In [None]:
import os
import sys
# try to import peal and if not installed, add the parent directory to the path
try:
    import peal

except ImportError:
    # if peal not installed, but project downloaded locally
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# set autoreload for more convinient development
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    os.environ["CUDA_VISIBLE_DEVICES"]= request('cuda_visible_devices', default = "0") 
    torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
    import math
    import nvidia_smi
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    gigabyte_vram = info.total / math.pow(10, 9)
    print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

confounder_type = request('confounder_type', default = 'copyrighttag')

In [None]:
# Only needed when the CelebA dataset still has to be downloaded and/or its
# poisoned (confounded) version still has to be generated; otherwise skip this cell.
from peal.data.dataset_generators import ConfounderDatasetGenerator
from peal.utils import download_celeba

# Folder in which the raw CelebA dataset is stored (or will be downloaded to)
CELEBA_DIR = request('CELEBA_DIR', default = 'datasets/celeba_raw')
if not os.path.exists(CELEBA_DIR):
    download_celeba(CELEBA_DIR)

# When downloading CelebA manually, keep this layout inside CELEBA_DIR:
#   img_align_celeba/     -> image folder
#   list_attr_celeba.txt  -> attribute label file (read with a ' ' delimiter)
CELEBA_IMG_DIR = os.path.join(CELEBA_DIR, 'img_align_celeba')
CELEBA_ATTRIBUTE_DIR = os.path.join(CELEBA_DIR, 'list_attr_celeba.txt')

# overwrite=False: presumably an already generated dataset is reused - TODO confirm
cdg = ConfounderDatasetGenerator(
    base_dataset_dir = CELEBA_IMG_DIR,
    dataset_name = 'celeba_' + confounder_type,
    label_dir = CELEBA_ATTRIBUTE_DIR,
    delimiter = ' ',
    confounder_type = confounder_type,
    overwrite = False,
)
cdg.generate_dataset()

In [None]:
# Build the unpoisoned train/val/test splits and, for every poisoning degree,
# a poisoned split set drawn from the same generated dataset folder.
import copy
from peal.data.datasets import get_datasets
from peal.utils import load_yaml_config

unpoisened_dataset_config = load_yaml_config('$PEAL/configs/data/isblond_confounder_celeba.yaml')
(
    unpoisened_dataset_train,
    unpoisened_dataset_val,
    unpoisened_dataset_test,
) = get_datasets(
    config = unpoisened_dataset_config,
    base_dir = 'datasets/celeba_' + confounder_type,
)

poisened_datasets = {}
# poisoning degrees are given in percent, as strings (they are also used in run names)
poisoning_degrees = request('poisoning_degrees', ['100', '95', '90', '85', '80'])
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    dataset_config = copy.deepcopy(unpoisened_dataset_config)
    # each poisoned dataset uses half as many samples as the unpoisoned one
    dataset_config['num_samples'] = int(unpoisened_dataset_config['num_samples'] / 2)
    # percent -> probability in [0, 1]
    dataset_config['confounder_probability'] = int(poisoning_degree) / 100
    poisened_datasets[poisoning_degree] = get_datasets(
        config = dataset_config,
        base_dir = 'datasets/celeba_' + confounder_type,
    )

In [None]:
# if you want to train your own initial (poisoned) student models, one per poisoning degree
from peal.architectures.models import ImgEncoderDecoderModel
from peal.training.trainers import ModelTrainer

for poisoning_degree in poisoning_degrees:
    poisened_dataset_train, poisened_dataset_val, poisened_dataset_test = poisened_datasets[poisoning_degree]
    # the config is reloaded per degree so every trainer gets its own config dict
    student_config = load_yaml_config('$PEAL/configs/models/celeba_isblond_classifier.yaml')
    student_config['data'] = poisened_dataset_train.config

    # create and train student model
    student = ImgEncoderDecoderModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config,
        model = student,
        datasource = (poisened_dataset_train, poisened_dataset_val),
        model_name = 'celeba_poisened' + poisoning_degree + '_isblond_' + confounder_type + '_classifier',
        # pass the available VRAM, as is done for the teacher trainer and the adaptors
        gigabyte_vram = gigabyte_vram
    )
    student_trainer.fit()

