In [1]:
# Optional one-time data fetch: uncomment to run the repo's get_data.sh script
# (presumably it downloads the MRI data read by utils.dataset.MRI — verify).
# !./get_data.sh

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import time
import wandb
from collections import defaultdict
from itertools import islice
from IPython import display
# import albumentations as A
from torchsummary import summary
from utils.dataset import MRI
from utils.functions import train_network, SaveBestModel
from utils.loss import dice_loss
from models.unet import UNet

In [2]:
# Prefer GPU index 1 as originally configured. A hard-coded 'cuda:1' raises
# "invalid device ordinal" on machines with fewer than two GPUs, so fall back
# to the default GPU and then to the CPU.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Input: all 8 slices (8-channel image)

In [None]:
# Debug run: 8-slice input (mode='slices', img_channels=8), no checkpointing.
experiment_name = 'DEBUG'
time_string = time.strftime("%d%m_%H%M", time.localtime())
# Checkpointing is disabled for debug runs; restore the line below to enable it.
# saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')
saver = None

config = {
    "learning_rate": 1e-3,
    "epochs": 500,
    "batch_size": 4,
    "saver": saver,
    "num_workers": 8,
    "seed": 0,                 # logged to wandb so the run can be reproduced
    # dice_loss, or an nn.CrossEntropyLoss() *instance* (not the class) —
    # presumably train_network calls criterion(pred, target); verify.
    "criterion": dice_loss,
    "model": {
        "num_classes": 8,
        "min_channels": 32,
        "max_channels": 512,
        "num_down_blocks": 4,
        "img_channels": 8,     # one input channel per slice
        "dropout": 0.5,
        "upsampling_mode": "upsampling",
        "norm_mode": "instance"
    }
}

# Seed torch so weight init and DataLoader shuffling are repeatable across runs.
torch.manual_seed(config["seed"])

wandb.init(project="bsim", name=experiment_name, config=config, entity="bsim-skt")
try:
    train_dataset = MRI('train', mode='slices')
    validation_dataset = MRI('test', mode='slices')
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  shuffle=True, num_workers=config['num_workers'])
    validation_dataloader = DataLoader(validation_dataset, batch_size=config['batch_size'],
                                       shuffle=False, num_workers=config['num_workers'])

    model = UNet(**config['model']).to(device)

    # Adam (same lr, weight_decay=1e-4) was the alternative tried; SGD is used here.
    opt = torch.optim.SGD(model.parameters(), lr=wandb.config['learning_rate'], weight_decay=1e-4)

    train_network(model, opt, config['criterion'], wandb.config['epochs'],
                  train_dataloader, validation_dataloader, device, saver, use_wandb=True)
except BaseException:
    # Close the run even when setup/training fails or is interrupted (Ctrl-C);
    # otherwise wandb.finish() below is skipped and the run is left open.
    wandb.finish(exit_code=1)
    raise
wandb.finish()

# Input: reconstructed image only (1 channel)

In [4]:
# Reconstructed image as the only input (mode='reconstructed_only', img_channels=1).
experiment_name = 'reconstr_only_dice'
time_string = time.strftime("%d%m_%H%M", time.localtime())
# Per-run output path under results/pths/; SaveBestModel is a project helper
# (presumably keeps the best checkpoint there — verify in utils.functions).
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')

config = {
    "learning_rate": 1e-3,
    "epochs": 500,
    "batch_size": 4,
    "saver": saver,
    "num_workers": 8,
    "seed": 0,                 # logged to wandb so the run can be reproduced
    # dice_loss, or an nn.CrossEntropyLoss() *instance* (not the class) —
    # presumably train_network calls criterion(pred, target); verify.
    "criterion": dice_loss,
    "model": {
        "num_classes": 8,
        "min_channels": 32,
        "max_channels": 512,
        "num_down_blocks": 4,
        "img_channels": 1,     # single reconstructed image
        "dropout": 0.5,
        "upsampling_mode": "upsampling",
        "norm_mode": "instance"
    }
}

# Seed torch so weight init and DataLoader shuffling are repeatable across runs.
torch.manual_seed(config["seed"])

