# Params

In [None]:
import os
# Enumerate GPUs in PCI bus order and expose only device index 5 to this process.
# Both variables are set here, ahead of the torch import further down the notebook.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [None]:
# device = 'cpu'
# map_vpc = {'1': 0, '3': 1, '4': 2, '5': 3, '0': 4, '6': 5} # map labels: 1 = benign (0), we don't have 2, i = Gleason i (i-2) for i=[3:5]
num_classes = 6  # output width of the classifier head built by NN
batch_size = 8  # not referenced by any of the visible cells
magnifications= [10, 20]  # magnification levels to process; each one has its own checkpoint directory
patch_size = 512  # passed to get_embeddings_all as the single entry of `sizes`
# stains = ['HnE']
folds = ['fold1']  # fold names; used as sub-directory names under the model and embedding roots

# Root of the trained checkpoints, read as <path_model><magnification>/<fold>/
path_model = 'model_VPC_Zurich/' #+ str(magnification) + '/' + fold + '/aug_model'
model_name = '256_aug_model'  # filename prefix used to pick the checkpoint; also prefixes the output .pkl name

path_embeddings = '256_VPC_Zurich_embeddings_overlap/'  # output root for the pickled embeddings
path_VPC = '../data/VPC-TMA/overlap_patches/'  # input root, read as <path_VPC><core>/<size>/<magnification>/


# Import

In [None]:
import os
from skimage import io
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from IPython import display
import numpy as np
import pickle
import cv2 as cv

In [None]:
from PIL import ImageFile
# Ask PIL to load image files whose data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Model

In [None]:
class NN(torch.nn.Module):
    """Pretrained ResNet-18 whose final layer is replaced by a ``num_classes``-way linear head.

    Inputs and outputs are dictionaries: the image batch is read from the key
    ``'img'`` and results are returned under the key ``'label'``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # The head stays wrapped in nn.Sequential so its parameters keep the
        # ``fc.0.*`` names in the state dict.
        head = nn.Linear(in_features=512, out_features=num_classes, bias=True)
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def forward(self, dictionary):
        """Return ``{'label': scores}``, the network output for ``dictionary['img']``."""
        scores = self.model(dictionary['img'])
        return {'label': scores}

    def prediction(self, dictionary):
        """Return ``{'label': indices}``, the argmax of the forward output along dim 1."""
        scores = self.forward(dictionary)['label']
        return {'label': torch.argmax(scores, dim=1)}

# Build the network on the GPU. This module-level `model` is the object that
# load_model later fills with checkpoint weights and returns.
model = NN(num_classes=num_classes).cuda()
# Count the parameters that have requires_grad set and print the total.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)
# print(model)

# Utils

In [None]:
def dict_to_device(orig, device):
    """Return a new dict holding every value of ``orig`` after ``.to(device)``.

    Keys and their order are preserved; ``orig`` itself is left untouched.
    """
    return {key: value.to(device) for key, value in orig.items()}

def plotImage(img, ax=plt):
    """Show ``img`` on ``ax`` after converting it to a PIL image.

    Args:
        img: image to display; it is converted with
            ``torchvision.transforms.ToPILImage``, so it must be a type that
            transform accepts.
        ax: object providing ``imshow``. Defaults to the ``matplotlib.pyplot``
            module, which draws on the current axes.
    """
    img_pil = torchvision.transforms.ToPILImage()(img)
    ax.imshow(img_pil)

def directory_maker(path):
    """Create directory ``path`` (and any missing parents) if nothing exists there yet.

    If ``path`` already exists, whether as a directory or as a file, the call
    is a no-op.

    Args:
        path: directory to create.
    """
    if os.path.exists(path):
        return
    # makedirs also creates missing parent directories, and exist_ok avoids a
    # failure if the directory appears between the check above and this call.
    os.makedirs(path, exist_ok=True)
        
def listdir_fullpath(d):
    """Return the entries of directory ``d``, each joined onto ``d`` with ``os.path.join``."""
    entry_names = os.listdir(d)
    return [os.path.join(d, entry_name) for entry_name in entry_names]
        
def path_constructor(root_dir, embd_dir, magnifications, sizes):
    """Collect patch file paths per magnification and core, creating the output tree on the way.

    Input files are listed from ``<root_dir><core>/<size>/<mag>/`` and the
    matching output directories ``<embd_dir><core>/<size>/<mag>/`` are created
    with ``directory_maker``. Both directory arguments are concatenated as
    plain strings, so they must end with ``/``.

    Args:
        root_dir: input root; every entry in it is treated as a core directory.
        embd_dir: output root under which the directory tree is mirrored.
        magnifications: magnification levels to visit.
        sizes: patch sizes to visit.

    Returns:
        A dict ``{mag: {core: [file paths]}}``; for each core the lists of all
        requested sizes are concatenated.
    """
    images_by_mag = {}
    for mag in magnifications:
        images_by_mag[mag] = {core: [] for core in os.listdir(root_dir)}

    for core in os.listdir(root_dir):
        core_out = embd_dir + core + '/'
        directory_maker(core_out)
        for size in sizes:
            size_out = core_out + str(size) + '/'
            directory_maker(size_out)
            for mag in magnifications:
                directory_maker(size_out + str(mag) + '/')
                patch_dir = root_dir + core + '/' + str(size) + '/' + str(mag) + '/'
                images_by_mag[mag][core] += listdir_fullpath(patch_dir)
    return images_by_mag
    
# dict_imgs = path_constructor(path_VPC, path_embeddings + fold + '/', [10, 20, 40], [patch_size])

## get embeddings
# Module-level store filled by the forward hooks created below.
activation = {}


def get_activation(name):
    """Build a forward hook that stores the hooked module's output in ``activation[name]``."""

    def hook(module, inputs, output):
        # Only the output is kept; the module and its inputs are ignored.
        activation.update({name: output})

    return hook

