In [1]:
import os
import pandas as pd
import numpy as np
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torchvision.transforms import ToTensor
from PIL import Image
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torchvision import transforms
from torchinfo import summary
import timm
import segmentation_models_pytorch as smp
import wandb

from torchgeometry.losses import DiceLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3050 Ti Laptop GPU (UUID: GPU-922977af-d68d-ebac-3c8f-b0b483c19424)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class DatasetCustom(Dataset):
    def __init__(self, img_dir, label_dir, resize=None, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.resize = resize
        self.transform = transform
        self.images = os.listdir(self.img_dir)

    def __len__(self):
        return len(self.images)
    
    def read_mask(self, mask_path):
        image = cv2.imread(mask_path)
        if self.resize:
            image = cv2.resize(image, self.resize)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        lower_red1 = np.array([0, 100, 20])
        upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160,100,20])
        upper_red2 = np.array([179,255,255])
        
        lower_mask_red = cv2.inRange(image, lower_red1, upper_red1)
        upper_mask_red = cv2.inRange(image, lower_red2, upper_red2)
        
        red_mask = lower_mask_red + upper_mask_red
        red_mask[red_mask != 0] = 1

        green_mask = cv2.inRange(image, (36, 25, 25), (70, 255, 255))
        green_mask[green_mask != 0] = 2

        full_mask = cv2.bitwise_or(red_mask, green_mask)
        full_mask = np.expand_dims(full_mask, axis=-1) 
        full_mask = full_mask.astype(np.uint8)
        
        return full_mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.images[idx])
        image = cv2.imread(img_path)  #  BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to RGB
        label = self.read_mask(label_path)  
        image = cv2.resize(image, self.resize)
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [5]:
images_path = "../data/train/train/"
image_path = []
TRAIN_DIR = '../data/train/train'
for root, dirs, files in os.walk(TRAIN_DIR):
    for file in files:
        path = os.path.join(root,file)
        image_path.append(path)
        
len(image_path)

1000

In [6]:
mask_path = []
TRAIN_MASK_DIR = '../data/train_gt/train_gt'
for root, dirs, files in os.walk(TRAIN_MASK_DIR):
    for file in files:
        path = os.path.join(root,file)
        mask_path.append(path)
        
len(mask_path)

1000

In [7]:
dataset = DatasetCustom(img_dir= TRAIN_DIR,
                             label_dir= TRAIN_MASK_DIR,
                             resize= (256,256),
                             transform = None)

In [8]:
batch_size = 8
images_data = []
labels_data = []
for x,y in dataset:
    images_data.append(x)
    labels_data.append(y)

In [9]:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=3     
)

In [10]:
class CustomDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        assert image.shape[:2] == label.shape[:2]
        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image = transformed['image'].float()
            label = transformed['mask'].float()
            label = label.permute(2, 0, 1)
        return image, label
    
    def __len__(self):
        return len(self.data)

