# Hyper-parameters

In [None]:
# Square side length (pixels) that training / validation images are resized to.
TRAIN_IMG_SIZE = 128
VAL_IMG_SIZE = 128
# When False, the training loop is skipped.
TRAINING = True
# Number of training epochs; also the length handed to the LR schedule.
EPOCHS = 20
# Mini-batch size for the train and validation DataLoaders.
BATCH_SIZE = 32

# Fractions of the shuffled path list used for training and validation;
# whatever remains after these two slices becomes the test set.
TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.18

# DeepLabV3+ model settings
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
ENCODER_OUTPUT_STRIDE=16 #default: 16
DECODER_ATROUS_RATES=(2, 4, 6) #default: (12, 24, 36)

# Results

## Base Line

| model | val Acc(%) | val IoU(%) |
| ---- | ---- | ---- |
| mobilenet_v2 + U-Net | 88.80 | - |
| mobilenet_v2 + DeepLabV3+ | 93.82 | 83.05 |

## Model

| model | val IoU(%) |
| ---- | ---- |
| mobilenet_v2 | 83.61 |
| resnet34 | 85.27 |
| resnet50 | 85.91 |
| resnet101 | 86.13 |
| resnet152 | **86.14** |

## Image Size

| train | val | val IoU(%) |
| ---- | ---- | ---- |
| 128 | 128 | 86.13 |
| 128 | 256 | 83.67 |
| 256 | 256 | **88.97** |

## Batch Size

| size | val IoU(%) |
| ---- | ---- |
| 64 | 85.75 |
| 32 | **86.01** |
| 16 | 85.3 |

## Learning Rate

### epochs=20

| lr | val IoU(%) |
| ---- | ---- |
| 0.0001 | 85.92 |
| 0.00009 | 86.13 |
| 0.000085 | 86.05 |
| 0.00008 | 86.13 |
| 0.00007 | 85.69 |
| CA0.0001-0.00005,T_mult1 | 85.74 |
| CA0.00009-0.00008,T_mult1 | **86.33** |
| CA0.00008-0.00005,T_mult1 | 85.64 |

### epochs=30

| lr | val IoU(%) |
| ---- | ---- |
| CA0.00009-0.000075,T_mult1 | 86.71 |
| CA0.00009-0.000075,T_mult2 | 86.67 |

## Loss Function

| Loss | val IoU(%) |
| ---- | ---- |
| CrossEntropy | 86.01 |
| Dice | 85.86 |
| 0.75*CE + 0.25*Dice | **86.09** |

## Atrous Rates

### Image Size (128, 128), Output Strides 16

| rates | val IoU(%) |
| ---- | ---- |
| (2, 4, 6) | **86.13** |
| (2, 4, 8) | 86.09 |
| (12, 24, 36) | 85.63 |

## Augmentation

| Aug | val IoU(%) |
| ---- | ---- |
| No Aug | 59.79 |
| Aug | 59.99 |

## Best Score

| val Accuracy(%) | val IoU(%) |
| ---- | ---- |
| 96.27 | 89.41 |

- encoder: ResNet152
- epoch: 30
- image size: 256
- batch size: 32
- learning rate: 0.00009-0.00008, cosine annealing
- loss: 0.75\*cross entropy + 0.25\*dice
- output stride: 16
- atrous rates: (4, 8, 12)


# Reference

## Data Augmentation

https://blog.shikoan.com/manual-augmentation/#6_Auto_Contrast
https://qiita.com/kurilab/items/b69e1be8d0224ae139ad#randombrightnesscontrast

## Batch size tuning

https://medium.com/mini-distill/effect-of-batch-size-on-training-dynamics-21c14f7a716e

## Learning rate scheduler

https://techburst.io/improving-the-way-we-work-with-learning-rate-5e99554f163b  
https://arxiv.org/abs/1608.03983

## segmentation_models_pytorch

https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/deeplabv3/model.py


