In [None]:
# Mount Google Drive. The dataset and checkpoint paths used later are relative
# ("drive/MyDrive/..."), so they assume the working directory is /content
# (the Colab default) — NOTE(review): confirm when running elsewhere.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import numpy as np
from scipy import ndimage
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.io import read_image, ImageReadMode
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import os
from torch import is_tensor, FloatTensor,tensor
import torch

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)
# Let cuDNN autotune convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True

cpu


In [None]:
# Training hyper-parameters: samples per DataLoader batch, and passes over the data.
batch_size, epochs = 30, 60

### Dataset

In [None]:
class SuperResolutionDataset(Dataset):
    """Paired low-/high-resolution patches returned as normalised luma (Y) tensors.

    Every ``*.png`` in ``high_resolution_dir`` is expected to have a counterpart with
    the same file name in ``low_resolution_dir``. Each item is ``(lr_y, hr_y)``; each
    tensor has shape (1, H, W) and holds ``(16 + 0.25679 R + 0.504 G + 0.09791 B) / 255``
    computed from the image read in RGB mode.

    Args:
        high_resolution_dir: directory of target (HR) PNG files.
        low_resolution_dir: directory of input (LR) PNG files with matching names.
        transform: optional callable applied to the returned LR Y tensor.
        target_transform: optional callable applied to the returned HR Y tensor.
    """

    def __init__(self, high_resolution_dir, low_resolution_dir, transform=None, target_transform=None):
        self.high_resolution_dir = high_resolution_dir
        self.low_resolution_dir = low_resolution_dir
        self.transform = transform
        self.target_transform = target_transform
        # Sorted so that item indices do not depend on os.listdir's arbitrary order.
        self.files = sorted(name for name in os.listdir(high_resolution_dir) if name.endswith('.png'))

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _rgb_to_y(image):
        """Map a (3, H, W) float RGB tensor in [0, 255] to a (1, H, W) Y tensor scaled by 1/255."""
        y = (16 + image[..., 0, :, :]*0.25679 + image[..., 1, :, :]*0.504 + image[..., 2, :, :]*0.09791) / 255
        return y[None, :, :]

    def _load_y(self, directory, name):
        """Read ``directory/name`` as RGB and return its normalised Y channel."""
        image = read_image(os.path.join(directory, name), ImageReadMode.RGB).float()
        return self._rgb_to_y(image)

    def __getitem__(self, item):
        name = self.files[item]
        lr_image_y = self._load_y(self.low_resolution_dir, name)
        hr_image_y = self._load_y(self.high_resolution_dir, name)
        # Transforms act on the tensors that are actually returned. Previously they
        # were applied to the RGB images (the HR one even through ``transform``
        # instead of ``target_transform``) and the results were discarded.
        if self.transform:
            lr_image_y = self.transform(lr_image_y)
        if self.target_transform:
            hr_image_y = self.target_transform(hr_image_y)
        return lr_image_y, hr_image_y

    
# Training data: HR/LR patch pairs from Drive, reshuffled every epoch.
SRDataset = SuperResolutionDataset(
    high_resolution_dir="drive/MyDrive/Datasets/HRPatches2",
    low_resolution_dir="drive/MyDrive/Datasets/LRPatches2",
)
train_dataloader = DataLoader(dataset=SRDataset, batch_size=batch_size, shuffle=True)

### Funkcja ucząca

