In [None]:
import sys, os, warnings
from argparse import Namespace
warnings.filterwarnings("ignore")

import torch as ch
from torchvision import transforms
import numpy as np
from PIL import Image
from tqdm import tqdm

from helpers import classifier_helpers
import helpers.data_helpers as dh
import helpers.context_helpers as coh
import helpers.rewrite_helpers as rh
import helpers.vis_helpers as vh

%matplotlib inline

In [None]:
# --- Experiment configuration -------------------------------------------
# Which classifier to edit, and where inside it.
ARCH = 'vgg16'      # re-bound in the load-model cell from get_default_paths()
LAYERNUM = 12       # index of the layer that gets rewritten

# Data the classifier was trained on.
DATASET_NAME = 'ImageNet'

# Rewriting method; the train_args cell picks the step count from this value.
REWRITE_MODE = 'editing'

## Load model

In [None]:
# Resolve the default dataset/model locations for this dataset + architecture.
# Note that ARCH is re-bound here to whatever get_default_paths() returns, and
# CD is the lookup used below to turn a class index into a readable name.
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = classifier_helpers.get_default_paths(
    DATASET_NAME, arch=ARCH
)

# Load the classifier; only the first three returned values are used here.
# NOTE(review): presumably context_model / target_model are the pieces of
# `model` around layer LAYERNUM -- confirm in classifier_helpers.load_classifier.
ret = classifier_helpers.load_classifier(
    MODEL_PATH,
    MODEL_CLASS,
    ARCH,
    DATASET_NAME,
    LAYERNUM,
)
model, context_model, target_model = ret[:3]

## Load base dataset and vehicles-on-snow data

In [None]:
# Base dataset plus its train/validation loaders; val_loader is later handed
# to rh.edit_classifier.
base_dataset, train_loader, val_loader = dh.get_dataset(
    DATASET_NAME,
    DATASET_PATH,
    batch_size=32,
    workers=8,
)

In [None]:
# Load the vehicles-on-snow exemplars (train) and held-out images (test).
# train_data is indexed below with 'imgs', 'masks' and 'modified_imgs';
# test_data is iterated below as {class index: batch of images}, and CD maps
# each class index to its display name.
train_data, test_data = dh.get_vehicles_on_snow_data(DATASET_NAME, CD)

In [None]:
print("Train exemplars")
# One image row per entry of train_data, in the same order as the row titles.
vh.show_image_row(
    [train_data[key] for key in ('imgs', 'masks', 'modified_imgs')],
    ['Original', 'Mask', 'Modified'],
    fontsize=20,
)

In [None]:
print("Flickr-sourced test set")
# Preview the first five test images of every class, titled "<name> (<index>)".
for c, x in test_data.items():
    preview = x[:5]
    row_title = f'{CD[c]} ({c})'
    vh.show_image_row([preview], title=row_title)

## Evaluate model performance on test set pre-rewriting

In [None]:
print("Original accuracy on test vehicles-on-snow data")

# RESULTS[stage]['acc'][class]   -> accuracy in percent for that class
# RESULTS[stage]['preds'][class] -> predicted labels for that class' test images
# 'pre' is filled here; 'post' is filled by the post-rewriting evaluation cell.
RESULTS = {k: {'preds': {}, 'acc': {}} for k in ['pre', 'post']}
for c, x in test_data.items():
    # Inference only: no autograd graph is needed for argmax predictions.
    with ch.no_grad():
        pred = model(x.cuda()).argmax(axis=1)
    # Count the hits with a single tensor reduction. Iterating over `pred` in
    # Python would compare (and synchronise with the GPU for) one element at a
    # time, which is needlessly slow for a CUDA tensor.
    n_correct = (pred == c).sum().item()
    acc = 100 * n_correct / len(x)
    print(f'Class: {c}/{CD[c]} | Accuracy: {acc:.2f}')
    RESULTS['pre']['acc'][c] = acc
    RESULTS['pre']['preds'][c] = pred

## Perform rewriting

In [None]:
# Number of rewriting steps depends on the chosen rewriting method.
if REWRITE_MODE == 'editing':
    rewrite_steps = 20000
else:
    rewrite_steps = 400

# Hyperparameters for the rewrite, passed to rh.edit_classifier as attributes.
train_args = Namespace(
    ntrain=1,                   # Number of exemplars
    arch=ARCH,                  # Network architecture
    mode_rewrite=REWRITE_MODE,  # Rewriting method ['editing', 'finetune_local', 'finetune_global']
    layernum=LAYERNUM,          # Layer to modify
    nsteps=rewrite_steps,       # Number of rewriting steps
    lr=1e-4,                    # Learning rate
    restrict_rank=True,         # Whether or not to perform low-rank update
    nsteps_proj=10,             # Frequency of weight projection
    rank=1,                     # Rank of subspace to project weights
    use_mask=True,              # Whether or not to use mask
)

In [None]:
# Cache location is keyed by dataset, architecture and edited layer.
covariance_cache_dir = f"./cache/covariances/{DATASET_NAME}_{ARCH}_layer{LAYERNUM}"

# Run the rewrite on the exemplars in train_data.
# NOTE(review): context_model is both an input and the re-bound output, so
# re-running this cell feeds the already-edited model back in.
context_model = rh.edit_classifier(
    train_args,
    train_data,
    context_model,
    target_model=target_model,
    val_loader=val_loader,
    caching_dir=covariance_cache_dir,
)

## Evaluate model performance on test set post-rewriting

In [None]:
print("Change in accuracy on test vehicles-on-snow data \n")

# Re-run the same per-class evaluation after the rewrite and report it next to
# the accuracy stored in RESULTS['pre'] by the pre-rewriting evaluation cell.
# NOTE(review): predictions come from `model`, not the `context_model` returned
# by rh.edit_classifier -- presumably they share weights so the edit is visible
# through `model`; confirm in classifier_helpers / rewrite_helpers.
for c, x in test_data.items():
    # Inference only: no autograd graph is needed for argmax predictions.
    with ch.no_grad():
        pred = model(x.cuda()).argmax(axis=1)
    # Count the hits with a single tensor reduction. Iterating over `pred` in
    # Python would compare (and synchronise with the GPU for) one element at a
    # time, which is needlessly slow for a CUDA tensor.
    n_correct = (pred == c).sum().item()
    acc = 100 * n_correct / len(x)
    print(f'Class: {c}/{CD[c]} \n Accuracy change: {RESULTS["pre"]["acc"][c]:.2f} -> {acc:.2f} \n')
    RESULTS['post']['acc'][c] = acc
    RESULTS['post']['preds'][c] = pred