In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

sys.path.insert(0, '../dlp')
from data_process import CNN_prepare_batch

# Prefer the second GPU (the original target) when it exists; otherwise fall back to
# the first GPU and then to CPU, so the notebook also runs on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# Reproducibility: torch.manual_seed seeds the CPU and every CUDA device.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# --- Run configuration --------------------------------------------------------
# One "epoch" in this notebook is one training batch (see the training loop).
epochs = 100_000        # total training iterations
val_epoch = 5000        # validate / checkpoint / log every `val_epoch` iterations
num_val = 500           # number of validation batches cached from `da`
batch_size = 64
dataset_name = "corpus_200_500_random"
virus_dataset_name = "corpus_200_500_Viruses_random"
non_virus_dataset_name = "corpus_200_500_Non_Viruses_random"
lr = 0.001
model_name = "Pure CNN"
max_seq_len = 500

# Dataset root; set the TAXSEQ_DATA_ROOT environment variable instead of editing
# this machine-specific default.
data_root = os.environ.get("TAXSEQ_DATA_ROOT", "/home/aac/Alireza/datasets/export_pqt_4_taxseq")

from data_access import PQDataAccess
virus_da = PQDataAccess(f"{data_root}/{virus_dataset_name}", batch_size)
non_virus_da = PQDataAccess(f"{data_root}/{non_virus_dataset_name}", batch_size)
da = PQDataAccess(f"{data_root}/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)  # atomic "create if missing"
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        # NOTE(review): label says one-hot, but the model uses learned embeddings.
        "architecture": "Onehot_CNN",
        "dataset": dataset_name,
        "epochs": epochs,
        "batch_size": batch_size,  # key was misspelled "batch_szie"
        "max_seq_len": max_seq_len,
        "seed": seed,
    }
)

  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
