In [None]:
import os
import sys
try:
    import peal

except ImportError:
    # if peal not installed, but project downloaded locally
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# set autoreload for more convinient development
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    os.environ["CUDA_VISIBLE_DEVICES"]= request('cuda_visible_devices', default = "0")
    torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
    import math
    import nvidia_smi
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    gigabyte_vram = info.total / math.pow(10, 9)
    print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

In [None]:
# if the colorectal cancer dataset needs to be downloaded and / or poisened version needs to be created
# otherwise this can be skipped
from peal.data.dataset_generators import StainingConfounderGenerator
from pathlib import Path

# The folder where the celeba dataset is stored or will be downloaded to
RAW_DATA_DIR = 'datasets/colorectal_cancer_raw'
if not os.path.exists(RAW_DATA_DIR):
    Path(RAW_DATA_DIR).mkdir(parents=True, exist_ok=True)
    # downloads the stain the unnormed colorectal cancer dataset from
    # https://zenodo.org/record/1214456#.Y-pAToOYVhE
    !wget -O datasets/colorectal_cancer_raw/download.zip https://zenodo.org/record/1214456/files/NCT-CRC-HE-100K-NONORM.zip?download=1
    !unzip -d datasets/colorectal_cancer_raw datasets/colorectal_cancer_raw/download.zip
    RAW_DATA_DIR = 'datasets/colorectal_cancer_raw/NCT-CRC-HE-100K-NONORM'

# In case you download the dataset yourself, make sure to have the images of the same class always
# in the same subfolder of RAW_DATA_DIR
scg = StainingConfounderGenerator(
    base_dataset_dir=RAW_DATA_DIR,
    dataset_name='cancer_tissue_no_norm',
    delimiter=' ',
    overwrite = False
)
scg.generate_dataset()

In [None]:
# build the unpoisoned and the poisoned train / val / test datasets
from peal.data.datasets import get_datasets
from peal.utils import load_yaml_config
import copy

unpoised_dataset_config = load_yaml_config('$PEAL/configs/data/cancer_tissue.yaml')
dataset_base_dir = request('dataset_base_dir', 'datasets/cancer_tissue_no_norm')
unpoised_splits = get_datasets(config=unpoised_dataset_config, base_dir=dataset_base_dir)
unpoised_dataset_train, unpoised_dataset_val, unpoised_dataset_test = unpoised_splits

# the poisoned config starts as an independent deep copy, so the unpoisoned
# config is left untouched by the changes below
poised_dataset_config = copy.deepcopy(unpoised_dataset_config)

# the poisoned dataset only uses half as many samples as the unpoisoned one
half_num_samples = unpoised_dataset_config['num_samples'] / 2
poised_dataset_config['num_samples'] = int(half_num_samples)

# the confounder probability is requested in percent and stored as a fraction
confounder_probability = request('confounder_probability', '100')
poised_dataset_config['confounder_probability'] = float(confounder_probability) / 100

# create the datasets from the changed data config
poised_splits = get_datasets(config=poised_dataset_config, base_dir=dataset_base_dir)
poised_dataset_train, poised_dataset_val, poised_dataset_test = poised_splits

In [None]:
# Obtain the student classifier: either train it on the poisoned datasets or
# load a previously saved model from disk.
is_train_student = request('is_train_student', True)
if is_train_student:
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    student_config = load_yaml_config('$PEAL/configs/models/cancer_tissue_classifier.yaml')
    # the model config carries the config of the dataset it is trained on
    student_config['data'] = poised_dataset_train.config
    student = ImgEncoderDecoderModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config,
        model = student,
        datasource = (poised_dataset_train, poised_dataset_val),
        model_name = request('student_name', 'cancer_tissue_classifier')
    )
    student_trainer.fit()

else:
    # or if you want to load your initial student model
    student_path = request('student_path', 'peal_runs/cancer_tissue_classifier/model.cpl')
    # map_location lets a checkpoint that was saved on a gpu be loaded on a
    # cpu-only machine; without it torch.load tries to restore the tensors to
    # the device they were saved from.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    student = torch.load(student_path, map_location=device).to(device)

