In [None]:
from datasets import load_dataset

# Load the dataset (a single "train" split with image, domain and label columns)
dataset = load_dataset("flwrlabs/pacs")

# Collect the distinct values of the "domain" column.
# `Dataset.unique` reads only that column, so rows (and their images) are not
# materialized one by one just to discover the domain names.
labels = sorted(dataset["train"].unique("domain"))

# Create id2label mapping (string index -> domain name, in alphabetical order)
id2label = {str(i): label for i, label in enumerate(labels)}

# Print the mapping
print(id2label)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

{'0': 'art_painting', '1': 'cartoon', '2': 'photo', '3': 'sketch'}


In [None]:
# Standard imports
import os
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms, models

from datasets import load_dataset

In [None]:
def set_seed(seed: int):
    """Seed every random number generator used in this notebook.

    Applies `seed` to Python's `random`, NumPy's global RNG and PyTorch (CPU
    and all CUDA devices), and switches cuDNN to deterministic mode with
    benchmarking turned off.
    """
    # The cuDNN flags do not depend on the seeding calls, so set them first.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Global compute device: CUDA when available, otherwise CPU.
# The training and evaluation loops move every batch onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Fraction of rows whose highest-scoring class matches the label.
    preds: per-class scores (raw logits or probabilities), shape (N, C)
    labels: class indices, shape (N,)
    Returns a Python float.
    """
    predicted = preds.argmax(dim=1)
    hits = predicted.eq(labels)
    return hits.float().mean().item()

In [None]:
# Load the dataset
dataset_all = load_dataset("flwrlabs/pacs")  # has a single "train" split with domain, label, image fields
print(dataset_all)

# Domain names, hard-coded from the mapping printed by the first cell.
# NOTE: despite its name, `id2label` maps indices to *domain* names, not to class labels.
id2label = {'0': 'art_painting', '1': 'cartoon', '2': 'photo', '3': 'sketch'}
domains = list(id2label.values())  # ['art_painting', 'cartoon', 'photo', 'sketch']

# Transform definitions
def get_transforms(train: bool):
    """Build the torchvision preprocessing pipeline for one split.

    Both variants resize to 256x256 first and finish with `ToTensor` followed
    by per-channel normalization (the mean/std values conventionally used with
    ImageNet-pretrained models). They differ only in the cropping stage:

    * train=True  -> random resized crop to 224 plus random horizontal flip
    * train=False -> center crop to 224
    """
    if train:
        crop_steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        crop_steps = [transforms.CenterCrop(224)]

    return transforms.Compose([
        transforms.Resize((256, 256)),
        *crop_steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

from torchvision.datasets import FakeData  # used for wrapping images
from PIL import Image

class PACSSubset(torch.utils.data.Dataset):
    """
    Dataset wrapper for PACS data filtered by domain and optionally restricted
    to a subset of that domain (e.g. a train or test split).

    Only the row positions of the selected domain are stored. Each row is
    fetched from `hf_dataset` on demand in `__getitem__`, so constructing the
    wrapper does not materialize every row (and image) of the dataset.
    """
    def __init__(self, hf_dataset, domain_name, transform=None, indices=None):
        """
        hf_dataset: the HF dataset (e.g. dataset_all["train"]); a plain
                    sequence of row dicts is also accepted
        domain_name: e.g. "photo"
        transform: optional callable applied to the RGB image
        indices: optional list of indices (within domain subset) to use
        """
        self.hf_dataset = hf_dataset
        try:
            # Column access returns just the "domain" values, which avoids
            # touching the image column while filtering.
            domain_column = hf_dataset["domain"]
        except (TypeError, KeyError, IndexError):
            # Plain sequence of row dicts: no column access, scan the rows.
            domain_column = [ex["domain"] for ex in hf_dataset]

        # Positions in hf_dataset of the rows that belong to the domain.
        self.row_ids = [i for i, dom in enumerate(domain_column) if dom == domain_name]
        if indices is not None:
            # `indices` are relative to the domain subset, not to hf_dataset.
            self.row_ids = [self.row_ids[i] for i in indices]
        self.transform = transform

    @property
    def examples(self):
        """Selected rows as a list of dicts, materialized on each access.

        Kept so code that read the former `examples` attribute still works.
        """
        return [self.hf_dataset[i] for i in self.row_ids]

    def __len__(self):
        return len(self.row_ids)

    def __getitem__(self, idx):
        """Return `(image, label)` for the idx-th selected row."""
        ex = self.hf_dataset[self.row_ids[idx]]
        img = ex["image"].convert("RGB")
        label = ex["label"]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

def make_dataloaders(hf_dataset, source_domain, target_domain,
                     batch_size=32, test_split_ratio=0.2, num_workers=4, seed=0):
    """
    Build the loaders for a source-only domain-adaptation run.

    The source domain is split into train/test with a shuffle driven only by
    `seed`; the target domain is used in full for testing.

    hf_dataset: HF dataset split holding all domains (e.g. dataset_all["train"])
    source_domain / target_domain: domain names, e.g. "photo" / "sketch"
    batch_size: batch size for all three loaders
    test_split_ratio: fraction of the source domain held out for testing
    num_workers: worker processes per DataLoader
    seed: seed of the private RNG that shuffles the source indices

    Returns:
      src_train_loader, src_test_loader, tgt_test_loader
      (the train loader shuffles and drops the last incomplete batch; both
      test loaders keep order and keep every sample)
    """
    # Only the size of the source domain is needed here, so build the helper
    # subset without a transform and do not keep it around.
    n_src = len(PACSSubset(hf_dataset, source_domain))

    # Deterministic split: a dedicated Random instance leaves the global RNG untouched.
    indices = list(range(n_src))
    random.Random(seed).shuffle(indices)
    split = int(n_src * (1 - test_split_ratio))
    train_idx = indices[:split]
    test_idx = indices[split:]

    # Source train uses augmentation; both test sets use the deterministic pipeline.
    src_train = PACSSubset(hf_dataset, source_domain, transform=get_transforms(train=True), indices=train_idx)
    src_test  = PACSSubset(hf_dataset, source_domain, transform=get_transforms(train=False), indices=test_idx)
    # Target test uses the full target domain (no split).
    tgt_test  = PACSSubset(hf_dataset, target_domain, transform=get_transforms(train=False), indices=None)

    src_train_loader = DataLoader(src_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    src_test_loader  = DataLoader(src_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    tgt_test_loader  = DataLoader(tgt_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return src_train_loader, src_test_loader, tgt_test_loader

DatasetDict({
    train: Dataset({
        features: ['image', 'domain', 'label'],
        num_rows: 9991
    })
})


In [None]:
class ResNet50Classifier(nn.Module):
    """ResNet-50 feature extractor followed by a single linear classification head."""

    def __init__(self, num_classes=7, pretrained=True):
        """
        num_classes: output size of the linear head
        pretrained: load ImageNet weights into the backbone when true
        """
        super().__init__()
        if hasattr(models, "ResNet50_Weights"):
            # Current torchvision API: `pretrained=` is deprecated in favour of
            # `weights=`. IMAGENET1K_V1 is the checkpoint `pretrained=True`
            # used to select, so the loaded weights are unchanged.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
        else:
            # Older torchvision without the weights enums.
            self.backbone = models.resnet50(pretrained=pretrained)
        in_feats = self.backbone.fc.in_features
        # Replace the original fc layer so the backbone outputs pooled features.
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(in_feats, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feats = self.backbone(x)
        logits = self.classifier(feats)
        return logits

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimization pass over `loader`.

    Puts the model in train mode, takes one optimizer step per batch, and
    returns `(mean_loss, mean_accuracy)` where both means are weighted by the
    number of samples in each batch.
    """
    model.train()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for inputs, targets in tqdm(loader, desc="Train src"):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight per-batch statistics by batch size to get a per-sample mean.
        count = targets.size(0)
        loss_sum += batch_loss.item() * count
        acc_sum += accuracy(logits, targets) * count
        seen += count
    return loss_sum / seen, acc_sum / seen

