In [0]:
# Notebook parameters, read from `dbutils.widgets` in a single pass.
CATALOG, SCHEMA, VOLUME, TABLE, EXPERIMENT = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("CATALOG", "SCHEMA", "VOLUME", "TABLE", "EXPERIMENT")
)

# Dotted <catalog>.<schema>.<table> name assembled from the parameters above.
table_path = "{}.{}.{}".format(CATALOG, SCHEMA, TABLE)

In [0]:
import torch

# Report what the CUDA runtime exposes to this kernel.
print("CUDA available:", torch.cuda.is_available())
gpu_count = torch.cuda.device_count()
print("GPU count:", gpu_count)

# One line per visible device: its index followed by its name.
for gpu_index in range(gpu_count):
    print(gpu_index, torch.cuda.get_device_name(gpu_index))

In [0]:
import mlflow

# Make the experiment named by the EXPERIMENT notebook parameter the active one.
mlflow.set_experiment(EXPERIMENT)
# NOTE(review): MLflow's PyTorch autologging is documented as hooking
# PyTorch Lightning's Trainer.fit; the training loop in this notebook is a
# hand-written loop that logs metrics/models explicitly — confirm this call
# is still wanted.
mlflow.pytorch.autolog()

In [0]:
import os
import shutil
import random

BASE_DIR = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/chest-xray-pneumonia/chest_xray_aug/"

train_dir = os.path.join(BASE_DIR, "train")
val_dir   = os.path.join(BASE_DIR, "val")

classes = ["NORMAL", "PNEUMONIA"]
n_to_move = 100
SPLIT_SEED = 42  # fixed seed so the same files are picked for a given directory listing


def _list_files(directory):
    """Return the sorted names of the regular files directly inside `directory`."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


# Dedicated RNG: reproducible selection without touching the global `random` state.
split_rng = random.Random(SPLIT_SEED)

for cls in classes:
    src_dir = os.path.join(train_dir, cls)
    dst_dir = os.path.join(val_dir, cls)

    os.makedirs(dst_dir, exist_ok=True)

    # Idempotence guard: this cell physically moves files, so re-running it
    # (e.g. Restart & Run All) would otherwise drain another `n_to_move` files
    # per class out of the training set each time. Once the validation folder
    # already holds at least `n_to_move` files the class is left untouched.
    already_in_val = len(_list_files(dst_dir))
    if already_in_val >= n_to_move:
        print(f"Skipping {cls}: {dst_dir} already holds {already_in_val} files.")
        continue

    # Sorted listing + seeded shuffle => selection does not depend on the
    # (unspecified) order in which os.listdir returns entries.
    files = _list_files(src_dir)
    split_rng.shuffle(files)

    # take up to n_to_move files (handles case where there are fewer than 100)
    to_move = files[:n_to_move]

    print(f"Moving {len(to_move)} files from {src_dir} to {dst_dir} ...")

    for fname in to_move:
        shutil.move(os.path.join(src_dir, fname), os.path.join(dst_dir, fname))

print("Done.")

In [0]:
import os
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.models as models

import mlflow
import mlflow.pytorch

###############################################################################
# Config
###############################################################################

# --- Model -------------------------------------------------------------------
MODEL_NAME = "resnet18"

# --- Optimisation (LR / WEIGHT_DECAY are handed to the AdamW optimizer) -------
NUM_EPOCHS = 20
LR = 1e-4
WEIGHT_DECAY = 1e-4

# --- Data loading ------------------------------------------------------------
BATCH_SIZE = 32          # per-GPU batch size
NUM_WORKERS = 4          # DataLoader worker processes per loader

In [0]:
###############################################################################
# DDP setup helpers
###############################################################################

def setup_ddp() -> Tuple[torch.device, int, int]:
    """
    Resolve this process's place in the job and, if needed, join the process group.

    Environment variables read:
        RANK, WORLD_SIZE: used only when BOTH are present; otherwise the
            process is treated as a single-process job (rank 0, world size 1).
        LOCAL_RANK: index of the CUDA device to bind to; defaults to 0.

    Returns:
        (device, rank, world_size) where `device` is `cuda:<LOCAL_RANK>` when
        CUDA is available and `cpu` otherwise.

    The process group is initialised only when `world_size > 1`, using NCCL on
    CUDA devices and Gloo on CPU. The call is skipped when a group already
    exists, so re-running the notebook cell in the same kernel does not raise.
    """
    env = os.environ

    if "RANK" in env and "WORLD_SIZE" in env:
        rank, world_size = int(env["RANK"]), int(env["WORLD_SIZE"])
    else:
        # Fallback to single-process multi-GPU or single GPU
        rank, world_size = 0, 1

    local_rank = int(env.get("LOCAL_RANK", 0))

    if torch.cuda.is_available():
        # Bind the process to its GPU before any collective is created.
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    # init_process_group raises if called twice in one process; guard against
    # that so the cell stays re-runnable.
    if world_size > 1 and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend="nccl" if device.type == "cuda" else "gloo",
            rank=rank,
            world_size=world_size,
        )

    return device, rank, world_size


def cleanup_ddp():
    """
    Destroy the default process group if one was created.

    Safe to call when no group exists and safe to call repeatedly.
    `is_available()` is checked first because a PyTorch build without the
    distributed package does not expose the rest of the `torch.distributed` API.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [0]:
