# Domain Generalization (ERM) on PACS

In [9]:
import os
import random
from pathlib import Path
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, Subset
from torchvision import transforms, datasets

 
DATA_ROOT = Path(r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data") 
SOURCE_DOMAINS = ["art_painting", "cartoon", "photo"]
TARGET_DOMAIN = "sketch"

IMAGE_SIZE = 224
BATCH_SIZE = 64
NUM_EPOCHS = 10

LR = 1e-4
SEED = 42

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VIT_MODEL_NAME = "WinKawaks/vit-tiny-patch16-224"

random.seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x18ee2301030>

In [10]:
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])


def load_domain(domain_name, transform):
    p = DATA_ROOT / domain_name
    if not p.exists():
        raise FileNotFoundError(f"Domain path not found: {p.resolve()}")
    ds = datasets.ImageFolder(str(p), transform=transform)
    return ds

source_datasets = {d: load_domain(d, train_transform) for d in SOURCE_DOMAINS}
target_dataset = load_domain(TARGET_DOMAIN, val_transform)

class_lists = [tuple(ds.classes) for ds in source_datasets.values()] + [tuple(target_dataset.classes)]
if len(set(class_lists)) != 1:
    print("WARNING: Class lists differ between domains. Ensure class folders match and are ordered the same.")

NUM_CLASSES = len(next(iter(source_datasets.values())).classes)

train_dataset = ConcatDataset(list(source_datasets.values()))
val_fraction = 0.2
num_val = int(len(train_dataset) * val_fraction)
num_train = len(train_dataset) - num_val
train_subset = Subset(train_dataset, list(range(0, num_train)))
val_subset = Subset(train_dataset, list(range(num_train, num_train + num_val)))

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

target_loader = DataLoader(target_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

source_loaders = {d: DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
                  for d, ds in source_datasets.items()}

print(f"Train size: {len(train_subset)}, Val size: {len(val_subset)}, Target size: {len(target_dataset)}")
print(f"Num classes: {NUM_CLASSES}")


Train size: 4842, Val size: 1210, Target size: 3929
Num classes: 7


In [11]:
import torchvision

class DomainBedModel(nn.Module):
    def __init__(self, featurizer, classifier, dropout_rate=0.0):
        super().__init__()
        self.featurizer = featurizer
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = classifier

    def forward(self, x):
        features = self.featurizer(x)
        features_dropped_out = self.dropout(features)
        return self.fc(features_dropped_out)

resnet_dropout = 0.0
freeze_bn_flag = True

base_model = torchvision.models.resnet18(pretrained=True)
n_outputs = base_model.fc.in_features
base_model.fc = nn.Identity()
featurizer = base_model

if freeze_bn_flag:
    for module in featurizer.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()

classifier = nn.Linear(n_outputs, NUM_CLASSES)

model = DomainBedModel(
    featurizer,
    classifier,
    dropout_rate=resnet_dropout
).to(DEVICE)

criterion = nn.CrossEntropyLoss()



In [12]:
import torch
from torch.optim import Optimizer

class SAM(Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        params = list(params)
        params = [p for p in params if p is not None and p.requires_grad]
        if len(params) == 0:
            raise ValueError("SAM received no parameters to optimize. "
                             "Did you pass model.parameters() after they've been consumed or did you freeze all params?")

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.adaptive = adaptive
        self.rho = rho

    @torch.no_grad()
    def first_step(self, zero_grad=True):
        grad_norm = self._grad_norm()
        scale = self.rho / (grad_norm + 1e-12)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if self.adaptive:
                    e_w = (torch.abs(p) * p.grad) * scale
                else:
                    e_w = p.grad * scale
                p.add_(e_w)                 
                self.state[p]['e_w'] = e_w  

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=True):
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].get('e_w', None)
                if e_w is None:
                    continue
                p.sub_(e_w)  
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def zero_grad(self):
        self.base_optimizer.zero_grad()

    def _grad_norm(self):
        device = self.param_groups[0]['params'][0].device
        norms = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if self.adaptive:
                    norms.append(((torch.abs(p) * p.grad).norm(p=2)).to(device))
                else:
                    norms.append((p.grad.norm(p=2)).to(device))
        if not norms:
            return torch.tensor(0.0, device=device)
        stacked = torch.stack(norms)
        return torch.norm(stacked, p=2)


