In [None]:
%reset -f

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import os
import torchvision
from torchvision.io.image import ImageReadMode
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader, random_split
import tqdm
from typing import Callable
import json

# Release cached, currently unused GPU memory held by PyTorch's allocator
# (presumably to reclaim memory left over from a previous run in the same kernel).
torch.cuda.empty_cache()

In [None]:
class UpscalerDataset(Dataset):
    """In-memory dataset of (input image, target image) tensor pairs."""

    def __init__(self, data:list[tuple[torch.Tensor, torch.Tensor]]):
        """Keep a reference to the already-loaded pairs (the list is not copied)."""
        self.data:list[tuple[torch.Tensor, torch.Tensor]] = data

    def __len__(self) -> int:
        """Return the number of stored image pairs."""
        return len(self.data)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (input, target) pair stored at position ``idx``."""
        return self.data[idx]

def _read_normalized_rgb(path:str) -> torch.Tensor:
    """Read the image at ``path`` as RGB and rescale it from [0, 255] to float32 in [-1, 1]."""
    image:torch.Tensor = torchvision.io.read_image(path, ImageReadMode.RGB).type(torch.float32) / 255.0
    return image * 2.0 - 1.0 # Extend the image for the [-1, 1] range


def load_dataset() -> "UpscalerDataset":
    """Load every (input, target) image pair from ``./Dataset`` into memory.

    Inputs are read from ``Dataset/64x64`` and targets from ``Dataset/512x512``
    (relative to the current working directory); a pair is formed by two files
    sharing the same file name.

    Returns:
        An ``UpscalerDataset`` whose items are ``(input, target)`` float32
        tensors in the [-1, 1] range.

    Raises:
        AssertionError: if the two directories do not contain the same file names.
    """
    input_images_dir:str = os.path.join(os.getcwd(), "Dataset", "64x64")
    output_images_dir:str = os.path.join(os.getcwd(), "Dataset", "512x512")
    # os.listdir returns entries in arbitrary order; sorting makes sample indices
    # (and therefore any seeded train/test split) reproducible across runs.
    input_images_path:list[str] = sorted(os.listdir(input_images_dir))
    output_images_path:list[str] = sorted(os.listdir(output_images_dir))
    # Comparing the names (not just the counts) catches a missing counterpart up front.
    assert input_images_path == output_images_path, "Dataset/64x64 and Dataset/512x512 must contain the same file names"

    pairs:list[tuple[torch.Tensor, torch.Tensor]] = []
    for image_local_path in input_images_path:
        input_image:torch.Tensor = _read_normalized_rgb(os.path.join(input_images_dir, image_local_path))
        output_image:torch.Tensor = _read_normalized_rgb(os.path.join(output_images_dir, image_local_path))
        pairs.append((input_image.cpu(), output_image.cpu()))
    return UpscalerDataset(pairs)

# Module-level dataset, loaded once; the plotting, training and testing
# functions defined in later cells read this global directly.
dataset:UpscalerDataset = load_dataset()

In [None]:
def plot_dataset_image(nb_image:int) -> None:
    """Show ``nb_image`` randomly picked pairs from the global ``dataset``.

    For each pick the input image is displayed first, then its target image.
    Picks are drawn with ``random.randint`` and may therefore repeat.
    """
    for _ in range(nb_image):
        pick:int = random.randint(0, len(dataset) - 1)
        sample = dataset[pick]

        # Map values from [-1, 1] back to [0, 1] and move channels last (H, W, C) for imshow.
        for image in (sample[0], sample[1]):
            displayable:torch.Tensor = (image * 0.5 + 0.5).permute(1, 2, 0)
            plt.imshow(displayable)
            plt.show()
        
# plot_dataset_image(1)

In [None]:
class UpscalerSettings:
    """Plain container for the hyper-parameters used to build the upscaler.

    Attributes:
        in_channels / out_channels: channel counts of the input and output images.
        base_channels: number of feature channels inside the network.
        num_rdb_blocks: how many residual dense blocks are built.
        rdb_growth_rate: channels produced by each layer of a block.
        rdb_num_layers: number of layers per block.
    """

    def __init__(self, in_channels:int=3, out_channels:int=3, base_channels:int=64, num_rdb_blocks:int=16, rdb_growth_rate:int=32, rdb_num_layers:int=8) -> None:
        # Image-side channel counts.
        self.in_channels, self.out_channels = in_channels, out_channels
        # Width of the feature maps.
        self.base_channels = base_channels
        # Residual dense block layout.
        self.num_rdb_blocks, self.rdb_growth_rate, self.rdb_num_layers = num_rdb_blocks, rdb_growth_rate, rdb_num_layers

# Model hyper-parameters, collected into `model_settings` below.
# Channel counts of the images fed to / produced by the network.
in_channels:int = 3
out_channels:int = 3
# Number of feature channels used by the network's convolutions.
base_channels:int=64
# Number of residual dense blocks.
num_rdb_blocks:int=16
# Channels produced by each layer inside a residual dense block.
rdb_growth_rate:int=32
# Number of layers inside each residual dense block.
rdb_num_layers:int=8
# Single settings object read wherever an RDN_Upscaler is constructed.
model_settings:UpscalerSettings = UpscalerSettings(in_channels, out_channels, base_channels, num_rdb_blocks, rdb_growth_rate, rdb_num_layers)

# Training parameters (read as globals by the training function).
# Number of samples per batch passed to the DataLoaders.
batch_size:int = 5
# Upper bound (exclusive) of the epoch loop.
epochs:int = 50
# Fraction of the dataset held out as the test split.
test_proportion:float = 0.1

base_lr:float = 2e-4


def lr_scheduler_fn(step:int) -> float:
    """Learning-rate schedule: a constant, ``base_lr`` is returned for every ``step``."""
    return base_lr


# Schedule callable consumed by the training code.
lr_scheduler:Callable[[int], float] = lr_scheduler_fn

# A checkpoint is written every `saving_epoch_interval` epochs (and on the last epoch);
# values below 1 disable saving.
saving_epoch_interval:int = 1
# Sample images are displayed every `print_epoch_interval` epochs (and on the last epoch);
# values of 0 or less disable the display.
print_epoch_interval:int = 1
# How many images of the last training batch are displayed (capped at batch_size).
nb_sample_to_print:int = 1
# Use the GPU when one is available, otherwise fall back to the CPU.
device:str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device : {device}")

In [None]:
class ResidualDenseBlock(nn.Module):
    """Densely connected 3x3 conv layers followed by a 1x1 fusion and a residual add.

    Layer ``k`` receives the block input concatenated with the outputs of the
    ``k`` previous layers and produces ``growth_rate`` new channels. A final
    1x1 convolution (``lff``) maps everything back to ``channels`` and the
    result is added to the block input.
    """

    def __init__(self, channels:int, growth_rate:int=32, num_layers:int=5):
        super().__init__()

        # Input width of layer k grows by `growth_rate` for every earlier layer.
        self.layers:nn.ModuleList = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels + depth * growth_rate, growth_rate, 3, padding=1),
                nn.ReLU(inplace=True),
            )
            for depth in range(num_layers)
        ])

        # 1x1 convolution fusing the input and all layer outputs back to `channels`.
        self.lff = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the fused dense features (same shape as ``x``)."""
        # Running channel-wise concatenation: [x, out_0, out_1, ...].
        stacked:torch.Tensor = x
        for dense_layer in self.layers:
            stacked = torch.cat((stacked, dense_layer(stacked)), dim=1)

        return x + self.lff(stacked)

