In [None]:

import os
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
import torch
import esm

# --- Reproducibility --------------------------------------------------------
# Seed NumPy and PyTorch (CPU and, when present, every CUDA device) and ask
# cuDNN for deterministic kernels so repeated runs are as repeatable as possible.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Hardware selection -----------------------------------------------------
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# --- File system + run configuration ---------------------------------------
# Defaults to the Colab working directory; set the PROJECT_ROOT environment
# variable to run the notebook somewhere else without editing code.
PROJECT_ROOT = Path(os.environ.get('PROJECT_ROOT', '/content'))
print(f"Using device: {device}")
POS_FASTA = PROJECT_ROOT / 'reps_30_rep_seq_pos.fasta'  # loaded with label 1
NEG_FASTA = PROJECT_ROOT / 'reps_30_rep_seq_neg.fasta'  # loaded with label 0
OUTPUT_DIR = PROJECT_ROOT / 'esm_outputs'

# Batching limits for ESM inference (each defined exactly once).
MAX_SEQUENCES_PER_BATCH = 8
MAX_TOKENS_PER_BATCH = 2000  # approximate per-batch token budget (residues + BOS/EOS per chunk)
# Longer sequences are split into consecutive chunks of at most this many residues.
# NOTE(review): ESM-2 was trained on crops of roughly 1k residues — confirm before raising this.
MAX_RESIDUES_PER_CHUNK = 1000

ESM_MODEL = 'esm2_t33_650M_UR50D'  # name of a loader function in esm.pretrained
# Header fields guaranteed to exist in the metadata table ('NA' when absent from a header).
META_KEYS = ['source', 'gene', 'entry', 'file', 'idx']

# Validate inputs before creating any output directories.
for fasta_path in [POS_FASTA, NEG_FASTA]:
    if not fasta_path.exists():
        raise FileNotFoundError(f"Missing FASTA file: {fasta_path}")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:

# One-letter residue codes accepted for embedding: the 20 canonical amino acids
# plus the rare/ambiguity codes B, O, U, X, Z. 'J' is deliberately excluded: it is
# not part of the ESM-2 residue vocabulary (fair-esm `proteinseq_toks`), so
# sequences containing it could not be tokenized — filter them out here instead.
VALID_RESIDUES = set("ACDEFGHIKLMNPQRSTVWYBOUXZ")


def is_valid_sequence(seq: str) -> bool:
    """Return True if ``seq`` is non-empty and uses only codes in VALID_RESIDUES.

    The comparison is case-insensitive. Empty sequences are rejected because
    they contain no residues to average over when pooling embeddings.
    """
    return bool(seq) and set(seq.upper()) <= VALID_RESIDUES


def parse_header_metadata(header: str) -> dict:
    """Split a pipe-delimited FASTA header into a metadata dict.

    The first ``|``-separated field is stored under ``'source'``. Every later
    field of the form ``key=value`` (split at the first ``=``) becomes an entry
    with whitespace stripped from both key and value; fields without ``=`` are
    ignored. Any key from META_KEYS that is still missing is set to ``'NA'``.
    """
    source, *fields = header.split('|')
    meta = {'source': source}
    for field in fields:
        key, sep, value = field.partition('=')
        if sep:
            meta[key.strip()] = value.strip()
    for missing in (k for k in META_KEYS if k not in meta):
        meta[missing] = 'NA'
    return meta


def read_fasta_with_metadata(file_path: Path, label: int) -> pd.DataFrame:
    """Load every record of a FASTA file into a DataFrame, one row per record.

    Each row holds the fields returned by ``parse_header_metadata`` for the
    record description, plus the raw ``header``, the ``sequence`` string and
    the supplied class ``label``.
    """
    def _to_row(record) -> dict:
        row = parse_header_metadata(record.description)
        row.update(header=record.description, sequence=str(record.seq), label=label)
        return row

    return pd.DataFrame([_to_row(rec) for rec in SeqIO.parse(str(file_path), 'fasta')])


positive_df = read_fasta_with_metadata(POS_FASTA, 1)
negative_df = read_fasta_with_metadata(NEG_FASTA, 0)
print(f"Loaded {len(positive_df)} positive and {len(negative_df)} negative representatives")

