In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
    classification_report,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch.nn.functional as F

# --- Reproducibility --------------------------------------------------------
# Seed every RNG the training code touches and pin cuDNN to deterministic
# kernels so repeated runs of this notebook produce the same numbers.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Device selection -------------------------------------------------------
# Preference order: CUDA GPU, then Apple-silicon MPS, then CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'Using device: {device}')

# Paths for embeddings generated by ESM2_features.ipynb.
# Look in the Colab working directory (/content) first, then fall back to the
# parent of the current working directory (running from a repo checkout).
PROJECT_ROOT = Path('/content')
EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
if not EMBED_DIR.exists():
    PROJECT_ROOT = Path('..').resolve()
    EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
if not EMBED_DIR.exists():
    # Name both candidate locations; reporting only the fallback path hides the
    # Colab location from users who simply forgot to upload the outputs there.
    raise FileNotFoundError(
        f"Could not find esm_outputs directory at {Path('/content') / 'esm_outputs'} "
        f'or {EMBED_DIR}'
    )

features_path = EMBED_DIR / 'esm_features.npy'
labels_path = EMBED_DIR / 'labels.npy'
CHECKPOINT_DIR = PROJECT_ROOT / 'model_checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
metadata_path = EMBED_DIR / 'sequence_metadata.csv'

print('Loading embeddings from', features_path)
X = np.load(features_path)
y = np.load(labels_path)
metadata = pd.read_csv(metadata_path)
# Features, labels and metadata are indexed with the same row positions by the
# train/test split, so they must have identical lengths. Fail fast with a
# message that says which file is out of step.
if len(y) != len(X):
    raise ValueError(f'Labels ({len(y)}) do not match embeddings ({len(X)})')
if len(metadata) != len(X):
    raise ValueError('Metadata rows do not match embeddings')

# Stratified 80/20 train/test split. Row positions are split together with X
# and y so the per-sequence metadata can be re-aligned with each partition.
indices = np.arange(len(y))
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, y, indices, test_size=0.2, stratify=y, random_state=random_seed
)
metadata_train, metadata_test = (
    metadata.iloc[positions].reset_index(drop=True) for positions in (idx_train, idx_test)
)

# nn.LSTM(batch_first=True) expects (batch, seq_len, features); every
# embedding is presented as a sequence of length one.
seq_len = 1
X_train, X_test = (
    split.reshape(-1, seq_len, split.shape[1]) for split in (X_train, X_test)
)

# Tensors: float32 features, int64 class indices for the loss.
X_train, X_test = (torch.tensor(split, dtype=torch.float32) for split in (X_train, X_test))
y_train, y_test = (torch.tensor(split, dtype=torch.long) for split in (y_train, y_test))

train_dataset, test_dataset = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

# Only the training loader shuffles; the test loader keeps split order so its
# outputs stay aligned with metadata_test.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


class ImprovedLSTMClassifier(nn.Module):
    """Bidirectional LSTM followed by a small convolutional head.

    The LSTM output at the final time step (``2 * hidden_dim`` features, both
    directions) is reshaped into a single-channel map of height
    ``2 * hidden_dim`` and width 1, then passed through three
    Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stages and a linear layer.

    Args:
        input_dim: Size of each input feature vector (last dim of ``x``).
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
        dropout_rate: Dropout between stacked LSTM layers (``nn.LSTM`` only
            applies it when ``num_layers > 1``).
        head_dropout: Dropout applied before the final linear layer; defaults
            to the value previously hard-coded here.

    ``forward`` returns log-probabilities of shape ``(batch, num_classes)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout_rate,
                 head_dropout=0.5456158649892608):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            # Non-zero dropout with a single layer has no effect and only warns.
            dropout=dropout_rate if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Not used by forward(); kept so state_dict keys stay compatible with
        # checkpoints saved from this class.
        self.bn = nn.BatchNorm1d(hidden_dim * 2)
        self.conv1 = nn.Conv2d(1, 76, kernel_size=(6, 1), padding=(1, 0))
        self.bn1 = nn.BatchNorm2d(76)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv2 = nn.Conv2d(76, 111, kernel_size=(4, 1), padding=(1, 0))
        self.bn2 = nn.BatchNorm2d(111)
        self.conv3 = nn.Conv2d(111, 487, kernel_size=(5, 1), padding=(1, 0))
        self.bn3 = nn.BatchNorm2d(487)
        self.dropout = nn.Dropout(head_dropout)
        self.flatten_dim = self._get_flatten_dim(input_dim)
        self.fc = nn.Linear(self.flatten_dim, num_classes)

    def _features(self, x):
        """LSTM + conv stack: ``(batch, seq_len, input_dim)`` -> ``(batch, flatten_dim)``."""
        # Zero initial states on x's own device/dtype (no global `device` needed).
        h0 = x.new_zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size)
        c0 = x.new_zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]                   # final time step: (batch, 2 * hidden_dim)
        out = out.unsqueeze(1).unsqueeze(-1)  # (batch, 1, 2 * hidden_dim, 1)
        out = self.pool(F.relu(self.bn1(self.conv1(out))))
        out = self.pool(F.relu(self.bn2(self.conv2(out))))
        out = self.pool(F.relu(self.bn3(self.conv3(out))))
        return out.flatten(1)

    def _get_flatten_dim(self, input_dim):
        """Infer the conv-head output width with a dummy forward pass.

        Uses the module's own LSTM configuration (not notebook globals) and
        runs in eval mode under ``no_grad`` so the probe neither updates the
        BatchNorm running statistics nor builds an autograd graph. The
        previous training flag is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = self.conv1.weight.new_ones(1, 1, input_dim)
                return self._features(probe).size(1)
        finally:
            self.train(was_training)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_len, input_dim)`` to log-probabilities."""
        out = self._features(x)
        out = self.dropout(out)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)


# Hyperparameters. NOTE(review): the long decimals look like the output of a
# hyperparameter search — TODO record which study/run they came from.
input_dim = X_train.shape[2]
hidden_dim = 181
num_layers = 4
dropout_rate = 0.4397133138964481
learning_rate = 0.0003466440190079221
num_classes = 2

model = ImprovedLSTMClassifier(input_dim, hidden_dim, num_layers, num_classes, dropout_rate).to(device)
# forward() already ends in F.log_softmax, so NLLLoss is the matching criterion;
# CrossEntropyLoss would apply a second, redundant log_softmax.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

n_epochs = 50
patience = 5
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint and the reload after training cannot hit a missing file.
best_val_acc = float('-inf')
early_stop_counter = 0
model_save_path = PROJECT_ROOT / 'best_improved_lstmCNN_model.pth'

# NOTE: the held-out *test* split drives early stopping and checkpoint
# selection here and is then reported as "test accuracy", so that figure is
# optimistically biased. Carve out a separate validation split for an
# unbiased estimate.
for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for data, target in tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{n_epochs}', leave=False):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size for an exact epoch mean.
        train_loss += loss.item() * data.size(0)
        correct += (output.argmax(1) == target).sum().item()
        total += target.size(0)

    train_loss /= total
    train_accuracy = correct / total

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=f'Validating Epoch {epoch+1}/{n_epochs}', leave=False):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
            correct += (output.argmax(1) == target).sum().item()
            total += target.size(0)

    val_loss /= total
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{n_epochs}')
    print(f'Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
    print(f'Training Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        early_stop_counter = 0
        torch.save(model.state_dict(), model_save_path)
        print(f'  Best model saved to {model_save_path}')
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print('Early stopping triggered')
            break

print('Loading the best model...')
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.eval()

# A single batched pass over the (unshuffled) test loader collects everything
# the metrics need: positive-class probabilities for ROC/PR and hard labels for
# the confusion matrix. This avoids a separate full-batch forward over all of
# X_test, which could exhaust accelerator memory on large test sets.
y_pred_prob = []
y_pred = []
y_true = []
print('Evaluating on the test set...')
with torch.no_grad():
    for data, target in tqdm(test_loader, desc='Testing', leave=False):
        probs = model(data.to(device)).exp()  # forward() returns log-probabilities
        y_pred_prob.extend(probs[:, 1].cpu().numpy())
        y_pred.extend(probs.argmax(1).cpu().numpy())
        y_true.extend(target.numpy())
y_pred = np.asarray(y_pred)

test_accuracy = accuracy_score(y_true, y_pred)
print(f'Test Accuracy: {test_accuracy:.4f}')

roc_auc = roc_auc_score(y_true, y_pred_prob)
fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
precision, recall, _ = precision_recall_curve(y_true, y_pred_prob)

conf_matrix = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'], ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.show()

print('Classification Report:\n', classification_report(y_true, y_pred))

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title='Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(recall, precision, color='blue', lw=2)
ax.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall Curve')
plt.show()


In [None]:
import optuna
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
    classification_report,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch.nn.functional as F

# Reproducibility: seed PyTorch (plus every CUDA device when one is present)
# and NumPy, and restrict cuDNN to deterministic kernels.
random_seed = 42
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
np.random.seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if not torch.cuda.is_available():
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
else:
    device = torch.device('cuda')
print(f'Using device: {device}')

# Locate the ESM2 embeddings: Colab's /content first, then the parent of the
# current working directory. The for/else raises only if neither exists.
for PROJECT_ROOT in (Path('/content'), Path('..').resolve()):
    EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
    if EMBED_DIR.exists():
        break
else:
    raise FileNotFoundError(f'Could not find esm_outputs directory at {EMBED_DIR}')

features_path, labels_path = (EMBED_DIR / fname for fname in ('esm_features.npy', 'labels.npy'))
CHECKPOINT_DIR = PROJECT_ROOT / 'model_checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print('Loading embeddings from', features_path)
X = np.load(features_path)
y = np.load(labels_path)

# Stratified 80/20 split; the held-out part is used both as the Optuna
# validation set and for the final report of the selected model.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=random_seed
)

# (batch, seq_len=1, features) layout for the batch_first LSTM, as tensors.
seq_len = 1
X_train = torch.tensor(X_train.reshape(-1, seq_len, X_train.shape[1]), dtype=torch.float32)
X_test = torch.tensor(X_test.reshape(-1, seq_len, X_test.shape[1]), dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)


def build_model(input_dim, hidden_dim, num_layers, num_classes, dropout_rate):
    """Build a fresh BiLSTM + convolutional-head classifier for one Optuna trial.

    Args:
        input_dim: Size of each input feature vector (last dim of the input).
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
        dropout_rate: Used between stacked LSTM layers (only when
            ``num_layers > 1``) and before the final linear layer.

    Returns:
        A new ``nn.Module`` (on the CPU) whose ``forward`` maps inputs of shape
        ``(batch, seq_len, input_dim)`` to log-probabilities of shape
        ``(batch, num_classes)``.
    """

    class ImprovedLSTMClassifier(nn.Module):
        def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout_rate):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            # nn.LSTM applies dropout only *between* stacked layers. With a
            # single layer a non-zero value has no effect but emits a
            # UserWarning, which the search space (num_layers 1..4) would hit.
            self.lstm = nn.LSTM(
                input_dim,
                hidden_dim,
                num_layers,
                batch_first=True,
                dropout=dropout_rate if num_layers > 1 else 0.0,
                bidirectional=True,
            )
            self.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 1), padding=(1, 0))
            self.bn1 = nn.BatchNorm2d(64)
            self.pool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
            self.conv2 = nn.Conv2d(64, 128, kernel_size=(3, 1), padding=(1, 0))
            self.bn2 = nn.BatchNorm2d(128)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=(3, 1), padding=(1, 0))
            self.bn3 = nn.BatchNorm2d(256)
            self.dropout = nn.Dropout(dropout_rate)
            self.fc = nn.Linear(self._get_flatten_dim(input_dim), num_classes)

        def _features(self, x):
            """LSTM + conv stack: ``(batch, seq_len, input_dim)`` -> ``(batch, features)``."""
            # Zero initial states on x's own device/dtype (no global `device` needed).
            h0 = x.new_zeros(self.num_layers * 2, x.size(0), self.hidden_dim)
            c0 = x.new_zeros(self.num_layers * 2, x.size(0), self.hidden_dim)
            out, _ = self.lstm(x, (h0, c0))
            out = out[:, -1, :]                   # final time step: (batch, 2 * hidden_dim)
            out = out.unsqueeze(1).unsqueeze(-1)  # (batch, 1, 2 * hidden_dim, 1)
            out = self.pool(F.relu(self.bn1(self.conv1(out))))
            out = self.pool(F.relu(self.bn2(self.conv2(out))))
            out = self.pool(F.relu(self.bn3(self.conv3(out))))
            return out.flatten(1)

        def _get_flatten_dim(self, input_dim):
            """Infer the conv-head output width with a dummy forward pass.

            Runs in eval mode under ``no_grad`` so the probe does not update
            the BatchNorm running statistics; the training flag is restored.
            """
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    probe = self.conv1.weight.new_ones(1, 1, input_dim)
                    return self._features(probe).size(1)
            finally:
                self.train(was_training)

        def forward(self, x):
            out = self.dropout(self._features(x))
            return F.log_softmax(self.fc(out), dim=1)

    return ImprovedLSTMClassifier(input_dim, hidden_dim, num_layers, num_classes, dropout_rate)


model_save_path = PROJECT_ROOT / 'best_improved_lstmCNN_model_optuna.pth'


def objective(trial):
    """Optuna objective: train one sampled configuration and return its best accuracy.

    Samples LSTM size/depth, dropout, learning rate and batch size, trains for
    up to 20 epochs with patience-3 early stopping, and saves the best-epoch
    weights to ``<stem>_trial<N>.pth`` next to ``model_save_path`` so the
    winning trial can be reloaded after the study.

    NOTE: ``test_dataset`` doubles as the validation set here, so the returned
    accuracy (and any later "test accuracy" of the selected model on the same
    split) is optimistically biased. Use a separate validation split for an
    unbiased estimate.
    """
    hidden_dim = trial.suggest_int('hidden_dim', 64, 256)
    num_layers = trial.suggest_int('num_layers', 1, 4)
    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = build_model(X_train.shape[2], hidden_dim, num_layers, 2, dropout_rate).to(device)
    # The model returns log-probabilities, so NLLLoss is the matching criterion
    # (CrossEntropyLoss would apply a redundant second log_softmax).
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    n_epochs = 20
    # -inf guarantees the first epoch writes this trial's checkpoint, so the
    # reload of the best trial can never hit a missing file.
    best_val_acc = float('-inf')
    patience = 3
    early_stop = 0

    trial_model_path = model_save_path.with_name(f"{model_save_path.stem}_trial{trial.number}.pth")
    if trial_model_path.exists():
        trial_model_path.unlink()

    for epoch in range(n_epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                correct += (output.argmax(1) == target).sum().item()
                total += target.size(0)
        val_acc = correct / total

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            early_stop = 0
            torch.save(model.state_dict(), trial_model_path)
        else:
            early_stop += 1
            if early_stop >= patience:
                break

    return best_val_acc


# Seed Optuna's TPE sampler as well; otherwise the search itself is not
# reproducible even though torch/numpy are seeded.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=random_seed),
)
study.optimize(objective, n_trials=25)
print('Best hyperparameters:', study.best_params)
print('Best validation accuracy:', study.best_value)

best_params = study.best_params
batch_size = best_params['batch_size']
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# objective() saved each trial's best-epoch weights under this naming scheme.
best_model_path = model_save_path.with_name(f"{model_save_path.stem}_trial{study.best_trial.number}.pth")
model = build_model(
    X_train.shape[2],
    best_params['hidden_dim'],
    best_params['num_layers'],
    2,
    best_params['dropout_rate'],
).to(device)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

# One batched pass under no_grad collects probabilities and hard predictions.
# (A separate forward outside no_grad returns a tensor that requires grad, on
# which .numpy() raises a RuntimeError.)
y_pred_prob = []
y_pred = []
y_true = []
with torch.no_grad():
    for data, target in tqdm(test_loader, desc='Testing', leave=False):
        probs = model(data.to(device)).exp()  # forward() returns log-probabilities
        y_pred_prob.extend(probs[:, 1].cpu().numpy())
        y_pred.extend(probs.argmax(1).cpu().numpy())
        y_true.extend(target.numpy())
y_pred = np.asarray(y_pred)

test_accuracy = accuracy_score(y_true, y_pred)
print(f'Test Accuracy: {test_accuracy:.4f}')

roc_auc = roc_auc_score(y_true, y_pred_prob)
fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
precision, recall, _ = precision_recall_curve(y_true, y_pred_prob)

conf_matrix = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'], ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix (Optuna model)')
plt.show()

print('Classification Report:\n', classification_report(y_true, y_pred))

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title='ROC Curve (Optuna model)')
ax.legend(loc='lower right')
plt.show()

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(recall, precision, color='blue', lw=2)
ax.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall Curve (Optuna model)')
plt.show()


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve, precision_recall_curve, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch.nn.functional as F
from pathlib import Path

# Set random seeds for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
np.random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Select the device with the same preference order as the other cells
# (CUDA, then Apple-silicon MPS, then CPU) so every experiment in the
# notebook runs on the same backend.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Locate the embeddings produced by ESM2_features.ipynb: /content (Colab)
# first, otherwise the parent of the current working directory.
PROJECT_ROOT = Path('/content')
if not (PROJECT_ROOT / 'esm_outputs').exists():
    PROJECT_ROOT = Path('..').resolve()
EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
if not EMBED_DIR.exists():
    raise FileNotFoundError(f'Could not find esm_outputs directory at {EMBED_DIR}')

features_path = EMBED_DIR / 'esm_features.npy'
labels_path = EMBED_DIR / 'labels.npy'
CHECKPOINT_DIR = PROJECT_ROOT / 'model_checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print('Loading embeddings from', features_path)
X, y = np.load(features_path), np.load(labels_path)


# nn.LSTM(batch_first=True) consumes (num_samples, seq_len, feature_dim); each
# embedding is treated as a length-one sequence.
seq_len = 1  # change if the embeddings ever carry a real sequence axis
X = X.reshape(-1, seq_len, X.shape[1])

# Wrap the full dataset once; each fold takes index Subsets of it.
X_tensor, y_tensor = torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)

# Model / optimisation hyperparameters (same values as the first training cell).
input_dim = X.shape[2]
num_classes = 2
hidden_dim, num_layers = 181, 4
dropout_rate = 0.4397133138964481
learning_rate = 0.0003466440190079221
batch_size, n_epochs, patience = 32, 50, 5

# 10-fold shuffled CV. NOTE(review): plain KFold does not preserve class
# proportions per fold; StratifiedKFold may be preferable for imbalanced labels.
kf = KFold(n_splits=10, shuffle=True, random_state=random_seed)

# Define model class
class ImprovedLSTMClassifier(nn.Module):
    """Bidirectional LSTM followed by a small convolutional head.

    The LSTM output at the final time step (``2 * hidden_dim`` features, both
    directions) is reshaped into a single-channel map of height
    ``2 * hidden_dim`` and width 1, then passed through three
    Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stages and a linear layer.

    Args:
        input_dim: Size of each input feature vector (last dim of ``x``).
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
        dropout_rate: Dropout between stacked LSTM layers (``nn.LSTM`` only
            applies it when ``num_layers > 1``).
        head_dropout: Dropout applied before the final linear layer; defaults
            to the value previously hard-coded here.

    ``forward`` returns log-probabilities of shape ``(batch, num_classes)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout_rate,
                 head_dropout=0.5456158649892608):
        super(ImprovedLSTMClassifier, self).__init__()
        # Non-zero LSTM dropout with a single layer has no effect and only warns.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout_rate if num_layers > 1 else 0.0, bidirectional=True)
        # Not used by forward(); kept so state_dict keys stay compatible with
        # checkpoints saved from this class.
        self.bn = nn.BatchNorm1d(hidden_dim * 2)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=76, kernel_size=(6, 1), padding=(1, 0))
        self.bn1 = nn.BatchNorm2d(76)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv2 = nn.Conv2d(in_channels=76, out_channels=111, kernel_size=(4, 1), padding=(1, 0))
        self.bn2 = nn.BatchNorm2d(111)
        self.conv3 = nn.Conv2d(in_channels=111, out_channels=487, kernel_size=(5, 1), padding=(1, 0))
        self.bn3 = nn.BatchNorm2d(487)

        self.flatten_dim = self._get_flatten_dim(input_dim)
        self.fc = nn.Linear(self.flatten_dim, num_classes)
        self.dropout = nn.Dropout(head_dropout)

    def _features(self, x):
        """LSTM + conv stack: ``(batch, seq_len, input_dim)`` -> ``(batch, flatten_dim)``."""
        # Zero initial states on x's own device/dtype (no global `device` needed).
        h0 = x.new_zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size)
        c0 = x.new_zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :]                   # final time step: (batch, 2 * hidden_dim)
        out = out.unsqueeze(1).unsqueeze(-1)  # (batch, 1, 2 * hidden_dim, 1)
        out = self.pool(F.relu(self.bn1(self.conv1(out))))
        out = self.pool(F.relu(self.bn2(self.conv2(out))))
        out = self.pool(F.relu(self.bn3(self.conv3(out))))
        return out.flatten(1)

    def _get_flatten_dim(self, input_dim):
        """Infer the conv-head output width with a dummy forward pass.

        Uses the module's own LSTM configuration (not notebook globals) and
        runs in eval mode under ``no_grad`` so the probe neither updates the
        BatchNorm running statistics nor builds an autograd graph. The
        previous training flag is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = self.conv1.weight.new_ones(1, 1, input_dim)
                return self._features(probe).size(1)
        finally:
            self.train(was_training)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_len, input_dim)`` to log-probabilities."""
        out = self._features(x)
        out = self.dropout(out)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)

# Cross-validation: train a fresh model per fold with early stopping on that
# fold's validation split, reload the best epoch, and record its accuracy.
accuracies = []
for fold, (train_idx, val_idx) in enumerate(kf.split(dataset), start=1):
    print(f"Training fold {fold}/{kf.n_splits}")

    # Create data loaders for this fold
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # Initialize the model, loss function, and optimizer. forward() returns
    # log-probabilities, so NLLLoss (not CrossEntropyLoss) is the matching loss.
    model = ImprovedLSTMClassifier(input_dim, hidden_dim, num_layers, num_classes, dropout_rate).to(device)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model_save_path = CHECKPOINT_DIR / f'best_lstmCNN_model_fold_{fold}.pth'
    # -inf makes the first epoch always write this fold's checkpoint, so the
    # reload below can never pick up a missing or stale file from an earlier run.
    best_val_acc = float('-inf')
    early_stop_counter = 0

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        for data, target in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{n_epochs}", leave=False):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            _, predicted = torch.max(output, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)

        train_loss /= total
        train_accuracy = correct / total

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in tqdm(val_loader, desc=f"Validating Epoch {epoch+1}/{n_epochs}", leave=False):
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                val_loss += loss.item() * data.size(0)
                _, predicted = torch.max(output, 1)
                correct += (predicted == target).sum().item()
                total += target.size(0)

        val_loss /= total
        val_accuracy = correct / total

        print(f'Epoch {epoch+1}/{n_epochs}')
        print(f'Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        print(f'Training Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            early_stop_counter = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"  Best model for fold {fold} saved to {model_save_path}")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered")
                break

    # Load the best model for this fold (map_location lets a checkpoint written
    # on another device load onto the current one).
    print(f"Loading the best model for fold {fold} from {model_save_path}...")
    model.load_state_dict(torch.load(model_save_path, map_location=device))

    # Evaluate on the validation set
    model.eval()
    correct = 0
    y_pred_prob = []
    y_true = []

    print(f"Evaluating fold {fold} on the validation set...")
    with torch.no_grad():
        for data, target in tqdm(val_loader, desc="Testing", leave=False):
            data, target = data.to(device), target.to(device)
            # Exponentiate the log-probabilities so the stored positive-class
            # scores are probabilities in [0, 1].
            probs = model(data).exp()
            correct += (probs.argmax(1) == target).sum().item()
            y_pred_prob.extend(probs[:, 1].cpu().numpy())
            y_true.extend(target.cpu().numpy())

    val_accuracy = correct / len(val_loader.dataset)
    accuracies.append(val_accuracy)
    print(f'Validation Accuracy for fold {fold}: {val_accuracy:.4f}')

# Print final cross-validated accuracy
mean_accuracy = np.mean(accuracies)
std_accuracy = np.std(accuracies)
print(f'Final Cross-Validated Accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}')


In [None]:
import numpy as np

def summarize_gene_performance(predictions, metadata: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame:
    """Link downstream model predictions back to genes for diagnostics.

    Args:
        predictions: Positive-class scores aligned row-for-row with
            ``metadata``. Accepted shapes are ``(n,)``, ``(n, 1)`` or
            ``(n, 2)``; for ``(n, 2)`` column 1 is used as the positive-class
            score. Scores are compared against ``threshold``, so with the
            default 0.5 they should be probabilities, not log-probabilities.
        metadata: Per-sequence table with at least ``gene``, ``label`` and
            ``sequence_id`` columns. It is copied, not modified.
        threshold: Score at or above which a sequence is predicted positive.

    Returns:
        One row per ``(gene, label)`` pair with ``n_sequences`` and
        ``accuracy`` (fraction of correctly thresholded predictions), sorted
        by accuracy, highest first.

    Raises:
        ValueError: If the prediction shape is unsupported or its length does
            not match ``metadata``.
    """
    scores = np.asarray(predictions)
    # Validate rank before calling len(): len() of a 0-d array raises TypeError,
    # and rank > 2 input would otherwise fail later with an opaque error.
    if scores.ndim == 2:
        if scores.shape[1] == 2:
            scores = scores[:, 1]
        elif scores.shape[1] == 1:
            scores = scores[:, 0]
        else:
            raise ValueError('Unsupported prediction shape; expected (n,), (n,1) or (n,2)')
    elif scores.ndim != 1:
        raise ValueError('Unsupported prediction shape; expected (n,), (n,1) or (n,2)')

    if len(scores) != len(metadata):
        raise ValueError('Predictions and metadata must have the same length')

    df = metadata.copy().reset_index(drop=True)
    df['prediction_score'] = scores
    df['pred_label'] = (df['prediction_score'] >= threshold).astype(int)
    df['correct'] = df['pred_label'] == df['label']

    summary = (
        df.groupby(['gene', 'label'])
          .agg(n_sequences=('sequence_id', 'count'), accuracy=('correct', 'mean'))
          .reset_index()
          .sort_values('accuracy', ascending=False)
    )
    return summary

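# Run this cell after training/evaluating to inspect per-gene metrics on the test split.
# `y_pred_prob` must hold positive-class probabilities in the same row order as
# `metadata_test`, e.g. the scores collected by the first cell's unshuffled test loop.
# The cross-validation cell overwrites `y_pred_prob` with the last fold's validation
# scores, which do not line up with `metadata_test`.
# Example:
# gene_perf = summarize_gene_performance(np.array(y_pred_prob), metadata_test)
# display(gene_perf.head())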
