In [19]:
import os
import shutil
import random
from glob import glob
from natsort import natsorted

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import nibabel as nib

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()

from tqdm import tqdm

from monai.data import DataLoader, Dataset, CacheDataset, pad_list_data_collate
from monai.utils import set_determinism
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    Spacingd,
    ToTensord,
    SpatialCropd,
    RandAdjustContrastd,
    RandGaussianNoised,
    RandCoarseShuffled,
    CropForegroundd,
)


In [20]:
# !nvidia-smi

In [21]:
# Ask PyTorch's caching allocator to release GPU memory it holds but is not
# currently using; tensors that are still referenced are not freed by this call.
torch.cuda.empty_cache()

In [22]:
# Prefer the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [23]:
class RelabelTransform(MapTransform):
    """Dictionary transform that rewrites one label value to another.

    For every configured key, each element equal to ``old_label`` is
    overwritten with ``new_label``.

    Args:
        keys: keys of the entries to relabel.
        old_label: value to search for.
        new_label: value written in place of ``old_label``.
    """

    def __init__(self, keys, old_label, new_label):
        super().__init__(keys)
        self.old_label = old_label
        self.new_label = new_label

    def __call__(self, data):
        # Shallow copy: the returned mapping is new, but the stored arrays are
        # shared with ``data`` and are modified in place below.
        sample = dict(data)
        for name in self.keys:
            labels = sample[name]
            labels[labels == self.old_label] = self.new_label
        return sample



def prepare(
    in_dir,
    pixdim=(1.2, 1.2, 1.0),
    spatial_size=[160, 160, 128],
    gamma=(0.8, 1.25),
    roi_center=[100, 100, 79],
    roi_size=[160, 160, 128],
    cache=True,
    start_file=0,
    end_file=None,
    train_size=0.8,
    batch_size=2,
):
    """Build the training and validation data loaders for a folder of cases.

    Each case is expected in its own sub-directory of ``in_dir`` and to provide
    ``*flair.nii``, ``*t1.nii``, ``*t1ce.nii``, ``*t2.nii`` and ``*seg.nii``.
    The four modalities are loaded as one multi-channel volume under ``"vol"``
    (channel order: flair, t1ce, t2, t1) and the mask under ``"seg"``.

    Args:
        in_dir: root directory containing one sub-directory per case.
        pixdim: target voxel spacing passed to ``Spacingd``.
        spatial_size: currently unused; kept so existing callers keep working.
        gamma: gamma range for ``RandAdjustContrastd`` (training only).
        roi_center: centre of the fixed crop (``SpatialCropd``).
        roi_size: size of the fixed crop (``SpatialCropd``).
        cache: use ``CacheDataset`` when True, plain ``Dataset`` otherwise.
        start_file: index of the first case (after natural sorting) to use.
        end_file: index one past the last case to use (``None`` = all).
        train_size: fraction of the selected cases assigned to training.
        batch_size: batch size of both returned loaders.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if no case is found or the per-modality file counts differ.
    """
    set_determinism(seed=0)

    def _modality_paths(suffix):
        # Natural sort keeps the cases aligned across modalities.
        found = natsorted(glob(os.path.join(in_dir, f'*/*{suffix}.nii')))
        return found[start_file:end_file]

    flair_paths = _modality_paths('flair')
    t1_paths = _modality_paths('t1')
    t1ce_paths = _modality_paths('t1ce')
    t2_paths = _modality_paths('t2')
    seg_paths = _modality_paths('seg')

    # zip() would silently drop cases if one modality had fewer files, which
    # would also mis-align volumes and masks; fail loudly instead.
    counts = {
        'flair': len(flair_paths),
        't1': len(t1_paths),
        't1ce': len(t1ce_paths),
        't2': len(t2_paths),
        'seg': len(seg_paths),
    }
    if len(set(counts.values())) != 1:
        raise ValueError(f"Mismatched number of files per modality in {in_dir!r}: {counts}")
    if not seg_paths:
        raise ValueError(f"No cases found under {in_dir!r}")

    full_dataset = [
        {"vol": [flair, t1ce, t2, t1], "seg": seg}
        for flair, t1ce, t2, t1, seg in zip(flair_paths, t1ce_paths, t2_paths, t1_paths, seg_paths)
    ]

    n_train = int(train_size * len(full_dataset))
    n_val = len(full_dataset) - n_train
    train_dataset, val_dataset = random_split(
        full_dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(42),  # fixed seed => reproducible split
    )

    def _preprocessing():
        # Deterministic steps shared by both pipelines. Built fresh on every
        # call so the two Compose objects do not share transform instances.
        return [
            LoadImaged(keys=["vol", "seg"]),
            EnsureChannelFirstd(keys=["vol", "seg"]),
            EnsureTyped(keys=["vol", "seg"]),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
            # Map label 4 to 3 so the label values are contiguous.
            RelabelTransform(keys=["seg"], old_label=4, new_label=3),
            SpatialCropd(keys=["vol", "seg"], roi_center=roi_center, roi_size=roi_size),
        ]

    # NOTE(review): noise/contrast are applied before NormalizeIntensityd(nonzero=True);
    # noise makes background voxels non-zero, so the "nonzero" statistics may include
    # background. Order kept as-is - confirm whether normalising first is intended.
    train_transform = Compose(
        _preprocessing()
        + [
            RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=2),
            RandGaussianNoised(keys=["vol"], prob=0.6, mean=0.0, std=0.1),
            RandCoarseShuffled(keys=["vol"], prob=0.7, holes=10, spatial_size=(16, 16, 16)),
            RandAdjustContrastd(keys=["vol"], prob=0.7, gamma=gamma),
            NormalizeIntensityd(keys="vol", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="vol", factors=0.15, prob=0.7),
            RandShiftIntensityd(keys="vol", offsets=0.15, prob=0.7),
            ToTensord(keys=["vol", "seg"]),
        ]
    )

    # Validation must be deterministic: no random augmentation, only the
    # preprocessing and the same intensity normalisation as training.
    val_transform = Compose(
        _preprocessing()
        + [
            NormalizeIntensityd(keys="vol", nonzero=True, channel_wise=True),
        ]
    )

    if cache:
        train_ds = CacheDataset(data=train_dataset, transform=train_transform, cache_rate=1.0, num_workers=0)
        val_ds = CacheDataset(data=val_dataset, transform=val_transform, cache_rate=1.0, num_workers=0)
    else:
        # The plain Dataset takes only data and transform; caching options
        # (cache_rate / num_workers) belong to CacheDataset.
        train_ds = Dataset(data=train_dataset, transform=train_transform)
        val_ds = Dataset(data=val_dataset, transform=val_transform)

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0),
        # Shuffling has no benefit for validation and makes runs harder to compare.
        DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0),
    )