class UpscaleBlock(nn.Module):
    """Upsample by ``scale_factor`` using conv -> pixel shuffle -> ReLU.

    The 3x3 convolution expands ``channels`` to ``channels * scale_factor**2``;
    the pixel shuffle then rearranges those channels into a spatial grid that is
    ``scale_factor`` times larger, bringing the channel count back to ``channels``.
    """

    def __init__(self, channels:int, scale_factor:int=2):
        super().__init__()
        expanded_channels:int = channels * (scale_factor ** 2)
        self.conv:nn.Conv2d = nn.Conv2d(channels, expanded_channels, 3, padding=1)
        self.ps:nn.PixelShuffle = nn.PixelShuffle(scale_factor)
        self.relu:nn.ReLU = nn.ReLU(inplace=True)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Return the upsampled, rectified feature map."""
        return self.relu(self.ps(self.conv(x)))

class RDN_Upscaler(nn.Module):
    """Residual-dense network that upsamples its input by a factor of 8.

    Pipeline: two shallow 3x3 convolutions, a chain of residual dense blocks
    whose outputs are all concatenated and fused, a global residual connection
    to the first shallow feature map, three x2 pixel-shuffle stages, and a
    reconstruction head ending in ``tanh``.
    """

    def __init__(self, in_channels:int=3, out_channels:int=3, base_channels:int=64, num_rdb_blocks:int=16, rdb_growth_rate:int=32, rdb_num_layers:int=8):
        super().__init__()

        # Shallow feature extraction.
        self.sfe1:nn.Conv2d = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.sfe2:nn.Conv2d = nn.Conv2d(base_channels, base_channels, 3, padding=1)

        # Chain of residual dense blocks; each one feeds the next.
        rdb_chain:list[ResidualDenseBlock] = []
        for _ in range(num_rdb_blocks):
            rdb_chain.append(ResidualDenseBlock(base_channels, rdb_growth_rate, rdb_num_layers))
        self.rdbs:nn.ModuleList = nn.ModuleList(rdb_chain)

        # Global feature fusion: 1x1 conv over all block outputs, then a 3x3 conv.
        fused_in_channels:int = base_channels * num_rdb_blocks
        self.gff:nn.Sequential = nn.Sequential(
            nn.Conv2d(fused_in_channels, base_channels, 1),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
        )

        # Three x2 stages: 64 -> 128 -> 256 -> 512.
        self.upscale:nn.Sequential = nn.Sequential(
            *(UpscaleBlock(base_channels, 2) for _ in range(3))
        )

        # Reconstruction head mapping features to the output image channels.
        self.reconstruction:nn.Sequential = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, out_channels, 3, padding=1),
        )

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Return the upscaled image, squashed into [-1, 1] by ``tanh``."""
        shallow:torch.Tensor = self.sfe1(x)
        current:torch.Tensor = self.sfe2(shallow)

        # Keep every block output: they are all concatenated for the global fusion.
        block_outputs:list[torch.Tensor] = []
        for rdb in self.rdbs:
            current = rdb(current)
            block_outputs.append(current)

        # Global residual learning: fused deep features plus the first shallow features.
        fused:torch.Tensor = self.gff(torch.cat(block_outputs, dim=1)) + shallow

        return torch.tanh(self.reconstruction(self.upscale(fused)))