In [11]:
train_transformation = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomCrop(height=256, width=256, p=0.5),
    A.RandomGamma(gamma_limit=(70, 130), eps=None, always_apply=False, p=0.2),
    A.RGBShift(p=0.3, r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

val_transformation = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

  A.RandomGamma(gamma_limit=(70, 130), eps=None, always_apply=False, p=0.2),


In [12]:
train_size = int(0.8 * len(images_data))
val_size = len(images_data) - train_size
train_dataset = CustomDataset(images_data[:train_size], labels_data[:train_size], transform=train_transformation)
val_dataset = CustomDataset(images_data[train_size:], labels_data[train_size:], transform=val_transformation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [13]:
learning_rate = 0.0001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
color_dict= {0: (0, 0, 0),
             1: (255, 0, 0),
             2: (0, 255, 0)}
def mask_to_rgb(mask, color_dict):
    output = np.zeros((mask.shape[0], mask.shape[1], 3))

    for k in color_dict.keys():
        output[mask==k] = color_dict[k]

    return np.uint8(output)    

In [15]:
wandb.init(
    project = "DlPolypSegment"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtuantruongvu[0m ([33mtuantruongvu-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
from tqdm import tqdm
import time

num_epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)
criterion1 = nn.CrossEntropyLoss()
criterion2 = DiceLoss()
best_val_loss = 999

epoch_bar = tqdm(total=num_epochs, desc='Total Progress')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        labels = labels.squeeze(dim=1).long()
        outputs = model(images)
    
        loss1 = criterion1(outputs, labels)
        loss2 = criterion2(outputs, labels)
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            labels = labels.squeeze(dim=1).long()
            
            outputs = model(images)

            val_loss1 = criterion1(outputs,labels)
            val_loss2 = criterion2(outputs,labels)
            val_loss += val_loss1 + val_loss2

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_loader):.10f} , Val Loss: {val_loss/len(val_loader):.10f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = { 
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': val_loss,
        }
        save_path = f'model.pth'
        torch.save(checkpoint, save_path)
        
    epoch_bar.update(1)
    wandb.log({'Val_loss': val_loss/len(val_loader),'Train_loss': train_loss/len(train_loader)})
epoch_bar.close()

Total Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/100], Train Loss: 1.3581176269 , Val Loss: 0.8742238283


Total Progress:   1%|          | 1/100 [00:16<27:31, 16.68s/it]

Epoch [2/100], Train Loss: 0.6321579900 , Val Loss: 0.4459774792


Total Progress:   2%|▏         | 2/100 [00:31<25:46, 15.78s/it]

Epoch [3/100], Train Loss: 0.3531374276 , Val Loss: 0.2740578055


Total Progress:   3%|▎         | 3/100 [00:47<25:09, 15.57s/it]

Epoch [4/100], Train Loss: 0.2303303611 , Val Loss: 0.1892667562


Total Progress:   4%|▍         | 4/100 [01:03<25:10, 15.74s/it]

Epoch [5/100], Train Loss: 0.1680121727 , Val Loss: 0.1491367072


Total Progress:   5%|▌         | 5/100 [01:18<24:52, 15.71s/it]

Epoch [6/100], Train Loss: 0.1342470510 , Val Loss: 0.1246974170


Total Progress:   6%|▌         | 6/100 [01:33<24:20, 15.53s/it]

Epoch [7/100], Train Loss: 0.1105819168 , Val Loss: 0.1101419032


Total Progress:   8%|▊         | 8/100 [02:04<23:24, 15.27s/it]

Epoch [8/100], Train Loss: 0.0959133750 , Val Loss: 0.1123297513
Epoch [9/100], Train Loss: 0.0906370817 , Val Loss: 0.0930941924


Total Progress:  10%|█         | 10/100 [02:35<22:55, 15.28s/it]

Epoch [10/100], Train Loss: 0.0830211243 , Val Loss: 0.0978180319


Total Progress:  11%|█         | 11/100 [02:49<22:26, 15.13s/it]

Epoch [11/100], Train Loss: 0.0804234630 , Val Loss: 0.0989124626
Epoch [12/100], Train Loss: 0.0806460091 , Val Loss: 0.0870253742


Total Progress:  12%|█▏        | 12/100 [03:04<22:12, 15.14s/it]

Epoch [13/100], Train Loss: 0.0716272062 , Val Loss: 0.0856924057


Total Progress:  13%|█▎        | 13/100 [03:20<21:57, 15.14s/it]

Epoch [14/100], Train Loss: 0.0663383986 , Val Loss: 0.0823409557


Total Progress:  15%|█▌        | 15/100 [03:50<21:22, 15.09s/it]

Epoch [15/100], Train Loss: 0.0611123418 , Val Loss: 0.0826050565
Epoch [16/100], Train Loss: 0.0609049870 , Val Loss: 0.0793058798


Total Progress:  16%|█▌        | 16/100 [04:06<21:24, 15.29s/it]

Epoch [17/100], Train Loss: 0.0556600227 , Val Loss: 0.0701037794


Total Progress:  18%|█▊        | 18/100 [04:36<20:51, 15.26s/it]

Epoch [18/100], Train Loss: 0.0509163867 , Val Loss: 0.0719923079


Total Progress:  19%|█▉        | 19/100 [04:51<20:25, 15.13s/it]

Epoch [19/100], Train Loss: 0.0527880571 , Val Loss: 0.0703394935
Epoch [20/100], Train Loss: 0.0463762946 , Val Loss: 0.0686662793


Total Progress:  21%|██        | 21/100 [05:21<19:50, 15.07s/it]

Epoch [21/100], Train Loss: 0.0496830301 , Val Loss: 0.0715233013


Total Progress:  22%|██▏       | 22/100 [05:36<19:29, 15.00s/it]

Epoch [22/100], Train Loss: 0.0493793577 , Val Loss: 0.0767694935
Epoch [23/100], Train Loss: 0.0497017920 , Val Loss: 0.0615422651


Total Progress:  24%|██▍       | 24/100 [06:06<19:05, 15.07s/it]

Epoch [24/100], Train Loss: 0.0398084719 , Val Loss: 0.0661726296


Total Progress:  25%|██▌       | 25/100 [06:21<18:44, 15.00s/it]

Epoch [25/100], Train Loss: 0.0382015822 , Val Loss: 0.0656309351


Total Progress:  26%|██▌       | 26/100 [06:36<18:25, 14.93s/it]

Epoch [26/100], Train Loss: 0.0368485443 , Val Loss: 0.0660573766


Total Progress:  27%|██▋       | 27/100 [06:51<18:07, 14.90s/it]

Epoch [27/100], Train Loss: 0.0437414907 , Val Loss: 0.0634634793


Total Progress:  28%|██▊       | 28/100 [07:06<17:52, 14.89s/it]

Epoch [28/100], Train Loss: 0.0405702850 , Val Loss: 0.0654426217


Total Progress:  29%|██▉       | 29/100 [07:21<17:36, 14.88s/it]

Epoch [29/100], Train Loss: 0.0392179883 , Val Loss: 0.0643109754


Total Progress:  30%|███       | 30/100 [07:35<17:20, 14.87s/it]

Epoch [30/100], Train Loss: 0.0382050926 , Val Loss: 0.0672450066
Epoch [31/100], Train Loss: 0.0369159583 , Val Loss: 0.0582945198


Total Progress:  32%|███▏      | 32/100 [08:06<17:07, 15.11s/it]

Epoch [32/100], Train Loss: 0.0299511776 , Val Loss: 0.0610341765


Total Progress:  33%|███▎      | 33/100 [08:21<16:48, 15.06s/it]

Epoch [33/100], Train Loss: 0.0315319639 , Val Loss: 0.0587797426


Total Progress:  34%|███▍      | 34/100 [08:36<16:30, 15.00s/it]

Epoch [34/100], Train Loss: 0.0292355368 , Val Loss: 0.0607857816


Total Progress:  35%|███▌      | 35/100 [08:51<16:12, 14.96s/it]

Epoch [35/100], Train Loss: 0.0310535219 , Val Loss: 0.0667238533


Total Progress:  36%|███▌      | 36/100 [09:06<15:56, 14.95s/it]

Epoch [36/100], Train Loss: 0.0307912674 , Val Loss: 0.0667355359


Total Progress:  37%|███▋      | 37/100 [09:21<15:40, 14.94s/it]

Epoch [37/100], Train Loss: 0.0390232862 , Val Loss: 0.0744055510


Total Progress:  38%|███▊      | 38/100 [09:36<15:26, 14.94s/it]

Epoch [38/100], Train Loss: 0.0452097284 , Val Loss: 0.0873742253


Total Progress:  39%|███▉      | 39/100 [09:51<15:10, 14.93s/it]

Epoch [39/100], Train Loss: 0.0425047737 , Val Loss: 0.0905403197


Total Progress:  40%|████      | 40/100 [10:05<14:54, 14.91s/it]

Epoch [40/100], Train Loss: 0.0324445698 , Val Loss: 0.0692761838


Total Progress:  41%|████      | 41/100 [10:20<14:38, 14.90s/it]

Epoch [41/100], Train Loss: 0.0347634821 , Val Loss: 0.0723378137


Total Progress:  42%|████▏     | 42/100 [10:35<14:23, 14.89s/it]

Epoch [42/100], Train Loss: 0.0300104786 , Val Loss: 0.0678314269


Total Progress:  43%|████▎     | 43/100 [10:50<14:08, 14.88s/it]

Epoch [43/100], Train Loss: 0.0257228228 , Val Loss: 0.0625303984


Total Progress:  44%|████▍     | 44/100 [11:05<13:51, 14.84s/it]

Epoch [44/100], Train Loss: 0.0287613211 , Val Loss: 0.0668816268


Total Progress:  45%|████▌     | 45/100 [11:20<13:35, 14.83s/it]

Epoch [45/100], Train Loss: 0.0276837957 , Val Loss: 0.0621629879


Total Progress:  46%|████▌     | 46/100 [11:34<13:20, 14.82s/it]

Epoch [46/100], Train Loss: 0.0283053428 , Val Loss: 0.0687312037


Total Progress:  47%|████▋     | 47/100 [11:49<13:06, 14.83s/it]

Epoch [47/100], Train Loss: 0.0269667692 , Val Loss: 0.0655271709


Total Progress:  48%|████▊     | 48/100 [12:04<12:51, 14.83s/it]

Epoch [48/100], Train Loss: 0.0269035924 , Val Loss: 0.0800674707


Total Progress:  49%|████▉     | 49/100 [12:19<12:37, 14.85s/it]

Epoch [49/100], Train Loss: 0.0245233109 , Val Loss: 0.0608966425


Total Progress:  50%|█████     | 50/100 [12:34<12:23, 14.86s/it]

Epoch [50/100], Train Loss: 0.0278783104 , Val Loss: 0.0654346868


Total Progress:  51%|█████     | 51/100 [12:49<12:08, 14.88s/it]

Epoch [51/100], Train Loss: 0.0288465328 , Val Loss: 0.0810491443


Total Progress:  52%|█████▏    | 52/100 [13:04<11:53, 14.87s/it]

Epoch [52/100], Train Loss: 0.0268786007 , Val Loss: 0.0688986331


Total Progress:  53%|█████▎    | 53/100 [13:19<11:39, 14.88s/it]

Epoch [53/100], Train Loss: 0.0261986896 , Val Loss: 0.0606727004


Total Progress:  54%|█████▍    | 54/100 [13:33<11:25, 14.89s/it]

Epoch [54/100], Train Loss: 0.0242486166 , Val Loss: 0.0652684048


Total Progress:  55%|█████▌    | 55/100 [13:49<11:13, 14.96s/it]

Epoch [55/100], Train Loss: 0.0240626787 , Val Loss: 0.0647045895


Total Progress:  56%|█████▌    | 56/100 [14:04<10:59, 14.98s/it]

Epoch [56/100], Train Loss: 0.0213877697 , Val Loss: 0.0645239800


Total Progress:  57%|█████▋    | 57/100 [14:19<10:45, 15.01s/it]

Epoch [57/100], Train Loss: 0.0221558333 , Val Loss: 0.0605948828


Total Progress:  58%|█████▊    | 58/100 [14:34<10:30, 15.02s/it]

Epoch [58/100], Train Loss: 0.0211407320 , Val Loss: 0.0684251562


Total Progress:  59%|█████▉    | 59/100 [14:49<10:15, 15.01s/it]

Epoch [59/100], Train Loss: 0.0219269997 , Val Loss: 0.0708676502


Total Progress:  60%|██████    | 60/100 [15:04<09:59, 15.00s/it]

Epoch [60/100], Train Loss: 0.0211554435 , Val Loss: 0.0645374283


Total Progress:  61%|██████    | 61/100 [15:19<09:44, 14.98s/it]

Epoch [61/100], Train Loss: 0.0212643852 , Val Loss: 0.0701467544


Total Progress:  62%|██████▏   | 62/100 [15:34<09:31, 15.03s/it]

Epoch [62/100], Train Loss: 0.0218652181 , Val Loss: 0.0684527531


Total Progress:  63%|██████▎   | 63/100 [15:49<09:17, 15.06s/it]

Epoch [63/100], Train Loss: 0.0220478219 , Val Loss: 0.0728445873


Total Progress:  64%|██████▍   | 64/100 [16:04<09:02, 15.08s/it]

Epoch [64/100], Train Loss: 0.0228651497 , Val Loss: 0.0624211505


Total Progress:  65%|██████▌   | 65/100 [16:19<08:47, 15.08s/it]

Epoch [65/100], Train Loss: 0.0199321223 , Val Loss: 0.0701500326


Total Progress:  66%|██████▌   | 66/100 [16:34<08:32, 15.06s/it]

Epoch [66/100], Train Loss: 0.0309367807 , Val Loss: 0.0794876590


Total Progress:  67%|██████▋   | 67/100 [16:49<08:16, 15.04s/it]

Epoch [67/100], Train Loss: 0.0301326171 , Val Loss: 0.0652298033


Total Progress:  68%|██████▊   | 68/100 [17:04<08:00, 15.02s/it]

Epoch [68/100], Train Loss: 0.0243304429 , Val Loss: 0.0827244148


Total Progress:  69%|██████▉   | 69/100 [17:19<07:44, 14.99s/it]

Epoch [69/100], Train Loss: 0.0219703391 , Val Loss: 0.0676845014


Total Progress:  70%|███████   | 70/100 [17:34<07:29, 14.97s/it]

Epoch [70/100], Train Loss: 0.0210611217 , Val Loss: 0.0655620098


Total Progress:  71%|███████   | 71/100 [17:49<07:13, 14.94s/it]

Epoch [71/100], Train Loss: 0.0197185687 , Val Loss: 0.0701412708


Total Progress:  72%|███████▏  | 72/100 [18:04<06:57, 14.93s/it]

Epoch [72/100], Train Loss: 0.0216375938 , Val Loss: 0.0719731078


Total Progress:  73%|███████▎  | 73/100 [18:19<06:42, 14.92s/it]

Epoch [73/100], Train Loss: 0.0255287470 , Val Loss: 0.0739833787


Total Progress:  74%|███████▍  | 74/100 [18:34<06:27, 14.90s/it]

Epoch [74/100], Train Loss: 0.0252028769 , Val Loss: 0.0787199065


Total Progress:  75%|███████▌  | 75/100 [18:48<06:12, 14.89s/it]

Epoch [75/100], Train Loss: 0.0274187556 , Val Loss: 0.0740900040


Total Progress:  76%|███████▌  | 76/100 [19:03<05:56, 14.85s/it]

Epoch [76/100], Train Loss: 0.0211062838 , Val Loss: 0.0630265102


Total Progress:  77%|███████▋  | 77/100 [19:18<05:40, 14.80s/it]

Epoch [77/100], Train Loss: 0.0182404390 , Val Loss: 0.0618452728


Total Progress:  78%|███████▊  | 78/100 [19:33<05:25, 14.79s/it]

Epoch [78/100], Train Loss: 0.0179517763 , Val Loss: 0.0630413592


Total Progress:  79%|███████▉  | 79/100 [19:47<05:10, 14.79s/it]

Epoch [79/100], Train Loss: 0.0168786588 , Val Loss: 0.0644734651


Total Progress:  80%|████████  | 80/100 [20:02<04:55, 14.77s/it]

Epoch [80/100], Train Loss: 0.0159527617 , Val Loss: 0.0625504330


Total Progress:  81%|████████  | 81/100 [20:17<04:40, 14.76s/it]

Epoch [81/100], Train Loss: 0.0148618616 , Val Loss: 0.0598094836


Total Progress:  82%|████████▏ | 82/100 [20:32<04:25, 14.77s/it]

Epoch [82/100], Train Loss: 0.0146446790 , Val Loss: 0.0612473860


Total Progress:  83%|████████▎ | 83/100 [20:46<04:11, 14.78s/it]

Epoch [83/100], Train Loss: 0.0161651107 , Val Loss: 0.0648224279


Total Progress:  84%|████████▍ | 84/100 [21:01<03:56, 14.79s/it]

Epoch [84/100], Train Loss: 0.0142468866 , Val Loss: 0.0658435300


Total Progress:  85%|████████▌ | 85/100 [21:16<03:41, 14.78s/it]

Epoch [85/100], Train Loss: 0.0165853059 , Val Loss: 0.0654491484


Total Progress:  86%|████████▌ | 86/100 [21:31<03:26, 14.78s/it]

Epoch [86/100], Train Loss: 0.0169472556 , Val Loss: 0.0631209239


Total Progress:  87%|████████▋ | 87/100 [21:45<03:11, 14.76s/it]

Epoch [87/100], Train Loss: 0.0150140364 , Val Loss: 0.0671034306


Total Progress:  88%|████████▊ | 88/100 [22:00<02:57, 14.75s/it]

Epoch [88/100], Train Loss: 0.0147660785 , Val Loss: 0.0609422512


Total Progress:  89%|████████▉ | 89/100 [22:15<02:42, 14.74s/it]

Epoch [89/100], Train Loss: 0.0240407889 , Val Loss: 0.0680233836


Total Progress:  90%|█████████ | 90/100 [22:30<02:27, 14.74s/it]

Epoch [90/100], Train Loss: 0.0269013304 , Val Loss: 0.0918786302


Total Progress:  91%|█████████ | 91/100 [22:44<02:12, 14.73s/it]

Epoch [91/100], Train Loss: 0.0218923834 , Val Loss: 0.0711963624


Total Progress:  92%|█████████▏| 92/100 [22:59<01:57, 14.73s/it]

Epoch [92/100], Train Loss: 0.0177498789 , Val Loss: 0.0585445091


Total Progress:  93%|█████████▎| 93/100 [23:14<01:43, 14.73s/it]

Epoch [93/100], Train Loss: 0.0186421738 , Val Loss: 0.0684938729


Total Progress:  94%|█████████▍| 94/100 [23:29<01:28, 14.72s/it]

Epoch [94/100], Train Loss: 0.0210145186 , Val Loss: 0.0698609874


Total Progress:  95%|█████████▌| 95/100 [23:43<01:13, 14.73s/it]

Epoch [95/100], Train Loss: 0.0172546427 , Val Loss: 0.0781958848


Total Progress:  96%|█████████▌| 96/100 [23:58<00:58, 14.73s/it]

Epoch [96/100], Train Loss: 0.0171769367 , Val Loss: 0.0761600509


Total Progress:  97%|█████████▋| 97/100 [24:13<00:44, 14.73s/it]

Epoch [97/100], Train Loss: 0.0160684142 , Val Loss: 0.0744323283


Total Progress:  98%|█████████▊| 98/100 [24:27<00:29, 14.73s/it]

Epoch [98/100], Train Loss: 0.0153252096 , Val Loss: 0.0802170858


Total Progress:  99%|█████████▉| 99/100 [24:42<00:14, 14.73s/it]

Epoch [99/100], Train Loss: 0.0144642045 , Val Loss: 0.0681337640


Total Progress: 100%|██████████| 100/100 [24:57<00:00, 14.97s/it]

Epoch [100/100], Train Loss: 0.0205203051 , Val Loss: 0.0789717808





In [20]:
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['model'])
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)

  checkpoint = torch.load('model.pth')


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [21]:
model.eval()
for i in os.listdir("../data/test/test"):
    img_path = os.path.join("../data/test/test", i)
    ori_img = cv2.imread(img_path)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    ori_w = ori_img.shape[0]
    ori_h = ori_img.shape[1]
    img = cv2.resize(ori_img, (256, 256))
    transformed = val_transformation(image=img)
    input_img = transformed["image"]
    input_img = input_img.unsqueeze(0).to(device)
    with torch.no_grad():
        output_mask = model.forward(input_img).squeeze(0).cpu().numpy().transpose(1,2,0)
    mask = cv2.resize(output_mask, (ori_h, ori_w))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, color_dict)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("prediction/{}".format(i), mask_rgb) 

