In [1]:
import os
import pandas as pd
import numpy as np
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torchvision.transforms import ToTensor
from PIL import Image
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torchvision import transforms
from torchinfo import summary
import timm
import segmentation_models_pytorch as smp
import wandb

from torchgeometry.losses import DiceLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3050 Ti Laptop GPU (UUID: GPU-922977af-d68d-ebac-3c8f-b0b483c19424)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class DatasetCustom(Dataset):
    def __init__(self, img_dir, label_dir, resize=None, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.resize = resize
        self.transform = transform
        self.images = os.listdir(self.img_dir)

    def __len__(self):
        return len(self.images)
    
    def read_mask(self, mask_path):
        image = cv2.imread(mask_path)
        if self.resize:
            image = cv2.resize(image, self.resize)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        lower_red1 = np.array([0, 100, 20])
        upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160,100,20])
        upper_red2 = np.array([179,255,255])
        
        lower_mask_red = cv2.inRange(image, lower_red1, upper_red1)
        upper_mask_red = cv2.inRange(image, lower_red2, upper_red2)
        
        red_mask = lower_mask_red + upper_mask_red
        red_mask[red_mask != 0] = 1

        green_mask = cv2.inRange(image, (36, 25, 25), (70, 255, 255))
        green_mask[green_mask != 0] = 2

        full_mask = cv2.bitwise_or(red_mask, green_mask)
        full_mask = np.expand_dims(full_mask, axis=-1) 
        full_mask = full_mask.astype(np.uint8)
        
        return full_mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.images[idx])
        image = cv2.imread(img_path)  #  BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to RGB
        label = self.read_mask(label_path)  
        image = cv2.resize(image, self.resize)
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [5]:
images_path = "../data/train/train/"
image_path = []
TRAIN_DIR = '../data/train/train'
for root, dirs, files in os.walk(TRAIN_DIR):
    for file in files:
        path = os.path.join(root,file)
        image_path.append(path)
        
len(image_path)

1000

In [6]:
mask_path = []
TRAIN_MASK_DIR = '../data/train_gt/train_gt'
for root, dirs, files in os.walk(TRAIN_MASK_DIR):
    for file in files:
        path = os.path.join(root,file)
        mask_path.append(path)
        
len(mask_path)

1000

In [7]:
dataset = DatasetCustom(img_dir= TRAIN_DIR,
                             label_dir= TRAIN_MASK_DIR,
                             resize= (256,256),
                             transform = None)

In [8]:
batch_size = 8
images_data = []
labels_data = []
for x,y in dataset:
    images_data.append(x)
    labels_data.append(y)

In [9]:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=3     
)

In [10]:
class CustomDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        assert image.shape[:2] == label.shape[:2]
        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image = transformed['image'].float()
            label = transformed['mask'].float()
            label = label.permute(2, 0, 1)
        return image, label
    
    def __len__(self):
        return len(self.data)

