In [None]:
# import libraries
from pathlib import Path
from typing import List, Dict, Union
from typing import Iterator, Dict
from Bio import SeqIO

import matplotlib.pyplot as plt
import pandas as pd
import random
import numpy as np
import tarfile
import csv
import time
import os

## Part 1 - decompress files
.fasta files can be packaged into a single .tar.gz archive. This part of the code is intended to extract the archive and retrieve the original .fasta files.


In [None]:
# 📁 Path of the .tar.gz archive
archive_path = r"/home/squarna/Desktop/input/STV.tar.gz"  

# 📂 Destination folder where to extract the files
extraction_path = r"/home/squarna/Desktop/input"

# 🛠️ Create the folder if it does not exist
os.makedirs(extraction_path, exist_ok=True)

# Resolve the destination once so every archive member can be checked against it
destination_root = os.path.realpath(extraction_path)

# 🔓 Extract the .fasta files (or all)
with tarfile.open(archive_path, "r:gz") as tar:
    for member in tar.getmembers():
        if not member.name.endswith(".fasta"):  # or remove this check to extract everything
            continue

        # Member names come from the archive itself: an absolute path or a
        # "../" component would otherwise be written outside extraction_path.
        # NOTE: this checks the member's own path only, not symlink targets.
        target = os.path.realpath(os.path.join(destination_root, member.name))
        try:
            is_inside = os.path.commonpath([destination_root, target]) == destination_root
        except ValueError:  # paths that cannot be compared (e.g. different drives)
            is_inside = False

        if not is_inside:
            print(f"⚠️ Skipped unsafe archive member: {member.name}")
            continue

        tar.extract(member, path=extraction_path)
    print(f"✅ Extraction completed in: {extraction_path}")

## Part 2 - merge .fasta files and create a unique file
After decompression, you may obtain multiple .fasta files. This code merges them into a single file to facilitate downstream analysis.

In [None]:
# 📁 Path of the folder containing the .fasta files
input_folder = Path(r"/home/squarna/Desktop/input/TemPhD")

# 📤 Path of the unified file to create
output_file = Path(r"/home/squarna/Desktop/input/TEMPHD.fasta")

merged_count = 0

with open(output_file, 'w', encoding='utf-8') as outfile:
    # Sorted so the merged file has the same record order on every run
    # (glob order depends on the file system).
    for fasta_file in sorted(input_folder.glob("*.fasta")):
        # Never feed the file being written back into itself
        # (possible if the output is placed inside the input folder).
        if fasta_file.resolve() == output_file.resolve():
            continue

        try:
            with open(fasta_file, 'r', encoding='utf-8') as infile:
                # Read the whole file before writing, so a decoding error
                # does not leave a half-copied file in the output.
                content = infile.read()

            outfile.write(content)
            # A file without a final newline would glue the next file's
            # ">" header onto this file's last sequence line.
            if content and not content.endswith("\n"):
                outfile.write("\n")

            merged_count += 1
            print(f"✅ Aggiunto: {fasta_file.name}")
        except Exception as e:
            # Best effort: report the problematic file and keep merging the others
            print(f"⚠️ Errore con {fasta_file.name}: {e}")

print(f"\n📦 Combinati {merged_count} file in: {output_file.name}")

## Part 3 - create a class to merge data and metadata
Here, I created a class that merges each .fasta file with its corresponding metadata and removes duplicate Phage_ID entries.
If you want to use this class for different datasets, you might need to change column names.