In [13]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm
from pathlib import Path
import json
import matplotlib.pyplot as plt

criterion = nn.CrossEntropyLoss()

def train_one_epoch_sam(model, loader, optimizer, device, criterion):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc="Train (SAM)", leave=False)
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        labels = labels.to(device)

        def closure():
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            return loss

        closure_loss = closure()
        optimizer.first_step(zero_grad=True)
        closure_loss = closure()
        optimizer.second_step(zero_grad=True)

        outputs = model(imgs)
        loss = criterion(outputs, labels)
        total_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
        pbar.set_postfix({'loss': total_loss/total, 'acc': 100*correct/total})

    return total_loss/total, 100*correct/total


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        total_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return total_loss/total, 100*correct/total




In [14]:
from torch.optim import AdamW
rho = 0.5
sam_optimizer = SAM(
    [p for p in model.parameters() if p.requires_grad],
    AdamW,
    lr=LR,
    weight_decay=1e-4,
    rho=rho
)

best_target_acc = 0.0
best_epoch = 0
metrics = {"epochs": []}

OUTPUT_DIR = Path("./sam_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch_sam(model, train_loader, sam_optimizer, DEVICE, criterion)
    target_loss, target_acc = evaluate(model, target_loader, DEVICE)
    print(f"Epoch {epoch}/{NUM_EPOCHS} | Train loss: {train_loss:.4f} acc: {train_acc:.2f}% | Target loss: {target_loss:.4f} acc: {target_acc:.2f}%")

    metrics["epochs"].append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "target_loss": target_loss,
        "target_acc": target_acc
    })

    if target_acc > best_target_acc:
        best_target_acc = target_acc
        best_epoch = epoch
        torch.save(model.state_dict(), OUTPUT_DIR / "best_model.pth")

print(f"\nBest target-domain accuracy (SAM): {best_target_acc:.2f}% at epoch {best_epoch}")

best_model_path = OUTPUT_DIR / "best_model.pth"
model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()

domain_results = {}
total_acc = 0.0
for d, loader in source_loaders.items():
    _, acc = evaluate(model, loader, DEVICE)
    domain_results[d] = acc
    total_acc += acc

t_loss, t_acc = evaluate(model, target_loader, DEVICE)
domain_results[TARGET_DOMAIN] = t_acc
total_acc += t_acc
mean_acc = total_acc / (len(source_loaders) + 1)
mean_source_acc = total_acc - t_acc
mean_source_acc /= len(source_loaders)

metrics.update({
    "best_target_acc": best_target_acc,
    "best_epoch": best_epoch,
    "mean_acc": mean_acc,
    "mean_source_acc": mean_source_acc,
    "domain_results": domain_results
})

with open(OUTPUT_DIR / "metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)

epochs = [m["epoch"] for m in metrics["epochs"]]
train_accs = [m["train_acc"] for m in metrics["epochs"]]
target_accs = [m["target_acc"] for m in metrics["epochs"]]
train_losses = [m["train_loss"] for m in metrics["epochs"]]
target_losses = [m["target_loss"] for m in metrics["epochs"]]

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_accs, label="Train Acc")
plt.plot(epochs, target_accs, label="Target Acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.title("SAM Training and Target Accuracy")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "acc_curves.png")
plt.close()

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, target_losses, label="Target Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("SAM Training and Target Loss")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "loss_curves.png")
plt.close()

domains = list(domain_results.keys())
accuracies = [domain_results[d] for d in domains]
plt.figure(figsize=(8, 5))
plt.bar(domains, accuracies)
plt.ylabel("Accuracy (%)")
plt.title("Per-Domain Accuracy (SAM Final Model)")
plt.xticks(rotation=30, ha="right")
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "domain_accuracy_bar.png")
plt.close()

                                                                                  

