In [1]:
# !./get_data.sh

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import time
import wandb
from collections import defaultdict
from itertools import islice
from IPython import display
# import albumentations as A
from torchsummary import summary
from utils.dataset import MRI
from utils.functions import train_network, SaveBestModel
from utils.loss import dice_loss
from models.unet import UNet

In [2]:
# Train on the first GPU when CUDA is available; otherwise fall back to the CPU
# so Restart & Run All still works (slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# 16 px images (`data_16`, `img_size=16`)

# Input: a single MRI slice per sample

In [3]:
# !bash get_data_16.sh

In [4]:
# Train a UNet on single-slice 16 px MRI images with dice loss and log the run to wandb.
experiment_name = '1_slice_dice_16px'
time_string = time.strftime("%d%m_%H%M", time.localtime())
# Checkpoints go under results/pths/<ddmm_HHMM>_<experiment_name>; the saving
# policy itself lives in utils.functions.SaveBestModel.
# Set `saver = None` instead to skip checkpointing.
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')

# Dataset variant used by both splits (kept in one place so train/val cannot drift).
img_size = 16
data_path = f'data_{img_size}'
slice_mode = '1_slice'

config = {
    "seed": 0,                 # fixed for reproducibility; change to compare runs
    "learning_rate": 5e-5,
    "weight_decay": 1e-4,      # Adam L2 penalty; logged so runs are comparable in wandb
    "epochs": 500,
    "batch_size": 4,
    "img_size": img_size,
    "saver": saver,
    "num_workers": 8,
    "criterion": dice_loss,    # nn.CrossEntropyLoss or dice_loss
    "model": {
        "num_classes": 5,
        "min_channels": 32,
        "max_channels": 512,
        "num_down_blocks": 3,
        "img_channels": 1,
        "dropout": 0.5,
        "upsampling_mode": "upsampling",
        "norm_mode": "instance"
    }
}

# Seed before any weights are initialised or batches shuffled.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

wandb.init(project="bsim", name=experiment_name, config=config, entity="bsim-skt")
try:
    train_dataset = MRI('train', data_path=data_path, img_size=img_size, mode=slice_mode)
    # NOTE: the 'test' split is used as the validation set for this experiment.
    validation_dataset = MRI('test', data_path=data_path, img_size=img_size, mode=slice_mode)
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  shuffle=True, num_workers=config['num_workers'])
    validation_dataloader = DataLoader(validation_dataset, batch_size=config['batch_size'],
                                       shuffle=False, num_workers=config['num_workers'])

    model = UNet(**config['model']).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=wandb.config['learning_rate'],
                           weight_decay=wandb.config['weight_decay'])
    # Alternative optimiser:
    # opt = torch.optim.SGD(model.parameters(), lr=wandb.config['learning_rate'],
    #                       weight_decay=wandb.config['weight_decay'])

    train_network(model, opt, config['criterion'], wandb.config['epochs'],
                  train_dataloader, validation_dataloader, device, saver, use_wandb=True)
finally:
    # Close the run even if data loading or training raises (OOM, KeyboardInterrupt, ...),
    # so the kernel is not left holding an open wandb run.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mdartemasov[0m ([33mbsim-skt[0m). Use [1m`wandb login --relogin`[0m to force relogin


Dice:0.823: 100%|█████████████████████████████████████████████████████████████████████| 500/500 [15:29<00:00,  1.86s/it]


VBox(children=(Label(value='0.180 MB of 0.180 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Mean Dice,▁▃▄▄▄▅▅▅▅▅▅▅▅▆▆▇▇▇▇▆▇▇▇▇▇▇▇███▇▇▇███▇▇██
Mean IOU,▁▃▄▄▄▅▅▆▆▅▆▆▅▆▇▇▇▇▇▆▇▇█▇▇▇▇███▇▇▇███▇▇██
Mean accuracy,▁▃▅▅▆▆▇▇▇▇▇▇▇▇▇████▇█▇██▇██████▇████▇███
Mean class recall,▁▃▄▄▄▅▅▅▅▅▅▅▅▆▆▇▇▇▇▆▇▇▇▇▆▇▇███▇▇▇██▇▇▇▇█
Train loss,█▅▄▃▃▃▃▃▂▃▃▃▄▂▂▄▂▃▁▂▂▃▂▂▃▂▃▅▂▃▃▂▃▂▂▃▂▂▂▃
Val loss,█▆▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Mean Dice,0.82309
Mean IOU,0.75521
Mean accuracy,0.9248
Mean class recall,0.8221
Train loss,0.70985
Val loss,0.68488
