In [3]:
# Setup & Mount Drive (Colab)

!pip install --quiet torch torchvision h5py tqdm

from google.colab import drive
drive.mount('/content/drive')

import os, torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on", DEVICE)

# Path to your HDF5 file in Drive
H5_PATH = "/content/drive/MyDrive/new_volumes.h5"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m123.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m95.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m62.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# Cell 2: HDF5 Volume Dataset with slice‐cap + DataLoaders

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# 1) Read all patient IDs and labels from the HDF5
with h5py.File(H5_PATH, "r") as hf:
    all_pids = list(hf["volumes"].keys())
    labels   = [int(hf["volumes"][pid].attrs["label"]) for pid in all_pids]

# 2) Train/Val split (stratified)
train_pids, val_pids = train_test_split(
    all_pids, test_size=0.2, stratify=labels, random_state=42
)

# 3) Dataset class (caps volumes at max_slices)
class H5VolumeDataset(Dataset):
    def __init__(self, h5_path, pids, max_slices=64):
        self.h5_path    = h5_path
        self.pids       = pids
        self.max_slices = max_slices

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        pid = self.pids[idx]
        with h5py.File(self.h5_path, "r") as hf:
            grp    = hf["volumes"][pid]
            volume = grp["images"][:]            # numpy [N,2,H,W]
            label  = int(grp.attrs["label"])

        # if more than max_slices, uniformly sample down
        N, C, H, W = volume.shape
        if N > self.max_slices:
            indices = np.linspace(0, N-1, self.max_slices, dtype=int)
            volume  = volume[indices]

        # to tensor & normalize to [0,1]
        vol_t = torch.from_numpy(volume).float().div(255.0)  # [n,2,H,W]
        return vol_t, label

# 4) Collate function: pad to the largest volume in the batch
def collate_fn(batch):
    vols, labs = zip(*batch)
    max_n = max(v.shape[0] for v in vols)
    padded = []
    for v in vols:
        n, c, h, w = v.shape
        if n < max_n:
            pad = torch.zeros((max_n-n, c, h, w), dtype=v.dtype)
            v   = torch.cat([v, pad], dim=0)
        padded.append(v)
    return torch.stack(padded, dim=0), torch.tensor(labs, dtype=torch.long)

# 5) Instantiate datasets & loaders
train_ds = H5VolumeDataset(H5_PATH, train_pids, max_slices=64)
val_ds   = H5VolumeDataset(H5_PATH, val_pids,   max_slices=64)

train_loader = DataLoader(
    train_ds,
    batch_size=1,        # start with 1 volume per batch
    shuffle=True,
    num_workers=0,       # keep 0 for safety in Colab
    pin_memory=False,
    collate_fn=collate_fn
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate_fn
)

print(f"Train vols: {len(train_ds)}, Val vols: {len(val_ds)}")


Train vols: 452, Val vols: 114


Train vols: 452, Val vols: 114


In [5]:
# Model Definition (MIL ResNet-18)
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights

class MILResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        backbone.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.feat = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
            backbone.avgpool  # outputs [B,512,1,1]
        )
        self.cls  = nn.Linear(512, 2)

    def forward(self, x):
        # x: [B, N, 2, H, W]
        B,N,_,H,W = x.shape
        x = x.view(B*N, 2, H, W)                  # [B*N,2,H,W]
        f = self.feat(x).view(B, N, 512)          # [B,N,512]
        bag, _ = f.max(dim=1)                     # [B,512]
        return self.cls(bag)

model = MILResNet18().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()


In [5]:
# Training Loop
from tqdm.notebook import tqdm
import time
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

EPOCHS = 5
CKPT_DIR = "/content/drive/MyDrive/PDAC_models"
os.makedirs(CKPT_DIR, exist_ok=True)
best_auc = 0.0

