In [17]:
from amr.dataset import HybridGenomeDataset

# Candidate pathogen/antibiotic dataset names.
kleb = "Klebsiella_pneumoniae_aztreonam"
staphy = "Staphylococcus_aureus_cefoxitin"

# Experiment configuration.
k = 6  # forwarded to HybridGenomeDataset as `k` (presumably the k-mer size — confirm)
pathogen = staphy
genes = ["pbp4", "gyrA", "fusA", "dfrB", "rpoB"]

# Arguments common to both splits, kept in one place so train/test stay in sync.
_shared_dataset_kwargs = dict(root_dir="../data/ds1", pathogen=pathogen, k=k)

train_dataset = HybridGenomeDataset(
    train_or_test="train",
    genes=genes,
    **_shared_dataset_kwargs,
)

# NOTE(review): the test split is restricted to "pbp4" while training uses the
# full `genes` panel above — confirm this restriction is intentional.
test_dataset = HybridGenomeDataset(
    train_or_test="test",
    genes=["pbp4"],
    **_shared_dataset_kwargs,
)

print(train_dataset[0])

((tensor([[[ 0.2235,  0.5451, -0.2314,  ..., -1.0000, -1.0000, -1.0000],
         [ 0.2235,  0.7961,  0.2235,  ..., -0.2314, -1.0000, -1.0000],
         [-1.0000,  0.2235,  0.2235,  ..., -1.0000, -1.0000, -0.2314],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]]), tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]]), [1]), 0)


In [18]:
import torch
from torch.nn.utils.rnn import pad_sequence

def hybrid_collate(batch):
    """Collate ``((image, sequence, gene), label)`` samples into one batch.

    Args:
        batch: list of samples, each shaped ``((image, sequence, gene), label)``.

    Returns:
        ``((images, sequences, genes), labels)`` where
          - ``images`` is ``torch.stack`` of the per-sample images,
          - ``sequences`` is zero-padded to the longest sequence in the batch
            along each sequence's LAST axis (the length axis for channels-first
            ``(C, L)`` tensors), giving ``(B, C, L_max)``,
          - ``genes`` and ``labels`` are built with ``torch.tensor``.
    """
    images, sequences, genes, labels = [], [], [], []
    for (img, seq, gene), label in batch:
        images.append(img)
        sequences.append(seq)
        genes.append(gene)
        labels.append(label)

    # The printed sample shows sequences stored channels-first, (4, L).
    # pad_sequence pads dim 0 and requires every trailing dim to match, so calling
    # it directly raises on genomes of different length. Move the length axis to
    # the front, pad, then move it back. For 1-D sequences both transposes are
    # no-ops; for equal lengths the result equals torch.stack(sequences).
    padded_sequences = pad_sequence(
        [seq.transpose(0, -1) for seq in sequences], batch_first=True
    ).transpose(1, -1).contiguous()

    return (torch.stack(images), padded_sequences, torch.tensor(genes)), torch.tensor(labels)

In [19]:
from torch.utils.data import Dataset, DataLoader, random_split

