In [2]:
!pip install opendatasets

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22


In [4]:
import opendatasets as od

# Download the Kaggle "natural-images" dataset into ./natural-images.
# Kaggle credentials are requested interactively; if the folder already exists
# the download is skipped (pass force=True to re-download).
KAGGLE_DATASET_URL = 'www.kaggle.com/datasets/prasunroy/natural-images'
od.download(KAGGLE_DATASET_URL)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: hoangquangduy
Your Kaggle Key: ··········
Dataset URL: https://www.kaggle.com/datasets/prasunroy/natural-images
Downloading natural-images.zip to ./natural-images


100%|██████████| 342M/342M [00:04<00:00, 76.4MB/s]







Skipping, found downloaded files in "./natural-images" (use force=True to force download)


In [5]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from math import log2, sqrt
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
import numpy as np
import random
from math import exp

In [6]:
# NOTE(review): assumes the working directory is /content (Colab), where the
# od.download() cell above writes ./natural-images.
DATASET = '/content/natural-images/natural_images'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32
EPOCHS = 120
LR = 2e-4
SEED = 42
IM_SIZE = 128

# Seed every RNG the notebook relies on:
#   - `random` drives MaskMaker's random walks,
#   - torch drives weight init, random_split and DataLoader shuffling
#     (torch.manual_seed also seeds every CUDA device),
#   - numpy for completeness.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2`` and divided by its sum,
    so its entries add up to 1.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-((k - center) ** 2) / two_sigma_sq) for k in range(window_size)])
    total = weights.sum()
    return weights / total

def create_window(window_size, channel, sigma=1.5):
    """Build a 2-D Gaussian window for depthwise (grouped) convolution.

    Args:
        window_size: side length of the square window.
        channel: number of channels; the same 2-D kernel is repeated per channel.
        sigma: standard deviation of the Gaussian (default 1.5, the value
            previously hard-coded).

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size), usable as
        the weight of ``F.conv2d(..., groups=channel)``.
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D Gaussian.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is deprecated and only returns a plain tensor, so it is dropped.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))


    return ssim_map.mean()

class SSIM(torch.nn.Module):
    """SSIM between two image batches, wrapped as an ``nn.Module``.

    The Gaussian window is cached and rebuilt whenever the input's channel
    count, dtype or device differs from the cached one.

    Args:
        window_size: side length of the Gaussian window (default 11).
        size_average: stored for API compatibility only; the returned value is
            always the mean SSIM over the whole batch (``_ssim`` averages the
            full SSIM map).
    """

    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        # Plain attribute, not a registered buffer, so `.to(device)` does not move it;
        # forward() moves/rebuilds it to match the input instead.
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the mean SSIM between ``img1`` and ``img2`` (4-D N x C x H x W tensors)."""
        (_, channel, _, _) = img1.size()

        window = self.window
        # Compare device (including CUDA index) and dtype explicitly; comparing
        # `.type()` strings missed a window cached on a different GPU.
        stale = (
            channel != self.channel
            or window.dtype != img1.dtype
            or window.device != img1.device
        )
        if stale:
            window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel)

In [8]:
class MaskMaker():

    def __init__(self, shape):
        self.shape = shape

        self.action_list = [[1, 0], [-1, 0], [0, 1], [0, -1], [1, 1], [-1, -1]]

    def pos_clip(self, pix, img_size):
        if pix < 0:
            return 0
        elif pix > img_size - 1:
            return img_size - 1

        return pix

    def random_walk(self, canvas, ini_x, ini_y, length):
        x, y = ini_x, ini_y
        img_size = canvas.shape

        for i in range(length):
            r = random.randint(0, len(self.action_list) - 1)

            x += self.action_list[r][0]
            y += self.action_list[r][1]

            x = self.pos_clip(x, img_size[0])
            y = self.pos_clip(y, img_size[1])
            canvas[x, y] = 0

        return canvas

    def forward(self):
        image_size = self.shape

        canvas = np.ones((image_size[0],image_size[1])).astype("i")
        ini_x = random.randint(0, image_size[0]-1)
        ini_y = random.randint(0, image_size[1]-1)

        mask = self.random_walk(canvas,ini_x,ini_y,int(image_size[0] * image_size[1]))

        return mask

