# Preliminaries: Inspect and Set up environment

In [None]:
# Import all libraries required

# Data Processing and EDA
import datetime
import pandas as pd
import numpy as np

# For bioinformatics tasks
from Bio import SeqIO
# older alignment method
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
# newer alignment method - not using this, 
# but just don't want to forget this option
from Bio import Align
from Bio.Align import PairwiseAligner
import multiprocessing

# For Machine Learning
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

# For Evaluation
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from seaborn import heatmap
from sklearn.inspection import partial_dependence
import matplotlib.pyplot as plt

In [None]:
# Environment
# Show all the output for every print not just the last
from IPython.core.interactiveshell import InteractiveShell
# Configuration and settings
InteractiveShell.ast_node_interactivity = "all"
# To check if in Google Colab
from IPython.core.getipython import get_ipython
# To display all the output in a nicer table
from IPython.display import display
# To time the execution of the code
import time
import os.path

In [None]:
print(datetime.datetime.now())

In [None]:
!which python

In [None]:
!python --version

In [None]:
!echo $PYTHONPATH

In [None]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

In [None]:
#if 'google.colab' in str(get_ipython()):
    # TODO: if running on Google Colab, install any packages you need to here. For example:
    #!pip install unidecode
    #!pip install category_encoders
    #!pip install scikeras

In [None]:
# Let's minimize randomness
# numpy
np.random.seed(42)

Start the timer

In [None]:
# Start timer
start_time = time.time()

In [None]:
import joblib  # For saving and loading models

In [None]:
def check_duplicate_headers(records):
    """Find records whose FASTA header repeats the header of an earlier record.

    The header compared is ``record.description``. The first record carrying a
    given header counts as the original; every later one counts as a duplicate.

    Args:
        records: A list of SeqRecord objects (anything exposing ``description``).

    Returns:
        list: Positions in ``records`` of the repeated records.
        list: The headers found at those positions, in the same order.
    """
    first_seen = set()
    repeat_positions, repeat_headers = [], []
    for position, entry in enumerate(records):
        description = entry.description  # Or use entry.id
        if description not in first_seen:
            # First occurrence: remember it and move on
            first_seen.add(description)
            continue
        repeat_positions.append(position)
        repeat_headers.append(description)
    return repeat_positions, repeat_headers

In [None]:
def get_sequence_length(row, column):
    """Return the length of the sequence stored under `column` in `row`."""
    sequence = row[column]
    return len(sequence)

In [None]:
# Define k-mer length
# kmer_length = 12
kmer_length = 6

In [None]:
# Function to extract kmers (can be reused)
def get_kmers(sequence, k):
  """
  Return every k-mer (contiguous slice of length k) of a sequence, left to right.

  Consecutive k-mers overlap by k - 1 positions. A sequence shorter than k
  yields an empty list.
  """
  window_count = len(sequence) - k + 1
  return [sequence[start:start + k] for start in range(window_count)]

In [None]:
# Define chunk size (adjust as needed)
chunk_size = 100

In [None]:
def process_chunk(chunk_dict):
    """
    One-hot encode the HA k-mers of every entry in a chunk of the k-mer dictionary.

    The encoder is cached on the function itself as ``process_chunk.ohe``. When
    that attribute is missing, a new encoder is fitted on the unique HA k-mers
    of this chunk. When it already exists (set by an earlier call or assigned
    from outside), it is reused unchanged and no fitting takes place.

    Args:
        chunk_dict: Mapping whose values are dicts holding a list of k-mers under "HA".

    Returns:
        list: One vector per entry, in the iteration order of ``chunk_dict``,
        obtained by summing the one-hot rows of that entry's k-mers.
    """
    entries = list(chunk_dict.values())

    # Fit only when no encoder has been attached to the function yet
    if not hasattr(process_chunk, 'ohe'):
        vocabulary = {kmer for entry in entries for kmer in entry["HA"]}
        process_chunk.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        process_chunk.ohe.fit([[kmer] for kmer in vocabulary])  # One single-feature sample per unique k-mer

    # Collapse each entry's one-hot matrix (one row per k-mer) into a single vector
    return [
        process_chunk.ohe.transform([[kmer] for kmer in entry["HA"]]).sum(axis=0)
        for entry in entries
    ]

