In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
from torchmetrics import F1Score
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import pandas as pd
import os
import numpy as np

# Trade a little float32 matmul precision for speed ('high' is one step below 'highest');
# see the torch.set_float32_matmul_precision documentation for what each level permits.
torch.set_float32_matmul_precision(precision='high')

# Preprocessing

In [None]:
# --- Load data ---
# Label tables for each task and split. load_kmer_arrays reads the sample id from the
# first column and the class label from the second.
ASSETS_DIR = '../assets'

mlst_train_df = pd.read_csv(f'{ASSETS_DIR}/training_set_mlst.csv')
mlst_test_df = pd.read_csv(f'{ASSETS_DIR}/test_set_mlst.csv')
mlst_val_df = pd.read_csv(f'{ASSETS_DIR}/validation_set_mlst.csv')

serotype_train_df = pd.read_csv(f'{ASSETS_DIR}/training_set_serotype.csv')
serotype_test_df = pd.read_csv(f'{ASSETS_DIR}/test_set_serotype.csv')
serotype_val_df = pd.read_csv(f'{ASSETS_DIR}/validation_set_serotype.csv')

subspecies_train_df = pd.read_csv(f'{ASSETS_DIR}/training_set_subspecies.csv')
subspecies_test_df = pd.read_csv(f'{ASSETS_DIR}/test_set_subspecies.csv')
subspecies_val_df = pd.read_csv(f'{ASSETS_DIR}/validation_set_subspecies.csv')

# Root directory holding the per-sample k-mer array (.npy) folders. Set the FCGR_ROOT
# environment variable to relocate it instead of editing a hard-coded home-directory path;
# the default reproduces the original location.
FCGR_ROOT = os.path.expanduser(os.environ.get('FCGR_ROOT', '~/PROJECTS/GaTech/FCGR_classifier'))
# The trailing '' keeps the trailing path separator the original strings had.
kmc5_arrays = os.path.join(FCGR_ROOT, 'salmonella_kmc5_arrays', '')
kmc7_arrays = os.path.join(FCGR_ROOT, 'salmonella_kmc7_arrays', '')

def load_kmer_arrays(df, array_dir, suffix):
    """Load one flattened k-mer array per sample listed in ``df``.

    Args:
        df (pandas.DataFrame): Table whose first column is the sample id and whose second
            column is the class label (columns are addressed by position, not by name).
        array_dir (str): Directory containing ``f"{sample_id}{suffix}.npy"`` files.
        suffix (str): Filename suffix placed between the sample id and ``.npy``
            (e.g. ``'_k5_k5'``).

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: ``(X, y)``. ``X`` has one flattened array per
        sample whose file was found, and ``y`` holds the matching labels in the same order.
        Samples without a file are skipped and a warning is printed. NOTE(review): all arrays
        are assumed to have the same size; otherwise ``np.array`` cannot stack them.
    """
    arrays = []
    labels = []
    # Iterate the two columns positionally. ``row[0]`` on an iterrows() Series is deprecated
    # positional access, and iterrows() upcasts mixed int/float rows to float, turning
    # integer sample ids into e.g. "1.0" and breaking the file lookup.
    for sample_id, label in zip(df.iloc[:, 0], df.iloc[:, 1]):
        array_path = os.path.join(array_dir, f"{sample_id}{suffix}.npy")
        if os.path.exists(array_path):
            arrays.append(np.load(array_path).flatten())
            labels.append(label)
        else:
            print(f"Warning: Array file {array_path} not found.")
    return np.array(arrays), np.array(labels)

# Naming: X_/y_<split>_<task>_<k>, e.g. X_train_mlst_7 = k=7 arrays for the MLST training split.
# MLST
X_train_mlst_5, y_train_mlst_5 = load_kmer_arrays(mlst_train_df, kmc5_arrays, '_k5_k5')
X_val_mlst_5, y_val_mlst_5 = load_kmer_arrays(mlst_val_df, kmc5_arrays, '_k5_k5')
X_test_mlst_5, y_test_mlst_5 = load_kmer_arrays(mlst_test_df, kmc5_arrays, '_k5_k5')

