# Preliminaries: Inspect and Set up environment

In [None]:
# Import all libraries required

# Data Processing and EDA
import datetime
import pandas as pd
import numpy as np

# For bioinformatics tasks
from Bio import SeqIO
# older alignment method
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
# newer alignment method - not using this, 
# but just don't want to forget this option
from Bio import Align
from Bio.Align import PairwiseAligner
import multiprocessing

# For Machine Learning
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

# For Evaluation
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from seaborn import heatmap
from sklearn.inspection import partial_dependence
import matplotlib.pyplot as plt
import shap

In [None]:
# Environment
# Show all the output for every print not just the last
from IPython.core.interactiveshell import InteractiveShell
# Configuration and settings
InteractiveShell.ast_node_interactivity = "all"
# To check if in Google Colab
from IPython.core.getipython import get_ipython
# To display all the output in a nicer table
from IPython.display import display
# To time the execution of the code
import time
import os.path

In [None]:
print(datetime.datetime.now())

In [None]:
!which python

In [None]:
!python --version

In [None]:
!echo $PYTHONPATH

In [None]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

In [None]:
#if 'google.colab' in str(get_ipython()):
    # TODO: if running on Google Colab, install any packages you need to here. For example:
    #!pip install unidecode
    #!pip install category_encoders
    #!pip install scikeras

In [None]:
# Let's minimize randomness
# numpy
np.random.seed(42)

Start the timer

In [None]:
# Start timer
start_time = time.time()

In [None]:
import joblib  # For saving and loading models

# 1.0 Data Understanding

## 1.1 Load data

In [None]:
# Parse the FASTA file
records = list(SeqIO.parse("/home/ajvilleg/Netdrive/AI/GISAID/EpiFlu_Training/28-Jul-2024/gisaid_epiflu_sequence_28-Jul-2024_USA.fasta", "fasta"))


In [None]:
def check_duplicate_headers(records):
    """Locate records whose FASTA header repeats an earlier one.

    The first record carrying a given header counts as the original; every
    later record with the same ``description`` is reported.

    Args:
        records: An iterable of objects exposing a ``description`` attribute
            (e.g. Biopython ``SeqRecord`` instances).

    Returns:
        tuple: ``(positions, headers)`` - two parallel lists holding, for each
        repeated record, its position in ``records`` and its header text.
    """
    first_seen = set()
    repeat_positions = []
    repeat_headers = []
    for position, entry in enumerate(records):
        key = entry.description  # full header line; ``entry.id`` is a possible alternative key
        if key not in first_seen:
            # First time this header shows up: remember it and move on.
            first_seen.add(key)
            continue
        repeat_positions.append(position)
        repeat_headers.append(key)
    return repeat_positions, repeat_headers

In [None]:
# Detect repeated FASTA headers before parsing so that only the first record
# of each header is kept.
duplicate_indices, duplicate_headers = check_duplicate_headers(records)
if duplicate_indices:
    print("Warning: Found the following duplicate FASTA headers:")
    print(*duplicate_headers, sep="\n")
    print("One copy of each duplicate record will be kept.")

    # Positions that survive: every position not flagged as a repeat
    dropped_positions = set(duplicate_indices)
    indices_to_keep = {pos for pos in range(len(records)) if pos not in dropped_positions}

    # Rebuild the list with only the surviving positions, preserving order
    records = [rec for pos, rec in enumerate(records) if pos in indices_to_keep]
    print("Duplicate records have been removed, keeping one copy of each.")


In [None]:
# Extract the details from the description of each record
data = []

# Pipe-separated header fields that must be identical in the two segment
# records of one isolate: (field position, label used in the error message).
# Position 4 (the segment label) is expected to differ and is handled below.
matched_fields = [
    (0, "Isolate names"),
    (1, "Isolate IDs"),
    (2, "Flu types"),
    (3, "Lineages"),
    (5, "Collection dates"),
    (6, "Clades"),  # important: the clade is the label for classification
]

# Records are consumed as consecutive pairs; zip() would silently drop a
# trailing record that has no partner, so make that visible.
if len(records) % 2:
    print(f"Warning: odd number of records ({len(records)}); the last record has no partner and is ignored.")

