In [6]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import sys, os, math

sys.path.insert(0, '../dlp')
from data_access import PQDataAccess

pd.set_option('future.no_silent_downcasting', True)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

batch_size = 32
da = PQDataAccess("/home/aac/Alireza/datasets/taxseq/corpus_1000", batch_size)
epochs= 10_000
val_epoch = 50
num_val = 25

model_name = "tokenizer" # "FNN", "hierarchy", "T5"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

from data_process import *

from models.TokenizerClassifier import TokenizerClassifier


cuda:0
../checkpoints/tokenizer_checkpoints


In [73]:
from models.TokenizerClassifier import TokenizerClassifier

model = TokenizerClassifier(output_dim=len_tokenizer, max_tax_len=max_tax_len).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

model: 8.021248 M parameters




In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_process import *

def train_step(model, optimizer, da, device):
    # Zero the gradients
    optimizer.zero_grad()
    
    # Get batch and convert to tensor
    tensor_batch = tokenizer_data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)
    
    prediction = model(tensor_batch.seq_ids)
    labels = tensor_batch.taxes
    
    loss = nn.BCEWithLogitsLoss()(prediction, labels.float())
        
    # Backward pass and optimization step
    loss.backward()
    optimizer.step()

    return loss.item()

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score
from data_process import *
from sklearn.metrics import confusion_matrix


def evaluate(model, da, device, num_val_batches):
    model.eval()  # Set the model to evaluation mode
    
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():  # Disable gradient computation for evaluation
        for _ in range(num_val_batches):
            tensor_batch = tokenizer_data_to_tensor_batch(da.get_batch())
            tensor_batch.gpu(device)
            
            prediction = model(tensor_batch.seq_ids)
            label = tensor_batch.taxes
            
            loss = nn.BCEWithLogitsLoss()(prediction, label.float())
            total_loss += loss

            all_preds.append(prediction)
            all_labels.append(label)

    all_preds = torch.cat(all_preds).cpu()
    all_labels = torch.cat(all_labels).cpu()
    
    thresholds = np.arange(0.0, 1.0, 0.1)
    best_threshold = 0.0
    best_f1 = 0.0
    
    # Iterate over thresholds to find the one with the highest F1 score
    for threshold in thresholds:
        predicted_classes = (torch.sigmoid(all_preds) > threshold).int()
        f1 = f1_score(all_labels, predicted_classes, average='micro')
        
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
    
    # Calculate accuracy
    predicted_classes = (torch.sigmoid(all_preds) > best_threshold).int()
    # print(all_labels)
    # print(predicted_classes)
    # print(all_labels.shape)
    # print(predicted_classes.shape)
    f1 = f1_score(all_labels, predicted_classes, average='micro')
    # accuracy = accuracy_score(all_labels, predicted_classes)

    accuracy = (predicted_classes == all_labels).sum().item() / predicted_classes.numel() 

    # Average losses
    val_loss = total_loss / num_val_batches
    
    model.train()  # Set the model back to training mode
    return val_loss, accuracy, f1, best_threshold

In [70]:
evaluate(model, da, device, num_val)

(tensor(0.0024, device='cuda:0'),
 0.9993202941785074,
 0.5497072683135799,
 0.30000000000000004)

In [None]:
model.train()

train_losses = []

for epoch in range(epochs):
    train_loss = train_step(model, optimizer, da, device)
    train_losses.append(train_loss)
    
    if (epoch + 1) % val_epoch == 0:
        val_loss, acc, f1, tresh = evaluate(model, da, device, num_val)
        
        mean_train_loss = sum(train_losses[-val_epoch:]) / val_epoch
        print(f"Epoch {epoch+1}, Train Loss: {mean_train_loss:.4f}, Val Loss: {val_loss:.4f}, val acc: {acc:.4f}, val f1: {f1:.4f}")
    
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_step_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': acc,
            'f1_score': f1,
            'tresh': tresh
        }, checkpoint_path)

In [74]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
len_tokenizer = len(tokenizer.vocab)

def encode_lineage_tokenizer(tax_lineage):
    return tokenizer.encode(tax_lineage.split(", "), add_special_tokens=False, padding='max_length', max_length=max_tax_len, is_split_into_words= True)

def tokenizer_data_to_tensor_batch(b):
    # if model_name in ["new_hierarchy", "hierarchy"]:
    sequences = [encode_sequence(e['sequence']) for e in b]
    tax_ids = [encode_lineage_tokenizer(e['tax_lineage']) for e in b]
    encoded_list = [[1 if _ in tax_id else 0 for _ in range(len_tokenizer)] for tax_id in tax_ids]

    return Batch(torch.LongTensor(sequences), torch.LongTensor(encoded_list))



