In [None]:
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

import numpy as np
import sklearn.metrics
import pandas as pd
import random

import dataset
import data_reader
import plots

In [None]:
# Pick the device used for all tensor work: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("Using GPU :", device)
else:
    print("Using CPU ")

In [None]:
# Prefix of the per-split feature-extractor checkpoints (loaded as EXP_NAME + split index).
EXP_NAME = "tcia_25ep"

# Prefix of the per-split patch databases (read as DATA_SET_NAME + split index).
DATA_SET_NAME = "tcia_data_set_SPLIT"

N_PATCHES = 400  # Number of patches to take from each WSI; WSIs with fewer patches are skipped

EPOCHS = 25

# Seed every RNG the notebook draws from (WSI shuffling, undersampling draws,
# weight initialisation, dropout) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def init_weights(m):
    """Xavier-uniform initialise an ``nn.Linear`` layer's weight and set its bias to 0.01.

    Intended to be passed to ``nn.Module.apply``. Modules whose exact type is not
    ``nn.Linear`` are left untouched. Linear layers built with ``bias=False`` are
    supported: only their weight is initialised.
    """
    # Exact type check (not isinstance) so Linear subclasses are skipped, as before.
    if type(m) is nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

def init():
    """(Re)build the WSI-level classifier and its training objects.

    Rebinds three module-level globals:
    - ``model``: Linear(512 -> 2) -> Dropout(0.5) -> Softmax, moved to ``device``,
      with weights initialised by ``init_weights``;
    - ``optimizer``: Adam over ``model``'s parameters with learning rate 1e-4;
    - ``loss_function``: ``nn.BCELoss`` (expects probabilities, hence the Softmax).

    NOTE(review): dropout is applied to the logits right before the softmax,
    not to the 512-d input features — confirm this placement is intended.
    """
    global model, loss_function, optimizer

    classifier_layers = [
        nn.Linear(512, 2),
        nn.Dropout(0.5),
        nn.Softmax(dim=-1),
    ]
    model = nn.Sequential(*classifier_layers).to(device)
    model.requires_grad_(True)  # every parameter is trainable
    model.apply(init_weights)  # Xavier init of the Linear layer

    lr = 1E-4
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.BCELoss()

In [None]:
def patch_feature_extraction(wsi_data_set, batch_size=1):
    """Run the global feature extractor ``net`` over every patch of one WSI.

    Input:
    - wsi_data_set: dataset object with the patches of a single WSI; each item is
      unpacked as (image, label, patch_id) and only the image is used here.
    - batch_size: number of patches fed to ``net`` at once (default 1, as before).
      ``net`` is expected to be in eval mode, so the result does not depend on it.
    Output:
    - CPU tensor with one row per patch, row k being ``net``'s output for patch k
      (an empty 1-d tensor if the dataset has no patches).
    """
    # ImageNet channel statistics used to normalise the patches.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    wsi_dataloader = DataLoader(wsi_data_set, batch_size=batch_size)
    features = []

    # Pure inference: no autograd graph is needed for the extractor.
    with torch.no_grad():
        for batch_X, _batch_y, _patch_ids in wsi_dataloader:
            # assumes patches arrive channel-last with 0-255 values — TODO confirm;
            # the permute (0, 3, 2, 1) is kept exactly as before so features match
            # what the extractor was used with.
            batch_X = batch_X.type(torch.FloatTensor).to(device).permute(0, 3, 2, 1)
            batch_X = normalize(batch_X / 255)
            features.extend(net(batch_X).cpu())

    if not features:
        return torch.tensor([])
    return torch.stack(features)