# Iterate through records for every pair for NA and HA segments
for record1, record2 in zip(records[::2], records[1::2]):
    description1 = [field.strip() for field in record1.description.split('|')]
    description2 = [field.strip() for field in record2.description.split('|')]

    # Both records of a pair must describe the same isolate
    for position, label in matched_fields:
        if description1[position] != description2[position]:
            print(f"{label} do not match: {description1[position]} vs {description2[position]}")
            raise ValueError(f"{label} do not match")

    # The segment labels are different for NA and HA segments. Exactly one of
    # the two records must be the HA segment; otherwise the HA/NA assignment
    # below would silently put a sequence in the wrong column.
    segment1 = description1[4]
    segment2 = description2[4]
    if (segment1 == 'HA') == (segment2 == 'HA'):
        print(f"Expected exactly one HA segment per pair: {segment1} vs {segment2}")
        raise ValueError("Expected exactly one HA segment per pair")

    # The sequences will be different, corresponding to the NA and HA segments
    sequence1 = str(record1.seq)
    sequence2 = str(record2.seq)
    if segment1 == 'HA':
        sequence_ha, sequence_na = sequence1, sequence2
    else:  # segment2 == 'HA' (guaranteed by the check above)
        sequence_ha, sequence_na = sequence2, sequence1

    isolate_name, isolate_id, flu_type, lineage = description1[:4]
    collection_date = description1[5]
    clade = description1[6]
    data.append([isolate_name, isolate_id, flu_type, lineage, sequence_ha, sequence_na, collection_date, clade])

df = pd.DataFrame(data, columns=['Isolate_Name', 'Isolate_ID', 'Flu_Type', 'Lineage', 'HA', 'NA', 'Collection Date', 'Clade'])

In [None]:
# Take a look at the data  
display(df)

In [None]:
# Extract Type using regular expressions
df['Type'] = df['Flu_Type'].astype(str).str.extract(r'(A|B|C)').fillna('')

In [None]:
# Extract H_Subtype and N_Subtype with updated regex, allowing for one or more digits after H or N.
df['H_Subtype'] = df['Flu_Type'].astype(str).str.extract(r'(H\d+)').fillna('')
df['N_Subtype'] = df['Flu_Type'].astype(str).str.extract(r'(N\d+)').fillna('')
print(df['H_Subtype'].value_counts().to_markdown(numalign="left", stralign="left"))
print(df['N_Subtype'].value_counts().to_markdown(numalign="left", stralign="left"))

In [None]:
# Take a look at the data again
display(df)

## 1.2 EDA

### 1.2.1 Dataframe structure

In [None]:
df.info()

In [None]:
# Convert all columns to strings except Collection Date
df = df.astype(str)

# Convert "Collection Date" column to date
df["Collection Date"] = pd.to_datetime(df["Collection Date"])

df.info()

### 1.2.2 Describe

In [None]:
df.describe()

### 1.2.3 Duplicated rows

In [None]:
# Flag rows that repeat an earlier row, report how many there are, and show them
repeat_mask = df.duplicated()
print(f'df has {repeat_mask.sum()} duplicate rows')
display(df[repeat_mask])
# Remove the repeats in place, then confirm that none remain
df.drop_duplicates(inplace=True)
remaining_repeats = df.duplicated().sum()
print(f'df has {remaining_repeats} duplicate rows')

### 1.2.4 Missing values / NaN / Empty Strings

In [None]:
# Check for missing values and empty strings
print("NaN values in df:")
print(df.isnull().sum())  # Check for NaN values
print("\nEmpty string values in df:")
for col in df.select_dtypes(include=['object']):  # Iterate over columns with string datatype
    print(f"{col}: {(df[col] == '').sum()}")     # Count empty strings

In [None]:
# Drop rows with nulls or empty strings in Clade, ignore Lineage nulls/empty strings
df.replace('', pd.NA, inplace=True)  # Replace empty strings with NaN
df.dropna(subset=['Clade'], inplace=True)  # Drop rows where Clade is NaN

In [None]:
# Check for missing values and empty strings
# NOTE: the previous empty strings that were not removed will now show up as NaN as we replaced empty strings with NaN
print("NaN values in df:")
print(df.isnull().sum())  # Check for NaN values
print("\nEmpty string values in df:")
for col in df.select_dtypes(include=['object']):  # Iterate over columns with string datatype
    print(f"{col}: {(df[col] == '').sum()}")     # Count empty strings

### 1.2.5 Class Imbalance

#### Clade Imbalance

In [None]:
df['Clade'].value_counts()

In [None]:
# Drop rows with "unassigned" in 'Clade' from the training data
df = df[df['Clade'] != 'unassigned']  # Filter out rows with label "unassigned"

In [None]:
df['Clade'].value_counts()

#### H_Subtype Imbalance

In [None]:
df['H_Subtype'].value_counts()

#### N_Subtype Imbalance

In [None]:
df['N_Subtype'].value_counts()

### 1.2.6 Shape

