In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import glob
import cv2
import numpy as np
import PIL.Image as Image
import cv2
import copy
from PIL import ImageFilter
import matplotlib.pyplot as plt

In [None]:
# Super-resolution factor: low-resolution inputs are synthesised by shrinking
# by this factor and upscaling back. Used as HeadDataset's default ``scale``.
SCALE: int = 3

In [None]:
class HeadDataset(Dataset):
    """Super-resolution training set built from square patches of input images.

    Every image is read once in ``__init__`` and cut into ``patch_size`` x
    ``patch_size`` RGB windows taken every ``stride`` pixels. All windows are
    kept in memory in ``self.files``. ``__getitem__`` synthesises the
    low-resolution input for a window on the fly. It Gaussian-blurs the
    window, bicubically downscales it by ``scale``, then bicubically upscales
    it back, so the input and target share the same spatial size.

    Args:
        files: iterable of image paths readable by ``cv2.imread``.
        scale: down/up-scaling factor used to synthesise the LR input.
        stride: step between neighbouring windows. It is also the window size
            when ``patch_size`` is not given (the original behaviour).
        patch_size: side length of each square window. Defaults to ``stride``.
        blur_radius: radius of the Gaussian blur applied before downscaling.

    Raises:
        FileNotFoundError: if a path cannot be decoded as an image.
    """

    def __init__(self, files, scale=SCALE, stride=30, patch_size=None, blur_radius=2):
        self.scale = scale
        self.blur_radius = blur_radius
        size = stride if patch_size is None else patch_size
        image_windows = []
        for image in files:
            bgr = cv2.imread(image)
            if bgr is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {image}")
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]
            # ``+ 1`` so a window ending exactly on the image border is kept.
            for i in range(0, h - size + 1, stride):
                for j in range(0, w - size + 1, stride):
                    image_windows.append(img[i:i + size, j:j + size, :])
        self.files = image_windows

    def __len__(self):
        """Number of extracted windows."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` for window ``idx``.

        Both are float32 arrays laid out channels-first (3, H, W), with H and W
        the window size cropped down to a multiple of ``scale``.
        """
        scale = self.scale
        hr = Image.fromarray(self.files[idx])
        hr_width = (hr.width // scale) * scale
        hr_height = (hr.height // scale) * scale
        # Crop the target to a multiple of ``scale`` so it has the same size as
        # (and is spatially aligned with) the LR round-trip below. This is a
        # no-op when the window size is already divisible by ``scale``.
        hr = hr.crop((0, 0, hr_width, hr_height))
        lr = hr.filter(ImageFilter.GaussianBlur(radius=self.blur_radius))
        lr = lr.resize((hr_width // scale, hr_height // scale), resample=Image.BICUBIC)
        lr = lr.resize((lr.width * scale, lr.height * scale), resample=Image.BICUBIC)
        # HWC uint8 -> CHW float32, the layout torch convolutions expect.
        hr = np.moveaxis(np.array(hr).astype(np.float32), 2, 0)
        lr = np.moveaxis(np.array(lr).astype(np.float32), 2, 0)
        return lr, hr

In [None]:
# Glob pattern (not a regex) for the training images. Sorted so the dataset's
# index order, and hence shuffling under a fixed seed, is reproducible across
# filesystems.
train_folder_regex = './Train/*'
train_files = sorted(glob.glob(train_folder_regex))
assert train_files, f"No training images matched {train_folder_regex!r}"
train_dataset = HeadDataset(train_files)
# An empty dataset would otherwise surface as an opaque sampler error below.
assert len(train_dataset) > 0, "Training images are smaller than one window; no patches extracted"
# NOTE(review): num_workers > 0 needs worker processes to load HeadDataset.
# Under the "spawn" start method, classes defined in a notebook may fail to
# unpickle; fall back to num_workers=0 if that happens.
train_data = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)

In [None]:
# Load every test image as a float32 tensor of shape (1, 3, H, W): RGB,
# channels-first, with a leading batch dimension of one.
# NOTE(review): these are the unmodified images (no blur/downscale as in
# HeadDataset.__getitem__). Confirm the LR input is built downstream.
test_data = []
# Sorted so evaluation order is stable across runs and filesystems.
for image in sorted(glob.glob('./Test/*/*.bmp')):
    bgr = cv2.imread(image)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # HWC -> CHW; made contiguous so the tensor below has a standard layout.
    chw = np.ascontiguousarray(np.moveaxis(rgb.astype(np.float32), 2, 0))
    img = torch.from_numpy(chw).unsqueeze(0)
    test_data.append(img)
assert test_data, "No test images matched './Test/*/*.bmp'"