In [None]:
import sys, os, warnings, json
from argparse import Namespace
warnings.filterwarnings("ignore")

from pathlib import Path
import torch as ch
from torchvision import transforms
import numpy as np
from PIL import Image
from tqdm import tqdm

from helpers import classifier_helpers
import helpers.data_helpers as dh
import helpers.context_helpers as coh
import helpers.rewrite_helpers as rh
import helpers.vis_helpers as vh
import helpers.analysis_helpers as ah

# Seed for the train/test split RNG. It is drawn at random, so every fresh
# kernel session gets a different value in [0, 1000).
random_seed = np.random.randint(low=0, high=1000)

%matplotlib inline

In [None]:
# Parse the JSON experiment configuration and expose its top-level keys as
# attributes on `args`.
with open('./helpers/config.json') as config_file:
    config = json.load(config_file)
args = Namespace(**config)

print(args)

## Load model

In [None]:
# Resolve the default dataset/model locations and the class dictionary for
# the configured dataset and architecture.
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = classifier_helpers.get_default_paths(
    args.dataset_name, arch=args.arch)

# Keep only the text before the first comma of each class description.
CD = {class_key: description.split(',')[0]
      for class_key, description in CD.items()}

# `ret` stays a global on purpose: the dataset-loading cell reads its last
# element when the architecture name starts with 'clip'.
ret = classifier_helpers.load_classifier(
    MODEL_PATH, MODEL_CLASS, ARCH, args.dataset_name, args.layernum)
model, context_model, target_model = ret[:3]

## Load base dataset and synthetic data

In [None]:
# Default loaders for the base dataset.
base_dataset, train_loader, val_loader = dh.get_dataset(
    args.dataset_name, DATASET_PATH, batch_size=32, workers=8)

# CLIP architectures: the last element of `ret` (returned by load_classifier)
# is installed as the dataset's test transform and also handed to the
# synthetic-data pipeline, then the validation loader is rebuilt.
preprocessing_transform = None
if args.arch.startswith('clip'):
    preprocessing_transform = ret[-1]
    base_dataset.transform_test = preprocessing_transform
    _, val_loader = base_dataset.make_loaders(
        workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle_val=False,
    )

In [None]:
# Load the saved concept segmentations for the configured dataset/concept.
concept_file = f'data/synthetic/segmentations/{args.concepts}/concept_{args.dataset_name}_{args.concepts}_{args.concept_name}.pt'
concept_info = ch.load(concept_file)

# Images become float32 scaled by 1/255 (presumably stored as 0-255 values;
# confirm against the file format). Masks are cast to uint8.
stored_imgs = concept_info['imgs']
stored_masks = concept_info['masks']
concept_info['imgs'] = stored_imgs.to(ch.float32) / 255.
concept_info['masks'] = stored_masks.to(ch.uint8)


In [None]:
# Build the train/test splits with an RNG seeded from `random_seed`.
split_rng = np.random.RandomState(random_seed)
data_dict, data_info_dict = dh.obtain_train_test_splits(
    args,
    concept_info,
    CD,
    args.style,
    preprocess=preprocessing_transform,
    rng=split_rng,
)

# Record which style and concept these splits were generated for.
data_info_dict['style_name'] = args.style
data_info_dict['concept_name'] = args.concept_name

In [None]:
# Preview the training images and a few randomly chosen test examples.
test_data = data_dict['test_data']
n_test = len(test_data['imgs'])

# Cap the sample size at the number of test images: choice(..., replace=False)
# raises ValueError when asked for more samples than exist. Drawing from a
# RandomState seeded with `random_seed` (instead of the unseeded global RNG)
# makes re-running this cell show the same images within a session.
sidx = np.random.RandomState(random_seed).choice(n_test, min(3, n_test), replace=False)

vh.show_image_row([data_dict['train_data']['imgs']], title='Train (original)')
vh.show_image_row([data_dict['train_data']['modified_imgs']], title='Train (modified)')

# The same sampled indices are used for every test panel so the rows line up.
test_panels = [
    ('imgs', 'Test (original)'),
    ('modified_imgs_same', 'Test (modified w/ train style)'),
    ('modified_imgs_diff', 'Test (modified w/ other styles)'),
]
for panel_key, panel_title in test_panels:
    vh.show_image_row([test_data[panel_key][sidx]], title=panel_title)