def load_model(model_dir, model_name, magnification, fold, embedding_layer):
    """Load checkpoint weights into the module-level ``model`` and hook its avgpool layer.

    The checkpoint is looked up in ``<model_dir><magnification>/<fold>/`` as the
    first directory entry whose name starts with ``model_name``.

    Args:
        model_dir: root directory of the checkpoints; concatenated as a plain
            string, so it must end with ``/``.
        model_name: filename prefix identifying the checkpoint.
        magnification: magnification sub-directory (converted with ``str``).
        fold: fold sub-directory name.
        embedding_layer: key under which the hook stores the captured output in
            ``activation``. The hook is always attached to
            ``model.model.avgpool``; this argument only names the entry.

    Returns:
        The module-level ``model``, switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint directory is missing or contains
            no file starting with ``model_name``.
    """
    # 'model/' + str(magnification) + '/' + fold  + '/aug_model'
    checkpoint_dir = model_dir + str(magnification) + '/' + fold + '/'
    candidates = [checkpoint_dir + f for f in os.listdir(checkpoint_dir) if f.startswith(model_name)]
    if not candidates:
        raise FileNotFoundError(
            "no checkpoint starting with '{}' found in '{}'".format(model_name, checkpoint_dir))
    model_path = candidates[0]
    print(model_path)
    # Deserialize onto the CPU so the load does not depend on the GPU index the
    # checkpoint was saved from; load_state_dict copies the values into the
    # model's existing parameters.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # Drop the hook left by a previous call so repeated loads do not stack hooks.
    previous_handle = getattr(model, '_embedding_hook_handle', None)
    if previous_handle is not None:
        previous_handle.remove()
    model._embedding_hook_handle = model.model.avgpool.register_forward_hook(
        get_activation(embedding_layer))
    model.eval()

    return model

# load_model(path_model, 'aug_model', magnification, fold)