In [11]:
train_transformation = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomCrop(height=256, width=256, p=0.5),
    A.RandomGamma(gamma_limit=(70, 130), always_apply=False, p=0.2),
    A.RGBShift(p=0.3, r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

val_transformation = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

In [12]:
train_size = int(0.8 * len(images_data))
val_size = len(images_data) - train_size
train_dataset = CustomDataset(images_data[:train_size], labels_data[:train_size], transform=train_transformation)
val_dataset = CustomDataset(images_data[train_size:], labels_data[train_size:], transform=val_transformation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [13]:
learning_rate = 0.0001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
color_dict= {0: (0, 0, 0),
             1: (255, 0, 0),
             2: (0, 255, 0)}
def mask_to_rgb(mask, color_dict):
    output = np.zeros((mask.shape[0], mask.shape[1], 3))

    for k in color_dict.keys():
        output[mask==k] = color_dict[k]

    return np.uint8(output)    

In [15]:
wandb.init(
    project = "DlPolypSegment",
    name = "Unet"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtuantruongvu[0m ([33mtuantruongvu-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
from tqdm import tqdm
import time

num_epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)
criterion1 = nn.CrossEntropyLoss()
criterion2 = DiceLoss()
best_val_loss = 999

epoch_bar = tqdm(total=num_epochs, desc='Total Progress')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        labels = labels.squeeze(dim=1).long()
        outputs = model(images)
    
        loss1 = criterion1(outputs, labels)
        loss2 = criterion2(outputs, labels)
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            labels = labels.squeeze(dim=1).long()
            
            outputs = model(images)

            val_loss1 = criterion1(outputs,labels)
            val_loss2 = criterion2(outputs,labels)
            val_loss += val_loss1 + val_loss2

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_loader):.10f} , Val Loss: {val_loss/len(val_loader):.10f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = { 
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': val_loss,
        }
        save_path = f'model.pth'
        torch.save(checkpoint, save_path)
        
    epoch_bar.update(1)
    wandb.log({'Val_loss': val_loss/len(val_loader),'Train_loss': train_loss/len(train_loader)})
epoch_bar.close()
wandb.finish()

Total Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/100], Train Loss: 1.4156233174 , Val Loss: 0.8349348307


Total Progress:   1%|          | 1/100 [00:17<28:41, 17.39s/it]

Epoch [2/100], Train Loss: 0.6414102826 , Val Loss: 0.4665164948


Total Progress:   2%|▏         | 2/100 [00:32<26:21, 16.14s/it]

Epoch [3/100], Train Loss: 0.3706114805 , Val Loss: 0.2850464284


Total Progress:   3%|▎         | 3/100 [00:48<25:38, 15.86s/it]

Epoch [4/100], Train Loss: 0.2419301184 , Val Loss: 0.2020360678


Total Progress:   4%|▍         | 4/100 [01:03<24:55, 15.58s/it]

Epoch [5/100], Train Loss: 0.1824513938 , Val Loss: 0.1539010853


Total Progress:   5%|▌         | 5/100 [01:18<24:26, 15.43s/it]

Epoch [6/100], Train Loss: 0.1447744396 , Val Loss: 0.1369906664


Total Progress:   6%|▌         | 6/100 [01:33<24:05, 15.37s/it]

Epoch [7/100], Train Loss: 0.1229066705 , Val Loss: 0.1311733276


Total Progress:   7%|▋         | 7/100 [01:48<23:42, 15.30s/it]

Epoch [8/100], Train Loss: 0.1106884025 , Val Loss: 0.1112009957


Total Progress:   8%|▊         | 8/100 [02:04<23:22, 15.25s/it]

Epoch [9/100], Train Loss: 0.0929467160 , Val Loss: 0.1014075354


Total Progress:   9%|▉         | 9/100 [02:19<23:05, 15.22s/it]

Epoch [10/100], Train Loss: 0.0913143014 , Val Loss: 0.0900473967


Total Progress:  11%|█         | 11/100 [02:49<22:20, 15.07s/it]

Epoch [11/100], Train Loss: 0.0764397995 , Val Loss: 0.0918737128
Epoch [12/100], Train Loss: 0.0739465679 , Val Loss: 0.0863076672


Total Progress:  12%|█▏        | 12/100 [03:04<22:07, 15.09s/it]

Epoch [13/100], Train Loss: 0.0709374840 , Val Loss: 0.0848543793


Total Progress:  13%|█▎        | 13/100 [03:19<21:57, 15.15s/it]

Epoch [14/100], Train Loss: 0.0661234054 , Val Loss: 0.0767937899


Total Progress:  15%|█▌        | 15/100 [03:49<21:21, 15.07s/it]

Epoch [15/100], Train Loss: 0.0621520465 , Val Loss: 0.0882717967
Epoch [16/100], Train Loss: 0.0599807185 , Val Loss: 0.0754979998


Total Progress:  17%|█▋        | 17/100 [04:19<20:46, 15.01s/it]

Epoch [17/100], Train Loss: 0.0526146760 , Val Loss: 0.0895142332


Total Progress:  18%|█▊        | 18/100 [04:34<20:26, 14.96s/it]

Epoch [18/100], Train Loss: 0.0539667884 , Val Loss: 0.0848251581


Total Progress:  19%|█▉        | 19/100 [04:49<20:08, 14.92s/it]

Epoch [19/100], Train Loss: 0.0528663885 , Val Loss: 0.0827867836
Epoch [20/100], Train Loss: 0.0469985294 , Val Loss: 0.0718805417


Total Progress:  21%|██        | 21/100 [05:19<19:42, 14.96s/it]

Epoch [21/100], Train Loss: 0.0475697471 , Val Loss: 0.0767958015


Total Progress:  22%|██▏       | 22/100 [05:34<19:24, 14.93s/it]

Epoch [22/100], Train Loss: 0.0465752776 , Val Loss: 0.0732364133


Total Progress:  23%|██▎       | 23/100 [05:49<19:08, 14.91s/it]

Epoch [23/100], Train Loss: 0.0416541798 , Val Loss: 0.0723756328


Total Progress:  24%|██▍       | 24/100 [06:03<18:51, 14.89s/it]

Epoch [24/100], Train Loss: 0.0435015935 , Val Loss: 0.0850890502


Total Progress:  25%|██▌       | 25/100 [06:18<18:36, 14.89s/it]

Epoch [25/100], Train Loss: 0.0468000332 , Val Loss: 0.0779942349
Epoch [26/100], Train Loss: 0.0479917901 , Val Loss: 0.0624882840


Total Progress:  27%|██▋       | 27/100 [06:48<18:09, 14.93s/it]

Epoch [27/100], Train Loss: 0.0466873510 , Val Loss: 0.0705930144


Total Progress:  28%|██▊       | 28/100 [07:03<17:54, 14.92s/it]

Epoch [28/100], Train Loss: 0.0439725461 , Val Loss: 0.0716617927


Total Progress:  29%|██▉       | 29/100 [07:18<17:38, 14.90s/it]

Epoch [29/100], Train Loss: 0.0361370283 , Val Loss: 0.0637115538


Total Progress:  30%|███       | 30/100 [07:33<17:22, 14.89s/it]

Epoch [30/100], Train Loss: 0.0355657116 , Val Loss: 0.0658224747
Epoch [31/100], Train Loss: 0.0361338698 , Val Loss: 0.0619552992


Total Progress:  32%|███▏      | 32/100 [08:03<16:56, 14.95s/it]

Epoch [32/100], Train Loss: 0.0383722734 , Val Loss: 0.0751283988


Total Progress:  33%|███▎      | 33/100 [08:18<16:39, 14.92s/it]

Epoch [33/100], Train Loss: 0.0322013484 , Val Loss: 0.0698330924


Total Progress:  34%|███▍      | 34/100 [08:33<16:24, 14.91s/it]

Epoch [34/100], Train Loss: 0.0349134495 , Val Loss: 0.0620221980


Total Progress:  35%|███▌      | 35/100 [08:48<16:08, 14.91s/it]

Epoch [35/100], Train Loss: 0.0363568546 , Val Loss: 0.0799955800
Epoch [36/100], Train Loss: 0.0390383301 , Val Loss: 0.0592526980


Total Progress:  37%|███▋      | 37/100 [09:18<15:41, 14.95s/it]

Epoch [37/100], Train Loss: 0.0317286856 , Val Loss: 0.0857253596
Epoch [38/100], Train Loss: 0.0327192588 , Val Loss: 0.0586996786


Total Progress:  39%|███▉      | 39/100 [09:48<15:14, 14.99s/it]

Epoch [39/100], Train Loss: 0.0295238936 , Val Loss: 0.0638316870


Total Progress:  40%|████      | 40/100 [10:03<14:57, 14.95s/it]

Epoch [40/100], Train Loss: 0.0334192177 , Val Loss: 0.0745705813


Total Progress:  41%|████      | 41/100 [10:18<14:39, 14.90s/it]

Epoch [41/100], Train Loss: 0.0351011133 , Val Loss: 0.0623549670


Total Progress:  42%|████▏     | 42/100 [10:32<14:23, 14.90s/it]

Epoch [42/100], Train Loss: 0.0291813593 , Val Loss: 0.0743219927


Total Progress:  43%|████▎     | 43/100 [10:47<14:07, 14.86s/it]

Epoch [43/100], Train Loss: 0.0314576567 , Val Loss: 0.0774240866


Total Progress:  44%|████▍     | 44/100 [11:02<13:52, 14.86s/it]

Epoch [44/100], Train Loss: 0.0267499324 , Val Loss: 0.0699066147


Total Progress:  45%|████▌     | 45/100 [11:17<13:36, 14.85s/it]

Epoch [45/100], Train Loss: 0.0266517891 , Val Loss: 0.0699161887


Total Progress:  46%|████▌     | 46/100 [11:32<13:21, 14.83s/it]

Epoch [46/100], Train Loss: 0.0271609612 , Val Loss: 0.0868144333


Total Progress:  47%|████▋     | 47/100 [11:46<13:05, 14.83s/it]

Epoch [47/100], Train Loss: 0.0296038495 , Val Loss: 0.0619847029


Total Progress:  48%|████▊     | 48/100 [12:01<12:50, 14.82s/it]

Epoch [48/100], Train Loss: 0.0279015356 , Val Loss: 0.0773081481


Total Progress:  49%|████▉     | 49/100 [12:16<12:36, 14.83s/it]

Epoch [49/100], Train Loss: 0.0276191126 , Val Loss: 0.0667604655


Total Progress:  50%|█████     | 50/100 [12:31<12:21, 14.83s/it]

Epoch [50/100], Train Loss: 0.0259080781 , Val Loss: 0.0639141947


Total Progress:  51%|█████     | 51/100 [12:46<12:05, 14.80s/it]

Epoch [51/100], Train Loss: 0.0294444706 , Val Loss: 0.0679208115


Total Progress:  52%|█████▏    | 52/100 [13:00<11:50, 14.80s/it]

Epoch [52/100], Train Loss: 0.0274439706 , Val Loss: 0.0692981109


Total Progress:  53%|█████▎    | 53/100 [13:15<11:35, 14.79s/it]

Epoch [53/100], Train Loss: 0.0249999603 , Val Loss: 0.0775171369


Total Progress:  54%|█████▍    | 54/100 [13:30<11:20, 14.80s/it]

Epoch [54/100], Train Loss: 0.0318184723 , Val Loss: 0.0648816302


Total Progress:  55%|█████▌    | 55/100 [13:45<11:05, 14.79s/it]

Epoch [55/100], Train Loss: 0.0273627751 , Val Loss: 0.0796312764


Total Progress:  56%|█████▌    | 56/100 [14:00<10:50, 14.79s/it]

Epoch [56/100], Train Loss: 0.0260301823 , Val Loss: 0.0707893074


Total Progress:  57%|█████▋    | 57/100 [14:14<10:35, 14.78s/it]

Epoch [57/100], Train Loss: 0.0247360885 , Val Loss: 0.0631930456


Total Progress:  58%|█████▊    | 58/100 [14:29<10:20, 14.76s/it]

Epoch [58/100], Train Loss: 0.0219485257 , Val Loss: 0.0660702810


Total Progress:  59%|█████▉    | 59/100 [14:44<10:05, 14.76s/it]

Epoch [59/100], Train Loss: 0.0237540491 , Val Loss: 0.0696342885


Total Progress:  60%|██████    | 60/100 [14:59<09:50, 14.75s/it]

Epoch [60/100], Train Loss: 0.0221062770 , Val Loss: 0.0635866374


Total Progress:  61%|██████    | 61/100 [15:13<09:35, 14.76s/it]

Epoch [61/100], Train Loss: 0.0192654991 , Val Loss: 0.0676189139


Total Progress:  62%|██████▏   | 62/100 [15:28<09:20, 14.76s/it]

Epoch [62/100], Train Loss: 0.0185887625 , Val Loss: 0.0698819831


Total Progress:  63%|██████▎   | 63/100 [15:43<09:05, 14.75s/it]

Epoch [63/100], Train Loss: 0.0219854944 , Val Loss: 0.0692262948


Total Progress:  64%|██████▍   | 64/100 [15:58<08:50, 14.74s/it]

Epoch [64/100], Train Loss: 0.0199484578 , Val Loss: 0.0813274235


Total Progress:  65%|██████▌   | 65/100 [16:12<08:36, 14.74s/it]

Epoch [65/100], Train Loss: 0.0204077258 , Val Loss: 0.0710155889


Total Progress:  66%|██████▌   | 66/100 [16:27<08:20, 14.73s/it]

Epoch [66/100], Train Loss: 0.0214376932 , Val Loss: 0.0713761970


Total Progress:  67%|██████▋   | 67/100 [16:42<08:05, 14.71s/it]

Epoch [67/100], Train Loss: 0.0185741850 , Val Loss: 0.0719060451


Total Progress:  68%|██████▊   | 68/100 [16:56<07:49, 14.68s/it]

Epoch [68/100], Train Loss: 0.0217450258 , Val Loss: 0.0827539787


Total Progress:  69%|██████▉   | 69/100 [17:11<07:35, 14.68s/it]

Epoch [69/100], Train Loss: 0.0378542634 , Val Loss: 0.0735819489


Total Progress:  70%|███████   | 70/100 [17:26<07:20, 14.69s/it]

Epoch [70/100], Train Loss: 0.0284983214 , Val Loss: 0.0588091500


Total Progress:  71%|███████   | 71/100 [17:41<07:06, 14.72s/it]

Epoch [71/100], Train Loss: 0.0306257652 , Val Loss: 0.0755160972


Total Progress:  72%|███████▏  | 72/100 [17:56<06:54, 14.80s/it]

Epoch [72/100], Train Loss: 0.0261042247 , Val Loss: 0.0657979771


Total Progress:  73%|███████▎  | 73/100 [18:10<06:40, 14.85s/it]

Epoch [73/100], Train Loss: 0.0204380774 , Val Loss: 0.0606478266


Total Progress:  74%|███████▍  | 74/100 [18:25<06:26, 14.85s/it]

Epoch [74/100], Train Loss: 0.0186301383 , Val Loss: 0.0702661201


Total Progress:  75%|███████▌  | 75/100 [18:40<06:11, 14.86s/it]

Epoch [75/100], Train Loss: 0.0160367880 , Val Loss: 0.0681894273


Total Progress:  76%|███████▌  | 76/100 [18:55<05:57, 14.88s/it]

Epoch [76/100], Train Loss: 0.0148060307 , Val Loss: 0.0679971427


Total Progress:  77%|███████▋  | 77/100 [19:10<05:42, 14.91s/it]

Epoch [77/100], Train Loss: 0.0353821884 , Val Loss: 0.0690390393


Total Progress:  78%|███████▊  | 78/100 [19:25<05:27, 14.91s/it]

Epoch [78/100], Train Loss: 0.0228020244 , Val Loss: 0.0671818331


Total Progress:  79%|███████▉  | 79/100 [19:40<05:12, 14.89s/it]

Epoch [79/100], Train Loss: 0.0192677474 , Val Loss: 0.0589426793


Total Progress:  80%|████████  | 80/100 [19:55<04:57, 14.89s/it]

Epoch [80/100], Train Loss: 0.0165137359 , Val Loss: 0.0671133250


Total Progress:  81%|████████  | 81/100 [20:10<04:43, 14.93s/it]

Epoch [81/100], Train Loss: 0.0179575439 , Val Loss: 0.0629432946


Total Progress:  82%|████████▏ | 82/100 [20:25<04:28, 14.92s/it]

Epoch [82/100], Train Loss: 0.0188978815 , Val Loss: 0.0698162839


Total Progress:  83%|████████▎ | 83/100 [20:40<04:13, 14.92s/it]

Epoch [83/100], Train Loss: 0.0196648791 , Val Loss: 0.0705471337


Total Progress:  84%|████████▍ | 84/100 [20:54<03:58, 14.90s/it]

Epoch [84/100], Train Loss: 0.0256054749 , Val Loss: 0.0832020715


Total Progress:  85%|████████▌ | 85/100 [21:09<03:43, 14.92s/it]

Epoch [85/100], Train Loss: 0.0297114631 , Val Loss: 0.0723200142


Total Progress:  86%|████████▌ | 86/100 [21:24<03:28, 14.92s/it]

Epoch [86/100], Train Loss: 0.0194494261 , Val Loss: 0.0669098273


Total Progress:  87%|████████▋ | 87/100 [21:39<03:13, 14.92s/it]

Epoch [87/100], Train Loss: 0.0166434452 , Val Loss: 0.0697203204


Total Progress:  88%|████████▊ | 88/100 [21:54<02:59, 14.95s/it]

Epoch [88/100], Train Loss: 0.0221981785 , Val Loss: 0.0950733349


Total Progress:  89%|████████▉ | 89/100 [22:09<02:44, 14.95s/it]

Epoch [89/100], Train Loss: 0.0210334626 , Val Loss: 0.0875064805


Total Progress:  90%|█████████ | 90/100 [22:24<02:29, 14.95s/it]

Epoch [90/100], Train Loss: 0.0163666967 , Val Loss: 0.0834618658


Total Progress:  91%|█████████ | 91/100 [22:39<02:14, 14.94s/it]

Epoch [91/100], Train Loss: 0.0234630830 , Val Loss: 0.0725720897


Total Progress:  92%|█████████▏| 92/100 [22:54<01:59, 14.93s/it]

Epoch [92/100], Train Loss: 0.0181823208 , Val Loss: 0.0869698822


Total Progress:  93%|█████████▎| 93/100 [23:09<01:44, 14.93s/it]

Epoch [93/100], Train Loss: 0.0191810295 , Val Loss: 0.0719959736


Total Progress:  94%|█████████▍| 94/100 [23:24<01:29, 14.94s/it]

Epoch [94/100], Train Loss: 0.0162532969 , Val Loss: 0.0680731460


Total Progress:  95%|█████████▌| 95/100 [23:39<01:14, 14.93s/it]

Epoch [95/100], Train Loss: 0.0152783103 , Val Loss: 0.0701837540


Total Progress:  96%|█████████▌| 96/100 [23:54<00:59, 14.94s/it]

Epoch [96/100], Train Loss: 0.0161809626 , Val Loss: 0.0730343685


Total Progress:  97%|█████████▋| 97/100 [24:09<00:44, 14.93s/it]

Epoch [97/100], Train Loss: 0.0191337145 , Val Loss: 0.0785927549


Total Progress:  98%|█████████▊| 98/100 [24:24<00:29, 14.92s/it]

Epoch [98/100], Train Loss: 0.0157275796 , Val Loss: 0.0767626390


Total Progress:  99%|█████████▉| 99/100 [24:38<00:14, 14.93s/it]

Epoch [99/100], Train Loss: 0.0162185761 , Val Loss: 0.0669567734


Total Progress: 100%|██████████| 100/100 [24:53<00:00, 14.94s/it]

Epoch [100/100], Train Loss: 0.0156760909 , Val Loss: 0.0667034388





0,1
Train_loss,█▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
Val_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Train_loss,0.01568
Val_loss,0.0667


In [17]:
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['model'])
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)

  checkpoint = torch.load('model.pth')


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [18]:
model.eval()
for i in os.listdir("../data/test/test"):
    img_path = os.path.join("../data/test/test", i)
    ori_img = cv2.imread(img_path)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    ori_w = ori_img.shape[0]
    ori_h = ori_img.shape[1]
    img = cv2.resize(ori_img, (256, 256))
    transformed = val_transformation(image=img)
    input_img = transformed["image"]
    input_img = input_img.unsqueeze(0).to(device)
    with torch.no_grad():
        output_mask = model.forward(input_img).squeeze(0).cpu().numpy().transpose(1,2,0)
    mask = cv2.resize(output_mask, (ori_h, ori_w))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, color_dict)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("prediction/{}".format(i), mask_rgb) 

