In [None]:
import os
import random
from pathlib import Path

import joblib
import matplotlib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from lime.lime_tabular import LimeTabularExplainer
from nibabel.processing import conform
from PIL import ImageFilter
from sklearn.manifold import TSNE
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torchsummary import summary
from torchviz import make_dot
from tqdm import tqdm

from FastSurfer.FastSurferCNN.data_loader.load_neuroimaging_data import load_and_conform_image
from models.densenet import densenet121

# Data Processing and Loading

In [None]:
# Specimen metadata table. process_raw_image() and load_raw() read this global
# and filter it on its `Filename` column to look up `Order` / `Family`.
# NOTE: cells further down rebind the name `df_info` to tables read from
# labels.csv (written with the header `filename,order,family`), so re-run this
# cell before calling the processing functions again in the same kernel.
df_info = pd.read_csv('data/info.csv')

# Taxonomic order name -> integer class id (names are listed alphabetically).
ORDER_MAP = {
    order_name: class_id
    for class_id, order_name in enumerate([
        'Artiodactyla',
        'Carnivora',
        'Chiroptera',
        'Eulipotyphla',
        'Hyracoidea',
        'Lagomorpha',
        'Marsupialia',
        'Perissodactyla',
        'Primates',
        'Rodentia',
        'Scandentia',
        'Xenarthra',
    ])
}

# Family name -> integer class id. The position in the list is the id.
FAMILY_MAP = {
    family_name: class_id
    for class_id, family_name in enumerate([
        'Bovidae',          # Artiodactyla
        'Cervidae',         # Artiodactyla
        'Canidae',          # Carnivora
        'Felidae',          # Carnivora
        'Mustelidae',       # Carnivora
        'Giraffidae',       # Artiodactyla
        'Pteropodidae',
        'Procaviidae',
        'Cercopithecidae',
        'Delphinidae',
        'Hyaenidae',
        'Ursidae',
        'Muridae',
        'Hominidae',
    ])
}

# The five families with ids 0-4, in id order.
CLASSES = ['Bovidae', 'Cervidae', 'Canidae', 'Felidae', 'Mustelidae']

# Inverse lookup (id -> family name) restricted to the five families in CLASSES.
REVERSED_FAMILY_MAP = dict(enumerate(CLASSES))

def normalize_image(img):
    """Min-max scale an image to the range [0, 1].

    Parameters
    ----------
    img : array-like
        Numeric image data of any shape (must be non-empty).

    Returns
    -------
    numpy.ndarray
        A new floating-point array in which the smallest input value maps to 0
        and the largest to 1. A constant image maps to all zeros. The input
        array is left unmodified.
    """
    img = np.asarray(img, dtype=float)
    # Subtract the minimum instead of adding abs(min): adding only moves the
    # minimum to 0 when it is negative; a positive minimum would be doubled.
    shifted = img - img.min()
    span = shifted.max()
    if span == 0:
        # Constant image: dividing would give 0/0 = NaN.
        return shifted
    return shifted / span

from augmentations import Crop
        
def process_raw_image(datadir, f, order, family, outputdir, n_crops=20, crop_size=(80, 80, 32)):
    """Cut random crops from one volume, save them, and return their label rows.

    Parameters
    ----------
    datadir : str
        Unused here; kept so existing positional callers keep working.
    f : pathlib.Path
        Path of the volume to load with nibabel.
    order, family : pandas.Series
        Single-element series naming the order / family of the specimen. When
        `order` is empty, both are looked up in the global `df_info` table using
        the part of the file stem before the first underscore.
    outputdir : str
        Crops are written to `<outputdir>/Processed/`, which is created if missing.
    n_crops : int, optional
        Number of random crops to save (default 20).
    crop_size : sequence of int, optional
        Crop shape handed to `Crop` (default (80, 80, 32)).

    Returns
    -------
    str
        One `filename,order_id,family_id` line per saved crop, or '' when the
        order or family of the specimen cannot be determined unambiguously.

    Raises
    ------
    KeyError
        If the order or family name is missing from ORDER_MAP / FAMILY_MAP.
        The lookup happens before anything is written to disk.
    """
    if order.empty:
        # split() returns the whole stem when it contains no underscore, whereas
        # stem[:stem.find('_')] would silently drop the last character.
        specimen = f.stem.split('_')[0]
        # Plain equality: the key is a literal file name, not a regex pattern.
        rows = df_info[df_info.Filename == specimen]
        order = rows.Order.drop_duplicates()
        family = rows.Family.drop_duplicates()
    if order.size != 1 or family.size != 1:
        # Missing or ambiguous taxonomy: report it and skip this file.
        print(f)
        print(order)
        if family.size != 1:
            print(family)
        return ''

    # Resolve the labels once, up front.
    order_id = ORDER_MAP[order.values[0]]
    family_id = FAMILY_MAP[family.values[0]]

    img = nib.load(f)
    img = img.get_fdata()
    img = np.nan_to_num(img)
    img = normalize_image(img)

    crop = Crop(list(crop_size), "random")

    # Path joining works whether or not `outputdir` ends with a separator.
    target_dir = Path(outputdir) / 'Processed'
    target_dir.mkdir(parents=True, exist_ok=True)

    lines = []
    for i in range(n_crops):
        crop_name = f'{f.stem}_{i}.joblib'
        # The payload stays [crop, order series] so existing loaders still unpack it.
        joblib.dump([crop(img), order], str(target_dir / crop_name))
        lines.append(f'{crop_name},{order_id},{family_id}\n')

    return ''.join(lines)

