In [None]:
# Train (k-fold cross-validation) and evaluate the CNN-LSTM-Attention-Residual hybrid model

# Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, precision_recall_curve, roc_auc_score, f1_score, auc, matthews_corrcoef
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Parameters
SEQ_LENGTH = 300  # Sequences are upper-cased, padded with 'N' and truncated to this length
BATCH_SIZE = 64
EPOCHS = 25  # Maximum number of training epochs per fold
PATIENCE = 7  # Early stopping patience (consecutive validation checks without improvement)
DATA_FOLDER = '/kaggle/input/algae-dataset'  # Input folder: one CSV file per class
CHECKPOINT_PATH_WORKING = '/kaggle/working/Algae_Final_checkpoint.pth'  # Checkpoint path in the writable working dir; NOTE(review): not referenced in this cell
CHECKPOINT_PATH_INPUT = '/kaggle/input/Algae_checkpoint_fold4_epoch24.pth'  # Checkpoint to resume from if it exists; the fold number is parsed from "fold<N>_"
K_FOLDS = 5  # Number of folds for cross-validation

# EarlyStopping class
class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    Call the instance once per validation pass. A loss counts as an improvement
    only when it is strictly lower than ``best_loss - delta``. After ``patience``
    consecutive non-improving calls, ``early_stop`` becomes True and stays True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        delta: minimum decrease required to count as an improvement.
        verbose: print a message on every non-improving call.
    """

    def __init__(self, patience=7, delta=0, verbose=False):
        self.patience, self.delta, self.verbose = patience, delta, verbose
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # Non-improvement: a best value exists and val_loss does not beat it by delta
        # (a NaN loss also lands here because every comparison with NaN is False).
        if self.best_loss is not None and not val_loss < self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping: No improvement in validation loss for {self.counter} epochs.")
            if self.counter >= self.patience:
                self.early_stop = True
            return
        # First call or a real improvement: remember it and reset the counter.
        self.best_loss = val_loss
        self.counter = 0

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """One-hot encode a nucleotide string as an integer array of shape (seq_length, 5).

    Columns are A, T, C, G, N. The input is upper-cased, right-padded with 'N' up
    to ``seq_length`` and then truncated to ``seq_length``. Any character outside
    ATCGN is encoded exactly like 'N'.
    """
    alphabet = 'ATCGN'
    rows = {base: [int(col == pos) for col in range(len(alphabet))]
            for pos, base in enumerate(alphabet)}
    fallback = rows['N']
    window = sequence.upper().ljust(seq_length, 'N')[:seq_length]
    return np.array([rows.get(base, fallback) for base in window])

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Each item is ``(float32 tensor, int64 scalar tensor)``. ``torch.tensor`` copies
    the indexed element, so the returned tensors do not share memory with the inputs.
    """

    def __init__(self, sequences, labels):
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder, sort_files=False):
    """Load every ``*.csv`` file in ``data_folder`` as one class and one-hot encode it.

    Each CSV is read without a header, and its first column holds one sequence per
    row. A file's class label is its position among the CSV files visited. Its class
    name is the file name without the ``.csv`` extension.

    Args:
        data_folder: directory containing one CSV file per class.
        sort_files: visit files in sorted name order instead of ``os.listdir`` order.
            ``os.listdir`` order depends on the filesystem, so class indices can
            differ between machines. Enable this for a reproducible label mapping.
            The default keeps the historical order so existing checkpoints keep the
            mapping they were trained with.

    Returns:
        ``(one_hot_sequences, labels, class_names)``, where:
        - ``one_hot_sequences`` stacks ``one_hot_encode`` outputs, shape
          ``(n, SEQ_LENGTH, 5)``, or shape ``(0,)`` when there are no rows;
        - ``labels`` is an integer array of length ``n``;
        - ``class_names`` lists the class names indexed by label.
    """
    file_names = os.listdir(data_folder)
    if sort_files:
        file_names = sorted(file_names)
    sequences = []
    labels = []
    class_names = []
    for file_name in file_names:
        if not file_name.endswith('.csv'):
            continue
        # Label = number of classes seen so far. The old listdir index let non-CSV
        # files leave gaps, producing labels >= len(class_names).
        label = len(class_names)
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        # Empty cells are read as NaN floats, which one_hot_encode cannot upper-case.
        column = data[0].dropna().astype(str)
        sequences.extend(column.tolist())
        labels.extend([label] * len(column))
        class_names.append(os.path.splitext(file_name)[0])  # Strip the '.csv' extension
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first inputs of shape (batch, seq_len, embed_dim).

    Wraps ``nn.MultiheadAttention``, which is built here without ``batch_first`` and
    therefore expects (seq_len, batch, embed_dim). The input is transposed in and out,
    and the attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)
        attended, _weights = self.attention(seq_first, seq_first, seq_first)
        return attended.transpose(0, 1)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Conv block: parallel 1-D convolutions, max-pool, dropout, plus a 1x1-conv shortcut.

    With ``use_multi_kernel=True`` the input passes through kernel sizes 3, 5 and 7.
    Each produces ``output_channels`` maps, and the maps are concatenated, so the block
    emits ``3 * output_channels`` channels. Otherwise a single size-3 convolution emits
    ``output_channels``. ``MaxPool1d(2, 2)`` halves the length (floor).

    NOTE(review): the shortcut is a 1x1 conv of the *unpooled* input, truncated to the
    pooled length. Pooled position ``i`` is therefore summed with input position ``i``
    rather than positions ``2i``/``2i+1``. Pooling the shortcut would align them, but
    it would change the outputs of already-trained checkpoints.
    """

    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super().__init__()
        self.use_multi_kernel = use_multi_kernel
        # Submodules are registered in the original order, so state_dict keys, parameter
        # order (which saved optimizer states rely on) and seeded initialisation match.
        self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        if use_multi_kernel:
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        branch_count = 3 if use_multi_kernel else 1
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.residual = nn.Conv1d(input_channels, output_channels * branch_count, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        shortcut = self.residual(x)
        convs = (self.conv3, self.conv5, self.conv7) if self.use_multi_kernel else (self.conv3,)
        features = [self.relu(conv(x)) for conv in convs]
        merged = torch.cat(features, dim=1) if len(features) > 1 else features[0]
        merged = self.dropout(self.pool(merged))
        # Trim the full-length shortcut to the pooled length (a no-op if already equal).
        return merged + shortcut[:, :, :merged.shape[2]]

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Attention across positions of a (batch, channels, length) feature map.

    1x1 convolutions project the input to queries and keys (``channel_dim // 8``
    channels) and to values (``channel_dim`` channels). Every position is updated with
    a softmax-weighted mix of the values at all positions. The result is added to the
    input, scaled by the learnable ``gamma``, which starts at 0, so the block is
    initially the identity.
    """

    def __init__(self, channel_dim):
        super().__init__()
        self.channel_dim = channel_dim
        reduced = channel_dim // 8
        self.query = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # Learnable scaling factor

    def forward(self, x):
        # Unpacking also enforces a 3-D (batch, channels, length) input.
        batch_size, _channels, seq_len = x.size()
        proj_q = self.query(x)  # (batch, channels // 8, seq_len)
        proj_k = self.key(x)    # (batch, channels // 8, seq_len)
        proj_v = self.value(x)  # (batch, channels, seq_len)
        # affinity[b, i, j] = weight given to position j when updating position i
        affinity = F.softmax(torch.bmm(proj_q.transpose(1, 2), proj_k), dim=-1)
        # mixed[b, c, i] = sum_j proj_v[b, c, j] * affinity[b, i, j]
        mixed = torch.bmm(proj_v, affinity.transpose(1, 2))
        return x + self.gamma * mixed

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> BiLSTM -> multi-head self-attention classifier for one-hot sequences.

    Data flow (L is the input length; each MaxPool1d(2, 2) floor-halves it):
      input (batch, L, 5) -> permute -> (batch, 5, L)
      multi_kernel_cnn1 (5 -> 128, k=3) + bn1          -> (batch, 128, L // 2)
      multi_kernel_cnn2 (128 -> 3*256, k=3/5/7)        -> (batch, 768, L // 4)
      cnn_attention2, bn2, pool, dropout               -> (batch, 768, L // 8)
      + 1x1-conv shortcut of the raw input (see NOTE)
      bidirectional LSTM (hidden 256) + bn_lstm        -> (batch, L // 8, 512)
      self-attention, mean over time, 3 FC layers      -> (batch, num_classes) logits

    NOTE(review): ``self.residual`` is computed on the full-length input and then
    sliced to its first ``L // 8`` positions, so it is not spatially aligned with the
    pooled features and only ever sees the beginning of the sequence. Changing this
    would alter the outputs of existing checkpoints.
    NOTE(review): BatchNorm1d layers (bn_fc1/bn_fc2 in particular) raise in training
    mode when a batch holds a single sample.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes):
        super(CNN_LSTM_Model, self).__init__()
        # Registration order below fixes state_dict keys and the parameter order that
        # saved optimizer states depend on; keep it when editing.
        # First CNN layer: single kernel size of 3 (5 -> 128 channels, length halved)
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)  # BatchNorm after first CNN layer
        # Second CNN layer: kernel sizes 3, 5, 7 concatenated (128 -> 3 * 256 channels, length halved)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)  # Attention across positions of the CNN feature map
        self.bn2 = nn.BatchNorm1d(256 * 3)  # BatchNorm after second CNN layer
        # 1x1 projection of the raw one-hot input, used as a model-level shortcut
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # Bidirectional LSTM over the pooled CNN features, followed by BatchNorm
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)
        self.bn_lstm = nn.BatchNorm1d(256 * 2)  # BatchNorm after LSTM (output is 256 * 2 due to bidirectional)
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)  # Dropout on the LSTM outputs
        # Fully connected head with BatchNorm: 512 -> 256 -> 512 -> num_classes
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)  # BatchNorm after first FC layer
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)  # BatchNorm after second FC layer
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout_fc = nn.Dropout(0.3)  # Shared by the CNN output and both hidden FC layers

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Args:
            x: float tensor of shape (batch, seq_len, 5), e.g. stacked one_hot_encode outputs.
        """
        x = x.permute(0, 2, 1)  # (batch, channels, seq_len)
        residual = self.residual(x)  # (batch, 768, seq_len): still full length
        # First CNN layer: single kernel size of 3
        x = self.multi_kernel_cnn1(x)
        x = self.bn1(x)  # BatchNorm after first CNN layer
        # Second CNN layer: multiple kernel sizes (3, 5, 7)
        x = self.multi_kernel_cnn2(x)
        x = self.cnn_attention2(x)  # Apply CNN Attention
        x = self.bn2(x)  # BatchNorm after second CNN layer
        x = self.pool(x)
        x = self.dropout_fc(x)
        if x.shape != residual.shape:
            # Keeps only the first seq_len // 8 positions of the shortcut (see class NOTE)
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        x = x.permute(0, 2, 1)  # (batch, seq_len, channels)
        # LSTM layer (the final hidden state hn is unused; the head pools over all steps)
        lstm_out, (hn, _) = self.lstm(x)
        lstm_out = lstm_out.permute(0, 2, 1)  # (batch, hidden_dim, seq_len) for BatchNorm
        lstm_out = self.bn_lstm(lstm_out)  # BatchNorm after LSTM
        lstm_out = lstm_out.permute(0, 2, 1)  # (batch, seq_len, hidden_dim)
        lstm_out = self.dropout_lstm(lstm_out)
        # Self-Attention
        attn_out = self.self_attention(lstm_out)
        context_vector = torch.mean(attn_out, dim=1)  # (batch, hidden_dim): mean over time steps
        # Fully connected layers
        x = self.fc1(context_vector)
        x = self.bn_fc1(x)  # BatchNorm after first FC layer
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)  # BatchNorm after second FC layer
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc3(x)  # Logits; softmax is applied by the caller when probabilities are needed
        return x


# ---------------------------------------------------------------------------
# Evaluation helpers (shared by the per-fold validation and the final test run)
# ---------------------------------------------------------------------------
def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode; return (labels, preds, probs) as NumPy arrays."""
    model.eval()
    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            outputs = model(batch_x.to(device))
            all_probs.append(F.softmax(outputs, dim=1).cpu().numpy())
            all_preds.append(outputs.argmax(dim=1).cpu().numpy())
            all_labels.append(batch_y.numpy())
    return np.concatenate(all_labels), np.concatenate(all_preds), np.concatenate(all_probs)


def _print_metrics(title, y_true, y_pred, y_prob, class_names):
    """Print the classification report, ROC-AUC, weighted F1, MCC and PR-AUC.

    Binary problems use the probability of class 1, as before. With more than two
    classes, ``y_prob[:, 1]`` alone is meaningless, so the function reports:
    - one-vs-rest ROC-AUC;
    - PR-AUC as the macro average of the per-class one-vs-rest curves.
    """
    class_ids = list(range(len(class_names)))
    print(title)
    # labels= keeps target_names aligned even when a class is absent from this split.
    print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names, digits=4))
    f1 = f1_score(y_true, y_pred, average='weighted')
    mcc = matthews_corrcoef(y_true, y_pred)  # Matthews Correlation Coefficient
    if len(class_ids) == 2:
        roc_auc = roc_auc_score(y_true, y_prob[:, 1])  # Probabilities of the positive class
        precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1], pos_label=1)
        pr_auc = auc(recall, precision)
    else:
        roc_auc = roc_auc_score(y_true, y_prob, multi_class='ovr', labels=class_ids)
        per_class_pr_auc = []
        for c in class_ids:
            precision, recall, _ = precision_recall_curve((y_true == c).astype(int), y_prob[:, c])
            per_class_pr_auc.append(auc(recall, precision))
        pr_auc = float(np.mean(per_class_pr_auc))
    print(f"ROC-AUC: {roc_auc:.4f}, F1-Score: {f1:.4f}, MCC: {mcc:.4f}")
    print(f"Precision-Recall AUC: {pr_auc:.4f}")


# Load data
sequences, labels, class_names = load_data(DATA_FOLDER)

# After loading data
print(f"Sequences shape: {sequences.shape}, Labels shape: {labels.shape}")
print(f"Class distribution: {np.bincount(labels)}")

# Split data into train and test sets (80% train, 20% test)
X_train_val, X_test, y_train_val, y_test = train_test_split(
    sequences, labels, test_size=0.2, random_state=42, stratify=labels
)

# Debug: Check shapes after train-test split
print(f"X_train_val shape: {X_train_val.shape}, y_train_val shape: {y_train_val.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")
print(f"Class distribution in training set: {np.bincount(y_train_val)}")
print(f"Class distribution in test set: {np.bincount(y_test)}")

# Initialize StratifiedKFold for cross-validation on the train set
skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

# Define device (CPU or GPU); mixed precision is only used on CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == 'cuda'

# Check for existing checkpoints
if os.path.exists(CHECKPOINT_PATH_INPUT):
    checkpoint = torch.load(CHECKPOINT_PATH_INPUT, map_location=device, weights_only=True)
    print(f"Resuming training from checkpoint: {CHECKPOINT_PATH_INPUT}")
    checkpoint_fold = int(CHECKPOINT_PATH_INPUT.split('fold')[-1].split('_')[0])  # Extract fold number from checkpoint path
    # 'epoch' is saved as epoch + 1, i.e. the number of epochs already completed in that fold.
    checkpoint_epoch = checkpoint['epoch']
    print(f"Checkpoint corresponds to Fold {checkpoint_fold}, Epoch {checkpoint_epoch}")
