# Late Fusion

In [1]:
from models.loader import load_model
from data.dataset import get_flair_loader, get_seg_loader, get_t2_loader, get_t1ce_loader, get_t1_loader

import pickle
import numpy as np
import torch
import torch.nn as nn   
import torchvision.transforms as T
from torchsummary import summary

from tqdm import tqdm

from sklearn.metrics import ndcg_score

  from .autonotebook import tqdm as notebook_tqdm


## Configuration

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Number of patients compared against each other by the retrieval code below.
number_patient = 200

# Models class
class Autoencoder(nn.Module):
    """
    Convolutional autoencoder for single-channel images.

    The encoder applies five stride-2 convolutions followed by a linear
    projection to `latent_dim`; the decoder mirrors it with five stride-2
    transposed convolutions and a final Tanh.

    The hard-coded `256 * 3 * 4` flatten size means the encoder expects its
    last feature map to be 3x4, i.e. inputs of shape (batch, 1, 96, 128)
    (each of the five stride-2 layers halves both spatial dimensions).

    Parameters
    ----------
    latent_dim : int, default 256
        Size of the encoded feature vector.
    """

    # Channel count of the deepest feature map (encoder output / decoder input).
    _BOTTLENECK_CHANNELS = 256

    def __init__(self, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # NOTE(review): PyTorch's BatchNorm `momentum` is the weight given to
        # the NEW batch statistic, so 0.9 makes the running statistics follow
        # the latest batch almost entirely. If 0.9 was carried over from a
        # framework where momentum weights the OLD value, 0.1 was probably
        # meant - confirm before retraining (changing it does not affect
        # checkpoint compatibility, but it does change training behavior).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(32, momentum=0.9),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.Conv2d(256, self._BOTTLENECK_CHANNELS, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(self._BOTTLENECK_CHANNELS, momentum=0.9),
            nn.Flatten(),
            nn.Linear(self._BOTTLENECK_CHANNELS * 3 * 4, self.latent_dim)
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self._BOTTLENECK_CHANNELS * 3 * 4),
            nn.Unflatten(1, (self._BOTTLENECK_CHANNELS, 3, 4)),
            # The input here is the unflattened 256-channel map, NOT a
            # latent_dim-channel map. Using latent_dim as in_channels only
            # worked by coincidence for the default latent_dim=256.
            nn.ConvTranspose2d(self._BOTTLENECK_CHANNELS, 256, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(32, momentum=0.9),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Return the (batch, latent_dim) feature vectors for a batch of images."""
        return self.encoder(x)

    def forward(self, x):
        """Encode then decode `x`; the output has the same shape as the input."""
        return self.decoder(self.encoder(x))

    def train_reconstruction(self, loader, epochs=10, lr=0.001):
        """
        Train the autoencoder to reconstruct its input with Adam + MSE loss.

        Moves the model to the module-level `device`, switches it to training
        mode and leaves it in that mode when done.

        Parameters
        ----------
        loader : iterable with `len()`, yielding image batches as tensors.
        epochs : number of passes over `loader`.
        lr : Adam learning rate.
        """
        self.to(device)
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.MSELoss()
        for epoch in range(epochs):
            for i, x in enumerate(loader):
                x = x.to(device)
                optimizer.zero_grad()
                x_reconstructed = self(x)
                loss = criterion(x_reconstructed, x)
                loss.backward()
                optimizer.step()
                # Report the loss of every 100th batch only
                if i % 100 == 0:
                    print(f"Epoch {epoch}, batch {i}/{len(loader)}, loss {loss.item()}")


## Loading models

In [3]:
# One model per modality, each loaded from ./models/<modality>.pth
models = dict(
    seg=load_model("./models/seg.pth"),
    t1=load_model("./models/t1.pth"),
    t1ce=load_model("./models/t1ce.pth"),
    flair=load_model("./models/flair.pth"),
    t2=load_model("./models/t2.pth"),
)
# The data loader matching each model (same keys, same order as `models`)
data = dict(
    seg=get_seg_loader(),
    t1=get_t1_loader(),
    t1ce=get_t1ce_loader(),
    flair=get_flair_loader(),
    t2=get_t2_loader(),
)

# Per-modality results keyed like `models`: `matrix` and `similarities` are
# filled by the extraction cell below; `candidates` is not written anywhere
# in the visible part of the notebook.
matrix, similarities, candidates = {}, {}, {}

100%|██████████| 1050/1050 [00:02<00:00, 417.83it/s]
100%|██████████| 1050/1050 [00:03<00:00, 349.77it/s]
100%|██████████| 1050/1050 [00:02<00:00, 352.70it/s]
100%|██████████| 1050/1050 [00:03<00:00, 342.80it/s]
100%|██████████| 1050/1050 [00:03<00:00, 336.97it/s]
100%|██████████| 200/200 [00:00<00:00, 435.72it/s]
100%|██████████| 200/200 [00:01<00:00, 165.92it/s]
100%|██████████| 200/200 [00:00<00:00, 333.89it/s]
100%|██████████| 200/200 [00:00<00:00, 331.68it/s]
100%|██████████| 200/200 [00:00<00:00, 350.88it/s]


In [4]:
def get_features_matrix_for(model_key):
    """
    Encode every batch of one modality and stack the results.

    Parameters
    ----------
    model_key : key into the module-level `models` and `data` dicts.

    Returns
    -------
    torch.Tensor
        The outputs of `models[model_key].encode(batch)` for every batch
        yielded by `data[model_key]`, concatenated along dim 0 in loader
        order. The tensor does not require grad.
    """
    model = models[model_key]

    # This is pure inference: normalisation/dropout layers must use their
    # stored statistics, otherwise the features depend on how patients happen
    # to be batched. Switch to eval mode only for the duration of the call
    # and restore the previous mode afterwards.
    was_training = isinstance(model, nn.Module) and model.training
    if was_training:
        model.eval()

    try:
        # No autograd graph is needed; without no_grad() every batch would
        # keep its whole graph alive until the returned tensor is released.
        with torch.no_grad():
            arrays = [model.encode(batch) for batch in data[model_key]]
    finally:
        if was_training:
            model.train()

    return torch.cat(arrays, dim=0)

def get_similarity_matrix_for(model_key):
    """
    Return the pairwise Euclidean distances between patient features.

    Despite the name, the values are distances: entry [i, j] is the L2 norm
    of (features of patient i - features of patient j), so SMALLER means
    more similar and the diagonal is 0.

    Only the first `number_patient` rows of the feature matrix are used.

    Parameters
    ----------
    model_key : key into the module-level `models` and `data` dicts.

    Returns
    -------
    numpy.ndarray of float64, shape (number_patient, number_patient)

    Raises
    ------
    IndexError
        If the feature matrix has fewer than `number_patient` rows.
    """
    result = get_features_matrix_for(model_key)

    if result.shape[0] < number_patient:
        raise IndexError(
            f"features for {model_key!r} have {result.shape[0]} rows, "
            f"need at least {number_patient}"
        )

    # One feature vector per patient, detached so no autograd graph is kept.
    # For the (N, latent_dim) features produced by Autoencoder.encode the
    # reshape is a no-op.
    features = result[:number_patient].detach().reshape(number_patient, -1)

    # A single batched call replaces number_patient**2 Python-level norms.
    # The matmul shortcut is disabled on purpose: it is less accurate and can
    # leave small non-zero values on the diagonal, which would break callers
    # that rely on each patient being its own nearest neighbour.
    with torch.no_grad():
        distances = torch.cdist(
            features, features, p=2.0,
            compute_mode="donot_use_mm_for_euclid_dist",
        )

    return distances.cpu().numpy().astype(np.float64)

def get_most_similar_patient(key_model, similarity_matrix):
    """
    Rank the patients of every row from smallest to largest value.

    Parameters
    ----------
    key_model : key selecting one matrix in `similarity_matrix`.
    similarity_matrix : mapping from key to a 2-D array whose row i holds
        the values of patient i against every patient (distances in this
        notebook, so the most relevant patient comes first).

    Returns
    -------
    numpy.ndarray of float64, same shape as the selected matrix
        Row i holds the column indices of row i in ascending value order.
        Indices are returned as floats, as callers have always received.
    """
    scores = np.asarray(similarity_matrix[key_model])

    # Size comes from the matrix itself rather than a global patient count,
    # so the function works for any matrix it is handed.
    return np.argsort(scores, axis=1).astype(np.float64)

def get_sort_similar(key_model, similarity_matrix):
    """
    Sort the values of every row in ascending order.

    Parameters
    ----------
    key_model : key selecting one matrix in `similarity_matrix`.
    similarity_matrix : mapping from key to a 2-D array of per-patient values.

    Returns
    -------
    numpy.ndarray of float64, same shape as the selected matrix
        A new array whose row i is row i of the input sorted ascending; the
        input is not modified.
    """
    scores = np.asarray(similarity_matrix[key_model])

    # Size comes from the matrix itself rather than a global patient count.
    return np.sort(scores, axis=1).astype(np.float64)
 

## Extracting the most similar patients

In [5]:
# For every modality, store the encoded feature matrix and the pairwise
# distance matrix computed from that modality's features.
for key in data :
    matrix[key] = get_features_matrix_for(key)
    # NOTE: get_similarity_matrix_for encodes the modality again internally
    # instead of reusing matrix[key], so each loader is traversed twice.
    similarities[key]  = get_similarity_matrix_for (key)

In [14]:
def borda_count(modals_keys):
    """
    Fuse the per-modality patient rankings with a Borda count.

    For every query patient, each modality ranks all patients (via
    `get_most_similar_patient` on the module-level `similarities`). A
    candidate at rank i of one modality earns `number_patient - i` points;
    points are summed over the modalities and the candidates are ordered by
    decreasing total.

    Parameters
    ----------
    modals_keys : iterable of keys of `similarities` to fuse.

    Returns
    -------
    dict with two float arrays of shape (number_patient, number_patient):
        "candidates"        - row p: candidate indices, best total first
        "values_candidates" - row p: the matching point totals, descending
    """
    # Materialise once: the keys are walked again for every patient, which
    # would silently yield nothing on a one-shot iterator.
    modals_keys = list(modals_keys)

    # Ranking of all patients for every query patient, one array per modality
    rankings = {
        key: get_most_similar_patient(key, similarities)
        for key in modals_keys
    }

    end_most_relevant_candidate = np.zeros((number_patient, number_patient))
    end_most_relevant_candidate_values = np.zeros((number_patient, number_patient))

    for patient in range(number_patient):

        # Total Borda points collected by each candidate for this patient
        points = {}

        for key in modals_keys:
            for rank, candidate in enumerate(rankings[key][patient]):
                points[candidate] = points.get(candidate, 0) + (number_patient - rank)

        # Ascending stable sort followed by a reversal (rather than
        # reverse=True) so that candidates with equal totals keep the same
        # relative order this function has always produced.
        sorted_list_patient = sorted(points, key=points.get)
        sorted_list_patient.reverse()

        value_sorted = sorted(points.values())
        value_sorted.reverse()

        end_most_relevant_candidate[patient] = np.array(sorted_list_patient)
        end_most_relevant_candidate_values[patient] = np.array(value_sorted)

    return {
        "candidates" : end_most_relevant_candidate,
        "values_candidates" : end_most_relevant_candidate_values
    }

In [15]:
# Fuse the rankings of every modality, then drop column 0 of both result
# arrays, i.e. the top-scored candidate of each row - presumably the query
# patient itself (distance 0 in every modality); confirm.
candidates_matrix = {
    name: ranked[:, 1:]
    for name, ranked in borda_count(data.keys()).items()
}

## Calculate NDCG

In [10]:
# Ground truth
# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open('./data/IoU.pickle', "rb") as file:
    djakkar = pickle.load(file)

djakkar_index = djakkar

# Row by row: the column indices in ascending value order, and the values
# themselves in ascending order.
most_similar_patient_matrix = np.zeros(djakkar_index.shape)
sorted_djakkar_index = np.zeros(djakkar_index.shape)

for x in range(djakkar_index.shape[0]):
    most_similar_patient_matrix[x] = np.argsort(djakkar_index[x])
    sorted_djakkar_index[x] = np.sort(djakkar_index[x])

testing_size = djakkar_index.shape[0]

# Column 0 is left out of the score.
# NOTE(review): both sorts are ascending, so column 0 is the SMALLEST value
# of each row. If the pickle stores IoU similarities (self-match = highest)
# this drops the least similar patient, not the patient itself - confirm.
ranked_ids = most_similar_patient_matrix[:, 1:]
ranked_values = sorted_djakkar_index[:, 1:]

# Getting the NDCG score
# NOTE(review): ndcg_score's signature is (y_true, y_score); here the
# patient indices are passed as y_true and the sorted values as y_score -
# confirm that this is the intended reference score.
ndcg_matrix = np.zeros((1, testing_size))

for x in range(testing_size):
    ndcg_matrix[0, x] = ndcg_score(
        np.asarray([ranked_ids[x]]),
        np.asarray([ranked_values[x]]),
    )

ndcg_matrix

array([[0.82245911, 0.8577357 , 0.86001779, 0.84384805, 0.8239485 ,
        0.85025024, 0.84933212, 0.84582292, 0.82992151, 0.83515181,
        0.84552729, 0.82532444, 0.82505279, 0.8401196 , 0.83646798,
        0.8505333 , 0.86071333, 0.83575711, 0.83674904, 0.83906377,
        0.84711791, 0.84435709, 0.82772311, 0.826835  , 0.86342462,
        0.84718996, 0.84024042, 0.82002108, 0.82232115, 0.84775996,
        0.82704775, 0.82707302, 0.85849851, 0.86810165, 0.84364039,
        0.86066422, 0.85222457, 0.83879503, 0.84431929, 0.84270972,
        0.83772172, 0.85594114, 0.83910083, 0.83538075, 0.85595874,
        0.83633646, 0.84436767, 0.8666647 , 0.82471694, 0.8540424 ,
        0.84227408, 0.8439962 , 0.86007306, 0.86854879, 0.83515426,
        0.8585878 , 0.82614278, 0.83060612, 0.864261  , 0.86138798,
        0.84850773, 0.84259706, 0.83970414, 0.87613908, 0.86706485,
        0.85716744, 0.83735291, 0.84925946, 0.83133931, 0.85560747,
        0.85127549, 0.8539983 , 0.85831997, 0.86

## Calculate the late-fusion model NDCG

In [12]:
# One NDCG value per patient for the fused (Borda count) ranking.
# NOTE(review): ndcg_score's signature is (y_true, y_score); here the fused
# candidate indices are passed as y_true and their Borda totals as y_score,
# and the ground truth loaded from IoU.pickle is not used - confirm that
# this is the intended evaluation.
ndcg_model_scores = np.zeros((1, number_patient))

for patient in tqdm(range(number_patient)):
    fused_ids = candidates_matrix["candidates"][patient]
    fused_points = candidates_matrix["values_candidates"][patient]
    ndcg_model_scores[0, patient] = ndcg_score([fused_ids], [fused_points])

print(ndcg_model_scores)

100%|██████████| 200/200 [00:00<00:00, 1652.95it/s]

[[0.86890092 0.86077733 0.84490902 0.86514623 0.89320719 0.86014605
  0.88793426 0.86158562 0.86527133 0.8343647  0.84706034 0.83822318
  0.84880401 0.87901092 0.83864675 0.88387067 0.8602223  0.85832263
  0.84667824 0.85879981 0.85675056 0.86091919 0.83034227 0.87970573
  0.85613518 0.86556851 0.87039555 0.86008233 0.85723729 0.85246656
  0.85202106 0.85675549 0.85334103 0.87967986 0.86045141 0.84151582
  0.88475922 0.85939146 0.88097416 0.86961371 0.88414596 0.85337048
  0.83970182 0.8377998  0.82883323 0.90470003 0.8383206  0.8464024
  0.89513772 0.84446946 0.87141245 0.8781591  0.86777451 0.8428604
  0.86948239 0.87044505 0.85731634 0.84926237 0.85921801 0.84333048
  0.86830662 0.84163462 0.88417557 0.88307683 0.84150225 0.8450156
  0.88040025 0.84613302 0.86064525 0.84911595 0.8531183  0.85279538
  0.86061261 0.87583728 0.84010454 0.8525269  0.88384196 0.85012938
  0.87670923 0.83459462 0.86644855 0.85435139 0.83589938 0.83550671
  0.84743478 0.84835012 0.85957804 0.88257221 0.882




In [13]:
# Persist the per-patient late-fusion NDCG scores.
with open("./data/late_fusion_ndcg.pickle", "wb") as output_file:
    pickle.dump(ndcg_model_scores, output_file)