*Multi-task learning in Self eXplainable Deep Neural Networks*
==============================================================

***Master's thesis - code***

*Part 2 - Preparatory Code*

**Author:** *Adrian Domagała*

# Preparatory Code: Defining Models, Essential Classes, and Functions

This section contains the essential code used throughout the entire study. It includes definitions of model performance metrics, model definitions, functions used for model optimization, functions utilized in Single-Task Learning (STL) and Multi-Task Learning (MTL) approaches, and functions for visualizing the obtained results.

## Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os
import copy
import dill
import itertools
import openml

import torch 
import torch.nn as nn
from torch.nn import MSELoss, BCELoss
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

import sklearn
import sklearn.preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import logistic

import optuna
import dataclasses
import joblib
from functools import partial
from typing import Literal, get_args, Dict

import lime
import lime.lime_tabular
import lime.lime_base

## Defining Configurations


In order to standardize and simplify the definition of functions, I introduced the storage of variables in the form of dataclasses.

In [None]:
@dataclasses.dataclass
class Folders:
    """Directory prefixes used when composing the paths of this study."""
    # Prefix handed to `Models` by `Config`; every model file name starts with it.
    base: str
    # Not read anywhere in the visible part of this file.
    mtl: str = 'MTL/'
    # Prepended by `Config` to the data and targets file names.
    dataset: str = 'Datasets/'
    # Prepended by `Config` to the study file name.
    study: str = 'Studies/'

@dataclasses.dataclass
class Files:
    """File names (without folder) that `Config` joins with `Folders` prefixes."""
    # Joined with `Folders.dataset` to form `Paths.data`.
    data: str
    # Joined with `Folders.dataset` to form `Paths.targets`.
    targets: str
    # Joined with `Folders.study` to form `Paths.study`.
    study: str

@dataclasses.dataclass
class Paths:
    """Full paths built by `Config` as folder prefix + file name."""
    # `Folders.dataset + Files.data`
    data: str
    # `Folders.dataset + Files.targets`
    targets: str
    # `Folders.study + Files.study`; read by `try_load_study` via `joblib.load`.
    study: str

@dataclasses.dataclass
class Training:
    """Hyper-parameters shared by the training loops in this file.

    NOTE: `early_stopping` and `optimizer` carry no type annotation, so the
    dataclass machinery does not register them as fields: they are plain class
    attributes, are not constructor parameters and do not appear in the
    generated ``__repr__``/``__eq__``.
    """
    # Batch size given to every training `DataLoader`.
    batch_size: int = 64
    # Upper bound on the number of training epochs.
    num_epochs: int = 50
    # Number of consecutive epochs without improvement that triggers early stopping.
    patience: int = 10
    # Class attribute (not a field); not read by the code visible in this file.
    early_stopping = True
    # Learning rate passed to the optimizer constructor.
    lr: float = 0.001
    # Class attribute (not a field); called as `optimizer(model.parameters(), lr=lr)`.
    optimizer = torch.optim.Adam
    # Progress is printed on every epoch whose index is a multiple of this value.
    epoch_log: int = 5

@dataclasses.dataclass
class Tuning():
    """Settings for the hyper-parameter search.

    NOTE: `sampler` carries no type annotation, so it is a plain class
    attribute rather than a dataclass field.
    """
    # Key under which the number of hidden layers is stored in the parameter dict
    # (see `get_hidden_sizes`).
    name_layers: str = 'num_layers'
    # Presumably the search bounds for the number of hidden layers; the search
    # code is not part of the visible file - TODO confirm.
    max_layers: int = 5
    min_layers: int = 1
    # Key prefix for the per-layer neuron counts; `get_hidden_sizes` appends the layer index.
    name_neurons: str = 'num_neurons'
    # Default to None; presumably filled in per dataset before tuning - TODO confirm.
    max_neurons: int = None
    min_neurons: int = None
    # Class attribute (not a field) holding the Optuna sampler class.
    sampler = optuna.samplers.TPESampler
    # Presumably the number of Optuna trials - not read in the visible code.
    num_trials: int = 25

@dataclasses.dataclass
class Models():
    """File-name prefixes (and extension) of the saved model checkpoints.

    Every prefix is ``base_dir`` concatenated verbatim with a fixed stem;
    callers append a run index (and, for MTL, the alpha value) plus ``ext``.

    The attributes are declared as real dataclass fields so that the generated
    ``__repr__`` and ``__eq__`` take them into account. With a hand-written
    ``__init__`` and no declared fields, any two instances compared equal and
    the repr was always ``Models()``.

    Args:
        base_dir: String prepended to every prefix; no separator is inserted.
    """
    # Init-only argument: consumed by __post_init__, not stored on the instance.
    base_dir: dataclasses.InitVar[str]
    # Optimised MLP checkpoints (one per run).
    mlp_opt: str = dataclasses.field(init=False)
    # Stand-alone linear/logistic regression checkpoints.
    reg: str = dataclasses.field(init=False)
    # Surrogate regression trained on MLP predictions (STL approach).
    stl_reg: str = dataclasses.field(init=False)
    # Multi-task models; the alpha value and run index are appended by callers.
    mtl: str = dataclasses.field(init=False)
    # Extension appended to every checkpoint file name.
    ext: str = dataclasses.field(init=False)

    def __post_init__(self, base_dir):
        self.mlp_opt = base_dir + "mlp_opt_"
        self.reg = base_dir + 'reg_'
        self.stl_reg = base_dir + 'stl_reg_'
        self.mtl = base_dir + 'mtl_alpha_'
        self.ext = '.pt'

@dataclasses.dataclass
class Result():
    """Pair of test scores stored for one multi-task model (one alpha, one run)."""
    # Score of the MLP head: accuracy for classification, MSE for regression.
    metrics: float
    # Global fidelity between the surrogate and the MLP predictions.
    fid: float

@dataclasses.dataclass
class Results():
    """Containers for the test metrics collected by the experiment functions."""
    # Stand-alone linear/logistic regression: run index -> test metric.
    reg: Dict[int, float] = dataclasses.field(default_factory=dict)
    # Stand-alone MLP: run index -> test metric.
    stl_mlp: Dict[int, float] = dataclasses.field(default_factory=dict)
    # Surrogate regression fitted to MLP predictions: run index -> global fidelity.
    stl_reg: Dict[int, float] = dataclasses.field(default_factory=dict)
    # Multi-task models: run index -> {str(alpha) -> Result}.
    mtl: Dict[int, dict] = dataclasses.field(default_factory=dict)
    # Not written by the code visible in this file.
    stl_gnf: Dict[int, float] = dataclasses.field(default_factory=dict)
    # Not written by the code visible in this file.
    mtl_gnf: Dict[int, Dict[str, float]] = dataclasses.field(default_factory=dict)

@dataclasses.dataclass
class Features():
    """Placeholder for feature metadata; every attribute starts as ``None``.

    NOTE: none of the attributes is annotated, so the dataclass decorator
    registers no fields. They are plain class attributes: the generated
    constructor takes no arguments, and the generated ``__repr__``/``__eq__``
    ignore them. The values are presumably assigned per dataset elsewhere -
    nothing in the visible file writes them.
    """
    values_names = None
    names = None
    categorical = None
    numerical = None
    categorical_indices = None
    numerical_indices = None
    dummy = None

# eq=False: the class declares no dataclass fields (everything is set in the
# hand-written __init__), so the auto-generated __eq__ would report every pair
# of Config objects as equal and would also make instances unhashable.
@dataclasses.dataclass(eq=False)
class Config:
    """Bundle of all settings and mutable state for one experiment.

    Args:
        folders: Directory prefixes.
        files: File names joined with the directory prefixes into `paths`.
        training: Training hyper-parameters.
        tuning: Hyper-parameter search settings.
    """
    def __init__(self, folders: Folders, files: Files, training: Training, tuning: Tuning):
        # Use the GPU when one is available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.training = training
        self.tuning = tuning
        self.folders = folders
        self.files = files
        # Checkpoint name prefixes, all rooted at `folders.base`.
        self.models = Models(self.folders.base)
        self.paths = Paths(
            data=self.folders.dataset+self.files.data,
            targets=self.folders.dataset+self.files.targets,
            study=self.folders.study+self.files.study
        )
        # Number of input features; must be set before any model is built.
        self.input_size = None
        self.results = Results()
        # Default alpha grid for the multi-task experiments.
        self.alpha_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        self.features = Features()
        # Parameter dict read by `get_hidden_sizes`; must be set before models are built.
        self.best_parameters = None
        # Default number of repetitions of each experiment.
        self.num = 5
        self.metrics_label = ''

# Allowed optimisation directions for early stopping: 'max' when a larger
# evaluation value is better, 'min' when a smaller one is better.
DIRECTION = Literal['max', 'min']

## Metrics

### Global Fidelity

In [None]:
class GlobalFidelity(nn.Module):
    """Global fidelity: mean squared difference between two prediction tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """Return the mean of the element-wise squared differences."""
        residual = predictions - targets
        return residual.pow(2).mean()

### Logarithm of the hyperbolic cosine 

In [None]:
class LogHCos(nn.Module):
    """Log-cosh loss: the mean of ``log(cosh(predictions - targets))``."""

    def __init__(self):
        super(LogHCos, self).__init__()

    def forward(self, predictions, targets):
        """Return the mean log-cosh of the element-wise errors.

        The loss is evaluated through the identity
        ``log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2)``.
        Calling ``torch.cosh`` directly overflows to ``inf`` once ``|x|``
        reaches roughly 89 in float32, which would turn the whole loss (and
        its gradients) into ``inf``/``nan``. In the rewritten form the
        argument of ``exp`` is never positive, so nothing can overflow.
        """
        abs_error = torch.abs(predictions - targets)
        log_cosh = abs_error + torch.log1p(torch.exp(-2.0 * abs_error)) - float(np.log(2.0))
        loss = torch.mean(log_cosh)
        return loss

### Accuracy score

In [None]:
class AccuracyScore(nn.Module):
    """Fraction of predictions that are exactly equal to their targets."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """Return (number of equal elements) / (size of the first dimension)."""
        matches = predictions.eq(targets)
        return torch.count_nonzero(matches) / predictions.size(0)

## Glass-box Models

### Linear Regression 

In [None]:
class LinearRegression(nn.Module):
    """Glass-box regressor: a single affine layer with one output."""

    def __init__(self, input_size):
        super().__init__()
        # Attribute name is part of the checkpoint keys ("linear.weight", "linear.bias").
        self.linear = nn.Linear(in_features=input_size, out_features=1)

    def forward(self, x):
        """Return the affine projection of `x`."""
        return self.linear(x)

### Logistic Regression

In [None]:
class LogisticRegression(nn.Module):
    """Glass-box classifier: one affine layer followed by a sigmoid."""

    def __init__(self, input_size):
        super().__init__()
        # Attribute names are part of the checkpoint keys; keep them stable.
        self.linear = nn.Linear(in_features=input_size, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the affine projection of `x`."""
        logits = self.linear(x)
        return self.sigmoid(logits)

## Black-box Models

The models were defined in a way that allows for their subsequent optimization with respect to parameters such as the number of hidden layers and the number of neurons in each layer.

### Multilayer Perceptron for Classification

In [None]:
class MLP_cls(nn.Module):
    """Fully connected classifier: ReLU hidden layers and a sigmoid output."""

    def __init__(self, input_size, hidden_sizes, output_size=1):
        super().__init__()
        # Consecutive pairs of widths give (in_features, out_features) per hidden layer.
        widths = [input_size, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # The output layer is fed by the last hidden layer (or by the input when there is none).
        self.out_layer = nn.Linear(widths[-1], output_size)

    def forward(self, x):
        """Apply every hidden layer with ReLU, then the sigmoid output layer."""
        hidden = x
        for dense in self.hidden_layers:
            hidden = self.relu(dense(hidden))
        return self.sigmoid(self.out_layer(hidden))

    @staticmethod
    def get_hidden_sizes(best_params, config: Config):
        """Read the per-layer neuron counts out of a parameter dict.

        The number of layers is stored under `config.tuning.name_layers`; the
        width of layer ``i`` under `config.tuning.name_neurons` followed by ``i``.
        """
        layer_count = best_params[f'{config.tuning.name_layers}']
        neuron_key = config.tuning.name_neurons
        return [best_params[f'{neuron_key}{idx}'] for idx in range(layer_count)]

### Multilayer Perceptron for Regression

In [None]:
class MLP_reg(nn.Module):
    """Fully connected regressor: ReLU hidden layers and a linear output."""

    def __init__(self, input_size, hidden_sizes, output_size=1):
        super().__init__()
        # Consecutive pairs of widths give (in_features, out_features) per hidden layer.
        widths = [input_size, *hidden_sizes]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )
        self.relu = nn.ReLU()
        # The output layer is fed by the last hidden layer (or by the input when there is none).
        self.out_layer = nn.Linear(widths[-1], output_size)

    def forward(self, x):
        """Apply every hidden layer with ReLU, then the linear output layer."""
        hidden = x
        for dense in self.hidden_layers:
            hidden = self.relu(dense(hidden))
        return self.out_layer(hidden)

    @staticmethod
    def get_hidden_sizes(best_params, config: Config):
        """Read the per-layer neuron counts out of a parameter dict.

        The number of layers is stored under `config.tuning.name_layers`; the
        width of layer ``i`` under `config.tuning.name_neurons` followed by ``i``.
        """
        layer_count = best_params[f'{config.tuning.name_layers}']
        neuron_key = config.tuning.name_neurons
        return [best_params[f'{neuron_key}{idx}'] for idx in range(layer_count)]

## Single-Task Learning Approach

### Training Function

In [None]:
def train_model(dataloader, model, optimizer, criterion, device, **kwargs):
    """Run one pass over `dataloader`, updating `model` in place.

    Each batch's target is viewed as a column vector (-1, 1) before the loss
    is computed. Extra keyword arguments are accepted and ignored.
    """
    model.to(device)
    model.train()
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.view(-1, 1).to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

### Testing Functions

In [None]:
def eval_reg(data, target, model, criterion, device, **kwargs):
    """Evaluate `model` on `data` and return `criterion(output, target)`.

    The model is moved to `device`, put in eval mode and run without gradient
    tracking. Extra keyword arguments are accepted and ignored.

    Returns:
        Whatever `criterion` returns (a tensor for the torch losses used here).
    """
    model = model.to(device)
    model.eval()
    data, target = data.to(device), target.to(device)
    with torch.no_grad():
        output = model(data)
        # The training step reshapes targets to a column vector, so do the
        # equivalent here: a 1-D target of length N against an (N, 1) output
        # would otherwise broadcast to (N, N) and silently give a wrong score.
        if target.shape != output.shape and target.numel() == output.numel():
            target = target.reshape(output.shape)
        loss = criterion(output, target)
        return loss

The evaluation function for classification tasks includes rounding the output to ensure the values are 0 or 1. 

In [None]:
def eval_cls(data, target, model, criterion, device, **kwargs):
    """Evaluate a binary classifier and return `criterion(labels, target)`.

    The model output is thresholded at 0.5 into hard 0/1 labels before the
    criterion is applied. The model is moved to `device`, put in eval mode and
    run without gradient tracking. Extra keyword arguments are ignored.
    """
    model = model.to(device)
    model.eval()
    data, target = data.to(device), target.to(device)

    with torch.no_grad():
        output = model(data)
        output = (output >= 0.5).float()
        # The training step reshapes targets to a column vector, so do the
        # equivalent here: a 1-D target of length N against an (N, 1) output
        # would otherwise broadcast to (N, N) and silently give a wrong score.
        if target.shape != output.shape and target.numel() == output.numel():
            target = target.reshape(output.shape)
        loss = criterion(output, target)
        return loss

### Training with early stopping criterion

Function to train the model with a specific early stopping criterion. The function evaluates the model's results on the validation set to halt training at the optimal moment, achieving the best possible results and simultaneously preventing overfitting.

In [None]:
def __get_worst_value(direction):
    """Return the starting "best so far" score for the given direction.

    The value is the worst possible one, so that the first finite evaluation
    result always counts as an improvement: ``-inf`` for 'max' and ``+inf``
    for 'min'. A start value of 0 for 'max' meant that a metric of exactly 0
    (or any negative score) was never recorded as a best result.

    Raises:
        ValueError: If `direction` is neither 'max' nor 'min'.
    """
    if direction == 'max':
        return torch.tensor(float('-inf'))
    elif direction == 'min':
        return torch.tensor(float('inf'))
    raise ValueError(f"Unknown direction: {direction!r}")

def __get_compare_func(direction):
    """Return ``is_better(candidate, best)`` for the given direction.

    Plain comparison functions are returned instead of the unbound
    ``float.__gt__``/``float.__lt__``: those raise ``TypeError`` when the first
    operand is not a float (e.g. an integer-valued metric) and can return
    ``NotImplemented`` instead of a boolean.

    Raises:
        ValueError: If `direction` is neither 'max' nor 'min'.
    """
    if direction == 'max':
        return lambda candidate, best: candidate > best
    elif direction == 'min':
        return lambda candidate, best: candidate < best
    raise ValueError(f"Unknown direction: {direction!r}")
    

def train_with_early_stopping(data_train, target_train, data_eval, target_eval, model, criterion_train, 
               criterion_eval, eval_func, direction: DIRECTION, config: Config, 
               num_epoch=None, patience=None, **kwargs):
    """Train `model`, keeping the weights that score best on the evaluation set.

    After every epoch the model is scored with `eval_func`/`criterion_eval`.
    Training stops once `patience` consecutive epochs bring no improvement, or
    after `num_epoch` epochs. In both cases the best weights seen are loaded
    back into `model`, so the returned score always describes the model the
    caller ends up with (the weights of the last epoch used to be kept when
    the epoch limit was reached).

    Args:
        data_train, target_train: Training tensors.
        data_eval, target_eval: Evaluation tensors used for early stopping.
        model: Module trained in place.
        criterion_train: Loss minimised during training.
        criterion_eval: Metric passed to `eval_func`.
        eval_func: Callable taking data, target, model, criterion, device.
        direction: 'max' if a larger evaluation value is better, else 'min'.
        config: Source of batch size, optimizer, learning rate, device and defaults.
        num_epoch: Epoch limit; defaults to `config.training.num_epochs`.
        patience: Early-stopping patience; defaults to `config.training.patience`.

    Returns:
        The best evaluation result (tensor), or the initial worst value when
        no epoch produced an improvement.
    """
    assert direction in get_args(DIRECTION), f"Direction should be one of: {get_args(DIRECTION)}, currently is {direction}"
    if num_epoch is None:
        num_epoch = config.training.num_epochs
    if patience is None:
        patience = config.training.patience

    train_dataset = TensorDataset(data_train, target_train)
    train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size, shuffle=True)
    optimizer = config.training.optimizer(model.parameters(), lr=config.training.lr)

    best_result = __get_worst_value(direction)
    compare = __get_compare_func(direction)
    epochs_with_no_improvement = 0

    best_model = None

    for epoch in range(num_epoch):
        train_model(
            dataloader=train_dataloader, 
            model=model, 
            optimizer=optimizer, 
            criterion=criterion_train,
            device=config.device
        )

        result = eval_func(
            data=data_eval, 
            target=target_eval, 
            model=model, 
            criterion=criterion_eval, 
            device=config.device
        )
        
        if compare(result.item(), best_result.item()):
            best_result = result
            # Deep copy: state_dict() returns references to the live parameters.
            best_model = copy.deepcopy(model.state_dict())
            epochs_with_no_improvement = 0
        else:
            epochs_with_no_improvement += 1
        
        if epoch % config.training.epoch_log == 0:
            print(f"Epoch: {epoch}, result ({criterion_eval._get_name()}): {result}")

        if epochs_with_no_improvement >= patience:
            print(f"Early stopping on epoch: {epoch}, best results from epoch: {epoch-patience} result ({criterion_eval._get_name()}): {best_result}")
            break
    else:
        print(f"Training ended on maximum epoch: {num_epoch}")

    # Restore the best weights on both exit paths. best_model is still None
    # when no epoch ever improved (e.g. the metric was NaN throughout); in
    # that case the current weights are kept instead of crashing.
    if best_model is not None:
        model.load_state_dict(best_model)
        
    return best_result

### Utility Functions

Function for loading, or training and saving models.

In [None]:
def load_or_train(func, model_path, model, overwrite=False, **kwargs):
    """Load `model` from `model_path`, or train it with `func` and save it.

    Args:
        func: Training callable, invoked as ``func(model=model, **kwargs)``.
        model_path: Checkpoint location (a ``state_dict`` saved with torch.save).
        model: Module that is loaded into, or trained, in place.
        overwrite: When true, skip loading and always train and save.
        **kwargs: Forwarded to `func`.
    """
    if not overwrite:
        try:
            # Deserialize onto the CPU so that a checkpoint written on a GPU
            # machine can still be read without one; load_state_dict copies
            # the values into the model's own parameters, wherever they live.
            state = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state)
        except Exception:
            # Missing, unreadable or incompatible checkpoint: fall back to
            # training. KeyboardInterrupt/SystemExit are deliberately not
            # caught, so interrupting a load no longer starts a training run.
            print(f'Model {model_path} cannot be loaded. Training has started...')
        else:
            print(f'Model {model_path} loaded successfully')
            return

    func(model=model, **kwargs)
    torch.save(model.state_dict(), model_path)

### Functions for Conducting the Experiments

Functions to perform the training and testing process a specified number of times (default is 5). The collected results are stored in the configuration. Specifically:

- For regression:

    - Multilayer Perceptron (MLP)
    - Linear Regression
    - Single Task Learning (MLP and Linear Regression)

- For classification:
    - MLP
    - Logistic Regression
    - Single Task Learning (MLP and Logistic Regression)

For Regression

In [None]:
def train_and_test_mlp_reg(X_train, y_train, X_eval, y_eval, X_test, y_test, config: Config, num=None):
    """Train (or load) the regression MLP `num` times and store each test MSE.

    Results go to `config.results.stl_mlp`, keyed by run index. `num` defaults
    to `config.num`.
    """
    n_runs = config.num if num is None else num

    for run in range(n_runs):
        network = MLP_reg(
            input_size=config.input_size,
            hidden_sizes=MLP_reg.get_hidden_sizes(config.best_parameters, config),
            output_size=1,
        )
        checkpoint = f'{config.models.mlp_opt}{run}{config.models.ext}'

        # Trained with log-cosh, early-stopped on validation MSE (lower is better).
        fit_options = dict(
            data_train=X_train,
            target_train=y_train,
            data_eval=X_eval,
            target_eval=y_eval,
            criterion_train=LogHCos(),
            criterion_eval=MSELoss(),
            eval_func=eval_reg,
            direction='min',
            config=config,
        )
        load_or_train(func=train_with_early_stopping, model_path=checkpoint, model=network, **fit_options)

        test_mse = eval_reg(data=X_test, target=y_test, model=network, criterion=MSELoss(), device=config.device)
        config.results.stl_mlp[run] = test_mse.item()

In [None]:
def train_and_test_lin_reg(X_train, y_train, X_eval, y_eval, X_test, y_test, config: Config, num=None):
    """Train (or load) the linear regression `num` times and store each test MSE.

    Results go to `config.results.reg`, keyed by run index. `num` defaults to
    `config.num`.
    """
    n_runs = config.num if num is None else num

    for run in range(n_runs):
        regressor = LinearRegression(config.input_size)
        checkpoint = f'{config.models.reg}{run}{config.models.ext}'

        # Trained with log-cosh, early-stopped on validation MSE (lower is better).
        fit_options = dict(
            data_train=X_train,
            target_train=y_train,
            data_eval=X_eval,
            target_eval=y_eval,
            criterion_train=LogHCos(),
            criterion_eval=MSELoss(),
            eval_func=eval_reg,
            direction='min',
            config=config,
        )
        load_or_train(func=train_with_early_stopping, model_path=checkpoint, model=regressor, **fit_options)

        test_mse = eval_reg(data=X_test, target=y_test, model=regressor, criterion=MSELoss(), device=config.device)
        config.results.reg[run] = test_mse.item()

In [None]:
def train_and_test_stl_lin_reg(X_train, X_eval, X_test, config: Config, num=None):
    """Fit a linear surrogate to each saved regression MLP and store its fidelity.

    For every run the MLP checkpoint of that run is loaded, its predictions on
    the three splits become the surrogate's targets, and the surrogate's global
    fidelity on the test split is written to `config.results.stl_reg[run]`.
    The MLP checkpoints must already exist. `num` defaults to `config.num`.
    """
    if num is None:
        num = config.num
    for i in range(num):
        mlp = MLP_reg(input_size=config.input_size, hidden_sizes=MLP_reg.get_hidden_sizes(config.best_parameters, config)).to(config.device)
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        mlp.load_state_dict(torch.load(config.models.mlp_opt + str(i) + config.models.ext, map_location=config.device))
        mlp.eval()
        reg = LinearRegression(input_size=config.input_size).to(config.device)

        # The MLP is only a fixed teacher here: no gradients are needed.
        with torch.no_grad():
            y_pred_train = mlp(X_train.to(config.device))
            y_pred_eval = mlp(X_eval.to(config.device))
            y_pred_test = mlp(X_test.to(config.device))

        load_or_train(
            func=train_with_early_stopping, 
            model_path=config.models.stl_reg + str(i) + config.models.ext,
            model=reg,
            data_train=X_train,
            target_train=y_pred_train,
            data_eval=X_eval,
            target_eval=y_pred_eval,
            config=config,
            criterion_train=GlobalFidelity(),
            criterion_eval=GlobalFidelity(),
            eval_func=eval_reg,
            direction='min'
        )

        config.results.stl_reg[i] = eval_reg(
            data=X_test,
            target=y_pred_test,
            model=reg,
            criterion=GlobalFidelity(), 
            device=config.device
        ).item()

For Classification

In [None]:

def train_and_test_mlp_cls(X_train, y_train, X_eval, y_eval, X_test, y_test, config: Config, num=None):
    """Train (or load) the classification MLP `num` times and store each test accuracy.

    Results go to `config.results.stl_mlp`, keyed by run index. `num` defaults
    to `config.num`.
    """
    n_runs = config.num if num is None else num

    for run in range(n_runs):
        network = MLP_cls(
            input_size=config.input_size,
            hidden_sizes=MLP_cls.get_hidden_sizes(config.best_parameters, config),
            output_size=1,
        )
        checkpoint = f'{config.models.mlp_opt}{run}{config.models.ext}'

        # Trained with binary cross-entropy, early-stopped on validation accuracy.
        fit_options = dict(
            data_train=X_train,
            target_train=y_train,
            data_eval=X_eval,
            target_eval=y_eval,
            criterion_train=BCELoss(),
            criterion_eval=AccuracyScore(),
            eval_func=eval_cls,
            direction='max',
            config=config,
        )
        load_or_train(func=train_with_early_stopping, model_path=checkpoint, model=network, **fit_options)

        test_accuracy = eval_cls(data=X_test, target=y_test, model=network, criterion=AccuracyScore(), device=config.device)
        config.results.stl_mlp[run] = test_accuracy.item()

In [None]:

def train_and_test_lin_cls(X_train, y_train, X_eval, y_eval, X_test, y_test, config: Config, num=None):
    """Train (or load) the logistic regression `num` times and store each test accuracy.

    Results go to `config.results.reg`, keyed by run index. `num` defaults to
    `config.num`.
    """
    n_runs = config.num if num is None else num

    for run in range(n_runs):
        classifier = LogisticRegression(config.input_size)
        checkpoint = f'{config.models.reg}{run}{config.models.ext}'

        # Trained with binary cross-entropy, early-stopped on validation accuracy.
        fit_options = dict(
            data_train=X_train,
            target_train=y_train,
            data_eval=X_eval,
            target_eval=y_eval,
            criterion_train=BCELoss(),
            criterion_eval=AccuracyScore(),
            eval_func=eval_cls,
            direction='max',
            config=config,
        )
        load_or_train(func=train_with_early_stopping, model_path=checkpoint, model=classifier, **fit_options)

        test_accuracy = eval_cls(data=X_test, target=y_test, model=classifier, criterion=AccuracyScore(), device=config.device)
        config.results.reg[run] = test_accuracy.item()

In [None]:
def train_and_test_stl_lin_cls(X_train, X_eval, X_test, config: Config, num=None):
    """Fit a logistic surrogate to each saved classification MLP and store its fidelity.

    For every run the MLP checkpoint of that run is loaded, its predictions on
    the three splits (thresholded at 0.5) become the surrogate's targets, and
    the surrogate's global fidelity on the test split is written to
    `config.results.stl_reg[run]`. The MLP checkpoints must already exist.
    `num` defaults to `config.num`.
    """
    if num is None:
        num = config.num
    for i in range(num):
        mlp = MLP_cls(input_size=config.input_size, hidden_sizes=MLP_cls.get_hidden_sizes(config.best_parameters, config)).to(config.device)
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        mlp.load_state_dict(torch.load(config.models.mlp_opt + str(i) + config.models.ext, map_location=config.device))
        mlp.eval()
        reg = LogisticRegression(input_size=config.input_size).to(config.device)

        # The MLP is only a fixed teacher here: hard 0/1 labels, no gradients.
        with torch.no_grad():
            y_pred_train = (mlp(X_train.to(config.device)) >= 0.5).float()
            y_pred_eval = (mlp(X_eval.to(config.device)) >= 0.5).float()
            y_pred_test = (mlp(X_test.to(config.device)) >= 0.5).float()

        load_or_train(
            func=train_with_early_stopping, 
            model_path=config.models.stl_reg + str(i) + config.models.ext,
            model=reg,
            data_train=X_train,
            target_train=y_pred_train,
            data_eval=X_eval,
            target_eval=y_pred_eval,
            config=config,
            criterion_train=GlobalFidelity(),
            criterion_eval=GlobalFidelity(),
            eval_func=eval_cls,
            direction='min'
        )

        config.results.stl_reg[i] = eval_cls(
            data=X_test,
            target=y_pred_test,
            model=reg,
            criterion=GlobalFidelity(), 
            device=config.device
        ).item()

## Multi-Task Learning approach

In multi-task learning, two tasks were specified: the first is a predictive task based on a black-box model (multilayer perceptron), and the second is an explainability task using a linear surrogate model. The soft parameter sharing approach was applied, characterized by a separate set of parameters (weights) for each task and the absence of shared layers.  Communication between tasks occurs through information flow mechanisms; in the discussed example, this is achieved via a shared loss function, which is a convex combination of the neural network's and the surrogate model's loss functions. (The dependencies between the loss functions are regularized by the $\alpha $ parameter.)

### Model definition

In [None]:
class MTL_mlp_linear(nn.Module):
    """Pair of independent models (MLP and linear surrogate) fed the same input."""

    def __init__(self, mlp, linear):
        super().__init__()
        # Attribute names are part of the checkpoint keys ("mlp.*", "linear.*").
        self.mlp, self.linear = mlp, linear

    def forward(self, x):
        """Return ``(mlp(x), linear(x))``."""
        return self.mlp(x), self.linear(x)

### Training and Testing functions

To evaluate the surrogate model, the global fidelity metric was used, whereas the neural network model was assessed using either mean squared error loss or accuracy score: the former for regression tasks and the latter for classification tasks. It is also worth mentioning that global fidelity for classification was calculated using rounded outputs (only values 0 or 1) for both the surrogate and the MLP models. (It might be worth further analysis to consider the fidelity results for probabilities returned directly by the sigmoid function in both models.)

In [None]:
def train_mtl_model(dataloader, model, optimizer, criterion_mlp, criterion_reg, alpha, device, **kwargs):
    """Run one pass over `dataloader` for the two-headed MTL model.

    The optimised loss is ``alpha * criterion_mlp(mlp_out, target) +
    (1 - alpha) * criterion_reg(surrogate_out, mlp_out)``. Targets are viewed
    as a column vector (-1, 1). Extra keyword arguments are ignored.
    """
    model.to(device)
    model.train()
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.view(-1, 1).to(device)

        optimizer.zero_grad()
        mlp_out, surrogate_out = model(batch_inputs)
        task_loss = criterion_mlp(mlp_out, batch_targets)
        # The surrogate is trained to imitate the MLP output, not the labels.
        fidelity_loss = criterion_reg(surrogate_out, mlp_out)
        combined = alpha * task_loss + (1 - alpha) * fidelity_loss
        combined.backward()
        optimizer.step()

In [None]:
def eval_mtl_cls(data, targets, model, accuracy, fidelity, device):
    """Score a two-headed classification model on one data split.

    Both heads' outputs are thresholded at 0.5 into hard 0/1 labels.

    Returns:
        ``(accuracy(mlp_labels, targets), fidelity(surrogate_labels, mlp_labels))``
        as Python numbers.
    """
    model = model.to(device)
    data, targets = data.to(device), targets.to(device)
    model.eval()
    with torch.no_grad():
        pred_mlp, pred_reg = model(data)
        pred_mlp = (pred_mlp >= 0.5).float()
        pred_reg = (pred_reg >= 0.5).float()
        # The training step reshapes targets to a column vector, so do the
        # equivalent here: 1-D targets of length N against (N, 1) predictions
        # would otherwise broadcast to (N, N) and silently give a wrong score.
        if targets.shape != pred_mlp.shape and targets.numel() == pred_mlp.numel():
            targets = targets.reshape(pred_mlp.shape)
    acc_score = accuracy(pred_mlp, targets)
    fid_score = fidelity(pred_reg, pred_mlp)
    return acc_score.item(), fid_score.item()

In [None]:
def eval_mtl_reg(data, targets, model, mse, fidelity, device):
    """Score a two-headed regression model on one data split.

    Returns:
        ``(mse(mlp_out, targets), fidelity(surrogate_out, mlp_out))`` as
        Python numbers.
    """
    model = model.to(device)
    data, targets = data.to(device), targets.to(device)
    model.eval()
    with torch.no_grad():
        pred_mlp, pred_reg = model(data)
        # The training step reshapes targets to a column vector, so do the
        # equivalent here: 1-D targets of length N against (N, 1) predictions
        # would otherwise broadcast to (N, N) and silently give a wrong score.
        if targets.shape != pred_mlp.shape and targets.numel() == pred_mlp.numel():
            targets = targets.reshape(pred_mlp.shape)
    mse_score = mse(pred_mlp, targets)
    fid_score = fidelity(pred_reg, pred_mlp)
    return mse_score.item(), fid_score.item()

### Training with early stopping criterion

Compared to the STL approach, deciding when to apply the early stopping criterion in MTL is somewhat more complex. This complexity arises because the model evaluation function simultaneously returns two values: the result of the neural network and the result of the surrogate model. Therefore, I applied a convex combination similar to the one in the training function, based on the same parameter alpha. As a result, I obtained a model that performs well on both tasks while maintaining the proportion of importance established during training.

In [None]:
def train_mtl_early_stopping_cls(data_train, target_train, data_eval, target_eval, model, alpha, config: Config, **kwargs):
    """Train the two-headed classification model with early stopping.

    After each epoch the model is scored on the evaluation split with
    ``alpha * accuracy - (1 - alpha) * fidelity`` (higher is better). Training
    stops after `config.training.patience` epochs without improvement or after
    `config.training.num_epochs` epochs. In both cases the best weights seen
    are loaded back into `model`, so the returned scores describe the model
    the caller ends up with (the weights of the last epoch used to be kept
    when the epoch limit was reached).

    Returns:
        ``(best_acc, best_fid)`` measured on the evaluation split.
    """
    train_dataset = TensorDataset(data_train, target_train)
    train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size, shuffle=True)
    optimizer = config.training.optimizer(model.parameters(), lr=config.training.lr)

    best_acc = 0
    best_fid = np.finfo(float).max
    best_result = np.finfo(float).min
    epochs_with_no_improvement = 0

    bce_loss = BCELoss()
    global_fidelity = GlobalFidelity()
    accuracy = AccuracyScore()

    best_model = None

    for epoch in range(config.training.num_epochs):
        train_mtl_model( 
            dataloader=train_dataloader, 
            model=model,
            optimizer=optimizer, 
            criterion_mlp=bce_loss,
            criterion_reg=global_fidelity,
            alpha=alpha,
            device=config.device
        )

        acc, fid = eval_mtl_cls(
            data=data_eval, 
            targets=target_eval, 
            model=model,
            accuracy=accuracy,
            fidelity=global_fidelity, 
            device=config.device
        )

        # Accuracy is maximised and fidelity (an error) minimised, hence the minus sign.
        current_result = alpha * acc - (1-alpha) * fid
        if current_result > best_result:
            best_result = current_result
            # Deep copy: state_dict() returns references to the live parameters.
            best_model = copy.deepcopy(model.state_dict())
            epochs_with_no_improvement = 0
            best_acc, best_fid = acc, fid
        else:
            epochs_with_no_improvement += 1
        
        if epoch % config.training.epoch_log == 0:
            print(f"Epoch: {epoch}, acc: {acc}, fid: {fid}")

        if epochs_with_no_improvement >= config.training.patience:
            print(f"Early stopping on epoch: {epoch}, best results from epoch: {epoch-config.training.patience}")
            print(f"Accuracy: {best_acc}, Fidelity: {best_fid}")
            break

    # Restore the best weights on both exit paths. best_model is still None
    # when no epoch ever improved (e.g. NaN scores throughout); the current
    # weights are then kept instead of crashing.
    if best_model is not None:
        model.load_state_dict(best_model)
        
    return best_acc, best_fid

In [None]:
def train_mtl_early_stopping_reg(data_train, target_train, data_eval, target_eval, model, alpha, config: Config, **kwargs):
    """Train the two-headed regression model with early stopping.

    After each epoch the model is scored on the evaluation split with
    ``alpha * mse + (1 - alpha) * fidelity`` (lower is better). Training stops
    after `config.training.patience` epochs without improvement or after
    `config.training.num_epochs` epochs. In both cases the best weights seen
    are loaded back into `model`, so the returned scores describe the model
    the caller ends up with (the weights of the last epoch used to be kept
    when the epoch limit was reached).

    Returns:
        ``(best_mse, best_fid)`` measured on the evaluation split.
    """
    train_dataset = TensorDataset(data_train, target_train)
    train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size, shuffle=True)
    optimizer = config.training.optimizer(model.parameters(), lr=config.training.lr)

    best_mse = np.finfo(float).max
    best_fid = np.finfo(float).max
    best_result = np.finfo(float).max
    epochs_with_no_improvement = 0

    loghcos = LogHCos()
    fidelity = GlobalFidelity()
    mse_loss = MSELoss()

    best_model = None

    for epoch in range(config.training.num_epochs):
        train_mtl_model( 
            dataloader=train_dataloader, 
            model=model,
            optimizer=optimizer, 
            criterion_mlp=loghcos,
            criterion_reg=fidelity,
            alpha=alpha,
            device=config.device
        )

        mse, fid = eval_mtl_reg(
            data=data_eval, 
            targets=target_eval, 
            model=model,
            mse=mse_loss,
            fidelity=fidelity,
            device=config.device
        )

        # Both terms are errors, so the convex combination is minimised.
        current_result = alpha * mse + (1-alpha) * fid
        if current_result < best_result:
            best_result = current_result
            # Deep copy: state_dict() returns references to the live parameters.
            best_model = copy.deepcopy(model.state_dict())
            epochs_with_no_improvement = 0
            best_mse, best_fid = mse, fid
        else:
            epochs_with_no_improvement += 1
        
        if epoch % config.training.epoch_log == 0:
            print(f"Epoch: {epoch}, mse: {mse}, fid: {fid}")

        if epochs_with_no_improvement >= config.training.patience:
            print(f"Early stopping on epoch: {epoch}, best results from epoch: {epoch-config.training.patience}")
            print(f"MSE: {best_mse}, Fidelity: {best_fid}")
            break

    # Restore the best weights on both exit paths. best_model is still None
    # when no epoch ever improved (e.g. NaN scores throughout); the current
    # weights are then kept instead of crashing.
    if best_model is not None:
        model.load_state_dict(best_model)
        
    return best_mse, best_fid

### Functions for Conducting the Experiments

Functions for training and testing for given alpha parameter

In [None]:
def train_and_test_mtl_cls(data_train, data_eval, data_test, target_train, target_eval, target_test, alpha, config: Config, path=None, save_load=False, model_params=None):
    """Build, train (or load) and test one multi-task classification model.

    Args:
        alpha: Weight of the MLP task in the combined loss.
        path: Checkpoint location, used only when `save_load` is true.
        save_load: When true, go through `load_or_train` (load if possible,
            otherwise train and save); otherwise always train.
        model_params: Parameter dict for the MLP layout; falls back to
            `config.best_parameters` when empty or None.

    Returns:
        ``(accuracy, fidelity)`` on the test split.
    """
    params = model_params if model_params else config.best_parameters

    model = MTL_mlp_linear(
        mlp=MLP_cls(input_size=config.input_size, hidden_sizes=MLP_cls.get_hidden_sizes(params, config)),
        linear=LogisticRegression(input_size=config.input_size),
    )

    # Both code paths train with exactly the same arguments.
    fit_kwargs = dict(
        data_train=data_train,
        target_train=target_train,
        data_eval=data_eval,
        target_eval=target_eval,
        model=model,
        alpha=alpha,
        config=config,
    )
    if save_load:
        load_or_train(func=train_mtl_early_stopping_cls, model_path=path, **fit_kwargs)
    else:
        train_mtl_early_stopping_cls(**fit_kwargs)

    return eval_mtl_cls(
        data=data_test,
        targets=target_test,
        model=model,
        accuracy=AccuracyScore(),
        fidelity=GlobalFidelity(),
        device=config.device,
    )


In [None]:
def train_and_test_mtl_reg(data_train, data_eval, data_test, target_train, target_eval, target_test, alpha, config: Config, path=None, save_load=False, model_params=None):
    """Build, train (or load) and test one multi-task regression model.

    Args:
        alpha: Weight of the MLP task in the combined loss.
        path: Checkpoint location, used only when `save_load` is true.
        save_load: When true, go through `load_or_train` (load if possible,
            otherwise train and save); otherwise always train.
        model_params: Parameter dict for the MLP layout; falls back to
            `config.best_parameters` when empty or None.

    Returns:
        ``(mse, fidelity)`` on the test split.
    """
    params = model_params if model_params else config.best_parameters

    model = MTL_mlp_linear(
        mlp=MLP_reg(input_size=config.input_size, hidden_sizes=MLP_reg.get_hidden_sizes(params, config)),
        linear=LinearRegression(input_size=config.input_size),
    )

    # Both code paths train with exactly the same arguments.
    fit_kwargs = dict(
        data_train=data_train,
        target_train=target_train,
        data_eval=data_eval,
        target_eval=target_eval,
        model=model,
        alpha=alpha,
        config=config,
    )
    if save_load:
        load_or_train(func=train_mtl_early_stopping_reg, model_path=path, **fit_kwargs)
    else:
        train_mtl_early_stopping_reg(**fit_kwargs)

    return eval_mtl_reg(
        data=data_test,
        targets=target_test,
        model=model,
        mse=MSELoss(),
        fidelity=GlobalFidelity(),
        device=config.device,
    )

Functions for training and testing models for list of alpha (default values from a closed interval [0, 1] with a step of 0.1). Acquired results are stored in config.  

In [None]:
def train_and_test_mtl_cls_for_alpha_list(data_train, data_eval, data_test, 
                           target_train, target_eval, target_test, 
                           num, config: Config, alpha_list=None, model_params=None):
    """Train and test one multi-task classification model per alpha value.

    Results are stored in `config.results.mtl[num]`, keyed by ``str(alpha)``;
    any previous results for run `num` are discarded first.

    Args:
        num: Run index; part of the checkpoint name and the results key.
        alpha_list: Alpha values to sweep; defaults to `config.alpha_list`
            when empty or None.
        model_params: Parameter dict for the MLP layout, forwarded to
            `train_and_test_mtl_cls` (it used to be accepted but dropped).
    """
    if not alpha_list: alpha_list = config.alpha_list
    config.results.mtl[num] = {}

    for alpha in alpha_list:
        print(f'Model for alpha: {alpha}')
        path = config.models.mtl + str(alpha) + '_' + str(num) + config.models.ext
        acc, fid = train_and_test_mtl_cls(
            data_train=data_train,
            data_eval=data_eval,
            data_test=data_test,
            target_train=target_train,
            target_eval=target_eval,
            target_test=target_test,
            alpha=alpha,
            path=path,
            save_load=True,
            config=config,
            model_params=model_params
        )
        print(f'Result on test dataset: {path}, Accuracy: {acc}, Fidelity: {fid}')
        config.results.mtl[num][str(alpha)] = Result(acc, fid)

In [None]:
def train_and_test_mtl_reg_for_alpha_list(data_train, data_eval, data_test, 
                           target_train, target_eval, target_test, 
                           num, config: Config, alpha_list=None, model_params=None):
    """Train and test one multi-task regression model per alpha value.

    Results are stored in `config.results.mtl[num]`, keyed by ``str(alpha)``;
    any previous results for run `num` are discarded first.

    Args:
        num: Run index; part of the checkpoint name and the results key.
        alpha_list: Alpha values to sweep; defaults to `config.alpha_list`
            when empty or None.
        model_params: Parameter dict for the MLP layout, forwarded to
            `train_and_test_mtl_reg` (it used to be accepted but dropped).
    """
    if not alpha_list: alpha_list = config.alpha_list
    config.results.mtl[num] = {}

    for alpha in alpha_list:
        print(f'Model for alpha: {alpha}')
        path = config.models.mtl + str(alpha) + '_' + str(num) + config.models.ext
        mse, fid = train_and_test_mtl_reg(
            data_train=data_train,
            data_eval=data_eval,
            data_test=data_test,
            target_train=target_train,
            target_eval=target_eval,
            target_test=target_test,
            alpha=alpha,
            path=path,
            save_load=True,
            config=config,
            model_params=model_params
        )
        print(f'Result on test dataset: {path}, MSE: {mse}, Fidelity: {fid}')
        config.results.mtl[num][str(alpha)] = Result(mse, fid)

Functions to perform the training and testing process a specified number of times (default is 5). The collected results are stored in the configuration.

In [None]:
def train_and_test_mtl_reg_n_times(data_train, target_train, data_eval, target_eval, data_test, target_test, config: Config, num=None):
    """Repeat the MTL regression alpha sweep ``num`` times.

    Each repetition ``i`` calls ``train_and_test_mtl_reg_for_alpha_list``
    with ``num=i``, so its results are keyed by the repetition index.

    Args:
        data_train, data_eval, data_test: Feature splits, forwarded unchanged.
        target_train, target_eval, target_test: Target splits, forwarded unchanged.
        config: Experiment configuration, forwarded to every repetition.
        num: Number of repetitions; ``config.num`` is used when ``None``.
    """
    if num is None:
        num = config.num
    for run in range(num):
        train_and_test_mtl_reg_for_alpha_list(
            data_train=data_train,
            data_eval=data_eval,
            data_test=data_test,
            target_train=target_train,
            target_eval=target_eval,
            target_test=target_test,
            config=config,
            num=run,
        )

In [None]:
def train_and_test_mtl_cls_n_times(data_train, target_train, data_eval, target_eval, data_test, target_test, config: Config, num=None):
    """Repeat the MTL classification alpha sweep ``num`` times.

    Each repetition ``i`` calls ``train_and_test_mtl_cls_for_alpha_list``
    with ``num=i``, so its results are keyed by the repetition index.

    Args:
        data_train, data_eval, data_test: Feature splits, forwarded unchanged.
        target_train, target_eval, target_test: Target splits, forwarded unchanged.
        config: Experiment configuration, forwarded to every repetition.
        num: Number of repetitions; ``config.num`` is used when ``None``.
    """
    if num is None:
        num = config.num
    for run in range(num):
        train_and_test_mtl_cls_for_alpha_list(
            data_train=data_train,
            data_eval=data_eval,
            data_test=data_test,
            target_train=target_train,
            target_eval=target_eval,
            target_test=target_test,
            config=config,
            num=run,
        )

## Optimization 

The optimization is conducted using the Optuna library. Parameters that were optimized:
- number of hidden layers (ranging from 1 to 5),
- number of neurons per layer (ranging from one-quarter to four times the number of input features).

Other parameters, such as the optimizer and the learning rate, were taken from the discussed research.

### Loading function 

Other optimization functions are defined individually for each dataset.

In [None]:
def try_load_study(func, config: Config, overwrite = False):
    """Return a study, loading it from disk when possible.

    Args:
        func: Callable invoked as ``func(config)`` to produce the study.
        config: Experiment configuration; ``config.paths.study`` is the
            file that ``joblib.load`` reads.
        overwrite: When true, skip loading and always call ``func``.

    Returns:
        The loaded study, or the result of ``func(config)`` when
        ``overwrite`` is set or loading fails.
    """
    if overwrite:
        return func(config)
    try:
        return joblib.load(config.paths.study)
    except Exception:
        # Best-effort cache: any load failure falls back to running the
        # study. Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return func(config)

## Plotting functions

Function presenting six plots illustrating the following relationships:

- The relationship between the black-box model evaluation metric and the alpha parameter value.
- The relationship between the surrogate model evaluation metric (global fidelity), and the alpha parameter value.
- The comparison of results for the STL (Single-Task Learning) and MTL (Multi-Task Learning) approaches.

In [None]:
def draw_plots_MTL_vs_STL(metrics_label, metrics_title, config: Config, org_met=None, org_fid=None, title=None):
    """Draw a 3x2 grid of plots comparing MTL results across alphas with STL.

    Scores are averaged over ``config.num`` runs for every alpha in
    ``config.alpha_list``. Panels are tagged on the figure:

    * (1), (2): mean metric against alpha, for all alphas and for the
      alphas without the first and last entry.
    * (3), (4): mean fidelity against alpha, same two ranges.
    * (5), (6): metric/fidelity scatter including the STL point, with a
      vertical line at the mean metric of the linear model. The legend of
      panel (6) is ordered by fidelity, highest first.

    Args:
        metrics_label: Axis and legend label of the black-box metric.
        metrics_title: Metric name used in the scatter panel titles.
        config: Source of ``alpha_list``, ``num`` and the stored results
            (``results.mtl``, ``results.stl_mlp``, ``results.stl_reg``,
            ``results.reg``).
        org_met: Optional reference metric values drawn in panel (2)
            against the inner alphas and labelled 'orig.'.
        org_fid: Optional reference fidelity values drawn in panel (4).
        title: Optional figure title.
    """
    COLORS = ['lightcoral', 'red', 'peru', 'moccasin', 'gold', 'greenyellow', 'lime', 'turquoise', 'skyblue', 'royalblue', 'violet', 'black']
    fig, axes = plt.subplots(3, 2, figsize=(12, 18))

    if title:
        fig.suptitle(title, fontsize=18)

    def tag_panel(ax, tag):
        # Small panel number in the top-left corner, in axes coordinates.
        ax.text(-0.1, 1.05, tag, ha='left', va='top', fontsize=12, transform=ax.transAxes)

    def plot_against_alpha(ax, scores, alphas, xlabel, heading, tag, limit_ticks=False):
        # Line plot with the score on the x axis and alpha on the y axis.
        ax.plot(scores, alphas, marker='o')
        ax.set_ylabel('alpha')
        ax.set_yticks(ticks=alphas)
        ax.set_xlabel(xlabel)
        ax.set_title(label=heading)
        if limit_ticks:
            ax.xaxis.set_major_locator(plt.MaxNLocator(7))
        tag_panel(ax, tag)

    def marker_handle(color):
        return plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=15)

    def legend_entry(label, metric, fid):
        return '{} ({:.4f} / {:.4f})'.format(label, round(float(metric), 4), round(fid, 4))

    def finish_scatter(ax, fids):
        # Axis labels plus the vertical reference line of the linear model.
        ax.set_ylabel('fidelity')
        ax.set_xlabel(metrics_label)
        ax.axvline(x = linear_met, color = 'seagreen', label = 'linear')
        ax.text(linear_met, max(fids), 'linear' + f" {metrics_label}: {linear_met:.4f}", color='seagreen', ha='right', va='top', rotation=90, style='italic')

    # Average every alpha's metric and fidelity over all runs.
    alpha_list = config.alpha_list
    metrics_score_list = []
    fid_score_list = []
    for alpha in alpha_list:
        runs = [config.results.mtl[i][str(alpha)] for i in range(config.num)]
        metrics_score_list.append(np.mean([run.metrics for run in runs]))
        fid_score_list.append(np.mean([run.fid for run in runs]))

    stl_metrics_score = np.mean(list(config.results.stl_mlp.values()))
    stl_fid_score = np.mean(list(config.results.stl_reg.values()))
    linear_met = np.mean(list(config.results.reg.values()))

    inner_alphas = alpha_list[1:-1]

    # Metric/alpha
    plot_against_alpha(axes[0,0], metrics_score_list, alpha_list, metrics_label, f'{metrics_label}/alpha[0, 1]', '(1)')
    plot_against_alpha(axes[0,1], metrics_score_list[1:-1], inner_alphas, metrics_label, f'{metrics_label}/alpha(0, 1)', '(2)', limit_ticks=True)

    if org_met:
        axes[0,1].plot(org_met, inner_alphas, marker='o', color='darkviolet')
        axes[0,1].legend(labels=('own', 'orig.'))

    # Fidelity/alpha
    plot_against_alpha(axes[1,0], fid_score_list, alpha_list, 'fidelity', 'Fidelity/alpha[0, 1]', '(3)')
    plot_against_alpha(axes[1,1], fid_score_list[1:-1], inner_alphas, 'fidelity', 'Fidelity/alpha(0, 1)', '(4)', limit_ticks=True)

    if org_fid:
        axes[1,1].plot(org_fid, inner_alphas, marker='o', color='darkviolet')
        axes[1,1].legend(labels=('own', 'orig.'))

    # Metric/Fidelity alpha [0, 1]
    metrics_with_stl = metrics_score_list + [stl_metrics_score]
    fid_with_stl = fid_score_list + [stl_fid_score]
    labels_with_stl = [f'{chr(945)}={alpha}' for alpha in alpha_list] + ['STL    ']
    legend_handles = [marker_handle(color) for color in COLORS]
    legend = [legend_entry(label, metric, fid) for label, metric, fid in zip(labels_with_stl, metrics_with_stl, fid_with_stl)]

    axes[2,0].scatter(metrics_with_stl, fid_with_stl, marker='o', s=100, edgecolors='black', c=COLORS)
    axes[2,0].set_title(f'{metrics_title}/Fidelity (with STL)')
    axes[2,0].legend(legend_handles, legend)
    finish_scatter(axes[2,0], fid_with_stl)
    tag_panel(axes[2,0], '(5)')

    # Metric/Fidelity alpha (0, 1)
    inner_metrics = [round(float(x), 4) for x in metrics_score_list[1:-1] + [stl_metrics_score]]
    inner_fids = [float(fid) for fid in fid_score_list[1:-1] + [stl_fid_score]]
    inner_labels = [f'{chr(945)}={alpha}' for alpha in inner_alphas] + ['STL    ']
    inner_colors = COLORS[1:-2] + ['black']
    # Legend rows ordered by fidelity, best first. ``reverse`` must be a
    # real boolean; the builtin ``reversed`` was passed here by mistake.
    ranked = sorted(zip(inner_labels, inner_metrics, inner_fids, inner_colors), key=lambda row: row[2], reverse=True)

    legend_handles = [marker_handle(color) for _, _, _, color in ranked]
    legend = [legend_entry(label, metric, fid) for label, metric, fid, _ in ranked]

    axes[2,1].scatter(inner_metrics, inner_fids, marker='o', s=100, edgecolors='black', c=inner_colors)
    axes[2,1].set_title(f'{metrics_title}/Fidelity (with STL)')
    axes[2,1].legend(handles=legend_handles, labels=legend, title=f'alpha ({metrics_label}/fid)', loc='center left', bbox_to_anchor=(1, 0.5))
    finish_scatter(axes[2,1], inner_fids)
    axes[2,1].xaxis.set_major_locator(plt.MaxNLocator(7))
    tag_panel(axes[2,1], '(6)')

    plt.tight_layout()
    fig.subplots_adjust(top=0.93, left=0.15)
    plt.show()

Functions to show the MTL results in the form of a table.

In [None]:
def get_results_non_lin_vs_lin(config: Config, metrics_label=None):
    """Build a one-row table comparing the linear model with the STL MLP.

    Args:
        config: Source of the stored results (``results.reg``,
            ``results.stl_mlp``), the dataset name (first component of
            ``folders.base``) and the default metric label.
        metrics_label: Text for the 'Metrics' column; ``config.metrics_label``
            is used when ``None``.

    Returns:
        A ``pandas.DataFrame`` with the columns 'Metrics', 'Dataset',
        'Linear' and 'Non-linear (MLP)', the last two holding mean scores.
    """
    if metrics_label is None:
        metrics_label = config.metrics_label
    return pd.DataFrame({
        'Metrics': [metrics_label],
        'Dataset': [config.folders.base.split('/')[0]],
        'Linear': [np.mean(list(config.results.reg.values()))],
        'Non-linear (MLP)': [np.mean(list(config.results.stl_mlp.values()))]
    })

def get_results_mtl_vs_stl(config: Config, metrics_label=None):
    """Build a two-row table comparing linear, STL and MTL results.

    The first row holds the black-box metric, the second the global
    fidelity. One 'MTL a= <alpha>' column is added per entry of
    ``config.alpha_list``, averaged over ``config.num`` runs.

    Args:
        config: Source of the stored results, the dataset name (first
            component of ``folders.base``) and the default metric label.
        metrics_label: Label of the first row; ``config.metrics_label`` is
            used when ``None``.

    Returns:
        A ``pandas.DataFrame`` with two rows. The 'Linear' fidelity cell is
        the placeholder string '-'.
    """
    if metrics_label is None:
        metrics_label = config.metrics_label
    dataset = config.folders.base.split('/')[0]
    df = pd.DataFrame({
        'Metrics': [metrics_label, 'Global Fidelity'],
        'Dataset': [dataset, dataset],
        'Linear': [np.mean(list(config.results.reg.values())), '-'],
        'STL': [np.mean(list(config.results.stl_mlp.values())), np.mean(list(config.results.stl_reg.values()))]
    })

    for alpha in config.alpha_list:
        runs = [config.results.mtl[i][str(alpha)] for i in range(config.num)]
        df[f'MTL a= {str(alpha)}'] = [
            np.mean([run.metrics for run in runs]),
            np.mean([run.fid for run in runs]),
        ]
    return df
    

def show_tables(config: Config, metrics_label=None):
    """Display the linear-vs-MLP table and the STL-vs-MTL table.

    Each table is read from a pickle file under ``config.folders.base``.
    When reading fails, the table is rebuilt from ``config`` and written to
    that file before being displayed.

    Args:
        config: Experiment configuration holding the results and the base
            folder.
        metrics_label: Metric label passed to the table builders;
            ``config.metrics_label`` is used when ``None``.
    """
    if metrics_label is None:
        metrics_label = config.metrics_label

    def load_or_build(file_name, builder):
        # Read the cached table, or rebuild and cache it when reading fails.
        path = config.folders.base + file_name
        try:
            return pd.read_pickle(path)
        except Exception:
            # Best-effort cache: a missing or unreadable file triggers a
            # rebuild, while KeyboardInterrupt/SystemExit still propagate.
            table = builder(config, metrics_label)
            table.to_pickle(path)
            return table

    df1 = load_or_build('results_lin_vs_non_lin.pkl', get_results_non_lin_vs_lin)
    display(df1.style.format(precision=4, decimal=",").hide())

    df2 = load_or_build('results_stl_vs_mtl.pkl', get_results_mtl_vs_stl)
    display(df2.style.format(precision=4, decimal=",").hide())

Functions to show the GNF results in the form of a table.

In [None]:
def get_gnf_df(config: Config):
    """Build a one-row table of Global Neighborhood Fidelity scores.

    The 'STL' cell is the mean of ``config.results.stl_gnf``. One
    'MTL a= <alpha>' column is added per entry of ``config.alpha_list``,
    holding the mean of ``config.results.mtl_gnf[i][str(alpha)]`` over
    ``config.num`` runs. The 'Linear' cell is the placeholder '-'.
    """
    dataset_name = config.folders.base.split('/')[0]
    base_columns = {
        'Metric': ['GNF'],
        'Dataset': [dataset_name],
        'Linear': ['-'],
        'STL': [np.mean(config.results.stl_gnf)],
    }
    df = pd.DataFrame(base_columns)

    for alpha in config.alpha_list:
        per_run = [config.results.mtl_gnf[i][str(alpha)] for i in range(config.num)]
        df[f'MTL a= {str(alpha)}'] = [np.mean(per_run)]
    return df

def get_gnf_summary_config(config_list):
    """Stack the GNF tables of several configurations into one DataFrame.

    ``get_gnf_df`` is applied to every configuration in ``config_list`` and
    the resulting frames are concatenated in the given order.
    """
    return pd.concat([get_gnf_df(config) for config in config_list])

def get_prediction_results(config: Config, metrics_label, n=0):
    """Build a one-row table with the prediction metric of run ``n``.

    Args:
        config: Source of the stored results and of the dataset name
            (first component of ``folders.base``).
        metrics_label: Text for the 'Metric' column.
        n: Key of the run whose results are reported.

    Returns:
        A ``pandas.DataFrame`` with 'Linear' and 'STL' scores of run ``n``
        plus one 'MTL a= <alpha>' column per entry of ``config.alpha_list``.
    """
    results = config.results
    dataset_name = config.folders.base.split('/')[0]
    df = pd.DataFrame({
        'Metric': [metrics_label],
        'Dataset': [dataset_name],
        'Linear': [results.reg[n]],
        'STL': [results.stl_mlp[n]],
    })
    for alpha in config.alpha_list:
        alpha_key = str(alpha)
        df[f'MTL a= {alpha_key}'] = results.mtl[n][alpha_key].metrics
    return df

def get_gnf_and_metrics(config: Config, metrics_label=None, num=0):
    """Return the combined GNF and prediction-metric table.

    The table is read from 'results_gnf.pkl' under ``config.folders.base``.
    When reading fails, it is rebuilt by stacking ``get_gnf_df(config)`` on
    top of ``get_prediction_results(config, metrics_label, num)`` and
    written to that file.

    Args:
        config: Experiment configuration holding the results and the base
            folder.
        metrics_label: Label of the prediction-metric row;
            ``config.metrics_label`` is used when ``None``.
        num: Run whose prediction results are reported.

    Returns:
        The loaded or rebuilt ``pandas.DataFrame``.
    """
    if metrics_label is None:
        metrics_label = config.metrics_label
    path = config.folders.base + 'results_gnf.pkl'
    try:
        return pd.read_pickle(path)
    except Exception:
        # Best-effort cache: a missing or unreadable file triggers a
        # rebuild, while KeyboardInterrupt/SystemExit still propagate.
        df_gnf = get_gnf_df(config)
        df_res = get_prediction_results(config, metrics_label, num)
        df = pd.concat([df_gnf, df_res])
        df.to_pickle(path)
        return df


def show_tables_lime(config: Config, metrics_label=None, num=0):
    """Display the GNF/prediction-metric table with three decimals."""
    table = get_gnf_and_metrics(config, metrics_label, num)
    styled = table.style.format(precision=3, decimal=",").hide()
    display(styled)

## Local Explainability with Lime

### Prediction Functions

Definitions of the prediction functions for classification. LIME requires a minimum of two classes; therefore, the prediction function must return two probability values. The prediction function must also apply the appropriate data transformations, as the models were trained on encoded data.

In [None]:
def _X_to_tensor(X, config:Config):
    """Encode raw feature rows and return them as a float32 tensor.

    The rows are labelled with ``config.features.names``, the columns in
    ``config.features.categorical`` are one-hot encoded, and the result is
    aligned to the column list ``config.features.dummy`` (columns missing
    after encoding are filled with 0.0). The tensor is moved to
    ``config.device``.
    """
    features = config.features
    frame = pd.DataFrame(X, columns=features.names)
    encoded = pd.get_dummies(frame, columns=features.categorical, dtype=float)
    aligned = encoded.reindex(columns=features.dummy, fill_value=0.0)
    return torch.tensor(aligned.values, dtype=torch.float32).to(config.device)

In [None]:
def predict_cls_for_lime(X, model, config: Config):
    """Return ``[1 - p, p]`` model outputs for the rows in ``X``.

    ``X`` is encoded with ``_X_to_tensor``, the model is run in eval mode
    without gradient tracking, and its output ``p`` is stacked horizontally
    after ``1 - p`` as a NumPy array.
    """
    inputs = _X_to_tensor(X=X, config=config)
    model.eval()
    with torch.no_grad():
        outputs = model(inputs)
        positive = outputs.cpu().numpy()
    negative = 1 - positive
    return np.hstack((negative, positive))

In [None]:
def predict_mtl_cls_for_lime(X, model, config: Config):
    """Return ``[1 - p, p]`` for the first output of an MTL model.

    ``X`` is encoded with ``_X_to_tensor``, the model is run in eval mode
    without gradient tracking, and element 0 of its output is used as
    ``p``. The result is a NumPy array with ``1 - p`` stacked before ``p``.
    """
    inputs = _X_to_tensor(X=X, config=config)
    model.eval()
    with torch.no_grad():
        first_output = model(inputs)[0]
        positive = first_output.cpu().numpy()
    negative = 1 - positive
    return np.hstack((negative, positive))

Definitions of the prediction functions for regression 

In [None]:
def predict_reg_for_lime(X, model, config: Config):
    """Return the model's predictions for the rows in ``X`` as a NumPy array.

    ``X`` is encoded with ``_X_to_tensor`` and the model is run in eval
    mode without gradient tracking.
    """
    inputs = _X_to_tensor(X=X, config=config)
    model.eval()
    with torch.no_grad():
        outputs = model(inputs)
        return outputs.cpu().numpy()

In [None]:
def predict_mtl_reg_for_lime(X, model, config: Config):
    """Return the first output of an MTL model for ``X`` as a NumPy array.

    ``X`` is encoded with ``_X_to_tensor`` and the model is run in eval
    mode without gradient tracking; element 0 of its output is returned.
    """
    inputs = _X_to_tensor(X=X, config=config)
    model.eval()
    with torch.no_grad():
        first_output = model(inputs)[0]
        return first_output.cpu().numpy()

### Neighbors Generation Functions

Function for generating neighbors ($|N_x| = 10$)
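- For continuous features: Gaussian noise is added to the instance, $x' = x + \varepsilon$ with $\varepsilon \sim \mathcal{N}(\mu, \sigma^2)$, where $\mu = 0$ and $\sigma^2 = 0.1$.
- For categorical features: the discussed study gives no information, so they are left unperturbed here.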

In [None]:
# Standard deviation of the perturbation noise (variance 0.1).
STD_DEV = np.sqrt(0.1)


def generate_neighbors_continuous_only(instance, config: Config, mean=0, std_deviation=STD_DEV, num_neighbors=10, features_num_idx=None):
    """Generate perturbed copies of ``instance``, changing numerical features only.

    ``instance`` is repeated ``num_neighbors`` times and zero-mean Gaussian
    noise is added to the columns listed in ``features_num_idx``. All other
    columns are copied unchanged.

    Args:
        instance: Array with a ``reshape`` method holding one data point.
        config: Supplies ``features.numerical_indices`` when
            ``features_num_idx`` is not given.
        mean: Accepted for interface compatibility.
            NOTE(review): this value is currently not applied; the noise is
            always drawn with ``loc=0``. Check positional call sites before
            wiring it in.
        std_deviation: Standard deviation of the noise.
        num_neighbors: Number of neighbours to generate.
        features_num_idx: Column indices to perturb; defaults to
            ``config.features.numerical_indices`` when ``None``.

    Returns:
        Array with ``num_neighbors`` rows; ``instance`` is not modified.
    """
    # Identity check: ``== None`` on a NumPy index array compares
    # element-wise and cannot be used in an ``if``.
    if features_num_idx is None:
        features_num_idx = config.features.numerical_indices
    perturbations = np.random.normal(loc=0, scale=std_deviation, size=(num_neighbors, len(features_num_idx)))
    # np.repeat returns a new array, so the in-place add below leaves the
    # caller's instance untouched.
    neighbors = np.repeat(instance.reshape(1, -1), num_neighbors, axis=0)
    neighbors[:, features_num_idx] += perturbations
    return neighbors

### Functions for getting LIME predictions

In [None]:
def get_lime_prediction_reg(instance, exp):
    """Evaluate the linear LIME explanation ``exp`` at ``instance``.

    Every ``(feature_index, weight)`` pair in ``exp.local_exp[1]``
    contributes ``weight * instance[feature_index]``; ``exp.intercept[1]``
    is added to the sum.
    """
    total = 0
    for feature_idx, weight in exp.local_exp[1]:
        total += weight * instance[feature_idx]
    return total + exp.intercept[1]

In [None]:
def get_lime_prediction_cat(instance, exp, categorical_indices):
    """Evaluate the LIME explanation ``exp`` at ``instance`` with categorical features.

    For each ``(feature_index, weight)`` pair in ``exp.local_exp[1]``, a
    feature listed in ``categorical_indices`` contributes its weight as is,
    and any other feature contributes ``weight * instance[feature_index]``.
    ``exp.intercept[1]`` is added to the sum.
    """
    contributions = (
        weight if feature_idx in categorical_indices else weight * instance[feature_idx]
        for feature_idx, weight in exp.local_exp[1]
    )
    return sum(contributions) + exp.intercept[1]

In [None]:
def get_lime_prediction(local_neighbors, exp, categorical_indices):
    """Evaluate the LIME explanation ``exp`` on every row of ``local_neighbors``.

    ``get_lime_prediction_cat`` is used when ``categorical_indices`` is
    non-empty, otherwise ``get_lime_prediction_reg``. The per-row values are
    returned as one tensor.
    """
    if categorical_indices:
        def predict_one(neighbor):
            return get_lime_prediction_cat(neighbor, exp, categorical_indices)
    else:
        def predict_one(neighbor):
            return get_lime_prediction_reg(neighbor, exp)

    return torch.tensor([predict_one(neighbor) for neighbor in local_neighbors])

### Global Neighborhood Fidelity

Global Neighborhood Fidelity is defined as the averaged value of Neighborhood Fidelity for all data points. Neighborhood Fidelity, in turn, is defined as the value of Global Fidelity in the local neighborhood of a given point.

Global Neighborhood Fidelity for regression

In [None]:
def global_neighborhood_fidelity_reg(model, neighbors_dataset, explainer, predict_func_lime, predict_func_model, config: Config, num_neighbors=10, n=1):
    """Compute Global Neighborhood Fidelity for a regression model.

    For every instance a LIME explanation is produced, ``num_neighbors``
    perturbed neighbours are generated, and ``GlobalFidelity`` compares the
    explanation's predictions with the model's predictions on those
    neighbours. The per-instance scores are averaged. The whole procedure
    is repeated ``n`` times.

    Args:
        model: Model passed to ``predict_func_model``.
        neighbors_dataset: Iterable of instances to explain.
        explainer: Object providing ``explain_instance(instance, predict_fn)``.
        predict_func_lime: Prediction function handed to the explainer.
        predict_func_model: Called as ``predict_func_model(neighbors, model, config)``.
        config: Supplies ``features.categorical_indices`` and is forwarded
            to the neighbour generator and ``predict_func_model``.
        num_neighbors: Number of neighbours generated per instance.
        n: Number of repetitions.

    Returns:
        List with one mean fidelity tensor per repetition.
    """
    fidelity = GlobalFidelity()
    categorical_indices = config.features.categorical_indices
    result = []
    for _ in range(n):
        fidelity_score = []
        for instance in neighbors_dataset:
            exp = explainer.explain_instance(instance, predict_func_lime)
            # Pass the count by keyword: the third positional parameter of
            # generate_neighbors_continuous_only is ``mean``, so a
            # positional call would leave the neighbour count at its default.
            local_neighbors = generate_neighbors_continuous_only(instance, config, num_neighbors=num_neighbors)
            lime_predictions = get_lime_prediction(local_neighbors=local_neighbors, exp=exp, categorical_indices=categorical_indices)
            # NOTE(review): model predictions keep whatever shape
            # predict_func_model returns; confirm GlobalFidelity handles it
            # against the 1-D LIME predictions.
            model_predictions = torch.tensor(np.array([pred for pred in predict_func_model(local_neighbors, model, config)]))
            fidelity_score.append(fidelity(lime_predictions, model_predictions))

        result.append(torch.mean(torch.tensor(fidelity_score)))
    return result

Global Neighborhood Fidelity for classification

In [None]:
def global_neighborhood_fidelity_cls(model, neighbors_dataset, explainer, predict_func_lime, predict_func_model, config: Config, num_neighbors=10, n=1):
    """Compute Global Neighborhood Fidelity for a classification model.

    For every instance a LIME explanation is produced, ``num_neighbors``
    perturbed neighbours are generated, and ``GlobalFidelity`` compares the
    thresholded (``>= 0.5``) explanation predictions with the thresholded
    second output column of the model on those neighbours. The
    per-instance scores are averaged. The whole procedure is repeated
    ``n`` times.

    Args:
        model: Model passed to ``predict_func_model``.
        neighbors_dataset: Iterable of instances to explain.
        explainer: Object providing ``explain_instance(instance, predict_fn)``.
        predict_func_lime: Prediction function handed to the explainer.
        predict_func_model: Called as ``predict_func_model(neighbors, model,
            config)``; each returned row must unpack into two values.
        config: Supplies ``features.categorical_indices`` and is forwarded
            to the neighbour generator and ``predict_func_model``.
        num_neighbors: Number of neighbours generated per instance.
        n: Number of repetitions.

    Returns:
        List with one mean fidelity tensor per repetition.
    """
    fidelity = GlobalFidelity()
    categorical_indices = config.features.categorical_indices
    result = []
    for _ in range(n):
        fidelity_score = []
        for instance in neighbors_dataset:
            exp = explainer.explain_instance(instance, predict_func_lime)
            # Pass the count by keyword: the third positional parameter of
            # generate_neighbors_continuous_only is ``mean``, so a
            # positional call would leave the neighbour count at its default.
            local_neighbors = generate_neighbors_continuous_only(instance, config, num_neighbors=num_neighbors)
            lime_scores = get_lime_prediction(local_neighbors=local_neighbors, exp=exp, categorical_indices=categorical_indices)
            lime_predictions = (lime_scores >= 0.5).float()
            # Keep only the second value of every output row.
            positive_probs = np.array([positive for _, positive in predict_func_model(local_neighbors, model, config)])
            model_predictions = (torch.tensor(positive_probs) >= 0.5).float()
            fidelity_score.append(fidelity(lime_predictions, model_predictions))
        result.append(torch.mean(torch.tensor(fidelity_score)))
    return result

### Functions for Conducting the Experiments

Global Neighborhood Fidelity for MTL regression

In [None]:
def gnf_for_mtl_reg(neighbors_dataset, explainer, model_params, path, config: Config, num_neighbors=10):
    """Load a saved MTL regression model and compute its Global Neighborhood Fidelity.

    An ``MTL_mlp_linear`` model is rebuilt from ``MLP_reg`` (hidden sizes
    derived from ``model_params``) and ``LinearRegression``, its weights
    are loaded from ``path``, and ``global_neighborhood_fidelity_reg`` is
    run with ``predict_mtl_reg_for_lime`` as the prediction function.

    Args:
        neighbors_dataset: Instances to explain.
        explainer: LIME explainer forwarded to the fidelity computation.
        model_params: Parameters passed to ``MLP_reg.get_hidden_sizes``.
        path: File containing the saved ``state_dict``.
        config: Supplies ``input_size`` and ``device``, and is forwarded on.
        num_neighbors: Number of neighbours generated per instance.

    Returns:
        The result of ``global_neighborhood_fidelity_reg``.
    """
    mlp = MLP_reg(input_size=config.input_size, hidden_sizes=MLP_reg.get_hidden_sizes(model_params, config))
    lin = LinearRegression(input_size=config.input_size)
    model = MTL_mlp_linear(mlp=mlp, linear=lin).to(config.device)

    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load onto the device this run actually uses.
    model.load_state_dict(torch.load(path, map_location=config.device))
    predict_func_lime = partial(predict_mtl_reg_for_lime, model=model, config=config)

    return global_neighborhood_fidelity_reg(
        model=model,
        neighbors_dataset=neighbors_dataset,
        explainer=explainer,
        predict_func_lime=predict_func_lime,
        predict_func_model=predict_mtl_reg_for_lime,
        config=config,
        num_neighbors=num_neighbors
    )


In [None]:
def gnf_for_mtl_reg_for_alpha_list(neighbors_dataset, explainer, model_params, config: Config, num_neighbors=10, alpha_list=None, n=5):
    """Compute GNF for the saved MTL regression models of every alpha, ``n`` times.

    Results are stored in ``config.results.mtl_gnf[i][str(alpha)]`` for
    ``i`` in ``range(n)``. Every repetition loads the model file with the
    ``'_0'`` suffix for the given alpha.

    Args:
        neighbors_dataset: Instances to explain.
        explainer: LIME explainer forwarded to ``gnf_for_mtl_reg``.
        model_params: Not forwarded; ``config.best_parameters`` is passed
            to ``gnf_for_mtl_reg`` instead.
        config: Supplies model paths, defaults and the result store.
        num_neighbors: Number of neighbours generated per instance.
        alpha_list: Alphas to evaluate; ``config.alpha_list`` when empty or
            ``None``.
        n: Number of repetitions; ``config.num`` when falsy.
    """
    n = n or config.num
    alphas = alpha_list or config.alpha_list

    for run in range(n):
        config.results.mtl_gnf[run] = {}
        for alpha in alphas:
            alpha_key = str(alpha)
            path = f'{config.models.mtl}{alpha_key}_0{config.models.ext}'
            score = gnf_for_mtl_reg(
                neighbors_dataset=neighbors_dataset,
                explainer=explainer,
                model_params=config.best_parameters,
                path=path,
                config=config,
                num_neighbors=num_neighbors,
            )
            config.results.mtl_gnf[run][alpha_key] = score
            print(f'Model: {path}, GNF: {score}')

Global Neighborhood Fidelity for MTL classification

In [None]:
def gnf_for_mtl_cls(neighbors_dataset, explainer, model_params, path, config: Config, num_neighbors=10):
    """Load a saved MTL classification model and compute its Global Neighborhood Fidelity.

    An ``MTL_mlp_linear`` model is rebuilt from ``MLP_cls`` (hidden sizes
    derived from ``model_params``) and ``LogisticRegression``, its weights
    are loaded from ``path``, and ``global_neighborhood_fidelity_cls`` is
    run with ``predict_mtl_cls_for_lime`` as the prediction function.

    Args:
        neighbors_dataset: Instances to explain.
        explainer: LIME explainer forwarded to the fidelity computation.
        model_params: Parameters passed to ``MLP_cls.get_hidden_sizes``.
        path: File containing the saved ``state_dict``.
        config: Supplies ``input_size`` and ``device``, and is forwarded on.
        num_neighbors: Number of neighbours generated per instance.

    Returns:
        The result of ``global_neighborhood_fidelity_cls``.
    """
    mlp = MLP_cls(input_size=config.input_size, hidden_sizes=MLP_cls.get_hidden_sizes(model_params, config))
    lin = LogisticRegression(input_size=config.input_size)
    model = MTL_mlp_linear(mlp=mlp, linear=lin).to(config.device)

    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load onto the device this run actually uses.
    model.load_state_dict(torch.load(path, map_location=config.device))
    predict_func_lime = partial(predict_mtl_cls_for_lime, model=model, config=config)

    return global_neighborhood_fidelity_cls(
        model=model,
        neighbors_dataset=neighbors_dataset,
        explainer=explainer,
        predict_func_lime=predict_func_lime,
        predict_func_model=predict_mtl_cls_for_lime,
        config=config,
        num_neighbors=num_neighbors
    )

In [None]:
def gnf_for_mtl_cls_for_alpha_list(neighbors_dataset, explainer, model_params, config: Config, num_neighbors=10, alpha_list=None, n=5):
    """Compute GNF for the saved MTL classification models of every alpha, ``n`` times.

    Results are stored in ``config.results.mtl_gnf[i][str(alpha)]`` for
    ``i`` in ``range(n)``. Every repetition loads the model file with the
    ``'_0'`` suffix for the given alpha.

    Args:
        neighbors_dataset: Instances to explain.
        explainer: LIME explainer forwarded to ``gnf_for_mtl_cls``.
        model_params: Not forwarded; ``config.best_parameters`` is passed
            to ``gnf_for_mtl_cls`` instead.
        config: Supplies model paths, defaults and the result store.
        num_neighbors: Number of neighbours generated per instance.
        alpha_list: Alphas to evaluate; ``config.alpha_list`` when empty or
            ``None``.
        n: Number of repetitions; ``config.num`` when falsy.
    """
    n = n or config.num
    alphas = alpha_list or config.alpha_list

    for run in range(n):
        config.results.mtl_gnf[run] = {}
        for alpha in alphas:
            alpha_key = str(alpha)
            path = f'{config.models.mtl}{alpha_key}_0{config.models.ext}'
            score = gnf_for_mtl_cls(
                neighbors_dataset=neighbors_dataset,
                explainer=explainer,
                model_params=config.best_parameters,
                path=path,
                config=config,
                num_neighbors=num_neighbors,
            )
            config.results.mtl_gnf[run][alpha_key] = score
            print(f'Model: {path}, GNF: {score}')