In [None]:
import os
import torch
import torchmetrics
from torch import nn
from torch.optim import Adam
from transformers import BertTokenizer, BertGenerationEncoder
from lightning import LightningModule, LightningDataModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
import re
import pickle
import pandas as pd

class ProteinClassifier(LightningModule):
    def __init__(self, n_classes=25):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.embedder = BertGenerationEncoder.from_pretrained("Rostlab/prot_bert")
        dmodel = 1024
        self.model = nn.Linear(dmodel, n_classes)  # Classification head
        
        self.criterion = nn.CrossEntropyLoss()
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass",
                                                                 num_classes=n_classes)
        self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass",
                                                                   num_classes=n_classes)
        self.val_f1 = torchmetrics.classification.F1Score(task="multiclass", num_classes=n_classes)
    
    def forward(self, x):
        lengths = torch.tensor([len(i) for i in x]).to(self.device)
        ids = self.tokenizer(x, add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(ids['input_ids']).to(self.device)
        attention_mask = torch.tensor(ids['attention_mask']).to(self.device).to(self.dtype)
        with torch.no_grad():
            embeddings = self.embedder(input_ids=input_ids,
                                       attention_mask=attention_mask).last_hidden_state
        embeddings = embeddings.sum(dim=1) / lengths.view(-1, 1)
        logits = self.model(embeddings)
        return logits
    
    def training_step(self, batch, batch_idx):
        x, y = batch  # Directly unpack the collated batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy(preds, y)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch  # Directly unpack the collated batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y)
        self.val_f1(preds, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=5e-5)


class PAFDatamodule(LightningDataModule):
    def __init__(self, root_path, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.root = root_path
        self.classes = pickle.load(open(f"{root_path}/selected_families.pkl", "rb"))
    
    def encode_classes(self, y):
        cls2idx = dict(zip(self.classes, range(len(self.classes))))
        return [cls2idx[i] for i in y]
    
    def get_dataset(self, part, with_target=True):
        file_path = f"{self.root}/{part}_data.csv"
        df = pd.read_csv(file_path)
        
        # Replace rare amino acids (X, U, B, O, Z) with 'X'
        sequences = df.loc[:, "sequence"].values
        sequences = [re.sub(r'[UBOZ]', 'X', seq) for seq in sequences]
        
        x = sequences
        
        if with_target:
            y = df.loc[:, "family_id"].values
            y = torch.tensor(self.encode_classes(y))
            return list(zip(x, y))
        else:
            # For test data without labels
            sequence_names = df.loc[:, "sequence_name"].values
            return list(zip(x, sequence_names))
    
    def train_dataloader(self):
        data = self.get_dataset("train")
        return DataLoader(data, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        data = self.get_dataset("val")
        return DataLoader(data, batch_size=self.batch_size, shuffle=False)
    
    def test_dataloader(self):
        data = self.get_dataset("test", with_target=False)
        return DataLoader(data, batch_size=self.batch_size, shuffle=False)
    
    def predict_dataloader(self):
        data = self.get_dataset("test", with_target=False)
        return DataLoader(data, batch_size=self.batch_size, shuffle=False)


if __name__ == "__main__":
    # Set parameters
    n_classes = 25
    data_root = "datafiles"
    
    # Create directories for saving models
    os.makedirs("models/baseline", exist_ok=True)
    os.makedirs("models/finetuned", exist_ok=True)
    
    # Initialize data module
    datamodule = PAFDatamodule(data_root, batch_size=16)
    
    # PART 1: Baseline model (frozen ProtBERT)
    print("Setting up baseline model...")
    baseline_model = ProteinClassifier(n_classes=n_classes)
    
    # Freeze the embedder for baseline model
    for param in baseline_model.embedder.parameters():
        param.requires_grad = False
    
    # Define callbacks for baseline model
    baseline_checkpoint = ModelCheckpoint(
        dirpath="models/baseline/",
        filename="baseline-{epoch:02d}-{val_acc:.4f}",
        monitor="val_acc",
        mode="max",
        save_top_k=1
    )
    
    early_stop = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode="min"
    )
    
    # Set up the TensorBoard logger
    logger = TensorBoardLogger("tb_logs", name="protein_classifier")
    
    # Train baseline model
    print("Training baseline model (frozen ProtBERT)...")
    baseline_trainer = Trainer(
        max_epochs=5,
        callbacks=[baseline_checkpoint, early_stop],
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        gradient_clip_val=1.0,
        check_val_every_n_epoch=1,
        logger=logger  
    )
    
    baseline_trainer.fit(model=baseline_model, datamodule=datamodule)
    
    # Save the final baseline model
    baseline_trainer.save_checkpoint("models/baseline/final_baseline_model.ckpt")
    
    # Save best model path for reference
    with open("models/baseline/best_model_path.txt", "w") as f:
        f.write(baseline_checkpoint.best_model_path)
    
    print(f"Baseline model saved at: {baseline_checkpoint.best_model_path}")


Setting up baseline model...
Training baseline model (frozen ProtBERT)...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\Anorm\Downloads\models\baseline exists and is not empty.

  | Name           | Type                  | Params | Mode 
-----------------------------------------------------------------
0 | embedder       | BertGenerationEncoder | 418 M  | eval 
1 | model          | Linear                | 25.6 K | train
2 | criterion      | CrossEntropyLoss      | 0      | train
3 | val_accuracy   | MulticlassAccuracy    | 0      | train
4 | train_accuracy | MulticlassAccuracy    | 0      | train
5 | val_f1         | MulticlassF1Score     | 0      | train
-----------------------------------------------------------------
25.6 K    Trainable params
418 M     Non-trainable params
418 M     Total params
1,675.620 Total estimated model params size (MB)
5         

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 2082/2082 [35:45<00:00,  0.97it/s, v_num=17]     


c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 2082/2082 [14:52<00:00,  2.33it/s, v_num=0, train_acc=0.333, val_acc=0.162, val_f1=0.162]

Metric val_loss improved. New best score: 3.122


Epoch 1: 100%|██████████| 2082/2082 [14:43<00:00,  2.36it/s, v_num=0, train_acc=0.000, val_acc=0.162, val_f1=0.162] 

Metric val_loss improved by 0.055 >= min_delta = 0.0. New best score: 3.067


Epoch 2: 100%|██████████| 2082/2082 [18:04<00:00,  1.92it/s, v_num=0, train_acc=0.000, val_acc=0.0887, val_f1=0.0887]

Metric val_loss improved by 0.034 >= min_delta = 0.0. New best score: 3.033


Epoch 3: 100%|██████████| 2082/2082 [37:30<00:00,  0.93it/s, v_num=0, train_acc=0.333, val_acc=0.102, val_f1=0.102]   

Metric val_loss improved by 0.023 >= min_delta = 0.0. New best score: 3.010


Epoch 4: 100%|██████████| 2082/2082 [24:23<00:00,  1.42it/s, v_num=0, train_acc=0.000, val_acc=0.153, val_f1=0.153] 

Metric val_loss improved by 0.017 >= min_delta = 0.0. New best score: 2.993
`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 2082/2082 [24:23<00:00,  1.42it/s, v_num=0, train_acc=0.000, val_acc=0.153, val_f1=0.153]
Baseline model saved at: C:\Users\Anorm\Downloads\models\baseline\baseline-epoch=01-val_acc=0.1622.ckpt


In [None]:
# PART 2: Fine-tuned model (unfreeze ProtBERT and fine-tune)
print("Setting up fine-tuned model...")
finetuned_model = ProteinClassifier(n_classes=n_classes)
    
    
    
    # Define callbacks for finetuned model
finetuned_checkpoint = ModelCheckpoint(
    dirpath="models/finetuned/",
    filename="finetuned-{epoch:02d}-{val_f1:.4f}",
    monitor="val_f1",
    mode="max",
    save_top_k=1
    )
    
    # Train finetuned model
print("Training fine-tuned model (unfrozen ProtBERT)...")
finetuned_trainer = Trainer(
    max_epochs=2,
    callbacks=[finetuned_checkpoint, early_stop],
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    gradient_clip_val=1.0
    )
    
finetuned_trainer.fit(model=finetuned_model, datamodule=datamodule)
    
    # Save the final finetuned model
finetuned_trainer.save_checkpoint("models/finetuned/final_finetuned_model.ckpt")
    
    # Save best model path
with open("models/finetuned/best_model_path.txt", "w") as f:
    f.write(finetuned_checkpoint.best_model_path)
    
print(f"Fine-tuned model saved at: {finetuned_checkpoint.best_model_path}")

Setting up fine-tuned model...
Training fine-tuned model (unfrozen ProtBERT)...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type                  | Params | Mode 
-----------------------------------------------------------------
0 | embedder       | BertGenerationEncoder | 418 M  | eval 
1 | model          | Linear                | 25.6 K | train
2 | criterion      | CrossEntropyLoss      | 0      | train
3 | val_accuracy   | MulticlassAccuracy    | 0      | train
4 | train_accuracy | MulticlassAccuracy    | 0      | train
5 | val_f1         | MulticlassF1Score     | 0      | train
-----------------------------------------------------------------
418 M     Trainable params
0         Non-trainable params
418 M     Total params
1,675.620 Total estimated model params size (MB)
5         Modules in train mode
548       Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 2082/2082 [15:16<00:00,  2.27it/s, v_num=19, train_acc=0.000, val_acc=0.0718, val_f1=0.0718]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2082/2082 [15:16<00:00,  2.27it/s, v_num=19, train_acc=0.000, val_acc=0.0718, val_f1=0.0718]
Fine-tuned model saved at: C:\Users\Anorm\Downloads\models\finetuned\finetuned-epoch=00-val_f1=0.1622.ckpt


In [43]:
# Function to generate predictions from a saved model
def generate_predictions(model_path, output_filename):
    print(f"Generating predictions using model: {model_path}")
        
        # Load the saved model
    model = ProteinClassifier.load_from_checkpoint(model_path)
    model.eval()
    model.to("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load test data
    test_df = pd.read_csv(f"{data_root}/test_data.csv")
    sequences = test_df.loc[:, "sequence"].values
        # Replace rare amino acids
    sequences = [re.sub(r'[UBOZ]', 'X', seq) for seq in sequences]
    sequence_names = test_df.loc[:, "sequence_name"].values
        
        # Process sequences in batches
    batch_size = 32
    predictions = []
        
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            print(f"Processing batch {i//batch_size + 1}/{(len(sequences) + batch_size - 1)//batch_size}")
            batch_sequences = sequences[i:i+batch_size]
            logits = model(batch_sequences)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
                # Convert indices back to family_ids
            pred_classes = [datamodule.classes[idx] for idx in preds]
            predictions.extend(pred_classes)
        
        # Create submission file
    submission_df = pd.DataFrame({
        "sequence_name": sequence_names,
        "family_id": predictions
        })
        
    submission_df.to_csv(output_filename, index=False)
    print(f"Submission file created at {output_filename} with {len(submission_df)} predictions")
    
    # Generate predictions for both models
generate_predictions(baseline_checkpoint.best_model_path, "baseline_submission.csv")
generate_predictions(finetuned_checkpoint.best_model_path, "finetuned_submission.csv")

Generating predictions using model: C:\Users\Anorm\Downloads\models\baseline\baseline-epoch=01-val_acc=0.1622.ckpt
Processing batch 1/131
Processing batch 2/131
Processing batch 3/131
Processing batch 4/131
Processing batch 5/131
Processing batch 6/131
Processing batch 7/131
Processing batch 8/131
Processing batch 9/131
Processing batch 10/131
Processing batch 11/131
Processing batch 12/131
Processing batch 13/131
Processing batch 14/131
Processing batch 15/131
Processing batch 16/131
Processing batch 17/131
Processing batch 18/131
Processing batch 19/131
Processing batch 20/131
Processing batch 21/131
Processing batch 22/131
Processing batch 23/131
Processing batch 24/131
Processing batch 25/131
Processing batch 26/131
Processing batch 27/131
Processing batch 28/131
Processing batch 29/131
Processing batch 30/131
Processing batch 31/131
Processing batch 32/131
Processing batch 33/131
Processing batch 34/131
Processing batch 35/131
Processing batch 36/131
Processing batch 37/131
Proces

In [None]:
import torch
from lightning import LightningModule
from torch import nn
from torch.optim import AdamW
import torchmetrics
from transformers import AutoTokenizer, AutoModel
import os
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import pandas as pd
import re
from datamodule import PAFDatamodule

class ESM2Classifier(LightningModule):
    def __init__(self, n_classes=25):
        super().__init__()
        self.model_name = "facebook/esm2_t12_35M_UR50D"  # Smaller ESM-2 model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.embedder = AutoModel.from_pretrained(self.model_name)
        
        # Get embedding dimension
        embedding_dim = self.embedder.config.hidden_size
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, n_classes)
        )
        
        self.criterion = nn.CrossEntropyLoss()
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=n_classes)
        self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=n_classes)
        self.val_f1 = torchmetrics.classification.F1Score(task="multiclass", num_classes=n_classes)
    
    def forward(self, x):
        # Tokenize sequences
        encoding = self.tokenizer(x, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        # Get sequence embeddings
        outputs = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        
        # Use CLS token representation
        embeddings = outputs.last_hidden_state[:, 0, :]
        
        # Pass through classification head
        logits = self.classifier(embeddings)
        return logits
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy(preds, y)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y)
        self.val_f1(preds, y)
        self.log("val_loss", loss)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # Use different learning rates for pretrained model and classifier
        optimizer = AdamW([
            {"params": self.embedder.parameters(), "lr": 1e-5},
            {"params": self.classifier.parameters(), "lr": 5e-5}
        ])
        
        # Add a learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=10, eta_min=1e-6
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch"
            }
        }

if __name__ == "__main__":
    # Set parameters
    n_classes = 25
    data_root = "datafiles"
    
    # Create directory for saving models
    os.makedirs("models/esm2", exist_ok=True)
    
    # Initialize data module
    datamodule = PAFDatamodule(data_root, batch_size=8)  # Smaller batch for larger model
    
    # Initialize ESM2 model
    print("Setting up ESM2 model...")
    esm2_model = ESM2Classifier(n_classes=n_classes)
    
    # Define callbacks
    checkpoint = ModelCheckpoint(
        dirpath="models/esm2/",
        filename="esm2-{epoch:02d}-{val_f1:.4f}",
        monitor="val_f1",
        mode="max",
        save_top_k=1
    )
    
    early_stop = EarlyStopping(
        monitor="val_f1",
        patience=3,
        verbose=True,
        mode="max"
    )
    
    # Setup TensorBoard logger to log training metrics
    logger = TensorBoardLogger("tb_logs", name="esm2_classifier")
    
    # Train model
    print("Training ESM2 model...")
    trainer = Trainer(
        max_epochs=1,
        callbacks=[checkpoint, early_stop],
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        gradient_clip_val=1.0,
        logger=logger  
    )
    
    trainer.fit(model=esm2_model, datamodule=datamodule)
    
    # Save the final model
    trainer.save_checkpoint("models/esm2/final_esm2_model.ckpt")
    
    # Save best model path
    with open("models/esm2/best_model_path.txt", "w") as f:
        f.write(checkpoint.best_model_path)
    
    print(f"ESM2 model saved at: {checkpoint.best_model_path}")
    


Setting up ESM2 model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training ESM2 model...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | embedder       | EsmModel           | 34.0 M | eval 
1 | classifier     | Sequential         | 259 K  | train
2 | criterion      | CrossEntropyLoss   | 0      | train
3 | val_accuracy   | MulticlassAccuracy | 0      | train
4 | train_accuracy | MulticlassAccuracy | 0      | train
5 | val_f1         | MulticlassF1Score  | 0      | train
--------------------------------------------------------------
34.3 M    Trainable params
0         Non-trainable params
34.3 M    Total params
137.008   Total estimated model params size (MB)
9         Modules in train mode
230       Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\Anorm\anaconda3\envs\crystal\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/4163 [27:37<?, ?it/s]                  =3, train_acc=1.000]
Epoch 0: 100%|██████████| 4163/4163 [3:28:31<00:00,  0.33it/s, v_num=3, train_acc=1.000, val_acc=1.000, val_f1=1.000]

Metric val_f1 improved. New best score: 1.000
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4163/4163 [3:28:34<00:00,  0.33it/s, v_num=3, train_acc=1.000, val_acc=1.000, val_f1=1.000]
ESM2 model saved at: C:\Users\Anorm\Downloads\models\esm2\esm2-epoch=00-val_f1=0.9995.ckpt


In [52]:
# Function to generate predictions from the ESM2 model
def generate_esm2_predictions(model_path, output_filename):
    print(f"Generating predictions using ESM2 model: {model_path}")
        
        # Load the saved model
    model = ESM2Classifier.load_from_checkpoint(model_path)
    model.eval()
    model.to("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load test data
    test_df = pd.read_csv(f"{data_root}/test_data.csv")
    sequences = test_df.loc[:, "sequence"].values
        # Replace rare amino acids
    sequences = [re.sub(r'[UBOZ]', 'X', seq) for seq in sequences]
    sequence_names = test_df.loc[:, "sequence_name"].values
        
        # Process sequences in batches
    batch_size = 16  # Adjust batch size as needed
    predictions = []
        
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            print(f"Processing batch {i//batch_size + 1}/{(len(sequences) + batch_size - 1)//batch_size}")
            batch_sequences = sequences[i:i+batch_size]
            logits = model(batch_sequences)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
                # Convert indices back to family_ids
            pred_classes = [datamodule.classes[idx] for idx in preds]
            predictions.extend(pred_classes)
        
        # Create submission file
    submission_df = pd.DataFrame({
        "sequence_name": sequence_names,
        "family_id": predictions
        })
        
    submission_df.to_csv(output_filename, index=False)
    print(f"ESM2 submission file created at {output_filename} with {len(submission_df)} predictions")
    
    # Generate predictions
generate_esm2_predictions(checkpoint.best_model_path, "esm2_submission.csv")

Generating predictions using ESM2 model: C:\Users\Anorm\Downloads\models\esm2\esm2-epoch=00-val_f1=0.9995.ckpt


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Processing batch 1/261
Processing batch 2/261
Processing batch 3/261
Processing batch 4/261
Processing batch 5/261
Processing batch 6/261
Processing batch 7/261
Processing batch 8/261
Processing batch 9/261
Processing batch 10/261
Processing batch 11/261
Processing batch 12/261
Processing batch 13/261
Processing batch 14/261
Processing batch 15/261
Processing batch 16/261
Processing batch 17/261
Processing batch 18/261
Processing batch 19/261
Processing batch 20/261
Processing batch 21/261
Processing batch 22/261
Processing batch 23/261
Processing batch 24/261
Processing batch 25/261
Processing batch 26/261
Processing batch 27/261
Processing batch 28/261
Processing batch 29/261
Processing batch 30/261
Processing batch 31/261
Processing batch 32/261
Processing batch 33/261
Processing batch 34/261
Processing batch 35/261
Processing batch 36/261
Processing batch 37/261
Processing batch 38/261
Processing batch 39/261
Processing batch 40/261
Processing batch 41/261
Processing batch 42/261
P

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import umap.umap_ as umap
from sklearn.manifold import TSNE
import pickle
import re
import sys
sys.argv = [sys.argv[0]]  # Clear any arguments passed to IPython

def extract_embeddings(model, sequences, batch_size=8):
    """Extract embeddings for visualization"""
    model.eval()
    embeddings = []
    
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i+batch_size]
            # Get embeddings depending on model type
            if hasattr(model, 'embedder') and model.embedder.__class__.__name__ == 'BertGenerationEncoder':
                # ProtBERT model
                lengths = torch.tensor([len(seq) for seq in batch_sequences]).to(model.device)
                ids = model.tokenizer(batch_sequences, add_special_tokens=True, padding="longest")
                input_ids = torch.tensor(ids['input_ids']).to(model.device)
                attention_mask = torch.tensor(ids['attention_mask']).to(model.device)
                
                outputs = model.embedder(input_ids=input_ids, attention_mask=attention_mask)
                batch_embeddings = outputs.last_hidden_state.sum(dim=1)/lengths.view(-1, 1)
            else:
                # ESM2 model
                encoding = model.tokenizer(batch_sequences, return_tensors="pt", padding=True, truncation=True)
                input_ids = encoding['input_ids'].to(model.device)
                attention_mask = encoding['attention_mask'].to(model.device)
                
                outputs = model.embedder(input_ids=input_ids, attention_mask=attention_mask)
                batch_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
            
            embeddings.append(batch_embeddings.cpu().numpy())
    
    return np.vstack(embeddings)

def visualize_embeddings(model_path, output_prefix, data_root="datafiles"):
    """Create UMAP and t-SNE visualizations of protein embeddings"""
    print(f"Generating visualizations for model: {model_path}")
    
    # Determine model type and load it
    if "esm2" in model_path:
        from esm2_model import ESM2Classifier
        model = ESM2Classifier.load_from_checkpoint(model_path)
    else:
        from prot_bert import ProteinClassifier
        model = ProteinClassifier.load_from_checkpoint(model_path)
    
    model.eval()
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load validation data (with known labels) for visualization
    val_df = pd.read_csv(f"{data_root}/val_data.csv")
    sequences = val_df.loc[:, "sequence"].values
    # Replace rare amino acids
    sequences = [re.sub(r'[UBOZ]', 'X', seq) for seq in sequences]
    
    # Get family_ids and map to indices
    family_ids = val_df.loc[:, "family_id"].values
    classes = pickle.load(open(f"{data_root}/selected_families.pkl", "rb"))
    cls2idx = dict(zip(classes, range(len(classes))))
    labels = np.array([cls2idx[fam] for fam in family_ids])
    
    # Limit to 1000 sequences for faster visualization
    if len(sequences) > 1000:
        indices = np.random.choice(len(sequences), 1000, replace=False)
        sequences = [sequences[i] for i in indices]
        labels = labels[indices]
    
    # Extract embeddings
    print("Extracting embeddings...")
    embeddings = extract_embeddings(model, sequences)
    
    # UMAP visualization
    print("Generating UMAP visualization...")
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    umap_embeddings = reducer.fit_transform(embeddings)
    
    # Plot UMAP
    plt.figure(figsize=(12, 10))
    unique_labels = np.unique(labels)
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
    
    for i, label in enumerate(unique_labels):
        idx = labels == label
        plt.scatter(umap_embeddings[idx, 0], umap_embeddings[idx, 1], 
                   c=[colors[i]], label=classes[label], alpha=0.7, s=10)
    
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title('UMAP projection of protein embeddings')
    plt.tight_layout()
    plt.savefig(f'{output_prefix}_umap.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    # t-SNE visualization
    print("Generating t-SNE visualization...")
    tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42)
    tsne_embeddings = tsne.fit_transform(embeddings)
    
    plt.figure(figsize=(12, 10))
    for i, label in enumerate(unique_labels):
        idx = labels == label
        plt.scatter(tsne_embeddings[idx, 0], tsne_embeddings[idx, 1], 
                   c=[colors[i]], label=classes[label], alpha=0.7, s=10)
    
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title('t-SNE projection of protein embeddings')
    plt.tight_layout()
    plt.savefig(f'{output_prefix}_tsne.png', dpi=300, bbox_inches='tight')
    
    print(f"Visualizations saved as {output_prefix}_umap.png and {output_prefix}_tsne.png")

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="Visualize protein embeddings")
    parser.add_argument("--model", type=str, help="Path to the saved model checkpoint")
    parser.add_argument("--output", type=str, default="embeddings", help="Output filename prefix")
    parser.add_argument("--data", type=str, default="datafiles", help="Path to data directory")
    
    args = parser.parse_args()
    
    # If no model path is provided, try to read from the best model path files
    if args.model is None:
        model_paths = []
        for model_type in ["baseline", "finetuned", "esm2"]:
            try:
                with open(f"models/{model_type}/best_model_path.txt", "r") as f:
                    model_paths.append((f.read().strip(), f"{args.output}_{model_type}"))
            except FileNotFoundError:
                print(f"Could not find best_model_path.txt for {model_type} model.")
        
        if not model_paths:
            print("No model paths found. Please specify a model path.")
            exit(1)
        
        # Visualize all available models
        for model_path, output_prefix in model_paths:
            visualize_embeddings(model_path, output_prefix, args.data)
    else:
        # Visualize the specified model
        visualize_embeddings(args.model, args.output, args.data)