else:
    checkpoint = None
    checkpoint_fold = -1  # No checkpoint
    checkpoint_epoch = -1
    print("No checkpoint found. Starting training from scratch.")


# Cross-validation loop
for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_val, y_train_val)):
    fold_number = fold + 1  # Fold numbers start from 1

    # Skip folds that were already completed
    if checkpoint_fold != -1 and fold_number < checkpoint_fold:
        print(f"Skipping Fold {fold_number} (already completed).", flush=True)
        continue

    print(f"Fold {fold_number}/{K_FOLDS}", flush=True)
    X_train, X_val = X_train_val[train_idx], X_train_val[val_idx]
    y_train, y_val = y_train_val[train_idx], y_train_val[val_idx]

    # Debug: Check shapes for this fold
    print(f"X_train shape (fold {fold_number}): {X_train.shape}, y_train shape: {y_train.shape}", flush=True)
    print(f"X_val shape (fold {fold_number}): {X_val.shape}, y_val shape: {y_val.shape}", flush=True)

    # Create datasets and dataloaders
    train_dataset = SequenceDataset(X_train, y_train)
    val_dataset = SequenceDataset(X_val, y_val)

    # BatchNorm1d raises in training mode on a one-sample batch, so drop the trailing
    # batch only when it would contain exactly one sample.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                              drop_last=len(train_dataset) % BATCH_SIZE == 1)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Initialize model, loss, optimizer, scheduler and gradient scaler
    model = CNN_LSTM_Model(num_classes=len(class_names)).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    # Cosine annealing with warm restarts, advanced once per epoch
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
    scaler = torch.amp.GradScaler(enabled=use_amp)

    # Load checkpoint if this is the fold we need to resume
    if checkpoint_fold != -1 and fold_number == checkpoint_fold:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # checkpoint['epoch'] already counts completed epochs, so it is the 0-based index of
        # the next epoch to run (adding 1 here would silently skip an epoch).
        epoch_start = checkpoint_epoch
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints carry no scheduler state: fast-forward the cosine schedule.
            scheduler.step(epoch_start)
        if checkpoint.get('scaler_state_dict'):  # Empty when saved from a disabled scaler
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Resuming training for Fold {fold_number} from epoch {epoch_start + 1}", flush=True)
    else:
        epoch_start = 0  # Start from scratch for other folds
        print(f"Starting training for Fold {fold_number} from scratch.", flush=True)

    early_stopping = EarlyStopping(patience=PATIENCE, verbose=True)

    # Training loop
    for epoch in range(epoch_start, EPOCHS):
        model.train()
        total_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()

            # Mixed precision forward pass (no-op on CPU)
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

            scaler.scale(loss).backward()
            # Unscale before clipping; otherwise max_norm is applied to the loss-scaled
            # gradients and clipping is far too aggressive.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                val_loss += criterion(model(batch_x), batch_y).item()
        val_loss /= len(val_loader)  # Average validation loss

        # CosineAnnealingWarmRestarts.step(x) interprets x as an epoch counter, so passing
        # val_loss corrupted the schedule; advance it by exactly one epoch instead.
        scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}", flush=True)

        # Save a resumable checkpoint every 2 epochs
        if (epoch + 1) % 2 == 0:
            checkpoint_path = f"/kaggle/working/Algae_checkpoint_fold{fold_number}_epoch{epoch+1}.pth"
            torch.save({
                'epoch': epoch + 1,  # Number of completed epochs
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'val_loss': val_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}", flush=True)

        # Early stopping check
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered. Stopping training.", flush=True)
            break

    # Evaluate on the validation set for this fold
    y_true, y_pred, y_prob = _collect_predictions(model, val_loader, device)
    _print_metrics(f"Validation Set Results for Fold {fold_number}:", y_true, y_pred, y_prob, class_names)


# Create test dataset and dataloader
test_dataset = SequenceDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Evaluate on the test set.
# NOTE(review): this uses the model from the last trained fold, at its last epoch
# (not the best-validation epoch).
y_true, y_pred, y_prob = _collect_predictions(model, test_loader, device)
_print_metrics("Test Set Results:", y_true, y_pred, y_prob, class_names)

# Save the final model
final_model_path = "/kaggle/working/Final_Algae_model.pth"
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved at {final_model_path}")

In [None]:
# Evaluate the trained model on the held-out test set and plot its confusion matrix

# Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_curve, roc_auc_score, f1_score, auc, matthews_corrcoef
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

# Parameters
SEQ_LENGTH = 270  # Pad/truncate length; presumably should match the length the checkpoint was trained on
BATCH_SIZE = 64
DATA_FOLDER = '//kaggle/input/plants algae dataset (270nt)'  # One CSV per class. NOTE(review): the leading '//' resolves to '/' on Linux but is implementation-defined in POSIX
FINAL_MODEL_PATH = '/kaggle/input/Plant-Algae(270nt)_Final checkpoint.pth'  # Trained weights to evaluate

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """One-hot encode a nucleotide string as an integer array of shape (seq_length, 5).

    Columns are A, T, C, G, N. The input is upper-cased, right-padded with 'N' up
    to ``seq_length`` and then truncated to ``seq_length``. Any character outside
    ATCGN is encoded exactly like 'N'.
    """
    alphabet = 'ATCGN'
    rows = {base: [int(col == pos) for col in range(len(alphabet))]
            for pos, base in enumerate(alphabet)}
    fallback = rows['N']
    window = sequence.upper().ljust(seq_length, 'N')[:seq_length]
    return np.array([rows.get(base, fallback) for base in window])

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Each item is ``(float32 tensor, int64 scalar tensor)``. ``torch.tensor`` copies
    the indexed element, so the returned tensors do not share memory with the inputs.
    """

    def __init__(self, sequences, labels):
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder, sort_files=False):
    """Load every ``*.csv`` file in ``data_folder`` as one class and one-hot encode it.

    Each CSV is read without a header, and its first column holds one sequence per
    row. A file's class label is its position among the CSV files visited. Its class
    name is the file name without the ``.csv`` extension.

    Args:
        data_folder: directory containing one CSV file per class.
        sort_files: visit files in sorted name order instead of ``os.listdir`` order.
            ``os.listdir`` order depends on the filesystem, so class indices can
            differ between machines. Enable this for a reproducible label mapping.
            The default keeps the historical order so existing checkpoints keep the
            mapping they were trained with.

    Returns:
        ``(one_hot_sequences, labels, class_names)``, where:
        - ``one_hot_sequences`` stacks ``one_hot_encode`` outputs, shape
          ``(n, SEQ_LENGTH, 5)``, or shape ``(0,)`` when there are no rows;
        - ``labels`` is an integer array of length ``n``;
        - ``class_names`` lists the class names indexed by label.
    """
    file_names = os.listdir(data_folder)
    if sort_files:
        file_names = sorted(file_names)
    sequences = []
    labels = []
    class_names = []
    for file_name in file_names:
        if not file_name.endswith('.csv'):
            continue
        # Label = number of classes seen so far. The old listdir index let non-CSV
        # files leave gaps, producing labels >= len(class_names).
        label = len(class_names)
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        # Empty cells are read as NaN floats, which one_hot_encode cannot upper-case.
        column = data[0].dropna().astype(str)
        sequences.extend(column.tolist())
        labels.extend([label] * len(column))
        class_names.append(os.path.splitext(file_name)[0])
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first inputs of shape (batch, seq_len, embed_dim).

    Wraps ``nn.MultiheadAttention``, which is built here without ``batch_first`` and
    therefore expects (seq_len, batch, embed_dim). The input is transposed in and out,
    and the attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)
        attended, _weights = self.attention(seq_first, seq_first, seq_first)
        return attended.transpose(0, 1)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Conv block: parallel 1-D convolutions, max-pool, dropout, plus a 1x1-conv shortcut.

    With ``use_multi_kernel=True`` the input passes through kernel sizes 3, 5 and 7.
    Each produces ``output_channels`` maps, and the maps are concatenated, so the block
    emits ``3 * output_channels`` channels. Otherwise a single size-3 convolution emits
    ``output_channels``. ``MaxPool1d(2, 2)`` halves the length (floor).

    NOTE(review): the shortcut is a 1x1 conv of the *unpooled* input, truncated to the
    pooled length. Pooled position ``i`` is therefore summed with input position ``i``
    rather than positions ``2i``/``2i+1``. Pooling the shortcut would align them, but
    it would change the outputs of already-trained checkpoints.
    """

    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super().__init__()
        self.use_multi_kernel = use_multi_kernel
        # Submodules are registered in the original order, so state_dict keys, parameter
        # order (which saved optimizer states rely on) and seeded initialisation match.
        self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        if use_multi_kernel:
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        branch_count = 3 if use_multi_kernel else 1
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.residual = nn.Conv1d(input_channels, output_channels * branch_count, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        shortcut = self.residual(x)
        convs = (self.conv3, self.conv5, self.conv7) if self.use_multi_kernel else (self.conv3,)
        features = [self.relu(conv(x)) for conv in convs]
        merged = torch.cat(features, dim=1) if len(features) > 1 else features[0]
        merged = self.dropout(self.pool(merged))
        # Trim the full-length shortcut to the pooled length (a no-op if already equal).
        return merged + shortcut[:, :, :merged.shape[2]]

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Attention across positions of a (batch, channels, length) feature map.

    1x1 convolutions project the input to queries and keys (``channel_dim // 8``
    channels) and to values (``channel_dim`` channels). Every position is updated with
    a softmax-weighted mix of the values at all positions. The result is added to the
    input, scaled by the learnable ``gamma``, which starts at 0, so the block is
    initially the identity.
    """

    def __init__(self, channel_dim):
        super().__init__()
        self.channel_dim = channel_dim
        reduced = channel_dim // 8
        self.query = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Unpacking also enforces a 3-D (batch, channels, length) input.
        batch_size, _channels, seq_len = x.size()
        proj_q = self.query(x)
        proj_k = self.key(x)
        proj_v = self.value(x)
        # affinity[b, i, j] = weight given to position j when updating position i
        affinity = F.softmax(torch.bmm(proj_q.transpose(1, 2), proj_k), dim=-1)
        mixed = torch.bmm(proj_v, affinity.transpose(1, 2))
        return x + self.gamma * mixed

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> BiLSTM -> multi-head self-attention classifier for one-hot sequences.

    Data flow (L is the input length; each MaxPool1d(2, 2) floor-halves it):
      input (batch, L, 5) -> permute -> (batch, 5, L)
      multi_kernel_cnn1 (5 -> 128, k=3) + bn1          -> (batch, 128, L // 2)
      multi_kernel_cnn2 (128 -> 3*256, k=3/5/7)        -> (batch, 768, L // 4)
      cnn_attention2, bn2, pool, dropout               -> (batch, 768, L // 8)
      + 1x1-conv shortcut of the raw input (see NOTE)
      bidirectional LSTM (hidden 256) + bn_lstm        -> (batch, L // 8, 512)
      self-attention, mean over time, 3 FC layers      -> (batch, num_classes) logits

    This must stay layer-for-layer identical to the training definition, or the saved
    state_dict will not load.

    NOTE(review): ``self.residual`` is computed on the full-length input and then
    sliced to its first ``L // 8`` positions, so it is not spatially aligned with the
    pooled features.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes):
        super(CNN_LSTM_Model, self).__init__()
        # CNN stage: single-kernel block, BatchNorm, multi-kernel block, attention, BatchNorm
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)
        self.bn2 = nn.BatchNorm1d(256 * 3)
        # 1x1 projection of the raw one-hot input, used as a model-level shortcut
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # Sequence stage: BiLSTM (output 2 * 256 features), BatchNorm, self-attention
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)
        self.bn_lstm = nn.BatchNorm1d(256 * 2)
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)
        # Classification head: 512 -> 256 -> 512 -> num_classes
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout_fc = nn.Dropout(0.3)  # Shared by the CNN output and both hidden FC layers

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Args:
            x: float tensor of shape (batch, seq_len, 5), e.g. stacked one_hot_encode outputs.
        """
        x = x.permute(0, 2, 1)  # (batch, 5, seq_len)
        residual = self.residual(x)  # (batch, 768, seq_len): still full length
        x = self.multi_kernel_cnn1(x)
        x = self.bn1(x)
        x = self.multi_kernel_cnn2(x)
        x = self.cnn_attention2(x)
        x = self.bn2(x)
        x = self.pool(x)
        x = self.dropout_fc(x)
        if x.shape != residual.shape:
            # Keeps only the first seq_len // 8 positions of the shortcut (see class NOTE)
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        x = x.permute(0, 2, 1)  # (batch, seq_len // 8, 768) for the batch-first LSTM
        lstm_out, (hn, _) = self.lstm(x)  # hn is unused; the head pools over all steps
        lstm_out = lstm_out.permute(0, 2, 1)  # (batch, 512, steps) for BatchNorm1d
        lstm_out = self.bn_lstm(lstm_out)
        lstm_out = lstm_out.permute(0, 2, 1)  # back to (batch, steps, 512)
        lstm_out = self.dropout_lstm(lstm_out)
        attn_out = self.self_attention(lstm_out)
        context_vector = torch.mean(attn_out, dim=1)  # (batch, 512): mean over time steps
        x = self.fc1(context_vector)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc3(x)  # Logits; no softmax here
        return x


import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import numpy as np
import torch
from torch.utils.data import DataLoader

# 1. Load data and class names.
# NOTE: label indices follow load_data's file order. That order must match the one
# used when the checkpoint was trained, or the matrix rows/columns are mislabeled.
one_hot_sequences, labels, class_names = load_data(DATA_FOLDER)

# 2. Recreate the held-out split (80/20, stratified, random_state=42, matching the
#    training cell's split parameters).
X_train, X_test, y_train, y_test = train_test_split(
    one_hot_sequences, labels, test_size=0.2, random_state=42, stratify=labels)

# 3. Create test dataset and dataloader
test_dataset = SequenceDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 4. Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 5. Number of classes (label ids are 0 .. num_classes - 1)
num_classes = len(class_names)
class_ids = list(range(num_classes))

# 6. Instantiate the model and load weights.
# weights_only=True refuses arbitrary pickled objects. Both a bare state_dict and a
# training checkpoint dict ({'model_state_dict': ..., ...}) are accepted.
model = CNN_LSTM_Model(num_classes=num_classes).to(device)
state = torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=True)
if isinstance(state, dict) and 'model_state_dict' in state:
    state = state['model_state_dict']
model.load_state_dict(state)

# 7. Set model to evaluation mode (BatchNorm uses running stats, dropout disabled)
model.eval()

# 8. Run inference on test data and collect predictions and true labels
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_x, labels_batch in test_loader:
        outputs = model(batch_x.to(device))
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels_batch.numpy())

# 9. Compute the confusion matrix over every class id, so its shape always matches
#    class_names even when a class is missing from the labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# 10. Normalize to row-wise percentages; rows with no true samples stay 0 instead of NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_percent = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals > 0) * 100

