# **Dataset preparation**

## Install the required libraries and imports

In [18]:
!pip install numpy==1.24.3 opencv-python pandas scikit-learn wandb albumentations torchgeometry torchsummary torchinfo timm segmentation_models_pytorch

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.
Defaulting to user installation because normal site-packages is not writeable


In [19]:
import os
import pandas as pd
import numpy as np
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torchvision.transforms import ToTensor
from PIL import Image
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torchvision import transforms
from torchinfo import summary
import timm
import segmentation_models_pytorch as smp
import wandb

## GPU availability

In [20]:
# Show the NVIDIA driver / CUDA versions and current GPU memory usage.
!nvidia-smi

Fri Nov 22 10:10:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4050 ...    Off | 00000000:01:00.0 Off |                  N/A |
| N/A   44C    P8               1W /  80W |     11MiB /  6141MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [21]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Dataset class (reads images and colour-coded masks from disk)

In [22]:
class DatasetCustom(Dataset):
    """Segmentation dataset pairing each image with the same-named mask file.

    Masks are colour-coded images that ``read_mask`` converts into class
    indices: 0 = background, 1 = red, 2 = green.

    Args:
        img_dir: directory containing the input images.
        label_dir: directory containing the colour-coded masks (same file names).
        resize: ``dsize`` tuple forwarded to ``cv2.resize`` (OpenCV order is
            ``(width, height)``); ``None`` keeps the original size.
        transform: optional callable applied to the RGB image only.
    """

    def __init__(self, img_dir, label_dir, resize=None, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.resize = resize
        self.transform = transform
        # os.listdir order is filesystem dependent; sort so sample indices (and
        # any contiguous train/val split built from them) are reproducible.
        self.images = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _imread(path):
        """Read an image with OpenCV (BGR), raising if it cannot be read."""
        image = cv2.imread(path)
        if image is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def read_mask(self, mask_path):
        """Convert a colour-coded mask file into an (H, W, 1) uint8 class map.

        Red pixels (HSV hue 0-10 or 160-179) become 1, green pixels
        (hue 36-70) become 2, everything else stays 0.
        """
        image = self._imread(mask_path)
        if self.resize is not None:
            # Nearest-neighbour: linear interpolation would blend red/green/black
            # at region borders into intermediate colours that match no class.
            image = cv2.resize(image, self.resize, interpolation=cv2.INTER_NEAREST)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        lower_red1 = np.array([0, 100, 20])
        upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160, 100, 20])
        upper_red2 = np.array([179, 255, 255])

        # Red wraps around hue 0 in OpenCV's 0-179 hue range, hence two bands.
        red_mask = cv2.bitwise_or(
            cv2.inRange(image, lower_red1, upper_red1),
            cv2.inRange(image, lower_red2, upper_red2),
        )
        red_mask[red_mask != 0] = 1

        green_mask = cv2.inRange(image, (36, 25, 25), (70, 255, 255))
        green_mask[green_mask != 0] = 2

        full_mask = cv2.bitwise_or(red_mask, green_mask)
        full_mask = np.expand_dims(full_mask, axis=-1)
        return full_mask.astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image, label)``: the RGB image (optionally transformed) and its class mask."""
        # Indexing first keeps IndexError at idx == len(self), which plain
        # `for x, y in dataset` iteration relies on to stop.
        name = self.images[idx]
        image = self._imread(os.path.join(self.img_dir, name))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        label = self.read_mask(os.path.join(self.label_dir, name))
        if self.resize is not None:
            image = cv2.resize(image, self.resize)
        if self.transform:
            image = self.transform(image)
        return image, label

## Create datasets

### Training dataset

In [24]:
images_path = "data/train/train/"
TRAIN_DIR = 'data/train/train'

# Every file below TRAIN_DIR (recursive); in this notebook the list is only
# used to report how many training images there are.
image_path = []
for dirpath, _subdirs, filenames in os.walk(TRAIN_DIR):
    image_path.extend(os.path.join(dirpath, fname) for fname in filenames)

len(image_path)

1000

In [25]:
TRAIN_MASK_DIR = 'data/train_gt/train_gt'

# Every file below TRAIN_MASK_DIR (recursive). Each image is paired with the
# same-named mask, so this count is expected to equal the image count.
mask_path = [
    os.path.join(folder, fname)
    for folder, _, fnames in os.walk(TRAIN_MASK_DIR)
    for fname in fnames
]

len(mask_path)

1000

In [26]:
# Raw (un-augmented) dataset over the training folders: images and masks are
# resized to 256x256 when read and no tensor transform is applied yet.
dataset = DatasetCustom(
    TRAIN_DIR,
    TRAIN_MASK_DIR,
    resize=(256, 256),
    transform=None,
)

In [27]:
batch_size = 8

# Decode every sample once and keep the arrays in memory, so the augmenting
# datasets built below index lists instead of re-reading files each epoch.
images_data, labels_data = [], []
for sample_idx in range(len(dataset)):
    image_arr, label_arr = dataset[sample_idx]
    images_data.append(image_arr)
    labels_data.append(label_arr)

In [31]:
class CustomDataset(Dataset):
    """In-memory dataset of pre-loaded (image, mask) pairs with optional augmentation.

    Args:
        data: sequence of images (arrays whose first two dims are H, W).
        targets: sequence of masks aligned with ``data``; each (H, W) or (H, W, C).
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``"image"`` and ``"mask"`` tensors (e.g. a pipeline ending in
            ``ToTensorV2``).

    With a transform, returns ``(image, mask)`` as float tensors with the mask
    shaped (C, H, W) (a 2-D mask becomes (1, H, W)); without one, the stored
    objects are returned unchanged.
    """

    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if image.shape[:2] != label.shape[:2]:
            raise ValueError(
                f"Image and mask sizes differ at index {index}: "
                f"{image.shape[:2]} vs {label.shape[:2]}"
            )
        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image = transformed['image'].float()
            label = transformed['mask'].float()
            if label.ndim == 2:
                # (H, W) mask: add a channel axis -> (1, H, W).
                label = label.unsqueeze(0)
            else:
                # ToTensorV2 (default transpose_mask=False) leaves masks
                # channel-last: (H, W, C) -> (C, H, W).
                label = label.permute(2, 0, 1)
        return image, label

    def __len__(self):
        return len(self.data)

In [32]:
# ImageNet statistics, matching the ImageNet-pretrained encoder of the model.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random flips (applied to image and mask together) plus
# mild photometric jitter, then normalisation and tensor conversion.
train_transformation = A.Compose([
    A.HorizontalFlip(p=0.4),
    A.VerticalFlip(p=0.4),
    # `eps=None` / `always_apply=False` were their default values and are
    # deprecated in recent albumentations releases, so they are omitted.
    A.RandomGamma(gamma_limit=(70, 130), p=0.2),
    A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.3),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation pipeline: deterministic, normalisation and tensor conversion only.
val_transformation = A.Compose([
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

In [33]:
# Contiguous 80/20 split of the pre-loaded samples into train / validation.
train_size = int(0.8 * len(images_data))
val_size = len(images_data) - train_size
train_dataset = CustomDataset(images_data[:train_size], labels_data[:train_size], transform=train_transformation)
val_dataset = CustomDataset(images_data[train_size:], labels_data[train_size:], transform=val_transformation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# No shuffling for validation: it brings no benefit, and with a partial last
# batch a random batch composition makes the summed batch-mean loss vary
# from run to run.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

## Define model

In [29]:
# U-Net++ on a ResNet-34 encoder initialised from ImageNet weights.
# Input: 3-channel RGB; output: one logit map per class (background, red, green).
model = smp.UnetPlusPlus(
    classes=3,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet34",
)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /home/tom/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth
100.0%


In [34]:
# Defined here because this cell runs before the "Hyperparameters" cell;
# keep the value in sync with LEARNING_RATE there (the value logged to wandb).
LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [35]:
# Class index -> RGB display colour (0 = background, 1 = red, 2 = green),
# mirroring the colour coding that DatasetCustom.read_mask decodes.
color_dict = {0: (0, 0, 0),
              1: (255, 0, 0),
              2: (0, 255, 0)}


def mask_to_rgb(mask, color_dict):
    """Colourise a class-index mask for display.

    Args:
        mask: (H, W) array of class indices. A trailing singleton channel,
            (H, W, 1) as produced by ``DatasetCustom.read_mask``, is also accepted.
        color_dict: mapping from class index to an (R, G, B) tuple.

    Returns:
        (H, W, 3) uint8 RGB image; indices absent from ``color_dict`` stay black.
    """
    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    output = np.zeros((*mask.shape[:2], 3), dtype=np.uint8)
    for class_idx, colour in color_dict.items():
        output[mask == class_idx] = colour
    return output

## Hyperparameters

In [38]:
# Training schedule
NUM_EPOCHS = 250
LEARNING_RATE = 1e-4  # optimiser learning rate

# File name for the saved model checkpoint.
CHECKPOINT_PATH = "model.pth"

## Set up wandb

In [36]:
import os

PROJECT = "KaggleCompetDL"
RESUME = "allow"
# Never hard-code API keys in a notebook: they end up in version control and
# in shared copies. Read the key from the environment instead. The key that
# used to be written here is exposed and should be revoked.
WANDB_KEY = os.environ.get("WANDB_API_KEY")

In [39]:
wandb.login(key=WANDB_KEY)
wandb.init(
    project=PROJECT,
    # RESUME was configured above but never passed, so it had no effect.
    resume=RESUME,
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": batch_size,
    },
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/tom/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtom-briand3[0m ([33mtom-briand3-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Training loop

In [40]:
!pip install tqdm

Defaulting to user installation because normal site-packages is not writeable


In [None]:
from tqdm import tqdm
import time

num_epochs = NUM_EPOCHS

model.to(device)
criterion = nn.CrossEntropyLoss()
# +inf guarantees the first epoch is checkpointed whatever its loss.
best_val_loss = float("inf")

# Optional early stopping: stop after this many consecutive epochs without a
# new best validation loss. None trains for all `num_epochs`.
early_stopping_patience = None
epochs_without_improvement = 0


def _run_epoch(loader, train):
    """Make one pass over `loader` and return the sum of per-batch losses.

    With ``train=True`` the model is in training mode and the optimizer steps
    after every batch; otherwise the model is in eval mode with gradients off.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images = images.to(device)
            # Masks arrive as (B, 1, H, W) float tensors; CrossEntropyLoss
            # expects (B, H, W) integer class indices.
            labels = labels.to(device).squeeze(dim=1).long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss


