In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
from sklearn.metrics import f1_score, confusion_matrix

sys.path.insert(0, '../dlp')
from data_process import *

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs= 10_000
val_epoch = 50
num_val = 25

model_name = "esm_hierarchy"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
{'superkingdom': 3, 'genus': 109455, 'order': 1915, 'family': 10413, 'subspecies': 29427, 'no rank': 242062, 'subfamily': 3261, 'strain': 46396, 'serogroup': 154, 'biotype': 17, 'tribe': 2393, 'phylum': 311, 'class': 534, 'species group': 359, 'forma': 699, 'clade': 956, 'suborder': 375, 'subclass': 169, 'varietas': 9991, 'kingdom': 13, 'subphylum': 31, 'forma specialis': 784, 'isolate': 1304, 'superfamily': 901, 'infraorder': 135, 'infraclass': 19, 'superorder': 57, 'subgenus': 1821, 'superclass': 6, 'parvorder': 26, 'begining root': 4, 'serotype': 1229, 'species subgroup': 134, 'subcohort': 3, 'cohort': 5, 'genotype': 22, 'subtribe': 587, 'section': 534, 'series': 9, 'morph': 11, 'subkingdom': 1, 'superphylum': 1, 'subsection': 41, 'pathogroup': 5}

Taxonomic ranks sorted by number of taxa:
no rank: 242062
genus: 109455
strain: 46396
subspecies: 29427
family: 10413
varietas: 9991
subfamily: 3261
tribe: 2393
order: 1915
subgenus: 1821
isolate: 1304
serotype: 1229
cl

In [21]:
class ESM_TaxonomyClassifier(nn.Module):
    def __init__(
        self,
        # Dictionary of taxonomy levels and their possible classes
        taxonomy_levels,
        input_dim=320,
        d_model=512,
        nhead=4,
        num_encoder_layers=1,
        dim_feedforward=256,
        dropout=0.1
    ):
        super().__init__()
        
        # Sequence embedding layers
        self.embedding = nn.Linear(input_dim, d_model)
        
        # Transformer encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers,
            num_encoder_layers
        )
        
        # Shared feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Create classifier heads for each taxonomy level
        self.classifier_heads = nn.ModuleDict({
            level: nn.Linear(d_model, num_classes, bias=False)
            for level, num_classes in taxonomy_levels.items()
        })
        
        self.d_model = d_model
        
    def forward(self, src):
        # Embedding and positional encoding
        # src = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        # src = self.pos_encoder(src)
        print(src.shape)
        src = self.embedding(src)
        
        print(src.shape)
        # Transform sequence
        encoder_output = self.transformer_encoder(
            src,
        )
        print(encoder_output.shape)
        
        # Global average pooling
        sequence_features = torch.mean(encoder_output, dim=1)
        print(sequence_features.shape)
        # Extract shared features
        shared_features = self.feature_extractor(sequence_features)
        print("shared_features:", shared_features.shape)
        # Get predictions for each taxonomy level
        predictions = {
            level: head(shared_features)
            for level, head in self.classifier_heads.items()
        }
        
        return predictions

In [22]:
model = ESM_TaxonomyClassifier(taxonomy_levels=tax_vocab_sizes).to(device)


tensor_batch = esm_hierarchy_data_to_tensor_batch('train', random.randint(0, 9999))
tensor_batch.gpu(device)

input_ids = tensor_batch.seq_ids

print(input_ids.shape)
predictions = model(input_ids)



torch.Size([16, 320])
torch.Size([16, 320])
torch.Size([16, 512])
torch.Size([16, 512])
torch.Size([16])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 512x512)

In [8]:
total = sum(importance_dict.values())
level_weights = {key: value / total for key, value in importance_dict.items()}

model = ESM_TaxonomyClassifier(taxonomy_levels=tax_vocab_sizes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')



model: 240.627968 M parameters


In [12]:
import random


def train_step(level_weights=None):
    # Zero the gradients
    optimizer.zero_grad()
    
    # Get batch and convert to tensor
    tensor_batch = esm_hierarchy_data_to_tensor_batch('train', random.randint(0, 9999))
    tensor_batch.gpu(device)

    input_ids = tensor_batch.seq_ids

    print(input_ids.shape)
    predictions = model(input_ids)
    labels = tensor_batch.taxes
    
    # Initialize total loss
    total_loss = 0
    level_losses = {}

    # Calculate loss for each level
    for level, pred in predictions.items():
        level_loss = F.BCEWithLogitsLoss(pred, labels[level])
        
        # Apply level weights if provided
        if level_weights and level in level_weights:
            level_loss *= level_weights[level]
            
        level_losses[level] = level_loss.item()  # Store loss value for logging
        total_loss += level_loss
    
    # Backward pass and optimization step
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), level_losses