In [22]:
def rle_to_string(runs):
    return ' '.join(str(x) for x in runs)

def rle_encode_one_mask(mask):
    pixels = mask.flatten()
    pixels[pixels > 225] = 255
    pixels[pixels <= 225] = 0
    use_padding = False
    if pixels[0] or pixels[-1]:
        use_padding = True
        pixel_padded = np.zeros([len(pixels) + 2], dtype=pixels.dtype)
        pixel_padded[1:-1] = pixels
        pixels = pixel_padded
    rle = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if use_padding:
        rle = rle - 1
    rle[1::2] = rle[1::2] - rle[:-1:2]
    
    return rle_to_string(rle)

def rle2mask(mask_rle, shape=(3,3)):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

def mask2string(dir):
    strings = []
    ids = []
    ws, hs = [[] for i in range(2)]
    for image_id in os.listdir(dir):
        id = image_id.split('.')[0]
        path = os.path.join(dir, image_id)
        print(path)
        img = cv2.imread(path)[:,:,::-1]
        h, w = img.shape[0], img.shape[1]
        for channel in range(2):
            ws.append(w)
            hs.append(h)
            ids.append(f'{id}_{channel}')
            string = rle_encode_one_mask(img[:,:,channel])
            strings.append(string)
    r = {
        'ids': ids,
        'strings': strings,
    }
    return r


