In [None]:
# if peal not installed, but project downloaded locally
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# set autoreload for more convinient development
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    os.environ["CUDA_VISIBLE_DEVICES"]= request('cuda_visible_devices', default = "0") 
    torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
    import math
    import nvidia_smi
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    gigabyte_vram = info.total / math.pow(10, 9)
    print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

#from IPython.core.debugger import set_trace #set_trace()

confounder_type = request('confounder_type', default = 'copyrighttag')

In [None]:
# Create the confounded ("poisoned") CelebA variant from the raw CelebA download.
# The raw images and the attribute file must already be downloaded; their locations are requested below.
from peal.data.dataset_generators import ConfounderDatasetGenerator
# exist_ok replaces the separate exists() check
os.makedirs('datasets', exist_ok = True)

# directory that contains the raw (aligned) CelebA images
CELEBA_IMG_DIR = request('CELEBA_IMG_DIR', default = '/home/space/datasets/celeba/img_align_celeba')

# path of the CelebA attribute label file
CELEBA_ATTRIBUTE_DIR = request('CELEBA_ATTRIBUTE_DIR', default = '/home/space/datasets/celeba/list_attr_celeba.txt')

# TODO do this also for color and intensity
# presumably written to datasets/celeba_<confounder_type>, which the next cell reads — confirm
cdg = ConfounderDatasetGenerator(
    base_dataset_dir = CELEBA_IMG_DIR,
    dataset_name = 'celeba_' + confounder_type,
    label_dir = CELEBA_ATTRIBUTE_DIR,
    delimiter = ' ',
    confounder_type = confounder_type
)
cdg.generate_dataset()

In [None]:
# Load the unpoisoned CelebA splits and, for every poisoning degree, a half-size set of
# splits whose confounder probability equals that degree (given in percent).
import copy
from peal.data.datasets import get_datasets
from peal.utils import load_yaml_config
unpoised_dataset_config = load_yaml_config('$PEAL/configs/data/isblond_confounder_celeba.yaml')
unpoised_dataset_train, unpoised_dataset_val, unpoised_dataset_test = get_datasets(
    config = unpoised_dataset_config,
    base_dir = 'datasets/celeba_' + confounder_type
)
poised_datasets = {}
# normalise to str: the degrees are concatenated into run names in later cells
poisoning_degrees = [str(degree) for degree in request('poisoning_degrees', ['100', '95', '90', '85', '80'])]
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    dataset_config = copy.deepcopy(unpoised_dataset_config)
    # poisoned datasets use half as many samples as the unpoisoned configuration
    dataset_config['num_samples'] = int(unpoised_dataset_config['num_samples'] / 2)
    dataset_config['confounder_probability'] = int(poisoning_degree) / 100
    # value is the (train, val, test) tuple returned by get_datasets
    poised_datasets[poisoning_degree] = get_datasets(
        config = dataset_config,
        base_dir = 'datasets/celeba_' + confounder_type
    )

In [None]:
# Train one student classifier per poisoning degree on the corresponding poisoned splits
# (only needed if you want to train your own initial student models).
from peal.architectures.models import ImgEncoderDecoderModel
from peal.training.trainers import ModelTrainer

for poisoning_degree in poisoning_degrees:
    poised_dataset_train, poised_dataset_val, poised_dataset_test = poised_datasets[poisoning_degree]
    # reload the config every iteration because it is mutated below
    student_config = load_yaml_config('$PEAL/configs/models/celeba_isblond_classifier.yaml')
    student_config['data'] = poised_dataset_train.config

    # create and train the student model
    student = ImgEncoderDecoderModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config,
        model = student,
        datasource = (poised_dataset_train, poised_dataset_val),
        model_name = 'celeba_poised' + poisoning_degree + '_isblond_' + confounder_type + '_classifier',
        # pass the available VRAM like the teacher trainer does
        gigabyte_vram = gigabyte_vram
    )
    student_trainer.fit()

In [None]:
# Choose whether to approximate the results by always using the generator of the most
# strongly poisoned (100%) dataset; otherwise no generator is passed on (generator = None).
use_predefined_generator = request('use_predefined_generator', True)
if use_predefined_generator:
    generator_path = request(
        'generator_path',
        'peal_runs/celeba_poised100_' + confounder_type + '_generator/model.cpl'
    )
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    generator = torch.load(generator_path, map_location = device).to(device)