# Context manager closes the bar even if training is interrupted.
with tqdm(total=num_epochs, desc='Total Progress') as epoch_bar:
    for epoch in range(num_epochs):
        train_loss = _run_epoch(train_loader, train=True)
        val_loss = _run_epoch(val_loader, train=False)

        # tqdm.write prints above the bar instead of interleaving with it.
        epoch_bar.write(f"Epoch [{epoch+1}/{num_epochs}], Loss: {val_loss/len(val_loader):.10f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': val_loss,
            }
            torch.save(checkpoint, CHECKPOINT_PATH)
        else:
            epochs_without_improvement += 1

        epoch_bar.update(1)
        wandb.log({'Val_loss': val_loss/len(val_loader), 'Train_loss': train_loss/len(train_loader)})

        if early_stopping_patience is not None and epochs_without_improvement >= early_stopping_patience:
            epoch_bar.write(f"Early stopping after epoch {epoch+1}: no improvement for {epochs_without_improvement} epochs")
            break

Total Progress:   0%|                                   | 0/250 [00:00<?, ?it/s]

Epoch [1/250], Loss: 0.2492629296


Total Progress:   0%|                         | 1/250 [00:21<1:31:06, 21.95s/it]

Epoch [2/250], Loss: 0.1359166586


Total Progress:   1%|▏                        | 2/250 [00:43<1:30:17, 21.85s/it]

Epoch [3/250], Loss: 0.0948687536


Total Progress:   1%|▎                        | 3/250 [01:05<1:29:56, 21.85s/it]

Epoch [4/250], Loss: 0.0790827870


Total Progress:   2%|▍                        | 4/250 [01:27<1:30:03, 21.97s/it]

Epoch [5/250], Loss: 0.0642434682


Total Progress:   2%|▌                        | 5/250 [01:49<1:29:45, 21.98s/it]

Epoch [6/250], Loss: 0.0596455342


Total Progress:   2%|▌                        | 6/250 [02:11<1:29:16, 21.95s/it]

Epoch [7/250], Loss: 0.0515719340


Total Progress:   3%|▋                        | 7/250 [02:33<1:28:52, 21.94s/it]

Epoch [8/250], Loss: 0.0494496162


Total Progress:   3%|▊                        | 8/250 [02:55<1:28:28, 21.94s/it]

Epoch [9/250], Loss: 0.0461576264


Total Progress:   4%|▉                        | 9/250 [03:17<1:28:09, 21.95s/it]

Epoch [10/250], Loss: 0.0446983383


