In [1]:
import argparse
import logging
import os
import sys

import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, random_split

sys.path.append("..")
sys.path.append("../scripts/")
import superlayer.utils

from scripts import eval_net, train_net, get_args

from superlayer.models import SuperNet, AESuperNet
from superlayer.utils import BrainD, dice_coeff, one_hot, plot_img_array, plot_side_by_side

In [2]:
# Filesystem locations used by this notebook.
# NOTE(review): these are absolute, machine-specific paths -- adjust them
# before running the notebook anywhere else.

# Image volumes and their matching segmentation ("asegs") directories.
dir_img, dir_mask = (
    '/home/gid-dalcaav/projects/neuron/data/t1_mix/proc/resize256-crop_x32-slice100/train/vols/',
    '/home/gid-dalcaav/projects/neuron/data/t1_mix/proc/resize256-crop_x32-slice100/train/asegs/',
)

# Partition files; later cells hand these to train_net as train_path / val_path.
dir_train, dir_val = (
    '/home/vib9/src/SL-Net/jupyter/partitions/train.txt',
    '/home/vib9/src/SL-Net/jupyter/partitions/val.txt',
)

# Relative checkpoint directories (not referenced by the cells shown below).
dir_checkpoint_1, dir_checkpoint_2 = 'checkpoints_1/', 'checkpoints_2/'