Epoch 1/10 | Train loss: 0.7529 acc: 79.35% | Target loss: 1.5186 acc: 44.03%


                                                                                  

Epoch 2/10 | Train loss: 0.4569 acc: 88.93% | Target loss: 1.1649 acc: 66.51%


                                                                                  

Epoch 3/10 | Train loss: 0.3225 acc: 93.10% | Target loss: 1.0704 acc: 71.77%


                                                                                  

Epoch 4/10 | Train loss: 0.2361 acc: 95.17% | Target loss: 0.9734 acc: 69.25%


                                                                                  

Epoch 5/10 | Train loss: 0.1955 acc: 96.03% | Target loss: 0.9127 acc: 68.44%


                                                                                  

Epoch 6/10 | Train loss: 0.1530 acc: 97.01% | Target loss: 0.8684 acc: 73.15%


                                                                                  

Epoch 7/10 | Train loss: 0.1247 acc: 97.77% | Target loss: 0.9399 acc: 72.92%


                                                                                   

Epoch 8/10 | Train loss: 0.0993 acc: 98.31% | Target loss: 0.9209 acc: 70.45%


                                                                                   

Epoch 9/10 | Train loss: 0.0888 acc: 98.62% | Target loss: 0.9392 acc: 71.82%


                                                                                   

Epoch 10/10 | Train loss: 0.0793 acc: 98.66% | Target loss: 0.8104 acc: 74.85%

Best target-domain accuracy (SAM): 74.85% at epoch 10


  model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))


In [17]:
best_model_path = OUTPUT_DIR / 'best_model.pth'

model = DomainBedModel(
    featurizer,
    classifier,
    dropout_rate=resnet_dropout
).to(DEVICE)

state_dict = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

domain_results = {}
total_acc = 0.0

for d, loader in source_loaders.items():
    loss, acc = evaluate(model, loader, DEVICE)
    domain_results[d] = {"loss": loss, "acc": acc}
    total_acc += acc
    print(f"Source {d}: loss {loss:.4f}, acc {acc:.2f}%")

t_loss, t_acc = evaluate(model, target_loader, DEVICE)
domain_results[TARGET_DOMAIN] = {"loss": t_loss, "acc": t_acc}
total_acc += t_acc

mean_source_acc = total_acc - t_acc
mean_source_acc /= len(source_loaders)
mean_acc = total_acc / (len(source_loaders) + 1)

print(f"\nMean source accuracy: {mean_source_acc:.2f}%")
print(f"Mean domain accuracy (incl. target): {mean_acc:.2f}%")

final_results = {
    "best_target_acc": best_target_acc,
    "mean_source_acc": mean_source_acc,
    "domain_results": domain_results,
    "mean_acc": mean_acc
}

torch.save(model, OUTPUT_DIR / 'final_model_full.pth')
print("Saved final model to erm_outputs/")

domains = list(domain_results.keys())
accuracies = [domain_results[d]["acc"] for d in domains]

plt.figure(figsize=(8, 5))
bars = plt.bar(domains, accuracies, color='skyblue', edgecolor='black', linewidth=1.2)
plt.ylabel("Accuracy (%)", fontsize=12)
plt.title("Per-Domain Accuracy (ERM Final Model)", fontsize=13, pad=10)

plt.xticks(range(len(domains)), domains, rotation=0, ha='center', fontsize=11)
plt.yticks(fontsize=11)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig(OUTPUT_DIR / "domain_accuracy_bar.png", dpi=300, bbox_inches='tight')
plt.close()


  state_dict = torch.load(best_model_path, map_location=DEVICE)


Source art_painting: loss 0.1415, acc 98.09%
Source cartoon: loss 0.0337, acc 99.45%
Source photo: loss 0.2736, acc 92.81%

Mean source accuracy: 96.78%
Mean domain accuracy (incl. target): 91.30%
Saved final model to erm_outputs/
