In [None]:
# Packages
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import gridspec
from mpl_toolkits import mplot3d
import seaborn as sns  # requires version 0.10.1/nur 0.11.1 funktioniert
plt.style.use('ggplot')

import scipy.io
import scipy.stats as st
import mat73
import time
import pickle5 as pickle

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture

# to disable seaborn warnings
import warnings
warnings.filterwarnings("ignore")

# Dataset
anz_samples = 200000  # total number of samples drawn for train/test (a quarter per distribution block)

comp_samples = 5000 # per-block hold-out size; anz_samples + 4*comp_samples must stay below 1,000,000

# Batch size
batch_size_train = 1024 
batch_size_test = 4096

# Training
n_epochs = 5000
learning_rate = 0.01   
used_optimizer = "adagrad" # (adam=Adam, nadam=NAdam, radam=RAdam, adamax=Adamax, adagrad=Adagrad, sgd=SGD, adadelta=Adadelta, rms=RMSProp) 
used_loss = "mse" # (mse=mean squared error, mae=mean absolute error, kl=Kullback-Leibler)

used_init = "he" # (he=Kaiming He, xavier=Xavier)
used_activation = "leaky_relu" # (leaky_relu=Leaky ReLU, relu=ReLU, elu=ELU, gelu=GELU, tanh=tanh)

use_dropout = [True, 0.6] # [enabled, p] where p is the probability of dropping a neuron
use_weight_decay = [False, 0.0] # [enabled, value] for the PyTorch optimizer weight decay
use_noise = [False, 0.0, 0.005] # [enabled, mean, stddev] of Gaussian noise added to the training inputs

# Logging
log_interval = 500 # print the losses every log_interval epochs
use_plot = True

# Neural network
I1 = 238           # number of input parameters
L1 = 1024          # neurons in layer 1
L2 = 1024          # neurons in layer 2
L3 = 1024          # neurons in layer 3
L4 = 2048          # neurons in layer 4
L5 = 2048          # neurons in layer 5
O1 = 3600          # number of output parameters

In [None]:
# define functions

def load_data():
    """Load the pickled dataset and build the train/test data loaders.

    The data file holds four consecutive blocks (starting at rows 0, 250000,
    500000 and 750000). From each block the first ``anz_samples / 4`` rows go
    into the pool that is randomly split 90/10 into training and test data;
    the following ``comp_samples`` rows of each block are kept separately as
    per-block comparison sets.

    Reads the module-level settings ``anz_samples``, ``comp_samples``,
    ``batch_size_train`` and ``batch_size_test``.

    Returns:
        train_loader, test_loader: DataLoaders over tensors that were moved
            to the GPU when one is available.
        x_train, y_train, x_test, y_test: float tensors of the split.
        x_test_comp, y_test_comp: float tensors whose first axis indexes the
            four blocks.
    """
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file -- only load data files from a trusted source.
    # The context manager closes the file even if unpickling fails.
    with open(r"path/to/file.mat", 'rb') as infile:
        data = pickle.load(infile)

    samples_pro_verteilung = int(anz_samples/4)
    block_starts = (0, 250000, 500000, 750000)

    # Pool for the random train/test split: leading rows of every block.
    x = []
    y = []
    for start in block_starts:
        stop = start + samples_pro_verteilung
        x.extend(data["input"][start:stop])
        y.extend(data["ref"][start:stop])

    x = np.array(x)
    y = np.array(y)

    # Comparison sets: the comp_samples rows right after the pooled rows.
    x_test_comp = []
    y_test_comp = []
    for start in block_starts:
        comp_start = start + samples_pro_verteilung
        comp_stop = comp_start + comp_samples
        x_test_comp.append(data["input"][comp_start:comp_stop])
        y_test_comp.append(data["ref"][comp_start:comp_stop])

    x_test_comp = np.array(x_test_comp)
    y_test_comp = np.array(y_test_comp)

    # Split into training and test
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, shuffle=True)

    # transforming training and test sets to torch.tensor
    x_train = torch.from_numpy(x_train).type(torch.FloatTensor)
    y_train = torch.from_numpy(y_train).type(torch.FloatTensor)

    x_test = torch.from_numpy(x_test).type(torch.FloatTensor)
    y_test = torch.from_numpy(y_test).type(torch.FloatTensor)

    x_test_comp = torch.from_numpy(x_test_comp).type(torch.FloatTensor)
    y_test_comp = torch.from_numpy(y_test_comp).type(torch.FloatTensor)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get batches via DataLoader; the datasets live on `device`, the returned
    # x_*/y_* tensors stay where torch.from_numpy created them.
    train = torch.utils.data.TensorDataset(x_train.to(device), y_train.to(device))
    test = torch.utils.data.TensorDataset(x_test.to(device), y_test.to(device))

    train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size_train, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=batch_size_test, shuffle=True)

    return train_loader, test_loader, x_train, y_train, x_test, y_test, x_test_comp, y_test_comp

