In [1]:
# Download PACS Dataset Images:
# clone the repository that hosts the images, keep only its PACS/ folder
# (moved into the current working directory) and delete the rest of the clone.
!git clone https://github.com/MachineLearning2020/Homework3-PACS/
!mv Homework3-PACS/PACS/ .
!rm -r Homework3-PACS/

# Download PACS Dataset Labels:
# clone the template repository, move the four per-domain list files
# (art_painting, cartoon, photo, sketch) next to the images in PACS/,
# then delete the clone. These are the PACS/<domain>.txt files opened by PACSDataset.
!git clone https://github.com/silvia1993/DANN_Template/
!mv DANN_Template/txt_lists/art_painting.txt PACS/
!mv DANN_Template/txt_lists/cartoon.txt PACS/
!mv DANN_Template/txt_lists/photo.txt PACS/
!mv DANN_Template/txt_lists/sketch.txt PACS/
!rm -r DANN_Template/

# Install additional libraries (torchmetrics provides the Accuracy meter used at test time)
!pip install torchmetrics

In [2]:
import torch

# Compute device: use the GPU when PyTorch reports one, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Problem / batching
NUM_CLASSES = 7      # Size of the category output layer
BATCH_SIZE = 256     # Samples per mini-batch in every DataLoader

# SGD optimizer
LR = 1e-3            # Learning rate at the start of training
MOMENTUM = 0.9       # SGD momentum term
WEIGHT_DECAY = 5e-5  # L2 regularization strength

# Training length and step-down learning-rate schedule
NUM_EPOCHS = 30      # Number of full passes over the training set
STEP_SIZE = 20       # Epochs between two learning-rate reductions
GAMMA = 0.1          # Factor the learning rate is multiplied by at each reduction

In [3]:
from torch.utils.data import Dataset
from PIL import Image

# Define the Dataset class
class PACSDataset(Dataset):
    """Image dataset for one PACS domain, driven by the list file `PACS/<domain>.txt`.

    Each non-empty line of the list file is `<relative image path> <integer label>`.
    The image path is resolved relative to the `PACS/` folder.

    Args:
        domain: one of 'photo', 'art_painting', 'cartoon', 'sketch'.
        transform: callable applied to the RGB PIL image in `__getitem__`.

    Attributes:
        examples: list of `(img_path, class_label)` tuples, in file order.
        T: the transform passed to the constructor.

    Raises:
        AssertionError: if `domain` is not one of the supported names.
    """

    DOMAINS = ('photo', 'art_painting', 'cartoon', 'sketch')

    def __init__(self, domain, transform):
        assert domain in self.DOMAINS, \
            f'unknown domain {domain!r}, expected one of {self.DOMAINS}'
        self.examples = [] # (img_path, class_label)
        self.T = transform

        with open(f'PACS/{domain}.txt', 'r') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing newline at the end of the file)
                if not line:
                    continue
                # Split on the LAST run of whitespace so paths containing spaces stay intact
                rel_path, label = line.rsplit(maxsplit=1)
                self.examples.append(('PACS/' + rel_path, int(label)))

    def __len__(self):
        """Return the number of listed examples."""
        return len(self.examples)

    def __getitem__(self, index):
        """Load example `index` and return `(transformed_image, class_label)`."""
        img_path, class_label = self.examples[index]
        # The context manager closes the underlying file as soon as the pixels
        # have been copied into the converted image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        img = self.T(img)
        return img, class_label

In [4]:
import torch.nn as nn