In [None]:
class Merger:
    """Combine sequences from a FASTA file with tab-separated metadata.

    Args:
        input_data: Path to the FASTA file.
        input_meta: Path to the tab-separated metadata file.
        id_column: Name of the metadata column holding the record
            identifier that is matched against the FASTA record id.
            Defaults to ``"Phage_ID"``.
        meta_columns: Metadata columns (besides ``id_column``) to keep in the
            final dataset. Defaults to ``DEFAULT_META_COLUMNS``.
    """

    # Metadata columns kept in the final dataset unless the caller overrides them
    DEFAULT_META_COLUMNS = (
        'Length', 'GC_content', 'Taxonomy', 'Completeness', 'Host',
        'Lifestyle', 'Cluster', 'Subcluster',
    )

    def __init__(self, input_data: str, input_meta: str,
                 id_column: str = "Phage_ID", meta_columns=None):
        self.input_file = Path(input_data)
        self.input_metafile = Path(input_meta)
        self.id_column = id_column
        self.meta_columns = list(
            self.DEFAULT_META_COLUMNS if meta_columns is None else meta_columns
        )

    def create_database(self) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Read the FASTA and metadata files and de-duplicate both by identifier.

        Returns:
            A ``(database, meta_file)`` tuple. ``database`` has the columns
            ``"ID"`` and ``"Sequence"`` with one row per distinct record id
            (first occurrence kept). ``meta_file`` is the metadata table with
            one row per distinct ``id_column`` value (first occurrence kept).

        Raises:
            KeyError: If the metadata file has no ``id_column`` column.
        """
        records = []
        with open(self.input_file) as fasta_file:
            for seq_record in SeqIO.parse(fasta_file, "fasta"):
                records.append((seq_record.id, str(seq_record.seq)))

        database = pd.DataFrame(records, columns=["ID", "Sequence"])

        # Remove duplicates from the ID column; reset the index so it stays
        # contiguous, as is done for the metadata below
        database = database.drop_duplicates(subset="ID").reset_index(drop=True)

        # Load metadata
        meta_file = pd.read_csv(self.input_metafile, sep='\t')

        id_col = self.id_column
        if id_col not in meta_file.columns:
            raise KeyError(
                f"Column '{id_col}' not found in metadata file {self.input_metafile}"
            )

        # Check for duplicates in the identifier column
        if meta_file[id_col].duplicated().any():
            print(f"⚠️ There are duplicate {id_col} entries in metadata!")
            meta_file = meta_file.drop_duplicates(subset=id_col).reset_index(drop=True)
            print(f"✅ Duplicate {id_col} entries removed from metadata!")
        else:
            print(f"✅ No duplicates in {id_col} within metadata.")

        return database, meta_file

    def create_final_database(self, database: pd.DataFrame, meta_file: pd.DataFrame) -> pd.DataFrame:
        """Inner-join sequences and metadata on the record identifier.

        Args:
            database: Frame with at least the columns ``"ID"`` and ``"Sequence"``.
            meta_file: Metadata frame containing ``id_column`` and every
                column listed in ``meta_columns``.

        Returns:
            A new frame holding only the identifiers present in both inputs.
            The identifier column is named after ``id_column`` and is followed
            by the remaining ``database`` columns and the selected metadata
            columns. The inputs are not modified.

        Raises:
            KeyError: If ``meta_file`` lacks any of the required columns.
        """
        id_col = self.id_column
        wanted = [id_col] + [col for col in self.meta_columns if col != id_col]

        missing = [col for col in wanted if col not in meta_file.columns]
        if missing:
            raise KeyError(f"Metadata is missing required column(s): {missing}")

        meta_subset = meta_file[wanted]

        # When both sides already share the key name there is no duplicate
        # key column to remove afterwards
        if id_col == "ID":
            return pd.merge(database, meta_subset, on="ID", how="inner")

        data_completed = pd.merge(
            database,
            meta_subset,
            left_on="ID",
            right_on=id_col,
            how="inner"
        )

        # The merge keeps both key columns: drop the metadata copy and give the
        # sequence key its final name. No in-place mutation, so re-running is safe.
        return data_completed.drop(columns=id_col).rename(columns={'ID': id_col})

## Part 4 - use the above class to create a dataset

In [None]:
# Input locations: sequences (.fasta) and their metadata (.tsv)
file_path = r"/home/squarna/Desktop/input/STV.fasta"
meta_path = r"/home/squarna/Desktop/input/stv_phage_meta_data.tsv"

# Build the merger, then load both inputs as DataFrames
merger = Merger(input_data=file_path, input_meta=meta_path)
database, meta_file = merger.create_database()

# Join sequences with metadata: this is the final dataset of interest
final_data = merger.create_final_database(database=database, meta_file=meta_file)

## Part 5 – Check the output dataset (final_data)

Metadata files (.tsv) and sequence files (.fasta) often do not have the same number of entries.  
For this reason, it is important to check their lengths:

1. **If len(file.fasta) > len(file.tsv)** → the .fasta file contains duplicated Phage_ID/sequences, or the .tsv file is missing some entries.  
2. **If len(file.fasta) < len(file.tsv)** → the .tsv file contains duplicated Phage_ID/entries, or it has more entries than the .fasta file.  
3. **If len(file.fasta) = len(file.tsv)** → this is a good sign; the files are most likely consistent.  


In [None]:
# Row counts, in order: sequences, metadata, merged result
for frame in (database, meta_file, final_data):
    print(len(frame))

# Columns available in the merged dataset
final_data.columns

## Part 6 - save the dataset in .csv format

In [None]:
# Destination of the exported dataset
csv_output_path = r"/home/squarna/Desktop/input/STV.csv"

# Write the merged dataset without the index column, then confirm
final_data.to_csv(path_or_buf=csv_output_path, index=False)
print(f"✅ File successfully saved in: {csv_output_path}")

## Part 7 - dataset analysis
We remove the missing values from all columns and create plots for a descriptive analysis of the dataset. In particular, we generate histograms to visualize the distribution of classes in the columns of interest.


In [None]:
# Load the comma-separated dataset to analyse
file_path = r"/media/ssd/Cleaned_datasets/000_dataset/000_cleaned_MIXED_dataset.csv"
file = pd.read_csv(filepath_or_buffer=file_path, sep=",")

In [None]:
# Check the dataset after merging operation: shows (rows, columns) and the column names
file.shape, file.columns

In [None]:
# Number of missing values per column
file.isna().sum()

In [None]:
# Build the working dataset: drop every row with a missing value, drop rows
# whose Taxonomy is the '-' placeholder, and renumber the index from 0
dataset = (
    file
    .dropna(how='any')
    .loc[lambda frame: frame['Taxonomy'] != '-']
    .reset_index(drop=True)
)

This part of the code generates plots for the columns of interest. Column names may vary from one dataset to another.

In [None]:
# Define column names of interest
name = ['Taxonomy', 'Completeness', 'Host', 'Lifestyle']

# Map each column of interest to the list of distinct values it contains
labels = {column_name: dataset[column_name].unique().tolist() for column_name in name}

In [None]:
# plot
# Choose how many categories to keep
top_n = 3

# Calculate frequencies (value_counts returns them from most to least frequent)
counts = dataset["Host"].value_counts()

# Select the top N by position
top_categories = counts.iloc[:top_n]

# Calculate the sum of the remaining ones
other_count = counts.iloc[top_n:].sum()

# Create a new Series with the top categories; add "Other" only when something
# is left over, so the pie never gets an empty 0.0% wedge
host_summary = top_categories.copy()
if other_count > 0:
    host_summary["Other"] = other_count

# Pie chart
host_summary.plot(kind='pie', autopct='%1.1f%%', figsize=(6, 6))
plt.title("Distribution of Host")
plt.ylabel('')
plt.tight_layout()
plt.show()

In [None]:
# Pie charts for every column of interest except 'Host' and 'Taxonomy'
pie_columns = [key for key in labels if key not in ('Host', 'Taxonomy')]

for column in pie_columns:
    counts = dataset[column].value_counts()
    total = counts.sum()

    # Legend entries: category name followed by its share of the total
    labels_with_pct = []
    for category, count in counts.items():
        labels_with_pct.append(f"{category} ({count / total:.1%})")

    fig, ax = plt.subplots(figsize=(6, 6))
    # No autopct: the percentages are shown in the legend instead
    wedges, texts = ax.pie(counts, startangle=90)
    ax.legend(
        wedges,
        labels_with_pct,
        title=column,
        loc="center left",
        bbox_to_anchor=(1, 0, 0.5, 1),
    )

    plt.title(f'Distribution of {column}')
    plt.tight_layout()
    plt.show()

In [None]:
# Bar charts for the columns of interest not already drawn as pies
bar_columns = [key for key in labels if key not in ('Host', 'Lifestyle', 'Completeness')]

for column in bar_columns:
    counts = dataset[column].value_counts()
    total = counts.sum()

    plt.figure(figsize=(8, 5))
    bars = plt.bar(counts.index.astype(str), counts.values)

    # Write each class's share of the total just above its bar
    for bar, count in zip(bars, counts):
        percent = count / total * 100
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(
            x_center,
            bar.get_height(),
            f'{percent:.1f}%',
            ha='center',
            va='bottom',
            fontsize=9,
        )

    plt.title(f'Distribution of {column}')
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.show()

If the bar plot is too cluttered due to the large number of classes, use this part of the code to produce a more readable plot.

In [None]:
# Same bar chart as before, but limited to the 8 most frequent classes
bar_columns = [key for key in labels if key not in ('Host', 'Lifestyle', 'Completeness')]

for column in bar_columns:
    counts = dataset[column].value_counts()

    # Split into the 8 most frequent classes and the total of everything else
    top_classes = counts.iloc[:8]
    others_sum = counts.iloc[8:].sum()

    # Append an "Other" bar only when there is something left to group
    counts_reduced = top_classes
    if others_sum > 0:
        counts_reduced = top_classes.copy()
        counts_reduced['Other'] = others_sum

    total = counts_reduced.sum()

    plt.figure(figsize=(8, 5))
    bars = plt.bar(counts_reduced.index.astype(str), counts_reduced.values)

    # Write each class's share of the total just above its bar
    for bar, count in zip(bars, counts_reduced):
        percent = count / total * 100
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(
            x_center,
            bar.get_height(),
            f'{percent:.1f}%',
            ha='center',
            va='bottom',
            fontsize=9,
        )

    plt.title(f'Distribution of {column}')
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.show()

## Part 8 - check distribution
Optional step: plot the distribution of GC_content, first for the whole dataset and then for each taxonomic category of interest.

In [None]:
# plot
import seaborn as sns

# Categories of interest (matched case-insensitively inside the Taxonomy text)
categories = ['caudovirales', 'inoviridae', 'microviridae', 'riboviria']

# --- First plot: global distribution ---
global_gc = dataset['GC_content'].dropna()

plt.figure(figsize=(8, 5))
sns.histplot(global_gc, bins=30, kde=True, color='steelblue')
plt.title('Global distribution of GC_content')
plt.xlabel('GC_content (%)')
plt.ylabel('Frequency')
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Separate plots for each category ---
# Lower-case the taxonomy once; every category reuses it
taxonomy_lower = dataset['Taxonomy'].str.lower()

for cat in categories:
    in_category = taxonomy_lower.str.contains(cat, na=False)
    subset = dataset[in_category]
    gc_values = subset['GC_content'].dropna()

    plt.figure(figsize=(8, 5))
    sns.histplot(gc_values, bins=30, kde=True)
    plt.title(f'Distribution of GC_content for {cat.capitalize()}')
    plt.xlabel('GC_content (%)')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

## Part 9 - final dataset 
Use this part of the code to remove Low-quality and Not-determined sequences (based on the Completeness column).
Then remove illegal protein sequences.

In [None]:
# Display the working dataset as it is before the Completeness filter is applied
dataset

In [None]:
# Clean the dataset by removing Low-quality and Not-determined sequences,
# then index the remaining rows by Phage_ID
is_excluded = dataset['Completeness'].isin(['Low-quality', 'Not-determined'])
final_dataset = dataset[~is_excluded].set_index('Phage_ID')

final_dataset

## Part 9.1 - illegal sequences
Use this part of the code to remove illegal sequences (sequences containing illegal characters).

In [None]:
import re

# Function to identify valid sequences
def is_valid_sequence(seq, allowed="ACDEFGHIKLMNPQRSTVWY"):
    """Return True when ``seq`` is a non-empty string made only of ``allowed`` characters.

    The comparison is case-sensitive and every character of ``allowed`` is
    taken literally, so characters such as ``-`` or ``]`` can be part of the
    alphabet without being read as pattern syntax.

    Args:
        seq: The sequence to check. Anything that is not a ``str`` (for
            example the NaN pandas uses for a missing cell) is invalid.
        allowed: The characters a valid sequence may contain.

    Returns:
        bool: ``True`` if ``seq`` is valid, ``False`` otherwise.
    """
    if not isinstance(seq, str) or not seq:
        return False
    return set(seq) <= set(allowed)

def clean_invalid_sequences(input_path, output_path, invalid_output_path):
    """Split a CSV dataset into rows with valid and rows with invalid sequences.

    Reads ``input_path`` with ``pd.read_csv``, checks the ``Sequence`` column of
    every row with ``is_valid_sequence`` and writes two CSV files without the
    index column. A summary line is printed for each file.

    Args:
        input_path: CSV file to read; it must contain a ``Sequence`` column.
        output_path: Where the rows with a valid sequence are written.
        invalid_output_path: Where the remaining rows are written.

    Raises:
        KeyError: If the input has no ``Sequence`` column.
    """
    df = pd.read_csv(input_path)

    if "Sequence" not in df.columns:
        raise KeyError(f"Column 'Sequence' not found in {input_path}")

    # Validity mask. Empty cells are read as NaN (a float): count them as
    # invalid instead of passing them to the validator. The cast keeps the
    # mask boolean even when the frame has no rows.
    valid_mask = df["Sequence"].map(
        lambda seq: isinstance(seq, str) and is_valid_sequence(seq)
    ).astype(bool)

    # Separation
    valid_df = df[valid_mask].reset_index(drop=True)
    invalid_df = df[~valid_mask].reset_index(drop=True)

    # Saving
    valid_df.to_csv(output_path, index=False)
    invalid_df.to_csv(invalid_output_path, index=False)

    print(f"✅ Valid sequences: {len(valid_df)} saved in {output_path}")
    print(f"❌ Invalid sequences: {len(invalid_df)} saved in {invalid_output_path}")

In [None]:
# Split the CSV at file_path into valid and invalid sequences.
# NOTE(review): when the cells run top to bottom, file_path is the raw CSV read
# in Part 7, not final_dataset, so this split does not affect final_dataset.
# NOTE(review): both output file names look like placeholders, and no later
# cell shown here reads them - confirm the intended destinations.
clean_invalid_sequences(
    input_path = file_path,
    output_path = r'/home/squarna/Desktop/csssleaned_MIXED_dataset.csv',
    invalid_output_path = r'/home/squarna/Desktop/cmerdaaaaaaleaned_MIXED_dataset.csv'
)

In [None]:
# Save the cleaned dataset (no illegal sequences and low-quality/not-determined sequences)
# final_dataset has only been filtered on Completeness so far, so the rows with
# illegal sequences are dropped here, right before saving.
legal_sequence_mask = final_dataset['Sequence'].apply(is_valid_sequence).astype(bool)
print(f"Rows removed for illegal sequences: {int((~legal_sequence_mask).sum())}")

final_dataset[legal_sequence_mask].to_csv('/home/squarna/Desktop/cleaned_MIXED_dataset.csv')

In [None]:
# Final check 🤪
# Count the rows that still carry one of the excluded Completeness labels
counter = int(final_dataset['Completeness'].isin(['Low-quality', 'Not-determined']).sum())

print("✅ Ok" if counter == 0 else f"⚠️ KO {counter}")


                                       