# FullySharedDataParallel (FSDP) Model Training

FSDP focus on **How we split/shard model parameters and optimizer state across GPUs**

### What this FSDP script does?
High-level design:
- Uses `mp.spawn` to lunch one process per GPU. <br>
- Each process:
    * Initilizes distributed with `init_process_group`
    * Wraps the model with FSDP
    * Uses a **DistributedSampler** for train/val sets. <br> <br>
- Runs for:
    * `gpu_count` in `[1, 2, 4]`
    * `batch_size` in `[64, 128, 512]`
    * `mode` in `["fsdp", "fsdp_amp"]` (plain FSDP vs FSDP+AMP)
    * `num_epochs = 10` <br><br>
- Per epoch, on each config, it:
    * Measures **epoch wall-clock time.**
    * Tracks **train loss, train accuracy, val loss, val accuracy.**
    * Measures peak **GPU Memory** on each rank with:
        * `torch.cuda.reset_peak_memory_stats(device)` at start of epoch
        * `torch.cuda.max_memory_allocated(device)` at end. <br><br>
- Aggregates metrics across all ranks (GPUs) using `all_reduce`:
    * Sums loss and counts to get **global** accuracy and loss.
    * Uses `all reduce(..., op=MAX)` to get **global peak memory in bytes.** <br><br>
- Rank 0:
    * Appends per-epoch metrics to `shared_metrics` list (multiprocessing Manager).
    * Prints nice logs: global acc, loss, epoch time, peak memory. <br><br>
- After all runs finish:
    * Convers `shared_metrics` to a DataFrame.
    * Save to `fsdp_benckmark_metrics.csv` and prints a summary DataFrame at the end.


In [1]:
%%writefile fsdp_benchmark_with_metrics_report.py
import os
import time
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist

from torch.nn import Module
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# AMP (new API if available, else fallback)
try:
    from torch.amp import autocast, GradScaler
    USE_NEW_AMP = True
except ImportError:
    from torch.cuda.amp import autocast, GradScaler
    USE_NEW_AMP = False


# ============================================================
# 1. DISTRIBUTED SETUP / CLEANUP
# ============================================================