In [19]:
def rle_to_string(runs):
    return ' '.join(str(x) for x in runs)

def rle_encode_one_mask(mask):
    pixels = mask.flatten()
    pixels[pixels > 225] = 255
    pixels[pixels <= 225] = 0
    use_padding = False
    if pixels[0] or pixels[-1]:
        use_padding = True
        pixel_padded = np.zeros([len(pixels) + 2], dtype=pixels.dtype)
        pixel_padded[1:-1] = pixels
        pixels = pixel_padded
    rle = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if use_padding:
        rle = rle - 1
    rle[1::2] = rle[1::2] - rle[:-1:2]
    
    return rle_to_string(rle)

def rle2mask(mask_rle, shape=(3,3)):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

def mask2string(dir):
    strings = []
    ids = []
    ws, hs = [[] for i in range(2)]
    for image_id in os.listdir(dir):
        id = image_id.split('.')[0]
        path = os.path.join(dir, image_id)
        print(path)
        img = cv2.imread(path)[:,:,::-1]
        h, w = img.shape[0], img.shape[1]
        for channel in range(2):
            ws.append(w)
            hs.append(h)
            ids.append(f'{id}_{channel}')
            string = rle_encode_one_mask(img[:,:,channel])
            strings.append(string)
    r = {
        'ids': ids,
        'strings': strings,
    }
    return r