def process_raw_images(inputdir='MRI', datadir='data/', outputdir='data/', order=pd.Series(dtype=str), family=pd.Series(dtype=str)):
    """Process every file in `<datadir>/<inputdir>` and write the labels table.

    Each file is passed to process_raw_image() together with `order`, `family`
    and `outputdir`; the rows it returns are collected under the header
    `filename,order,family` and written to `<datadir>/labels.csv`.
    """
    source_dir = Path(datadir) / inputdir
    rows = ['filename,order,family\n']
    for volume_path in tqdm(source_dir.iterdir()):
        rows.append(process_raw_image(datadir, volume_path, order, family, outputdir))
    # NOTE(review): the table is written under `datadir`, while the crops it
    # lists are written under `outputdir` - confirm this is intended when the
    # two directories differ.
    with open(Path(datadir) / 'labels.csv', 'w') as label_file:
        label_file.write(''.join(rows))
        
def load_processed(datadir='data/'):
    """Return the joblib payload of every file in `<datadir>/Processed` as a list.

    joblib.load unpickles the files, so only point this at directories you trust.
    """
    processed_dir = Path(datadir) / 'Processed'
    return [joblib.load(entry) for entry in processed_dir.iterdir()]

def load_raw(datadir='data/'):
    """Load every volume in `<datadir>/MRI` together with its order lookup.

    Returns a list of `[volume_array, order_series]` pairs. Files whose lookup
    in the global `df_info` table yields more than one distinct order are
    printed and skipped.
    """
    samples = []
    for path in (Path(datadir) / 'MRI').iterdir():
        stem = path.stem
        # NOTE(review): the `- 1` also drops the last character before the first
        # underscore, and str.contains does a substring/regex match - confirm
        # both are intended (process_raw_image uses the full prefix and fullmatch).
        key = stem[:stem.find('_') - 1]
        matching_rows = df_info[df_info.Filename.str.contains(key)]
        order = matching_rows.Order.drop_duplicates()
        # NOTE(review): an empty lookup (size 0) is not skipped here.
        if order.size > 1:
            print(order)
            continue
        volume = nib.load(path).get_fdata()
        samples.append([volume, order])
    return samples

def trim(arr, mask):
    """Crop `arr` to the bounding box of the non-zero entries of `mask`.

    For every axis the slice runs from the smallest to the largest index at
    which `mask` is truthy (inclusive). An all-false mask raises ValueError
    from the min/max reduction.
    """
    extents = []
    for axis_hits in np.where(mask):
        first = np.min(axis_hits)
        last = np.max(axis_hits)
        extents.append(slice(first, last + 1))
    return arr[tuple(extents)]
    

In [None]:
# Crop every volume in data/Resampled; taxonomy is looked up in data/info.csv.
process_raw_images(inputdir='Resampled')

In [None]:
# OpenNeuro ds004114: every volume is labelled Rodentia / Muridae.
# NOTE(review): absolute, machine-specific paths.
process_raw_images(
    'Resampled',
    datadir='f:/Data/OpenNeuro/ds004114-download/',
    outputdir='f:/Data/OpenNeuro/',
    order=pd.Series(['Rodentia']),
    family=pd.Series(['Muridae']),
)

