In [1]:
# One-time dataset preparation, kept disabled inside a string literal: nothing
# below is executed, the notebook merely echoes the string as this cell's result.
# Run as shell commands, the lines would download and unpack the LFW
# "deep-funneled" archive, create train/ and test/ inside it, and then move the
# top-level entries whose names start with A-W into train/ and those starting
# with X-Z into test/.
"""
!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
!tar xf lfw-deepfunneled.tgz
!mkdir lfw-deepfunneled/train
!mkdir lfw-deepfunneled/test
!mv lfw-deepfunneled/[A-W]* lfw-deepfunneled/train
!mv lfw-deepfunneled/[X-Z]* lfw-deepfunneled/test
"""

'\n!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz\n!tar xf lfw-deepfunneled.tgz\n!mkdir lfw-deepfunneled/train\n!mkdir lfw-deepfunneled/test\n!mv lfw-deepfunneled/[A-W]* lfw-deepfunneled/train\n!mv lfw-deepfunneled/[X-Z]* lfw-deepfunneled/test\n'

In [2]:
import torch
from torch import nn,optim
from torch.utils.data import Dataset,TensorDataset,DataLoader
import tqdm

import numpy as np

from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
import time
import math

In [3]:
# ImageFolder variant that serves down-sized image pairs
class DownSizedPairImageFolder(ImageFolder):
    """Image folder whose items are ``(small, large)`` versions of one image.

    Every file found by ``ImageFolder`` is loaded once and resized twice, to
    ``small_size`` and to ``large_size``. The class label that ``ImageFolder``
    associates with the file is discarded.

    Args:
        root: Directory handed to ``ImageFolder``.
        transform: Optional callable applied separately to each of the two
            resized images.
        large_size: Size passed to ``transforms.Resize`` for the large image.
        small_size: Size passed to ``transforms.Resize`` for the small image.
        **kwds: Forwarded unchanged to ``ImageFolder.__init__``.

    NOTE(review): with an integer size, torchvision's ``Resize`` is documented
    to scale the shorter edge, so the outputs are square only for square
    inputs -- confirm this matches the images in use.
    """

    def __init__(self, root, transform=None, large_size=128, small_size=32, **kwds):
        super().__init__(root, transform=transform, **kwds)
        self.small_resizer = transforms.Resize(small_size)
        self.large_resizer = transforms.Resize(large_size)

    def __getitem__(self, index):
        """Return ``(small_img, large_img)`` for the file at ``index``."""
        image_path, _label = self.imgs[index]
        source = self.loader(image_path)

        high_res = self.large_resizer(source)
        low_res = self.small_resizer(source)

        if self.transform is None:
            return low_res, high_res

        # The transform is invoked on the large image first, then on the small one.
        high_res = self.transform(high_res)
        low_res = self.transform(low_res)
        return low_res, high_res

In [4]:
# Prepare the datasets and data loaders.
# Both datasets yield (small, large) image pairs converted to tensors by ToTensor.
train_data=DownSizedPairImageFolder("./lfw-deepfunneled/train", transform = transforms.ToTensor())
test_data=DownSizedPairImageFolder("./lfw-deepfunneled/test", transform = transforms.ToTensor())

batch_size=32
# Training batches are reshuffled on every pass over the data.
train_loader=DataLoader(train_data,batch_size = batch_size,shuffle = True, num_workers = 4)
# The evaluation loader is only used to average an error over the whole test
# set, so shuffling buys nothing; a fixed order keeps evaluation deterministic.
test_loader=DataLoader(test_data,batch_size = batch_size,shuffle = False, num_workers = 4)

