In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool
import pickle
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.cuda.amp import autocast, GradScaler

Run this notebook only after the scraping code has finished and the data has been downloaded.

In [None]:
# Root folder produced by the scraping/download step; graph tensors live in graph_data/.
BASE_DIR = "amr_dataset_500"
GRAPH_DATA_DIR = os.path.join(BASE_DIR, "graph_data")

# Seed the RNGs used for weight initialisation and DataLoader shuffling so that a
# Restart & Run All reproduces the same run (CUDA kernels may still introduce
# some non-determinism).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices


In [None]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")

# Drop any cached allocator blocks left over from an earlier run in this kernel.
if _cuda_ok:
    torch.cuda.empty_cache()


Load Data

In [None]:
def _graph_record_to_data(rec):
    """Convert one pickled graph dict into a PyG ``Data`` object.

    ``x`` and ``edge_index`` become int64 tensors. ``y``, ``card`` and
    ``genome_feat`` become float32 tensors with a leading axis of size 1, which
    is meant to give one row per graph after batching (assumes each is a 1-D
    vector in the pickle; confirm against the scraper).
    """
    return Data(
        x=torch.tensor(rec["x"], dtype=torch.long),
        edge_index=torch.tensor(rec["edge_index"], dtype=torch.long),
        y=torch.tensor(rec["y"], dtype=torch.float32).unsqueeze(0),
        card=torch.tensor(rec["card"], dtype=torch.float32).unsqueeze(0),
        # Records without "genome_feat" fall back to a length-3 zero vector.
        genome_feat=torch.tensor(rec.get("genome_feat", np.zeros(3)), dtype=torch.float32).unsqueeze(0),
    )


def _load_json(filename):
    """Read and parse one JSON file from GRAPH_DATA_DIR."""
    with open(os.path.join(GRAPH_DATA_DIR, filename), "r") as fh:
        return json.load(fh)


# NOTE: pickle.load can execute arbitrary code -- only load a graphs.pkl you generated yourself.
with open(os.path.join(GRAPH_DATA_DIR, "graphs.pkl"), "rb") as fh:
    graphs_raw = pickle.load(fh)

graphs = [_graph_record_to_data(rec) for rec in graphs_raw]

genome_ids = _load_json("genome_ids.json")
antibiotics = _load_json("antibiotics.json")
kmer_vocab = _load_json("kmer_vocab.json")
card_genes = _load_json("card_genes.json")["genes"]

print(f"Loaded {len(graphs)} samples")

Class Weights

In [None]:
# Per-label positive-class weight for BCEWithLogitsLoss: the negative/positive
# count ratio, clamped to [0.1, 10.0]. The 1e-6 avoids division by zero when a
# label has no positive samples.
# NOTE(review): these counts include the graphs later used for validation and
# test; consider computing them from the training split only.
labels_all = torch.stack([g.y.squeeze(0) for g in graphs])
n_pos = labels_all.sum(dim=0)
n_neg = len(labels_all) - n_pos
pos_weights = (n_neg / (n_pos + 1e-6)).clamp(min=0.1, max=10.0).to(dtype=torch.float32, device=device)


Split

In [None]:
# 70/15/15 split. The order of graphs.pkl comes from the upstream scraping step
# and may be grouped (e.g. by organism or label) -- not verified here -- so
# shuffle with a private, seeded generator before slicing. The private
# generator keeps the split fixed regardless of other RNG use in the notebook.
TRAIN_FRAC, VAL_FRAC = 0.7, 0.15
SPLIT_SEED = 42
BATCH_SIZE = 4  # small batches to limit GPU memory use

n_graphs = len(graphs)
train_size = int(TRAIN_FRAC * n_graphs)
val_size = int(VAL_FRAC * n_graphs)

split_order = torch.randperm(n_graphs, generator=torch.Generator().manual_seed(SPLIT_SEED)).tolist()
train_dataset = [graphs[i] for i in split_order[:train_size]]
val_dataset = [graphs[i] for i in split_order[train_size:train_size + val_size]]
test_dataset = [graphs[i] for i in split_order[train_size + val_size:]]

# eval_epoch cannot handle an empty loader; fail here with a clear message instead.
assert train_dataset and val_dataset and test_dataset, "a split is empty; the dataset is too small"