In [24]:
# train_path = r'G:\BraTS\MICCAI_BraTS2020_TestingData'
# train_loader , val_loader = prepare(in_dir=train_path,start_file=0,end_file=10)

# # Get one batch from the DataLoader
# val_batch = next(iter(val_loader))  # assuming val_loader is your DataLoader
# # Pick the first sample from the batch
# val_data_example = {k: v[0] for k, v in val_batch.items()}

# print(f"Volume shape: {val_data_example['vol'].shape}")

# for slice_idx in range(20, 100, 3):  # slice index
#     plt.figure("Volume", (24, 6))
#     for ch in range(4):  # channels: flair, t1ce, t2, t1
#         plt.subplot(1, 4, ch + 1)
#         plt.title(f"vol channel {ch}")
#         plt.imshow(val_data_example["vol"][ch, :, :, slice_idx].detach().cpu(), cmap="gray")
#     plt.tight_layout()
#     plt.show()

#     # Visualize the segmentation mask
#     seg = val_data_example['seg'].detach().cpu()

#     print(np.unique(seg))  # Optional: check unique values
#     print(f"Segmentation shape: {seg.shape}")  # (1, H, W, D)

#     plt.figure("Segmentation", (6, 6))
#     plt.title(f"segmentation (slice {slice_idx})")
#     plt.imshow(seg[0, :, :, slice_idx], cmap="viridis")
#     plt.axis('off')
#     plt.show()


In [25]:
from monai.utils import first
import matplotlib.pyplot as plt
def _show_vol_seg_slice(batch, slice_number):
    """Plot channel 0 of the first sample's volume next to its mask for one slice index."""
    plt.figure(figsize=(12, 6))

    # Left panel: image (first sample of the batch, first channel).
    plt.subplot(1, 2, 1)
    plt.title(f'vol {slice_number}')
    vol_slice = np.array(batch['vol'][0, 0, :, :, slice_number], dtype=np.float16)
    plt.imshow(vol_slice, cmap="gray")

    # Right panel: segmentation mask of the same sample and slice.
    plt.subplot(1, 2, 2)
    plt.title(f'seg {slice_number}')
    seg_slice = np.array(batch['seg'][0, 0, :, :, slice_number], dtype=np.float16)
    plt.imshow(seg_slice, cmap="gray")

    plt.show()