else:
    generator = None

In [None]:
# Provide the teacher for knowledge distillation:
#   'train' -> train a new classifier on the unpoisoned data (teacher_type becomes 'oracle')
#   'load'  -> load an existing classifier from disk          (teacher_type becomes 'oracle')
#   other   -> the string itself is the teacher (web interface); teacher_type is its part before '@'
teacher_type = request('teacher_type', 'train')
if teacher_type == 'train':
    # if you want to train and use a new model for knowledge distillation
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/celeba_isblond_classifier.yaml')
    teacher_config['data'] = unpoised_dataset_train.config

    # create and train teacher model
    teacher = ImgEncoderDecoderModel(teacher_config).to(device)
    teacher_trainer = ModelTrainer(
        config = teacher_config,
        model = teacher,
        datasource = (unpoised_dataset_train, unpoised_dataset_val),
        model_name = request('teacher_model_name', 'celeba_unpoised_isblond_' + confounder_type + '_classifier'),
        gigabyte_vram = gigabyte_vram
    )
    teacher_trainer.fit()
    teacher_type = 'oracle'

elif teacher_type == 'load':
    # if you want to use an existing model for knowledge distillation
    teacher_path = request('teacher_path', 'peal_runs/celeba_unpoised_isblond_' + confounder_type + '_classifier/model.cpl')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    teacher = torch.load(teacher_path, map_location = device).to(device)
    teacher_type = 'oracle'

else:
    # if you want to teach the model yourself with the web interface
    teacher = teacher_type
    teacher_type = teacher.split('@')[0]

In [None]:
# Run counterfactual knowledge distillation (CFKD) for the student of every poisoning degree,
# training on its poisoned splits and testing on the unpoisoned test split.
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    student_path = os.path.join(
        'peal_runs',
        'celeba_poised' + poisoning_degree + '_isblond_' + confounder_type + '_classifier'
    )
    student = torch.load(
        os.path.join(student_path, 'model.cpl'),
        map_location = device
    )
    cfkd = CounterfactualKnowledgeDistillation(
        student = student,
        datasource = (
            poised_datasets[poisoning_degree][0],
            poised_datasets[poisoning_degree][1],
            unpoised_dataset_test
        ),
        output_size = 2,
        teacher = teacher,
        generator = generator,
        # NOTE(review): the folder is 'cfkd_oracle' regardless of teacher_type; the
        # evaluation cell reads the corrected model from this folder
        base_dir = os.path.join(student_path, 'cfkd_oracle'),
        gigabyte_vram = gigabyte_vram
    )
    cfkd.run()

In [None]:
# Perform P-ClArC (class artifact compensation) for the student of every poisoning degree;
# results are written to <student run>/pclarc_<teacher_type>.
from peal.adaptors.class_artifact_compensation import ClassArtifactCompensation
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    student_path = os.path.join('peal_runs', 'celeba_poised' + poisoning_degree + '_isblond_' + confounder_type + '_classifier')
    student = torch.load(
        os.path.join(student_path, 'model.cpl'),
        map_location = device
    )
    pclarc = ClassArtifactCompensation(
        student = student,
        datasource = (
            poised_datasets[poisoning_degree][0],
            poised_datasets[poisoning_degree][1],
            unpoised_dataset_test
        ),
        output_size = 2,
        base_dir = os.path.join(student_path, 'pclarc_' + teacher_type),
        teacher = teacher,
        gigabyte_vram = gigabyte_vram,
        overwrite = False
    )
    pclarc.run()

In [None]:
# turn off IPython's automatic post-mortem debugger (leftover from debugging the cells above)
%pdb off