Total Progress:   4%|█                       | 11/250 [04:00<1:26:57, 21.83s/it]

Epoch [11/250], Loss: 0.0489421238


Total Progress:   5%|█▏                      | 12/250 [04:22<1:26:14, 21.74s/it]

Epoch [12/250], Loss: 0.0459715146


Total Progress:   5%|█▏                      | 13/250 [04:44<1:25:44, 21.71s/it]

Epoch [13/250], Loss: 0.0490506009


Total Progress:   6%|█▎                      | 14/250 [05:05<1:25:16, 21.68s/it]

Epoch [14/250], Loss: 0.0486072105
Epoch [15/250], Loss: 0.0421819370


Total Progress:   6%|█▍                      | 15/250 [05:27<1:25:14, 21.76s/it]

Epoch [16/250], Loss: 0.0405490945


Total Progress:   6%|█▌                      | 16/250 [05:49<1:25:05, 21.82s/it]

Epoch [17/250], Loss: 0.0395048497


Total Progress:   7%|█▋                      | 18/250 [06:33<1:24:18, 21.80s/it]

Epoch [18/250], Loss: 0.0417271645


Total Progress:   8%|█▊                      | 19/250 [06:54<1:23:47, 21.76s/it]

Epoch [19/250], Loss: 0.0443165907


Total Progress:   8%|█▉                      | 20/250 [07:16<1:23:26, 21.77s/it]

Epoch [20/250], Loss: 0.0403784419
Epoch [21/250], Loss: 0.0368719778


Total Progress:   9%|██                      | 22/250 [08:00<1:22:55, 21.82s/it]

Epoch [22/250], Loss: 0.0573285728


Total Progress:   9%|██▏                     | 23/250 [08:22<1:22:16, 21.75s/it]

Epoch [23/250], Loss: 0.0419564151


Total Progress:  10%|██▎                     | 24/250 [08:43<1:21:45, 21.71s/it]

Epoch [24/250], Loss: 0.0372602614


Total Progress:  10%|██▍                     | 25/250 [09:05<1:21:21, 21.70s/it]

Epoch [25/250], Loss: 0.0376015801
Epoch [26/250], Loss: 0.0353807619


Total Progress:  11%|██▌                     | 27/250 [09:49<1:20:46, 21.74s/it]

Epoch [27/250], Loss: 0.0367807284


Total Progress:  11%|██▋                     | 28/250 [10:10<1:20:13, 21.68s/it]

Epoch [28/250], Loss: 0.0363297287


Total Progress:  12%|██▊                     | 29/250 [10:32<1:19:45, 21.65s/it]

Epoch [29/250], Loss: 0.0487010510


Total Progress:  12%|██▉                     | 30/250 [10:53<1:19:17, 21.63s/it]

Epoch [30/250], Loss: 0.0394097874


Total Progress:  12%|██▉                     | 31/250 [11:15<1:18:50, 21.60s/it]

Epoch [31/250], Loss: 0.0389069052


Total Progress:  13%|███                     | 32/250 [11:36<1:18:26, 21.59s/it]

Epoch [32/250], Loss: 0.0429777542


Total Progress:  13%|███▏                    | 33/250 [11:58<1:18:05, 21.59s/it]

Epoch [33/250], Loss: 0.0417459943


Total Progress:  14%|███▎                    | 34/250 [12:20<1:17:47, 21.61s/it]

Epoch [34/250], Loss: 0.0379184238


Total Progress:  14%|███▎                    | 35/250 [12:41<1:17:32, 21.64s/it]

Epoch [35/250], Loss: 0.0382762626


Total Progress:  14%|███▍                    | 36/250 [13:03<1:17:05, 21.62s/it]

Epoch [36/250], Loss: 0.0391083168


Total Progress:  15%|███▌                    | 37/250 [13:24<1:16:41, 21.60s/it]

Epoch [37/250], Loss: 0.0376108276


Total Progress:  15%|███▋                    | 38/250 [13:46<1:16:18, 21.60s/it]

Epoch [38/250], Loss: 0.0417464741


Total Progress:  16%|███▋                    | 39/250 [14:08<1:15:56, 21.60s/it]

Epoch [39/250], Loss: 0.0400515383


Total Progress:  16%|███▊                    | 40/250 [14:29<1:15:38, 21.61s/it]

Epoch [40/250], Loss: 0.0392128688


Total Progress:  16%|███▉                    | 41/250 [14:51<1:15:19, 21.62s/it]

Epoch [41/250], Loss: 0.0416903002


Total Progress:  17%|████                    | 42/250 [15:13<1:14:55, 21.61s/it]

Epoch [42/250], Loss: 0.0463836050


Total Progress:  17%|████▏                   | 43/250 [15:34<1:14:31, 21.60s/it]

Epoch [43/250], Loss: 0.0444478658


Total Progress:  18%|████▏                   | 44/250 [15:56<1:14:06, 21.58s/it]

Epoch [44/250], Loss: 0.0421835690


Total Progress:  18%|████▎                   | 45/250 [16:17<1:13:41, 21.57s/it]

Epoch [45/250], Loss: 0.0429799020


Total Progress:  18%|████▍                   | 46/250 [16:39<1:13:18, 21.56s/it]

Epoch [46/250], Loss: 0.0466817205


Total Progress:  19%|████▌                   | 47/250 [17:00<1:12:55, 21.55s/it]

Epoch [47/250], Loss: 0.0456526333


Total Progress:  19%|████▌                   | 48/250 [17:22<1:12:36, 21.57s/it]

Epoch [48/250], Loss: 0.0497746770


Total Progress:  20%|████▋                   | 49/250 [17:44<1:12:20, 21.59s/it]

Epoch [49/250], Loss: 0.0534647902


Total Progress:  20%|████▊                   | 50/250 [18:05<1:12:13, 21.67s/it]

Epoch [50/250], Loss: 0.0518261999


Total Progress:  20%|████▉                   | 51/250 [18:27<1:11:48, 21.65s/it]

Epoch [51/250], Loss: 0.0498359150


Total Progress:  21%|████▉                   | 52/250 [18:49<1:11:22, 21.63s/it]

Epoch [52/250], Loss: 0.0754990074


Total Progress:  21%|█████                   | 53/250 [19:10<1:10:57, 21.61s/it]

Epoch [53/250], Loss: 0.0577882510


Total Progress:  22%|█████▏                  | 54/250 [19:32<1:10:31, 21.59s/it]

Epoch [54/250], Loss: 0.0451850204


Total Progress:  22%|█████▎                  | 55/250 [19:53<1:10:07, 21.58s/it]