wandb.init(project="bsim", name=experiment_name, config=config, entity="bsim-skt")
try:
    train_dataset = MRI('train', mode='reconstructed_only')
    validation_dataset = MRI('test', mode='reconstructed_only')
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  shuffle=True, num_workers=config['num_workers'])
    validation_dataloader = DataLoader(validation_dataset, batch_size=config['batch_size'],
                                       shuffle=False, num_workers=config['num_workers'])

    model = UNet(**config['model']).to(device)

    # Adam (same lr, weight_decay=1e-4) was the alternative tried; SGD is used here.
    opt = torch.optim.SGD(model.parameters(), lr=wandb.config['learning_rate'], weight_decay=1e-4)

    train_network(model, opt, config['criterion'], wandb.config['epochs'],
                  train_dataloader, validation_dataloader, device, saver, use_wandb=True)
except BaseException:
    # Close the run even when setup/training fails or is interrupted (Ctrl-C);
    # otherwise wandb.finish() below is skipped and the run is left open.
    wandb.finish(exit_code=1)
    raise
wandb.finish()

Dice:0.311: 100%|█████████████████████████████| 500/500 [31:46<00:00,  3.81s/it]


VBox(children=(Label(value='0.718 MB of 0.718 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Mean Dice,▁▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▅▆▇▇▇▇██████████████████
Mean IOU,▁▂▃▃▃▄▄▄▃▃▃▃▄▄▄▄▅▇▇▇▇▇██████████████████
Mean accuracy,▁▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████████████
Mean class recall,▁▄▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇███████████████████
Train loss,██▇▆▆▆▇▇▆▇▇▆▆▄▃▆▆▃▅▄▆▄▅▄▂▂▂▃▄▂▂▂▄▂▂▂▁▃▁▁
Val loss,█▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
Mean Dice,0.31054
Mean IOU,0.27336
Mean accuracy,0.6156
Mean class recall,0.39725
Train loss,0.86613
Val loss,0.87519


# Input: reconstructed image plus the zeroth (first) slice (2 channels)

In [None]:
# Reconstructed image + first slice as input (img_channels=2).
experiment_name = 'reconstr_and_zero_dice'
time_string = time.strftime("%d%m_%H%M", time.localtime())
# Per-run output path under results/pths/; SaveBestModel is a project helper
# (presumably keeps the best checkpoint there — verify in utils.functions).
saver = SaveBestModel(f'results/pths/{time_string}_{experiment_name}')

config = {
    "learning_rate": 1e-3,
    "epochs": 500,
    "batch_size": 4,
    "saver": saver,
    "num_workers": 8,
    "seed": 0,                 # logged to wandb so the run can be reproduced
    # dice_loss, or an nn.CrossEntropyLoss() *instance* (not the class) —
    # presumably train_network calls criterion(pred, target); verify.
    "criterion": dice_loss,
    "model": {
        "num_classes": 8,
        "min_channels": 32,
        "max_channels": 512,
        "num_down_blocks": 4,
        "img_channels": 2,     # reconstructed image + first slice
        "dropout": 0.5,
        "upsampling_mode": "upsampling",
        "norm_mode": "instance"
    }
}

# Seed torch so weight init and DataLoader shuffling are repeatable across runs.
torch.manual_seed(config["seed"])

wandb.init(project="bsim", name=experiment_name, config=config, entity="bsim-skt")
try:
    # NOTE(review): 'fist' looks like a typo for 'first', but the string must
    # match the mode name utils.dataset.MRI accepts — verify before renaming.
    train_dataset = MRI('train', mode='fist_plus_reconstr')
    validation_dataset = MRI('test', mode='fist_plus_reconstr')
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  shuffle=True, num_workers=config['num_workers'])
    validation_dataloader = DataLoader(validation_dataset, batch_size=config['batch_size'],
                                       shuffle=False, num_workers=config['num_workers'])

    model = UNet(**config['model']).to(device)

    # Adam (same lr, weight_decay=1e-4) was the alternative tried; SGD is used here.
    opt = torch.optim.SGD(model.parameters(), lr=wandb.config['learning_rate'], weight_decay=1e-4)

    train_network(model, opt, config['criterion'], wandb.config['epochs'],
                  train_dataloader, validation_dataloader, device, saver, use_wandb=True)
except BaseException:
    # Close the run even when setup/training fails or is interrupted (Ctrl-C);
    # otherwise wandb.finish() below is skipped and the run is left open.
    wandb.finish(exit_code=1)
    raise
wandb.finish()

Dice:0.234:  37%|██████████▋                  | 184/500 [12:12<20:48,  3.95s/it]