# Note

## batch size

### 64 -> 32

- 学習速度・精度向上

### 32 -> 16

- 学習速度向上、精度低下
- 12epochほどでtrain,valともに頭打ち

# Preprocess

### Libraries 📚⬇

In [None]:
import os, cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

from sklearn.utils import shuffle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

In [None]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

### Read Data & Create train / valid splits 📁

In [None]:
def split_paths(paths, train_rate, val_rate):
    """Cut ``paths`` into consecutive train / validation / test slices.

    Args:
        paths: Sliceable sequence of items (e.g. file paths).
        train_rate: Fraction of ``len(paths)`` for the training slice.
        val_rate: Fraction of ``len(paths)`` for the validation slice.

    Returns:
        A ``(train, val, test)`` tuple. The slice sizes are truncated with
        ``int()``; ``test`` holds whatever is left after the first two slices.
    """
    total = len(paths)
    val_start = int(total * train_rate)
    test_start = val_start + int(total * val_rate)

    return paths[:val_start], paths[val_start:test_start], paths[test_start:]

In [None]:
IMG_PATH = '/kaggle/input/the-oxfordiiit-pet-dataset/images/images'
ANNOTATION_PATH = '/kaggle/input/the-oxfordiiit-pet-dataset/annotations/annotations/trimaps'

# Collect (image, annotation) path pairs, keeping only pairs where both files
# can be decoded by OpenCV.
input_img_paths = []
annotation_img_paths = []
# os.listdir() returns entries in arbitrary order; sort so the collected lists
# come out in the same order on every run.
for fname in sorted(os.listdir(IMG_PATH)):
    img_path = os.path.join(IMG_PATH, fname)
    # The annotation uses the image's file name with a .png extension.
    ann_path = os.path.join(ANNOTATION_PATH, fname.replace(".jpg", ".png"))
    # cv2.imread returns None (it does not raise) when a file cannot be read.
    if cv2.imread(img_path) is None or cv2.imread(ann_path) is None:
        print(f"cannot load {fname}")
        continue

    input_img_paths.append(img_path)
    annotation_img_paths.append(ann_path)

In [None]:
# Fixed seed so that train / val / test membership is identical on every run.
# Without it, a checkpoint saved by an earlier run could be evaluated on images
# it was trained on, and scores from different runs are not comparable.
SPLIT_SEED = 42

# Put the pairs into a canonical order first, so the seeded shuffle does not
# depend on the order in which the paths were collected.
_sorted_pairs = sorted(zip(input_img_paths, annotation_img_paths))
input_img_paths = [img for img, _ in _sorted_pairs]
annotation_img_paths = [ann for _, ann in _sorted_pairs]

# Both lists are shuffled with the same permutation, so pairs stay aligned.
input_img_paths, annotation_img_paths = shuffle(
    input_img_paths, annotation_img_paths, random_state=SPLIT_SEED
)

train_img_paths, val_img_paths, test_img_paths = split_paths(input_img_paths, TRAIN_SPLIT, VAL_SPLIT)
print(len(train_img_paths), len(val_img_paths), len(test_img_paths))

train_label_paths, val_label_paths, test_label_paths = split_paths(annotation_img_paths, TRAIN_SPLIT, VAL_SPLIT)
print(len(train_label_paths), len(val_label_paths), len(test_label_paths))

## visualize mask

In [None]:
# Read the first annotation file and reorder its channels from BGR to RGB.
_first_annotation_bgr = cv2.imread(annotation_img_paths[0])
mask = cv2.cvtColor(_first_annotation_bgr, cv2.COLOR_BGR2RGB)
mask.shape

In [None]:
# Display the annotation as a 3-channel image.
plt.imshow(mask)

In [None]:
# Show each of the three channels of the annotation in its own figure.
plt.figure()
plt.imshow(mask[:,:,0])
plt.figure()
plt.imshow(mask[:,:,1])
plt.figure()
plt.imshow(mask[:,:,2])

