# In [None]:
# Resolve where the promoter CSV lives: on Google Colab it is read from the
# mounted Drive, anywhere else from the current working directory.
try:
    from google.colab import drive
except ImportError:
    # google.colab is not importable, so this is not a Colab runtime.
    PROMOTER_CSV_PATH = "promoter_dataset.csv"
else:
    # Mount outside the ``try`` so that an ImportError raised while mounting
    # is reported instead of being mistaken for "not running on Colab".
    drive.mount('/content/drive')
    PROMOTER_CSV_PATH = "/content/drive/MyDrive/Colab Notebooks/promoter_dataset.csv"

print("Using promoter CSV at:", PROMOTER_CSV_PATH)

# ----------------- Imports -----------------
import csv
import os
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

# Seed both random number generators used in this file so that shuffling and
# weight initialisation are repeatable from run to run.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Integer index assigned to each nucleotide, in the order A, C, G, T.
NUC2IDX = {base: index for index, base in enumerate("ACGT")}


def clean_seq(seq: str) -> str:
    """Normalise a raw sequence to the A/C/G/T alphabet.

    The input is upper-cased and stripped of surrounding whitespace; every
    remaining character that is not a key of ``NUC2IDX`` is replaced by "A".
    """
    normalized = seq.upper().strip()
    return "".join(
        "A" if symbol not in NUC2IDX else symbol
        for symbol in normalized
    )


def load_sequence_csv(path: str) -> List[Tuple[str, int]]:
    """Load labelled sequences from a CSV file.

    The file must start with a header row that names at least the columns
    ``sequence`` and ``label``. Every sequence is normalised with
    ``clean_seq`` and every label is converted with ``int``.

    Args:
        path: Location of the CSV file.

    Returns:
        A list of ``(sequence, label)`` tuples in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        KeyError: If the header has no ``sequence`` or ``label`` column.
        ValueError: If a label is not a valid integer literal.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"CSV not found at: {path}")

    # Decode explicitly rather than with the platform default. "utf-8-sig"
    # reads plain ASCII/UTF-8 and also drops a leading byte-order mark, which
    # would otherwise become part of the first header name and make the
    # "sequence" lookup fail.
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        return [
            (clean_seq(row["sequence"]), int(row["label"]))
            for row in reader
        ]


def split_dataset(data: List[Tuple[str, int]], val_ratio: float = 0.2):
    """Randomly split ``data`` into a training and a validation subset.

    The samples are shuffled with the global ``random`` generator and the
    first ``int(len(data) * val_ratio)`` shuffled samples become the
    validation set. The input itself is left untouched.

    Args:
        data: The ``(sequence, label)`` samples to split.
        val_ratio: Fraction of samples to hold out; must lie in ``[0, 1]``.

    Returns:
        A ``(train, val)`` tuple of two new lists.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio}")

    # Shuffle a copy so the caller's list keeps its order. For a given RNG
    # state this yields the same permutation as shuffling in place would.
    shuffled = list(data)
    random.shuffle(shuffled)
    n_val = int(len(shuffled) * val_ratio)
    return shuffled[n_val:], shuffled[:n_val]


