In [None]:
# ======================================
# Imports
# ======================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import csv
import random
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    roc_auc_score, average_precision_score
)
from sklearn.calibration import calibration_curve
from scipy.stats import spearmanr

# ======================================
# Reproducibility
# ======================================
def set_seed(seed=42):
    """Seed every RNG this notebook draws from with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG
    and the RNGs of all CUDA devices, in that order. Returns ``None``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)

# ======================================
# Data Utilities
# ======================================
def one_hot_encode(seq, max_len=23):
    """One-hot encode a nucleotide string into a ``(max_len, 5)`` float tensor.

    Columns are A, C, G, T, N in that order; matching is case-insensitive and
    any other character is put in the N column. Only the first ``max_len``
    characters are encoded; rows past the end of a shorter ``seq`` stay
    all-zero (padding).
    """
    alphabet = "ACGTN"
    column_of = {base: col for col, base in enumerate(alphabet)}
    unknown_col = column_of["N"]

    cols = [column_of.get(ch.upper(), unknown_col) for ch in seq[:max_len]]

    one_hot = torch.zeros((max_len, len(alphabet)))
    if cols:
        # One (row, col) pair per encoded character; rows are distinct.
        one_hot[torch.arange(len(cols)), torch.tensor(cols)] = 1.0
    return one_hot


class gRNADataset(Dataset):
    """Map-style dataset of sequence pairs with one float label per pair.

    Args:
        pairs: indexable collection of 2-tuples ``(first_seq, second_seq)``.
        labels: indexable collection of numeric labels aligned with ``pairs``.
        max_len: length passed to ``one_hot_encode`` for both sequences.

    Each item is ``(x1, x2, label)``: the two one-hot encodings and the label
    as a float32 scalar tensor. Encoding happens lazily on every access.
    """

    def __init__(self, pairs, labels, max_len=23):
        self.pairs, self.labels, self.max_len = pairs, labels, max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        first_seq, second_seq = self.pairs[idx]
        encoded = tuple(
            one_hot_encode(s, self.max_len) for s in (first_seq, second_seq)
        )
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return (*encoded, target)


# ======================================
# LSTM Encoder
# ======================================
class LSTMEncoder(nn.Module):
    """Bidirectional LSTM encoder with max-over-time pooling.

    Args:
        input_dim: feature size of each time step (5 matches the width of
            ``one_hot_encode``'s output).
        hidden_dim: LSTM hidden size per direction, and the output size.
        dropout: probability for the dropout applied to the pooled features.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) because the
    LSTM is ``batch_first`` and returns a (batch, hidden_dim) embedding.
    """

    def __init__(self, input_dim=5, hidden_dim=128, dropout=0.3):
        super().__init__()
        # Submodules are created in this order (lstm, dropout, fc) so that
        # parameter initialization consumes the RNG identically.
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)  # MC-dropout layer
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, x):
        seq_out = self.lstm(x)[0]          # (batch, seq_len, 2 * hidden_dim)
        summary = seq_out.max(dim=1).values  # max over the time axis
        return self.fc(self.dropout(summary))


# ======================================
# Siamese Network
# ======================================
class SiameseNetwork(nn.Module):
    """Pair classifier that scores two sequences with one shared encoder.

    Both inputs pass through the same ``LSTMEncoder`` (shared weights). The
    pair is represented by ``[|e1 - e2|, e1 * e2]``, which does not change
    when ``x1`` and ``x2`` are swapped. A small MLP maps that representation
    to one logit per pair.

    Args:
        input_dim: per-step feature size forwarded to the encoder.
        hidden_dim: encoder embedding size.
        dropout: dropout probability used both inside the encoder and in the
            classifier head. The default 0.3 equals the previously hard-coded
            value, so existing callers and checkpoints are unaffected.
    """

    def __init__(self, input_dim=5, hidden_dim=128, dropout=0.3):
        super().__init__()

        self.encoder = LSTMEncoder(input_dim, hidden_dim, dropout=dropout)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout),  # MC-dropout layer
            nn.Linear(64, 1),
        )

    def forward(self, x1, x2):
        """Return one logit per pair, shape (batch,)."""
        e1 = self.encoder(x1)
        e2 = self.encoder(x2)

        diff = torch.abs(e1 - e2)
        prod = e1 * e2

        combined = torch.cat([diff, prod], dim=1)
        logits = self.classifier(combined)

        # (batch, 1) -> (batch,) to match BCEWithLogitsLoss targets.
        return logits.squeeze(-1)


# ======================================
# Training
# ======================================
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Each batch ``(x1, x2, y)`` is moved to ``device``. The model is called as
    ``model(x1, x2)`` and the result is scored with ``criterion`` against
    ``y``. One optimizer step is taken per batch.

    Returns:
        The mean of the per-batch loss values. Every batch is weighted
        equally regardless of its size. An empty loader raises
        ZeroDivisionError.
    """
    model.train()
    running_total = 0.0

    for batch in dataloader:
        x1, x2, y = (t.to(device) for t in batch)

        optimizer.zero_grad()
        batch_loss = criterion(model(x1, x2), y)
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

    return running_total / len(dataloader)