report = classification_report(all_labels, all_preds, labels=class_ids,
                               target_names=class_names, digits=4)
print("Classification Report:\n", report)

# 11. Visualize the confusion matrix with a seaborn heatmap
plt.figure(figsize=(5, 4))
sns.heatmap(cm_percent, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            annot_kws={'size': 14})

plt.xlabel('Predicted Labels', fontsize=16, labelpad=16)
plt.ylabel('True Labels', fontsize=16, labelpad=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

output_path = '/kaggle/working/Final_confusion_matrix_Plants-Algae.png'
plt.savefig(output_path, bbox_inches='tight', dpi=300)
plt.show()

print(f"Confusion matrix plot saved to {output_path}")



In [None]:
# Compute and plot the ROC and precision-recall curves

# Import required libraries
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (classification_report, roc_curve, auc,
                            precision_recall_curve, average_precision_score,
                            f1_score, matthews_corrcoef)
import matplotlib.pyplot as plt

# Parameters
SEQ_LENGTH = 300  # Pad/truncate length used by one_hot_encode
BATCH_SIZE = 64
DATA_FOLDER = '/kaggle/input/plant-algae-dataset'  # Input folder: one CSV file per class
FINAL_MODEL_PATH = '/kaggle/input/final-models/Final_Plant-Algae_model.pth'  # Presumably the trained weights to evaluate (usage not visible here)
N_SPLITS = 5  # Number of folds for cross-validation
RANDOM_STATE = 42  # Presumably the seed for the data splits (usage not visible here)

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """One-hot encode a nucleotide string as an integer array of shape (seq_length, 5).

    Columns are A, T, C, G, N. The input is upper-cased, right-padded with 'N' up
    to ``seq_length`` and then truncated to ``seq_length``. Any character outside
    ATCGN is encoded exactly like 'N'.
    """
    alphabet = 'ATCGN'
    rows = {base: [int(col == pos) for col in range(len(alphabet))]
            for pos, base in enumerate(alphabet)}
    fallback = rows['N']
    window = sequence.upper().ljust(seq_length, 'N')[:seq_length]
    return np.array([rows.get(base, fallback) for base in window])

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Each item is ``(float32 tensor, int64 scalar tensor)``. ``torch.tensor`` copies
    the indexed element, so the returned tensors do not share memory with the inputs.
    """

    def __init__(self, sequences, labels):
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder, sort_files=False):
    """Load every ``*.csv`` file in ``data_folder`` as one class and one-hot encode it.

    Each CSV is read without a header, and its first column holds one sequence per
    row. A file's class label is its position among the CSV files visited. Its class
    name is the file name without the ``.csv`` extension.

    Args:
        data_folder: directory containing one CSV file per class.
        sort_files: visit files in sorted name order instead of ``os.listdir`` order.
            ``os.listdir`` order depends on the filesystem, so class indices can
            differ between machines. Enable this for a reproducible label mapping.
            The default keeps the historical order so existing checkpoints keep the
            mapping they were trained with.

    Returns:
        ``(one_hot_sequences, labels, class_names)``, where:
        - ``one_hot_sequences`` stacks ``one_hot_encode`` outputs, shape
          ``(n, SEQ_LENGTH, 5)``, or shape ``(0,)`` when there are no rows;
        - ``labels`` is an integer array of length ``n``;
        - ``class_names`` lists the class names indexed by label.
    """
    file_names = os.listdir(data_folder)
    if sort_files:
        file_names = sorted(file_names)
    sequences = []
    labels = []
    class_names = []
    for file_name in file_names:
        if not file_name.endswith('.csv'):
            continue
        # Label = number of classes seen so far. The old listdir index let non-CSV
        # files leave gaps, producing labels >= len(class_names).
        label = len(class_names)
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        # Empty cells are read as NaN floats, which one_hot_encode cannot upper-case.
        column = data[0].dropna().astype(str)
        sequences.extend(column.tolist())
        labels.extend([label] * len(column))
        class_names.append(os.path.splitext(file_name)[0])
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first inputs of shape (batch, seq_len, embed_dim).

    Wraps ``nn.MultiheadAttention``, which is built here without ``batch_first`` and
    therefore expects (seq_len, batch, embed_dim). The input is transposed in and out,
    and the attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)
        attended, _weights = self.attention(seq_first, seq_first, seq_first)
        return attended.transpose(0, 1)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Conv1d stage with optional parallel 3/5/7 kernels, max-pooling and a 1x1 residual.

    With ``use_multi_kernel=True``, three length-preserving convolutions
    (kernel sizes 3, 5, 7 with padding 1, 2, 3) run in parallel. Their ReLU
    outputs are concatenated on the channel axis, giving
    ``3 * output_channels`` channels. Otherwise a single kernel-3 convolution
    gives ``output_channels`` channels. The result is max-pooled by 2, passed
    through dropout (p=0.5) and added to a 1x1 convolution of the input.

    Args:
        input_channels: Channels of the ``(batch, channels, length)`` input
            expected by ``nn.Conv1d``.
        output_channels: Channels produced by each convolution branch.
        use_multi_kernel: Whether to use the three parallel kernels.
    """
    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super(MultiKernelCNN, self).__init__()
        self.use_multi_kernel = use_multi_kernel
        if use_multi_kernel:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        else:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # 1x1 projection so the residual has the same channel count as the (concatenated) conv output.
        self.residual = nn.Conv1d(input_channels, output_channels * (3 if use_multi_kernel else 1), kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # The residual is taken from the un-pooled input, so it keeps the full input length.
        residual = self.residual(x)
        if self.use_multi_kernel:
            x1 = self.relu(self.conv3(x))
            x2 = self.relu(self.conv5(x))
            x3 = self.relu(self.conv7(x))
            x = torch.cat((x1, x2, x3), dim=1)
        else:
            x = self.relu(self.conv3(x))
        x = self.pool(x)  # halves the length (floor)
        x = self.dropout(x)
        if x.shape != residual.shape:
            # NOTE(review): this keeps only the first len(x) positions of the
            # full-length residual. Pooled position i (input positions 2i, 2i+1)
            # is therefore added to the residual of input position i, so the two
            # are not spatially aligned. Changing it would alter the outputs of
            # already-trained checkpoints, so it is only flagged here.
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        return x

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Self-attention over the length axis of a ``(batch, channels, length)`` feature map.

    Query and key are 1x1 convolutions that reduce channels to
    ``channel_dim // 8``; value keeps ``channel_dim`` channels. Scores are
    plain dot products with no 1/sqrt(d) scaling. The attended output is
    multiplied by a learnable ``gamma`` and added to the input. Because
    ``gamma`` starts at 0, a freshly initialised module returns its input
    unchanged.
    """
    def __init__(self, channel_dim):
        super(CNN_Attention, self).__init__()
        self.channel_dim = channel_dim
        self.query = nn.Conv1d(channel_dim, channel_dim // 8, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, channel_dim // 8, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # starts at 0 -> attention branch initially contributes nothing

    def forward(self, x):
        batch_size, channels, seq_len = x.size()
        query = self.query(x).view(batch_size, -1, seq_len)  # (B, C//8, L)
        key = self.key(x).view(batch_size, -1, seq_len)  # (B, C//8, L)
        value = self.value(x).view(batch_size, -1, seq_len)  # (B, C, L)
        # scores[b, i, j] = <query at position i, key at position j>  -> (B, L, L)
        attention_scores = torch.bmm(query.permute(0, 2, 1), key)
        # Normalise over key positions j for each query position i.
        attention_scores = F.softmax(attention_scores, dim=-1)
        # out[b, :, i] = sum_j value[b, :, j] * attn[b, i, j]  -> (B, C, L)
        out = torch.bmm(value, attention_scores.permute(0, 2, 1))
        out = self.gamma * out + x
        return out

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> BiLSTM -> multi-head self-attention classifier.

    Pipeline, as implemented in ``forward``:
    - Two ``MultiKernelCNN`` stages; the second uses parallel 3/5/7 kernels
      and outputs 256 * 3 channels.
    - ``CNN_Attention``, batch norm and one more max-pool.
    - A global 1x1 residual taken from the raw 5-channel input.
    - A bidirectional LSTM with 256 hidden units per direction.
    - 4-head self-attention, then a mean over time.
    - A 3-layer MLP head that outputs ``num_classes`` raw logits.

    Args:
        num_classes: Number of output logits.
    """
    def __init__(self, num_classes):
        super(CNN_LSTM_Model, self).__init__()
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)
        self.bn2 = nn.BatchNorm1d(256 * 3)
        # Global residual: 1x1 projection of the raw 5-channel input to the CNN output width.
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)  # 2 * 256 output features
        self.bn_lstm = nn.BatchNorm1d(256 * 2)
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout_fc = nn.Dropout(0.3)  # reused after the CNN stage and in the MLP head

    def forward(self, x):
        # assumes x is (batch, seq_len, 5), as produced by one_hot_encode via SequenceDataset — TODO confirm for other callers
        x = x.permute(0, 2, 1)  # -> (batch, 5, seq_len) for Conv1d
        residual = self.residual(x)  # full-length (un-pooled) global residual
        x = self.multi_kernel_cnn1(x)  # length halved
        x = self.bn1(x)
        x = self.multi_kernel_cnn2(x)  # length halved again, 256 * 3 channels
        x = self.cnn_attention2(x)
        x = self.bn2(x)
        x = self.pool(x)  # third halving
        x = self.dropout_fc(x)
        if x.shape != residual.shape:
            # NOTE(review): keeps only the first len(x) input positions of the residual,
            # so it is not position-aligned with the pooled features; left as-is to
            # match trained checkpoints.
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        x = x.permute(0, 2, 1)  # -> (batch, time, channels) for the batch_first LSTM
        lstm_out, (hn, _) = self.lstm(x)  # hn is unused
        # BatchNorm1d normalises dim 1, so move features there and back.
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.bn_lstm(lstm_out)
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.dropout_lstm(lstm_out)
        attn_out = self.self_attention(lstm_out)
        context_vector = torch.mean(attn_out, dim=1)  # average over time -> (batch, 512)
        x = self.fc1(context_vector)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc3(x)  # raw logits; softmax is applied by callers
        return x

# Main execution
if __name__ == "__main__":
    # Subset, roc_curve and average_precision_score are used below but are not
    # among the imports shown for this cell (NameError otherwise). Importing
    # them here keeps the cell self-contained.
    from torch.utils.data import Subset
    from sklearn.metrics import roc_curve, average_precision_score
    import matplotlib.pyplot as plt

    # Load and prepare data
    sequences, labels, class_names = load_data(DATA_FOLDER)
    print(f"Sequences shape: {sequences.shape}, Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")
    print(f"Class names: {class_names}")
    # Explicit label ids keep classification_report aligned with class_names.
    label_ids = np.arange(len(class_names))

    # Results collected across folds
    all_folds_probs = []
    all_folds_labels = []
    all_folds_preds = []

    # Stratified K-Fold cross-validator
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

    # Full dataset; each fold views it through a Subset of indices
    full_dataset = SequenceDataset(sequences, labels)

    # Build and load the model once. Every fold scores the same checkpoint and
    # nothing is trained here, so reloading it per fold was redundant.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN_LSTM_Model(num_classes=len(class_names)).to(device)
    # Fail loudly rather than silently reporting metrics for random weights.
    if not os.path.exists(FINAL_MODEL_PATH):
        raise FileNotFoundError(f"Model checkpoint not found: {FINAL_MODEL_PATH}")
    try:
        state_dict = torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e
    print(f"Loaded model from: {FINAL_MODEL_PATH}")
    model.eval()

    # NOTE(review): the same final checkpoint is evaluated on every validation
    # fold. If it was trained on all of DATA_FOLDER, these fold "validation"
    # scores are computed on training data and will be optimistic. Confirm how
    # the checkpoint was produced before reporting them as cross-validation.

    for fold, (_train_idx, val_idx) in enumerate(skf.split(sequences, labels)):
        print(f"\n=== Fold {fold + 1}/{N_SPLITS} ===")

        # Only the validation indices are needed: this cell does not train.
        val_loader = DataLoader(Subset(full_dataset, val_idx), batch_size=BATCH_SIZE,
                                shuffle=False, num_workers=4)

        fold_probs = []
        fold_labels = []
        fold_preds = []

        with torch.no_grad():
            # Batch names differ from the dataset-level `sequences` / `labels`
            # so those arrays are not shadowed.
            for batch_x, batch_y in val_loader:
                outputs = model(batch_x.to(device))
                probs = F.softmax(outputs, dim=1)

                fold_probs.append(probs.cpu().numpy())
                fold_labels.append(batch_y.numpy())
                fold_preds.append(outputs.argmax(dim=1).cpu().numpy())

        # Convert lists to numpy arrays
        fold_probs = np.concatenate(fold_probs)
        fold_labels = np.concatenate(fold_labels)
        fold_preds = np.concatenate(fold_preds)

        # Store this fold's results
        all_folds_probs.append(fold_probs)
        all_folds_labels.append(fold_labels)
        all_folds_preds.append(fold_preds)

        print(f"Fold {fold + 1} validation results:")
        print(classification_report(fold_labels, fold_preds, labels=label_ids, target_names=class_names))

    # Pool predictions over all folds for the final report
    all_labels = np.concatenate(all_folds_labels)
    all_preds = np.concatenate(all_folds_preds)

    def plot_avg_roc_pr_curves(all_folds_probs, all_folds_labels, class_names):
        """Plot pooled ROC and precision-recall curves and save them to disk.

        The curves use the predictions of all folds concatenated. The ``±``
        values are the standard error of the per-fold AUC / average precision.
        Binary only: column 1 of the softmax output is taken as the positive
        class, so other class counts are skipped rather than mis-plotted.
        """
        if len(class_names) != 2:
            print(f"ROC/PR plotting assumes 2 classes, got {len(class_names)}; skipping curves.")
            return

        plt.style.use('default')
        plt.rcParams.update({
            'font.family': 'serif',
            # Fallback avoids findfont warnings where Times New Roman is not installed.
            'font.serif': ['Times New Roman', 'DejaVu Serif'],
            'font.size': 12,
            'axes.labelsize': 18,
            'axes.titlesize': 16,
            'legend.fontsize': 14,
            'xtick.labelsize': 14,
            'ytick.labelsize': 14,
            'figure.dpi': 600,
            'savefig.dpi': 600,
            'savefig.bbox': 'tight',
            'lines.linewidth': 2,
            'grid.alpha': 0.3
        })

        # Pooled positive-class probabilities and labels
        y_probs = np.concatenate(all_folds_probs)[:, 1]
        y_true = np.concatenate(all_folds_labels)

        fpr, tpr, _ = roc_curve(y_true, y_probs)
        roc_auc = auc(fpr, tpr)
        precision, recall, _ = precision_recall_curve(y_true, y_probs)
        average_precision = average_precision_score(y_true, y_probs)

        # Per-fold scores for the spread estimate
        roc_aucs = []
        ap_scores = []
        for fold_probs, fold_labels in zip(all_folds_probs, all_folds_labels):
            fold_pos = fold_probs[:, 1]
            fpr_fold, tpr_fold, _ = roc_curve(fold_labels, fold_pos)
            roc_aucs.append(auc(fpr_fold, tpr_fold))
            ap_scores.append(average_precision_score(fold_labels, fold_pos))

        # Standard error of the mean across folds (sample std, ddof=1)
        roc_auc_sem = np.std(roc_aucs, ddof=1) / np.sqrt(len(roc_aucs))
        ap_sem = np.std(ap_scores, ddof=1) / np.sqrt(len(ap_scores))

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # ROC curve
        ax1.plot(fpr, tpr, color='#1f77b4', label=f'ROC (AUC = {roc_auc:.3f} ± {roc_auc_sem:.3f})')
        ax1.plot([0, 1], [0, 1], 'k--', alpha=0.3)
        ax1.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
                xlabel='False Positive Rate',
                ylabel='True Positive Rate',
                title='ROC Curve')
        ax1.grid(True, linestyle='--', alpha=0.3)
        ax1.legend(loc='lower right')

        # Precision-recall curve
        ax2.plot(recall, precision, color='#ff7f0e', label=f'PR (AP = {average_precision:.3f} ± {ap_sem:.3f})')
        ax2.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
                xlabel='Recall',
                ylabel='Precision',
                title='Precision-Recall Curve')
        ax2.grid(True, linestyle='--', alpha=0.3)
        ax2.legend(loc='lower left')

        plt.tight_layout(pad=3.0)

        # Save in multiple formats
        output_base = '/kaggle/working/classification_curves_PG_300'
        formats = ['pdf', 'png', 'svg']
        for fmt in formats:
            plt.savefig(f'{output_base}.{fmt}', dpi=600, bbox_inches='tight')
        plt.show()

        print("\nPublication-quality curves saved to:")
        for fmt in formats:
            print(f"- {output_base}.{fmt}")
        print("\nPerformance Metrics:")
        print(f"ROC AUC: {roc_auc:.4f} ± {roc_auc_sem:.4f}")
        print(f"Average Precision: {average_precision:.4f} ± {ap_sem:.4f}")

    # Generate the curves
    plot_avg_roc_pr_curves(all_folds_probs, all_folds_labels, class_names)

    # Final classification report over all folds
    print("\n=== Final Classification Report ===")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

    # Additional metrics
    print("\n=== Additional Metrics ===")
    print(f"Macro F1-score: {f1_score(all_labels, all_preds, average='macro'):.3f}")
    print(f"Weighted F1-score: {f1_score(all_labels, all_preds, average='weighted'):.3f}")
    print(f"Matthews Correlation Coefficient: {matthews_corrcoef(all_labels, all_preds):.3f}")

