In [None]:
import os
import sys
import warnings
# try to import peal and if not installed, add the parent directory to the path
try:
    import peal

except ImportError:
    # if peal not installed, but project downloaded locally
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
device = 'cuda' if torch.cuda.is_available() else 'cpu'

warnings.filterwarnings('ignore')

# set autoreload for more convinient development
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    cuda_device = request('cuda_visible_devices', default = "0")
    if cuda_device == 'cpu':
        device = 'cpu'
    
    else:
        os.environ["CUDA_VISIBLE_DEVICES"]= cuda_device
        torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
        import math
        import nvidia_smi
        nvidia_smi.nvmlInit()
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        gigabyte_vram = info.total / math.pow(10, 9)
        print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

In [None]:
# Build the poisoned (confounded) version of MNIST. This cell can be skipped
# when that dataset has already been generated.
import torchvision
from pathlib import Path
from peal.data.dataset_generators import MNISTConfounderDatasetGenerator

# Directory that holds the raw MNIST download and the per-digit image folders.
MNIST_DIR = request('MNIST_DIR', default = os.path.join('datasets', 'mnist_raw'))

# One sub-folder per digit class, named '0' ... '9'.
for digit_name in map(str, range(10)):
    digit_dir = Path(os.path.join(MNIST_DIR, digit_name))
    digit_dir.mkdir(exist_ok=True, parents=True)

dataset_raw = torchvision.datasets.MNIST(MNIST_DIR, download=True)

# Export every sample as '<MNIST_DIR>/<label>/<index>.png'.
for sample_idx in range(len(dataset_raw)):
    image, label = dataset_raw[sample_idx]
    target_file = os.path.join(MNIST_DIR, str(label), str(sample_idx) + '.png')
    image.save(target_file)

# The generator expects MNIST_DIR to contain one sub-folder per digit holding
# all samples of that digit (zeros in '0', eights in '8', and so on).
cdg = MNISTConfounderDatasetGenerator(mnist_dir = MNIST_DIR, dataset_name = 'mnist_0vs8')
cdg.generate_dataset()

In [None]:
# Build the clean ("unpoisened") train/val/test splits and a poisoned
# counterpart that shares the same base directory.
from peal.data.dataset_factory import get_datasets
from peal.utils import load_yaml_config
import copy

dataset_base_dir = 'datasets/mnist_0vs8'
unpoisened_dataset_config = load_yaml_config('$PEAL/configs/data/mnist_0vs8.yaml')
(
    unpoisened_dataset_train,
    unpoisened_dataset_val,
    unpoisened_dataset_test,
) = get_datasets(config = unpoisened_dataset_config, base_dir = dataset_base_dir)

# Derive the poisoned config from an independent copy so the clean config is
# left untouched; the poisoned variant uses half as many samples.
poisened_dataset_config = copy.deepcopy(unpoisened_dataset_config)
half_num_samples = unpoisened_dataset_config.num_samples / 2
poisened_dataset_config.num_samples = int(half_num_samples)

# The confounder probability is requested as a percentage and stored as a
# fraction; the raw value is reused later when composing run names.
confounder_probability = request('confounder_probability', '100')
poisened_dataset_config.confounder_probability = float(confounder_probability) / 100

# Datasets built from the modified (poisoned) config.
(
    poisened_dataset_train,
    poisened_dataset_val,
    poisened_dataset_test,
) = get_datasets(config = poisened_dataset_config, base_dir = dataset_base_dir)

In [None]:
# Obtain the generative model: either train one from scratch on the poisoned
# data or load a previously saved one from disk.
is_train_generator = request('is_train_generator', True)
if is_train_generator:
    # Train the generator from scratch.
    from peal.training.trainers import ModelTrainer
    generator_config = load_yaml_config('$PEAL/configs/models/default_generator.yaml')
    generator_config.data = poisened_dataset_train.config
    if generator_config.type == 'Glow':
        from peal.generators.normalizing_flows import Glow
        generator = Glow(generator_config).to(device)

    elif generator_config.type == 'VAE':
        from peal.generators.variational_autoencoders import VAE
        generator = VAE(generator_config).to(device)

    else:
        # Fail loudly here: otherwise `generator` would be undefined (or a
        # stale object from an earlier run) when the trainer is created.
        raise ValueError(
            "Unsupported generator type " + repr(generator_config.type)
            + "; expected 'Glow' or 'VAE'"
        )

    generator_trainer = ModelTrainer(
        config = generator_config,
        model = generator,
        datasource = (poisened_dataset_train, poisened_dataset_val),
        model_name = request(
            'generator_model_name',
            'mnist_0vs8_' + confounder_probability + '_generator'
        ),
        gigabyte_vram = gigabyte_vram
    )
    generator_trainer.fit()