In [None]:
def fwd_pass(wsi_data_set, train=False):
    """Classify one WSI from the mean of its patch features; optionally update the model.

    The mean patch feature of the WSI is computed once with ``patch_feature_extraction``
    and cached in the global ``feature_dataset`` under ``wsi_data_set.case_ids[0]``.
    Later calls for the same key reuse the cached vector.

    Parameters:
    - wsi_data_set: dataset of one WSI; ``labels[0]`` is used as the WSI label
      (presumably a 2-element one-hot vector — TODO confirm) and ``case_ids[0]``
      as the cache key.
    - train: when True, backpropagate the loss and take one optimizer step.

    Returns ``[loss, y_pred, y_true]``: the loss tensor, and one-element lists holding
    the argmax of the model output and of the label.
    """
    global feature_dataset

    y = torch.tensor(wsi_data_set.labels[0], dtype=torch.float).to(device)
    cache_key = wsi_data_set.case_ids[0]

    if train:
        model.zero_grad()

    with torch.no_grad():
        if cache_key in feature_dataset:  # Using saved computations
            features = feature_dataset[cache_key]
        else:
            features = torch.mean(patch_feature_extraction(wsi_data_set), dim=0).to(device)
            feature_dataset[cache_key] = features

    # Only build an autograd graph when we are going to backpropagate.
    with torch.set_grad_enabled(train):
        output = model(features)
        # BCE on the first class probability vs. the first label entry.
        loss = loss_function(output[0], y[0])

    y_pred = [torch.argmax(output)]
    y_true = [torch.argmax(y)]

    if train:
        loss.backward()  # Calculate gradients using backprop
        optimizer.step()  # Updates W and b using previously calculated gradients

    return [loss, y_pred, y_true]

In [None]:
# Cross-validation over the splits. For each SPLIT:
# - split SPLIT is the validation set;
# - split (SPLIT + 1) % n_splits is held out for testing (see the TEST section);
# - all remaining splits train a fresh aggregation model on top of the split's
#   frozen feature extractor.
n_splits = 10

