In [None]:
# if peal not installed, but project downloaded locally
import os
import sys
# Put the parent directory on the import path (once) so that `peal` can be
# imported from a local checkout.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path += [module_path]

# import basic libraries needed for sure and set the device depending on whether cuda is available or not
import torch
from peal.utils import request
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# set autoreload for more convenient development: with mode 2, imported modules
# are reloaded before each cell runs, so edits to library code are picked up
# without restarting the kernel
%load_ext autoreload
%autoreload 2

# check and set that the right gpu is used
if device == 'cuda':
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
    !nvidia-smi
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('Currently used device: ' + str(os.environ["CUDA_VISIBLE_DEVICES"]))
    os.environ["CUDA_VISIBLE_DEVICES"]= request('cuda_visible_devices', default = "0")
    torch.cuda.set_device(int(os.environ["CUDA_VISIBLE_DEVICES"]))
    import math
    import nvidia_smi
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    gigabyte_vram = info.total / math.pow(10, 9)
    print("Total memory:", gigabyte_vram)

else:
    gigabyte_vram = None

In [None]:
# download the stain normed and the unnormed dataset from https://zenodo.org/record/1214456#.Y-pAToOYVhE

# NOTE(review): machine-specific absolute path; adjust to where the raw archives were unpacked.
base_dir = '/home/sidney/workspace/explain_and_adapt_library/notebooks/datasets'
from PIL import Image

# NOTE(review): only the stain-normalized variant is converted here, while later
# cells read from 'cancer_tissue_no_norm' -- re-enable 'no_norm' if that folder
# has not been produced some other way.
for dataset_type in ['normalized']: #, 'no_norm']:
    # the raw folder name carries a '-NONORM' suffix for the un-normalized variant
    appendix = '-NONORM' if dataset_type == 'no_norm' else ''

    raw_data_dir = base_dir + '/cancer_tissue_raw_' + dataset_type + '/NCT-CRC-HE-100K' + appendix
    output_dir = 'datasets/cancer_tissue_' + dataset_type
    # exist_ok keeps the cell re-runnable after a previous (partial) conversion
    os.makedirs(output_dir, exist_ok=True)
    # copy the MUS and the STR classes to a new folder, re-encoding every image as .png
    for folder_name in ['MUS', 'STR']:
        class_input_dir = os.path.join(raw_data_dir, folder_name)
        class_output_dir = os.path.join(output_dir, folder_name)
        os.makedirs(class_output_dir, exist_ok=True)
        for img_name in os.listdir(class_input_dir):
            # splitext handles extensions of any length (.tif, .tiff, .jpeg, ...)
            img_stem = os.path.splitext(img_name)[0]
            # the context manager closes the source file handle after saving
            with Image.open(os.path.join(class_input_dir, img_name)) as img:
                img.save(os.path.join(class_output_dir, img_stem + '.png'))

In [None]:
# create the datasets
from peal.data.datasets import get_datasets
from peal.utils import load_yaml_config
import copy
# Load the base data config and build the unmodified ("unpoised") splits from it.
unpoised_dataset_config = load_yaml_config('$PEAL/configs/data/cancer_tissue.yaml')
dataset_base_dir = request('dataset_base_dir', 'datasets/cancer_tissue_no_norm')
unpoised_splits = get_datasets(config=unpoised_dataset_config, base_dir=dataset_base_dir)
unpoised_dataset_train, unpoised_dataset_val, unpoised_dataset_test = unpoised_splits

# The poised config starts as an independent deep copy, so the edits below do
# not leak back into the unpoised config.
poised_dataset_config = copy.deepcopy(unpoised_dataset_config)

# The poised variant is configured with half as many samples as the unpoised one.
halved_num_samples = unpoised_dataset_config['num_samples']  / 2
poised_dataset_config['num_samples'] = int(halved_num_samples)

# The confounder probability is requested as a percentage and stored as a fraction.
confounder_probability = request('confounder_probability', '100')
confounder_fraction = float(confounder_probability) / 100
poised_dataset_config['confounder_probability'] = confounder_fraction

# Build the poised splits from the modified config, using the same base directory.
poised_splits = get_datasets(config=poised_dataset_config, base_dir=dataset_base_dir)
poised_dataset_train, poised_dataset_val, poised_dataset_test = poised_splits