In [None]:
def format_runtime(total_time):
    """
    Formats a given runtime (in seconds) into a human-readable string,
    omitting zero-value components (days, hours, minutes).

    The seconds component is always present and keeps its fractional part
    (two decimals), e.g. ``"1 hours, 1 minutes and 1.25 seconds"``.

    Args:
        total_time (float): The total runtime in seconds.

    Returns:
        str: A formatted string representing the runtime.
    """
    # Convert total time in seconds to a timedelta object
    td = datetime.timedelta(seconds=total_time)

    # Extract days, hours, minutes, and seconds from the timedelta
    days = td.days
    hours, remainder = divmod(td.seconds, 3600)
    minutes, whole_seconds = divmod(remainder, 60)
    # td.seconds is an integer; the sub-second part is held in td.microseconds
    seconds = whole_seconds + td.microseconds / 1_000_000

    # Collect the non-zero components together with their labels
    time_components = []
    if days > 0:
        time_components.append(f"{days} days")
    if hours > 0:
        time_components.append(f"{hours} hours")
    if minutes > 0:
        time_components.append(f"{minutes} minutes")
    time_components.append(f"{seconds:.2f} seconds")  # Always include seconds

    # A lone component is returned as-is; otherwise join with commas and a final "and"
    if len(time_components) == 1:
        return time_components[0]
    return ", ".join(time_components[:-1]) + " and " + time_components[-1]

# Load the Model, LabelEncoder and One-Hot Encoder

In [None]:
# Load the model and LabelEncoder
# The file is expected to hold a dict with the keys 'model', 'label_encoder'
# and 'one_hot_encoder' (the three entries read below).
# NOTE(review): joblib files are pickle-based; only load files from a trusted source.
model_filename = "model_learn-flu-train_2024-08-19_01:50:52.joblib" # Update the filename of the trained model as required
model_data = joblib.load(model_filename)
model = model_data['model']
le = model_data['label_encoder']
# Attaching the stored encoder to process_chunk makes it skip its own fitting
# step and reuse this encoder for every chunk
process_chunk.ohe = model_data['one_hot_encoder'] 

# 5.0 Test

## 5.1 Load and Prepare Unseen Test Data

In [None]:
# Parse the FASTA file for the new dataset
test_records = list(SeqIO.parse("/home/ajvilleg/Netdrive/AI/GISAID/EpiFlu_Test/11-Aug-2024_Oceania/gisaid_epiflu_sequence_11-Aug-2024.fasta", "fasta"))

In [None]:
# Detect repeated FASTA headers before the header fields are parsed
duplicate_indices, duplicate_headers = check_duplicate_headers(test_records)
if duplicate_indices:
    print("Warning: Found the following duplicate FASTA headers:")
    print(*duplicate_headers, sep="\n")
    print("One copy of each duplicate record will be kept.")

    # Every position that was not reported as a repeat survives
    indices_to_keep = set(range(len(test_records))) - set(duplicate_indices)

    # Rebuild the list from the surviving positions, preserving file order
    test_records = [test_records[position] for position in sorted(indices_to_keep)]
    print("Duplicate records have been removed, keeping one copy of each.")

In [None]:
# Extract the details from the description of each record.
# Records are consumed two at a time: each consecutive pair is expected to be
# the HA and NA segment of the same isolate, described by a '|'-separated header.
test_data = []

# Header fields that must be identical in both records of a pair:
# (position in the '|'-separated description, label used in the messages)
shared_fields = [
    (0, "Isolate names"),
    (1, "Isolate IDs"),
    (2, "Flu types"),
    (3, "Lineages"),
    (5, "Collection dates"),
    (6, "Clades"),  # Important: the clade is the label for classification
]
segment_position = 4  # The segment label differs between the two records

# zip() over the even/odd slices stops at the shorter one, so an unpaired
# trailing record would otherwise disappear without any trace
if len(test_records) % 2 != 0:
    print(f"Warning: {len(test_records)} records cannot be fully paired; the last record is ignored.")