In [None]:
# OpenNeuro ds004215: every volume is labelled Primates / Hominidae.
# NOTE(review): absolute, machine-specific paths.
process_raw_images(
    'Resampled',
    datadir='f:/Data/OpenNeuro/ds004215-download/',
    outputdir='f:/Data/OpenNeuro/',
    order=pd.Series(['Primates']),
    family=pd.Series(['Hominidae']),
)

### Get Data Distribution

In [None]:
# Collect the specimen key (stem up to the first underscore) of every entry in
# data/Resampled, ignoring an entry whose stem is exactly '.nii'.
files = []
for f in Path('data/Resampled/').iterdir():
    if f.stem != '.nii':
        files.append(f.stem.split('_')[0])

In [None]:
# Separate copy of the specimen metadata (same file as the global `df_info`),
# used only for the data-distribution summary below.
df_info_l = pd.read_csv('data/info.csv')

In [None]:
# Keep only specimens that have a resampled volume and drop the `Id` column.
# Not idempotent: re-running fails because `Id` is already gone.
df_info_l = (
    df_info_l
    .loc[lambda frame: frame.Filename.isin(files)]
    .drop('Id', axis=1)
)

In [None]:
# Bar chart of the number of specimens per family, saved next to the data.
family_counts = df_info_l.Family.value_counts()
ax = family_counts.plot(kind='bar')
ax.set_ylabel('Number of Samples')
ax.figure.savefig('data/data_histogram.jpg', bbox_inches = 'tight', dpi=300)

In [None]:
# Persist the filtered metadata (specimens with a resampled volume, no `Id` column).
df_info_l.to_csv("data/info_short.csv", index=False)

### Train/Test Split

In [None]:
# Crop labels restricted to family ids 0-4 (the five families in CLASSES).
# NOTE: this rebinds `df_info`, which the processing functions expect to be the
# specimen metadata table.
df_info = (
    pd.read_csv('data/labels.csv')
    .loc[lambda labels: labels['family'].isin([0,1,2,3,4])]
)

In [None]:
# Split by source scan, not by crop. Every scan contributes several random
# crops named '<scan stem>_<i>.joblib'; a row-level split would put crops of the
# same scan in both train and test and inflate the test score (data leakage).
crop_source = df_info['filename'].str.rsplit('_', n=1).str[0]
splitter = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_idx, test_idx = next(splitter.split(df_info, groups=crop_source))
train, test = df_info.iloc[train_idx], df_info.iloc[test_idx]
# No scan may appear on both sides of the split.
assert set(crop_source.iloc[train_idx]).isdisjoint(crop_source.iloc[test_idx])

In [None]:
# Persist the split; data/test.csv is what the inference section loads.
train.to_csv("data/train.csv", index=False)
test.to_csv("data/test.csv", index=False)

In [None]:
# Crop labels for the OpenNeuro data. Rebinds `df_info` again.
# NOTE(review): process_raw_images() writes labels.csv under its `datadir`
# (the ds00xxxx-download folders above), not under f:/Data/OpenNeuro/ - confirm
# how this combined file is produced. Absolute, machine-specific path.
df_info = pd.read_csv('f:/Data/OpenNeuro/labels.csv')

In [None]:
# Split by source scan, not by crop, so that crops cut from the same scan never
# end up on both sides of the split (data leakage).
# assumes the `filename` column follows '<scan stem>_<i>.joblib' as written by
# process_raw_image - TODO confirm for this labels file.
crop_source = df_info['filename'].str.rsplit('_', n=1).str[0]
splitter = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=42)
train_idx, test_idx = next(splitter.split(df_info, groups=crop_source))
train, test = df_info.iloc[train_idx], df_info.iloc[test_idx]
# No scan may appear on both sides of the split.
assert set(crop_source.iloc[train_idx]).isdisjoint(crop_source.iloc[test_idx])

In [None]:
# Persist the OpenNeuro split. Absolute, machine-specific paths.
train.to_csv("f:/Data/OpenNeuro/train.csv", index=False)
test.to_csv("f:/Data/OpenNeuro/test.csv", index=False)

In [None]:
# Crop labels restricted to family id 5 (Giraffidae in FAMILY_MAP).
# NOTE(review): the result is not used by any cell shown here - presumably the
# source of data/unknown.csv; confirm.
df_info = (
    pd.read_csv('data/labels.csv')
    .loc[lambda labels: labels['family'].isin([5])]
)

## Export a Preview Slice of Every Processed Crop