X_train_mlst_7, y_train_mlst_7 = load_kmer_arrays(mlst_train_df, kmc7_arrays, '_k7_k7')
X_val_mlst_7, y_val_mlst_7 = load_kmer_arrays(mlst_val_df, kmc7_arrays, '_k7_k7')
X_test_mlst_7, y_test_mlst_7 = load_kmer_arrays(mlst_test_df, kmc7_arrays, '_k7_k7')

# Serotype
X_train_sero_5, y_train_sero_5 = load_kmer_arrays(serotype_train_df, kmc5_arrays, '_k5_k5')
X_val_sero_5, y_val_sero_5 = load_kmer_arrays(serotype_val_df, kmc5_arrays, '_k5_k5')
X_test_sero_5, y_test_sero_5 = load_kmer_arrays(serotype_test_df, kmc5_arrays, '_k5_k5')

X_train_sero_7, y_train_sero_7 = load_kmer_arrays(serotype_train_df, kmc7_arrays, '_k7_k7')
X_val_sero_7, y_val_sero_7 = load_kmer_arrays(serotype_val_df, kmc7_arrays, '_k7_k7')
X_test_sero_7, y_test_sero_7 = load_kmer_arrays(serotype_test_df, kmc7_arrays, '_k7_k7')

# Subspecies
X_train_sub_5, y_train_sub_5 = load_kmer_arrays(subspecies_train_df, kmc5_arrays, '_k5_k5')
X_val_sub_5, y_val_sub_5 = load_kmer_arrays(subspecies_val_df, kmc5_arrays, '_k5_k5')
X_test_sub_5, y_test_sub_5 = load_kmer_arrays(subspecies_test_df, kmc5_arrays, '_k5_k5')

X_train_sub_7, y_train_sub_7 = load_kmer_arrays(subspecies_train_df, kmc7_arrays, '_k7_k7')
X_val_sub_7, y_val_sub_7 = load_kmer_arrays(subspecies_val_df, kmc7_arrays, '_k7_k7')
X_test_sub_7, y_test_sub_7 = load_kmer_arrays(subspecies_test_df, kmc7_arrays, '_k7_k7')


# Performance metric functions

In [None]:
def accuracy(preds, y):
    """Return the fraction of positions where ``preds`` equals ``y``.

    Both arguments are tensors of class indices with matching shapes. The result is a 0-dim
    float32 tensor.
    """
    hits = preds.eq(y)
    return hits.to(torch.float32).mean()

def F1(out_dim,preds,y):
    """Macro-averaged multiclass F1 score for a single batch.

    Args:
        out_dim (int): Number of classes.
        preds (torch.Tensor): Predicted class indices.
        y (torch.Tensor): Target class indices, on the same device as ``preds``.

    Returns:
        torch.Tensor: 0-dim tensor with the macro F1 of this batch only. A fresh metric is
        built on every call, so nothing is accumulated across calls.
    """
    # Put the metric on the inputs' device. Choosing CUDA whenever it merely *exists* breaks
    # with a device mismatch when the tensors are on the CPU.
    f1 = F1Score(task="multiclass", num_classes=out_dim, average='macro').to(preds.device)
    return f1(preds, y)

# Panspace CNN model (the encoder portion of the original Panspace model)

