In [None]:
!pip install -q albumentations==1.3.0 timm captum optuna


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Training resizes to RESIZE_SIZE and random-crops CROP_SIZE. Validation resizes
# to the same RESIZE_SIZE and centre-crops CROP_SIZE, so anatomy appears at the
# scale the model saw during training. A plain Resize(224) at eval time would
# show structures ~12% smaller and include borders never seen in training.
RESIZE_SIZE, CROP_SIZE = 256, 224

# NOTE: neither pipeline normalises; AlbChestDataset divides the uint8 output by 255.
# NOTE(review): horizontal flips mirror the heart to the right side of the
# chest; confirm that this is acceptable for the target labels.
train_aug = A.Compose([
    A.Resize(RESIZE_SIZE, RESIZE_SIZE),
    A.RandomCrop(CROP_SIZE, CROP_SIZE),
    A.HorizontalFlip(p=0.5),
    A.GaussNoise(p=0.2),
    A.RandomBrightnessContrast(p=0.2),
    ToTensorV2(),
])
val_aug = A.Compose([
    A.Resize(RESIZE_SIZE, RESIZE_SIZE),
    A.CenterCrop(CROP_SIZE, CROP_SIZE),
    ToTensorV2(),
])

# Adapter dataset: load each PNG as an RGB numpy array and run it through an
# albumentations pipeline (albumentations operates on HWC numpy arrays, not PIL).
class AlbChestDataset(ChestXrayDataset):
    """ChestXrayDataset variant that applies an albumentations transform.

    Relies on the parent class for ``self.df``, ``self.root``,
    ``self.label_col`` and ``self.label2idx``.

    Args:
        df: annotation frame. Each row's ``image_id`` names the file
            ``<root>/<image_id>.png``.
        root_dir: image directory, forwarded to ``ChestXrayDataset``.
        alb_transform: albumentations pipeline ending in ``ToTensorV2``, or None.
            It must return an un-normalised uint8 CHW tensor under ``'image'``,
            because that tensor is divided by 255 here.
        label_col: column holding the class name.
        mean, std: optional per-channel statistics (length-C sequences). When
            given, the [0, 1] image is normalised as ``(img - mean) / std``.
            By default no normalisation is applied.

    ``__getitem__`` returns ``(image, label)``: a float CHW tensor and a scalar
    ``torch.long`` label.

    Raises:
        ValueError: if only one of ``mean`` / ``std`` is given.
    """

    def __init__(self, df, root_dir, alb_transform=None, label_col='label',
                 mean=None, std=None):
        super().__init__(df, root_dir, transform=None, label_col=label_col)
        self.alb = alb_transform
        if (mean is None) != (std is None):
            raise ValueError("mean and std must be provided together")
        # Stored as (C, 1, 1) so they broadcast over a CHW image.
        self._norm_mean = (None if mean is None
                           else torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1))
        self._norm_std = (None if std is None
                          else torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # `with` closes the file handle immediately instead of waiting for GC
        # (this matters with many DataLoader workers). convert() loads the pixels.
        with Image.open(self.root / f"{row['image_id']}.png") as pil_img:
            img_np = np.array(pil_img.convert('RGB'))
        if self.alb:
            # ToTensorV2 already returns CHW and does not rescale values.
            img = self.alb(image=img_np)['image'].float() / 255.0
        else:
            # Without a transform, still return a CHW float tensor in [0, 1]
            # so the default collate_fn can batch it (a PIL image cannot be batched).
            img = torch.from_numpy(img_np).permute(2, 0, 1).contiguous().float() / 255.0
        if self._norm_mean is not None:
            img = (img - self._norm_mean) / self._norm_std
        # Unknown labels, or a frame without the label column, fall back to class 0.
        label = (self.label2idx.get(row[self.label_col], 0)
                 if self.label_col in self.df.columns else 0)
        return img, torch.tensor(label, dtype=torch.long)


In [None]:
import timm, optuna
def create_model(lr, dropout, model_name='tf_efficientnet_b0', pretrained=True):
    """Build a timm classifier that applies dropout before its head.

    Args:
        lr: unused. It is accepted so existing ``create_model(lr, drop)`` calls
            keep working; the learning rate belongs to the caller's optimizer.
        dropout: dropout probability applied to the pooled features just
            before the final classifier layer.
        model_name: timm architecture name.
        pretrained: whether to load pretrained backbone weights.

    Returns:
        The model moved to the notebook-level ``device``. It has a
        ``num_classes``-way head, where ``num_classes`` is a notebook global.
    """
    # timm's `drop_rate` kwarg inserts dropout between global pooling and the
    # classifier. This replaces the manual Sequential(Dropout, Linear) swap,
    # which silently did nothing for models without a `.classifier` attribute.
    model = timm.create_model(model_name, pretrained=pretrained,
                              num_classes=num_classes, drop_rate=dropout)
    return model.to(device)

def objective(trial, max_train_batches=22, max_val_batches=6):
    """Optuna objective: briefly fine-tune a fresh model and return its val loss.

    Samples a learning rate (log-uniform in [1e-5, 1e-3]) and a dropout
    probability (uniform in [0, 0.5]). It then trains on at most
    ``max_train_batches`` batches of ``train_loader`` and scores
    cross-entropy on at most ``max_val_batches`` batches of ``val_loader``.
    This is a cheap, noisy proxy for full-training quality.

    Args:
        trial: the ``optuna.Trial`` supplied by ``study.optimize``.
        max_train_batches: number of optimizer steps taken.
        max_val_batches: number of validation batches scored.

    Returns:
        Mean per-sample validation cross-entropy (lower is better).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    from itertools import islice

    # suggest_float replaces the suggest_loguniform / suggest_uniform calls
    # deprecated in Optuna 3. It keeps the same parameter names and distributions.
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    drop = trial.suggest_float('dropout', 0.0, 0.5)
    model = create_model(lr, drop)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Short training run on the first few batches only (quick search).
    model.train()
    for xb, yb in islice(train_loader, max_train_batches):
        xb, yb = xb.to(device), yb.to(device)
        loss = criterion(model(xb), yb)
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Sum per-sample losses and divide by the sample count, so a short final
    # batch is not over-weighted.
    val_criterion = nn.CrossEntropyLoss(reduction='sum')
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for xb, yb in islice(val_loader, max_val_batches):
            xb, yb = xb.to(device), yb.to(device)
            total_loss += val_criterion(model(xb), yb).item()
            n_samples += yb.size(0)
    if n_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot score this trial")
    return total_loss / n_samples

# Seed the TPE sampler (Optuna's default for single-objective studies) so the
# sequence of proposed hyperparameters is reproducible. Other randomness,
# such as weight init and data shuffling or augmentation, is not seeded here,
# so objective values may still vary between runs.
OPTUNA_SEED = 42
N_TRIALS = 8
study = optuna.create_study(
    direction='minimize',
    sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
)
study.optimize(objective, n_trials=N_TRIALS)
print("Best params:", study.best_params)


In [None]:
from captum.attr import LayerGradCam, visualize_image_attr
from captum.attr import LayerAttribution
import matplotlib.pyplot as plt

# Grad-CAM for the first validation image, targeting its ground-truth class.
# NOTE(review): `model.layer4[-1]` assumes the global `model` (from an earlier
# cell) is ResNet-style. The timm EfficientNet built by `create_model` has no
# `layer4`, so a different target layer is needed for that model. Confirm.
model.eval()
layer_gc = LayerGradCam(model, model.layer4[-1])
img, label = val_ds[0]
inp = img.unsqueeze(0).to(device)
attr = layer_gc.attribute(inp, target=label.item())
# Upsample the coarse layer map to the input's spatial size (H, W).
cam = LayerAttribution.interpolate(attr, img.shape[1:])

fig, ax = plt.subplots()
ax.imshow(img.permute(1, 2, 0).cpu())
# Attributions come out of an autograd computation and may still require
# grad; detach them before matplotlib converts them to a numpy array.
ax.imshow(cam.squeeze().detach().cpu(), alpha=0.5)
ax.set_title("Grad-CAM overlay")
ax.axis('off')
plt.show()
