In [12]:
from typing import List

import torch
import torch.cuda
from torch.utils.data import DataLoader
from typing import Union
from datetime import datetime

from PIL import Image as PILImage

import numpy as np
import pyml
from models import DiscriminatorPixelMse, DownSampler, GeneratorESPCN, DiscriminatorVggMse, DiscriminatorBaselineNetwork

In [13]:
# Image source for the super-resolution demo. Alternative location used earlier:
#   /mnt/pi8-v2/mnt/safedata/archive
root = "/mnt/pi8-v2/mnt/safedata/datasets/2020/torrent/windows/p"

# Recursively collect images under `root`. `shuffle_seed` presumably fixes the
# shuffle order (TODO confirm in pyml); images are requested channels-first (C, H, W).
images = pyml.ImagesDataset.from_dirs_recursive(
    [root],
    shuffle_seed=12,
    channels_order='chw',
)
len(images)

37385

In [3]:
class EvalHelper:
    """Wrap an ESPCN generator for qualitative super-resolution demos.

    Builds a ``GeneratorESPCN``, restores its weights from a ``.pth`` checkpoint
    via ``load_model`` and turns a single image into an upscaled PIL image via
    ``demonstrate``.
    """

    # Directory containing ``generator_espcn_<suffix>.pth`` checkpoints.
    DEFAULT_MODELS_DIR = "/home/lgor/projects/2023/mySmallProjects/2023/ml_experiments/superresolution/models/espcn"

    def __init__(self, channels: int = 128, scale: int = 4):
        """Create the generator.

        Args:
            channels: forwarded to ``GeneratorESPCN`` as ``channels``.
            scale: upscale factor, forwarded to ``GeneratorESPCN`` as ``upscale``.
        """
        self.generator = GeneratorESPCN(channels=channels, upscale=scale)

    def demonstrate(self, label: Union[np.ndarray, torch.Tensor]) -> PILImage.Image:
        """Run the generator on one image and return the result as a PIL image.

        Args:
            label: image array/tensor, either unbatched (3-D, presumably C x H x W
                to match the dataset's 'chw' order) or batched (4-D). Only the
                first batch element is rendered. NOTE(review): the dtype is passed
                through unchanged, so it must already match the generator's
                weights (presumably float32 in [0, 1]) - confirm against pyml.

        Returns:
            The generator output of the first batch element, clipped to [0, 1]
            and scaled to an 8-bit HWC ``PIL.Image.Image``.
        """
        if isinstance(label, np.ndarray):
            label = torch.from_numpy(label)
        assert isinstance(label, torch.Tensor)
        # Add the batch dimension for unbatched input, for arrays and tensors alike.
        if label.dim() == 3:
            label = label.unsqueeze(0)

        # Inference only: no_grad avoids recording an autograd graph (and keeping
        # its intermediate activations alive) and leaves the caller's tensor untouched.
        with torch.no_grad():
            output = self.generator(label)

        # First batch element; .cpu() makes .numpy() work if the model was moved
        # to an accelerator. Then CHW -> HWC for PIL.
        pic = output[0].cpu().numpy()
        pic = np.moveaxis(pic, 0, 2)
        pic = np.clip(pic, 0.0, 1.0)
        return PILImage.fromarray((pic * 255.0).astype(np.uint8))

    def load_model(self, suffix: str, models_dir: str = DEFAULT_MODELS_DIR):
        """Load ``<models_dir>/generator_espcn_<suffix>.pth`` into the generator and switch to eval mode.

        Args:
            suffix: checkpoint suffix (e.g. a training timestamp).
            models_dir: directory holding the checkpoints; defaults to the
                original hard-coded location.
        """
        path = f"{models_dir}/generator_espcn_{suffix}.pth"
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
        # machines; load_state_dict then copies the values into the generator's
        # parameters wherever they live.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        self.generator.load_state_dict(state_dict)
        self.generator.eval()

In [4]:
# Build the demo helper and restore the generator checkpoint tagged with this suffix.
checkpoint_suffix = "2023-04-26-12-58-31"
eval_helper = EvalHelper()
eval_helper.load_model(suffix=checkpoint_suffix)

In [10]:
# Destination directory for the original / upscaled image pairs saved below.
demo_output_dir = "/home/lgor/projects/2023/myml/srgan/src/demo/2023-04-26"
saver = pyml.ImageSaver(demo_output_dir)

In [14]:
# Upscale the first few dataset images and save each original next to its
# super-resolved version. Oversized images are skipped to bound inference cost.
N_DEMO_IMAGES = 200  # how many dataset indices to try
MAX_SIDE = 1024      # skip images whose height or width exceeds this many pixels

# Pure inference: don't record autograd graphs for any image in the loop.
with torch.no_grad():
    # Bound by the dataset size so a small dataset doesn't raise IndexError.
    for i in range(min(N_DEMO_IMAGES, len(images))):
        img = images[i]
        if img is None:
            # presumably the dataset yields None for unreadable files - TODO confirm
            continue
        # Shapes are (C, H, W) because the dataset was built with channels_order='chw'.
        if img.shape[1] > MAX_SIDE or img.shape[2] > MAX_SIDE:
            continue
        print(img.shape)
        pic = eval_helper.demonstrate(img)
        saver.saveCHW(img)
        saver.savePIL(pic)

(3, 612, 612)
(3, 630, 1024)
(3, 900, 600)
(3, 640, 640)
(3, 1024, 683)
(3, 377, 604)
(3, 900, 600)
(3, 1024, 768)
(3, 1024, 683)
(3, 450, 450)
(3, 965, 723)
(3, 333, 500)
(3, 900, 600)
(3, 450, 600)
(3, 630, 420)
(3, 887, 550)
(3, 747, 750)
(3, 500, 332)
(3, 900, 600)
(3, 453, 604)
(3, 865, 588)
(3, 750, 500)
(3, 900, 600)
(3, 446, 540)
(3, 426, 640)
(3, 900, 608)
(3, 1024, 747)
(3, 1000, 667)
(3, 600, 900)
(3, 539, 807)
(3, 1024, 700)
(3, 640, 960)
(3, 720, 555)
(3, 750, 500)
(3, 910, 1024)
(3, 900, 600)
(3, 770, 500)
(3, 639, 960)
(3, 604, 492)
(3, 822, 477)
(3, 423, 635)
(3, 1024, 668)
(3, 640, 480)
(3, 700, 467)
(3, 683, 1024)
(3, 750, 500)
(3, 409, 604)
(3, 1024, 680)
(3, 882, 600)
(3, 530, 340)
(3, 900, 600)
(3, 1024, 683)
(3, 612, 612)
(3, 900, 600)
(3, 900, 600)
(3, 400, 600)
(3, 859, 610)
(3, 682, 1024)
(3, 876, 617)
(3, 1024, 585)
(3, 1024, 683)
(3, 683, 1024)
(3, 600, 900)
(3, 385, 511)
(3, 612, 612)
(3, 369, 250)
(3, 1024, 682)
(3, 900, 600)
(3, 900, 600)
(3, 959, 640)
(3,