In [None]:
class PanspaceCNN(pl.LightningModule):
    """CNN classifier over single-channel square CGR images.

    The network has three conv -> batch-norm -> ReLU -> 2x2 max-pool stages with 32, 64 and
    128 channels. These are followed by a 512-unit hidden layer and a linear layer that
    produces ``out_dim`` logits.

    Training and evaluation batches are ``(x, y)`` pairs:
    - ``x`` has shape ``(batch, 1, input_size, input_size)``.
    - ``y`` holds float one-hot (or class-probability) targets of shape ``(batch, out_dim)``,
      as built by ``prepare_data``.
    """

    def __init__(self, input_size, out_dim, learning_rate):
        """
        Initializes the PyTorch model for Panspace.

        Args:
            input_size (int): Side length of the square CGR image (e.g., 64 for k=6). Must be
                at least 8 so that the three 2x2 poolings leave a non-empty feature map.
            out_dim (int): Number of output classes (size of the logit vector).
            learning_rate (float): Adam learning rate used by ``configure_optimizers``.

        Raises:
            ValueError: If ``input_size`` is smaller than 8.
        """
        super().__init__()
        if input_size < 8:
            raise ValueError(f"input_size must be >= 8 for three 2x2 poolings, got {input_size}")

        # Stores input_size / out_dim / learning_rate in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        # Convolutional layers (filter counts and kernel sizes are tunable example values).
        # Layer creation order is kept the same so parameter initialisation draws the same
        # random numbers as before.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Max pooling halves the spatial size (rounding down) after every conv stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Spatial size after three floor-halvings is input_size // 8.
        final_conv_size = input_size // (2**3)
        self.fc1 = nn.Linear(128 * final_conv_size * final_conv_size, 512)
        self.fc2 = nn.Linear(512, out_dim)

        # Batch normalization and activation
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()

        # Stateful epoch-level macro F1, one per stage. As submodules they follow the model's
        # device. Logging the metric object lets Lightning compute and reset it per epoch,
        # instead of averaging per-batch macro-F1 values, which is not the dataset macro-F1.
        self.val_f1 = F1Score(task="multiclass", num_classes=out_dim, average='macro')
        self.test_f1 = F1Score(task="multiclass", num_classes=out_dim, average='macro')

    def forward(self, x):
        """
        Defines the forward pass of the model.

        Args:
            x (torch.Tensor): CGR images of shape (batch_size, 1, H, W) with H = W = input_size.

        Returns:
            torch.Tensor: Unnormalised class logits of shape (batch_size, out_dim). Softmax is
            not applied here; ``F.cross_entropy`` expects raw logits.
        """
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.bn3(self.conv3(x))))

        x = torch.flatten(x, 1)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # y holds float one-hot targets; cross_entropy accepts class probabilities.
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def _shared_eval_step(self, batch, stage, f1_metric):
        """Compute and log ``<stage>_loss``, ``<stage>_acc`` and ``<stage>_f1`` for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        targets = torch.argmax(y, dim=1)  # one-hot -> class index
        acc = accuracy(preds, targets)
        # Calling the metric returns this batch's F1 and also accumulates epoch state.
        batch_f1 = f1_metric(preds, targets)
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        self.log(f'{stage}_f1', f1_metric, prog_bar=True)
        return {f'{stage}_loss': loss, f'{stage}_acc': acc, f'{stage}_f1': batch_f1}

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, 'val', self.val_f1)

    def test_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, 'test', self.test_f1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

In [None]:
# ------------------------------------------------------
# Map raw labels to class indices (0..k-1) for one-hot encoding
# ------------------------------------------------------

def scale(y,classes):
    """Replace every label in ``y`` by its index in ``classes``.

    Args:
        y (iterable): Labels to convert.
        classes (sequence): Known classes. If a class appears more than once, its first
            position is used.

    Returns:
        numpy.ndarray: Integer (``np.intp``) indices, one per label. An empty ``y`` gives an
        empty integer array, so it can still be fed to ``torch.nn.functional.one_hot``.

    Raises:
        ValueError: If a label does not occur in ``classes`` (e.g. a validation/test label
            never seen in training).
    """
    # One dict lookup per label instead of an O(len(classes)) scan; setdefault keeps the
    # first occurrence.
    class_to_index = {}
    for position, cls in enumerate(classes):
        class_to_index.setdefault(cls, position)

    indices = []
    for label in y:
        if label not in class_to_index:
            raise ValueError(f"Label {label!r} is not one of the known classes")
        indices.append(class_to_index[label])
    return np.array(indices, dtype=np.intp)

def _make_loader(X, y, classes, size, out_dim, batch_size, shuffle):
    """Wrap one split as a DataLoader of (image, one-hot label) pairs."""
    # Assumes each row of X is one flattened size x size image. Reshape to (N, 1, size, size).
    # torch.tensor copies, so later edits to the numpy array cannot leak into the dataset.
    X_tensor = torch.tensor(X.reshape(-1, size, size)).unsqueeze(1).float()
    # Class indices must be int64 for one_hot.
    y_index = torch.as_tensor(scale(y, classes), dtype=torch.long)
    y_onehot = F.one_hot(y_index, num_classes=out_dim).float()
    return DataLoader(TensorDataset(X_tensor, y_onehot), batch_size=batch_size, shuffle=shuffle)


def prepare_data(X_train,X_val,X_test,y_train,y_val,y_test,size=32,out_dim=254,batch_size=64):
    """Build train/val/test DataLoaders of CGR images and one-hot labels.

    The class order is the sorted set of unique *training* labels, cast to int. Validation
    and test labels are mapped onto the same indices, so every one of them must also occur
    in ``y_train``; otherwise ``scale`` raises.

    Args:
        X_train, X_val, X_test (numpy.ndarray): Flattened images, reshaped to
            ``(N, 1, size, size)``.
        y_train, y_val, y_test (numpy.ndarray): Labels convertible with ``.astype(int)``.
        size (int): Image side length (default 32).
        out_dim (int): One-hot width. Must be at least the number of training classes, or
            ``F.one_hot`` raises (default 254).
        batch_size (int): Batch size for all three loaders (default 64).

    Returns:
        tuple[DataLoader, DataLoader, DataLoader]: ``(train_loader, val_loader, test_loader)``.
        Only the training loader shuffles.
    """
    y_train = y_train.astype(int)
    y_val = y_val.astype(int)
    y_test = y_test.astype(int)
    # np.unique already returns its values sorted.
    classes = np.unique(y_train)

    train_loader = _make_loader(X_train, y_train, classes, size, out_dim, batch_size, shuffle=True)
    val_loader = _make_loader(X_val, y_val, classes, size, out_dim, batch_size, shuffle=False)
    test_loader = _make_loader(X_test, y_test, classes, size, out_dim, batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

# Define the objective function for Optuna
def objective(trial, loaders=None, input_size=128, out_dim=254, max_epochs=50):
    """Optuna objective: train a PanspaceCNN with a sampled learning rate and return val loss.

    Args:
        trial (optuna.trial.Trial): Used to sample ``learning_rate`` (log-uniform in
            [1e-7, 1e-2]) and, through the pruning callback, to report ``val_loss`` per epoch.
        loaders (tuple, optional): ``(train_loader, val_loader, ...)`` as returned by
            ``prepare_data``. Build them once and bind them, e.g.
            ``study.optimize(functools.partial(objective, loaders=loaders), ...)``, so trials
            stop rebuilding identical tensors. When ``None``, the loaders are built from the
            notebook globals ``X_train_7 ... y_test_7`` as before.
        input_size (int): Image side length (default 128).
        out_dim (int): Number of output classes (default 254).
        max_epochs (int): Maximum training epochs per trial (default 50).

    Returns:
        float: The ``val_loss`` from the last validation epoch that ran.
    """
    learning_rate = trial.suggest_float('learning_rate', 1e-7, 1e-2, log=True)

    model = PanspaceCNN(
        input_size=input_size,
        out_dim=out_dim,
        learning_rate=learning_rate
    )

    # Stop once val_loss has not improved for 5 consecutive epochs.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=False,
        mode='min'
    )

    # Reports val_loss to Optuna each epoch so the study's pruner can stop weak trials.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor='val_loss')

    logger = TensorBoardLogger(save_dir=os.getcwd(), name=f"optuna_logs/trial_{trial.number}")

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[early_stop_callback, pruning_callback],
        logger=logger,
        enable_progress_bar=False,
        enable_model_summary=False
    )

    if loaders is None:
        # NOTE(review): the data-loading cell creates X_train_mlst_7 / X_train_sero_7 /
        # X_train_sub_7 etc., not X_train_7. These globals must be assigned elsewhere
        # (or `loaders` passed) before this fallback runs.
        loaders = prepare_data(X_train_7, X_val_7, X_test_7, y_train_7, y_val_7, y_test_7,
                               input_size, out_dim)
    train_loader, val_loader = loaders[0], loaders[1]

    trainer.fit(model, train_loader, val_loader)

    # callback_metrics holds the most recently logged value, i.e. the last epoch's val_loss.
    return trainer.callback_metrics['val_loss'].item()

def run_optimization(n_trials=50):
    """Tune the learning rate with Optuna and print the best trial.

    Runs ``n_trials`` trials of ``objective`` in a minimisation study. The study uses a
    ``MedianPruner`` with ``n_startup_trials=5`` and ``n_warmup_steps=10``. The best value and
    parameters are printed and the study is returned.
    """
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
    study = optuna.create_study(direction='minimize', pruner=median_pruner)
    study.optimize(objective, n_trials=n_trials)

    print("Best trial:")
    best = study.best_trial
    print(f"  Value: {best.value}")
    print("  Params: ")
    for param_name, param_value in best.params.items():
        print(f"    {param_name}: {param_value}")

    return study

def test_best_model(study,X_train,X_val,X_test,y_train,y_val,y_test,size,num_classes,
                    max_epochs=150,patience=5):
    """Retrain PanspaceCNN with the study's best learning rate and evaluate it on the test split.

    Training stops early once ``val_loss`` has not improved for ``patience`` epochs, the same
    criterion ``objective`` uses while tuning. Test metrics come from the checkpoint with the
    lowest ``val_loss``, not from the last epoch's (possibly overfit) weights.

    Args:
        study (optuna.Study): Finished study; ``best_trial.params['learning_rate']`` is used.
        X_train, X_val, X_test, y_train, y_val, y_test: Passed to ``prepare_data``.
        size (int): Image side length.
        num_classes (int): Number of output classes.
        max_epochs (int): Upper bound on training epochs (default 150).
        patience (int | None): Early-stopping patience on ``val_loss``. ``None`` disables
            early stopping and trains for the full ``max_epochs``.

    Returns:
        list[dict]: Metrics returned by ``trainer.test`` (``test_loss``, ``test_acc``,
        ``test_f1``).
    """
    # Local import: ModelCheckpoint is not among the notebook's top-level imports.
    from lightning.pytorch.callbacks import ModelCheckpoint

    best_params = study.best_trial.params

    model = PanspaceCNN(
        input_size=size,
        out_dim=num_classes,
        learning_rate=best_params['learning_rate']
    )

    train_loader, val_loader, test_loader = prepare_data(X_train,X_val,X_test,y_train,y_val,y_test,size,num_classes)

    # Keep the weights with the lowest val_loss so they can be reloaded for testing.
    callbacks = [ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)]
    if patience is not None:
        callbacks.append(EarlyStopping(monitor='val_loss', patience=patience, mode='min'))

    trainer = pl.Trainer(max_epochs=max_epochs, callbacks=callbacks)
    trainer.fit(model, train_loader, val_loader)

    # 'best' loads the checkpoint selected by the val_loss ModelCheckpoint above.
    results = trainer.test(model, dataloaders=test_loader, ckpt_path='best')
    return results