In [1]:
import argparse
import logging
import os
import sys

import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, random_split

sys.path.append("..")
sys.path.append("../scripts/")
import superlayer.utils

from scripts import eval_net, train_net, get_args

from superlayer.models import SUnet
from superlayer.utils import BrainD, dice_coeff, one_hot, plot_img_array, plot_side_by_side

In [2]:
# Load the pretrained superblock weight banks and keep the LEADING slice along
# axis 0 of each autoencoder bank. The later weight-loading cell takes the
# remaining entries ([32:], [64:], ...) instead.
_superblock_dir = "/home/vib9/src/SL-Net/superlayer/models/superblocks/"

# NOTE(review): net1_W is loaded but no visible cell passes it to a network.
net1_W = np.load(_superblock_dir + "SLN_64.npy")

net5_W = np.load(_superblock_dir + "AE_64.npy")[:32, :, :, :]    # -> 32-channel SUnet
net6_W = np.load(_superblock_dir + "AE_128.npy")[:64, :, :, :]   # -> 64-channel SUnet
net7_W = np.load(_superblock_dir + "AE_200.npy")[:100, :, :, :]  # -> 100-channel SUnet
net8_W = np.load(_superblock_dir + "AE_256.npy")[:128, :, :, :]  # -> 128-channel SUnet

In [3]:
# Input data locations. The preprocessing variant is encoded in the directory
# name (presumably resized to 256, cropped, 100 slices — confirm with the data owner).
_neuron_train_root = '/home/gid-dalcaav/projects/neuron/data/t1_mix/proc/resize256-crop_x32-slice100/train/'
dir_img = _neuron_train_root + 'vols/'    # image volumes
dir_mask = _neuron_train_root + 'asegs/'  # segmentation label maps

# Text files listing the train / validation partitions; these are what the
# train_net calls below actually receive (train_path / val_path).
_partition_root = '/home/vib9/src/SL-Net/jupyter/partitions/'
dir_train = _partition_root + 'train.txt'
dir_val = _partition_root + 'val.txt'

# Relative checkpoint directories (not referenced by any visible cell).
dir_checkpoint_1, dir_checkpoint_2 = 'checkpoints_1/', 'checkpoints_2/'

In [4]:
# Configure logging, parse the training CLI options (batch size, lr, scale, val %),
# and pick the compute device.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA, instead of erroring later when the networks are
# moved to the device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    logging.warning('CUDA is not available; falling back to CPU')
logging.info(f'Using device {device}')

INFO: Using device cuda


In [5]:
# Segmentation label IDs to predict. There are 15 of them, matching out_ch=15
# in the SUnet cells below.
target_label_numbers = [0, 2, 3, 4, 10, 16, 17, 28, 31, 41, 42, 43, 49, 53, 63]

# NOTE(review): the training cells pass `args.val / 100` (not this value) and
# read args.batchsize / args.lr / args.scale directly; these locals only feed
# the TensorBoard run name below.
val_percent = 0.1
batch_size, lr, img_scale = args.batchsize, args.lr, args.scale

# NOTE(review): every train_net call below receives this same writer, so all
# runs presumably log into a single TensorBoard run directory. Consider one
# writer per network.
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')

In [None]:
# 32-channel SUnet (4 encoder + 4 decoder levels), initialised from net5_W.
# In a top-to-bottom run, net5_W is the first 32 entries of AE_64.npy.
enc_nf = [32] * 4
dec_nf = list(enc_nf)
net5 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net5_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

# NOTE(review): net5 is never moved off the device afterwards, so it keeps
# holding GPU memory while the larger nets below are trained.
net5.to(device=device)