In [None]:
df.shape

### 1.2.7 Look at sequence length stats

In [None]:
def get_sequence_length(row, column):
    """Return the number of characters stored under ``column`` in ``row``.

    Args:
        row: Any object indexable by ``column`` (for example a DataFrame row
            handed over by ``DataFrame.apply(..., axis=1)``).
        column: Key of the entry whose length is measured.

    Returns:
        int: ``len(row[column])``.
    """
    sequence = row[column]
    return len(sequence)

In [None]:
# Compute the HA and NA sequence length of every row: apply() with axis=1
# calls get_sequence_length(row, column=...) once per row.
ha_sequence_lengths = df.apply(get_sequence_length, axis=1, column="HA")
na_sequence_lengths = df.apply(get_sequence_length, axis=1, column="NA")
print("Sequence lengths in HA columns:")
# Bare expression: its result is only shown because the notebook is configured
# (InteractiveShell.ast_node_interactivity = "all") to display every
# top-level expression, not just the last one in the cell.
ha_sequence_lengths.describe()
print("Sequence lengths in NA columns:")
na_sequence_lengths.describe()


# 2.0 Data Preparation

### 2.1 K-mers and k-mer encoding

In [None]:
# Define k-mer length
# kmer_length = 12
kmer_length = 6

In [None]:
# Function to extract kmers (can be reused)
def get_kmers(sequence, k):
    """Return every length-``k`` window of ``sequence``, from left to right.

    Neighbouring windows overlap by ``k - 1`` characters.

    Args:
        sequence: The string (or other sliceable sequence) to scan.
        k: Window length.

    Returns:
        list: The windows ordered by start position; empty when ``sequence``
        is shorter than ``k``.
    """
    window_count = len(sequence) - k + 1
    return [sequence[start:start + k] for start in range(window_count)]

In [None]:
# Create an empty dictionary to store kmers for each sequence (identified by row index)
kmer_dict = {}

In [None]:
# Extract k-mers with length kmer_length from each HA sequence and store them
# in the dictionary, keyed by the DataFrame row label.
# The "HA" column is iterated directly: df is always built with that column,
# and this avoids constructing a full row Series per record as iterrows() does.
for row_label, ha_sequence in df["HA"].items():
    kmer_dict[row_label] = {
        "HA": get_kmers(str(ha_sequence), kmer_length)
    }

In [None]:
df.loc[22, "HA"]


In [None]:
row_index = 22  # Replace with the desired row index
kmer_breakdown = kmer_dict[row_index]
print(kmer_breakdown)

#### 2.1.1 One-Hot Encoding using chunking to optimize memory usage

In [None]:
# Define chunk size (adjust as needed)
chunk_size = 100

In [None]:
def process_chunk(chunk_dict):
    """One-hot encode the HA k-mers of a chunk into per-sequence count vectors.

    The encoder lives on the function itself as ``process_chunk.ohe``. If that
    attribute does not exist yet, an encoder is created and fitted on the
    k-mers of the chunk passed to this call; on every later call the existing
    encoder is reused unchanged, so all data (including test data processed
    later) is encoded with the same columns. K-mers the fitted encoder does
    not know are ignored (``handle_unknown='ignore'``).

    Args:
        chunk_dict: Mapping of row identifier -> ``{"HA": [kmer, ...]}``.

    Returns:
        list: One 1-D array per entry of ``chunk_dict`` (in iteration order).
        Element ``j`` is the number of times the encoder's ``j``-th k-mer
        occurs in that sequence.
    """
    # Fit only when no encoder exists yet, so the encoder is fitted exactly once.
    if not hasattr(process_chunk, 'ohe'):
        vocabulary = set()
        for entry in chunk_dict.values():
            vocabulary.update(entry["HA"])

        process_chunk.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        process_chunk.ohe.fit([[kmer] for kmer in vocabulary])  # Fit on unique kmers

    encoder = process_chunk.ohe
    n_features = len(encoder.get_feature_names_out())

    # Transform each sequence into a vector of k-mer counts
    chunk_ha_features = []
    for entry in chunk_dict.values():
        ha_kmers = entry["HA"]
        if not ha_kmers:
            # transform() rejects an empty sample list; a sequence without any
            # k-mer (shorter than k) contributes an all-zero count vector.
            chunk_ha_features.append(np.zeros(n_features))
            continue
        # One one-hot row per k-mer; summing the rows yields the counts.
        one_hot_rows = encoder.transform([[kmer] for kmer in ha_kmers])
        chunk_ha_features.append(one_hot_rows.sum(axis=0))

    return chunk_ha_features

