In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import sys, os, math

sys.path.insert(0, '../dlp')
from data_access import PQDataAccess
from data_process import *

# Opt in to pandas' upcoming behaviour: no silent dtype downcasting.
pd.set_option('future.no_silent_downcasting', True)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# Run constants.
batch_size = 32
epochs = 10_000  # number of iterations of the training loop
val_epoch = 50   # validation/checkpoint interval, in loop iterations
num_val = 25     # forwarded to evaluate(); presumably the number of validation batches - TODO confirm

# Data accessor over the corpus; the path is machine-specific.
da = PQDataAccess("/home/aac/Alireza/datasets/taxseq/corpus_1000", batch_size)

# Checkpoints for this run live in a directory named after the model.
model_name = "new_hierarchy"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(checkpoint_dir, exist_ok=True)
print(checkpoint_dir)

 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 


  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
cuda:0
../checkpoints/new_hierarchy_checkpoints


In [2]:
from models.TaxonomyClassifier import TaxonomyClassifier

# Normalise the per-level importance scores so that the weights sum to 1.
total = sum(importance_dict.values())
level_weights = {
    level: importance / total
    for level, importance in importance_dict.items()
}

# Build the classifier on the selected device and attach an Adam optimizer.
model = TaxonomyClassifier(taxonomy_levels=tax_vocab_sizes).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Report the total number of parameters, in millions.
param_count = sum(weights.numel() for weights in model.parameters())
print("model:", param_count / 1e6, 'M parameters')

  from .autonotebook import tqdm as notebook_tqdm


model: 120.04096 M parameters


In [None]:
from train import train_step
from evaluate import evaluate

# Training loop: one train_step call per iteration, with validation and a
# checkpoint every `val_epoch` iterations.
model.train()

# Histories: train_losses gets one entry per iteration, the val_* lists one
# entry per validation round.
train_losses = []
val_losses = []
val_accuracies = []
val_f1s = []

for epoch in range(epochs):
    train_loss, level_losses = train_step(model, optimizer, da, device, level_weights)
    train_losses.append(train_loss)

    if (epoch + 1) % val_epoch == 0:
        # Keep the validation per-level losses under their own name so they do
        # not overwrite the training ones returned by train_step above.
        val_loss, val_level_losses, acc, level_f1, cms = evaluate(model, da, device, tax_vocab_sizes, level_weights, num_val)
        # evaluate() may have switched the model to eval mode; make sure the
        # following training steps run in train mode again (no-op otherwise).
        model.train()

        print("cms", cms)
        val_losses.append(val_loss)
        val_accuracies.append(acc)
        val_f1s.append(level_f1)

        # Average over the iterations since the previous validation round.
        mean_train_loss = sum(train_losses[-val_epoch:]) / val_epoch

        print(f"Epoch {epoch+1}, Train Loss: {mean_train_loss:.4f}, Val Loss: {val_loss:.4f}, val acc: {acc:.4f}")
        # Unweighted mean of the per-level F1 scores.
        print(sum(level_f1.values()) / len(level_f1))

        # The file name uses the 1-based iteration count, while the stored
        # 'epoch' field is 0-based.
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_step_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,  # loss of the latest iteration only, not the mean
            'val_loss': val_loss,
            'accuracy': acc,
            'f1_score': level_f1
        }, checkpoint_path)

In [4]:
def _checkpoint_step(filename):
    """Return the step number encoded in a checkpoint file name, or None if it is not numeric."""
    try:
        return int(filename.split("_")[-1].split(".")[0])
    except ValueError:
        return None


def load_latest_checkpoint(checkpoint_dir, model, specific=None):
    """Load checkpoint weights into ``model`` and return the stored metrics.

    Args:
        checkpoint_dir: Directory containing ``checkpoint_step_<N>...`` files.
        model: Module whose ``load_state_dict`` receives the saved
            ``'model_state_dict'``; it is modified in place.
        specific: Optional file name inside ``checkpoint_dir``. When given,
            that file is loaded directly and the directory is not scanned.
            When ``None``, the file with the highest step number is loaded.

    Returns:
        A dict with the checkpoint's ``epoch``, ``train_loss``, ``val_loss``,
        ``accuracy`` and ``f1_score`` entries, or ``None`` when ``specific`` is
        not given and the directory holds no usable checkpoint file.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` or the ``specific`` file does
            not exist.
    """
    if specific is not None:
        checkpoint_path = os.path.join(checkpoint_dir, specific)
    else:
        # Map each candidate file to its step number; files whose suffix is
        # not a number (e.g. leftovers of an interrupted save) are skipped
        # instead of crashing the lookup.
        steps = {}
        for name in os.listdir(checkpoint_dir):
            if name.startswith("checkpoint_step_"):
                step = _checkpoint_step(name)
                if step is not None:
                    steps[name] = step
        if not steps:
            print("No checkpoints found in directory.")
            return None
        checkpoint_path = os.path.join(checkpoint_dir, max(steps, key=steps.get))

    # Map saved tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']

    # The stored epoch is 0-based; report it 1-based to match the file names.
    print(f"Loaded checkpoint from epoch {epoch+1}")

    return {
        "epoch": epoch,
        "train_loss": checkpoint['train_loss'],
        "val_loss": checkpoint['val_loss'],
        "accuracy": checkpoint['accuracy'],
        "f1_score": checkpoint['f1_score']
    }