In [None]:
# Determine the test accuracies of the uncorrected, P-ClArC-corrected and CFKD-corrected
# students of every poisoning degree on the unpoisoned test split, and store them per run
# in peal_runs/<student>/results.json.
import json
from tqdm import tqdm
from peal.data.dataloaders import get_dataloader
accuracies = []
accuracies_pclarc = []
accuracies_cfkd = []
student_config = load_yaml_config('$PEAL/configs/models/celeba_isblond_classifier.yaml')
student_config['training']['test_batch_size'] = 10
unpoised_dataloader_test = get_dataloader(
    dataset = unpoised_dataset_test,
    training_config = student_config['training'],
    mode = 'test',
    task_config = student_config['task']
)
for poisoning_degree in poisoning_degrees:
    student_name = 'celeba_poised' + poisoning_degree + '_isblond_' + confounder_type + '_classifier'
    student_path = os.path.join('peal_runs', student_name)
    student = torch.load(
        os.path.join(student_path, 'model.cpl'),
        map_location = device
    )
    # P-ClArC models are currently not evaluated: the uncorrected student is a placeholder.
    # To evaluate them, load os.path.join(student_path, 'pclarc_human', 'model.cpl') here.
    student_pclarc = student
    student_cfkd = torch.load(
        os.path.join(student_path, 'cfkd_oracle', 'model.cpl'),
        map_location = device
    )
    # inference mode: disables dropout / batch-norm statistic updates
    for evaluated_model in (student, student_pclarc, student_cfkd):
        evaluated_model.eval()

    correct = 0
    correct_pclarc = 0
    correct_cfkd = 0
    num_samples = 0
    # no_grad: no autograd graphs are needed (and kept in memory) for evaluation
    with torch.no_grad(), tqdm(unpoised_dataloader_test) as pbar:
        for it, (X, y) in enumerate(pbar):
            X = X.to(device)
            y_pred = student(X).argmax(-1).to('cpu')
            correct += float(torch.sum(y_pred == y))
            y_pred_plarc = student_pclarc(X).argmax(-1).to('cpu')
            correct_pclarc += float(torch.sum(y_pred_plarc == y))
            y_pred_cfkd = student_cfkd(X).argmax(-1).to('cpu')
            correct_cfkd += float(torch.sum(y_pred_cfkd == y))
            # count what was actually evaluated (robust to loaders that drop samples)
            num_samples += len(y)
            pbar.set_description('poisoning_degree: ' + poisoning_degree + ', it: ' + str(it))

    accuracy = correct / num_samples
    accuracies.append(accuracy)
    accuracy_pclarc = correct_pclarc / num_samples
    accuracies_pclarc.append(accuracy_pclarc)
    accuracy_cfkd = correct_cfkd / num_samples
    accuracies_cfkd.append(accuracy_cfkd)
    with open(os.path.join(student_path, 'results.json'), 'w') as result_file:
        json.dump(
            {'accuracy' : accuracy, 'accuracy_pclarc' : accuracy_pclarc, 'accuracy_cfkd' : accuracy_cfkd},
            result_file,
            indent=4
        )
# previously recorded results:
# 0.6, 0.7, 0.8, 0.9, 0.95, 0.99
# [0.91025, 0.901, 0.864, 0.819, 0.7635, 0.51375]
# 0.56
# > 0.8
# 1.0, 0.95, 0.9, 0.85, 0.8
# copyrighttag: 0.5, 0.7635, 0.819, 0.86875, 0.864
# intensity: 0.5045, 0.807, 0.84925, 0.86125, 0.887
# color: 0.502, 0.71075, 0.8295, 0.868, 0.89425
print(accuracies)
print(accuracies_pclarc)
print(accuracies_cfkd)

# TODO Visualize curves in nice matplotlib plot

In [None]:
# NOTE(review): this overwrites the CFKD accuracies computed by the evaluation cell with
# hard-coded values (the commented line is an earlier variant). Skip this cell to plot the
# freshly computed accuracies instead.
#accuracies_cfkd = [0.65, 0.82, 0.84, 0.894, 0.897]
accuracies_cfkd = [0.84, 0.82, 0.84, 0.894, 0.897]

In [None]:
# copyright tag
# recorded test accuracy used as the 'Oracle Model' reference line in the accuracy plot;
# run only the cell that matches confounder_type
oracle_test_accuracy = 0.916

In [None]:
# intensity
# recorded test accuracy used as the 'Oracle Model' reference line in the accuracy plot;
# run only the cell that matches confounder_type
oracle_test_accuracy = 0.922

In [None]:
# color
# recorded test accuracy used as the 'Oracle Model' reference line in the accuracy plot;
# run only the cell that matches confounder_type
oracle_test_accuracy = 0.9155

In [None]:
# Plot the accuracy on the unpoisoned test split against the confounder/class correlation
# used in training, for the uncorrected and the CFKD-corrected students.
import matplotlib.pyplot as plt