# Iterate through records for every pair for NA and HA segments
for record1, record2 in zip(test_records[::2], test_records[1::2]):
    description1 = [field.strip() for field in record1.description.split('|')]
    description2 = [field.strip() for field in record2.description.split('|')]

    # Both records must describe the same isolate
    for position, label in shared_fields:
        value1, value2 = description1[position], description2[position]
        if value1 != value2:
            print(f"{label} do not match: {value1} vs {value2}")
            raise ValueError(f"{label} do not match")

    # The sequences will be different, corresponding to the NA and HA segments
    segment1 = description1[segment_position]
    segment2 = description2[segment_position]
    sequence1 = str(record1.seq)
    sequence2 = str(record2.seq)
    if segment1 == 'HA':
        sequence_ha = sequence1
        sequence_na = sequence2
    elif segment2 == 'HA':
        sequence_ha = sequence2
        sequence_na = sequence1
    else:
        # Without this check a non-HA sequence would silently land in the HA column
        print(f"Neither record is an HA segment: {segment1} vs {segment2}")
        raise ValueError("Neither record is an HA segment")

    isolate_name, isolate_id, flu_type, lineage = description1[:4]
    collection_date, clade = description1[5], description1[6]
    test_data.append([isolate_name, isolate_id, flu_type, lineage, sequence_ha, sequence_na, collection_date, clade])

test_df = pd.DataFrame(test_data, columns=['Isolate_Name', 'Isolate_ID', 'Flu_Type', 'Lineage', 'HA', 'NA', 'Collection Date', 'Clade'])


In [None]:
# Take a look at the data
display(test_df)

In [None]:
# Extract Type using regular expressions: capture the first "A", "B" or "C"
# found in Flu_Type; rows without a match get an empty string
test_df['Type'] = test_df['Flu_Type'].astype(str).str.extract(r'(A|B|C)').fillna('')

In [None]:
# Extract H_Subtype and N_Subtype with updated regex, allowing for one or more digits after H or N.
# Rows where Flu_Type has no such token get an empty string.
test_df['H_Subtype'] = test_df['Flu_Type'].astype(str).str.extract(r'(H\d+)').fillna('')
test_df['N_Subtype'] = test_df['Flu_Type'].astype(str).str.extract(r'(N\d+)').fillna('')
# Print how often each subtype occurs, as markdown tables
print(test_df['H_Subtype'].value_counts().to_markdown(numalign="left", stralign="left"))
print(test_df['N_Subtype'].value_counts().to_markdown(numalign="left", stralign="left"))

In [None]:
# Take a look at the data again
display(test_df)

## 5.2 EDA

### 5.2.1 Dataframe structure

In [None]:
test_df.info()

In [None]:
# Convert every column to strings. This includes "Collection Date", which is
# turned into a datetime right below.
test_df = test_df.astype(str)

# Convert "Collection Date" column to date
test_df["Collection Date"] = pd.to_datetime(test_df["Collection Date"])

# Show the resulting column dtypes
test_df.info()

### 5.2.2 Describe

In [None]:
test_df.describe()

### 5.2.3 Duplicated rows

In [None]:
# Check for duplicated rows in test data
print(f'test_df has {test_df.duplicated().sum()} duplicate rows')
display(test_df[test_df.duplicated()])
# Drop duplicates and check again
test_df.drop_duplicates(inplace=True)
print(f'test_df has {test_df.duplicated().sum()} duplicate rows')

### 5.2.4 Missing values / NaN / Empty Strings

In [None]:
# Check for missing values and empty strings
print("NaN values in test_df:")
print(test_df.isnull().sum())  # Check for NaN values
print("\nEmpty string values in test_df:")
for col in test_df.select_dtypes(include=['object']):  # Iterate over columns with string datatype
    print(f"{col}: {(test_df[col] == '').sum()}")     # Count empty strings