@torch.no_grad()
def evaluate(model, loader, criterion=None):
    """Evaluate `model` on `loader` without tracking gradients.

    model: module to evaluate (switched to eval mode)
    loader: iterable of `(inputs, labels)` batches
    criterion: optional loss function; when None no loss is computed

    Returns `(avg_loss, avg_acc)`: sample-weighted means over the loader, with
    `avg_loss` set to None when no criterion was given.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n = 0
    for x, y in tqdm(loader, desc="Eval"):
        x = x.to(device)
        y = y.to(device)
        out = model(x)
        if criterion is not None:
            total_loss += criterion(out, y).item() * y.size(0)
        total_acc += accuracy(out, y) * y.size(0)
        n += y.size(0)
    # Same `is not None` test as inside the loop, so a criterion object that
    # happens to be falsy still gets its accumulated loss reported.
    avg_loss = total_loss / n if criterion is not None else None
    return avg_loss, total_acc / n

In [None]:
def run_source_only(hf_dataset, source_domain, target_domain,
                    num_classes=7, epochs=20, batch_size=32,
                    lr=1e-4, weight_decay=1e-5, seed=0):
    """Train on the source domain only and report source/target accuracy.

    hf_dataset: HF dataset split holding all domains
    source_domain / target_domain: domain names, e.g. "photo" / "sketch"
    num_classes: number of output classes of the classifier
    epochs, batch_size, lr, weight_decay: training hyper-parameters (Adam)
    seed: seeds the global RNGs and the source train/test split

    Returns a dict with keys "source_domain", "target_domain", "src_acc",
    "tgt_acc" (both measured with the restored best checkpoint) and
    "best_tgt_acc" (highest per-epoch target accuracy seen during training).
    """
    set_seed(seed)
    src_train_loader, src_test_loader, tgt_test_loader = make_dataloaders(
        hf_dataset, source_domain, target_domain, batch_size=batch_size, seed=seed)

    model = ResNet50Classifier(num_classes=num_classes, pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_tgt_acc = 0.0
    best_state = None

    for ep in range(epochs):
        print(f"Epoch {ep}/{epochs-1}")
        tr_loss, tr_acc = train_one_epoch(model, src_train_loader, optimizer, criterion)
        print(f"  Src Train — loss: {tr_loss:.4f}, acc: {tr_acc:.4f}")

        _, src_test_acc = evaluate(model, src_test_loader, criterion)
        print(f"  Src Test — acc: {src_test_acc:.4f}")

        _, tgt_test_acc = evaluate(model, tgt_test_loader, criterion)
        print(f"  Tgt Test — acc: {tgt_test_acc:.4f}")

        # NOTE(review): the checkpoint is selected on the target test set, so
        # the reported target accuracy is an optimistic estimate.
        if tgt_test_acc > best_tgt_acc:
            best_tgt_acc = tgt_test_acc
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Clone them so the
            # snapshot really holds this epoch's weights.
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}

    # After training, load best model
    if best_state is not None:
        model.load_state_dict(best_state)
    _, final_src_acc = evaluate(model, src_test_loader, criterion)
    _, final_tgt_acc = evaluate(model, tgt_test_loader, criterion)

    return {
        "source_domain": source_domain,
        "target_domain": target_domain,
        "src_acc": final_src_acc,
        "tgt_acc": final_tgt_acc,
        "best_tgt_acc": best_tgt_acc
    }

In [None]:
def multiple_runs(hf_dataset, source_domain, target_domain,
                  seeds=(0, 1, 2), **kwargs):
    """Repeat `run_source_only` once per seed and print a summary.

    hf_dataset: HF dataset split holding all domains
    source_domain / target_domain: domain names, e.g. "photo" / "sketch"
    seeds: iterable of integer seeds, one run each (immutable default so the
           default value can never be shared and mutated between calls)
    **kwargs: forwarded unchanged to `run_source_only`

    Prints each run's result plus the mean and (population) standard deviation
    of source and target accuracy, and returns the list of per-run result dicts.
    """
    results = []
    for s in seeds:
        print("=== Seed", s, "===")
        res = run_source_only(hf_dataset, source_domain, target_domain, seed=s, **kwargs)
        print("Result:", res)
        results.append(res)
    src_accs = [r["src_acc"] for r in results]
    tgt_accs = [r["tgt_acc"] for r in results]
    print("---- Summary ----")
    print("Source Acc: mean = {:.4f}, std = {:.4f}".format(np.mean(src_accs), np.std(src_accs)))
    print("Target Acc: mean = {:.4f}, std = {:.4f}".format(np.mean(tgt_accs), np.std(tgt_accs)))
    return results

# Example: pick "photo" as source, "sketch" as target.
# Runs three seeds of 30 epochs each; every epoch also evaluates on the source
# test split and on the full target domain, so this cell is long-running.
results = multiple_runs(dataset_all["train"], source_domain="photo", target_domain="sketch",
                        num_classes=7, epochs=30, batch_size=32, lr=1e-4, weight_decay=1e-5, seeds=[0,1,2])
print(results)

=== Seed 0 ===




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 184MB/s]


Epoch 0/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.58it/s]


  Src Train — loss: 0.4753, acc: 0.8765


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.54it/s]


  Src Test — acc: 0.9880


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.47it/s]


  Tgt Test — acc: 0.1710
Epoch 1/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.77it/s]


  Src Train — loss: 0.2108, acc: 0.9299


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.56it/s]


  Src Test — acc: 0.9880


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.20it/s]


  Tgt Test — acc: 0.2955
Epoch 2/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.80it/s]


  Src Train — loss: 0.1691, acc: 0.9505


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.74it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.10it/s]


  Tgt Test — acc: 0.3306
Epoch 3/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.74it/s]


  Src Train — loss: 0.1611, acc: 0.9451


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.09it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.01it/s]


  Tgt Test — acc: 0.1871
Epoch 4/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.1128, acc: 0.9688


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.44it/s]


  Src Test — acc: 0.9850


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.04it/s]


  Tgt Test — acc: 0.1662
Epoch 5/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.1363, acc: 0.9543


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.62it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.89it/s]


  Tgt Test — acc: 0.1575
Epoch 6/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.65it/s]


  Src Train — loss: 0.0879, acc: 0.9710


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.51it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.59it/s]


  Tgt Test — acc: 0.2245
Epoch 7/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0954, acc: 0.9688


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.66it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.02it/s]


  Tgt Test — acc: 0.1822
Epoch 8/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0941, acc: 0.9695


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.82it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.99it/s]


  Tgt Test — acc: 0.1553
Epoch 9/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0946, acc: 0.9703


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.55it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.98it/s]


  Tgt Test — acc: 0.1611
Epoch 10/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0993, acc: 0.9649


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.27it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.73it/s]


  Tgt Test — acc: 0.1644
Epoch 11/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.1142, acc: 0.9619


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.93it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.00it/s]


  Tgt Test — acc: 0.2594
Epoch 12/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.0830, acc: 0.9756


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.58it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.54it/s]


  Tgt Test — acc: 0.2367
Epoch 13/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.1161, acc: 0.9604


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.65it/s]


  Src Test — acc: 0.9880


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.93it/s]


  Tgt Test — acc: 0.2166
Epoch 14/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.74it/s]


  Src Train — loss: 0.0704, acc: 0.9764


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.74it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.01it/s]


  Tgt Test — acc: 0.1639
Epoch 15/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.0920, acc: 0.9718


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.67it/s]


  Src Test — acc: 0.9671


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.90it/s]


  Tgt Test — acc: 0.1637
Epoch 16/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.75it/s]


  Src Train — loss: 0.1336, acc: 0.9497


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.61it/s]


  Src Test — acc: 0.9641


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.01it/s]


  Tgt Test — acc: 0.1637
Epoch 17/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0897, acc: 0.9718


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.77it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.99it/s]


  Tgt Test — acc: 0.2087
Epoch 18/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.65it/s]


  Src Train — loss: 0.0703, acc: 0.9741


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.59it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.62it/s]


  Tgt Test — acc: 0.1891
Epoch 19/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0727, acc: 0.9756


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.76it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.95it/s]


  Tgt Test — acc: 0.1947
Epoch 20/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.0982, acc: 0.9665


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.45it/s]


  Src Test — acc: 0.9581


Eval: 100%|██████████| 123/123 [00:13<00:00,  9.04it/s]


  Tgt Test — acc: 0.2362
Epoch 21/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.75it/s]


  Src Train — loss: 0.0823, acc: 0.9741


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.91it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.97it/s]


  Tgt Test — acc: 0.1575
Epoch 22/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0715, acc: 0.9771


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.65it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.79it/s]


  Tgt Test — acc: 0.1588
Epoch 23/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.66it/s]


  Src Train — loss: 0.0733, acc: 0.9764


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.81it/s]


  Src Test — acc: 0.9671


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.97it/s]


  Tgt Test — acc: 0.2461
Epoch 24/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.1277, acc: 0.9642


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.68it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.45it/s]


  Tgt Test — acc: 0.2609
Epoch 25/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0907, acc: 0.9649


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.67it/s]


  Src Test — acc: 0.9850


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.96it/s]


  Tgt Test — acc: 0.2336
Epoch 26/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.73it/s]


  Src Train — loss: 0.0771, acc: 0.9748


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.53it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.95it/s]


  Tgt Test — acc: 0.3296
Epoch 27/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0738, acc: 0.9764


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.43it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.79it/s]


  Tgt Test — acc: 0.2957
Epoch 28/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.67it/s]


  Src Train — loss: 0.0793, acc: 0.9756


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.70it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.93it/s]


  Tgt Test — acc: 0.2168
Epoch 29/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.74it/s]


  Src Train — loss: 0.0971, acc: 0.9710


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.69it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.95it/s]


  Tgt Test — acc: 0.3215


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.34it/s]
Eval: 100%|██████████| 123/123 [00:14<00:00,  8.78it/s]


Result: {'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.973053893643225, 'tgt_acc': 0.3214558411809621, 'best_tgt_acc': 0.3306184779842199}
=== Seed 1 ===
Epoch 0/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.5001, acc: 0.8537


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.34it/s]


  Src Test — acc: 0.9641


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.84it/s]


  Tgt Test — acc: 0.2459
Epoch 1/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.2073, acc: 0.9299


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.90it/s]


  Tgt Test — acc: 0.2945
Epoch 2/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.1557, acc: 0.9482


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.44it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.85it/s]


  Tgt Test — acc: 0.2639
Epoch 3/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.1287, acc: 0.9627


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.46it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.83it/s]


  Tgt Test — acc: 0.2670
Epoch 4/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.67it/s]


  Src Train — loss: 0.1165, acc: 0.9642


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.28it/s]


  Src Test — acc: 0.9671


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.47it/s]


  Tgt Test — acc: 0.1573
Epoch 5/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1238, acc: 0.9588


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.45it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.88it/s]


  Tgt Test — acc: 0.3553
Epoch 6/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.1119, acc: 0.9634


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.43it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.64it/s]


  Tgt Test — acc: 0.1560
Epoch 7/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.66it/s]


  Src Train — loss: 0.0984, acc: 0.9703


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.37it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.85it/s]


  Tgt Test — acc: 0.1558
Epoch 8/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0837, acc: 0.9695


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.18it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.90it/s]


  Tgt Test — acc: 0.1583
Epoch 9/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.67it/s]


  Src Train — loss: 0.0981, acc: 0.9718


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.57it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.85it/s]


  Tgt Test — acc: 0.1591
Epoch 10/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1037, acc: 0.9657


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.74it/s]


  Src Test — acc: 0.9611


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.48it/s]


  Tgt Test — acc: 0.2255
Epoch 11/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.1433, acc: 0.9489


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.14it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.83it/s]


  Tgt Test — acc: 0.1659
Epoch 12/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.73it/s]


  Src Train — loss: 0.0986, acc: 0.9665


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.30it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.86it/s]


  Tgt Test — acc: 0.1924
Epoch 13/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0866, acc: 0.9741


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.11it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.70it/s]


  Tgt Test — acc: 0.2713
Epoch 14/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.0783, acc: 0.9733


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.35it/s]


  Src Test — acc: 0.9491


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.88it/s]


  Tgt Test — acc: 0.2041
Epoch 15/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.1231, acc: 0.9634


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.46it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.79it/s]


  Tgt Test — acc: 0.1799
Epoch 16/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.62it/s]


  Src Train — loss: 0.1055, acc: 0.9695


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.53it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.45it/s]


  Tgt Test — acc: 0.3011
Epoch 17/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1189, acc: 0.9657


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.52it/s]


  Src Test — acc: 0.9401


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.82it/s]


  Tgt Test — acc: 0.3087
Epoch 18/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.65it/s]


  Src Train — loss: 0.0831, acc: 0.9733


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.70it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.89it/s]


  Tgt Test — acc: 0.2499
Epoch 19/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0641, acc: 0.9787


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.41it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.80it/s]


  Tgt Test — acc: 0.2685
Epoch 20/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.0934, acc: 0.9741


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.12it/s]


  Src Test — acc: 0.9611


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.81it/s]


  Tgt Test — acc: 0.2820
Epoch 21/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0656, acc: 0.9779


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.46it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.84it/s]


  Tgt Test — acc: 0.3166
Epoch 22/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1093, acc: 0.9649


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.19it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:15<00:00,  8.16it/s]


  Tgt Test — acc: 0.1769
Epoch 23/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.0924, acc: 0.9733


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.40it/s]


  Src Test — acc: 0.9581


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.77it/s]


  Tgt Test — acc: 0.1782
Epoch 24/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0856, acc: 0.9657


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.61it/s]


  Src Test — acc: 0.9970


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.65it/s]


  Tgt Test — acc: 0.1573
Epoch 25/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.64it/s]


  Src Train — loss: 0.0890, acc: 0.9688


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.43it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s]


  Tgt Test — acc: 0.2672
Epoch 26/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0853, acc: 0.9733


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.36it/s]


  Src Test — acc: 0.9581


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.78it/s]


  Tgt Test — acc: 0.2863
Epoch 27/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.63it/s]


  Src Train — loss: 0.0666, acc: 0.9771


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.50it/s]


  Src Test — acc: 0.9910


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.79it/s]


  Tgt Test — acc: 0.2993
Epoch 28/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0630, acc: 0.9802


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.50it/s]


  Src Test — acc: 0.9850


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.57it/s]


  Tgt Test — acc: 0.2530
Epoch 29/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.0474, acc: 0.9878


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.62it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.74it/s]


  Tgt Test — acc: 0.2568


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.40it/s]
Eval: 100%|██████████| 123/123 [00:13<00:00,  8.81it/s]


Result: {'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.9700598802395209, 'tgt_acc': 0.25680834818019854, 'best_tgt_acc': 0.3553066938152202}
=== Seed 2 ===
Epoch 0/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.4906, acc: 0.8765


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.45it/s]


  Src Test — acc: 0.9611


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.68it/s]


  Tgt Test — acc: 0.1598
Epoch 1/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.60it/s]


  Src Train — loss: 0.1826, acc: 0.9413


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]


  Src Test — acc: 0.9671


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.80it/s]


  Tgt Test — acc: 0.2568
Epoch 2/29


Train src: 100%|██████████| 41/41 [00:14<00:00,  2.74it/s]


  Src Train — loss: 0.1556, acc: 0.9497


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.47it/s]


  Src Test — acc: 0.9491


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.41it/s]


  Tgt Test — acc: 0.2985
Epoch 3/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.1525, acc: 0.9466


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.98it/s]


  Src Test — acc: 0.9551


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.70it/s]


  Tgt Test — acc: 0.3627
Epoch 4/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1238, acc: 0.9619


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.45it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.77it/s]


  Tgt Test — acc: 0.1677
Epoch 5/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1191, acc: 0.9634


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.00it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.66it/s]


  Tgt Test — acc: 0.2522
Epoch 6/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.1193, acc: 0.9604


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.30it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.72it/s]


  Tgt Test — acc: 0.1624
Epoch 7/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1276, acc: 0.9596


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.09it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.52it/s]


  Tgt Test — acc: 0.1952
Epoch 8/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.0892, acc: 0.9665


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.57it/s]


  Src Test — acc: 0.9850


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.37it/s]


  Tgt Test — acc: 0.2510
Epoch 9/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0786, acc: 0.9779


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.60it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.74it/s]


  Tgt Test — acc: 0.3039
Epoch 10/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.62it/s]


  Src Train — loss: 0.0958, acc: 0.9703


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.23it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s]


  Tgt Test — acc: 0.1550
Epoch 11/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.1211, acc: 0.9627


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.57it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.72it/s]


  Tgt Test — acc: 0.2128
Epoch 12/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.67it/s]


  Src Train — loss: 0.0936, acc: 0.9718


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.50it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s]


  Tgt Test — acc: 0.2433
Epoch 13/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0870, acc: 0.9787


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.36it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s]


  Tgt Test — acc: 0.1550
Epoch 14/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.0960, acc: 0.9710


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.35it/s]


  Src Test — acc: 0.9940


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.36it/s]


  Tgt Test — acc: 0.1547
Epoch 15/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0811, acc: 0.9726


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.71it/s]


  Tgt Test — acc: 0.1904
Epoch 16/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0755, acc: 0.9802


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.42it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.59it/s]


  Tgt Test — acc: 0.1614
Epoch 17/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0725, acc: 0.9779


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.48it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.67it/s]


  Tgt Test — acc: 0.1550
Epoch 18/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.1038, acc: 0.9688


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.03it/s]


  Src Test — acc: 0.9880


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.59it/s]


  Tgt Test — acc: 0.1568
Epoch 19/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.0874, acc: 0.9710


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.89it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.76it/s]


  Tgt Test — acc: 0.2072
Epoch 20/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.72it/s]


  Src Train — loss: 0.0762, acc: 0.9748


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.08it/s]


  Src Test — acc: 0.9641


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.24it/s]


  Tgt Test — acc: 0.2059
Epoch 21/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.63it/s]


  Src Train — loss: 0.1132, acc: 0.9627


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.35it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:13<00:00,  8.81it/s]


  Tgt Test — acc: 0.3105
Epoch 22/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0894, acc: 0.9756


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.32it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.78it/s]


  Tgt Test — acc: 0.1715
Epoch 23/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.0499, acc: 0.9863


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.67it/s]


  Src Test — acc: 0.9611


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.74it/s]


  Tgt Test — acc: 0.2008
Epoch 24/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0872, acc: 0.9680


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.54it/s]


  Src Test — acc: 0.9760


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.74it/s]


  Tgt Test — acc: 0.3062
Epoch 25/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.68it/s]


  Src Train — loss: 0.0761, acc: 0.9748


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.10it/s]


  Src Test — acc: 0.9701


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.75it/s]


  Tgt Test — acc: 0.2543
Epoch 26/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0719, acc: 0.9764


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.97it/s]


  Src Test — acc: 0.9731


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.40it/s]


  Tgt Test — acc: 0.1771
Epoch 27/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.70it/s]


  Src Train — loss: 0.0650, acc: 0.9764


Eval: 100%|██████████| 11/11 [00:02<00:00,  5.01it/s]


  Src Test — acc: 0.9611


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.75it/s]


  Tgt Test — acc: 0.2352
Epoch 28/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.69it/s]


  Src Train — loss: 0.0730, acc: 0.9748


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.38it/s]


  Src Test — acc: 0.9820


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.75it/s]


  Tgt Test — acc: 0.1827
Epoch 29/29


Train src: 100%|██████████| 41/41 [00:15<00:00,  2.71it/s]


  Src Train — loss: 0.0705, acc: 0.9756


Eval: 100%|██████████| 11/11 [00:01<00:00,  6.08it/s]


  Src Test — acc: 0.9790


Eval: 100%|██████████| 123/123 [00:14<00:00,  8.54it/s]


  Tgt Test — acc: 0.1825


Eval: 100%|██████████| 11/11 [00:01<00:00,  5.65it/s]
Eval: 100%|██████████| 123/123 [00:14<00:00,  8.73it/s]

Result: {'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.9790419161676647, 'tgt_acc': 0.18248918299821837, 'best_tgt_acc': 0.3626877067956223}
---- Summary ----
Source Acc: mean = 0.9741, std = 0.0037
Target Acc: mean = 0.2536, std = 0.0568
[{'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.973053893643225, 'tgt_acc': 0.3214558411809621, 'best_tgt_acc': 0.3306184779842199}, {'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.9700598802395209, 'tgt_acc': 0.25680834818019854, 'best_tgt_acc': 0.3553066938152202}, {'source_domain': 'photo', 'target_domain': 'sketch', 'src_acc': 0.9790419161676647, 'tgt_acc': 0.18248918299821837, 'best_tgt_acc': 0.3626877067956223}]





Task 2

In [None]:
import os
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms, models
from datasets import load_dataset

from sklearn.metrics import classification_report, confusion_matrix

In [None]:
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also sets ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
    """
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Module-level compute device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    preds:  per-class scores, reduced over dim 1.
    labels: class indices compared element-wise with the predictions.
    """
    predicted = torch.max(preds, dim=1).indices
    hits = predicted.eq(labels)
    return hits.float().mean().item()

In [None]:
# Load the HF PACS dataset (the printed summary below shows a single "train"
# split with 'image', 'domain' and 'label' columns).
dataset_all = load_dataset("flwrlabs/pacs")
print(dataset_all)

# NOTE: despite its name, this mapping enumerates the *domain* names, not the
# object classes; `domains` is simply the list of its values.
id2label = {'0': 'art_painting', '1': 'cartoon', '2': 'photo', '3': 'sketch'}
domains = list(id2label.values())

def get_transforms(train: bool):
    """Build the image preprocessing pipeline for training or evaluation.

    Both variants resize to 256x256 first and finish with ToTensor + Normalize.
    In between, training applies a random resized 224 crop and a random
    horizontal flip, while evaluation applies a 224 center crop.
    """
    if train:
        crop_steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        crop_steps = [transforms.CenterCrop(224)]

    return transforms.Compose([
        transforms.Resize((256, 256)),
        *crop_steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

class PACSSubset(torch.utils.data.Dataset):
    """Map-style dataset holding the rows of ``hf_dataset`` from one domain.

    Rows are materialised into a list at construction time. ``indices``, when
    given, selects (and orders) positions within that domain-filtered list.
    """

    def __init__(self, hf_dataset, domain_name, transform=None, indices=None):
        matching = [row for row in hf_dataset if row["domain"] == domain_name]
        self.examples = matching if indices is None else [matching[i] for i in indices]
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        image = record["image"].convert("RGB")
        target = record["label"]
        if self.transform:
            image = self.transform(image)
        return image, target

def make_dataloaders(hf_dataset, source_domain, target_domain,
                     batch_size=32, test_split_ratio=0.2, seed=0, num_workers=4):
    """Build the source train/test loaders and the two target loaders.

    The source domain is split into train/test with a seeded shuffle; the
    target domain is used in full, both for evaluation and for alignment.

    Returns, IN THIS ORDER:
        src_train_loader        - shuffled, drop_last, training transforms
        src_test_loader         - unshuffled, evaluation transforms
        tgt_test_loader         - unshuffled full target set, evaluation transforms
        tgt_train_for_alignment - shuffled, drop_last full target set,
                                  evaluation transforms (labels are ignored by
                                  the training loops that consume it)
    """
    # Filter the raw dataset once per domain. Every PACSSubset construction
    # iterates over its whole input, so the source split subsets below are built
    # from the already-filtered example list instead of re-scanning hf_dataset.
    src_all = PACSSubset(hf_dataset, source_domain)

    # Seeded train/test split of the source domain.
    n_src = len(src_all)
    indices = list(range(n_src))
    random.Random(seed).shuffle(indices)
    split = int(n_src * (1 - test_split_ratio))
    train_idx = indices[:split]
    test_idx = indices[split:]

    src_train = PACSSubset(src_all.examples, source_domain,
                           transform=get_transforms(train=True),
                           indices=train_idx)
    src_test = PACSSubset(src_all.examples, source_domain,
                          transform=get_transforms(train=False),
                          indices=test_idx)

    # The full target set backs both the evaluation loader and the alignment
    # loader; only the loader options (shuffle / drop_last) differ.
    tgt_all = PACSSubset(hf_dataset, target_domain,
                         transform=get_transforms(train=False),
                         indices=None)

    src_train_loader = DataLoader(src_train, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers, drop_last=True)
    src_test_loader = DataLoader(src_test, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)
    tgt_test_loader = DataLoader(tgt_all, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)
    tgt_train_for_alignment = DataLoader(tgt_all, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers,
                                         drop_last=True)

    return src_train_loader, src_test_loader, tgt_test_loader, tgt_train_for_alignment

DatasetDict({
    train: Dataset({
        features: ['image', 'domain', 'label'],
        num_rows: 9991
    })
})


In [None]:
class FeatureExtractor(nn.Module):
    """ResNet-50 backbone whose final ``fc`` layer is replaced by identity.

    ``out_dim`` records the input width of the removed ``fc`` layer, i.e. the
    size of the feature vector returned by ``forward``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        self.out_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

    def forward(self, x):
        features = self.backbone(x)
        return features