# Combine both classes and store sequences upper-case: the validity check is
# case-insensitive, but the stored string is what gets embedded later, so it
# must be normalised too. NOTE(review): ESM's residue vocabulary appears to be
# upper-case only — keep this unless that changes.
combined_df = pd.concat([positive_df, negative_df], ignore_index=True)
combined_df['sequence'] = combined_df['sequence'].str.upper()

# Filter with the shared is_valid_sequence predicate so the rule lives in one place.
valid_mask = combined_df['sequence'].map(is_valid_sequence).astype(bool)
invalid = combined_df[~valid_mask]
if len(invalid):
    print(f'Removed {len(invalid)} sequences with non-standard residues')
combined_df = combined_df[valid_mask].reset_index(drop=True)

# Shuffle once (for downstream ML), then assign ids in the shuffled order.
sequence_df = combined_df.sample(frac=1.0, random_state=random_seed).reset_index(drop=True)
sequence_df.insert(0, 'sequence_id', range(len(sequence_df)))
sequence_df['sequence_length'] = sequence_df['sequence'].str.len()

metadata_columns = ['sequence_id', 'label'] + META_KEYS + ['header', 'sequence_length']
metadata_df = sequence_df[metadata_columns].copy()
print(f"Unique genes represented: {metadata_df['gene'].nunique()}")
print(sequence_df.groupby('label')['sequence_length'].describe()[['mean', 'min', 'max']])


In [None]:

# Load the pretrained checkpoint named by ESM_MODEL, switch it to inference mode
# and move it to the selected device. Representations are requested from layer
# index model.num_layers (the last layer in ESM's numbering).
model_loader = getattr(esm.pretrained, ESM_MODEL)
model, alphabet = model_loader()
model.eval()
model = model.to(device)
repr_layer = model.num_layers
batch_converter = alphabet.get_batch_converter()


def chunked_rows(df: pd.DataFrame):
    """Yield ``(row, chunk)`` pairs for every row of ``df``.

    ``row`` is the namedtuple produced by ``DataFrame.itertuples``. Sequences of
    at most MAX_RESIDUES_PER_CHUNK residues are yielded whole (an empty sequence
    yields one empty chunk); longer ones are cut into consecutive,
    non-overlapping slices of at most MAX_RESIDUES_PER_CHUNK residues.
    """
    for row in df.itertuples():
        seq = row.sequence
        if len(seq) > MAX_RESIDUES_PER_CHUNK:
            starts = range(0, len(seq), MAX_RESIDUES_PER_CHUNK)
            yield from ((row, seq[s:s + MAX_RESIDUES_PER_CHUNK]) for s in starts)
        else:
            yield row, seq


def batch_iter(df: pd.DataFrame):
    """Group the chunks produced by ``chunked_rows(df)`` into batches.

    Before a chunk is added, the current batch is emitted if it already holds
    MAX_SEQUENCES_PER_BATCH chunks or if adding the chunk would push the running
    token count (residues + 2 for BOS/EOS per chunk) above MAX_TOKENS_PER_BATCH.
    A chunk larger than the whole budget still forms a batch on its own.
    Yields non-empty lists of ``(row, chunk_seq)`` pairs.
    """
    current = []
    used_tokens = 0
    for item in chunked_rows(df):
        n_tokens = len(item[1]) + 2  # BOS + EOS
        is_full = len(current) >= MAX_SEQUENCES_PER_BATCH
        over_budget = used_tokens + n_tokens > MAX_TOKENS_PER_BATCH
        if current and (is_full or over_budget):
            yield current
            current, used_tokens = [], 0
        current.append(item)
        used_tokens += n_tokens
    if current:
        yield current