Epoch [55/250], Loss: 0.0436247330


Total Progress:  22%|█████▍                  | 56/250 [20:15<1:09:44, 21.57s/it]

Epoch [56/250], Loss: 0.0464226601


Total Progress:  23%|█████▍                  | 57/250 [20:36<1:09:22, 21.57s/it]

Epoch [57/250], Loss: 0.0483895605


Total Progress:  23%|█████▌                  | 58/250 [20:58<1:09:01, 21.57s/it]

Epoch [58/250], Loss: 0.0518066821


Total Progress:  24%|█████▋                  | 59/250 [21:19<1:08:40, 21.58s/it]

Epoch [59/250], Loss: 0.0445132079


Total Progress:  24%|█████▊                  | 60/250 [21:41<1:08:20, 21.58s/it]

Epoch [60/250], Loss: 0.0459726828


Total Progress:  24%|█████▊                  | 61/250 [22:03<1:07:53, 21.56s/it]

Epoch [61/250], Loss: 0.0455751276


Total Progress:  25%|█████▉                  | 62/250 [22:24<1:07:34, 21.56s/it]

Epoch [62/250], Loss: 0.0589033261


Total Progress:  25%|██████                  | 63/250 [22:46<1:07:12, 21.56s/it]

Epoch [63/250], Loss: 0.0487365581


Total Progress:  26%|██████▏                 | 64/250 [23:07<1:06:51, 21.57s/it]

Epoch [64/250], Loss: 0.0451880808


Total Progress:  26%|██████▏                 | 65/250 [23:29<1:06:29, 21.56s/it]

Epoch [65/250], Loss: 0.0441234550


Total Progress:  26%|██████▎                 | 66/250 [23:50<1:06:06, 21.56s/it]

Epoch [66/250], Loss: 0.0414908893


Total Progress:  27%|██████▍                 | 67/250 [24:12<1:05:45, 21.56s/it]

Epoch [67/250], Loss: 0.0484425498


Total Progress:  27%|██████▌                 | 68/250 [24:34<1:05:25, 21.57s/it]

Epoch [68/250], Loss: 0.0448474289


Total Progress:  28%|██████▌                 | 69/250 [24:55<1:05:02, 21.56s/it]

Epoch [69/250], Loss: 0.0447990107


Total Progress:  28%|██████▋                 | 70/250 [25:17<1:04:38, 21.55s/it]

Epoch [70/250], Loss: 0.0486968097


Total Progress:  28%|██████▊                 | 71/250 [25:38<1:04:19, 21.56s/it]

Epoch [71/250], Loss: 0.0473303130


Total Progress:  29%|██████▉                 | 72/250 [26:00<1:03:54, 21.54s/it]

Epoch [72/250], Loss: 0.0519951677


Total Progress:  29%|███████                 | 73/250 [26:21<1:03:33, 21.54s/it]

Epoch [73/250], Loss: 0.0565849792


Total Progress:  30%|███████                 | 74/250 [26:43<1:03:11, 21.54s/it]

Epoch [74/250], Loss: 0.0485855663


Total Progress:  30%|███████▏                | 75/250 [27:04<1:02:51, 21.55s/it]

Epoch [75/250], Loss: 0.0544015347


Total Progress:  30%|███████▎                | 76/250 [27:26<1:02:27, 21.54s/it]

Epoch [76/250], Loss: 0.0601769137


Total Progress:  31%|███████▍                | 77/250 [27:47<1:02:10, 21.56s/it]

Epoch [77/250], Loss: 0.0579112989


Total Progress:  31%|███████▍                | 78/250 [28:09<1:01:50, 21.57s/it]

Epoch [78/250], Loss: 0.0554262374


Total Progress:  32%|███████▌                | 79/250 [28:31<1:01:33, 21.60s/it]

Epoch [79/250], Loss: 0.0545162685


Total Progress:  32%|███████▋                | 80/250 [28:52<1:01:09, 21.59s/it]

Epoch [80/250], Loss: 0.0521945949


Total Progress:  32%|███████▊                | 81/250 [29:14<1:00:46, 21.58s/it]

Epoch [81/250], Loss: 0.0596694472


Total Progress:  33%|███████▊                | 82/250 [29:35<1:00:23, 21.57s/it]

Epoch [82/250], Loss: 0.0557373258


Total Progress:  33%|████████▋                 | 83/250 [29:57<59:58, 21.55s/it]

Epoch [83/250], Loss: 0.0490568506


Total Progress:  34%|████████▋                 | 84/250 [30:18<59:34, 21.53s/it]

Epoch [84/250], Loss: 0.0508629396


Total Progress:  34%|████████▊                 | 85/250 [30:40<59:13, 21.53s/it]

Epoch [85/250], Loss: 0.0482335371


Total Progress:  34%|████████▉                 | 86/250 [31:01<58:52, 21.54s/it]

Epoch [86/250], Loss: 0.0506684052


Total Progress:  35%|█████████                 | 87/250 [31:23<58:36, 21.58s/it]

Epoch [87/250], Loss: 0.0483389920


Total Progress:  35%|█████████▏                | 88/250 [31:45<58:19, 21.60s/it]

Epoch [88/250], Loss: 0.0476332053


Total Progress:  36%|█████████▎                | 89/250 [32:06<57:55, 21.59s/it]

Epoch [89/250], Loss: 0.0699291267


Total Progress:  36%|█████████▎                | 90/250 [32:28<57:36, 21.60s/it]

Epoch [90/250], Loss: 0.0517626470


Total Progress:  36%|█████████▍                | 91/250 [32:50<57:14, 21.60s/it]

Epoch [91/250], Loss: 0.0547307046


Total Progress:  37%|█████████▌                | 92/250 [33:11<56:50, 21.59s/it]

Epoch [92/250], Loss: 0.0509347496


Total Progress:  37%|█████████▋                | 93/250 [33:33<56:33, 21.62s/it]

Epoch [93/250], Loss: 0.0605109332


Total Progress:  38%|█████████▊                | 94/250 [33:54<56:11, 21.61s/it]

Epoch [94/250], Loss: 0.0516143027


Total Progress:  38%|█████████▉                | 95/250 [34:16<55:50, 21.61s/it]

Epoch [95/250], Loss: 0.0491977966


Total Progress:  38%|█████████▉                | 96/250 [34:38<55:25, 21.60s/it]

Epoch [96/250], Loss: 0.0508154394


Total Progress:  39%|██████████                | 97/250 [34:59<55:03, 21.59s/it]

Epoch [97/250], Loss: 0.0508716992


