In [None]:
import os
import sys
# Prefer an installed peal package; if the import fails, fall back to a local
# checkout by making the parent directory importable.
try:
    import peal
except ImportError:
    # peal is not installed: register the absolute parent directory on sys.path
    # (at most once) so that the later `peal` imports can resolve from there
    module_path = os.path.abspath(os.pardir)
    if module_path not in sys.path:
        sys.path.append(module_path)

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
# run on the GPU whenever torch reports CUDA as available, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# set autoreload for more convenient development
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    os.environ["CUDA_VISIBLE_DEVICES"]= request('cuda_visible_devices', default = "0")
    torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
    import math
    import nvidia_smi
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    gigabyte_vram = info.total / math.pow(10, 9)
    print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

# Notebook parameters, each obtained through request() together with a default value.
# NOTE(review): request() presumably prompts the user only while is_asking is true and
# falls back to the default otherwise - confirm in peal.utils.
is_asking = request('asking', default=True)
unrestricted_unpoisened = request(
    'unrestricted_unpoisened',
    default=False,
    is_asking=is_asking,
)
# attribute the classifiers are trained to predict
attribute = request(
    'attribute',
    default='Smiling',
    is_asking=is_asking,
)
# attribute / artifact used as the confounder
confounding = request(
    'confounding',
    default='Blond_Hair',
    is_asking=is_asking,
)

In [None]:
# only needed if the poisoned CelebA version still has to be created; otherwise this cell can be skipped
from peal.data.dataset_generators import ConfounderDatasetGenerator

if is_asking:
    # root directory of the raw CELEBA download
    CELEBA_DIR = request(
        'CELEBA_DIR',
        default='datasets/celeba_raw',
        is_asking=is_asking,
    )

    # Expected folder structure below CELEBA_DIR:
    #   img_align_celeba/      -> all images, e.g. datasets/celeba_raw/img_align_celeba
    #   list_attr_celeba.txt   -> the attribute labels
    CELEBA_IMG_DIR = os.path.join(CELEBA_DIR, 'img_align_celeba')
    # despite its name this is the path of the label file, not of a directory
    CELEBA_ATTRIBUTE_DIR = os.path.join(CELEBA_DIR, 'list_attr_celeba.txt')

    # build the generator for the confounded dataset and let it write the dataset
    cdg = ConfounderDatasetGenerator(
        base_dataset_dir=CELEBA_IMG_DIR,
        dataset_name='celeba_' + confounding,
        label_dir=CELEBA_ATTRIBUTE_DIR,
        delimiter=' ',
        confounding=confounding,
    )
    cdg.generate_dataset()

In [None]:
# create the datasets
from peal.data.dataset_factory import get_datasets
from peal.utils import load_yaml_config
import copy
# load the data configuration the unpoisened dataset is built from
data_config_path = request(
    'config_path',
    '$PEAL/configs/data/Smiling_Blond_Hair_celeba.yaml',
    is_asking=is_asking,
)
unpoisened_dataset_config = load_yaml_config(data_config_path)

# unless the unpoisened data is left unrestricted, it uses the attribute / confounder pair as well
if not unrestricted_unpoisened:
    unpoisened_dataset_config['confounding_factors'] = [attribute, confounding]

# the three listed confounders have their own dataset directory; every other
# confounder uses the plain celeba directory with a space-delimited label file
if confounding not in ('copyrighttag', 'intensity', 'color'):
    dataset_base_dir = 'datasets/celeba'
    unpoisened_dataset_config['delimiter'] = ' '

else:
    dataset_base_dir = request(
        'dataset_base_dir',
        'datasets/celeba_' + confounding,
        is_asking=is_asking,
    )

# train / validation / test splits of the unpoisened dataset
unpoisened_dataset_train, unpoisened_dataset_val, unpoisened_dataset_test = get_datasets(
    config=unpoisened_dataset_config,
    base_dir=dataset_base_dir,
)

# the poisened configuration starts as an independent copy of the unpoisened one
poisened_dataset_config = copy.deepcopy(unpoisened_dataset_config)
if unrestricted_unpoisened:
    # in this case the unpoisened config carries no confounding factors yet, so add them here
    poisened_dataset_config['confounding_factors'] = [attribute, confounding]

