# 𝔻𝕖𝕔𝕖𝕟𝕥ℕ𝕖𝕥: 𝕕𝕚𝕤𝕖𝕟𝕥𝕒𝕟𝕘𝕝𝕖𝕕 𝕟𝕖𝕥

Goal: create a sparse and modular ConvNet

Todos: 
* [ ] delete a node (filter) if it has no input edges or no output edges
* [ ] AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'
* [ ] CUDA error if one of the decent1x1 filters has no kernels left: each 1x1 filter needs at least one input
* [ ] can training continue after a filter is removed (e.g. when reloading the model)?
* [ ] filter removal needs to work in general; right now it only happens at reload
* [ ] currently commented out in data/octa500.py: `# img, msk = flt.execute() # flattened`
* [ ] make sure the code runs without the fda (Fourier) library


Notes:
* additionally needed: position, activated channels, connections between channels
* within this layer, a whole filter can be deactivated
* within a filter, single channels can be deactivated
* within this layer, filters can be swapped
* the 'value' in the csv file is random if the CI metric is 'random'
     
* PyTorch's built-in pruning doesn't actually reduce speed or memory: https://discuss.pytorch.org/t/pruning-doesnt-affect-speed-nor-memory-for-resnet-101/75814
* fine-tuning a pruned model: https://stackoverflow.com/questions/73103144/how-to-fine-tune-the-pruned-model-in-pytorch
* an actual pruning mechanism: https://arxiv.org/pdf/2002.08258.pdf
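The pruning note can be illustrated without PyTorch. `torch.nn.utils.prune` keeps the dense tensor: it stores `weight_orig` plus a `weight_mask` and multiplies them before each forward pass. As a result, the zeros still take up memory and still go through the convolution.

The NumPy sketch below compares that kind of masking with physically removing whole filters, which is what the todo list aims for. The layer shape is made up.

```python
import numpy as np

rng = np.random.default_rng(0)
# made-up conv weight: (out_channels, in_channels, kernel_h, kernel_w)
weight = rng.normal(size=(8, 4, 3, 3))

# mask-based (unstructured) pruning: zero the 50% smallest-magnitude weights
threshold = np.quantile(np.abs(weight), 0.5)
mask = (np.abs(weight) > threshold).astype(weight.dtype)
masked = weight * mask

# structural pruning: drop the 4 output filters with the smallest L2 norm
norms = np.sqrt((weight ** 2).sum(axis=(1, 2, 3)))
keep = np.sort(np.argsort(norms)[4:])
compact = weight[keep]

print(masked.shape, int((masked == 0).sum()))  # dense shape unchanged, half the entries are zero
print(compact.shape)                            # fewer filters, so fewer stored and multiplied weights
```

In a real network, the next layer's matching input channels have to be dropped as well.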

pip install:
* pytorch_lightning

preprocessing possible:
* flatten layers
* denoise
* crop background


warnings:
C:\Users\Christina\anaconda3\envs\chrisy\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\result.py:211: You called `self.log('unpruned', ...)` in your `on_train_epoch_end` but the value needs to be floating point. Converting it to torch.float32.
C:\Users\Christina\anaconda3\envs\chrisy\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\result.py:211: You called `self.log('unpruned_state', ...)` in your `on_train_epoch_end` but the value needs to be floating point. Converting it to torch.float32.

![uml of code](examples/example_vis/uml.png)

# conventions

`id` is the image id if available, otherwise the batch id

* entry image and mask: entry_id5_0_0_0_mo3_gt2.png
* mat: mat_id10004_size26_0_0_0_mo2_gt2.mat (size - 
* hidden layer: hid_id5_3_8_2.png 
* last layer: pool_2_3_4_cl2.png (global pooling - connected to class n, cl=class)
* activated image: cam_id5_mo3_gt2.png
* activated image gray: camgray_id5_mo3_gt2.png


* circle in: in_2_3_4_ep65.png
* circle out: out_2_3_4_ep65.png

* filter: filter_2_3_4.csv and filter_2_3_4.png
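The names above follow a common pattern:

* a prefix;
* tagged numbers such as `id5`, `mo3`, `gt2`, `ep65`, `cl2` or `size26`;
* bare numbers.

A small parser can split them apart. `parse_name` is a hypothetical helper, not part of the repository. What the unnamed numbers stand for (e.g. the `3_8_2` in `hid_id5_3_8_2.png`) is not spelled out here, so the sketch just keeps them as an ordered list.

```python
import os
import re

# a lowercase tag directly followed by digits, e.g. "id5", "mo3", "ep65", "size26"
TAG = re.compile(r"([a-z]+)(\d+)$")

def parse_name(filename):
    """Split a file name following the conventions above into (prefix, tags, positions)."""
    stem = os.path.splitext(os.path.basename(filename))[0]
    prefix, *tokens = stem.split("_")
    tags, positions = {}, []
    for token in tokens:
        if token.isdigit():
            positions.append(int(token))  # unnamed numeric field, kept in order
        else:
            m = TAG.match(token)
            if m is None:
                raise ValueError(f"unexpected token {token!r} in {filename!r}")
            tags[m.group(1)] = int(m.group(2))
    return prefix, tags, positions

print(parse_name("entry_id5_0_0_0_mo3_gt2.png"))
print(parse_name("pool_2_3_4_cl2.png"))
```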

# imports

In [1]:
# =============================================================================
# future imports first
# =============================================================================
from __future__ import print_function
# =============================================================================
# sys
# =============================================================================
import sys 
sys.path.insert(0, "helper")
# =============================================================================
# alphabetic order misc
# =============================================================================
import glob
import math
import matplotlib.pyplot as plt
plt.ioff()
import seaborn as sns
import numpy as np
import os
import pandas as pd
from PIL import Image
import random
import scipy.io
# from sklearn.model_selection import train_test_split
import warnings
# =============================================================================
# torch
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
# import torchvision
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# from pytorch_lightning.callbacks.model_checkpoint import *
# =============================================================================
# datasceyence
# =============================================================================
from helper.model.decentnet import DecentNet
from helper.visualisation import filter_activation
from helper.visualisation.colour import *
from helper.data.mnist import DataLoaderMNIST
from helper.data.retinamnist import DataLoaderRetinaMNIST
from helper.data.octmnist import DataLoaderOCTMNIST
from helper.data.octa500 import DataLoaderOCTA500
from helper.data.organmnist3D import DataLoaderOrganMNIST3D
from data.transform.octa500_resize import *



In [2]:
seed = 1997 # was 19 before

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

print("cuda available:", torch.cuda.is_available())

debug_mode = False # todo - enables additional debug prints

print('torch 2.0.0 ==', torch.__version__=='2.0.0')
print('tl 2.1.0 ==', pl.__version__=='2.1.0')

cuda available: True
torch 2.0.0 == False
tl 2.1.0 == False


# Settings

In [3]:
model_kwargs = {
    'in_channels' : 1, # not in use yet
    'n_classes': None, # filled in the dataset
    'out_dim' :  [1, 8, 16, 32], # [1, 8, 16, 32], #[1, 16, 24, 32] # entry, decent1, decent2, decent3
    'grid_size' : 18*18,
    'criterion' : torch.nn.CrossEntropyLoss(), # torch.nn.BCEWithLogitsLoss(),
    'new_cc_mode' : True, # this is for using the new connection cost loss term
    'reset_optimiser_at_update' : True, # needs to be reset when pruning # not needed anymore
    'optimizer': "sgd", # sgd adamw
    'base_lr': 0.1, #0.001,
    'min_lr' : 0.001, #0.00001,
    'momentum' : 0.9,
    'lr_update' : 100,
    # decentnet
    'cc_weight': 5, # high weight as the cc doesn't change a lot
    'cc_metric' : 'l2', # connection cost metric (for loss) - distance metric # no idea how the torch works oops
    'ci_metric' : 'l2', # todo: should be l2 # channel importance metric (for pruning)
    'cm_metric' : 'not implemented yet', # 'count', # crossing minimisation 
    'update_every_nth_epoch' : 3, # 5 # todo - remove from code
    'pretrain_epochs' : 3, # 20 # todo - remove from code
    'prune_keep' : 0.95, # 0.97, # in each epoch
    'prune_keep_total' : 0.4, # this number is not exact, depends on the prune_keep value
}

train_kwargs = {
    'input_data_csv': ["data_prep/data_octa500.csv"],
    'result_path': "examples/example_results", # "example_results/lightning_logs", # not in use??
    'exp_name': "counter_seems_to_work", # must include dataset name, otherwise mnist is used
    'load_ckpt_file' : "version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt", # "version_0/checkpoints/mu_epoch=14-val_f1_macro=0.41-unpruned=5676.ckpt" # "version_0/checkpoints/epoch=94-unpruned=1600-val_f1=0.67.ckpt", # 'version_94/checkpoints/epoch=26-step=1080.ckpt', # change this for loading a file and using "test", if you want training, keep None
    'load_mode' : True, # True, False
    'dataset' : 'octa500',
    'epochs': None,
    'training_stages': [1] + [5]*2 + [2], # [epochs] pretrain + [epochs between pruning] * pruning amount + [epochs] after-training
    'img_size' : 28, #168, # keep MNIST at its original size, training didn't work when I increased the size ... # MNIST/MedMNIST 28 × 28 pixels
    'p_augment' : 0.2, # probability of torchvision transforms on the training data (doesn't apply to all transforms) # 0.1 low, 0.5 half, 1 always
    'batch_size': 8, # laptop: 2, pc: 128, # the higher the batch_size the faster the training - every iteration adds A LOT OF comp cost
    'log_every_n_steps' : 50, # lightning default: 50 # should not exceed the number of steps per epoch (train_size / batch_size), otherwise Lightning warns about the logging interval
    'device': "cuda",
    'num_workers' : 0, # 18, # 18 for seri computer, 0 or 8 for my laptop # make sure it is smaller than the active dataset sizes
    'train_size' : 5000, # total, none = 0, all = -1  (batch size * forward passes per epoch) # set 0 to skip training and just do testing
    'val_size' : 100, # total, none = 0, all = -1 (batch size * forward passes per epoch) 
    'test_size' : 100, # total, none = 0, all = -1 (batch size * forward passes per epoch)
    'octa500_id' : 200-1, # not in use - we use preselected data from a csv
    'xai_done' : False, # DO NOT CHANGE, WILL BE CHANGED IN CODE
}

print("train kwargs", train_kwargs)
print("model kwargs", model_kwargs)

kwargs = {'train_kwargs':train_kwargs, 'model_kwargs':model_kwargs}

train kwargs {'input_data_csv': ['data_prep/data_octa500.csv'], 'result_path': 'examples/example_results', 'exp_name': 'counter_seems_to_work', 'load_ckpt_file': 'version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt', 'load_mode': True, 'dataset': 'octa500', 'epochs': None, 'training_stages': [1, 5, 5, 2], 'img_size': 28, 'p_augment': 0.2, 'batch_size': 8, 'log_every_n_steps': 50, 'device': 'cuda', 'num_workers': 0, 'train_size': 5000, 'val_size': 100, 'test_size': 100, 'octa500_id': 199, 'xai_done': False}
model kwargs {'in_channels': 1, 'n_classes': None, 'out_dim': [1, 8, 16, 32], 'grid_size': 324, 'criterion': CrossEntropyLoss(), 'new_cc_mode': True, 'reset_optimiser_at_update': True, 'optimizer': 'sgd', 'base_lr': 0.1, 'min_lr': 0.001, 'momentum': 0.9, 'lr_update': 100, 'cc_weight': 5, 'cc_metric': 'l2', 'ci_metric': 'l2', 'cm_metric': 'not implemented yet', 'update_every_nth_epoch': 3, 'pretrain_epochs': 3, 'prune_keep': 0.95, 'prune_keep_total': 0
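With `'optimizer': "sgd"`, `configure_optimizers` pairs SGD with `CosineAnnealingWarmRestarts(T_0=lr_update, eta_min=min_lr)`. PyTorch documents the rate as `lr = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2`. With the default `T_mult=1`, every cycle has length `T_0`.

The sketch evaluates that formula for `base_lr=0.1`, `min_lr=0.001` and `lr_update=100`. `cosine_warm_restart_lr` is a hypothetical helper written for illustration.

```python
import math

def cosine_warm_restart_lr(t, base_lr=0.1, min_lr=0.001, t_0=100):
    """Learning rate after t scheduler steps, using the documented
    CosineAnnealingWarmRestarts formula with T_mult=1 (restart every t_0 steps)."""
    t_cur = t % t_0  # steps since the last restart
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_0)) / 2

for t in (0, 25, 50, 99, 100):
    print(t, round(cosine_warm_restart_lr(t), 4))
```

Lightning steps a scheduler returned from `configure_optimizers` once per epoch by default, so `t` counts epochs here unless the scheduler interval is changed.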

## check the values

In [4]:
# todo: check where the 6000 comes from; it should be calculated from the model

breaking = 6000*model_kwargs['prune_keep_total']
weights = 6000 # this value is an estimate for a model [1, 8, 16, 32]
# 'unpruned' is the logger variable for the value

pairs = []
print("weights that stay after epoch")
for i in range(len(train_kwargs['training_stages'])):
    
    if (weights < breaking): # weights*model_kwargs['prune_keep']
        print("stop:", breaking)
        print('you need at least this many epochs:', i)
        print('you currently have this many epochs:', len(train_kwargs['training_stages']))
        print("recommended to add 2*update_every_nth_epoch")
        break
    
    # not sure whether -1 is correct, have to check
    #if i >= model_kwargs['pretrain_epochs'] and ((i-1)%model_kwargs['update_every_nth_epoch'] == 0):
    if i >= train_kwargs['training_stages'][0] and i< train_kwargs['training_stages'][-1] and weights > breaking: #  and ((i-1)%model_kwargs['update_every_nth_epoch'] == 0):
        weights = int(weights*model_kwargs['prune_keep'])
    else:
        pass
    
    print(i, weights)
    pairs.append((i, weights))


print(f"First 5 pairs: {pairs[:5]}")
print(f"Last 5 pairs: {pairs[-5:]}")
print(f"Total pairs: {len(pairs)}")

# print(f"Min i: {min([i for i, a in your_data])}, Max i: {max([i for i, a in your_data])}")
    

weights that stay after epoch
0 6000
1 5700
2 5700
3 5700
First 5 pairs: [(0, 6000), (1, 5700), (2, 5700), (3, 5700)]
Last 5 pairs: [(0, 6000), (1, 5700), (2, 5700), (3, 5700)]
Total pairs: 4
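Assume each pruning round keeps a fraction `prune_keep` of the remaining weights, as the loop above models it. The loop then stops once `prune_keep**k` drops below `prune_keep_total`. That takes about `k = ceil(log(prune_keep_total) / log(prune_keep))` rounds, ignoring the `int()` truncation the loop applies.

`rounds_needed` is a hypothetical helper. With 0.95 and 0.4 it needs 18 rounds, while the loop above only ran over the 4 entries of `training_stages` and never reached the stop value.

```python
import math

def rounds_needed(prune_keep, prune_keep_total):
    """Smallest k with prune_keep**k < prune_keep_total (the loop's stop condition)."""
    if prune_keep_total >= 1:
        return 0
    k = math.ceil(math.log(prune_keep_total) / math.log(prune_keep))
    # correct for floating-point error right at the boundary
    while prune_keep ** k >= prune_keep_total:
        k += 1
    while k > 0 and prune_keep ** (k - 1) < prune_keep_total:
        k -= 1
    return k

rounds = rounds_needed(0.95, 0.4)
print(rounds, round(6000 * 0.95 ** rounds))  # rounds and remaining weights of the 6000 estimate
```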


# Data

## Dataset
* the dataset name needs to be part of the experiment name

In [11]:
if 'octmnist' in train_kwargs['dataset']:
    # OCTMNIST + class weights for the loss, because the data is heavily imbalanced
    dataloader = DataLoaderOCTMNIST(train_kwargs, model_kwargs)  
    print("")
    all_labels = []
    # Extract all labels from the DataLoader
    for inputs, labels in dataloader.train_dataloader:
        all_labels.append(labels.flatten())
    # Concatenate all labels into a single tensor
    
    all_labels = torch.cat(all_labels)
    sorted_labels, sorted_indices = torch.sort(all_labels)
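    # Count the occurrences of each class; minlength keeps one entry per class even if the highest labels never occur
    class_counts = torch.bincount(sorted_labels, minlength=model_kwargs['n_classes'] or 0)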
    # Calculate weights (inverse of class frequency)
    class_weights = 1.0 / class_counts.float()
    # Normalize weights (optional, but recommended for stability)
    class_weights = class_weights / class_weights.sum()
    print("class_counts", class_counts, "class_weights:", class_weights) 
    if torch.isnan(class_weights).any(): 
        print("DECENT INFO: dataset too small, no weighting used") 
    else:
        model_kwargs["criterion"] = torch.nn.CrossEntropyLoss(weight=class_weights)
    
    # class_mapper = ['cnv', 'dr', 'amd', 'healthy']
elif 'retinamnist' in train_kwargs['dataset']:
    # RetinaMNIST
    dataloader = DataLoaderRetinaMNIST(train_kwargs, model_kwargs)
    
elif 'octa500' in train_kwargs['dataset']:
    # OCTA-500
    dataloader = DataLoaderOCTA500(train_kwargs, model_kwargs)
elif '3d' in train_kwargs['dataset']:
    dataloader = DataLoaderOrganMNIST3D(train_kwargs, model_kwargs)
else:
    raise ValueError(f"DECENT ERROR: select a valid dataset, got {train_kwargs['dataset']!r}")
    
class_mapper = dataloader.info["label"]

train
Empty DataFrame
Columns: [img_id, img_path, msk_path, mode, lbl_disease, sex, os-od, age]
Index: []
0
val
Empty DataFrame
Columns: [img_id, img_path, msk_path, mode, lbl_disease, sex, os-od, age]
Index: []
0
test
     img_id                                           img_path  \
0     10001  C://Users/Prinzessin/projects/decentnet/datasc...   
1     10004  C://Users/Prinzessin/projects/decentnet/datasc...   
2     10005  C://Users/Prinzessin/projects/decentnet/datasc...   
3     10007  C://Users/Prinzessin/projects/decentnet/datasc...   
4     10008  C://Users/Prinzessin/projects/decentnet/datasc...   
..      ...                                                ...   
175   10288  C://Users/Prinzessin/projects/decentnet/datasc...   
176   10290  C://Users/Prinzessin/projects/decentnet/datasc...   
177   10291  C://Users/Prinzessin/projects/decentnet/datasc...   
178   10297  C://Users/Prinzessin/projects/decentnet/datasc...   
179   10300  C://Users/Prinzessin/projects/decentnet/da
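Two details explain the NaN check in the OCTMNIST branch:

* In torch, a class with a zero count gives `1/0 = inf`, and normalising turns that entry into `inf/inf = NaN`.
* `torch.bincount` only returns `max_label + 1` entries unless `minlength` is given.

The plain-Python sketch below makes both cases explicit. `inverse_frequency_weights` is a hypothetical helper; it returns `None` where the notebook falls back to "no weighting used".

```python
from collections import Counter

def inverse_frequency_weights(labels, n_classes):
    """Normalised inverse-frequency class weights; None if any class is missing."""
    counts = Counter(labels)
    class_counts = [counts.get(c, 0) for c in range(n_classes)]  # one entry per class
    if 0 in class_counts:
        return None  # the torch version would produce NaN here
    inv = [1.0 / n for n in class_counts]
    total = sum(inv)
    return [w / total for w in inv]

weights = inverse_frequency_weights([0, 0, 0, 1, 2, 2, 3, 3, 3, 3], n_classes=4)
print([round(w, 3) for w in weights])
```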

In [12]:
assert model_kwargs['n_classes'] is not None, "DECENT ERROR: make sure you set the n_classes with the dataset"  
print("n_classes:", model_kwargs['n_classes'])

print("train_size", train_kwargs["train_size"])
print("val_size", train_kwargs["val_size"])
print("test_size", train_kwargs["test_size"])

n_classes: 4
train_size 0
val_size 0
test_size 100


## X (Datatype)

In [13]:
class X:
    # =============================================================================
    #
    # an object with image representations and their positions
    # the number of channels must equal the length of the ms_x and ns_x lists
    #
    # =============================================================================
    
    def __init__(self, data, ms_x, ns_x):
        self.data = data # list of tensors (image representations)
        self.ms_x = ms_x # list of integers (m position of each image representation)
        self.ns_x = ns_x # list of integers (n position of each image representation)
                
    def setter(self, data, ms_x, ns_x):
        self.data = data
        self.ms_x = ms_x
        self.ns_x = ns_x
        
    def getter(self):
        return self.data, self.ms_x, self.ns_x
    
    def __str__(self):
        return 'X(data: ' + str(self.data.shape) +' at positions: ms_x= ' + ', '.join(str(m.item()) for m in self.ms_x) + ', ns_x= ' + ', '.join(str(n.item()) for n in self.ns_x) + ')'
    __repr__ = __str__
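The header of `X` states an invariant: one (m, n) position per channel. A construction-time check can enforce it. `XSketch` is a hypothetical variant, not the class above.

It assumes `data` uses the (batch, channels, height, width) layout that PyTorch convolutions expect. It also uses NumPy instead of torch so it runs standalone.

```python
from dataclasses import dataclass
from typing import List

import numpy as np

@dataclass
class XSketch:
    data: np.ndarray  # assumed layout: (batch, channels, height, width)
    ms_x: List[int]   # m position of each channel
    ns_x: List[int]   # n position of each channel

    def __post_init__(self):
        n_channels = self.data.shape[1]
        if not (len(self.ms_x) == len(self.ns_x) == n_channels):
            raise ValueError(
                f"{n_channels} channels but {len(self.ms_x)} m and {len(self.ns_x)} n positions"
            )

    def getter(self):
        return self.data, self.ms_x, self.ns_x

x = XSketch(np.zeros((2, 3, 28, 28)), ms_x=[0, 4, 9], ns_x=[1, 1, 7])
print(x.getter()[1:])
```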

  

# Lightning

## Callbacks

In [14]:
# in use
class SaveLastModelCheckpoint(ModelCheckpoint):
    
    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates) -> None:
        if not self.save_last:
            return
        
        # self.CHECKPOINT_NAME_LAST = 

        filepath = self.format_checkpoint_name(monitor_candidates)
        # only rename the file name ("mf" -> "last"), not directories that may contain "mf"
        dirname, basename = os.path.split(filepath)
        last_filepath = os.path.join(dirname, basename.replace("mf", "last", 1))
        
        print("last filepath", filepath)
        print("+ last", last_filepath)
        
        # examples/example_results\lightning_logs\trying_counter\version_6\checkpoints\mf_stage=0.0_counter=1.0_val_f1_macro=0.13_unpruned=5980.ckpt
        
        # everything after "_counter=" (minus the extension) changes between epochs of a stage
        replace_part = basename.split("_counter=")[1].replace(".ckpt", "")
        print("replace_part", replace_part)
        
        # remove all previous "last" checkpoints of the current stage
        tmp = last_filepath.replace(replace_part, "*")
        print("tmp", tmp)
        for previous in glob.glob(tmp):
            print("prev", previous)
            self._remove_checkpoint(trainer, previous)
        
        self._save_checkpoint(trainer, last_filepath)
        
        # if there is one with last and same stage - delete

        #if self._enable_version_counter:
            #version_cnt = self.STARTING_VERSION
            #while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
                #filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
                #version_cnt += 1

        # set the last model path before saving because it will be part of the state.
        #previous, self.last_model_path = self.last_model_path, filepath
        #if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
        #    self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
        #else:
            
        #if previous and self._should_remove_checkpoint(trainer, previous, filepath):
            #self._remove_checkpoint(trainer, previous)
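`SaveLastModelCheckpoint` should keep exactly one `last_...` checkpoint per stage. That naming and globbing logic can be tested without Lightning.

In the sketch, `save_last_for_stage` is a hypothetical stand-in: it writes empty placeholder files instead of calling `_save_checkpoint`/`_remove_checkpoint`. Only the first file name comes from the comment in the callback; the other two are made up following the same pattern.

```python
import glob
import os
import tempfile

def save_last_for_stage(ckpt_dir, mf_name):
    """Rename "mf_..." to "last_...", delete older "last_..." files of the same stage, write a placeholder."""
    last_name = mf_name.replace("mf", "last", 1)   # rename the file name only
    stage_prefix = last_name.split("_counter=")[0]  # e.g. "last_stage=0.0"
    # escape the directory and prefix so only the counter part acts as a wildcard
    pattern = glob.escape(os.path.join(ckpt_dir, stage_prefix)) + "_counter=*.ckpt"
    for previous in glob.glob(pattern):
        os.remove(previous)
    path = os.path.join(ckpt_dir, last_name)
    open(path, "w").close()  # empty placeholder instead of a real checkpoint
    return path

with tempfile.TemporaryDirectory() as d:
    save_last_for_stage(d, "mf_stage=0.0_counter=1.0_val_f1_macro=0.13_unpruned=5980.ckpt")
    save_last_for_stage(d, "mf_stage=0.0_counter=2.0_val_f1_macro=0.20_unpruned=5980.ckpt")
    save_last_for_stage(d, "mf_stage=1.0_counter=3.0_val_f1_macro=0.25_unpruned=5681.ckpt")
    remaining = sorted(os.listdir(d))
print(remaining)
```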

In [15]:
# not in use 
class EndModelCheckpoint(ModelCheckpoint):
    
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # =============================================================================
        # custom model checkpoint 
        # if unpruned state != -1
        # Save a checkpoint at the end of a defined training epoch.
        # parameters:
        #    trainer
        #    module
        # saves:
        #    the checkpoint model
        # sources:
        #    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/callbacks/model_checkpoint.py
        # =============================================================================
        
        monitor_candidates = self._monitor_candidates(trainer)
        monitor_candidates["epoch"] = monitor_candidates["epoch"]

        self._save_last_checkpoint(trainer, monitor_candidates)
        pl_module.model.get_everything(counter=trainer.current_epoch)

In [16]:
# not working

from pytorch_lightning.callbacks import Callback
class OptimiserUpdateCheckpoint(Callback):
    
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if pl_module.reset_optimiser_at_update: # todo - something missing here!!
            print("update optimiser")
            
            # print(pl_module.model.parameters())
            
            new_optimizers = optim.SGD(pl_module.model.parameters(), lr=pl_module.base_lr, momentum=pl_module.momentum)
            trainer.optimizers = [new_optimizers]
            
            
            # trainer.lr_schedulers = trainer.configure_schedulers([new_schedulers])
            # self.model.parameters
            
            #trainer.optimizers[0] = new_optimizers
            
            #print("* params begin "*10)
            #for param in pl_module.model.parameters():
            #    print(param)
            #    break
            #print("-"*10)    

In [17]:
# not in use anymore

class DecentModelCheckpoint(ModelCheckpoint):
    
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # =============================================================================
        # custom model checkpoint 
        # if unpruned state != -1
        # Save a checkpoint at the end of a defined training epoch.
        # parameters:
        #    trainer
        #    module
        # saves:
        #    the checkpoint model
        # sources:
        #    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/callbacks/model_checkpoint.py
        # =============================================================================
        
        # when pruning, then save model!
        if (
            not self._should_skip_saving_checkpoint(trainer) 
            and self._should_save_on_train_epoch_end(trainer)
        ):
            monitor_candidates = self._monitor_candidates(trainer)
            monitor_candidates["epoch"] = monitor_candidates["epoch"]
            print("DECENT NOTE: callback on_train_epoch_end", monitor_candidates["epoch"].item())
            if monitor_candidates["epoch"] > 0:
                if monitor_candidates["unpruned_state"] != -1:
                    print("DECENT NOTE: save model", monitor_candidates["epoch"].item())
                    if self._every_n_epochs >= 1 and ((trainer.current_epoch + 1) % self._every_n_epochs) == 0:
                        self._save_topk_checkpoint(trainer, monitor_candidates)
                        
                    self._save_last_checkpoint(trainer, monitor_candidates)
                    
                    pl_module.model.get_everything(counter=trainer.current_epoch)


## LightningModule

In [18]:
class DecentLightning(pl.LightningModule):
    # =============================================================================
    #
    # Lightning Module consists of functions that define the training routine
    # train, val, test: before epoch, step, after epoch, ...
    # https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/core/module.py
    # order for the instance methods:
    # https://pytorch-lightning.readthedocs.io/en/1.7.2/common/lightning_module.html#hooks
    # 
    # =============================================================================

    def __init__(self, kwargs, log_dir, ckpt_path=None):
        super().__init__()
        
        # print("the kwargs: ", kwargs)
        
        # keep kwargs for saving hyperparameters
        model_kwargs = kwargs['model_kwargs']
        
        self.log_dir = log_dir
        
        print("ckpt_path:", ckpt_path)
        
        if kwargs['train_kwargs']["load_mode"]: # True, False
            # ckpt_path = os.path.join(log_dir, train_kwargs["load_ckpt_file"]).replace("_xAI", "")
            if ckpt_path is not None and os.path.isfile(ckpt_path):
                print(f"Found pretrained model at {ckpt_path}, loading...")
                self.model = DecentNet(model_kwargs=model_kwargs, log_dir=log_dir, ckpt_path=ckpt_path).to(kwargs['train_kwargs']["device"])
            else:
                print(f"DECENT NOTE: Did not find {ckpt_path}, create new model")
                # n_classes=self.n_classes, grid_size=self.grid_size, out_dim=self.out_dim, prune_keep=self.prune_keep, prune_keep_total=self.prune_keep_total, cc_metric=self.cc_metric
                self.model = DecentNet(model_kwargs=model_kwargs, log_dir=log_dir).to(kwargs['train_kwargs']["device"])
        else:
            print("Create new model")
            # n_classes=self.n_classes, grid_size=self.grid_size, out_dim=self.out_dim, prune_keep=self.prune_keep, prune_keep_total=self.prune_keep_total, cc_metric=self.cc_metric
            self.model = DecentNet(model_kwargs=model_kwargs, log_dir=log_dir).to(kwargs['train_kwargs']["device"])
            
        # print(self.model)
        
        self.n_classes = model_kwargs["n_classes"]
        self.cc_weight = model_kwargs["cc_weight"]
        self.criterion = model_kwargs["criterion"]
        self.optim = model_kwargs["optimizer"]
        self.base_lr = model_kwargs["base_lr"]
        self.min_lr = model_kwargs["min_lr"]
        self.lr_update = model_kwargs["lr_update"]
        self.momentum = model_kwargs["momentum"]
        self.update_every_nth_epoch = model_kwargs["update_every_nth_epoch"]
        self.pretrain_epochs = model_kwargs["pretrain_epochs"]
        self.image_size = kwargs['train_kwargs']["img_size"]
        self.new_cc_mode = kwargs["model_kwargs"]["new_cc_mode"]
        self.reset_optimiser_at_update = kwargs["model_kwargs"]["reset_optimiser_at_update"]
        
        self.cc_ci = torch.tensor([0]).to(kwargs["train_kwargs"]["device"])
    
        self.counter = 0
        
        # needed for hparams.yaml file
        self.save_hyperparameters()
        
        
        self.train_metrics = torchmetrics.MetricCollection(
            {
            "acc": torchmetrics.classification.MulticlassAccuracy(num_classes=self.n_classes),
            "f1_macro": torchmetrics.classification.MulticlassF1Score(num_classes=self.n_classes),
            "f1_micro" : torchmetrics.classification.MulticlassF1Score(num_classes=self.n_classes, average='micro'),
            "prec": torchmetrics.classification.MulticlassPrecision(num_classes=self.n_classes),
            "rec": torchmetrics.classification.MulticlassRecall(num_classes=self.n_classes),
            # "cm": torchmetrics.classification.MulticlassConfusionMatrix(num_classes=self.n_classes)
            }, prefix="train_",)
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

        self.cm = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=self.n_classes)
        self.roc_auc = torchmetrics.classification.MulticlassROC(num_classes=self.n_classes)
        self.pr_curve = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=self.n_classes)
        
        print("DECENT INFO: init done")        
                
    def forward(self, x, mode="grad"):
        # =============================================================================
        # we make it possible to use model_output = self(image)
        # =============================================================================
        return self.model(x, mode)
    
    def set_stage(self, i_prune_stage): #  i_stage, max_epochs):
        self.this_prune_stage = i_prune_stage
        #self.max_epochs = max_epochs
        #self.counter = prev_counter # i_stage * max_epochs # self.prune_stage * self.max_epochs + self.current_epoch

    # from here the fit starts    
    
    def prune(self):
        # does not work with on_fit_start
        
        if self.this_prune_stage < 1:
            return
        
        print("DECENT INFO: pruning now at the beginning")
        # pruning and save model

        # counter

        # update model
        # don't update unless pretrain epochs is reached
        #if (self.current_epoch % self.update_every_nth_epoch) == 0 and self.current_epoch >= self.pretrain_epochs:
        #    print("DECENT NOTE: update model", self.current_epoch)      

        if debug_mode:
            print("DECENT NOTE: before update")
            print("DECENT NOTE: print model ...")
            print(self.model)
        
        
        self.model.update(current_epoch = self.counter)

        #if self.reset_optimiser_at_update:
            # self.configure_optimizers() # reset optimisers due to the big change in the loss term + because of pruned parameters
            # self.trainer.accelerator_backend.setup_optimizers(self) xxxxxx
            #new_optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
            #self.optimizer.load_state_dict(new_optimizer.state_dict())  

            # self.trainer.accelerator.setup_optimizers(self)
            #self.trainer.strategy.setup_optimizers(self)
            
        if debug_mode:
            print("DECENT NOTE: model updated")  
    
    def configure_optimizers(self):
        # =============================================================================
        # returns:
        #    optimiser and lr scheduler
        # =============================================================================  
        print("DECENT NOTE: configure_optimizers")
        
        if self.optim == "adamw":
            adam = optim.AdamW(self.model.parameters(), lr=self.base_lr)
            multistep = optim.lr_scheduler.MultiStepLR(adam, milestones=[50,100], gamma=0.1)
            
            # first a list of optimisers, then a list of learning rate schedulers
            return [adam], [multistep]
        else:
            sgd = optim.SGD(self.model.parameters(), lr=self.base_lr, momentum=self.momentum)
            cosine = optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd,
                                                                    T_0 = self.lr_update, # number of iterations for the first restart.
                                                                    eta_min = self.min_lr
                                                                   )
            
        for param in self.model.parameters(): # not sure whether this is actually needed
            if param.grad is not None:
                param.grad.zero_()

        # first a list of optimisers, then a list of learning rate schedulers
        return [sgd], [cosine]
    
        
    def on_train_start(self):
        # =============================================================================
        # logging from first and then iteratively "previously pruned model" 
        # == our current new model
        # plot of circular layer - todo - make sure we get this into some other dir
        # todo - check whether the circular plots break the code if a node is gone
        # =============================================================================

        # numel: returns the total number of elements in the input tensor
        unpruned = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("unpruned", unpruned) # todo: log
        self.log('unpruned', float(unpruned), on_step=False, on_epoch=True) # neither ??
        
        self.log('stage', float(self.this_prune_stage), on_step=False, on_epoch=True)

        self.model.plot_incoming_connections(current_epoch=self.this_prune_stage)
        self.model.plot_outgoing_connections(current_epoch=self.this_prune_stage)
        
    def on_train_epoch_start(self):
        # =============================================================================
        # increments the epoch counter (kept across pruning stages) and logs it
        # =============================================================================  
        
        # absolutely never delete this line!!! 
        self.counter += 1 # = self.prune_stage * self.max_epochs + self.current_epoch
        self.log('counter', float(self.counter), on_step=False, on_epoch=True)

        self.log('train_decent2_first_weight_start', self.model.decent2.filter_list[0].weights[0].flatten()[0].item(), on_step=False, on_epoch=True)
        
        print("e_start d1:", self.model.decent1.filter_list[0].weights.flatten()[0])
        print("e_start d2:", self.model.decent2.filter_list[0].weights[0].flatten()[0])
                
        
    def on_validation_epoch_start(self): # Lightning hook, runs at the start of every validation epoch
        self.log('counter', float(self.counter), on_step=False, on_epoch=True)
    
    def training_step(self, batch, batch_idx):
        # =============================================================================
        # calculates loss for a batch
        # parameters:
        #    batch
        #    batch id
        # returns:
        #    loss
        # notes:
        #    calling gradcam like self.gradcam(batch) here is dangerous because it changes the gradients
        # =============================================================================     

        
        # calculate loss
        # loss = torch.tensor(1)
        loss = self.run_loss_n_metrics(batch, mode="train")
        
        
        # debugging messages!!
        if debug_mode:
            print("t_step d1:", self.model.decent1.filter_list[0].weights.flatten()[0])
            print("t_step d2:", self.model.decent2.filter_list[0].weights[0].flatten()[0])
        
        if False: # set to True to check the gradients - on 14.11.2024 they seemed fine
            # note: backward() runs after training_step returns, so these are not the gradients of the current batch
            print("next ********************************")
            ignored = []
            for i_p, param in enumerate(self.model.parameters()):
                if param.grad is not None:
                    print("++ para", i_p, " ", param.grad.shape, " ", param.grad.flatten()[0])
                else:
                    ignored.append(i_p)
            print("parameters without gradient:", len(ignored))
        
        return loss

    def validation_step(self, batch, batch_idx):
        # =============================================================================
        # calculate loss for logging # 2
        # =============================================================================
        if False: # batch_idx < 2:
            print("DECENT NOTE: validation_step", batch_idx)
        
        self.run_loss_n_metrics(batch, mode="val")
        
    def on_validation_epoch_end(self):
        # =============================================================================
        # currently nothing # 3
        # =============================================================================
        if debug_mode:
            print("DECENT NOTE: on_validation_epoch_end")
    
    def on_train_epoch_end(self):
        # =============================================================================
        # log weights and the number of trainable parameters # 4
        # (saving the model before the next pruning step is commented out below)
        # this has to run before the checkpoint callback - if Lightning changes its hook order, this will stop working
        # =============================================================================
        if debug_mode:
            print("DECENT NOTE: on_train_epoch_end", self.counter)
               
        if False:
            print("current epoch")
            print(((self.counter+1) % self.update_every_nth_epoch) == 0)
            print(self.counter+1)
            print(self.counter)
            print(self.update_every_nth_epoch)
        
        
        #if ((self.current_epoch+1) % self.update_every_nth_epoch) == 0 and self.current_epoch != 0:
            # if next epoch is an update, set unpruned flag            
            
            # self.log(f'unpruned_state', 1.0, on_step=False, on_epoch=True)
            
            # save file
            #with open(os.path.join(self.log_dir, 'logger.txt'), 'a') as f:
            #    f.write("\n# parameter requires grad shape #\n")
            #    for p in self.model.parameters():
            #        if p.requires_grad:
            #            f.write(str(p.shape))
            
            
        
        print("e_end d1:", self.model.decent1.filter_list[0].weights.flatten()[0])
        print("e_end d2:", self.model.decent2.filter_list[0].weights[0].flatten()[0])
        
        self.log('train_decent2_first_weight_end', self.model.decent2.filter_list[0].weights[0].flatten()[0].detach().cpu().numpy().item(), on_step=False, on_epoch=True)
        
        unpruned = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("unpruned", unpruned) # todo: log
        # self.log(f'unpruned', float(unpruned), on_step=False, on_epoch=True) # neither ??
        
    def on_train_end(self):
        pass

    def on_test_epoch_start(self):
        # =============================================================================
        # write the model structure to csv before testing
        # =============================================================================
        
        # helper/model/decentnet - get_everything() - creates csv
        self.model.get_everything(counter='final_test')
        
    def test_step(self, batch, batch_idx):
        # =============================================================================
        # calculate loss for logging, plot gradcam
        # =============================================================================
        if batch_idx < 2:
            print("DECENT NOTE: test_step", batch_idx)

        # run_loss_n_metrics also updates self.mo and self.gt (used in the file names below)
        self.run_loss_n_metrics(batch, mode="test")

        """
        with torch.enable_grad():
            grad_preds = preds.requires_grad_()
            preds2 = self.layer2(grad_preds)

        """
        
        # unpack the batch (with or without mask)
        
        if len(batch) == 4:
            # if mask is there
            img, _, msks, img_id = batch
            img_id = img_id.detach().cpu().item()
            tmp_b4 = True
        else:
            # if no mask
            img, _ = batch # image and label come out of this
            img_id = batch_idx
            msks = None
            tmp_b4 = False             
        
        # print(img.shape)
        
        # save image
        tmp_file_name = f'entry_id{img_id}_{0}_{0}_{0}_mo{self.mo}_gt{self.gt}.png'
        # tmp_img = self.feature_maps.squeeze()[i_map].cpu().detach().numpy()
        tmp_img = img.squeeze().cpu().detach().numpy()
        tmp_path = os.path.join(self.log_dir, "img_choice")
        os.makedirs(tmp_path, exist_ok=True)
        plt.imsave(os.path.join(tmp_path, tmp_file_name), tmp_img)
        plt.close()
        
        if tmp_b4:
            msks = msks.detach().cpu().numpy().squeeze()
            
            tmp_msk = msks[0] # 28
            
            # save mask
            plt.figure(figsize=(5, 5))
            for boundary in tmp_msk: # one line per layer boundary
                # plt.plot(list(range(len(layer))), layer)
                plt.plot(boundary[:,1]-0.5, boundary[:,0])
            plt.ylim(0, 28 - 1)
            plt.gca().invert_yaxis()
            #plt.axis('off')
            # Save the plot
            # plt.savefig('plot_without_axes.png', bbox_inches='tight', pad_inches=0)
            tmp_file_name = f'entry_id{img_id}_{0}_{0}_{0}_mo{self.mo}_gt{self.gt}.png'
            tmp_path = os.path.join(self.log_dir, "msk_choice")
            os.makedirs(tmp_path, exist_ok=True)
            plt.savefig(os.path.join(tmp_path, tmp_file_name), bbox_inches='tight', pad_inches=0)
            #plt.imsave(os.path.join(tmp_path, tmp_file_name), tmp_msk) # todo
            plt.close()

            # save image + mask (todo)
            plt.imshow(tmp_img, cmap="gray")
            for o, boundary in enumerate(tmp_msk):
                plt.plot(boundary[:,1]-0.5, boundary[:,0])
            tmp_file_name = f'entry_id{img_id}_{0}_{0}_{0}_mo{self.mo}_gt{self.gt}.png'
            tmp_path = os.path.join(self.log_dir, "img_with_msk")
            os.makedirs(tmp_path, exist_ok=True)
            plt.savefig(os.path.join(tmp_path, tmp_file_name), bbox_inches='tight', pad_inches=0)
            plt.close()
            
            # save mat file
            for msk, msk_size in zip(msks, [28,26,24,22]):
                # note: scipy.io.savemat skips keys starting with '_', so only 'Layer' is written
                tmp_mat = {'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Fri May 06 15:17:37 2022',
                     '__version__': '1.0',
                     '__globals__': [],
                     'Layer': msk
                            }

                tmp_file_name = f'mat_id{img_id}_size{msk_size}_{0}_{0}_{0}_mo{self.mo}_gt{self.gt}.mat'
                tmp_path = os.path.join(self.log_dir, "mat_transformed_choice")
                os.makedirs(tmp_path, exist_ok=True)
                scipy.io.savemat(file_name=os.path.join(tmp_path, tmp_file_name), mdict=tmp_mat)
            
            
        
            # plt.imsave(os.path.join(tmp_path, tmp_file_name), tmp_img)
        
        # save feature maps of hidden layers and the layer that gets globally pooled
        try:
            with torch.set_grad_enabled(True): # Grad-CAM needs gradients, even in test mode
                self.run_xai_gradcam(batch, batch_idx, mode='explain')
        except Exception as e:
            print("DECENT EXCEPTION: Grad-CAM failed (the batch size has to be 1)")
            print(e)
            
        with torch.set_grad_enabled(True):
            # save the feature maps of every decent layer
            # (note: the layer_str value seemed to make no difference - always the same output)
            for layer_str in ['decent1', 'decent2', 'decent3', 'decent1x1']:
                layer = getattr(self.model, layer_str)
                self.run_xai_feature_map(batch, batch_idx, layer, layer_str, device='cuda')
            
        # get filter list            
        filter_list = []
        for l in self.model.decent1.filter_list:
            filter_list.append(f"filter_{int(l.m_this)}_{int(l.n_this)}_{1}")
        for l in self.model.decent2.filter_list:
            filter_list.append(f"filter_{int(l.m_this)}_{int(l.n_this)}_{2}")
        for l in self.model.decent3.filter_list:
            filter_list.append(f"filter_{int(l.m_this)}_{int(l.n_this)}_{3}")
        for l in self.model.decent1x1.filter_list:
            filter_list.append(f"filter_{int(l.m_this)}_{int(l.n_this)}_{4}")
        df = pd.DataFrame(filter_list, columns=['filter'])
        df.to_csv(os.path.join(self.log_dir, "all_filters.csv"), index=False)
            
    def on_test_epoch_end(self):
        # =============================================================================
        # plot confusion matrix, precision-recall curve and ROC curve
        # =============================================================================
        if False:
            print("DECENT NOTE: on_test_epoch_end", self.counter)
        
        tmp_path = os.path.join(self.log_dir, "final_plots")
        os.makedirs(tmp_path, exist_ok=True)
        
        # confusion matrix
        cm = self.cm.compute()
        cm = cm.cpu().numpy()
        df_cm = pd.DataFrame(cm, index=list(class_mapper.values()), columns=list(class_mapper.values()))
        plt.figure(figsize=(8, 6))
        sns.heatmap(df_cm, annot=True, fmt="d", cmap="Blues", cbar=False) # df_cm adds the class names as tick labels
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.title("Confusion Matrix")
        # plt.gca().invert_yaxis() - should not be inverted
        plt.savefig(os.path.join(tmp_path, "confusion_matrix.png"), bbox_inches='tight', pad_inches=0)
        plt.close()
        
        # precision-recall curve
        pr_precision, pr_recall, pr_thresholds = self.pr_curve.compute()
        plt.figure(figsize=(10, 8))
        for i in range(self.n_classes):
            converted_label = class_mapper.get(str(i))
            plt.plot(pr_recall[i].cpu(), pr_precision[i].cpu(), label=f"{converted_label}", color=cnv_dr_amd_normal.colors[i]) 
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.title("Precision-Recall Curve")
        plt.legend(loc="lower left")
        plt.savefig(os.path.join(tmp_path, "precision_recall_curve.png"), bbox_inches='tight', pad_inches=0)
        plt.close()
        
        roc_fpr, roc_tpr, roc_thresholds = self.roc_auc.compute()
        plt.figure(figsize=(10, 8))
        for i in range(self.n_classes):
            converted_label = class_mapper.get(str(i))
            plt.plot(roc_fpr[i].cpu(), roc_tpr[i].cpu(), label=f"{converted_label}", color=cnv_dr_amd_normal.colors[i]) # could add the AUC here (e.g. torchmetrics.functional.auroc), but the preds are not available in this hook
            # plt.plot(roc_fpr[i].cpu(), roc_tpr[i].cpu(), label=f"Class {i} (AUC = {torchmetrics.functional.auroc(probs[:, i], target == i):.2f})")
        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC Curve")
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(tmp_path, "roc_auc_curve.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

        
    
    def run_xai_feature_map(self, batch, batch_idx, layer, layer_str, device='cuda'):
        # https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/5
 
        # img, label = testset.__getitem__(0) # batch x channel x width x height, class

        # img = X(img.to(device).unsqueeze(0), [torch.tensor(0)], [torch.tensor(0)])
                    
        if len(batch) == 4:
            img, ground_truth, msk, img_id = batch
            img_id = img_id.detach().cpu().item()
        else:
            img, ground_truth = batch # image and label come out of this
            img_id = batch_idx
            msk = None
            
        
        # make it an X object, init with position 0/0 as input for first layer
        tmp_img = X(img.to(device), [torch.tensor(0)], [torch.tensor(0)])

        # print(img.data.shape)

        # run feature map
        # model, layer, layer_str, log_dir, device="cpu"
        fm = filter_activation.DecentFilterActivation(model=self.model, layer=layer, layer_str=layer_str, log_dir=self.log_dir, device=device)
        fm.run(tmp_img, img_id)
        
        filter_list = fm.log()
        
        return filter_list
        
        
    
    def run_xai_gradcam(self, batch, batch_idx, mode='explain'):
        # =============================================================================
        # Grad-CAM: weights the activations of the last conv layer by the spatially averaged gradients of the predicted class score
        # todo error: RuntimeError: cannot register a hook on a tensor that doesn't require gradient
        # BATCH SIZE HAS TO BE ONE!!!
        # grad enable in test mode:
        # https://github.com/Project-MONAI/MONAI/discussions/1598
        # https://lightning.ai/docs/pytorch/stable/common/trainer.html
        # =============================================================================
    
        if len(batch) == 4:
            img, ground_truth, msk, img_id = batch
            img_id = img_id.detach().cpu().item()
        else:
            img, ground_truth = batch # image and label come out of this
            img_id = batch_idx
            msk = None

        # make it an X object, init with position 0/0 as input for first layer
        tmp_img1 = X(img.to("cuda"), [torch.tensor(0)], [torch.tensor(0)]) # .requires_grad_()
        tmp_img2 = X(img.to("cuda"), [torch.tensor(0)], [torch.tensor(0)])


        model_output = self(tmp_img1, mode)

        #print('c1', tmp_img1)
        #print('c2', tmp_img2)

        # get the gradient of the output with respect to the parameters of the model
        #pred[:, 386].backward()

        # get prediction value
        pred_max = model_output.argmax(dim=1)

        #print('d1', tmp_img1)

        #print("mo", model_output)
        #print("max", pred_max)
        #print("backprop", model_output[:, pred_max])

        # backpropagate for gradient tracking
        model_output[:, pred_max].backward()

        # pull the gradients out of the model
        gradients = self.model.get_activations_gradient()

        # pool the gradients across the channels
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

        #print('e2', tmp_img2)

        # get the activations of the last convolutional layer
        activations = self.model.get_activations(tmp_img2).detach()

        # weight the channels by corresponding gradients
        for i in range(activations.shape[1]): # one weight per channel of the activation map
            activations[:, i, :, :] *= pooled_gradients[i]

        # average the channels of the activations
        heatmap = torch.mean(activations, dim=1).squeeze()

        #print("hm", heatmap.shape)

        # ReLU on top of the heatmap (currently disabled)
        # expression (2) in https://arxiv.org/pdf/1610.02391.pdf
        #heatmap = torch.relu(heatmap)

        # normalize the heatmap
        #heatmap /= torch.max(heatmap)

        #print("hm", heatmap.shape)

        # draw the heatmap
        # plt.matshow(heatmap.detach().cpu().numpy().squeeze())
        # fig.savefig(os.path.join(self.log_dir, f"{self.ci_metric}_m{int(self.m_l2_plot[0])}_n{int(self.n_l2_plot[0])}_{str(current_epoch)}.png"))
        
        tmp_path = os.path.join(self.log_dir, "gradcam")
        os.makedirs(tmp_path, exist_ok=True)
        plt.imsave(os.path.join(tmp_path, 
                                f"cam_id{img_id}_mo{pred_max.detach().cpu().numpy().squeeze()}_gt{ground_truth.detach().cpu().numpy().squeeze()}.png"
                               ), heatmap.detach().cpu().numpy().squeeze())


        heatmap *= 255.0 / heatmap.max()
        pil_heatmap = Image.fromarray(heatmap.detach().cpu().numpy().squeeze()).convert('RGB')
       
        tmp_path = os.path.join(self.log_dir, "gradcam")
        os.makedirs(tmp_path, exist_ok=True)
        pil_heatmap.save(os.path.join(tmp_path, 
                                      f"camgray_id{img_id}_mo{pred_max.detach().cpu().numpy().squeeze()}_gt{ground_truth.detach().cpu().numpy().squeeze()}.png" 
                                     )) 
            
    def run_loss_n_metrics(self, batch, mode="train"):
        # =============================================================================
        # put image through model, calculate loss and metrics
        # use cc term that has been calculated previously
        # =============================================================================
        
        if len(batch) == 4:
            img, ground_truth, mask, img_id = batch
        else:
            img, ground_truth = batch
        
        # init with position 0/0 as input for first layer
        img = X(img.to("cuda"), [torch.tensor(0)], [torch.tensor(0)])
        
        model_output = self(img, mode) # calls self.forward(img, mode)
        
        # for test routine "test_step"
        self.mo = model_output.argmax(dim=1).squeeze().detach().cpu().numpy()
        self.gt = ground_truth.squeeze().detach().cpu().numpy()
        
        ground_truth = ground_truth.squeeze()
        if len(ground_truth.shape) < 1:
            ground_truth = ground_truth.unsqueeze(0)
        ce_loss = self.criterion(model_output, ground_truth.long()) # ground_truth_multi_hot)
        
        
        # the cc/ci term only works with the new connection cost function - the old connection cost (below) performed really badly
        # cc = torch.mean(self.model.cc) * self.cc_weight # update_new_connection_cost
        if mode == "train" and self.new_cc_mode:
            self.cc_ci = self.model.get_cc_and_ci_loss_term()
            
            # print(self.cc_ci)
            
            # print(self.model.decent2.filter_list[1].weights[0][0])
            
            # log the max cc/ci values of decent3 to monitor the regularisation term
            cc_max_decent3 = self.model.decent3.cc_max_of_layer
            ci_max_decent3 = self.model.decent3.ci_max_of_layer
            #cc_mean = 
            #ci_mean =                     

            
            self.log(f'{mode}_cc_max_decent3', cc_max_decent3, on_step=False, on_epoch=True)
            self.log(f'{mode}_ci_max_decent3', ci_max_decent3, on_step=False, on_epoch=True)
            #self.log(f'{mode}_ci_mean', ci_mean, on_step=False, on_epoch=True)
            #self.log(f'{mode}_cc_mean', cc_mean, on_step=False, on_epoch=True)
            self.log(f'{mode}_cc', self.cc_ci, on_step=False, on_epoch=True) # this should have a more general name, for the future!!
            if debug_mode:
                print("decent note: self.cc_ci:", self.cc_ci)
                print("decent note: cc_max_decent3:", cc_max_decent3)
                print("decent note: ci_max_decent3:", ci_max_decent3)
                #print("decent note: all_ci = mean:", all_ci)
                #print("decent note: all_cc = mean", all_cc)
                print("ce_loss", ce_loss)
            
            loss = ce_loss + (self.cc_ci * self.cc_weight) # set cc_weight in the args; compute the cc/ci term with torch ops (not scipy) so gradients can flow
        else:
            loss = ce_loss
        
        pred_value, pred_i = torch.max(model_output, 1)
                
        if mode == "train":
            value = self.train_metrics(preds=pred_i, target=ground_truth)
            self.log_dict(value, on_step=False, on_epoch=True)
                
        elif mode == "val":
            value = self.val_metrics(preds=pred_i, target=ground_truth)
            self.log_dict(value, on_step=False, on_epoch=True)
                
        else:
            value = self.test_metrics(preds=pred_i, target=ground_truth)
            self.log_dict(value, on_step=False, on_epoch=True)
            
            
            self.cm.update(preds=pred_i, target=ground_truth) # prediction (class)
            self.pr_curve.update(preds=model_output, target=ground_truth) # class scores, not argmax
            self.roc_auc.update(preds=model_output, target=ground_truth) # class scores, not argmax
                        
        self.log(f'{mode}_ce_loss', ce_loss, on_step=False, on_epoch=True)
        self.log(f'{mode}_loss', loss, on_step=False, on_epoch=True)
        
        if False: # debug_mode:
            print("loss", loss)
            print("ce_loss", ce_loss)
            print("mo", model_output)
            print("gt", ground_truth)
            print("mo", self.mo)
            print("gt", self.gt)
            
        if torch.isnan(loss).any():
            print("DECENT WARNING: Loss contains NaN value(s).")
        
        # ce loss (+ connection cost term while training)
        
        return loss
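The four loops at the end of `test_step` that build `all_filters.csv` follow one pattern, so they could be collapsed into a single loop. A minimal sketch, assuming only the `m_this`/`n_this` attributes used above; `Filter` is a hypothetical stand-in for a decent filter, and `filter_names`/`to_csv` are illustrative helpers, not part of the code base. In the module, `layers` would be `[self.model.decent1.filter_list, self.model.decent2.filter_list, self.model.decent3.filter_list, self.model.decent1x1.filter_list]`.

```python
import csv
import io
from collections import namedtuple

# hypothetical stand-in for a decent filter: only the two position attributes used in test_step
Filter = namedtuple("Filter", ["m_this", "n_this"])


def filter_names(layers):
    """layers: list of filter lists, ordered decent1, decent2, decent3, decent1x1."""
    names = []
    for layer_no, filters in enumerate(layers, start=1):  # layer ids 1..4 as in test_step
        for f in filters:
            names.append(f"filter_{int(f.m_this)}_{int(f.n_this)}_{layer_no}")
    return names


def to_csv(names) -> str:
    """Same layout as all_filters.csv: one 'filter' column, no index."""
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["filter"])
    writer.writerows([n] for n in names)
    return buf.getvalue()


# toy model: decent3 has lost all its filters
layers = [[Filter(0.0, 1.0)], [Filter(2.0, 3.0), Filter(4.0, 0.0)], [], [Filter(1.0, 1.0)]]
print(to_csv(filter_names(layers)))
```

Numbering with `enumerate(..., start=1)` keeps the layer ids 1 to 4 even if a layer has lost all its filters, which matches the hard-coded ids in `test_step`.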

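`on_test_epoch_end` notes that the AUC could be added to the ROC legend. Since the per-class `fpr`/`tpr` points are already available there, the area can be computed directly with the trapezoidal rule. A sketch, assuming the points are sorted by increasing FPR (the usual order of ROC outputs); `trapezoid_auc` is a hypothetical helper, and in the module its inputs would be `roc_fpr[i].tolist()` and `roc_tpr[i].tolist()`.

```python
from typing import Sequence


def trapezoid_auc(fpr: Sequence[float], tpr: Sequence[float]) -> float:
    """Area under a ROC curve via the trapezoidal rule (fpr sorted ascending)."""
    if len(fpr) != len(tpr) or len(fpr) < 2:
        raise ValueError("need at least two (fpr, tpr) points of equal length")
    area = 0.0
    for i in range(1, len(fpr)):
        # width of the FPR step times the mean height of the two TPR values
        area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0
    return area


# toy curves; in on_test_epoch_end these would come from roc_fpr[i] / roc_tpr[i]
curves = {"perfect": ([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]),
          "chance": ([0.0, 1.0], [0.0, 1.0])}
for name, (fpr, tpr) in curves.items():
    print(f"{name} (AUC = {trapezoid_auc(fpr, tpr):.2f})")
```

The legend label could then read `f"{converted_label} (AUC = {trapezoid_auc(...):.2f})"`, without needing the raw predictions.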

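`run_xai_gradcam` averages the gradients over the spatial dimensions, weights each activation channel with that value and averages over the channels; the ReLU of expression (2) in the Grad-CAM paper is currently commented out. The NumPy sketch below shows the whole computation for one image with the ReLU applied, assuming activations and gradients of shape `(1, C, H, W)` taken from the same layer. Summing instead of averaging over the channels only scales the map by the constant `C`, which the normalization removes. `grad_cam` is an illustrative name, not a function of the project.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM heatmap for one image.

    activations, gradients: arrays of shape (1, C, H, W) from the same layer,
    where gradients = d(score of the predicted class) / d(activations).
    Returns a uint8 heatmap of shape (H, W), scaled to 0..255.
    """
    if activations.shape != gradients.shape or activations.shape[0] != 1:
        raise ValueError("expected two arrays of identical shape (1, C, H, W)")
    # one weight per channel: gradients averaged over batch and spatial dims
    weights = gradients.mean(axis=(0, 2, 3))                   # (C,)
    # weight each channel; works for any C, not only C == n_classes
    weighted = activations * weights[None, :, None, None]      # (1, C, H, W)
    # combine the channels and keep only positive evidence (ReLU)
    cam = np.maximum(weighted.sum(axis=1)[0], 0.0)             # (H, W)
    peak = cam.max()
    if peak > 0:
        cam = cam / peak                                       # normalize to 0..1
    return np.round(cam * 255.0).astype(np.uint8)


# toy example: channel 0 supports the class, channel 1 speaks against it
acts = np.array([[[[1.0, 0.0], [0.0, 0.0]],
                  [[0.0, 0.0], [0.0, 1.0]]]])                  # (1, 2, 2, 2)
grads = np.stack([np.full((2, 2), 2.0), np.full((2, 2), -1.0)])[None]
print(grad_cam(acts, grads))  # 255 at the top-left pixel, 0 elsewhere
```

Without the ReLU, the bottom-right pixel would get a negative value from channel 1; the ReLU drops it so the map only shows regions that push the score of the predicted class up.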

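`run_loss_n_metrics` adds the connection cost / CI term only during training and only when `new_cc_mode` is set; validation and test log plain cross-entropy as `{mode}_loss`. A plain-Python sketch of that branching, assuming `self.criterion` is a softmax cross-entropy (e.g. `torch.nn.CrossEntropyLoss`) and using floats instead of tensors. In the real model, `cc_ci` has to stay a torch tensor so the gradient reaches the weights. Both function names are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy of one sample, using a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def total_loss(logits, target, cc_ci, cc_weight, mode="train", new_cc_mode=True):
    """Same branching as run_loss_n_metrics: the cc/ci term is only added while training."""
    ce = cross_entropy(logits, target)
    if mode == "train" and new_cc_mode:
        return ce + cc_ci * cc_weight
    return ce


# two equally likely classes -> ce = ln 2 ~ 0.693
print(total_loss([0.0, 0.0], 0, cc_ci=2.0, cc_weight=0.1))              # ~ 0.893
print(total_loss([0.0, 0.0], 0, cc_ci=2.0, cc_weight=0.1, mode="val"))  # ~ 0.693
```

This also means `train_loss` and `val_loss` are not directly comparable while the cc/ci term is active; `{mode}_ce_loss` is the like-for-like metric.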
# Run

## run dev routine ****************************
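The cell below creates one `pl.Trainer` per entry of `train_kwargs["training_stages"]`, calls `light.prune()` before every stage except the first, and relies on `light.counter` (incremented in `on_train_epoch_start`) because every new Trainer restarts `current_epoch` at 0. The sketch lists the resulting epoch schedule, assuming `light.counter` starts at 0 and every stage has at least one epoch; `stage_schedule` is a hypothetical helper.

```python
from typing import List, Tuple


def stage_schedule(training_stages: List[int]) -> List[Tuple[int, int, int, bool]]:
    """(stage, trainer_epoch, counter, pruned_before) for every training epoch.

    - stage: index into training_stages (light.set_stage)
    - trainer_epoch: current_epoch of the fresh pl.Trainer, restarts at 0 every stage
    - counter: light.counter after the += 1 in on_train_epoch_start (assumed to start at 0)
    - pruned_before: True for the first epoch of a stage that follows light.prune()
    """
    schedule = []
    counter = 0
    for stage, max_epochs in enumerate(training_stages):
        for trainer_epoch in range(max_epochs):
            counter += 1
            schedule.append((stage, trainer_epoch, counter, stage > 0 and trainer_epoch == 0))
    return schedule


for row in stage_schedule([2, 1, 2]):
    print(row)
```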

In [19]:
train_kwargs["xai_done"] = False

if train_kwargs["train_size"] > 0:
    
    # =============================================================================
    # train model and run test/xAI routine

    # logger - save logs in "examples/example_results/lightning_logs"
    # light - DecentLightning model
    # trainer - pl.Trainer
    # trainer.fit
    # explainer - pl.Trainer
    # explainer.test
    # =============================================================================

    pl.seed_everything(19) # for reproducibility

    # THE LOGGER
    logger = CSVLogger(os.path.join(train_kwargs["result_path"], "lightning_logs"), name=train_kwargs["exp_name"])
    
    ckpt_path = os.path.join(*[train_kwargs["result_path"], "lightning_logs", train_kwargs["exp_name"], train_kwargs["load_ckpt_file"]])

    # THE LIGHTNING MODEL
    # Initialize the LightningModule
    light = DecentLightning(kwargs=kwargs, log_dir=logger.log_dir, ckpt_path=ckpt_path)

    # THE LIGHTNING TRAINER (for training)
    
    for i_prune_stage, max_epochs in enumerate(train_kwargs["training_stages"]):
        
        light.set_stage(i_prune_stage=i_prune_stage)
        if i_prune_stage > 0:
            light.prune()
                
        trainer = pl.Trainer(default_root_dir=train_kwargs["result_path"],
                             accelerator="gpu" if str(train_kwargs["device"]).startswith("cuda") else "cpu",
                             devices=[0],
                             # inference_mode=False, # do grad manually
                             log_every_n_steps=train_kwargs["log_every_n_steps"],
                             logger=logger,
                             check_val_every_n_epoch=1,
                             max_epochs=max_epochs, # epochs of this stage
                             callbacks=[SaveLastModelCheckpoint(save_weights_only=True, mode="max", monitor="val_f1_macro",
                                                       filename='mf_{stage}_{counter}_{val_f1_macro:.2f}_{unpruned:.0f}', save_last=True),
                                        #ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_f1_macro",
                                        #               filename='mf_{stage}_{counter}_{val_f1_macro:.2f}_{unpruned:.0f}'), # monitor fscore
                                        #EndModelCheckpoint(save_weights_only=True,  mode="min", monitor="unpruned", save_top_k=-1,
                                        #                   filename='mu_{epoch}-{val_f1_macro:.2f}-{unpruned:.0f}'), # monitor unpruned
                                        #ModelCheckpoint(save_weights_only=True, 
                                        #                save_last=True, 
                                        #                filename='mi_{i_stage}_{epoch}-{val_f1_macro:.2f}-{unpruned:.0f}'),
                                        #DecentModelCheckpoint(save_weights_only=True, mode="min", monitor="unpruned", save_top_k=-1, save_on_train_epoch_end=True,
                                        #                filename='mu_{epoch}-{val_f1_macro:.2f}-{unpruned:.0f}'), # monitor unpruned
                                        #OptimiserUpdateCheckpoint(),
                                        LearningRateMonitor("epoch")])
        
        # trainer.save_checkpoint(f"ml_{light.stage}_{light.counter}-{light.val_f1_macro:.2f}-{light.unpruned:.0f}.ckpt")

        # pruning happens at the beginning of each stage (light.prune() above, skipped for the first stage); a new Trainer is created per stage and resets its state, hence the stage is set manually
        
        trainer.logger._log_graph = True         # only used by TensorBoardLogger (plots the computation graph); no effect with CSVLogger
        trainer.logger._default_hp_metric = None # only used by TensorBoardLogger; not needed here

        # THE TRAIN-RUN
        # Train the model using a Trainer
        trainer.fit(light, dataloader.train_dataloader, dataloader.val_dataloader)
    

    # THE LIGHTNING TRAINER (for testing)
    # we want the grad to work in test, hence: inference_mode=False
    explainer = pl.Trainer(default_root_dir=train_kwargs["result_path"], # note: this tests the in-memory model of the last stage, no checkpoint is reloaded
                         accelerator="gpu" if str(train_kwargs["device"]).startswith("cuda") else "cpu",
                         devices=[0],
                         logger=logger,
                         inference_mode=False)

    # THE TEST-RUN
    # including test
    test_result = explainer.test(light, dataloader.xai_dataloader, verbose=False)
    

    
    
    train_kwargs["xai_done"] = True

print("Done")

Done


## run test routine ****************************

We need this with the OCTA-500 dataset.

`torch.load(ckpt_path)['state_dict']`

In [None]:
if train_kwargs["load_mode"] and not train_kwargs["xai_done"]:

    # =============================================================================
    # load model and run test/xAI routine

    # logger - save logs in "<exp_name>_xAI"
    # light - DecentLightning model
    # explainer - pl.Trainer
    # explainer.test
    # =============================================================================
    
    print("DECENT INFO: be aware, that you have to manually check, whether every output node has an input. otherwise an error may be triggered by cuda")

    pl.seed_everything(19) # for reproducibility

    # train_kwargs["load_ckpt_file"] = "version_7/checkpoints/epoch=0-val_f1=0.62-unpruned=1560.ckpt"
    
    # Check whether pretrained model exists. If yes, load it.
    # ckpt_path = os.path.join(*[train_kwargs["result_path"], "lightning_logs\debug_oct_no_fc", 'version_13', 'checkpoints/epoch=2-unpruned=269-val_f1=0.25.ckpt'])
    ckpt_path = os.path.join(*[train_kwargs["result_path"], "lightning_logs", train_kwargs["exp_name"], train_kwargs["load_ckpt_file"]])
    print("DECENT INFO: You are using checkpoint file: ", ckpt_path)
    

    print(train_kwargs["exp_name"])
    
    print(ckpt_path)

    
    if os.path.isfile(ckpt_path):

        # tmp = +"_xAI"
        
        # THE LOGGER
        logger = CSVLogger(os.path.join(train_kwargs["result_path"], 'lightning_logs'), name=train_kwargs["exp_name"]+"_xAI") # the xAI routine for an experiment
        # logger = CSVLogger(os.path.join(train_kwargs["result_path"], 'lightning_logs'), name='dumpster')
        
        print("logdir", logger.log_dir)
        
        # THE LIGHTNING MODEL
        # load from checkpoint doesn't work, since our architecture is 'messed up' through pruning
        # light = DecentLightning.load_from_checkpoint(state_dict, model_kwargs=model_kwargs, log_dir="example_results/lightning_logs") # Automatically loads the model with the saved hyperparameters
        # use this line instead:
        light = DecentLightning(kwargs=kwargs, log_dir=logger.log_dir, ckpt_path=ckpt_path)

        # THE LIGHTNING TRAINER (for testing)
        # we want the grad to work in test, hence: inference_mode=False
        explainer = pl.Trainer(default_root_dir=train_kwargs["result_path"],
                             accelerator="gpu" if str(train_kwargs["device"]).startswith("cuda") else "cpu",
                             #devices=[0], # why is this not on??
                             logger=logger,
                             inference_mode=False)

        # THE TEST-RUN
        # only test
        test_result = explainer.test(light, dataloader.xai_dataloader, verbose=False)

    else:
        print('DECENT ERROR: checkpoint file not found - it may have been reset in the dev routine. Check load_ckpt_file, set the dev routine to False and run everything again')

    
print("Done")
    

Global seed set to 19
Missing logger folder: examples/example_results\lightning_logs\counter_seems_to_work_xAI


DECENT INFO: be aware, that you have to manually check, whether every output node has an input. otherwise an error may be triggered by cuda
DECENT INFO: You are using checkpoint file:  examples/example_results\lightning_logs\counter_seems_to_work\version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt
counter_seems_to_work
examples/example_results\lightning_logs\counter_seems_to_work\version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt
logdir examples/example_results\lightning_logs\counter_seems_to_work_xAI\version_0
ckpt_path: examples/example_results\lightning_logs\counter_seems_to_work\version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt
Found pretrained model at examples/example_results\lightning_logs\counter_seems_to_work\version_3/checkpoints/mf_stage=3.0_counter=13.0_val_f1_macro=0.45_unpruned=3732.ckpt, loading...
DECENT INFO: dimensions are entry, decent1, decent2, decent3, decent1x1 =

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



DECENT NOTE: test_step 0
DECENT NOTE: test_step 1
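The load routine prints a reminder to check manually whether every output node has an input, and the todo list mentions the CUDA error when a `decent1x1` filter has no kernels left. A sketch of automating that check before testing, assuming the remaining input channels of each `decent1x1` filter can be collected into a dict (filter id to list of input channel indices); how to extract that from the filter objects is not shown here, and both helper names are hypothetical.

```python
from typing import Dict, List


def filters_without_input(in_channels: Dict[str, List[int]]) -> List[str]:
    """Ids of filters whose list of remaining input channels is empty."""
    return sorted(fid for fid, chans in in_channels.items() if not chans)


def check_output_layer(in_channels: Dict[str, List[int]], n_classes: int) -> None:
    """Raise before testing instead of running into a CUDA error later."""
    dead = filters_without_input(in_channels)
    alive = len(in_channels) - len(dead)
    if dead or alive < n_classes:
        raise RuntimeError(
            f"decent1x1 needs {n_classes} filters with at least one input, "
            f"found {alive}; filters without input: {dead}"
        )


# hypothetical snapshot of decent1x1: filter id -> still connected input channels
decent1x1 = {"filter_0_0_4": [0, 2], "filter_1_0_4": [], "filter_2_0_4": [1]}
print(filters_without_input(decent1x1))  # ['filter_1_0_4']
try:
    check_output_layer(decent1x1, n_classes=3)
except RuntimeError as err:
    print("DECENT ERROR:", err)
```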


# random nonsense

In [None]:
light.model

In [None]:
for a in light.model.decent3.filter_list:
    print(a.weights)

In [None]:
dir(DecentNet)

In [None]:
len(range(0))

In [None]:
#print(value) 
        #print(img)
        #print(ground_truth)
        # make it an X object
        
        #print(img.shape)
                #print("loss", loss)
        
        # print(cc)
        # from BIMT
        # loss_train = loss_fn(mlp(x.to(device)), one_hots[label])
        # cc = mlp.get_cc(weight_factor=2.0, no_penalize_last=True)
        # total_loss = loss_train + lamb*cc


In [None]:
"""
            try:
                ta = self.train_acc(preds=pred_i, target=ground_truth) # (model_output.argmax(dim=-1) == ground_truth).float().mean()
                tf = self.train_f1(preds=pred_i, target=ground_truth) 
                tp = self.train_prec(preds=pred_i, target=ground_truth) 
            except Exception as e:
                print("DECENT ERROR: we are most likely experiencing this CUDA ERROR because our decent1x1 has too few filters.")
                print("We need the same number as classes. It can happen, that all in-connections to a filter in decent1x1 got pruned and hence it is gone.")
                print("preds", pred_i)
                print("target", ground_truth)
                print(e)
            
            self.log(f'{mode}_acc', self.train_acc, on_step=False, on_epoch=True)
            self.log(f'{mode}_f1', self.train_f1, on_step=False, on_epoch=True)
            self.log(f'{mode}_prec', self.train_prec, on_step=False, on_epoch=True)
            
            if random.randint(1, 50) == 5:
                print()
                print("train info at random intervals")
                print("p", pred_i)
                print("g", ground_truth)
                print("a", ta)
                print("f", tf)
                print("p", tp)
                print("l", loss)
            """
"""
va = self.val_acc(preds=pred_i, target=ground_truth)  # (model_output.argmax(dim=-1) == ground_truth).float().mean()
vf = self.val_f1(preds=pred_i, target=ground_truth)
vp = self.val_prec(preds=pred_i, target=ground_truth)

self.log(f'{mode}_acc', self.val_acc, on_step=False, on_epoch=True)
self.log(f'{mode}_f1', self.val_f1, on_step=False, on_epoch=True)
self.log(f'{mode}_prec', self.val_prec, on_step=False, on_epoch=True)

if random.randint(1, 50) == 5:
    print()
    print("val info at random intervals")
    print("pred", pred_i)
    print("gt", ground_truth)
    print("acc", va)
    print("f1", vf)
    print("prec", vp)
    print("loss", loss)
"""
"""
print(pred_i)
print(ground_truth)
ta = self.test_acc(preds=pred_i, target=ground_truth)  # (model_output.argmax(dim=-1) == ground_truth).float().mean()
tf = self.test_f1(preds=pred_i, target=ground_truth)
tp = self.test_prec(preds=pred_i, target=ground_truth)

self.log(f'{mode}_acc', self.test_acc, on_step=False, on_epoch=True)
self.log(f'{mode}_f1', self.test_f1, on_step=False, on_epoch=True)
self.log(f'{mode}_prec', self.test_prec, on_step=False, on_epoch=True)
"""
        
        
try:
    pass  # print('pred i', pred_i.squeeze().detach().cpu().numpy())
except Exception as e:
    print("DECENT EXCEPTION: loss and metrics, pred")
    print(e)
try:
    pass  # print('gt', ground_truth.squeeze().detach().cpu().numpy())
except Exception as e:
    print("DECENT EXCEPTION: loss and metrics, gt")
    print(e)
        

#print('self mo', self.mo)
#print('self gt', self.gt)

# ground_truth = ground_truth
        
"""
print("gt", ground_truth)
print("gt shape", ground_truth.shape)
print("gt type", ground_truth.type())
print(torch.zeros(ground_truth.size(0), self.n_classes))

# bring the labels into shape (batch, k) so they can serve as scatter_ indices
if len(ground_truth.shape) < 2:
    ground_truth_tmp = ground_truth.unsqueeze(1)
else:
    ground_truth_tmp = ground_truth.transpose(1, 0)
# build the multi-hot matrix on the labels' own device; scatter_ expects int64 indices
ground_truth_multi_hot = torch.zeros(ground_truth_tmp.size(0), self.n_classes, device=ground_truth_tmp.device).scatter_(1, ground_truth_tmp.long(), 1.)

# this needs fixing
# ground_truth_multi_hot = torch.zeros(ground_truth.size(0), 10).to("cuda").scatter_(torch.tensor(1).to("cuda"), ground_truth.to("cuda"), torch.tensor(1.).to("cuda")).to("cuda")
"""
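The snippet above builds a multi-hot label matrix with `scatter_(1, index, 1.)`: each row starts as zeros and gets a 1 at every class index listed for that sample. Below is a plain-Python sketch of the same encoding that makes the index logic explicit. It assumes labels arrive as one list of class indices per sample, so a single-label batch is simply one index per sample. `multi_hot` is a hypothetical helper. The explicit range check is there because an out-of-range class index is the kind of mistake that tends to surface as an opaque error once it reaches the GPU.

```python
from typing import List, Sequence


def multi_hot(labels: Sequence[Sequence[int]], n_classes: int) -> List[List[float]]:
    """Row b gets 1.0 at every class index in labels[b] and 0.0 elsewhere.

    Same result as zeros(batch, n_classes).scatter_(1, index, 1.) for an
    index tensor of shape (batch, k).
    """
    out = [[0.0] * n_classes for _ in labels]
    for row, idx in zip(out, labels):
        for c in idx:
            if not 0 <= c < n_classes:
                raise IndexError(f"class index {c} out of range for {n_classes} classes")
            row[c] = 1.0
    return out


# usage: sample 0 has class 2, sample 1 has classes 0 and 1
print(multi_hot([[2], [0, 1]], n_classes=3))
# [[0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
```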