# Mount tới Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/MAP_ROAD/

# Import thư viện cần thiết

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

import os
import cv2
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
import time


from tqdm import tqdm
from copy import deepcopy
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# Configuration

In [51]:

TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # Cần thiết cho tqdm

# DATASET
images_file = 'data/images' # Tập chứa ảnh training
segs_file = 'data/labels'  # Tập chứa label của ảnh training
test_file = 'data/test'   # Tập chứa ảnh test
save_dir = 'train_val'   # Path lưu kết quả của tập test
ckpt_file = "best.pt"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_size = 0.1          # Tỷ lệ chia tập val

# AUGMENTATION
augmentation = True   # True nếu muốn tăng cường dữ liệu
ColorJitter = True
Rotate = True
GridDistortion = True

# TRAINING
lr = 5e-2         # learning rate
num_epochs = 35  # Số lượng epoch training
batch_size = 16   # Batch size

# Chuẩn bị dữ liệu (Data Loader)


In [52]:
class CustomDataLoader(Dataset):
    def __init__(self, data,
                     images_file,
                     segs_file,
                     is_val,
                     augmentation=False,
                     ColorJitter=False,
                     Rotate=False,
                     GridDistortion=False):

        self.ims_file = data
        self.images_file = images_file
        self.segs_file = segs_file

        self.Tensor = transforms.ToTensor()

        if is_val:
            self.aug = [False, ColorJitter, Rotate, GridDistortion]
        else:
            self.aug = [augmentation, ColorJitter, Rotate, GridDistortion]

        self.do_aug = [A.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.2, always_apply=False, p=0.4), A.Rotate(limit=15, p=0.3), A.GridDistortion(p=0.3)]



    def __len__(self):
        return (len(self.ims_file))

    def img2seg(self, path):
        return path.replace(self.images_file, self.segs_file).replace('.jpg', '.png')


    def preprocess(self, img, seg):
        h0, w0 = img.shape[:2]

        img = cv2.resize(img, (160, 80), interpolation = cv2.INTER_LINEAR)
        seg = cv2.resize(seg, (160, 80), interpolation = cv2.INTER_LINEAR)

        img = img/255
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # contiguous

        seg = seg/255
        seg = seg.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        seg = np.ascontiguousarray(seg)  # contiguous

        return img, seg


    def __getitem__(self, index):

        img = cv2.imread(self.ims_file[index])
        seg = cv2.imread(self.img2seg(self.ims_file[index]))


        if self.aug[0]:
            img_pre, seg_pre = self.preprocess(img, seg)
            img_out, seg_out = [torch.from_numpy(img_pre)], [seg_pre]
            for i, f in enumerate(self.do_aug):
                if self.aug[i+1]:
                    au = f(image=img, mask=seg)
                    img_aug, seg_aug = au['image'], au['mask']

                    img_aug, seg_aug = self.preprocess(img_aug, seg_aug)

                    img_out.append(torch.from_numpy(img_aug))
                    seg_out.append(torch.from_numpy(seg_aug))

            return img_out, seg_out

        else:
            img_pre, seg_pre = self.preprocess(img, seg)
            img_out, seg_out = [torch.from_numpy(img_pre)], [torch.from_numpy(seg_pre)]

            return img_out, seg_out

In [53]:
def split_data(images_file, segs_file, val_size):
    pbar = tqdm(os.listdir(images_file), total=len(os.listdir(images_file)), desc='Loading data', bar_format=TQDM_BAR_FORMAT)
    f = []
    for name in pbar:
        f.append(os.path.join(images_file, name))

    train, val = train_test_split(f, test_size=val_size, train_size=(1-val_size))

    return train, val

train, val = split_data(images_file, segs_file, val_size)

train_dataset = CustomDataLoader(train,
                                 images_file,
                                 segs_file,
                                 is_val=False,
                                 augmentation=augmentation,
                                 ColorJitter=ColorJitter,
                                 Rotate=Rotate,
                                 GridDistortion=GridDistortion)