# x positions are the poisoning degrees as fractions, reversed like the accuracies below
# (previously hard-coded, which mislabelled points for the default degrees 80..100)
x1 = [float(poisoning_degree) / 100 for poisoning_degree in poisoning_degrees][::-1]
y1 = accuracies[::-1]
y2 = accuracies_cfkd[::-1]
y3 = accuracies_pclarc[::-1]

# NOTE(review): presumably the training correlation above which the confounder is a more
# reliable feature than the true class feature (bounded by the oracle accuracy) — confirm.
confounder_stronger = 1 - 2 * (1 - oracle_test_accuracy)

plt.plot(x1, y1, label='uncorrected', color='red')
plt.plot(x1, y2, label='cfkd', color='green')
# P-ClArC curve disabled; y3 is kept so it can be re-enabled
#plt.plot(x1, y3, label='pclarc', color='blue')

plt.axvline(x=confounder_stronger, linestyle='--', color='gray', label='Confounder is stronger feature')
plt.axhline(y=oracle_test_accuracy, linestyle='--', color='purple', label='Oracle Model')
plt.axhline(y=0.5, linestyle='--', color='black', label='Random')

plt.xlabel('Correlation Confounder & Class in Training')
plt.ylabel('Accuracy on Test Set without Correlation')
plt.legend(loc='lower left')
plt.show()

In [None]:
# Plot the accuracy on the unpoisoned test split against the confounder/class correlation
# used in training, including the P-ClArC curve and a fixed oracle accuracy of 0.91.
import matplotlib.pyplot as plt

# x positions are the poisoning degrees as fractions, reversed like the accuracies below
# (previously hard-coded, which mislabelled points for the default degrees 80..100)
x1 = [float(poisoning_degree) / 100 for poisoning_degree in poisoning_degrees][::-1]
y1 = accuracies[::-1]
y2 = accuracies_cfkd[::-1]
y3 = accuracies_pclarc[::-1]

# NOTE(review): overrides any per-confounder oracle_test_accuracy set earlier
oracle_test_accuracy = 0.91
confounder_stronger = 1 - 2 * (1 - oracle_test_accuracy)

plt.plot(x1, y1, label='uncorrected', color='red')
plt.plot(x1, y2, label='cfkd', color='green')
plt.plot(x1, y3, label='pclarc', color='blue')

plt.axvline(x=confounder_stronger, linestyle='--', color='gray', label='Confounder is stronger feature')
plt.axhline(y=oracle_test_accuracy, linestyle='--', color='purple', label='Oracle Model')
plt.axhline(y=0.5, linestyle='--', color='black', label='Random')

plt.xlabel('Correlation Confounder & Class in Training')
plt.ylabel('Accuracy on Test Set without Correlation')
plt.legend(loc='lower left')
plt.show()

In [None]:
# Recorded per-iteration values (6 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Validation Accuracy'
y1 = [0.97, 0.97, 0.95, 0.93, 0.92, 0.96]
y2 = [0.97, 0.96, 0.94, 0.92, 0.91, 0.91]
y3 = [0.97, 0.97, 0.96, 0.95, 0.94, 0.92]

In [None]:
# Recorded per-iteration values (6 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Unpoised Test Accuracy'
y1 = [0.54, 0.58, 0.70, 0.78, 0.78, 0.78]
y2 = [0.54, 0.59, 0.73, 0.79, 0.81, 0.82]
y3 = [0.54, 0.58, 0.71, 0.78, 0.79, 0.80]

In [None]:
# Recorded per-iteration values (6 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Feedback Accuracy'
y1 = [0.29, 0.44, 0.49, 0.4, 0.29, 0.33]
y2 = [0.05, 0.13, 0.21, 0.30, 0.35, 0.41]
y3 = [0.34, 0.45, 0.49, 0.54, 0.55, 0.56]

In [None]:
# Recorded per-iteration values (3 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Validation Accuracy'
y1 = [0.97, 0.97, 0.95]
y2 = [0.97, 0.95, 0.94]
y3 = [0.97, 0.96, 0.95]

In [None]:
# Recorded per-iteration values (3 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Unpoised Test Accuracy'
y1 = [0.54, 0.58, 0.70]
y2 = [0.54, 0.60, 0.60]
y3 = [0.54, 0.57, 0.63]