MASK_DIR_PATH = 'prediction/'
dir = MASK_DIR_PATH
res = mask2string(dir)
df = pd.DataFrame(columns=['Id', 'Expected'])
df['Id'] = res['ids']
df['Expected'] = res['strings']

df.to_csv(r'output.csv', index=False)

prediction/019410b1fcf0625f608b4ce97629ab55.jpeg
prediction/02fa602bb3c7abacdbd7e6afd56ea7bc.jpeg
prediction/0398846f67b5df7cdf3f33c3ca4d5060.jpeg
prediction/05734fbeedd0f9da760db74a29abdb04.jpeg
prediction/05b78a91391adc0bb223c4eaf3372eae.jpeg
prediction/0619ebebe9e9c9d00a4262b4fe4a5a95.jpeg
prediction/0626ab4ec3d46e602b296cc5cfd263f1.jpeg
prediction/0a0317371a966bf4b3466463a3c64db1.jpeg
prediction/0a5f3601ad4f13ccf1f4b331a412fc44.jpeg
prediction/0af3feff05dec1eb3a70b145a7d8d3b6.jpeg
prediction/0fca6a4248a41e8db8b4ed633b456aaa.jpeg
prediction/1002ec4a1fe748f3085f1ce88cbdf366.jpeg
prediction/1209db6dcdda5cc8a788edaeb6aa460a.jpeg
prediction/13dd311a65d2b46d0a6085835c525af6.jpeg
prediction/1531871f2fd85a04faeeb2b535797395.jpeg
prediction/15fc656702fa602bb3c7abacdbd7e6af.jpeg
prediction/1ad4f13ccf1f4b331a412fc44655fb51.jpeg
prediction/1b62f15ec83b97bb11e8e0c4416c1931.jpeg
prediction/1c0e9082ea2c193ac8d551c149b60f29.jpeg
prediction/1db239dda50f954ba59c7de13a35276a.jpeg
prediction/26679bff5