In [1]:
import numpy as np
import pandas as pd
import random
import torch

Let's dynamically choose the type of device to use for our computations. To run on your own GPU, you need to install PyTorch (with CUDA support) and the CUDA toolkit.

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device = torch.device("cuda" if cuda_available else "cpu")

True


In [3]:
# Display the GPU's properties. Guarded so a CPU-only kernel can still run the
# notebook top to bottom: get_device_properties raises for a non-CUDA device.
torch.cuda.get_device_properties(device) if device.type == "cuda" else f"No CUDA device available; using {device}"



_CudaDeviceProperties(name='NVIDIA GeForce RTX 2080 Ti', major=7, minor=5, total_memory=10824MB, multi_processor_count=68)

Define the seed for reproducibility:

In [4]:
def set_seed(seed, use_gpu=True):
    """
    Seed every random number generator the notebook relies on, for repeatable runs.

    Args:
        seed: integer seed shared by Python's `random`, NumPy and PyTorch.
        use_gpu: when True, also seed the CUDA generators and put cuDNN in
            deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not use_gpu:
        return

    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    # Give up cuDNN's autotuned (and non-deterministic) kernels for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 196
USE_SEED = True  # set to False to let each run draw fresh random state

if USE_SEED:
    # Only touch the CUDA generators / cuDNN flags when a GPU is actually present.
    set_seed(SEED, use_gpu=torch.cuda.is_available())

**Load the images and masks**

In [5]:
# Choose between the augmented and the original (non-augmented) training set.
# The answer is normalised so that e.g. "Y" or " n " are also accepted.
# NOTE(review): an interactive input() makes the result depend on a manual answer
# and blocks "Run All"; consider replacing it with a config constant.
augmented = input("Data augmentation? (y/n)").strip().lower()

# answer -> (training images file, training masks file)
TRAIN_FILES = {
    "y": ("data/data/train_augmented/combined_images.npy",
          "data/data/train_augmented/combined_masks.npy"),
    "n": ("data/data/train_nonaugmented/train_images.npy",
          "data/data/train_nonaugmented/train_masks.npy"),
}
if augmented not in TRAIN_FILES:
    raise ValueError(f"Only 'y' or 'n' are accepted, got {augmented!r}")

train_images_path, train_masks_path = TRAIN_FILES[augmented]
train_images = np.load(train_images_path)
train_masks = np.load(train_masks_path)

# The test split is the same regardless of the augmentation choice.
test_images = np.load("data/data/test/test_images.npy")
test_masks = np.load("data/data/test/test_masks.npy")

Data augmentation? (y/n) y


In [6]:
print("Train images shape:", train_images.shape)
print("Test images shape:", test_images.shape)
print("Train masks shape:", train_masks.shape)
print("Test masks shape:", test_masks.shape)

# Images and masks are later paired by index (train_test_split, KvasirDataset),
# so every split needs one mask per image with the same spatial size; only the
# trailing channel axis may differ (3 for RGB images vs 1 for masks above).
assert train_images.shape[:-1] == train_masks.shape[:-1], "train images/masks do not match"
assert test_images.shape[:-1] == test_masks.shape[:-1], "test images/masks do not match"

Train images shape: (3500, 256, 256, 3)
Test images shape: (300, 256, 256, 3)
Train masks shape: (3500, 256, 256, 1)
Test masks shape: (300, 256, 256, 1)


**Standardization**

In [7]:
# Per-channel mean / standard deviation over every training pixel, rescaled from
# the 0-255 pixel range to [0, 1] so they can be fed to transforms.Normalize.
# NOTE(review): for integer input, std materialises a float64 copy of the whole
# stack (roughly 5.5 GB for 3500x256x256x3) — assumes uint8 data, TODO confirm dtype.
pixel_axes = (0, 1, 2)
mean = train_images.mean(axis=pixel_axes) / 255
std = train_images.std(axis=pixel_axes) / 255

print("-----  NORMALIZATION VALUES  -----")
print(f"Mean (RGB): {mean}")
print(f"Standard Deviation (RGB): {std}")

-----  NORMALIZATION VALUES  -----
Mean (RGB): [0.56289435 0.33033543 0.24323519]
Standard Deviation (RGB): [0.31274417 0.23138525 0.19586279]


In [8]:
import torchvision
import torchvision.transforms as transforms

# Per-image preprocessing: ToTensor turns an HxWxC array into a CxHxW float tensor
# (it only rescales to [0, 1] when the input is uint8 — presumably the case here,
# since mean/std above were divided by 255; TODO confirm dtype), then Normalize
# standardizes each channel with the training-set mean and std.
# NOTE: this rebinds `transforms`, shadowing the torchvision.transforms module alias.
to_tensor = torchvision.transforms.ToTensor()
standardize = torchvision.transforms.Normalize(mean, std)
transforms = torchvision.transforms.Compose([to_tensor, standardize])

**Validation set**

In [9]:
from sklearn.model_selection import train_test_split

# Hold out 40% of the provided test images as a validation set (used for model
# selection / early stopping); the remaining 60% is kept as the final test set.
testing_images, val_images, testing_masks, val_masks = train_test_split(
    test_images,
    test_masks,
    test_size=0.4,
    random_state=42,
)

# Report shapes and sizes of every split.
split_report = {
    "Training Images Shape:": train_images.shape,
    "Training Masks Shape:": train_masks.shape,
    "Validation Images Shape:": val_images.shape,
    "Validation Masks Shape:": val_masks.shape,
    "Number of Train Examples:": len(train_images),
    "Number of Validation Examples:": len(val_images),
    "Number of Test Examples:": len(testing_images),
}
for caption, value in split_report.items():
    print(caption, value)

Training Images Shape: (3500, 256, 256, 3)
Training Masks Shape: (3500, 256, 256, 1)
Validation Images Shape: (120, 256, 256, 3)
Validation Masks Shape: (120, 256, 256, 1)
Number of Train Examples: 3500
Number of Validation Examples: 120
Number of Test Examples: 180


In [10]:
from src.preprocess import KvasirDataset

# Wrap each (images, masks) split in a KvasirDataset sharing the same preprocessing.
train_dataset, val_dataset, test_dataset = (
    KvasirDataset(images=split_images, masks=split_masks, transforms=transforms)
    for split_images, split_masks in (
        (train_images, train_masks),
        (val_images, val_masks),
        (testing_images, testing_masks),
    )
)


**Define the model**

In [11]:
# Run configuration. `name` also selects the output folder models/<name>/
# used for the training log, checkpoint and test log.
config = {
    'name': 'DiceLoss_try_Aug',
    'epochs': 100,
    'batch_size': 16,
    'arch': 'NestedUNet',
    'deep_supervision': True,
    'input_channels': 3,
    'num_classes': 1,
    'early_stopping': 10,  # stop after 10 epochs without a validation Dice improvement
    # 'input_w': 128,
    # 'input_h': 128,
}


In [12]:
from torch.utils.data import DataLoader

# Only the training batches are shuffled; test/validation keep a fixed order.
loader_kwargs = {'batch_size': config['batch_size']}
train_iterator = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_iterator = DataLoader(test_dataset, **loader_kwargs)
val_iterator = DataLoader(val_dataset, **loader_kwargs)

In [13]:
from src.main import NestedUNet
from src.utils import count_parameters

# Build the UNet++ model from the run config and report its size.
model = NestedUNet(config)
n_trainable = count_parameters(model)
print(f"The model has {n_trainable:,} trainable parameters.")

The model has 2,264,899 trainable parameters.


In [14]:
# Move the model to the selected device (CUDA when available, else CPU).
model = model.to(device=device)

In [15]:
from src.loss import DiceLoss, IoULoss, BCEDiceLoss
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

LEARNING_RATE = 1e-4

criterion = DiceLoss().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Multiply the LR by 0.1 once the value given to scheduler.step() has not
# improved for 5 epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)



In [None]:
from src.train_val_test import train, evaluate
from collections import OrderedDict

import os

# Per-epoch history of training and validation metrics. It is rewritten to CSV
# after every epoch, so an interrupted run still leaves a usable log.
log = OrderedDict([
    ('epoch', []),
    ('loss', []),
    ('iou', []),
    ('dice',[]),
    ('accuracy',[]),
    ('val_loss', []),
    ('val_iou', []),
    ('val_dice',[]),
    ('val_accuracy',[])
])

# All artefacts of this run (log.csv, model.pth) go to models/<name>/.
# Create the folder up front: to_csv / torch.save fail if it does not exist.
model_dir = os.path.join('models', config['name'])
os.makedirs(model_dir, exist_ok=True)

best_dice = 0  # best validation Dice seen so far
trigger = 0    # epochs since the last validation Dice improvement
for epoch in range(config['epochs']):
    print('Epoch [%d/%d]' % (epoch, config['epochs']))

    # train for one epoch
    train_log = train(config, model, train_iterator, criterion, optimizer, config['deep_supervision'], device)
    # evaluate on validation set
    val_log = evaluate(config, model, val_iterator, criterion, config['deep_supervision'], device)

    print('loss %.4f - iou %.4f - dice %.4f - accuracy %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f - val_accuracy %.4f'
          % (train_log['loss'], train_log['iou'], train_log['dice_coef'], train_log['accuracy'],
             val_log['loss'], val_log['iou'], val_log['dice_coef'], val_log['accuracy']))

    log['epoch'].append(epoch)
    # Same four metrics for both phases; validation columns carry a 'val_' prefix.
    for prefix, phase_log in (('', train_log), ('val_', val_log)):
        log[prefix + 'loss'].append(phase_log['loss'])
        log[prefix + 'iou'].append(phase_log['iou'])
        log[prefix + 'dice'].append(phase_log['dice_coef'])
        log[prefix + 'accuracy'].append(phase_log['accuracy'])

    pd.DataFrame(log).to_csv(os.path.join(model_dir, 'log.csv'), index=False)

    trigger += 1

    # Checkpoint whenever validation Dice improves; this also resets the patience counter.
    if val_log['dice_coef'] > best_dice:
        torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
        best_dice = val_log['dice_coef']
        print("=> saved best model")
        trigger = 0

    # Reduce the learning rate when the validation loss stops improving
    scheduler.step(val_log['loss'])

    # early stopping (a negative config['early_stopping'] disables it)
    if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
        print("=> early stopping")
        break

    # release cached GPU memory between epochs
    torch.cuda.empty_cache()

Epoch [0/100]


100%|█████████████████████████| 219/219 [01:13<00:00,  2.99it/s, loss=0.677, iou=0.203, dice_coef=0.336, accuracy=0.552]
100%|█████████████████████████████| 8/8 [00:00<00:00,  9.89it/s, loss=0.663, iou=0.202, dice_coef=0.335, accuracy=0.688]


loss 0.6775 - iou 0.2035 - dice 0.3364 - accuracy 0.5521 - val_loss 0.6627 - val_iou 0.2016 - val_dice 0.3349 - val_accuracy 0.6875
=> saved best model
Epoch [1/100]


100%|██████████████████████████| 219/219 [01:18<00:00,  2.78it/s, loss=0.633, iou=0.24, dice_coef=0.386, accuracy=0.628]
100%|█████████████████████████████| 8/8 [00:00<00:00,  9.49it/s, loss=0.634, iou=0.234, dice_coef=0.379, accuracy=0.727]


loss 0.6326 - iou 0.2403 - dice 0.3857 - accuracy 0.6275 - val_loss 0.6340 - val_iou 0.2340 - val_dice 0.3787 - val_accuracy 0.7267
=> saved best model
Epoch [2/100]


100%|█████████████████████████| 219/219 [01:24<00:00,  2.58it/s, loss=0.599, iou=0.271, dice_coef=0.424, accuracy=0.675]
100%|█████████████████████████████| 8/8 [00:01<00:00,  7.98it/s, loss=0.591, iou=0.278, dice_coef=0.434, accuracy=0.722]


loss 0.5988 - iou 0.2709 - dice 0.4243 - accuracy 0.6754 - val_loss 0.5905 - val_iou 0.2777 - val_dice 0.4337 - val_accuracy 0.7217
=> saved best model
Epoch [3/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.43it/s, loss=0.567, iou=0.304, dice_coef=0.464, accuracy=0.725]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.17it/s, loss=0.564, iou=0.306, dice_coef=0.468, accuracy=0.801]


loss 0.5666 - iou 0.3044 - dice 0.4643 - accuracy 0.7249 - val_loss 0.5643 - val_iou 0.3059 - val_dice 0.4677 - val_accuracy 0.8006
=> saved best model
Epoch [4/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.532, iou=0.343, dice_coef=0.509, accuracy=0.771]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.13it/s, loss=0.536, iou=0.328, dice_coef=0.493, accuracy=0.757]


loss 0.5322 - iou 0.3434 - dice 0.5088 - accuracy 0.7713 - val_loss 0.5359 - val_iou 0.3283 - val_dice 0.4932 - val_accuracy 0.7569
=> saved best model
Epoch [5/100]


100%|███████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.5, iou=0.385, dice_coef=0.553, accuracy=0.813]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.17it/s, loss=0.492, iou=0.389, dice_coef=0.559, accuracy=0.864]


loss 0.4997 - iou 0.3855 - dice 0.5533 - accuracy 0.8134 - val_loss 0.4923 - val_iou 0.3894 - val_dice 0.5592 - val_accuracy 0.8637
=> saved best model
Epoch [6/100]


100%|█████████████████████████| 219/219 [01:29<00:00,  2.44it/s, loss=0.466, iou=0.432, dice_coef=0.601, accuracy=0.858]
100%|█████████████████████████████| 8/8 [00:01<00:00,  7.19it/s, loss=0.474, iou=0.429, dice_coef=0.598, accuracy=0.876]


loss 0.4657 - iou 0.4317 - dice 0.6008 - accuracy 0.8581 - val_loss 0.4735 - val_iou 0.4285 - val_dice 0.5981 - val_accuracy 0.8756
=> saved best model
Epoch [7/100]


100%|██████████████████████████| 219/219 [01:31<00:00,  2.39it/s, loss=0.437, iou=0.477, dice_coef=0.643, accuracy=0.89]
100%|██████████████████████████████| 8/8 [00:01<00:00,  7.79it/s, loss=0.487, iou=0.38, dice_coef=0.549, accuracy=0.805]


loss 0.4373 - iou 0.4768 - dice 0.6429 - accuracy 0.8896 - val_loss 0.4868 - val_iou 0.3804 - val_dice 0.5491 - val_accuracy 0.8050
Epoch [8/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.409, iou=0.524, dice_coef=0.685, accuracy=0.902]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.43it/s, loss=0.424, iou=0.499, dice_coef=0.665, accuracy=0.899]


loss 0.4094 - iou 0.5237 - dice 0.6849 - accuracy 0.9019 - val_loss 0.4240 - val_iou 0.4991 - val_dice 0.6646 - val_accuracy 0.8986
=> saved best model
Epoch [9/100]


100%|███████████████████████████| 219/219 [01:31<00:00,  2.39it/s, loss=0.386, iou=0.565, dice_coef=0.72, accuracy=0.91]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.13it/s, loss=0.405, iou=0.525, dice_coef=0.687, accuracy=0.91]


loss 0.3856 - iou 0.5649 - dice 0.7195 - accuracy 0.9104 - val_loss 0.4052 - val_iou 0.5254 - val_dice 0.6873 - val_accuracy 0.9102
=> saved best model
Epoch [10/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.369, iou=0.592, dice_coef=0.741, accuracy=0.914]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.26it/s, loss=0.39, iou=0.567, dice_coef=0.722, accuracy=0.907]


loss 0.3693 - iou 0.5916 - dice 0.7410 - accuracy 0.9145 - val_loss 0.3896 - val_iou 0.5670 - val_dice 0.7219 - val_accuracy 0.9066
=> saved best model
Epoch [11/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.42it/s, loss=0.353, iou=0.623, dice_coef=0.765, accuracy=0.921]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.22it/s, loss=0.375, iou=0.574, dice_coef=0.728, accuracy=0.915]


loss 0.3532 - iou 0.6228 - dice 0.7654 - accuracy 0.9209 - val_loss 0.3745 - val_iou 0.5741 - val_dice 0.7279 - val_accuracy 0.9150
=> saved best model
Epoch [12/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.41it/s, loss=0.341, iou=0.646, dice_coef=0.783, accuracy=0.924]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.22it/s, loss=0.378, iou=0.583, dice_coef=0.734, accuracy=0.908]


loss 0.3408 - iou 0.6461 - dice 0.7829 - accuracy 0.9242 - val_loss 0.3784 - val_iou 0.5825 - val_dice 0.7343 - val_accuracy 0.9081
=> saved best model
Epoch [13/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.42it/s, loss=0.331, iou=0.667, dice_coef=0.798, accuracy=0.928]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.27it/s, loss=0.354, iou=0.632, dice_coef=0.774, accuracy=0.925]


loss 0.3306 - iou 0.6670 - dice 0.7982 - accuracy 0.9283 - val_loss 0.3542 - val_iou 0.6317 - val_dice 0.7737 - val_accuracy 0.9247
=> saved best model
Epoch [14/100]


100%|███████████████████████████| 219/219 [01:30<00:00,  2.41it/s, loss=0.32, iou=0.684, dice_coef=0.81, accuracy=0.931]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.26it/s, loss=0.34, iou=0.639, dice_coef=0.778, accuracy=0.926]


loss 0.3202 - iou 0.6836 - dice 0.8102 - accuracy 0.9310 - val_loss 0.3401 - val_iou 0.6394 - val_dice 0.7781 - val_accuracy 0.9260
=> saved best model
Epoch [15/100]


100%|██████████████████████████| 219/219 [01:30<00:00,  2.42it/s, loss=0.312, iou=0.698, dice_coef=0.82, accuracy=0.933]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.25it/s, loss=0.34, iou=0.636, dice_coef=0.775, accuracy=0.926]


loss 0.3121 - iou 0.6977 - dice 0.8201 - accuracy 0.9334 - val_loss 0.3396 - val_iou 0.6365 - val_dice 0.7751 - val_accuracy 0.9264
Epoch [16/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.42it/s, loss=0.305, iou=0.707, dice_coef=0.827, accuracy=0.935]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.42it/s, loss=0.349, iou=0.616, dice_coef=0.76, accuracy=0.924]


loss 0.3050 - iou 0.7074 - dice 0.8268 - accuracy 0.9348 - val_loss 0.3488 - val_iou 0.6156 - val_dice 0.7596 - val_accuracy 0.9244
Epoch [17/100]


100%|███████████████████████████| 219/219 [01:30<00:00,  2.42it/s, loss=0.3, iou=0.719, dice_coef=0.835, accuracy=0.937]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.07it/s, loss=0.322, iou=0.656, dice_coef=0.79, accuracy=0.926]


loss 0.3001 - iou 0.7193 - dice 0.8351 - accuracy 0.9374 - val_loss 0.3222 - val_iou 0.6564 - val_dice 0.7903 - val_accuracy 0.9256
=> saved best model
Epoch [18/100]


100%|███████████████████████████| 219/219 [01:29<00:00,  2.44it/s, loss=0.289, iou=0.74, dice_coef=0.85, accuracy=0.941]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.38it/s, loss=0.321, iou=0.664, dice_coef=0.797, accuracy=0.926]


loss 0.2886 - iou 0.7404 - dice 0.8495 - accuracy 0.9412 - val_loss 0.3211 - val_iou 0.6644 - val_dice 0.7967 - val_accuracy 0.9262
=> saved best model
Epoch [19/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.41it/s, loss=0.284, iou=0.749, dice_coef=0.855, accuracy=0.942]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.30it/s, loss=0.312, iou=0.675, dice_coef=0.806, accuracy=0.928]


loss 0.2844 - iou 0.7491 - dice 0.8553 - accuracy 0.9424 - val_loss 0.3120 - val_iou 0.6755 - val_dice 0.8057 - val_accuracy 0.9277
=> saved best model
Epoch [20/100]


100%|█████████████████████████| 219/219 [01:30<00:00,  2.41it/s, loss=0.276, iou=0.764, dice_coef=0.865, accuracy=0.945]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.22it/s, loss=0.319, iou=0.646, dice_coef=0.782, accuracy=0.918]


loss 0.2761 - iou 0.7637 - dice 0.8649 - accuracy 0.9452 - val_loss 0.3192 - val_iou 0.6459 - val_dice 0.7818 - val_accuracy 0.9181
Epoch [21/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.273, iou=0.767, dice_coef=0.867, accuracy=0.946]
100%|██████████████████████████████| 8/8 [00:00<00:00,  8.14it/s, loss=0.317, iou=0.656, dice_coef=0.789, accuracy=0.93]


loss 0.2732 - iou 0.7668 - dice 0.8667 - accuracy 0.9455 - val_loss 0.3170 - val_iou 0.6563 - val_dice 0.7892 - val_accuracy 0.9300
Epoch [22/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.39it/s, loss=0.269, iou=0.776, dice_coef=0.873, accuracy=0.947]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.14it/s, loss=0.311, iou=0.673, dice_coef=0.802, accuracy=0.932]


loss 0.2689 - iou 0.7760 - dice 0.8726 - accuracy 0.9470 - val_loss 0.3111 - val_iou 0.6726 - val_dice 0.8018 - val_accuracy 0.9316
Epoch [23/100]


100%|██████████████████████████| 219/219 [01:33<00:00,  2.34it/s, loss=0.265, iou=0.78, dice_coef=0.875, accuracy=0.948]
100%|█████████████████████████████| 8/8 [00:01<00:00,  7.95it/s, loss=0.312, iou=0.666, dice_coef=0.798, accuracy=0.928]


loss 0.2648 - iou 0.7796 - dice 0.8750 - accuracy 0.9475 - val_loss 0.3119 - val_iou 0.6664 - val_dice 0.7980 - val_accuracy 0.9278
Epoch [24/100]


100%|█████████████████████████| 219/219 [01:31<00:00,  2.40it/s, loss=0.263, iou=0.784, dice_coef=0.878, accuracy=0.948]
100%|██████████████████████████████| 8/8 [00:01<00:00,  7.93it/s, loss=0.315, iou=0.66, dice_coef=0.791, accuracy=0.931]


loss 0.2628 - iou 0.7838 - dice 0.8777 - accuracy 0.9482 - val_loss 0.3146 - val_iou 0.6602 - val_dice 0.7913 - val_accuracy 0.9307
Epoch [25/100]


100%|██████████████████████████| 219/219 [01:32<00:00,  2.36it/s, loss=0.258, iou=0.795, dice_coef=0.885, accuracy=0.95]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.08it/s, loss=0.303, iou=0.695, dice_coef=0.819, accuracy=0.933]


loss 0.2584 - iou 0.7946 - dice 0.8845 - accuracy 0.9503 - val_loss 0.3027 - val_iou 0.6949 - val_dice 0.8186 - val_accuracy 0.9329
=> saved best model
Epoch [26/100]


100%|█████████████████████████| 219/219 [01:32<00:00,  2.37it/s, loss=0.253, iou=0.805, dice_coef=0.891, accuracy=0.952]
100%|█████████████████████████████| 8/8 [00:00<00:00,  8.13it/s, loss=0.299, iou=0.699, dice_coef=0.821, accuracy=0.935]


loss 0.2531 - iou 0.8051 - dice 0.8911 - accuracy 0.9525 - val_loss 0.2986 - val_iou 0.6988 - val_dice 0.8206 - val_accuracy 0.9354
=> saved best model
Epoch [27/100]


 63%|███████████████▋         | 137/219 [00:57<00:33,  2.43it/s, loss=0.248, iou=0.816, dice_coef=0.898, accuracy=0.954]

In [None]:
# Evaluate the checkpoint with the best validation Dice saved by the training loop.
# The in-memory weights come from the last epoch trained. After early stopping that
# is `early_stopping` epochs past the best one, so reporting on it would mix model
# selection with extra (non-improving) training.
best_model_path = 'models/%s/model.pth' % config['name']
model.load_state_dict(torch.load(best_model_path, map_location=device))

test_log = evaluate(config, model, test_iterator, criterion, config['deep_supervision'], device)

# index=False to match log.csv (no unnamed index column)
pd.DataFrame([test_log]).to_csv('models/%s/test_log.csv' % config['name'], index=False)

print('test_loss %.4f - test_iou %.4f - test_dice %.4f - test_accuracy %.4f'
      % (test_log['loss'], test_log['iou'], test_log['dice_coef'], test_log['accuracy']))

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

import os

# Re-plot the learning curves from the CSV logs written by the training and test
# cells, so this cell can also be re-run on its own after a kernel restart.
# RUN_NAME must match config['name'] of the run being plotted.
RUN_NAME = 'DiceLoss_try_Aug'

# Read the CSV logs into DataFrames
df = pd.read_csv(f'models/{RUN_NAME}/log.csv')
test = pd.read_csv(f'models/{RUN_NAME}/test_log.csv')

# Test values for loss, IoU, Dice and accuracy (test_log.csv holds a single row)
test_loss = test['loss'].item()
test_iou = test['iou'].item()
test_dice = test['dice_coef'].item()
test_accuracy = test['accuracy'].item()


def _plot_metric_panel(ax, epochs, train_values, val_values, test_value, label_offset, metric):
    """Draw training/validation curves for one metric plus a dashed test-set reference line.

    The test value is annotated at the last logged epoch (right-aligned) instead of a
    fixed x position. Matplotlib skips data-coordinate annotations that fall outside
    the axes, so a fixed x=90 label disappears when early stopping ends training early.
    """
    ax.plot(epochs, train_values, label=f'Training {metric}')
    ax.plot(epochs, val_values, label=f'Validation {metric}')
    ax.axhline(y=test_value, color='r', linestyle='--', label=f'Test {metric}')
    ax.annotate(f'{test_value:.2f}', xy=(epochs.iloc[-1], test_value + label_offset),
                color='r', ha='right')
    ax.set_title(f'{metric} Metrics')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.legend()


# (training column, validation column, test value, annotation y-offset, metric name)
panels = [
    ('loss', 'val_loss', test_loss, 0.005, 'Loss'),
    ('accuracy', 'val_accuracy', test_accuracy, 0.002, 'Accuracy'),
    ('iou', 'val_iou', test_iou, 0.003, 'IoU'),
    ('dice', 'val_dice', test_dice, 0.005, 'Dice'),
]

# 2x2 grid: loss | accuracy on top, IoU | Dice below
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
for ax, (train_col, val_col, test_value, offset, metric) in zip(axes.flat, panels):
    _plot_metric_panel(ax, df['epoch'], df[train_col], df[val_col], test_value, offset, metric)

fig.suptitle("UNET++ with DiceLoss and with Data Augmentation")
# Adjust the spacing between subplots
plt.tight_layout()

# savefig fails if the output folder does not exist yet
os.makedirs("results", exist_ok=True)
plt.savefig("results/DiceLoss_Aug.png")
# Display the plot
plt.show()