In [5]:
def _build_net():
    """Assemble the convolutional encoder/decoder used for super-resolution.

    Every convolution uses a 4x4 kernel, stride 2 and padding 1. Two ``Conv2d``
    stages halve the spatial size twice, then four ``ConvTranspose2d`` stages
    double it four times, so the output is four times the input size per side.
    Each stage except the last is followed by ``ReLU`` and ``BatchNorm2d``.
    """
    conv_args = dict(kernel_size=4, stride=2, padding=1)
    layers = []

    # Down-sampling stages: (in_channels, out_channels)
    for c_in, c_out in [(3, 256), (256, 512)]:
        layers += [nn.Conv2d(c_in, c_out, **conv_args), nn.ReLU(), nn.BatchNorm2d(c_out)]

    # Up-sampling stages that are followed by activation + normalisation
    for c_in, c_out in [(512, 256), (256, 128), (128, 64)]:
        layers += [nn.ConvTranspose2d(c_in, c_out, **conv_args), nn.ReLU(), nn.BatchNorm2d(c_out)]

    # Final up-sampling stage maps back to 3 channels with no activation.
    layers.append(nn.ConvTranspose2d(64, 3, **conv_args))
    return nn.Sequential(*layers)


net = _build_net()

In [6]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still executes on machines without CUDA.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
net.to(dev)

def psnr(mse,max_v = 1.0):
    """Convert a mean squared error into a peak signal-to-noise ratio in dB.

    Computes ``10 * log10(max_v**2 / mse)``.

    Args:
        mse: Mean squared error between two signals.
        max_v: Largest value the signal can take (1.0 by default).

    Returns:
        The PSNR as a float; ``inf`` when ``mse`` is exactly zero, i.e. the
        two signals are identical.

    Raises:
        ValueError: If ``mse`` is negative (raised by ``math.log10``).
    """
    if mse == 0:
        # A zero error would otherwise divide by zero; by convention the PSNR
        # of a perfect reconstruction is infinite.
        return math.inf
    return 10*math.log10(max_v**2/mse)

In [7]:
def eval_net(net,data_loader,device = dev):
    """Return the mean squared error of ``net`` over everything in ``data_loader``.

    The network is switched to eval mode and left in that mode. The error is
    averaged over every element of every target tensor, which equals
    ``mse_loss`` applied to the concatenation of all batches.

    Args:
        net: Module called as ``net(x)`` on each input batch.
        data_loader: Iterable yielding ``(x, y)`` tensor pairs.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        The overall mean squared error as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no data.
    """
    net.eval()
    # Accumulate the squared-error sum and element count batch by batch rather
    # than keeping every target and prediction on the device until the end;
    # memory use then no longer grows with the size of the dataset.
    squared_error_sum = 0.0
    element_count = 0
    with torch.no_grad():
        for x,y in data_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = net(x)
            squared_error_sum += nn.functional.mse_loss(y_pred,y,reduction = "sum").item()
            element_count += y.numel()
    if element_count == 0:
        raise ValueError("data_loader produced no data to evaluate")
    return squared_error_sum/element_count

In [8]:
import os

# Directory for the per-epoch preview images written by the (currently
# disabled) snapshot code in train_net. exist_ok avoids the separate
# existence check.
os.makedirs("resolution-data", exist_ok=True)

# Only needed by the disabled GIF export at the end of the notebook:
#from PIL import Image
#import glob

# Draw one random batch of four (small, large) test pairs and keep it as the
# fixed set of example images used for the visual comparison later on.
random_test_loader = DataLoader(test_data,batch_size = 4, shuffle = True)
it = iter(random_test_loader)
x,y = next(it)