### RGBすべてのレイヤーで数値が同じグレースケールの画像

In [None]:
# Check that every pixel has the same value in all three channels,
# i.e. the annotation is a grey-scale image stored as RGB.
_first_channel = mask[..., 0]
print(np.all(_first_channel == mask[..., 1]))
print(np.all(_first_channel == mask[..., 2]))

### ラベルの確認

- 1: object
- 2: background
- 3: edge

In [None]:
# Shrink the annotation to 28x28 so the raw label values fit on screen,
# then print and plot its first channel.
mask_ = cv2.resize(mask, dsize=(28, 28))
_small_first_channel = mask_[:, :, 0]
print(_small_first_channel)
plt.imshow(_small_first_channel)

In [None]:
# DATA_DIR = '../input/deepglobe-road-extraction-dataset'

# metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
# metadata_df = metadata_df[metadata_df['split']=='train']
# metadata_df = metadata_df[['image_id', 'sat_image_path', 'mask_path']]
# metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# # Shuffle DataFrame
# metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)

# # Perform 90/10 split for train / val
# valid_df = metadata_df.sample(frac=0.1, random_state=42)
# train_df = metadata_df.drop(valid_df.index)
# len(train_df), len(valid_df)

In [None]:
# Read the first input image, convert BGR -> RGB and shrink it to 10x10 so the
# raw pixel values of its first channel can be printed.
_first_image_bgr = cv2.imread(input_img_paths[0])
img = cv2.resize(cv2.cvtColor(_first_image_bgr, cv2.COLOR_BGR2RGB), dsize=(10, 10))
print(img[:, :, 0])
plt.imshow(img)

## metadata

In [None]:
# Label value stored in the annotation for each class, in class-index order.
_label_value_by_class = {"object": 1, "background": 2, "edge": 3}

# Class names, and each label value repeated over the three colour channels.
class_names = list(_label_value_by_class)
class_rgb_values = [[value] * 3 for value in _label_value_by_class.values()]

print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

#### Shortlist specific classes to segment

In [None]:
# Useful to shortlist specific classes in datasets with large number of classes
select_classes = ["object", "background", "edge"]

# Look up each selected class (lower-cased) in class_names and pick the
# matching rows of class_rgb_values.
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]

# Report the *selected* classes and values, not the full class list.
print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', select_classes)
print('Class RGB values: ', select_class_rgb_values.tolist())

### Helper functions for viz. & one-hot encoding/decoding

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword itself, with
    underscores replaced by spaces and title-cased, is used as the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(20,8))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        # remove the tick marks on both axes
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_',' ').title(), fontsize=20)
        plt.imshow(picture)
    plt.show()

# Perform one hot encoding on label
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of per-class boolean maps.

    # Arguments
        label: label array whose last axis holds the colour of each pixel
        label_values: one colour per class; the order defines the class index

    # Returns
        A boolean array with the last axis of `label` replaced by an axis of
        length len(label_values). Entry k is True where every component of the
        pixel's colour equals label_values[k].
    """
    class_planes = [
        np.all(np.equal(label, colour), axis=-1)
        for colour in label_values
    ]
    return np.stack(class_planes, axis=-1)
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """Collapse a one-hot (or per-class score) array to class indices.

    # Arguments
        image: array whose last axis has one entry per class

    # Returns
        An array without that last axis, holding for every position the index
        of the largest entry along it.
    """
    return np.argmax(image, axis=-1)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """Replace every class index in `image` by that class's colour.

    # Arguments
        image: array of class indices (cast to int before the lookup)
        label_values: one colour per class, indexed by class index

    # Returns
        An array shaped like `image` with the colour components appended as a
        trailing axis.
    """
    palette = np.asarray(label_values)
    class_keys = image.astype(int)
    return palette[class_keys]

In [None]:
class RoadsDataset(torch.utils.data.Dataset):
    """Segmentation dataset that reads image / mask pairs from disk.

    The class name is historical (it comes from a road-extraction dataset);
    nothing in the implementation is specific to roads.

    For every item the image and the mask are read with OpenCV and converted
    to RGB, the mask is one-hot encoded against ``class_rgb_values``, and then
    the optional augmentation and preprocessing pipelines are applied, in that
    order.

    Args:
        image_paths (list): paths of the input images
        mask_paths (list): paths of the masks; entry ``i`` belongs to
            ``image_paths[i]``
        class_rgb_values (list): RGB values of select classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
    """

    def __init__(
            self,
            image_paths,
            mask_paths,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.image_paths = image_paths
        self.mask_paths = mask_paths

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read_rgb(path):
        """Read the file at ``path`` with OpenCV and return it in RGB order.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        bgr = cv2.imread(path)
        # cv2.imread signals failure by returning None; fail here with the
        # offending path instead of with an opaque error inside cvtColor.
        if bgr is None:
            raise OSError(f"cv2.imread could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``."""
        # read images and masks
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])

        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