In [25]:
""" Parts of the U-Net model """
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    The layers live in ``self.double_conv`` so saved state_dict keys
    (``double_conv.0.weight`` ...) stay compatible.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the second conv.
        mid_channels: channels between the two convs; falls back to
            ``out_channels`` when falsy (None or 0).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then DoubleConv.

    The layers live in ``self.maxpool_conv`` to keep state_dict keys stable.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Decoder stage: upsample x2, pad to the skip tensor's size, concatenate, DoubleConv.

    Args:
        in_channels: channels fed to the DoubleConv after concatenation. In the
            transposed-conv path it is also the channel count of the tensor being
            upsampled, which the ConvTranspose2d reduces to ``in_channels // 2``.
        out_channels: channels produced by the stage.
        bilinear: use nn.Upsample (no learned weights) instead of ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if bilinear:
            # Upsampling keeps the channel count; the DoubleConv's middle width reduces it.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)
        else:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` (N x C x H x W), fuse, convolve."""
        upsampled = self.up(x1)

        # Pad so H/W match x2 (e.g. when an odd size was floored by max-pooling),
        # splitting any odd difference with the extra pixel on the right/bottom.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top])

        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution projecting feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name `conv` is kept for state_dict compatibility.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net image-to-image model built from DoubleConv / Down / Up / OutConv.

    Encoder: ``inc`` (64 ch) then ``down1``..``down4`` (128, 256, 512,
    1024 // factor ch). Decoder: ``up1``..``up4`` with skip connections, then
    the 1x1 ``outc`` projection. ``factor`` is 2 when ``bilinear`` is True,
    otherwise 1.

    Args:
        n_channels: channels of the input image (default 3).
        n_classes: channels of the output (default 3).
        bilinear: use bilinear upsampling instead of transposed convolutions.
    """

    def __init__(self, n_channels=3, n_classes=3, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Off by default; enabled by use_checkpointing(). Not stored in state_dict.
        self._checkpointing = False

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run(self, module, *inputs):
        """Call ``module(*inputs)``, via activation checkpointing when it is enabled and useful."""
        # Checkpointing only saves memory when gradients are being recorded for training.
        if getattr(self, '_checkpointing', False) and self.training and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage (trades compute for memory).

        The previous version called the *module* ``torch.utils.checkpoint`` as a
        function, which raises TypeError. Submodules are no longer replaced, so
        state_dict keys are unchanged and pretrained weights still load.

        NOTE(review): stages are recomputed during backward, so BatchNorm
        running statistics are likely updated twice per step while checkpointing.
        """
        self._checkpointing = True

In [10]:
def get_loader(train_ratio=0.8, seed=None):
    """Build train/test DataLoaders over the ImageFolder dataset at ``DATASET``.

    Each image goes through ``ToTensor`` (values in [0, 1]) and is resized to
    ``IM_SIZE x IM_SIZE``. No mean/std normalization is applied; the SSIM
    constants in ``_ssim`` (0.01**2, 0.03**2) appear to assume a [0, 1] range.

    Args:
        train_ratio: fraction of images assigned to the training split.
        seed: seed for the train/test split. Defaults to the global ``SEED`` so
            the same images land in the test split on every run; without it a
            reloaded model could be evaluated on images it was trained on.

    Returns:
        (train_loader, test_loader): shuffled training loader and unshuffled
        test loader, both with batch size ``BATCH_SIZE``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IM_SIZE, IM_SIZE)),
    ])

    dataset = datasets.ImageFolder(root=DATASET, transform=transform)

    datasize = len(dataset)
    trainsize = int(train_ratio * datasize)
    testsize = datasize - trainsize

    # Dedicated generator: the split is reproducible and independent of the global RNG state.
    split_generator = torch.Generator().manual_seed(SEED if seed is None else seed)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [trainsize, testsize], generator=split_generator
    )

    # Pinned memory only helps host-to-GPU copies; skip it on CPU-only machines.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin)

    return train_loader, test_loader

In [11]:
def make_input(img_batch):
    """Apply an independent random-walk mask to every image in a batch.

    Args:
        img_batch: image batch, assumed (B, C, H, W).

    Returns:
        Tensor of the same shape on ``img_batch``'s device, with masked pixels
        zeroed in every channel. The int32 mask is promoted by the
        multiplication, so a float batch yields a float result.
    """
    batch_size = img_batch.shape[0]
    if batch_size == 0:
        return img_batch.clone()

    # Size the mask from the batch itself rather than assuming IM_SIZE x IM_SIZE.
    height, width = img_batch.shape[-2], img_batch.shape[-1]
    masker = MaskMaker((height, width))

    # One mask per image, generated in batch order (same `random` stream as a per-image loop).
    masks = np.stack([masker.forward() for _ in range(batch_size)])
    # (B, 1, H, W) broadcasts over channels; keep it on the images' device, not the global one.
    masks = torch.from_numpy(masks).unsqueeze(1).to(img_batch.device)
    return masks * img_batch


In [12]:
# 80/20 train/test loaders used by the training and evaluation cells below.
train_loader, test_loader = get_loader(train_ratio=0.8)

In [13]:
from tqdm import tqdm
import gc

def _plot_training_curves(loss_, ssim_capture):
    """Plot per-epoch mean loss and mean SSIM side by side (x-axis is the 1-based epoch)."""
    epochs = list(range(1, len(loss_) + 1))
    _, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Combined loss (loss_fn + (1 - SSIM)) on the left panel.
    ax[0].plot(epochs, loss_)
    ax[0].set_title('Loss')
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('Value')

    ax[1].plot(epochs, ssim_capture)
    ax[1].set_title('Chỉ số SSIM')  # "SSIM index"
    ax[1].set_xlabel('Epochs')
    ax[1].set_ylabel('Value')

    plt.tight_layout()
    plt.show()


def training(model,
             train_loader,
             optimizer,
             loss_fn,
             ssim):
    """Train ``model`` to reconstruct images from randomly masked copies.

    For ``EPOCHS`` epochs, each batch is masked with ``make_input``, the model
    predicts the full image, and the optimized loss is
    ``loss_fn(pred, images) + (1 - ssim(pred, images))``. Per-epoch means are
    printed and, at the end, plotted for every epoch.

    Returns:
        (loss_, ssim_capture): per-epoch mean combined loss and mean SSIM.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError('train_loader yields no batches')

    loss_ = []
    ssim_capture = []

    for epoch in tqdm(range(EPOCHS)):
        model.train()

        loss_sum = 0.0
        ssim_sum = 0.0
        for images, _ in train_loader:
            images = images.to(device)
            inp = make_input(images).to(device)

            pred = model(inp)

            ssim_ = ssim(pred, images)
            # Pixel loss plus a structural term: maximizing SSIM == minimizing (1 - SSIM).
            loss = loss_fn(pred, images) + (1 - ssim_)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            ssim_sum += ssim_.item()

        epoch_loss = loss_sum / n_batches
        epoch_ssim = ssim_sum / n_batches
        loss_.append(epoch_loss)
        ssim_capture.append(epoch_ssim)
        print(f'Epoch {epoch + 1}/{EPOCHS}, Loss: {epoch_loss}, SSIM: {epoch_ssim}')

    # Previously hard-coded to 70 epochs: crashed when EPOCHS < 70, truncated otherwise.
    _plot_training_curves(loss_, ssim_capture)

    return loss_, ssim_capture

In [14]:
def sharpen(image):
    """Sharpen an image: unsharp masking followed by a 3x3 sharpening kernel.

    Args:
        image: numpy image array; values are clipped to [0, 255] and cast to uint8.

    Returns:
        uint8 array of the same shape.
    """
    as_uint8 = np.clip(image, 0, 255).astype(np.uint8)

    # Unsharp mask: 1.5 * image - 0.5 * (strongly blurred image).
    blur = cv2.GaussianBlur(as_uint8, (31, 31), 5.0)  # 31x31 kernel, sigma 5
    unsharp = cv2.addWeighted(as_uint8, 1.5, blur, -0.5, 0)

    # Identity minus 0.5 * Laplacian; weights sum to 1, so flat regions are unchanged.
    laplace_kernel = np.array([[0, -0.5, 0],
                               [-0.5, 3, -0.5],
                               [0, -0.5, 0]])
    return cv2.filter2D(unsharp, -1, laplace_kernel)


In [15]:

def _tensor_to_display(img):
    """Convert one CHW image tensor to an HWC int32 numpy array clamped to [0, 255].

    Floating-point tensors are treated as [0, 1] data and scaled by 255;
    integer tensors are treated as already being in [0, 255].
    """
    img = img.detach().cpu()
    if img.is_floating_point():
        img = img * 255
    return img.int().clamp(0, 255).permute(1, 2, 0).numpy()


def plot_result(in_imgs, pred_img, true_img):
    """Show masked input, sharpened prediction and ground truth, one row per image.

    Scaling is decided per tensor by dtype, so float [0, 1] batches and uint8
    [0, 255] images are both displayed correctly. Previously the single-image
    branch never scaled ``true_img`` while the batch branch always did, so
    behavior depended on batch size (e.g. a final batch of one rendered black).

    Args:
        in_imgs, pred_img, true_img: image batches with the same batch size.
    """
    num = in_imgs.shape[0]
    if num == 0:
        return

    # squeeze=False keeps `axs` 2-D even for a single row, so one code path serves all sizes.
    _, axs = plt.subplots(num, 3, figsize=(9, 3 * num), squeeze=False)

    for i in range(num):
        panels = (
            _tensor_to_display(in_imgs[i]),
            sharpen(_tensor_to_display(pred_img[i])),
            _tensor_to_display(true_img[i]),
        )
        for ax, panel in zip(axs[i], panels):
            ax.imshow(panel)
            ax.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()


In [27]:
def training_phase(model_path='',
                   use_pretrain=True,
                   device='cuda',
                   save_path=None):
    """Return a UNet ready for inference, either loaded from disk or trained from scratch.

    Args:
        model_path: state_dict file to load when ``use_pretrain`` is True.
        use_pretrain: load weights from ``model_path`` instead of training.
        device: device to build the model on (and map loaded weights to).
        save_path: if given and a new model is trained, its state_dict is saved here.

    Returns:
        The UNet in eval mode, so BatchNorm uses running statistics at inference.
    """
    if use_pretrain:
        print('Using pre-trained model----')
        net = UNet().to(device)
        state_dict = torch.load(model_path, weights_only=True, map_location=torch.device(device))
        net.load_state_dict(state_dict)
    else:
        print('Start training new model----')
        net = UNet().to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=LR)
        loss_fn = nn.MSELoss().to(device)
        ssim = SSIM().to(device)

        # Uses the notebook-level train_loader; training() prints and plots its own curves.
        training(net, train_loader, optimizer, loss_fn, ssim)

        if save_path:
            torch.save(net.state_dict(), save_path)

    # training() leaves the model in train mode; switch once here so every caller
    # gets inference-mode BatchNorm without having to remember net.eval().
    net.eval()
    return net

In [17]:
def test_on_single_image(net, img_path):
    """Mask one image file with a random walk, inpaint it with ``net`` and plot the result.

    The image is used at its native resolution (not resized to IM_SIZE). Note
    that MaskMaker walks H * W Python steps, so large images are slow to mask.

    Args:
        net: trained UNet; it is switched to eval mode.
        img_path: path to an image readable by OpenCV.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_torch = torch.from_numpy(np.asarray(img, dtype=np.uint8)).permute(2, 0, 1)

    mask = torch.from_numpy(MaskMaker(img_torch.shape[1:]).forward())
    model_device = next(net.parameters()).device

    torch.cuda.empty_cache()
    # BatchNorm must use running statistics for a single-image batch.
    net.eval()
    with torch.no_grad():
        # uint8 * int32 mask -> int32, then scaled to [0, 1] floats like the training data.
        masked = (img_torch * mask.unsqueeze(0)) / 255.0

        pred = net(masked.unsqueeze(0).to(model_device))
        plot_result(masked.unsqueeze(0), pred, img_torch.unsqueeze(0))

def test_on_many_image(net, num_of_batches, testloader):
    """Plot masked input / prediction / ground truth for the first batches of ``testloader``.

    Fixes: the batch counter was never initialized (UnboundLocalError on the
    first batch), and the global ``test_loader`` was used instead of the
    ``testloader`` argument.

    Args:
        net: trained UNet; it is switched to eval mode.
        num_of_batches: number of batches to plot; nothing is plotted if <= 0.
        testloader: DataLoader yielding (images, labels) pairs.
    """
    if num_of_batches <= 0:
        return

    model_device = next(net.parameters()).device
    net.eval()
    with torch.inference_mode():
        for batch_idx, (images, _) in enumerate(testloader):
            images = images.to(model_device)
            masked = make_input(images).to(model_device)
            y_pred = net(masked)

            plot_result(masked, y_pred, images)
            if batch_idx + 1 >= num_of_batches:
                break

In [None]:
# NOTE(review): absolute paths — confirm these files exist in the current runtime.
PRETRAINED_WEIGHTS = '/Models/checkpoint.pth'
SAMPLE_IMAGE = '/images/test_image1.jpg'

net = training_phase(model_path=PRETRAINED_WEIGHTS, use_pretrain=True, device=device)
test_on_single_image(net, SAMPLE_IMAGE)