In [3]:
# Notebook-wide logging: INFO and above, printed as "LEVEL: message".
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Run options; later cells read args.batchsize, args.lr, args.scale and args.val.
args = get_args()

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook does not fail later at the first `.to(device=device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

INFO: Using device cuda


In [4]:
# Labels handed to train_net as `target_label_numbers`
# (presumably segmentation label IDs -- TODO confirm against the dataset).
target_label_numbers = [
    0, 2, 3, 4, 10,
    16, 17, 28, 31, 41,
    42, 43, 49, 53, 63,
]

# Validation fraction kept as a notebook-level constant.
val_percent = 0.1

# Hyperparameters taken from the parsed run options.
batch_size, lr, img_scale = args.batchsize, args.lr, args.scale

# TensorBoard writer; the run comment records the hyperparameters above.
writer = SummaryWriter(
    comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}'
)

In [None]:
# Train a SuperNet with superblock_size=64.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell
# (15 entries as defined there).
in_channels = 1
n_classes = len(target_label_numbers)

net1 = SuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                superblock_size=64, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net1.to(device=device)

# train_net returns four sequences; the plotting cell draws the *_scores as
# Dice-loss curves per mini-epoch and the *_var values as the +/- band.
# Hyperparameters come from the config-cell variables (the same ones the
# SummaryWriter run name is built from) instead of re-reading `args`.
# NOTE(review): the config cell also defines val_percent = 0.1, but this call
# has always passed args.val / 100 -- confirm which one is intended.
# NOTE(review): `writer` is shared by every training cell -- confirm that
# train_net keeps the individual runs distinguishable in TensorBoard.
train_scores1, val_scores1, train_var_1, val_var_1 = train_net(
    net=net1,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

INFO: Network:
	1 input channels
	15 output channels (classes)

INFO: Creating dataset with 7329 examples
INFO: Creating dataset with 7329 examples
INFO: Starting training:
        Epochs:          8
        Batch size:      8
        Learning rate:   0.001
        Training size:   7329
        Validation size: 7329
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
    
Epoch 1/8:  20%|█▉        | 1456/7329 [00:07<00:32, 182.87img/s, loss (batch)=0.24] 
Validation round:   0%|          | 0/916 [00:00<?, ?batch/s][A
Validation round:   0%|          | 1/916 [00:00<10:43,  1.42batch/s][A
Validation round:   1%|          | 9/916 [00:00<07:34,  2.00batch/s][A
Validation round:   2%|▏         | 17/916 [00:01<05:24,  2.77batch/s][A
Validation round:   3%|▎         | 23/916 [00:01<03:50,  3.88batch/s][A
Validation round:   3%|▎         | 26/916 [00:01<03:03,  4.84batch/s][A
Validation round:   4%|▎         | 33/916 [00:01<02:13,  6.61batch/s][A
Valid

Validation round:  75%|███████▌  | 690/916 [00:27<00:09, 24.84batch/s][A
Validation round:  76%|███████▌  | 698/916 [00:27<00:08, 24.81batch/s][A
Validation round:  77%|███████▋  | 706/916 [00:27<00:08, 24.91batch/s][A
Validation round:  78%|███████▊  | 714/916 [00:28<00:08, 24.76batch/s][A
Validation round:  79%|███████▉  | 722/916 [00:28<00:07, 25.62batch/s][A
Validation round:  80%|███████▉  | 730/916 [00:28<00:07, 25.54batch/s][A
Validation round:  81%|████████  | 738/916 [00:29<00:06, 25.81batch/s][A
Validation round:  81%|████████▏ | 746/916 [00:29<00:06, 26.29batch/s][A
Validation round:  82%|████████▏ | 753/916 [00:29<00:05, 32.19batch/s][A
Validation round:  83%|████████▎ | 757/916 [00:29<00:06, 25.81batch/s][A
Validation round:  83%|████████▎ | 762/916 [00:29<00:06, 22.99batch/s][A
Validation round:  84%|████████▍ | 769/916 [00:30<00:05, 28.30batch/s][A
Validation round:  84%|████████▍ | 773/916 [00:30<00:05, 24.76batch/s][A
Validation round:  85%|████████▍ | 778

Validation round:  47%|████▋     | 434/916 [00:17<00:18, 25.66batch/s][A
Validation round:  48%|████▊     | 442/916 [00:17<00:18, 25.64batch/s][A
Validation round:  49%|████▉     | 450/916 [00:17<00:18, 24.91batch/s][A
Validation round:  50%|█████     | 458/916 [00:18<00:18, 25.21batch/s][A
Validation round:  51%|█████     | 466/916 [00:18<00:17, 25.46batch/s][A
Validation round:  52%|█████▏    | 474/916 [00:18<00:17, 25.46batch/s][A
Validation round:  53%|█████▎    | 482/916 [00:18<00:17, 25.45batch/s][A
Epoch 1/8:  40%|███▉      | 2928/7329 [01:10<00:23, 186.16img/s, loss (batch)=0.213]
Validation round:  54%|█████▍    | 498/916 [00:19<00:16, 25.21batch/s][A
Validation round:  55%|█████▌    | 506/916 [00:19<00:16, 25.12batch/s][A
Validation round:  56%|█████▌    | 514/916 [00:20<00:16, 24.96batch/s][A
Validation round:  57%|█████▋    | 522/916 [00:20<00:15, 25.19batch/s][A
Validation round:  58%|█████▊    | 530/916 [00:20<00:15, 25.36batch/s][A
Validation round:  59%|████

Validation round:  35%|███▍      | 317/916 [00:12<00:24, 24.39batch/s][A
Validation round:  35%|███▌      | 325/916 [00:13<00:23, 24.96batch/s][A
Validation round:  36%|███▋      | 333/916 [00:13<00:23, 24.93batch/s][A
Validation round:  37%|███▋      | 341/916 [00:13<00:23, 24.61batch/s][A
Validation round:  38%|███▊      | 349/916 [00:14<00:22, 24.66batch/s][A
Validation round:  39%|███▉      | 357/916 [00:14<00:22, 24.83batch/s][A
Validation round:  40%|███▉      | 365/916 [00:14<00:21, 25.08batch/s][A
Validation round:  41%|████      | 373/916 [00:15<00:21, 25.43batch/s][A
Validation round:  42%|████▏     | 381/916 [00:15<00:21, 25.14batch/s][A
Epoch 1/8:  60%|█████▉    | 4392/7329 [01:50<00:16, 182.21img/s, loss (batch)=0.211]
Validation round:  43%|████▎     | 397/916 [00:16<00:20, 25.65batch/s][A
Validation round:  44%|████▍     | 405/916 [00:16<00:19, 26.07batch/s][A
Validation round:  45%|████▌     | 413/916 [00:16<00:19, 25.54batch/s][A
Validation round:  46%|████

Validation round:  23%|██▎       | 207/916 [00:08<00:28, 25.23batch/s][A
Validation round:  23%|██▎       | 215/916 [00:09<00:27, 25.58batch/s][A
Validation round:  24%|██▍       | 223/916 [00:09<00:26, 25.70batch/s][A
Validation round:  25%|██▌       | 231/916 [00:09<00:26, 25.63batch/s][A
Validation round:  26%|██▌       | 239/916 [00:10<00:26, 25.64batch/s][A
Validation round:  27%|██▋       | 247/916 [00:10<00:26, 25.68batch/s][A
Validation round:  28%|██▊       | 255/916 [00:10<00:26, 25.22batch/s][A
Validation round:  29%|██▊       | 263/916 [00:11<00:26, 24.48batch/s][A
Validation round:  30%|██▉       | 271/916 [00:11<00:26, 24.40batch/s][A
Validation round:  30%|███       | 279/916 [00:11<00:25, 24.87batch/s][A
Epoch 1/8:  80%|███████▉  | 5856/7329 [02:30<00:06, 212.63img/s, loss (batch)=0.21]
Validation round:  32%|███▏      | 295/916 [00:12<00:24, 25.08batch/s][A
Validation round:  33%|███▎      | 303/916 [00:12<00:24, 24.69batch/s][A
Validation round:  34%|███▍ 

Validation round:  13%|█▎        | 121/916 [00:05<00:34, 23.16batch/s][A
Validation round:  14%|█▍        | 129/916 [00:05<00:33, 23.58batch/s][A
Validation round:  15%|█▍        | 137/916 [00:06<00:33, 23.58batch/s][A
Validation round:  16%|█▌        | 145/916 [00:06<00:32, 23.90batch/s][A
Validation round:  17%|█▋        | 152/916 [00:06<00:25, 29.48batch/s][A
Validation round:  17%|█▋        | 156/916 [00:06<00:34, 21.78batch/s][A
Validation round:  18%|█▊        | 161/916 [00:07<00:36, 20.63batch/s][A
Validation round:  18%|█▊        | 169/916 [00:07<00:34, 21.70batch/s][A
Validation round:  19%|█▉        | 177/916 [00:07<00:32, 22.79batch/s][A
Validation round:  20%|██        | 185/916 [00:07<00:31, 23.24batch/s][A
Validation round:  21%|██        | 193/916 [00:08<00:30, 23.83batch/s][A
Validation round:  22%|██▏       | 201/916 [00:08<00:29, 24.01batch/s][A
Validation round:  23%|██▎       | 209/916 [00:08<00:28, 24.39batch/s][A
Validation round:  24%|██▎       | 217

Epoch 2/8:  20%|█▉        | 1448/7329 [00:07<00:29, 201.93img/s, loss (batch)=0.208]
Validation round:   0%|          | 0/916 [00:00<?, ?batch/s][A
Validation round:   0%|          | 1/916 [00:00<12:03,  1.27batch/s][A
Validation round:   1%|          | 7/916 [00:00<08:27,  1.79batch/s][A
Validation round:   1%|          | 9/916 [00:01<06:27,  2.34batch/s][A
Validation round:   2%|▏         | 17/916 [00:01<04:38,  3.23batch/s][A
Validation round:   3%|▎         | 25/916 [00:01<03:23,  4.38batch/s][A
Validation round:   4%|▎         | 33/916 [00:02<02:31,  5.82batch/s][A
Validation round:   4%|▍         | 41/916 [00:02<01:56,  7.54batch/s][A
Validation round:   5%|▌         | 49/916 [00:02<01:29,  9.68batch/s][A
Validation round:   6%|▌         | 57/916 [00:02<01:11, 12.00batch/s][A
Validation round:   7%|▋         | 65/916 [00:03<00:59, 14.34batch/s][A
Validation round:   8%|▊         | 73/916 [00:03<00:49, 17.04batch/s][A
Validation round:   9%|▉         | 81/916 [00:03<00

Validation round:  80%|████████  | 733/916 [00:27<00:06, 27.55batch/s][A
Validation round:  81%|████████  | 741/916 [00:27<00:06, 28.20batch/s][A
Validation round:  82%|████████▏ | 749/916 [00:28<00:06, 26.79batch/s][A
Validation round:  83%|████████▎ | 757/916 [00:28<00:05, 27.97batch/s][A
Validation round:  84%|████████▎ | 765/916 [00:28<00:05, 27.80batch/s][A
Validation round:  84%|████████▍ | 773/916 [00:28<00:05, 25.86batch/s][A
Validation round:  85%|████████▌ | 781/916 [00:29<00:05, 26.73batch/s][A
Validation round:  86%|████████▌ | 789/916 [00:29<00:04, 27.70batch/s][A
Validation round:  87%|████████▋ | 797/916 [00:29<00:04, 28.13batch/s][A
Validation round:  88%|████████▊ | 805/916 [00:30<00:03, 28.46batch/s][A
Validation round:  89%|████████▉ | 813/916 [00:30<00:03, 28.87batch/s][A
Validation round:  90%|████████▉ | 821/916 [00:30<00:03, 28.54batch/s][A
Validation round:  91%|█████████ | 829/916 [00:30<00:02, 29.48batch/s][A
Validation round:  91%|█████████▏| 837

Validation round:  48%|████▊     | 436/916 [00:17<00:16, 29.43batch/s][A
Validation round:  48%|████▊     | 440/916 [00:18<00:19, 23.81batch/s][A
Validation round:  49%|████▊     | 445/916 [00:18<00:21, 21.42batch/s][A
Validation round:  49%|████▉     | 453/916 [00:18<00:20, 22.43batch/s][A
Validation round:  50%|█████     | 461/916 [00:18<00:19, 23.18batch/s][A
Validation round:  51%|█████     | 469/916 [00:19<00:18, 23.90batch/s][A
Validation round:  52%|█████▏    | 477/916 [00:19<00:18, 24.18batch/s][A
Validation round:  53%|█████▎    | 485/916 [00:19<00:17, 24.27batch/s][A
Validation round:  54%|█████▍    | 493/916 [00:20<00:17, 24.65batch/s][A
Validation round:  55%|█████▍    | 500/916 [00:20<00:13, 30.30batch/s][A
Validation round:  55%|█████▌    | 504/916 [00:20<00:18, 22.74batch/s][A
Validation round:  56%|█████▌    | 509/916 [00:20<00:19, 21.11batch/s][A
Validation round:  56%|█████▋    | 517/916 [00:21<00:18, 21.93batch/s][A
Validation round:  57%|█████▋    | 524

Validation round:  26%|██▌       | 236/916 [00:10<00:28, 24.25batch/s][A
Validation round:  27%|██▋       | 244/916 [00:10<00:27, 24.50batch/s][A
Validation round:  28%|██▊       | 252/916 [00:10<00:27, 24.42batch/s][A
Validation round:  28%|██▊       | 259/916 [00:11<00:21, 30.11batch/s][A
Validation round:  29%|██▊       | 263/916 [00:11<00:28, 22.92batch/s][A
Validation round:  29%|██▉       | 268/916 [00:11<00:31, 20.73batch/s][A
Validation round:  30%|███       | 276/916 [00:11<00:30, 21.32batch/s][A
Validation round:  31%|███       | 284/916 [00:12<00:28, 22.37batch/s][A
Validation round:  32%|███▏      | 292/916 [00:12<00:27, 22.79batch/s][A
Validation round:  33%|███▎      | 300/916 [00:12<00:26, 22.92batch/s][A
Validation round:  34%|███▎      | 308/916 [00:13<00:25, 23.67batch/s][A
Validation round:  34%|███▍      | 316/916 [00:13<00:24, 24.06batch/s][A
Validation round:  35%|███▌      | 324/916 [00:13<00:24, 24.48batch/s][A
Validation round:  36%|███▌      | 332

In [None]:
# Train a SuperNet with superblock_size=128.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net2 = SuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                superblock_size=128, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net2.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores2, val_scores2, train_var_2, val_var_2 = train_net(
    net=net2,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train a SuperNet with superblock_size=200.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net3 = SuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                superblock_size=200, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net3.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores3, val_scores3, train_var_3, val_var_3 = train_net(
    net=net3,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train a SuperNet with superblock_size=256.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net4 = SuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                superblock_size=256, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net4.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores4, val_scores4, train_var_4, val_var_4 = train_net(
    net=net4,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train an AESuperNet with superblock_size=64.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net5 = AESuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                  superblock_size=64, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net5.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores5, val_scores5, train_var_5, val_var_5 = train_net(
    net=net5,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train an AESuperNet with superblock_size=128.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net6 = AESuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                  superblock_size=128, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net6.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores6, val_scores6, train_var_6, val_var_6 = train_net(
    net=net6,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train an AESuperNet with superblock_size=200.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net7 = AESuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                  superblock_size=200, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net7.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores7, val_scores7, train_var_7, val_var_7 = train_net(
    net=net7,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
# Train an AESuperNet with superblock_size=256.
# Channel counts are named once so the constructor and the log line cannot
# drift apart; the class count follows the label list from the config cell.
in_channels = 1
n_classes = len(target_label_numbers)

net8 = AESuperNet(input_ch=in_channels, out_ch=n_classes, use_bn=True,
                  superblock_size=256, depth=4)

logging.info(f'Network:\n'
             f'\t{in_channels} input channels\n'
             f'\t{n_classes} output channels (classes)\n')

net8.to(device=device)

# The four returned sequences feed the plotting cell (curves and +/- bands).
# Hyperparameters come from the config-cell variables that also name the
# TensorBoard run, rather than from `args` directly.
# NOTE(review): config-cell val_percent (0.1) is not used here; this call
# passes args.val / 100 -- confirm which one is intended.
train_scores8, val_scores8, train_var_8, val_var_8 = train_net(
    net=net8,
    epochs=8,
    batch_size=batch_size,
    lr=lr,
    device=device,
    img_scale=img_scale,
    val_percent=args.val / 100,
    checkpoint=2,
    target_label_numbers=target_label_numbers,
    writer=writer,
    train_path=dir_train,
    val_path=dir_val,
)

In [None]:
def _plot_loss_panel(ax, title, train_scores, train_var, val_scores, val_var,
                     ylim=(0, 0.5)):
    """Draw one train/val loss panel with shaded ``score +/- var`` bands.

    Args:
        ax: Matplotlib axes to draw on.
        title: Panel title.
        train_scores: Training loss values, one per mini-epoch (curve centre).
        train_var: Half-widths of the shaded band around ``train_scores``.
        val_scores: Validation loss values, one per mini-epoch (curve centre).
        val_var: Half-widths of the shaded band around ``val_scores``.
        ylim: ``(bottom, top)`` limits applied to the y-axis.
    """
    # 1-based mini-epoch index, sized from this run's own history rather than
    # from the first experiment's, so runs of different length still plot.
    x_values = list(range(1, len(train_scores) + 1))

    ax.set_ylim(ylim)
    ax.set_title(title)
    ax.set_xlabel("Mini-epochs")
    ax.set_ylabel("Dice Loss")

    for scores, spread, line_color, band_color, label in (
        (train_scores, train_var, "blue", "lightskyblue", "train"),
        (val_scores, val_var, "orange", "navajowhite", "val"),
    ):
        upper = [s + v for s, v in zip(scores, spread)]
        lower = [s - v for s, v in zip(scores, spread)]
        ax.plot(x_values, scores, color=line_color, label=label)
        ax.fill_between(x_values, upper, lower, facecolor=band_color, alpha=0.5)

    ax.legend()
    ax.grid()


def _show_loss_row(panels, figsize=None):
    """Render one row of loss panels as a single figure, show it, and close it.

    Args:
        panels: Sequence of ``(title, train_scores, train_var, val_scores,
            val_var)`` tuples, one per panel, drawn left to right.
        figsize: Optional ``(width, height)`` in inches; ``None`` keeps
            Matplotlib's default figure size.
    """
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    for ax, panel in zip(axes[0], panels):
        _plot_loss_panel(ax, *panel)
    fig.tight_layout()
    plt.show()
    # Close explicitly so figures do not accumulate across re-runs.
    plt.close(fig)


print("configuring combined plots")

# First row: the four SuperNet runs, one panel per superblock size.
_show_loss_row([
    ("Complex SLN 64 Superblock", train_scores1, train_var_1, val_scores1, val_var_1),
    ("Complex SLN 128 Superblock", train_scores2, train_var_2, val_scores2, val_var_2),
    ("Complex SLN 200 Superblock", train_scores3, train_var_3, val_scores3, val_var_3),
    ("Complex SLN 256 Superblock", train_scores4, train_var_4, val_scores4, val_var_4),
])

# Second row: the four AESuperNet runs, same superblock sizes.
_show_loss_row([
    ("AE SLN 64 Superblock", train_scores5, train_var_5, val_scores5, val_var_5),
    ("AE SLN 128 Superblock", train_scores6, train_var_6, val_scores6, val_var_6),
    ("AE SLN 200 Superblock", train_scores7, train_var_7, val_scores7, val_var_7),
    ("AE SLN 256 Superblock", train_scores8, train_var_8, val_scores8, val_var_8),
])