#Definition of the net structure
class Net(nn.Module):
    """Fully connected network with five hidden layers and a softmax output.

    The layer sizes (``I1``, ``L1``..``L5``, ``O1``), the weight
    initialisation (``used_init``), the hidden activation
    (``used_activation``) and the dropout setting (``use_dropout``) are read
    from the module-level configuration.
    """

    def __init__(self):
        """Create fc1..fc6 and initialise the weights of fc1..fc5.

        Raises:
            ValueError: if ``used_init`` is neither 'he' nor 'xavier'.
        """
        super(Net, self).__init__()

        if used_init == 'he':
            def init_weight(weight):
                torch.nn.init.kaiming_normal_(weight, mode='fan_in')
        elif used_init == 'xavier':
            init_weight = torch.nn.init.xavier_uniform_
        else:
            raise ValueError(f"unknown used_init: {used_init!r}")

        # I1 = number of input parameters, L1..L5 = hidden layer widths
        sizes = [I1, L1, L2, L3, L4, L5]
        for idx in range(1, len(sizes)):
            layer = nn.Linear(sizes[idx - 1], sizes[idx])
            init_weight(layer.weight)
            # attribute names fc1..fc5 are kept so saved models stay loadable
            setattr(self, f"fc{idx}", layer)

        # output layer (O1 values to estimate) keeps PyTorch's default init
        self.fc6 = nn.Linear(L5, O1)

    def forward(self, x):
        """Run the network and return a softmax over the last dimension.

        Dropout is only active while the module is in training mode, so
        ``network.eval()`` yields deterministic predictions.

        Raises:
            ValueError: if ``used_activation`` is not a supported name.
        """
        activations = {
            "leaky_relu": F.leaky_relu,
            "relu": F.relu,
            "elu": F.elu,
            "gelu": F.gelu,
            "tanh": torch.tanh,  # F.tanh is deprecated
        }
        if used_activation not in activations:
            raise ValueError(f"unknown used_activation: {used_activation!r}")
        activation = activations[used_activation]

        hidden_layers = (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5)
        last_hidden = len(hidden_layers) - 1

        for idx, layer in enumerate(hidden_layers):
            x = activation(layer(x))

            # NOTE(review): only the leaky_relu path had dropout after fc5
            # switched off; every other activation still applied it. That
            # difference is preserved here -- confirm whether it is intended.
            skip_dropout = idx == last_hidden and used_activation == "leaky_relu"

            if use_dropout[0] and not skip_dropout:
                # F.dropout defaults to training=True; without passing
                # self.training it would stay active in eval mode.
                x = F.dropout(x, p=use_dropout[1], training=self.training)

        x = self.fc6(x)
        return F.softmax(x, dim=-1)