In [None]:
def load_upscaler(file:str) -> RDN_Upscaler:
    """Build an ``RDN_Upscaler`` from ``model_settings`` and load saved weights into it.

    Args:
        file: file name of the state dict inside the ``Models`` directory of the
            current working directory.

    Returns:
        The model with the loaded weights, in eval mode, on the CPU (callers
        move it to the desired device).
    """
    model:RDN_Upscaler = RDN_Upscaler(model_settings.in_channels, model_settings.out_channels, model_settings.base_channels, model_settings.num_rdb_blocks, model_settings.rdb_growth_rate, model_settings.rdb_num_layers)
    model_path:str = os.path.join(os.getcwd(), "Models", file)
    # map_location="cpu": a checkpoint saved from a CUDA model would otherwise try to
    # restore its tensors on the GPU and fail on a machine without one.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# upscaler:RDN_Upscaler = load_upscaler("denoiser.model")

# param_size = 0
# for param in upscaler.parameters():
#     param_size += param.nelement() * param.element_size()

# size_all_mb = (param_size + 0) / (1024**2)
# print("Model size: {:.3f} MB with {} parameters".format(size_all_mb, sum(p.numel() for p in upscaler.parameters())))

In [None]:
class TrainResult:
    """Per-epoch loss history of a training run, serialisable to and from JSON."""

    # One entry per completed epoch.
    train_losses:list[float]
    test_losses:list[float]

    def __init__(self, train_losses:list[float], test_losses:list[float]) -> None:
        self.train_losses = train_losses
        self.test_losses = test_losses

    def to_json(self) -> str:
        """Serialise the instance attributes as an indented JSON object."""
        return json.dumps(vars(self), indent=4)

    @staticmethod
    def from_json(json_str:str) -> 'TrainResult':
        """Rebuild a ``TrainResult`` from a string produced by ``to_json``.

        The decoded keys are passed as keyword arguments to the constructor,
        so an unexpected key raises ``TypeError``.
        """
        return TrainResult(**json.loads(json_str))

In [None]:
def train(model:RDN_Upscaler) -> TrainResult:
    """Train ``model`` from scratch and return the recorded loss history.

    Thin wrapper around ``resume_train`` starting at epoch 0.
    """
    # The result must be returned: without it callers receive None.
    return resume_train(model, 0)
    
def _evaluate_mean_loss(model:nn.Module, loader:DataLoader, criterion:_Loss) -> float:
    """Return the per-sample mean of ``criterion`` over ``loader`` (NaN when it is empty)."""
    sample_count:int = len(loader.dataset)
    if sample_count == 0:
        # An empty test split has no loss; NaN keeps the history aligned with the train losses.
        return float("nan")

    model.eval()
    weighted_loss:float = 0.0
    with torch.no_grad():
        for input_batch, output_batch in loader:
            input_batch = input_batch.to(device)
            output_batch = output_batch.to(device)
            loss:torch.Tensor = criterion(model(input_batch), output_batch)
            # Weight by batch size so a shorter final batch does not skew the mean.
            weighted_loss += loss.item() * input_batch.shape[0]
    return weighted_loss / sample_count