In [35]:
import sys
sys.path.insert(0, '../dlp')

from data_process import *

# Example input used to sanity-check the trained model: a sequence string
# (presumably a protein in one-letter amino-acid code - TODO confirm) and the
# comma-separated reference lineage it is compared against.
test_seq = "MKRLRPSDKFFELLGYKPHHVQLAIHRSTAKRRVACLGRQSGKSEAASVEAVFELFARPGSQGWIIAPTYDQAEIIFGRVVEKVERLSEVFPTTEVQLQRRRLRLLVHHYDRPVNAPGAKRVATSEFRGKSADRPDNLRGATLDFVILDEAAMIPFSVWSEAIEPTLSVRDGWALIISTPKGLNWFYEFFLMGWRGGLKEGIPNSGINQTHPDFESFHAASWDVWPERREWYMERRLYIPDLEFRQEYGAEFVSHSNSVFSGLDMLILLPYERRGTRLVVEDYRPDHIYCIGADFGKNQDYSVFSVLDLDTGAIACLERMNGATWSDQVARLKALSEDYGHAYVVADTWGVGDAIAEELDAQGINYTPLPVKSSSVKEQLISNLALLMEKGQVAVPNDKTILDELRNFRYYRTASGNQVMRAYGRGHDDIVMSLALAYSQYEGKDGYKFELAEERPSKLKHEESVMSLVEDDFTDLELANRAFSA"
tax_lineage = "cellular organisms, Bacteria, Pseudomonadota, Betaproteobacteria, unclassified Betaproteobacteria, Betaproteobacteria bacterium GR16-43"

# Build a fresh model and restore the most recent checkpoint's weights into it.
model = TaxonomyClassifier(taxonomy_levels=tax_vocab_sizes).to(device)
latest_checkpoint = load_latest_checkpoint(checkpoint_dir, model)

# A newly constructed module starts in training mode; switch to eval mode so
# layers such as dropout and batch norm behave deterministically at inference.
model.eval()

# Batch of one encoded sequence.
input_tensor = torch.LongTensor([encode_sequence(test_seq)]).to(device)

# No gradients are needed for prediction; skipping autograd saves memory.
with torch.no_grad():
    output = model(input_tensor)

# Highest-scoring index for each entry of the model's output dict
# (presumably one entry per taxonomy level - TODO confirm).
output_indexes = {k: v.argmax().item() for k,v in output.items()}

# Rank names in the order in which they are printed by the helpers below.
# The strings are used as lookup keys (e.g. into level_decoder), so they must
# stay exactly as written - including the spelling "begining root".
hierarchy = [
    "begining root", "no rank", "superkingdom", "kingdom", "subkingdom", "superphylum", "phylum",
    "subphylum", "superclass", "class", "subclass", "infraclass", "superorder", "order", "suborder",
    "infraorder", "parvorder", "superfamily", "family", "subfamily", "tribe", "subtribe", "genus",
    "subgenus", "species group", "species subgroup", "species", "subspecies", "varietas", "forma specialis",
    "forma", "biotype", "pathogroup", "serogroup", "serotype", "isolate", "strain", "genotype", "clade",
    "cohort", "subcohort", "section", "subsection", "series", "morph",
]

def pretty_print(dict_index):
    """Print the decoded name of each populated rank, in ``hierarchy`` order.

    Args:
        dict_index: Mapping from rank name to an index into
            ``level_decoder[rank]``.

    Ranks absent from ``dict_index`` and ranks whose index is not positive are
    skipped (index 0 presumably means "no label" - TODO confirm). Each printed
    line has the form ``<rank> \\t <decoded name>``.
    """
    for rank in hierarchy:
        if rank not in dict_index:
            continue
        index = dict_index[rank]
        if index > 0:
            print(rank, "\t", level_decoder[rank][index])


def decode_input_lineage(tax_lineage):
    """Echo a lineage string, then print the ranks obtained by encoding it.

    Args:
        tax_lineage: Lineage text passed to ``encode_lineage``.

    The first element of each encoded rank entry is taken as that rank's
    index, and the result is printed with ``pretty_print`` so both helpers
    share one output format.
    """
    print(tax_lineage)
    encoded = encode_lineage(tax_lineage)
    # Only ranks listed in `hierarchy` are inspected, as before.
    pretty_print({rank: encoded[rank][0] for rank in hierarchy if rank in encoded})


# Print the model's predicted ranks first, then - below a separator - the
# ranks decoded from the reference lineage, so the two can be compared by eye.
pretty_print(output_indexes)
print("--------")
decode_input_lineage(tax_lineage)

Loaded checkpoint from epoch 10000
begining root 	 cellular organisms
superkingdom 	 Bacteria
--------
cellular organisms, Bacteria, Pseudomonadota, Betaproteobacteria, unclassified Betaproteobacteria, Betaproteobacteria bacterium GR16-43
begining root 	 cellular organisms
no rank 	 unclassified Betaproteobacteria
superkingdom 	 Bacteria
phylum 	 Pseudomonadota
class 	 Betaproteobacteria