def preprocess():
    """Create the network and the optimizer selected in the configuration.

    Reads the module-level settings ``used_optimizer``, ``learning_rate`` and
    ``use_weight_decay`` (``[enabled, value]``).

    Returns:
        (network, optimizer): the ``Net`` instance, moved to the GPU when one
        is available, and the optimizer over its parameters.

    Raises:
        ValueError: if ``used_optimizer`` is not a supported name.
    """
    # Get cpu or gpu device for training.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    network = Net().to(device) # load network to gpu

    # Looked up by name so that only the selected optimizer class has to
    # exist in the installed torch version.
    optimizer_names = {
        "nadam": "NAdam",
        "adam": "Adam",
        "sgd": "SGD",
        "radam": "RAdam",
        "adamax": "Adamax",
        "adadelta": "Adadelta",
        "rms": "RMSprop",
        "adagrad": "Adagrad",
    }
    if used_optimizer not in optimizer_names:
        raise ValueError(f"unknown used_optimizer: {used_optimizer!r}")
    optimizer_class = getattr(optim, optimizer_names[used_optimizer])

    # use_weight_decay is a [enabled, value] list. Comparing the whole list
    # with True/False never matches, which left `optimizer` unassigned.
    if isinstance(use_weight_decay, (list, tuple)):
        weight_decay_enabled = use_weight_decay[0]
    else:
        weight_decay_enabled = use_weight_decay

    optimizer_kwargs = {"lr": learning_rate}
    if weight_decay_enabled:
        optimizer_kwargs["weight_decay"] = use_weight_decay[1]

    optimizer = optimizer_class(network.parameters(), **optimizer_kwargs)

    return network, optimizer

def gaussian(ins, is_training=True, mean=0.0, stddev=0.005):
    """Add Gaussian noise to a tensor.

    Args:
        ins: input tensor.
        is_training: when False the input is returned unchanged.
        mean: mean of the noise.
        stddev: standard deviation of the noise.

    Returns:
        ``ins`` plus noise of the same shape, dtype and device, or ``ins``
        itself when ``is_training`` is False.
    """
    if not is_training:
        return ins

    # new_empty replaces the deprecated Variable(ins.data.new(...)) idiom;
    # the buffer is filled in place with the same normal_ draw.
    noise = ins.new_empty(ins.size()).normal_(mean, stddev)
    return ins + noise
    
def predict(network, data):
    """Call ``network`` on ``data`` and return the result."""
    return network(data)

# Training function for each epoch
def train(epoch, network, optimizer):
    """Train ``network`` for one pass over the module-level ``train_loader``.

    The per-batch losses are collected in the module-level list
    ``train_loss``; their sum divided by the number of batches is appended to
    ``train_losses`` and ``train_loss`` is cleared afterwards.

    Args:
        epoch: index of the current epoch (not used inside the function).
        network: model to train.
        optimizer: optimizer that updates the model parameters.

    Raises:
        ValueError: if ``used_loss`` is not 'mse', 'mae' or 'kl'.
    """
    network.train()

    # Loop over the batches
    for data, target in train_loader:

        target = target.to(torch.float32)
        data = data.to(torch.float32)

        if use_noise[0] == True:
            data = gaussian(data, is_training=True, mean=use_noise[1], stddev=use_noise[2])

        # --- Steps of the training of the net ---
        optimizer.zero_grad()
        output = network(data)

        if used_loss == "mse":
            loss = F.mse_loss(output, target)
        elif used_loss == "mae":
            # nn.L1Loss(output, target) only constructs a module; the
            # functional form actually computes the loss.
            loss = F.l1_loss(output, target)
        elif used_loss == "kl":
            # kl_div expects log-probabilities as input. This assumes the
            # network outputs probabilities (Net ends in a softmax); the
            # clamp avoids log(0).
            loss = F.kl_div(output.clamp_min(1e-12).log(), target, reduction='batchmean')
        else:
            raise ValueError(f"unknown used_loss: {used_loss!r}")

        loss.backward()
        optimizer.step()

        # --- Save Evaluation metrics ---
        train_loss.append(loss.item())

    train_losses.append(sum(train_loss) / len(train_loader))
    train_loss.clear()