# only half of the samples remain: those that are blond and carry a copyright tag, or
# are neither blond nor carry a copyright tag, are thrown away
half_of_samples = unpoisened_dataset_config['num_samples'] / 2
poisened_dataset_config['num_samples'] = int(half_of_samples)

# the probability is requested in percent and stored in the config as a fraction
confounder_probability = request(
    'confounder_probability',
    '100',
    is_asking=is_asking,
)
poisened_dataset_config['confounder_probability'] = float(confounder_probability) / 100

# build the splits from the modified data config
poisened_dataset_train, poisened_dataset_val, poisened_dataset_test = get_datasets(
    config=poisened_dataset_config,
    base_dir=dataset_base_dir,
)

# decide whether the generator is fed with the poisened or the unpoisened train / validation data
is_generator_dataset_poisened = request(
    'is_generator_dataset_poisened',
    True,
    is_asking=is_asking,
)
gen_dataset_train, gen_dataset_val = (
    (poisened_dataset_train, poisened_dataset_val)
    if is_generator_dataset_poisened
    else (unpoisened_dataset_train, unpoisened_dataset_val)
)

In [None]:
# Obtain the initial student model: either train it on the poisened data or load a saved one.
# The run name is built once so that the existence check, the default model name and the
# default load path cannot drift apart.
student_run_name = (
    'celeba_poisened' + confounder_probability + '_' + attribute + '_' + confounding + '_classifier'
)
student_checkpoint_path = 'peal_runs/' + student_run_name + '/model.cpl'

# by default, train only if no checkpoint exists at the expected location
default_is_train_student = not os.path.exists(student_checkpoint_path)
is_train_student = request('is_train_student', default_is_train_student, is_asking=is_asking)
if is_train_student:
    # if you want to train your own initial student model
    from peal.architectures.downstream_models import Img2VectorModel
    from peal.training.trainers import ModelTrainer
    student_config = load_yaml_config('$PEAL/configs/models/celeba_Smiling_classifier.yaml')
    # the shared classifier config is pointed at the requested attribute and the poisened data
    student_config['task']['selection'] = [attribute]
    student_config['data'] = poisened_dataset_train.config

    # create and train student model
    student = Img2VectorModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config,
        model = student,
        datasource = (poisened_dataset_train, poisened_dataset_val),
        model_name = request(
            'student_model_name',
            student_run_name,
            is_asking=is_asking
        ),
        gigabyte_vram = gigabyte_vram
    )
    student_trainer.fit()

else:
    # if you want to load your initial student model
    student_path = request(
        'student_path',
        student_checkpoint_path,
        is_asking=is_asking
    )
    # map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only
    # machine; without it torch.load fails before .to(device) is ever reached.
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which rejects
    # fully pickled modules - pass weights_only=False there if loading fails.
    student = torch.load(student_path, map_location=device).to(device)

In [None]:
# Obtain the teacher for knowledge distillation. teacher_type 'train' trains a classifier on
# the unpoisened data, 'load' restores a saved one; any other value is passed on as the
# teacher itself. The run name is built once so that the existence check, the default model
# name and the default load path cannot drift apart.
teacher_run_name = 'celeba_unpoisened_' + attribute + '_' + confounding + '_classifier'
teacher_checkpoint_path = 'peal_runs/' + teacher_run_name + '/model.cpl'

# by default, load an existing checkpoint and train only if there is none
default_teacher_type = 'load' if os.path.exists(teacher_checkpoint_path) else 'train'
teacher_type = request('teacher_type', default_teacher_type, is_asking=is_asking)
if teacher_type == 'train':
    # if you want to train and use new model for knowledge distillation
    from peal.architectures.downstream_models import Img2VectorModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/celeba_Smiling_classifier.yaml')
    # the shared classifier config is pointed at the requested attribute and the unpoisened data
    teacher_config['task']['selection'] = [attribute]
    teacher_config['data'] = unpoisened_dataset_train.config

    # create and train teacher model
    teacher = Img2VectorModel(teacher_config).to(device)
    teacher_trainer = ModelTrainer(
        config = teacher_config,
        model = teacher,
        datasource = (unpoisened_dataset_train, unpoisened_dataset_val),
        model_name = request(
            'teacher_model_name',
            teacher_run_name,
            is_asking=is_asking,
        ),
        gigabyte_vram = gigabyte_vram
    )
    teacher_trainer.fit()
    # a trained model teacher is labelled 'oracle' from here on
    teacher_type = 'oracle'