In [None]:
import imageio

# Save slice 0 along the last axis of every processed crop as a JPEG preview.
# The output folder is created first; imwrite does not create missing folders.
Path('data/Processed-demo').mkdir(parents=True, exist_ok=True)
for f in Path('data/Processed/').iterdir():
    # The payload is [crop, order]; index it instead of unpacking into `_`,
    # which would shadow IPython's last-output variable.
    # joblib.load unpickles the file - only use on trusted data.
    img_arr = joblib.load(f)[0]
    imageio.imwrite(f'data/Processed-demo/{f.stem}.jpg', img_arr[:,:,0])

# Training

## Training and Validation Loss

In [None]:
# `model` is the checkpoint object returned by torch.load (indexed like a dict
# below), not a network; its 'losses' entry is indexed with 'train' and
# 'validation' in the following cells.
model = torch.load('checkpoint/SupCon_epoch_100.pth')
losses = model['losses']

In [None]:
# Training and validation loss per epoch, saved as 'losses_SupCon_all'.
plt.figure(figsize=(3, 2), dpi=300)
for split in ('train', 'validation'):
    curve = losses[split]
    plt.plot(range(len(curve)), curve, label=split)
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Cross Entropy Loss')
plt.savefig('losses_SupCon_all', bbox_inches = 'tight', dpi=300)

In [None]:
# Lowest validation loss recorded in the checkpoint.
min(losses['validation'])

In [None]:
# 1-based position of the first occurrence of that lowest validation loss.
losses['validation'].index(min(losses['validation']))+1

# Inference

## Get predictions

In [None]:
from dataset import CustomImageDataset
from config import Config, FINE_TUNING

config = Config(FINE_TUNING)


def _build_loader(csv_path):
    """Create the dataset for `csv_path` over data/Processed/ and its DataLoader."""
    dataset = CustomImageDataset(config, csv_path, 'data/Processed/', FINE_TUNING)
    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        pin_memory=config.pin_mem,
        num_workers=config.num_cpu_workers,
    )
    return dataset, loader


# Held-out split written by the train/test split section.
dataset_test, loader_test = _build_loader('data/test.csv')

# NOTE(review): data/unknown.csv is not written by any cell shown here.
dataset_unknown, loader_unknown = _build_loader('data/unknown.csv')

def get_predictions(net, is_encoder=False, loader=None):
    """Run `net` over a loader and collect its outputs and the true labels.

    Parameters
    ----------
    net : torch.nn.Module
        Model to evaluate. It is switched to eval mode for the duration of the
        call and restored to its previous mode afterwards.
    is_encoder : bool, optional
        If True, the raw network output of every sample is collected; otherwise
        the index of the largest output along dimension 1 (the predicted class).
    loader : iterable, optional
        Yields `(inputs, labels, paths)` batches. Defaults to the global
        `loader_test`.

    Returns
    -------
    (list, list)
        Per-sample predictions (or raw outputs) and per-sample true labels.
    """
    if loader is None:
        loader = loader_test

    y_pred = []
    y_true = []

    # Inference must not run in training mode: BatchNorm would use (and update)
    # batch statistics and Dropout would randomise the outputs.
    was_training = net.training
    net.eval()
    try:
        # No autograd graph is needed for inference.
        with torch.no_grad():
            for inputs, labels, paths in loader:
                outputs = net(inputs)
                if not is_encoder:
                    outputs = torch.max(outputs, 1)[1]
                y_pred.extend(outputs.cpu().numpy())  # Save Prediction
                y_true.extend(labels.detach().cpu().numpy())  # Save Truth
    finally:
        net.train(was_training)

    return y_pred, y_true

def get_embeddings(net, unknown=False, loader=None):
    """Collect the hidden representation `net` returns for every sample.

    Parameters
    ----------
    net : torch.nn.Module
        Model called as `net(inputs, return_hidden=True)`. It is switched to
        eval mode for the duration of the call and restored afterwards.
    unknown : bool, optional
        When no `loader` is given, selects the global `loader_unknown` (True)
        or `loader_test` (False) and prints which one is used.
    loader : iterable, optional
        Yields `(inputs, labels, paths)` batches; overrides `unknown`.

    Returns
    -------
    (list, list)
        Per-sample embeddings and per-sample true labels.
    """
    if loader is None:
        if unknown:
            print("taking unknown data")
            loader = loader_unknown
        else:
            print("taking test data")
            loader = loader_test

    embed = []
    y_true = []

    # Embeddings must come from eval mode: in training mode BatchNorm depends on
    # the batch and Dropout randomises the features.
    was_training = net.training
    net.eval()
    try:
        # No autograd graph is needed for inference.
        with torch.no_grad():
            for inputs, labels, paths in loader:
                hidden = net(inputs, return_hidden=True)
                embed.extend(hidden.cpu().numpy())  # Save embedding
                y_true.extend(labels.detach().cpu().numpy())  # Save Truth
    finally:
        net.train(was_training)

    return embed, y_true