for SPLIT in range(n_splits):
    # Cached WSI features are produced by this split's extractor `net`; features
    # from another split's extractor must not be reused, so reset the cache here.
    feature_dataset = dict()

    ep_loss, val_ep_loss = [], []
    ep_acc, val_ep_acc = [], []

    X, y, case_ids = [], [], []

    TEST_SPLIT = (SPLIT + 1) % n_splits  # the last split wraps around to split 0
    TRAIN_SPLITS = [s for s in range(n_splits) if s not in (SPLIT, TEST_SPLIT)]

    MODEL_NAME = EXP_NAME + f"{SPLIT}"

    print(DATA_SET_NAME + f"{TEST_SPLIT}")  # name of the test split held out for this SPLIT

    init()  # Initializing patch merging model

    # Frozen per-split feature extractor; its `fc` head is replaced by Identity so
    # it outputs features instead of class scores.
    net = torch.load(f"C:\\Users\\Alejandro\\Desktop\\heterogeneous-data\\results\\WSI\\models\\{MODEL_NAME}.pth", map_location=device)
    net.eval()
    net.fc = nn.Identity()

    # Loading validation split:
    SPLIT_NAME = DATA_SET_NAME + f"{SPLIT}"
    print(f"Loading validation split: {SPLIT}")
    val_X, val_y, _, val_patch_ids = data_reader.read_lmdb(f"D:/data/WSI/patches/{SPLIT_NAME}")

    # Loading training splits:
    for i in TRAIN_SPLITS:
        print(f"Loading training split: {i}")
        SPLIT_NAME = DATA_SET_NAME + f"{i}"
        X_, y_, _, patch_ids_ = data_reader.read_lmdb(f"D:/data/WSI/patches/{SPLIT_NAME}")
        X.extend(X_)
        y.extend(y_)
        case_ids.extend(patch_ids_)

    # Keep only the sample (WSI) id, i.e. the second "_"-separated field of each patch id.
    sample_ids = np.array([case_id.split("_")[1] for case_id in case_ids])
    val_sample_ids = np.array([case_id.split("_")[1] for case_id in val_patch_ids])
    unique_sample_ids = np.unique(sample_ids)
    val_unique_sample_ids = np.unique(val_sample_ids)

    for EPOCH in range(EPOCHS):
        print("EPOCH: ", EPOCH+1)
        losses, outputs, labels = [], [], []
        val_losses, val_outputs, val_labels = [], [], []

        # Visit the training WSIs in a fresh random order every epoch.
        np.random.shuffle(unique_sample_ids)

        model.train()
        for unique_sample_id in tqdm(unique_sample_ids):
            # Indices of (at most) the first N_PATCHES patches of this WSI. Direct
            # index selection keeps the WSI's last patch and does not require its
            # patches to be stored contiguously.
            ii = np.flatnonzero(sample_ids == unique_sample_id)[:N_PATCHES]

            wsi_data_set = dataset.PatchDataset([], [], [])
            wsi_data_set.inputs.extend(X[j] for j in ii)
            wsi_data_set.labels.extend(y[j] for j in ii)
            wsi_data_set.case_ids.extend(sample_ids[ii])
            prob = np.random.uniform()

            if len(wsi_data_set) < N_PATCHES:
                continue  # too few patches for this WSI

            # Undersampling: keep every WSI whose labels[0][0] == 1 and roughly half
            # of those whose labels[0][0] == 0.
            label = wsi_data_set.labels[0][0]
            if label == 1 or (label == 0 and prob < 0.5):
                output = fwd_pass(wsi_data_set, train=True)
                losses.append(float(output[0]))
                outputs.append(output[1][0].cpu())
                labels.append(output[2][0].cpu())

        model.eval()
        for unique_sample_id in tqdm(val_unique_sample_ids):
            ii = np.flatnonzero(val_sample_ids == unique_sample_id)[:N_PATCHES]

            wsi_data_set = dataset.PatchDataset([], [], [])
            wsi_data_set.inputs.extend(val_X[j] for j in ii)
            wsi_data_set.labels.extend(val_y[j] for j in ii)
            wsi_data_set.case_ids.extend(val_patch_ids[j] for j in ii)

            if len(wsi_data_set) < N_PATCHES:
                continue  # too few patches for this WSI

            output = fwd_pass(wsi_data_set, train=False)
            val_losses.append(float(output[0]))
            val_outputs.append(output[1][0].cpu())
            val_labels.append(output[2][0].cpu())

        train_loss = np.mean(losses)
        train_acc = sklearn.metrics.accuracy_score(labels, outputs)
        train_bacc = sklearn.metrics.balanced_accuracy_score(labels, outputs)
        train_f1 = sklearn.metrics.f1_score(labels, outputs, average="macro")
        train_conf_m = sklearn.metrics.confusion_matrix(labels, outputs, labels=[0, 1])

        val_loss = np.mean(val_losses)
        val_acc = sklearn.metrics.accuracy_score(val_labels, val_outputs)
        val_bacc = sklearn.metrics.balanced_accuracy_score(val_labels, val_outputs)
        val_f1 = sklearn.metrics.f1_score(val_labels, val_outputs, average="macro")
        val_conf_m = sklearn.metrics.confusion_matrix(val_labels, val_outputs, labels=[0, 1])

        print("Train Loss: ", train_loss, "Train ACC:", train_acc, "Train F1:", train_f1, "\nTrain CONF M:\n", train_conf_m)
        print("Val Loss: ", val_loss, "Val ACC:", val_acc, "Val F1:", val_f1, "\nVal CONF M:\n", val_conf_m)

        ep_loss.append(train_loss)
        ep_acc.append(train_acc)

        val_ep_loss.append(val_loss)
        val_ep_acc.append(val_acc)

    # Persist this split's trained aggregation model for the TEST section.
    MODEL_NAME = f"feat_avg_tcia_{SPLIT}"
    torch.save(model, f"C:\\Users\\Alejandro\\Desktop\\heterogeneous-data\\results\\WSI\\models\\{MODEL_NAME}.pth")

In [None]:
# Accuracy per epoch for the last split trained above
# (ep_acc / val_ep_acc are reset at the start of every split).
fig, ax = plt.subplots()
ax.plot(ep_acc, label="train")
ax.plot(val_ep_acc, label="validation")
ax.set(title="Accuracy per epoch (last trained split)", xlabel="epoch", ylabel="accuracy")
ax.legend();

In [None]:
# Loss per epoch for the last split trained above
# (ep_loss / val_ep_loss are reset at the start of every split).
fig, ax = plt.subplots()
ax.plot(ep_loss, label="train")
ax.plot(val_ep_loss, label="validation")
ax.set(title="BCE loss per epoch (last trained split)", xlabel="epoch", ylabel="loss")
ax.legend();