# Returns four per-evaluation series. The plotting cell uses *_var_* as a
# +/- band around the scores.
(train_scores5, val_scores5,
 train_var_5, val_var_5) = train_net(
    net=net5,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

(32, 64, 3, 3)


INFO: Network:
	1 input channels
	15 output channels (classes)

INFO: Creating dataset with 7329 examples
INFO: Creating dataset with 7329 examples
INFO: Starting training:
        Epochs:          8
        Batch size:      8
        Learning rate:   0.001
        Training size:   7329
        Validation size: 7329
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
    
Epoch 1/8:  20%|█▉        | 1456/7329 [00:08<00:29, 200.50img/s, loss (batch)=0.425]
Validation round:   0%|          | 0/916 [00:00<?, ?batch/s][A
Validation round:   0%|          | 1/916 [00:00<10:32,  1.45batch/s][A
Validation round:   1%|          | 6/916 [00:00<07:25,  2.04batch/s][A
Validation round:   1%|          | 9/916 [00:00<05:26,  2.78batch/s][A
Validation round:   1%|▏         | 12/916 [00:01<03:58,  3.79batch/s][A
Validation round:   2%|▏         | 17/916 [00:01<02:56,  5.10batch/s][A
Validation round:   2%|▏         | 20/916 [00:01<02:16,  6.57batch/s][A
Valida

Validation round:  77%|███████▋  | 702/916 [00:31<00:12, 17.73batch/s][A
Validation round:  78%|███████▊  | 710/916 [00:31<00:10, 18.81batch/s][A
Validation round:  78%|███████▊  | 718/916 [00:31<00:09, 20.02batch/s][A
Validation round:  79%|███████▉  | 726/916 [00:32<00:09, 20.92batch/s][A
Validation round:  80%|████████  | 734/916 [00:32<00:08, 21.55batch/s][A
Validation round:  81%|████████  | 741/916 [00:32<00:06, 26.95batch/s][A
Validation round:  81%|████████▏ | 745/916 [00:33<00:08, 21.05batch/s][A
Validation round:  82%|████████▏ | 750/916 [00:33<00:08, 19.31batch/s][A
Validation round:  83%|████████▎ | 758/916 [00:33<00:07, 20.66batch/s][A
Validation round:  84%|████████▎ | 765/916 [00:33<00:05, 26.05batch/s][A
Validation round:  84%|████████▍ | 769/916 [00:34<00:06, 21.24batch/s][A
Validation round:  84%|████████▍ | 774/916 [00:34<00:07, 19.15batch/s][A
Validation round:  85%|████████▌ | 782/916 [00:34<00:06, 19.73batch/s][A
Validation round:  86%|████████▌ | 790

Epoch 1/8:  40%|███▉      | 2928/7329 [01:10<00:25, 175.95img/s, loss (batch)=0.385]
Validation round:  37%|███▋      | 338/916 [00:14<00:23, 24.09batch/s][A
Validation round:  37%|███▋      | 343/916 [00:14<00:24, 23.43batch/s][A
Validation round:  38%|███▊      | 346/916 [00:14<00:22, 25.04batch/s][A
Validation round:  38%|███▊      | 351/916 [00:14<00:22, 24.68batch/s][A
Validation round:  39%|███▉      | 355/916 [00:14<00:20, 27.46batch/s][A
Validation round:  39%|███▉      | 359/916 [00:15<00:24, 22.74batch/s][A
Validation round:  40%|███▉      | 366/916 [00:15<00:19, 28.40batch/s][A
Validation round:  40%|████      | 370/916 [00:15<00:25, 21.64batch/s][A
Validation round:  41%|████      | 375/916 [00:15<00:25, 20.99batch/s][A
Validation round:  42%|████▏     | 382/916 [00:15<00:20, 26.43batch/s][A
Validation round:  42%|████▏     | 386/916 [00:16<00:25, 20.58batch/s][A
Validation round:  43%|████▎     | 391/916 [00:16<00:26, 20.05batch/s][A
Validation round:  44%|████

In [None]:
# 64-channel SUnet (4 encoder + 4 decoder levels), initialised from net6_W.
# In a top-to-bottom run, net6_W is the first 64 entries of AE_128.npy.
enc_nf = [64] * 4
dec_nf = list(enc_nf)
net6 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net6_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net6.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores6, val_scores6,
 train_var_6, val_var_6) = train_net(
    net=net6,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# 100-channel SUnet (4 encoder + 4 decoder levels), initialised from net7_W.
# In a top-to-bottom run, net7_W is the first 100 entries of AE_200.npy.
enc_nf = [100] * 4
dec_nf = list(enc_nf)
net7 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net7_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net7.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores7, val_scores7,
 train_var_7, val_var_7) = train_net(
    net=net7,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# 128-channel SUnet (4 encoder + 4 decoder levels), initialised from net8_W.
# In a top-to-bottom run, net8_W is the first 128 entries of AE_256.npy.
enc_nf = [128] * 4
dec_nf = list(enc_nf)
net8 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net8_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net8.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores8, val_scores8,
 train_var_8, val_var_8) = train_net(
    net=net8,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
print("configuring combined plots")
# All four panels share the x-axis of the net5 run (1-based mini-epoch index).
domain = len(train_scores5)
x_values = list(range(1, domain + 1))

# NOTE(review): the *_var_* series are drawn as a symmetric +/- band. If they
# are variances rather than standard deviations, the band is mis-scaled.

# ---- Figure 1, left panel: net5 (32 channels) ----
a5 = plt.subplot(1, 2, 1)
a5.set_ylim([0, 0.5])
a5.set_title("AE SLN 64 Superblock")
a5.set_xlabel("Mini-epochs")
a5.set_ylabel("Dice Loss")

ziptrainup = [s + v for s, v in zip(train_scores5, train_var_5)]
ziptraindown = [s - v for s, v in zip(train_scores5, train_var_5)]
zipvalup = [s + v for s, v in zip(val_scores5, val_var_5)]
zipvaldown = [s - v for s, v in zip(val_scores5, val_var_5)]

a5.plot(x_values, train_scores5, color="blue", label="train")
a5.fill_between(x_values, ziptrainup, ziptraindown, facecolor='lightskyblue', alpha=0.5)
a5.plot(x_values, val_scores5, color="orange", label="val")
a5.fill_between(x_values, zipvalup, zipvaldown, facecolor='navajowhite', alpha=0.5)
a5.legend()
a5.grid()

# ---- Figure 1, right panel: net6 (64 channels) ----
a6 = plt.subplot(1, 2, 2)
a6.set_ylim([0, 0.5])
a6.set_title("AE SLN 128 Superblock")
a6.set_xlabel("Mini-epochs")
a6.set_ylabel("Dice Loss")

ziptrainup = [s + v for s, v in zip(train_scores6, train_var_6)]
ziptraindown = [s - v for s, v in zip(train_scores6, train_var_6)]
zipvalup = [s + v for s, v in zip(val_scores6, val_var_6)]
zipvaldown = [s - v for s, v in zip(val_scores6, val_var_6)]

a6.plot(x_values, train_scores6, color="blue", label="train")
a6.fill_between(x_values, ziptrainup, ziptraindown, facecolor='lightskyblue', alpha=0.5)
a6.plot(x_values, val_scores6, color="orange", label="val")
a6.fill_between(x_values, zipvalup, zipvaldown, facecolor='navajowhite', alpha=0.5)
a6.legend()
a6.grid()

plt.show()

# ---- Figure 2, left panel: net7 (100 channels) ----
a7 = plt.subplot(1, 2, 1)
a7.set_ylim([0, 0.5])
a7.set_title("AE SLN 200 Superblock")
a7.set_xlabel("Mini-epochs")
a7.set_ylabel("Dice Loss")

ziptrainup = [s + v for s, v in zip(train_scores7, train_var_7)]
ziptraindown = [s - v for s, v in zip(train_scores7, train_var_7)]
zipvalup = [s + v for s, v in zip(val_scores7, val_var_7)]
zipvaldown = [s - v for s, v in zip(val_scores7, val_var_7)]

a7.plot(x_values, train_scores7, color="blue", label="train")
a7.fill_between(x_values, ziptrainup, ziptraindown, facecolor='lightskyblue', alpha=0.5)
a7.plot(x_values, val_scores7, color="orange", label="val")
a7.fill_between(x_values, zipvalup, zipvaldown, facecolor='navajowhite', alpha=0.5)
a7.legend()
a7.grid()

# ---- Figure 2, right panel: net8 (128 channels) ----
a8 = plt.subplot(1, 2, 2)
a8.set_ylim([0, 0.5])
a8.set_title("AE SLN 256 Superblock")
a8.set_xlabel("Mini-epochs")
a8.set_ylabel("Dice Loss")

ziptrainup = [s + v for s, v in zip(train_scores8, train_var_8)]
ziptraindown = [s - v for s, v in zip(train_scores8, train_var_8)]
zipvalup = [s + v for s, v in zip(val_scores8, val_var_8)]
zipvaldown = [s - v for s, v in zip(val_scores8, val_var_8)]

a8.plot(x_values, train_scores8, color="blue", label="train")
a8.fill_between(x_values, ziptrainup, ziptraindown, facecolor='lightskyblue', alpha=0.5)
a8.plot(x_values, val_scores8, color="orange", label="val")
a8.fill_between(x_values, zipvalup, zipvaldown, facecolor='navajowhite', alpha=0.5)
a8.legend()
a8.grid()

In [None]:
# Reload the superblock weight banks and keep the REMAINING entries along
# axis 0 (complementing the leading slices taken by the first weight cell).
# NOTE(review): this rebinds net5_W..net8_W. Re-running the net5..net8 training
# cells after this one silently trains them on these second-half weights.
# Distinct names (e.g. net9_W..net12_W) would remove that hidden-state hazard.
_superblock_dir = "/home/vib9/src/SL-Net/superlayer/models/superblocks/"

# NOTE(review): net1_W is loaded but no visible cell passes it to a network.
net1_W = np.load(_superblock_dir + "SLN_64.npy")

net5_W = np.load(_superblock_dir + "AE_64.npy")[32:, :, :, :]    # -> net9  (32 channels)
net6_W = np.load(_superblock_dir + "AE_128.npy")[64:, :, :, :]   # -> net10 (64 channels)
net7_W = np.load(_superblock_dir + "AE_200.npy")[100:, :, :, :]  # -> net11 (100 channels)
net8_W = np.load(_superblock_dir + "AE_256.npy")[128:, :, :, :]  # -> net12 (128 channels)

In [None]:
# 32-channel SUnet initialised from net5_W. In a top-to-bottom run this is the
# second-half slice (AE_64.npy[32:]) bound by the preceding weight cell.
enc_nf = [32] * 4
dec_nf = list(enc_nf)
net9 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net5_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net9.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores9, val_scores9,
 train_var_9, val_var_9) = train_net(
    net=net9,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# 64-channel SUnet initialised from net6_W. In a top-to-bottom run this is the
# second-half slice (AE_128.npy[64:]) bound by the preceding weight cell.
enc_nf = [64] * 4
dec_nf = list(enc_nf)
net10 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net6_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net10.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores10, val_scores10,
 train_var_10, val_var_10) = train_net(
    net=net10,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# 100-channel SUnet initialised from net7_W. In a top-to-bottom run this is the
# second-half slice (AE_200.npy[100:]) bound by the preceding weight cell.
enc_nf = [100] * 4
dec_nf = list(enc_nf)
net11 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net7_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net11.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores11, val_scores11,
 train_var_11, val_var_11) = train_net(
    net=net11,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# 128-channel SUnet initialised from net8_W. In a top-to-bottom run this is the
# second-half slice (AE_256.npy[128:]) bound by the preceding weight cell.
enc_nf = [128] * 4
dec_nf = list(enc_nf)
net12 = SUnet(
    input_ch=1,
    out_ch=15,
    use_bn=True,
    enc_nf=enc_nf,
    dec_nf=dec_nf,
    ignore_last=False,
    W=net8_W,
)

logging.info('Network:\n'
             '\t1 input channels\n'
             '\t15 output channels (classes)\n')

net12.to(device=device)

# Four per-evaluation series; *_var_* is drawn as a +/- band by the plotting cell.
(train_scores12, val_scores12,
 train_var_12, val_var_12) = train_net(
    net=net12,
    epochs=8,
    batch_size=args.batchsize,
    lr=args.lr,
    device=device,
    img_scale=args.scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
print("configuring combined plots")
# BUG FIX: this previously read len(train_scores1), a name never defined in this
# notebook, so the cell raised NameError. The runs plotted here are net9..net12.
domain = len(train_scores9)
x_values = [i + 1 for i in range(domain)]  # kept as a module-level name for any later cell


def _plot_loss_panel(position, title, train_scores, val_scores, train_var, val_var):
    """Draw one train/val panel into slot `position` of a 1x2 subplot grid.

    Args:
        position: 1 (left) or 2 (right) subplot slot of the current figure.
        title: panel title.
        train_scores, val_scores: per-evaluation series returned by ``train_net``.
        train_var, val_var: matching spread series, drawn as a symmetric +/- band.
            NOTE(review): if these are variances rather than standard deviations,
            the band is mis-scaled. Confirm against train_net.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    ax = plt.subplot(1, 2, position)
    ax.set_ylim([0, 0.5])
    ax.set_title(title)
    ax.set_xlabel("Mini-epochs")
    ax.set_ylabel("Dice Loss")
    # Same draw order as before: train line, train band, val line, val band.
    series = (
        (train_scores, train_var, "blue", "lightskyblue", "train"),
        (val_scores, val_var, "orange", "navajowhite", "val"),
    )
    for scores, spread, line_color, band_color, label in series:
        # Per-series 1-based x-axis, so a run with a different number of
        # evaluations still plots instead of raising a length-mismatch error.
        x = [i + 1 for i in range(len(scores))]
        upper = [s + v for s, v in zip(scores, spread)]
        lower = [s - v for s, v in zip(scores, spread)]
        ax.plot(x, scores, color=line_color, label=label)
        ax.fill_between(x, upper, lower, facecolor=band_color, alpha=0.5)
    ax.legend()
    ax.grid()
    return ax


# NOTE(review): these titles are identical to the ones used for net5..net8;
# consider a suffix so the first-half and second-half figures can be told apart.
a9 = _plot_loss_panel(1, "AE SLN 64 Superblock",
                      train_scores9, val_scores9, train_var_9, val_var_9)
a10 = _plot_loss_panel(2, "AE SLN 128 Superblock",
                       train_scores10, val_scores10, train_var_10, val_var_10)
plt.show()

a11 = _plot_loss_panel(1, "AE SLN 200 Superblock",
                       train_scores11, val_scores11, train_var_11, val_var_11)
a12 = _plot_loss_panel(2, "AE SLN 256 Superblock",
                       train_scores12, val_scores12, train_var_12, val_var_12)
plt.show()