# Test function that applies the test set to the trained net
def test(network):
    
    network.eval()
    
    # Gradient calculation is disabled (as not needed)
    with torch.no_grad():

        # Loop over the batches
        for data, target in test_loader:

            target = target.to(torch.float32)
            data = data.to(torch.float32)

            # --- Prediction and calculation of evaluation metrics ---
            output = network(data) 

            if used_loss == "mse":
                loss = F.mse_loss(output, target) 
            elif used_loss == "mae":
                loss = nn.L1Loss(output, target)
            elif used_loss == "kl":
                kl_loss = nn.KLDivLoss(reduction = 'batchmean')
                loss = kl_loss(output, target)

            test_loss.append(loss.item())

        test_losses.append(sum(test_loss) / len(test_loader))
        test_loss.clear()

def train_test(network, optimizer):
    """Run the complete training: initial evaluation, then n_epochs epochs.

    Every epoch calls ``train`` followed by ``test``. Every ``log_interval``
    epochs the latest entries of the module-level ``train_losses`` and
    ``test_losses`` are printed together with the elapsed time and an
    estimate of the remaining time (both in hours). The total run time is
    printed at the end.

    Returns:
        The trained network (the same object that was passed in).
    """
    start_time = time.time() # Start timer

    # Initial Evaluation metrics
    test(network)
    print('Test set: Avg. loss: {:.10f},\n'.format(test_losses[-1]))

    # run training
    for epoch in range(1, n_epochs + 1):
        train(epoch, network, optimizer)
        test(network)

        # print evaluation metrics
        if epoch % log_interval == 0:
            print("Epoch ", epoch)
            print('Train set: Avg. loss: {:.10f}'.format(train_losses[-1]))
            print('Test set: Avg. loss: {:.10f}'.format(test_losses[-1]))

            passed_sec_epoch = time.time() - start_time
            # remaining = total duration projected from the current pace
            #             minus the time already spent
            print('geschätzte verbleibende Dauer nach', (passed_sec_epoch/3600), ':', ((((passed_sec_epoch/epoch)*n_epochs)/3600)-(passed_sec_epoch/3600)), '\n')

    passed_sec = time.time() - start_time
    passed_min = passed_sec/60
    passed_hrs = passed_min/60
    print("--- %s seconds ---" % (passed_sec))
    print("--- %s minutes ---" % (passed_min))
    print("--- %s hours ---" % (passed_hrs))

    return network

# Test function that applies the test set to the trained net
def test_comp(network, x_data_comp, y_data_comp):
    test_loss_comp = []
    test_losses_comp = []
    
    device =  "cuda" if torch.cuda.is_available() else "cpu"
    
    test_data_comp = torch.utils.data.TensorDataset(x_data_comp.to(device), y_data_comp.to(device))
    test_loader = torch.utils.data.DataLoader(dataset=test_data_comp, batch_size=comp_samples, shuffle=False) # shuffle?
    
    network.eval()
    # Gradient calculation is disabled (as not needed)
    with torch.no_grad():

        # Loop over the batches
        for data, target in test_loader:

            target = target.to(torch.float32)
            data = data.to(torch.float32)

            # --- Prediction and calculation of evaluation metrics ---
            output = network(data)

            if used_loss == "mse":
                loss = F.mse_loss(output, target) 
            elif used_loss == "mae":
                loss = nn.L1Loss(output, target)
            elif used_loss == "kl":
                kl_loss = nn.KLDivLoss(reduction = 'batchmean')
                loss = kl_loss(output, target)

            test_loss_comp.append(loss.item())

        test_losses_comp.append(sum(test_loss_comp) / len(test_loader))
        test_loss_comp.clear()
        
        return sum(test_losses_comp)/len(test_losses_comp)
        
def plot_train_test_loss(epochen, train_losses, test_losses):
    """Plot the training and the test loss over the epochs.

    Args:
        epochen: x values, one per entry of ``train_losses``.
        train_losses: training loss per epoch.
        test_losses: test losses; the first entry is skipped so the rest
            lines up with ``train_losses``.
    """
    plt.figure()
    plt.plot(epochen, train_losses,  color='blue',  label='Train Loss')
    plt.plot(epochen, test_losses[1:],  color='red', label='Test Loss')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    # the x values are epoch indices, not a count of training examples
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
  