###############################################################################
# Model
###############################################################################

def build_model(num_classes: int) -> nn.Module:
    """
    Build a ResNet18 initialised with the IMAGENET1K_V1 weights whose final
    fully connected layer is swapped for a new `num_classes`-output linear head.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Keep the original feature width; only the output dimension changes.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

###############################################################################
# Training & evaluation loops
###############################################################################

def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    rank: int,
):
    """
    Run one optimisation pass over `loader`.

    Each batch is moved to `device`, followed by a forward pass, backward pass
    and optimizer step. Rank 0 prints the batch loss every 50 steps.

    Returns:
        (mean loss per sample, accuracy) over the samples seen by this process;
        both are 0.0 when the loader yields nothing.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    for step, batch in enumerate(loader):
        images, labels = batch
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch figure is per-sample.
        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)

        predicted = torch.max(logits, 1)[1]
        num_correct += (predicted == labels).sum().item()
        num_seen += labels.size(0)

        if rank == 0 and step % 50 == 0:
            print(f"Epoch {epoch} Step {step} Loss {batch_loss:.4f}")

    # max(..., 1) avoids a division by zero on an empty loader.
    denominator = max(num_seen, 1)
    return loss_sum / denominator, num_correct / denominator


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
):
    """
    Score `model` on `loader` with gradients disabled and the model in eval mode.

    Returns:
        (mean loss per sample, accuracy) over the samples seen by this process;
        both are 0.0 when the loader yields nothing.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels).item()

        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss * batch_images.size(0)
        predicted = torch.max(logits, 1)[1]
        hits += (predicted == batch_labels).sum().item()
        seen += batch_labels.size(0)

    # max(..., 1) avoids a division by zero on an empty loader.
    denominator = max(seen, 1)
    return loss_sum / denominator, hits / denominator

In [0]:
###############################################################################
# Data
###############################################################################

def get_transforms(image_size: int = 224):
    """
    Build the (train, eval) torchvision transform pipelines.

    Both pipelines resize to `image_size` x `image_size`, convert to a tensor
    and normalise with the mean/std values below; the training pipeline
    additionally applies a random horizontal flip and a random rotation of up
    to 10 degrees.
    """
    target_size = (image_size, image_size)

    def tensor_and_normalize():
        # Fresh instances per pipeline so the two Compose objects share nothing.
        return [
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ]

    augmentations = [T.RandomHorizontalFlip(), T.RandomRotation(10)]

    train_tf = T.Compose([T.Resize(target_size), *augmentations, *tensor_and_normalize()])
    eval_tf = T.Compose([T.Resize(target_size), *tensor_and_normalize()])

    return train_tf, eval_tf


def get_dataloaders(rank: int, world_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build the train / val / test DataLoaders from the folders under BASE_DIR.

    Expects this folder structure (Kaggle-style):
        chest_xray_aug/
          train/
            NORMAL/
            PNEUMONIA/
          val/
            NORMAL/
            PNEUMONIA/
          test/
            NORMAL/
            PNEUMONIA/

    Args:
        rank: index of this process within the job.
        world_size: number of processes; when > 1 every loader gets a
            DistributedSampler so each process reads its own shard.

    Returns:
        (train_loader, val_loader, test_loader). Only the training data is
        shuffled (by the sampler when distributed, by the loader otherwise).
    """
    train_tf, eval_tf = get_transforms()

    # os.path.join (as used where BASE_DIR is defined) keeps this working
    # whether or not BASE_DIR ends with a path separator.
    train_dataset = datasets.ImageFolder(os.path.join(BASE_DIR, "train"), transform=train_tf)
    val_dataset   = datasets.ImageFolder(os.path.join(BASE_DIR, "val"), transform=eval_tf)
    test_dataset  = datasets.ImageFolder(os.path.join(BASE_DIR, "test"), transform=eval_tf)

    distributed = world_size > 1
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        """Create one loader; shuffling is delegated to the sampler when distributed."""
        sampler = None
        if distributed:
            # NOTE(review): with the default drop_last=False, DistributedSampler
            # pads shards to equal length by repeating samples, so per-rank
            # val/test metrics can count a few samples twice — confirm this is
            # acceptable for reporting.
            sampler = DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
            )
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            sampler=sampler,
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle=shuffle and sampler is None,
            num_workers=NUM_WORKERS,
            pin_memory=pin_memory,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

