In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim


import random
from pathlib import Path
import torchvision.transforms.functional as TF

import os
import tarfile
import urllib.request
from PIL import Image

## Data Preparation

Loading STL-10 Dataset

In [None]:
# Preprocessing for both ImageNet-pretrained backbones: resize the native 96x96 STL-10
# images to 224x224 and normalise with ImageNet statistics. torchvision's ViT-B/16
# rejects any input whose spatial size differs from its configured 224x224, so without
# the resize every ViT forward pass fails.
# Named `stl10_transform` (not `transforms`) so it is not mistaken for the
# torchvision.transforms module.
stl10_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
stl10_train = torchvision.datasets.STL10(root='./stl_data', split='train', download=True, transform=stl10_transform)
stl10_test = torchvision.datasets.STL10(root='./stl_data', split='test', download=True, transform=stl10_transform)

In [None]:
# Hold out 10% of the STL-10 training split for validation. The seeded generator makes
# the split identical across kernel restarts, so results are reproducible.
SPLIT_SEED = 42
train_size = int(0.9 * len(stl10_train))
val_size = len(stl10_train) - train_size
stl10_train_split, stl10_val_split = random_split(
    stl10_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU so the notebook also
# runs on machines with neither accelerator (the old expression chose "mps" blindly).
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

In [None]:
BATCH_SIZE = 64

# Settings shared by all three STL-10 loaders; only the training loader reshuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(stl10_train_split, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(stl10_val_split, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(stl10_test, shuffle=False, **_loader_kwargs)

## Fine-tuning

Loading ResNet-50

In [None]:
# ImageNet-pretrained ResNet-50; the bare name on the last line displays its architecture.
resnet_weights = torchvision.models.ResNet50_Weights.DEFAULT
resnet = torchvision.models.resnet50(weights=resnet_weights)
resnet

Loading ViT-B/16

In [None]:
# ImageNet-pretrained ViT-B/16; the bare name on the last line displays its architecture.
vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit = torchvision.models.vit_b_16(weights=vit_weights)
vit

Fine-tuning ResNet

In [None]:
# Replace the 1000-way ImageNet head with one output per STL-10 class.
# `stl10_train.classes` is a list of class names, so its length is the output size
# (passing the list itself to nn.Linear raises a TypeError).
resnet.fc = torch.nn.Linear(
    in_features=resnet.fc.in_features,
    out_features=len(stl10_train.classes),
    bias=True,
)

# False -> linear probe (only the new head is trained); True -> full fine-tuning.
RESNET_FULL_FINETUNE = False

for param in resnet.parameters():
    param.requires_grad = RESNET_FULL_FINETUNE
for param in resnet.fc.parameters():
    param.requires_grad = True

resnet = resnet.to(DEVICE)
EPOCHS = 10
criterion = torch.nn.CrossEntropyLoss()
if RESNET_FULL_FINETUNE:
    # Much smaller learning rate when every pretrained weight is being updated.
    optimizer = optim.Adam(resnet.parameters(), lr=0.0001)
else:
    optimizer = optim.Adam(resnet.fc.parameters(), lr=0.005)

In [None]:
def _keep_frozen_batchnorm_in_eval(model):
    """Put BatchNorm layers whose parameters are all frozen back into eval mode.

    ``model.train()`` switches every BatchNorm layer to training mode, in which it
    normalises with per-batch statistics and overwrites its running mean/variance.
    During a linear probe that silently changes the "frozen" pretrained backbone, so
    such layers keep using their stored statistics. Layers with any trainable
    parameter (full fine-tuning) stay in training mode.
    """
    batchnorm_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
    for module in model.modules():
        if isinstance(module, batchnorm_types) and not any(p.requires_grad for p in module.parameters()):
            module.eval()


resnet_train_accuracies = []
resnet_train_losses = []
resnet_val_accuracies = []
resnet_val_losses = []

for epoch in range(EPOCHS):
    # Training pass
    resnet.train()
    _keep_frozen_batchnorm_in_eval(resnet)
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by batch size so the smaller
        # final batch does not skew the per-sample epoch average.
        running_loss += loss.item() * labels.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    train_acc = 100. * correct / total
    avg_train_loss = running_loss / total
    resnet_train_accuracies.append(train_acc)
    resnet_train_losses.append(avg_train_loss)
    print(f"Epoch {epoch+1}")
    print(f"\tTraining Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Validation pass
    resnet.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = resnet(images)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    val_acc = 100. * correct / total
    avg_val_loss = val_running_loss / total
    resnet_val_accuracies.append(val_acc)
    resnet_val_losses.append(avg_val_loss)
    print(f"\tValidation Loss: {avg_val_loss:.4f}, Validation Acc: {val_acc:.2f}%")

In [None]:
# Persist only the fine-tuned ResNet-50 parameters (state_dict), not the module object.
resnet_checkpoint_path = 'resnet_finetuned.pth'
torch.save(resnet.state_dict(), resnet_checkpoint_path)

In [None]:
# Restore the fine-tuned ResNet-50 weights; map_location places the tensors directly on DEVICE.
resnet.load_state_dict(
    torch.load('resnet_finetuned.pth', map_location=DEVICE)
)
print("Loaded fine-tuned ResNet-50 weights.")

Fine-tuning ViT

In [None]:
# Replace the 1000-way ImageNet head with one output per STL-10 class. The count comes
# from the dataset (a list of class names) rather than a hard-coded 10.
vit.heads.head = torch.nn.Linear(vit.heads.head.in_features, len(stl10_train.classes))

# False -> linear probe (only the new head is trained); True -> full fine-tuning.
VIT_FULL_FINETUNE = False

for param in vit.parameters():
    param.requires_grad = VIT_FULL_FINETUNE
for param in vit.heads.head.parameters():
    param.requires_grad = True

vit = vit.to(DEVICE)
EPOCHS = 10
criterion = torch.nn.CrossEntropyLoss()
if VIT_FULL_FINETUNE:
    # Much smaller learning rate when every pretrained weight is being updated.
    optimizer = optim.Adam(vit.parameters(), lr=0.0001)
else:
    optimizer = optim.Adam(vit.heads.head.parameters(), lr=0.005)

In [None]:
vit_train_accuracies = []
vit_train_losses = []
vit_val_accuracies = []
vit_val_losses = []

for epoch in range(EPOCHS):
    # Training pass. torchvision's ViT normalises with LayerNorm, which keeps no running
    # statistics, so vit.train() does not alter a frozen backbone.
    vit.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = vit(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by batch size so the smaller
        # final batch does not skew the per-sample epoch average.
        running_loss += loss.item() * labels.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    train_acc = 100. * correct / total
    avg_train_loss = running_loss / total
    vit_train_accuracies.append(train_acc)
    vit_train_losses.append(avg_train_loss)
    print(f"Epoch {epoch+1}")
    print(f"\tTraining Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Validation pass
    vit.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = vit(images)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    val_acc = 100. * correct / total
    avg_val_loss = val_running_loss / total
    vit_val_accuracies.append(val_acc)
    vit_val_losses.append(avg_val_loss)
    print(f"\tValidation Loss: {avg_val_loss:.4f}, Validation Acc: {val_acc:.2f}%")

In [None]:
# Persist only the fine-tuned ViT parameters (state_dict), not the module object.
vit_checkpoint_path = 'vit_finetuned.pth'
torch.save(vit.state_dict(), vit_checkpoint_path)

In [None]:
# Restore the fine-tuned ViT weights; map_location places the tensors directly on DEVICE.
vit.load_state_dict(
    torch.load('vit_finetuned.pth', map_location=DEVICE)
)
print("Loaded fine-tuned ViT weights.")

## In-Distribution Performance

Evaluating Fine-Tuned ResNet

In [None]:
# In-distribution test accuracy and mean per-sample loss of the fine-tuned ResNet-50.
resnet.eval()
# A local loss keeps the result independent of whichever setup cell last bound `criterion`.
test_criterion = torch.nn.CrossEntropyLoss()
resnet_test_loss = 0.0
resnet_correct = 0
resnet_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = resnet(images)
        loss = test_criterion(outputs, labels)
        # Weight the batch-mean loss by batch size for an exact per-sample average.
        resnet_test_loss += loss.item() * labels.size(0)
        _, predicted = outputs.max(1)
        resnet_total += labels.size(0)
        resnet_correct += predicted.eq(labels).sum().item()

resnet_test_acc = 100. * resnet_correct / resnet_total
resnet_avg_test_loss = resnet_test_loss / resnet_total
print(f"ResNet-50 Test Loss: {resnet_avg_test_loss:.4f}, Test Acc: {resnet_test_acc:.2f}%")

Evaluating Fine-Tuned ViT

In [None]:
# In-distribution test accuracy and mean per-sample loss of the fine-tuned ViT-B/16
# (the model loaded above is vit_b_16, so it is reported as ViT-B/16).
vit.eval()
# A local loss keeps the result independent of whichever setup cell last bound `criterion`.
test_criterion = torch.nn.CrossEntropyLoss()
vit_test_loss = 0.0
vit_correct = 0
vit_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = vit(images)
        loss = test_criterion(outputs, labels)
        # Weight the batch-mean loss by batch size for an exact per-sample average.
        vit_test_loss += loss.item() * labels.size(0)
        _, predicted = outputs.max(1)
        vit_total += labels.size(0)
        vit_correct += predicted.eq(labels).sum().item()

vit_test_acc = 100. * vit_correct / vit_total
vit_avg_test_loss = vit_test_loss / vit_total
print(f"ViT-B/16 Test Loss: {vit_avg_test_loss:.4f}, Test Acc: {vit_test_acc:.2f}%")

## Color-bias Test

Grayscale Dataset

In [None]:
# Colour-bias probe: STL-10 test images with colour removed (replicated to 3 channels so
# the RGB backbones accept them), resized to the 224x224 input ViT-B/16 requires and
# normalised with ImageNet statistics.
transforms_grayscale = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Grayscale(num_output_channels=3),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
stl10_test_grayscale = torchvision.datasets.STL10(root='./stl_data', split='test', download=True, transform=transforms_grayscale)
# NOTE: this rebinds `test_loader`. If the in-distribution evaluation cells are re-run
# after this cell, they will score grayscale images.
test_loader = DataLoader(stl10_test_grayscale, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Evaluating Fine-Tuned ResNet

In [None]:
# Re-score the fine-tuned ResNet-50 on the grayscale test images (test_loader was just
# rebound to them); the gap to the colour accuracy indicates reliance on colour cues.
resnet.eval()
resnet_gray_correct, resnet_gray_total = 0, 0

with torch.no_grad():
    for gray_images, gray_labels in test_loader:
        gray_images = gray_images.to(DEVICE)
        gray_labels = gray_labels.to(DEVICE)
        _, gray_preds = resnet(gray_images).max(1)
        resnet_gray_correct += gray_preds.eq(gray_labels).sum().item()
        resnet_gray_total += gray_labels.size(0)

resnet_gray_acc = 100. * resnet_gray_correct / resnet_gray_total
print(f"ResNet-50 Grayscale Test Acc: {resnet_gray_acc:.2f}%")
print(f"ResNet-50 Accuracy Drop: {resnet_test_acc - resnet_gray_acc:.2f}%")


Evaluating Fine-Tuned ViT

In [None]:
# Score the fine-tuned ViT-B/16 (the model loaded is vit_b_16) on the grayscale test
# images (test_loader was rebound to them) and report the drop from its colour accuracy.
vit.eval()
vit_gray_correct = 0
vit_gray_total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = vit(images)
        _, predicted = outputs.max(1)
        vit_gray_total += labels.size(0)
        vit_gray_correct += predicted.eq(labels).sum().item()

vit_gray_acc = 100. * vit_gray_correct / vit_gray_total
print(f"ViT-B/16 Grayscale Test Acc: {vit_gray_acc:.2f}%")
print(f"ViT-B/16 Accuracy Drop: {vit_test_acc - vit_gray_acc:.2f}%")

## Shape vs. Texture Bias – Stylized Images

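TODO (not yet implemented): generate stylized, texture-swapped test images with the Stylized-ImageNet code (https://github.com/rgeirhos/Stylized-ImageNet). Then evaluate both fine-tuned models on them to measure shape vs. texture bias.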

## Translation Invariance Test

Translated Dataset

In [None]:
# def get_shifted_test_loader(shift_x=0, shift_y=0):
#     shift_transform = torchvision.transforms.Compose([
#         torchvision.transforms.ToPILImage(),
#         torchvision.transforms.functional.Lambda(lambda img: torchvision.transforms.functional.affine(img, angle=0, translate=(shift_x, shift_y), scale=1.0, shear=0)),
#         torchvision.transforms.ToTensor(),
#     ])
#     shifted_dataset = torchvision.datasets.STL10(
#         root='./stl_data',
#         split='test',
#         download=True,
#         transform=lambda img: shift_transform(img)
#     )
#     return DataLoader(shifted_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# shifted_loader = get_shifted_test_loader(shift_x=5, shift_y=5)

In [None]:
SHIFT_PIXELS = 5

# Unit step per direction in image coordinates (x grows rightwards, y grows downwards),
# scaled by SHIFT_PIXELS. Insertion order matters because the shifted dataset samples
# from list(DIRECTIONS.items()).
DIRECTIONS = {
    name: (step_x * SHIFT_PIXELS, step_y * SHIFT_PIXELS)
    for name, (step_x, step_y) in [
        ("up", (0, -1)),
        ("down", (0, 1)),
        ("left", (-1, 0)),
        ("right", (1, 0)),
        ("up_left", (-1, -1)),
        ("up_right", (1, -1)),
        ("down_left", (-1, 1)),
        ("down_right", (1, 1)),
    ]
}

def make_shift_transform(dx, dy, fill=0):
    """Build a transform that translates a PIL image by ``(dx, dy)`` pixels, then converts it to a tensor.

    Args:
        dx: horizontal shift in pixels (positive moves the content right).
        dy: vertical shift in pixels (positive moves the content down).
        fill: pixel value for the area uncovered by the shift.

    The affine call is bound with ``functools.partial`` rather than a ``lambda`` so the
    returned transform can be pickled. DataLoader workers need that under the
    ``spawn`` start method (the default on macOS and Windows).
    """
    import functools

    shift = functools.partial(
        TF.affine,
        angle=0,
        translate=(dx, dy),
        scale=1.0,
        shear=0.0,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        fill=fill,
    )
    return torchvision.transforms.Compose([
        torchvision.transforms.Lambda(shift),
        torchvision.transforms.ToTensor(),
    ])

# class RollShiftToTensor:
#     """Tensor-space circular shift (wrap-around); no padding borders."""
#     def __init__(self, dx, dy):
#         self.dx, self.dy = dx, dy
#     def __call__(self, img_pil):
#         x = torchvision.transforms.ToTensor()(img_pil)          # [C,H,W]
#         x = torch.roll(x, shifts=(self.dy, self.dx), dims=(1, 2))  # (H,W)
#         return x
    
class RandomShiftSTL10(torchvision.datasets.STL10):
    """STL-10 split in which every image is translated in one of several fixed directions.

    Each index gets its direction once, at construction time, so every pass over the
    dataset (and therefore every model evaluated on it) sees exactly the same shifted
    images. Drawing the direction inside ``__getitem__`` would instead depend on the
    per-worker RNG that DataLoader reseeds for each new iterator, so two models
    iterating the same loader would be scored on different inputs.

    Items are ``(image_tensor, label, direction_name)`` triples.

    Args:
        root, split, download: forwarded to ``torchvision.datasets.STL10``.
        directions: mapping of direction name -> ``(dx, dy)`` pixel translation, applied
            to the native-resolution PIL image before ``post_transform``.
        fill: pixel value for the area uncovered by the shift.
        seed: seeds the direction assignment. When ``None``, the global ``random`` module
            is used, so seeding it beforehand makes the assignment reproducible.
        post_transform: callable applied to the shifted PIL image. Defaults to a resize
            to 224x224 (required by ViT-B/16), tensor conversion and ImageNet normalisation.

    Raises:
        ValueError: if ``directions`` is empty.
    """

    def __init__(self, root, split, download, directions, fill=0, seed=None, post_transform=None):
        # transform=None so the parent returns raw PIL images that can still be shifted.
        super().__init__(root=root, split=split, download=download, transform=None)
        self.directions = list(directions.items())
        if not self.directions:
            raise ValueError("directions must contain at least one (dx, dy) entry")
        self.fill = fill
        rng = random.Random(seed) if seed is not None else random
        # One fixed (name, (dx, dy)) pair per sample index.
        self.assignments = [rng.choice(self.directions) for _ in range(len(self))]
        if post_transform is None:
            post_transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.post_transform = post_transform

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        name, (dx, dy) = self.assignments[index]
        img = TF.affine(
            img, angle=0, translate=(dx, dy), scale=1.0, shear=0.0,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR, fill=self.fill,
        )
        return self.post_transform(img), target, name

def get_mixed_shift_loader(split="test", batch_size=BATCH_SIZE, fill=0, seed=None):
    """Return an unshuffled DataLoader over a ``RandomShiftSTL10`` split using ``DIRECTIONS``.

    When ``seed`` is given, Python's ``random`` and torch's global RNG are both reseeded
    (in that order) before the dataset is built.
    """
    if seed is not None:
        for reseed in (random.seed, torch.manual_seed):
            reseed(seed)

    shifted_ds = RandomShiftSTL10(
        root="./stl_data",
        split=split,
        download=True,
        directions=DIRECTIONS,
        fill=fill,
    )
    return DataLoader(shifted_ds, batch_size=batch_size, shuffle=False, num_workers=2)


# Fixed seed so the shifted test set can be rebuilt identically.
SHIFT_SEED = 1337
shifted_loader = get_mixed_shift_loader(split="test", seed=SHIFT_SEED)


Evaluating ResNet

In [None]:
# Score ResNet-50 on the shifted test set, overall and per shift direction.
# shifted_loader yields (images, labels, direction_names) triples; unpacking only two
# values raises "too many values to unpack".
resnet.eval()
resnet_shift_correct = 0
resnet_shift_total = 0
resnet_shift_by_direction = {}  # direction name -> [correct, total]

with torch.no_grad():
    for images, labels, direction_names in shifted_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = resnet(images)
        _, predicted = outputs.max(1)
        hits = predicted.eq(labels).cpu().tolist()
        resnet_shift_total += labels.size(0)
        resnet_shift_correct += sum(hits)
        for name, hit in zip(direction_names, hits):
            counts = resnet_shift_by_direction.setdefault(name, [0, 0])
            counts[0] += int(hit)
            counts[1] += 1

resnet_shift_acc = 100. * resnet_shift_correct / resnet_shift_total
print(f"ResNet-50 Shifted Test Acc: {resnet_shift_acc:.2f}%")
for name, (n_correct, n_total) in sorted(resnet_shift_by_direction.items()):
    print(f"\t{name}: {100. * n_correct / n_total:.2f}% ({n_total} images)")


Evaluating ViT

In [None]:
# Score ViT-B/16 (the model loaded is vit_b_16) on the shifted test set, overall and per
# shift direction. shifted_loader yields (images, labels, direction_names) triples.
vit.eval()
vit_shift_correct = 0
vit_shift_total = 0
vit_shift_by_direction = {}  # direction name -> [correct, total]

with torch.no_grad():
    for images, labels, direction_names in shifted_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = vit(images)
        _, predicted = outputs.max(1)
        hits = predicted.eq(labels).cpu().tolist()
        vit_shift_total += labels.size(0)
        vit_shift_correct += sum(hits)
        for name, hit in zip(direction_names, hits):
            counts = vit_shift_by_direction.setdefault(name, [0, 0])
            counts[0] += int(hit)
            counts[1] += 1

vit_shift_acc = 100. * vit_shift_correct / vit_shift_total
print(f"ViT-B/16 Shifted Test Acc: {vit_shift_acc:.2f}%")
for name, (n_correct, n_total) in sorted(vit_shift_by_direction.items()):
    print(f"\t{name}: {100. * n_correct / n_total:.2f}% ({n_total} images)")

## Permutation / Occlusion Test

## Feature Representation Analysis

## Domain Generalization Test on PACS

In [None]:
from datasets import load_dataset
from PIL import Image

# Fetch PACS from the Hugging Face Hub. The cells below draw both the source and the
# held-out domains from the 'train' split.
ds = load_dataset("flwrlabs/pacs")
pacs_features = ds['train'].features
print(f"Dataset structure: {ds}")
print(f"Available domains: {pacs_features['domain'].names}")
print(f"Available classes: {pacs_features['label'].names}")

# Create a PyTorch dataset wrapper for the Hugging Face dataset
class PACSDatasetWrapper(Dataset):
    """Map-style torch Dataset over a Hugging Face PACS split, restricted to some domains.

    Args:
        hf_dataset: Hugging Face dataset with ``image``, ``label`` and ``domain`` columns.
        domains: domain *names* to keep, e.g. ``['photo', 'cartoon']``.
        transform: optional callable applied to each RGB PIL image.

    Items are ``(image, label)`` pairs where ``label`` is the integer class id.
    """

    def __init__(self, hf_dataset, domains, transform=None):
        self.transform = transform
        domain_feature = hf_dataset.features['domain']
        # A ClassLabel column stores integer ids, so comparing rows against the name
        # strings would keep nothing. Translate names to ids when the feature supports it
        # (an unknown name raises instead of silently yielding an empty dataset).
        if hasattr(domain_feature, 'str2int'):
            wanted = {domain_feature.str2int(name) for name in domains}
        else:
            wanted = set(domains)
        # input_columns passes only the 'domain' value, so images are not decoded just to filter.
        self.data = hf_dataset.filter(lambda domain: domain in wanted, input_columns='domain')
        self.classes = hf_dataset.features['label'].names

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = item['image'].convert('RGB')
        label = item['label']

        if self.transform:
            image = self.transform(image)

        return image, label

# Resize to the 224x224 input both backbones expect and normalise with ImageNet statistics.
# Spelled out as torchvision.transforms because no module is imported under the bare
# name `transforms` in this notebook.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Leave-one-domain-out protocol: train on three PACS domains, test on the unseen 'sketch'.
train_domains = ['photo', 'art_painting', 'cartoon']
test_domain = ['sketch']

pacs_train = PACSDatasetWrapper(ds['train'], domains=train_domains, transform=transform)
pacs_test = PACSDatasetWrapper(ds['train'], domains=test_domain, transform=transform)
# Fail loudly here rather than with an opaque error inside random_split/DataLoader.
assert len(pacs_train) > 0 and len(pacs_test) > 0, "domain filtering produced an empty PACS dataset"

# 90/10 train/validation split of the source domains; the seeded generator makes it reproducible.
PACS_SPLIT_SEED = 42
train_size = int(0.9 * len(pacs_train))
val_size = len(pacs_train) - train_size
pacs_train_split, pacs_val_split = random_split(
    pacs_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(PACS_SPLIT_SEED),
)

BATCH_SIZE = 64
# Prefer CUDA, then MPS, and fall back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

pacs_train_loader = DataLoader(pacs_train_split, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
pacs_val_loader = DataLoader(pacs_val_split, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
pacs_test_loader = DataLoader(pacs_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training on {len(pacs_train_split)} images from domains: {', '.join(train_domains)}")
print(f"Validating on {len(pacs_val_split)} images from domains: {', '.join(train_domains)}")
print(f"Testing on {len(pacs_test)} images from domain: {', '.join(test_domain)}")

In [None]:
# Download PACS dataset
def download_pacs(url='https://drive.google.com/uc?export=download&id=1JFr8f805nMUelQWWmfnJR3y75PlRmGCJ',
                  data_dir='./pacs_data'):
    """Download and unpack the PACS archive into ``data_dir``; safe to call repeatedly.

    Each step is skipped when its output already exists: the archive at
    ``<data_dir>/PACS.tar.gz`` and the extracted tree at ``<data_dir>/PACS``.

    Args:
        url: location of the PACS ``.tar.gz`` archive (any URL ``urllib`` can fetch).
        data_dir: directory that receives the archive and the extracted files.

    Raises:
        RuntimeError: if the downloaded file is not a tar archive (large Google Drive
            files often come back as an HTML confirmation page), or if an archive
            member would be written outside ``data_dir``.
    """
    os.makedirs(data_dir, exist_ok=True)
    archive_path = os.path.join(data_dir, 'PACS.tar.gz')
    extracted_dir = os.path.join(data_dir, 'PACS')

    if not os.path.exists(archive_path):
        print("Downloading PACS dataset...")
        # Download under a temporary name so an interrupted or bogus transfer never
        # leaves a file that later runs would mistake for a complete archive.
        partial_path = archive_path + '.part'
        urllib.request.urlretrieve(url, partial_path)
        if not tarfile.is_tarfile(partial_path):
            os.remove(partial_path)
            raise RuntimeError(
                f"Downloaded file from {url} is not a tar archive (Google Drive may have "
                f"returned an HTML confirmation page). Download PACS manually to {archive_path}."
            )
        os.replace(partial_path, archive_path)

    if not os.path.exists(extracted_dir):
        print("Extracting PACS dataset...")
        with tarfile.open(archive_path) as tar:
            _safe_extractall(tar, data_dir)
    print("PACS dataset ready")


def _safe_extractall(tar, dest):
    """Extract ``tar`` into ``dest``, refusing members that would land outside ``dest``."""
    if hasattr(tarfile, 'data_filter'):
        # Python 3.12+ and security backports: the 'data' filter rejects absolute
        # paths, '..' components and links that escape dest.
        tar.extractall(dest, filter='data')
        return
    root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest, member.name))
        if os.path.commonpath([root, target]) != root:
            raise RuntimeError(f"Refusing to extract {member.name!r} outside {dest!r}")
    tar.extractall(dest)

# Custom dataset class for PACS
class PACSDataset(Dataset):
    """Image-folder view of an extracted PACS tree laid out as ``root_dir/<domain>/<class>/<image>``.

    Args:
        root_dir: directory holding one sub-directory per domain.
        domains: domain sub-directory names to include.
        transform: optional callable applied to each RGB PIL image.

    Directory listings are sorted, so the sample order (and any seeded ``random_split``
    of this dataset) is the same on every machine. ``os.listdir`` order is
    filesystem-dependent. Items are ``(image, class_index)`` pairs.
    """

    # Matched case-insensitively so files like 'x.JPG' are not silently skipped.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, domains, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # PACS has 7 object classes; the list position is the class index.
        self.classes = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        for domain in domains:
            domain_dir = os.path.join(root_dir, domain)
            for class_name in sorted(os.listdir(domain_dir)):
                if class_name not in self.class_to_idx:
                    continue  # ignore stray files/directories that are not PACS classes
                class_dir = os.path.join(domain_dir, class_name)
                label = self.class_to_idx[class_name]
                for img_name in sorted(os.listdir(class_dir)):
                    if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                        self.samples.append((os.path.join(class_dir, img_name), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns an independent,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Fetch the PACS archive and unpack it under ./pacs_data. download_pacs skips each step
# whose output (the .tar.gz, the extracted PACS/ directory) already exists, so
# re-running this cell does not download or extract again.
download_pacs()

In [None]:
# Resize to the 224x224 input both backbones expect and normalise with ImageNet statistics.
# Spelled out as torchvision.transforms because no module is imported under the bare
# name `transforms` in this notebook.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Leave-one-domain-out: train on three domains, test on the unseen 'sketch' domain.
# These local-folder datasets/loaders replace the Hugging Face-based ones built earlier.
train_domains = ['photo', 'art_painting', 'cartoon']
test_domain = ['sketch']
pacs_root = './pacs_data/PACS/kfold'

pacs_train = PACSDataset(root_dir=pacs_root, domains=train_domains, transform=transform)
pacs_test = PACSDataset(root_dir=pacs_root, domains=test_domain, transform=transform)
assert len(pacs_train) > 0 and len(pacs_test) > 0, f"no PACS images found under {pacs_root}"

# 90/10 train/validation split of the source domains; the seeded generator makes it reproducible.
PACS_SPLIT_SEED = 42
train_size = int(0.9 * len(pacs_train))
val_size = len(pacs_train) - train_size
pacs_train_split, pacs_val_split = random_split(
    pacs_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(PACS_SPLIT_SEED),
)

pacs_train_loader = DataLoader(pacs_train_split, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
pacs_val_loader = DataLoader(pacs_val_split, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
pacs_test_loader = DataLoader(pacs_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training on {len(pacs_train_split)} images from domains: {', '.join(train_domains)}")
print(f"Validating on {len(pacs_val_split)} images from domains: {', '.join(train_domains)}")
print(f"Testing on {len(pacs_test)} images from domain: {', '.join(test_domain)}")


In [None]:
# Training function
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Returns ``(mean_loss_per_sample, accuracy_percent)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    if training:
        # With a frozen backbone, keep its BatchNorm layers in eval mode so that
        # train() does not overwrite their pretrained running statistics.
        batchnorm_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
        for module in model.modules():
            if isinstance(module, batchnorm_types) and not any(p.requires_grad for p in module.parameters()):
                module.eval()

    total_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # criterion returns the batch mean; weight it so uneven batches average correctly.
            total_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, 100. * correct / total


def train_model(model, dataloader, criterion, optimizer, epochs=5, val_loader=None, device=None):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: network to optimise; expected to already be on ``device``.
        dataloader: iterable of ``(images, labels)`` training batches.
        criterion: loss returning the batch mean (e.g. ``CrossEntropyLoss``).
        optimizer: optimiser over the trainable parameters of ``model``.
        epochs: number of passes over ``dataloader``.
        val_loader: iterable of ``(images, labels)`` validation batches. Defaults to the
            notebook-global ``pacs_val_loader`` for backward compatibility.
        device: device batches are moved to. Defaults to the notebook-global ``DEVICE``.

    Returns:
        ``(train_accuracies, train_losses, val_accuracies, val_losses)``, one entry per
        epoch; accuracies are percentages, losses are per-sample means.

    Raises:
        ValueError: if the training or validation loader yields no samples.
    """
    if val_loader is None:
        val_loader = pacs_val_loader
    if device is None:
        device = DEVICE

    train_accuracies = []
    train_losses = []
    val_accuracies = []
    val_losses = []

    for epoch in range(epochs):
        avg_train_loss, train_acc = _run_epoch(model, dataloader, criterion, device, optimizer)
        train_accuracies.append(train_acc)
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}")
        print(f"\tTraining Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%")

        avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        val_accuracies.append(val_acc)
        val_losses.append(avg_val_loss)
        print(f"\tValidation Loss: {avg_val_loss:.4f}, Validation Acc: {val_acc:.2f}%")

    return train_accuracies, train_losses, val_accuracies, val_losses


In [None]:
# ImageNet ResNet-50 with a fresh classification head sized to the PACS label set.
resnet_pacs = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
num_pacs_classes = len(pacs_train.classes)
resnet_pacs.fc = torch.nn.Linear(resnet_pacs.fc.in_features, num_pacs_classes)

# Linear probe: freeze every parameter, then re-enable gradients for the new head only.
resnet_pacs.requires_grad_(False)
resnet_pacs.fc.requires_grad_(True)

resnet_pacs = resnet_pacs.to(DEVICE)
resnet_pacs_criterion = torch.nn.CrossEntropyLoss()
resnet_pacs_optimizer = optim.Adam(resnet_pacs.fc.parameters(), lr=0.005)

In [None]:
print("Fine-tuning ResNet-50 on PACS dataset...")
resnet_pacs_history = train_model(
    resnet_pacs,
    pacs_train_loader,
    resnet_pacs_criterion,
    resnet_pacs_optimizer,
    epochs=5,
)
# history = (train accuracies, train losses, val accuracies, val losses), one entry per epoch
resnet_pacs_train_acc, resnet_pacs_train_loss, resnet_pacs_val_acc, resnet_pacs_val_loss = resnet_pacs_history

In [None]:
# ImageNet ViT-B/16 with a fresh classification head sized to the PACS label set.
vit_pacs = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
num_pacs_classes = len(pacs_train.classes)
vit_pacs.heads.head = torch.nn.Linear(vit_pacs.heads.head.in_features, num_pacs_classes)

# Linear probe: freeze every parameter, then re-enable gradients for the new head only.
vit_pacs.requires_grad_(False)
vit_pacs.heads.head.requires_grad_(True)

vit_pacs = vit_pacs.to(DEVICE)
vit_pacs_criterion = torch.nn.CrossEntropyLoss()
vit_pacs_optimizer = optim.Adam(vit_pacs.heads.head.parameters(), lr=0.005)

In [None]:
# Train the ViT-B/16 head on the PACS source domains.
print("Fine-tuning ViT-B/16 on PACS dataset...")
vit_pacs_history = train_model(
    vit_pacs,
    pacs_train_loader,
    vit_pacs_criterion,
    vit_pacs_optimizer,
    epochs=5,
)
# history = (train accuracies, train losses, val accuracies, val losses), one entry per epoch
vit_pacs_train_acc, vit_pacs_train_loss, vit_pacs_val_acc, vit_pacs_val_loss = vit_pacs_history

In [None]:
# Out-of-domain check: score the PACS-trained ResNet-50 on the held-out sketch domain.
resnet_pacs.eval()
resnet_sketch_correct, resnet_sketch_total = 0, 0

with torch.no_grad():
    for sketch_images, sketch_labels in pacs_test_loader:
        sketch_images = sketch_images.to(DEVICE)
        sketch_labels = sketch_labels.to(DEVICE)
        _, sketch_preds = resnet_pacs(sketch_images).max(1)
        resnet_sketch_correct += sketch_preds.eq(sketch_labels).sum().item()
        resnet_sketch_total += sketch_labels.size(0)

resnet_sketch_acc = 100. * resnet_sketch_correct / resnet_sketch_total
print(f"ResNet-50 Sketch Domain Acc: {resnet_sketch_acc:.2f}%")

In [None]:

# Out-of-domain check: score the PACS-trained ViT-B/16 on the held-out sketch domain.
vit_pacs.eval()
vit_sketch_correct, vit_sketch_total = 0, 0

with torch.no_grad():
    for sketch_images, sketch_labels in pacs_test_loader:
        sketch_images = sketch_images.to(DEVICE)
        sketch_labels = sketch_labels.to(DEVICE)
        _, sketch_preds = vit_pacs(sketch_images).max(1)
        vit_sketch_correct += sketch_preds.eq(sketch_labels).sum().item()
        vit_sketch_total += sketch_labels.size(0)

vit_sketch_acc = 100. * vit_sketch_correct / vit_sketch_total
print(f"ViT-B/16 Sketch Domain Acc: {vit_sketch_acc:.2f}%")


In [None]:
# Learning curves of the PACS fine-tuning runs; the x-axis counts epochs from 1.
metric_panels = [
    # (panel title, y-axis label, ResNet curve, ViT curve, legend suffix)
    ('Training Accuracy', 'Accuracy (%)', resnet_pacs_train_acc, vit_pacs_train_acc, 'Train'),
    ('Validation Accuracy', 'Accuracy (%)', resnet_pacs_val_acc, vit_pacs_val_acc, 'Val'),
    ('Training Loss', 'Loss', resnet_pacs_train_loss, vit_pacs_train_loss, 'Train'),
    ('Validation Loss', 'Loss', resnet_pacs_val_loss, vit_pacs_val_loss, 'Val'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for ax, (title, ylabel, resnet_curve, vit_curve, phase) in zip(axes.flat, metric_panels):
    ax.plot(range(1, len(resnet_curve) + 1), resnet_curve, label=f'ResNet {phase}')
    ax.plot(range(1, len(vit_curve) + 1), vit_curve, label=f'ViT {phase}')
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend()

fig.tight_layout()
plt.show()

# Print final results
print("\nFinal Domain Generalization Results:")
print(f"ResNet-50: Sketch Domain Accuracy: {resnet_sketch_acc:.2f}%")
print(f"ViT-B/16: Sketch Domain Accuracy: {vit_sketch_acc:.2f}%")