tensor(42)

In [2]:
def load_latest_checkpoint(checkpoint_dir, model, specific=None):
    # List all checkpoint files and sort them by step number
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_step_")]
    if not checkpoints:
        print("No checkpoints found in directory.")
        return None

    # Find the latest checkpoint based on step number
    checkpoints.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]), reverse=True)
    if specific is None:
        latest_checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0])
    else:
        latest_checkpoint_path = os.path.join(checkpoint_dir, specific)

    # Load the checkpoint
    checkpoint = torch.load(latest_checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    train_loss = checkpoint['train_loss']
    val_loss = checkpoint['val_loss']
    accuracy = checkpoint['accuracy']
    f1_score = checkpoint['f1_score']
    tresh = checkpoint['tresh']

    print(f"Loaded checkpoint from epoch {epoch+1}")
    
    return {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "accuracy": accuracy,
        "f1_score": f1_score,
        "tresh": tresh
    }

In [8]:

test_seq = "MKRLRPSDKFFELLGYKPHHVQLAIHRSTAKRRVACLGRQSGKSEAASVEAVFELFARPGSQGWIIAPTYDQAEIIFGRVVEKVERLSEVFPTTEVQLQRRRLRLLVHHYDRPVNAPGAKRVATSEFRGKSADRPDNLRGATLDFVILDEAAMIPFSVWSEAIEPTLSVRDGWALIISTPKGLNWFYEFFLMGWRGGLKEGIPNSGINQTHPDFESFHAASWDVWPERREWYMERRLYIPDLEFRQEYGAEFVSHSNSVFSGLDMLILLPYERRGTRLVVEDYRPDHIYCIGADFGKNQDYSVFSVLDLDTGAIACLERMNGATWSDQVARLKALSEDYGHAYVVADTWGVGDAIAEELDAQGINYTPLPVKSSSVKEQLISNLALLMEKGQVAVPNDKTILDELRNFRYYRTASGNQVMRAYGRGHDDIVMSLALAYSQYEGKDGYKFELAEERPSKLKHEESVMSLVEDDFTDLELANRAFSA"
tax_lineage = "cellular organisms, Bacteria, Pseudomonadota, Betaproteobacteria, unclassified Betaproteobacteria, Betaproteobacteria bacterium GR16-43"

model = TokenizerClassifier(output_dim=len_tokenizer, max_tax_len=max_tax_len).to(device)
latest_checkpoint = load_latest_checkpoint(checkpoint_dir, model)
latest_checkpoint

Loaded checkpoint from epoch 5150


{'epoch': 5149,
 'train_loss': 0.005232478026300669,
 'val_loss': tensor(0.0050, device='cuda:0'),
 'accuracy': 0.9980904693750862,
 'f1_score': 0.36409836771609455,
 'tresh': 0.2}

In [15]:
input_tensor = torch.LongTensor([encode_sequence(test_seq)]).to(device)
output = model(input_tensor)
predicted_labels = (torch.sigmoid(output) > 0.2).int()
indexes = [i for i, _ in enumerate(predicted_labels[0]) if _]
indexes

[0,
 113,
 114,
 118,
 119,
 139,
 140,
 142,
 144,
 153,
 155,
 156,
 159,
 188,
 1116,
 1161,
 1162,
 1182,
 1183,
 1279,
 1361,
 1372,
 1465,
 1566,
 1643,
 1776,
 1777,
 1810,
 1874,
 2083,
 2093,
 2173,
 2822,
 2897,
 3052,
 4527,
 4559,
 5096,
 5114,
 6112,
 6140,
 6766,
 8209,
 8362,
 8814,
 9126,
 10548,
 10961,
 11179,
 12023,
 12658,
 12809,
 12985,
 14391,
 14521,
 15447,
 15540,
 16386,
 18484,
 18757,
 18882,
 19415,
 19810,
 19890,
 19891,
 21020,
 23916,
 25857,
 26503,
 26918]

In [16]:
tokenizer.decode(indexes)

'[PAD] ( ) -. B C E G P R S V ssaeiyesus groupiatepisttadareterce Actbalestesmyino Prohozoutebraeo unceae Op bacteriaazact organismsukaoboa cellular Neoumeosteriarda Baomi Met Terrakon Gramroteclassifiedcterryoleo'