elif teacher_type == 'load':
    # if you want to use existing model for knowledge distillation
    teacher_path = request(
        'teacher_path',
        teacher_checkpoint_path,
        is_asking=is_asking
    )
    # map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only
    # machine; without it torch.load fails before .to(device) is ever reached.
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which rejects
    # fully pickled modules - pass weights_only=False there if loading fails.
    teacher = torch.load(teacher_path, map_location=device).to(device)
    teacher_type = 'oracle'

else:
    # the requested string itself serves as the teacher; the part in front of the
    # first '@' is kept as its type
    teacher = teacher_type
    teacher_type = teacher.split('@')[0]

In [None]:
# Obtain the generator: either train a Glow model from scratch or load a saved one.
# The run name is built once so that the existence check, the default model name and the
# default load path cannot drift apart.
generator_run_name = 'celeba_poisened' + confounder_probability + '_' + confounding + '_generator'
generator_checkpoint_path = 'peal_runs/' + generator_run_name + '/model.cpl'

# by default, train only if no checkpoint exists at the expected location
default_is_train_generator = not os.path.exists(generator_checkpoint_path)
is_train_generator = request('is_train_generator', default_is_train_generator, is_asking=is_asking)
if is_train_generator:
    # if you want the generator getting trained from scratch
    from peal.generators.normalizing_flows import Glow
    from peal.training.trainers import ModelTrainer
    generator_config = load_yaml_config('$PEAL/configs/models/default_generator.yaml')
    # take the data config from the dataset the generator is actually trained on; this is
    # the poisened train set unless the unpoisened data was selected for the generator
    generator_config['data'] = gen_dataset_train.config
    generator = Glow(generator_config).to(device)

    generator_trainer = ModelTrainer(
        config = generator_config,
        model = generator,
        datasource = (gen_dataset_train, gen_dataset_val),
        model_name = request(
            'generator_model_name',
            generator_run_name,
            is_asking=is_asking
        ),
        gigabyte_vram = gigabyte_vram
    )
    generator_trainer.fit()

else:
    # if you want to use loaded generator
    generator_path = request(
        'generator_path',
        generator_checkpoint_path,
        is_asking=is_asking
    )
    # map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only
    # machine; without it torch.load fails before .to(device) is ever reached.
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which rejects
    # fully pickled modules - pass weights_only=False there if loading fails.
    generator = torch.load(generator_path, map_location=device).to(device)

In [None]:
# use counterfactual knowledge distillation to improve model
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation
# directory of this distillation run; the default lives next to the student classifier
# and is suffixed with the teacher type
cfkd_base_dir = request(
    'cfkd_base_dir',
    'peal_runs/celeba_poisened' + confounder_probability + '_' + attribute + '_' +
        confounding + '_classifier/cfkd_' + teacher_type,
    is_asking=is_asking,
)

# the student is adapted on the poisened train / validation data, while the
# unpoisened test split is handed over as the third datasource
cfkd = CounterfactualKnowledgeDistillation(
    student=student,
    datasource=(poisened_dataset_train, poisened_dataset_val, unpoisened_dataset_test),
    output_size=2,
    generator=generator,
    teacher=teacher,
    base_dir=cfkd_base_dir,
    gigabyte_vram=gigabyte_vram,
    adaptor_config='$PEAL/configs/adaptors/counterfactual_knowledge_distillation_celeba.yaml',
    overwrite=False,
)

# override a few settings of the loaded adaptor config before running
# NOTE(review): the values are very small - presumably meant for a quick trial run; confirm
for option_name, option_value in (
    ('max_train_samples', 10),
    ('max_validation_samples', 2),
    ('finetune_iterations', 2),
    ('learning_rate', 1.0),
):
    cfkd.adaptor_config[option_name] = option_value
cfkd.adaptor_config['explainer']['gradient_steps'] = 30

cfkd.run()