def get_train_val_dataloaders(val_split=0.2, dataset=None, batch_size=32, seed=None, collate_fn=None):
    """Randomly split a dataset into train/validation subsets wrapped in DataLoaders.

    Args:
        val_split: fraction of samples assigned to the validation subset;
            must be strictly between 0 and 1.
        dataset: dataset to split; defaults to the notebook-level ``train_dataset``.
        batch_size: batch size used by both loaders.
        seed: if given, seeds the split so it is identical across calls and
            kernel restarts; ``None`` keeps the global RNG behavior.
        collate_fn: batch collation function; defaults to ``hybrid_collate``.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles;
        validation order does not matter for metrics aggregated over the subset.

    Raises:
        ValueError: if ``val_split`` is not in (0, 1) — an empty subset would
            otherwise surface later as a division by zero in the training loop.
    """
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be in (0, 1), got {val_split!r}")
    if dataset is None:
        dataset = train_dataset
    if collate_fn is None:
        collate_fn = hybrid_collate

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_split_dataset, val_split_dataset = random_split(
        dataset, [1 - val_split, val_split], **split_kwargs
    )

    train_loader = DataLoader(train_split_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_split_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return train_loader, val_loader

In [23]:
from net.HybridGenomeNet import HybridGenomeNet
from torch import nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


def _select_device():
    """Pick a torch device: Apple MPS if available, else CUDA, else CPU."""
    # `is_available` must be *called*; the bare function object is always truthy,
    # which previously forced "mps" even on machines without an MPS backend.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _run_hybrid_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader`: trains when `optimizer` is given, else evaluates.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(nan, nan)`` when
        the loader yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)  # model.train(False) is equivalent to model.eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(is_train):
        for (images, sequences, genes), labels in loader:
            images = images.to(device)
            sequences = sequences.to(device)
            genes = genes.to(device)
            labels = labels.to(device)

            outputs = model((images, sequences, genes))
            loss = criterion(outputs, labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # loss is a per-batch mean; re-weight by batch size before averaging.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total


def train_hybrid(num_epochs=500, lr=0.0001, class_weights=(2.0, 1.0)):
    """Train ``HybridGenomeNet`` on a random train/validation split of ``train_dataset``.

    Args:
        num_epochs: number of passes over the training split.
        lr: initial Adam learning rate.
        class_weights: per-class weights for ``nn.CrossEntropyLoss``
            (class 0 is weighted 2x by default).

    Returns:
        ``(model, history)``: the trained model and a dict of per-epoch
        ``train_loss`` / ``val_loss`` / ``train_acc`` / ``val_acc`` lists.
    """
    train_loader, val_loader = get_train_val_dataloaders()
    device = _select_device()

    model = HybridGenomeNet().to(device)
    criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32, device=device)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,        # initial cycle length (in scheduler steps = epochs here)
        T_mult=2,      # each restart doubles the cycle length
        eta_min=1e-6,  # LR floor
    )

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': []
    }

    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_acc = _run_hybrid_epoch(
            model, train_loader, criterion, device, optimizer
        )
        epoch_val_loss, epoch_val_acc = _run_hybrid_epoch(
            model, val_loader, criterion, device
        )

        # CosineAnnealingWarmRestarts.step(x) interprets x as the (fractional)
        # epoch index, not a metric. Passing the validation loss (~0.69) reset
        # the schedule to epoch ~0.69 every epoch, so the LR never annealed.
        scheduler.step()

        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)

        print(f'Epoch {epoch+1}/{num_epochs} | '
              f'Train Loss: {epoch_train_loss:.4f} | '
              f'Train Acc: {epoch_train_acc:.4f} | '
              f'Val Loss: {epoch_val_loss:.4f} | '
              f'Val Acc: {epoch_val_acc:.4f}')

    return model, history


# Assign the result so the notebook does not render the model/history repr.
model, history = train_hybrid()

Epoch 1/500 | Train Loss: 0.7544 | Train Acc: 0.5537 | Val Loss: 0.6947 | Val Acc: 0.6963
Epoch 2/500 | Train Loss: 0.7097 | Train Acc: 0.5722 | Val Loss: 0.6919 | Val Acc: 0.6963
Epoch 3/500 | Train Loss: 0.7070 | Train Acc: 0.6333 | Val Loss: 0.6947 | Val Acc: 0.6963
Epoch 4/500 | Train Loss: 0.7066 | Train Acc: 0.6111 | Val Loss: 0.6901 | Val Acc: 0.6963
Epoch 5/500 | Train Loss: 0.7190 | Train Acc: 0.5389 | Val Loss: 0.6904 | Val Acc: 0.6815
Epoch 6/500 | Train Loss: 0.6843 | Train Acc: 0.6352 | Val Loss: 0.6854 | Val Acc: 0.6963
Epoch 7/500 | Train Loss: 0.6923 | Train Acc: 0.6352 | Val Loss: 0.6896 | Val Acc: 0.6963
Epoch 8/500 | Train Loss: 0.6812 | Train Acc: 0.6556 | Val Loss: 0.6878 | Val Acc: 0.6963
Epoch 9/500 | Train Loss: 0.6892 | Train Acc: 0.6463 | Val Loss: 0.6939 | Val Acc: 0.6963
Epoch 10/500 | Train Loss: 0.6779 | Train Acc: 0.6833 | Val Loss: 0.6956 | Val Acc: 0.6963
Epoch 11/500 | Train Loss: 0.6920 | Train Acc: 0.6574 | Val Loss: 0.6871 | Val Acc: 0.7037
Epoch 12