def embed_dataframe(df: pd.DataFrame) -> np.ndarray:
    """Return one mean-pooled ESM embedding per row of ``df``.

    Sequences are chunked/batched by ``batch_iter`` and run through ``model``;
    representations from layer ``repr_layer`` are averaged over residue
    positions only (BOS, EOS and padding are excluded). Chunks of the same
    sequence are combined with a residue-weighted mean, so the result is the
    mean over all residues of the sequence, computed chunk by chunk.

    Args:
        df: Frame with at least ``sequence`` and ``sequence_id`` columns.
            Output row ``i`` corresponds to ``df.iloc[i]``.

    Returns:
        float32 array of shape ``(len(df), model.embed_dim)``.

    Raises:
        RuntimeError: if some row contributed no residues (e.g. an empty sequence).
    """
    # Positional index so output row i always matches df.iloc[i], whatever df's index is.
    df = df.reset_index(drop=True)
    n_rows = len(df)
    # Accumulate residue-level sums in float64 and divide once at the end.
    residue_sums = np.zeros((n_rows, model.embed_dim), dtype=np.float64)
    residue_counts = np.zeros(n_rows, dtype=np.int64)
    report_every = 200
    processed = 0
    for batch_rows in batch_iter(df):
        # Upper-case defensively: validation elsewhere is case-insensitive.
        batch_data = [(str(row.sequence_id), seq.upper()) for row, seq in batch_rows]
        _, _, batch_tokens = batch_converter(batch_data)
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
        token_representations = results['representations'][repr_layer].float().cpu().numpy()
        for i, (row, seq) in enumerate(batch_rows):
            n_res = len(seq)
            # Assumes one token per residue letter with BOS at position 0 (as the
            # original 1:-1 slice did). Positions after n_res are EOS and, for
            # chunks shorter than the longest in the batch, padding — those must
            # not enter the average.
            residue_sums[row.Index] += token_representations[i, 1:n_res + 1].sum(axis=0)
            residue_counts[row.Index] += n_res
        previous = processed
        processed += len(batch_rows)
        # processed grows by whole batches, so report on crossing each multiple.
        if processed // report_every > previous // report_every:
            print(f"Processed {processed} chunks")
    if not np.all(residue_counts):
        missing = np.where(residue_counts == 0)[0]
        raise RuntimeError(f"Missing embeddings for rows: {missing[:10]} ...")
    embeddings = (residue_sums / residue_counts[:, None]).astype(np.float32)
    return embeddings


embeddings = embed_dataframe(sequence_df)
labels = sequence_df['label'].to_numpy(dtype=np.int8)

# Features, labels and metadata are all derived from sequence_df in the same row
# order, so row i of each output file describes the same sequence.
features_file = OUTPUT_DIR / 'esm_features.npy'
labels_file = OUTPUT_DIR / 'labels.npy'
metadata_stem = OUTPUT_DIR / 'sequence_metadata'
np.save(features_file, embeddings)
np.save(labels_file, labels)
metadata_df.to_csv(metadata_stem.with_suffix('.csv'), index=False)
metadata_df.to_parquet(metadata_stem.with_suffix('.parquet'), index=False)

print('Saved embeddings to', features_file)
print('Saved labels to', labels_file)
print('Saved metadata to', OUTPUT_DIR / 'sequence_metadata.(csv|parquet)')


In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
    classification_report,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import torch.nn.functional as F

# --- Reproducibility --------------------------------------------------------
# Re-seed here so this cell is reproducible even in a fresh kernel where the
# embedding cells were not run.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Device selection --------------------------------------------------------
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'Using device: {device}')

# Paths for embeddings generated by ESM2_features.ipynb. Prefer the Colab
# location, then fall back to an esm_outputs directory one level up.
PROJECT_ROOT = Path('/content')
EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
if not EMBED_DIR.exists():
    PROJECT_ROOT = Path('..').resolve()
    EMBED_DIR = PROJECT_ROOT / 'esm_outputs'
if not EMBED_DIR.exists():
    raise FileNotFoundError(f'Could not find esm_outputs directory at {EMBED_DIR}')

features_path = EMBED_DIR / 'esm_features.npy'
labels_path = EMBED_DIR / 'labels.npy'
metadata_path = EMBED_DIR / 'sequence_metadata.csv'

print('Loading embeddings from', features_path)
X = np.load(features_path)
y = np.load(labels_path)
metadata = pd.read_csv(metadata_path)