In [None]:
# The code to have the analysis of attention heatmaps

# Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_curve, roc_auc_score, f1_score, auc, matthews_corrcoef
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

# Parameters
SEQ_LENGTH = 300  # one_hot_encode pads with 'N' / truncates every sequence to this length
BATCH_SIZE = 64  # Batch size of the test DataLoaders below
DATA_FOLDER = '/kaggle/input/algae-dataset'  # One headerless CSV per class (see load_data)
FINAL_MODEL_PATH = '/kaggle/input/final-models/Final_Algae_model.pth'  # Trained weights analysed in this cell

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """One-hot encode a nucleotide string, one 5-element row per position.

    Columns are A, T, C, G, N in that order. The string is upper-cased,
    right-padded with 'N' up to ``seq_length`` and then cut to that length.
    Any character other than A/T/C/G is encoded the same way as 'N'.
    """
    unknown = [0, 0, 0, 0, 1]
    lookup = {
        'A': [1, 0, 0, 0, 0],
        'T': [0, 1, 0, 0, 0],
        'C': [0, 0, 1, 0, 0],
        'G': [0, 0, 0, 1, 0],
        'N': unknown,
    }
    fixed_length = sequence.upper().ljust(seq_length, 'N')[:seq_length]
    rows = [lookup.get(base, unknown) for base in fixed_length]
    return np.array(rows)

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Conversion to tensors happens lazily in ``__getitem__``: the sequence
    becomes ``float32`` and the label becomes ``long``.
    """

    def __init__(self, sequences, labels):
        # Kept as given; items are indexed on demand.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        """Number of samples, taken from the sequence container."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence_tensor, label_tensor)`` for position ``idx``."""
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder):
    """Load every CSV in ``data_folder`` as one class and one-hot encode its sequences.

    Each ``.csv`` file is one class. Its name without the extension is the
    class name, and every non-empty cell of its first column is a sequence.

    Args:
        data_folder: Directory holding one headerless CSV per class.

    Returns:
        Tuple ``(one_hot_sequences, labels, class_names)``:
        - ``one_hot_sequences`` stacks ``one_hot_encode`` of every sequence.
        - ``labels[i]`` is the index in ``class_names`` of sequence i's file.
        Labels are contiguous 0..len(class_names)-1.
    """
    sequences = []
    labels = []
    class_names = []
    # Enumerate only CSV files: counting every directory entry (as before) let a
    # stray non-CSV file skip a label index, so labels could exceed
    # len(class_names) - 1. File order is still os.listdir order, which keeps
    # the class-index mapping used by checkpoints trained with this loader.
    # NOTE(review): os.listdir order is not guaranteed; sorting would make the
    # mapping deterministic but might not match existing checkpoints.
    csv_files = [name for name in os.listdir(data_folder) if name.endswith('.csv')]
    for idx, file_name in enumerate(csv_files):
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        # Empty cells are read as NaN, which would crash .upper() in one_hot_encode.
        class_sequences = data[0].dropna().astype(str).tolist()
        sequences.extend(class_sequences)
        labels.extend([idx] * len(class_sequences))
        class_names.append(os.path.splitext(file_name)[0])
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first sequences.

    Wraps ``nn.MultiheadAttention``, which is built with its default
    sequence-first layout. The input is transposed in and out so callers pass
    ``(batch, seq, embed_dim)`` tensors. The attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)  # (batch, seq, E) -> (seq, batch, E)
        out, _unused_weights = self.attention(seq_first, seq_first, seq_first)
        return out.transpose(0, 1)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Conv1d stage with optional parallel 3/5/7 kernels, max-pooling and a 1x1 residual.

    With ``use_multi_kernel=True``, three length-preserving convolutions
    (kernel sizes 3, 5, 7 with padding 1, 2, 3) run in parallel. Their ReLU
    outputs are concatenated on the channel axis, giving
    ``3 * output_channels`` channels. Otherwise a single kernel-3 convolution
    gives ``output_channels`` channels. The result is max-pooled by 2, passed
    through dropout (p=0.5) and added to a 1x1 convolution of the input.

    Args:
        input_channels: Channels of the ``(batch, channels, length)`` input
            expected by ``nn.Conv1d``.
        output_channels: Channels produced by each convolution branch.
        use_multi_kernel: Whether to use the three parallel kernels.
    """
    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super(MultiKernelCNN, self).__init__()
        self.use_multi_kernel = use_multi_kernel
        if use_multi_kernel:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        else:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # 1x1 projection so the residual has the same channel count as the (concatenated) conv output.
        self.residual = nn.Conv1d(input_channels, output_channels * (3 if use_multi_kernel else 1), kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # The residual is taken from the un-pooled input, so it keeps the full input length.
        residual = self.residual(x)
        if self.use_multi_kernel:
            x1 = self.relu(self.conv3(x))
            x2 = self.relu(self.conv5(x))
            x3 = self.relu(self.conv7(x))
            x = torch.cat((x1, x2, x3), dim=1)
        else:
            x = self.relu(self.conv3(x))
        x = self.pool(x)  # halves the length (floor)
        x = self.dropout(x)
        if x.shape != residual.shape:
            # NOTE(review): this keeps only the first len(x) positions of the
            # full-length residual. Pooled position i (input positions 2i, 2i+1)
            # is therefore added to the residual of input position i, so the two
            # are not spatially aligned. Changing it would alter the outputs of
            # already-trained checkpoints, so it is only flagged here.
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        return x

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Self-attention over the length axis of a ``(batch, channels, length)`` feature map.

    Query and key are 1x1 convolutions that reduce channels to
    ``channel_dim // 8``; value keeps ``channel_dim`` channels. Scores are
    plain dot products with no 1/sqrt(d) scaling. The attended output is
    multiplied by a learnable ``gamma`` and added to the input. Because
    ``gamma`` starts at 0, a freshly initialised module returns its input
    unchanged.
    """
    def __init__(self, channel_dim):
        super(CNN_Attention, self).__init__()
        self.channel_dim = channel_dim
        self.query = nn.Conv1d(channel_dim, channel_dim // 8, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, channel_dim // 8, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # starts at 0 -> attention branch initially contributes nothing

    def forward(self, x):
        batch_size, channels, seq_len = x.size()
        query = self.query(x).view(batch_size, -1, seq_len)  # (B, C//8, L)
        key = self.key(x).view(batch_size, -1, seq_len)  # (B, C//8, L)
        value = self.value(x).view(batch_size, -1, seq_len)  # (B, C, L)
        # scores[b, i, j] = <query at position i, key at position j>  -> (B, L, L)
        attention_scores = torch.bmm(query.permute(0, 2, 1), key)
        # Normalise over key positions j for each query position i.
        attention_scores = F.softmax(attention_scores, dim=-1)
        # out[b, :, i] = sum_j value[b, :, j] * attn[b, i, j]  -> (B, C, L)
        out = torch.bmm(value, attention_scores.permute(0, 2, 1))
        out = self.gamma * out + x
        return out

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> BiLSTM -> multi-head self-attention classifier.

    Pipeline, as implemented in ``forward``:
    - Two ``MultiKernelCNN`` stages; the second uses parallel 3/5/7 kernels
      and outputs 256 * 3 channels.
    - ``CNN_Attention``, batch norm and one more max-pool.
    - A global 1x1 residual taken from the raw 5-channel input.
    - A bidirectional LSTM with 256 hidden units per direction.
    - 4-head self-attention, then a mean over time.
    - A 3-layer MLP head that outputs ``num_classes`` raw logits.

    Args:
        num_classes: Number of output logits.
    """
    def __init__(self, num_classes):
        super(CNN_LSTM_Model, self).__init__()
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)
        self.bn2 = nn.BatchNorm1d(256 * 3)
        # Global residual: 1x1 projection of the raw 5-channel input to the CNN output width.
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)  # 2 * 256 output features
        self.bn_lstm = nn.BatchNorm1d(256 * 2)
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout_fc = nn.Dropout(0.3)  # reused after the CNN stage and in the MLP head

    def forward(self, x):
        # assumes x is (batch, seq_len, 5), as produced by one_hot_encode via SequenceDataset — TODO confirm for other callers
        x = x.permute(0, 2, 1)  # -> (batch, 5, seq_len) for Conv1d
        residual = self.residual(x)  # full-length (un-pooled) global residual
        x = self.multi_kernel_cnn1(x)  # length halved
        x = self.bn1(x)
        x = self.multi_kernel_cnn2(x)  # length halved again, 256 * 3 channels
        x = self.cnn_attention2(x)
        x = self.bn2(x)
        x = self.pool(x)  # third halving
        x = self.dropout_fc(x)
        if x.shape != residual.shape:
            # NOTE(review): keeps only the first len(x) input positions of the residual,
            # so it is not position-aligned with the pooled features; left as-is to
            # match trained checkpoints.
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        x = x.permute(0, 2, 1)  # -> (batch, time, channels) for the batch_first LSTM
        lstm_out, (hn, _) = self.lstm(x)  # hn is unused
        # BatchNorm1d normalises dim 1, so move features there and back.
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.bn_lstm(lstm_out)
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.dropout_lstm(lstm_out)
        attn_out = self.self_attention(lstm_out)
        context_vector = torch.mean(attn_out, dim=1)  # average over time -> (batch, 512)
        x = self.fc1(context_vector)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc3(x)  # raw logits; softmax is applied by callers
        return x

# Main execution
if __name__ == "__main__":
    # Load and prepare data
    sequences, labels, class_names = load_data(DATA_FOLDER)
    print(f"Sequences shape: {sequences.shape}, Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")
    print(f"Class names: {class_names}")

    # Recreate the held-out split.
    # NOTE(review): this is only the training-time test set if training used the
    # same test_size, random_state and stratification on identically ordered
    # data. Confirm against the training notebook.
    _, X_test, _, y_test = train_test_split(
        sequences, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Create test dataset and loader
    test_dataset = SequenceDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Initialize device and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN_LSTM_Model(num_classes=len(class_names)).to(device)

    # Load final model
    try:
        model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=True))
        print(f"Successfully loaded final model from: {FINAL_MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model with weights_only=True: {e}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects. That
        # is only acceptable if FINAL_MODEL_PATH is a trusted, self-produced
        # checkpoint (presumably the case for this Kaggle input; verify).
        try:
            model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=False))
            print("Successfully loaded with weights_only=False")
        except Exception as e2:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model: {e2}") from e2


import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader

def compute_attention_heatmaps(model, dataloader, device):
    """
    Compute per-sequence CNN and LSTM attention profiles for a trained model.

    The forward pass here replays ``CNN_LSTM_Model.forward`` step by step up to
    the self-attention layer. This includes bn1, cnn_attention2, bn2 and
    bn_lstm, which the previous version skipped, so its profiles came from a
    different network than the one that was trained.

    Args:
        model: Trained model exposing the CNN_LSTM_Model sub-modules.
        dataloader: DataLoader yielding ``(sequences, labels)`` batches.
        device: Device (e.g., 'cuda' or 'cpu').

    Returns:
        cnn_attention_weights: Per sequence, the channel-mean activation of the
            CNN stage output. There is one value per pooled position; these are
            feature magnitudes, not softmax weights.
        lstm_attention_weights: Per sequence, the self-attention weights
            (averaged over heads by ``nn.MultiheadAttention``) averaged over
            query positions, i.e. the mean attention each key position receives.
        sequences: Input sequences as numpy arrays.
        labels: Labels as numpy scalars.
    """
    model.eval()
    cnn_attention_weights = []
    lstm_attention_weights = []
    sequences = []
    labels = []

    with torch.no_grad():
        for batch_sequences, batch_labels in dataloader:
            batch_sequences = batch_sequences.to(device)

            # CNN stage: identical to CNN_LSTM_Model.forward.
            x = batch_sequences.permute(0, 2, 1)
            residual = model.residual(x)
            x = model.multi_kernel_cnn1(x)
            x = model.bn1(x)
            x = model.multi_kernel_cnn2(x)
            x = model.cnn_attention2(x)
            x = model.bn2(x)
            x = model.pool(x)
            x = model.dropout_fc(x)
            if x.shape != residual.shape:
                residual = residual[:, :, :x.shape[2]]
            x = x + residual

            # Average across channels -> one value per pooled position.
            cnn_attn = x.mean(dim=1).cpu().numpy()

            # LSTM stage, including the BatchNorm applied on the feature axis.
            lstm_out, _ = model.lstm(x.permute(0, 2, 1))
            lstm_out = model.bn_lstm(lstm_out.permute(0, 2, 1)).permute(0, 2, 1)
            lstm_out = model.dropout_lstm(lstm_out)

            # Self-attention; nn.MultiheadAttention here expects (seq, batch, embed).
            seq_first = lstm_out.permute(1, 0, 2)
            _, attn_weights = model.self_attention.attention(seq_first, seq_first, seq_first)
            # attn_weights: (batch, query_len, key_len); average over queries.
            lstm_attn = attn_weights.mean(dim=1).cpu().numpy()

            cnn_attention_weights.extend(cnn_attn)
            lstm_attention_weights.extend(lstm_attn)
            sequences.extend(batch_sequences.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    return cnn_attention_weights, lstm_attention_weights, sequences, labels


def compute_group_attention_heatmaps(attention_weights, labels, class_names, sequence_length=300):
    """
    Average per-sequence attention profiles within each class.

    Each 1-D profile is linearly resampled (``align_corners=False``) to
    ``sequence_length`` before averaging. This lets profiles computed on pooled,
    shorter feature maps be shown on the input-position axis.

    Args:
        attention_weights: Iterable of 1-D arrays, one per sequence.
        labels: Class index of each sequence, aligned with attention_weights.
        class_names: Class names; their count defines the groups 0..n-1.
        sequence_length: Length every profile is resampled to (default: 300).

    Returns:
        Dict mapping class index -> averaged profile of shape
        (sequence_length,). A class with no sequences keeps an all-zero
        profile instead of the NaN that a 0/0 division would produce (NaN
        would also poison a shared colour scale).
    """
    n_groups = len(class_names)
    sums = {group: np.zeros(sequence_length) for group in range(n_groups)}
    counts = {group: 0 for group in range(n_groups)}

    for attn_weights, label in zip(attention_weights, labels):
        label = int(label)
        # Add batch and channel dims: F.interpolate(mode='linear') needs (N, C, L).
        profile = torch.tensor(np.asarray(attn_weights, dtype=np.float64)).view(1, 1, -1)
        resized = F.interpolate(profile, size=sequence_length, mode='linear', align_corners=False)
        sums[label] += resized.view(-1).numpy()
        counts[label] += 1

    for group, count in counts.items():
        if count > 0:
            sums[group] /= count
    return sums



def plot_attention_heatmaps(group_attention_heatmaps, class_names, heatmap_type="CNN", save_dir="attention_heatmaps"):
    """
    Plot one single-row heatmap per class and save each as a PNG.

    All maps passed in one call share a colour range (the global min/max over
    every group, ignoring NaN), so all CNN maps or all LSTM maps are directly
    comparable with each other.

    Args:
        group_attention_heatmaps: Dict mapping class index -> 1-D array with
            one value per sequence position.
        class_names: Class names; position i names group i.
        heatmap_type: Label used in titles and file names ("CNN" or "LSTM").
        save_dir: Directory the PNGs are written to (created if missing).
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    if not group_attention_heatmaps:
        print(f"No {heatmap_type} heatmaps to plot.")
        return

    # Global range for consistent colour scaling across all maps of this type.
    all_values = np.concatenate([np.asarray(h) for h in group_attention_heatmaps.values()])
    vmin = np.nanmin(all_values)
    vmax = np.nanmax(all_values)

    for group, group_name in enumerate(class_names):
        heatmap = np.asarray(group_attention_heatmaps[group])
        seq_len = heatmap.shape[0]

        plt.figure(figsize=(10, 2))
        sns.heatmap(
            heatmap.reshape(1, -1),  # one row -> 2-D input for seaborn
            cmap="viridis",
            cbar=True,
            yticklabels=False,
            vmin=vmin,
            vmax=vmax,
        )
        plt.title(f"Group-Level {heatmap_type} Attention Heatmap for {group_name} (Test Set)", fontsize=12)
        plt.xlabel("Sequence Position (bp)", fontsize=14, labelpad=13)
        plt.ylabel("Average Attention", fontsize=12, labelpad=13, rotation=90)

        # Tick every 20 positions plus the last one (derived from the data, not
        # hard-coded to 300). Labels are offsets from the sequence end:
        # position 0 -> -seq_len, last position -> -1. seaborn draws cell i
        # between x=i and x=i+1, so +0.5 centres each tick on its cell.
        positions = np.arange(0, seq_len, step=20)
        if positions[-1] != seq_len - 1:
            positions = np.append(positions, seq_len - 1)
        tick_labels = [f"{p - seq_len}" for p in positions]
        plt.xticks(positions + 0.5, labels=tick_labels, fontsize=12, rotation=60, ha='right')

        # No y tick labels; the axis is described by plt.ylabel.
        plt.yticks(ticks=[0], labels=[""], fontsize=12)

        save_path = os.path.join(save_dir, f"{heatmap_type}_attention_heatmap_{group_name}.png")
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Heatmap saved to {save_path}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()


# Create a Dataset for the test set
test_dataset = SequenceDataset(X_test, y_test)

# DataLoader for the test set; no shuffling, so outputs stay in X_test / y_test order
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Compute attention weights for the test set
cnn_attention_weights, lstm_attention_weights, test_sequences, test_labels = compute_attention_heatmaps(model, test_loader, device)

# Group-level heatmaps, resampled back to the one-hot input length (SEQ_LENGTH)
# instead of a separately hard-coded 300
cnn_group_attention_heatmaps = compute_group_attention_heatmaps(cnn_attention_weights, test_labels, class_names, sequence_length=SEQ_LENGTH)
lstm_group_attention_heatmaps = compute_group_attention_heatmaps(lstm_attention_weights, test_labels, class_names, sequence_length=SEQ_LENGTH)

# Plot CNN attention heatmaps
plot_attention_heatmaps(cnn_group_attention_heatmaps, class_names, heatmap_type="CNN")

# Plot LSTM attention heatmaps
plot_attention_heatmaps(lstm_group_attention_heatmaps, class_names, heatmap_type="LSTM")


In [None]:
# The code for the perturbation test analysis

# Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_curve, roc_auc_score, f1_score, auc, matthews_corrcoef
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

# Parameters
SEQ_LENGTH = 300  # one_hot_encode pads with 'N' / truncates every sequence to this length
BATCH_SIZE = 64  # Batch size used by this cell
DATA_FOLDER = '/kaggle/input/plant-dataset'  # One headerless CSV per class (see load_data)
FINAL_MODEL_PATH = '/kaggle/input/final-models/Final_Plant_model.pth'  # Trained weights used in this cell

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """One-hot encode a nucleotide string, one 5-element row per position.

    Columns are A, T, C, G, N in that order. The string is upper-cased,
    right-padded with 'N' up to ``seq_length`` and then cut to that length.
    Any character other than A/T/C/G is encoded the same way as 'N'.
    """
    unknown = [0, 0, 0, 0, 1]
    lookup = {
        'A': [1, 0, 0, 0, 0],
        'T': [0, 1, 0, 0, 0],
        'C': [0, 0, 1, 0, 0],
        'G': [0, 0, 0, 1, 0],
        'N': unknown,
    }
    fixed_length = sequence.upper().ljust(seq_length, 'N')[:seq_length]
    rows = [lookup.get(base, unknown) for base in fixed_length]
    return np.array(rows)

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Conversion to tensors happens lazily in ``__getitem__``: the sequence
    becomes ``float32`` and the label becomes ``long``.
    """

    def __init__(self, sequences, labels):
        # Kept as given; items are indexed on demand.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        """Number of samples, taken from the sequence container."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence_tensor, label_tensor)`` for position ``idx``."""
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder):
    """Load every CSV in ``data_folder`` as one class and one-hot encode its sequences.

    Each ``.csv`` file is one class. Its name without the extension is the
    class name, and every non-empty cell of its first column is a sequence.

    Args:
        data_folder: Directory holding one headerless CSV per class.

    Returns:
        Tuple ``(one_hot_sequences, labels, class_names)``:
        - ``one_hot_sequences`` stacks ``one_hot_encode`` of every sequence.
        - ``labels[i]`` is the index in ``class_names`` of sequence i's file.
        Labels are contiguous 0..len(class_names)-1.
    """
    sequences = []
    labels = []
    class_names = []
    # Enumerate only CSV files: counting every directory entry (as before) let a
    # stray non-CSV file skip a label index, so labels could exceed
    # len(class_names) - 1. File order is still os.listdir order, which keeps
    # the class-index mapping used by checkpoints trained with this loader.
    # NOTE(review): os.listdir order is not guaranteed; sorting would make the
    # mapping deterministic but might not match existing checkpoints.
    csv_files = [name for name in os.listdir(data_folder) if name.endswith('.csv')]
    for idx, file_name in enumerate(csv_files):
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        # Empty cells are read as NaN, which would crash .upper() in one_hot_encode.
        class_sequences = data[0].dropna().astype(str).tolist()
        sequences.extend(class_sequences)
        labels.extend([idx] * len(class_sequences))
        class_names.append(os.path.splitext(file_name)[0])
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first sequences.

    Wraps ``nn.MultiheadAttention``, which is built with its default
    sequence-first layout. The input is transposed in and out so callers pass
    ``(batch, seq, embed_dim)`` tensors. The attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)  # (batch, seq, E) -> (seq, batch, E)
        out, _unused_weights = self.attention(seq_first, seq_first, seq_first)
        return out.transpose(0, 1)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Conv1d stage with optional parallel 3/5/7 kernels, max-pooling and a 1x1 residual.

    With ``use_multi_kernel=True``, three length-preserving convolutions
    (kernel sizes 3, 5, 7 with padding 1, 2, 3) run in parallel. Their ReLU
    outputs are concatenated on the channel axis, giving
    ``3 * output_channels`` channels. Otherwise a single kernel-3 convolution
    gives ``output_channels`` channels. The result is max-pooled by 2, passed
    through dropout (p=0.5) and added to a 1x1 convolution of the input.

    Args:
        input_channels: Channels of the ``(batch, channels, length)`` input
            expected by ``nn.Conv1d``.
        output_channels: Channels produced by each convolution branch.
        use_multi_kernel: Whether to use the three parallel kernels.
    """
    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super(MultiKernelCNN, self).__init__()
        self.use_multi_kernel = use_multi_kernel
        if use_multi_kernel:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        else:
            self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # 1x1 projection so the residual has the same channel count as the (concatenated) conv output.
        self.residual = nn.Conv1d(input_channels, output_channels * (3 if use_multi_kernel else 1), kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # The residual is taken from the un-pooled input, so it keeps the full input length.
        residual = self.residual(x)
        if self.use_multi_kernel:
            x1 = self.relu(self.conv3(x))
            x2 = self.relu(self.conv5(x))
            x3 = self.relu(self.conv7(x))
            x = torch.cat((x1, x2, x3), dim=1)
        else:
            x = self.relu(self.conv3(x))
        x = self.pool(x)  # halves the length (floor)
        x = self.dropout(x)
        if x.shape != residual.shape:
            # NOTE(review): this keeps only the first len(x) positions of the
            # full-length residual. Pooled position i (input positions 2i, 2i+1)
            # is therefore added to the residual of input position i, so the two
            # are not spatially aligned. Changing it would alter the outputs of
            # already-trained checkpoints, so it is only flagged here.
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        return x

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Self-attention across the length axis of a ``(batch, channels, length)`` feature map.

    Queries and keys are 1x1-conv projections to ``channel_dim // 8`` channels and
    values keep ``channel_dim`` channels. The attended values are scaled by the
    learnable ``gamma`` (initialised to 0, so the block starts as the identity)
    and added back to the input.
    """

    def __init__(self, channel_dim):
        super().__init__()
        self.channel_dim = channel_dim
        reduced = channel_dim // 8
        self.query = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, _, length = x.size()  # requires a 3-D (batch, channels, length) input
        q = self.query(x).reshape(batch, -1, length)
        k = self.key(x).reshape(batch, -1, length)
        v = self.value(x).reshape(batch, -1, length)
        weights = torch.softmax(torch.bmm(q.transpose(1, 2), k), dim=-1)  # (batch, length, length)
        attended = torch.bmm(v, weights.transpose(1, 2))  # (batch, channels, length)
        return self.gamma * attended + x

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> bidirectional LSTM -> multi-head self-attention classifier.

    Input ``(batch, length, 5)`` (permuted to channels-first for the 5-channel convs);
    output ``(batch, num_classes)`` logits. Submodule names and construction order
    are unchanged so existing state_dicts load and seeded initialisation is identical.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Convolutional front end: 5 -> 128 channels (k=3), then 3 x 256 channels (k=3/5/7).
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)
        self.bn2 = nn.BatchNorm1d(256 * 3)
        # 1x1 shortcut from the raw one-hot input to the CNN output width.
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # Sequence model over the pooled CNN features.
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)
        self.bn_lstm = nn.BatchNorm1d(256 * 2)
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)
        # Classification head.
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout_fc = nn.Dropout(0.3)

    def forward(self, x):
        channels_first = x.transpose(1, 2)  # (batch, length, 5) -> (batch, 5, length)
        skip = self.residual(channels_first)

        feats = self.bn1(self.multi_kernel_cnn1(channels_first))
        feats = self.bn2(self.cnn_attention2(self.multi_kernel_cnn2(feats)))
        feats = self.dropout_fc(self.pool(feats))
        # NOTE(review): the full-length shortcut is cut to the first pooled-length
        # positions rather than downsampled; kept because trained weights depend on it.
        feats = feats + skip[:, :, :feats.shape[2]]

        seq, _ = self.lstm(feats.transpose(1, 2))  # (batch, steps, 2 * 256)
        # BatchNorm1d normalises dim 1, so move the feature axis there and back.
        seq = self.bn_lstm(seq.transpose(1, 2)).transpose(1, 2)
        seq = self.dropout_lstm(seq)
        pooled = self.self_attention(seq).mean(dim=1)  # average over time steps

        hidden = self.dropout_fc(self.relu(self.bn_fc1(self.fc1(pooled))))
        hidden = self.dropout_fc(self.relu(self.bn_fc2(self.fc2(hidden))))
        return self.fc3(hidden)