print(f"\nDataset split: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


Model

In [None]:
class MemoryEfficientGAT(nn.Module):
    """Multi-label classifier combining a k-mer graph encoder with CARD and genome features.

    Node k-mer ids are embedded and passed through two GATConv layers. The node
    embeddings are pooled per graph with both mean and max pooling. The pooled
    vector is concatenated with small MLP embeddings of the per-graph CARD
    vector and genome-level features, then mapped to one logit per class.

    Args:
        num_kmers: size of the k-mer vocabulary (number of embedding rows).
        kmer_emb_dim: width of the k-mer embedding.
        gat_hidden: total output width of the first GAT layer (summed over
            heads) and output width of the second; must be divisible by
            ``gat_heads``.
        gat_heads: number of attention heads in the first GAT layer.
        card_feat_dim: length of the per-graph CARD feature vector. NOTE: the
            default ``len(card_genes)`` is evaluated when this cell runs.
        genome_feat_dim: length of the per-graph genome feature vector.
        num_classes: number of output logits. NOTE: the default
            ``len(antibiotics)`` is evaluated when this cell runs.

    Raises:
        ValueError: if ``gat_hidden`` is not divisible by ``gat_heads``.
    """

    def __init__(self, num_kmers, kmer_emb_dim=64, gat_hidden=128, gat_heads=4,
                 card_feat_dim=len(card_genes), genome_feat_dim=3, num_classes=len(antibiotics)):
        super().__init__()

        # gat1 emits gat_heads * (gat_hidden // gat_heads) channels while gat2 expects
        # exactly gat_hidden; a remainder would otherwise only surface as a shape
        # error at the first forward pass.
        if gat_hidden % gat_heads != 0:
            raise ValueError(
                f"gat_hidden ({gat_hidden}) must be divisible by gat_heads ({gat_heads})"
            )

        self.kmer_emb = nn.Embedding(num_kmers, kmer_emb_dim)

        # Layer 1: gat_heads heads of width gat_hidden // gat_heads, concatenated.
        self.gat1 = GATConv(kmer_emb_dim, gat_hidden // gat_heads, heads=gat_heads, dropout=0.3)
        # Layer 2: a single head with output width gat_hidden.
        self.gat2 = GATConv(gat_hidden, gat_hidden, heads=1, concat=False, dropout=0.3)

        self.card_mlp = nn.Sequential(
            nn.Linear(card_feat_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        self.genome_mlp = nn.Sequential(
            nn.Linear(genome_feat_dim, 16),
            nn.ReLU()
        )

        # mean-pool + max-pool of node embeddings (2 * gat_hidden) + CARD (32) + genome (16)
        combined_dim = gat_hidden * 2 + 32 + 16

        self.final_mlp = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, edge_index, batch, card, genome_feat):
        """Return one unnormalised logit per class for every graph in the batch.

        Args:
            x: integer k-mer ids per node. NOTE(review): the ``squeeze(1)`` after
                the embedding suggests x is (num_nodes, 1); confirm against graphs.pkl.
            edge_index: graph connectivity, fed to both GAT layers.
            batch: node-to-graph assignment used for pooling.
            card: per-graph CARD features of width ``card_feat_dim``.
            genome_feat: per-graph genome features of width ``genome_feat_dim``.

        Returns:
            Tensor of shape (num_graphs, num_classes).
        """
        x = self.kmer_emb(x).squeeze(1)

        x = F.elu(self.gat1(x, edge_index))
        x = F.elu(self.gat2(x, edge_index))

        # Graph-level readout: concatenate mean and max over each graph's nodes.
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x_graph = torch.cat([x_mean, x_max], dim=1)

        card_feat = self.card_mlp(card)
        genome_embed = self.genome_mlp(genome_feat)

        combined = torch.cat([x_graph, card_feat, genome_embed], dim=1)
        return self.final_mlp(combined)

# Pass the data-dependent widths explicitly. The class defaults are bound when the
# class cell runs, so they can go stale if the data is reloaded afterwards.
model = MemoryEfficientGAT(
    num_kmers=len(kmer_vocab["kmer2idx"]),
    card_feat_dim=len(card_genes),
    num_classes=len(antibiotics),
).to(device)
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-3)
# The scheduler is stepped once per epoch. T_max should equal EPOCHS in the
# training cell (30) so the cosine actually reaches eta_min at the end of
# training. With T_max=50 and 30 epochs, the LR would stop at ~35% of its peak.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)
# pos_weight re-balances each label's positive term (see the class-weights cell).
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

Mixed-precision training: gradient scaler, metrics, train/eval loops, and the training run

In [None]:
# Loss scaling only matters for CUDA float16 autocast. Enabling it only when CUDA is
# present makes the CPU fallback explicit and avoids GradScaler's "CUDA not
# available, disabling" warning; on CPU, scale/unscale_/step/update become no-ops.
scaler = GradScaler(enabled=torch.cuda.is_available())

def compute_metrics(output, target):
    """Return (element-wise accuracy, macro ROC-AUC) for multi-label logits.

    Args:
        output: raw logits with one column per label; sigmoid is applied here.
        target: 0/1 labels with the same shape as ``output``.

    Returns:
        acc: fraction of (sample, label) entries where ``sigmoid(output) > 0.5``
            agrees with ``target``.
        auc: unweighted mean of the per-label ROC-AUC over the labels whose
            target column contains both classes; 0.0 if no label qualifies.
    """
    # Under autocast the logits may be float16; cast to float32 so we do not rely
    # on half-precision CPU kernels for sigmoid / NumPy conversion.
    probs = torch.sigmoid(output.detach().float()).cpu().numpy()
    target_np = target.detach().float().cpu().numpy()
    preds = (probs > 0.5).astype(float)
    acc = (preds == target_np).mean()

    if target_np.ndim == 1:  # single-label input: treat it as one column
        target_np, probs = target_np[:, None], probs[:, None]

    # roc_auc_score(average="macro") raises if ANY label column is single-class,
    # which is likely on a small validation split. The old bare `except` then
    # reported 0.0 for the whole epoch. Score only the labels where AUC is defined.
    per_label_auc = []
    for j in range(target_np.shape[1]):
        y_col = target_np[:, j]
        if np.unique(y_col).size < 2:
            continue
        try:
            per_label_auc.append(roc_auc_score(y_col, probs[:, j]))
        except ValueError:
            # e.g. NaN scores from a diverged model: skip this label instead of aborting
            continue
    auc = float(np.mean(per_label_auc)) if per_label_auc else 0.0
    return acc, auc

def train_epoch(empty_cache_every=50):
    """Run one optimisation pass over ``train_loader`` and return the mean loss per graph.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``scaler``
    and ``device``. The forward pass and loss run under CUDA autocast. Gradients
    go through ``scaler`` and are unscaled before clipping to a global norm of 1.0.

    Args:
        empty_cache_every: release cached CUDA allocator blocks every this many
            batches; 0 or a negative value disables it. Calling
            ``torch.cuda.empty_cache()`` on every batch (the old behaviour) does
            not free memory held by live tensors and forces fresh device
            allocations on each step.

    Returns:
        Sum of ``loss * num_graphs`` over all batches divided by the dataset size.
    """
    model.train()
    total_loss = 0.0

    for step, batch in enumerate(train_loader, start=1):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Mixed precision forward
        with autocast():
            out = model(batch.x, batch.edge_index, batch.batch, batch.card, batch.genome_feat)
            loss = criterion(out, batch.y)

        # Scaled backward; unscale first so max_norm applies to the true gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * batch.num_graphs

        # Clear the cache periodically, not on every batch.
        if empty_cache_every > 0 and step % empty_cache_every == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    return total_loss / len(train_loader.dataset)

def eval_epoch(loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the notebook-level ``model``, ``criterion`` and ``device``. The forward
    pass runs under CUDA autocast, as in training.

    Args:
        loader: DataLoader yielding batched graphs that carry ``x``,
            ``edge_index``, ``batch``, ``card``, ``genome_feat`` and ``y``.

    Returns:
        (mean loss per graph, accuracy, ROC-AUC). The last two come from
        ``compute_metrics`` over all concatenated predictions.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    logit_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            with autocast():
                out = model(batch.x, batch.edge_index, batch.batch, batch.card, batch.genome_feat)
                loss = criterion(out, batch.y)

            total_loss += loss.item() * batch.num_graphs
            # Autocast logits may be float16; keep float32 copies so the metric
            # code does not depend on half-precision CPU support.
            logit_chunks.append(out.float().cpu())
            target_chunks.append(batch.y.cpu())

    if not logit_chunks:
        raise ValueError("eval_epoch: loader yielded no batches")

    all_outputs = torch.cat(logit_chunks, dim=0)
    all_targets = torch.cat(target_chunks, dim=0)

    acc, auc = compute_metrics(all_outputs, all_targets)

    return total_loss / len(loader.dataset), acc, auc

print("\n training\n")
EPOCHS = 30
BEST_MODEL_PATH = os.path.join(BASE_DIR, "best_model.pt")

# Start at -inf so the first epoch always writes a checkpoint, even if its AUC is
# 0.0 (e.g. when AUC cannot be computed). Otherwise the torch.load below could
# hit a missing file or a stale checkpoint from an earlier run.
best_val_auc = float("-inf")
for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch()
    val_loss, val_acc, val_auc = eval_epoch(val_loader)

    # The cosine schedule advances once per epoch.
    scheduler.step()

    print(f"Epoch {epoch:02d} | Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}")

    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"Best model saved (AUC: {val_auc:.4f})")

    # Clear cache after each epoch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n Training complete!")
# Restore the best-validation weights before testing. The file only holds a
# state_dict, so weights_only=True suffices and avoids unpickling arbitrary objects.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device, weights_only=True))
test_loss, test_acc, test_auc = eval_epoch(test_loader)

print(f"\nTest Results:")
print(f"   Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"   ROC-AUC:  {test_auc:.4f}")

