# MedMNIST Classification

This notebook contains the PyTorch code for the OrganSMNIST downstream task (the OrganSMNIST challenge from MedMNIST).

Before you start, you need to do the following:
- Download the pre-training checkpoints linked in the GitHub README.md

The code downloads the data automatically from: https://medmnist.com/


The code contains parts of: MONAI (https://monai.io/), Lightning AI (https://lightning.ai/), and the Lightning-Hydra template (https://github.com/ashleve/lightning-hydra-template).


### Preferences:

You need to fill out the first cell with your preferences:

In [1]:
# Folder where the results will be saved (create a folder on your computer and enter its path here)
root_dir = "/home/wolfda/Data/Spark/Downstream/Results"

# Choose names for Weights & Biases
Project = "MedMNIST" # Name of the WandB Project
Run = "SparK_1" # Name of the Run inside the Project
wandb_tag = ["Tag"] # Optional: tags for the run

# Pre-Training
# Choose whether to use a pre-trained model (True: pre-trained weights are loaded | False: no pre-training, the model is trained from scratch)
preTrain = True 
# Enter the path to the downloaded pre-training checkpoint here (.ckpt or .pth) [download it from the README page]
pretrained_weights = "/home/wolfda/Data/Spark/Paper_Checkpoints_Upload/BYOL/BYOL.ckpt" # "/path/to/checkpoint/SparK.pth" 
# Choose the Pre-Training Method: 
pre_train = "BYOL" # You can write: "SwAV" "SparK" "BYOL" "MoCo" (this must match the Checkpoint)

# Choose the model
backbone_model = "ResNet" # Only ResNet50 is implemented -> "ResNet"

# Train Dataset Reduction: 
# What percentage of the downstream task data should be used for training? 
data_percentage = 1 # 1 for all Data; 0.5 for half the Data; ...

# Number of Epochs: 
Epochen = 70

# Learning Rate
lr = 1e-4 
WeightDecay = 0.0005  

# Batch Size
bs = 64 

# Freeze Encoder in the beginning (only the linear layer is trained) 
first_frozen = True # True: encoder frozen first, unfrozen after n epochs with a smaller learning rate | False: train the entire encoder from the start
unfreeze = 10 # Epoch at which the encoder is unfrozen (ATTENTION: if you change the model, the encoder must be named "backbone")
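The freeze/unfreeze option relies on Lightning's `BackboneFinetuning` callback. The following pure-Python sketch approximates the backbone learning rate that callback produces, assuming its defaults as we understand them: `backbone_initial_ratio_lr=0.1`, a `lambda_func` that multiplies the rate by 1.5 each epoch, and `should_align=True`, which caps the backbone rate at the main rate. It also assumes the backbone is added to the optimizer as its own parameter group when it is unfrozen. The helper `backbone_lr_schedule` is hypothetical; check the documentation of your installed Lightning version before relying on the numbers.

```python
# Sketch of the backbone LR schedule under ASSUMED BackboneFinetuning defaults
# (initial ratio 0.1, factor 1.5 per epoch, capped at the main LR).
# Pure Python; it does not call Lightning.

def backbone_lr_schedule(lr, unfreeze_epoch, max_epochs,
                         initial_ratio=0.1, factor=1.5, align=True):
    """Return {epoch: backbone_lr}; None means the backbone is still frozen."""
    schedule = {}
    previous = None
    for epoch in range(max_epochs):
        if epoch < unfreeze_epoch:
            schedule[epoch] = None  # frozen: only the classifier is trained
        elif epoch == unfreeze_epoch:
            previous = lr * initial_ratio  # start with a smaller LR
            schedule[epoch] = previous
        else:
            nxt = previous * factor  # grow geometrically every epoch
            if align and nxt > lr:
                nxt = lr  # never exceed the LR of the classifier
            previous = nxt
            schedule[epoch] = nxt
    return schedule

sched = backbone_lr_schedule(lr=1e-4, unfreeze_epoch=10, max_epochs=20)
for epoch in range(9, 18):
    print(epoch, sched[epoch])
```

With the notebook defaults (`lr = 1e-4`, `unfreeze = 10`) this gives 1e-5 at epoch 10, growing by a factor of 1.5 per epoch until it is capped at 1e-4 from epoch 16 on.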



From here on, the code should run without any changes.

### Imports:

In [2]:
import os
import numpy as np

# Monai
from monai.transforms import Resize
from monai.config import print_config

# Output Metrics
from sklearn.metrics import roc_auc_score, f1_score

# Weights & Biases
from pytorch_lightning.loggers import WandbLogger
# W&B callbacks from the Lightning-Hydra template
import wandb_callbacks as wbcall
# Logger
from logging import getLogger
logger = getLogger()

# PyTorch 
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import pytorch_lightning
from pytorch_lightning import callbacks
import torchmetrics
import torch.utils.data as data_utils
torch.multiprocessing.set_sharing_strategy('file_system')

# MedMNIST
import medmnist
from medmnist import INFO
import torch.utils.data as data

# Seed Everything 
pytorch_lightning.seed_everything(42, workers=True)



2023-11-23 16:40:43,949 - Created a temporary directory at /tmp/tmpe92l88ye
2023-11-23 16:40:43,950 - Writing /tmp/tmpe92l88ye/_remote_module_non_scriptable.py
2023-11-23 16:40:44,025 - Global seed set to 42


42

### Loss + Optimizer
You can change the Loss or the Optimizer here:

In [3]:
# Loss, Optimizer 
loss = torch.nn.CrossEntropyLoss() # CrossEntropyLoss() expects raw logits (it applies log-softmax itself) | NLLLoss() if the model already outputs log-softmax || To weight the classes, pass one weight per class: weight=torch.FloatTensor([w_0, ..., w_(n_classes-1)])
optim = "Adam"
# Save information for wandb:
info_params={"Path_save": root_dir, "PreTrain_Weights": pretrained_weights, "Epochs": Epochen, "first_frozen": first_frozen, "Epochs Freeze until": unfreeze, "Learning_Rate": lr,  "Bs": bs, "Optim Weight_Decay": WeightDecay,}
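The loss comment above hinges on the fact that `torch.nn.CrossEntropyLoss` combines log-softmax and the negative log-likelihood, so it must receive raw logits. A minimal NumPy sketch of that arithmetic (NumPy stands in for PyTorch here, and the function names are ours, not PyTorch's) also shows that passing probabilities instead of logits applies softmax twice and changes the loss:

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss(log_probs, targets):
    # Mean negative log-probability of the true class
    return -log_probs[np.arange(len(targets)), targets].mean()

def cross_entropy(logits, targets):
    # Cross-entropy = NLL applied to log-softmax of the logits
    return nll_loss(log_softmax(logits), targets)

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = np.array([0, 2])
ce = cross_entropy(logits, targets)

# Feeding softmax probabilities instead of logits applies softmax twice
probs = np.exp(log_softmax(logits))
ce_double = cross_entropy(probs, targets)
print(ce, ce_double)
```

Both samples are classified correctly, yet the double-softmax loss is noticeably larger, because the second softmax squeezes the probabilities towards uniform.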

### Dataset:

In [4]:
data_flag = 'organsmnist'
download = True

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])
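`data_percentage` is applied later, in the Lightning class, via `torch.arange(int(len(train_dataset_all)*data_percentage))`, which keeps the first N training samples in storage order. A minimal pure-Python sketch contrasting that with a seeded random subset; both helper names are hypothetical, and whether a random subset is preferable depends on how the MedMNIST training split is ordered, which we have not checked:

```python
import random

def first_n_indices(n_total, fraction):
    # Same idea as the notebook: keep the first int(n_total * fraction) samples
    return list(range(int(n_total * fraction)))

def random_subset_indices(n_total, fraction, seed=42):
    # Hypothetical alternative: a reproducible random sample that does not
    # depend on the order in which the samples are stored
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_total), int(n_total * fraction)))

first = first_n_indices(1000, 0.5)
rand = random_subset_indices(1000, 0.5)
print(len(first), len(rand), first[:3], rand[:3])
```

In the notebook itself, something like `torch.randperm(len(train_dataset_all), generator=torch.Generator().manual_seed(42))[:n]` could play the role of `random_subset_indices`.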

### Transforms:

In [5]:
# Grayscale transforms:
# For grayscale images, stack the same image three times to get 3 channels (the model is built for RGB (3-channel) images)
class gray_to_rgb(object):
    def __init__(self):
        pass
    def __call__(self, image):
        image = torch.cat((image, image, image), 0)
        return image
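`gray_to_rgb` concatenates three copies of a `(1, H, W)` tensor along the channel axis. A small NumPy sketch of the same shape change (NumPy stands in for PyTorch; in PyTorch, `image.repeat(3, 1, 1)` should give the same result as the `torch.cat` call):

```python
import numpy as np

gray = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # 1 channel, 2x3 image
rgb_cat = np.concatenate((gray, gray, gray), axis=0)   # like torch.cat(..., 0)
rgb_tile = np.tile(gray, (3, 1, 1))                    # like tensor.repeat(3, 1, 1)

print(rgb_cat.shape)  # (3, 2, 3)
```

All three output channels are identical copies of the grayscale image, so the pre-trained RGB filters see the same intensity in every channel.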

In [6]:
data_transform = transforms.Compose([
    transforms.ToTensor(),
    gray_to_rgb(),
    transforms.Normalize(mean=[.5], std=[.5]),
    Resize(spatial_size=(224,224)),
])
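The pipeline above scales pixels to [0, 1] with `ToTensor`, replicates the channel, and then applies `Normalize(mean=[.5], std=[.5])`, i.e. `(x - 0.5) / 0.5`, mapping the values to [-1, 1]. As far as we know, torchvision broadcasts the single-value mean and std over all channels, which is why `[.5]` works for a 3-channel tensor. A NumPy sketch of that arithmetic on a hypothetical three-pixel image (the final `Resize` is left out):

```python
import numpy as np

pixels = np.array([[0, 128, 255]], dtype=np.uint8)  # hypothetical 1x3 grayscale image
x = pixels.astype(np.float32) / 255.0               # ToTensor: uint8 -> [0, 1]
x = x[np.newaxis]                                   # add channel axis -> (1, 1, 3)
x = np.concatenate((x, x, x), axis=0)               # gray_to_rgb -> (3, 1, 3)
mean = np.array([0.5]).reshape(-1, 1, 1)            # single value, broadcast over channels
std = np.array([0.5]).reshape(-1, 1, 1)
normalized = (x - mean) / std                       # Normalize -> [-1, 1]
print(normalized[0])
```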

### Load Model and Pre-Training
Here, the ResNet50 model is loaded and initialized with the pre-training checkpoints:

#### Expected output of this cell: 
If the following appears, everything is correct: 
- missing_keys = ['fc.weight', 'fc.bias'] (because the last fully connected layer was not pre-trained) 
- unexpected_keys = 
    - MoCo: all "encoder_k" layers (because MoCo has 2 encoders and we use only encoder_q)
    - BYOL: all "online_network.projector", "online_network.predictor" and "target_network" layers (because BYOL has 2 encoders and we use only online_network.encoder)
    - SwAV: all "projection_head" layers (because SwAV has an additional projection head for the online clustering) 
    - SparK: []
    
    
#### You can insert a new model here:
To check the layer names of the model and of the pre-training checkpoint, use:

    model = models.xyz()
    state_dict = torch.load(pretrained_weights)
    for k, v in model.state_dict().items():
        print(k)
    print("--------------------------------------------------------------------------------")
    for k, v in state_dict.items():
        print(k)
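
The next cell renames the checkpoint keys with `str.replace` and then calls `load_state_dict(..., strict=False)`, which reports keys the model expects but the checkpoint lacks (`missing_keys`) and checkpoint keys the model does not use (`unexpected_keys`). A minimal pure-Python sketch of that bookkeeping, using made-up, shortened BYOL-style key names and the hypothetical helpers `strip_prefix` and `compare_keys`. It compares key names only and ignores tensor shapes, which PyTorch also checks, and it strips the prefix only at the start of a key, which is slightly stricter than `str.replace`:

```python
def strip_prefix(state_dict, prefix):
    # Remove `prefix` only at the start of a key
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

def compare_keys(model_keys, ckpt_keys):
    # Key-name part of what load_state_dict(strict=False) reports
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    return missing, unexpected

model_keys = ["conv1.weight", "bn1.weight", "fc.weight", "fc.bias"]
checkpoint = {
    "online_network.encoder.conv1.weight": 0,
    "online_network.encoder.bn1.weight": 0,
    "online_network.projector.model.0.weight": 0,
    "target_network.encoder.conv1.weight": 0,
}
renamed = strip_prefix(checkpoint, "online_network.encoder.")
missing, unexpected = compare_keys(model_keys, renamed)
print("missing:", missing)
print("unexpected:", unexpected)
```

The pattern matches the expected output described above: only the new classifier (`fc.*`) is missing, and everything outside the online encoder is unexpected.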

In [7]:
# Model-----------------------------------------------------------------------------

if backbone_model == "ResNet":
    res_model = models.resnet50() # PyTorch torchvision ResNet50 model
else:
    raise ValueError(f"Unknown backbone_model: {backbone_model!r} (only 'ResNet' is implemented)")
    

# Pre-Training --------------------------------------------------------------------------------
if preTrain:
    # Load the pre-training weights (map_location="cpu" lets GPU-saved checkpoints load on any machine)
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    
    # Match the correct name of the layers between pre-trained model and PyTorch ResNet
    # Extraction:
    if "module" in state_dict: # (SparK)
        state_dict = state_dict["module"] 
    if "state_dict" in state_dict: # (SwAV, MoCo, BYOL) 
        state_dict = state_dict["state_dict"]
    # Replacement:
    if pre_train == "SparK" or pre_train == "SwAV":
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} 
    elif pre_train == "MoCo":
        state_dict = {k.replace("encoder_q.", ""): v for k, v in state_dict.items()} 
    elif pre_train == "BYOL":
        state_dict = {k.replace("online_network.encoder.", ""): v for k, v in state_dict.items()} 

    # Initialise the ResNet model with the pre-training checkpoint
    msg = res_model.load_state_dict(state_dict, strict=False)

    # Check if it worked (compare with the expected output described above)
    print(msg)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=['online_network.projector.model.0.weight', 'online_network.projector.model.1.weight', 'online_network.projector.model.1.bias', 'online_network.projector.model.1.running_mean', 'online_network.projector.model.1.running_var', 'online_network.projector.model.1.num_batches_tracked', 'online_network.projector.model.3.weight', 'online_network.projector.model.3.bias', 'online_network.predictor.model.0.weight', 'online_network.predictor.model.1.weight', 'online_network.predictor.model.1.bias', 'online_network.predictor.model.1.running_mean', 'online_network.predictor.model.1.running_var', 'online_network.predictor.model.1.num_batches_tracked', 'online_network.predictor.model.3.weight', 'online_network.predictor.model.3.bias', 'target_network.encoder.conv1.weight', 'target_network.encoder.bn1.weight', 'target_network.encoder.bn1.bias', 'target_network.encoder.bn1.running_mean', 'target_network.encoder.bn1.running_var', 't

### PyTorch Lightning Class

In [8]:
class Net(pytorch_lightning.LightningModule): 
    
    def __init__(self):
        super().__init__()
        
        # Remove the last (linear) layer and add our own linear layer.
        # This is needed for freezing: the convolutional part to be frozen (everything except the linear layer) must be named "backbone".
        self.net = res_model
        
        if backbone_model == "ResNet":
            # Remove last layer 
            num_filters = self.net.fc.in_features
            layers = list(self.net.children())[:-1]
            self.backbone = torch.nn.Sequential(*layers)
            # Add one linear layer for classification
            self.classifier = torch.nn.Linear(num_filters, n_classes)
            del self.net
        
        # Loss
        self.loss_function = loss 
        
        # Metrics
        self.best_acc = torchmetrics.MaxMetric()
    

    def forward(self, x):
        
        x = self.backbone(x).flatten(1)
        x = self.classifier(x)

        return x
    
    def on_train_start(self):
        # By default, Lightning runs validation sanity-check steps before training starts,
        # so we need to make sure best_acc doesn't store the accuracy from these checks
        self.best_acc.reset()


    def setup(self, stage=None):
        # Lightning recommends assigning datasets in setup(); prepare_data() should only download
        # Create the datasets (images + labels + transforms) and reduce the training set:
        train_dataset_all = DataClass(split='train', transform=data_transform, download=download)
        indices = torch.arange(int(len(train_dataset_all)*data_percentage)) # keeps the first N samples; data_percentage=1 means all data
        self.train_dataset = data_utils.Subset(train_dataset_all, indices)
        self.val_dataset = DataClass(split='val', transform=data_transform, download=download)
        self.test_dataset = DataClass(split='test', transform=data_transform, download=download)
      
    def train_dataloader(self):
        train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=bs, shuffle=True, num_workers=10)
        print("Train Data Loader | Bs:", bs, "| len",  len(train_loader), "| ges", bs*len(train_loader))
        return train_loader

    def val_dataloader(self):
        val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=1, num_workers=10)
        print("Val Data Loader", len(val_loader))
        return val_loader
    
    def test_dataloader(self): 
        test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=1, num_workers=10)
        print("Test Data Loader", len(test_loader))
        return test_loader

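    def configure_optimizers(self):
        # Pass only trainable params, so BackboneFinetuning can add the backbone as its own param group when it unfreezes it
        params = filter(lambda p: p.requires_grad, self.parameters())
        if optim == "Adam":
            optimizer = torch.optim.Adam(params, lr, weight_decay=WeightDecay)
        else:
            raise ValueError(f"Unsupported optimizer: {optim!r}")
        return optimizer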

    def training_step(self, batch, batch_idx): # batch = (images, labels) from the train DataLoader
        images, labels= batch 
        labels = labels.squeeze(1).long()
        
        output = self.forward(images) 
        
        loss = self.loss_function(output, labels)
        
        # Wandb Logs
        self.log("train/loss", loss.item())

        return {"loss": loss}
    
    def training_epoch_end(self, outputs): # outputs = list of the dicts returned by training_step, one per batch: {"loss": loss}
        train_loss, num_items = 0, 0
        
        for output in outputs: # Loops through all batches of the train dataset
            train_loss += output["loss"].sum().item() # Adds up all batch losses
            num_items += 1
            
        mean_train_loss = torch.tensor(train_loss / num_items) # mean loss
        
        # Wandb Logs
        self.log("train/mean_loss", mean_train_loss,)
    

    def validation_step(self, batch, batch_idx): 
        images, labels= batch 
        labels = labels.squeeze(0).long()
        
        outputs = self.forward(images)
        
        klasse = torch.argmax(outputs, dim=1) # Find index with highest probability
        
        loss = self.loss_function(outputs, labels)
        
        return {"val_loss": loss, "targets": labels, "preds": klasse}
    

    def validation_epoch_end(self, outputs): # outputs = Samples in val dataset x {"val_loss": loss, "targets": labels, "preds": klasse}
        val_loss, num_items, true, false = 0, 0, 0, 0 
        
        
        for output in outputs: # Loops through all samples in the val dataset
            val_loss += output["val_loss"].sum().item() # Add all losses
            num_items += 1
            
            # How many classified correctly
            if output["preds"] == output["targets"]: # if index with highest probability == correct class -> correctly classified
                true += 1
            else:
                false += 1
            
        mean_val_loss = torch.tensor(val_loss / num_items) 
    
        # Accuracy
        acc = torch.tensor(true / num_items)
        self.best_acc.update(acc)
        best_acc = self.best_acc.compute()
        
        # Wandb Logs 
        self.log("val/loss", mean_val_loss)
        self.log("val/accuracy", acc)
        self.log("val/best_accuracy", best_acc)
        

    
    def test_step(self, batch, batch_idx): 
        images, labels = batch 
        labels = labels.squeeze(0).long()
        
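        logits = self.forward(images) 
        outputs = torch.nn.functional.softmax(logits, dim=1) # class probabilities (needed for the AUC)
        
        klasse = torch.argmax(outputs, dim=1)
        
        loss = self.loss_function(logits, labels) # CrossEntropyLoss expects raw logits, not probabilities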
        
        return {"test_loss": loss, "output": outputs, "targets": labels, "preds": klasse} 
    
    
    def test_epoch_end(self, outputs): # outputs = Samples in test dataset x {"test_loss": loss, "output": outputs, "targets": labels, "preds": klasse}
        
        test_loss, true, false, num_items = 0, 0, 0, 0
        
        output_list = np.array([])
        lable_list = np.array([])
        output_prob = np.array([[]])
        
        i = 0
        for output in outputs: # Loops through all samples in the test dataset

            test_loss += output["test_loss"] 
            num_items += 1
            
            output_hold = output["preds"].cpu().detach()
            lable_hold = output["targets"].cpu().detach()
            output_prob_hold = output["output"].cpu().detach()
            
            output_list = np.append(output_list, output_hold) # Predicted class 0 ... n_classes-1
            lable_list = np.append(lable_list, lable_hold) # Label class 0 ... n_classes-1

            if output_prob.shape == (1, 0):
                output_prob=output_prob_hold
            else: 
                output_prob = np.append(output_prob, output_prob_hold, axis=0)
            

            if output["preds"] == output["targets"]: 
                true += 1
            else:
                false += 1
            i+=1
            
        output_list = output_list.astype(int)
        lable_list = lable_list.astype(int)
        print(output_list.shape)
        print(lable_list.shape)
        print(output_prob.shape)
        
        
        # __________ REPORT ____________________________
        
        
        # Number correctly classified
        print("\n" + "True", true) 
        print("False", false)
        
        # Accuracy
        acc = torch.tensor(true / num_items) 
        print("\n" + "acc", acc)
        
        # F1 (micro-averaged; for single-label multi-class data this equals the accuracy)
        f1 = f1_score(lable_list, output_list, average='micro')
        print("\n" + "F1",f1)
        
        # AUC
        auc = roc_auc_score(lable_list, output_prob, average='macro', multi_class='ovr') # one-vs-rest AUC, macro-averaged over classes
        print("\n" + "AUC",auc)

        # Wandb Logs
        self.log("test/Accuracy", acc)
        self.log("test/AUC", auc)
        self.log("test/F1", f1)
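`test_epoch_end` reports accuracy, micro-averaged F1 and macro one-vs-rest AUC via scikit-learn. The pure-Python sketch below (the function names are ours; it is illustrative, not a replacement for `f1_score`/`roc_auc_score`) spells out what those numbers mean on a hypothetical 3-class example: micro-F1 pools true/false positives over all classes, which for single-label predictions makes it equal to accuracy, and one-vs-rest AUC scores each class against all others and then averages the per-class AUCs with equal weight. To our understanding this should agree with scikit-learn's `average='macro', multi_class='ovr'` on examples like this one.

```python
def accuracy(labels, preds):
    return sum(l == p for l, p in zip(labels, preds)) / len(labels)

def micro_f1(labels, preds, n_classes):
    # Pool true positives, false positives and false negatives over all classes
    tp = fp = fn = 0
    for c in range(n_classes):
        for l, p in zip(labels, preds):
            tp += (p == c and l == c)
            fp += (p == c and l != c)
            fn += (p != c and l == c)
    return 2 * tp / (2 * tp + fp + fn)

def binary_auc(scores, is_positive):
    # Probability that a random positive scores higher than a random negative (ties count 1/2)
    pos = [s for s, y in zip(scores, is_positive) if y]
    neg = [s for s, y in zip(scores, is_positive) if not y]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def macro_ovr_auc(labels, probs, n_classes):
    # One-vs-rest: class c is positive, all other classes are negative;
    # then average the per-class AUCs with equal weight (macro)
    aucs = [binary_auc([row[c] for row in probs], [l == c for l in labels])
            for c in range(n_classes)]
    return sum(aucs) / n_classes

labels = [0, 1, 2, 2]
probs = [[0.7, 0.2, 0.1], [0.3, 0.4, 0.3], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]]
preds = [max(range(len(row)), key=lambda c: row[c]) for row in probs]  # argmax per sample
print(accuracy(labels, preds), micro_f1(labels, preds, 3), macro_ovr_auc(labels, probs, 3))
```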

### Train, Test, Save

In [9]:
# initialise the LightningModule
net = Net()

# Creates a path to save the checkpoints and WandB infos: 
checkpoint_dir = os.path.join(root_dir, Project, Run) 
try:
    os.makedirs(checkpoint_dir) 
except OSError:
    print ("Path %s exists" % checkpoint_dir)
else:
    print ("Successfully created path %s" % checkpoint_dir)

    
# weights and biases
wandb_logger = WandbLogger(project=Project, name=Run, tags = wandb_tag, save_dir=checkpoint_dir)


## Callbacks  

# Saves the Checkpoints 
checkpoint_callback = callbacks.ModelCheckpoint( # Saves checkpoints to checkpoint_dir
    dirpath = checkpoint_dir, # Save path
    monitor = "val/accuracy", # Metric used to rank the checkpoints
    mode="max", 
    filename = "{epoch}", # File names: epoch=<n>.ckpt
    save_last = True, # Saves the state after the last training epoch (name: last.ckpt)
    save_top_k = 2, # Saves the k=2 best epochs [highest val accuracy] (name: epoch=*.ckpt)
)

# Set the epoch at which the backbone is unfrozen
# The part to be frozen must be named "backbone"
finetuning = callbacks.BackboneFinetuning(
    unfreeze_backbone_at_epoch=unfreeze, 
    )

# Logs the learning rate in W&B
lr_monitor = callbacks.LearningRateMonitor(logging_interval=None, log_momentum=True)


# All callbacks together (defined here, plus wbcall from wandb_callbacks.py of the Lightning-Hydra template)
if first_frozen:
    callback_summary = [checkpoint_callback, finetuning, lr_monitor, wbcall.LogConfusionMatrix(), wbcall.LogF1PrecRecHeatmap()] #wbcall.LogF1AUCTest()
else:
    callback_summary = [checkpoint_callback, lr_monitor, wbcall.LogConfusionMatrix(), wbcall.LogF1PrecRecHeatmap()] #wbcall.LogF1AUCTest()


# initialise Lightning's trainer
trainer = pytorch_lightning.Trainer(
    accelerator="gpu", devices=1, # one GPU (replaces the deprecated gpus=1)
    max_epochs=Epochen, # Epochs
    logger=wandb_logger, # Weights & Biases
    log_every_n_steps=9, # How often to log (every 9 steps)
    callbacks=callback_summary,
    num_sanity_val_steps=1, # Runs one validation step before training to check that everything works
)
trainer.logger.log_hyperparams(info_params) # So that my information is logged in wandb

# train + val
trainer.fit(net)

# test
ckpt_path = trainer.checkpoint_callback.best_model_path
print(ckpt_path)
trainer.test(net, ckpt_path=ckpt_path)

Successfully created path /home/wolfda/Data/Spark/Downstream/Results/MedMNIST/SparK_1


[34m[1mwandb[0m: Currently logged in as: [33mwolfda95[0m. Use [1m`wandb login --relogin`[0m to force relogin




2023-11-23 16:40:56,406 - GPU available: True (cuda), used: True
2023-11-23 16:40:56,408 - TPU available: False, using: 0 TPU cores
2023-11-23 16:40:56,408 - IPU available: False, using: 0 IPUs
2023-11-23 16:40:56,408 - HPU available: False, using: 0 HPUs
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


2023-11-23 16:40:56,704 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2023-11-23 16:40:56,715 - 
  | Name          | Type             | Params
---------------------------------------------------
0 | backbone      | Sequential       | 23.5 M
1 | classifier    | Linear           | 22.5 K
2 | loss_function | CrossEntropyLoss | 0     
3 | best_acc      | MaxMetric        | 0     
---------------------------------------------------
75.7 K    Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.122    Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
Val Data Loader 2452
Train Data Loader | Bs: 64 | len 218 | ges 13952
Epoch 0:   8%|▊         | 218/2670 [00:36<06:45,  6.04it/s, loss=2.19, v_num=ntuo]
Validation: 0it [00:00, ?it/s]
Validation:   0%|          | 0/2452 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/2452 [00:00<?, ?it/s]
Epoch 0:   8%|▊         | 219/2670 [00:36<06:49,  5.98it/s, loss=2.



Epoch 1:   6%|▌         | 154/2670 [00:26<07:06,  5.90it/s, loss=2.06, v_num=ntuo]
/home/wolfda/Data/Spark/Downstream/Results/MedMNIST/SparK_1/epoch=0.ckpt
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz
Using downloaded and verified file: /home/wolfda/.medmnist/organsmnist.npz
2023-11-23 16:42:32,490 - Restoring states from the checkpoint path at /home/wolfda/Data/Spark/Downstream/Results/MedMNIST/SparK_1/epoch=0.ckpt


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [12]:
import IPython

IPython.Application.instance().kernel.do_shutdown(True) #automatically restarts kernel

{'status': 'ok', 'restart': True}