Total Progress:  39%|██████████▏               | 98/250 [35:21<54:42, 21.59s/it]

Epoch [98/250], Loss: 0.0610829381


Total Progress:  40%|██████████▎               | 99/250 [35:42<54:22, 21.61s/it]

Epoch [99/250], Loss: 0.0538726055


Total Progress:  40%|██████████               | 100/250 [36:04<53:58, 21.59s/it]

Epoch [100/250], Loss: 0.0495132396


Total Progress:  40%|██████████               | 101/250 [36:26<53:36, 21.58s/it]

Epoch [101/250], Loss: 0.0507516509


Total Progress:  41%|██████████▏              | 102/250 [36:47<53:11, 21.56s/it]

Epoch [102/250], Loss: 0.0494590354


Total Progress:  41%|██████████▎              | 103/250 [37:09<52:49, 21.56s/it]

Epoch [103/250], Loss: 0.0579591462


Total Progress:  42%|██████████▍              | 104/250 [37:30<52:33, 21.60s/it]

Epoch [104/250], Loss: 0.0619307467


Total Progress:  42%|██████████▌              | 105/250 [37:52<52:08, 21.58s/it]

Epoch [105/250], Loss: 0.0633499585


Total Progress:  42%|██████████▌              | 106/250 [38:13<51:46, 21.57s/it]

Epoch [106/250], Loss: 0.0575112125


Total Progress:  43%|██████████▋              | 107/250 [38:35<51:22, 21.56s/it]

Epoch [107/250], Loss: 0.0534905139


Total Progress:  43%|██████████▊              | 108/250 [38:57<51:04, 21.58s/it]

Epoch [108/250], Loss: 0.0539346563


Total Progress:  44%|██████████▉              | 109/250 [39:18<50:41, 21.57s/it]

Epoch [109/250], Loss: 0.0559779035


Total Progress:  44%|███████████              | 110/250 [39:40<50:18, 21.56s/it]

Epoch [110/250], Loss: 0.0505971324


Total Progress:  44%|███████████              | 111/250 [40:01<49:58, 21.57s/it]

Epoch [111/250], Loss: 0.0512902437


Total Progress:  45%|███████████▏             | 112/250 [40:23<49:36, 21.57s/it]

Epoch [112/250], Loss: 0.0536980645


Total Progress:  45%|███████████▎             | 113/250 [40:44<49:16, 21.58s/it]

Epoch [113/250], Loss: 0.0540788755


Total Progress:  46%|███████████▍             | 114/250 [41:06<48:57, 21.60s/it]

Epoch [114/250], Loss: 0.0515755015


Total Progress:  46%|███████████▌             | 115/250 [41:28<48:35, 21.59s/it]

Epoch [115/250], Loss: 0.0542094576


Total Progress:  46%|███████████▌             | 116/250 [41:49<48:12, 21.59s/it]

Epoch [116/250], Loss: 0.0575224566


Total Progress:  47%|███████████▋             | 117/250 [42:11<47:54, 21.61s/it]

Epoch [117/250], Loss: 0.0582239565


Total Progress:  47%|███████████▊             | 118/250 [42:32<47:31, 21.60s/it]

Epoch [118/250], Loss: 0.0609045658


Total Progress:  48%|███████████▉             | 119/250 [42:54<47:08, 21.59s/it]

Epoch [119/250], Loss: 0.0558330011


Total Progress:  48%|████████████             | 120/250 [43:16<46:45, 21.58s/it]

Epoch [120/250], Loss: 0.0533292121


Total Progress:  48%|████████████             | 121/250 [43:37<46:22, 21.57s/it]

Epoch [121/250], Loss: 0.0546265303


Total Progress:  49%|████████████▏            | 122/250 [43:59<45:58, 21.55s/it]

Epoch [122/250], Loss: 0.0557230120


Total Progress:  49%|████████████▎            | 123/250 [44:20<45:37, 21.56s/it]

Epoch [123/250], Loss: 0.0571082454


Total Progress:  50%|████████████▍            | 124/250 [44:42<45:14, 21.55s/it]

Epoch [124/250], Loss: 0.0584228022


Total Progress:  50%|████████████▌            | 125/250 [45:03<44:52, 21.54s/it]

Epoch [125/250], Loss: 0.0572718180


Total Progress:  50%|████████████▌            | 126/250 [45:25<44:30, 21.54s/it]

Epoch [126/250], Loss: 0.0572861141


Total Progress:  51%|████████████▋            | 127/250 [45:46<44:07, 21.52s/it]

Epoch [127/250], Loss: 0.0545188814


Total Progress:  51%|████████████▊            | 128/250 [46:08<43:44, 21.51s/it]

Epoch [128/250], Loss: 0.0508898344


Total Progress:  52%|████████████▉            | 129/250 [46:29<43:23, 21.51s/it]

Epoch [129/250], Loss: 0.0541002146


Total Progress:  52%|█████████████            | 130/250 [46:51<43:01, 21.51s/it]

Epoch [130/250], Loss: 0.0412195811


Total Progress:  52%|█████████████            | 131/250 [47:12<42:42, 21.53s/it]

Epoch [131/250], Loss: 0.0652247637


Total Progress:  53%|█████████████▏           | 132/250 [47:34<42:21, 21.53s/it]

Epoch [132/250], Loss: 0.0441883887


Total Progress:  53%|█████████████▎           | 133/250 [47:55<42:01, 21.55s/it]

Epoch [133/250], Loss: 0.0431376152


Total Progress:  54%|█████████████▍           | 134/250 [48:17<41:39, 21.55s/it]

Epoch [134/250], Loss: 0.0522586886


Total Progress:  54%|█████████████▌           | 135/250 [48:39<41:16, 21.53s/it]

Epoch [135/250], Loss: 0.0505364488


Total Progress:  54%|█████████████▌           | 136/250 [49:00<40:56, 21.55s/it]

Epoch [136/250], Loss: 0.0533349480


Total Progress:  55%|█████████████▋           | 137/250 [49:22<40:35, 21.55s/it]

Epoch [137/250], Loss: 0.0513292824


Total Progress:  55%|█████████████▊           | 138/250 [49:43<40:13, 21.55s/it]

Epoch [138/250], Loss: 0.0526221282


Total Progress:  56%|█████████████▉           | 139/250 [50:05<39:53, 21.56s/it]

Epoch [139/250], Loss: 0.0529849734


Total Progress:  56%|██████████████           | 140/250 [50:26<39:32, 21.57s/it]

Epoch [140/250], Loss: 0.0524284442