In [None]:
# Iterate through kmer_dict in chunks

# Fit the encoder on the k-mers of ALL sequences before chunking.
# process_chunk() fits only when no encoder exists yet, i.e. on the first
# chunk it sees; k-mers that first appear in a later chunk would then be
# unknown to the encoder and silently dropped (handle_unknown='ignore').
# An already existing encoder is left untouched.
if not hasattr(process_chunk, 'ohe'):
    full_vocabulary = set()
    for kmer_entry in kmer_dict.values():
        full_vocabulary.update(kmer_entry["HA"])
    if full_vocabulary:
        process_chunk.ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        process_chunk.ohe.fit([[kmer] for kmer in full_vocabulary])

# Materialise the items once instead of rebuilding the list for every chunk
kmer_items = list(kmer_dict.items())

ha_features = []
for i in range(0, len(kmer_items), chunk_size):
    # Get a chunk of data
    chunk_dict = dict(kmer_items[i:i + chunk_size])

    # Process features for the chunk
    chunk_ha_features = process_chunk(chunk_dict)

    # Append features from the chunk
    ha_features.extend(chunk_ha_features)

### 2.3 Define X and y and Train Test Split

In [None]:
# OHE
# Convert list of lists to numpy array for X (training data)
X = np.array(ha_features)

In [None]:
le = LabelEncoder()

# Fit the LabelEncoder to all unique classes (call only once)
le.fit(df['Clade'])

# Encode the whole column with a single transform() call instead of one call
# per row. y remains a plain list of encoded labels in df row order.
y = list(le.transform(df['Clade']))

In [None]:
# Train-test split (80% training, 20% testing)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # Set random_state for reproducibility

# 3.0 Modelling

In [None]:
# # Train a Random Forest model
# model = RandomForestClassifier(n_estimators=100, random_state=42)  # You can adjust hyperparameters
# model.fit(X_train, y_train) 

In [None]:
# Train a Random Forest model with balanced class weights to handle class imbalance
model = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced') 
model.fit(X_train, y_train) 

# 4.0 Evaluation

In [None]:
# Make predictions on the testing set
y_pred = model.predict(X_test)

In [None]:
# Evaluate Predictions
print("\n### Model Evaluation on Train/Validation Dataset ###")

# Accuracy: Proportion of correctly predicted samples
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# Precision: Ratio of true positives to all predicted positives
precision = precision_score(y_test, y_pred, average='weighted') # Weighted average for multi-class
print("Precision:", precision)

# Recall: Ratio of true positives to all actual positives
recall = recall_score(y_test, y_pred, average='weighted')  # Weighted average for multi-class
print("Recall:", recall)

# F1-score: Harmonic mean of precision and recall
f1 = f1_score(y_test, y_pred, average='weighted')  # Weighted average for multi-class
print("F1-score:", f1)

### Classification Report

In [None]:
# Classification Report
print("\nClassification Report:")
print(classification_report(y_test, y_pred))


### Confusion Matrix

In [None]:
# Confusion Matrix with Seaborn
print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)

# Create a new figure for the confusion matrix
plt.figure(figsize=(8, 6))

# Create heatmap using seaborn
heatmap(cm, annot=True, fmt="d", cmap="Blues") # Customize heatmap with annotations, format, and colormap

# Add labels and title
plt.xlabel("Predicted Clade")
plt.ylabel("True Clade")
plt.title("Confusion Matrix")

# Show the confusion matrix
plt.show()

End the timer

In [None]:
# End timing
end_time = time.time()

In [None]:
def format_runtime(total_time):
    """
    Formats a given runtime (in seconds) into a human-readable string,
    omitting zero-value components (days, hours, minutes).

    Seconds are always included and shown with two decimals, e.g.
    ``"1 hour, 2 minutes and 5.68 seconds"`` or ``"0.25 seconds"``.

    Args:
        total_time (float): The total runtime in seconds.

    Returns:
        str: A formatted string representing the runtime.
    """
    # Round to the displayed precision first, so that e.g. 59.999 s rolls over
    # into "1 minute and 0.00 seconds" rather than printing "60.00 seconds".
    td = datetime.timedelta(seconds=round(total_time, 2))

    # Extract days, hours, minutes, and seconds from the timedelta
    days = td.days
    hours, remainder = divmod(td.seconds, 3600)
    minutes, whole_seconds = divmod(remainder, 60)
    # td.seconds only holds whole seconds; add the fractional part back so the
    # two decimals are meaningful.
    seconds = whole_seconds + td.microseconds / 1e6

    # Keep only the non-zero larger units, using the singular for exactly 1
    time_components = [
        f"{value} {unit if value == 1 else unit + 's'}"
        for value, unit in ((days, "day"), (hours, "hour"), (minutes, "minute"))
        if value > 0
    ]
    time_components.append(f"{seconds:.2f} seconds")  # Always include seconds

    # Join the time components with commas and a final "and"
    if len(time_components) == 1:
        return time_components[0]
    return ", ".join(time_components[:-1]) + " and " + time_components[-1]

