In [1]:
# !./get_data.sh

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import time
import wandb
from collections import defaultdict
from itertools import islice
from IPython import display
# import albumentations as A
from torchsummary import summary
from utils.dataset import MRI
from utils.functions import train_network, SaveBestModel
from utils.loss import dice_loss
from models.unet import UNet

In [2]:
device = 'cuda:1'

# Sweep config

In [3]:
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'Mean Dice',
        'goal': 'maximize'   
    }}

In [4]:
parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
    },
    'learning_rate': {
        'distribution': 'uniform',
        'min': 0,
        'max': 0.1
    },
    'weight_decay': {
        'distribution': 'uniform',
        'min': 0,
        'max': 0.75
    },
    'batch_size': {
        'distribution': 'q_log_uniform_values',
        'q': 4,
        'min': 4,
        'max': 64,
    },
    'criterion': {
        'values': ['ce', 'dice']
    },
    'num_epochs': {
        'value': 200
    },
    'saver': {
        'value': None
    },
    'num_workers': {
        'value': 8
    },
    'num_classes': {
        'value': 8
    },
    'min_channels': {
        'value': 32
    },
    'max_channels': {
        'value': 512
    },
    'num_down_blocks': {
        'values': [3, 4, 5]
    },
    'img_channels': {
        'value': 8
    },
    'dropout': {
          'values': [0.3, 0.4, 0.5, 0.6]
    },
    'upsampling_mode': {
        'values': ['upsampling', 'conv_transpose']
    },
    'norm_mode': {
        'values': ['instance', 'batch']
    },
    }

sweep_config['parameters'] = parameters_dict

In [5]:
sweep_id = wandb.sweep(sweep_config, project="test-sweeps", entity="bsim-skt")

Create sweep with ID: gx43qrty
Sweep URL: https://wandb.ai/bsim-skt/test-sweeps/sweeps/gx43qrty


# Image from 8 slices is used

In [6]:
saver = None

train_dataset = MRI('train', mode = 'slices')
validation_dataset = MRI('test', mode = 'slices')