def plot_3d_spect_vergleich(gt_data, nn_data, winkel1=90, winkel2=0):
    """Show ground truth and network prediction as two 3D surfaces.

    Args:
        gt_data: ground-truth values, reshaped to 60 x 60.
        nn_data: network prediction, reshaped to 60 x 60.
        winkel1: elevation angle relative to the base plane.
        winkel2: rotation of the x-y axes.
    """
    figure = plt.figure(figsize=(12, 6))

    grid_axis = np.linspace(0, 60, 60)
    X, Y = np.meshgrid(grid_axis, grid_axis)

    # left panel: ground truth
    truth_surface = gt_data.reshape(60, 60)
    left_ax = figure.add_subplot(121, projection='3d')
    left_ax.view_init(winkel1, winkel2)
    left_ax.plot_surface(X, Y, truth_surface, rstride=1, cstride=1,
                         cmap='viridis', edgecolor='none')
    left_ax.set_title("realer Datenpunkt")

    # right panel: prediction, contour lines underneath the surface
    predicted_surface = nn_data.reshape(60, 60)
    right_ax = figure.add_subplot(122, projection='3d')
    right_ax.contour3D(X, Y, predicted_surface, 50, cmap='binary')
    right_ax.view_init(winkel1, winkel2)
    right_ax.plot_surface(X, Y, predicted_surface, rstride=1, cstride=1,
                          cmap='viridis', edgecolor='none')
    right_ax.set_title("geschätzte Verteilung")
    
def dice_loss(inputs, targets, smooth=1):
    """Return the smoothed Dice coefficient of two arrays.

    Both arguments are flattened first. The result is
    ``(2 * sum(inputs * targets) + smooth) /
    (sum(inputs) + sum(targets) + smooth)``; despite the function name the
    coefficient itself is returned, not ``1 - coefficient``.
    """
    flat_inputs = inputs.reshape(-1)
    flat_targets = targets.reshape(-1)

    overlap = (flat_inputs * flat_targets).sum()
    numerator = 2.*overlap + smooth
    denominator = flat_inputs.sum() + flat_targets.sum() + smooth

    return numerator/denominator
    
def _counts_to_points(counts):
    """Expand a 2-D count table into one (row, col) index pair per count.

    The pairs are ordered row by row, then column by column. Negative counts
    are treated as zero.
    """
    counts = np.clip(np.asarray(counts), 0, None)
    row_idx, col_idx = np.indices(counts.shape)
    repeats = counts.ravel()

    points = np.zeros([int(repeats.sum()), 2])
    points[:, 0] = np.repeat(row_idx.ravel(), repeats)
    points[:, 1] = np.repeat(col_idx.ravel(), repeats)
    return points


