In [7]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Per-channel (R, G, B) statistics applied by transforms.Normalize below;
# these are the values conventionally used with pre-trained PyTorch models.
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))


class ImageDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs read from PNG files.

    Every ``*.png`` file directly inside ``root`` is one sample. Each sample is
    resized twice with bicubic interpolation -- once to the high-resolution
    shape and once to that shape divided by ``lr_factor`` -- then converted to
    a tensor and normalized with the module-level ``mean`` / ``std``.

    Args:
        root: Directory containing the ``.png`` images (str or path-like).
        hr_shape: ``(height, width)`` of the high-resolution output.
        lr_factor: Integer factor by which the high-resolution height and
            width are floor-divided to obtain the low-resolution size.
    """

    def __init__(self, root, hr_shape, lr_factor):
        hr_height, hr_width = hr_shape
        # Resize takes (height, width); use both dimensions so that
        # non-square hr_shape values are honoured.
        lr_size = (hr_height // lr_factor, hr_width // lr_factor)
        hr_size = (hr_height, hr_width)

        # Transforms for low resolution images and high resolution images
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize(lr_size, Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize(hr_size, Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        # Sorted so that index -> file mapping is deterministic across runs.
        self.files = sorted(glob.glob(os.path.join(root, "*.png")))

    def __getitem__(self, index):
        """Return ``{"lr": tensor, "hr": tensor}`` for the image at ``index``.

        ``index`` wraps around modulo the number of files.

        Raises:
            IndexError: If no ``.png`` files were found in ``root``.
        """
        if not self.files:
            # Without this guard the modulo below raises ZeroDivisionError.
            raise IndexError("ImageDataset is empty: no .png files were found")

        path = self.files[index % len(self.files)]
        # Close the file handle promptly, and force 3 channels so that
        # grayscale / RGBA PNGs match the 3-channel Normalize statistics.
        with Image.open(path) as img:
            img = img.convert("RGB")

        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        """Return the number of PNG files found in ``root``."""
        return len(self.files)

In [8]:
# Build the path with os.path.join instead of a backslash literal: "\D" is an
# invalid escape sequence, and a hard-coded "\" separator only works on Windows.
d_set = ImageDataset(os.path.join("data", "DIV2K_train_HR", "DIV2K_train_HR"), (480, 480), 2)
len(d_set)



800

In [9]:
# Peek at the first sample only (a dict with "lr" and "hr" tensors).
image = d_set[0]
print(image)

{'lr': tensor([[[-0.7822, -1.1589, -1.3130,  ..., -1.8268, -1.8439, -1.8439],
         [-0.7137, -0.9192, -1.2103,  ..., -1.7583, -1.8268, -1.8439],
         [-0.7993, -0.7993, -0.9192,  ..., -1.7583, -1.7925, -1.8439],
         ...,
         [-1.6898, -1.4672, -1.1589,  ..., -0.5082, -0.4226, -0.6281],
         [-1.7583, -1.4500, -0.9192,  ..., -0.4911, -0.3541, -1.0733],
         [-1.6213, -1.3987, -1.0733,  ..., -1.0904, -0.8849, -1.6042]],

        [[-0.9853, -1.2304, -1.2479,  ..., -1.9132, -1.9657, -1.9482],
         [-0.8627, -0.9853, -1.2304,  ..., -1.7906, -1.9307, -1.9657],
         [-0.8452, -0.8978, -0.8803,  ..., -1.7906, -1.8782, -1.9482],
         ...,
         [-1.7556, -1.6155, -1.3529,  ..., -0.6877, -0.5651, -0.7402],
         [-1.8431, -1.5455, -1.1253,  ..., -0.5651, -0.5126, -1.1779],
         [-1.6856, -1.4580, -1.2654,  ..., -1.0728, -1.1078, -1.6681]],

        [[-0.7587, -1.0201, -1.0724,  ..., -1.7870, -1.8044, -1.7870],
         [-0.5495, -0.7761, -0.8807,  

In [29]:
from torch.utils.data import DataLoader
# Batches of 8 samples, drawn in a freshly shuffled order on each pass.
dataloader = DataLoader(d_set, batch_size=8, shuffle=True)

In [30]:
# Pull a single batch from the loader and show it.
data = next(iter(dataloader))
print(data)

{'lr': tensor([[[[-0.5767, -1.1418, -0.9363,  ...,  0.5364,  0.5364,  0.6392],
          [-0.7993, -0.9534, -0.8164,  ...,  0.3994,  0.4337,  0.5364],
          [-0.9877, -0.9877, -1.1075,  ...,  0.2453,  0.2453,  0.2967],
          ...,
          [ 0.6734,  0.2111,  0.1939,  ..., -1.6555, -1.7069, -1.7754],
          [ 1.0331,  1.0502,  1.0844,  ..., -1.4329, -1.2788, -1.7412],
          [ 0.9132,  0.8104,  0.7591,  ..., -0.5424, -1.4500, -1.5357]],

         [[ 0.0651, -0.4776, -0.4776,  ...,  0.5378,  0.5203,  0.6078],
          [-0.2850, -0.4601, -0.3550,  ...,  0.4153,  0.4328,  0.5553],
          [-0.5301, -0.5126, -0.5126,  ...,  0.2752,  0.2577,  0.2927],
          ...,
          [ 1.0980,  0.6779,  0.6254,  ..., -1.6681, -1.7031, -1.4930],
          [ 1.4132,  1.4482,  1.4832,  ..., -1.3704, -1.4055, -1.7206],
          [ 1.3081,  1.2206,  1.0455,  ..., -0.1099, -1.4405, -1.5805]],

         [[-0.7064, -0.8110, -0.5495,  ...,  0.3568,  0.3393,  0.4439],
          [-0.6367, -0.

In [23]:
import multiprocessing

# Number of CPUs reported for this machine.
# NOTE(review): presumably checked to pick a DataLoader num_workers value -- confirm.
multiprocessing.cpu_count()

8