def resume_train(model:RDN_Upscaler, epoch:int, split_seed:int=0) -> TrainResult:
    """Train ``model`` on the global ``dataset``, optionally continuing from a checkpoint.

    Reads the module-level configuration (``batch_size``, ``epochs``,
    ``test_proportion``, ``lr_scheduler``, ``device`` and the print/save intervals).

    Args:
        model: network to train; its weights must already be loaded when resuming.
        epoch: ``0`` (or less) starts a fresh run. A positive value ``N`` means
            "checkpoint ``N`` is already done": the history is read from
            ``./Models/train_result_epoch{N}.txt`` and training continues with
            epoch ``N + 1``.
        split_seed: seed of the dedicated generator used for the train/test split,
            so a resumed run gets the same split as the run it continues.

    Returns:
        The ``TrainResult`` holding one train loss and one test loss per epoch.

    Raises:
        ValueError: if the training split is empty.

    Note:
        Only model weights are checkpointed; the optimizer state restarts on resume.
    """
    epoch = max(0, epoch)

    train_result:TrainResult
    start_epoch:int
    if epoch <= 0:
        train_result = TrainResult([], [])
        start_epoch = 0
    else:
        train_res_path:str = f"./Models/train_result_epoch{epoch}.txt"
        with open(train_res_path, "r") as file:
            train_result = TrainResult.from_json(file.read())
        # Checkpoint `epoch` was written after that epoch finished, so continue with
        # the next one instead of repeating it (which duplicated a loss entry).
        start_epoch = epoch + 1

    test_size:int = int(test_proportion * len(dataset))
    train_size:int = len(dataset) - test_size
    if train_size <= 0:
        raise ValueError("The training split is empty: check the dataset and test_proportion")
    # A dedicated, seeded generator keeps the split identical between a run and its
    # resumption; drawing it from the global RNG would leak test samples into training.
    split_generator:torch.Generator = torch.Generator().manual_seed(split_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

    train_loader:DataLoader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader:DataLoader = DataLoader(test_dataset, batch_size, shuffle=False)

    model.to(device)

    # LambdaLR multiplies the optimizer's initial lr by the lambda's return value.
    # `lr_scheduler` returns an absolute learning rate, so the initial lr is 1.0 and
    # the lambda result becomes the effective lr (previously it was base_lr squared).
    step_offset:int = start_epoch * len(train_loader)  # optimizer steps already done when resuming
    optimizer:torch.optim.AdamW = torch.optim.AdamW(model.parameters(), lr=1.0, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_scheduler(step_offset + step))
    criterion:_Loss = nn.MSELoss()

    sample_to_print:int = min(nb_sample_to_print, batch_size)
    os.makedirs("./Models", exist_ok=True)

    print("Start training Upscaler")
    with tqdm.tqdm(total=epochs * len(train_dataset), desc="Training model") as pbar:
        pbar.update(start_epoch * len(train_dataset))
        for current_epoch in range(start_epoch, epochs):
            model.train()
            last_pred_output_batch:torch.Tensor = None
            last_output_batch:torch.Tensor = None

            weighted_loss:float = 0.0
            for input_batch, output_batch in train_loader:
                input_batch = input_batch.to(device)
                output_batch = output_batch.to(device)
                optimizer.zero_grad()

                predicted_image:torch.Tensor = model(input_batch)
                loss:torch.Tensor = criterion(predicted_image, output_batch)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # Advance the schedule once per optimizer step (its argument is a step index).
                scheduler.step()

                weighted_loss += loss.item() * input_batch.shape[0]
                # Detach so the kept sample does not hold the autograd graph alive.
                last_pred_output_batch = predicted_image.detach()
                last_output_batch = output_batch
                pbar.update(input_batch.shape[0])

            print_sample:bool = print_epoch_interval > 0 and ((current_epoch + 1) % print_epoch_interval == 0 or current_epoch + 1 == epochs)
            if print_sample and last_pred_output_batch is not None:
                # The last batch can be shorter than the requested number of samples.
                for i in range(0, min(sample_to_print, last_pred_output_batch.shape[0])):
                    pred_img:torch.Tensor = last_pred_output_batch[i].cpu() * 0.5 + 0.5
                    output_img:torch.Tensor = last_output_batch[i].detach().cpu() * 0.5 + 0.5
                    print("target image :")
                    plt.imshow(output_img.permute(1, 2, 0))
                    plt.show()
                    print("Predict image :")
                    plt.imshow(pred_img.permute(1, 2, 0))
                    plt.show()

            # Per-sample mean: exact even when the final batch is incomplete.
            train_result.train_losses.append(weighted_loss / train_size)
            train_result.test_losses.append(_evaluate_mean_loss(model, test_loader, criterion))

            save_epoch:bool = saving_epoch_interval >= 1 and ((current_epoch + 1) % saving_epoch_interval == 0 or current_epoch + 1 == epochs)
            if save_epoch:
                is_last_epoch:bool = current_epoch >= epochs - 1
                # torch.save overwrites an existing file, no prior removal is needed.
                path:str = "./Models/upscaler.model" if is_last_epoch else f"./Models/upscaler_epoch{current_epoch}.model"
                torch.save(model.state_dict(), path)

                result_path:str = "./Models/train_result.txt" if is_last_epoch else f"./Models/train_result_epoch{current_epoch}.txt"
                with open(result_path, "w") as file:
                    file.write(train_result.to_json())

        return train_result
    
# Seed every random number generator BEFORE the model is built: weight initialisation
# draws from torch's RNG, so seeding afterwards leaves the initial weights unseeded.
seed:int = 2442157549
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

upscaler:RDN_Upscaler = RDN_Upscaler(model_settings.in_channels, model_settings.out_channels, model_settings.base_channels, model_settings.num_rdb_blocks, model_settings.rdb_growth_rate, model_settings.rdb_num_layers)
train_result:TrainResult = train(upscaler)

# epoch_to_resume:int = 549
# upscaler:RDN_Upscaler = load_upscaler(f"upscaler_epoch{epoch_to_resume}.model")
# train_result:TrainResult = resume_train(upscaler, epoch_to_resume)

In [None]:
def show_training_result(train_result:'TrainResult') -> None:
    """Plot the train and test loss curves of ``train_result`` against the epoch number.

    Epochs are numbered from 1 on the x axis.

    Raises:
        ValueError: if the two loss histories do not have the same length.
    """
    train_count:int = len(train_result.train_losses)
    test_count:int = len(train_result.test_losses)
    # Both curves share the same x axis, so they need one value per epoch each.
    # (The previous check compared test_losses with itself and could never fail.)
    if train_count != test_count:
        raise ValueError(f"train_losses has {train_count} entries but test_losses has {test_count}")

    X:list[float] = [float(i) for i in range(1, test_count + 1)]
    plt.plot(X, train_result.train_losses, label='train loss')
    plt.plot(X, train_result.test_losses, label='test loss')
    plt.title("Losses", fontsize=18)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()

# Reload the loss history saved next to checkpoint 549 and plot it.
# NOTE(review): this file only exists after a run that reached epoch 549 — confirm
# the index matches a checkpoint present in ./Models.
train_res_path:str = f"./Models/train_result_epoch{549}.txt"
# `with` closes the file even if reading or parsing raises.
with open(train_res_path, "r") as file:
    train_result:TrainResult = TrainResult.from_json(file.read())
show_training_result(train_result)

In [None]:
def test_model(upscaler:RDN_Upscaler, test_size:int) -> None:
    """Upscale ``test_size`` random samples of the global ``dataset`` and display them.

    For every sample the target image is shown first, then the model's prediction.
    Samples are drawn with ``random.randint`` and may therefore repeat.
    """
    upscaler.to(device)
    with torch.no_grad():
        low_res:list[torch.Tensor] = []
        high_res:list[torch.Tensor] = []
        for _ in range(test_size):
            sample = dataset[random.randint(0, len(dataset) - 1)]
            low_res.append(sample[0])
            high_res.append(sample[1])

        # Batch the picks so the model runs once over all of them.
        input_batch:torch.Tensor = torch.stack(low_res).to(device)
        output_batch:torch.Tensor = torch.stack(high_res).to(device)
        pred_images:torch.Tensor = upscaler(input_batch)

        for target, prediction in zip(output_batch, pred_images):
            # Map [-1, 1] back to [0, 1] and move channels last for imshow.
            print("Output image : ")
            plt.imshow((target.cpu() * 0.5 + 0.5).permute(1, 2, 0))
            plt.show()

            print("Predicted image : ")
            plt.imshow((prediction.cpu() * 0.5 + 0.5).permute(1, 2, 0))
            plt.show()

# Load the weights saved for epoch 549 from the ./Models directory and preview the
# model's output on one randomly chosen dataset sample.
upscaler:RDN_Upscaler = load_upscaler(f"upscaler_epoch{549}.model")
test_model(upscaler, 1)