def gmm_eval(network, seg_mask, n_comp=4):
    """Fit a Gaussian mixture to the predicted in-vivo spectra and plot it.

    The network predicts a 60 x 60 spectrum for every voxel of the in-vivo
    data set. The spectra inside the selected mask are thresholded, turned
    into counts and summed; a Gaussian mixture with ``n_comp`` components is
    fitted to the summed counts. Afterwards every masked voxel gets the
    fraction of its counts assigned to each component, and per component a
    figure with the fraction map, the component density, a histogram of the
    fractions and a difference image against the best matching segmentation
    (by Dice value) is shown.

    Side effects: moves ``network`` to the CPU, switches it to eval mode,
    reads three .mat files, prints and shows figures.

    Args:
        network: trained model.
        seg_mask: 'brain_tissue', 'wm', 'gm' or 'csf'.
        n_comp: number of mixture components.

    Raises:
        ValueError: if ``seg_mask`` is not one of the supported names.
    """
    # load in-vivo data
    invivo_data = scipy.io.loadmat(r"path/to/in-vivo.mat")

    # numpy array -> tensor
    invivo_tensor = torch.from_numpy(invivo_data["m"]).to(torch.float32)

    # get prediction; eval mode and no_grad because this is pure inference
    network.to("cpu")
    network.eval()
    with torch.no_grad():
        pred_invivo = network(invivo_tensor)

    # tensor back to numpy
    pred_invivo_numpy = pred_invivo.detach().numpy()

    # reshape into correct form (Fortran order is required)
    pred_invivo = np.reshape(pred_invivo_numpy, (128, 128, 60, 60), order = "f")

    # load segmentation maps
    seg_data = scipy.io.loadmat(r"path/to/seg_data.mat")
    if seg_mask == "brain_tissue":
        mask_data = scipy.io.loadmat(r"path/to/mask.mat")
        mask = mask_data["mask"]
    elif seg_mask == "wm":
        mask = seg_data["WM"]
    elif seg_mask == "gm":
        mask = seg_data["GM"]
    elif seg_mask == "csf":
        mask = seg_data["CSF"]
    else:
        raise ValueError(f"unknown seg_mask: {seg_mask!r}")

    t1_data = np.linspace(50.0, 3500.0, num=60).reshape(1,-1)
    t2_data = np.linspace(5.0, 400.0, num=60).reshape(1,-1)

    # 'YlGnBu' is the valid matplotlib spelling (names are case sensitive)
    colormap_cycle = ['Blues', 'Reds', 'Greens', 'Oranges', 'Purples', 'BuGn', 'YlOrBr', 'PuBu', 'PuRd', 'YlGnBu', 'Greys']

    # apply mask to spectra (broadcast over the two spectrum axes)
    nx = np.shape(pred_invivo)[0]
    ny = np.shape(pred_invivo)[1]
    q = np.shape(pred_invivo)[2]
    F_gmm_ = np.multiply(pred_invivo, mask.reshape(np.shape(mask)[0], np.shape(mask)[1], 1, 1))

    # threshold F_gmm_
    F_gmm__thresh = F_gmm_.copy()
    F_gmm__thresh[F_gmm__thresh < 1e-2] = 0

    # convert spectra into counts (smallest remaining value becomes 1)
    C_thresh = F_gmm__thresh * 1. / F_gmm__thresh[F_gmm__thresh > 0].min()
    C_thresh = np.round(C_thresh)
    C_thresh = C_thresh.astype(int)

    C_thresh_mask = C_thresh[mask > 0]

    C_thresh_mask_sum = np.sum(C_thresh_mask, axis=0)
    C_thresh_mask_sum[C_thresh_mask_sum < 1000] = 0

    plt.figure(figsize=(7, 5))
    plt.contour(C_thresh_mask_sum, levels=60)
    plt.colorbar()
    plt.show()

    # Put counts into an index list: column 0 = first spectrum axis,
    # column 1 = second spectrum axis
    X = _counts_to_points(C_thresh_mask_sum)

    # run GMM
    gm = GaussianMixture(n_components=n_comp, covariance_type='full', verbose=True).fit(X)
    labels = gm.predict(X)
    labels = labels + 1  # to set only background to 0

    # Go back to physical ranges
    # NOTE(review): column 0 is filled from the index in column 1 and vice
    # versa. Left unchanged because the axis order of the data cannot be
    # verified from this file -- confirm that the swap is intended.
    X_phys = X.copy()
    for t1_ind, t1 in enumerate(t1_data[0]):
        X_phys[X[:, 1] == t1_ind, 0] = t1
    for t2_ind, t2 in enumerate(t2_data[0]):
        X_phys[X[:, 0] == t2_ind, 1] = t2

    # plot of all gaussians for average spectrum
    plt.figure(figsize=(4, 4))
    for i in range(1, n_comp + 1):
        ind = np.where(labels == i)
        ind = np.random.permutation(ind[0])
        df = pd.DataFrame(data={'T1 (ms)': X_phys[ind, 0], 'T2 (ms)': X_phys[ind, 1]})
        sns.kdeplot(df['T1 (ms)'], df['T2 (ms)'], shade=True, shade_lowest=False, alpha=0.6, antialiased=True, bw=1,
                    cmap=colormap_cycle[(i - 1) % len(colormap_cycle)])
        plt.xlim(0, np.max(t1_data))
        plt.ylim(0, np.max(t2_data))
        plt.xlabel('T1 (ms)', size=14)
        plt.ylabel('T2 (ms)', size=14)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)

    plt.show()

    # voxel wise labeling
    label_array = np.zeros([nx, ny, n_comp])
    label_array_mask = np.zeros([C_thresh_mask.shape[0], n_comp])

    count_skip = 0

    for i in range(C_thresh_mask.shape[0]):
        try:
            X_test = _counts_to_points(C_thresh_mask[i])

            labels_out = gm.predict(X_test)
            counts_total = labels_out.shape[0]
            for n in range(n_comp):
                count = np.where(labels_out == n)[0]
                if counts_total:
                    label_array_mask[i, n] = count.shape[0] / counts_total
                else:
                    label_array_mask[i, n] = 0.
        except Exception:
            # best effort: e.g. voxels without counts cannot be predicted
            count_skip = count_skip + 1

    if count_skip:
        print('skipped voxels during voxel wise labeling:', count_skip)

    label_array[mask > 0, :] = label_array_mask

    # plot all maps side by side
    seg_pred = [] # list with the segmentations of the predictions
    n_rows = n_comp
    n_cols = 4
    for n in range(0, n_comp):

        fig = plt.figure(figsize=(16, 16))

        # compartmental volume fraction map
        plt.subplot(n_rows, n_cols, 1).set_title(f'Compartment Nr.{n + 1}')
        plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[]) # deactivate xy ticks
        plt.imshow(label_array[:, :, n], vmin=0, vmax=1)
        seg_pred.append(label_array[:, :, n])

        # gaussian average spectrum; same colormap as in the overview plot
        plt.subplot(n_rows, n_cols, 2).set_title(f'Compartment Nr.{n + 1}')
        ind = np.where(labels == n + 1)
        ind = np.random.permutation(ind[0])
        df = pd.DataFrame(data={'T1 (ms)': X_phys[ind, 0], 'T2 (ms)': X_phys[ind, 1]})
        sns.kdeplot(df['T1 (ms)'], df['T2 (ms)'], shade=True, shade_lowest=False, alpha=0.8, antialiased=True, bw=1,
                    cmap=colormap_cycle[n % len(colormap_cycle)])
        plt.xlim(0, np.max(t1_data))
        plt.ylim(0, np.max(t2_data))
        plt.xlabel('T1 (ms)', size=14)
        plt.ylabel('T2 (ms)', size=14)

        # histogram of the fractions above the threshold
        plt.subplot(n_rows, n_cols, 3).set_title(f'Compartment Nr.{n + 1}')
        # drop all elements below the threshold
        hist_list = [value for value in label_array[:, :, n].reshape(-1) if value >= 0.1]
        plt.hist(hist_list, density=True, bins = 100)
        mn, mx = plt.xlim()
        plt.xlim(mn, mx)
        kde_xs = np.linspace(mn, mx, 300)
        kde = st.gaussian_kde(hist_list)
        plt.plot(kde_xs, kde.pdf(kde_xs), label="PDF_gmm_")
        plt.ylabel('Vorkommen', size=14)
        plt.xlabel('Wahrscheinlichkeit', size=14)

        # find out which compartment it is by getting the dice value of each
        seg_wert_GM = dice_loss(seg_pred[n], seg_data['GM'], smooth=1)
        seg_wert_WM = dice_loss(seg_pred[n], seg_data['WM'], smooth=1)
        seg_wert_CSF = dice_loss(seg_pred[n], seg_data['CSF'], smooth=1)

        print('der dice wert von gm, wm, csf: ', seg_wert_GM, seg_wert_WM, seg_wert_CSF)

        plt.subplot(n_rows, n_cols, 4).set_title(f'Compartment Nr.{n + 1}')

        # difference to the best matching segmentation; nothing is drawn
        # when the largest dice value is not unique
        if seg_wert_GM > seg_wert_WM and seg_wert_GM > seg_wert_CSF:
            plt.imshow(seg_data['GM'] - seg_pred[n], vmin=0, vmax=1)
        elif seg_wert_WM > seg_wert_GM and seg_wert_WM > seg_wert_CSF:
            plt.imshow(seg_data['WM'] - seg_pred[n], vmin=0, vmax=1)
        elif seg_wert_CSF > seg_wert_WM and seg_wert_CSF > seg_wert_GM:
            plt.imshow(seg_data['CSF'] - seg_pred[n], vmin=0, vmax=1)

        plt.show()

    print('---------------------------------------------------------------------------')