def train_net(net,train_loader,test_loader,optimizer_cls = optim.Adam,loss_fn = nn.MSELoss(),n_iter = 10,device = dev):
    """Train ``net`` for ``n_iter`` epochs, evaluating on ``test_loader`` after each.

    After every epoch one line is printed containing: the epoch index, the mean
    training loss per batch, the PSNR derived from that mean loss, and the PSNR
    derived from the test-set MSE returned by ``eval_net``. The PSNR figures
    use ``psnr``'s default peak value of 1.0 and treat the training loss as a
    mean squared error, so they are only meaningful for an MSE-style ``loss_fn``.

    Args:
        net: Module to optimise in place.
        train_loader: Iterable with ``len()`` yielding ``(input, target)`` batches.
        test_loader: Loader passed to ``eval_net`` after each epoch.
        optimizer_cls: Optimiser class, instantiated as
            ``optimizer_cls(net.parameters())``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        n_iter: Number of epochs.
        device: Device each batch is moved to.

    Returns:
        ``(train_losses, val_losses)``: two lists with one float per epoch,
        the mean training loss per batch and the test-set MSE.
    """
    train_losses = []
    val_losses = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # eval_net leaves the network in eval mode, so re-enable training mode
        # at the start of every epoch.
        net.train()
        for xx,yy in tqdm.tqdm(train_loader,total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            y_pred = net(xx)
            loss = loss_fn(y_pred,yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_losses.append(running_loss/len(train_loader))
        val_losses.append(eval_net(net,test_loader,device))
        print(epoch,train_losses[-1],psnr(train_losses[-1]),psnr(val_losses[-1]),flush=True)
        # Disabled per-epoch snapshot of the fixed example batch; it relies on
        # the notebook-level variables x, y and dev:
        # yp = net(x.to(dev)).to("cpu")
        # save_image(torch.cat([y,yp],0), f'./resolution-data/{epoch}.jpg', nrow = 4)
    return train_losses,val_losses

In [9]:
train_net(net,train_loader,test_loader,device = dev)

# Baseline for comparison: plain bilinear up-sampling of the small example
# images to 128 pixels.
bl_recon = nn.functional.interpolate(x,128,mode = "bilinear", align_corners = True)

# Inference only: make eval mode explicit instead of relying on train_net
# having left the network in it, and skip building an autograd graph.
net.eval()
with torch.no_grad():
    yp = net(x.to(dev)).to("cpu")

# Four images per row: the large originals, the bilinear baseline, and the
# network output, concatenated in that order.
save_image(torch.cat([y,bl_recon,yp],0),"upscale.jpg",nrow = 4)

# Disabled: assemble the per-epoch snapshots into an animated GIF (needs the
# commented-out PIL / glob imports and the snapshot code in train_net).
# files = sorted(glob.glob('./resolution-data/*.jpg'))
# images = list(map(lambda file : Image.open(file) , files))
# images[0].save('Super_resolving_process.gif' , save_all = True , append_images = images[1:] , duration = 400 , loop = 0)

100%|██████████| 409/409 [00:09<00:00, 42.26it/s]


0 0.04507972321798683 13.460187591137311 21.004050743272327


100%|██████████| 409/409 [00:09<00:00, 42.11it/s]


1 0.005329560779358436 22.73308580647135 23.540606817573067


100%|██████████| 409/409 [00:09<00:00, 42.45it/s]


2 0.004018978975157648 23.958842657653925 24.422237570796533


100%|██████████| 409/409 [00:09<00:00, 42.10it/s]


3 0.00345408890913252 24.61666487773572 25.512737846621768


100%|██████████| 409/409 [00:09<00:00, 42.10it/s]


4 0.0032383495222203305 24.896762786896613 25.960164557959263


100%|██████████| 409/409 [00:09<00:00, 41.89it/s]


5 0.002989008495570468 25.244728505766005 25.79718299721562


100%|██████████| 409/409 [00:09<00:00, 41.55it/s]


6 0.002877220074238579 25.410269183070817 25.700847238034772


100%|██████████| 409/409 [00:09<00:00, 41.58it/s]


7 0.0028052504725462795 25.520283558023642 25.688415407346987


100%|██████████| 409/409 [00:09<00:00, 41.65it/s]


8 0.002684663395422413 25.711101586328265 26.44840638280737


100%|██████████| 409/409 [00:09<00:00, 41.53it/s]


9 0.002657028075938357 25.75603856519086 26.114822450045352


"\nfiles = sorted(glob.glob('./resolution-data/*.jpg'))  \nimages = list(map(lambda file : Image.open(file) , files))\nimages[0].save('Super_resolving_process.gif' , save_all = True , append_images = images[1:] , duration = 400 , loop = 0)\n"