In [None]:
import matplotlib.pyplot as plt
import pandas as pd

## Part 1 - dataset descriptive analysis
We remove the rows with missing values in any column and create plots for a descriptive analysis of the dataset. In particular, we generate pie charts and bar charts to visualize the distribution of classes in the columns of interest.


In [None]:
# Read the merged (MIXED) dataset fully into memory as a DataFrame.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
file_path = r"/media/ssd/Cleaned_datasets/000_dataset/dataset/000_cleaned_MIXED_dataset.csv"
file = pd.read_csv(file_path)  # the comma separator is pandas' default

In [None]:
# Inspect the raw table: (n_rows, n_columns) together with the column labels
(file.shape, file.columns)

In [None]:
# Number of NaN/None entries per column (isna is the canonical alias of isnull)
file.isna().sum()

In [None]:
# Build the working dataset:
#   1. drop every row with a NaN in any column;
#   2. drop rows whose Taxonomy is the "-" placeholder, which isna()/dropna() do not detect;
#   3. renumber the rows from 0.
dataset = (
    file
    .dropna(how='any')
    .loc[lambda frame: frame['Taxonomy'] != '-']
    .reset_index(drop=True)
)

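Missing values are often reported as "-" instead of being left empty.
This can undermine the analysis, because the `isnull()` function does not count them as missing values (the cell above already removes "-" entries from the Taxonomy column).
For this reason, we carry out an additional check on the Host column.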

In [None]:
# Count the rows whose Host is the "-" placeholder, i.e. undetermined hosts that
# isnull()/dropna() cannot see.
# Fix: the previous `!=` comparison counted the *determined* hosts instead.
undetermined_host = (dataset['Host'] == '-').sum()
print(f"Rows with undetermined Host ('-'): {undetermined_host}")

## Part 2 - plots
This part of the code generates plots for the columns of interest. Column names may vary from one dataset to another, so adjust the list of columns below accordingly.

In [None]:
# Columns of interest for the plots below; names may differ between datasets.
name = ['Taxonomy', 'Completeness', 'Host', 'Lifestyle']

# Map each column of interest to the list of distinct values found in it
# (Series.unique keeps the order of first appearance).
labels = {column_name: dataset[column_name].unique().tolist() for column_name in name}

In [None]:
# ----------------------------------------------------------------------------------------------------------
# PIE CHART 1: HOST DISTRIBUTION
# ----------------------------------------------------------------------------------------------------------

# Number of most frequent hosts drawn as their own wedge; all remaining hosts are pooled together
top_n = 3

# Frequency of each host; value_counts() sorts by count in descending order
counts = dataset["Host"].value_counts()

# The index holds host names, so use .iloc to make the positional slicing explicit
top_categories = counts.iloc[:top_n]
other_count = counts.iloc[top_n:].sum()

host_summary = top_categories.copy()
# Only add the pooled wedge when something was actually pooled; otherwise the chart
# would show an empty "0.0%" slice. If a real host is already called "Other", use a
# different label so its count is not overwritten.
if other_count > 0:
    other_label = "Other" if "Other" not in host_summary.index else "Other (remaining hosts)"
    host_summary[other_label] = other_count

# Plot on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(6, 6))
host_summary.plot(kind='pie', autopct='%1.1f%%', ax=ax)

ax.set_title("Distribution of Host")
ax.set_ylabel('')  # hide the default y-axis label drawn next to the pie
fig.tight_layout()
plt.show()

In [None]:
# ----------------------------------------------------------------------------------------------------------
# PIE CHART 2: COMPLETENESS AND LIFESTYLE
# ----------------------------------------------------------------------------------------------------------

# One pie per column of interest, skipping Host (plotted above) and Taxonomy (bar chart below)
for column in (k for k in labels if k not in {'Host', 'Taxonomy'}):
    counts = dataset[column].value_counts()
    total = counts.sum()

    # Legend entries "<class> (<share>%)"; the wedges themselves carry no text (no autopct)
    labels_with_pct = []
    for category, count in counts.items():
        labels_with_pct.append(f"{category} ({count / total:.1%})")

    fig, ax = plt.subplots(figsize=(6, 6))
    wedges, texts = ax.pie(counts, startangle=90)
    # Anchor the legend just outside the right edge of the axes
    ax.legend(wedges, labels_with_pct, title=column, loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))

    ax.set_title(f'Distribution of {column}')
    plt.tight_layout()
    plt.show()

In [None]:
# ----------------------------------------------------------------------------------------------------------
# HISTOGRAM: DISTRIBUTION OF TAXONOMY
# ----------------------------------------------------------------------------------------------------------

# Excluding the columns already shown as pie charts leaves only Taxonomy
for column in (k for k in labels if k not in {'Host', 'Lifestyle', 'Completeness'}):
    counts = dataset[column].value_counts()
    total = counts.sum()

    plt.figure(figsize=(8, 5))
    bars = plt.bar(counts.index.astype(str), counts.values)

    # Write each bar's share of all rows, centred just above the top of the bar
    for bar, count in zip(bars, counts):
        x_centre = bar.get_x() + bar.get_width() / 2
        plt.text(x_centre, bar.get_height(), f'{count / total * 100:.1f}%',
                 ha='center', va='bottom', fontsize=9)

    plt.title(f'Distribution of {column}')
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.show()

If the bar plot is too cluttered due to the large number of classes, use the cell below instead: it keeps only the most frequent classes and groups the remaining ones into an "Other" bar.

In [None]:
# ----------------------------------------------------------------------------------------------------------
# HISTOGRAM: DISTRIBUTION OF TAXONOMY 2 (most frequent classes + "Other")
# ----------------------------------------------------------------------------------------------------------

# Number of most frequent classes drawn as individual bars; the rest are pooled into one bar
TOP_K = 8

for column in [k for k in labels.keys() if k not in ['Host', 'Lifestyle', 'Completeness']]:
    counts = dataset[column].value_counts()

    # value_counts() is sorted by frequency, so the first TOP_K entries are the most numerous
    # classes. The index holds class names, so .iloc makes the positional slicing explicit.
    top_classes = counts.iloc[:TOP_K]
    others_sum = counts.iloc[TOP_K:].sum()

    # Add the pooled bar only when there is something to pool
    if others_sum > 0:
        counts_reduced = top_classes.copy()
        # Do not overwrite a genuine class that happens to be called "Other"
        other_label = 'Other' if 'Other' not in counts_reduced.index else 'Other (remaining classes)'
        counts_reduced[other_label] = others_sum
    else:
        counts_reduced = top_classes

    # Pooling keeps every row, so this equals counts.sum()
    total = counts_reduced.sum()

    plt.figure(figsize=(8, 5))
    bars = plt.bar(counts_reduced.index.astype(str), counts_reduced.values)

    # Add percentages above the bars
    for bar, count in zip(bars, counts_reduced):
        percent = count / total * 100
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
                 f'{percent:.1f}%', ha='center', va='bottom', fontsize=9)

    plt.title(f'Distribution of {column}')
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.show()

## Part 3 - final dataset
Use this part of the code to remove Low-quality and Not-determined sequences (based on the Completeness column).
Moreover, remove illegal protein sequences before feeding the data into the ESM-2 model.

In [None]:
# Keep only the sequences whose Completeness is neither 'Low-quality' nor 'Not-determined',
# and index the result by Phage_ID.
# A boolean mask is used instead of drop(index=...) so the result does not depend on
# `dataset` having a unique index.
is_excluded_completeness = dataset['Completeness'].isin(['Low-quality', 'Not-determined'])
final_dataset = dataset.loc[~is_excluded_completeness].set_index('Phage_ID')
final_dataset

Use this part of the code to remove illegal sequences (sequences containing characters that do not belong to the standard amino acid alphabet). If a protein sequence contains illegal characters, the ESM-2 model raises an error during computation.

In [None]:
import re

# Function to identify valid sequences
def is_valid_sequence(seq, allowed = "ACDEFGHIKLMNPQRSTVWY"):
    """Return True if `seq` is a non-empty string made only of characters in `allowed`.

    Parameters
    ----------
    seq : object
        Candidate protein sequence. Non-string values (e.g. a NaN read from a CSV
        cell) are treated as invalid instead of raising.
    allowed : str
        Accepted characters. The default is the 20 standard one-letter amino acid
        codes. The check is case-sensitive.

    Returns
    -------
    bool
    """
    # Missing values arrive as float NaN; re.fullmatch would raise TypeError on them.
    # An empty `allowed` would also produce an invalid "[]+" pattern.
    if not isinstance(seq, str) or not allowed:
        return False
    # re.escape keeps characters such as "^", "-", "]" or "\\" literal inside the class
    return re.fullmatch(f"[{re.escape(allowed)}]+", seq) is not None

def _split_by_sequence_validity(df, sequence_column="Sequence"):
    """Split `df` into (valid_df, invalid_df) by `is_valid_sequence` on `sequence_column`.

    Both returned frames have a fresh 0..n-1 index.
    """
    # astype(bool) guarantees a boolean mask even when `df` is empty
    valid_mask = df[sequence_column].apply(is_valid_sequence).astype(bool)
    valid_df = df[valid_mask].reset_index(drop=True)
    invalid_df = df[~valid_mask].reset_index(drop=True)
    return valid_df, invalid_df

def clean_invalid_sequences(input_path, output_path, invalid_output_path):
    """Read a CSV, split its rows by protein-sequence validity and save both parts.

    Parameters
    ----------
    input_path : str
        CSV file with a "Sequence" column.
    output_path : str
        Destination CSV (written without index) for rows with a valid sequence.
    invalid_output_path : str
        Destination CSV (written without index) for rows with an invalid or missing sequence.
    """
    df = pd.read_csv(input_path)

    valid_df, invalid_df = _split_by_sequence_validity(df)

    valid_df.to_csv(output_path, index=False)
    invalid_df.to_csv(invalid_output_path, index=False)

    print(f"✅ Valid sequences: {len(valid_df)} saved in {output_path}")
    print(f"❌ Invalid sequences: {len(invalid_df)} saved in {invalid_output_path}")

In [None]:
# Split the raw merged file into valid / invalid protein sequences and write both sets to disk.
# NOTE(review): the input is the raw file (file_path), not final_dataset, so this report also
# covers rows that were removed earlier (missing values, Low-quality / Not-determined).
# The output file names replace garbled placeholder names.
OUTPUT_DIR = '/home/squarna/Desktop'
clean_invalid_sequences(
    input_path = file_path,
    output_path = f'{OUTPUT_DIR}/valid_sequences_MIXED_dataset.csv',
    invalid_output_path = f'{OUTPUT_DIR}/invalid_sequences_MIXED_dataset.csv'
)

In [None]:
# Save the cleaned dataset: no Low-quality / Not-determined sequences AND no illegal sequences.
# Fix: final_dataset had only been filtered on Completeness, so rows whose Sequence contains
# characters outside the allowed amino acid alphabet (see is_valid_sequence) were still saved.
valid_sequence_mask = final_dataset['Sequence'].apply(is_valid_sequence).astype(bool)
final_dataset_valid = final_dataset.loc[valid_sequence_mask]
print(f"Removed {(~valid_sequence_mask).sum()} rows with illegal sequences; "
      f"saving {len(final_dataset_valid)} rows.")
final_dataset_valid.to_csv('/home/squarna/Desktop/cleaned_MIXED_dataset.csv')