In [13]:
def evaluate(tax_vocab_sizes, level_weights=None, num_val_batches=1):
    model.eval()  # Set the model to evaluation mode
    
    level_losses = {} 
    total_loss = 0

    # Initialize accumulators for each level
    level_correct = {level: 0 for level in tax_vocab_sizes.keys()}
    level_total = {level: 0 for level in tax_vocab_sizes.keys()}
    level_preds = {level: [] for level in tax_vocab_sizes.keys()}
    level_labels = {level: [] for level in tax_vocab_sizes.keys()}
    
    # Only track confusion matrix for "begining root"
    root_preds = []
    root_labels = []
    
    with torch.no_grad():  # Disable gradient computation for evaluation
        for _ in range(num_val_batches):
            tensor_batch = esm_hierarchy_data_to_tensor_batch('val', _)
            tensor_batch.gpu(device)
            
            predictions = model(tensor_batch.seq_ids)
            labels = tensor_batch.taxes
            
            batch_loss = 0
            batch_level_losses = {}

            # Calculate loss for each level
            for level, pred in predictions.items():
                level_loss = F.cross_entropy(pred, labels[level])
                # Apply level weights if provided
                if level_weights and level in level_weights:
                    level_loss *= level_weights[level]
                
                batch_level_losses[level] = level_loss.item()
                batch_loss += level_loss
            
                # Store predictions and labels for F1 calculation
                predicted_classes = torch.argmax(pred, dim=1)
                level_preds[level].extend(predicted_classes.cpu().numpy())
                level_labels[level].extend(labels[level].cpu().numpy())
                
                # Store predictions and labels for "begining root" confusion matrix
                if level == "begining root":
                    root_preds.extend(predicted_classes.cpu().numpy())
                    root_labels.extend(labels[level].cpu().numpy())
                
                # Accumulate correct predictions and total samples for each level
                level_correct[level] += (predicted_classes == labels[level]).sum().item()
                level_total[level] += labels[level].size(0)
            
            # Update total loss and level-specific losses
            total_loss += batch_loss.item()
            for level, level_loss_value in batch_level_losses.items():
                if level in level_losses:
                    level_losses[level] += level_loss_value
                else:
                    level_losses[level] = level_loss_value
    
    # Calculate average losses
    val_loss = total_loss / num_val_batches
    level_losses = {level: loss / num_val_batches for level, loss in level_losses.items()}
    
    # Calculate accuracy for each level
    level_acc = {level: correct / total if total > 0 else 0 
                 for level, (correct, total) in 
                 zip(level_correct.keys(), zip(level_correct.values(), level_total.values()))}
    
    # Calculate F1 scores for each level
    level_f1 = {level: f1_score(np.array(level_labels[level]), 
                               np.array(level_preds[level]), 
                               average='micro') 
                for level in tax_vocab_sizes.keys()}
    
    # Calculate confusion matrix only for "begining root"
    total_cms = {"begining root": confusion_matrix(np.array(root_labels),
                                                 np.array(root_preds),
                                                 labels=[_ for _ in range(tax_vocab_sizes["begining root"])])}
    
    # Calculate overall accuracy across all levels
    total_correct = sum(level_correct.values())
    total_samples = sum(level_total.values())
    overall_accuracy = total_correct / total_samples if total_samples > 0 else 0
    
    model.train()  # Set the model back to training mode
    return val_loss, level_losses, overall_accuracy, level_acc, level_f1, total_cms

In [14]:
model.train()

train_losses = []
val_losses = []
val_accuracies = []
val_f1s = []

for epoch in range(epochs):
    train_loss, level_losses = train_step(level_weights)
    train_losses.append(train_loss)
    
    if (epoch + 1) % val_epoch == 0:
        val_loss, level_losses, acc, level_acc, level_f1, cms = evaluate(tax_vocab_sizes, level_weights, num_val)
        print("cms", cms)
        val_losses.append(val_loss)
        val_accuracies.append(acc)
        val_f1s.append(level_f1)

        mean_train_loss = sum(train_losses[-val_epoch:]) / val_epoch
        
        print(f"Epoch {epoch+1}, Train Loss: {mean_train_loss:.4f}, Val Loss: {val_loss:.4f}, val acc: {acc:.4f}")
        print(sum(level_f1.values()) / len(level_f1))
    
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_step_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': acc,
            'f1_score': level_f1
        }, checkpoint_path)