# Main execution
if __name__ == "__main__":
    # Load every class CSV, one-hot encode it and report the class balance.
    sequences, labels, class_names = load_data(DATA_FOLDER)
    print(f"Sequences shape: {sequences.shape}, Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")

    # Stratified 80/20 split with a fixed seed; only the held-out 20% is used here.
    # NOTE(review): presumably the same split as in training (same random_state) - confirm.
    _, X_test, _, y_test = train_test_split(
        sequences, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Test dataset and loader (order preserved so batch indices are stable).
    test_dataset = SequenceDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN_LSTM_Model(num_classes=len(class_names)).to(device)

    # Prefer the restricted weights-only loader. The fallback unpickles arbitrary
    # objects (it can execute code), so it must only ever be used on trusted files.
    try:
        model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=True))
        print(f"Successfully loaded final model from: {FINAL_MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model with weights_only=True: {e}")
        try:
            model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=False))
            print("Successfully loaded with weights_only=False")
        except Exception as e2:
            raise RuntimeError(f"Failed to load model: {e2}") from e2

    # Inference only from here on: use BatchNorm running statistics and disable dropout.
    model.eval()

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import os
from torch.profiler import profile, record_function, ProfilerActivity
import time
import zipfile

class PerturbationTester:
    """In-silico mutagenesis (perturbation) analysis for a trained sequence classifier.

    For each tested start position ``pos`` and mutation size ``size``, positions
    ``[pos, pos + size)`` of every sequence in a batch are overwritten with one
    randomly drawn base from A/T/C/G. There is one draw per sequence, so the block
    becomes a homopolymer and may coincide with the original base. A sample's
    importance score is ``original_correct - mutated_correct``:
      * 1 means the mutation broke a correct prediction;
      * -1 means it fixed a wrong one.
    Scores are summed per true class and divided by that class's sample count.

    Partial (un-normalised) results are checkpointed under ``checkpoint_dir``.
    They can be resumed from ``input_checkpoint_dir`` (a read-only Kaggle input
    dataset) in a later session.
    """

    def __init__(self, model, dataloader, class_names, sequence_length=300):
        self.model = model
        self.dataloader = dataloader
        self.class_names = class_names
        self.sequence_length = sequence_length
        self.num_classes = len(class_names)
        # batch_idx -> {'sequences', 'labels', 'correct'} for the unmutated data
        self.original_predictions_cache = {}

        # Define paths for saving and loading checkpoints
        self.checkpoint_dir = "/kaggle/working/perturbation_checkpoints"  # For saving
        self.input_checkpoint_dir = "/kaggle/input/plant-analysis"  # For loading (read-only)

        # Create save directory if it doesn't exist
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Check if input directory exists (for resuming)
        self.can_resume = os.path.exists(self.input_checkpoint_dir)
        if self.can_resume:
            print(f"Found checkpoint directory at {self.input_checkpoint_dir} - can resume")
        else:
            print("No existing checkpoints found - starting fresh")

    # ------------------------------------------------------------------ helpers

    def _device(self):
        """Return the device holding the model's parameters (CPU if it has none)."""
        try:
            return next(self.model.parameters()).device
        except (StopIteration, AttributeError):
            return torch.device("cpu")

    def _empty_results(self, mutation_sizes):
        """Return zeroed accumulators: ``results[group][size]`` has length ``sequence_length``."""
        return {group: {size: torch.zeros(self.sequence_length) for size in mutation_sizes}
                for group in range(self.num_classes)}

    def _normalize(self, results, group_counts, mutation_sizes):
        """Divide accumulated scores in place by per-class sample counts (empty classes untouched)."""
        for group in range(self.num_classes):
            if group_counts[group] > 0:
                for size in mutation_sizes:
                    results[group][size] /= group_counts[group]
        return results

    def _count_groups(self, num_batches):
        """Count samples per class in the first ``num_batches`` cached batches."""
        counts = torch.zeros(self.num_classes)
        for batch_idx in sorted(self.original_predictions_cache)[:num_batches]:
            labels = self.original_predictions_cache[batch_idx]['labels'].cpu()
            for group in range(self.num_classes):
                counts[group] += (labels == group).sum()
        return counts

    def get_checkpoint_path(self, group, size, pass_num, input_dir=False):
        """Get path for either saving (working dir) or loading (input dir) a checkpoint."""
        base_dir = self.input_checkpoint_dir if input_dir else self.checkpoint_dir
        return os.path.join(base_dir, f"group_{group}_size_{size}_pass_{pass_num}.pt")

    def mutate_sequences_vectorized(self, sequences, start_pos, end_pos):
        """Return a copy of ``sequences`` with positions ``[start_pos, end_pos)`` overwritten.

        Expects ``(batch, length, 5)`` one-hot input (the 5-column layout used below).
        Each sequence gets one random base from A/T/C/G that fills the whole block.
        The input tensor is not modified.
        """
        mutated = sequences.clone()

        # One-hot rows for A, T, C, G on the same device as the input
        nucleotides = torch.tensor([
            [1, 0, 0, 0, 0], [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]
        ], dtype=torch.float32, device=sequences.device)

        # One random base per sequence in the batch
        rand_indices = torch.randint(0, 4, (sequences.shape[0],), device=sequences.device)

        # Use the real slice length so a block running past the end cannot mismatch shapes.
        mutation_length = mutated[:, start_pos:end_pos, :].shape[1]
        block = nucleotides[rand_indices].unsqueeze(1).expand(-1, mutation_length, -1)

        mutated[:, start_pos:end_pos, :] = block
        return mutated

    def cache_original_predictions(self):
        """Run the unmutated data once and cache inputs, labels and per-sample correctness.

        Tensors are moved to the model's device so later forward passes also work on GPU.
        """
        self.model.eval()
        self.original_predictions_cache = {}
        device = self._device()

        with torch.no_grad():
            for batch_idx, (sequences, labels) in enumerate(self.dataloader):
                sequences = sequences.float().to(device)
                labels = labels.to(device)
                _, preds = torch.max(self.model(sequences), 1)
                self.original_predictions_cache[batch_idx] = {
                    'sequences': sequences,
                    'labels': labels,
                    'correct': (preds == labels).float()
                }

    # -------------------------------------------------------------- checkpoints

    def save_checkpoint(self, results, current_batch, total_batches, mutation_sizes, pass_num,
                        group_counts=None):
        """Save raw (un-normalised) results every ~10% of batches and at the last batch.

        One file is written per (group, size) pair. ``group_counts`` are the per-class
        sample counts accumulated so far. When omitted, they are recomputed from the
        first ``current_batch`` cached batches.
        """
        if total_batches <= 0 or current_batch <= 0:
            return
        # Integer interval instead of a float modulo: ``progress % 0.1`` misfires for
        # values such as 0.3 or 1.0 because of binary rounding.
        interval = max(1, (total_batches + 9) // 10)
        if current_batch % interval != 0 and current_batch != total_batches:
            return
        if group_counts is None:
            group_counts = self._count_groups(current_batch)

        for group in range(self.num_classes):
            for size in mutation_sizes:
                checkpoint_path = self.get_checkpoint_path(group, size, pass_num)
                try:
                    torch.save({
                        'results': results[group][size],
                        'current_batch': current_batch,
                        'group_counts': float(group_counts[group]),
                        'pass_num': pass_num,
                        'timestamp': time.time()
                    }, checkpoint_path)
                except Exception as e:
                    print(f"Error saving checkpoint {checkpoint_path}: {e}")
        progress = current_batch / total_batches
        print(f"Saved checkpoint at {progress*100:.0f}% completion (Pass {pass_num})")

    def _load_checkpoints_from(self, mutation_sizes, pass_num, from_input):
        """Load all (group, size) checkpoints of one directory.

        Returns ``(results, group_counts, last_batch, found_any)``.
        """
        results = self._empty_results(mutation_sizes)
        group_counts = torch.zeros(self.num_classes)
        last_batch = 0
        found = False

        for group in range(self.num_classes):
            for size in mutation_sizes:
                checkpoint_path = self.get_checkpoint_path(group, size, pass_num, input_dir=from_input)
                if not os.path.exists(checkpoint_path):
                    continue
                try:
                    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
                except Exception as e:
                    print(f"Error loading checkpoint {checkpoint_path}: {e}")
                    # The Kaggle input directory is read-only; only working-dir files can be discarded.
                    if not from_input:
                        try:
                            os.remove(checkpoint_path)
                        except OSError:
                            pass
                    continue
                if not isinstance(checkpoint, dict) or 'results' not in checkpoint:
                    print(f"Invalid checkpoint format in {checkpoint_path}")
                    continue

                results[group][size] = checkpoint['results']
                group_counts[group] = checkpoint.get('group_counts', 0)
                last_batch = max(last_batch, checkpoint.get('current_batch', 0))
                found = True

        return results, group_counts, last_batch, found

    def load_checkpoint(self, mutation_sizes, pass_num):
        """Load the most advanced checkpoint set from the input and working directories.

        Returns ``(results, group_counts, iterator, last_batch)``. ``results`` are raw
        sums and ``group_counts`` are per-class sample counts. ``iterator`` yields the
        cached ``(batch_idx, batch_data)`` items still to process, or ``None`` when
        there is nothing to resume.
        """
        sources = []
        if self.can_resume:
            sources.append(("input", True))
        sources.append(("working", False))

        best = (self._empty_results(mutation_sizes), torch.zeros(self.num_classes), 0)
        best_label = None
        for label, from_input in sources:
            try:
                results, counts, last_batch, found = self._load_checkpoints_from(
                    mutation_sizes, pass_num, from_input)
            except Exception as e:
                print(f"Error loading from {label} directory: {e}")
                continue
            # Working-dir progress made after resuming from the input dir supersedes it.
            if found and (best_label is None or last_batch > best[2]):
                best = (results, counts, last_batch)
                best_label = label

        results, group_counts, last_batch = best
        if last_batch > 0:
            print(f"Resuming from batch {last_batch} (Pass {pass_num}) from {best_label} directory")
            cache_items = list(self.original_predictions_cache.items())
            return results, group_counts, iter(cache_items[last_batch:]), last_batch

        return results, group_counts, None, 0

    def transfer_checkpoints_to_input(self):
        """Zip the working checkpoints so they can be re-uploaded as an input dataset."""
        if not os.path.exists(self.checkpoint_dir):
            return False

        try:
            zip_path = os.path.join("/kaggle/working", "perturbation_checkpoints.zip")
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for root, _, files in os.walk(self.checkpoint_dir):
                    for file in files:
                        full_path = os.path.join(root, file)
                        zipf.write(full_path, os.path.relpath(full_path, self.checkpoint_dir))

            print(f"Checkpoints archived to {zip_path}. Please:")
            print("1. Download this zip file")
            print("2. Create a new Kaggle dataset with it")
            print("3. Add it as input to your next session")
            return True
        except Exception as e:
            print(f"Error transferring checkpoints: {e}")
            return False

    def clear_checkpoints(self, pass_num, mutation_sizes=None):
        """Remove working-dir checkpoint files of one pass (default sizes: 1, 4, 8, 12)."""
        sizes = (1, 4, 8, 12) if mutation_sizes is None else mutation_sizes
        for group in range(self.num_classes):
            for size in sizes:
                checkpoint_path = self.get_checkpoint_path(group, size, pass_num)
                if os.path.exists(checkpoint_path):
                    try:
                        os.remove(checkpoint_path)
                    except Exception as e:
                        print(f"Error removing checkpoint {checkpoint_path}: {e}")

    # ------------------------------------------------------------------ running

    def profile_execution(self):
        """Profile a reduced run (single size, coarse step, no refinement) on the CPU."""
        with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
            test_results = self.run_perturbation_test(
                mutation_sizes=[4],  # Use single size for profiling
                window_step=5,
                num_workers=2,
                refine_iterations=0
            )

        print(prof.key_averages().table(sort_by="cpu_time_total"))
        return test_results

    def run_perturbation_test(self, mutation_sizes=(1, 4, 8, 12),
                              window_step=3, num_workers=2, refine_iterations=1):
        """Run a coarse pass, optional refinement passes, plot, then clear checkpoints.

        Returns ``results[group][size]``: a tensor of length ``sequence_length`` holding
        the mean importance per start position. Positions never tested stay 0.
        """
        mutation_sizes = list(mutation_sizes)
        start_time = time.time()

        # First pass: coarse screening with window sampling
        try:
            coarse_results = self._run_pass(
                mutation_sizes=mutation_sizes,
                window_step=window_step,
                num_workers=num_workers,
                pass_num=0
            )
        except Exception as e:
            print(f"Error during first pass: {e}")
            print("Attempting to recover...")
            # Checkpoints hold raw sums, so normalise them like a completed pass.
            coarse_results, group_counts, _, _ = self.load_checkpoint(mutation_sizes, 0)
            if not any(torch.any(r) for group in coarse_results.values() for r in group.values()):
                raise RuntimeError("Could not recover from error - no valid checkpoints found") from e
            self._normalize(coarse_results, group_counts, mutation_sizes)

        # Refinement passes at full resolution inside the important regions
        for iteration in range(refine_iterations):
            try:
                important_regions = self._identify_important_regions(coarse_results)
                refined_results = self._run_pass(
                    mutation_sizes=mutation_sizes,
                    window_step=1,
                    num_workers=num_workers,
                    focus_regions=important_regions,
                    pass_num=iteration+1
                )
                coarse_results = self._merge_results(coarse_results, refined_results)
            except Exception as e:
                print(f"Error during refinement pass {iteration+1}: {e}")
                print("Continuing with existing results...")
                break

        self._plot_results(coarse_results)

        # Clear checkpoints after successful completion
        for pass_num in range(refine_iterations+1):
            self.clear_checkpoints(pass_num, mutation_sizes)

        print(f"Total execution time: {time.time()-start_time:.2f} seconds")
        return coarse_results

    def _run_pass(self, mutation_sizes, window_step, num_workers, focus_regions=None, pass_num=0):
        """Run one pass over all cached batches, resuming from checkpoints when possible.

        Returns per-class, per-size importance profiles normalised by class sample counts.
        """
        if not self.original_predictions_cache:
            self.cache_original_predictions()
        self.model.eval()

        try:
            results, group_counts, _, start_batch = self.load_checkpoint(mutation_sizes, pass_num)
        except Exception as e:
            print(f"Error initializing results: {e}")
            print("Starting fresh pass...")
            results = self._empty_results(mutation_sizes)
            group_counts = torch.zeros(self.num_classes)
            start_batch = 0

        cache = self.original_predictions_cache
        total_batches = len(cache)
        # Positions depend only on the pass configuration, not on the batch.
        positions = list(self._get_positions_to_test(window_step, focus_regions))
        chunk_size = 32  # tasks submitted at once; bounds peak memory of in-flight forwards

        with ThreadPoolExecutor(max_workers=num_workers) as executor, torch.no_grad():
            for batch_idx in sorted(cache):
                if batch_idx < start_batch:
                    continue
                batch_data = cache[batch_idx]
                sequences = batch_data['sequences']
                labels = batch_data['labels']
                original_correct = batch_data['correct']

                # Per-class masks (on CPU, where the accumulators live) computed once per batch.
                labels_cpu = labels.cpu()
                group_masks = {}
                for group in range(self.num_classes):
                    mask = labels_cpu == group
                    if mask.any():
                        group_masks[group] = mask

                args_list = [(sequences, labels, original_correct, size, pos)
                             for size in mutation_sizes
                             for pos in positions]

                for i in range(0, len(args_list), chunk_size):
                    chunk = args_list[i:i + chunk_size]
                    for (pos, scores), (_, _, _, size, _) in zip(
                            executor.map(self._process_position, chunk), chunk):
                        scores = scores.cpu()
                        for group, mask in group_masks.items():
                            results[group][size][pos] += scores[mask].sum()

                # Per-class sample counts (not batch counts) used for normalisation.
                for group in range(self.num_classes):
                    group_counts[group] += (labels_cpu == group).sum()

                self.save_checkpoint(results, batch_idx + 1, total_batches, mutation_sizes,
                                     pass_num, group_counts=group_counts)

        return self._normalize(results, group_counts, mutation_sizes)

    def _get_positions_to_test(self, window_step, focus_regions=None):
        """Every ``window_step``-th position, or every position inside ``focus_regions``."""
        if focus_regions:
            positions = set()
            for start, end in focus_regions:
                positions.update(range(start, end))
            return sorted(positions)
        return range(0, self.sequence_length, window_step)

    def _process_position(self, args):
        """Mutate one block for the whole batch and return ``(pos, per-sample importance)``."""
        sequences, labels, original_correct, size, pos = args
        # Grad mode is thread-local in PyTorch: the caller's no_grad() does not reach
        # ThreadPoolExecutor workers, so autograd is disabled here explicitly.
        with torch.no_grad():
            mutated_sequences = self.mutate_sequences_vectorized(
                sequences, pos, min(pos + size, self.sequence_length))
            _, preds = torch.max(self.model(mutated_sequences), 1)
            mutated_correct = (preds == labels).float()
        return pos, original_correct - mutated_correct

    def _identify_important_regions(self, results, threshold=0.1):
        """Return merged ``[start, end)`` runs where any class/size score exceeds ``threshold``.

        Falls back to the whole sequence when there are no scores or no position exceeds it.
        """
        all_scores = [results[group][size].detach().cpu().numpy()
                      for group in range(self.num_classes)
                      for size in results[group]]
        if not all_scores:
            return [(0, self.sequence_length)]

        combined_scores = np.max(np.stack(all_scores), axis=0)
        # Pad with False so every run has both a rising and a falling edge.
        above_threshold = np.concatenate(([False], combined_scores > threshold, [False]))
        diff = np.diff(above_threshold.astype(int))
        starts = np.where(diff > 0)[0]
        ends = np.where(diff < 0)[0]

        if len(starts) == 0:
            return [(0, self.sequence_length)]

        important_regions = sorted((int(s), int(e)) for s, e in zip(starts, ends))
        merged = [important_regions[0]]
        for current in important_regions[1:]:
            last = merged[-1]
            if current[0] <= last[1]:
                merged[-1] = (last[0], max(last[1], current[1]))
            else:
                merged.append(current)
        return merged

    def _merge_results(self, coarse, refined):
        """Prefer non-zero refined scores; otherwise keep the coarse ones."""
        merged = {group: {} for group in range(self.num_classes)}
        for group in range(self.num_classes):
            for size in coarse[group]:
                merged[group][size] = torch.where(
                    refined[group][size] != 0,
                    refined[group][size],
                    coarse[group][size]
                )
        return merged

    def _plot_results(self, results):
        """Plot one smoothed importance profile per class and save it as a 300-dpi PNG."""
        plot_dir = "/kaggle/working/perturbation_plots"
        os.makedirs(plot_dir, exist_ok=True)

        for group, group_name in enumerate(self.class_names):
            plt.figure(figsize=(12, 4), dpi=300)
            ax = plt.gca()

            for size in results[group]:
                y = results[group][size].detach().cpu().numpy()
                # Moving-average smoothing
                window_size = min(11, self.sequence_length // 10)
                if window_size > 1:
                    y = np.convolve(y, np.ones(window_size)/window_size, mode='same')
                plt.plot(range(self.sequence_length), y, label=f'Mutation Size: {size} nt', linewidth=2)

            plt.title(f"Importance Score Over Sequence for {group_name}", fontsize=16, pad=20)
            plt.xlabel("Sequence Position (bp)", fontsize=18, labelpad=15)
            plt.ylabel("Importance Score", fontsize=18, labelpad=13)

            plt.ylim(-0.09, 2.00)
            plt.yticks(fontsize=14)

            # Label positions relative to the sequence end (index p -> p - sequence_length).
            x_ticks = np.arange(0, self.sequence_length + 1, step=20)
            x_tick_labels = [f"{-self.sequence_length + x}" for x in x_ticks]
            # The tick at the sequence end would read "0"; show it as "-1" instead.
            if len(x_ticks) and x_ticks[-1] == self.sequence_length:
                x_tick_labels[-1] = "-1"

            plt.xticks(x_ticks, labels=x_tick_labels, fontsize=16, rotation=30, ha='right')

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            plt.legend(fontsize=14, framealpha=0.9, loc='upper left', bbox_to_anchor=(0.02, 0.98))
            plt.xlim(0, self.sequence_length)
            plt.tight_layout()

            plot_path = os.path.join(plot_dir, f"{group_name}_importance_plot.png")
            plt.savefig(plot_path, dpi=300, bbox_inches='tight', format='png')
            print(f"Saved high-resolution plot to: {plot_path}")

            plt.show()
            # Release the 300-dpi figure; one is created per class.
            plt.close()



# Usage Example: perturbation analysis on the held-out test loader built above.
tester = PerturbationTester(model, test_loader, class_names)

try:
    # Option 1: Run with profiling first to identify bottlenecks
    # profiled_results = tester.profile_execution()

    # Option 2: Run full optimized test
    final_results = tester.run_perturbation_test(
        mutation_sizes=[1, 4, 8, 12],
        window_step=3,          # Start with coarse sampling
        num_workers=4,         # Adjust based on the CPU
        refine_iterations=1    # Add refinement passes
    )
except Exception as e:
    print(f"Run interrupted: {e}")
    # A successful run deletes its checkpoints, so archiving only makes sense here,
    # while partial progress is still on disk and can seed the next session.
    tester.transfer_checkpoints_to_input()
    print("Progress has been saved. When you restart, it will resume from last checkpoint.")

In [None]:
# The code to run the group saliency map analysis

# Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_curve, roc_auc_score, f1_score, auc, matthews_corrcoef
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

# Parameters
SEQ_LENGTH: int = 300  # every sequence is padded/truncated to this many bases
BATCH_SIZE: int = 64
DATA_FOLDER: str = '/kaggle/input/algae-dataset'
FINAL_MODEL_PATH: str = '/kaggle/input/final-models/Final_Algae_model.pth'

# One-hot encoding function
def one_hot_encode(sequence, seq_length=SEQ_LENGTH):
    """Encode a DNA string as a ``(seq_length, 5)`` integer one-hot array.

    Columns are A, T, C, G, N. The string is upper-cased, right-padded with ``N``
    and truncated to ``seq_length``; any other character is encoded as ``N``.
    """
    alphabet = "ATCGN"
    codes = {base: [int(base == other) for other in alphabet] for base in alphabet}
    fitted = sequence.upper().ljust(seq_length, "N")[:seq_length]
    return np.array([codes.get(base, codes["N"]) for base in fitted])

# Dataset class
class SequenceDataset(Dataset):
    """Map-style dataset pairing encoded sequences with integer class labels.

    Each item is ``(float32 tensor of the sequence, int64 scalar label tensor)``.
    """

    def __init__(self, sequences, labels):
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.sequences[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Load data function
def load_data(data_folder):
    """Load each ``*.csv`` in *data_folder* as one class and one-hot encode its sequences.

    Each CSV is read without a header and its first column supplies the sequences.
    The class name is the file name without extension. The class index is the
    file's position among the CSV files found, so other files in the folder do not
    create label gaps.

    Returns:
        (one_hot_sequences, labels, class_names): encoded sequences stacked by
        ``np.array``, an integer label array, and the class names in label order.
    """
    sequences = []
    labels = []
    class_names = []
    # NOTE(review): os.listdir order is not guaranteed; class indices match the saved
    # model only if training listed the folder in the same order - deliberately unsorted.
    for file_name in os.listdir(data_folder):
        if not file_name.endswith('.csv'):
            continue
        label = len(class_names)  # index among CSV files only
        data = pd.read_csv(os.path.join(data_folder, file_name), header=None)
        sequences.extend(data[0].tolist())
        labels.extend([label] * len(data))
        class_names.append(os.path.splitext(file_name)[0])
    one_hot_sequences = np.array([one_hot_encode(seq) for seq in sequences])
    return one_hot_sequences, np.array(labels), class_names

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head self-attention for batch-first ``(batch, seq, embed_dim)`` input.

    ``nn.MultiheadAttention`` is built with its default sequence-first layout, so the
    input is transposed to ``(seq, batch, embed_dim)`` before attention and back after.
    Only the attended values are returned; the attention weights are discarded.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.transpose(0, 1)  # (batch, seq, E) -> (seq, batch, E)
        attended = self.attention(seq_first, seq_first, seq_first)[0]
        return attended.transpose(0, 1)  # back to (batch, seq, E)

# Multi-kernel CNN with Residual Connections
class MultiKernelCNN(nn.Module):
    """Parallel 1-D convolutions, max-pooling and a 1x1-convolution shortcut.

    With ``use_multi_kernel`` the input goes through kernels of size 3, 5 and 7
    (same-length padding) whose ReLU outputs are concatenated on the channel axis;
    otherwise only the size-3 kernel is used. The result is max-pooled (stride 2),
    passed through dropout(0.5) and added to a 1x1-conv projection of the input.

    Input ``(batch, input_channels, length)``; output
    ``(batch, output_channels * (3 if use_multi_kernel else 1), length // 2)``.
    """

    def __init__(self, input_channels, output_channels, use_multi_kernel=True):
        super().__init__()
        self.use_multi_kernel = use_multi_kernel
        # Submodule names and creation order are kept stable so saved state_dicts
        # load and seeded initialisation is reproducible.
        self.conv3 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        if use_multi_kernel:
            self.conv5 = nn.Conv1d(input_channels, output_channels, kernel_size=5, padding=2)
            self.conv7 = nn.Conv1d(input_channels, output_channels, kernel_size=7, padding=3)
        branch_count = 3 if use_multi_kernel else 1
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.residual = nn.Conv1d(input_channels, output_channels * branch_count, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        shortcut = self.residual(x)
        convs = [self.conv3, self.conv5, self.conv7] if self.use_multi_kernel else [self.conv3]
        features = torch.cat([self.relu(conv(x)) for conv in convs], dim=1)
        features = self.dropout(self.pool(features))
        # NOTE(review): the shortcut is computed at full length and only cut to the
        # pooled length, so shortcut position i is added to pooled position i (which
        # summarises input positions 2i and 2i+1). Kept as-is: trained weights depend on it.
        return features + shortcut[:, :, :features.shape[2]]

# CNN Attention Mechanism
class CNN_Attention(nn.Module):
    """Self-attention across the length axis of a ``(batch, channels, length)`` feature map.

    Queries and keys are 1x1-conv projections to ``channel_dim // 8`` channels and
    values keep ``channel_dim`` channels. The attended values are scaled by the
    learnable ``gamma`` (initialised to 0, so the block starts as the identity)
    and added back to the input.
    """

    def __init__(self, channel_dim):
        super().__init__()
        self.channel_dim = channel_dim
        reduced = channel_dim // 8
        self.query = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.key = nn.Conv1d(channel_dim, reduced, kernel_size=1)
        self.value = nn.Conv1d(channel_dim, channel_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, _, length = x.size()  # requires a 3-D (batch, channels, length) input
        q = self.query(x).reshape(batch, -1, length)
        k = self.key(x).reshape(batch, -1, length)
        v = self.value(x).reshape(batch, -1, length)
        weights = torch.softmax(torch.bmm(q.transpose(1, 2), k), dim=-1)  # (batch, length, length)
        attended = torch.bmm(v, weights.transpose(1, 2))  # (batch, channels, length)
        return self.gamma * attended + x

# CNN-LSTM Hybrid Model with Self-Attention
class CNN_LSTM_Model(nn.Module):
    """CNN -> CNN self-attention -> BiLSTM -> self-attention -> MLP classifier.

    Stages (channel counts taken from the layer definitions below):
      1. MultiKernelCNN 5 -> 128 channels (single kernel-3 branch), BatchNorm.
      2. MultiKernelCNN 128 -> 256*3 = 768 channels (kernel 3/5/7 branches
         concatenated), CNN_Attention, BatchNorm, max-pool, dropout, plus a
         1x1-conv shortcut computed from the raw input.
      3. Bidirectional LSTM (hidden 256 per direction -> 512 features),
         BatchNorm over the features, dropout, SelfAttention, mean over positions.
      4. Linear 512 -> 256 -> 512 -> num_classes with BatchNorm/ReLU/dropout.

    forward returns the raw fc3 output (no softmax applied).

    NOTE: the order in which submodules are registered here determines the
    state_dict key names and the parameter initialisation order under a fixed
    seed, so it must stay in sync with saved checkpoints.
    """

    def __init__(self, num_classes):
        super(CNN_LSTM_Model, self).__init__()
        self.multi_kernel_cnn1 = MultiKernelCNN(input_channels=5, output_channels=128, use_multi_kernel=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.multi_kernel_cnn2 = MultiKernelCNN(input_channels=128, output_channels=256, use_multi_kernel=True)
        # Multi-kernel block concatenates 3 branches of 256 channels each.
        self.cnn_attention2 = CNN_Attention(channel_dim=256 * 3)
        self.bn2 = nn.BatchNorm1d(256 * 3)
        # Shortcut from the raw 5-channel input to the 768-channel CNN output.
        self.residual = nn.Conv1d(in_channels=5, out_channels=256 * 3, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.lstm = nn.LSTM(input_size=256 * 3, hidden_size=256, batch_first=True, bidirectional=True)
        # Bidirectional LSTM output has 2 * 256 features per position.
        self.bn_lstm = nn.BatchNorm1d(256 * 2)
        # SelfAttention is defined elsewhere in this file; presumably it maps
        # (batch, positions, 512) -> (batch, positions, 512) -- TODO confirm.
        self.self_attention = SelfAttention(embed_dim=256 * 2, num_heads=4)
        self.dropout_lstm = nn.Dropout(0.3)
        self.fc1 = nn.Linear(256 * 2, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        # Shared by the CNN stage and the classifier head (dropout has no weights).
        self.dropout_fc = nn.Dropout(0.3)

    def forward(self, x):
        """Classify a batch of encoded sequences.

        Args:
            x: assumes (batch, seq_len, 5) -- the permute below must yield the
               5 input channels expected by self.residual and multi_kernel_cnn1.

        Returns:
            (batch, num_classes) output of fc3; no softmax applied.
        """
        x = x.permute(0, 2, 1)  # -> (batch, 5, seq_len) for Conv1d
        residual = self.residual(x)  # (batch, 768, seq_len), not downsampled
        x = self.multi_kernel_cnn1(x)  # (batch, 128, seq_len // 2)
        x = self.bn1(x)
        x = self.multi_kernel_cnn2(x)  # (batch, 768, seq_len // 4)
        x = self.cnn_attention2(x)
        x = self.bn2(x)
        x = self.pool(x)  # (batch, 768, seq_len // 8)
        x = self.dropout_fc(x)
        # NOTE(review): the shortcut was never pooled, so this slice keeps only
        # the first x.shape[2] raw-input positions; it is not position-aligned
        # with x. Changing it would alter the outputs of already-trained weights.
        if x.shape != residual.shape:
            residual = residual[:, :, :x.shape[2]]
        x = x + residual
        x = x.permute(0, 2, 1)  # -> (batch, positions, 768) for the batch_first LSTM
        lstm_out, (hn, _) = self.lstm(x)  # lstm_out: (batch, positions, 512); hn is unused
        # BatchNorm1d normalises dim 1, so move the features there and back.
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.bn_lstm(lstm_out)
        lstm_out = lstm_out.permute(0, 2, 1)
        lstm_out = self.dropout_lstm(lstm_out)
        attn_out = self.self_attention(lstm_out)
        # Average over positions (dim 1) to get one 512-d vector per sequence.
        context_vector = torch.mean(attn_out, dim=1)
        x = self.fc1(context_vector)
        x = self.bn_fc1(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.relu(x)
        x = self.dropout_fc(x)
        x = self.fc3(x)
        return x

# Main execution
if __name__ == "__main__":
    # Load and prepare data (load_data is defined elsewhere in this file).
    sequences, labels, class_names = load_data(DATA_FOLDER)
    print(f"Sequences shape: {sequences.shape}, Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")

    # Recreate the 20% stratified hold-out split. The fixed random_state is
    # presumably what keeps it identical to the split used for training --
    # TODO confirm against the training cell.
    _, X_test, _, y_test = train_test_split(
        sequences, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Create test dataset and loader
    test_dataset = SequenceDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Initialize device and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN_LSTM_Model(num_classes=len(class_names)).to(device)

    # Load final model.
    # NOTE(review): FINAL_MODEL_PATH must be defined earlier in the file; the
    # header cell only defines the CHECKPOINT_PATH_* constants.
    try:
        model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=True))
        print(f"Successfully loaded final model from: {FINAL_MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model with weights_only=True: {e}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects and can
        # execute code embedded in the file -- only use it on trusted checkpoints.
        try:
            model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=device, weights_only=False))
            print("Successfully loaded with weights_only=False")
        except Exception as e2:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model: {e2}") from e2

    # Inference mode: disables dropout and makes BatchNorm use running statistics.
    model.eval()

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Function to compute saliency maps
def compute_saliency_maps(model, dataloader, device):
    """
    Compute saliency maps for the given dataloader using the trained model.

    The saliency of a sequence is |d score / d input|, where score is the model
    output for the class the model predicts for that sequence. Gradients are
    taken with respect to the inputs only, so no ``.grad`` is accumulated on
    the model's parameters.

    Args:
        model: Trained model. It is switched to eval mode.
        dataloader: DataLoader yielding (sequences, labels) batches.
        device: Device (e.g., 'cuda' or 'cpu').

    Returns:
        saliency_maps: List of numpy arrays, one per sequence, each with the
            same shape as that sequence's input.
        sequences: List of the input sequences as numpy arrays.
        labels: List of the true labels as numpy scalars.
    """
    model.eval()
    saliency_maps = []
    sequences = []
    labels = []

    for batch_sequences, batch_labels in dataloader:
        # Fresh leaf tensor so gradients are tracked for this batch's inputs only.
        batch_sequences = batch_sequences.to(device).detach().requires_grad_(True)
        batch_labels = batch_labels.to(device)

        # cuDNN's RNN kernels refuse to run backward while the module is in eval
        # mode ("cudnn RNN backward can only be called in training mode"), so
        # cuDNN is disabled for this forward/backward pair. On CPU this is a no-op.
        # enable_grad keeps this working even if the caller is inside no_grad().
        with torch.enable_grad(), torch.backends.cudnn.flags(enabled=False):
            outputs = model(batch_sequences)  # (batch_size, num_classes)

            # Score of the predicted class for every sequence in the batch.
            _, preds = torch.max(outputs, dim=1)
            selected_outputs = outputs.gather(1, preds.unsqueeze(1)).squeeze(1)

            # In eval mode each sample's score depends only on its own input,
            # so differentiating the batch sum yields per-sample gradients.
            (gradients,) = torch.autograd.grad(selected_outputs.sum(), batch_sequences)

        saliency_maps.extend(gradients.abs().cpu().numpy())
        sequences.extend(batch_sequences.detach().cpu().numpy())
        labels.extend(batch_labels.cpu().numpy())

    return saliency_maps, sequences, labels
# Function to compute group-level saliency maps
def compute_group_saliency_maps(saliency_maps, labels, class_names, sequence_length=300):
    """
    Compute group-level saliency maps by averaging saliency maps for each group.

    Each per-sequence map is first summed over its last axis (the encoding
    channels), which gives one value per sequence position. These profiles are
    then averaged per class.

    Args:
        saliency_maps: List of saliency maps, each of shape (sequence_length, channels).
        labels: List of integer class labels aligned with saliency_maps.
        class_names: List of class names; group i corresponds to class_names[i].
        sequence_length: Length of the sequences (default: 300 bp).

    Returns:
        group_saliency_maps: Dictionary mapping group index -> average saliency
            profile of length sequence_length. A group with no sequences keeps
            an all-zero profile instead of producing NaNs from a 0/0 division.
    """
    n_groups = len(class_names)
    group_saliency_maps = {group: np.zeros(sequence_length) for group in range(n_groups)}
    group_counts = {group: 0 for group in range(n_groups)}

    # Sum per-position saliency (over channels) for each class.
    for saliency_map, label in zip(saliency_maps, labels):
        group = int(label)
        group_saliency_maps[group] += np.asarray(saliency_map).sum(axis=1)
        group_counts[group] += 1

    # Average across sequences; skip empty groups to avoid dividing by zero.
    for group, count in group_counts.items():
        if count:
            group_saliency_maps[group] /= count

    return group_saliency_maps


# Function to plot group-level saliency maps
def plot_group_saliency_maps(group_saliency_maps, class_names, sequence_length=300,
                             y_min=0.0, y_max=0.8, dpi=600):
    """
    Plot group-level saliency maps with enhanced formatting and a consistent y-axis.
    Saves each map to its own PNG file and shows it.

    Args:
        group_saliency_maps: Dictionary of average saliency maps for each group
        class_names: List of class names; group i is plotted as class_names[i]
        sequence_length: Length of the sequences (default: 300 bp)
        y_min, y_max: Shared y-axis limits for every plot (default 0.0-0.8).
            Values outside this range are clipped from view.
        dpi: Resolution of the saved PNG files (default 600).
    """
    y_ticks = np.arange(y_min, y_max + 0.01, 0.2)  # e.g. 0.0, 0.2, ..., 0.8

    # Positions are labelled relative to the sequence end (position 0 -> -sequence_length).
    x_ticks = np.arange(0, sequence_length + 1, step=50)
    x_tick_labels = [f"{-sequence_length + x}" for x in x_ticks]
    if x_ticks[-1] == sequence_length:
        # The final tick marks the end of the sequence; label it as the last base.
        x_tick_labels[-1] = "-1"

    for group, group_name in enumerate(class_names):
        fig, ax = plt.subplots(figsize=(12, 4), facecolor='white')

        ax.plot(np.arange(sequence_length), group_saliency_maps[group],
                color='red', linewidth=2.0)

        ax.set_title(f"Group-Level Saliency Map for {group_name}",
                     fontsize=14, pad=15, fontweight='bold')
        ax.set_xlabel("Sequence Position (bp)", fontsize=18, labelpad=15)
        ax.set_ylabel("Average Saliency", fontsize=16, labelpad=15)

        ax.set_ylim(y_min, y_max)
        ax.set_yticks(y_ticks)
        ax.tick_params(axis='y', labelsize=14)

        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_tick_labels, fontsize=14, rotation=45, ha='right')

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        fig.tight_layout()

        output_filename = f"Saliency_Map_{group_name.replace(' ', '_')}.png"
        fig.savefig(output_filename, dpi=dpi, bbox_inches='tight', facecolor=fig.get_facecolor())
        print(f"Saved high-resolution saliency map for {group_name} as {output_filename}")

        plt.show()
        # Release the figure; non-inline backends otherwise keep every one open.
        plt.close(fig)



# Saliency pass over the hold-out set.
# NOTE(review): X_test, y_test, model, device and class_names are created inside
# the `if __name__ == "__main__":` block above, so this cell only works when the
# notebook/script is run directly.

# Rebuild the test loader (default num_workers=0, unlike the loader above).
test_dataset = SequenceDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Per-sequence saliency maps for the test set
saliency_maps, test_sequences, test_labels = compute_saliency_maps(model, test_loader, device)

# Use the same sequence length the one-hot encoder pads/truncates to.
group_saliency_maps = compute_group_saliency_maps(saliency_maps, test_labels, class_names, sequence_length=SEQ_LENGTH)

# Plot group-level saliency maps
plot_group_saliency_maps(group_saliency_maps, class_names, sequence_length=SEQ_LENGTH)


from Bio import motifs
from Bio.Seq import Seq
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

def identify_important_regions(group_saliency_maps, class_names, window_size=20, top_n=5,
                               min_separation=0):
    """
    Identify important regions by finding peaks in saliency maps.

    A peak is a position strictly greater than both of its neighbours, so the
    first and last positions are never peaks. Peaks are ranked by saliency,
    highest first; equal scores keep left-to-right order. Each selected peak is
    expanded into a window around it.

    Args:
        group_saliency_maps: Dictionary mapping group index -> 1-D saliency map
        class_names: List of class names; group i is reported as class_names[i]
        window_size: Size of region to consider around each peak. The region is
            [peak - window_size//2, peak + window_size//2), clipped to the map.
        top_n: Maximum number of regions to return per group
        min_separation: If > 0, skip any peak closer than this many positions
            to an already-selected, stronger peak. This avoids near-duplicate,
            heavily overlapping regions. The default of 0 keeps the plain top_n
            peaks.

    Returns:
        Dictionary mapping class name -> list of (start, end, peak_saliency)
        tuples, ordered by decreasing peak saliency.
    """
    important_regions = {}
    half_window = window_size // 2

    for group, group_name in enumerate(class_names):
        saliency = np.asarray(group_saliency_maps[group])
        length = len(saliency)

        # Strict interior local maxima, found with one vectorised comparison.
        if length >= 3:
            centre = saliency[1:-1]
            is_peak = (centre > saliency[:-2]) & (centre > saliency[2:])
            peaks = (np.flatnonzero(is_peak) + 1).tolist()
        else:
            peaks = []

        # Stable sort, strongest first.
        peaks.sort(key=lambda idx: saliency[idx], reverse=True)

        if min_separation > 0:
            limit = len(peaks) if top_n is None else top_n
            selected = []
            for peak in peaks:
                if len(selected) >= limit:
                    break
                if all(abs(peak - kept) >= min_separation for kept in selected):
                    selected.append(peak)
        else:
            selected = peaks[:top_n]

        important_regions[group_name] = [
            (max(0, peak - half_window), min(length, peak + half_window), saliency[peak])
            for peak in selected
        ]

    return important_regions

def get_nucleotide_color(letter, alpha=1.0):
    """Return the RGBA plotting colour for a nucleotide letter (case-insensitive).

    A is green, T red, C blue, G orange and N grey. Any other letter gets black.
    `alpha` becomes the fourth component of the returned tuple.
    """
    rgb_by_base = {
        'A': (0.0, 0.6, 0.0),
        'T': (0.8, 0.0, 0.0),
        'C': (0.0, 0.0, 0.8),
        'G': (0.9, 0.6, 0.0),
        'N': (0.5, 0.5, 0.5),
    }
    red, green, blue = rgb_by_base.get(letter.upper(), (0.0, 0.0, 0.0))
    return (red, green, blue, alpha)

def _one_hot_to_strings(one_hot_sequences):
    """Decode one-hot sequences (columns A, T, C, G, N) to nucleotide strings via argmax."""
    index_to_base = 'ATCGN'
    return [''.join(index_to_base[i] for i in np.argmax(seq, axis=-1))
            for seq in one_hot_sequences]


def _column_frequencies(region_seqs, bases='ACGT', pseudocount=0.5):
    """Return one {base: frequency} dict per position, adding `pseudocount` to every base.

    This is the same normalisation as Bio.motifs' counts.normalize(pseudocounts=0.5).
    Letters outside `bases` (e.g. 'N' padding) are ignored, so a column made
    only of padding comes out uniform (zero information).
    """
    width = min(len(s) for s in region_seqs)
    frequencies = []
    for pos in range(width):
        column = [s[pos] for s in region_seqs]
        counts = {b: column.count(b) for b in bases}
        total = sum(counts.values()) + pseudocount * len(bases)
        frequencies.append({b: (counts[b] + pseudocount) / total for b in bases})
    return frequencies


def _information_content(freq):
    """Return 2 - Shannon entropy (in bits) of one DNA column given as {base: frequency}."""
    entropy = -sum(p * np.log2(p) for p in freq.values() if p > 0)
    return 2 - entropy


def create_sequence_logos(test_sequences, test_labels, class_names, important_regions):
    """
    Create sequence logos for the important regions of each group.

    Per-position base frequencies are computed directly from the decoded test
    sequences. The previous Bio.motifs path indexed the PWM by letter index
    (``pwm[position]``) instead of by column, which raised for every region.
    Each logo is saved as ``Sequence_Logo_<group>_Region_<k>.png``. A region
    that fails is reported and skipped.

    Args:
        test_sequences: List of all test sequences (one-hot encoded, columns A, T, C, G, N)
        test_labels: List of corresponding integer labels
        class_names: List of class names; label i corresponds to class_names[i]
        important_regions: Dictionary of class name -> list of (start, end, score)
    """
    seq_strings = _one_hot_to_strings(test_sequences)
    class_list = list(class_names)

    for group_name, regions in important_regions.items():
        group_idx = class_list.index(group_name)
        group_seqs = [seq for seq, label in zip(seq_strings, test_labels) if label == group_idx]

        for i, (start, end, score) in enumerate(regions):
            region_seqs = [seq[start:end] for seq in group_seqs if len(seq) >= end]

            if not region_seqs:
                print(f"No sequences long enough for {group_name} region {i+1} ({start}-{end})")
                continue

            fig = None
            try:
                freqs = _column_frequencies(region_seqs)
                ic = [_information_content(f) for f in freqs]

                fig, ax = plt.subplots(figsize=(10, 3))
                for position, (freq, bits) in enumerate(zip(freqs, ic)):
                    # Stack bases from most to least frequent (ties keep A, T, C, G order).
                    ranked = sorted(((base, freq[base]) for base in 'ATCG'),
                                    key=lambda item: item[1], reverse=True)
                    yshift = 0
                    for base, f in ranked:
                        if f > 0.1:  # Only show nucleotides with >10% frequency
                            height = f * bits
                            ax.text(position + 0.5, yshift + height / 2, base,
                                    ha='center', va='center', fontsize=12,
                                    color=get_nucleotide_color(base))
                            ax.fill_between([position, position + 1],
                                            [yshift, yshift],
                                            [yshift + height, yshift + height],
                                            color=get_nucleotide_color(base, alpha=0.3))
                            yshift += height

                ax.set_xlim(0, end - start)
                ax.set_ylim(0, 2)  # Max information content for DNA is 2 bits
                ax.set_xticks(np.arange(0.5, end - start + 0.5, 1))
                ax.set_xticklabels(np.arange(start, end, 1), rotation=45)
                ax.set_xlabel('Position (bp)')
                ax.set_ylabel('Information (bits)')
                ax.set_title(f'Sequence Logo for {group_name}\nRegion {i+1} (bp {start}-{end}), Saliency score: {score:.3f}')
                ax.grid(False)

                filename = f"Sequence_Logo_{group_name.replace(' ', '_')}_Region_{i+1}.png"
                fig.savefig(filename, dpi=300, bbox_inches='tight')
                print(f"Saved sequence logo for {group_name} region {i+1} as {filename}")

            except Exception as e:
                # Best effort: report and move on to the next region.
                print(f"Error creating logo for {group_name} region {i+1}: {str(e)}")
            finally:
                if fig is not None:
                    plt.close(fig)

# Find the strongest saliency peaks per class, then draw a sequence logo for each one.
important_regions = identify_important_regions(
    group_saliency_maps=group_saliency_maps,
    class_names=class_names,
)
create_sequence_logos(
    test_sequences=test_sequences,
    test_labels=test_labels,
    class_names=class_names,
    important_regions=important_regions,
)