for epoch in range(1, EPOCHS+1):
    t0 = time.time()

    # — Train —
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    train_bar = tqdm(train_loader,
                     desc=f"Epoch {epoch} Train:",
                     unit="it",
                     leave=False)

    for x, y in train_bar:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # accumulate
        train_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)

        # optional: show running avg in bar
        train_bar.set_postfix(loss=f"{train_loss/total:.4f}", acc=f"{correct/total:.4f}")

    train_loss /= total
    train_acc  = correct/total
    print(f" → Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")

    # — Validate —
    model.eval()
    val_loss = 0.0
    val_corr = 0
    val_tot  = 0
    y_true, y_score = [], []

    val_bar = tqdm(val_loader,
                   desc=f"Epoch {epoch} Val:  ",
                   unit="it",
                   leave=False)

    with torch.no_grad():
        for x, y in val_bar:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = criterion(logits, y)

            # accumulate
            val_loss += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            val_corr += (preds == y).sum().item()
            val_tot  += x.size(0)

            probs = torch.softmax(logits, dim=1)[:,1].cpu().numpy()
            y_true.extend(y.cpu().numpy())
            y_score.extend(probs)

            val_bar.set_postfix(loss=f"{val_loss/val_tot:.4f}", acc=f"{val_corr/val_tot:.4f}")

    val_loss /= val_tot
    val_acc   = val_corr/val_tot
    auc = roc_auc_score(y_true, y_score)
    print(f" → Val   Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {auc:.3f}")

    # — Checkpoint —
    ckpt_path = os.path.join(CKPT_DIR, f"resnet18_epoch{epoch}.pth")
    torch.save(model.state_dict(), ckpt_path)
    print(f" ✔️ Saved {os.path.basename(ckpt_path)}")

    # track best
    if auc > best_auc:
        best_auc = auc
        best_path = os.path.join(CKPT_DIR, "resnet18_best.pth")
        torch.save(model.state_dict(), best_path)
        print(f"    → New best model saved to {os.path.basename(best_path)}")

print("Training complete.")


Epoch 1 Train::   0%|          | 0/452 [00:00<?, ?it/s]

 → Train Loss: 0.6319, Acc: 0.6881


Epoch 1 Val:  :   0%|          | 0/114 [00:00<?, ?it/s]

 → Val   Loss: 0.6404, Acc: 0.5877, AUC: 0.582
 ✔️ Saved resnet18_epoch1.pth
    → New best model saved to resnet18_best.pth


Epoch 2 Train::   0%|          | 0/452 [00:00<?, ?it/s]

 → Train Loss: 0.4189, Acc: 0.8186


Epoch 2 Val:  :   0%|          | 0/114 [00:00<?, ?it/s]

 → Val   Loss: 1.0218, Acc: 0.5877, AUC: 0.500
 ✔️ Saved resnet18_epoch2.pth


Epoch 3 Train::   0%|          | 0/452 [00:00<?, ?it/s]

 → Train Loss: 0.2320, Acc: 0.9159


Epoch 3 Val:  :   0%|          | 0/114 [00:00<?, ?it/s]

 → Val   Loss: 0.9304, Acc: 0.6491, AUC: 0.636
 ✔️ Saved resnet18_epoch3.pth
    → New best model saved to resnet18_best.pth


Epoch 4 Train::   0%|          | 0/452 [00:00<?, ?it/s]

 → Train Loss: 0.0593, Acc: 0.9845


Epoch 4 Val:  :   0%|          | 0/114 [00:00<?, ?it/s]

 → Val   Loss: 0.9489, Acc: 0.5789, AUC: 0.655
 ✔️ Saved resnet18_epoch4.pth
    → New best model saved to resnet18_best.pth


Epoch 5 Train::   0%|          | 0/452 [00:00<?, ?it/s]

 → Train Loss: 0.0107, Acc: 1.0000


Epoch 5 Val:  :   0%|          | 0/114 [00:00<?, ?it/s]

 → Val   Loss: 1.1374, Acc: 0.6053, AUC: 0.621
 ✔️ Saved resnet18_epoch5.pth
Training complete.
