In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

sys.path.insert(0, '../dlp')
from data_process import *

# Prefer the second GPU when the machine actually has one; otherwise fall back
# to the first GPU and finally to the CPU, rather than naming a device index
# that may not exist.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
print(device)

epochs = 10_000   # number of training iterations (one batch is fetched per iteration)
val_epoch = 500   # validation is run every `val_epoch` iterations
num_val = 1000    # number of batches consumed by one evaluation pass

model_name = "esm_hierarchy"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

# exist_ok makes this cell safe to re-run without a separate existence check.
os.makedirs(checkpoint_dir, exist_ok=True)
print(checkpoint_dir)

# Input representation used by the run: "embedding" selects the mean-pooled
# model, any other value selects the flattened one-hot model.
config = "embedding"

wandb.init(
    # set the wandb project where this run will be logged
    project="test classification",

    # track hyperparameters and run metadata; these mirror the values the
    # notebook really uses instead of placeholder quick-start values
    config={
        "learning_rate": 0.001,  # keep in sync with the lr passed to the optimizer
        "architecture": "FNN",
        # NOTE(review): dataset name inferred from model_name / the data loader name - confirm
        "dataset": model_name,
        "epochs": epochs,
        "config": config,
    }
)

  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
30522
cuda:1
../checkpoints/esm_hierarchy_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malirezanor-310-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class ESMFNNClassifier(nn.Module):
    """Small feed-forward classification head with two input-pooling modes.

    Args:
        model: Pooling mode. ``"Flat"`` flattens every non-batch dimension and
            projects it with ``fc1``; ``"Mean"`` averages over dimension 1 and
            projects the result with ``fc3``. Any other value leaves the input
            untouched and passes it straight to ``fc2``.
        embedding_dim: Size of the last dimension of the input features.
        num_classes: Number of output scores produced by ``fc2``.
        max_seq_len: Sequence length assumed by the ``"Flat"`` projection
            (``fc1`` expects ``embedding_dim * max_seq_len`` input features).
        hidden_dim: Width of the hidden representation fed to ``fc2``.

    Both ``fc1`` and ``fc3`` are always created, whichever mode is selected,
    so both contribute to the module's parameter count.
    """

    def __init__(self, model, embedding_dim, num_classes, max_seq_len, hidden_dim):
        # Zero-argument super() avoids naming the class, so it cannot go stale
        # when the class is renamed.
        super().__init__()
        self.model = model
        # Projection used by the "Flat" mode: one weight per (position, feature) pair.
        self.fc1 = nn.Linear(embedding_dim * max_seq_len, hidden_dim)
        # Output layer shared by both modes; no bias term.
        self.fc2 = nn.Linear(hidden_dim, num_classes, bias=False)
        # Projection used by the "Mean" mode, applied after averaging over dim 1.
        self.fc3 = nn.Linear(embedding_dim, hidden_dim)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_classes)``."""
        if self.model == "Flat":
            x = x.view(x.size(0), -1)  # Flatten the output for the fully connected layer
            x = torch.relu(self.fc1(x))

        elif self.model == "Mean":
            x = x.mean(dim=1)
            x = torch.relu(self.fc3(x))

        return self.fc2(x)


# Backward-compatible alias: presumably the class's earlier name, which the
# original super() call still referred to.
FNNClassifier = ESMFNNClassifier

In [23]:
max_seq_len = 1000  # fixed length every encoded sequence is padded/truncated to
max_tax_len = 150   # not referenced by the code in this cell

# Character vocabulary for protein sequences (20 amino acids + 1 padding)
vocab = "ACDEFGHIKLMNPQRSTVWY"
# Indices start at 1 so that 0 can stand for padding and for unknown characters.
char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}


def encode_sequence(sequence):
    """Convert a protein sequence into a list of exactly ``max_seq_len`` integers.

    Each character is mapped through ``char_to_idx``; characters outside the
    vocabulary map to 0. Sequences shorter than ``max_seq_len`` are right-padded
    with 0 and longer ones are truncated, so every result has the same length
    and a batch of them can always be stacked into one tensor.
    """
    indices = [char_to_idx.get(char, 0) for char in sequence[:max_seq_len]]
    indices.extend([0] * (max_seq_len - len(indices)))
    return indices


def encode_sequence_batch(sequences):
    """Encode several sequences and stack them into one tensor.

    Returns a tensor with one row of ``max_seq_len`` indices per sequence.
    ``torch.Tensor(...)`` yields the default floating dtype, so the indices
    are stored as floats.
    """
    return torch.Tensor([encode_sequence(s) for s in sequences])

In [24]:
# Hyperparameters
vocab_dim = 21        # 20 amino acids + 1 padding/unknown index
embedding_dim = 320   # per-position feature size consumed by the "Mean" model
hidden_dim = 512
num_taxonomy_ids = 4  # number of output classes (width of the final layer)

# Initialize the model, optimizer, and loss function.
# Both variants are built with the classifier class defined in this notebook
# and share the sequence length used by the sequence encoder.
ESM_model = ESMFNNClassifier(
    "Mean",
    embedding_dim,
    num_taxonomy_ids,
    max_seq_len,
    hidden_dim
).to(device)

onehot_model = ESMFNNClassifier(
    "Flat",
    vocab_dim,
    num_taxonomy_ids,
    max_seq_len,
    hidden_dim
).to(device)


# `config` selects which of the two models is trained and evaluated.
model = ESM_model if config == "embedding" else onehot_model
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

model: 164.006912 M parameters


In [25]:
def evaluate(split='val'):
    """Score the global ``model`` on ``num_val`` batches of ``split``.

    Reads the module-level ``model``, ``criterion``, ``config``, ``num_val``,
    ``device`` and ``vocab_dim``. Leaves the model in evaluation mode.

    Args:
        split: Name of the data split handed to the batch loader.

    Returns:
        Tuple ``(avg_loss, accuracy, f1_micro, f1_macro, conf_matrix)`` where
        ``avg_loss`` is the mean per-batch loss and the remaining metrics are
        computed over the predictions of all batches together.
    """
    model.eval()  # Set model to evaluation mode

    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Disable gradient computation during evaluation
        for batch_idx in range(num_val):
            tensor_batch, sequences = esm_hierarchy_data_to_tensor_batch(split, batch_idx)
            # presumably moves the batch to `device` in place (return value unused) - confirm
            tensor_batch.gpu(device)

            # NOTE(review): the training loop reads taxes["clade"] (singular);
            # confirm which key the loader really provides.
            labels = tensor_batch.taxes["clades"]

            if config == "embedding":
                outputs = model(tensor_batch.seq_ids)
            else:
                # The flat model's first layer expects vocab_dim features per
                # position, so expand the integer indices into one-hot vectors.
                # Encoding is done only on this path because the embedding
                # path never uses it.
                token_ids = encode_sequence_batch(sequences).to(device).long()
                outputs = model(F.one_hot(token_ids, num_classes=vocab_dim).float())

            # Calculate the loss
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    y_true = all_labels.numpy()
    y_pred = all_preds.numpy()

    # Compute evaluation metrics for single-label, multi-class predictions
    accuracy = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average='macro')  # unweighted mean of per-class F1
    f1_micro = f1_score(y_true, y_pred, average='micro')  # F1 over all samples pooled together
    conf_matrix = confusion_matrix(y_true, y_pred)
    avg_loss = running_loss / num_val

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [1]:
# Sum of per-iteration training losses since the start of the run. It is
# deliberately NOT reset inside the loop, so dividing by the number of
# completed iterations below yields the mean training loss so far.
running_loss = 0.0

for epoch in range(epochs):
    model.train()  # Set model to training mode (evaluate() switches it to eval mode)

    tensor_batch, sequences = esm_hierarchy_data_to_tensor_batch('train', epoch)
    # presumably moves the batch to `device` in place (return value unused) - confirm
    tensor_batch.gpu(device)

    # NOTE(review): evaluate() reads taxes["clades"] (plural); confirm which
    # key the loader really provides.
    labels = tensor_batch.taxes["clade"]

    if config == "embedding":
        outputs = model(tensor_batch.seq_ids)
    else:
        # The flat model's first layer expects vocab_dim features per position,
        # so expand the integer indices into one-hot vectors. Encoding is done
        # only on this path because the embedding path never uses it.
        token_ids = encode_sequence_batch(sequences).to(device).long()
        outputs = model(F.one_hot(token_ids, num_classes=vocab_dim).float())

    # Calculate the loss
    loss = criterion(outputs, labels)

    # Backpropagation: Zero the gradients, compute the backward pass, and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Track the loss
    running_loss += loss.item()

    if (epoch + 1) % val_epoch == 0:
        # Mean training loss over all iterations completed so far
        epoch_loss = running_loss / (epoch + 1)
        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {epoch_loss:.4f}")

        # Evaluate the model on the validation split (evaluate()'s default)
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, cm = evaluate()
        print(cm)
        print(f"val Loss: {val_loss:.4f}, val Accuracy: {val_accuracy:.4f}, val F1 Score (micro): {val_f1_micro:.4f}, , val F1 Score (macro): {val_f1_macro:.4f}")

        wandb.log({"train loss": epoch_loss, "val acc": val_accuracy, "val loss": val_loss})

wandb.finish()

NameError: name 'epochs' is not defined

In [11]:
# Score the 'test' split; the returned tuple
# (avg loss, accuracy, micro-F1, macro-F1, confusion matrix) is the cell output.
evaluate(split='test')

(0.14237267754599453,
 0.9738125,
 0.9738125,
 0.32891084301742607,
 array([[    0,   237,     0],
        [    0, 15581,     0],
        [    0,   182,     0]]))