# Ananya Agrawal (ananyaa2)

# In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchmetrics
import umap
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Configuration
MODEL_NAME = "Rostlab/prot_bert"  # Hugging Face checkpoint used for tokenizer and model
BATCH_SIZE = 32
EPOCHS = 50
DATA_DIR = "datafiles"  # folder holding the CSV splits and the pickled class list
NUM_WORKERS = 8  # DataLoader worker processes


# Device Handling
def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _pick_device()

# Load Data
train_df = pd.read_csv(os.path.join(DATA_DIR, "train_data.csv"))
val_df = pd.read_csv(os.path.join(DATA_DIR, "val_data.csv"))
test_df = pd.read_csv(os.path.join(DATA_DIR, "test_data.csv"))

# `classes` is the ordered collection of family ids; its position defines each
# family's integer label (presumably a list — it is indexed by position later).
# NOTE: pickle.load can execute arbitrary code — only load a trusted file here.
with open(os.path.join(DATA_DIR, "selected_families.pkl"), "rb") as fh:
    classes = pickle.load(fh)
cls2idx = {cls: idx for idx, cls in enumerate(classes)}

# Residue letters removed before tokenisation: X (unknown) and the rare or
# ambiguous amino-acid codes U, B, O and Z.
_DROPPED_RESIDUES = str.maketrans("", "", "XUBOZ")


def preprocess_sequence(seq):
    """Clean a raw amino-acid string and space-separate it for ProtBERT.

    ProtBERT's vocabulary holds single-residue tokens only, so the tokenizer
    must see one residue per whitespace-separated word. An unspaced sequence is
    a single word that WordPiece cannot split and ends up as ``[UNK]``.

    Args:
        seq: amino-acid sequence, e.g. ``"MKTXAY"``.

    Returns:
        The sequence with X/U/B/O/Z and any whitespace removed, and the remaining
        residues joined by single spaces, e.g. ``"M K T A Y"``.
    """
    residues = "".join(seq.translate(_DROPPED_RESIDUES).split())
    return " ".join(residues)

