In [1]:
from zipfile import ZipFile
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import os
import cv2

import matplotlib.pyplot as plt

In [2]:
def preproc_data():
    targets, samples = [], []

    with ZipFile("fotoarc.zip", "r") as zip_data:
        cnt_file = len(zip_data.namelist())
        with tqdm(total=cnt_file, position=0, leave=True) as pbar:
            for fileidx in range(cnt_file):

                file = zip_data.filelist[fileidx]
                name = file.filename
                zip_data.extract(name, path="fotoarc/")

                img = cv2.imread('fotoarc/' + name)
                img = cv2.resize(img, (256, 256))

                yuv = cv2.cvtColor(img, cv2.COLOR_RGB2LUV)
                samples += [yuv[:, :, 0]]
                targets += [yuv[:, :, 1:]]

                os.remove("fotoarc/" + name)

                pbar.set_description(f"Files: {fileidx+1}/{cnt_file}")
                pbar.update()

    x_train, x_test, y_train, y_test = train_test_split(samples,
                                                        targets,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)

    x_train, x_test, y_train, y_test = torch.Tensor(np.array(x_train)), \
        torch.Tensor(np.array(x_test)), \
        torch.Tensor(np.array(y_train)), \
        torch.Tensor(np.array(y_test))

    return TensorDataset(x_train, y_train), TensorDataset(x_test, y_test)

In [3]:
def print_img(tensor_L, tensor_UV):
    img = np.uint8(torch.cat((tensor_L.view(256, 256, 1), tensor_UV), dim=2))
    img = cv2.cvtColor(img, cv2.COLOR_LUV2RGB)
    plt.imshow(img)

In [4]:
train_dataset, test_dataset = preproc_data()

Files: 7129/7129: 100%|████████████████████████████████████████████████████████████| 7129/7129 [01:59<00:00, 59.59it/s]


In [5]:
batch_size = 32
train_loader, test_loader = DataLoader(train_dataset, batch_size), DataLoader(test_dataset, 1)

In [None]:
print_img(train_dataset[0][0], train_dataset[0][1])

In [25]:
class CNN_CI(nn.Module):
    def __init__(self, parameters={}):
        super().__init__()

        # ----------------------------------------- model -------------------------------------------
        # 256x256x1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=2, padding=2)
        # 128x128x64
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        # 64x64x128
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=512, kernel_size=3, stride=2, padding=1)
        # 32x32x512
        self.conv4 = nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)
        # 32x32x128
        # resize 256x256x2
        # --------------------------------------------------------------------------------------------

        self.name = "ColorImg_1_0"
        self.modules = ["conv1", "conv2", "conv3", "conv4"]

        self.load_self(parameters)

        self.to(self.device)

    def init_weigth(self):
        for module in self.modules:
            nn.init.xavier_uniform_(getattr(self, module).weight)

    def testing(self, test_dataloader):
        len_test_data = len(test_dataloader)

        with tqdm(total=len_test_data, position=0, leave=True) as pbar:

            self.eval()
            running_loss = 0
            num_test = 1

            for test in test_dataloader:
                x_data = test[0].view(test[0].shape[0], 1, 256, 256).to(self.device)
                y_data = test[1].to(self.device)

                y_pred = self(x_data)
                loss = self.criterion(y_pred, y_data)

                loss.backward()
                running_loss += loss.item()

                pbar.set_description(f"Test: {num_test}/{len_test_data}, Loss: {running_loss}")
                pbar.update()

                num_test += 1

        self.history_test_loss.update({self.epochs_train: running_loss})
        self.save()

    def fit(self, train_dataloader, epochs):
        len_train_data = len(train_dataloader)

        with tqdm(total=epochs * len_train_data, position=0, leave=True) as pbar:

            for epoch in range(epochs):
                self.train()
                running_loss = 0
                num_batch = 1

                for batch in train_dataloader:
                    x_data = batch[0].view(batch[0].shape[0], 1, 256, 256).to(self.device)
                    y_data = batch[1].to(self.device)

                    y_pred = self(x_data)
                    loss = self.criterion(y_pred, y_data)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    running_loss += loss.item()

                    pbar.set_description(f"Epoch: {self.epochs_train + 1}/{self.epochs_train + epochs - epoch}, Batch: {num_batch}/{len_train_data}, Loss: {running_loss}")
                    pbar.update()

                    num_batch += 1

                self.epochs_train += 1
                self.history_train_loss.update({self.epochs_train: running_loss})
                self.save()

        self.pretrained = True
        self.save()

    def load_self(self, parameters):
        try:
            checkpoint = torch.load(f"{self.name}.pth")

            self.pretrained = True
            self.learning_rate = checkpoint["learning_rate"]
            self.load_state_dict(checkpoint['load_state_dict'])
            self.optimizer = checkpoint['optimizer']
            self.epochs_train = checkpoint["epochs_train"]
            self.history_train_loss = checkpoint["history_train_loss"]
            self.history_test_loss = checkpoint["history_test_loss"]
            self.device = checkpoint["device"]
            self.criterion = checkpoint["criterion"]

        except FileNotFoundError:

            self.pretrained = False
            self.learning_rate = parameters["learning_rate"]

            self.init_weigth()
            self.optimizer = parameters["optimizer"](self.parameters(), lr=self.learning_rate)

            self.epochs_train = 0
            self.history_train_loss = {}
            self.history_test_loss = {}
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.criterion = parameters["criterion"]

    def save(self):
        checkpoint = {"load_state_dict": self.state_dict(),
                      "learning_rate": self.learning_rate,
                      "optimizer": self.optimizer,
                      "epochs_train": self.epochs_train,
                      "history_train_loss": self.history_train_loss,
                      "history_test_loss": self.history_test_loss,
                      "device": self.device,
                      "criterion": self.criterion}

        torch.save(checkpoint, f"{self.name}.pth")

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.shape[0], 256, 256, 2)

        return x

In [13]:
model = CNN_CI(parameters={"learning_rate": 0.005,
                           "optimizer": Adam,
                           "criterion": nn.CrossEntropyLoss()})

In [30]:
model.optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    weight_decay: 0
)

In [15]:
model.fit(train_loader, epochs=3)

Epoch: 3/3, Batch: 179/179, Loss: 30305198.9375: 100%|███████████████████████████████| 537/537 [03:54<00:00,  2.29it/s]


In [17]:
model.history_train_loss

{1: 38339434.875, 2: 30305240.203125, 3: 30305198.9375}

In [18]:
model.testing(test_loader)

Test: 1426/1426, Loss: 241902751.9375: 100%|███████████████████████████████████████| 1426/1426 [00:30<00:00, 46.82it/s]


In [20]:
model.history_test_loss

{3: 241902751.9375}

In [21]:
model = CNN_CI()

In [26]:
model.history_train_loss

{1: 38339434.875, 2: 30305240.203125, 3: 30305198.9375}

In [31]:
model.fit(train_loader, 1)

Epoch: 5/5, Batch: 179/179, Loss: 30305146.015625: 100%|█████████████████████████████| 179/179 [01:13<00:00,  2.44it/s]


In [32]:
model.history_train_loss

{1: 38339434.875,
 2: 30305240.203125,
 3: 30305198.9375,
 4: 30305146.015625,
 5: 30305146.015625}