In [1]:
import os
import os.path as osp
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from generator_model import Generator
import config
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class PhotoMonetDataset(Dataset):
    """Dataset yielding ``(monet_img, photo_img)`` pairs from two independent image folders.

    The folders are not aligned: item ``index`` takes the ``index % num_monet``-th
    Monet file and the ``index % num_photo``-th photo file (both in sorted filename
    order), so the shorter folder is cycled and ``len()`` is the size of the larger one.

    Args:
        root_photo: directory containing the photo images.
        root_monet: directory containing the Monet images.
        transform: optional callable applied to each image. It receives an
            ``H x W x 3`` array produced by ``np.array(PIL_image.convert("RGB"))``.

    Raises:
        ValueError: if either directory contains no entries.
    """

    def __init__(self, root_photo, root_monet, transform=None):
        self.root_photo = root_photo
        self.root_monet = root_monet
        self.transform = transform
        # os.listdir order is arbitrary and filesystem-dependent; sorting makes the
        # index -> file mapping (and any per-index outputs) reproducible across machines.
        self.monet_images = sorted(os.listdir(self.root_monet))
        self.photo_images = sorted(os.listdir(self.root_photo))
        self.num_monet = len(self.monet_images)
        self.num_photo = len(self.photo_images)
        # Fail fast: __getitem__ indexes with `% num_*`, which would otherwise raise
        # an opaque ZeroDivisionError on the first access.
        if self.num_monet == 0:
            raise ValueError(f"No files found in Monet directory: {self.root_monet!r}")
        if self.num_photo == 0:
            raise ValueError(f"No files found in photo directory: {self.root_photo!r}")
        self.len_dataset = max(self.num_monet, self.num_photo)

    def __len__(self):
        return self.len_dataset

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        monet_path = osp.join(self.root_monet, self.monet_images[index % self.num_monet])
        photo_path = osp.join(self.root_photo, self.photo_images[index % self.num_photo])

        monet_img = self._load_rgb(monet_path)
        photo_img = self._load_rgb(photo_path)

        if self.transform is not None:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img


In [5]:
DATA_ROOT = "../monet2photo"

# ToTensor turns the uint8 HxWx3 array into a float CxHxW tensor in [0, 1];
# Normalize(mean=0.5, std=0.5) then maps it to [-1, 1]. Anything saved or shown
# later must be mapped back with x * 0.5 + 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# testB is used as the photo domain and testA as the Monet domain.
test_dataset = PhotoMonetDataset(
    root_photo=osp.join(DATA_ROOT, "testB"),
    root_monet=osp.join(DATA_ROOT, "testA"),
    transform=transform,
)

print(len(test_dataset))
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)


751


In [13]:
CHECKPOINT_PATH = "checkpoints/1212210058/gen_m_11.pth"

gen_m = Generator(in_channels=3)
# map_location="cpu": a checkpoint saved from a GPU run would otherwise be restored
# onto CUDA (failing on CPU-only machines and allocating a throwaway GPU copy);
# the module itself is created on CPU and moved to the inference device later.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
gen_m.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

<All keys matched successfully>

In [14]:
NUM_SAMPLES = 100              # number of test photos to translate and save
OUTPUT_DIR = "testOutputs"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def denormalize(t):
    """Invert Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1] for saving."""
    return t * 0.5 + 0.5


# save_image does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# eval() puts layers such as BatchNorm/Dropout (if the Generator has any) into inference mode.
gen_m = gen_m.to(DEVICE).eval()

with torch.no_grad():
    total = min(NUM_SAMPLES, len(test_dataloader))
    for i, (_, photo_img) in enumerate(tqdm(test_dataloader, total=total)):
        if i >= NUM_SAMPLES:
            break
        photo_img = photo_img.to(DEVICE)
        fake_monet = gen_m(photo_img)
        # The input photo is still in the normalized [-1, 1] range, and the generator
        # output is treated the same way (hence * 0.5 + 0.5). Both must be denormalized,
        # otherwise save_image clamps the photo's negative values and it comes out dark.
        pair = torch.cat((denormalize(photo_img), denormalize(fake_monet)))
        # save_image tiles a batch into a grid itself: input photo | generated Monet.
        save_image(pair, osp.join(OUTPUT_DIR, f"image_{i}.png"))