## Evaluate model performance on test set pre-rewriting

In [None]:
# Pre-edit model accuracy on the validation loader.
cache_file = f'./cache/accuracy/{args.arch}_{args.dataset_name}.pt'

# Create the directory that will hold `cache_file`. exist_ok=True is required
# so this cell does not raise FileExistsError on every run after the first.
Path(cache_file).parent.mkdir(parents=True, exist_ok=True)

# Only the third returned value (the accuracy) is kept.
_, _, acc_pre = ah.eval_accuracy(model, val_loader, batch_size=args.batch_size, cache_file=cache_file)

In [None]:
print("Pre-rewrite eval on synthetic data")

# Image sets scored within each split, and the display label for each one.
log_keys = {
    'train': ['imgs', 'modified_imgs'],
    'test': ['imgs', 'modified_imgs_same', 'modified_imgs_diff'],
}
log_labels = {
    ('train', 'imgs'): 'Original train images',
    ('train', 'modified_imgs'): 'Modified train images',
    ('test', 'imgs'): 'Original test images',
    ('test', 'modified_imgs_same'): 'Modified test images w/ train style',
    ('test', 'modified_imgs_diff'): 'Modified test images w/ other style',
}

# For every (split, image set) pair, store the predictions and the percentage
# of predictions that equal the split's labels.
RESULTS = {}
for split in ('train', 'test'):
    split_data = data_dict[f'{split}_data']
    for image_key in log_keys[split]:
        preds = ah.get_preds(context_model, split_data[image_key],
                             BS=args.batch_size).numpy()
        labels = split_data['labels'].numpy()
        acc = 100 * np.mean(preds == labels)
        subset_label = log_labels[(split, image_key)]
        print(f"Subset: {subset_label} | Accuracy: {acc:.2f}")
        RESULTS[f'{split}_pre_{image_key}'] = {'preds': preds, 'acc': acc}

## Perform rewrite

In [None]:
# Directory passed to the rewrite helper as `caching_dir`, keyed by dataset,
# architecture and layer number.
covariance_cache_dir = f"./cache/covariances/{args.dataset_name}_{ARCH}_layer{args.layernum}"

# Note: this rebinds `context_model` to the helper's return value, so running
# the cell again passes the previously returned model back in.
context_model = rh.edit_classifier(
    args,
    data_dict['train_data'],
    context_model,
    target_model=target_model,
    val_loader=val_loader,
    caching_dir=covariance_cache_dir,
)

## Evaluate model performance on test set post-rewriting

In [None]:
print("Post-rewrite eval on synthetic data")

# Re-score the same (split, image set) pairs listed in `log_keys`, this time
# storing results under the '<split>_post_<image set>' keys of RESULTS.
for split in ('train', 'test'):
    split_data = data_dict[f'{split}_data']
    for image_key in log_keys[split]:
        preds = ah.get_preds(context_model, split_data[image_key],
                             BS=args.batch_size).numpy()
        labels = split_data['labels'].numpy()
        acc = 100 * np.mean(preds == labels)
        RESULTS[f'{split}_post_{image_key}'] = {'preds': preds, 'acc': acc}

In [None]:
# Pass the split data and the collected pre-/post-rewrite entries to the
# analysis helper; RESULTS is rebound to whatever the helper returns.
RESULTS = ah.evaluate_rewrite_effect(data_dict, RESULTS)

In [None]:
# Hand RESULTS and the run configuration to the plotting helper
# (presumably a bar chart of the rewrite's effect; confirm in analysis_helpers).
ah.plot_improvement_bar(RESULTS, args)

In [None]:
# Post-edit accuracy on the validation loader. cache_file=None is passed here,
# unlike the pre-edit call, which supplied a cache path.
# NOTE(review): this evaluates `model`, while the rewrite cell rebinds
# `context_model`. It reflects the edit only if the two share parameters;
# confirm in classifier_helpers.load_classifier.
_, _, acc_post = ah.eval_accuracy(model, val_loader, batch_size=args.batch_size, cache_file=None)