## Load the classifier and plot a confusion matrix

In [None]:
# Rebuild the 12-class classifier, restore the fine-tuned weights and predict
# on the test loader. `model` is the checkpoint object; its 'model' entry is
# the state dict.
model = torch.load('checkpoint/fine_tune_epoch_199_all.pth')
net = densenet121(mode="classifier", drop_rate=0.0, num_classes=12)
# The network is wrapped in DataParallel before the weights are loaded, so the
# state-dict keys must match the wrapped module.
net = torch.nn.DataParallel(net).to('cuda')
net.load_state_dict(model['model'])
# A freshly built module is in training mode; switch to eval so BatchNorm uses
# its stored statistics during inference.
net.eval()
y_pred, y_true = get_predictions(net)

In [None]:
# Short axis labels for family ids 0-11, in FAMILY_MAP order. Cervidae (id 1)
# and Cercopithecidae (id 8) get four letters so the two stay distinguishable.
CLASSES_ALL = ['Bov', 'Cerv', 'Can', 'Fel', 'Mus', 'Gir', 'Pte', 'Pro', 'Cerc', 'Del', 'Hya', 'Urs']
# Row-normalised confusion matrix. `labels` pins the matrix to ids 0-11 so the
# tick labels line up even when a class is missing from y_true / y_pred.
ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    labels=list(range(len(CLASSES_ALL))),
    display_labels=CLASSES_ALL,
    cmap='Blues',
    values_format='.0%',
    normalize='true',
)
plt.savefig('confusion_matrix_SupCon_all_98.png', bbox_inches = 'tight', dpi=300)

# Latent Space Visualization

In [None]:
# Rebuild the 5-class classifier, restore the fine-tuned weights and extract
# the hidden representation of every test sample.
model = torch.load('checkpoint/fine_tune_epoch_46_5_classes_supcon.pth')
net = densenet121(mode="classifier", drop_rate=0.0, num_classes=5)
# The network is wrapped in DataParallel before the weights are loaded, so the
# state-dict keys must match the wrapped module.
net = torch.nn.DataParallel(net).to('cuda')
net.load_state_dict(model['model'])
# A freshly built module is in training mode; switch to eval so BatchNorm uses
# its stored statistics while the embeddings are computed.
net.eval()
embed, y_true = get_embeddings(net)

In [None]:
# Recomputes the test-set embeddings for the current `net`. This repeats the
# call that ends the model-loading cell above, so one of the two is redundant.
embed, y_true = get_embeddings(net)

In [None]:
# Project the embeddings to 2-D with t-SNE (fixed seed for reproducibility).
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` -
# check against the installed version.
tsne = TSNE(n_components=2, verbose=1, random_state=123, n_iter=10000)
z = tsne.fit_transform(embed) 

In [None]:
# 12-colour palette: 5 colours from Set1 followed by 7 from Dark2, one per
# family id. plt.get_cmap is used because plt.cm.get_cmap is no longer
# available in recent Matplotlib releases.
cmap = plt.get_cmap('Set1')
cmap2 = plt.get_cmap('Dark2')
cmap = matplotlib.colors.ListedColormap(cmap.colors[:5] + cmap2.colors[:7])
# Fixing vmin/vmax maps family id i to palette entry i regardless of which
# families are present in y_true.
scatter = plt.scatter(x=z[:,0], y=z[:,1], c=y_true, cmap=cmap,
                      vmin=0, vmax=len(CLASSES_ALL) - 1)
# One legend handle per family id actually present (num=None = all unique
# values, in ascending order), labelled by id rather than by position.
present_ids = np.unique(y_true)
legend_handles = scatter.legend_elements(num=None)[0]
plt.legend(handles=legend_handles, labels=[CLASSES_ALL[int(i)] for i in present_ids])
plt.savefig('latent_space.png', bbox_inches = 'tight', dpi=300)