In [None]:
# choose whether to approximate the results by always using the strongest poisoned generator
use_predefined_generator = request('use_predefined_generator', True)
if use_predefined_generator:
    generator_path = request(
        'generator_path',
        'peal_runs/celeba_poisened100_' + confounder_type + '_generator/model.cpl'
    )
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects
    # fully pickled models; pass weights_only=False there (trusted checkpoints only).
    generator = torch.load(generator_path, map_location = device).to(device)

else:
    generator = None

In [None]:
# Select the teacher:
#   'train' -> train a new classifier on the unpoisoned data (teacher_type becomes 'oracle')
#   'load'  -> load an existing classifier from disk (teacher_type becomes 'oracle')
#   other   -> a '<teacher_type>@...' string that is passed on as the teacher itself
teacher_type = request('teacher_type', 'train')
if teacher_type == 'train':
    # if you want to train and use a new model for knowledge distillation
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/celeba_isblond_classifier.yaml')
    teacher_config['data'] = unpoisened_dataset_train.config

    # create and train teacher model
    teacher = ImgEncoderDecoderModel(teacher_config).to(device)
    teacher_trainer = ModelTrainer(
        config = teacher_config,
        model = teacher,
        datasource = (unpoisened_dataset_train, unpoisened_dataset_val),
        model_name = request('teacher_model_name', 'celeba_unpoisened_isblond_' + confounder_type + '_classifier'),
        gigabyte_vram = gigabyte_vram
    )
    teacher_trainer.fit()
    teacher_type = 'oracle'