# ======================================
# Probability Evaluation
# ======================================
def evaluate_probs(model, dataloader, device):
    """Deterministically score ``dataloader`` and return (probs, labels).

    The model is switched to eval mode (so dropout is off) and left there.
    Gradients are disabled. Logits from ``model(x1, x2)`` are passed through
    a sigmoid.

    Returns:
        Two NumPy arrays in loader order: the sigmoid probabilities and the
        labels ``y`` taken from each batch (labels are never moved to
        ``device``).
    """
    model.eval()
    collected_probs = []
    collected_labels = []

    with torch.no_grad():
        for x1, x2, y in dataloader:
            batch_logits = model(x1.to(device), x2.to(device))
            collected_probs.extend(torch.sigmoid(batch_logits).cpu().numpy())
            collected_labels.extend(y.numpy())

    return np.array(collected_probs), np.array(collected_labels)


# ======================================
# MC Dropout Prediction
# ======================================
def mc_dropout_predict(model, dataloader, device, n_samples=100):
    """Monte Carlo dropout: draw ``n_samples`` stochastic predictions per input.

    The model is put in eval mode and then only its dropout layers are
    switched back to training mode. Each forward pass therefore uses a fresh
    dropout mask, while every other layer runs in inference mode.

    The loader is iterated exactly once, and each batch goes through the
    model ``n_samples`` times. As a result, every column of the output holds
    samples of the same input (even with a shuffling loader), and the
    dataset is loaded and encoded only once instead of ``n_samples`` times.

    Args:
        model: module called as ``model(x1, x2)`` that returns logits with
            a leading batch axis (e.g. ``SiameseNetwork``).
        dataloader: yields ``(x1, x2, y)`` batches; ``y`` is ignored.
        device: device the inputs are moved to.
        n_samples: number of stochastic forward passes per input (>= 1).

    Returns:
        ``np.ndarray`` of sigmoid probabilities with shape
        ``(n_samples, N)``, where N is the number of examples the loader
        yields, in loader order.

    Raises:
        ValueError: if ``n_samples < 1``.

    The model's original train/eval mode is restored before returning.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)
    was_training = model.training

    # Enable dropout selectively. Calling model.train() is not enough when a
    # helper later calls model.eval(), and it would also switch any other
    # mode-dependent layers.
    model.eval()
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()

    batch_samples = []
    try:
        with torch.no_grad():
            for x1, x2, _ in dataloader:
                x1, x2 = x1.to(device), x2.to(device)
                draws = [torch.sigmoid(model(x1, x2)) for _ in range(n_samples)]
                # (n_samples, batch)
                batch_samples.append(torch.stack(draws, dim=0).cpu().numpy())
    finally:
        model.train(was_training)

    if not batch_samples:
        return np.empty((n_samples, 0), dtype=np.float32)
    return np.concatenate(batch_samples, axis=1)


# ======================================
# Metrics
# ======================================
def classification_metrics(probs, labels, threshold=0.5):
    """Threshold probabilities and compute accuracy, precision, recall and F1.

    Args:
        probs: predicted positive-class probabilities (array-like).
        labels: ground-truth binary labels aligned with ``probs``.
        threshold: a sample is predicted positive when ``prob > threshold``
            (strict inequality).

    Returns:
        dict with keys "Accuracy", "Precision", "Recall", "F1".

    Precision, recall and F1 are reported as 0.0 when undefined (no
    predicted positives or no actual positives). This is the value sklearn
    falls back to anyway, but passing it explicitly avoids an
    UndefinedMetricWarning on every call. The recorded training run, where
    F1 stayed 0 every epoch, hits exactly this case.
    """
    preds = (np.asarray(probs) > threshold).astype(int)
    return {
        "Accuracy": accuracy_score(labels, preds),
        "Precision": precision_score(labels, preds, zero_division=0),
        "Recall": recall_score(labels, preds, zero_division=0),
        "F1": f1_score(labels, preds, zero_division=0),
    }


def auc_metrics(probs, labels):
    """Return threshold-free ranking metrics for binary predictions.

    Keys: "ROC-AUC" (``roc_auc_score``) and "PR-AUC"
    (``average_precision_score``). Both are computed from ``labels`` and the
    continuous scores ``probs``.

    NOTE: ROC-AUC is undefined when ``labels`` contains only one class.
    """
    scorers = (
        ("ROC-AUC", roc_auc_score),
        ("PR-AUC", average_precision_score),
    )
    return {name: score_fn(labels, probs) for name, score_fn in scorers}


def spearman_metric(probs, labels):
    """Return Spearman's rank correlation between ``probs`` and ``labels``.

    Only the correlation coefficient is returned; scipy's p-value is
    discarded.
    """
    rho, _p_value = spearmanr(probs, labels)
    return rho


def plot_calibration(probs, labels, n_bins=10, strategy="uniform"):
    """Show a reliability diagram (calibration plot) for binary probabilities.

    Args:
        probs: predicted positive-class probabilities.
        labels: ground-truth binary labels aligned with ``probs``.
        n_bins: number of bins passed to ``calibration_curve``. The default
            of 10 is the previously fixed value.
        strategy: binning strategy passed to ``calibration_curve``.
            "uniform" gives equal-width bins (the previous behaviour);
            "quantile" gives equal-count bins, which is often more
            informative when most predictions sit near 0.

    The dashed diagonal marks perfect calibration. The figure is displayed
    with ``plt.show()`` and nothing is returned.
    """
    frac_pos, mean_pred = calibration_curve(
        labels, probs, n_bins=n_bins, strategy=strategy
    )

    plt.figure()
    plt.plot(mean_pred, frac_pos, marker="o", label="Model")
    plt.plot([0, 1], [0, 1], "--", label="Perfectly calibrated")
    plt.xlabel("Predicted probability")
    plt.ylabel("Observed frequency")
    plt.title("Calibration Plot")
    plt.legend()
    plt.show()


# ======================================
# CSV Loader
# ======================================
def load_pairs_and_labels(csv_file):
    """Read sequence pairs and labels from a headed CSV file.

    The file must have columns ``on_seq``, ``off_seq`` and ``label``. Every
    "-" and "_" character is removed from both sequences, and the label is
    parsed with ``float``.

    Returns:
        ``(pairs, labels)``: a list of ``(on_seq, off_seq)`` string tuples
        and a list of floats, in file order.

    Raises:
        KeyError: if a required column is missing.
        ValueError: if a label cannot be parsed as a float.
    """
    # NOTE(review): stripping "-" discards any alignment/gap information the
    # dashes may encode, so pairs can end up with different lengths. Confirm
    # this is intended.
    strip_gaps = str.maketrans("", "", "-_")
    pairs, labels = [], []

    # The csv module requires newline="". Using "utf-8-sig" also accepts a
    # leading BOM (e.g. an Excel export), which would otherwise be glued onto
    # the first header name and make row["on_seq"] raise KeyError.
    with open(csv_file, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            on_seq = row["on_seq"].translate(strip_gaps)
            off_seq = row["off_seq"].translate(strip_gaps)
            pairs.append((on_seq, off_seq))
            labels.append(float(row["label"]))

    return pairs, labels


# ======================================
# Main
# ======================================
def main():
    """Train the Siamese LSTM, save it, and report MC-dropout test metrics.

    Pipeline:
    1. Load pairs and labels from the CSV below and make a stratified 80/20
       split.
    2. Train for ``num_epochs`` epochs, printing the loss and the test-split
       F1@0.5 and PR-AUC after each epoch.
    3. Save the weights.
    4. Evaluate with MC dropout and report:
       - classification, AUC and Spearman metrics on the mean probabilities;
       - the average predictive variance;
       - a calibration plot.

    Paths are hard-coded to a Google Drive folder (``/content/drive/...``).
    """
    data_dir = "/content/drive/MyDrive/Deep Learning"
    csv_file = f"{data_dir}/circle_seq_train_ratio_preserved.csv"
    save_path = f"{data_dir}/siamese_lstm_mc_dropout.pth"
    num_epochs = 5

    pairs, labels = load_pairs_and_labels(csv_file)

    train_p, test_p, train_l, test_l = train_test_split(
        pairs, labels, test_size=0.2, stratify=labels, random_state=42
    )
    # Make the class balance visible: threshold-0.5 metrics are misleading
    # when positives are rare. (Assumes labels are 0/1.)
    print(f"Train positive rate: {np.mean(train_l):.4f}")

    train_loader = DataLoader(gRNADataset(train_p, train_l), batch_size=32, shuffle=True)
    # shuffle=False keeps prediction order identical to test_l, so the test
    # labels can be taken from test_l directly without another model pass.
    test_loader = DataLoader(gRNADataset(test_p, test_l), batch_size=32, shuffle=False)
    test_labels = np.asarray(test_l, dtype=np.float32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SiameseNetwork().to(device)
    # NOTE(review): the recorded run printed F1 = 0.0000 at every epoch while
    # the loss was already ~0.05. That suggests the positive class is rare
    # and never predicted at threshold 0.5. Consider
    # BCEWithLogitsLoss(pos_weight=...) or tuning the decision threshold, and
    # confirm against the printed positive rate.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # -------- Training --------
    for epoch in range(num_epochs):
        loss = train_epoch(model, train_loader, criterion, optimizer, device)
        probs, _ = evaluate_probs(model, test_loader, device)
        f1 = f1_score(test_labels, (probs > 0.5).astype(int), zero_division=0)
        pr_auc = average_precision_score(test_labels, probs)

        print(
            f"Epoch [{epoch + 1}/{num_epochs}] | Loss: {loss:.4f} "
            f"| F1: {f1:.4f} | PR-AUC: {pr_auc:.4f}"
        )

    # -------- Save model --------
    # Save before the long MC-dropout evaluation. The recorded run was
    # interrupted (KeyboardInterrupt) after training, before the old
    # save step, and so lost the weights.
    torch.save(model.state_dict(), save_path)
    print("✅ Model saved at:", save_path)

    # -------- Final Evaluation (MC dropout) --------
    mc_preds = mc_dropout_predict(model, test_loader, device)

    mean_probs = mc_preds.mean(axis=0)
    var_probs = mc_preds.var(axis=0)

    print("\n📊 Classification Metrics")
    print(classification_metrics(mean_probs, test_labels))

    print("\n📈 AUC Metrics")
    print(auc_metrics(mean_probs, test_labels))

    print("\n📐 Spearman Correlation")
    print("Spearman:", spearman_metric(mean_probs, test_labels))

    print("\n📦 Average Predictive Variance:", var_probs.mean())

    plot_calibration(mean_probs, test_labels)


# Run the full train/evaluate pipeline when executed as a script or notebook
# cell, but not when this module is imported.
if __name__ == "__main__":
    main()


Epoch [1/5] | Loss: 0.0637 | F1: 0.0000
Epoch [2/5] | Loss: 0.0576 | F1: 0.0000
Epoch [3/5] | Loss: 0.0577 | F1: 0.0000
Epoch [4/5] | Loss: 0.0561 | F1: 0.0000
Epoch [5/5] | Loss: 0.0534 | F1: 0.0000


KeyboardInterrupt: 