In [None]:
# Calculate and print total training and validation runtime
training_and_validation_time = end_time - start_time
print(f"\nTotal training and validation runtime: {format_runtime(training_and_validation_time)}")

## K-mer importance: how many clades contain each k-mer?

In [None]:
# Feature Importance Analysis

importances = model.feature_importances_

# Map feature indices back to k-mers
kmer_names = process_chunk.ohe.get_feature_names_out()  # Get k-mer names from the encoder

feature_importances = pd.DataFrame({'kmer': kmer_names, 'importance': importances})
feature_importances.sort_values(by='importance', ascending=False, inplace=True)

# y_train as an array so it can be indexed with a boolean mask
y_train_labels = np.asarray(y_train)

# Collect, for each k-mer, the clades of the training sequences containing it.
# After sorting, each row's index label is still the k-mer's original position
# in kmer_names, i.e. its column in X_train, so no per-row name search is needed.
clades_with_kmer = []
for kmer_index in feature_importances.index:
    # Training rows where this k-mer is present (non-zero count)
    present = X_train[:, kmer_index] > 0

    # Distinct encoded labels of those rows, decoded back to clade names.
    # np.unique sorts the labels, so the resulting list order is deterministic.
    unique_labels = np.unique(y_train_labels[present])
    clades_with_kmer.append(list(le.inverse_transform(unique_labels)))

# Add a new column 'Clades' to feature_importances
feature_importances['Clades'] = clades_with_kmer

# Get the total number of unique clades
total_clades = len(le.classes_)

# Add a new column 'Clade Proportion' to calculate the proportion
feature_importances['Clade Proportion'] = feature_importances['Clades'].apply(lambda x: len(x) / total_clades)

# # Increase the maximum number of rows and columns to display
# pd.set_option('display.max_rows', None)  # Display all rows
# pd.set_option('display.max_columns', None)  # Display all columns

# # You can also specifically increase the width of the 'Clades' column
# pd.set_option('display.max_colwidth', None)  # No limit on column width

total_kmers = len(kmer_names)  # kmer_names holds one entry per encoder feature

# Display the top N most important k-mers
N = 20  # You can adjust this to show more or fewer k-mers
print(f"\nTop {N} (of {total_kmers}) most important k-mers:")
display(feature_importances.head(N))

# Print feature_importances to markdown
print(feature_importances.head(N).to_markdown(numalign="left", stralign="left"))

# # Output everything to file
# feature_importances.to_csv('feature_importances.csv', index=False)

### SHAP on k-mer importance

In [None]:
# Initialize the SHAP explainer. For tree-based models like Random Forest, use TreeExplainer
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for the test set, disabling the additivity check
shap_values = explainer.shap_values(X_test, check_additivity=False) # You might want to use a smaller subset of X_test for faster computation

# Print shapes for debugging
print(shap_values.shape)
print(X_test.shape)
print(len(process_chunk.ohe.get_feature_names_out()))

# Visualize the first prediction's explanation for the first class (index 0)
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values[0,:,0], feature_names=process_chunk.ohe.get_feature_names_out())

# Summarize the effects of all the features for all classes
shap.summary_plot(shap_values, X_test, feature_names=process_chunk.ohe.get_feature_names_out())

# Export the trained model

In [None]:
# Define a function to generate a timestamp
def get_timestamp():
    """Return the current local date and time as ``YYYY-MM-DD_HH:MM:SS``.

    NOTE(review): the result contains ``:`` characters and is embedded in the
    saved model's file name - confirm the target filesystem accepts them.
    """
    current_moment = datetime.datetime.now()
    return current_moment.strftime("%Y-%m-%d_%H:%M:%S")

In [None]:
# Save the trained model
timestamp = get_timestamp()
model_data = {'model': model, 'label_encoder': le, 'one_hot_encoder': process_chunk.ohe} # Save the model, labelencoder, and OHE together
model_filename = f"model_learn-flu-train_{timestamp}.joblib"
joblib.dump(model_data, model_filename)
print(f"Trained Model, LabelEncoder, and OneHotEncoder and saved to: {model_filename}")