In [None]:
# Recorded per-iteration values (3 iterations) for the iteration plot:
# y1 = oracle, y2 = SpRAy, y3 = human teacher.
label = 'Feedback Accuracy'
y1 = [0.29, 0.44, 0.49]
y2 = [0.04, 0.12, 0.23]
y3 = [0.05, 0.18, 0.34]

In [None]:
# Plot the metric selected by one of the data cells above (label, y1, y2, y3) per
# iteration for the oracle, SpRAy and human teachers.
import matplotlib.pyplot as plt

# derive the x axis from the data so both the 3- and the 6-iteration variants can be plotted
x1 = range(len(y1))

plt.plot(x1, y1, label='oracle', color='red')
plt.plot(x1, y2, label='SpRAy', color='green')
plt.plot(x1, y3, label='human', color='blue')

plt.xlabel('Num Iterations')
plt.ylabel(label)
plt.legend(loc='upper left')
plt.show()

In [None]:
# Plot the recorded per-iteration metrics of the cancer tissue classifier together with
# the oracle accuracy and the accuracy of the poisoned (uncorrected) model.
import matplotlib.pyplot as plt

oracle_test_accuracy = 0.94
poised_accuracy = 0.65
y1 = [0.31, 0.57, 0.44, 0.27, 0.71, 0.51]
y2 = [0.99, 0.94, 0.94, 0.58, 0.98, 0.97]
y3 = [0.65, 0.69, 0.71, 0.52, 0.83, 0.76]
# one x position per recorded iteration
x1 = range(len(y1))

for curve_values, curve_label, curve_color in (
    (y1, 'Feedback Accuracy', 'red'),
    (y2, 'Validation Accuracy', 'green'),
    (y3, 'Unpoised Test Accuracy', 'blue'),
):
    plt.plot(x1, curve_values, label=curve_label, color=curve_color)

plt.axhline(y=oracle_test_accuracy, linestyle='--', color='purple', label='Oracle Model')
plt.axhline(y=poised_accuracy, linestyle='--', color='black', label='Poised accuracy')

plt.xlabel('Iteration')
plt.ylabel('Accuracy')
plt.legend(loc='upper left')
plt.show()

In [None]:
# Load a corrected classifier produced by the external pytorch_global_explanations_library
# project. NOTE(review): this overrides student_cfkd for the qualitative comparison below.
module_path = '/home/sidney/workspace/pytorch_global_explanations_library'
if module_path not in sys.path:
    sys.path.append(module_path)

# presumably required so torch.load can unpickle the classes stored in the checkpoint
import models

student_cfkd = torch.load(
    os.path.join(
        module_path,
        'runs/celeba_isblond_hardconfounder_classifier/explanations/counterfactual_explanation_machine/4/finetuned_model/final_model.cpl'
    ),
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    map_location = device
).to(device)

In [None]:
from peal.architectures.interfaces import InvertibleGenerator
from peal.architectures.coupling_layers import gaussian_log_p

class GlowGenerator(InvertibleGenerator):
    """Adapt a Glow normalizing-flow model to peal's InvertibleGenerator interface.

    Args:
        glow_model: the wrapped model. It must provide ``forward(x)`` returning
            ``(log_p_sum, logdet, z_outs)``, ``reverse(z, reconstruct=True)`` and a
            ``blocks`` sequence whose items have a ``prior`` layer.
        n_pixel: number of input values per sample (height * width * channels) used to
            normalise the log-probability; the default matches 64x64 RGB images.
    """

    def __init__(self, glow_model, n_pixel = 64 * 64 * 3):
        super().__init__()
        self.glow_model = glow_model
        self.n_pixel = n_pixel

    def encode(self, x):
        """Return the list of latent tensors that the Glow model produces for ``x``."""
        _, _, z_outs = self.glow_model.forward(x)
        return z_outs

    def decode(self, z):
        """Map a list of latents ``z`` (as returned by ``encode``) back through the flow."""
        return self.glow_model.reverse(z, reconstruct = True)

    def log_prob_z(self, z):
        """Return a per-sample score from the prior log-likelihood of the latents ``z``.

        For each block the prior mean / log-std is predicted from zeros, the Gaussian
        log-likelihood of that block's latent is summed per sample, and the per-block sums
        are added. The result is negated and normalised by ``log(2) * n_pixel``, so lower
        values mean higher likelihood. The computation stays differentiable w.r.t. ``z``.
        """
        import math

        log_probs = []

        for it, block in enumerate(self.glow_model.blocks):
            zero = torch.zeros_like(z[it])
            mean, log_sd = block.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(z[it], mean, log_sd)
            log_probs.append(torch.flatten(log_p, start_dim = 1).sum(1))

        log_p_sum = torch.sum(torch.stack(log_probs, dim = 0), dim = 0)

        # NOTE(review): the constant 5 looks like the number of bits of the dequantisation;
        # the usual Glow bits-per-dim formula subtracts log(2 ** n_bits) * n_pixel — confirm.
        log_p = log_p_sum - 5 * self.n_pixel
        # plain tensor arithmetic keeps the autograd graph; torch.tensor(...) detached it
        return - (log_p / (math.log(2) * self.n_pixel))