Total Progress:  56%|██████████████           | 141/250 [50:48<39:15, 21.61s/it]

Epoch [141/250], Loss: 0.0545291031


Total Progress:  57%|██████████████▏          | 142/250 [51:10<38:54, 21.62s/it]

Epoch [142/250], Loss: 0.0519647243


Total Progress:  57%|██████████████▎          | 143/250 [51:31<38:31, 21.60s/it]

Epoch [143/250], Loss: 0.0535557510


Total Progress:  58%|██████████████▍          | 144/250 [51:53<38:09, 21.60s/it]

Epoch [144/250], Loss: 0.0522674775


Total Progress:  58%|██████████████▍          | 145/250 [52:15<37:49, 21.61s/it]

Epoch [145/250], Loss: 0.0528055071


Total Progress:  58%|██████████████▌          | 146/250 [52:36<37:27, 21.61s/it]

Epoch [146/250], Loss: 0.0558307340


Total Progress:  59%|██████████████▋          | 147/250 [52:58<37:08, 21.64s/it]

Epoch [147/250], Loss: 0.0575311186


Total Progress:  59%|██████████████▊          | 148/250 [53:19<36:46, 21.63s/it]

Epoch [148/250], Loss: 0.0552614266


Total Progress:  60%|██████████████▉          | 149/250 [53:41<36:24, 21.63s/it]

Epoch [149/250], Loss: 0.0537039649


Total Progress:  60%|███████████████          | 150/250 [54:03<36:04, 21.65s/it]

Epoch [150/250], Loss: 0.0547555133


Total Progress:  60%|███████████████          | 151/250 [54:24<35:45, 21.67s/it]

Epoch [151/250], Loss: 0.0570455915


Total Progress:  61%|███████████████▏         | 152/250 [54:46<35:19, 21.63s/it]

Epoch [152/250], Loss: 0.0537707008


Total Progress:  61%|███████████████▎         | 153/250 [55:08<34:56, 21.62s/it]

Epoch [153/250], Loss: 0.0542801223


Total Progress:  62%|███████████████▍         | 154/250 [55:29<34:34, 21.61s/it]

Epoch [154/250], Loss: 0.0556074937


Total Progress:  62%|███████████████▌         | 155/250 [55:51<34:12, 21.60s/it]

Epoch [155/250], Loss: 0.0547824931


Total Progress:  62%|███████████████▌         | 156/250 [56:12<33:50, 21.60s/it]

Epoch [156/250], Loss: 0.0548753314


Total Progress:  63%|███████████████▋         | 157/250 [56:34<33:28, 21.60s/it]

Epoch [157/250], Loss: 0.0591665990


Total Progress:  63%|███████████████▊         | 158/250 [56:56<33:06, 21.60s/it]

Epoch [158/250], Loss: 0.0628868875


Total Progress:  64%|███████████████▉         | 159/250 [57:17<32:44, 21.59s/it]

Epoch [159/250], Loss: 0.0590287811


Total Progress:  64%|████████████████         | 160/250 [57:39<32:22, 21.58s/it]

Epoch [160/250], Loss: 0.0580507935


Total Progress:  64%|████████████████         | 161/250 [58:00<32:02, 21.60s/it]

Epoch [161/250], Loss: 0.0605737960


Total Progress:  65%|████████████████▏        | 162/250 [58:22<31:40, 21.60s/it]

Epoch [162/250], Loss: 0.0625193202


Total Progress:  65%|████████████████▎        | 163/250 [58:44<31:18, 21.60s/it]

Epoch [163/250], Loss: 0.0613209694


Total Progress:  66%|████████████████▍        | 164/250 [59:05<30:57, 21.60s/it]

Epoch [164/250], Loss: 0.0601561327


Total Progress:  66%|████████████████▌        | 165/250 [59:27<30:35, 21.59s/it]

Epoch [165/250], Loss: 0.0649064746


Total Progress:  66%|████████████████▌        | 166/250 [59:48<30:13, 21.59s/it]

Epoch [166/250], Loss: 0.1018927939


Total Progress:  67%|███████████████▎       | 167/250 [1:00:10<29:51, 21.59s/it]

Epoch [167/250], Loss: 0.0598344495


Total Progress:  67%|███████████████▍       | 168/250 [1:00:31<29:29, 21.58s/it]

Epoch [168/250], Loss: 0.0513066990


Total Progress:  68%|███████████████▌       | 169/250 [1:00:53<29:07, 21.57s/it]

Epoch [169/250], Loss: 0.0507483098


Total Progress:  68%|███████████████▋       | 170/250 [1:01:15<28:45, 21.57s/it]

Epoch [170/250], Loss: 0.0539443396


Total Progress:  68%|███████████████▋       | 171/250 [1:01:36<28:24, 21.57s/it]

Epoch [171/250], Loss: 0.0502195390


Total Progress:  69%|███████████████▊       | 172/250 [1:01:58<28:01, 21.56s/it]

Epoch [172/250], Loss: 0.0608667605


Total Progress:  69%|███████████████▉       | 173/250 [1:02:19<27:40, 21.56s/it]

Epoch [173/250], Loss: 0.0551532050


Total Progress:  70%|████████████████       | 174/250 [1:02:41<27:19, 21.57s/it]

Epoch [174/250], Loss: 0.0490934002


Total Progress:  70%|████████████████       | 175/250 [1:03:02<26:58, 21.58s/it]

Epoch [175/250], Loss: 0.0522942143


Total Progress:  70%|████████████████▏      | 176/250 [1:03:24<26:41, 21.64s/it]

Epoch [176/250], Loss: 0.0482817823


Total Progress:  71%|████████████████▎      | 177/250 [1:03:46<26:18, 21.63s/it]

Epoch [177/250], Loss: 0.0612721747


Total Progress:  71%|████████████████▍      | 178/250 [1:04:07<25:56, 21.62s/it]

Epoch [178/250], Loss: 0.0558690112


Total Progress:  72%|████████████████▍      | 179/250 [1:04:29<25:34, 21.61s/it]

Epoch [179/250], Loss: 0.0579194305


Total Progress:  72%|████████████████▌      | 180/250 [1:04:51<25:12, 21.61s/it]

Epoch [180/250], Loss: 0.0570056405


Total Progress:  72%|████████████████▋      | 181/250 [1:05:12<24:50, 21.60s/it]

Epoch [181/250], Loss: 0.0583136839


Total Progress:  73%|████████████████▋      | 182/250 [1:05:34<24:28, 21.60s/it]

Epoch [182/250], Loss: 0.0586587379