In [None]:
# Drop rows with nulls or empty strings in Clade, ignore Lineage nulls/empty strings
# Note: the replacement applies to every column, so empty strings in any
# column become missing values; only a missing Clade causes a row to be dropped.
test_df.replace('', pd.NA, inplace=True)  # Replace empty strings with NaN
test_df.dropna(subset=['Clade'], inplace=True)  # Drop rows where Clade is NaN

In [None]:
# Check for missing values and empty strings
print("NaN values in test_df:")
print(test_df.isnull().sum())  # Check for NaN values
print("\nEmpty string values in test_df:")
for col in test_df.select_dtypes(include=['object']):  # Iterate over columns with string datatype
    print(f"{col}: {(test_df[col] == '').sum()}")     # Count empty strings

### 5.2.5 Class Imbalance

#### Clade Imbalance

In [None]:
test_df['Clade'].value_counts()

In [None]:
# Drop rows with "unassigned" in 'Clade' from the test data
test_df = test_df[test_df['Clade'] != 'unassigned']  # Filter out rows with label "unassigned"

#### HACK: Drop test rows whose 'Clade' label is not present in the training data

In [None]:
# Drop every test row whose 'Clade' label is not among the labels of the
# LabelEncoder restored with the model (le.classes_ holds the labels it was
# fitted on). Filtering against le.classes_ covers the originally hard-coded
# "2.3.2.1c" as well as any other clade the encoder does not know, so a new
# test dataset does not need another hand-written filter.
known_clades = set(le.classes_)
unseen_clades = sorted(set(test_df['Clade'].dropna()) - known_clades)
if unseen_clades:
    print(f"Dropping test rows with clades not present in the training data: {unseen_clades}")
test_df = test_df[test_df['Clade'].isin(known_clades)]  # Keep only rows with a known label


In [None]:
# Temporary HACK: Drop rows with "Am_nonGsGD" in 'Clade' from the test data (TODO: which test data?), 
# as it is not present in the training data 
test_df = test_df[test_df['Clade'] != 'Am_nonGsGD']  # Filter out rows with label "Am_nonGsGD"

In [None]:
# Temporary HACK: Drop rows with "3C.2a1b.2a.1a.1" in 'Clade' from the test data (EpiFlu_Test/11-Aug-2024_Oceania/gisaid_epiflu_sequence_11-Aug-2024.fasta), 
# as it is not present in the training data (EpiFlu_Training/01-Aug-2024/gisaid_epiflu_sequence_01-Aug-2024_All_Hosts_Type_A_USA.fasta) 
test_df = test_df[test_df['Clade'] != '3C.2a1b.2a.1a.1']  # Filter out rows with label "3C.2a1b.2a.1a.1"

In [None]:
# Temporary HACK: Drop rows with "2.3.2.1a" in 'Clade' from the test data (EpiFlu_Test/11-Aug-2024_Oceania/gisaid_epiflu_sequence_11-Aug-2024.fasta), 
# as it is not present in the training data (EpiFlu_Training/01-Aug-2024/gisaid_epiflu_sequence_01-Aug-2024_All_Hosts_Type_A_USA.fasta) 
test_df = test_df[test_df['Clade'] != '2.3.2.1a']  # Filter out rows with label "2.3.2.1a"

In [None]:
test_df['Clade'].value_counts()

#### H_Subtype Imbalance

In [None]:
test_df['H_Subtype'].value_counts()

#### N_Subtype Imbalance

In [None]:
test_df['N_Subtype'].value_counts()

### 5.2.6 Shape

In [None]:
test_df.shape

### 5.2.7 Look at sequence length stats

In [None]:
# Per-row sequence lengths of the HA and NA columns
new_ha_sequence_lengths = test_df.apply(get_sequence_length, axis=1, column="HA")
new_na_sequence_lengths = test_df.apply(get_sequence_length, axis=1, column="NA")
print("Sequence lengths in HA columns:")
# Bare expression: it appears in the cell output because
# ast_node_interactivity is set to "all" at the top of the notebook
new_ha_sequence_lengths.describe()
print("Sequence lengths in NA columns:")
new_na_sequence_lengths.describe()

## 5.3 K-mers and k-mer encoding