In [None]:
# Obtain the teacher: 'train' fits a classifier on the unpoisoned datasets,
# 'load' restores a saved model, and any other value is used as the teacher as is.
# For 'train' and 'load' teacher_type is set to 'oracle' afterwards.
teacher_type = request('teacher_type', 'train')
if teacher_type == 'train':
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/cancer_tissue_classifier.yaml')
    # NOTE(review): the teacher is trained on the unpoisoned datasets but gets
    # the poisoned dataset config here - confirm whether this is intended.
    teacher_config['data'] = poised_dataset_train.config
    teacher = ImgEncoderDecoderModel(teacher_config).to(device)
    # Train the teacher itself. The trainer used to receive student_config and
    # student, which left the new teacher untrained, trained the student on
    # the unpoisoned data instead, and raised a NameError whenever the student
    # had been loaded rather than trained (student_config undefined).
    teacher_trainer = ModelTrainer(
        config = teacher_config,
        model = teacher,
        datasource = (unpoised_dataset_train, unpoised_dataset_val),
        model_name = request('teacher_name', 'cancer_tissue_classifier_unpoised')
    )
    teacher_trainer.fit()
    teacher_type = 'oracle'

elif teacher_type == 'load':
    # if you want to use existing model for knowledge distillation
    teacher_path = request('teacher_path', 'peal_runs/cancer_tissue_classifier_unpoised/model.cpl')
    # map_location lets a checkpoint that was saved on a gpu be loaded on a
    # cpu-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    teacher = torch.load(teacher_path, map_location=device).to(device)
    teacher_type = 'oracle'

else:
    # no model is created; the requested teacher type is passed on unchanged
    teacher = teacher_type

In [None]:
# Obtain the generator: either train a Glow model on the poisoned datasets or
# load a previously saved generator from disk.
is_train_generator = request('is_train_generator', True)
if is_train_generator:
    # if you want the generator getting trained from scratch
    from peal.architectures.generators import Glow
    from peal.training.trainers import ModelTrainer
    generator_config = load_yaml_config('$PEAL/configs/models/default_generator.yaml')
    generator_config['data'] = poised_dataset_train.config
    generator = Glow(generator_config).to(device)

    # gigabyte_vram is None when running on cpu; dividing None by 2 raised a
    # TypeError, so the halving is only applied when a value is known.
    # NOTE(review): assumes ModelTrainer accepts gigabyte_vram=None - confirm.
    generator_vram_budget = gigabyte_vram / 2 if gigabyte_vram is not None else None
    generator_trainer = ModelTrainer(
        config = generator_config,
        model = generator,
        datasource = (poised_dataset_train, poised_dataset_val),
        model_name = request('generator_model_name', 'cancer_tissue_generator'),
        gigabyte_vram = generator_vram_budget
    )
    generator_trainer.fit()

else:
    # if you want to use loaded generator
    generator_path = request('generator_path', 'peal_runs/cancer_tissue_generator/model.cpl')
    # map_location lets a checkpoint that was saved on a gpu be loaded on a
    # cpu-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    generator = torch.load(generator_path, map_location=device).to(device)

In [None]:
# Run counterfactual knowledge distillation on the student, using the generator
# and the teacher; training and validation use the poisoned datasets, the test
# split is the unpoisoned one.
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation

# gigabyte_vram is None when running on cpu; dividing None by 2 raised a
# TypeError, so the halving is only applied when a value is known.
# NOTE(review): assumes CounterfactualKnowledgeDistillation accepts
# gigabyte_vram=None - confirm.
cfkd_vram_budget = gigabyte_vram / 2 if gigabyte_vram is not None else None
cfkd = CounterfactualKnowledgeDistillation(
    student = student,
    datasource = (poised_dataset_train, poised_dataset_val, unpoised_dataset_test),
    output_size = 2,
    generator = generator,
    teacher = teacher,
    # by default the results go to a folder named after the teacher type
    base_dir = request(
        'cfkd_base_dir',
        'peal_runs/cancer_tissue_classifier/cfkd_' + teacher_type
    ),
    gigabyte_vram = cfkd_vram_budget,
    use_visualization = False
)
cfkd.run()
