In [57]:
%pip install torchmetrics

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [58]:
from amr.dataset import HybridGenomeDataset
import torch

kleb = "Klebsiella_pneumoniae_aztreonam"
staphy = "Staphylococcus_aureus_cefoxitin"

config = {
    "k" : 6,
    "pathogen" : kleb,
    "genes" : ["acrR","gyrA","gyrB","ompK35","ompK36","ompK37","parC","rpsL"],
    "root_dir" : "../data/ds1",
    "batch_size" : 32,
    "learning_rate" : 0.0001,
    "epochs" : 200
}

def get_device():
    device_string = "cpu"
    if torch.mps.is_available:
        device_string = "mps"
    elif torch.cuda.is_available():
        device_string = "cuda"

    return torch.device(device_string)

In [59]:
import torch
from torch.nn.utils.rnn import pad_sequence

def hybrid_collate(batch):
    images = []
    sequences = []
    genes = []
    labels = []
    
    for (img, seq,gene), label in batch:
        images.append(img)
        sequences.append(seq)
        labels.append(label)
        genes.append(gene)
    
    return (torch.stack(images), pad_sequence(sequences, batch_first=True),torch.tensor(genes)), torch.tensor(labels)

In [60]:
from torch.utils.data import DataLoader, random_split

def get_dataloader(dataset,batch_size,shuffle=False):
    return DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,collate_fn=hybrid_collate)

def get_train_val_dataloaders(dataset,batch_size=32,val_split=0.2):
    train_split_dataset, val_split_dataset = random_split(
        dataset, [1 - val_split, val_split]
    )
    train_loader = get_dataloader(train_split_dataset,batch_size,True)
    val_loader = get_dataloader(val_split_dataset,batch_size,False)
    return train_loader,val_loader

In [61]:
from torchmetrics.classification import ConfusionMatrix

def step(dataloader, device, model, criterion, optimizer,val=False):    
    model.eval() if eval else model.train()

    loss, correct, total = 0, 0, 0
    confmat = ConfusionMatrix(task="multiclass", num_classes=2).to(device)
    
    for inputs, labels in dataloader:
        images = inputs[0].to(device)
        sequences = inputs[1].to(device)
        labels = labels.to(device)

        if not val:
            optimizer.zero_grad()
        outputs = model((images, sequences))
        loss = criterion(outputs, labels)
        if not val:
            loss.backward()
            optimizer.step()

        loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        confmat.update(predicted, labels)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    epoch_loss = loss / total
    epoch_acc = correct / total
    matrix = confmat.compute()
    
    return epoch_loss,epoch_acc,matrix

In [62]:
from net.HybridGenomeNet import HybridGenomeNet
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def train(model,device,criterion,optimizer):
    train_dataset = HybridGenomeDataset(
        root_dir=config["root_dir"],
        train_or_test="train",
        pathogen=config["pathogen"],
        genes=config["genes"],
        k=config["k"]
    )
    train_loader, val_loader = get_train_val_dataloaders(dataset=train_dataset,batch_size=config["batch_size"],val_split=0.2)
    scheduler = CosineAnnealingWarmRestarts(optimizer, 
                                      T_0=30,  # Initial cycle length
                                      T_mult=1,  # Cycle length multiplier
                                      eta_min=1e-7)  # Minimum LR
    
    num_epochs = config["epochs"]
    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_acc,train_matrix = step(
            train_loader, device, model, criterion, optimizer
        )
        with torch.no_grad():
            epoch_val_loss, epoch_val_acc,val_matrix = step(
                val_loader, device, model, criterion,optimizer,True
            )
            print(val_matrix)
        scheduler.step(epoch_val_loss)

        print(
            f"Epoch {epoch+1}/{num_epochs} | "
            f"Train Loss: {epoch_train_loss:.4f} | "
            f"Train Acc: {epoch_train_acc:.4f} | "
            f"Val Loss: {epoch_val_loss:.4f} | "
            f"Val Acc: {epoch_val_acc:.4f}"
        )

    torch.save(model.state_dict(), "model_weights.pth")
    
    return model


In [63]:
def test(model,device,criterion,optimizer):
    
    for i in range(0,len(config["genes"])):
        gene = config["genes"][i]
        print("Gene: ",gene)
    
        test_dataset = HybridGenomeDataset(
            root_dir=config["root_dir"],
            train_or_test="test",
            pathogen=config["pathogen"],
            genes=[gene],
            k=config["k"]
        )
        test_loader = get_dataloader(test_dataset,config["batch_size"],False)
        with torch.no_grad():
            test_loss, test_acc,confusion_matrix = step(test_loader, device, model, criterion,optimizer,True)
            print(
                f"Test Results| " f"Test Loss: {test_loss:.4f} | " f"Test Acc: {test_acc:.4f}"
            )
            print(confusion_matrix)

In [None]:

device = get_device()
model = HybridGenomeNet().to(device)
class_weights = torch.tensor([10.0, 1.0]).to(device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

model = train(model,device,criterion,optimizer)
test(model,device,criterion,optimizer)

tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 1/200 | Train Loss: 0.0029 | Train Acc: 0.9132 | Val Loss: 0.0038 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 2/200 | Train Loss: 0.0035 | Train Acc: 0.9132 | Val Loss: 0.0040 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 3/200 | Train Loss: 0.0004 | Train Acc: 0.9132 | Val Loss: 0.0036 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 4/200 | Train Loss: 0.0036 | Train Acc: 0.9132 | Val Loss: 0.0036 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 5/200 | Train Loss: 0.0010 | Train Acc: 0.9132 | Val Loss: 0.0037 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 6/200 | Train Loss: 0.0022 | Train Acc: 0.9132 | Val Loss: 0.0036 | Val Acc: 0.9398
tensor([[  0,  13],
        [  0, 203]], device='mps:0')
Epoch 7/200 | Train Loss: 0.0009 | Train Acc: 0.9132 | Val Lo

In [None]:
device = get_device()
model = HybridGenomeNet().to(device)
model.load_state_dict(torch.load("model_weights.pth"))

test(model,device,criterion,optimizer)

Gene:  acrR
Test Results| Test Loss: 1.0792 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  gyrA
Test Results| Test Loss: 1.4586 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  gyrB
Test Results| Test Loss: 1.1402 | Test Acc: 0.7333
tensor([[ 1,  3],
        [ 1, 10]], device='mps:0')
Gene:  ompK35
Test Results| Test Loss: 0.9812 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  ompK36
Test Results| Test Loss: 1.7126 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  ompK37
Test Results| Test Loss: 1.1736 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  parC
Test Results| Test Loss: 1.4806 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
Gene:  rpsL
Test Results| Test Loss: 1.1750 | Test Acc: 0.7333
tensor([[ 0,  4],
        [ 0, 11]], device='mps:0')