def show_patient(tran_loader,val_loader,SLICE_NUMBER=1 ,train:bool=True,val:bool=False):
    """Display one slice of the first batch of the training and/or validation loader.

    Args:
        tran_loader: training loader yielding dict batches with "vol" and "seg".
        val_loader: validation loader with the same batch layout.
        SLICE_NUMBER: index along the last axis of the slice to show.
        train: show the first training batch when True.
        val: show the first validation batch when True.
    """
    # A loader is only consumed when its view is requested, and the batch is
    # kept separate from the extracted slices so both "vol" and "seg" can be read.
    if train:
        _show_vol_seg_slice(first(tran_loader), SLICE_NUMBER)

    if val:
        _show_vol_seg_slice(first(val_loader), SLICE_NUMBER)

In [26]:
# Applying the transform to a single sample
# sample = transforms(train_dataset[0])
# print(f"Volume shape: {sample['vol'].shape}")
# print(f"Segmentation shape: {sample['seg'].shape}")


In [27]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

class UNetEncoder(nn.Module):
    """Encoder stage: two (3x3x3 conv -> batch norm -> ReLU) layers, then 2x max pooling.

    ``forward`` returns ``(pooled, skip)`` where ``skip`` is the feature map
    before pooling, to be reused by the matching decoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable so
        # saved checkpoints remain loadable.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.pool = nn.MaxPool3d(2, 2)

    def forward(self, x):
        features = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.relu(norm(conv(features)))
        pooled = self.pool(features)
        return pooled, features


class UNetDecoder(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, concatenation with the
    skip connection along the channel axis, then two
    (3x3x3 conv -> batch norm -> ReLU) layers.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        # After concatenation the tensor carries the upsampled channels plus
        # the skip channels, hence ``out_channels * 2`` inputs.
        self.conv1 = nn.Conv3d(out_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x, skip):
        upsampled = self.upconv(x)
        merged = torch.cat((upsampled, skip), dim=1)
        hidden = F.relu(self.bn1(self.conv1(merged)))
        return F.relu(self.bn2(self.conv2(hidden)))


class UNetBottleneck(nn.Module):
    """Bottleneck: two (3x3x3 conv -> batch norm -> ReLU) layers followed by
    channel-wise 3D dropout with probability 0.3.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.dropout = nn.Dropout3d(0.3)  # 30% dropout

    def forward(self, x):
        first = F.relu(self.bn1(self.conv1(x)))
        second = F.relu(self.bn2(self.conv2(first)))
        return self.dropout(second)


class UnetModel(nn.Module):
    """Four-level 3D U-Net built from UNetEncoder / UNetBottleneck / UNetDecoder.

    Args:
        num_classes: number of output channels of the final 1x1x1 convolution.
        in_channels: number of input channels (default 4).
        base_channels: width of the first encoder stage; each deeper stage
            doubles it (default 32 -> 32, 64, 128, 256 and 512 in the bottleneck).

    ``forward`` returns raw per-class logits (no softmax is applied) with the
    same spatial size as the input.
    """

    def __init__(self, num_classes, in_channels=4, base_channels=32):
        super().__init__()

        c1 = base_channels
        c2, c3, c4, c5 = c1 * 2, c1 * 4, c1 * 8, c1 * 16

        self.encoder1 = UNetEncoder(in_channels, c1)
        self.encoder2 = UNetEncoder(c1, c2)
        self.encoder3 = UNetEncoder(c2, c3)
        self.encoder4 = UNetEncoder(c3, c4)

        self.bottleneck = UNetBottleneck(c4, c5)

        self.decoder1 = UNetDecoder(c5, c4)
        self.decoder2 = UNetDecoder(c4, c3)
        self.decoder3 = UNetDecoder(c3, c2)
        self.decoder4 = UNetDecoder(c2, c1)

        self.final_conv = nn.Conv3d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        # Four 2x poolings followed by four 2x upsamplings only line up with
        # the skip connections when every spatial size is a multiple of 16;
        # otherwise the concatenation in a decoder fails with a shape error.
        if any(size % 16 for size in x.shape[2:]):
            raise ValueError(
                f"Spatial dimensions must be divisible by 16, got {tuple(x.shape[2:])}"
            )

        x, skip1 = self.encoder1(x)
        x, skip2 = self.encoder2(x)
        x, skip3 = self.encoder3(x)
        x, skip4 = self.encoder4(x)

        x = self.bottleneck(x)

        # Decoders consume the skips in reverse (deepest first).
        x = self.decoder1(x, skip4)
        x = self.decoder2(x, skip3)
        x = self.decoder3(x, skip2)
        x = self.decoder4(x, skip1)

        return self.final_conv(x)