class ClassifierHead(nn.Module):
    """Single linear layer mapping ``in_dim`` features to ``num_classes`` logits."""

    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, feats):
        logits = self.fc(feats)
        return logits


In [None]:
class GradReverse(torch.autograd.Function):
    """Autograd function: identity forward, gradient multiplied by ``-alpha`` backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale on the context so backward can read it.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output * ctx.alpha
        # Second slot is the gradient w.r.t. ``alpha``: none is provided.
        return flipped, None

def grad_reverse(x, alpha):
    """Functional wrapper: run ``x`` through ``GradReverse`` with scale ``alpha``."""
    reversed_x = GradReverse.apply(x, alpha)
    return reversed_x

class DomainDiscriminator(nn.Module):
    """MLP (Linear-ReLU-Linear-ReLU-Linear) that outputs 2 logits per input row."""

    def __init__(self, in_dim, hidden_dim=1024):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        domain_logits = self.net(x)
        return domain_logits


In [None]:
def gaussian_kernel(x, y, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of ``kernel_num`` Gaussian RBF kernels over the rows of ``[x; y]``.

    Args:
        x, y: 2-D tensors with the same number of columns; they are stacked
            along dim 0 and all pairwise squared L2 distances are computed.
        kernel_mul: geometric ratio between consecutive bandwidths.
        kernel_num: number of kernels summed.
        fix_sigma: if truthy, used as the base bandwidth; otherwise the base
            bandwidth is the mean off-diagonal squared distance (treated as a
            constant, i.e. no gradient flows through it).

    Returns:
        Tensor of shape (len(x) + len(y), len(x) + len(y)).
    """
    n_samples = x.size(0) + y.size(0)
    total = torch.cat([x, y], dim=0)
    # Pairwise squared L2 distances via broadcasting: (N, 1, D) - (1, N, D).
    l2dist = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean-distance heuristic over the N^2 - N off-diagonal pairs.
        # Guard both degenerate cases that previously produced NaN kernels:
        # a single sample (0 / 0) and fully collapsed features (bandwidth 0).
        pair_count = max(n_samples ** 2 - n_samples, 1)
        bandwidth = l2dist.detach().sum() / pair_count
        bandwidth = torch.clamp(bandwidth, min=torch.finfo(bandwidth.dtype).eps)

    # Bandwidths are centred on the base value: base * mul^(i - num // 2).
    bandwidth_list = [bandwidth * (kernel_mul ** (i - kernel_num // 2))
                      for i in range(kernel_num)]
    kernel_val = [torch.exp(-l2dist / b) for b in bandwidth_list]
    return sum(kernel_val)

def mmd_loss(source_feats, target_feats):
    """MMD estimate: mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts).

    ``K`` is the joint kernel matrix from ``gaussian_kernel`` over the stacked
    source and target rows; the four terms are its source/target sub-blocks.
    """
    n_src = source_feats.size(0)
    n_tgt = target_feats.size(0)
    kernel = gaussian_kernel(source_feats, target_feats)

    src_rows = slice(0, n_src)
    tgt_rows = slice(n_src, n_src + n_tgt)

    mean_ss = kernel[src_rows, src_rows].mean()
    mean_tt = kernel[tgt_rows, tgt_rows].mean()
    mean_st = kernel[src_rows, tgt_rows].mean()
    mean_ts = kernel[tgt_rows, src_rows].mean()
    return mean_ss + mean_tt - mean_st - mean_ts

def cd_input(feats, class_probs):
    """Per-sample outer product of class probabilities and features, flattened.

    For feats of shape (batch, d) and class_probs of shape (batch, C), the
    result has shape (batch, C * d), laid out class-major: entry c * d + j
    is class_probs[:, c] * feats[:, j].
    """
    outer = class_probs[:, :, None] * feats[:, None, :]  # (batch, C, d)
    return outer.reshape(feats.size(0), -1)


In [None]:
def train_epoch_source_only(feat_ext, classifier, optimizer, src_loader):
    """Run one supervised epoch on the source loader.

    Returns (mean cross-entropy loss, mean accuracy), both averaged per sample
    over every batch seen.
    """
    feat_ext.train()
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for images, targets in tqdm(src_loader, desc="Train SourceOnly"):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = classifier(feat_ext(images))
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        batch = targets.size(0)
        loss_sum += loss.item() * batch
        acc_sum += accuracy(logits, targets) * batch
        seen += batch
    return loss_sum / seen, acc_sum / seen

def train_epoch_dan(feat_ext, classifier, optimizer, src_loader, tgt_loader, lambda_mmd):
    """Run one epoch of source classification plus an MMD alignment penalty.

    Source and target batches are consumed in lock-step (the shorter loader
    ends the epoch); target labels are ignored. The optimised loss is
    ``cross_entropy + lambda_mmd * mmd_loss(source_feats, target_feats)``.

    Returns (mean total loss, mean source accuracy), averaged per source sample.
    """
    feat_ext.train()
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for (src_x, src_y), (tgt_x, _) in zip(src_loader, tgt_loader):
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x = tgt_x.to(device)

        optimizer.zero_grad()
        src_feats = feat_ext(src_x)
        tgt_feats = feat_ext(tgt_x)
        src_logits = classifier(src_feats)
        cls_term = criterion(src_logits, src_y)
        mmd_term = mmd_loss(src_feats, tgt_feats)
        loss = cls_term + lambda_mmd * mmd_term
        loss.backward()
        optimizer.step()

        batch = src_y.size(0)
        loss_sum += loss.item() * batch
        acc_sum += accuracy(src_logits, src_y) * batch
        seen += batch
    return loss_sum / seen, acc_sum / seen

def train_epoch_dann(feat_ext, classifier, dom_disc, optimizer, src_loader, tgt_loader, alpha):
    """Run one epoch of adversarial training with a gradient-reversed domain head.

    Source and target features are concatenated, passed through
    ``grad_reverse(..., alpha)`` and classified by ``dom_disc`` as domain 0
    (source) or 1 (target). The optimised loss is classification loss plus
    domain loss; target labels are ignored.

    Returns (mean source classification loss, mean source accuracy). Note the
    reported loss excludes the domain term.
    """
    feat_ext.train()
    classifier.train()
    dom_disc.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for (src_x, src_y), (tgt_x, _) in zip(src_loader, tgt_loader):
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x = tgt_x.to(device)

        optimizer.zero_grad()
        src_feats = feat_ext(src_x)
        tgt_feats = feat_ext(tgt_x)
        src_logits = classifier(src_feats)
        cls_term = criterion(src_logits, src_y)

        # Domain branch: source rows are labelled 0, target rows 1.
        stacked = torch.cat([src_feats, tgt_feats], dim=0)
        dom_logits = dom_disc(grad_reverse(stacked, alpha))
        dom_targets = torch.cat([
            torch.zeros(src_feats.size(0), dtype=torch.long),
            torch.ones(tgt_feats.size(0), dtype=torch.long),
        ], dim=0).to(device)
        dom_term = criterion(dom_logits, dom_targets)

        loss = cls_term + dom_term
        loss.backward()
        optimizer.step()

        batch = src_y.size(0)
        loss_sum += cls_term.item() * batch
        acc_sum += accuracy(src_logits, src_y) * batch
        seen += batch
    return loss_sum / seen, acc_sum / seen

def train_epoch_cdan(feat_ext, classifier, dom_disc, optimizer, src_loader, tgt_loader, alpha):
    """Run one epoch of conditional adversarial training.

    Same structure as the DANN epoch, except the domain discriminator receives
    ``cd_input(features, softmax(logits))`` for both domains instead of the raw
    features. Target labels are ignored.

    Returns (mean source classification loss, mean source accuracy). Note the
    reported loss excludes the domain term.
    """
    feat_ext.train()
    classifier.train()
    dom_disc.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for (src_x, src_y), (tgt_x, _) in zip(src_loader, tgt_loader):
        src_x, src_y = src_x.to(device), src_y.to(device)
        tgt_x = tgt_x.to(device)

        optimizer.zero_grad()
        src_feats = feat_ext(src_x)
        tgt_feats = feat_ext(tgt_x)

        src_logits = classifier(src_feats)
        src_probs = torch.softmax(src_logits, dim=1)
        cls_term = criterion(src_logits, src_y)

        # Target predictions are only used to condition the discriminator input.
        tgt_logits = classifier(tgt_feats)
        tgt_probs = torch.softmax(tgt_logits, dim=1)

        conditioned = torch.cat([
            cd_input(src_feats, src_probs),
            cd_input(tgt_feats, tgt_probs),
        ], dim=0)
        dom_logits = dom_disc(grad_reverse(conditioned, alpha))
        dom_targets = torch.cat([
            torch.zeros(src_feats.size(0), dtype=torch.long),
            torch.ones(tgt_feats.size(0), dtype=torch.long),
        ], dim=0).to(device)
        dom_term = criterion(dom_logits, dom_targets)

        loss = cls_term + dom_term
        loss.backward()
        optimizer.step()

        batch = src_y.size(0)
        loss_sum += cls_term.item() * batch
        acc_sum += accuracy(src_logits, src_y) * batch
        seen += batch
    return loss_sum / seen, acc_sum / seen


In [None]:
@torch.no_grad()
def extract_feats_logits(feat_ext, classifier, loader):
    """Collect features, logits and labels for every batch of ``loader``.

    Both modules are put in eval mode. Returns three NumPy arrays
    ``(feats, logits, labels)``, each concatenated along dim 0 in loader order.
    """
    feat_ext.eval()
    classifier.eval()
    feat_chunks, logit_chunks, label_chunks = [], [], []
    for images, targets in loader:
        batch_feats = feat_ext(images.to(device))
        batch_logits = classifier(batch_feats)
        feat_chunks.append(batch_feats.cpu())
        logit_chunks.append(batch_logits.cpu())
        # Labels are appended exactly as the loader yields them.
        label_chunks.append(targets)

    feats = torch.cat(feat_chunks, dim=0).numpy()
    logits = torch.cat(logit_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    return feats, logits, labels

def evaluate_on_loader(feat_ext, classifier, loader):
    """Return only the accuracy component of ``evaluate_loop`` for ``loader``."""
    return evaluate_loop(feat_ext, classifier, loader)[1]

def evaluate_loop(feat_ext, classifier, loader):
    """Evaluate the model on ``loader`` without tracking gradients.

    Both modules are switched to eval mode. Returns
    ``(mean cross-entropy loss, mean accuracy)``, each averaged per sample.
    """
    feat_ext.eval(); classifier.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0; total_acc = 0.0; n = 0
    # Evaluation never calls backward(); disabling autograd avoids building a
    # graph (and holding activations) for every evaluation batch.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device); y = y.to(device)
            logit = classifier(feat_ext(x))
            loss = criterion(logit, y)
            total_loss += loss.item() * y.size(0)
            total_acc += accuracy(logit, y) * y.size(0)
            n += y.size(0)
    return total_loss / n, total_acc / n


In [None]:
def run_method(method_name, hf_dataset, source_domain, target_domain,
               num_classes, epochs=20, batch_size=32, lr=1e-4, weight_decay=1e-5,
               lambda_mmd=0.1, alpha=1.0, seed=0):
    """Train one adaptation method from ``source_domain`` to ``target_domain``.

    Args:
        method_name: one of "source_only", "DAN", "DANN", "CDAN".
        hf_dataset: iterable of rows with "image", "domain" and "label" fields.
        source_domain, target_domain: domain names used to filter ``hf_dataset``.
        num_classes: size of the classifier output.
        epochs, batch_size, lr, weight_decay: optimisation settings (Adam).
        lambda_mmd: weight of the MMD term (DAN only).
        alpha: gradient-reversal scale (DANN / CDAN only).
        seed: RNG seed, also used for the source train/test split.

    Returns:
        dict with keys "method", "src_acc", "tgt_acc", "class_report",
        "confusion_matrix" and "domain_proxy_acc", computed with the weights of
        the epoch that reached the highest target accuracy.

    Raises:
        ValueError: if ``method_name`` is not recognised (raised in the first epoch).
    """
    from sklearn.linear_model import LogisticRegression

    def _snapshot(module):
        # state_dict() returns references to the live parameter/buffer tensors,
        # which the optimizer keeps updating in place. Clone them so the saved
        # checkpoint really is the best epoch and not whatever came last.
        return {key: value.detach().clone() for key, value in module.state_dict().items()}

    set_seed(seed)
    # make_dataloaders returns (src_train, src_test, tgt_test, tgt_align) in that
    # order: the evaluation loader (unshuffled, no dropped samples) comes BEFORE
    # the shuffled drop_last loader that is only meant for alignment.
    src_train, src_test, tgt_test, tgt_align = make_dataloaders(
        hf_dataset, source_domain, target_domain, batch_size=batch_size, seed=seed)

    feat_ext = FeatureExtractor(pretrained=True).to(device)
    classifier = ClassifierHead(feat_ext.out_dim, num_classes).to(device)
    # CDAN feeds the discriminator the flattened (class x feature) product, so its
    # input is num_classes times wider than the plain feature vector.
    if method_name == "CDAN":
        dom_disc = DomainDiscriminator(feat_ext.out_dim * num_classes).to(device)
    else:
        dom_disc = DomainDiscriminator(feat_ext.out_dim).to(device)

    # Combined optimizer (backbone, classifier, disc)
    optimizer = optim.Adam(list(feat_ext.parameters()) +
                           list(classifier.parameters()) +
                           list(dom_disc.parameters()),
                           lr=lr, weight_decay=weight_decay)

    best_tgt = 0.0
    best_state = None

    for ep in range(epochs):
        print(f"Epoch {ep}/{epochs-1} - Method {method_name}")
        if method_name == "source_only":
            train_epoch_source_only(feat_ext, classifier, optimizer, src_train)
        elif method_name == "DAN":
            train_epoch_dan(feat_ext, classifier, optimizer, src_train, tgt_align, lambda_mmd)
        elif method_name == "DANN":
            train_epoch_dann(feat_ext, classifier, dom_disc, optimizer, src_train, tgt_align, alpha)
        elif method_name == "CDAN":
            train_epoch_cdan(feat_ext, classifier, dom_disc, optimizer, src_train, tgt_align, alpha)
        else:
            raise ValueError("Unknown method")

        # Evaluate
        _, src_acc = evaluate_loop(feat_ext, classifier, src_test)
        _, tgt_acc = evaluate_loop(feat_ext, classifier, tgt_test)
        print(f"  Source Acc: {src_acc:.4f}, Target Acc: {tgt_acc:.4f}")
        # NOTE(review): checkpoint selection uses labelled target accuracy, so the
        # reported target numbers are an optimistic (oracle-selected) estimate.
        if tgt_acc > best_tgt:
            best_tgt = tgt_acc
            best_state = {
                "feat": _snapshot(feat_ext),
                "clf": _snapshot(classifier),
                "disc": _snapshot(dom_disc)
            }

    # Restore the best-epoch weights (if any epoch beat the initial 0.0).
    if best_state is not None:
        feat_ext.load_state_dict(best_state["feat"])
        classifier.load_state_dict(best_state["clf"])
        dom_disc.load_state_dict(best_state["disc"])

    # Final evaluation
    _, src_acc = evaluate_loop(feat_ext, classifier, src_test)
    _, tgt_acc = evaluate_loop(feat_ext, classifier, tgt_test)

    # Extract features & predictions on target to get class-wise metrics
    f_t, log_t, y_t = extract_feats_logits(feat_ext, classifier, tgt_test)
    y_pred = np.argmax(log_t, axis=1)
    clf_report = classification_report(y_t, y_pred, digits=4, zero_division=0)
    conf_mat = confusion_matrix(y_t, y_pred)

    # Domain classifier proxy: logistic regression separating source-test from
    # target features (0 = source, 1 = target).
    # NOTE(review): the score is measured on the same rows the classifier was
    # fitted on, so it is an upper bound on domain separability.
    f_s, _, y_s = extract_feats_logits(feat_ext, classifier, src_test)
    X = np.vstack([f_s, f_t])
    y_dom = np.hstack([np.zeros(f_s.shape[0]), np.ones(f_t.shape[0])])
    dom_clf = LogisticRegression(max_iter=1000)
    dom_clf.fit(X, y_dom)
    dom_acc = dom_clf.score(X, y_dom)

    result = {
        "method": method_name,
        "src_acc": src_acc,
        "tgt_acc": tgt_acc,
        "class_report": clf_report,
        "confusion_matrix": conf_mat,
        "domain_proxy_acc": dom_acc
    }
    return result

In [None]:
# Driver cell: train DAN and DANN on photo -> sketch with identical settings and
# print, per method, the accuracies, the per-class report and the confusion
# matrix. Each result dict is also collected in `all_res`.
methods = ["DAN", "DANN"]
all_res = []
for m in methods:
    print("Running method:", m)
    res = run_method(m, dataset_all["train"], source_domain="photo",
                     target_domain="sketch", num_classes=7,
                     epochs=25, batch_size=32, lr=1e-4, weight_decay=1e-5,
                     lambda_mmd=0.1, alpha=1.0, seed=0)
    print("Result:", res["src_acc"], res["tgt_acc"], "domain proxy acc:", res["domain_proxy_acc"])
    print("Classification report (target):\n", res["class_report"])
    print("Confusion matrix (target):\n", res["confusion_matrix"])
    all_res.append(res)


Running method: DAN




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 190MB/s]


Epoch 0/24 - Method DAN
  Source Acc: 0.9731, Target Acc: 0.4682
Epoch 1/24 - Method DAN
  Source Acc: 0.9491, Target Acc: 0.4490
Epoch 2/24 - Method DAN
  Source Acc: 0.9611, Target Acc: 0.4196
Epoch 3/24 - Method DAN
  Source Acc: 0.9760, Target Acc: 0.4152
Epoch 4/24 - Method DAN
  Source Acc: 0.9760, Target Acc: 0.3901
Epoch 5/24 - Method DAN
  Source Acc: 0.9641, Target Acc: 0.4388
Epoch 6/24 - Method DAN
  Source Acc: 0.9311, Target Acc: 0.4070
Epoch 7/24 - Method DAN
  Source Acc: 0.9162, Target Acc: 0.3809
Epoch 8/24 - Method DAN
  Source Acc: 0.9162, Target Acc: 0.3950
Epoch 9/24 - Method DAN
  Source Acc: 0.9521, Target Acc: 0.3512
Epoch 10/24 - Method DAN
  Source Acc: 0.9671, Target Acc: 0.3799
Epoch 11/24 - Method DAN
  Source Acc: 0.9222, Target Acc: 0.4329
Epoch 12/24 - Method DAN
  Source Acc: 0.9581, Target Acc: 0.3737
Epoch 13/24 - Method DAN
  Source Acc: 0.9251, Target Acc: 0.3043
Epoch 14/24 - Method DAN
  Source Acc: 0.9581, Target Acc: 0.4214
Epoch 15/24 - Method



Result: 0.9251497034541147 0.3785860655737705 domain proxy acc: 1.0
Classification report (target):
               precision    recall  f1-score   support

           0     0.9091    0.0131    0.0258       765
           1     0.6809    0.0434    0.0816       737
           2     0.4014    0.6787    0.5045       750
           3     0.5913    0.5854    0.5883       603
           4     0.5110    0.6593    0.5757       810
           5     0.5439    0.3924    0.4559        79
           6     0.0114    0.0625    0.0192       160

    accuracy                         0.3788      3904
   macro avg     0.5213    0.3478    0.3216      3904
weighted avg     0.5926    0.3788    0.3377      3904

Confusion matrix (target):
 [[ 10   8 150  74 231   4 288]
 [  0  32  90  65 167   2 381]
 [  1   2 509  53  96   4  85]
 [  0   0 200 353  11   3  36]
 [  0   3 150  40 534   4  79]
 [  0   1  41   0   6  31   0]
 [  0   1 128  12   0   9  10]]
Running method: DANN




Epoch 0/24 - Method DANN
  Source Acc: 0.9611, Target Acc: 0.4439
Epoch 1/24 - Method DANN
  Source Acc: 0.9132, Target Acc: 0.3601
Epoch 2/24 - Method DANN
  Source Acc: 0.9192, Target Acc: 0.4203
Epoch 3/24 - Method DANN
  Source Acc: 0.8952, Target Acc: 0.4068
Epoch 4/24 - Method DANN
  Source Acc: 0.9251, Target Acc: 0.3371
Epoch 5/24 - Method DANN
  Source Acc: 0.8353, Target Acc: 0.3988
Epoch 6/24 - Method DANN
  Source Acc: 0.8683, Target Acc: 0.3107
Epoch 7/24 - Method DANN
  Source Acc: 0.9251, Target Acc: 0.3755
Epoch 8/24 - Method DANN
  Source Acc: 0.8743, Target Acc: 0.2861
Epoch 9/24 - Method DANN
  Source Acc: 0.8772, Target Acc: 0.3763
Epoch 10/24 - Method DANN
  Source Acc: 0.8653, Target Acc: 0.3573
Epoch 11/24 - Method DANN
  Source Acc: 0.8892, Target Acc: 0.3043
Epoch 12/24 - Method DANN
  Source Acc: 0.9341, Target Acc: 0.3668
Epoch 13/24 - Method DANN
  Source Acc: 0.9281, Target Acc: 0.3384
Epoch 14/24 - Method DANN
  Source Acc: 0.8713, Target Acc: 0.3340
Epoch



Result: 0.9041916196217794 0.2646004098360656 domain proxy acc: 1.0
Classification report (target):
               precision    recall  f1-score   support

           0     0.8649    0.0418    0.0798       765
           1     0.9062    0.0393    0.0754       737
           2     0.3262    0.2840    0.3036       750
           3     0.2274    0.8076    0.3548       603
           4     0.4177    0.2568    0.3180       810
           5     0.1344    0.8354    0.2316        79
           6     0.0000    0.0000    0.0000       160

    accuracy                         0.2651      3904
   macro avg     0.4110    0.3236    0.1948      3904
weighted avg     0.5277    0.2651    0.2137      3904

Confusion matrix (target):
 [[ 32   3  89 435 123  56  27]
 [  1  29  95 368 112 111  21]
 [  1   0 213 403  50  82   1]
 [  0   0  64 487   5  47   0]
 [  3   0  99 421 208  77   2]
 [  0   0  11   2   0  66   0]
 [  0   0  82  26   0  52   0]]


In [None]:
# Driver cell: same photo -> sketch experiment, CDAN only.
# NOTE(review): `all_res` is rebound to a fresh list here, so any results
# collected under that name by an earlier cell are no longer reachable.
methods = ["CDAN"]
all_res = []
for m in methods:
    print("Running method:", m)
    res = run_method(m, dataset_all["train"], source_domain="photo",
                     target_domain="sketch", num_classes=7,
                     epochs=25, batch_size=32, lr=1e-4, weight_decay=1e-5,
                     lambda_mmd=0.1, alpha=1.0, seed=0)
    print("Result:", res["src_acc"], res["tgt_acc"], "domain proxy acc:", res["domain_proxy_acc"])
    print("Classification report (target):\n", res["class_report"])
    print("Confusion matrix (target):\n", res["confusion_matrix"])
    all_res.append(res)

Running method: CDAN




Epoch 0/24 - Method CDAN
  Source Acc: 0.9341, Target Acc: 0.3699
Epoch 1/24 - Method CDAN
  Source Acc: 0.8503, Target Acc: 0.3796
Epoch 2/24 - Method CDAN
  Source Acc: 0.8922, Target Acc: 0.4062
Epoch 3/24 - Method CDAN
  Source Acc: 0.9251, Target Acc: 0.4047
Epoch 4/24 - Method CDAN
  Source Acc: 0.9641, Target Acc: 0.4070
Epoch 5/24 - Method CDAN
  Source Acc: 0.9162, Target Acc: 0.3988
Epoch 6/24 - Method CDAN
  Source Acc: 0.9222, Target Acc: 0.2748
Epoch 7/24 - Method CDAN
  Source Acc: 0.9192, Target Acc: 0.3847
Epoch 8/24 - Method CDAN
  Source Acc: 0.8683, Target Acc: 0.3320
Epoch 9/24 - Method CDAN
  Source Acc: 0.9132, Target Acc: 0.3858
Epoch 10/24 - Method CDAN
  Source Acc: 0.9311, Target Acc: 0.3537
Epoch 11/24 - Method CDAN
  Source Acc: 0.9311, Target Acc: 0.3814
Epoch 12/24 - Method CDAN
  Source Acc: 0.9281, Target Acc: 0.2636
Epoch 13/24 - Method CDAN
  Source Acc: 0.9042, Target Acc: 0.3589
Epoch 14/24 - Method CDAN
  Source Acc: 0.9431, Target Acc: 0.3322
Epoch



Result: 0.9221556904073247 0.3230020491803279 domain proxy acc: 1.0
Classification report (target):
               precision    recall  f1-score   support

           0     0.4881    0.2669    0.3451       768
           1     0.7368    0.0190    0.0371       735
           2     0.3860    0.1404    0.2059       748
           3     0.2591    0.9835    0.4102       605
           4     0.4562    0.3469    0.3941       810
           5     0.4460    0.7949    0.5714        78
           6     0.0070    0.0063    0.0066       160

    accuracy                         0.3235      3904
   macro avg     0.3971    0.3654    0.2815      3904
weighted avg     0.4527    0.3235    0.2714      3904

Confusion matrix (target):
 [[205   2  22 372  82  18  67]
 [ 81  14  25 429 131   7  48]
 [ 60   0 105 432 118  18  15]
 [  0   1   7 595   2   0   0]
 [ 74   0  77 349 281  18  11]
 [  0   1   1  12   2  62   0]
 [  0   1  35 107   0  16   1]]


Self-Training results

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import random
from tqdm.auto import tqdm

def set_seed(seed: int):
    """Seed every RNG in use (Python, NumPy, PyTorch CPU/CUDA) for reproducibility.

    Also pins cuDNN to ``deterministic=True`` / ``benchmark=False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of rows in ``preds`` whose arg-max over dim 1 matches ``labels``."""
    top_class = torch.max(preds, dim=1).indices
    matches = top_class.eq(labels).float()
    return matches.mean().item()

# Module-level compute device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data Handling ---

def get_transforms(train: bool):
    """Return the preprocessing pipeline for the requested mode.

    train=True : Resize(256) -> RandomResizedCrop(224) -> RandomHorizontalFlip
    train=False: Resize(256) -> CenterCrop(224)
    Both end with ToTensor followed by the same Normalize step.
    """
    head = [transforms.Resize((256, 256))]
    tail = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    if train:
        middle = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
    else:
        middle = [transforms.CenterCrop(224)]
    return transforms.Compose(head + middle + tail)

class PACSSubset(Dataset):
    """Dataset over the rows of ``hf_dataset`` whose "domain" equals ``domain_name``.

    The matching rows are copied into a list up front; ``indices`` optionally
    picks positions from that filtered list.
    """

    def __init__(self, hf_dataset, domain_name, transform=None, indices=None):
        in_domain = [row for row in hf_dataset if row["domain"] == domain_name]
        if indices is None:
            self.examples = in_domain
        else:
            self.examples = [in_domain[pos] for pos in indices]
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        row = self.examples[idx]
        image = row["image"].convert("RGB")
        target = row["label"]
        if self.transform:
            image = self.transform(image)
        return image, target

class PseudoLabelDataset(Dataset):
    """Pairs the examples of an existing subset with externally supplied labels.

    ``target_domain_subset`` must expose an ``examples`` list; ``pseudo_labels``
    must have the same length and is indexed in parallel with it.
    """

    def __init__(self, target_domain_subset, pseudo_labels, transform=None):
        examples = target_domain_subset.examples
        assert len(examples) == len(pseudo_labels), "Mismatch in data and pseudo-labels length"
        self.examples = examples
        self.pseudo_labels = pseudo_labels
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        image = self.examples[idx]["image"].convert("RGB")
        assigned_label = self.pseudo_labels[idx]
        if self.transform:
            image = self.transform(image)
        # The stored label of the example is ignored; the pseudo-label is returned.
        return image, assigned_label

# --- Model Definitions ---

class FeatureExtractor(nn.Module):
    """ResNet-50 with its ``fc`` layer swapped for identity.

    ``out_dim`` holds the input width of the original ``fc`` layer.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = models.resnet50(pretrained=pretrained)
        feature_width = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.out_dim = feature_width

    def forward(self, x):
        features = self.backbone(x)
        return features

class ClassifierHead(nn.Module):
    """Single linear layer mapping feature vectors to class logits.

    Args:
        in_dim: Size of the incoming feature vector.
        num_classes: Number of output logits.
    """
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, feats):
        logits = self.fc(feats)
        return logits

# --- Training and Evaluation Loops ---

def train_one_epoch(feat_ext, classifier, loader, optimizer):
    """Run one supervised cross-entropy epoch over ``loader``.

    Both modules are switched to train mode and updated through ``optimizer``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``, each averaged per sample over
        every batch seen in the epoch.
    """
    feat_ext.train()
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for images, targets in tqdm(loader, desc="Training"):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = classifier(feat_ext(images))
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight per-batch means by batch size so a short last batch is not over-counted.
        batch_count = targets.size(0)
        loss_sum += loss.item() * batch_count
        acc_sum += accuracy(logits, targets) * batch_count
        seen += batch_count
    return loss_sum / seen, acc_sum / seen

@torch.no_grad()
def evaluate(feat_ext, classifier, loader):
    """Return the per-sample mean accuracy of ``classifier(feat_ext(x))`` over ``loader``.

    Both modules are put in eval mode; gradients are disabled by the decorator.
    """
    feat_ext.eval()
    classifier.eval()
    correct_weighted, seen = 0.0, 0
    for images, targets in loader:
        images = images.to(device)
        targets = targets.to(device)
        batch_count = targets.size(0)
        batch_logits = classifier(feat_ext(images))
        # accuracy() is a batch mean; re-weight it by the batch size.
        correct_weighted += accuracy(batch_logits, targets) * batch_count
        seen += batch_count
    return correct_weighted / seen

# --- Main Self-Training Logic ---

def run_self_training(hf_dataset, source_domain, target_domain, num_classes=7,
                      epochs_teacher=15, epochs_student=10, batch_size=32,
                      lr=1e-4, weight_decay=1e-5, confidence_threshold=0.95, seed=0):
    """
    Performs the full self-training pipeline.
    1. Trains a "teacher" model on the source domain.
    2. Generates pseudo-labels for the target domain.
    3. Filters pseudo-labels by confidence (kept when confidence >= threshold).
    4. Fine-tunes a "student" model on the confident pseudo-labels.
    5. Evaluates the final student model.

    Args:
        hf_dataset: Iterable of PACS examples with "domain", "image", "label".
        source_domain: Domain name the teacher is trained on.
        target_domain: Domain name that is pseudo-labelled and evaluated.
        num_classes: Number of output classes of the classifier head.
        epochs_teacher: Training epochs for the teacher on source data.
        epochs_student: Fine-tuning epochs for the student on pseudo-labels.
        batch_size: Batch size of every DataLoader.
        lr: Teacher learning rate; the student uses ``lr / 10``.
        weight_decay: Adam weight decay for both optimizers.
        confidence_threshold: Minimum softmax confidence for a pseudo-label to be kept.
        seed: Seed passed to ``set_seed`` before anything else happens.

    Returns:
        A dict with the teacher/student accuracies and the confident-sample
        counts, or ``None`` when no target sample reaches the threshold.

    Note:
        The "source accuracy" is computed on the same source images the teacher
        was trained on (no held-out split is made here), so it is a training
        accuracy, not a generalisation estimate.
    """
    set_seed(seed)

    # --- 1. Train Teacher Model on Source Data ---
    print("--- Step 1: Training Teacher Model ---")
    src_train_set = PACSSubset(hf_dataset, source_domain, transform=get_transforms(train=True))
    tgt_test_set = PACSSubset(hf_dataset, target_domain, transform=get_transforms(train=False))
    # Build the source evaluation view from the examples already collected above
    # instead of scanning the whole HF dataset a further time; the filter is a
    # no-op on this list, so the resulting examples and their order are the same.
    src_test_set = PACSSubset(src_train_set.examples, source_domain, transform=get_transforms(train=False))

    src_train_loader = DataLoader(src_train_set, batch_size=batch_size, shuffle=True, num_workers=4)
    src_test_loader = DataLoader(src_test_set, batch_size=batch_size, shuffle=False, num_workers=4)
    tgt_test_loader = DataLoader(tgt_test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    teacher_feat_ext = FeatureExtractor().to(device)
    teacher_classifier = ClassifierHead(teacher_feat_ext.out_dim, num_classes).to(device)
    optimizer_teacher = optim.Adam(
        list(teacher_feat_ext.parameters()) + list(teacher_classifier.parameters()),
        lr=lr, weight_decay=weight_decay
    )

    for ep in range(epochs_teacher):
        train_loss, train_acc = train_one_epoch(teacher_feat_ext, teacher_classifier, src_train_loader, optimizer_teacher)
        print(f"Teacher Epoch {ep}/{epochs_teacher-1} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    teacher_src_acc = evaluate(teacher_feat_ext, teacher_classifier, src_test_loader)
    teacher_tgt_acc = evaluate(teacher_feat_ext, teacher_classifier, tgt_test_loader)
    print(f"Teacher Model Performance -> Source Acc: {teacher_src_acc:.4f}, Target Acc: {teacher_tgt_acc:.4f}\n")

    # --- 2. Generate Pseudo-Labels on Target Data ---
    print("--- Step 2: Generating Pseudo-Labels ---")
    teacher_feat_ext.eval()
    teacher_classifier.eval()

    all_pseudo_labels = []
    all_confidences = []

    # The unshuffled target loader is reused, so prediction i belongs to
    # tgt_test_set.examples[i]; that alignment is what step 3 relies on.
    with torch.no_grad():
        for x, _ in tqdm(tgt_test_loader, desc="Predicting on Target"):
            x = x.to(device)
            features = teacher_feat_ext(x)
            logits = teacher_classifier(features)
            probs = torch.softmax(logits, dim=1)
            confidences, pseudo_labels = torch.max(probs, dim=1)
            all_pseudo_labels.append(pseudo_labels.cpu())
            all_confidences.append(confidences.cpu())

    all_pseudo_labels = torch.cat(all_pseudo_labels)
    all_confidences = torch.cat(all_confidences)

    # --- 3. Filter by Confidence and Create New Dataset ---
    # The message now states ">=" to match the comparison actually applied below.
    print(f"--- Step 3: Filtering with confidence threshold >= {confidence_threshold} ---")
    confident_indices = torch.where(all_confidences >= confidence_threshold)[0]

    if len(confident_indices) == 0:
        print("Warning: No samples passed the confidence threshold. Self-training cannot proceed.")
        return None

    confident_pseudo_labels = all_pseudo_labels[confident_indices]

    # Select the confident target examples from the list that is already in
    # memory (same order as the predictions) rather than re-scanning hf_dataset.
    confident_target_subset = PACSSubset(tgt_test_set.examples, target_domain, indices=confident_indices.tolist())

    pseudo_labeled_dataset = PseudoLabelDataset(
        confident_target_subset,
        confident_pseudo_labels,
        transform=get_transforms(train=True) # Use training transforms for fine-tuning
    )

    pseudo_labeled_loader = DataLoader(pseudo_labeled_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    print(f"Found {len(confident_indices)} confident samples out of {len(tgt_test_set)} ({len(confident_indices)/len(tgt_test_set):.2%}).")

    # --- 4. Fine-tune Student Model ---
    print("\n--- Step 4: Fine-tuning Student Model on Pseudo-Labels ---")
    # Initialize student model with teacher's weights
    student_feat_ext = FeatureExtractor().to(device)
    student_feat_ext.load_state_dict(teacher_feat_ext.state_dict())
    student_classifier = ClassifierHead(student_feat_ext.out_dim, num_classes).to(device)
    student_classifier.load_state_dict(teacher_classifier.state_dict())

    optimizer_student = optim.Adam(
        list(student_feat_ext.parameters()) + list(student_classifier.parameters()),
        lr=lr/10, # Use a smaller learning rate for fine-tuning
        weight_decay=weight_decay
    )

    for ep in range(epochs_student):
        train_loss, train_acc = train_one_epoch(student_feat_ext, student_classifier, pseudo_labeled_loader, optimizer_student)
        print(f"Student Epoch {ep}/{epochs_student-1} - Pseudo-Train Loss: {train_loss:.4f}, Pseudo-Train Acc: {train_acc:.4f}")

    # --- 5. Final Evaluation ---
    print("\n--- Step 5: Final Evaluation of Student Model ---")
    final_src_acc = evaluate(student_feat_ext, student_classifier, src_test_loader)
    final_tgt_acc = evaluate(student_feat_ext, student_classifier, tgt_test_loader)

    print(f"\nFinal Student Performance -> Source Acc: {final_src_acc:.4f}, Target Acc: {final_tgt_acc:.4f}")

    return {
        "method": "Self-Training",
        "teacher_tgt_acc": teacher_tgt_acc,
        "student_src_acc": final_src_acc,
        "student_tgt_acc": final_tgt_acc,
        "confident_samples": len(confident_indices),
        "total_target_samples": len(tgt_test_set)
    }

# Load dataset once
print("Loading PACS dataset...")
dataset_all = load_dataset("flwrlabs/pacs")
print("Dataset loaded.")

# Run the Self-Training experiment
self_training_results = run_self_training(
    hf_dataset=dataset_all["train"],
    source_domain="photo",
    target_domain="sketch",
    num_classes=7,
    epochs_teacher=20,
    epochs_student=10,
    batch_size=32,
    lr=1e-4,
    confidence_threshold=0.95,
    seed=0
)

print("\n--- Self-Training Experiment Summary ---")
# run_self_training returns None when no pseudo-label passes the threshold.
if self_training_results:
    teacher_tgt_acc = self_training_results['teacher_tgt_acc']
    student_tgt_acc = self_training_results['student_tgt_acc']
    print(f"Teacher (Source-Only) Target Accuracy: {teacher_tgt_acc:.4f}")
    print(f"Student (Fine-tuned) Target Accuracy: {student_tgt_acc:.4f}")
    improvement = student_tgt_acc - teacher_tgt_acc
    if teacher_tgt_acc > 0:
        print(f"Improvement over baseline: {improvement:.4f} ({improvement/teacher_tgt_acc:.2%})")
    else:
        # A teacher accuracy of exactly 0 makes the relative change undefined;
        # report the absolute change only instead of raising ZeroDivisionError.
        print(f"Improvement over baseline: {improvement:.4f} (relative change undefined: baseline accuracy is 0)")
else:
    print("Experiment failed to run.")


Loading PACS dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

Dataset loaded.
--- Step 1: Training Teacher Model ---




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 190MB/s]


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 0/19 - Train Loss: 0.4500, Train Acc: 0.8814


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 1/19 - Train Loss: 0.1979, Train Acc: 0.9443


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 2/19 - Train Loss: 0.1706, Train Acc: 0.9485


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 3/19 - Train Loss: 0.1755, Train Acc: 0.9455


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 4/19 - Train Loss: 0.1276, Train Acc: 0.9593


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 5/19 - Train Loss: 0.1175, Train Acc: 0.9599


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 6/19 - Train Loss: 0.1110, Train Acc: 0.9641


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 7/19 - Train Loss: 0.1191, Train Acc: 0.9605


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 8/19 - Train Loss: 0.1044, Train Acc: 0.9665


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Teacher Epoch 9/19 - Train Loss: 0.1021, Train Acc: 0.9641


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 10/19 - Train Loss: 0.1069, Train Acc: 0.9617


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 11/19 - Train Loss: 0.0909, Train Acc: 0.9701


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 12/19 - Train Loss: 0.0925, Train Acc: 0.9665


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 13/19 - Train Loss: 0.0960, Train Acc: 0.9683


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 14/19 - Train Loss: 0.0856, Train Acc: 0.9707


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 15/19 - Train Loss: 0.0965, Train Acc: 0.9701


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 16/19 - Train Loss: 0.1375, Train Acc: 0.9599


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 17/19 - Train Loss: 0.1100, Train Acc: 0.9629


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 18/19 - Train Loss: 0.1004, Train Acc: 0.9707


Training:   0%|          | 0/53 [00:00<?, ?it/s]

Teacher Epoch 19/19 - Train Loss: 0.0797, Train Acc: 0.9701
Teacher Model Performance -> Source Acc: 0.9988, Target Acc: 0.2074

--- Step 2: Generating Pseudo-Labels ---


Predicting on Target:   0%|          | 0/123 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

       if w.is_alive(): 
      ^ ^ ^^  ^^^^^^^^^^^^^^^^^^^^

  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
        assert self.

--- Step 3: Filtering with confidence threshold > 0.95 ---
Found 563 confident samples out of 3929 (14.33%).

--- Step 4: Fine-tuning Student Model on Pseudo-Labels ---


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 0/9 - Pseudo-Train Loss: 4.8351, Pseudo-Train Acc: 0.1670


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 1/9 - Pseudo-Train Loss: 3.5108, Pseudo-Train Acc: 0.1883


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 2/9 - Pseudo-Train Loss: 2.5970, Pseudo-Train Acc: 0.2238


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 3/9 - Pseudo-Train Loss: 1.9571, Pseudo-Train Acc: 0.3517


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 4/9 - Pseudo-Train Loss: 1.4100, Pseudo-Train Acc: 0.5329


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 5/9 - Pseudo-Train Loss: 1.0671, Pseudo-Train Acc: 0.6927


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7b6080a4cfe0>if w.is_alive():

Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
          self._shutdown_workers()^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    ^if w.is_alive():
^^  ^^ ^^ ^ ^^^ 
 ^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^  ^ ^^^ ^^^^ 
   File "/usr/lib/

Student Epoch 6/9 - Pseudo-Train Loss: 0.7661, Pseudo-Train Acc: 0.8259


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 7/9 - Pseudo-Train Loss: 0.6568, Pseudo-Train Acc: 0.8472


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 8/9 - Pseudo-Train Loss: 0.4656, Pseudo-Train Acc: 0.9201


Training:   0%|          | 0/18 [00:00<?, ?it/s]

Student Epoch 9/9 - Pseudo-Train Loss: 0.4076, Pseudo-Train Acc: 0.9361

--- Step 5: Final Evaluation of Student Model ---

Final Student Performance -> Source Acc: 0.1144, Target Acc: 0.2545

--- Self-Training Experiment Summary ---
Teacher (Source-Only) Target Accuracy: 0.2074
Student (Fine-tuned) Target Accuracy: 0.2545
Improvement over baseline: 0.0471 (22.70%)


concept shift

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from sklearn.metrics import classification_report
import numpy as np
import random
from tqdm.auto import tqdm

# --- Boilerplate and Utilities (from previous code) ---

def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and make cuDNN deterministic."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows in ``preds`` whose dim-1 maximum index equals ``labels``."""
    predicted_classes = preds.max(dim=1).indices
    hits = predicted_classes == labels
    return hits.float().mean().item()

# Prefer the GPU when one is visible to PyTorch, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Display names indexed by integer label id, as used by this notebook's reports.
PACS_CLASS_NAMES = [
    'dog',
    'elephant',
    'giraffe',
    'guitar',
    'horse',
    'house',
    'person',
]

# --- Data Handling ---

def get_transforms(train: bool):
    """Build the image preprocessing pipeline.

    Both variants resize to 256x256, produce a 224 crop, convert to a tensor
    and normalise with the mean/std constants below. The training variant uses
    a random resized crop plus a random horizontal flip; the evaluation variant
    uses a centre crop.
    """
    if train:
        crop_steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        crop_steps = [transforms.CenterCrop(224)]

    pipeline = [transforms.Resize((256, 256))]
    pipeline.extend(crop_steps)
    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(pipeline)

class CustomPACSDataset(Dataset):
    """Map-style dataset over an already filtered list of PACS examples.

    Args:
        examples: Sequence of mappings with ``"image"`` and ``"label"`` entries.
        transform: Optional callable applied to the RGB-converted image.
    """
    def __init__(self, examples, transform=None):
        self.examples, self.transform = examples, transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        rgb_image = record["image"].convert("RGB")
        target = record["label"]
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image, target

def make_dataloaders_label_shift(hf_dataset, source_domain, target_domain,
                                 batch_size=32, seed=0,
                                 rare_class_id=None, rare_class_ratio=1.0):
    """
    Creates DataLoaders for a label shift experiment.

    The key difference is that the target_alignment_loader has the 'rare_class_id'
    downsampled by 'rare_class_ratio'.

    Args:
        hf_dataset: Iterable of PACS examples with "domain", "image", "label".
        source_domain: Domain used for the 80/20 shuffled train/test split.
        target_domain: Domain used for alignment and for final evaluation.
        batch_size: Batch size of all four loaders.
        seed: Passed to ``set_seed``; note this re-seeds the global RNGs.
        rare_class_id: Label id to downsample in the alignment set, or None
            to align on the full target set.
        rare_class_ratio: Probability of keeping each ``rare_class_id`` example
            in the alignment set (0.0 drops them all, 1.0 keeps them all).

    Returns:
        src_train_loader, src_test_loader,
        target_alignment_loader (modified), target_test_loader (unmodified)
    """
    set_seed(seed)

    # Collect both domains in a single pass instead of scanning the whole
    # dataset once per domain. Two independent `if`s keep the result identical
    # to the separate filters even when source and target are the same domain.
    source_full_examples = []
    target_full_examples = []
    for ex in hf_dataset:
        ex_domain = ex["domain"]
        if ex_domain == source_domain:
            source_full_examples.append(ex)
        if ex_domain == target_domain:
            target_full_examples.append(ex)

    # 1. Create Source DataLoaders: shuffled 80/20 train/test split
    n_src = len(source_full_examples)
    indices = list(range(n_src))
    random.shuffle(indices)
    split = int(n_src * 0.8)
    train_indices = indices[:split]
    test_indices = indices[split:]

    src_train_examples = [source_full_examples[i] for i in train_indices]
    src_test_examples = [source_full_examples[i] for i in test_indices]

    src_train_set = CustomPACSDataset(src_train_examples, transform=get_transforms(train=True))
    src_test_set = CustomPACSDataset(src_test_examples, transform=get_transforms(train=False))

    src_train_loader = DataLoader(src_train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    src_test_loader = DataLoader(src_test_set, batch_size=batch_size, shuffle=False)

    # 2. Create FULL Target Test Loader for final evaluation (unmodified)
    target_test_set = CustomPACSDataset(target_full_examples, transform=get_transforms(train=False))
    target_test_loader = DataLoader(target_test_set, batch_size=batch_size, shuffle=False)

    # 3. Create MODIFIED Target Alignment Loader
    if rare_class_id is not None:
        rare_class_name = PACS_CLASS_NAMES[rare_class_id]
        # Only claim "Removing" when nothing of the class can survive; any
        # positive ratio merely downsamples it.
        if rare_class_ratio <= 0:
            print(f"Modifying target alignment set: Removing class '{rare_class_name}' (ID: {rare_class_id})...")
        else:
            print(f"Modifying target alignment set: Downsampling class '{rare_class_name}' (ID: {rare_class_id}) with keep ratio {rare_class_ratio}...")
        modified_target_examples = []
        for ex in target_full_examples:
            if ex['label'] == rare_class_id:
                if random.random() < rare_class_ratio: # Downsample
                    modified_target_examples.append(ex)
            else: # Keep all other classes
                modified_target_examples.append(ex)

        print(f"Original target size: {len(target_full_examples)}. New alignment size: {len(modified_target_examples)}.")
        target_alignment_set = CustomPACSDataset(modified_target_examples, transform=get_transforms(train=True))
    else:
        # If no rare class specified, use the full target set for alignment
        target_alignment_set = CustomPACSDataset(target_full_examples, transform=get_transforms(train=True))

    target_alignment_loader = DataLoader(target_alignment_set, batch_size=batch_size, shuffle=True, drop_last=True)

    return src_train_loader, src_test_loader, target_alignment_loader, target_test_loader

# --- Model Definitions & Training Loops (Copied from previous tasks for self-containment) ---

class FeatureExtractor(nn.Module):
    """ResNet-50 backbone whose ``fc`` layer is replaced by an identity.

    ``out_dim`` records the input width of the removed ``fc`` layer, i.e. the
    size of the feature vector returned by ``forward``.
    """
    def __init__(self, pretrained=True):
        super().__init__()
        if pretrained:
            weights = models.ResNet50_Weights.IMAGENET1K_V1
        else:
            weights = None
        resnet = models.resnet50(weights=weights)
        self.out_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

    def forward(self, x):
        return self.backbone(x)

class ClassifierHead(nn.Module):
    """Linear classification layer: features of size ``in_dim`` -> ``num_classes`` logits."""
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, feats):
        logits = self.fc(feats)
        return logits

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates and scales the gradient by ``alpha`` in backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale on the context so backward can read it.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output.neg()
        scaled = flipped * ctx.alpha
        # Second slot is the gradient w.r.t. `alpha`, which receives none.
        return scaled, None

class DomainDiscriminator(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing 2 domain logits per feature vector."""
    def __init__(self, in_dim, hidden_dim=1024):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        domain_logits = self.net(x)
        return domain_logits

def train_epoch_source_only(feat_ext, classifier, optimizer, src_loader):
    """Train for one epoch with cross-entropy on labelled source batches only."""
    feat_ext.train()
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    for images, targets in tqdm(src_loader, desc="Source-Only Train"):
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        source_logits = classifier(feat_ext(images))
        batch_loss = criterion(source_logits, targets)
        batch_loss.backward()
        optimizer.step()

def mmd_loss(source_feats, target_feats):
    """Squared Euclidean distance between the dim-0 means of the two feature batches.

    This is the simplified, mean-matching stand-in for MMD used for demonstration.
    """
    mean_gap = source_feats.mean(dim=0) - target_feats.mean(dim=0)
    return torch.sum(torch.pow(mean_gap, 2))

def train_epoch_dan(feat_ext, classifier, optimizer, src_loader, tgt_loader, lambda_mmd):
    """Train one epoch on source cross-entropy plus ``lambda_mmd`` times ``mmd_loss``.

    Source and target batches are paired with ``zip``, so the epoch ends when
    the shorter loader is exhausted. Target labels are ignored.
    """
    feat_ext.train()
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    for src_batch, tgt_batch in zip(src_loader, tgt_loader):
        (x_s, y_s), (x_t, _) = src_batch, tgt_batch
        x_s = x_s.to(device)
        y_s = y_s.to(device)
        x_t = x_t.to(device)

        optimizer.zero_grad()
        # Source features are computed before target features, as separate forward passes.
        feats_src = feat_ext(x_s)
        feats_tgt = feat_ext(x_t)
        cls_term = criterion(classifier(feats_src), y_s)
        mmd_term = mmd_loss(feats_src, feats_tgt)
        total = cls_term + lambda_mmd * mmd_term
        total.backward()
        optimizer.step()

def train_epoch_dann(feat_ext, classifier, dom_disc, optimizer, src_loader, tgt_loader, alpha):
    """Train one adversarial epoch: source cross-entropy plus a domain loss.

    The domain discriminator sees source and target features through
    ``GradReverse`` (scale ``alpha``); source rows are labelled 0, target rows 1.
    Batches are paired with ``zip``, so the shorter loader bounds the epoch.
    """
    feat_ext.train()
    classifier.train()
    dom_disc.train()
    criterion = nn.CrossEntropyLoss()
    for src_batch, tgt_batch in zip(src_loader, tgt_loader):
        (x_s, y_s), (x_t, _) = src_batch, tgt_batch
        x_s = x_s.to(device)
        y_s = y_s.to(device)
        x_t = x_t.to(device)
        optimizer.zero_grad()

        # Label-classification term (source only).
        feats_src = feat_ext(x_s)
        cls_term = criterion(classifier(feats_src), y_s)

        # Domain-classification term over both domains, with reversed gradient.
        feats_tgt = feat_ext(x_t)
        stacked = torch.cat([feats_src, feats_tgt], dim=0)
        domain_logits = dom_disc(GradReverse.apply(stacked, alpha))
        domain_targets = torch.cat([
            torch.zeros(feats_src.size(0), dtype=torch.long),
            torch.ones(feats_tgt.size(0), dtype=torch.long),
        ]).to(device)
        dom_term = criterion(domain_logits, domain_targets)

        total = cls_term + dom_term
        total.backward()
        optimizer.step()

@torch.no_grad()
def evaluate_and_report(feat_ext, classifier, loader, title="Evaluation"):
    """Predict over ``loader`` and return scikit-learn's text classification report.

    Args:
        feat_ext: Feature extractor module (put in eval mode).
        classifier: Classifier head module (put in eval mode).
        loader: Iterable of ``(images, labels)`` batches.
        title: Description shown on the progress bar.

    Returns:
        The report string, with one row per entry of ``PACS_CLASS_NAMES``.
    """
    feat_ext.eval(); classifier.eval()
    all_preds, all_labels = [], []
    for x, y in tqdm(loader, desc=title):
        x = x.to(device)
        logits = classifier(feat_ext(x))
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.cpu())
        all_labels.append(y)

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Pin the label ids explicitly: without `labels=`, classification_report
    # infers them from the data and raises when a class is absent from both the
    # labels and the predictions, because `target_names` no longer matches in length.
    class_ids = list(range(len(PACS_CLASS_NAMES)))
    report = classification_report(
        all_labels, all_preds,
        labels=class_ids, target_names=PACS_CLASS_NAMES, zero_division=0,
    )
    return report

# --- Main Experiment Runner ---

def run_experiment(method, hf_dataset, source, target, rare_class_id, rare_class_ratio, epochs=15):
    """Train one adaptation method under simulated label shift and print its target report.

    Args:
        method: One of "Source-Only", "DAN" or "DANN".
        hf_dataset: Iterable of PACS examples with "domain", "image", "label".
        source: Source domain name (labelled training data).
        target: Target domain name (alignment data and final evaluation).
        rare_class_id: Label id downsampled in the target alignment set, or
            None to align on the full target set.
        rare_class_ratio: Keep-probability for ``rare_class_id`` examples.
        epochs: Number of training epochs (default 15, reduced for quicker
            demonstration).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported_methods = ("Source-Only", "DAN", "DANN")
    # Fail fast: an unknown name would otherwise skip every training branch and
    # silently report the metrics of an untrained classifier.
    if method not in supported_methods:
        raise ValueError(f"Unknown method {method!r}; expected one of {supported_methods}")

    print(f"\n===== Running Experiment: {method} =====")
    print(f"Source: {source}, Target: {target}")
    if rare_class_id is not None:
        print(f"Simulating Label Shift: Class '{PACS_CLASS_NAMES[rare_class_id]}' (ID: {rare_class_id}) removed from target alignment set.")
    else:
        # The loader factory accepts None (no downsampling); do not crash on the name lookup.
        print("Simulating Label Shift: disabled (full target set used for alignment).")

    # NOTE(review): make_dataloaders_label_shift calls set_seed(seed) itself and
    # is invoked below with its default seed=0, so the RNG state used for model
    # initialisation comes from seed 0, not from this 42. Left unchanged so that
    # existing results stay reproducible.
    set_seed(42)

    # Get data loaders with the label shift applied
    src_train, _, tgt_align, tgt_test = make_dataloaders_label_shift(
        hf_dataset, source, target,
        rare_class_id=rare_class_id, rare_class_ratio=rare_class_ratio
    )

    # Initialize models. The class count follows PACS_CLASS_NAMES, which is also
    # what the evaluation report uses for its rows.
    feat_ext = FeatureExtractor().to(device)
    classifier = ClassifierHead(feat_ext.out_dim, len(PACS_CLASS_NAMES)).to(device)
    # The discriminator is built (and handed to the optimizer) for every method,
    # exactly as before, even though only DANN trains it.
    dom_disc = DomainDiscriminator(feat_ext.out_dim).to(device)

    params = list(feat_ext.parameters()) + list(classifier.parameters()) + list(dom_disc.parameters())
    optimizer = optim.Adam(params, lr=1e-4, weight_decay=1e-5)

    # Training loop
    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        if method == "Source-Only":
            train_epoch_source_only(feat_ext, classifier, optimizer, src_train)
        elif method == "DAN":
            train_epoch_dan(feat_ext, classifier, optimizer, src_train, tgt_align, lambda_mmd=0.1)
        elif method == "DANN":
            train_epoch_dann(feat_ext, classifier, dom_disc, optimizer, src_train, tgt_align, alpha=1.0)

    # Final evaluation on the COMPLETE target set
    final_report = evaluate_and_report(feat_ext, classifier, tgt_test, title=f"Final Eval ({method})")
    print(f"\n--- Classification Report for {method} on Full Target Set ---")
    print(final_report)
    print("=" * 40)


print("Loading PACS dataset...")
dataset_all = load_dataset("flwrlabs/pacs")["train"]
print("Dataset loaded.")

SOURCE_DOMAIN = "photo"
TARGET_DOMAIN = "sketch"
# Class 3 in PACS is 'guitar'. Let's remove it from the target during alignment.
RARE_CLASS_ID = 3

# Scenario 1: Complete removal of the class during training/alignment.
# The three methods run back to back with identical settings.
for method_name in ("Source-Only", "DAN", "DANN"):
    run_experiment(
        method_name,
        dataset_all,
        SOURCE_DOMAIN,
        TARGET_DOMAIN,
        RARE_CLASS_ID,
        rare_class_ratio=0.0,
    )

Loading PACS dataset...
Dataset loaded.

===== Running Experiment: Source-Only =====
Source: photo, Target: sketch
Simulating Label Shift: Class 'guitar' (ID: 3) removed from target alignment set.
Modifying target alignment set: Removing class 'guitar' (ID: 3)...
Original target size: 3929. New alignment size: 3321.
Epoch 1/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 2/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 3/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 4/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 5/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 6/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 7/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 8/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 9/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 10/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 11/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 12/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 13/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 14/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Epoch 15/15


Source-Only Train:   0%|          | 0/41 [00:00<?, ?it/s]

Final Eval (Source-Only):   0%|          | 0/123 [00:00<?, ?it/s]


--- Classification Report for Source-Only on Full Target Set ---
              precision    recall  f1-score   support

         dog       0.00      0.00      0.00       772
    elephant       0.00      0.00      0.00       740
     giraffe       0.00      0.00      0.00       753
      guitar       0.73      0.03      0.05       608
       horse       0.21      1.00      0.35       816
       house       0.00      0.00      0.00        80
      person       0.00      0.00      0.00       160

    accuracy                           0.21      3929
   macro avg       0.13      0.15      0.06      3929
weighted avg       0.16      0.21      0.08      3929


===== Running Experiment: DAN =====
Source: photo, Target: sketch
Simulating Label Shift: Class 'guitar' (ID: 3) removed from target alignment set.
Modifying target alignment set: Removing class 'guitar' (ID: 3)...
Original target size: 3929. New alignment size: 3321.
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
E

Final Eval (DAN):   0%|          | 0/123 [00:00<?, ?it/s]


--- Classification Report for DAN on Full Target Set ---
              precision    recall  f1-score   support

         dog       0.73      0.15      0.25       772
    elephant       0.68      0.19      0.30       740
     giraffe       0.37      0.58      0.45       753
      guitar       0.27      0.87      0.41       608
       horse       0.69      0.35      0.46       816
       house       0.84      0.47      0.61        80
      person       0.00      0.00      0.00       160

    accuracy                           0.39      3929
   macro avg       0.51      0.37      0.35      3929
weighted avg       0.55      0.39      0.36      3929


===== Running Experiment: DANN =====
Source: photo, Target: sketch
Simulating Label Shift: Class 'guitar' (ID: 3) removed from target alignment set.
Modifying target alignment set: Removing class 'guitar' (ID: 3)...
Original target size: 3929. New alignment size: 3321.
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/

Final Eval (DANN):   0%|          | 0/123 [00:00<?, ?it/s]


--- Classification Report for DANN on Full Target Set ---
              precision    recall  f1-score   support

         dog       0.43      0.14      0.21       772
    elephant       0.52      0.54      0.53       740
     giraffe       0.36      0.63      0.46       753
      guitar       0.22      0.24      0.23       608
       horse       0.61      0.32      0.42       816
       house       0.00      0.01      0.00        80
      person       0.00      0.00      0.00       160

    accuracy                           0.35      3929
   macro avg       0.31      0.27      0.26      3929
weighted avg       0.41      0.35      0.35      3929