#### Visualize Sample Image and Mask 📈

In [None]:
# Un-augmented dataset over every collected pair, used to eyeball one random sample.
dataset = RoadsDataset(input_img_paths, annotation_img_paths, class_rgb_values=select_class_rgb_values)
random_idx = random.randrange(len(dataset))
image, mask = dataset[random_idx]

print(image[:5,:5,0])

# Decode the one-hot mask once and reuse the class-index map below.
_class_index_map = reverse_one_hot(mask)

ground_truth_mask = colour_code_segmentation(_class_index_map, select_class_rgb_values)
print(ground_truth_mask.shape)
print(ground_truth_mask[:10,:10,0])

visualize(
    original_image=image,
    ground_truth_mask=_class_index_map,
    one_hot_encoded_mask=_class_index_map,
)

# Defining Augmentations 🙃

In [None]:
def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    Steps, in order: resize to TRAIN_IMG_SIZE x TRAIN_IMG_SIZE, horizontal
    flip (p=0.5), ShiftScaleRotate with rotate_limit=15 and no shift or scale
    (p=0.5), RandomBrightnessContrast with both limits 0.1 (p=0.5), and
    Normalize with mean 0 / std 1.
    """
    side = TRAIN_IMG_SIZE
    return album.Compose([
        album.Resize(side, side, p=1),
        album.HorizontalFlip(p=0.5),
        album.ShiftScaleRotate(shift_limit=0, scale_limit=0, rotate_limit=15, p=0.5),
        # RGB shift is currently disabled:
        # album.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        album.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        album.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
    ])

def get_validation_augmentation():
    """Build the albumentations pipeline applied to validation / test samples.

    Only deterministic steps: resize to VAL_IMG_SIZE x VAL_IMG_SIZE followed
    by Normalize with mean 0 / std 1.
    """
    side = VAL_IMG_SIZE
    return album.Compose([
        album.Resize(side, side, p=1),
        album.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
    ])

def to_tensor(x, **kwargs):
    """Move the last axis of a 3-axis array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Assemble the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn (callable): optional normalization function applied
            to the image only (can be specific to a pretrained encoder)
    Return:
        albumentations.Compose that runs ``preprocessing_fn`` on the image
        (when given) and then ``to_tensor`` on both image and mask
    """
    normalise_step = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    tensor_step = [album.Lambda(image=to_tensor, mask=to_tensor)]

    return album.Compose(normalise_step + tensor_step)

#### Visualize Augmented Images & Masks

In [None]:
# Training-split dataset with augmentation only (no model preprocessing),
# used here purely for visual inspection.
augmented_dataset = RoadsDataset(
    train_img_paths,
    train_label_paths, 
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)

random_idx = random.randint(0, len(augmented_dataset)-1)

# Different augmentations on image/mask pairs:
# the same index is fetched 10 times and the augmentation pipeline runs again
# on every access.
for _ in range(10):
    image, mask = augmented_dataset[random_idx]
    visualize(
        original_image = image,
        ground_truth_mask = reverse_one_hot(mask)
    )

# Training DeepLabV3+

<h3><center>DeepLabV3+ Model Architecture</center></h3>
<img src="https://miro.medium.com/max/1000/1*2mYfKnsX1IqCCSItxpXSGA.png" width="750" height="750"/>
<h4></h4>
<h4><center><a href="https://arxiv.org/abs/1802.02611">Image Source: DeepLabV3+ [Liang-Chieh Chen et al.]</a></center></h4>

### Model Definition

In [None]:
# The model gets one output channel per selected class.
CLASSES = select_classes

# create segmentation model with pretrained encoder
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
    encoder_output_stride=ENCODER_OUTPUT_STRIDE, #default: 16
    decoder_atrous_rates=DECODER_ATROUS_RATES, #default: (12, 24, 36)
)

# Image preprocessing function for the chosen encoder / pretrained weights;
# passed to get_preprocessing() when the datasets are built.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

#### Get Train / Val DataLoaders

In [None]:
# Get train and val dataset instances
# Training dataset: random augmentations followed by encoder preprocessing.
train_dataset = RoadsDataset(
    train_img_paths,
    train_label_paths,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

# Validation dataset: resize / normalize only, then the same preprocessing.
valid_dataset = RoadsDataset(
    val_img_paths,
    val_label_paths,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

# Get train and val data loaders
# (only the training loader shuffles; both use 4 worker processes)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

#### Set Hyperparams

In [None]:
# Set device: `cuda` or `cpu`
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")

# define loss function
class CeDiceLoss(smp.utils.losses.base.Loss):
    """Weighted sum of binary cross-entropy and Dice loss.

    ``loss = alpha * BCE(y_pr, y_gt) + (1 - alpha) * Dice(y_pr, y_gt)``

    NOTE(review): later cells read this loss from the epoch logs under the key
    'ce_dice_loss', which presumably is derived from this class name by the
    smp base class -- confirm before renaming the class.

    Args:
        alpha (float): weight of the BCE term; the Dice term gets ``1 - alpha``
        **kwargs: forwarded to the base ``Loss`` constructor
    """

    def __init__(self, alpha=0.75, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        # Build the two component losses once instead of on every forward pass.
        self.bce = smp.utils.losses.BCELoss()
        self.dice = smp.utils.losses.DiceLoss()

    def forward(self, y_pr, y_gt):
        """Return the weighted BCE + Dice loss of prediction vs. ground truth."""
        bce_term = self.bce(y_pr, y_gt)
        dice_term = self.dice(y_pr, y_gt)
        return self.alpha * bce_term + (1 - self.alpha) * dice_term

loss = CeDiceLoss(alpha=0.75)
# loss = smp.utils.losses.DiceLoss()

# define metrics
metrics = [
    smp.utils.metrics.Accuracy(),
    smp.utils.metrics.IoU(threshold=0.5),
]

# load best saved model checkpoint from previous commit (if present)
# This must happen BEFORE the optimizer is created: loading rebinds `model` to
# a new object, and an optimizer built earlier would keep updating the
# parameters of the discarded model.
# NOTE: torch.load unpickles a whole model object; only load checkpoints you trust.
if os.path.exists('../input/the-oxfordiiit-pet-dataset-deeplabv3/best_model.pth'):
    model = torch.load('../input/the-oxfordiiit-pet-dataset-deeplabv3/best_model.pth', map_location=DEVICE)
    print('Loaded pre-trained DeepLabV3+ model!')

# define optimizer (over the parameters of whichever model is now bound to `model`)
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [None]:
# Settings shared by the training and the validation runner.
_epoch_settings = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Runner for one training pass over a loader (uses the optimizer).
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_settings)

# Runner for one evaluation pass over a loader (no optimizer).
valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_settings)

## Define the cosine annealing LR scheduler

In [None]:
class CosineAnnealingWarmRestarts:
    """Epoch-indexed cosine learning-rate schedule with an initial hold.

    ``calc_lr(ep)`` takes a 1-based epoch number and returns the learning rate
    for that epoch:

    * ``ep < T_0``: ``max_lr`` (hold phase).
    * ``T_mult == 1``: a single cosine decay that starts at ``max_lr`` on epoch
      ``T_0`` and reaches ``min_lr`` on epoch ``epoch`` (no restarts).
    * ``T_mult > 1``: epochs are re-counted so that epoch ``T_0`` is count 1.
      Cycle ``k`` (k = 0, 1, ...) covers counts ``T_mult**k`` up to
      ``T_mult**(k+1) - 1``. Every cycle restarts at ``max_lr`` and its last
      epoch reaches ``min_lr`` (a one-epoch cycle just yields ``max_lr``).

    Args:
        epoch: total number of epochs (only used when ``T_mult == 1``)
        min_lr: lowest learning rate of the schedule
        max_lr: highest learning rate of the schedule
        T_0: first epoch of the cosine phase; earlier epochs return ``max_lr``
        T_mult: growth factor of the cycle boundaries; must be >= 1

    Raises:
        ValueError: if ``T_mult`` is smaller than 1 (the cycle boundaries would
            never move past the current epoch).
    """

    def __init__(self, epoch=10, min_lr=1e-4, max_lr=1e-3, T_0=1, T_mult=1):
        if T_mult < 1:
            raise ValueError(f"T_mult must be >= 1, got {T_mult}")
        self.epoch = epoch
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.T_0 = T_0
        self.T_mult = T_mult

    def _cosine(self, T_cur, T_i):
        """Cosine interpolation: max_lr at T_cur == 0, min_lr at T_cur == T_i."""
        return self.min_lr + 1/2 * (self.max_lr - self.min_lr) * (1 + np.cos(T_cur/T_i*np.pi))

    def calc_lr(self, ep):
        """Return the learning rate for the 1-based epoch number ``ep``."""
        if ep < self.T_0:
            return self.max_lr
        # Re-count epochs so that the first epoch of the cosine phase is 1.
        ep = ep - (self.T_0-1)

        if self.T_mult == 1:
            T_i = self.epoch - self.T_0
            # epoch == T_0 leaves a one-epoch cosine phase; avoid dividing by zero.
            if T_i == 0:
                T_i = 1
            return self._cosine(ep - 1, T_i)

        # Find the start of the cycle that contains `ep`: the largest power of
        # T_mult that does not exceed `ep`.
        cycle_start = 1
        while cycle_start * self.T_mult <= ep:
            cycle_start *= self.T_mult

        # The cycle runs until the next power of T_mult. Its last epoch must hit
        # min_lr, hence "length - 1" (clamped to 1 for one-epoch cycles).
        # For T_mult == 2 this equals cycle_start - 1.
        cycle_len = cycle_start * self.T_mult - cycle_start
        T_i = max(cycle_len - 1, 1)

        return self._cosine(ep - cycle_start, T_i)

In [None]:
# Sanity-check plots of the schedule on a 0..1 learning-rate range over 35 epochs.
# First plot: T_mult=1 with T_0=5 (epochs before 5 return max_lr).
ca = CosineAnnealingWarmRestarts(epoch=35, min_lr=0, max_lr=1, T_0=5, T_mult=1)
a = [ca.calc_lr(i) for i in range(1, 36)]
plt.plot(range(1, len(a)+1), a)
plt.title("T_mult: 1")

# Second plot: T_mult=2 with T_0=6, drawn in a separate figure.
ca = CosineAnnealingWarmRestarts(epoch=35, min_lr=0, max_lr=1, T_0=6, T_mult=2)
a = [ca.calc_lr(i) for i in range(1, 36)]
plt.figure()
plt.plot(range(1, len(a)+1), a)
plt.title("T_mult: 2")

In [None]:
# Schedule actually used by the training loop: it spans EPOCHS epochs between
# max_lr=0.00009 and min_lr=0.00008, with no hold phase (T_0=1).
ca_lr_scheduler = CosineAnnealingWarmRestarts(epoch=EPOCHS, min_lr=0.00008, max_lr=0.00009, T_0=1, T_mult=1)

# Plot the learning rate of every epoch as a visual check.
a = [ca_lr_scheduler.calc_lr(i) for i in range(1, EPOCHS+1)]
plt.figure()
plt.plot(range(1, len(a)+1), a)

### Training DeepLabV3+

In [None]:
%%time

if TRAINING:

    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    # Build the optimizer and the training runner ONCE, before the loop.
    # Re-creating Adam at the start of every epoch would throw away its
    # running moment estimates each time. The lr given here is only a
    # placeholder; the scheduled value is written in at the top of each epoch.
    optimizer = torch.optim.Adam([ 
        dict(params=model.parameters(), lr=0.0001),
    ])

    train_epoch = smp.utils.train.TrainEpoch(
        model, 
        loss=loss, 
        metrics=metrics, 
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    for i in range(1, EPOCHS+1):

        # Perform training & validation
        print('\nEpoch: {}'.format(i))

        lr = ca_lr_scheduler.calc_lr(i)
        print('Learning rate: {}'.format(lr))

        # Apply this epoch's learning rate to the existing optimizer in place.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Save model if a better val IoU score is obtained
        if best_iou_score < valid_logs['iou_score']:
            best_iou_score = valid_logs['iou_score']
            torch.save(model, './best_model.pth')
            print('Model saved!')

# Prediction on Test Data

In [None]:
# NOTE: torch.load unpickles a whole model object; only load checkpoints you trust.

# load best saved model checkpoint from the current run
if os.path.exists('./best_model.pth'):
    best_model = torch.load('./best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from this run.')

# load best saved model checkpoint from previous commit (if present)
elif os.path.exists('../input/the-oxfordiiit-pet-dataset-deeplabv3/best_model.pth'):
    best_model = torch.load('../input/the-oxfordiiit-pet-dataset-deeplabv3/best_model.pth', map_location=DEVICE)
    print('Loaded DeepLabV3+ model from a previous commit.')

# No checkpoint anywhere: fall back to the in-memory model so that the
# following cells do not fail with a NameError on `best_model`.
else:
    best_model = model
    print('No saved checkpoint found; using the in-memory model instead.')

In [None]:
# Test dataset fed to the model: same resize / normalize pipeline as the
# validation split, followed by encoder preprocessing and to_tensor(...).
test_dataset = RoadsDataset(
    test_img_paths,
    test_label_paths,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

# DataLoader with default settings (no batch size or shuffling specified).
test_dataloader = DataLoader(test_dataset)

# test dataset for visualization (without preprocessing augmentations & transformations)
test_dataset_vis = RoadsDataset(
    test_img_paths,
    test_label_paths,
    class_rgb_values=select_class_rgb_values,
)

# get a random test image/mask index
random_idx = random.randint(0, len(test_dataset_vis)-1)
image, mask = test_dataset_vis[random_idx]

# Both mask panels show the same class-index map decoded from the one-hot mask.
visualize(
    original_image = image,
    ground_truth_mask = reverse_one_hot(mask),
    one_hot_encoded_mask = reverse_one_hot(mask)
)


In [None]:
# Output folder for sample predictions. exist_ok=True makes the call a no-op
# when the folder already exists, without a separate check-then-create step.
sample_preds_folder = 'sample_predictions/'
os.makedirs(sample_preds_folder, exist_ok=True)

In [None]:
def create_mask(pred_mask):
    """Return the class-index mask of the first sample in a batch.

    Implemented with NumPy: the previous version called ``tf.*``, but ``tf``
    is never imported in this notebook, so it could not run.

    Args:
        pred_mask: array of per-class scores whose last axis is the class
            axis and whose first axis indexes the samples of the batch.

    Returns:
        The arg-max over the class axis for sample 0, with a trailing axis of
        length 1 in place of the class axis.
    """
    class_ids = np.argmax(pred_mask, axis=-1)
    class_ids = class_ids[..., np.newaxis]
    return class_ids[0]

In [None]:
# Number of random test samples to visualise (indices are drawn with replacement).
n = 30
idxes = np.random.randint(0,len(test_img_paths),n)

# Inference only: put layers such as batch-norm / dropout into evaluation mode.
best_model.eval()

for idx in idxes:
    image, gt_mask = test_dataset[idx]
    image_vis, _ = test_dataset_vis[idx]
    # Add a leading batch axis of size 1.
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    # Predict test image; no gradients are needed, so do not record them.
    with torch.no_grad():
        pred_mask = best_model(x_tensor)
    pred_mask = pred_mask.squeeze().cpu().numpy()

#     print(image_vis.shape, gt_mask.shape, pred_mask.shape)

    # Convert pred_mask from `CHW` format to `HWC` format
    pred_mask = np.transpose(pred_mask,(1,2,0))
    gt_mask = np.transpose(gt_mask,(1,2,0))
    
    visualize(
        # Resize to the size the masks were produced at (VAL_IMG_SIZE, the size
        # used by the validation/test pipeline) instead of a hard-coded 128.
        original_image = cv2.resize(image_vis, dsize=(VAL_IMG_SIZE, VAL_IMG_SIZE)),
        ground_truth_mask = reverse_one_hot(gt_mask),
        predicted_mask = reverse_one_hot(pred_mask),
    )

# Model Evaluation on Test Dataset

In [None]:
# Evaluate the best checkpoint (the model used for the predictions above),
# not `model`, which holds the weights of the last training epoch.
test_epoch = smp.utils.train.ValidEpoch(
    best_model,
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

valid_logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
print(f"Mean IoU Score: {valid_logs['iou_score']:.4f}")
# The logged loss is the combined CE + Dice loss, so label it as such.
print(f"Mean CE + Dice Loss: {valid_logs['ce_dice_loss']:.4f}")

### Plot Dice Loss & IoU Metric for Train vs. Val

In [None]:
# One row per logged training epoch; displayed transposed.
train_logs_df = pd.DataFrame(data=train_logs_list)
train_logs_df.transpose()

In [None]:
# One row per logged validation epoch; displayed transposed.
valid_logs_df = pd.DataFrame(data=valid_logs_list)
valid_logs_df.transpose()

In [None]:
# IoU per epoch for the training and validation logs, saved to disk and shown.
plt.figure(figsize=(20,8))
for _curve_label, _logs_df in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(_logs_df.index.tolist(), _logs_df['iou_score'].tolist(), lw=3, label=_curve_label)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.title('IoU Score Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('iou_score_plot.png')
plt.show()

In [None]:
# Loss per epoch for the training and validation logs. The plotted column is
# the combined CE + Dice loss ('ce_dice_loss'), so the labels say so. The
# output file name is kept unchanged for anything that expects it.
plt.figure(figsize=(20,8))
plt.plot(train_logs_df.index.tolist(), train_logs_df.ce_dice_loss.tolist(), lw=3, label = 'Train')
plt.plot(valid_logs_df.index.tolist(), valid_logs_df.ce_dice_loss.tolist(), lw=3, label = 'Valid')
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('CE + Dice Loss', fontsize=20)
plt.title('CE + Dice Loss Plot', fontsize=20)
plt.legend(loc='best', fontsize=16)
plt.grid()
plt.savefig('dice_loss_plot.png')
plt.show()