# Dataset Class
class ProteinDataset(Dataset):
    """Tokenised protein sequences, optionally paired with integer family labels.

    Args:
        df: DataFrame with a ``sequence`` column and, when ``with_labels`` is
            True, a ``family_id`` column whose values must all be keys of ``cls2idx``.
        with_labels: whether items carry the integer class label.

    Raises:
        ValueError: if ``with_labels`` is True and some ``family_id`` values are
            not in ``cls2idx`` (they would otherwise become NaN float labels).
    """

    def __init__(self, df, with_labels=True):
        self.sequences = df["sequence"].apply(preprocess_sequence).values
        self.with_labels = with_labels
        self.labels = None
        if with_labels:
            mapped = df["family_id"].map(cls2idx)
            unknown = df.loc[mapped.isna(), "family_id"].unique()
            if len(unknown):
                raise ValueError(
                    f"{len(unknown)} family_id value(s) are not in the selected families, "
                    f"e.g. {list(unknown[:5])}"
                )
            # Integer dtype so labels become LongTensor targets for CrossEntropyLoss.
            self.labels = mapped.astype("int64").values
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask[, label])`` for item ``idx``.

        ``input_ids`` and ``attention_mask`` are 1-D tensors padded/truncated
        to 256 tokens; ``label`` is a scalar ``torch.long`` tensor.
        """
        encoding = self.tokenizer(
            self.sequences[idx],
            padding="max_length",
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        # A single string is encoded as a batch of one; drop that leading dim.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        if not self.with_labels:
            return input_ids, attention_mask
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, attention_mask, label

# Data Loaders
def _make_loader(df, *, shuffle, with_labels=True):
    """Wrap ``df`` in a ProteinDataset and batch it with the shared loader settings."""
    dataset = ProteinDataset(df, with_labels=with_labels)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


train_loader = _make_loader(train_df, shuffle=True)  # reshuffled each epoch
val_loader = _make_loader(val_df, shuffle=False)
test_loader = _make_loader(test_df, shuffle=False, with_labels=False)  # keep row order for submission

# Model Definition
class ProteinClassifier(LightningModule):
    """ProtBERT sequence classifier over the protein families in ``classes``.

    Wraps ``AutoModelForSequenceClassification`` loaded from ``MODEL_NAME`` and
    trains it with cross-entropy. Multiclass accuracy is tracked with a separate
    metric instance per stage, so training and validation counts never mix.

    Args:
        n_classes: number of output classes; defaults to ``len(classes)``.
    """

    def __init__(self, n_classes=len(classes)):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=n_classes)
        # Recompute activations in backward to save memory.
        self.model.gradient_checkpointing_enable()
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits, shape ``(batch, n_classes)``."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

    def _shared_step(self, batch, stage, metric):
        """Run one labelled batch, log ``{stage}_loss``/``{stage}_acc``, return the loss."""
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)
        # torchmetrics argmaxes float logits itself, so no softmax is needed.
        metric(logits, labels)
        # Training also logs per step for the progress bar; both stages log the
        # epoch aggregate (the metric object is computed and reset each epoch).
        on_step = stage == "train"
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=on_step, on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=on_step, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_acc)

    def validation_step(self, batch, batch_idx):
        # Logged as "val_acc" (epoch-level only), the key ModelCheckpoint monitors.
        self._shared_step(batch, "val", self.val_acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=2e-5)

# Training
model = ProteinClassifier()
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
trainer = Trainer(
    max_epochs=EPOCHS,
    callbacks=[checkpoint_callback],
    # CSV metrics (per-epoch history) are written under ./lightning_logs/version_N/.
    logger=CSVLogger(save_dir="."),
    accumulate_grad_batches=2,
    accelerator="gpu" if torch.cuda.is_available() else "auto",
)
trainer.fit(model, train_loader, val_loader)

# Continue with the weights of the best validation-accuracy epoch selected by
# the checkpoint callback, rather than whatever the final epoch left in `model`.
if checkpoint_callback.best_model_path:
    model = ProteinClassifier.load_from_checkpoint(checkpoint_callback.best_model_path)

# Inference
# Place the model on the configured DEVICE explicitly instead of relying on
# wherever training or checkpoint loading left it; disable dropout for inference.
model.to(DEVICE)
model.eval()
predictions = []  # predicted class index per test row, in test_loader order
with torch.inference_mode():
    for input_ids, attention_mask in test_loader:
        logits = model(input_ids.to(DEVICE), attention_mask.to(DEVICE))
        predictions.extend(logits.argmax(dim=1).cpu().tolist())

# Create submission file
# NOTE(review): assumes sample_submission.csv lists rows in the same order as
# test_data.csv (test_loader is unshuffled) — confirm against the data.
submission = pd.read_csv(os.path.join(DATA_DIR, "sample_submission.csv"))
predicted_families = [classes[class_idx] for class_idx in predictions]  # index -> family id
submission["family_id"] = predicted_families
submission.to_csv("submission.csv", index=False)

# UMAP Visualization
# Embed every training sequence with the fine-tuned encoder (the first-token,
# [CLS], hidden state) and project the embeddings to 2-D with UMAP.
# Set device and eval mode here so this cell does not depend on earlier ones.
model.to(DEVICE)
model.eval()
embedding_batches = []
label_batches = []
with torch.inference_mode():
    for input_ids, attention_mask, label in train_loader:
        hidden = model.model.base_model(
            input_ids=input_ids.to(DEVICE),
            attention_mask=attention_mask.to(DEVICE),
        ).last_hidden_state
        embedding_batches.append(hidden[:, 0, :].float().cpu().numpy())
        label_batches.append(label.cpu().numpy())

# Embeddings and labels are collected in the same loop, so rows stay aligned
# even though train_loader shuffles.
embeddings = np.concatenate(embedding_batches)
labels = np.concatenate(label_batches)

reducer = umap.UMAP(n_components=2)
umap_embeddings = reducer.fit_transform(embeddings)

plt.figure(figsize=(10, 6))
plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=labels, cmap='viridis', alpha=0.6)
plt.colorbar()
plt.title("Protein Embeddings - UMAP Visualization")
plt.savefig("umap_visualization.png")
plt.close()

# Accuracy plot
def _epoch_curve(history, *columns):
    """Return the per-epoch mean of the first of ``columns`` present in ``history``.

    ``history`` is a Lightning CSV metrics table (one row per logging call, with
    an ``epoch`` column). Averaging handles metrics logged per step as well as
    once per epoch. Returns None when none of ``columns`` was logged.
    """
    for column in columns:
        if column in history.columns:
            return history.groupby("epoch")[column].mean().dropna()
    return None


# Plot the real per-epoch curves recorded by the trainer's logger. Only a CSV
# logger writes metrics.csv; with any other logger the plot is skipped.
log_dir = getattr(trainer.logger, "log_dir", None)
metrics_path = os.path.join(log_dir, "metrics.csv") if log_dir else None

if metrics_path is None or not os.path.exists(metrics_path):
    print("No metrics.csv found for this run (requires a CSVLogger); skipping accuracy plot.")
else:
    history = pd.read_csv(metrics_path)
    train_curve = _epoch_curve(history, "train_acc_epoch", "train_acc")
    val_curve = _epoch_curve(history, "val_acc_epoch", "val_acc")

    plt.figure(figsize=(10, 5))
    if train_curve is not None:
        plt.plot(train_curve.index, train_curve.values, marker="o", label="Train Accuracy")
    if val_curve is not None:
        plt.plot(val_curve.index, val_curve.values, marker="o", label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training and Validation Accuracy Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Output filename kept unchanged for anything that already references it.
    plt.savefig("accuracy_plot_real_data_simulated.png")
    plt.close()
