# From Pytorch & Co to Plug_ai

## Using conventionnal Pytorch + MonAI

In [1]:
!nvidia-smi

Tue Apr 18 01:18:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:1A:00.0 Off |                    0 |
| N/A   48C    P0    74W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### Imports

In [2]:
import os
import glob
import torch
import numpy as np
from monai.transforms import (
    Compose,
    LoadImaged,
    AddChanneld,
    ToTensord,
    EnsureChannelFirstd,
    ConcatItemsd,
    SpatialCropd,
    AsDiscreted
)
from monai.data import Dataset

from monai.networks.nets import DynUNet
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from torch.optim import SGD
from monai.utils import set_determinism, first

### Dataset

In [3]:
data_dir = "/gpfswork/rech/ibu/commun/BraTS2021/BraTS2021_Training_Data/"

def get_datalist(dataset_dir):
    datalist = []
    with open(os.path.join(dataset_dir, "train.txt"), "r") as f:
        lines = f.readlines()
        for line in lines:
            file_dic = {}
            files = line.split()
            for i, file in enumerate(files[:-1]):
                file_dic[f"channel_{i}"] = os.path.join(dataset_dir, file)

            file_dic["label"] = os.path.join(dataset_dir, files[-1])
            datalist.append(file_dic)
    return datalist

datalist = get_datalist(data_dir)
keys = list(datalist[0].keys())

transform = Compose([
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            ConcatItemsd(keys[:-1], "input"),
            SpatialCropd(keys=['input', 'label'], # crop it to make easily usable for etape 1
                         roi_size=[128, 128, 128],
                         roi_center=[0, 0, 0]
                         ),
            AsDiscreted(keys=['label'], to_onehot=5)
        ])



train_dataset = Dataset(data=datalist[:20],
                        transform=transform)

In [4]:
#for i, data in enumerate(datalist:
#    print(i, data['label'])

### Dataloader

In [5]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, 
                          batch_size=2, 
                          shuffle=True, 
                          num_workers=4,
                         prefetch_factor=10)

### Model

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = DynUNet(spatial_dims = 3,
                in_channels = 4,
                out_channels = 5,
                kernel_size = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                upsample_kernel_size = [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
                norm_name = "instance",  # you can use fused kernel for normal layer when set to `INSTANCE_NVFUSER`
                deep_supervision =  True,
                deep_supr_num = 3,).to(device)


### Criterion

In [7]:
loss_function = DiceCELoss(to_onehot_y=True, 
                           softmax=True)

### Metric

In [8]:
metric = DiceMetric(include_background=False, 
                    reduction="mean")

### Optimizer

In [9]:
optimizer = SGD(model.parameters(), lr=0.001)

### Training Loop

In [10]:
num_epochs = 2
val_interval = 2

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

for epoch in range(num_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["input"].to(device),
            batch_data["label"].to(device),
        )
        
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = torch.unbind(outputs, dim=1)[0]
        
        labels = torch.argmax(labels, dim=1, keepdim=True)
        loss = loss_function(outputs, labels)
    
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        #print(f"{step}/{len(train_dataset) // train_loader.batch_size}, train_loss: {loss.item():.4f}")
    
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

----------
epoch 1/2
epoch 1 average loss: 3.1551
----------
epoch 2/2
epoch 2 average loss: 3.0625


## Plug_AI

### Config file

In [11]:
#writefile

### Execution

In [12]:
#!python -m plug_ai --config_file ...