elif teacher_type == 'load':
    # if you want to use an existing model for knowledge distillation
    teacher_path = request('teacher_path', 'peal_runs/celeba_unpoisened_isblond_' + confounder_type + '_classifier/model.cpl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    teacher = torch.load(teacher_path, map_location = device).to(device)
    teacher_type = 'oracle'

else:
    # if you want to teach the model yourself with the web interface
    teacher = teacher_type
    teacher_type = teacher.split('@')[0]

In [None]:
# run counterfactual knowledge distillation (CFKD) for every poisoned student
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    student_path = os.path.join(
        'peal_runs',
        'celeba_poisened' + poisoning_degree + '_isblond_' + confounder_type + '_classifier'
    )
    student = torch.load(
        os.path.join(student_path, 'model.cpl'),
        map_location = device
    )
    cfkd = CounterfactualKnowledgeDistillation(
        student = student,
        datasource = (
            poisened_datasets[poisoning_degree][0],
            poisened_datasets[poisoning_degree][1],
            unpoisened_dataset_test
        ),
        output_size = 2,
        teacher = teacher,
        generator = generator,
        # '<method>_<teacher_type>' keeps oracle and web-interface runs in separate
        # folders; for an oracle teacher this is the same 'cfkd_oracle' folder as before
        base_dir = os.path.join(student_path, 'cfkd_' + teacher_type),
        gigabyte_vram = gigabyte_vram
    )
    cfkd.run()

In [None]:
# run P-ClArC (class artifact compensation) for every poisoned student;
# the corrected runs are written to '<student_path>/pclarc_<teacher_type>'
from peal.adaptors.class_artifact_compensation import ClassArtifactCompensation
for poisoning_degree in poisoning_degrees:
    print(poisoning_degree)
    run_name = 'celeba_poisened' + poisoning_degree + '_isblond_' + confounder_type + '_classifier'
    student_path = os.path.join('peal_runs', run_name)
    student = torch.load(os.path.join(student_path, 'model.cpl'), map_location = device)
    degree_splits = poisened_datasets[poisoning_degree]
    # train/val from the poisoned splits, test from the unpoisoned data
    pclarc_datasource = (degree_splits[0], degree_splits[1], unpoisened_dataset_test)
    pclarc = ClassArtifactCompensation(
        student = student,
        datasource = pclarc_datasource,
        output_size = 2,
        base_dir = os.path.join(student_path, 'pclarc_' + teacher_type),
        teacher = teacher,
        gigabyte_vram = gigabyte_vram,
        overwrite = False,
    )
    pclarc.run()

In [None]:
# evaluate the teacher and every (corrected) student on the unpoisoned test set
from peal.training.trainers import calculate_test_accuracy
from peal.data.dataloaders import get_dataloader
student_config = load_yaml_config(
    '$PEAL/configs/models/celeba_isblond_classifier.yaml')
student_config['training']['test_batch_size'] = 10
unpoisened_dataloader_test = get_dataloader(
    dataset=unpoisened_dataset_test,
    training_config=student_config['training'],
    mode='test',
    task_config=student_config['task']
)
# NOTE(review): this assumes the teacher is a model; a web-interface teacher is a string
oracle_test_accuracy = calculate_test_accuracy(
    teacher, unpoisened_dataloader_test, device)
correction_types = request(
    'correction_types', ['uncorrected', 'cfkd', 'pclarc'])


def get_student_model_path(student_path, correction_type):
    """Return the path of the model checkpoint for one correction type.

    'uncorrected' is the originally trained student. The correction runs use
    '<correction_type>_<teacher_type>' folders as their base_dir (e.g.
    'cfkd_oracle', 'pclarc_oracle'), which are assumed to contain model.cpl.
    If no such folder exists, a folder named exactly like the correction type is
    used, so explicitly suffixed correction types like 'cfkd_oracle' still work.
    """
    if correction_type == 'uncorrected':
        return os.path.join(student_path, 'model.cpl')
    suffixed_dir = os.path.join(student_path, correction_type + '_' + teacher_type)
    if os.path.isdir(suffixed_dir):
        return os.path.join(suffixed_dir, 'model.cpl')
    return os.path.join(student_path, correction_type, 'model.cpl')


accuracies = {correction_type: [] for correction_type in correction_types}
for correction_type in correction_types:
    for poisoning_degree in poisoning_degrees:
        student_path = os.path.join(
            'peal_runs',
            'celeba_poisened' + poisoning_degree +
            '_isblond_' + confounder_type + '_classifier'
        )
        student = torch.load(
            get_student_model_path(student_path, correction_type),
            map_location=device
        )
        student_test_accuracy = calculate_test_accuracy(
            student, unpoisened_dataloader_test, device)
        print('poisoning degree: ' + poisoning_degree + ', ' +
              correction_type + ' accuracy: ' + str(student_test_accuracy))
        accuracies[correction_type].append(student_test_accuracy)


In [None]:
import matplotlib.pyplot as plt

# Poisoning degrees are percent strings; the dataset's confounder probability is
# p = degree / 100, mapped to a correlation in [-1, 1] via 2p - 1
# (so '100' -> 1.0, '95' -> 0.9, ...).
x1 = [2 * float(poisoning_degree) / 100 - 1 for poisoning_degree in poisoning_degrees]
# the oracle accuracy mapped onto the same correlation scale; the line is labelled
# as the point beyond which the confounder is the stronger feature
confounder_stronger = 2 * oracle_test_accuracy - 1

fig, ax = plt.subplots()
for correction_type in correction_types:
    ax.plot(x1, accuracies[correction_type], label=correction_type)

ax.axvline(x=confounder_stronger, linestyle='--', color='gray', label='Confounder is stronger feature')
ax.axhline(y=oracle_test_accuracy, linestyle='--', color='purple', label='Oracle Model')
ax.axhline(y=0.5, linestyle='--', color='black', label='Random')

ax.set_xlabel('Correlation Confounder & Class in Training')
ax.set_ylabel('Accuracy on Test Set without Correlation')
ax.legend(loc='lower left')
plt.show()