MASK_DIR_PATH = 'prediction/'
dir = MASK_DIR_PATH
res = mask2string(dir)
df = pd.DataFrame(columns=['Id', 'Expected'])
df['Id'] = res['ids']
df['Expected'] = res['strings']

df.to_csv(r'output.csv', index=False)

prediction/019410b1fcf0625f608b4ce97629ab55.jpeg
prediction/02fa602bb3c7abacdbd7e6afd56ea7bc.jpeg
prediction/0398846f67b5df7cdf3f33c3ca4d5060.jpeg
prediction/05734fbeedd0f9da760db74a29abdb04.jpeg
prediction/05b78a91391adc0bb223c4eaf3372eae.jpeg
prediction/0619ebebe9e9c9d00a4262b4fe4a5a95.jpeg
prediction/0626ab4ec3d46e602b296cc5cfd263f1.jpeg
prediction/0a0317371a966bf4b3466463a3c64db1.jpeg
prediction/0a5f3601ad4f13ccf1f4b331a412fc44.jpeg
prediction/0af3feff05dec1eb3a70b145a7d8d3b6.jpeg
prediction/0fca6a4248a41e8db8b4ed633b456aaa.jpeg
prediction/1002ec4a1fe748f3085f1ce88cbdf366.jpeg
prediction/1209db6dcdda5cc8a788edaeb6aa460a.jpeg
prediction/13dd311a65d2b46d0a6085835c525af6.jpeg
prediction/1531871f2fd85a04faeeb2b535797395.jpeg
prediction/15fc656702fa602bb3c7abacdbd7e6af.jpeg
prediction/1ad4f13ccf1f4b331a412fc44655fb51.jpeg
prediction/1b62f15ec83b97bb11e8e0c4416c1931.jpeg
prediction/1c0e9082ea2c193ac8d551c149b60f29.jpeg
prediction/1db239dda50f954ba59c7de13a35276a.jpeg
prediction/26679bff5