In [None]:
def pre_train(dataloader, model, loss_fn, optimizer, device=None, log_every=4):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches with a ``.dataset``.
        model: module to train; it is switched to train mode.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, so the function no longer relies on a global.
        log_every: print the running mean loss of the last few batches (and its
            square root — RMSE when ``loss_fn`` is MSE) every this many batches.

    Returns:
        Mean of the per-batch losses over the epoch, or ``nan`` for an empty loader.
    """
    if device is None:
        device = next(model.parameters()).device
    size = len(dataloader.dataset)
    epoch_losses = []
    window = []  # losses since the last progress print
    seen = 0     # samples processed so far, including the current batch
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        batch_loss = loss.item()
        loss.backward()
        optimizer.step()

        epoch_losses.append(batch_loss)
        window.append(batch_loss)
        seen += len(X)
        if batch % log_every == 0:
            mean_loss = np.mean(window)
            print(f"loss: {mean_loss:>7f}, sqr {mean_loss**0.5:>7f}  [{seen:>6d}/{size:>6d}]")
            window = []
    return float(np.mean(epoch_losses)) if epoch_losses else float("nan")

### Model

In [None]:
class NeuralNetwork(nn.Module):
    """Kernel-predicting 2x upscaler for single-channel images.

    Input ``x`` has shape (B, 1, H, W) with H and W divisible by 4 (required by
    ``PixelUnshuffle(4)``); the output has shape (B, 1, 2H, 2W).

    A shared convolutional trunk at 1/4 resolution predicts 25 weights per pixel
    (one 5x5 kernel) at 1/4, 1/2 and full input resolution. Each kernel is applied
    to the matching 5x5 neighbourhood of the zero-padded input (stride 4, 2, 1).
    The three results are merged coarse-to-fine. Each merge step is a bicubic x2
    upsample, then the trainable 5x5 ``pyrConv`` filter, then addition. A final
    bicubic x2 upsample produces the output.
    """

    def __init__(self):
        super().__init__()
        # 5x5 binomial kernel ([1,4,6,4,1] outer product), shape (1, 1, 5, 5), used to
        # initialise pyrConv. NOTE(review): the taps sum to 256, so the 1/100 factor
        # gives a DC gain of 2.56 rather than 1; kept because pyrConv is trainable
        # and saved checkpoints overwrite it — confirm the intended normalisation.
        self.kernel = (1.0/100)*torch.tensor([[[[1, 4, 6, 4, 1],[4, 16, 24, 16, 4],[6, 24, 36, 24, 6], [4, 16, 24, 16, 4],[1, 4, 6, 4, 1]]]])
        self.downsample = nn.PixelUnshuffle(4)
        self.conv1a = nn.Conv2d(16 , 64 , 3 , padding=1)
        self.conv1b = nn.Conv2d(64, 64, 3, padding=1)
        self.conv1qa = nn.Conv2d(64, 64, 3, padding=1)
        self.conv1qb = nn.Conv2d(64, 64, 3, padding=1)
        self.conv1ha = nn.Conv2d(16, 64, 3, padding=1)
        self.conv1hb = nn.Conv2d(64, 64, 3, padding=1)
        self.conv1fa = nn.Conv2d(4, 64, 3, padding=1)
        self.conv1fb = nn.Conv2d(64, 64, 3, padding=1)
        self.relu = nn.LeakyReLU()
        # 16 repeats of Conv-LeakyReLU-Conv = 48 modules. The order and indices match
        # the original hand-written list, so saved state_dict keys
        # (stack.0 ... stack.47) still load.
        # NOTE(review): the second conv of one repeat feeds the first conv of the next
        # with no activation in between — confirm this is intended.
        stack_layers = []
        for _ in range(16):
            stack_layers += [
                nn.Conv2d(64, 64, 3, padding=1),
                nn.LeakyReLU(),
                nn.Conv2d(64, 64, 3, padding=1),
            ]
        self.stack = nn.Sequential(*stack_layers)
        self.upsample2 = nn.PixelShuffle(2)
        self.upsample4 = nn.PixelShuffle(4)
        self.conv2q = nn.Conv2d(64, 25 , 3 , padding=1)
        self.conv2h = nn.Conv2d(64, 25, 3, padding=1)
        self.conv2f = nn.Conv2d(64, 25, 3, padding=1)
        # NOTE(review): conv3q/conv3h/conv3f are not used in forward(); they are kept
        # so existing checkpoints still load with strict=True.
        self.conv3q = nn.Conv2d(25 , 1 , 5, padding='same')
        self.conv3h = nn.Conv2d(25, 1, 5, padding='same')
        self.conv3f = nn.Conv2d(25, 1, 5, padding='same')

        self.pyrConv = nn.Conv2d(1 ,1 ,5 , padding="same" , bias=False)

        self.pyrConv.weight = nn.Parameter(self.kernel)

        self.normalUp = nn.Upsample(scale_factor  = 2 , mode='bicubic')
        self.padLayer = nn.ZeroPad2d(2)

    @staticmethod
    def _patches(padded, stride, out_h, out_w):
        """Return the 5x5 neighbourhoods of channel 0 of ``padded``, taken every ``stride`` pixels.

        Output shape is (B, 25, out_h, out_w), where
        ``out[b, 5*r + c, i, j] == padded[b, 0, i*stride + r, j*stride + c]``.
        This is a vectorised equivalent of the original per-pixel Python loops.
        """
        batch = padded.shape[0]
        cols = nn.functional.unfold(padded[:, :1], kernel_size=5, stride=stride)
        rows = (padded.shape[2] - 5) // stride + 1
        width = (padded.shape[3] - 5) // stride + 1
        return cols.reshape(batch, 25, rows, width)[:, :, :out_h, :out_w]

    def forward(self, x):
        """Upscale ``x`` of shape (B, 1, H, W) to (B, 1, 2H, 2W); H, W divisible by 4."""
        # Shared trunk at 1/4 resolution: (B, 16, H/4, W/4) -> (B, 64, H/4, W/4).
        common = self.downsample(x)
        common = self.conv1a(common)
        common = self.relu(common)
        common = self.stack(common)
        common = self.conv1b(common)
        common = self.relu(common)

        # Predicted 25-tap kernels at 1/4 resolution.
        quarter = self.relu(self.conv1qa(common))
        quarter = self.relu(self.conv1qb(quarter))
        quarter = self.relu(self.conv2q(quarter))

        # At 1/2 resolution: PixelShuffle(2) turns 64 channels into 16.
        half = self.relu(self.conv1ha(self.upsample2(common)))
        half = self.relu(self.conv1hb(half))
        half = self.relu(self.conv2h(half))

        # At full resolution: PixelShuffle(4) turns 64 channels into 4.
        full = self.relu(self.conv1fa(self.upsample4(common)))
        full = self.relu(self.conv1fb(full))
        full = self.relu(self.conv2f(full))

        # Weighted sum of each 5x5 input neighbourhood with its predicted kernel.
        # Allocations follow x's device/dtype instead of a module-level global.
        h, w = x.shape[2], x.shape[3]
        padded = self.padLayer(x)
        e = torch.sum(full * self._patches(padded, 1, h, w), 1, keepdim=True)
        eh = torch.sum(half * self._patches(padded, 2, h // 2, w // 2), 1, keepdim=True)
        eq = torch.sum(quarter * self._patches(padded, 4, h // 4, w // 4), 1, keepdim=True)

        # Coarse-to-fine merge: bicubic x2, pyrConv filter, then add the finer level.
        eq = self.pyrConv(self.normalUp(eq))
        eh = self.pyrConv(self.normalUp(eh + eq))
        e = eh + e
        return self.normalUp(e)


### Uczenie sieci

In [None]:
# Train from scratch. The checkpoint is overwritten after every epoch, so an
# interrupted session keeps the weights of the last completed epoch.
# Seeding torch's global RNG makes weight init (and DataLoader shuffling, which
# uses the global RNG when no generator is given) repeatable; GPU runs may still
# differ slightly because cudnn.benchmark is enabled.
torch.manual_seed(0)
torch.cuda.empty_cache()
modelLeaky = NeuralNetwork().to(device)
cost = nn.MSELoss()
opt = optim.Adam(modelLeaky.parameters() ,lr= 0.0001 )

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    pre_train(train_dataloader, modelLeaky, cost, opt)
    torch.save(modelLeaky.state_dict(), 'drive/MyDrive/KPNLP.model')
print("Done!")

In [None]:
# Resume training from the checkpoint written by the training loop.
torch.cuda.empty_cache()
modelLeaky = NeuralNetwork().to(device)
# map_location lets a checkpoint saved during a GPU session load when `device`
# is the CPU (torch.load would otherwise try to restore tensors onto CUDA).
modelLeaky.load_state_dict(torch.load('drive/MyDrive/KPNLP.model', map_location=device))


cost = nn.MSELoss()
# NOTE(review): only the model state_dict is checkpointed, so Adam's moment
# estimates restart from zero here and the epoch counter restarts at 1.
opt = optim.Adam(modelLeaky.parameters() ,lr= 0.0001 )
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    pre_train(train_dataloader, modelLeaky, cost, opt)
    torch.save(modelLeaky.state_dict(), 'drive/MyDrive/KPNLP.model')
print("Done!")

### Sprawdzanie wyników na podanym zdjęciu

In [None]:
def PSNR(pred, target):
    """Peak signal-to-noise ratio in dB for images in the 8-bit range (peak 255).

    Args:
        pred, target: array-likes of the same shape. They are converted to float64
            before subtracting, so unsigned-integer (e.g. uint8) inputs cannot wrap.

    Returns:
        ``10 * log10(255**2 / MSE)``, or ``inf`` when the images are identical.
    """
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return 10 * np.log10(255 ** 2 / mse)
from PIL import Image

# Super-resolve one image. The network upscales only the luma (Y) channel; the
# chroma channels (Cb, Cr) are upscaled with plain bicubic interpolation.
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load('drive/MyDrive/KPNLP.model' , map_location=device))
model.eval()

# Read as RGB explicitly (as SuperResolutionDataset does), so RGBA or grayscale
# PNGs still yield exactly three channels. Shape (1, 3, H, W), values in [0, 255].
lr_image = read_image("img (12).png", ImageReadMode.RGB)[None , :].float()
# BT.601 studio-range RGB -> YCbCr. Y uses the same coefficients and 1/255
# scaling as the training data.
lr_image_y = (16+ lr_image[..., 0, :, :]*0.25679 + lr_image[..., 1, :, :]*0.504 + lr_image[..., 2, :, :]*0.09791)/255
lr_image_y = lr_image_y[None , :, :]  # (1, 1, H, W)
lr_image_cb = (128 - 37.945*lr_image[..., 0, :, :]/256 - 74.494*lr_image[..., 1, :, :]/256 + 112.439*lr_image[..., 2, :, :]/256)
lr_image_cr = (128 + 112.439*lr_image[..., 0, :, :]/256 - 94.154*lr_image[..., 1, :, :]/256 - 18.285*lr_image[..., 2, :, :]/256)
hr_cb = nn.functional.interpolate(lr_image_cb[None , :,:],scale_factor = 2 , mode='bicubic').detach().numpy()[0,0]
hr_cr = nn.functional.interpolate(lr_image_cr[None , :,:],scale_factor = 2 , mode='bicubic').detach().numpy()[0,0]
print(str(lr_image_y.shape) , lr_image_y.dtype)
lr_image_y = lr_image_y.to(device)
with torch.no_grad():  # inference only: no autograd graph needed
    pom = model(lr_image_y)
# Predicted luma back in the [0, 255] range, shape (2H, 2W).
pom2 = np.clip(pom.detach().cpu().numpy()[0,0] * 255, 0, 255)
hr_cr = np.clip(hr_cr, 0, 255)
hr_cb = np.clip(hr_cb , 0, 255)
# Inverse of the studio-range (8-bit ITU-R BT.601) transform above. The
# full-range JPEG inverse used previously ignored Y's 16 offset and 219-level
# range, which compressed contrast: black came out near 16 and white near 235.
y_scaled = 1.164383 * (pom2 - 16)
r = y_scaled + 1.596027 * (hr_cr - 128)
g = y_scaled - 0.391762 * (hr_cb - 128) - 0.812968 * (hr_cr - 128)
b = y_scaled + 2.017232 * (hr_cb - 128)
# Round and clip before casting. astype(uint8) does not saturate
# out-of-range floats, so they would wrap around or be undefined.
improvedImg = np.clip(np.rint(np.dstack((r,g,b))), 0, 255).astype(np.uint8)
plt.imshow(improvedImg)
plt.show()
im = Image.fromarray(improvedImg)
im.save("img (12)Pred.png")

hr_image = read_image("img (12) (1).png", ImageReadMode.RGB).float().numpy()
hr_image = np.moveaxis(hr_image, 0, -1)  # (3, H', W') -> (H', W', 3)
print("psnr: " , PSNR(improvedImg , hr_image))
# Baseline: bicubic upscaling of the RGB image. It is clipped to the valid pixel
# range so it is scored on the same footing as the (clipped) network output.
pred_biciubic = nn.functional.interpolate(lr_image,scale_factor = 2 , mode='bicubic').detach().numpy()[0]
pred_biciubic = np.clip(np.moveaxis(pred_biciubic, 0, -1), 0, 255)
print("psnr: bicubic" , PSNR(pred_biciubic , hr_image))