def run_sweep(config=None):
    with wandb.init(config=config):
        config = wandb.config

        train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers) 
        validation_dataloader = DataLoader(validation_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers)

        model = UNet(
            num_classes=config.num_classes, min_channels=config.min_channels,
            max_channels=config.max_channels, num_down_blocks=config.num_down_blocks, 
            img_channels=config.img_channels, upsampling_mode=config.upsampling_mode, 
            norm_mode=config.norm_mode).to(device)

        if config.optimizer == 'adam':
            opt = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        elif config.optimizer == 'sgd':
            opt = torch.optim.SGD(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        
        if config.criterion == 'ce':
            criterion = nn.CrossEntropyLoss()
        elif config.criterion == 'dice':
            criterion = dice_loss
        

        train_network(network=model, opt=opt, criterion=criterion, num_epochs=config.num_epochs, 
                      train_loader=train_dataloader, val_loader=validation_dataloader, device=device, 
                      saver=config.saver, use_wandb=True)

In [None]:
wandb.agent('bsim-skt/test-sweeps/30sppqsh', run_sweep, count=10)

[34m[1mwandb[0m: Agent Starting Run: c0jbtypq with config:
[34m[1mwandb[0m: 	batch_size: 4
[34m[1mwandb[0m: 	criterion: dice
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	img_channels: 8
[34m[1mwandb[0m: 	learning_rate: 0.0006781078949649116
[34m[1mwandb[0m: 	max_channels: 512
[34m[1mwandb[0m: 	min_channels: 32
[34m[1mwandb[0m: 	norm_mode: instance
[34m[1mwandb[0m: 	num_classes: 8
[34m[1mwandb[0m: 	num_down_blocks: 5
[34m[1mwandb[0m: 	num_epochs: 200
[34m[1mwandb[0m: 	num_workers: 8
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: 	saver: None
[34m[1mwandb[0m: 	upsampling_mode: upsampling
[34m[1mwandb[0m: 	weight_decay: 0.3810053559630216


Dice:0.282: 100%|█████████████████████████████| 200/200 [21:39<00:00,  6.50s/it]


VBox(children=(Label(value='0.440 MB of 0.440 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Mean Dice,█▄▅▁▁▂▅▅▅▅▅▅▆▃▃▂█▅▃▄▅▃▅▄▅▅▄▅▄▃▅▂▆▃▅▄▂▅▁▅
Mean IOU,█▄▆▂▁▃▅▅▆▆▅▆▇▄▄▃█▆▃▅▆▄▅▅▅▆▅▆▅▃▅▂▆▃▆▅▃▅▂▆
Mean accuracy,▅▆▆▃▁▂▆▆▆▅▆▆▇▄▄▄█▆▄▆▇▅▆▆▆▆▅▇▆▄▆▃▇▄▆▅▃▆▂▆
Mean class recall,█▅▆▂▁▄▆▅▆▅▆▆▇▄▄▃█▆▄▅▆▄▆▅▆▆▅▆▄▄▅▂▇▄▆▅▃▅▂▅
Train loss,▂▃▅▅▃▇▄▂▅▅▆▃▆▇▄▁▃▄▆▆▆▅▂▂▆▅▆▆▄▅▅▃▅▆▆▃█▄▄▆
Val loss,▁▅▆▇▇█▆▆▆▅▆▅▃▆▂▃▂▁▂▃▃▃▃▁▂▁▄▁▅▆▆▆▄▅▄▃▃▄▄▁

0,1
Mean Dice,0.2824
Mean IOU,0.25466
Mean accuracy,0.56891
Mean class recall,0.33119
Train loss,0.90202
Val loss,0.90152


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 211zxk9i with config:
[34m[1mwandb[0m: 	batch_size: 4
[34m[1mwandb[0m: 	criterion: dice
[34m[1mwandb[0m: 	dropout: 0.6
[34m[1mwandb[0m: 	img_channels: 8
[34m[1mwandb[0m: 	learning_rate: 0.0006064438876040385
[34m[1mwandb[0m: 	max_channels: 512
[34m[1mwandb[0m: 	min_channels: 32
[34m[1mwandb[0m: 	norm_mode: instance
[34m[1mwandb[0m: 	num_classes: 8
[34m[1mwandb[0m: 	num_down_blocks: 5
[34m[1mwandb[0m: 	num_epochs: 200
[34m[1mwandb[0m: 	num_workers: 8
[34m[1mwandb[0m: 	optimizer: sgd
[34m[1mwandb[0m: 	saver: None
[34m[1mwandb[0m: 	upsampling_mode: conv_transpose
[34m[1mwandb[0m: 	weight_decay: 0.09457546202889688


Dice:0.248:  93%|██████████████████████████▉  | 186/200 [19:11<01:26,  6.19s/it]

In [13]:
criterion = nn.CrossEntropyLoss()
train_dataset = MRI('train', mode = 'slices')
validation_dataset = MRI('test', mode = 'slices')
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8) 
validation_dataloader = DataLoader(validation_dataset, batch_size=8, shuffle=False, num_workers=8)

model = UNet(num_classes=8, in_channels=8, min_channels=32, num_down_blocks=4, normalization_mode='instance').to(device)
opt = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)

In [13]:
experiment_name = 'unet_test'
time_string = time.strftime("%d%m_%H%M", time.localtime())
writer = SummaryWriter(comment=experiment_name, flush_secs=30, log_dir=f'results/runs/{time_string}_{experiment_name}/{experiment_name}')
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')
# writer = None
# saver = None
train_network(model, opt, criterion, 300, writer, saver)

  0%|          | 0/300 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [32, 2, 3, 3], expected input[8, 8, 64, 64] to have 2 channels, but got 8 channels instead

In [None]:
pred = model(validation_dataset[:][0].to(device))
plt.imshow(validation_dataset[21][1])

In [None]:
plt.imshow(pred[21].squeeze().argmax(axis=0).cpu())

# Only reconstr image is used

In [None]:
criterion = nn.CrossEntropyLoss()
train_dataset = MRI('train')
validation_dataset = MRI('test')
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8) 
validation_dataloader = DataLoader(validation_dataset, batch_size=8, shuffle=False, num_workers=8)

model = UNet(num_classes=8, in_channels=1, min_channels=32, num_down_blocks=4).to(device)
opt = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)

In [None]:
experiment_name = 'only reconstructed image'
time_string = time.strftime("%d%m_%H%M", time.localtime())
writer = SummaryWriter(comment=experiment_name, flush_secs=30, log_dir=f'results/runs/{time_string}_{experiment_name}/{experiment_name}')
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')
# writer = None
# saver = None
train_network(model, opt, criterion, 100, writer, saver)

In [None]:
pred = model(validation_dataset[:][0].to(device))
plt.imshow(validation_dataset[21][1])

In [None]:
plt.imshow(pred[21].squeeze().argmax(axis=0).cpu())

# Reconstr image and zero slice is used

In [None]:
criterion = nn.CrossEntropyLoss()
train_dataset = MRI('train', mode = 'fist_plus_reconstr')
validation_dataset = MRI('test', mode = 'fist_plus_reconstr')
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8) 
validation_dataloader = DataLoader(validation_dataset, batch_size=8, shuffle=False, num_workers=8)

model = UNet(num_classes=8, in_channels=2, min_channels=32, num_down_blocks=4).to(device)
opt = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)

In [None]:
experiment_name = 'reconstructed and zero slice'
time_string = time.strftime("%d%m_%H%M", time.localtime())
writer = SummaryWriter(comment=experiment_name, flush_secs=30, log_dir=f'results/runs/{time_string}_{experiment_name}/{experiment_name}')
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')
# writer = None
# saver = None
train_network(model, opt, criterion, 100, writer, saver)

In [None]:
pred = model(validation_dataset[:][0].to(device))
plt.imshow(validation_dataset[21][1])

In [None]:
plt.imshow(pred[21].squeeze().argmax(axis=0).cpu())