cuda:1
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/Pure CNN_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malirezanor-310-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
class EnhancedProteinCNN(nn.Module):
    """Multi-kernel 1-D CNN classifier over per-residue and per-sequence features.

    Each residue carries an amino-acid index (channel 0 of ``x``) and three numeric
    features (channels 1-3 of ``x``). The two are embedded, summed and
    layer-normalised, then passed through parallel 'same'-padded convolutions with
    different kernel sizes. Each branch is max-pooled over the sequence. The pooled
    vectors are concatenated with an embedding of the 28 per-sequence global
    features and fed to a 3-layer MLP that returns unnormalised class logits.

    Args:
        num_classes: number of output logits.
        vocab_size: size of the amino-acid index vocabulary; index 0 is padding.
        embedding_dim: width of the residue and global-feature embeddings.
        max_seq_length: accepted for API compatibility; not used by any layer.
        num_filters: output channels of each convolution branch.
        kernel_sizes: one convolution branch is built per kernel size.
        dropout_rate: dropout probability applied before each fully connected layer.
    """

    def __init__(self,
                 num_classes=4,
                 vocab_size=25,
                 embedding_dim=128,
                 max_seq_length=max_seq_len,
                 num_filters=256,
                 kernel_sizes=(3, 5, 7, 9, 11),
                 dropout_rate=0.5):
        super().__init__()
        # Immutable default above; accept any iterable of ints from callers.
        kernel_sizes = tuple(kernel_sizes)

        # NOTE: modules are registered in the same order as before, so parameter
        # order (and therefore saved optimizer state) stays compatible.

        # Amino-acid index embedding; padding_idx=0 keeps index 0 at a zero vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Projects the 3 per-residue numeric features (hydrophobicity, volume, polarity).
        self.feature_dense = nn.Linear(3, embedding_dim)

        # Projects the 28 per-sequence global features.
        self.global_feature_dense = nn.Linear(28, embedding_dim)

        # One 'same'-padded convolution branch (followed by BatchNorm) per kernel size.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embedding_dim,
                      out_channels=num_filters,
                      kernel_size=k,
                      padding='same')
            for k in kernel_sizes
        ])

        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(num_filters)
            for _ in kernel_sizes
        ])

        # Pooled features from every conv branch plus the global-feature embedding.
        total_filters = num_filters * len(kernel_sizes) + embedding_dim

        # Fully connected head.
        self.fc1 = nn.Linear(total_filters, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x, global_features, attention_mask=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: tensor (batch, seq_len, >=4). Channel 0 holds amino-acid indices
                (cast to long); channels 1-3 hold the per-residue numeric features.
            global_features: tensor (batch, 28) of per-sequence features.
            attention_mask: optional (batch, seq_len) tensor that multiplies the
                sequence; presumably 1 for real residues and 0 for padding
                (TODO confirm against CNN_prepare_batch).
        """
        # (batch, seq_len, embedding_dim) each.
        seq_embeddings = self.embedding(x[:, :, 0].long())
        feature_embeddings = self.feature_dense(x[:, :, 1:4])

        # Combine both residue representations and normalise per position.
        h = self.layer_norm(seq_embeddings + feature_embeddings)

        global_embedding = self.global_feature_dense(global_features)

        mask = None
        if attention_mask is not None:
            # (batch, 1, seq_len) so it broadcasts over channels after the transpose.
            mask = attention_mask.unsqueeze(1).to(h.dtype)

        # Conv1d expects channels first: (batch, embedding_dim, seq_len).
        h = h.transpose(1, 2)
        if mask is not None:
            h = h * mask

        conv_outputs = []
        for conv, bn in zip(self.convs, self.batch_norms):
            conv_out = F.relu(bn(conv(h)))
            if mask is not None:
                # Conv bias and the BatchNorm shift make padded positions non-zero,
                # so they could win the max-pool. After ReLU all values are >= 0,
                # so zeroing padded positions removes them from the max.
                conv_out = conv_out * mask
            pooled = F.adaptive_max_pool1d(conv_out, 1).squeeze(-1)  # (batch, num_filters)
            conv_outputs.append(pooled)

        # Concatenate all branch outputs with the global-feature embedding.
        h = torch.cat(conv_outputs + [global_embedding], dim=1)

        # MLP head with dropout before every linear layer.
        h = self.dropout(h)
        h = F.relu(self.fc1(h))
        h = self.dropout(h)
        h = F.relu(self.fc2(h))
        h = self.dropout(h)
        return self.fc3(h)

In [3]:
# Build the classifier on the selected device and report its size.
model = EnhancedProteinCNN().to(device)
n_params = sum(param.numel() for param in model.parameters())
print("model:", n_params / 1e6, 'M parameters')
print(model)

# Multi-class objective over raw logits (the model's forward ends in a plain Linear).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Cosine annealing with warm restarts: the first cycle spans T_0=10 scheduler
# steps, each following cycle is T_mult=2 times longer, and the learning rate
# never drops below eta_min=1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

model: 2.012164 M parameters
EnhancedProteinCNN(
  (embedding): Embedding(25, 128, padding_idx=0)
  (feature_dense): Linear(in_features=3, out_features=128, bias=True)
  (global_feature_dense): Linear(in_features=28, out_features=128, bias=True)
  (convs): ModuleList(
    (0): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=same)
    (1): Conv1d(128, 256, kernel_size=(5,), stride=(1,), padding=same)
    (2): Conv1d(128, 256, kernel_size=(7,), stride=(1,), padding=same)
    (3): Conv1d(128, 256, kernel_size=(9,), stride=(1,), padding=same)
    (4): Conv1d(128, 256, kernel_size=(11,), stride=(1,), padding=same)
  )
  (batch_norms): ModuleList(
    (0-4): 5 x BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=1408, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=4, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (layer_norm)

In [4]:
# Draw the validation batches once so every evaluation scores the same data.
val_batches = [da.get_batch() for _ in range(num_val)]


def evaluate(model):
    """Score ``model`` on the cached validation batches in ``val_batches``.

    Runs every batch in eval mode under ``torch.no_grad()``. Afterwards it restores
    whatever train/eval mode the model was in before the call.

    Returns:
        (avg_loss, accuracy, f1_micro, f1_macro, conf_matrix):
        - avg_loss: mean per-batch ``criterion`` loss.
        - accuracy, f1_micro, f1_macro: scores for single-label, multi-class
          predictions (argmax of the logits).
        - conf_matrix: 4x4 sklearn confusion matrix over labels 0-3
          (rows = true class, columns = predicted class).

    Raises:
        ValueError: if ``val_batches`` is empty.
    """
    if not val_batches:
        raise ValueError("val_batches is empty; nothing to evaluate")

    was_training = model.training
    model.eval()

    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():  # no autograd graph needed for evaluation
        for raw_batch in val_batches:
            tensor_batch = CNN_prepare_batch(raw_batch, max_seq_len_=max_seq_len)
            tensor_batch.gpu(device)

            labels = tensor_batch.taxes["begining root"]
            outputs = model(
                tensor_batch.seq_ids["batch_encoding"],
                tensor_batch.seq_ids["batch_global_features"],
                tensor_batch.seq_ids["batch_maks"],
            )

            running_loss += criterion(outputs, labels).item()
            all_preds.append(torch.argmax(outputs, dim=1).cpu())
            all_labels.append(labels.cpu())

    # Leave the model in the mode the caller had it in.
    model.train(was_training)

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    # Single-label multi-class F1: macro = unweighted mean over classes,
    # micro = pooled over all samples (equal to accuracy in this setting).
    f1_macro = f1_score(y_true, y_pred, average='macro')
    f1_micro = f1_score(y_true, y_pred, average='micro')
    conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
    avg_loss = running_loss / len(val_batches)

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [5]:
import glob
import re


def load_checkpoint(model, optimizer=None, scheduler=None, ckpt_dir=None, map_location=None):
    """Restore training state from the newest ``checkpoint_epoch_<N>.pt`` file.

    Args:
        model: module whose weights are overwritten from ``model_state_dict``.
        optimizer: optional optimizer, restored from ``optimizer_state_dict`` when given.
        scheduler: optional LR scheduler, restored from ``scheduler_state_dict`` when given.
        ckpt_dir: directory to search; defaults to the notebook-level ``checkpoint_dir``.
        map_location: device to load tensors onto and move the model to; defaults
            to the notebook-level ``device``.

    Returns:
        ``(model, optimizer, scheduler, epoch, metrics)``, where ``epoch`` and
        ``metrics`` are read from the checkpoint.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` has no ``checkpoint_epoch_<N>.pt`` file.
    """
    if ckpt_dir is None:
        ckpt_dir = checkpoint_dir
    if map_location is None:
        map_location = device

    # Only names of the form checkpoint_epoch_<digits>.pt count. The parsed number
    # (not lexicographic order) decides which file is newest. glob.escape guards
    # against glob metacharacters in the directory name.
    numbered = []
    for path in glob.glob(os.path.join(glob.escape(ckpt_dir), 'checkpoint_epoch_*.pt')):
        match = re.fullmatch(r'checkpoint_epoch_(\d+)\.pt', os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))
    if not numbered:
        raise FileNotFoundError(f"No checkpoint_epoch_<N>.pt files found in {ckpt_dir!r}")
    _, latest_checkpoint = max(numbered)

    # weights_only=False: the saved 'metrics' dict holds numpy confusion matrices,
    # which the weights-only unpickler (default since PyTorch 2.6) rejects. This
    # unpickles arbitrary objects, so only load checkpoints you produced yourself.
    checkpoint = torch.load(latest_checkpoint, map_location=map_location, weights_only=False)

    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(map_location)

    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Make sure per-parameter buffers (e.g. Adam moments) live on the target device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(map_location)

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']

    print(f"Successfully loaded checkpoint from epoch {epoch}")
    print("Metrics at checkpoint:", metrics)

    return model, optimizer, scheduler, epoch, metrics


# Set resume=True to continue from the newest checkpoint in checkpoint_dir.
resume = False
if resume:
    model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
else:
    latest_epoch = 0

In [6]:
def get_partition_ratio(epoch, decay_epochs=50000):
    """Step-wise schedule that lowers a fraction from 8/16 down to 1/16.

    The fraction starts at 8/16 and drops by 1/16 every ``decay_epochs // 7``
    iterations. After 7 drops it reaches 1/16 (by iteration ``decay_epochs`` when
    ``decay_epochs >= 7``) and stays there.

    Args:
        epoch: current 0-based iteration; negative values are treated as 0.
        decay_epochs: length of the decay phase. Values below 7 (including 0)
            make the fraction drop once per iteration instead of raising
            ZeroDivisionError.

    Returns:
        float, one of 8/16, 7/16, ..., 1/16.
    """
    num_steps = 7  # 8/16 -> 1/16 in 1/16 decrements
    epochs_per_step = max(1, decay_epochs // num_steps)

    step = min(max(epoch, 0) // epochs_per_step, num_steps)

    return (8 - step) / 16

In [None]:
# Accumulators for the current validation window.
running_loss = 0
train_preds = []
train_labels = []
current_lr = lr

try:
    # `latest_epoch` is the next iteration to run (0 for a fresh run), and training
    # stops after `epochs` iterations in total, matching the "[x/epochs]" display.
    for epoch in tqdm(range(latest_epoch, epochs)):
        model.train()

        current_partition = get_partition_ratio(epoch)

        tensor_batch = CNN_prepare_batch(
            virus_da.get_batch(),
            non_virus_da.get_batch(),
            max_seq_len,
            partition=current_partition
        )
        tensor_batch.gpu(device)

        labels = tensor_batch.taxes["begining root"]
        outputs = model(
            tensor_batch.seq_ids["batch_encoding"],
            tensor_batch.seq_ids["batch_global_features"],
            tensor_batch.seq_ids["batch_maks"],
        )

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        preds = torch.argmax(outputs, dim=1)
        train_preds.append(preds.cpu())
        train_labels.append(labels.cpu())

        if (epoch + 1) % val_epoch == 0:
            # Training metrics over the window just finished.
            all_train_preds = torch.cat(train_preds).numpy()
            all_train_labels = torch.cat(train_labels).numpy()

            train_accuracy = accuracy_score(all_train_labels, all_train_preds)
            train_f1_micro = f1_score(all_train_labels, all_train_preds, average='micro')
            train_f1_macro = f1_score(all_train_labels, all_train_preds, average='macro')
            train_cm = confusion_matrix(all_train_labels, all_train_preds, labels=[0, 1, 2, 3])
            # Divide by the batches actually seen (a resumed first window may be shorter).
            train_loss = running_loss / len(train_preds)

            print(f"Epoch [{epoch + 1}/{epochs}]")
            print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
            print(f"Train F1 (micro): {train_f1_micro:.4f}, Train F1 (macro): {train_f1_macro:.4f}")
            print("Train Confusion Matrix:")
            print(train_cm)

            # Evaluate on validation set
            val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)

            print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
            print(f"Val F1 (micro): {val_f1_micro:.4f}, Val F1 (macro): {val_f1_macro:.4f}")
            print("Val Confusion Matrix:")
            print(val_cm)

            # LR that was in effect during this window.
            current_lr = scheduler.get_last_lr()[0]
            # Exactly one scheduler step per validation window, so the scheduler's
            # T_0 / T_mult are measured in windows of `val_epoch` iterations.
            # (Previously it was stepped twice per window with `epoch + loss`
            # passed as a fractional epoch.) Step before saving so a resumed run
            # continues with the next LR.
            scheduler.step()

            # Create metrics dictionary for saving
            metrics = {
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "train_f1_micro": train_f1_micro,
                "train_f1_macro": train_f1_macro,
                "train_confusion_matrix": train_cm,
                "val_loss": val_loss,
                "val_accuracy": val_accuracy,
                "val_f1_micro": val_f1_micro,
                "val_f1_macro": val_f1_macro,
                "val_confusion_matrix": val_cm,
                "epoch": epoch + 1,
                "current_portion": current_partition,
                "lr": current_lr
            }

            # Save 'epoch' as the next iteration to run (epoch + 1), so resuming via
            # range(latest_epoch, ...) does not repeat the last completed iteration.
            next_epoch = epoch + 1
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{next_epoch}.pt')
            torch.save({
                'epoch': next_epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'metrics': metrics
            }, checkpoint_path)

            # Log to wandb
            wandb.log(metrics)

            # Reset training metrics
            running_loss = 0
            train_preds = []
            train_labels = []
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

  5%|▍         | 4999/100000 [13:13<3:34:43,  7.37it/s]

Epoch [5000/100000]
Train Loss: 0.7249, Train Accuracy: 0.5000
Train F1 (micro): 0.5000, Train F1 (macro): 0.1670
Train Confusion Matrix:
[[     0      0      0      0]
 [     7 159858    132      3]
 [     2 158436    130      6]
 [     0   1423      3      0]]


  5%|▌         | 5000/100000 [14:32<628:56:58, 23.83s/it]

Val Loss: 0.7211, Val Accuracy: 0.0080
Val F1 (micro): 0.0080, Val F1 (macro): 0.0053
Val Confusion Matrix:
[[    0     0     0     0]
 [    0   256     0     0]
 [    0 31660     0     0]
 [    0    84     0     0]]


  8%|▊         | 7806/100000 [21:59<4:09:06,  6.17it/s]  