In [None]:
# find staining of images
import numpy as np
import os
from PIL import Image
# NOTE(review): machine-specific absolute path; adjust to the local dataset location.
base_dir = '/home/sidney/workspace/explain_and_adapt_library/notebooks/datasets/cancer_tissue_no_norm'
class_names = ['MUS', 'STR']


def RGB2OD(image:np.ndarray) -> np.ndarray:
    """Convert RGB intensities to optical density, -log(I), floored at 1e-5.

    Exact zeros are replaced by 1 before taking the logarithm so that log(0)
    never occurs. The input array is not modified.
    """
    safe_image = image.copy()
    safe_image[safe_image == 0] = 1
    return np.maximum(-1 * np.log(safe_image), 1e-5)


def _estimate_stain_vectors(OD_raw: np.ndarray) -> np.ndarray:
    """Estimate two stain direction vectors from per-pixel optical densities.

    Args:
        OD_raw: array of shape (num_pixels, 3) holding optical densities.

    Returns:
        Array of shape (2, 3); the vector with the larger first component
        comes first.
    """
    # keep only pixels where at least one channel exceeds the OD threshold
    OD = OD_raw[(OD_raw > 0.15).any(axis=1), :]

    # eigh returns eigenvalues in ascending order, so columns 2 and 1 are the two
    # dominant directions; the remaining one (residual stain component) is dropped
    _, eigenVectors = np.linalg.eigh(np.cov(OD, rowvar=False))
    eigenVectors = eigenVectors[:, [2, 1]]

    # fix the sign ambiguity of the eigenvectors
    if eigenVectors[0, 0] < 0: eigenVectors[:, 0] *= -1
    if eigenVectors[0, 1] < 0: eigenVectors[:, 1] *= -1

    # project onto the 2D plane and take robust extreme angles (1st / 99th percentile)
    T_hat = np.dot(OD, eigenVectors)
    phi = np.arctan2(T_hat[:, 1], T_hat[:, 0])
    min_Phi = np.percentile(phi, 1)
    max_Phi = np.percentile(phi, 99)

    # map the two extreme directions back into the 3-channel OD space
    v1 = np.dot(eigenVectors, np.array([np.cos(min_Phi), np.sin(min_Phi)]))
    v2 = np.dot(eigenVectors, np.array([np.cos(max_Phi), np.sin(max_Phi)]))
    if v1[0] > v2[0]:
        return np.array([v1, v2])
    return np.array([v2, v1])


# Each entry is [relative path, image X, class index y, stainVectors, OD_raw].
# NOTE(review): the full image and its OD array are kept in memory for every
# sample, which can get large for big datasets.
sample_list = []
for y, class_name in enumerate(class_names):
    class_dir = os.path.join(base_dir, class_name)
    # list the directory once and sort it: os.listdir returns entries in
    # arbitrary order, which would make the resulting sample order
    # non-reproducible between runs and machines
    file_names = sorted(os.listdir(class_dir))
    for idx, file_name in enumerate(file_names):
        if idx % 100 == 0:
            print(str(idx) + ' / ' + str(len(file_names)))
        # load as float32 scaled to [0, 1]; the context manager closes the file handle
        with Image.open(os.path.join(class_dir, file_name)) as image_file:
            X = np.array(image_file, dtype=np.float32) / 255
        # flatten to one row per pixel (assumes 3-channel images)
        OD_raw = RGB2OD(X.reshape(-1,3))
        stainVectors = _estimate_stain_vectors(OD_raw)

        sample_list.append([os.path.join(class_name, file_name), X, y, stainVectors, OD_raw])

In [None]:
import numpy as np
# per-class accumulator: the outer index is the class label y (0 or 1) and each
# inner list collects one intensity value per sample of that class
hematoxylin_intensities_by_class = [[], []]
def cosine_similarity(a, b):
    """Return the cosine similarity between `b` and `a`.

    `a` may be a single vector or a stack of vectors along its last axis, in
    which case one similarity per vector is returned. `b` is treated as a
    single vector (its norm is taken over all of its entries).
    """
    dot_products = np.dot(a, b)
    norm_product = np.linalg.norm(a, axis = -1) * np.linalg.norm(b)
    return dot_products / norm_product

