<a href="https://colab.research.google.com/github/RoshanRane/DeepRepViz/blob/main/DeepRepViz_tutorial.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a><a href="https://github.com/RoshanRane/DeepRepViz"><img align="left" src="https://img.shields.io/badge/Github-Download-blue.svg" alt="Download" title="Download Notebook"></a><a href="https://link.springer.com/chapter/10.1007/978-3-031-72117-5_18"><img align="left" src="https://img.shields.io/badge/Read%20the%20Paper-8A2BE2" alt="Read the Paper" title="Read the Paper"></a>

# Setup

<span style="color:red">**Alert**</span>
  - Click on `Runtime` in the top menu.<br>
  - Select `Change runtime type` from the dropdown.<br>
  - In the "Hardware accelerator" dropdown, select `T4 GPU`.

In [None]:
# @title Install dependencies <a name="install"></a>

%matplotlib inline
%load_ext autoreload
%autoreload 2

import subprocess
import sys

# Function to install a package
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet"] + package.split())

# List of packages to check and install
packages = {
    "numpy": "numpy",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "statsmodels": "statsmodels",
    "tqdm": "tqdm",
    "torch": "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
    "einops": "einops",
    "lightning": "lightning",
    "dcor": "dcor",
}

# Loop through packages and check if they need to be installed
for lib, install_name in packages.items():
    try:
        __import__(lib)
    except ImportError:
        print(f"{lib} not found, installing...")
        install(install_name)

!pip install gpustat
! gpustat

In [None]:
# @title Download source files from github
!git clone https://github.com/RoshanRane/DeepRepViz.git
!unzip -qq /content/DeepRepViz/data/toybrains_n5000_lblmidr-consite_cy025-cX025-yX050.zip -d /content/DeepRepViz/data
!rm /content/DeepRepViz/data/toybrains_n5000_lblmidr-consite_cy025-cX025-yX050.zip

***Roshan Prakash Rane***<br>
​-----------<br>
PhD Candidate<br>
Department of Psychiatry and Neurosciences, Charité - Universitätsmedizin Berlin, Berlin, Germany<br>
Department of Psychology, Humboldt-Universität zu Berlin, Berlin, Germany<br>
<br>
***JiHoon Kim***<br>
​-----------<br>
PhD Candidate<br>
Max Planck Institute for Human Cognitive and Brain Sciences, Leipzig, Germany<br>
INM-7, Research Centre Jülich, Jülich, Germany<br>

# DeepRepViz Tutorial Notebook

- [Tutorial Objective](#objective)
- [A Guide to Implementing DeepRepViz in Your Predictive Modeling](#guide)
  - [1. Use the DeepRepViz Callback](#use)
  - [2. Train your model](#train)
  - [3. Compute Metrics with DeepRepVizBackend](#compute)
  - [4. Find the Generated File for Web-Based Visualization](#generate)
  - [Interim summary](#interim)

- [Use case](#usecase)
  - [Import packages](#packages)
  - [Dataset](#dataset)
  - [Utils](#utils)
  - [Configuration](#config)
  - [Main](#main)
- [References](#ref)
- [Acknowledgements and Funding](#ack&funding)

## Tutorial Objective <a name="objective"></a>

*Estimated timing of tutorial: 20 minutes*

This notebook demonstrates how DeepRepViz can be integrated into a predictive deep learning application.

In this tutorial, we show how to use [DeepRepViz](https://link.springer.com/chapter/10.1007/978-3-031-72117-5_18) with the [Toy Brains dataset](https://github.com/RoshanRane/toybrains). You will:
- Learn how to use the DeepRepViz callback
- Compute the Con-score and generate the files for the [online visualization tool](https://deep-rep-viz.vercel.app/) using DeepRepVizBackend

---
## A Guide to Implementing DeepRepViz in Your Predictive Modeling <a name="guide"></a>

A [callback](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html) in PyTorch Lightning lets you add custom logic at specific points of the training loop, for example to monitor metrics or to modify the training behavior. Callbacks let you hook into the training process and capture useful information along the way.
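
To make the hook mechanism concrete, here is a toy, pure-Python sketch of the callback pattern. It does not use Lightning: `Callback`, `MetricRecorder`, `ToyTrainer` and the `on_epoch_end` hook are hypothetical names invented for this illustration. Lightning's own `Callback` class exposes analogous hooks, such as `on_train_epoch_end`.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_start(self, trainer):
        pass

    def on_epoch_end(self, trainer, epoch, metrics):
        pass


class MetricRecorder(Callback):
    """Hypothetical callback that stores the metrics reported after each epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, trainer, epoch, metrics):
        self.history.append((epoch, dict(metrics)))


class ToyTrainer:
    """Toy training loop that calls the hooks of every registered callback."""

    def __init__(self, callbacks, max_epochs):
        self.callbacks = list(callbacks)
        self.max_epochs = max_epochs

    def fit(self, epoch_fn):
        for cb in self.callbacks:
            cb.on_train_start(self)
        for epoch in range(self.max_epochs):
            metrics = epoch_fn(epoch)  # stands in for one epoch of training
            for cb in self.callbacks:
                cb.on_epoch_end(self, epoch, metrics)


recorder = MetricRecorder()
trainer = ToyTrainer(callbacks=[recorder], max_epochs=3)
trainer.fit(lambda epoch: {"loss": 1.0 / (epoch + 1)})
print(recorder.history)
```

The training loop never needs to know what a callback does with the information it receives, which is why a callback such as `DeepRepViz` can be added to an existing `Trainer` without changing the model code.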

### 1. Use the DeepRepViz Callback <a name="use"></a>

To incorporate the DeepRepViz callback, instantiate it and pass it to the Trainer:

```
from deeprepviz.callback import DeepRepViz
import lightning as L

# Initialize DeepRepViz callback
drv = DeepRepViz(...)

# Train model
trainer = L.Trainer(callbacks=[drv])
```

### 2. Train your model <a name="train"></a>

Start the training process with `trainer.fit(model, datamodule=data_module)` to train with the callback enabled.

### 3. Compute Metrics with DeepRepVizBackend <a name="compute"></a>

The DeepRepVizBackend automatically computes metrics like the `Con-score`, and generates files for the web-based visualization tool. This requires the `raw_csv` and the Trainer. For instructions on obtaining the `raw_csv`, please refer to the [documentation](https://deep-rep-viz.vercel.app/docs.html).
```
from deeprepviz.backend import DeepRepVizBackend

# Initialize DeepRepViz Backend
drv_backend = DeepRepVizBackend(...)
log_dir = trainer.log_dir + '/deeprepvizlog/'
drv_backend.load_log(log_dir)

# Compute and generate the files
drv_backend.convert_log_to_v1_table(log_key=log_dir)
```

### 4. Find the Generated File for Web-Based Visualization <a name="generate"></a>

After computing metrics with the DeepRepVizBackend, you can find the generated files in the `log/your-dataset-name_your-model-name/trial_*/deeprepvizlog/` folder. Look for files named `DeepRepViz-*.csv`.

Once you have the file, you can upload it to the web-based visualization tool. For detailed instructions on using the tool, please refer to the [documentation](https://deep-rep-viz.vercel.app/docs.html).
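
As a convenience, the generated files can also be listed programmatically. The following is a minimal sketch that assumes the folder layout described above; `find_drv_tables` is a hypothetical helper and not part of the DeepRepViz package.

```python
from pathlib import Path


def find_drv_tables(log_root="log"):
    """Return all DeepRepViz-*.csv files found under <log_root>/<run>/trial_*/deeprepvizlog/."""
    pattern = "*/trial_*/deeprepvizlog/DeepRepViz-*.csv"
    return sorted(Path(log_root).glob(pattern))


# Prints nothing if no training run has produced a table yet
for csv_path in find_drv_tables():
    print(csv_path)
```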

### Interim summary <a name="interim"></a>
We use a `DeepRepViz` callback to capture useful information during training, while the `DeepRepVizBackend` computes the Con-score metric and generates files for [the web-based visualization tool](https://deep-rep-viz.vercel.app/).

---
## Use case <a name="usecase"></a>

In [None]:
# @title Import packages <a name="packages"></a>

# import standard python packages
from dataclasses import dataclass, field
from datetime import datetime
from glob import glob
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import os
import numpy as np
import pandas as pd
from PIL import Image
import random
import re
from sklearn.model_selection import train_test_split, StratifiedKFold
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torchvision import transforms
from typing import Dict, List, Type
import warnings

sys.path.append(os.path.abspath('/content/DeepRepViz'))
warnings.filterwarnings("ignore", category=DeprecationWarning)

# import DeepRepViz packages
from deeprepviz.utils import D2metric
from deeprepviz.callback import (
    DeepRepViz,
    get_all_model_layers,
    get_param_count
    )
from deeprepviz.backend import DeepRepVizBackend

In [None]:
# @title Dataset <a name="dataset"></a>

PATHS = "/content/DeepRepViz/data/toybrains_*"
DATASETS = sorted([os.path.abspath(path) for path in glob(PATHS)])
print(f"Fitting DL model on the following toybrains datasets:\n{DATASETS}")

In [None]:
# @title Utils <a name="utils"></a>

# DataLoader
class ToyBrainsDataloader(Dataset):
  def __init__(self, img_dir, img_names, labels, transform=None):
    self.img_dir = img_dir
    self.img_names = img_names
    self.labels = labels
    self.transform = transform

  def __getitem__(self, index):
    number = str(self.img_names[index]).zfill(5)
    img = Image.open(os.path.join(self.img_dir, number + ".jpg"))

    if self.transform is not None:
      img = self.transform(img)

    label = torch.as_tensor(self.labels[int(index)]).type(torch.LongTensor)
    return img, label

  def __len__(self):
    return self.labels.shape[0]

# Model
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2,2),
    )

    self._init_weights()

  def _init_weights(self):
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.zeros_(m.bias)

  def forward(self, x):
    x = self.conv(x)
    return x

class SimpleCNN(nn.Module):
  def __init__(self, num_classes, final_act_size=64):
    super().__init__()
    self.final_act_size = final_act_size # weights + 1 bias
    # convolutional layers
    self.conv = nn.Sequential(
        ConvBlock(in_channels=3, out_channels=16),
        nn.Dropout(0.1),
        ConvBlock(in_channels=16, out_channels=32),
        nn.Dropout(0.1),
        ConvBlock(in_channels=32, out_channels=64),
    )

    # fully connected layers
    self.fc = nn.Sequential(
        nn.Flatten(),
        # TODO hardcoded input size
        nn.Linear(64 * 8 * 8, self.final_act_size, bias=True),
        nn.Linear(self.final_act_size, num_classes, bias=True),
    )

  def forward(self, x):
    x = self.conv(x)
    x = self.fc(x)
    return x

# Lightning Module
class LightningModel(L.LightningModule):
    def __init__(self, model, learning_rate,
                 task="binary", num_classes=1):
        '''LightningModule that receives a PyTorch model as input'''
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model
        self.num_classes = num_classes
        # self.metric_acc = torchmetrics.Accuracy(task=task, num_classes=num_classes)
        self._metric_spec = torchmetrics.Specificity(task=task, num_classes=num_classes)
        self._metric_recall = torchmetrics.Recall(task=task, num_classes=num_classes)
        self.metric_D2 = D2metric()

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        features, true_labels = batch
        logits = self(features)
        # compute all metrics on the predictions
        if self.num_classes==1:
            logits = torch.squeeze(logits, dim=-1)
            true_labels = true_labels.to(torch.float32)
            # print(logits.shape, true_labels.shape)
            loss = F.binary_cross_entropy_with_logits(logits, true_labels)
            predicted_labels = torch.sigmoid(logits)>0.5

        else:
            loss = F.cross_entropy(logits, true_labels)
            predicted_labels = torch.argmax(logits, dim=1)
        # acc = self.metric_acc(predicted_labels, true_labels)
        # calculate balanced accuracy
        spec = self._metric_spec(predicted_labels, true_labels)
        recall = self._metric_recall(predicted_labels, true_labels)
        BAC = (spec+recall)/2
        D2 = self.metric_D2(logits, true_labels)
        metrics = {'loss':loss, 'BAC':BAC, 'D2':D2}
        return true_labels, logits, metrics

    def training_step(self, batch, batch_idx):
        labels, preds, metrics = self._shared_step(batch)
        # prepend 'train_' to every key
        log_metrics = {'train_'+k:v for k,v in metrics.items()}
        self.log_dict(log_metrics,
                      prog_bar=True,
                      on_epoch=True, on_step=False)
        return log_metrics['train_loss']

    def validation_step(self, batch, batch_idx):
        labels, preds, metrics = self._shared_step(batch)
        # prepend 'val_' to every key
        log_metrics = {'val_'+k:v for k,v in metrics.items()}
        self.log_dict(log_metrics,
                      prog_bar=True,
                      on_epoch=True, on_step=False)
        return labels, preds

    def test_step(self, batch, batch_idx):
        labels, preds, metrics = self._shared_step(batch)
        # prepend 'test_' to every key
        log_metrics = {'test_'+k:v for k,v in metrics.items()}
        self.log_dict(log_metrics,
                      prog_bar=True,
                      on_epoch=True, on_step=False)
        return labels, preds

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # DeepRepViz returns ids and the normal batch outputs as tuples
        ids, batch = batch
        true_labels, logits, metrics = self._shared_step(batch)
        return ids, true_labels, logits, metrics

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
                                                         factor=0.75, patience=3)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "val_loss",
                "interval": "epoch", # default
                "frequency": 1, # default
            },
        }
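
The `_shared_step` method above reports the balanced accuracy (`BAC`) as the mean of specificity and recall. The plain-Python sketch below shows the same formula on a small hand-made example; `balanced_accuracy` is a hypothetical helper for illustration only (the notebook itself relies on `torchmetrics`), and it assumes binary 0/1 labels with both classes present.

```python
def balanced_accuracy(predicted, actual):
    """Mean of recall (true-positive rate) and specificity (true-negative rate)."""
    pairs = list(zip(predicted, actual))
    tp = sum(1 for p, a in pairs if p == 1 and a == 1)
    fn = sum(1 for p, a in pairs if p == 0 and a == 1)
    tn = sum(1 for p, a in pairs if p == 0 and a == 0)
    fp = sum(1 for p, a in pairs if p == 1 and a == 0)
    recall = tp / (tp + fn)        # fraction of positives that were found
    specificity = tn / (tn + fp)   # fraction of negatives that were found
    return (recall + specificity) / 2


actual = [1, 1, 0, 0, 0, 0]
predicted = [1, 0, 0, 0, 0, 1]
# recall = 1/2, specificity = 3/4
print(balanced_accuracy(predicted, actual))  # 0.625
```

Unlike plain accuracy, this score stays at 0.5 for a model that always predicts the majority class, which makes it more informative on imbalanced labels.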

In [None]:
# @title Configuration <a name="config"></a>
@dataclass
class Config:
  """Configuration options for the Lightning ToyBrains example.

  Attributes:
      random_seed (int or None): Seed for random number generation.
      debug (bool): Flag for enabling debug mode.
      id_col (str): Column name for ID.
      trial (str): Trial identifier.
      label_col (str): Column name for labels.
      split_col (str): Column name for the data split (train/val/test).
      additional_drv_test_data (Dict): Additional test data for DRV.
      model_class (Type): Class of the model to use.
      model_kwargs (Dict): Additional keyword arguments for the model.
      learning_rate (float): Learning rate for the optimizer.
      unique_name (str): Unique name for the experiment.
      additional_loggers (List): List of additional loggers.
      additional_callbacks (List): List of additional callbacks.
      early_stop_patience (int): Patience for early stopping.
      batch_size (int): Batch size for training.
      num_workers (int): Number of workers for data loading.
      gen_v1_table (bool): Flag to generate version 1 table.
      k_fold (int): Number of folds for cross-validation.
      no_ood_val (bool): Flag to disable out-of-distribution validation.
      data_dir (str): Directory path for data.
      trainer_args (Dict): Additional arguments for the trainer.

  Examples:
      A new instance of this dataclass can be created as follows:

      >>> config = Config()

      The default value of each field is listed in the class body below. If desired, any of these values can be overridden when creating a new instance of the dataclass:

      >>> config = Config(batch_size=128, trainer_args=dict(max_epochs=5, accelerator='gpu', devices=[0]))

  """
  random_seed: int = None
  debug: bool = False
  trial: str = 'trial_0'
  no_ood_val: bool = True
  id_col: str = 'subjectID'
  label_col: str = 'lbl_lesion'
  split_col: str = "datasplit"
  data_dir: str = '/content/DeepRepViz/data/toybrains_n5000_lblmidr-consite_cy025-cX025-yX050'
  additional_drv_test_data: Dict = field(default_factory=dict)
  model_class: Type = SimpleCNN
  model_kwargs: Dict = field(default_factory=lambda: dict(num_classes=1, final_act_size=65))
  learning_rate: float = 0.03
  unique_name: str = ''
  additional_loggers: List = field(default_factory=list)
  additional_callbacks: List = field(default_factory=list)
  early_stop_patience: int = 8
  batch_size: int = 64
  num_workers: int = 2
  k_fold: int = 1
  trainer_args: Dict = field(default_factory=lambda: dict(max_epochs=10, accelerator='gpu', devices=[0]))
  gen_v1_table: bool = True

### Main <a name="main"></a>

In [None]:
start_time = datetime.now()

# Set the configuration
config = Config()

# Default seed used for the data splits (overridden below if config.random_seed is set)
random_seed = 42

if config.random_seed is not None:
  random_seed = config.random_seed
  torch.manual_seed(random_seed)
  np.random.seed(random_seed)
  random.seed(random_seed)
  L.seed_everything(random_seed)

unique_name = config.unique_name

if config.debug:
  os.system('rm -rf log/*debugmode*')
  unique_name = 'debugmode'+unique_name
  config.trainer_args['max_epochs'] = 2 if config.trainer_args['max_epochs'] > 100 else config.trainer_args['max_epochs']
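  # Apply the debug overrides to the config so that the cells below pick them up
  config.batch_size = 5
  config.k_fold = k_fold = 2 if config.k_fold>2 else config.k_fold
  config.num_workers = 5
  config.no_ood_val = True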

In [None]:
# Load the dataset
DATA_DIR = os.path.abspath(config.data_dir)
DATA_CSV = glob(DATA_DIR + '/toybrains*.csv')
assert len(DATA_CSV)==1, f"Toybrains dataset found = {DATA_CSV}.\
\nEnsure that the dataset {config.data_dir} is generated using the `create_toybrains.py` script in the toybrains repo. \
Also ensure only one dataset exists for the given query '{DATA_DIR}'."
DATA_CSV = DATA_CSV[0]
N_SAMPLES = int(DATA_DIR.split('_n')[-1].split('_')[0])

# Collect the corresponding OOD test data
OOD_test_datasets = {}
if not config.no_ood_val:
  test_suffix = '_test'# Hardcoded
  test_nsamples = 1000 # Hardcoded: n samples of toybrains test datasets are 1000
  data_dir_test = re.sub(f'_n{N_SAMPLES}_', f'_n{test_nsamples}_', DATA_DIR) + test_suffix
  data_dir_test_noconf = re.sub('cX...', 'cX000', re.sub('cy...','cy000', data_dir_test))
  assert os.path.exists(data_dir_test_noconf), f"Could not find the equivalent 'no-conf' dataset {data_dir_test_noconf} for {DATA_DIR}"

  data_dir_test_notrue = re.sub('yX...','yX000', data_dir_test)
  assert os.path.exists(data_dir_test_notrue), f"Could not find the equivalent 'no-true' dataset {data_dir_test_notrue} for {DATA_DIR}"

  OOD_test_datasets = {'test-no-conf': data_dir_test_noconf, 'test-no-true': data_dir_test_notrue}

In [None]:
# Prepare the data splits as a dataframe mapping the subjectID to the split and trial
data = pd.read_csv(DATA_CSV)
ID_COL, LABEL_COL, SPLIT_COL = config.id_col, config.label_col, config.split_col
assert ID_COL in data.columns, f"ID_COL={ID_COL} is not present in the dataset's csv file. \
Available colnames = {data.columns.tolist()}"
assert LABEL_COL in data.columns, f"LABEL_COL={LABEL_COL} is not present in the dataset's csv file. \
Available colnames = {data.columns.tolist()}"

## SPLITS: Create the n-fold splits for the data
# Drop all columns except subjectID and label
datasplit_df = data.drop(columns=[c for c in data.columns if c not in [ID_COL, LABEL_COL]])
datasplit_df = datasplit_df.set_index(ID_COL)
# Create one 'trial_x' column per fold (config.k_fold columns in total)
for trial in range(config.k_fold):
  datasplit_df[f'trial_{trial}'] = 'unknown'
# First, set aside 20% of the data as test and assign it commonly to all folds
train_idxs, test_idxs = train_test_split(datasplit_df.index, test_size=0.2,
                                         random_state=random_seed)
for trial in range(config.k_fold):
  datasplit_df.loc[test_idxs, f'trial_{trial}'] = 'test'

if config.k_fold <= 1: # if 1 fold then initialize all data to the first trial
  train_idxs, val_idxs = train_test_split(train_idxs, test_size=0.1,
                                          random_state=random_seed)
  datasplit_df.loc[train_idxs, 'trial_0'] = 'train'
  datasplit_df.loc[val_idxs, 'trial_0'] = 'val'
else: # if k-fold then split the data into k times and assign each to a sep trial
  splitter = StratifiedKFold(n_splits=config.k_fold,
                             shuffle=True,
                             random_state=random_seed)
  splits = splitter.split(train_idxs, y=datasplit_df.loc[train_idxs, LABEL_COL])
  for trial_idx, (train_idxs_i, val_idxs_i) in enumerate(splits):
      datasplit_df.loc[train_idxs[train_idxs_i], f'trial_{trial_idx}'] = 'train'
      datasplit_df.loc[train_idxs[val_idxs_i], f'trial_{trial_idx}'] = 'val'

datasplit_df = datasplit_df.sort_index()
assert (datasplit_df.filter(like='trial')!='unknown').all().all(), "some data points are not assigned to any split. {}".format(datasplit_df)

dataset_name = os.path.basename(DATA_DIR)

# Split the dataset as defined in the datasplit_df
if datasplit_df.index.name==ID_COL: datasplit_df = datasplit_df.reset_index()
# Select the trial given by config.trial out of the k-folds
trial = config.trial

datasplit_df = datasplit_df.rename(columns={trial:SPLIT_COL})
datasplit_df = datasplit_df[[ID_COL, LABEL_COL, SPLIT_COL]]

df_train = datasplit_df[datasplit_df[SPLIT_COL]=='train']
df_val = datasplit_df[datasplit_df[SPLIT_COL]=='val']
df_test = datasplit_df[datasplit_df[SPLIT_COL]=='test']

print(f"Dataset: {dataset_name} \n  Training data split = {len(df_train)} \n\
  Validation data split = {len(df_val)} \n  Test data split = {len(df_test)}")
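
As a sanity check on the printed split sizes, the arithmetic of the single-fold case (`k_fold = 1`) can be worked out by hand. The sketch below assumes that `train_test_split` rounds the held-out part up and assigns the remainder to the other part, which is our understanding of its default behavior; `split_sizes` is a hypothetical helper that is not used elsewhere in the notebook.

```python
def split_sizes(n_samples, test_pct=20, val_pct=10):
    """Expected (train, val, test) sizes for a test split followed by a val split."""
    # Integer ceiling division: the held-out part is rounded up
    n_test = -(-n_samples * test_pct // 100)
    n_trainval = n_samples - n_test
    n_val = -(-n_trainval * val_pct // 100)
    n_train = n_trainval - n_val
    return n_train, n_val, n_test


# The dataset name contains 'n5000', i.e. 5000 samples
print(split_sizes(5000))  # (3600, 400, 1000)
```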

In [None]:
# Create pytorch data loaders
train_dataset = ToyBrainsDataloader(
  img_names = df_train[ID_COL].values, # TODO change hardcoded
  labels = df_train[LABEL_COL].values,
  img_dir = DATA_DIR+'/images',
  transform = transforms.Compose([transforms.ToTensor()])
  )
train_loader = DataLoader(
  dataset=train_dataset,
  shuffle=True, batch_size=config.batch_size, drop_last=True,
  num_workers=config.num_workers,
  )

val_dataset = ToyBrainsDataloader(
  img_names = df_val[ID_COL].values,
  labels = df_val[LABEL_COL].values,
  img_dir = DATA_DIR+'/images',
  transform = transforms.Compose([transforms.ToTensor()])
  )
val_loader = DataLoader(
  dataset=val_dataset,
  shuffle=False, batch_size=config.batch_size, drop_last=True,
  num_workers=config.num_workers,
  )

# Create dataloaders for DeepRepViz() with no shuffle
drv_train_dataset = {
    'dataloader_kwargs': dict(img_dir=DATA_DIR+'/images',
                                img_names=df_train[ID_COL].values,
                                labels=df_train[LABEL_COL].values,
                                transform=transforms.ToTensor()),
    "expected_IDs":df_train[ID_COL].values,
    "expected_labels":df_train[LABEL_COL].values,
    }
drv_test_datasets = {
    'val': {
          'dataloader_kwargs': dict(img_dir=DATA_DIR+'/images',
                                    img_names=df_val[ID_COL].values,
                                    labels=df_val[LABEL_COL].values,
                                    transform=transforms.ToTensor()),
          "expected_IDs":df_val[ID_COL].values,
          "expected_labels":df_val[LABEL_COL].values
          },
    'test': {
          'dataloader_kwargs': dict(img_dir=DATA_DIR+'/images',
                                    img_names=df_test[ID_COL].values,
                                    labels=df_test[LABEL_COL].values,
                                    transform=transforms.ToTensor()),
          "expected_IDs":df_test[ID_COL].values,
          "expected_labels":df_test[LABEL_COL].values
          }
      }
# Append any additional test datasets provided too
for testdata_name, testdata_path in config.additional_drv_test_data.items():
  testdata_csv = glob(testdata_path + '/toybrains*.csv')
  assert len(testdata_csv)==1, f"Toybrains Test dataset found = {testdata_csv} in the path {testdata_path} .."
  testdata_df = pd.read_csv(testdata_csv[0])

  drv_test_datasets[testdata_name] = {
        'dataloader_kwargs': dict(img_dir=testdata_path+'/images',
                                  img_names=testdata_df[ID_COL].values,
                                  labels=testdata_df[LABEL_COL].values,
                                  transform=transforms.ToTensor()),
        "expected_IDs":testdata_df[ID_COL].values,
        "expected_labels":testdata_df[LABEL_COL].values
  }

In [None]:
# Load model
model = config.model_class(**config.model_kwargs)
lightning_model = LightningModel(model, learning_rate=config.learning_rate,
                                   num_classes=config.model_kwargs['num_classes'])

# Configure TensorBoardLogger as the main logger
# Create a unique name for the logs based on the dataset, model and user provided suffix
if unique_name != '': unique_name = '_' + unique_name
unique_name = f'{dataset_name}_{config.model_class.__name__}{unique_name}'
logger = TensorBoardLogger(save_dir='log', name=unique_name, version=trial)
if config.additional_loggers: # plus, any additional user provided loggers
  logger = [logger] + config.additional_loggers

print(f"pytorch_total_params = {get_param_count(model)}")
print(get_all_model_layers(model))

In [None]:
## DeepRepViz Callback
# Initialize DeepRepViz callback
drv = DeepRepViz(dataloader_class=ToyBrainsDataloader,
                 dataset_kwargs=drv_train_dataset,
                 datasets_kwargs_test=drv_test_datasets,
                 hook_layer=-1,
                 best_ckpt_by='loss_val', best_ckpt_metric_should_be='min',
                 verbose=int(config.debug))

callbacks = config.additional_callbacks + [drv]
# Add any other callbacks
if config.early_stop_patience:
  callbacks.append(EarlyStopping(monitor="val_loss", mode="min",
                                 patience=config.early_stop_patience))

In [None]:
# Train model
trainer = L.Trainer(callbacks=callbacks,
                      logger=logger,
                      overfit_batches = 5 if config.debug else 0,
                      log_every_n_steps= 2 if config.debug else 50,
                      **config.trainer_args) # deterministic=True
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader)

In [None]:
## DeepRepViz Backend
# Load the raw dataset table and the training logs into the backend
raw_csv_path = glob(f'{DATA_DIR}/*{dataset_name}.csv')[0]

df_data = pd.read_csv(raw_csv_path)
drv_backend = DeepRepVizBackend(
    conf_table=df_data,
    best_ckpt_by='loss_val',
    ID_col=ID_COL, label_col=LABEL_COL)

log_dir = trainer.log_dir + '/deeprepvizlog/'
drv_backend.load_log(log_dir)

# Downsample the activations to 3D if not already done
drv_backend.downsample_activations()

# Create the DeepRepViz v1 table only if requested in the config
if config.gen_v1_table:
  drv_backend.convert_log_to_v1_table(log_key=log_dir, unique_name=unique_name)
drv_backend.debug = False

metrics = ['dcor', 'mi', 'con', 'costeta', 'r2']
existing_metrics = drv_backend.get_metrics(log_dir, ckpt_idx='best')

if existing_metrics is not None:
  metrics = [m for m in metrics if m not in existing_metrics]
  print(f"Skipping {list(existing_metrics.keys())} for {log_dir} because they have already been computed.")

# Compute and store the metrics in the metametadata.json file of the log_dir
if len(metrics) > 0:
  result = drv_backend.compute_metrics(log_key=log_dir,
                                       metrics=metrics,
                                       #   covariates=['lbl_lesion','cov_site', 'brain-int_fill','shape-midr_curv', 'shape-midr_vol-rad'],
                                       ckpt_idx='best')

total_time = datetime.now() - start_time
minutes, seconds = divmod(total_time.total_seconds(), 60)
print(f"Total runtime: {int(minutes)} minutes {int(seconds)} seconds")

In the `log/your-dataset-name_your-model-name/trial_*/deeprepvizlog/` folder, locate the `DeepRepViz-*.csv` file. After downloading it, upload the file to the web-based visualization tool. For detailed instructions, visit the [documentation page](https://deep-rep-viz.vercel.app/docs.html).

In [None]:
# Summarize the result
all_results = {}
logdirs = sorted([log for log in glob(f"log/toybrains_*/*/deeprepvizlog/") if 'debug' not in log])
print(logdirs)

for logdir in logdirs:
    print("loading:", logdir)
    drv_backend = DeepRepVizBackend()
    drv_backend.load_log(logdir)
    log = drv_backend.deeprepvizlogs[logdir]
    ckpt_idx = log['best_ckpt_idx']
    ckptname, log_ckpt = log['checkpoints'][ckpt_idx]
    logdirname = logdir.split('/')[-4].replace('toybrains-','')
    model_setting = logdirname.split('_')[-1]
    logdirname = logdirname.replace('_'+model_setting, '')
    # print('='*100,'\n', method_name, "at ckpt =", ckptname)
    # print(log.keys())
    # print("Model accuracy =", log_ckpt['metrics'])
    result = {("Model",k): v for k,v in log_ckpt['metrics'].items()}
    for metric_name, metric_scores in log_ckpt['act_metrics'].items():
        # print('-'*100,"\nMetric =", metric_name, '\n', '-'*100,)
        for key in ['lbl_lesion', 'cov_site', 'brain-int_fill', 'shape-midr_curv', 'shape-midr_vol-rad']:
            result.update({(key, metric_name): metric_scores[key]})
            # print("{} = {:.4f}".format(key, metric_scores[key]))

    all_results.update({(model_setting, logdirname): result})

df_results = pd.DataFrame.from_dict(all_results, orient='index')
# Sort the dataframe by the two levels of column headers
df_results = df_results.sort_index(axis=1, level=[0,1]).sort_index()
df_results.style.bar()

## References <a name="ref"></a>

> Rane, R.P., Kim, J., Umesha, A., Stark, D., Schulz, M.A., Ritter, K. (2024). **DeepRepViz: Identifying Potential Confounders in Deep Learning Model Predictions.** In *Medical Image Computing and Computer Assisted Intervention - MICCAI 2024*, pp. 186-196.

> Have a look at toybrains dataset and tutorial: https://github.com/RoshanRane/toybrains

## Acknowledgements and Funding <a name="ack&funding"></a>

> This project was inspired by Google Brain's [projector.tensorflow.org](https://projector.tensorflow.org/), but caters more to the medical domain and medical imaging analysis. For the implementation, we rely heavily on the [3D scatter plot from plotly.js](https://plotly.com/javascript/3d-scatter-plots/).

> This project was funded by the DeSBi Research Unit (DFG; KI-FOR 5363; Project ID 459422098), the consortium SFB/TRR 265 Losing and Regaining Control over Drug Intake (DFG; Project ID 402170461), FONDA (DFG; SFB 1404; Project ID: 414984028) and FOR 5187 (DFG; Project ID: 442075332).