# Define AlexNet architecture class
class AlexNet(nn.Module):
    """AlexNet feature extractor feeding two parallel fully-connected heads.

    `features` is the convolutional stack; its output is flattened and passed
    to both `classifier` (category logits, `num_classes` outputs) and
    `domain_classifier` (domain logits, `num_domains` outputs). The two heads
    have the same layer layout and differ only in the size of the last layer.
    """

    def __init__(self, num_classes=1000, num_domains=2):
        super().__init__()

        # Convolutional feature extractor
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        flat_dim = 256 * 6 * 6  # length of the flattened feature vector the heads expect
        hidden_dim = 4096

        def make_head(num_outputs):
            # Same layer order for both heads, so sub-module indices (1, 4, 6 for
            # the Linear layers) are identical in `classifier` and `domain_classifier`.
            return nn.Sequential(
                nn.Dropout(),
                nn.Linear(flat_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, num_outputs),
            )

        # Category classifier
        self.classifier = make_head(num_classes)
        # Domain classifier
        self.domain_classifier = make_head(num_domains)

    def forward(self, x):
        """Return `(class_outputs, domain_outputs)` computed from the same features."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.shape[0], -1)
        return self.classifier(flat), self.domain_classifier(flat)

In [None]:
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.models import AlexNet_Weights
import torch.nn.functional as F
from torchmetrics import Accuracy
from tqdm import tqdm

#### DATA SETUP
# Preprocessing shared by the training and the test images:
# resize, center-crop to 224, convert to tensor, normalize per channel.
dataset_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Source (training) domain is 'cartoon', target (test) domain is 'sketch'
train_dataset, test_dataset = (
    PACSDataset(domain=domain_name, transform=dataset_transform)
    for domain_name in ('cartoon', 'sketch')
)

# Training batches are shuffled and the incomplete last batch is dropped;
# test batches keep the dataset order and include every sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
)


#### ARCHITECTURE SETUP
# Build the network with its default head sizes (1000 categories, 2 domains)
model = AlexNet(num_classes=1000, num_domains=2)

# Load the ImageNet pre-trained weights; strict=False so that keys which do not
# match between the checkpoint and this model do not raise an error
model.load_state_dict(
    AlexNet_Weights.IMAGENET1K_V1.get_state_dict(progress=True),
    strict=False,
)

# Swap the last category layer for a fresh one sized for the PACS classes
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, NUM_CLASSES)


#### TRAINING SETUP
# The model has to live on its final device before the optimizer
# receives the parameters
model = model.to(DEVICE)

# SGD with momentum and weight decay, plus a step-down schedule that multiplies
# the learning rate by GAMMA every STEP_SIZE scheduler steps
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=STEP_SIZE,
    gamma=GAMMA,
)


#### TRAINING LOOP
class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by `-scale` on the way back.

    Placing this between the feature extractor and the domain head lets a single
    optimizer step train the domain head to MINIMIZE the domain loss while the
    feature extractor receives the reversed gradient and is pushed to MAXIMIZE it.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: reversed gradient for x, none for scale
        return -ctx.scale * grad_output, None


def flat_features(net, images):
    """Run the convolutional stack of `net` and flatten the result per sample."""
    feature_maps = net.features(images)
    return feature_maps.view(feature_maps.size(0), -1)


USE_DANN = True  # False --> source-only baseline, True --> domain-adversarial training

model.train()
if not USE_DANN:
    # Baseline
    for epoch in range(NUM_EPOCHS):
        epoch_loss = [0.0, 0]  # [sum of per-sample losses, number of samples]
        for x, y in tqdm(train_loader):
            x, y = x.to(DEVICE), y.to(DEVICE)
            # x --> [B x C x H x W]

            # Category Loss
            cls_o, _ = model(x)
            loss = F.cross_entropy(cls_o, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a per-batch MEAN: weight it by the batch size so that the
            # division by the sample count below yields a true per-sample average
            epoch_loss[0] += loss.item() * x.size(0)
            epoch_loss[1] += x.size(0)

        scheduler.step()
        print(f'[EPOCH {epoch+1}] Avg. Loss: {epoch_loss[0] / epoch_loss[1]}')
else:
    # DANN
    LAMBDA = 1e-4  # Strength of the reversed gradient reaching the feature extractor

    # Dedicated loader for the (unlabelled) target images used during training.
    # It is shuffled so each epoch sees a different part of the target domain;
    # the evaluation loader is left untouched.
    target_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    # zip() stops at the shorter loader
    num_batches = min(len(train_loader), len(target_loader))

    for epoch in range(NUM_EPOCHS):
        epoch_loss = [0.0, 0]  # [sum of per-sample losses, number of source samples]
        paired_batches = zip(train_loader, target_loader)
        for batch_idx, ((src_x, src_y), (trg_x, _)) in enumerate(tqdm(paired_batches, total=num_batches)):
            src_x, src_y = src_x.to(DEVICE), src_y.to(DEVICE)

            # Classification Loss (source images only, computed on every batch)
            src_feat = flat_features(model, src_x)
            cls_loss = F.cross_entropy(model.classifier(src_feat), src_y)

            if batch_idx % 2 == 0:
                # Even batches: classification step only, the target batch is not needed
                loss = cls_loss

            else:
                # Odd batches: classification + domain-adversarial step
                trg_x = trg_x.to(DEVICE)
                trg_feat = flat_features(model, trg_x)

                src_dom_o = model.domain_classifier(GradientReversal.apply(src_feat, LAMBDA))
                trg_dom_o = model.domain_classifier(GradientReversal.apply(trg_feat, LAMBDA))

                # Source Domain Adversarial Loss --> src_dom_label = 0
                src_dom_label = torch.zeros(src_dom_o.size(0), dtype=torch.long, device=DEVICE)
                src_dom_loss = F.cross_entropy(src_dom_o, src_dom_label)

                # Target Domain Adversarial Loss --> trg_dom_label = 1
                trg_dom_label = torch.ones(trg_dom_o.size(0), dtype=torch.long, device=DEVICE)
                trg_dom_loss = F.cross_entropy(trg_dom_o, trg_dom_label)

                # Final Loss: the losses are ADDED; the sign flip for the feature
                # extractor happens inside GradientReversal, so the domain head
                # itself is still trained to discriminate the two domains
                loss = cls_loss + src_dom_loss + trg_dom_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a per-batch MEAN: weight it by the batch size (see baseline)
            epoch_loss[0] += loss.item() * src_x.size(0)
            epoch_loss[1] += src_x.size(0)

        scheduler.step()
        print(f'[EPOCH {epoch+1}] Avg. Loss: {epoch_loss[0] / epoch_loss[1]}')
# pip install wandb

#### TEST LOOP
# Put the model in evaluation mode before scoring it on the test loader
model.eval()

# Running multiclass accuracy, accumulated batch by batch on the compute device
meter = Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(DEVICE)

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        # Only the category output is scored; the domain output is ignored
        logits, _ = model(images.to(DEVICE))
        meter.update(logits, labels.to(DEVICE))
accuracy = meter.compute()

print(f'\nAccuracy on the target domain: {100 * accuracy:.2f}%')