## TEST

In [None]:
# Evaluate every split's saved aggregation model on its held-out test split
# ((SPLIT + 1) % n_test_splits), using that split's own feature extractor.
init()  # provides `loss_function` for fwd_pass; `model` is replaced per split below

test_acc, test_loss, test_f1 = [], [], []
tot_conf = np.zeros((2, 2))  # confusion matrix summed over all splits

# Release the training patches still held from the training loop.
X, y, case_ids = [], [], []

n_test_splits = 10

for SPLIT in range(n_test_splits):
    # Features depend on this split's extractor, so never reuse another split's cache.
    feature_dataset = dict()

    MODEL_NAME = f"feat_avg_tcia_{SPLIT}"
    print("AVG model: ", MODEL_NAME)
    model = torch.load(f"C:\\Users\\Alejandro\\Desktop\\heterogeneous-data\\results\\WSI\\models\\{MODEL_NAME}.pth", map_location=device)
    model.eval()

    MODEL_NAME = EXP_NAME + f"{SPLIT}"
    print("Feat extr model: ", MODEL_NAME)
    net = torch.load(f"C:\\Users\\Alejandro\\Desktop\\heterogeneous-data\\results\\WSI\\models\\{MODEL_NAME}.pth", map_location=device)
    net.eval()
    net.fc = nn.Identity()  # output features instead of class scores

    TEST_SPLIT = (SPLIT + 1) % n_test_splits  # the last split wraps around to split 0

    # Loading test split:
    SPLIT_NAME = DATA_SET_NAME + f"{TEST_SPLIT}"
    print(f"Loading test split: {TEST_SPLIT}")
    val_X, val_y, _, val_patch_ids = data_reader.read_lmdb(f"D:/data/WSI/patches/{SPLIT_NAME}")

    # Keep only the sample (WSI) id, i.e. the second "_"-separated field of each patch id.
    val_sample_ids = np.array([case_id.split("_")[1] for case_id in val_patch_ids])
    val_unique_sample_ids = np.unique(val_sample_ids)

    val_losses, val_outputs, val_labels = [], [], []

    for unique_sample_id in tqdm(val_unique_sample_ids):
        # Indices of (at most) the first N_PATCHES patches of this WSI, including
        # its last patch and without assuming the patches are contiguous.
        ii = np.flatnonzero(val_sample_ids == unique_sample_id)[:N_PATCHES]

        wsi_data_set = dataset.PatchDataset([], [], [])
        wsi_data_set.inputs.extend(val_X[j] for j in ii)
        wsi_data_set.labels.extend(val_y[j] for j in ii)
        wsi_data_set.case_ids.extend(val_patch_ids[j] for j in ii)

        if len(wsi_data_set) < N_PATCHES:
            continue  # too few patches for this WSI

        output = fwd_pass(wsi_data_set, train=False)
        val_losses.append(float(output[0]))
        val_outputs.append(output[1][0].cpu())
        val_labels.append(output[2][0].cpu())

    val_loss = np.mean(val_losses)
    val_acc = sklearn.metrics.accuracy_score(val_labels, val_outputs)
    val_bacc = sklearn.metrics.balanced_accuracy_score(val_labels, val_outputs)
    val_f1 = sklearn.metrics.f1_score(val_labels, val_outputs, average="macro")
    val_conf_m = sklearn.metrics.confusion_matrix(val_labels, val_outputs, labels=[0, 1])

    print("Test Loss: ", val_loss, "Test ACC:", val_acc, "Test F1:", val_f1)
    print("CONf: ", val_conf_m)

    test_acc.append(val_acc)
    test_loss.append(val_loss)
    test_f1.append(val_f1)
    tot_conf += val_conf_m

In [None]:
# Cross-split summary: confusion matrix summed over all test splits, then the
# mean test accuracy across splits.
mean_test_acc = np.mean(test_acc)
for summary_value in (tot_conf, mean_test_acc):
    print(summary_value)