# The three files must be row-aligned (row i of each describes the same
# sequence); fail early instead of training on mismatched data.
if len(metadata) != len(X):
    raise ValueError('Metadata rows do not match embeddings')
if len(y) != len(X):
    raise ValueError(f'Label count ({len(y)}) does not match embeddings ({len(X)})')
if 'label' in metadata.columns and not np.array_equal(metadata['label'].to_numpy(), y):
    raise ValueError('Labels in metadata do not match labels.npy')

# Stratified 80/20 train/test split. Row indices are split alongside the data so
# each metadata row stays aligned with its embedding.
indices = np.arange(len(y))
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, y, indices, test_size=0.2, stratify=y, random_state=random_seed
)
metadata_train, metadata_test = (
    metadata.iloc[split_idx].reset_index(drop=True) for split_idx in (idx_train, idx_test)
)

# The LSTM expects (batch, seq_len, features); each embedding is one time step.
seq_len = 1
X_train, X_test = (arr.reshape(-1, seq_len, arr.shape[1]) for arr in (X_train, X_test))

# Tensors: float32 features, int64 class targets.
X_train, X_test = (torch.tensor(arr, dtype=torch.float32) for arr in (X_train, X_test))
y_train, y_test = (torch.tensor(arr, dtype=torch.long) for arr in (y_train, y_test))

train_dataset, test_dataset = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


class ImprovedLSTMClassifier(nn.Module):
    """Bidirectional LSTM followed by a small convolutional head.

    The LSTM runs over inputs of shape ``(batch, seq_len, input_dim)``. The output
    at the last time step (``2 * hidden_dim`` features, both directions) is
    reshaped to a one-channel map of height ``2 * hidden_dim`` and width 1, then
    passed through three conv -> batch-norm -> ReLU -> max-pool stages (kernels
    span the height only), dropout and a linear layer. ``forward`` returns
    log-probabilities (``log_softmax`` over classes).

    Args:
        input_dim: Features per time step.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
        dropout_rate: Dropout between stacked LSTM layers.
        head_dropout: Dropout before the final linear layer; the default is the
            previously hard-coded tuned value.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout_rate,
                 head_dropout=0.5456158649892608):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True,
        )
        # NOTE(review): not used in forward(); kept so the state_dict keys (and
        # therefore previously saved checkpoints) stay compatible.
        self.bn = nn.BatchNorm1d(hidden_dim * 2)
        self.conv1 = nn.Conv2d(1, 76, kernel_size=(6, 1), padding=(1, 0))
        self.bn1 = nn.BatchNorm2d(76)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.conv2 = nn.Conv2d(76, 111, kernel_size=(4, 1), padding=(1, 0))
        self.bn2 = nn.BatchNorm2d(111)
        self.conv3 = nn.Conv2d(111, 487, kernel_size=(5, 1), padding=(1, 0))
        self.bn3 = nn.BatchNorm2d(487)
        self.dropout = nn.Dropout(head_dropout)
        self.flatten_dim = self._get_flatten_dim(input_dim)
        self.fc = nn.Linear(self.flatten_dim, num_classes)

    def _features(self, x):
        """Run the LSTM and conv stages; return flattened features ``(batch, flatten_dim)``."""
        # With no (h0, c0) given, nn.LSTM starts from zeros on x's device/dtype.
        out, _ = self.lstm(x)
        out = out[:, -1, :]                    # last time step: (batch, 2 * hidden_dim)
        out = out.unsqueeze(1).unsqueeze(-1)   # (batch, 1, 2 * hidden_dim, 1)
        out = self.pool(F.relu(self.bn1(self.conv1(out))))
        out = self.pool(F.relu(self.bn2(self.conv2(out))))
        out = self.pool(F.relu(self.bn3(self.conv3(out))))
        return out.flatten(1)

    def _get_flatten_dim(self, input_dim):
        """Infer the flattened feature size with one dummy forward pass.

        Uses this module's own layers (no notebook globals), disables autograd,
        and temporarily switches to eval mode so the dummy input does not update
        the BatchNorm running statistics. The previous train/eval mode is restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.ones(1, 1, input_dim)
                return self._features(dummy).size(1)
        finally:
            self.train(was_training)

    def forward(self, x):
        """Return class log-probabilities of shape ``(batch, num_classes)``."""
        out = self._features(x)
        out = self.dropout(out)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)