In [28]:
# Instantiate the U-Net with num_classes=4 and move its parameters to `device`.
model = UnetModel(4).to(device)

In [29]:
class CombinedLoss(nn.Module):
    """Weighted sum of soft Dice loss and cross-entropy.

    ``loss = alpha * dice_loss + (1 - alpha) * cross_entropy``

    Args:
        alpha: weight of the Dice term (the cross-entropy term gets ``1 - alpha``).
        smooth: constant added to numerator and denominator of the Dice ratio
            to avoid division by zero.

    ``pred`` holds raw logits of shape ``[B, C, *spatial]``. ``target`` may be
    either a one-hot tensor of the same shape or a label map of class indices
    shaped ``[B, 1, *spatial]`` / ``[B, *spatial]``.
    """

    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth

    @staticmethod
    def _as_one_hot(pred, target):
        """Return ``target`` as a one-hot tensor with the same channel count as ``pred``."""
        if target.dim() == pred.dim() - 1:
            target = target.unsqueeze(1)
        num_classes = pred.shape[1]
        if target.shape[1] == num_classes:
            return target
        if target.shape[1] != 1:
            raise ValueError(
                f"target has {target.shape[1]} channels; expected 1 (label map) or {num_classes} (one-hot)"
            )
        # Compare the label map against every class index via broadcasting:
        # [B, 1, ...] == [1, C, 1, ...] -> [B, C, ...].
        class_ids = torch.arange(num_classes, device=target.device)
        class_ids = class_ids.view(1, num_classes, *([1] * (target.dim() - 2)))
        return (target.long() == class_ids).to(pred.dtype)

    def dice_loss(self, pred, target):
        """Soft Dice loss (1 - mean Dice over batch and classes) on softmax probabilities."""
        target = self._as_one_hot(pred, target)
        probs = torch.softmax(pred, dim=1)
        spatial_dims = tuple(range(2, pred.dim()))
        intersection = (probs * target).sum(dim=spatial_dims)
        union = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

    def forward(self, pred, target):
        one_hot = self._as_one_hot(pred, target)
        dice = self.dice_loss(pred, one_hot)
        # Cross-entropy expects class indices; argmax over the one-hot channel
        # axis recovers them for both accepted target layouts.
        ce = nn.functional.cross_entropy(pred, one_hot.argmax(dim=1))
        return self.alpha * dice + (1 - self.alpha) * ce

In [30]:
import os
import time
import torch
import torch.nn as nn
from tqdm import tqdm
 
# Dice Score function
def dice_score(pred, target, smooth=1e-6):
    """Mean soft Dice coefficient over batch and classes.

    Args:
        pred: raw logits of shape ``[B, C, *spatial]``; softmax is applied here.
        target: one-hot tensor shaped like ``pred``, or a label map of class
            indices shaped ``[B, 1, *spatial]`` / ``[B, *spatial]``.
        smooth: constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor with the mean Dice value.
    """
    probs = torch.softmax(pred, dim=1)  # [B, C, *spatial]

    if target.dim() == pred.dim() - 1:
        target = target.unsqueeze(1)
    num_classes = pred.shape[1]
    if target.shape[1] != num_classes:
        if target.shape[1] != 1:
            raise ValueError(
                f"target has {target.shape[1]} channels; expected 1 (label map) or {num_classes} (one-hot)"
            )
        # Expand the label map to one-hot so each class channel is compared
        # against its own binary mask rather than against raw label values.
        class_ids = torch.arange(num_classes, device=target.device)
        class_ids = class_ids.view(1, num_classes, *([1] * (target.dim() - 2)))
        target = (target.long() == class_ids).to(probs.dtype)

    spatial_dims = tuple(range(2, pred.dim()))
    intersection = (probs * target).sum(dim=spatial_dims)
    union = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    dice = (2. * intersection + smooth) / (union + smooth)
    return dice.mean()

def _load_metric_history(path):
    """Return the values stored in the ``.npy`` file at ``path`` as a list, or ``[]`` if it is absent."""
    if os.path.exists(path):
        return np.load(path).tolist()
    return []


