In [16]:
%pip install torchmetrics

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [17]:
from amr.dataset import HybridGenomeDataset

kleb = "Klebsiella_pneumoniae_aztreonam"
staphy = "Staphylococcus_aureus_cefoxitin"

k=6
pathogen = staphy
genes=["pbp4"]

train_dataset = HybridGenomeDataset(
    root_dir="../data/ds1",
    train_or_test="train",
    pathogen=pathogen,
    genes=genes,
    k=k
)

test_dataset = HybridGenomeDataset(
    root_dir="../data/ds1",
    train_or_test="test",
    k=k,
    pathogen=pathogen,
    genes=genes
)

print(train_dataset[0])

((tensor([[[ 0.2235,  0.5451, -0.2314,  ..., -1.0000, -1.0000, -1.0000],
         [ 0.2235,  0.7961,  0.2235,  ..., -0.2314, -1.0000, -1.0000],
         [-1.0000,  0.2235,  0.2235,  ..., -1.0000, -1.0000, -0.2314],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]]), tensor([[0., 0., 1.,  ..., 1., 1., 1.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]]), [1]), 0)


In [18]:
import torch
from torch.nn.utils.rnn import pad_sequence

def hybrid_collate(batch):
    images = []
    sequences = []
    genes = []
    labels = []
    
    for (img, seq,gene), label in batch:
        images.append(img)
        sequences.append(seq)
        labels.append(label)
        genes.append(gene)
    
    return (torch.stack(images), pad_sequence(sequences, batch_first=True),torch.tensor(genes)), torch.tensor(labels)

In [19]:
from torch.utils.data import Dataset, DataLoader, random_split

def get_train_val_dataloaders(val_split=0.2):
    train_split_dataset, val_split_dataset = random_split(
        train_dataset, [1 - val_split, val_split]
    )
    train_loader = DataLoader(train_split_dataset, batch_size=32, shuffle=True, collate_fn=hybrid_collate)
    val_loader = DataLoader(val_split_dataset, batch_size=32, shuffle=False, collate_fn=hybrid_collate)
    return train_loader,val_loader

def get_test_dataloader():
    test_loader = DataLoader(test_dataset,batch_size=32,shuffle=False,collate_fn=hybrid_collate)
    return test_loader

In [None]:
from net.HybridGenomeNet import HybridGenomeNet
from torch import nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchmetrics.classification import ConfusionMatrix
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau


def train_step(dataloader, device, model, criterion, optimizer):
    model.train()
    train_loss, train_correct, train_total = 0, 0, 0
    for inputs, labels in dataloader:
        images = inputs[0].to(device)
        sequences = inputs[1].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model((images, sequences))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    epoch_train_loss = train_loss / train_total
    epoch_train_acc = train_correct / train_total
    return epoch_train_loss, epoch_train_acc


def evaluate_step(dataloader, device, model, criterion):
    val_loss, val_correct, val_total = 0, 0, 0
    model.eval()

    with torch.no_grad():
        for inputs, labels in dataloader:
            images = inputs[0].to(device)
            sequences = inputs[1].to(device)
            labels = labels.to(device)

            outputs = model((images, sequences))
            loss = criterion(outputs, labels)

            # Statistics
            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            # Update confusion matrix
            # confmat.update(predicted, labels)
    epoch_val_loss = val_loss / val_total
    epoch_val_acc = val_correct / val_total

    return epoch_val_loss, epoch_val_acc


def train_hybrid():
    train_loader, val_loader = get_train_val_dataloaders()

    device_string = "cpu"
    if torch.mps.is_available:
        device_string = "mps"
    elif torch.cuda.is_available():
        device_string = "cuda"

    device = torch.device(device_string)
    model = HybridGenomeNet().to(device)
    class_weights = torch.tensor([2.0, 1.0]).to(device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = CosineAnnealingWarmRestarts(optimizer, 
                                      T_0=30,  # Initial cycle length
                                      T_mult=1,  # Cycle length multiplier
                                      eta_min=1e-7)  # Minimum LR
    

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    num_epochs = 200

    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_acc = train_step(
            train_loader, device, model, criterion, optimizer
        )
        epoch_val_loss, epoch_val_acc = evaluate_step(
            val_loader, device, model, criterion
        )

        scheduler.step(epoch_val_loss)

        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)
        history["train_acc"].append(epoch_train_acc)
        history["val_acc"].append(epoch_val_acc)

        print(
            f"Epoch {epoch+1}/{num_epochs} | "
            f"Train Loss: {epoch_train_loss:.4f} | "
            f"Train Acc: {epoch_train_acc:.4f} | "
            f"Val Loss: {epoch_val_loss:.4f} | "
            f"Val Acc: {epoch_val_acc:.4f}"
        )

    test_loader = get_test_dataloader()
    test_loss, test_acc = evaluate_step(test_loader, device, model, criterion)
    print(
        f"Test Results| " f"Test Loss: {test_loss:.4f} | " f"Test Acc: {test_acc:.4f}"
    )

    return model


model = train_hybrid()

Epoch 1/200 | Train Loss: 5080.6347 | Train Acc: 0.5093 | Val Loss: 4253.8164 | Val Acc: 0.7778
Epoch 2/200 | Train Loss: 994.9496 | Train Acc: 0.7130 | Val Loss: 958.4403 | Val Acc: 0.7778
Epoch 3/200 | Train Loss: 495.5942 | Train Acc: 0.6574 | Val Loss: 646.2391 | Val Acc: 0.7778
Epoch 4/200 | Train Loss: 314.0625 | Train Acc: 0.5000 | Val Loss: 59.2643 | Val Acc: 0.2222
Epoch 5/200 | Train Loss: 66.6621 | Train Acc: 0.4722 | Val Loss: 46.6154 | Val Acc: 0.2222
Epoch 6/200 | Train Loss: 26.7841 | Train Acc: 0.5463 | Val Loss: 0.6738 | Val Acc: 0.7778
Epoch 7/200 | Train Loss: 0.8118 | Train Acc: 0.6944 | Val Loss: 0.6461 | Val Acc: 0.7778
Epoch 8/200 | Train Loss: 8.0559 | Train Acc: 0.6667 | Val Loss: 0.6555 | Val Acc: 0.7778
Epoch 9/200 | Train Loss: 0.7285 | Train Acc: 0.7130 | Val Loss: 0.6555 | Val Acc: 0.7778
Epoch 10/200 | Train Loss: 5.1860 | Train Acc: 0.6296 | Val Loss: 0.6557 | Val Acc: 0.7778
Epoch 11/200 | Train Loss: 0.7214 | Train Acc: 0.7037 | Val Loss: 0.6571 | Val 