def batchify(data: List[Tuple[str, int]], batch_size: int):
    """Yield shuffled mini-batches of ``data``.

    The samples are shuffled with the global ``random`` generator and then cut
    into consecutive slices of ``batch_size`` samples; the last batch may be
    shorter. The input list itself is not reordered.

    Args:
        data: The ``(sequence, label)`` samples to batch.
        batch_size: Maximum number of samples per batch; must be positive.

    Yields:
        ``(seqs, labels)`` pairs, where ``seqs`` is a list of sequence strings
        and ``labels`` is a 1-D ``torch.long`` tensor of the matching labels.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Shuffle a private copy: reordering the caller's list on every call is a
    # surprising side effect, especially for evaluation data.
    shuffled = list(data)
    random.shuffle(shuffled)
    for start in range(0, len(shuffled), batch_size):
        batch = shuffled[start:start + batch_size]
        seqs = [sample[0] for sample in batch]
        labels = torch.tensor([sample[1] for sample in batch], dtype=torch.long)
        yield seqs, labels


# ----------------- Improved OQFA model -----------------
class ImprovedOQFA(nn.Module):
    """Quantum-inspired finite-automaton classifier for nucleotide strings.

    The model keeps a state vector of size ``num_states``. Reading a
    nucleotide multiplies the state by that nucleotide's learnable transition
    matrix and rescales the result by its L2 norm. After the last character
    the state passes through dropout and a linear layer that produces one
    logit per class.

    The four transition matrices are initialised orthogonal; nothing in this
    class re-orthogonalises them during training.

    Args:
        num_states: Size of the state vector.
        num_classes: Number of output logits.
        dropout_p: Dropout probability applied to the final state.
    """

    def __init__(self, num_states: int = 32, num_classes: int = 2, dropout_p: float = 0.3):
        super().__init__()
        self.num_states = num_states

        # One (num_states x num_states) matrix per nucleotide index 0..3.
        # Initialise a plain tensor first and wrap it afterwards, instead of
        # writing through ``Parameter.data``.
        transitions = torch.empty(4, num_states, num_states)
        for i in range(4):
            nn.init.orthogonal_(transitions[i])
        self.transitions = nn.Parameter(transitions)

        # Learnable start state, initialised to the first basis vector.
        start = torch.zeros(num_states)
        start[0] = 1.0
        self.initial_state = nn.Parameter(start)

        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(num_states, num_classes)

    def step(self, h: torch.Tensor, ch: str) -> torch.Tensor:
        """Advance state ``h`` by one character and renormalise it.

        Characters missing from ``NUC2IDX`` use transition index 0.
        """
        idx = NUC2IDX.get(ch, 0)
        h = torch.matmul(h, self.transitions[idx])
        # The small epsilon keeps the division finite if the state collapses
        # to the zero vector.
        return h / (h.norm() + 1e-8)

    def forward_on_sequence(self, seq: str) -> torch.Tensor:
        """Return the class logits, shape ``(num_classes,)``, for one sequence.

        An empty sequence is classified from the initial state directly.
        """
        h = self.initial_state
        for ch in seq:
            h = self.step(h, ch)
        h = self.dropout(h)
        return self.classifier(h)

    def forward(self, seqs: List[str]) -> torch.Tensor:
        """Return logits of shape ``(len(seqs), num_classes)`` for a batch.

        Sequences are processed one at a time, so they may differ in length.
        """
        if not seqs:
            # torch.stack rejects an empty list; return an empty batch with
            # the dtype and device of the classifier instead.
            return self.classifier.weight.new_zeros((0, self.classifier.out_features))
        return torch.stack([self.forward_on_sequence(s) for s in seqs], dim=0)


# ----------------- Training & evaluation -----------------
def evaluate_model(model: nn.Module, data: List[Tuple[str, int]], batch_size: int = 16, silent=False):
    """Compute the classification accuracy of ``model`` on ``data``.

    The model is switched to evaluation mode (and left in it) and run without
    gradient tracking. Samples are processed in their given order: evaluation
    neither reorders ``data`` nor draws from the global ``random`` generator,
    and accuracy does not depend on sample order.

    Args:
        model: Module that maps a list of sequence strings to a
            ``(batch, num_classes)`` logit tensor. It must have at least one
            parameter, which is used to find the device.
        data: The ``(sequence, label)`` samples to score.
        batch_size: Number of samples per forward pass.
        silent: If false, print the accuracy before returning it.

    Returns:
        The fraction of correctly classified samples, or ``0.0`` when ``data``
        is empty.
    """
    device = next(model.parameters()).device
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            seqs = [sample[0] for sample in batch]
            labels = torch.tensor(
                [sample[1] for sample in batch], dtype=torch.long, device=device
            )
            preds = model(seqs).argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

    # max(1, ...) avoids a division by zero for an empty dataset.
    acc = total_correct / max(1, total_samples)
    if not silent:
        print(f"Accuracy: {acc:.3f}")
    return acc


def train_oqfa(
    train_data: List[Tuple[str, int]],
    val_data: List[Tuple[str, int]],
    num_states: int = 32,
    num_epochs: int = 80,
    batch_size: int = 16,
    lr: float = 3e-3,
):
    """Train an ``ImprovedOQFA`` binary classifier and return it.

    The model is optimised with AdamW (weight decay 1e-4) on a cross-entropy
    loss, using CUDA when available. Training loss, training accuracy and
    validation accuracy are printed after the first epoch and after every
    fifth epoch.

    Args:
        train_data: The ``(sequence, label)`` samples to fit; must be
            non-empty.
        val_data: The ``(sequence, label)`` samples used for the logged
            validation accuracy.
        num_states: Size of the model's state vector.
        num_epochs: Number of passes over ``train_data``.
        batch_size: Number of samples per optimisation step; must be positive.
        lr: Learning rate for AdamW.

    Returns:
        The trained model, in evaluation mode.

    Raises:
        ValueError: If ``train_data`` is empty or ``batch_size`` is not
            positive.
    """
    # Fail early with a clear message; otherwise an empty epoch would end in a
    # ZeroDivisionError when the averages are computed.
    if not train_data:
        raise ValueError("train_data is empty; cannot train a model")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    model = ImprovedOQFA(num_states=num_states, num_classes=2, dropout_p=0.3).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(1, num_epochs + 1):
        # Evaluation leaves the model in eval mode, so re-enable dropout here.
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        for seqs, labels in batchify(train_data, batch_size):
            labels = labels.to(device)
            logits = model(seqs)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch average is
            # per sample even when the last batch is shorter.
            total_loss += loss.item() * labels.size(0)
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

        if epoch % 5 == 0 or epoch == 1:
            train_loss = total_loss / total_samples
            train_acc = total_correct / total_samples
            # Validation accuracy is only reported, never used for training,
            # so it is computed only on the epochs that are logged.
            val_acc = evaluate_model(model, val_data, batch_size, silent=True)
            print(
                f"Epoch {epoch:02d} | "
                f"Train Loss: {train_loss:.4f} | "
                f"Train Acc: {train_acc:.3f} | "
                f"Val Acc: {val_acc:.3f}"
            )

    # Hand the model back with dropout disabled, whichever epoch came last.
    model.eval()
    return model


# ----------------- Main execution -----------------
# Load the labelled promoter sequences from the CSV path chosen at the top of
# the file.
print("\nLoading PROMOTER dataset from:", PROMOTER_CSV_PATH)
promoter_data = load_sequence_csv(PROMOTER_CSV_PATH)
print("Total promoter samples:", len(promoter_data))
print("First 3 samples:", promoter_data[:3])

# Hold out 20% of the samples for validation; split_dataset shuffles before
# splitting.
train_data, val_data = split_dataset(promoter_data, val_ratio=0.2)
print(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

# Train with the hyperparameters spelled out explicitly; these match
# train_oqfa's defaults.
model = train_oqfa(
    train_data=train_data,
    val_data=val_data,
    num_states=32,
    num_epochs=80,
    batch_size=16,
    lr=3e-3,
)

# Score the trained model on the held-out split. evaluate_model prints the
# accuracy itself (silent is left at its default) and the raw value is printed
# again below.
print("\nFinal evaluation on validation set:")
final_acc = evaluate_model(model, val_data, batch_size=16)
print("Final validation accuracy:", final_acc)