In [None]:
# Create an empty dictionary to store kmers for each sequence in the new dataset (identified by row index)
test_kmer_dict = {}

In [None]:
# Fill test_kmer_dict with the HA k-mers of every row, keyed by the row's index label
for row_label, record in test_df.iterrows():
    if "HA" in record:  # Row carries an "HA" entry
        ha_kmers = get_kmers(str(record["HA"]), kmer_length)
    else:
        # No HA entry: keep an empty k-mer list for this row
        ha_kmers = []

    # Each row gets its own dict so further segments could be stored alongside
    test_kmer_dict[row_label] = {"HA": ha_kmers}

### 5.3.1 One-Hot Encoding using chunking to optimize memory usage

In [None]:
# One-Hot Encoding on new data using the same kmer chunking logic as before to optimize memory usage
# Apply the same k-mer extraction and chunking as with training data
test_ha_features = []

# Materialize the items once; rebuilding this list for every chunk made the
# loop quadratic in the number of sequences
kmer_items = list(test_kmer_dict.items())
for i in range(0, len(kmer_items), chunk_size):
    # Get a chunk of data
    chunk_dict = dict(kmer_items[i:i + chunk_size])

    # Process features for the chunk
    chunk_ha_features = process_chunk(chunk_dict)

    # Append features from the chunk
    test_ha_features.extend(chunk_ha_features)

## 5.4 Prediction

In [None]:
# OHE
# Convert list of lists to numpy array
X_new = np.array(test_ha_features)  

In [None]:
# Predict using the trained model
y_pred_new = model.predict(X_new)

# Decode predictions to get the original Clade labels
predicted_clades = le.inverse_transform(y_pred_new)

# Add predicted clades back to test_df. The assignment is positional: the
# features were built by iterating test_df row by row, so prediction k belongs
# to the k-th row.
test_df['Predicted_Clade'] = predicted_clades

# Display the data with predictions and true clades
display(test_df[['Isolate_Name', 'Isolate_ID', 'Collection Date', 'Clade', 'Predicted_Clade']])  # Display true and predicted clades

## 5.5 Evaluation

In [None]:

# Encode the true clade labels with the same LabelEncoder as the predictions,
# so that both can be compared by the metrics below.
# NOTE(review): this presumably fails for clades the encoder has never seen,
# which is why such rows are filtered out of test_df earlier - confirm.
y_true_new = le.transform(test_df['Clade']) # Encode true labels

In [None]:
# Evaluate Predictions
# All scores compare the encoded true labels with the encoded predictions.
print("\n### Model Evaluation on Test Dataset ###")

# Accuracy: Proportion of correctly predicted samples
accuracy = accuracy_score(y_true_new, y_pred_new)
print("Accuracy:", accuracy)

# Precision: Ratio of true positives to all predicted positives
precision = precision_score(y_true_new, y_pred_new, average='weighted') # Weighted average for multi-class
print("Precision:", precision)

# Recall: Ratio of true positives to all actual positives
recall = recall_score(y_true_new, y_pred_new, average='weighted') # Weighted average for multi-class
print("Recall:", recall)

# F1-score: Harmonic mean of precision and recall
f1 = f1_score(y_true_new, y_pred_new, average='weighted') # Weighted average for multi-class
print("F1-score:", f1)

#### Classification Report

In [None]:

# Classification Report
print("\nClassification Report:")
# Label the report rows with clade names instead of the encoder's integer
# codes. Only classes that occur in the true or predicted labels are listed,
# which is the same set of rows the report shows by default.
report_labels = np.unique(np.concatenate([y_true_new, y_pred_new]))
print(classification_report(
    y_true_new,
    y_pred_new,
    labels=report_labels,
    target_names=le.inverse_transform(report_labels).tolist(),
))

#### Confusion Matrix

In [None]:
# Confusion Matrix with Seaborn
print("\nConfusion Matrix:")
# Fix the row/column order explicitly (sorted classes present in the true or
# predicted labels, i.e. the default order) so that the axis ticks can carry
# the matching clade names instead of bare positions 0..k-1
cm_labels = np.unique(np.concatenate([y_true_new, y_pred_new]))
cm_clades = le.inverse_transform(cm_labels).tolist()
cm = confusion_matrix(y_true_new, y_pred_new, labels=cm_labels)