val_dataset = CustomDataLoader(val,
                               images_file,
                               segs_file,
                               is_val=True,
                               augmentation=augmentation,
                               ColorJitter=ColorJitter,
                               Rotate=Rotate,
                               GridDistortion=GridDistortion)

print(f'Training size: {len(train_dataset)}')
print(f'Val size: {len(val_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)

Loading data: 100%|██████████| 317/317 [00:00<00:00, 611588.95it/s]

Training size: 285
Val size: 32





# Khởi tạo Model

In [54]:
def double_conv(in_ch, out_ch):
    conv_op = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True)
    )
    return conv_op

class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(UNet, self).__init__()
        self.conv1 = double_conv(in_ch, 8)
        self.conv2 = double_conv(8, 16)
        self.conv3 = double_conv(16, 32)
        self.conv4 = double_conv(32, 64)

        self.conv5 = double_conv(96, 32)
        self.conv6 = double_conv(48, 16)
        self.conv7 = double_conv(24, 8)
        self.pooling = nn.MaxPool2d(kernel_size=2)

        self.upsample1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.upsample2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2)
        self.upsample3 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)

        self.conv0 = nn.Conv2d(in_channels=8, out_channels=out_ch, kernel_size=1)


    def forward(self, x):
        #Encoder
        down1 = self.conv1(x)
        pool1 = self.pooling(down1)
        down2 = self.conv2(pool1)
        pool2 = self.pooling(down2)
        down3 = self.conv3(pool2)
        pool3 = self.pooling(down3)
        down4 = self.conv4(pool3)

        #Decoder
        upsample1 = self.upsample1(down4)
        cat1 = torch.cat([down3, upsample1], dim=1)
        up1 = self.conv5(cat1)
        upsample2 = self.upsample2(up1)
        cat2 = torch.cat([down2, upsample2], dim=1)
        up2 = self.conv6(cat2)
        upsample3 = self.upsample3(up2)
        cat3 = torch.cat([down1, upsample3], dim=1)
        up3 = self.conv7(cat3)

        outputs = self.conv0(up3)

        return outputs

img = torch.rand(1, 3, 80, 160)
model = UNet()
total_params = sum(p.numel() for p in model.parameters())

print(f'Input size: {img.size()}')
print(f'Total params: {total_params}')

Input size: torch.Size([1, 3, 80, 160])
Total params: 144433


# Training