Total Progress:  73%|████████████████▊      | 183/250 [1:05:55<24:07, 21.60s/it]

Epoch [183/250], Loss: 0.0579041971


Total Progress:  74%|████████████████▉      | 184/250 [1:06:17<23:45, 21.59s/it]

Epoch [184/250], Loss: 0.0656376066


Total Progress:  74%|█████████████████      | 185/250 [1:06:39<23:23, 21.59s/it]

Epoch [185/250], Loss: 0.0571670211


Total Progress:  74%|█████████████████      | 186/250 [1:07:00<23:02, 21.61s/it]

Epoch [186/250], Loss: 0.0582735744


Total Progress:  75%|█████████████████▏     | 187/250 [1:07:22<22:41, 21.61s/it]

Epoch [187/250], Loss: 0.0655491939


Total Progress:  75%|█████████████████▎     | 188/250 [1:07:43<22:18, 21.59s/it]

Epoch [188/250], Loss: 0.0620825674


Total Progress:  76%|█████████████████▍     | 189/250 [1:08:05<21:56, 21.59s/it]

Epoch [189/250], Loss: 0.0586000874


Total Progress:  76%|█████████████████▍     | 190/250 [1:08:27<21:35, 21.59s/it]

Epoch [190/250], Loss: 0.0657742855


Total Progress:  76%|█████████████████▌     | 191/250 [1:08:48<21:14, 21.60s/it]

Epoch [191/250], Loss: 0.0750390873


Total Progress:  77%|█████████████████▋     | 192/250 [1:09:10<20:52, 21.59s/it]

Epoch [192/250], Loss: 0.0673576223


Total Progress:  77%|█████████████████▊     | 193/250 [1:09:31<20:30, 21.59s/it]

Epoch [193/250], Loss: 0.0658944929


Total Progress:  78%|█████████████████▊     | 194/250 [1:09:53<20:09, 21.59s/it]

Epoch [194/250], Loss: 0.0734964412


Total Progress:  78%|█████████████████▉     | 195/250 [1:10:15<19:49, 21.62s/it]

Epoch [195/250], Loss: 0.0748960326


Total Progress:  78%|██████████████████     | 196/250 [1:10:36<19:27, 21.62s/it]

Epoch [196/250], Loss: 0.0714899012


Total Progress:  79%|██████████████████     | 197/250 [1:10:58<19:05, 21.61s/it]

Epoch [197/250], Loss: 0.0720562016


Total Progress:  79%|██████████████████▏    | 198/250 [1:11:19<18:43, 21.60s/it]

Epoch [198/250], Loss: 0.0751406649


Total Progress:  80%|██████████████████▎    | 199/250 [1:11:41<18:21, 21.60s/it]

Epoch [199/250], Loss: 0.0717874061


Total Progress:  80%|██████████████████▍    | 200/250 [1:12:03<17:59, 21.60s/it]

Epoch [200/250], Loss: 0.0916231498


Total Progress:  80%|██████████████████▍    | 201/250 [1:12:24<17:38, 21.60s/it]

Epoch [201/250], Loss: 0.0622884185


Total Progress:  81%|██████████████████▌    | 202/250 [1:12:46<17:18, 21.63s/it]

Epoch [202/250], Loss: 0.0659649181


Total Progress:  81%|██████████████████▋    | 203/250 [1:13:08<16:56, 21.63s/it]

Epoch [203/250], Loss: 0.0641530478


Total Progress:  82%|██████████████████▊    | 204/250 [1:13:29<16:34, 21.62s/it]

Epoch [204/250], Loss: 0.0639706012


Total Progress:  82%|██████████████████▊    | 205/250 [1:13:51<16:12, 21.61s/it]

Epoch [205/250], Loss: 0.0628574616


Total Progress:  82%|██████████████████▉    | 206/250 [1:14:12<15:50, 21.61s/it]

Epoch [206/250], Loss: 0.0660056574


Total Progress:  83%|███████████████████    | 207/250 [1:14:34<15:29, 21.61s/it]

Epoch [207/250], Loss: 0.0638983938


Total Progress:  83%|███████████████████▏   | 208/250 [1:14:55<15:07, 21.60s/it]

Epoch [208/250], Loss: 0.0719793886


Total Progress:  84%|███████████████████▏   | 209/250 [1:15:17<14:44, 21.57s/it]

Epoch [209/250], Loss: 0.0690110041


Total Progress:  84%|███████████████████▎   | 210/250 [1:15:39<14:23, 21.58s/it]

Epoch [210/250], Loss: 0.0736839549


Total Progress:  84%|███████████████████▍   | 211/250 [1:16:00<14:01, 21.58s/it]

Epoch [211/250], Loss: 0.0689672951


Total Progress:  85%|███████████████████▌   | 212/250 [1:16:22<13:39, 21.56s/it]

Epoch [212/250], Loss: 0.0700898634


Total Progress:  85%|███████████████████▌   | 213/250 [1:16:43<13:17, 21.56s/it]

Epoch [213/250], Loss: 0.0619219397


Total Progress:  86%|███████████████████▋   | 214/250 [1:17:05<12:55, 21.55s/it]

Epoch [214/250], Loss: 0.0642133430


Total Progress:  86%|███████████████████▊   | 215/250 [1:17:26<12:33, 21.53s/it]

Epoch [215/250], Loss: 0.0620206902


Total Progress:  86%|███████████████████▊   | 216/250 [1:17:48<12:13, 21.57s/it]

Epoch [216/250], Loss: 0.0637572261


Total Progress:  87%|███████████████████▉   | 217/250 [1:18:09<11:51, 21.56s/it]

Epoch [217/250], Loss: 0.0702093249


Total Progress:  87%|████████████████████   | 218/250 [1:18:31<11:30, 21.58s/it]

Epoch [218/250], Loss: 0.0629093903


Total Progress:  88%|████████████████████▏  | 219/250 [1:18:53<11:09, 21.59s/it]

Epoch [219/250], Loss: 0.0799373467


Total Progress:  88%|████████████████████▏  | 220/250 [1:19:14<10:47, 21.58s/it]

Epoch [220/250], Loss: 0.0623684151


Total Progress:  88%|████████████████████▎  | 221/250 [1:19:36<10:26, 21.60s/it]

Epoch [221/250], Loss: 0.0557920483


Total Progress:  89%|████████████████████▍  | 222/250 [1:19:58<10:06, 21.67s/it]

Epoch [222/250], Loss: 0.0582437287


Total Progress:  89%|████████████████████▌  | 223/250 [1:20:19<09:44, 21.66s/it]

Epoch [223/250], Loss: 0.0633760266