def setup_dist(rank, world_size):
    """
    Initialize the distributed process group for FSDP.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup_dist():
    dist.destroy_process_group()


# ============================================================
# 2. DATASET & MODEL DEFINITION
# ============================================================

class OralCancerDataset(Dataset):
    """
    Dataset:
    - dataframe has columns: 'id', 'label'
    - path_map maps 'id' -> full image path
    """
    def __init__(self, dataframe, path_map, transform=None):
        self.data = dataframe.reset_index(drop=True)
        self.path_map = path_map
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_id = self.data.iloc[idx]["id"]
        label = int(self.data.iloc[idx]["label"])

        img_path = self.path_map.get(img_id)
        if img_path is None:
            image = Image.new("RGB", (96, 96))
        else:
            image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


class SimpleCNN(Module):
    """
    Simple CNN:
    - 3 conv blocks (Conv -> BN -> ReLU -> MaxPool)
    - Flatten -> FC(256) -> Dropout -> FC(1)
    - Paired with BCEWithLogitsLoss for binary classification (0/1).
    """
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 12 * 12, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(-1, 128 * 12 * 12)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


# ============================================================
# 3. FSDP TRAINING FUNCTION (PER PROCESS / PER GPU)
# ============================================================

def fsdp_train_process(
    rank, world_size, df, path_map, batch_size, num_epochs, mode, csv_path
):
    """
    rank      : process rank (0 .. world_size-1)
    world_size: number of GPUs/processes
    df        : full dataframe with columns ['id', 'label']
    path_map  : dict mapping id -> image path
    batch_size: batch size per GPU
    num_epochs: number of training epochs
    mode      : 'fsdp' or 'fsdp_amp'
    csv_path  : path to the CSV file to append metrics
    """
    setup_dist(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    # ---------- Data transforms ----------
    IMG_SIZE = 96
    train_transforms = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )

    val_transforms = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )

    # ---------- Train/Val split ----------
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        random_state=42,
        stratify=df["label"],
    )

    train_dataset = OralCancerDataset(train_df, path_map, transform=train_transforms)
    val_dataset = OralCancerDataset(val_df, path_map, transform=val_transforms)

    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    val_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,  # sampler does the shuffling
        sampler=train_sampler,
        num_workers=0,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=0,
        pin_memory=True,
    )

    # ---------- Model, FSDP wrapper, optimizer, loss ----------
    torch.cuda.set_device(device)
    model = SimpleCNN().to(device)

    # Choose sharding strategy explicitly to avoid warnings
    if world_size == 1:
        sharding_strategy = ShardingStrategy.NO_SHARD
    else:
        sharding_strategy = ShardingStrategy.FULL_SHARD

    model = FSDP(model, device_id=device, sharding_strategy=sharding_strategy)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    scaler = GradScaler() if mode == "fsdp_amp" else None

    # ---------- Checkpoint setup ----------
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_name = f"fsdp_{mode}_g{world_size}_b{batch_size}.pt"
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)

    start_epoch = 0
    if os.path.exists(ckpt_path):
        # Load checkpoint on every rank (simpler for FSDP)
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        if scaler is not None and checkpoint.get("scaler_state", None) is not None:
            scaler.load_state_dict(checkpoint["scaler_state"])
        start_epoch = checkpoint.get("epoch", 0)
        if rank == 0:
            print(
                f"[RESUME] Mode={mode}, GPUs={world_size}, Batch={batch_size} "
                f"from epoch {start_epoch}"
            )
    else:
        if rank == 0:
            print(
                f"[START] Mode={mode}, GPUs={world_size}, Batch={batch_size}, "
                f"{num_epochs} epochs"
            )

    run_start_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)

        # Reset peak memory stats at the start of each epoch
        torch.cuda.reset_peak_memory_stats(device)

        # --------------- TRAINING PHASE ---------------
        model.train()
        epoch_start = time.time()

        train_loss_sum = 0.0
        train_correct_sum = 0
        train_total = 0

        for images, labels in train_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.float().unsqueeze(1).to(device, non_blocking=True)

            optimizer.zero_grad()

            if mode == "fsdp_amp" and scaler is not None:
                # AMP-enabled forward + backward
                if USE_NEW_AMP:
                    amp_ctx = autocast(device_type="cuda")
                else:
                    amp_ctx = autocast()

                with amp_ctx:
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            train_loss_sum += loss.item() * images.size(0)
            probs = torch.sigmoid(outputs)
            preds = (probs >= 0.5).float()
            train_correct_sum += (preds == labels).sum().item()
            train_total += labels.size(0)

        epoch_time = time.time() - epoch_start

        # --------------- VALIDATION PHASE ---------------
        model.eval()
        val_loss_sum = 0.0
        val_correct_sum = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.float().unsqueeze(1).to(device, non_blocking=True)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss_sum += loss.item() * images.size(0)
                probs = torch.sigmoid(outputs)
                preds = (probs >= 0.5).float()
                val_correct_sum += (preds == labels).sum().item()
                val_total += labels.size(0)

        # --------------- METRIC AGGREGATION ---------------
        metrics_tensor = torch.tensor(
            [
                train_loss_sum,
                train_correct_sum,
                train_total,
                val_loss_sum,
                val_correct_sum,
                val_total,
            ],
            device=device,
        )
        dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)

        # Peak memory (MAX across ranks)
        peak_mem_bytes_local = torch.cuda.max_memory_allocated(device)
        mem_tensor = torch.tensor(
            peak_mem_bytes_local, device=device, dtype=torch.float64
        )
        dist.all_reduce(mem_tensor, op=dist.ReduceOp.MAX)
        peak_mem_bytes_global = mem_tensor.item()
        peak_mem_gb_global = peak_mem_bytes_global / (1024 ** 3)

        # Unpack global metrics safely
        global_train_loss = metrics_tensor[0].item() / max(
            metrics_tensor[2].item(), 1
        )
        global_train_acc = metrics_tensor[1].item() / max(
            metrics_tensor[2].item(), 1
        )
        global_val_loss = metrics_tensor[3].item() / max(
            metrics_tensor[5].item(), 1
        )
        global_val_acc = metrics_tensor[4].item() / max(
            metrics_tensor[5].item(), 1
        )

        print(
            f"[Rank {rank}] Epoch {epoch+1}/{num_epochs} "
            f"LocalTrainAcc={(train_correct_sum / max(train_total,1)):.4f}"
        )

        # --------------- SAVE CHECKPOINT & LOG (RANK 0) ---------------
        if rank == 0:
            # Save checkpoint so we can resume this config later
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scaler_state": scaler.state_dict() if scaler is not None else None,
                },
                ckpt_path,
            )

            print(
                f"   >>> [Mode={mode.upper()}] Epoch {epoch+1}/{num_epochs} "
                f"| TrainAcc={global_train_acc:.4f} ValAcc={global_val_acc:.4f} "
                f"| EpochTime={epoch_time:.2f}s | PeakMem={peak_mem_gb_global:.2f} GB"
            )

            # Append metrics row to CSV (survives crashes)
            row = {
                "mode": mode,
                "gpu_count": world_size,
                "batch_size": batch_size,
                "epoch": epoch + 1,
                "train_loss": global_train_loss,
                "train_acc": global_train_acc,
                "val_loss": global_val_loss,
                "val_acc": global_val_acc,
                "epoch_time": epoch_time,
                "peak_mem_bytes": peak_mem_bytes_global,
                "peak_mem_gb": peak_mem_gb_global,
            }
            file_exists = os.path.exists(csv_path)
            pd.DataFrame([row]).to_csv(
                csv_path, mode="a", header=not file_exists, index=False
            )

    total_run_time = time.time() - run_start_time
    if rank == 0:
        print(
            f"\n[SUMMARY][Mode={mode.upper()}] [GPUs={world_size}] [Batch={batch_size}] "
            f"Total time for {num_epochs} epochs (including resume): {total_run_time:.2f}s\n"
        )

    cleanup_dist()


# ============================================================
# 4. MAIN EXECUTION LOOP
# ============================================================

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    print("Loading dataset for FSDP benchmark...")

    # ------------- Load CSV (must have 'id', 'label' columns) -------------
    df = pd.read_csv("oral_cancer_balanced.csv")

    # ------------- Build id -> path map (similar to your DDP script) -------------
    path_map = {}
    for root, dirs, files in os.walk("Data"):
        if "val" in dirs:
            dirs.remove("val")
        for file in files:
            if file.lower().endswith((".jpg", ".jpeg", ".png")) and "val" not in root:
                path_map[file] = os.path.join(root, file)

    max_gpus_available = torch.cuda.device_count()
    print(f"Available GPUs: {max_gpus_available}")

    # You can tweak which GPU counts & batch sizes to test
    target_gpu_counts = [1, 2, 4]
    gpu_counts = [g for g in target_gpu_counts if g <= max_gpus_available]

    batch_sizes = [64, 128, 512]
    num_epochs = 10

    modes = ["fsdp", "fsdp_amp"]  # plain FSDP vs FSDP + AMP
    csv_path = "fsdp_benchmark_metrics.csv"

    print("Starting FSDP benchmark (with checkpointing + AMP comparison)...")

    for mode in modes:
        for n_gpus in gpu_counts:
            for b_size in batch_sizes:
                print(
                    f"\n=== Running Mode={mode.upper()}, GPUs={n_gpus}, BatchSize={b_size} ==="
                )
                run_start = time.time()
                try:
                    mp.spawn(
                        fsdp_train_process,
                        args=(n_gpus, df, path_map, b_size, num_epochs, mode, csv_path),
                        nprocs=n_gpus,
                        join=True,
                    )
                except Exception as e:
                    print(f"[ERROR] Mode={mode}, GPUs={n_gpus}, Batch={b_size} :: {e}")
                run_total = time.time() - run_start
                print(
                    f"*** Completed Mode={mode.upper()}, GPUs={n_gpus}, Batch={b_size} "
                    f"in {run_total:.2f}s ***"
                )

    print("\nFSDP benchmark completed. Metrics saved in 'fsdp_benchmark_metrics.csv'.")

Overwriting fsdp_benchmark_with_metrics_report.py


In [2]:
!python fsdp_benchmark_with_metrics_report.py

Loading dataset for FSDP benchmark...
Available GPUs: 4
Starting FSDP benchmark (with checkpointing + AMP comparison)...

=== Running Mode=FSDP, GPUs=1, BatchSize=64 ===
  checkpoint = torch.load(ckpt_path, map_location=device)
[RESUME] Mode=fsdp, GPUs=1, Batch=64 from epoch 10

[SUMMARY][Mode=FSDP] [GPUs=1] [Batch=64] Total time for 10 epochs (including resume): 0.00s

*** Completed Mode=FSDP, GPUs=1, Batch=64 in 5.86s ***

=== Running Mode=FSDP, GPUs=1, BatchSize=128 ===
  checkpoint = torch.load(ckpt_path, map_location=device)
[RESUME] Mode=fsdp, GPUs=1, Batch=128 from epoch 10

[SUMMARY][Mode=FSDP] [GPUs=1] [Batch=128] Total time for 10 epochs (including resume): 0.00s

*** Completed Mode=FSDP, GPUs=1, Batch=128 in 5.28s ***

=== Running Mode=FSDP, GPUs=1, BatchSize=512 ===
  checkpoint = torch.load(ckpt_path, map_location=device)
[RESUME] Mode=fsdp, GPUs=1, Batch=512 from epoch 10

[SUMMARY][Mode=FSDP] [GPUs=1] [Batch=512] Total time for 10 epochs (including resume): 0.00s

*** Co

## Re-run

In [19]:
%%writefile clean_fsdp_amp_2_4.py
import os
import pandas as pd

CSV_PATH = "fsdp_benchmark_metrics.csv"
CKPT_DIR = "checkpoints"

# --- 1. Clean the CSV -------------------------------------------------
if os.path.exists(CSV_PATH):
    df = pd.read_csv(CSV_PATH)
    print(f"Loaded {len(df)} rows from {CSV_PATH}")

    # Rows we want to drop: fsdp_amp with gpu_count 2 or 4
    drop_mask = (df["mode"] == "fsdp_amp") & (df["gpu_count"].isin([1, 2, 4]))
    n_drop = drop_mask.sum()

    df_clean = df[~drop_mask].reset_index(drop=True)
    df_clean.to_csv(CSV_PATH, index=False)

    print(f"Removed {n_drop} rows where mode='fsdp_amp' and gpu_count in [2, 4].")
    print(f"New CSV row count: {len(df_clean)}")
else:
    print(f"{CSV_PATH} not found. Skipping CSV cleanup.")

# --- 2. Delete old checkpoints for fsdp_amp, GPUs 2 and 4 -------------
# In the FSDP script we named checkpoints:
#   ckpt_name = f\"fsdp_{mode}_g{world_size}_b{batch_size}.pt\"
# So for mode='fsdp_amp' and world_size=2, batch=64:
#   'fsdp_fsdp_amp_g2_b64.pt'

if os.path.isdir(CKPT_DIR):
    modes = ["fsdp_amp"]
    gpu_counts = [1, 2, 4]
    batch_sizes = [64, 128, 512]

    removed = 0
    for mode in modes:
        for g in gpu_counts:
            for b in batch_sizes:
                ckpt_name = f"fsdp_{mode}_g{g}_b{b}.pt"
                ckpt_path = os.path.join(CKPT_DIR, ckpt_name)
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                    removed += 1
                    print(f"Removed checkpoint: {ckpt_path}")

    if removed == 0:
        print("No matching fsdp_amp checkpoints found to remove.")
    else:
        print(f"Total checkpoints removed: {removed}")
else:
    print(f"No '{CKPT_DIR}' directory found. Skipping checkpoint cleanup.")

Overwriting clean_fsdp_amp_2_4.py


In [20]:
!python clean_fsdp_amp_2_4.py

Loaded 120 rows from fsdp_benchmark_metrics.csv
Removed 30 rows where mode='fsdp_amp' and gpu_count in [2, 4].
New CSV row count: 90
No matching fsdp_amp checkpoints found to remove.


In [13]:
%%writefile fsdp_fsdp_amp_rerun_g2_g4_metrics2.py
import os
import time
import math
from PIL import Image

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist

from torch.nn import Module
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# AMP check
try:
    from torch.amp import autocast, GradScaler
    USE_NEW_AMP = True
except ImportError:
    from torch.cuda.amp import autocast, GradScaler
    USE_NEW_AMP = False

# ============================================================
# 1. DISTRIBUTED SETUP / CLEANUP
# ============================================================
def setup_dist(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12387"  
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_dist():
    dist.destroy_process_group()

# ============================================================
# 2. DATASET & MODEL
# ============================================================
class OralCancerDataset(Dataset):
    def __init__(self, dataframe, path_map, transform=None):
        self.data = dataframe.reset_index(drop=True)
        self.path_map = path_map
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_id = row["id"]
        label = int(row["label"])
        img_path = self.path_map.get(img_id)
        if img_path is None:
            image = Image.new("RGB", (96, 96))
        else:
            image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

class SimpleCNN(Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 12 * 12, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = x.view(-1, 128 * 12 * 12)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# ============================================================
# 3. WORKER: FSDP + AMP TRAINING FOR ONE CONFIG
# ============================================================
def fsdp_amp_worker(rank, world_size, df, path_map, batch_size, num_epochs, csv_path):
    setup_dist(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    IMG_SIZE = 96
    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])

    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df["label"])

    train_dataset = OralCancerDataset(train_df, path_map, transform=train_tf)
    val_dataset   = OralCancerDataset(val_df,   path_map, transform=val_tf)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, shuffle=False, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, shuffle=False, num_workers=0, pin_memory=True)

    model = SimpleCNN().to(device)
    sharding_strategy = ShardingStrategy.NO_SHARD if world_size == 1 else ShardingStrategy.FULL_SHARD
    model = FSDP(model, device_id=device, sharding_strategy=sharding_strategy)

    criterion = nn.BCEWithLogitsLoss()
    
    # FIX 1: LOWER LEARNING RATE
    # 1e-3 is often too unstable for Multi-GPU AMP. 1e-4 is safer.
    optimizer = optim.Adam(model.parameters(), lr=1e-4) 
    
    scaler = GradScaler()

    if rank == 0:
        print(f"[FSDP_AMP] world_size={world_size}, batch_size={batch_size}")

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)
        torch.cuda.reset_peak_memory_stats(device)
        epoch_start = time.time()

        # ----- TRAIN -----
        model.train()
        train_loss_sum = 0.0
        train_correct_sum = 0
        train_total = 0

        for images, labels in train_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.float().unsqueeze(1).to(device, non_blocking=True)

            optimizer.zero_grad()

            if USE_NEW_AMP:
                amp_ctx = autocast(device_type="cuda")
            else:
                amp_ctx = autocast()

            with amp_ctx:
                outputs = model(images)
                loss = criterion(outputs, labels)

            # FIX 2: GRADIENT CLIPPING WITH AMP
            scaler.scale(loss).backward()
            
            # Unscale the optimizer first
            scaler.unscale_(optimizer)
            # Now clip gradients to prevent them from hitting Infinity
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            scaler.step(optimizer)
            scaler.update()

            # FIX 3: NAN CHECK BEFORE SUMMING
            loss_val = loss.item()
            if not math.isnan(loss_val) and not math.isinf(loss_val):
                train_loss_sum += loss_val * images.size(0)
            
            probs = torch.sigmoid(outputs)
            preds = (probs >= 0.5).float()
            train_correct_sum += (preds == labels).sum().item()
            train_total += labels.size(0)

        epoch_time = time.time() - epoch_start

        # ----- VAL -----
        model.eval()
        val_loss_sum = 0.0
        val_correct_sum = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.float().unsqueeze(1).to(device, non_blocking=True)

                if USE_NEW_AMP:
                    amp_ctx = autocast(device_type="cuda")
                else:
                    amp_ctx = autocast()

                with amp_ctx:
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                # FIX 3 (Applied to Val): NAN CHECK
                loss_val = loss.item()
                if not math.isnan(loss_val) and not math.isinf(loss_val):
                    val_loss_sum += loss_val * images.size(0)

                probs = torch.sigmoid(outputs)
                preds = (probs >= 0.5).float()
                val_correct_sum += (preds == labels).sum().item()
                val_total += labels.size(0)

        # ----- METRICS -----
        metrics = torch.tensor([
            train_loss_sum, train_correct_sum, train_total,
            val_loss_sum, val_correct_sum, val_total,
        ], device=device)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)

        peak_mem_local = torch.cuda.max_memory_allocated(device)
        mem_tensor = torch.tensor(peak_mem_local, device=device, dtype=torch.float64)
        dist.all_reduce(mem_tensor, op=dist.ReduceOp.MAX)
        peak_mem_bytes = mem_tensor.item()
        peak_mem_gb = peak_mem_bytes / (1024**3)

        g_train_loss = metrics[0].item() / max(metrics[2].item(), 1)
        g_train_acc  = metrics[1].item() / max(metrics[2].item(), 1)
        g_val_loss   = metrics[3].item() / max(metrics[5].item(), 1)
        g_val_acc    = metrics[4].item() / max(metrics[5].item(), 1)

        if rank == 0:
            print(f" >>> [FSDP_AMP] GPUs={world_size}, epoch={epoch+1} | TrainLoss={g_train_loss:.4f} ValAcc={g_val_acc:.4f}")
            row = {
                "mode": "fsdp_amp",
                "gpu_count": world_size,
                "batch_size": batch_size,
                "epoch": epoch + 1,
                "train_loss": g_train_loss,
                "train_acc": g_train_acc,
                "val_loss": g_val_loss,
                "val_acc": g_val_acc,
                "epoch_time": epoch_time,
                "peak_mem_bytes": peak_mem_bytes,
                "peak_mem_gb": peak_mem_gb,
            }
            file_exists = os.path.exists(csv_path)
            pd.DataFrame([row]).to_csv(csv_path, mode="a", header=not file_exists, index=False)

    cleanup_dist()

# ============================================================
# 4. MAIN
# ============================================================
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    
    # We save to the SAME file you used before, so it just appends the missing data
    CSV_PATH = "fsdp_benchmark_metrics2.csv"
    
    print("Loading dataset...")
    df = pd.read_csv("oral_cancer_balanced.csv")
    path_map = {}
    for root, dirs, files in os.walk("Data"):
        if "val" in dirs: dirs.remove("val")
        for file in files:
            if file.lower().endswith((".jpg", ".jpeg", ".png")) and "val" not in root:
                path_map[file] = os.path.join(root, file)

    max_gpus = torch.cuda.device_count()
    
    # ONLY RERUN 2 and 4, since 1 worked
    target_gpu_counts = [2, 4]
    gpu_counts = [g for g in target_gpu_counts if g <= max_gpus]

    batch_sizes = [64, 128, 512] 
    num_epochs = 10 

    for n_gpus in gpu_counts:
        for b_size in batch_sizes:
            print(f"\n=== RUN: GPUs={n_gpus}, batch_size={b_size} ===")
            try:
                mp.spawn(fsdp_amp_worker, args=(n_gpus, df, path_map, b_size, num_epochs, CSV_PATH), nprocs=n_gpus, join=True)
            except Exception as e:
                print(f"[ERROR] GPUs={n_gpus}, batch={b_size} :: {e}")

Overwriting fsdp_fsdp_amp_rerun_g2_g4_metrics2.py


In [14]:
!python fsdp_fsdp_amp_rerun_g2_g4_metrics2.py

Loading dataset...

=== RUN: GPUs=2, batch_size=64 ===
[FSDP_AMP] world_size=2, batch_size=64
 >>> [FSDP_AMP] GPUs=2, epoch=1 | TrainLoss=0.1090 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=2 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=4 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=6 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=7 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=9 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=10 | TrainLoss=0.0000 ValAcc=0.5000

=== RUN: GPUs=2, batch_size=128 ===
[FSDP_AMP] world_size=2, batch_size=128
 >>> [FSDP_AMP] GPUs=2, epoch=1 | TrainLoss=0.2041 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=3 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=4 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=5 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=6 | TrainLoss=0.0000 ValAcc=0.5000
 >>> [FSDP_AMP] GPUs=2, epoch=7 | TrainLoss=0.00