In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from pathlib import Path
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import random_split, DataLoader, Dataset

  import pynvml  # type: ignore[import]


In [3]:
DATA_ROOT = Path(r"C:\Users\Fatim_Sproj\Desktop\Fatim\Spring 2025\Datasets\pacs_data\pacs_data")
SOURCE_DOMAINS = ["art_painting", "cartoon", "photo"]
TARGET_DOMAIN = "sketch"

IMAGE_SIZE = 224
BATCH_SIZE = 64
LR = 3e-4
EPOCHS = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

In [4]:
VAL_RATIO = 0.2

class PACSDataset(Dataset):
    def __init__(self, root, domain, transform=None):
        self.dataset = datasets.ImageFolder(root=root/domain, transform=transform)
        self.domain = domain
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        return img, label, self.domain

source_train_loaders = {}
source_val_loaders = {}

for domain in SOURCE_DOMAINS:
    full_dataset = PACSDataset(DATA_ROOT, domain, transform=train_transform)
    n_total = len(full_dataset)
    n_val = int(VAL_RATIO * n_total)
    n_train = n_total - n_val
    
    train_set, val_set = random_split(full_dataset, [n_train, n_val])
    
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    source_train_loaders[domain] = train_loader
    source_val_loaders[domain] = val_loader

target_dataset = PACSDataset(DATA_ROOT, TARGET_DOMAIN, transform=val_transform)
target_loader = DataLoader(target_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


In [7]:
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = 7
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(DEVICE)

In [8]:
import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from pathlib import Path

SAVE_DIR = Path("dro_outputs")
SAVE_DIR.mkdir(exist_ok=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

def train_group_dro(model, source_loaders, target_loader, optimizer, epochs, device):
    domain_names = list(source_loaders.keys())
    best_target_acc = 0.0
    best_worst_source_acc = 0.0

    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        domain_iterators = {d: iter(l) for d, l in source_loaders.items()}
        steps = min(len(l) for l in source_loaders.values())
        
        worst_losses = []
        domain_correct = defaultdict(int)
        domain_total = defaultdict(int)

        for step in range(steps):
            optimizer.zero_grad()
            losses = {}

            for domain in domain_names:
                try:
                    x, y, _ = next(domain_iterators[domain])
                except StopIteration:
                    domain_iterators[domain] = iter(source_loaders[domain])
                    x, y, _ = next(domain_iterators[domain])

                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = F.cross_entropy(logits, y)
                losses[domain] = loss

                preds = logits.argmax(dim=1)
                domain_correct[domain] += (preds == y).sum().item()
                domain_total[domain] += y.size(0)

            worst_domain = max(losses, key=losses.get)
            worst_loss = losses[worst_domain]
            worst_loss.backward()
            optimizer.step()

            worst_losses.append(worst_loss.item())

        print(f"\nEpoch {epoch}/{epochs}")
        avg_worst_loss = np.mean(worst_losses)
        print(f"Average worst-domain loss: {avg_worst_loss:.4f}")

        source_accs = {}
        for domain in domain_names:
            acc = domain_correct[domain] / domain_total[domain]
            source_accs[domain] = acc
            print(f"Source-domain '{domain}' accuracy: {acc*100:.2f}%")

        worst_source_acc = min(source_accs.values())
        print(f"Worst-source-domain accuracy: {worst_source_acc*100:.2f}%")

        target_acc = evaluate(model, target_loader, "Target Domain", device)

        if target_acc > best_target_acc:
            best_target_acc = target_acc
            torch.save(model.state_dict(), SAVE_DIR / "best_model_target.pth")
            print(f"Saved new best model (target acc = {target_acc*100:.2f}%)")

        if worst_source_acc > best_worst_source_acc:
            best_worst_source_acc = worst_source_acc
            torch.save(model.state_dict(), SAVE_DIR / "best_model_worst_source.pth")
            print(f"Saved new best model (worst-source acc = {worst_source_acc*100:.2f}%)")

        history.append({
            "epoch": epoch,
            "avg_worst_loss": avg_worst_loss,
            "source_accs": {d: float(a) for d, a in source_accs.items()},
            "worst_source_acc": float(worst_source_acc),
            "target_acc": float(target_acc)
        })

    epochs = [h["epoch"] for h in history]
    target_accs = [h["target_acc"] for h in history]
    worst_accs = [h["worst_source_acc"] for h in history]
    losses = [h["avg_worst_loss"] for h in history]

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, target_accs, label="Target Acc")
    plt.plot(epochs, worst_accs, label="Worst Source Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(SAVE_DIR / "accuracy_plot.png")
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, losses, label="Worst Loss", color="orange")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(SAVE_DIR / "loss_plot.png")
    plt.close()

    return best_target_acc, best_worst_source_acc


def evaluate(model, loader, name="Dataset", device="cuda"):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    acc = correct / total
    print(f"{name} Accuracy: {acc*100:.2f}%")
    return acc

In [9]:
best_target_acc, best_worst_source_acc = train_group_dro(
    model, source_train_loaders, target_loader, optimizer, EPOCHS, DEVICE
)


Epoch 1/10
Average worst-domain loss: 0.7635
Source-domain 'art_painting' accuracy: 77.23%
Source-domain 'cartoon' accuracy: 75.30%
Source-domain 'photo' accuracy: 85.70%
Worst-source-domain accuracy: 75.30%
Target Domain Accuracy: 55.59%
Saved new best model (target acc = 55.59%)
Saved new best model (worst-source acc = 75.30%)

Epoch 2/10
Average worst-domain loss: 0.2981
Source-domain 'art_painting' accuracy: 92.56%
Source-domain 'cartoon' accuracy: 93.45%
Source-domain 'photo' accuracy: 94.76%
Worst-source-domain accuracy: 92.56%
Target Domain Accuracy: 62.89%
Saved new best model (target acc = 62.89%)
Saved new best model (worst-source acc = 92.56%)

Epoch 3/10
Average worst-domain loss: 0.2040
Source-domain 'art_painting' accuracy: 94.94%
Source-domain 'cartoon' accuracy: 96.13%
Source-domain 'photo' accuracy: 96.26%
Worst-source-domain accuracy: 94.94%
Target Domain Accuracy: 64.88%
Saved new best model (target acc = 64.88%)
Saved new best model (worst-source acc = 94.94%)

Epo

In [11]:
best_model_path = SAVE_DIR / "best_model_target.pth"

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = 7
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(DEVICE)

state_dict = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

domain_results = {}
total_acc = 0.0

for d, loader in source_val_loaders.items():
    acc = evaluate(model, loader, f"Source {d}", DEVICE)
    domain_results[d] = acc
    total_acc += acc

t_acc = evaluate(model, target_loader, f"Target {TARGET_DOMAIN}", DEVICE)
domain_results[TARGET_DOMAIN] = t_acc
total_acc += t_acc

mean_source_acc = (total_acc - t_acc) / len(source_val_loaders)
mean_acc = total_acc / (len(source_val_loaders) + 1)

print(f"\nMean source accuracy: {mean_source_acc*100:.2f}%")
print(f"Mean domain accuracy (incl. target): {mean_acc*100:.2f}%")

torch.save(model, SAVE_DIR / "final_model_full.pth")
print("Saved final model to dro_outputs/")

domains = list(domain_results.keys())
accuracies = [domain_results[d] * 100 for d in domains]

plt.figure(figsize=(8, 5))
plt.bar(domains, accuracies, color='lightcoral', edgecolor='black', linewidth=1.2)
plt.ylabel("Accuracy (%)", fontsize=12)
plt.title("Per-Domain Accuracy (DRO Final Model)", fontsize=13, pad=10)
plt.xticks(rotation=0, ha='center', fontsize=11)
plt.yticks(fontsize=11)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig(SAVE_DIR / "domain_accuracy_bar.png", dpi=300, bbox_inches='tight')
plt.close()


  state_dict = torch.load(best_model_path, map_location=DEVICE)


Source art_painting Accuracy: 85.26%
Source cartoon Accuracy: 90.38%
Source photo Accuracy: 95.21%
Target sketch Accuracy: 68.90%

Mean source accuracy: 90.28%
Mean domain accuracy (incl. target): 84.94%
Saved final model to dro_outputs/