In [None]:
# get data
(train_loader, test_loader,
 x_train, y_train,
 x_test, y_test,
 x_test_comp, y_test_comp) = load_data()

params = ['params_to_test']

# keep every trained network for the GMM analysis
networks = []
# keep every loss curve so all runs can be plotted together
train_losses_all = []
test_losses_all = []

for param in params:

    param_to_test = param

    print("Parameter: ", str(param))

    # fresh network and optimizer for this run
    network, optimizer = preprocess()

    # module-level loss buffers that train() and test() write into
    train_loss, train_losses = [], []
    test_loss, test_losses = [], []

    # do the actual training
    network = train_test(network, optimizer)

    train_losses_all.append(train_losses)
    test_losses_all.append(test_losses)

    # loss evaluation: plot both curves
    epochen = list(range(len(train_losses)))
    plot_train_test_loss(epochen, train_losses, test_losses)

    # losses on the four separately held-out comparison sets
    for comp_idx in range(4):
        comp_loss = test_comp(network, x_test_comp[comp_idx], y_test_comp[comp_idx])
        print('loss bei ', str(comp_idx + 1), 'compartments: ', comp_loss)

    # plot spectra prediction for the first two training and test samples
    if use_plot == True:
        network = network.cpu()
        plt.rcParams["figure.figsize"] = (10,6)
        for headline, inputs, targets in (("train", x_train, y_train),
                                          ("test", x_test, y_test)):
            print(headline)
            for sample_idx in range(2):
                prediction = predict(network, inputs[sample_idx]).detach().numpy()
                plot_3d_spect_vergleich(targets[sample_idx].cpu(), prediction)

    networks.append(network)

    # save model
    model_path = "model_trained_with_" + str(param)
    torch.save(network, model_path)

    # clear cache for next cycle
    torch.cuda.empty_cache()

