In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json

sys.path.insert(0, '../dlp')
sys.path.insert(1, 'ProtCLIP/proteinclip')
from proteinclip import model_utils


from batch import Batch

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 100
num_val = 10
batch_size = 8
dataset_name = "new_corpus"
lr = 0.001
model_name = "Contrastive Learning"
max_seq_len = 500

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

cuda:1
../checkpoints/Contrastive Learning_checkpoints


In [18]:
tax_ids_file = "../data/tax_ids.csv"

tax_ids = pd.read_csv(tax_ids_file)
# print(tax_ids)
num_classes = len(tax_ids) + 1
print(num_classes)
id_encoder = {name: idx + 1 for idx, name in enumerate(tax_ids['Taxonomic_lineage_IDs'].values)}

id_decoder = {idx + 1: name for idx, name in enumerate(tax_ids['Taxonomic_lineage_IDs'].values)}
id_decoder[0] = "NOT DEFINED"

# Character vocabulary for protein sequences (20 amino acids + 1 padding)
vocab = "ACDEFGHIKLMNPQRSTVWY"
char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}  # Start index from 1 for padding
# Sequence encoder: Convert the protein sequence into integers
def encode_sequence(sequence):
    return [char_to_idx.get(char, 0) for char in sequence] + [0 for _ in range(max_seq_len - len(sequence))]  # 0 for unknown characters or padding 

def data_to_tensor_batch(b):
    inputs = torch.LongTensor([encode_sequence(e['sequence']) for e in b])

    tax_ids = torch.LongTensor([id_encoder.get(e['Taxonomic_lineage_IDs'], 0) for e in b])

    return Batch(inputs, tax_ids)

67486


In [57]:
import torch
import torch.nn as nn

class ONNXWithHead(nn.Module):
    def __init__(self, head_output_size):
        super().__init__()
        
        self.onnx_model = model_utils.load_proteinclip("esm", 33)  # For ESM2, 33-layer model
        onnx_output_size = 128  # Replace with your model's output size
        
        self.head = nn.Sequential(
            nn.Linear(onnx_output_size, head_output_size),
            nn.LogSoftmax(dim=1)
        )
    
    def forward(self, model_input):
#         # If input is numpy array, convert to torch
#         if isinstance(model_input, np.ndarray):
#             model_input = torch.from_numpy(model_input).float()
        
#         # Normalize
#         norm = torch.norm(model_input.cpu(), dim=-1, keepdim=True)
#         normalized_input = model_input / (norm + 1e-7)
        
        with torch.no_grad():
            base_output = self.onnx_model.predict(model_input.cpu().numpy())
            base_output = torch.tensor(base_output, dtype=torch.float32)
            
        return self.head(base_output)

In [58]:
model = ONNXWithHead(num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

In [59]:
val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [da.get_batch() for _ in range(num_val)]
# print(val_batches[0])

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = data_to_tensor_batch(val_batches[epoch])
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes
            outputs = model(tensor_batch.seq_ids)

            # Calculate the loss
            loss = criterion(outputs, labels)
    
            running_loss += loss.item()
                
            preds = torch.argmax(outputs, dim=1)
    
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')  # F1-score for multi-label classification
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    # conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val
    
    return avg_loss, accuracy, f1_micro, f1_macro

evaluate(model)

AssertionError: 

In [7]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
latest_epoch = 0

In [8]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(latest_epoch + 1, latest_epoch + epochs + 1)):
    model.train()
    
    tensor_batch = data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro = evaluate(model)
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"test Loss: {val_loss:.4f}, test Accuracy: {val_accuracy:.4f}")
        print(f"test F1 (micro): {val_f1_micro:.4f}, test F1 (macro): {val_f1_macro:.4f}")
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "test_loss": val_loss,
            "test_accuracy": val_accuracy,
            "test_f1_micro": val_f1_micro,
            "test_f1_macro": val_f1_macro,
            "epoch": epoch + 1,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0

wandb.finish()

  0%|          | 0/100000 [00:00<?, ?it/s]


OutOfMemoryError: HIP out of memory. Tried to allocate 306.00 MiB. GPU 

In [None]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, protein_embeddings, lineage_embeddings):
        # Compute similarity matrix
        similarity = torch.matmul(protein_embeddings, lineage_embeddings.T)
        similarity = similarity / self.temperature
        
        # Labels are diagonal (matching pairs)
        labels = torch.arange(protein_embeddings.size(0), device=device)
        
        # Compute loss in both directions
        loss_protein = F.cross_entropy(similarity, labels)
        loss_lineage = F.cross_entropy(similarity.T, labels)
        
        return (loss_protein + loss_lineage) / 2

In [None]:
# Initialize models
protein_encoder = ProteinEncoder()
lineage_encoder = LineageEncoder(len(dataset.lineage_vocab))
criterion = ContrastiveLoss()

# Initialize optimizer
optimizer = torch.optim.Adam([
    {'params': protein_encoder.parameters()},
    {'params': lineage_encoder.parameters()}
], lr=1e-4)


total_loss = 0
protein_encoder.train()
lineage_encoder.train()

for batch in dataloader:
    # Move batch to device
    input_ids = batch['input_ids'].to(self.device)
    attention_mask = batch['attention_mask'].to(self.device)
    lineage = batch['lineage'].to(self.device)
    
    # Forward pass
    protein_embeddings = self.protein_encoder(input_ids, attention_mask)
    lineage_embeddings = self.lineage_encoder(lineage)
    
    # Compute loss
    loss = self.criterion(protein_embeddings, lineage_embeddings)
    
    # Backward pass
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    
    total_loss += loss.item()
    
return total_loss / len(dataloader)
print(f"Training loss: {loss:.4f}")

In [None]:
model, _, _, latest_epoch, metrics = load_checkpoint(model)

val_batches_ = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

# input_sequences_ = [e['Sequence'] for b in val_batches_ for e in b]
# labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches_ for e in b]

input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        "sequence": [],
        "label": [],
        "pred": [],
        "loss": []
    }

    metrics = {
        "loss": 0,
        "accuracy": 0,
        "f1 macro": 0,
        "f1 micro": 0
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        pred = output.argmax(dim=-1).cpu().item()
        loss = criterion(output, torch.tensor([label[0]]).to(device))
        df["sequence"].append(sequence)
        df["label"].append(level_decoder[0][label[0]])
        df["pred"].append(level_decoder[0][pred])
        df["loss"].append(round(loss.cpu().item(), 4))

    # Convert to DataFrame
    new_df = pd.DataFrame(df)
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv(f'classification_results__new_att.csv', index=False)

    metrics["loss"] = np.array(df["loss"]).mean()
    metrics["accuracy"] = accuracy_score(np.array(df["label"]), np.array(df["pred"]))
    metrics["f1 macro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='macro')  # F1-score for multi-label classification
    metrics["f1 micro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='micro') 
    print(metrics)

evaluate_df(model)