# Domain Adaptation Via Activation Shaping

**Teacher assistant** Iurada Leonardo

**Students**

- Bar Giorgio
- Distefano Giuseppe
- Incaviglia Salvatore

## 0 - Reproduce the Baseline

### Libraries

In [28]:
!pip install torch torchvision tqdm torchmetrics

import torch

import torch.backends.mps
from argparse import ArgumentParser

import os
import torchvision.transforms as T

from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10

import numpy as np
import random
from PIL import Image

import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

import torch.nn.functional as F
from torchmetrics import Accuracy
from tqdm import tqdm

import logging
import warnings



### Setup PACS dataset and environment

In [29]:
# Download PACS Dataset Images and Labels
!git clone https://github.com/MachineLearning2020/Homework3-PACS/
!git clone https://github.com/silvia1993/DANN_Template/

# Setup data
!rm -rf data || true
!rm -rf record || true
!mkdir data
!mkdir data/kfold
!cp -r Homework3-PACS/PACS/ data/kfold
!cp DANN_Template/txt_lists/*.txt data
!rm -rf Homework3-PACS/
!rm -rf DANN_Template/

Cloning into 'Homework3-PACS'...
remote: Enumerating objects: 10032, done.[K
remote: Total 10032 (delta 0), reused 0 (delta 0), pack-reused 10032[K
Receiving objects: 100% (10032/10032), 174.13 MiB | 22.11 MiB/s, done.
Resolving deltas: 100% (1/1), done.
Updating files: 100% (9993/9993), done.
Cloning into 'DANN_Template'...
remote: Enumerating objects: 23, done.[K
remote: Total 23 (delta 0), reused 0 (delta 0), pack-reused 23[K
Receiving objects: 100% (23/23), 33.86 KiB | 642.00 KiB/s, done.
Resolving deltas: 100% (5/5), done.


### Globals

#### globals

In [30]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

CONFIG = dotdict({})

if torch.cuda.is_available():
    CONFIG.device = 'cuda'
elif torch.backends.mps.is_available() and \
    torch.backends.mps.is_built():
    CONFIG.device = 'mps'
else:
    CONFIG.device = 'cpu'

CONFIG.dtype = torch.float32

#### parse args

In [31]:
def _clear_args(parsed_args):
    parsed_args.experiment_args = eval(parsed_args.experiment_args)
    parsed_args.dataset_args = eval(parsed_args.dataset_args)
    return parsed_args


def parse_arguments():
    parser = ArgumentParser()

    parser.add_argument('--seed', type=int, default=0, help='Seed used for deterministic behavior')
    parser.add_argument('--test_only', action='store_true', help='Whether to skip training')
    parser.add_argument('--cpu', action='store_true', help='Whether to force the usage of CPU')

    parser.add_argument('--experiment', type=str, default='baseline')
    parser.add_argument('--experiment_name', type=str, default='baseline')
    parser.add_argument('--experiment_args', type=str, default='{}')
    parser.add_argument('--dataset_args', type=str, default='{}')

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--grad_accum_steps', type=int, default=1)

    return _clear_args(parser.parse_args())


def parse_fake_arguments(conf):
    parser = ArgumentParser()

    parser.add_argument('--seed', type=int, default=0, help='Seed used for deterministic behavior')
    parser.add_argument('--test_only', action='store_true', help='Whether to skip training')
    parser.add_argument('--cpu', action='store_true', help='Whether to force the usage of CPU')

    parser.add_argument('--experiment', type=str, default='baseline')
    #parser.add_argument('--experiment_name', type=str, default='baseline')
    if conf=='cartoon': parser.add_argument('--experiment_name', type=str, default='baseline/cartoon')
    elif conf=='sketch': parser.add_argument('--experiment_name', type=str, default='baseline/sketch')
    elif conf=='photo': parser.add_argument('--experiment_name', type=str, default='baseline/photo')
    else: parser.add_argument('--experiment_name', type=str, default='baseline/cartoon')
    parser.add_argument('--experiment_args', type=str, default='{}')
    #parser.add_argument('--dataset_args', type=str, default="{'root': 'data/PACS', 'source_domain': 'art_painting', 'target_domain': '${target_domain}'}")
    if conf=='cartoon': parser.add_argument('--dataset_args', type=str, default="{'root': 'data/PACS', 'source_domain': 'art_painting', 'target_domain': 'cartoon'}")
    elif conf=='sketch': parser.add_argument('--dataset_args', type=str, default="{'root': 'data/PACS', 'source_domain': 'art_painting', 'target_domain': 'sketch'}")
    elif conf=='photo': parser.add_argument('--dataset_args', type=str, default="{'root': 'data/PACS', 'source_domain': 'art_painting', 'target_domain': 'photo'}")
    else: parser.add_argument('--dataset_args', type=str, default="{'root': 'data/PACS', 'source_domain': 'art_painting', 'target_domain': 'cartoon'}")

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--grad_accum_steps', type=int, default=1)

    return _clear_args(parser.parse_args())

#### Global variables

In [32]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLASSES = 7
BATCH_SIZE = 256
LR = 1e-3               # The initial Learning Rate
MOMENTUM = 0.9          # Hyperparameter for SGD, keep this at 0.9 when using SGD
WEIGHT_DECAY = 5e-5     # Regularization, you can keep this at the default
NUM_EPOCHS = 30         # Total number of training epochs (iterations over dataset)
STEP_SIZE = 20          # How many epochs before decreasing learning rate (if using a step-down policy)
GAMMA = 0.1             # Multiplicative factor for learning rate step-down

LOG_FREQUENCY = 10

### Dataset

#### utils

In [33]:
class BaseDataset(Dataset):
    def __init__(self, examples, transform):
        self.examples = examples
        self.T = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        x, y = self.examples[index]
        x = Image.open(x).convert('RGB')
        x = self.T(x).to(CONFIG.dtype)
        y = torch.tensor(y).long()
        return x, y

######################################################
# TODO: modify 'BaseDataset' for the Domain Adaptation setting.
# Hint: randomly sample 'target_examples' to obtain targ_x
#class DomainAdaptationDataset(Dataset):
#    def __init__(self, source_examples, target_examples, transform):
#        self.source_examples = source_examples
#        self.target_examples = target_examples
#        self.T = transform
#
#    def __len__(self):
#        return len(self.source_examples)
#
#    def __getitem__(self, index):
#        src_x, src_y = ...
#        targ_x = ...
#
#        src_x = self.T(src_x)
#        targ_x = self.T(targ_x)
#        return src_x, src_y, targ_x

# [OPTIONAL] TODO: modify 'BaseDataset' for the Domain Generalization setting.
# Hint: combine the examples from the 3 source domains into a single 'examples' list
#class DomainGeneralizationDataset(Dataset):
#    def __init__(self, examples, transform):
#        self.examples = examples
#        self.T = transform
#
#    def __len__(self):
#        return len(self.examples)
#
#    def __getitem__(self, index):
#        x1, x2, x3 = self.examples[index]
#        x1, x2, x3 = self.T(x1), self.T(x2), self.T(x3)
#        targ_x = self.T(targ_x)
#        return x1, x2, x3

######################################################

class SeededDataLoader(DataLoader):
    def __init__(self, dataset: Dataset, batch_size=1, shuffle=None,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None,
                 generator=None, *, prefetch_factor=None, persistent_workers=False,
                 pin_memory_device=""):

        if not CONFIG.use_nondeterministic:
            def seed_worker(worker_id):
                worker_seed = torch.initial_seed() % 2**32
                np.random.seed(worker_seed)
                random.seed(worker_seed)

            generator = torch.Generator()
            generator.manual_seed(CONFIG.seed)

            worker_init_fn = seed_worker

        super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn,
                         pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context, generator,
                         prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                         pin_memory_device=pin_memory_device)



#### PACS

In [34]:
def get_transform(size, mean, std, preprocess):
    transform = []
    if preprocess:
        transform.append(T.Resize(256))
        transform.append(T.RandomResizedCrop(size=size, scale=(0.7, 1.0)))
        transform.append(T.RandomHorizontalFlip())
    else:
        transform.append(T.Resize(size))
    transform.append(T.ToTensor())
    transform.append(T.Normalize(mean, std))
    return T.Compose(transform)


def load_data():
    CONFIG.num_classes = 7
    CONFIG.data_input_size = (3, 224, 224)

    # Create transforms
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) # ImageNet Pretrain statistics
    train_transform = get_transform(size=224, mean=mean, std=std, preprocess=True)
    test_transform = get_transform(size=224, mean=mean, std=std, preprocess=False)

    # Load examples & create Dataset
    if CONFIG.experiment in ['baseline', 'random']:
        source_examples, target_examples = [], []

        # Load source
        #with open(os.path.join(CONFIG.dataset_args['text_root'], f"{CONFIG.dataset_args['source_domain']}.txt"), 'r') as f:
            #lines = f.readlines()
        # domain/category/sample n
        f = open(os.path.join(CONFIG.dataset_args['text_root'], f"{CONFIG.dataset_args['source_domain']}.txt"), 'r')
        for line in f:
            line = line.strip().split()
            path, label = line[0].split('/')[0:], int(line[1])
            source_examples.append((os.path.join(CONFIG.dataset_args['images_root'], *path), label))
        f.close()

        # Load target
        #with open(os.path.join(CONFIG.dataset_args['text_root'], f"{CONFIG.dataset_args['target_domain']}.txt"), 'r') as f:
            #lines = f.readlines()
        f = open(os.path.join(CONFIG.dataset_args['text_root'], f"{CONFIG.dataset_args['target_domain']}.txt"), 'r')
        for line in f:
            line = line.strip().split()
            path, label = line[0].split('/')[0:], int(line[1])
            target_examples.append((os.path.join(CONFIG.dataset_args['images_root'], *path), label))
        f.close()

        train_dataset = BaseDataset(source_examples, transform=train_transform)
        test_dataset = BaseDataset(target_examples, transform=test_transform)

    ######################################################
    #elif... TODO: Add here how to create the Dataset object for the other experiments


    ######################################################

    # Dataloaders
    train_loader = SeededDataLoader(
        train_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=True,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    test_loader = SeededDataLoader(
        test_dataset,
        batch_size=CONFIG.batch_size,
        shuffle=False,
        num_workers=CONFIG.num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    return {'train': train_loader, 'test': test_loader}

### Models

#### ResNet

In [35]:
class BaseResNet18(nn.Module):
    def __init__(self):
        super(BaseResNet18, self).__init__()
        self.resnet = resnet18(weights=ResNet18_Weights)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 7)

    def forward(self, x):
        return self.resnet(x)

######################################################
# TODO: either define the Activation Shaping Module as a nn.Module
#class ActivationShapingModule(nn.Module):
#...
#
# OR as a function that shall be hooked via 'register_forward_hook'
#def activation_shaping_hook(module, input, output):
#...
#
######################################################
# TODO: modify 'BaseResNet18' including the Activation Shaping Module
#class ASHResNet18(nn.Module):
#    def __init__(self):
#        super(ASHResNet18, self).__init__()
#        ...
#
#    def forward(self, x):
#        ...
#
######################################################


### Training functions

In [36]:
import argparse

@torch.no_grad()
def evaluate(model, data):
    model.eval()

    acc_meter = Accuracy(task='multiclass', num_classes=CONFIG.num_classes)
    acc_meter = acc_meter.to(CONFIG.device)

    loss = [0.0, 0]
    for x, y in tqdm(data):
        with torch.autocast(device_type=CONFIG.device, dtype=torch.float16, enabled=True):
            x, y = x.to(CONFIG.device), y.to(CONFIG.device)
            logits = model(x)
            acc_meter.update(logits, y)
            loss[0] += F.cross_entropy(logits, y).item()
            loss[1] += x.size(0)

    accuracy = acc_meter.compute()
    loss = loss[0] / loss[1]
    logging.info(f'Accuracy: {100 * accuracy:.2f} - Loss: {loss}')
    print(f'Accuracy: {100 * accuracy:.2f} - Loss: {loss} \n')


def train(model, data):

    # Create optimizers & schedulers
    optimizer = torch.optim.SGD(model.parameters(), weight_decay=0.0005, momentum=0.9, nesterov=True, lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(CONFIG.epochs * 0.8), gamma=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    # Load checkpoint (if it exists)
    cur_epoch = 0
    if os.path.exists(os.path.join('record', CONFIG.experiment_name, 'last.pth')):
        checkpoint = torch.load(os.path.join('record', CONFIG.experiment_name, 'last.pth'))
        cur_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        model.load_state_dict(checkpoint['model'])

    # Optimization loop
    for epoch in range(cur_epoch, CONFIG.epochs):
        model.train()

        for batch_idx, batch in enumerate(tqdm(data['train'])):

            # Compute loss
            with torch.autocast(device_type=CONFIG.device, dtype=torch.float16, enabled=True):

                if CONFIG.experiment in ['baseline', 'random']:
                    x, y = batch
                    x, y = x.to(CONFIG.device), y.to(CONFIG.device)
                    loss = F.cross_entropy(model(x), y)

                ######################################################
                #elif... TODO: Add here train logic for the other experiments

                ######################################################

            # Optimization step
            scaler.scale(loss / CONFIG.grad_accum_steps).backward()

            if ((batch_idx + 1) % CONFIG.grad_accum_steps == 0) or (batch_idx + 1 == len(data['train'])):
                scaler.step(optimizer)
                optimizer.zero_grad(set_to_none=True)
                scaler.update()

        scheduler.step()

        # Test current epoch
        logging.info(f'[TEST @ Epoch={epoch}]')
        evaluate(model, data['test'])

        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'model': model.state_dict()
        }
        torch.save(checkpoint, os.path.join('record', CONFIG.experiment_name, 'last.pth'))


### Run (baseline)

In [45]:
import argparse

def main():

    # Load dataset
    data = load_data()

    # Load model
    if CONFIG.experiment in ['baseline']:
        model = BaseResNet18()

    ######################################################
    #elif... TODO: Add here model loading for the other experiments (eg. DA and optionally DG)

    ######################################################

    model.to(CONFIG.device)

    if not CONFIG.test_only:
        train(model, data)
    else:
        evaluate(model, data['test'])


if __name__ == '__main__':
    warnings.filterwarnings('ignore', category=UserWarning)

    # Parse arguments (uncomment only one of the following)
    '''
    confs = [
        { # Cartoon
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'baseline',
            'experiment_name': 'baseline/cartoon/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'cartoon'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]
    '''
    '''
    confs = [
        { # Sketch
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'baseline',
            'experiment_name': 'baseline/sketch/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'sketch'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]
    '''
    confs = [
        { # Photo
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'baseline',
            'experiment_name': 'baseline/photo/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'photo'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]

    #args = parse_arguments()
    for conf in confs:
        args = argparse.Namespace(**conf)
        CONFIG.update(vars(args))

        # Setup output directory
        CONFIG.save_dir = os.path.join('record', CONFIG.experiment_name)
        os.makedirs(CONFIG.save_dir, exist_ok=True)

        # Setup logging
        logging.basicConfig(
            filename=os.path.join(CONFIG.save_dir, 'log.txt'),
            format='%(message)s',
            level=logging.INFO,
            filemode='a'
        )

        # Set experiment's device & deterministic behavior
        if CONFIG.cpu:
            CONFIG.device = 'cpu'
        else:
            CONFIG.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        torch.manual_seed(CONFIG.seed)
        random.seed(CONFIG.seed)
        np.random.seed(CONFIG.seed)
        torch.backends.cudnn.benchmark = True
        torch.use_deterministic_algorithms(mode=True, warn_only=True)

    main()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 152MB/s]
100%|██████████| 16/16 [00:11<00:00,  1.45it/s]
100%|██████████| 14/14 [00:05<00:00,  2.51it/s]


Accuracy: 60.78 - Loss: 0.011620737050107853 



100%|██████████| 16/16 [00:09<00:00,  1.75it/s]
100%|██████████| 14/14 [00:04<00:00,  2.95it/s]


Accuracy: 79.82 - Loss: 0.006599299101058594 



100%|██████████| 16/16 [00:11<00:00,  1.42it/s]
100%|██████████| 14/14 [00:04<00:00,  3.30it/s]


Accuracy: 87.78 - Loss: 0.0041108920545635105 



100%|██████████| 16/16 [00:09<00:00,  1.63it/s]
100%|██████████| 14/14 [00:04<00:00,  3.26it/s]


Accuracy: 89.40 - Loss: 0.00310309904064247 



100%|██████████| 16/16 [00:09<00:00,  1.61it/s]
100%|██████████| 14/14 [00:04<00:00,  3.28it/s]


Accuracy: 91.14 - Loss: 0.0024948941287166342 



100%|██████████| 16/16 [00:09<00:00,  1.63it/s]
100%|██████████| 14/14 [00:04<00:00,  3.21it/s]


Accuracy: 92.16 - Loss: 0.0021858058602153185 



100%|██████████| 16/16 [00:09<00:00,  1.64it/s]
100%|██████████| 14/14 [00:04<00:00,  3.19it/s]


Accuracy: 92.69 - Loss: 0.0019666722479337703 



100%|██████████| 16/16 [00:09<00:00,  1.63it/s]
100%|██████████| 14/14 [00:04<00:00,  3.24it/s]


Accuracy: 93.17 - Loss: 0.0018043650704586576 



100%|██████████| 16/16 [00:09<00:00,  1.66it/s]
100%|██████████| 14/14 [00:04<00:00,  3.26it/s]


Accuracy: 93.53 - Loss: 0.0017245613350839672 



100%|██████████| 16/16 [00:09<00:00,  1.61it/s]
100%|██████████| 14/14 [00:04<00:00,  3.24it/s]


Accuracy: 94.13 - Loss: 0.0015787166303503298 



100%|██████████| 16/16 [00:10<00:00,  1.59it/s]
100%|██████████| 14/14 [00:04<00:00,  3.25it/s]


Accuracy: 94.31 - Loss: 0.0015136572080636454 



100%|██████████| 16/16 [00:09<00:00,  1.61it/s]
100%|██████████| 14/14 [00:04<00:00,  3.29it/s]


Accuracy: 94.37 - Loss: 0.0014615919351756216 



100%|██████████| 16/16 [00:09<00:00,  1.62it/s]
100%|██████████| 14/14 [00:04<00:00,  3.23it/s]


Accuracy: 94.91 - Loss: 0.0013111949937340028 



100%|██████████| 16/16 [00:09<00:00,  1.67it/s]
100%|██████████| 14/14 [00:04<00:00,  3.27it/s]


Accuracy: 94.85 - Loss: 0.0013379836811901566 



100%|██████████| 16/16 [00:09<00:00,  1.73it/s]
100%|██████████| 14/14 [00:04<00:00,  3.01it/s]


Accuracy: 94.85 - Loss: 0.0013176740699274811 



100%|██████████| 16/16 [00:08<00:00,  1.83it/s]
100%|██████████| 14/14 [00:05<00:00,  2.72it/s]


Accuracy: 95.21 - Loss: 0.0012480593158188695 



100%|██████████| 16/16 [00:08<00:00,  1.90it/s]
100%|██████████| 14/14 [00:05<00:00,  2.58it/s]


Accuracy: 95.63 - Loss: 0.0011855909154711369 



100%|██████████| 16/16 [00:08<00:00,  1.99it/s]
100%|██████████| 14/14 [00:05<00:00,  2.36it/s]


Accuracy: 95.81 - Loss: 0.001148534507607807 



100%|██████████| 16/16 [00:08<00:00,  1.95it/s]
100%|██████████| 14/14 [00:06<00:00,  2.29it/s]


Accuracy: 95.39 - Loss: 0.0011825696860780258 



100%|██████████| 16/16 [00:08<00:00,  1.99it/s]
100%|██████████| 14/14 [00:06<00:00,  2.30it/s]


Accuracy: 95.51 - Loss: 0.0011261282310246706 



100%|██████████| 16/16 [00:07<00:00,  2.02it/s]
100%|██████████| 14/14 [00:06<00:00,  2.27it/s]


Accuracy: 95.57 - Loss: 0.00110271982625573 



100%|██████████| 16/16 [00:08<00:00,  2.00it/s]
100%|██████████| 14/14 [00:06<00:00,  2.23it/s]


Accuracy: 95.63 - Loss: 0.0010885661646366834 



100%|██████████| 16/16 [00:09<00:00,  1.65it/s]
100%|██████████| 14/14 [00:06<00:00,  2.33it/s]


Accuracy: 95.81 - Loss: 0.001059915064025425 



100%|██████████| 16/16 [00:08<00:00,  2.00it/s]
100%|██████████| 14/14 [00:06<00:00,  2.22it/s]


Accuracy: 95.69 - Loss: 0.001080644929859631 



100%|██████████| 16/16 [00:08<00:00,  1.99it/s]
100%|██████████| 14/14 [00:06<00:00,  2.30it/s]


Accuracy: 95.63 - Loss: 0.0010905734570083505 



100%|██████████| 16/16 [00:08<00:00,  2.00it/s]
100%|██████████| 14/14 [00:06<00:00,  2.28it/s]


Accuracy: 95.75 - Loss: 0.0010759724281772882 



100%|██████████| 16/16 [00:08<00:00,  1.97it/s]
100%|██████████| 14/14 [00:05<00:00,  2.33it/s]


Accuracy: 95.75 - Loss: 0.0010753345932298436 



100%|██████████| 16/16 [00:08<00:00,  1.98it/s]
100%|██████████| 14/14 [00:05<00:00,  2.43it/s]


Accuracy: 95.81 - Loss: 0.001072866622425482 



100%|██████████| 16/16 [00:08<00:00,  1.94it/s]
100%|██████████| 14/14 [00:05<00:00,  2.56it/s]


Accuracy: 95.69 - Loss: 0.001083402799074343 



100%|██████████| 16/16 [00:08<00:00,  1.89it/s]
100%|██████████| 14/14 [00:05<00:00,  2.77it/s]


Accuracy: 95.81 - Loss: 0.0010701698529907687 



## 1 - Activation Shaping module

In [46]:
def get_activation_shaping_hook(mask):
    # The hook captures mask variable from the parent scope, to update it,
    # remove() the hook and register a new one with the updated mask.

    def activation_shaping_hook(module, input, output):

        # binarize both activation map and mask using zero as threshold
        # print(output.shape, mask.shape)
        A_binary = torch.where(output <= 0, torch.tensor(0.0), torch.tensor(1.0))
        M_binary = torch.where(mask <= 0, torch.tensor(0.0), torch.tensor(1.0))

        # return the element-wise product of activation map and mask
        shaped_output = A_binary * M_binary
        # print(output.sum(), shaped_output.sum())
        return shaped_output
    return activation_shaping_hook

class ASHResNet18(nn.Module):
    def __init__(self):
        super(ASHResNet18, self).__init__()
        self.resnet = resnet18(pretrained=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 7)

    def register_activation_shaping_hook(self, layer_name = 'layer4.1.relu', mask_out_ratio = 0.0):
        self.layer_name = layer_name
        self.mask_out_ratio = mask_out_ratio

        # create a mask tensor with a given ratio of zeros
        print('self.mask_out_ratio: ', self.mask_out_ratio)
        rand_mat = torch.rand(512, 7, 7).to(CONFIG.device)
        mask = torch.where(rand_mat <= self.mask_out_ratio, 0.0, 1.0).to(CONFIG.device)

        hook = get_activation_shaping_hook(mask)
        # penultimate_layer = self.resnet.layer4[1].bn2
        # self.hook_handle = penultimate_layer.register_forward_hook(hook)

        for name, module in self.resnet.named_modules():
          if (isinstance(module, nn.ReLU) and name == self.layer_name):
            print('Insert activation shaping layer after ', name, module)
            self.hook_handle = module.register_forward_hook(hook)

    def remove_activation_shaping_hook(self):
        if self.hook_handle is not None:
            self.hook_handle.remove()

    def forward(self, x):
        return self.resnet(x)


## 2 - Random Activation Maps

In [49]:
def main_random():

    # Load dataset
    data = load_data()

    # Load model
    if CONFIG.experiment in ['random']:
        model = ASHResNet18()
        # model.register_activation_shaping_hook(layer_name = 'layer4.1.relu', mask_out_ratio = 0.55)
        model.register_activation_shaping_hook(layer_name = 'layer4.0.relu', mask_out_ratio = 0.55)

    model.to(CONFIG.device)

    if not CONFIG.test_only:
        train(model, data)
    else:
        evaluate(model, data['test'])

if __name__ == '__main__':
    warnings.filterwarnings('ignore', category=UserWarning)

    # Parse arguments (uncomment only one of the following)
    '''
    confs = [
        { # Cartoon
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'random',
            'experiment_name': 'random/cartoon/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'cartoon'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]
    '''
    '''
    confs = [
        { # Sketch
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'random',
            'experiment_name': 'random/sketch/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'sketch'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]
    '''
    confs = [
        { # Photo
            'seed': 0,
            'test_only': False,
            'cpu': False,
            'experiment': 'random',
            'experiment_name': 'random/photo/',
            'experiment_args': '{}',
            'dataset_args': {'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'photo'},
            'batch_size': 128,
            'epochs': 30,
            'num_workers': 5,
            'grad_accum_steps': 1
        }
    ]

    #args = parse_arguments()
    for conf in confs:
        args = argparse.Namespace(**conf)
        print(args)
        CONFIG.update(vars(args))

        # Setup output directory
        CONFIG.save_dir = os.path.join('record', CONFIG.experiment_name)
        os.makedirs(CONFIG.save_dir, exist_ok=True)

        # Setup logging
        logging.basicConfig(
            filename=os.path.join(CONFIG.save_dir, 'log.txt'),
            format='%(message)s',
            level=logging.INFO,
            filemode='a'
        )

        # Set experiment's device & deterministic behavior
        if CONFIG.cpu:
            CONFIG.device = 'cpu'
        else:
            CONFIG.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        torch.manual_seed(CONFIG.seed)
        random.seed(CONFIG.seed)
        np.random.seed(CONFIG.seed)
        torch.backends.cudnn.benchmark = True
        torch.use_deterministic_algorithms(mode=True, warn_only=True)

    main_random()


Namespace(seed=0, test_only=False, cpu=False, experiment='random', experiment_name='random/photo/', experiment_args='{}', dataset_args={'text_root': 'data', 'images_root': 'data/kfold/PACS', 'source_domain': 'art_painting', 'target_domain': 'photo'}, batch_size=128, epochs=30, num_workers=5, grad_accum_steps=1)
self.mask_out_ratio:  0.55
Insert activation shaping layer after  layer4.0.relu ReLU(inplace=True)


100%|██████████| 16/16 [00:09<00:00,  1.64it/s]
100%|██████████| 14/14 [00:05<00:00,  2.74it/s]


Accuracy: 25.87 - Loss: 0.01576805493074977 



100%|██████████| 16/16 [00:10<00:00,  1.60it/s]
100%|██████████| 14/14 [00:04<00:00,  3.05it/s]


Accuracy: 26.05 - Loss: 0.015808126812209625 



100%|██████████| 16/16 [00:09<00:00,  1.65it/s]
100%|██████████| 14/14 [00:04<00:00,  3.04it/s]


Accuracy: 26.71 - Loss: 0.015702521372698026 



100%|██████████| 16/16 [00:09<00:00,  1.66it/s]
100%|██████████| 14/14 [00:04<00:00,  3.10it/s]


Accuracy: 27.07 - Loss: 0.015656027751054593 



100%|██████████| 16/16 [00:09<00:00,  1.69it/s]
100%|██████████| 14/14 [00:04<00:00,  3.20it/s]


Accuracy: 26.83 - Loss: 0.015639504986608814 



100%|██████████| 16/16 [00:09<00:00,  1.68it/s]
100%|██████████| 14/14 [00:04<00:00,  3.26it/s]


Accuracy: 26.89 - Loss: 0.015524127526197605 



100%|██████████| 16/16 [00:11<00:00,  1.39it/s]
100%|██████████| 14/14 [00:04<00:00,  3.22it/s]


Accuracy: 27.43 - Loss: 0.015527444256993825 



100%|██████████| 16/16 [00:09<00:00,  1.70it/s]
100%|██████████| 14/14 [00:04<00:00,  3.27it/s]


Accuracy: 27.31 - Loss: 0.015457651286781905 



100%|██████████| 16/16 [00:08<00:00,  1.83it/s]
100%|██████████| 14/14 [00:04<00:00,  2.85it/s]


Accuracy: 27.07 - Loss: 0.015445040728517635 



100%|██████████| 16/16 [00:08<00:00,  1.95it/s]
100%|██████████| 14/14 [00:05<00:00,  2.51it/s]


Accuracy: 27.01 - Loss: 0.015379162962565165 



100%|██████████| 16/16 [00:07<00:00,  2.12it/s]
100%|██████████| 14/14 [00:06<00:00,  2.28it/s]


Accuracy: 27.19 - Loss: 0.015333800758430344 



100%|██████████| 16/16 [00:07<00:00,  2.11it/s]
100%|██████████| 14/14 [00:06<00:00,  2.22it/s]


Accuracy: 27.01 - Loss: 0.015362573169662566 



100%|██████████| 16/16 [00:07<00:00,  2.10it/s]
100%|██████████| 14/14 [00:06<00:00,  2.33it/s]


Accuracy: 26.89 - Loss: 0.01526111442885713 



100%|██████████| 16/16 [00:07<00:00,  2.10it/s]
100%|██████████| 14/14 [00:05<00:00,  2.56it/s]


Accuracy: 26.83 - Loss: 0.01527670443414928 



100%|██████████| 16/16 [00:08<00:00,  1.95it/s]
100%|██████████| 14/14 [00:04<00:00,  2.85it/s]


Accuracy: 26.53 - Loss: 0.015268064615969173 



100%|██████████| 16/16 [00:08<00:00,  1.84it/s]
100%|██████████| 14/14 [00:04<00:00,  3.16it/s]


Accuracy: 25.87 - Loss: 0.01529640459014984 



100%|██████████| 16/16 [00:09<00:00,  1.72it/s]
100%|██████████| 14/14 [00:04<00:00,  3.21it/s]


Accuracy: 27.07 - Loss: 0.015198302126216317 



100%|██████████| 16/16 [00:09<00:00,  1.69it/s]
100%|██████████| 14/14 [00:04<00:00,  3.19it/s]


Accuracy: 26.95 - Loss: 0.01520213659651979 



100%|██████████| 16/16 [00:09<00:00,  1.67it/s]
100%|██████████| 14/14 [00:04<00:00,  3.06it/s]


Accuracy: 26.59 - Loss: 0.015191334378933477 



100%|██████████| 16/16 [00:09<00:00,  1.69it/s]
100%|██████████| 14/14 [00:04<00:00,  3.22it/s]


Accuracy: 25.69 - Loss: 0.01519078844321702 



100%|██████████| 16/16 [00:09<00:00,  1.68it/s]
100%|██████████| 14/14 [00:04<00:00,  3.13it/s]


Accuracy: 27.07 - Loss: 0.015112785165181417 



100%|██████████| 16/16 [00:09<00:00,  1.67it/s]
100%|██████████| 14/14 [00:04<00:00,  3.24it/s]


Accuracy: 26.53 - Loss: 0.0151487278367231 



100%|██████████| 16/16 [00:09<00:00,  1.67it/s]
100%|██████████| 14/14 [00:04<00:00,  3.22it/s]


Accuracy: 26.23 - Loss: 0.015130366251140297 



100%|██████████| 16/16 [00:09<00:00,  1.68it/s]
100%|██████████| 14/14 [00:04<00:00,  3.20it/s]


Accuracy: 27.01 - Loss: 0.015138322293401478 



100%|██████████| 16/16 [00:08<00:00,  1.81it/s]
100%|██████████| 14/14 [00:04<00:00,  2.81it/s]


Accuracy: 26.77 - Loss: 0.015120259944550291 



100%|██████████| 16/16 [00:08<00:00,  1.90it/s]
100%|██████████| 14/14 [00:05<00:00,  2.62it/s]


Accuracy: 26.29 - Loss: 0.015129004838223943 



100%|██████████| 16/16 [00:10<00:00,  1.57it/s]
100%|██████████| 14/14 [00:05<00:00,  2.50it/s]


Accuracy: 26.59 - Loss: 0.015102863133310558 



100%|██████████| 16/16 [00:07<00:00,  2.00it/s]
100%|██████████| 14/14 [00:06<00:00,  2.31it/s]


Accuracy: 25.87 - Loss: 0.015113375251164693 



100%|██████████| 16/16 [00:07<00:00,  2.08it/s]
100%|██████████| 14/14 [00:06<00:00,  2.19it/s]


Accuracy: 26.11 - Loss: 0.015128330199304455 



100%|██████████| 16/16 [00:07<00:00,  2.07it/s]
100%|██████████| 14/14 [00:06<00:00,  2.26it/s]


Accuracy: 26.35 - Loss: 0.015114920938800194 



## 3 - Adapting Activation Maps across Domains

## Ext. 2 - Binarization Ablation