In [None]:
# plot train loss of all tests
fig = plt.figure()
for run_idx, run_losses in enumerate(train_losses_all):
    plt.plot(list(range(len(run_losses))), run_losses, label='parameter: '+str(params[run_idx]))

train_legend = ['parameter '+str(params[run_idx]) for run_idx in range(len(train_losses_all))]
plt.legend(train_legend, loc='upper right')

# plot test loss of all tests
fig = plt.figure()
for run_idx, run_losses in enumerate(test_losses_all):
    plt.plot(list(range(len(run_losses))), run_losses, label='parameter: '+str(params[run_idx]))

test_legend = ['parameter '+str(params[run_idx]) for run_idx in range(len(test_losses_all))]
plt.legend(test_legend, loc='upper right')

In [None]:
# gmm eval
# headline and mask name of the per-tissue analyses, in the order they run
tissue_runs = (
    ('White Matter Analysis for Parameter', 'wm'),
    ('Gray Matter Analysis for Parameter', 'gm'),
    ('CSF Analysis for Parameter', 'csf'),
)

for n in range(len(params)):
    print(params[n])
    gmm_eval(networks[n], 'brain_tissue', n_comp=4)

    print('einzelne GMM Analysen')

    for headline, tissue in tissue_runs:
        print(headline, params[n])
        gmm_eval(networks[n], tissue, n_comp=3)