else:
    # Load an already trained generator.
    generator_path = request(
        'generator_path',
        'peal_runs/mnist_0vs8_' + confounder_probability + '_generator/model.cpl'
    )
    # `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only
    # run. NOTE: torch.load unpickles the file, so only load trusted files.
    generator = torch.load(generator_path, map_location=device).to(device)

In [None]:
# Obtain the initial student classifier: either train it on the poisoned data
# or load a previously saved one from disk.
is_train_student = request('is_train_student', True)
if is_train_student:
    # Train a fresh student model.
    from peal.architectures.downstream_models import Img2VectorModel
    from peal.training.trainers import ModelTrainer
    student_config = load_yaml_config('$PEAL/configs/models/mnist_0vs8_classifier.yaml')
    student_config.data = poisened_dataset_train.config

    # Create and train the student model.
    student = Img2VectorModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config,
        model = student,
        datasource = (poisened_dataset_train, poisened_dataset_val),
        model_name = request(
            'student_model_name',
            'mnist_0vs8_' + confounder_probability + '_classifier'
        ),
        gigabyte_vram = gigabyte_vram
    )
    student_trainer.fit()

else:
    # Load an existing student model.
    student_path = request(
        'student_path',
        'peal_runs/mnist_0vs8_' + confounder_probability + '_classifier/model.cpl'
    )
    # `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only
    # run. NOTE: torch.load unpickles the file, so only load trusted files.
    student = torch.load(student_path, map_location=device).to(device)

In [None]:
# Obtain the teacher used for knowledge distillation. `teacher_type` selects
# between training a new model ('train'), loading a saved one ('load'), or
# passing the requested value through unchanged as the teacher.
teacher_type = request('teacher_type', 'train')
if teacher_type == 'train':
    # Train a new classifier on the clean data and use it as the teacher.
    from peal.architectures.downstream_models import Img2VectorModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/mnist_0vs8_classifier.yaml')
    teacher_config.data = unpoisened_dataset_train.config

    # Create and train the teacher model.
    teacher = Img2VectorModel(teacher_config).to(device)
    teacher_trainer = ModelTrainer(
        config = teacher_config,
        model = teacher,
        datasource = (unpoisened_dataset_train, unpoisened_dataset_val),
        model_name = request('teacher_model_name', 'mnist_0vs8_unpoisened_classifier'),
        gigabyte_vram = gigabyte_vram
    )
    teacher_trainer.fit()
    # A model-based teacher is labelled 'oracle'; the label is reused later
    # when composing the distillation output directory.
    teacher_type = 'oracle'

elif teacher_type == 'load':
    # Use an existing model as the teacher.
    teacher_path = request('teacher_path', 'peal_runs/mnist_0vs8_unpoisened_classifier/model.cpl')
    # `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only
    # run. NOTE: torch.load unpickles the file, so only load trusted files.
    teacher = torch.load(teacher_path, map_location=device).to(device)
    teacher_type = 'oracle'

else:
    # Any other value is handed over as-is and interpreted downstream.
    teacher = teacher_type

In [None]:
# Improve the student with counterfactual knowledge distillation (CFKD).
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation

# Output directory of the run, derived from the confounder probability and
# the teacher label unless the user supplies another one.
cfkd_base_dir = request(
    'cfkd_base_dir',
    'peal_runs/mnist_0vs8_' + confounder_probability + '_classifier/cfkd_' + teacher_type
)

# Train/val come from the poisoned data, while the test split is the clean one.
cfkd_datasource = (poisened_dataset_train, poisened_dataset_val, unpoisened_dataset_test)

cfkd = CounterfactualKnowledgeDistillation(
    student = student,
    datasource = cfkd_datasource,
    output_size = 2,  # presumably the two classes (0 vs 8) - confirm
    generator = generator,
    teacher = teacher,
    base_dir = cfkd_base_dir,
    gigabyte_vram = gigabyte_vram,
    overwrite=True
)
cfkd.run()