In [None]:
from pathlib import Path
import torch
from torch import nn
import torchvision
import os
import numpy as np
import wavemix.sisr as sisr
from PIL import Image
import kornia

In [2]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# All notebook I/O lives under ./data, relative to the kernel's working directory.
data_path = Path.cwd() / "data"

# Low-resolution input images to be super-resolved.
data_folder = data_path / "lowres_images"

# Super-resolved results are written here.
output_folder = data_path / "output_images"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs,
# and parents=True also creates ./data if it does not exist yet.
output_folder.mkdir(parents=True, exist_ok=True)

In [4]:
class WaveMixSR(nn.Module):
    """WaveMix-based single-image super-resolution network.

    Channel 0 of the input (luma when fed ``kornia.color.rgb_to_ycbcr`` output,
    as in this notebook) goes through the learned path. It is bilinearly
    upsampled by ``scale_factor``, lifted to ``final_dim`` feature channels,
    refined by ``depth`` residual ``sisr.Level1Waveblock`` stages, and projected
    back to a single channel. Channels 1-2 (chroma in that setting) are only
    bilinearly upsampled by the same ``scale_factor``.

    Args:
        depth: Number of residual ``Level1Waveblock`` stages.
        mult: Forwarded to each ``Level1Waveblock``.
        ff_channel: Forwarded to each ``Level1Waveblock``.
        final_dim: Feature width of the luma path (also forwarded to the blocks).
        dropout: Dropout probability forwarded to each ``Level1Waveblock``.
        scale_factor: Spatial upsampling factor, applied identically to the luma
            and chroma paths so their outputs can be concatenated.

    Note:
        The submodule names ``layers``, ``final``, ``path1`` and ``path2`` define
        the state-dict key layout. Renaming them breaks loading of pretrained
        checkpoints.
    """

    def __init__(
        self,
        *,
        depth,
        mult = 1,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.3,
        scale_factor = 2
    ):
        super().__init__()

        hidden_dim = final_dim // 2

        # Residual refinement stages operating at the upsampled resolution.
        self.layers = nn.ModuleList([
            sisr.Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)
            for _ in range(depth)
        ])

        # Project the refined features back to a single luma channel.
        self.final = nn.Sequential(
            nn.Conv2d(final_dim, hidden_dim, 3, stride=1, padding=1),
            nn.Conv2d(hidden_dim, 1, 1)
        )

        # Luma path: upsample first, then lift 1 -> final_dim channels.
        self.path1 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners = False),
            nn.Conv2d(1, hidden_dim, 3, 1, 1),
            nn.Conv2d(hidden_dim, final_dim, 3, 1, 1)
        )

        # Chroma path: plain bilinear upsampling. It must use exactly the same
        # factor as path1 (previously int(scale_factor)), otherwise non-integer
        # factors produce mismatched spatial sizes and torch.cat fails.
        self.path2 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners = False),
        )

    def forward(self, img):
        """Super-resolve a batch of images.

        Args:
            img: 4-D tensor (B, C, H, W) with C >= 3. Channel 0 takes the learned
                path, channels 1-2 are only upsampled, and any further channels
                are ignored.

        Returns:
            Tensor (B, 3, H', W'), where H' and W' are H and W upsampled by
            ``scale_factor``, in the same channel order as the input.
        """
        luma = self.path1(img[:, 0:1, :, :])

        for block in self.layers:
            luma = block(luma) + luma  # residual connection around each stage

        luma = self.final(luma)

        chroma = self.path2(img[:, 1:3, :, :])

        return torch.cat((luma, chroma), dim=1)

In [5]:
# Load the pretrained state_dict straight onto the selected device.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load.
weights = torch.load('weights.pth', map_location=device, weights_only=True)

# These hyperparameters must match the ones the checkpoint was trained with.
# load_state_dict (strict by default) raises on missing/unexpected keys or shape mismatches.
model = WaveMixSR(depth = 4, mult = 1, ff_channel = 144, final_dim = 144, dropout = 0.3, scale_factor = 2).to(device)
model.load_state_dict(weights)
# Inference mode: disables dropout and uses BatchNorm running statistics.
model.eval()

WaveMixSR(
  (layers): ModuleList(
    (0-3): 4 x Level1Waveblock(
      (feedforward): Sequential(
        (0): Conv2d(144, 144, kernel_size=(1, 1), stride=(1, 1))
        (1): GELU(approximate='none')
        (2): Dropout(p=0.3, inplace=False)
        (3): Conv2d(144, 144, kernel_size=(1, 1), stride=(1, 1))
        (4): ConvTranspose2d(144, 144, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (5): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (reduction): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (final): Sequential(
    (0): Conv2d(144, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(72, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (path1): Sequential(
    (0): Upsample(scale_factor=2.0, mode='bilinear')
    (1): Conv2d(1, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(72, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (path2): Sequential(
   

In [6]:
# Converts a PIL image into a (C, H, W) float tensor; ToTensor also rescales
# 8-bit pixel values into [0, 1].
transform_target = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor()]
)

In [None]:
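# Note: the code below is disabled because it is not needed for inference; it was only used for testing and for comparing outputs against ground-truth images.
# To re-enable it, also add `import json` to the imports cell (it is used by `json.dump`), and check that
# `data_folder/image.replace('x2','')` points at the intended high-resolution targets. As written, it reads them from the same folder as the low-resolution inputs.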

# losses = []
# test_data_folder_tmp = data_folder
# for image in os.listdir(test_data_folder_tmp):
#     img = Image.open(test_data_folder_tmp/image)
#     img = transform_target(img)
#     img = kornia.color.rgb_to_ycbcr(img)
#     img = img.unsqueeze(0).to(device)

#     with torch.no_grad():
#         output = model(img)
#         output = kornia.color.ycbcr_to_rgb(output)
#         output = output.squeeze(0)
#         output = output.to('cpu')
#         output = output.detach().numpy()
#         output = output.transpose(1,2,0)
#         output = np.clip(output, 0, 1)
#         output = (output*255).astype(np.uint8)
#         target_image = Image.open(data_folder/image.replace('x2',''))
#         if output.shape[0] < target_image.size[1]:
#             output = np.pad(output, ((0, target_image.size[1] - output.shape[0]), (0, 0), (0, 0)), mode='constant')
#         elif output.shape[0] > target_image.size[1]:
#             output = output[:target_image.size[1], :, :]
#         if output.shape[1] < target_image.size[0]:
#             output = np.pad(output, ((0, 0), (0, target_image.size[0] - output.shape[1]), (0, 0)), mode='constant')
#         elif output.shape[1] > target_image.size[0]:
#             output = output[:, :target_image.size[0], :]
#         loss = np.mean((np.array(target_image) - output)**2)
#         losses.append(loss)
#         output = Image.fromarray(output)
#         output.save(output_folder/image)

# losses_file = os.path.join(output_folder, "losses.json")
# with open(losses_file, 'w') as f:
#     json.dump(losses, f)
    
# average_loss = sum(losses)/len(losses)

In [None]:
def upscale_image(lowres_img: Image.Image, sr_model: nn.Module, dev: torch.device) -> Image.Image:
    """Super-resolve one PIL image with a model that maps YCbCr batches to YCbCr batches.

    Args:
        lowres_img: Input image in any PIL mode. It is converted to RGB first, so
            grayscale, RGBA and palette images also give the three channels that
            ``kornia.color.rgb_to_ycbcr`` needs (any alpha channel is discarded).
        sr_model: Super-resolution network (e.g. ``WaveMixSR``) that is fed a
            (1, 3, H, W) YCbCr batch.
        dev: Device the model lives on; the input batch is moved there.

    Returns:
        The upscaled image as an 8-bit RGB ``PIL.Image``.
    """
    rgb = transform_target(lowres_img.convert("RGB"))
    ycbcr = kornia.color.rgb_to_ycbcr(rgb).unsqueeze(0).to(dev)

    with torch.no_grad():
        sr_rgb = kornia.color.ycbcr_to_rgb(sr_model(ycbcr))

    # (1, 3, H, W) tensor -> (H, W, 3) array, the layout PIL expects.
    pixels = sr_rgb.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    pixels = np.clip(pixels, 0, 1)
    # Round instead of truncating so pixel values are not biased down by half a level.
    return Image.fromarray(np.round(pixels * 255).astype(np.uint8))


# Only process files Pillow recognises by extension, so stray entries such as
# .ipynb_checkpoints directories or .DS_Store files do not abort the run.
supported_exts = Image.registered_extensions()
lowres_paths = sorted(
    p for p in data_folder.iterdir()
    if p.is_file() and p.suffix.lower() in supported_exts
)

for lowres_path in lowres_paths:
    # The context manager closes the file handle once the pixels have been read.
    with Image.open(lowres_path) as lowres_img:
        upscaled = upscale_image(lowres_img, model, device)
    upscaled.save(output_folder / lowres_path.name)