In [0]:
###############################################################################
# MLflow-enabled main
###############################################################################

from contextlib import nullcontext

device, rank, world_size = setup_ddp()
is_main_process = rank == 0

try:
    # Only rank 0 logs to MLflow, so only rank 0 opens a run; the other ranks
    # enter a no-op context instead of each creating an empty run of their own.
    with (mlflow.start_run() if is_main_process else nullcontext()) as run:
        train_loader, val_loader, test_loader = get_dataloaders(rank, world_size)
        num_classes = len(train_loader.dataset.classes)

        model = build_model(num_classes)
        model.to(device)

        if world_size > 1:
            # Wrap with DDP for multi-process training. device_ids must only be
            # given for CUDA modules; on the CPU/gloo path device.index is None.
            if device.type == "cuda":
                model = DDP(model, device_ids=[device.index], output_device=device.index)
            else:
                model = DDP(model)

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

        best_val_acc = 0.0

        for epoch in range(1, NUM_EPOCHS + 1):
            if world_size > 1:
                # Ensure each epoch shuffles differently in DistributedSampler
                train_loader.sampler.set_epoch(epoch)

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device, epoch, rank
            )
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)

            # Aggregate metrics across processes if needed: each rank holds the
            # averages for its own shard, so sum them and divide by world_size.
            if world_size > 1:
                metrics = torch.tensor(
                    [train_loss, train_acc, val_loss, val_acc],
                    device=device,
                )
                torch.distributed.all_reduce(metrics, op=torch.distributed.ReduceOp.SUM)
                metrics = metrics / world_size
                train_loss, train_acc, val_loss, val_acc = metrics.tolist()

            if is_main_process:
                print(
                    f"Epoch {epoch}/{NUM_EPOCHS} "
                    f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                    f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}"
                )
                mlflow.log_metrics(
                    {
                        "train_loss": train_loss,
                        "train_acc": train_acc,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                    },
                    step=epoch,
                )

                # Save best model
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    # For DDP, model.module is the underlying model
                    model_to_log = model.module if isinstance(model, DDP) else model
                    mlflow.pytorch.log_model(model_to_log, artifact_path="model_best")

        # Final test evaluation. This scores the in-memory model, i.e. the
        # weights after the LAST epoch, not the "model_best" artifact.
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        if world_size > 1:
            metrics = torch.tensor([test_loss, test_acc], device=device)
            torch.distributed.all_reduce(metrics, op=torch.distributed.ReduceOp.SUM)
            metrics = metrics / world_size
            test_loss, test_acc = metrics.tolist()

        if is_main_process:
            print(f"Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}")
            mlflow.log_metrics({"test_loss": test_loss, "test_acc": test_acc})
            # Log final model as well
            model_to_log = model.module if isinstance(model, DDP) else model
            mlflow.pytorch.log_model(model_to_log, artifact_path="model_final")
finally:
    # Always tear the process group down, even when training raised, so the
    # kernel is left in a state where the cell can be re-run.
    cleanup_ddp()