# For every sample, compute the 99th percentile of the per-pixel RGB norm, with
# pixels zeroed out unless their optical density is closer (by cosine
# similarity) to the first stain vector than to the second. The value is
# collected per class and appended to the sample record.
sample_list_new = []
for path, X, y, stainVectors, OD_raw in sample_list:
    first_stain, second_stain = stainVectors[0], stainVectors[1]
    hematoxylin_greater_mask = (
        cosine_similarity(OD_raw, first_stain) > cosine_similarity(OD_raw, second_stain)
    )
    pixel_norms = np.linalg.norm(X, axis = -1).flatten()
    stable_maximum = np.percentile(pixel_norms * hematoxylin_greater_mask, 99)
    hematoxylin_intensities_by_class[y].append(stable_maximum)
    sample_list_new.append([path, X, y, stainVectors, OD_raw, stable_maximum])

In [None]:
base_dir = '/home/sidney/workspace/explain_and_adapt_library/notebooks/datasets/cancer_tissue_no_norm'
# 50th percentile of the intensities pooled over both classes
pooled_intensities = np.concatenate(
    (hematoxylin_intensities_by_class[0], hematoxylin_intensities_by_class[1])
)
intensity_median = np.percentile(pooled_intensities, 50)
def check(sample, has_attribute, has_confounder):
    """Tell whether `sample` belongs to the requested (attribute, confounder) group.

    The class label is read from `sample[2]`; the confounder flag is 1 when the
    sample's last entry lies strictly above the global `intensity_median`.
    """
    attribute_matches = sample[2] == has_attribute
    if not attribute_matches:
        return attribute_matches
    confounder_flag = int(sample[-1] > intensity_median)
    return confounder_flag == has_confounder

# total number of rows to emit; rows rotate through the four
# (attribute, confounder) groups, so each group contributes a quarter of them
NUM_CSV_SAMPLES = 18000
data_csv_path = os.path.join(base_dir, 'data.csv')


def _write_data_csv(lines):
    """Overwrite data.csv with the given lines, closing the file handle afterwards."""
    with open(data_csv_path, 'w') as csv_file:
        csv_file.write('\n'.join(lines))


lines_out = ['ImgPath,Cancer,Confounder,ConfounderStrength']
# idxs[attribute][confounder] is the position in sample_list_new from which the
# search for the next sample of that group continues
idxs = np.zeros([2, 2], dtype=np.int32)
for sample_idx in range(NUM_CSV_SAMPLES):
    if sample_idx % 100 == 0:
        print(sample_idx)
        # periodic checkpoint of the rows collected so far
        _write_data_csv(lines_out)

    # group rotation over sample_idx % 4: (1, 1), (1, 0), (0, 1), (0, 0)
    has_attribute = int(sample_idx % 4 in (0, 1))
    has_confounder = int(sample_idx % 2 == 0)

    # advance this group's cursor to the next sample that matches the group
    cursor = int(idxs[has_attribute][has_confounder])
    while cursor < len(sample_list_new) and not check(sample_list_new[cursor], has_attribute, has_confounder):
        cursor += 1

    if cursor >= len(sample_list_new):
        # same exception type as running off the end of the list, but with a
        # message that says which group was exhausted
        raise IndexError(
            'Ran out of samples for has_attribute=' + str(has_attribute)
            + ', has_confounder=' + str(has_confounder)
            + ' after ' + str(sample_idx) + ' rows'
        )

    sample = sample_list_new[cursor]
    lines_out.append(sample[0] + ',' + str(has_attribute) + ',' + str(has_confounder) + ',' + str(sample[-1]))
    print(str(has_attribute) + ' ' + str(has_confounder) + ' ' + str(cursor))
    idxs[has_attribute][has_confounder] = cursor + 1

_write_data_csv(lines_out)

In [None]:
# Print min, 25th/50th/75th percentile and max of the collected intensities,
# first for class 0 and then for class 1 (one value per line).
for class_label in (0, 1):
    class_intensities = hematoxylin_intensities_by_class[class_label]
    print(np.min(class_intensities))
    for percentile in (25, 50, 75):
        print(np.percentile(class_intensities, percentile))
    print(np.max(class_intensities))

In [None]:
# Obtain the student classifier: either train it on the poised data or load a
# previously saved model.
is_train_student = request('is_train_student', True)
if is_train_student:
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    student_config = load_yaml_config('$PEAL/configs/models/cancer_tissue_classifier.yaml')
    student_config['data'] = poised_dataset_train.config
    student = ImgEncoderDecoderModel(student_config).to(device)
    student_trainer = ModelTrainer(
        config = student_config, 
        model = student, 
        datasource = (poised_dataset_train, poised_dataset_val),
        model_name = request('student_name', 'cancer_tissue_classifier')
    )
    student_trainer.fit()