In [55]:
def dice_loss(input: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
#     assert input.size() == target.size()
#     assert input.dim() == 3 or not reduce_batch_first

#     sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

#     print(input.size())
#     print(target.size())
    inter = 2 * (input * target)
    sets_sum = input + target
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return 1.0 - dice.mean()


def poly_lr_scheduler(lr, max_epochs, optimizer, epoch, power=2):
    lr = round(lr * (1 - epoch / max_epochs) ** power, 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


def evaluate(model, val_loader):
    model.eval()
    num_val_batches = len(val_loader)
    dice_score = 0

    for i, (img_l, target_l) in enumerate(val_loader):

        for img, target in zip(img_l, target_l):
            img = img.to(device).float()
            true_masks = target.to(device).float()
            true_masks = torch.mean(true_masks, dim=1, keepdim=True)
            mask_pred = model(img)
            mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()


            dice_score += dice_loss(mask_pred, true_masks, reduce_batch_first=False)


    model.train()

    return dice_score / max(num_val_batches, 1)

In [56]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr, (0.9, 0.999), eps=1e-08, weight_decay=5e-4)
model = model.to(device)

best = 10000
dice_score = 0

for epoch in range(num_epochs):

    poly_lr_scheduler(lr, num_epochs, optimizer, epoch)
    for param_group in optimizer.param_groups:
        lr = param_group['lr']

    model.train()

    print(('\n' + '%11s' * 4) % ('Epoch', 'Loss', 'Score', 'Lr'))
    pbar = enumerate(train_loader)
    total_batch = len(train_loader)
    pbar = tqdm(pbar, total=total_batch, bar_format=TQDM_BAR_FORMAT)

    for i, (img_l, target_l) in pbar:
        for img, target in zip(img_l, target_l):
            img = img.to(device).float()
            true_masks = target.to(device).float()
            true_masks = torch.mean(true_masks, dim=1, keepdim=True)
            mask_pred = model(img)

            optimizer.zero_grad()
            loss = criterion(mask_pred, true_masks)
            loss += dice_loss(torch.sigmoid(mask_pred), true_masks, reduce_batch_first=True)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            last = loss.item()

            pbar.set_description(('%13s' * 1 + '%13.4g'*3) %
                                     (f'{epoch}/{num_epochs - 1}', last, dice_score, lr))

            dice_score = evaluate(model, val_loader)

            if last < best:
                best = last

            ckpt = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': best,
                    'dice_score': dice_score,
                    'date': datetime.now().isoformat()
            }

            torch.save(ckpt, ckpt_file)



      Epoch       Loss      Score         Lr


         0/34       0.7488       0.1007         0.05: 100%|██████████| 18/18 [00:48<00:00,  2.69s/it]



      Epoch       Loss      Score         Lr


         1/34       0.7394       0.4364      0.04718: 100%|██████████| 18/18 [00:51<00:00,  2.87s/it]



      Epoch       Loss      Score         Lr


         2/34       0.1925       0.1062      0.04195: 100%|██████████| 18/18 [00:47<00:00,  2.66s/it]



      Epoch       Loss      Score         Lr


         3/34       0.1723       0.0525      0.03506: 100%|██████████| 18/18 [00:49<00:00,  2.74s/it]



      Epoch       Loss      Score         Lr


         4/34       0.1265       0.1475      0.02751: 100%|██████████| 18/18 [00:47<00:00,  2.65s/it]



      Epoch       Loss      Score         Lr


         5/34       0.1108       0.0394      0.02021: 100%|██████████| 18/18 [00:51<00:00,  2.85s/it]



      Epoch       Loss      Score         Lr


         6/34      0.07909      0.02515      0.01387: 100%|██████████| 18/18 [00:49<00:00,  2.74s/it]



      Epoch       Loss      Score         Lr


         7/34       0.1073      0.03489     0.008879: 100%|██████████| 18/18 [00:48<00:00,  2.69s/it]



      Epoch       Loss      Score         Lr


         8/34      0.08564      0.01844     0.005284: 100%|██████████| 18/18 [00:47<00:00,  2.65s/it]



      Epoch       Loss      Score         Lr


         9/34       0.1185      0.01579     0.002916: 100%|██████████| 18/18 [00:48<00:00,  2.68s/it]



      Epoch       Loss      Score         Lr


        10/34      0.08937      0.01344     0.001488: 100%|██████████| 18/18 [00:50<00:00,  2.83s/it]



      Epoch       Loss      Score         Lr


        11/34      0.07414      0.01412    0.0006995: 100%|██████████| 18/18 [00:47<00:00,  2.62s/it]



      Epoch       Loss      Score         Lr


        12/34      0.05401      0.01402    0.0003021:  61%|██████    | 11/18 [00:32<00:20,  2.93s/it]


KeyboardInterrupt: ignored

# Đánh giá với dữ liệu test

In [44]:
model = UNet()
checkpoint = torch.load(ckpt_file)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

images = os.listdir(test_file)

for name in images:
    img = cv2.imread(os.path.join(test_file, name))
    img = cv2.resize(img, (160, 80), interpolation = cv2.INTER_LINEAR)

    img = img/255
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)  # contiguous

    img = torch.from_numpy(img)
    img = img.to(device).unsqueeze(0).float()

#     print(img.shape)

    with torch.no_grad():
        mask_pred = model(img)

    mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
    to_save = mask_pred.squeeze(0).squeeze(0)
    to_save = (to_save.cpu().numpy()*255).astype(np.uint8)

    save = f'train_val/{name}.jpg'
    cv2.imwrite(save, to_save)