In [None]:
!pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
!pip install -q torchmetrics


In [None]:
import os, pathlib, time, random
import torch, torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image

In [None]:
class ChestXrayDataset(Dataset):
    """Image-classification dataset driven by a DataFrame of image ids and labels.

    Every row of ``df`` must provide an ``image_id``; the image is read from
    ``<root_dir>/train/<image_id>.png`` and converted to RGB.

    Args:
        df: DataFrame with an ``image_id`` column and, optionally, ``label_col``.
        root_dir: Dataset root; images are looked up in its ``train`` subfolder.
        transform: Optional callable applied to the loaded PIL image.
        label_col: Name of the column holding the class label.
        label2idx: Optional pre-built ``{label: class_index}`` mapping. Pass the
            mapping of the training dataset when building a validation/test
            dataset so that all splits agree on the class indices. When omitted,
            the mapping is derived from the sorted unique labels in ``df``.
    """

    def __init__(self, df, root_dir, transform=None, label_col='label', label2idx=None):
        self.df = df.reset_index(drop=True)
        self.root = pathlib.Path(root_dir)/"train"
        self.transform = transform
        self.label_col = label_col
        if label2idx is not None:
            # Copy so later changes to the caller's dict do not leak in.
            self.label2idx = dict(label2idx)
        elif self.label_col in df.columns:
            # Build label mapping from the sorted, non-null unique labels.
            unique = sorted(df[self.label_col].dropna().unique())
            self.label2idx = {v: i for i, v in enumerate(unique)}
        else:
            self.label2idx = {}

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; label is a ``torch.long`` scalar."""
        row = self.df.iloc[idx]
        img_path = self.root/f"{row['image_id']}.png"
        # Close the underlying file as soon as the pixels have been decoded.
        with Image.open(img_path) as fh:
            img = fh.convert('RGB')
        if self.transform:
            img = self.transform(img)
        if self.label_col in self.df.columns:
            # Labels missing from the mapping (including NaN) fall back to class 0.
            lab = self.label2idx.get(row[self.label_col], 0)
        else:
            # No label column at all: emit a dummy class 0.
            lab = 0
        return img, torch.tensor(lab, dtype=torch.long)


In [None]:
train_df = pd.read_csv("data/vinbigdata_png/train_split.csv")
val_df   = pd.read_csv("data/vinbigdata_png/val_split.csv")

# Per-channel normalization statistics (the standard ImageNet values used with
# torchvision's pretrained backbones), shared by both pipelines.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip as augmentation, normalize.
train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
# Validation pipeline: same preprocessing without the random augmentation.
val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

train_ds = ChestXrayDataset(train_df, "data/vinbigdata_png", transform=train_tf)
val_ds   = ChestXrayDataset(val_df, "data/vinbigdata_png", transform=val_tf)
# Each dataset derives its own label->index mapping from its DataFrame; if the
# validation split has a different set of labels the indices would not line up
# with what the model was trained on. Reuse the training mapping instead.
val_ds.label2idx = train_ds.label2idx

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4)


In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One output per label in the training mapping; use 2 when no mapping was built.
label_count = len(train_ds.label2idx)
num_classes = label_count if label_count else 2

# ResNet-18 with pretrained weights; swap the final layer for a fresh head
# sized to this task.
model = torchvision.models.resnet18(pretrained=True)
head_in_features = model.fc.in_features
model.fc = nn.Linear(head_in_features, num_classes)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()


In [None]:
def train_one_epoch(model, loader, opt, device, loss_fn=None):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        opt: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            loss. Defaults to the notebook-level ``criterion``.

    Returns:
        Batch-size-weighted mean of the per-batch losses over the samples that
        were actually processed, or ``0.0`` if the loader yielded no samples.
    """
    if loss_fn is None:
        # Fall back to the loss defined in the model-setup cell.
        loss_fn = criterion
    model.train()
    running = 0.0
    seen = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb)
        loss = loss_fn(preds, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        batch_size = xb.size(0)
        running += loss.item() * batch_size
        # Count processed samples instead of trusting len(loader.dataset), which
        # is wrong with drop_last/samplers and zero for an empty dataset.
        seen += batch_size
    return running / seen if seen else 0.0

def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over all samples in ``loader``.

    The metric is computed directly with torch (correct predictions divided by
    total samples) rather than through ``torchmetrics.Accuracy()``, whose
    constructor requires a ``task`` argument in current releases and would
    raise when called without one.

    Args:
        model: Module producing per-class scores of shape ``(batch, classes)``;
            it is switched to evaluation mode.
        loader: Iterable yielding ``(inputs, targets)`` batches with integer
            class targets.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` if the loader is empty.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = model(xb).argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.numel()
    return correct / total if total else 0.0

for epoch in range(3):  # baseline: 3 epochs
    t0 = time.time()
    tr_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_acc = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1} Loss {tr_loss:.4f} ValAcc {val_acc:.4f} time {(time.time()-t0):.0f}s")
# Save baseline model
# torch.save does not create parent directories, so make sure "models/" exists;
# otherwise the whole training run is lost on a fresh checkout.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/resnet18_baseline.pth")


In [None]:

import os, numpy as np
from torchvision.models.detection import fasterrcnn_resnet50_fpn

class DetectionDataset(Dataset):
    """Object-detection dataset built from a per-box annotation DataFrame.

    ``df`` holds one row per bounding box with columns ``image_id``, ``x``,
    ``y``, ``w``, ``h`` (top-left corner plus width/height) and, optionally,
    ``class_id``. Rows are grouped by ``image_id``; each dataset item is one
    image (read from ``<root_dir>/train/<image_id>.png``) together with all of
    its boxes.

    Args:
        df: Annotation DataFrame as described above.
        root_dir: Dataset root; images are looked up in its ``train`` subfolder.
        transforms: Optional callable applied to the PIL image only. The boxes
            are not transformed, so geometric transforms would desynchronize
            image and boxes.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.df = df
        self.root = pathlib.Path(root_dir)/"train"
        # group rows per image
        self.grouped = df.groupby('image_id')
        self.image_ids = list(self.grouped.groups.keys())
        self.transforms = transforms

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.image_ids)

    def _build_target(self, idx):
        """Build the target dict (``boxes``, ``labels``, ``image_id``) for item ``idx``."""
        ann_df = self.grouped.get_group(self.image_ids[idx])
        # Convert (x, y, w, h) to (x1, y1, x2, y2) for all boxes at once; the
        # sum is done in float64 before the final cast to float32.
        xywh = torch.tensor(ann_df[['x', 'y', 'w', 'h']].to_numpy(dtype='float64'))
        top_left = xywh[:, :2]
        boxes = torch.cat([top_left, top_left + xywh[:, 2:]], dim=1).to(torch.float32)
        if 'class_id' in ann_df.columns:
            # NOTE(review): torchvision detection models reserve label 0 for
            # background -- confirm class_id values in the CSV start at 1.
            labels = [int(c) for c in ann_df['class_id'].tolist()]
        else:
            # No class column: every box gets the single foreground class 1.
            labels = [1] * len(ann_df)
        return {
            "boxes": boxes,
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([idx]),
        }

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for item ``idx``."""
        img_path = self.root/f"{self.image_ids[idx]}.png"
        # Close the underlying file as soon as the pixels have been decoded.
        with Image.open(img_path) as fh:
            img = fh.convert("RGB")
        target = self._build_target(idx)
        if self.transforms:
            img = self.transforms(img)
        # Only convert when the user transform has not already produced a
        # tensor; ToTensor rejects tensor input.
        # (`transforms` here is the torchvision.transforms module.)
        if not torch.is_tensor(img):
            img = transforms.ToTensor()(img)
        return img, target

ann_csv = pd.read_csv("data/vinbigdata_png/annotations.csv")  # adapt filename

# Take the first row of every image group, then keep the first 500 distinct
# image ids among them as the training subset.
first_rows = ann_csv.groupby('image_id').head(1)
sample_ann = first_rows['image_id'].unique()[:500]

# All annotation rows (every box) belonging to the sampled images.
in_sample = ann_csv['image_id'].isin(sample_ann)
sample_df = ann_csv[in_sample]
det_ds = DetectionDataset(sample_df, "data/vinbigdata_png")

def _detection_collate(batch):
    """Turn a list of (image, target) pairs into an (images, targets) tuple pair."""
    return tuple(zip(*batch))

det_loader = DataLoader(det_ds, batch_size=4, shuffle=True, collate_fn=_detection_collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Faster R-CNN (ResNet-50 FPN backbone) initialised with pretrained weights.
det_model = fasterrcnn_resnet50_fpn(pretrained=True)
det_model.to(device)
# Optimise only the parameters that require gradients.
params = [p for p in det_model.parameters() if p.requires_grad]
opt = torch.optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=0.0005)
# Training loop (very small)
det_model.train()
for epoch in range(2):
    for images, targets in det_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # When called with targets, the model returns a dict of loss terms.
        loss_dict = det_model(images, targets)
        # Total loss is the plain sum of all loss terms.
        loss = sum(loss_dict.values())
        opt.zero_grad()
        loss.backward()
        opt.step()
    print("Detection epoch", epoch+1, "done")
# torch.save does not create parent directories; make sure "models/" exists so
# the trained weights are not lost at the very last step.
os.makedirs("models", exist_ok=True)
torch.save(det_model.state_dict(), "models/fasterrcnn_baseline.pth")