Total Progress:  90%|████████████████████▌  | 224/250 [1:20:41<09:23, 21.67s/it]

Epoch [224/250], Loss: 0.0655234977


## Load saved model checkpoint

In [None]:
# Reload the weights stored in model.pth into the already-constructed `model`.
# map_location=device lets a checkpoint written on a different device (e.g. GPU)
# be loaded on whatever device this session uses.
checkpoint = torch.load('model.pth', map_location=device)
model.load_state_dict(checkpoint['model'])

## Prediction

In [None]:
# Output folder for the predicted masks; exist_ok=True makes this cell safe to re-run
# (the shell `mkdir` errors when the folder already exists).
os.makedirs("prediction", exist_ok=True)

In [None]:
# Run the trained model on every test image and write a colour-coded mask with the
# same file name into ./prediction. `val_transformation`, `mask_to_rgb`, `color_dict`
# and `device` come from earlier cells (not visible here).
model.eval()
os.makedirs("prediction", exist_ok=True)  # keep this cell self-sufficient and re-runnable
for file_name in sorted(os.listdir("data/test/test")):
    img_path = os.path.join("data/test/test", file_name)
    ori_img = cv2.imread(img_path)
    if ori_img is None:
        # cv2.imread returns None for directories / undecodable files; skip them visibly.
        print(f"skipping unreadable file: {img_path}")
        continue
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    # numpy image shape is (height, width, channels)
    ori_h, ori_w = ori_img.shape[:2]
    img = cv2.resize(ori_img, (256, 256))
    input_img = val_transformation(image=img)["image"].unsqueeze(0).to(device)
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks run. Assumes the output
        # is (1, C, H, W) per-class scores — TODO confirm against the model definition.
        output_mask = model(input_img).squeeze(0).cpu().numpy().transpose(1, 2, 0)
    # cv2.resize takes dsize as (width, height)
    mask = cv2.resize(output_mask, (ori_w, ori_h))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, color_dict)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    # NOTE(review): the mask keeps the input's extension; if that is .jpeg the mask is
    # stored lossily — the 225 threshold in rle_encode_one_mask may be compensating.
    cv2.imwrite(os.path.join("prediction", file_name), mask_rgb)

In [None]:
def rle_to_string(runs):
    """Join run-length values into the space-separated submission format.

    Args:
        runs: Iterable of integers (alternating 1-based start, run length).

    Returns:
        str: e.g. ``[2, 2, 6, 1] -> "2 2 6 1"``; an empty iterable gives ``""``.
    """
    return ' '.join(map(str, runs))


def rle_encode_one_mask(mask, threshold=225):
    """Run-length encode a single-channel mask.

    Pixels strictly greater than ``threshold`` are foreground, all others background.
    The mask is flattened in row-major (C) order, and foreground runs are emitted as
    1-based ``start length`` pairs.

    Args:
        mask: Array-like of pixel values (e.g. one channel of a uint8 image).
            It is not modified.
        threshold: Foreground cut-off. The default of 225 matches the previously
            hard-coded value (presumably chosen to tolerate compression artefacts in
            255-valued masks — TODO confirm).

    Returns:
        str: Space-separated RLE string, ``""`` for an empty or all-background mask.
    """
    foreground = np.asarray(mask).ravel() > threshold
    # Pad with background on both ends so every run has a rising and a falling edge,
    # including runs that touch the first or last pixel.
    padded = np.concatenate(([False], foreground, [False]))
    # A change between padded[k] and padded[k+1] is a run boundary at 0-based pixel k
    # of the unpadded mask, i.e. 1-based position k + 1.
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every second boundary (exclusive run end) into a run length.
    runs[1::2] -= runs[::2]
    return rle_to_string(runs)

def rle2mask(mask_rle, shape=(3,3)):
    """Decode a space-separated RLE string into a binary uint8 mask.

    Tokens alternate ``start length`` with 1-based starts. Runs are painted with 1s
    onto a flat buffer of ``shape[0] * shape[1]`` zeros. The buffer is reshaped to
    ``shape`` and then transposed, so the result has shape ``(shape[1], shape[0])``.

    NOTE(review): because of the final transpose, decoding a string produced by
    ``rle_encode_one_mask`` (which flattens row-major) with ``shape=mask.shape``
    gives the transposed binarised mask. Confirm which layout the consumer expects.
    """
    tokens = mask_rle.split()
    zero_based_starts = np.asarray(tokens[::2], dtype=int) - 1
    stops = zero_based_starts + np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(zero_based_starts, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape).T

def mask2string(dir, num_channels=2, verbose=False):
    """Run-length encode every predicted mask image in a directory.

    Each file is read with OpenCV and converted from BGR to RGB. Its first
    ``num_channels`` RGB channels are then encoded separately with
    ``rle_encode_one_mask``. Row ids have the form ``<file stem>_<channel index>``.

    Args:
        dir: Directory containing the mask images. The name shadows the builtin but
            is kept for backward compatibility with keyword callers.
        num_channels: Number of leading RGB channels to encode. The default of 2
            (R and G) matches the previous hard-coded behaviour.
        verbose: If True, print each path as it is processed.

    Returns:
        dict: ``{'ids': [...], 'strings': [...]}``, two parallel lists.

    Raises:
        ValueError: If a file in ``dir`` cannot be decoded as an image.
    """
    ids = []
    strings = []
    # sorted() makes the output order deterministic (os.listdir order is arbitrary).
    for file_name in sorted(os.listdir(dir)):
        path = os.path.join(dir, file_name)
        if not os.path.isfile(path):
            # skip sub-directories (e.g. notebook checkpoint folders)
            continue
        if verbose:
            print(path)
        bgr = cv2.imread(path)
        if bgr is None:
            raise ValueError(f"could not read mask image: {path}")
        rgb = bgr[:, :, ::-1]
        stem = os.path.splitext(file_name)[0]
        for channel in range(num_channels):
            ids.append(f'{stem}_{channel}')
            strings.append(rle_encode_one_mask(rgb[:, :, channel]))
    return {
        'ids': ids,
        'strings': strings,
    }


# Encode every predicted mask and write the submission CSV.
# The prediction cell writes masks to the *relative* folder "prediction", so read
# from there; the previous '/prediction' pointed at the filesystem root.
MASK_DIR_PATH = 'prediction'
res = mask2string(MASK_DIR_PATH)
assert res['ids'], f"no masks found in {MASK_DIR_PATH!r}"
df = pd.DataFrame({'Id': res['ids'], 'Expected': res['strings']})
df.to_csv('output.csv', index=False)