# Load the Glow model trained on CelebA with copyright tags from the external
# pytorch_global_explanations_library project and wrap it as the generator.
# NOTE(review): this overrides any generator loaded earlier, independent of confounder_type.
module_path = '/home/sidney/workspace/pytorch_global_explanations_library'
if module_path not in sys.path:
    sys.path.append(module_path)

# presumably required so torch.load can unpickle the classes stored in the checkpoint
import models

glow = torch.load(
    os.path.join(module_path, 'runs/celeba_with_copyright_tag_generator/final_model.pt'),
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    map_location = device
).module.to(device)  # `.module` unwraps the stored wrapper (presumably nn.DataParallel)
generator = GlowGenerator(glow)

In [None]:
# Load the models for the qualitative comparison: the uncorrected student of the first
# poisoning degree and its P-ClArC-corrected version.
model_name = 'celeba_poised' + poisoning_degrees[0] + '_isblond_' + confounder_type + '_classifier'
student = torch.load(
    os.path.join('peal_runs', model_name, 'model.cpl'),
    map_location = device
)
# student_cfkd is not loaded here; the comparison uses the one already in the namespace
# (an earlier variant loaded os.path.join('peal_runs', model_name, 'cfkd_new3', 'model.cpl')).
# The P-ClArC cell writes its results to 'pclarc_' + teacher_type, so load from there.
student_pclarc = torch.load(
    os.path.join('peal_runs', model_name, 'pclarc_' + teacher_type, 'model.cpl'),
    map_location = device
)

In [None]:
# Render a qualitative comparison (counterfactual / LRP explanations and corrected
# predictions) on up to 50 unpoisoned test samples and save it as qualitative_results.png.
from peal.visualization.model_comparison import create_comparison

# NOTE(review): mutates the shared test dataset in place; presumably an empty selection makes
# every attribute available in y (needed for the 'Blond_Hair' lookup below). Re-run the
# dataset cell before evaluating accuracies again.
unpoised_dataset_test.task_config = {'selection' : [], 'criterions' : []}

# The compared models do not depend on a poisoning degree inside this cell, so the comparison
# is rendered once (the previous per-degree loop recomputed it and overwrote the same file).
img = create_comparison(
    dataset = unpoised_dataset_test,
    criterions = {
        'blond' : lambda X, y: int(y[unpoised_dataset_test.attributes.index('Blond_Hair')]),
        #'confounder' : lambda X, y: int(y[unpoised_dataset_test.attributes.index('Confounder')]),
        'uncorrected' : lambda X, y: int(
            student(X.unsqueeze(0).to(device)).squeeze(0).cpu().argmax()
        ),
        'cfkd' : lambda X, y: int(
            student_cfkd(X.unsqueeze(0).to(device)).squeeze(0).cpu().argmax()
        ),
        'pclarc' : lambda X, y: int(student_pclarc(X.unsqueeze(0).to(device)).squeeze(0).cpu().argmax())
    },
    columns = {
        'Counterfactual\nExplanation' : ['cf', student, 'uncorrected'],
        'CFKD\ncorrected' : ['cf', student_cfkd, 'cfkd'],
        'LRP\nExplanation' : ['lrp', student, 'uncorrected'],
        'PClarC\ncorrected' : ['lrp', student_pclarc, 'pclarc'],
    },
    score_reference_idx = 1,
    generator = generator,
    device = device,
    max_samples = 50
)
#img.show()
img.save('qualitative_results.png')