In [12]:
from pathlib import Path
import random
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
import os
import numpy as np
import wavemix.sisr as sisr
from PIL import Image
from torchinfo import summary
import gc
import json

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Dataset locations, all resolved relative to the notebook's working directory.
data_path = Path.cwd() / "data"

# Bicubic x2 low-resolution inputs.
train_data_folder = data_path / "DIV2K" / "DIV2K_train_LR_bicubic" / "X2"
test_data_folder = data_path / "DIV2K" / "DIV2K_valid_LR_bicubic" / "X2"

# High-resolution ground-truth images.
train_data_targest_folder = data_path / "DIV2K" / "DIV2K_train_HR"
test_data_targest_folder = data_path / "DIV2K" / "DIV2K_valid_HR"

# Destination for super-resolved predictions; created (with parents) if missing.
output_folder = data_path / "output_bicubic_x2_valid"
output_folder.mkdir(parents=True, exist_ok=True)

In [4]:
class WaveMixSR(nn.Module):
    """WaveMix-based single-image super-resolution network.

    Input channel 0 is upsampled, projected to ``final_dim`` features, refined
    by a stack of residual ``sisr.Level1Waveblock`` layers and projected back
    to one channel. Input channels 1:3 are only upsampled bilinearly. The two
    results are concatenated into the output.

    Args:
        depth: number of ``Level1Waveblock`` layers.
        mult: forwarded to every ``Level1Waveblock``.
        ff_channel: forwarded to every ``Level1Waveblock``.
        final_dim: channel width of the learned feature path; the intermediate
            convolutions use ``final_dim // 2`` channels.
        dropout: forwarded to every ``Level1Waveblock``.
        scale_factor: spatial upsampling factor, applied identically to both
            paths so their outputs can be concatenated.
    """

    def __init__(
        self,
        *,
        depth,
        mult = 1,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.3,
        scale_factor = 2
    ):
        super().__init__()

        half_dim = final_dim // 2

        # Submodule names and Sequential positions must stay as they are: they
        # determine the state_dict keys (layers.*, final.0/1, path1.1/2) that
        # saved checkpoints are loaded against.
        self.layers = nn.ModuleList([
            sisr.Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)
            for _ in range(depth)
        ])

        # Projects the refined features back to a single channel.
        self.final = nn.Sequential(
            nn.Conv2d(final_dim, half_dim, 3, stride=1, padding=1),
            nn.Conv2d(half_dim, 1, 1)
        )

        # Learned path for channel 0: upsample first, then lift to final_dim features.
        self.path1 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners = False),
            nn.Conv2d(1, half_dim, 3, 1, 1),
            nn.Conv2d(half_dim, final_dim, 3, 1, 1)
        )

        # Channels 1:3 are only interpolated. This uses the same scale_factor as
        # path1 (previously int(scale_factor), which produced a size mismatch in
        # torch.cat for non-integer factors such as 1.5).
        self.path2 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners = False),
        )

    def forward(self, img):
        """Upsample ``img`` by ``scale_factor``.

        Args:
            img: 4-D tensor (N, C, H, W), expected to have at least 3 channels.
                Channel 0 goes through the learned path, channels 1 and 2 are
                interpolated, and any further channels are ignored.
                NOTE(review): the y/crcb naming suggests YCbCr input — confirm
                against the training pipeline.

        Returns:
            Tensor (N, 3, floor(H*s), floor(W*s)) with s = ``scale_factor``.
        """
        y = self.path1(img[:, 0:1, :, :])

        for block in self.layers:
            y = block(y) + y  # residual connection around each wave block

        y = self.final(y)

        crcb = self.path2(img[:, 1:3, :, :])

        return torch.cat((y, crcb), dim=1)

In [14]:
# weights_only=True limits unpickling to tensors and primitive containers, so a
# tampered checkpoint cannot run arbitrary code. It also avoids the FutureWarning
# newer torch versions emit when the argument is left unset (requires torch >= 1.13).
weights = torch.load('weights.pth', map_location=device, weights_only=True)

  weights = torch.load('weights.pth', map_location=device)


In [15]:
# Hyperparameters must match the saved checkpoint: load_state_dict is strict by
# default and rejects missing/unexpected keys or shape mismatches.
# scale_factor=2 pairs with the X2 data folders configured above.
model = WaveMixSR(
    depth=4,
    mult=1,
    ff_channel=144,
    final_dim=144,
    dropout=0.3,
    scale_factor=2,
).to(device)
model.load_state_dict(weights)

<All keys matched successfully>

In [18]:
# Super-resolve every validation LR image, save the prediction, and record the
# per-image MSE (on the 0-255 scale) against the matching HR ground truth.
to_tensor = torchvision.transforms.ToTensor()
losses = []

model.eval()

# sorted(): os.listdir order is arbitrary, which made losses.json non-reproducible.
for image in sorted(os.listdir(test_data_folder)):
    # NOTE(review): the model treats channel 0 as luma and channels 1:3 as chroma,
    # but RGB is fed here. If the checkpoint was trained on YCbCr, convert with
    # .convert('YCbCr') and convert the output back — TODO confirm.
    with Image.open(test_data_folder / image) as lr_img:
        lr_tensor = to_tensor(lr_img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        sr_tensor = model(lr_tensor)

    # (1, C, H, W) float tensor -> (H, W, C) uint8 array for PIL.
    sr_array = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    sr_array = (np.clip(sr_array, 0, 1) * 255).astype(np.uint8)

    # The HR file name is the LR name with 'x2' removed (e.g. 0801x2.png -> 0801.png).
    with Image.open(test_data_targest_folder / image.replace('x2', '')) as hr_img:
        hr_array = np.array(hr_img.convert("RGB"), dtype=np.float64)

    if hr_array.shape != sr_array.shape:
        raise ValueError(f"{image}: SR output {sr_array.shape} does not match HR target {hr_array.shape}")

    # Compute in float64: uint8 subtraction wraps around (3 - 5 == 254) and
    # squaring overflows again, which corrupted the original MSE values.
    loss = float(np.mean((hr_array - sr_array.astype(np.float64)) ** 2))
    losses.append(loss)

    Image.fromarray(sr_array).save(output_folder / image)

losses_file = output_folder / "losses.json"
with open(losses_file, 'w') as f:
    json.dump(losses, f)

# NaN instead of ZeroDivisionError when the input folder is empty.
average_loss = sum(losses) / len(losses) if losses else float('nan')
average_loss

In [None]:
# Drop the model and the loaded checkpoint. `weights` was loaded with
# map_location=device and holds its own copy of every parameter, so deleting
# only `model` would not release that memory.
del model, weights
gc.collect()
# gc.collect() only removes Python references. PyTorch's CUDA caching allocator
# keeps freed blocks reserved until empty_cache() hands them back to the driver.
if torch.cuda.is_available():
    torch.cuda.empty_cache()