def save_embeddings(model, imgs_path, embd_dir, embedding_layer, model_name):
    """Embed every PNG in ``imgs_path`` and pickle the results as a single dict.

    The output file is
    ``<embd_dir><last three directory components of imgs_path[0]>/<model_name>_<embedding_layer>.pkl``.
    If that file already exists, or ``imgs_path`` is empty, nothing is done.

    Args:
        model: network called with ``{'img': tensor}``; the embedding is read
            from ``activation[embedding_layer]`` after each forward pass, so a
            hook writing that key must already be registered.
        imgs_path: file paths of the patches, using ``/`` as separator.
        embd_dir: output root, ending with ``/``.
        embedding_layer: key of the captured activation; also part of the
            output filename.
        model_name: prefix of the output filename.

    The pickled dict maps each filename with its last four characters removed
    to a numpy array. Paths not ending in ``.png`` are printed, skipped, and
    keep the value ``None``.
    """
    if not imgs_path:
        # Nothing to embed, and no first path to derive the output location from.
        return
    target_path = embd_dir + '/'.join(imgs_path[0].split('/')[-4:-1]) + '/' + model_name + '_' + embedding_layer + '.pkl'
    if os.path.exists(target_path):
        return
    transform = transforms.ToTensor()
    dict_embd = {img_path.split('/')[-1][:-4]: None for img_path in imgs_path}
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for img_path in imgs_path:
            if not img_path.endswith('.png'):
                print(img_path)
                continue
            img = io.imread(img_path)
            if img.shape[2] == 4:
                # Keep only the first three channels of a four-channel image.
                img = img[:, :, :3]
            img = cv.resize(img, (256, 256), interpolation=cv.INTER_CUBIC)  # input is 256x256!
            img = transform(img)
            dict_gpu = dict_to_device({'img': torch.unsqueeze(img, 0)}, 'cuda')
            # The return value is ignored; the forward hook captures the embedding.
            model(dict_gpu)
            key = img_path.split('/')[-1][:-4]
            dict_embd[key] = torch.squeeze(activation[embedding_layer]).cpu().detach().numpy()
    with open(target_path, 'wb') as f:
        pickle.dump(dict_embd, f)

# Embeddings saver

In [None]:
def get_embeddings_all(root_dir, embd_dir, model_dir, model_name, magnifications, folds, embedding_layer='avgpool', sizes=(512,)):
    """Extract and save embeddings for every fold, magnification and core.

    For each fold the input tree is scanned with ``path_constructor``; for each
    magnification the matching checkpoint is loaded with ``load_model`` and
    ``save_embeddings`` is called once per core.

    Args:
        root_dir: input root with one sub-directory per core, ending with ``/``.
        embd_dir: output root, ending with ``/``; results go to ``<embd_dir><fold>/``.
        model_dir: root directory of the checkpoints, ending with ``/``.
        model_name: checkpoint filename prefix; also prefixes the output files.
        magnifications: magnification levels to process.
        folds: fold names to process.
        embedding_layer: key naming the captured activation.
        sizes: patch sizes to scan.
    """
    for fold in folds:
        fold_embd_dir = embd_dir + fold + '/'
        # Make sure the per-fold output root exists before sub-directories are
        # created beneath it.
        os.makedirs(fold_embd_dir, exist_ok=True)
        dict_imgs_path = path_constructor(root_dir, fold_embd_dir, magnifications, sizes)
        print('{}: images loaded'.format(fold))
        for magnification in magnifications:
            # Use the caller-supplied model_dir rather than a notebook-level global.
            model = load_model(model_dir, model_name, magnification, fold, embedding_layer)
            print('{}, magnification {}: model loaded'.format(fold, magnification))
            model.eval()
            imgs_by_core = dict_imgs_path[magnification]
            total = len(imgs_by_core)
            for cnt, core in enumerate(imgs_by_core, start=1):
                save_embeddings(model, imgs_by_core[core], fold_embd_dir, embedding_layer, model_name)
                if cnt % 100 == 0:
                    print('{}, magnification {}: {}/{} embeddings are saved'.format(fold, magnification, cnt, total))
                
# Run the extraction for every configured fold and magnification, using the paths set in the Params cell.
get_embeddings_all(path_VPC, path_embeddings, path_model, model_name, magnifications, folds, sizes=[patch_size])