def model_train(model, val_data, train_data,
                model_train_val_metric_path=r"D:\web dev backup\Pytorch\Brain_tumor_metric",
                epochs=20, checkpoint_path="Brain_tumor/tumor_unet_model_v2.pth",
                i=1,
                model_concat_train_val_metric_path=r"D:\web dev backup\Pytorch\Brain_tumor_metric_concat"):
    """Train ``model`` with early stopping and keep the checkpoint with the best validation Dice.

    Args:
        model: network to train; resumed from ``checkpoint_path`` if that file exists.
        val_data: validation loader yielding dict batches with "vol" and "seg".
        train_data: training loader with the same batch layout.
        model_train_val_metric_path: directory receiving the curves of this run
            only (``loss_train.npy``, ``metric_train.npy``, ``loss_val.npy``,
            ``metric_val.npy``), rewritten every epoch.
        epochs: maximum number of epochs.
        checkpoint_path: file the best model is loaded from and saved to.
        i: suffix of the cumulative history files (``loss_train_{i}.npy`` ...).
        model_concat_train_val_metric_path: directory of the cumulative history
            files; values already stored there are kept and this run is appended.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_data) == 0 or len(val_data) == 0:
        raise ValueError("train_data and val_data must each contain at least one batch.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

    # np.save / torch.save do not create missing directories.
    for directory in (model_train_val_metric_path,
                      model_concat_train_val_metric_path,
                      os.path.dirname(checkpoint_path)):
        if directory:
            os.makedirs(directory, exist_ok=True)

    # Always load/save through the unwrapped module so checkpoint keys do not
    # depend on whether DataParallel is in use.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    base_model.to(device)

    criterion = CombinedLoss()
    optimizer = torch.optim.Adam(base_model.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

    # **Load previous checkpoint if available**
    best_dice = 0.0
    if os.path.exists(checkpoint_path):
        print(f"Loading previous checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        if checkpoint.get('best_dice') is not None:
            best_dice = checkpoint['best_dice']
            print(f"Loaded best Dice score from checkpoint: {best_dice:.4f}")
        else:
            print("No best Dice score found in checkpoint. Using default value of 0.0.")

        state_dict = checkpoint['model_state_dict']
        # Checkpoints written from a DataParallel-wrapped model prefix every key with "module.".
        if state_dict and all(key.startswith('module.') for key in state_dict):
            state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
        base_model.load_state_dict(state_dict)
        print("Model state loaded successfully!")

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Optimizer state loaded successfully!")
        else:
            print("No optimizer state found in checkpoint. Starting with a new optimizer.")
    else:
        print("No checkpoint found. Starting training from scratch.")

    if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
        print("Using DataParallel for multi-GPU!")
        model = nn.DataParallel(base_model)
    model.to(device)

    metric_names = ('loss_train', 'metric_train', 'loss_val', 'metric_val')
    # History from earlier runs is read once; re-reading the file every epoch
    # and appending the whole running list would duplicate entries.
    previous_runs = {
        name: _load_metric_history(os.path.join(model_concat_train_val_metric_path, f'{name}_{i}.npy'))
        for name in metric_names
    }
    current_run = {name: [] for name in metric_names}

    def _persist(names):
        # Rewrite both the cumulative and the per-run file for each metric.
        for name in names:
            np.save(os.path.join(model_concat_train_val_metric_path, f'{name}_{i}.npy'),
                    previous_runs[name] + current_run[name])
            np.save(os.path.join(model_train_val_metric_path, f'{name}.npy'), current_run[name])

    best_val_loss = float('inf')
    patience = 5  # Early stopping
    patience_counter = 0

    for epoch in range(epochs):
        start_time = time.time()

        # **Training Phase**
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        train_bar = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs} [Train]")

        for step, batch_data in enumerate(train_bar, start=1):
            images = batch_data['vol'].to(device)
            labels = batch_data['seg'].to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # The metric needs no gradients; avoid building an extra graph.
            with torch.no_grad():
                train_dice += dice_score(outputs, labels).item()
            train_bar.set_postfix({'train_loss': train_loss / step})

        avg_train_loss = train_loss / len(train_data)
        avg_train_dice = train_dice / len(train_data)

        current_run['loss_train'].append(avg_train_loss)
        current_run['metric_train'].append(avg_train_dice)
        _persist(('loss_train', 'metric_train'))

        # **Validation Phase**
        model.eval()
        val_loss = 0.0
        val_dice = 0.0
        val_bar = tqdm(val_data, desc=f"Epoch {epoch+1}/{epochs} [Val]")

        with torch.no_grad():
            for step, batch_data in enumerate(val_bar, start=1):
                # Batches are dicts; iterating one directly would yield its keys.
                images = batch_data['vol'].to(device)
                labels = batch_data['seg'].to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_dice += dice_score(outputs, labels).item()

                val_bar.set_postfix({'val_loss': val_loss / step,
                                     'val_dice': val_dice / step})

        avg_val_loss = val_loss / len(val_data)
        avg_val_dice = val_dice / len(val_data)

        current_run['loss_val'].append(avg_val_loss)
        current_run['metric_val'].append(avg_val_dice)
        _persist(('loss_val', 'metric_val'))

        scheduler.step(avg_val_loss)

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch+1}/{epochs}')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Training Dice: {avg_train_dice:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f'Validation Dice: {avg_val_dice:.4f}')
        print(f'Time Taken: {epoch_duration:.2f} seconds')
        print('-' * 50)

        # **Save the best model based on Dice Score**
        if avg_val_dice > best_dice:
            best_dice = avg_val_dice
            best_val_loss = avg_val_loss

            # Written to checkpoint_path so the next call can resume from it.
            torch.save({
                'model_state_dict': base_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dice': best_dice,
            }, checkpoint_path)

            print(f"Best model saved with Dice: {best_dice:.4f}, Val Loss: {best_val_loss:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break


In [31]:
# Build the loaders from the first 10 naturally-sorted cases under train_path
# (start_file=0, end_file=10); prepare() splits them into train / validation.
# NOTE(review): the folder is named "TestingData" but is used for training -
# confirm it really contains the *seg.nii masks that prepare() globs for.
train_path = r'G:\BraTS\MICCAI_BraTS2020_TestingData'
train_loader , val_loader = prepare(in_dir=train_path,start_file=0,end_file=10)


Loading dataset: 100%|██████████| 8/8 [00:34<00:00,  4.25s/it]
Loading dataset: 100%|██████████| 2/2 [00:09<00:00,  4.70s/it]


In [32]:
# Sanity check: report the device holding the model's first parameter.
# NOTE(review): the same check is printed twice; one call would be enough.
print(next(model.parameters()).device)
print(next(model.parameters()).device)  # Should print cuda:0

cuda:0
cuda:0


In [33]:
# Train for at most 10 epochs, resuming from / saving to the given checkpoint.
# NOTE(review): the recorded run below aborted with CUDA out-of-memory on a
# 4 GiB GPU during the first training batch; the default 160x160x128 crop from
# prepare() with batch size 2 is likely too large for this card - consider a
# smaller roi_size.
model_train(model, val_data= val_loader, train_data= train_loader, epochs=10,checkpoint_path="Brain_tumor/tumor_unet_model_v41.pth",i=1)

Number of GPUs: 1
No checkpoint found. Starting training from scratch.


Epoch 1/10 [Train]:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1/10 [Train]:   0%|          | 0/4 [00:54<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 800.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.51 GiB is allocated by PyTorch, and 52.70 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Reload the per-epoch curves that model_train writes for the latest run.
# Raw string: the backslashes of the Windows path must not be read as escapes.
model_dir = r"D:\web dev backup\Pytorch\Brain_tumor_metric"
train_loss = np.load(os.path.join(model_dir, 'loss_train.npy'))
train_metric = np.load(os.path.join(model_dir, 'metric_train.npy'))
val_loss = np.load(os.path.join(model_dir, 'loss_val.npy'))
val_metric = np.load(os.path.join(model_dir, 'metric_val.npy'))

In [None]:
# Plot the four training curves on a 2x2 grid, one panel per curve.
# The loss curves come from CombinedLoss (Dice + cross-entropy) and the second
# split is the validation set, so the panels are titled accordingly.
curves = [
    ("Train loss (Dice + CE)", "loss", train_loss),
    ("Train metric DICE", "dice", train_metric),
    ("Validation loss (Dice + CE)", "loss", val_loss),
    ("Validation metric DICE", "dice", val_metric),
]

plt.figure("Results 25 june", (12, 12))
for panel, (title, y_label, values) in enumerate(curves, start=1):
    plt.subplot(2, 2, panel)
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(y_label)
    # Epochs are numbered from 1 on the x axis.
    plt.plot(range(1, len(values) + 1), values)

plt.show()