else:
    # or if you want to load your initial student model
    student_path = request('student_path', 'peal_runs/cancer_tissue_classifier/model.cpl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    student = torch.load(student_path, map_location=device).to(device)

In [None]:
# Obtain the teacher: train a classifier on the unpoised data, load a saved
# model, or pass the requested teacher type through unchanged.
teacher_type = request('teacher_type', 'train')
if teacher_type == 'train':
    from peal.architectures.models import ImgEncoderDecoderModel
    from peal.training.trainers import ModelTrainer
    teacher_config = load_yaml_config('$PEAL/configs/models/cancer_tissue_classifier.yaml')
    # NOTE(review): the data config is taken from the poised dataset although the
    # teacher is trained on the unpoised splits -- confirm this is intended.
    teacher_config['data'] = poised_dataset_train.config
    teacher = ImgEncoderDecoderModel(teacher_config).to(device)
    # Train the teacher itself. Passing the student's config and model here would
    # leave the teacher untrained and overwrite the student with unpoised training.
    teacher_trainer = ModelTrainer(
        config = teacher_config, 
        model = teacher, 
        datasource = (unpoised_dataset_train, unpoised_dataset_val),
        model_name = request('teacher_name', 'cancer_tissue_classifier_unpoised')
    )
    teacher_trainer.fit()
    # a model-based teacher is labelled 'oracle' from here on
    teacher_type = 'oracle'

elif teacher_type == 'load':
    # if you want to use existing model for knowledge distillation
    teacher_path = request('teacher_path', 'peal_runs/cancer_tissue_classifier_unpoised/model.cpl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    teacher = torch.load(teacher_path, map_location=device).to(device)
    teacher_type = 'oracle'

else:
    # any other value is handed on as-is in place of a model
    teacher = teacher_type

In [None]:
# Obtain the generator: either train a Glow model on the poised data or load a
# previously saved one.
is_train_generator = request('is_train_generator', True)
if is_train_generator:
    # if you want the generator getting trained from scratch
    from peal.architectures.generators import Glow
    from peal.training.trainers import ModelTrainer
    generator_config = load_yaml_config('$PEAL/configs/models/default_generator.yaml')
    generator_config['data'] = poised_dataset_train.config
    generator = Glow(generator_config).to(device)

    generator_trainer_kwargs = dict(
        config = generator_config,
        model = generator,
        datasource = (poised_dataset_train, poised_dataset_val),
        model_name = request('generator_model_name', 'cancer_tissue_generator'),
    )
    # gigabyte_vram is None when running on the CPU, where `gigabyte_vram / 2`
    # would raise a TypeError. In that case the argument is omitted so the
    # trainer's own default applies (the student trainer does not pass it either).
    if gigabyte_vram is not None:
        generator_trainer_kwargs['gigabyte_vram'] = gigabyte_vram / 2
    generator_trainer = ModelTrainer(**generator_trainer_kwargs)
    generator_trainer.fit()

else:
    # if you want to use loaded generator
    generator_path = request('generator_path', 'peal_runs/cancer_tissue_generator/model.cpl')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    generator = torch.load(generator_path, map_location=device).to(device)

In [None]:
# run counterfactual knowledge distillation: adapt the student using the
# generator and the teacher, evaluating on the unpoised test split
from peal.adaptors.counterfactual_knowledge_distillation import CounterfactualKnowledgeDistillation
cfkd_kwargs = dict(
    student = student,
    datasource = (poised_dataset_train, poised_dataset_val, unpoised_dataset_test),
    output_size = 2,
    generator = generator,
    teacher = teacher,
    base_dir = request(
        'cfkd_base_dir',
        'peal_runs/cancer_tissue_classifier/cfkd_' + teacher_type
    ),
    use_visualization = False
)
# gigabyte_vram is None when running on the CPU, where `gigabyte_vram / 2` would
# raise a TypeError; only pass a VRAM budget when one is known.
# NOTE(review): assumes CounterfactualKnowledgeDistillation has a default for
# gigabyte_vram -- confirm against its signature.
if gigabyte_vram is not None:
    cfkd_kwargs['gigabyte_vram'] = gigabyte_vram / 2
cal = CounterfactualKnowledgeDistillation(**cfkd_kwargs)
cal.run()