# Create a new figure for the confusion matrix
plt.figure(figsize=(8, 6))

# Create heatmap using seaborn: integer annotations, blue colormap, clade names on both axes
heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=cm_clades, yticklabels=cm_clades)

# Add labels and title
plt.xlabel("Predicted Clade")
plt.ylabel("True Clade")
plt.title("Confusion Matrix")

# Show the confusion matrix
plt.show()

End the timer

In [None]:
# End timing
end_time = time.time()

In [None]:
# Calculate and print total testing runtime
testing_time = end_time - start_time
print(f"\nTotal testing runtime: {format_runtime(testing_time)}")

#### What are the predicted clades for each subtype? 

In [None]:
# Create a new column 'Subtype' by concatenating 'H_Subtype' and 'N_Subtype'
# NOTE(review): rows where either part is missing presumably end up with a
# missing Subtype and are then left out of the groupby below - confirm.
test_df['Subtype'] = test_df['H_Subtype'] + test_df['N_Subtype']

# Group by 'Subtype' and collect unique 'Predicted_Clade' values into a list
grouped_data = test_df.groupby('Subtype')['Predicted_Clade'].unique().reset_index()

# Rename the 'Predicted_Clade' column to 'Predicted Clades'
grouped_data = grouped_data.rename(columns={'Predicted_Clade': 'Predicted Clades'})

# Convert NumPy arrays in 'Predicted Clades' to strings
grouped_data['Predicted Clades'] = grouped_data['Predicted Clades'].astype(str)
# Bare expression: shown in the cell output because ast_node_interactivity is "all"
grouped_data

# Display the grouped_data directly in markdown format
print(grouped_data.to_markdown(index=False, numalign="left", stralign="left"))


In [None]:
# Uses the 'Subtype' column of test_df; this cell does not create it itself.

# Count the rows for every (Subtype, Predicted_Clade) combination
grouped_data = test_df.groupby(['Subtype', 'Predicted_Clade']).size().reset_index(name='Count')

# Calculate total count for the entire test dataset
total_count = len(test_df) 

# Calculate percentage using the total count of the test dataset
# (share of the whole test set, not of the rows within one subtype)
grouped_data['Percentage'] = ((grouped_data['Count'] / total_count) * 100).round(1)

# Rename the 'Predicted_Clade' column to 'Predicted Clades'
grouped_data = grouped_data.rename(columns={'Predicted_Clade': 'Predicted Clades'})
# Bare expression: shown in the cell output because ast_node_interactivity is "all"
grouped_data

# Display the grouped_data directly in markdown format
print(grouped_data.to_markdown(index=False, numalign="left", stralign="left"))

In [None]:
# Uses the 'Subtype' column of test_df; this cell does not create it itself.

# Count the rows for every (Subtype, Predicted_Clade) combination
grouped_data = test_df.groupby(['Subtype', 'Predicted_Clade']).size().reset_index(name='Count')

# Calculate total count for the entire test dataset
total_count = len(test_df) 

# Calculate percentage using the total count of the test dataset
# (share of the whole test set, not of the rows within one subtype)
grouped_data['Percentage'] = ((grouped_data['Count'] / total_count) * 100).round(1)

# Rename the 'Predicted_Clade' column to 'Predicted Clades'
grouped_data = grouped_data.rename(columns={'Predicted_Clade': 'Predicted Clades'})

# Aggregate Predicted Clades, Count and Percentage for each Subtype:
# one row per subtype, each column holding a comma-separated list whose
# entries line up across the three columns
grouped_data = grouped_data.groupby('Subtype').agg({
    'Predicted Clades': lambda x: ', '.join(x), 
    'Count': lambda x: ', '.join(x.astype(str)), 
    'Percentage': lambda x: ', '.join(x.astype(str))
}).reset_index()

# Display the grouped_data directly in markdown format
print(grouped_data.to_markdown(index=False, numalign="left", stralign="left"))