# Model hyperparameters (previously tuned values).
input_dim = X_train.shape[2]
hidden_dim = 181
num_layers = 4
dropout_rate = 0.4397133138964481
learning_rate = 0.0003466440190079221
num_classes = 2

model = ImprovedLSTMClassifier(input_dim, hidden_dim, num_layers, num_classes, dropout_rate).to(device)
# The model returns log-probabilities (log_softmax), so NLLLoss is the matching
# criterion; CrossEntropyLoss would apply log_softmax a second time.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Early stopping and checkpoint selection use a stratified slice of the
# *training* data, so the test set is touched only by the final evaluation.
VAL_FRACTION = 0.1
fit_idx, val_idx = train_test_split(
    np.arange(len(y_train)),
    test_size=VAL_FRACTION,
    stratify=y_train.numpy(),
    random_state=random_seed,
)
fit_idx = torch.as_tensor(fit_idx)
val_idx = torch.as_tensor(val_idx)
fit_loader = DataLoader(TensorDataset(X_train[fit_idx], y_train[fit_idx]), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_train[val_idx], y_train[val_idx]), batch_size=batch_size, shuffle=False)

n_epochs = 50
patience = 5
best_val_acc = float('-inf')  # ensures a checkpoint is always written after epoch 1
early_stop_counter = 0
model_save_path = PROJECT_ROOT / 'best_improved_lstmCNN_model.pth'


def _run_epoch(loader, desc, train):
    """Make one pass over ``loader``, updating weights only when ``train`` is True.

    Returns ``(mean_loss, accuracy)`` over all samples seen.
    """
    model.train(train)
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(train):
        for data, target in tqdm(loader, desc=desc, leave=False):
            data, target = data.to(device), target.to(device)
            if train:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * data.size(0)
            n_correct += (output.argmax(1) == target).sum().item()
            n_seen += target.size(0)
    return total_loss / n_seen, n_correct / n_seen


for epoch in range(n_epochs):
    train_loss, train_accuracy = _run_epoch(fit_loader, f'Training Epoch {epoch+1}/{n_epochs}', train=True)
    val_loss, val_accuracy = _run_epoch(val_loader, f'Validating Epoch {epoch+1}/{n_epochs}', train=False)

    print(f'Epoch {epoch+1}/{n_epochs}')
    print(f'Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
    print(f'Training Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        early_stop_counter = 0
        torch.save(model.state_dict(), model_save_path)
        print(f'  Best model saved to {model_save_path}')
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print('Early stopping triggered')
            break

print('Loading the best model...')
model.load_state_dict(torch.load(model_save_path, map_location=device))

# One pass over the held-out test set collects P(class 1) for the threshold-free
# metrics and hard predictions for accuracy, confusion matrix and report.
model.eval()
y_pred_prob = []
y_pred = []
y_true = []
print('Evaluating on the test set...')
with torch.no_grad():
    for data, target in tqdm(test_loader, desc='Testing', leave=False):
        probs = model(data.to(device)).exp().cpu()  # model returns log-probabilities
        y_pred_prob.extend(probs[:, 1].numpy())
        y_pred.extend(probs.argmax(1).numpy())
        y_true.extend(target.numpy())
y_pred_prob = np.asarray(y_pred_prob)
y_pred = np.asarray(y_pred)
y_true = np.asarray(y_true)

test_accuracy = accuracy_score(y_true, y_pred)
print(f'Test Accuracy: {test_accuracy:.4f}')

roc_auc = roc_auc_score(y_true, y_pred_prob)
fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
precision, recall, _ = precision_recall_curve(y_true, y_pred_prob)

# Label 0 = negative FASTA, label 1 = positive FASTA.
class_names = ['Negative', 'Positive']
# labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.show()

print('Classification Report:')
print(classification_report(y_true, y_pred, labels=[0, 1], target_names=class_names))

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title='Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(recall, precision, color='blue', lw=2)
ax.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall Curve')
plt.show()