torch.Size([16, 320])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 512x512)

In [4]:
def load_latest_checkpoint(checkpoint_dir, model, specific=None):
    # List all checkpoint files and sort them by step number
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_step_")]
    if not checkpoints:
        print("No checkpoints found in directory.")
        return None

    # Find the latest checkpoint based on step number
    checkpoints.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]), reverse=True)
    if specific is None:
        latest_checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0])
    else:
        latest_checkpoint_path = os.path.join(checkpoint_dir, specific)

    # Load the checkpoint
    checkpoint = torch.load(latest_checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    train_loss = checkpoint['train_loss']
    val_loss = checkpoint['val_loss']
    accuracy = checkpoint['accuracy']
    f1_score = checkpoint['f1_score']

    print(f"Loaded checkpoint from epoch {epoch+1}")
    
    return {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "accuracy": accuracy,
        "f1_score": f1_score
    }

In [35]:
import sys
sys.path.insert(0, '../dlp')

from data_process import *

test_seq = "MKRLRPSDKFFELLGYKPHHVQLAIHRSTAKRRVACLGRQSGKSEAASVEAVFELFARPGSQGWIIAPTYDQAEIIFGRVVEKVERLSEVFPTTEVQLQRRRLRLLVHHYDRPVNAPGAKRVATSEFRGKSADRPDNLRGATLDFVILDEAAMIPFSVWSEAIEPTLSVRDGWALIISTPKGLNWFYEFFLMGWRGGLKEGIPNSGINQTHPDFESFHAASWDVWPERREWYMERRLYIPDLEFRQEYGAEFVSHSNSVFSGLDMLILLPYERRGTRLVVEDYRPDHIYCIGADFGKNQDYSVFSVLDLDTGAIACLERMNGATWSDQVARLKALSEDYGHAYVVADTWGVGDAIAEELDAQGINYTPLPVKSSSVKEQLISNLALLMEKGQVAVPNDKTILDELRNFRYYRTASGNQVMRAYGRGHDDIVMSLALAYSQYEGKDGYKFELAEERPSKLKHEESVMSLVEDDFTDLELANRAFSA"
tax_lineage = "cellular organisms, Bacteria, Pseudomonadota, Betaproteobacteria, unclassified Betaproteobacteria, Betaproteobacteria bacterium GR16-43"

model = TaxonomyClassifier(taxonomy_levels=tax_vocab_sizes).to(device)
latest_checkpoint = load_latest_checkpoint(checkpoint_dir, model)

input_tensor = torch.LongTensor([encode_sequence(test_seq)]).to(device)
output = model(input_tensor)

output_indexes = {k: v.argmax().item() for k,v in output.items()}

hierarchy = [
    "begining root", "no rank", "superkingdom", "kingdom", "subkingdom", "superphylum", "phylum",
    "subphylum", "superclass", "class", "subclass", "infraclass", "superorder", "order", "suborder",
    "infraorder", "parvorder", "superfamily", "family", "subfamily", "tribe", "subtribe", "genus",
    "subgenus", "species group", "species subgroup", "species", "subspecies", "varietas", "forma specialis",
    "forma", "biotype", "pathogroup", "serogroup", "serotype", "isolate", "strain", "genotype", "clade",
    "cohort", "subcohort", "section", "subsection", "series", "morph",
]

def pretty_print(dict_index):
    for k in hierarchy:
        if k in dict_index:
            v = dict_index[k]
            if v > 0:
                print(k, "\t", level_decoder[k][v])


def decode_input_lineage(tax_lineage):
    print(tax_lineage)
    test_input = encode_lineage(tax_lineage)
    for k in hierarchy:
        if k in test_input:
            v = test_input[k][0]
            if v > 0:
                print(k, "\t", level_decoder[k][v])


pretty_print(output_indexes)
print("--------")
decode_input_lineage(tax_lineage)

Loaded checkpoint from epoch 10000
begining root 	 cellular organisms
superkingdom 	 Bacteria
--------
cellular organisms, Bacteria, Pseudomonadota, Betaproteobacteria, unclassified Betaproteobacteria, Betaproteobacteria bacterium GR16-43
begining root 	 cellular organisms
no rank 	 unclassified Betaproteobacteria
superkingdom 	 Bacteria
phylum 	 Pseudomonadota
class 	 Betaproteobacteria
