In [1]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.models import vgg19
import torchvision.transforms as transforms
import os.path as osp
from glob import glob
import shutil
from tqdm import tqdm

import cv2
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import sys
import json
import random

from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
#from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

In [2]:
"""
ディレクトリ 構成
root
    | - input
    |        | - annotation
    |        | - cat_face
    |                   | - demo
    |                   | - test
    |                   | - train
    |        | - images
    |                   | - cat
    |
    | - outputs

"""

'\nディレクトリ 構成\nroot\n    | - input\n    |        | - annotation\n    |        | - cat_face\n    |                   | - demo\n    |                   | - test\n    |                   | - train\n    |        | - images\n    |                   | - cat\n    |\n    | - outputs\n\n'

# データのダウンロード

In [3]:
"""
#!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar xf 102flowers.tgz
!mkdir oxford-102
!mkdir oxford-102-train
!mkdir oxford-102-test
!mkdir params
!mkdir oxford-102/jpg
!mv jpg/*.jpg oxford-102/jpg
!cp ./oxford-102/jpg/image_0000[1-5].jpg ./oxford-102-test/
"""

'\n#!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz\n!tar xf 102flowers.tgz\n!mkdir oxford-102\n!mkdir oxford-102-train\n!mkdir oxford-102-test\n!mkdir params\n!mkdir oxford-102/jpg\n!mv jpg/*.jpg oxford-102/jpg\n!cp ./oxford-102/jpg/image_0000[1-5].jpg ./oxford-102-test/\n'

In [4]:
image_dir = "./oxford-102/jpg/"
train_dir = "./oxford-102-train"
test_dir = "./oxford-102-test"

In [5]:
weight_save_dir = "./params/"
param_save_path = osp.join("./params/", 'param.json')
!mkdir ./output
image_test_save_dir = "./output"

mkdir: ディレクトリ `./output' を作成できません: ファイルが存在します


In [6]:
random_crop_times = 4
hr_const = 256
# クロップする画像のサイズ
crop_size = (hr_const, hr_const)
dataset_name = 'cat_face'

In [7]:
#データクロップ用の関数
def random_crop(image, crop_size):
    h, w, _ = image.shape

    top = np.random.randint(0, h - crop_size[0])
    left = np.random.randint(0, w - crop_size[1])

    bottom = top + crop_size[0]
    right = left + crop_size[1]

    image = image[top:bottom, left:right, :]
    return image

## 猫の名前からテスト用とデモ用を取り出す。

In [8]:
"image_{:05}".format(1)

'image_00001'

## 上で分離した訓練用のデータに対し、クロップを適用してデータを加工・増量する。

In [None]:
#"""
# 学習に用いる画像
lst = [ "image_{:05}".format(i) for i in range(1,8189)]
cropped_image_save_path = "./oxford-102-train"

for item in tqdm(lst, total=len(lst)):
    image_name = '{}.jpg'.format(item)
    image_path = osp.join(image_dir, image_name)
    image = cv2.imread(image_path)
    h, w, _ = image.shape
    # 画像のサイズが小さい時は対象から除外する。
    if (h < crop_size[0]) | (w < crop_size[1]):
        print('{} size is invalid. h: {},  w: {}'.format(image_name, h, w))
        continue
    #ランダムクロップ分だけ作る
    for num in range(random_crop_times):
        cropped_image = random_crop(image, crop_size=crop_size)
        image_save_name = '{}_{:03}.jpg'.format(item, num)
        cropped_image_save_path = osp.join(train_dir, image_save_name)
        os.makedirs(osp.dirname(cropped_image_save_path), exist_ok=True)
        cv2.imwrite(cropped_image_save_path, cropped_image)
#"""

## テスト用とデモ用の画像を専用のディレクトリに移しておく。

In [9]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

In [10]:
class ImageDataset(Dataset):
    """
    学習のためのDatasetクラス
    32×32の低解像度の本物画像と、128×128の本物画像を出力する
    """
    def __init__(self, dataset_dir, hr_shape):
        hr_height, hr_width = hr_shape
        
        # 低解像度の画像を取得するための処理
        self.lr_transform = transforms.Compose([
            transforms.Resize((hr_height // 4, hr_height // 4), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

        # 高像度の画像を取得するための処理
        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_height, hr_height), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        
        self.files = sorted(glob(osp.join(dataset_dir, '*')))
    
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        
        return {'lr': img_lr, 'hr': img_hr}
    
    def __len__(self):
        return len(self.files)

In [11]:
class TestImageDataset(Dataset):
    def __init__(self, dataset_dir):
        self.hr_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])
        self.files = sorted(glob(osp.join(dataset_dir, '*')))
    
    def lr_transform(self, img, img_size):
        img_width, img_height = img_size
        self.__lr_transform = transforms.Compose([
            transforms.Resize((img_height // 4, 
                               img_width // 4), 
                               Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        img = self.__lr_transform(img)
        return img
            
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_size = img.size
        img_lr = self.lr_transform(img, img_size)
        img_hr = self.hr_transform(img)        
        return {'lr': img_lr, 'hr': img_hr}
    
    def __len__(self):
        return len(self.files)

# parts of ESRGAN

## 生成器の準備

## Dense Residual Blockの構築

In [12]:

class DenseResidualBlock(nn.Module):
    def __init__(self, filters, res_scale = 0.2):
        super().__init__()
        self.res_scale = res_scale
        
        def block(in_features, non_linearity = True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias = True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)
    
        self.b1 = block(in_features = 1 * filters)
        self.b2 = block(in_features = 2 * filters)
        self.b3 = block(in_features = 3 * filters)
        self.b4 = block(in_features = 4 * filters)
        self.b5 = block(in_features = 5 * filters, non_linearity = False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]
    
    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], dim = 1)
        return out.mul(self.res_scale) + x

In [None]:
# サイズの確認
x = torch.randn(10, 64, 5, 5)
D = DenseResidualBlock(filters = 64)
D(x).shape

## Residual in Residual Blockの構築

In [13]:
# residual-blockの再帰
class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), 
            DenseResidualBlock(filters), 
            DenseResidualBlock(filters)
        )
    
    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x

In [None]:
# サイズの確認
DD = ResidualInResidualDenseBlock(filters = 64)
DD(x).shape

## 構築してきたパーツを組み合わせて生成器を構築

In [14]:
"""
Conv2D
RinRD-Block
     .
     .
RinRD-Block
Conv2D
|---------------|
|Conve2D     |
|LeaklyReLU |
|Pixelshuffle |
|---------------|
     .
     .
|---------------|
|Conve2D     |
|LeaklyReLU |
|Pixelshuffle |
|---------------|
Conv2D
LeaklyReLU
Conv2D
"""

class GeneratorRRDB(nn.Module):
    def __init__(self, channels, filters = 64, num_res_blocks = 16, num_upsample = 2):
        super().__init__()
        
        self.conv1 = nn.Conv2d(channels, filters, kernel_size = 3, stride = 1, padding = 1)
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters)  for _ in range(num_res_blocks)])
        self.conv2 = nn.Conv2d(filters, filters, kernel_size = 3, stride = 1, padding = 1)
        
        upsample_layers = []
        
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size = 3, stride = 1, padding = 1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor = 2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size = 3, stride = 1, padding = 1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size = 3, stride = 1, padding = 1),
        )
    
    def forward(self, x):
        out1 = self.conv1(x)
        #print(1,out1.shape)
        out = self.res_blocks(out1)
        #print(2,out.shape)
        out2 = self.conv2(out)
        #print(3,out2.shape)
        out = torch.add(out1, out2)
        #print(4,out.shape)
        out = self.upsampling(out)
        #print(5,out.shape)
        out = self.conv3(out)
        #print(6,out.shape)
        return out

In [None]:
# サイズの確認
# ex : (10,3,32,32) --> (10, 3, 128, 128)
y = torch.randn(10,3,32,32)
DDD = GeneratorRRDB(3, filters = 64, num_res_blocks = 23)
DDD(y).shape

## 特徴量抽出層

In [None]:
class FeatureExtractor(nn.Module):
    """
    vgg19を応用した特徴量抽出器
    """
    def __init__(self):
        super().__init__()
        vgg19_model = vgg19(pretrained = True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, x):
        return self.vgg19_54(x)

In [None]:
# ex : (10, 3, 128, 128) --> (10, 512, 8,8)
y = torch.randn(10,3,32,32)
FE = FeatureExtractor()
FE(y).shape

In [None]:
y = torch.randn(10,3,32,32)
(y-y.mean(0, keepdim=True)).shape

## 識別器の準備

In [None]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super().__init__()
                
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)
    
        def discriminator_block(in_filters, out_filters, first_block = False):
            layers = []
            layers.append(nn.Conv2d(in_filters, 
                                    out_filters, 
                                    kernel_size = 3, 
                                    stride = 1, 
                                    padding = 1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace = True))
            layers.append(nn.Conv2d(out_filters, 
                                    out_filters, 
                                    kernel_size = 3, 
                                    stride = 2, 
                                    padding = 1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace = True))
            return layers
        
        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            #print(discriminator_block(in_filters, out_filters,  first_block=(i == 0)))
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        
        layers.append(nn.Conv2d(out_filters,  1, kernel_size = 3, stride = 1,  padding = 1))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, img):
        return self.model(img)

## 構築してきたパーツを組み合わせてESRGANを構築

In [None]:
class ESRGAN():
    def __init__(self, opt):
        # 生成器・識別器の設定
        self.generator = GeneratorRRDB(opt.channels, filters = 64, num_res_blocks = opt.residual_blocks).to(opt.device)
        self.discriminator = Discriminator(input_shape = (opt.channels, *hr_shape)).to(opt.device)

        # 特徴量抽出器の設定
        self.feature_extractor = FeatureExtractor().to(opt.device)
        self.feature_extractor.eval()
        
        # 損失関数の設定
        self.criterion_GAN = nn.BCEWithLogitsLoss().to(opt.device)
        self.criterion_content = nn.L1Loss().to(opt.device)
        self.criterion_pixel = nn.L1Loss().to(opt.device)

        # オプティマイザーの設定
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))

        # デバイスとテンソル型の定義
        self.Tensor = torch.Tensor
        self.dev = opt.device
        
    #==================================================================================
    
    def pre_train(self, imgs, batch_num):
        imgs_lr = imgs['lr'].type(torch.Tensor).to(self.dev)
        imgs_hr = imgs['hr'].type(torch.Tensor).to(self.dev)

        valid = torch.tensor(np.ones((imgs_lr.size(0), *self.discriminator.output_shape)), requires_grad=False).to(self.dev)
        fake = torch.tensor(np.zeros((imgs_lr.size(0), *self.discriminator.output_shape)), requires_grad=False).to(self.dev)

        # 勾配初期化
        self.optimizer_G.zero_grad()

        # 低解像度 --> 高解像度を実行し、ピクセル単位の損失計算
        gen_hr = self.generator(imgs_lr)
        loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

        # 画素単位の損失であるloss_pixelで事前学習を行う
        loss_pixel.backward()
        self.optimizer_G.step()
        train_info = {'epoch': epoch, 'batch_num': batch_num, 'loss_pixel': loss_pixel.item()}
        """
        if batch_num == 1:
            sys.stdout.write('\n{}'.format(train_info))
        else:
            sys.stdout.write('\r{}'.format('\t'*20))
            sys.stdout.write('\r{}'.format(train_info))
        """
        
        sys.stdout.write('\r{}'.format('\t'*20))
        sys.stdout.write('\r{}'.format(train_info))
        sys.stdout.flush()
        
    #============================================================================
    
    def train(self, imgs, batch_num):
        imgs_lr = imgs['lr'].type(self.Tensor).to(self.dev)
        imgs_hr = imgs['hr'].type(self.Tensor).to(self.dev)

        ################### 
        #####生成器の損失#####
        ###################
        # 正解ラベル
        valid = torch.tensor(np.ones((imgs_lr.size(0), *self.discriminator.output_shape)), requires_grad = False).to(self.dev)
        fake = torch.tensor(np.zeros((imgs_lr.size(0), *self.discriminator.output_shape)),requires_grad = False).to(self.dev)

        # 低解像度 --> 高解像度
        self.optimizer_G.zero_grad()
        gen_hr = self.generator(imgs_lr)

        # (1)ピクセル単位の損失計算
        loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

        # (2)Adversarial loss
        # 本物画像と超解像画像の識別器による判定を見る
        pred_real = self.discriminator(imgs_hr).detach()
        pred_fake = self.discriminator(gen_hr)
        loss_GAN = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # (3)Perceptual loss
        #特徴量の比較、シンプルな損失関数
        gen_feature = self.feature_extractor(gen_hr)
        real_feature = self.feature_extractor(imgs_hr).detach()
        loss_content = self.criterion_content(gen_feature, real_feature)

        
        loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel
        loss_G.backward()
        self.optimizer_G.step()

        
        ################### 
        #####識別機の損失#####
        ###################
        self.optimizer_D.zero_grad()
        pred_real = self.discriminator(imgs_hr)
        pred_fake = self.discriminator(gen_hr.detach())

        loss_real = self.criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)            
        loss_fake = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)    
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        self.optimizer_D.step()

        train_info = {'epoch': epoch, 'batch_num': batch_num,  'loss_D': loss_D.item(), 'loss_G': loss_G.item(),
                      'loss_content': loss_content.item(), 'loss_GAN': loss_GAN.item(), 'loss_pixel': loss_pixel.item(),}
        """
        if batch_num == 1:
            sys.stdout.write('\n{}'.format(train_info))
        else:
            sys.stdout.write('\r{}'.format('\t'*20))
            sys.stdout.write('\r{}'.format(train_info))
        """
        sys.stdout.write('\r{}'.format('\t'*20))
        sys.stdout.write('\r{}'.format(train_info))
        sys.stdout.flush()

    def save_image(self,imgs,batches_done):
        with torch.no_grad():
            imgs_lr = imgs["lr"].type(self.Tensor).to(self.dev)
            gen_hr = self.generator(imgs_lr)
            gen_hr = denormalize(gen_hr)

            image_batch_save_dir = osp.join(image_test_save_dir, '{:03}'.format(i))
            gen_hr_dir = osp.join(image_batch_save_dir, "hr_image")
            os.makedirs(image_batch_save_dir, exist_ok=True)
            save_image(gen_hr, osp.join(image_batch_save_dir, "{:09}.png".format(batches_done)), nrow=1, normalize=False)
        

    def save_weight(self, batches_done):
        """
        重みの保存
        """
        generator_weight_path = osp.join(weight_save_dir, "generator_{:08}.pth".format(batches_done))
        discriminator_weight_path = osp.join(weight_save_dir, "discriminator_{:08}.pth".format(batches_done))

        torch.save(self.generator.state_dict(), generator_weight_path)
        torch.save(self.discriminator.state_dict(), discriminator_weight_path)

In [15]:
def denormalize(tensors):
    """
    平均と標準偏差を使ってデータ値の加工
    """
    for c in range(3):
        tensors[:, c].mul_(std[c]).add_(mean[c])
    return torch.clamp(tensors, 0, 255)

In [16]:
def save_json(file, save_path, mode):
    """Jsonファイルを保存
    """
    with open(save_path, mode) as outfile:
        json.dump(file, outfile, indent=4)

In [17]:
class Opts():
    def __init__(self):
        self.n_epoch = 50
        self.residual_blocks = 17
        self.lr = 0.0002
        self.b1 = 0.9
        self.b2 = 0.999
        self.batch_size = 16
        self.n_cpu = 8
        self.warmup_batches = 500
        self.lambda_adv = 5e-3
        self.lambda_pixel = 1e-2
        self.pretrained = False
        self.dataset_name = 'cat'
        self.sample_interval = 100
        self.checkpoint_interval = 1000
        self.hr_height = hr_const
        self.hr_width = hr_const
        self.channels = 3
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        print(self.device)
    def to_dict(self):
        parameters = {
            'n_epoch': self.n_epoch,
            'hr_height': self.hr_height,
            'residual_blocks': self.residual_blocks,
            'lr': self.lr,
            'b1': self.b1,
            'b2': self.b2,
            'batch_size': self.batch_size,
            'n_cpu': self.n_cpu,
            'warmup_batches': self.warmup_batches,
            'lambda_adv': self.lambda_adv,
            'lambda_pixel': self.lambda_pixel,
            'pretrained': self.pretrained,
            'dataset_name': self.dataset_name,
            'sample_interval': self.sample_interval,
            'checkpoint_interval': self.checkpoint_interval,
            'hr_height': self.hr_height,
            'hr_width': self.hr_width,
            'channels': self.channels,
            'device': str(self.device),
        }
        return parameters
opt = Opts()
save_json(opt.to_dict(), param_save_path, 'w')

cuda:0


In [18]:
hr_shape = (opt.hr_height, opt.hr_height)
#hr_shape = (256,256)

In [19]:
train_dataloader = DataLoader(
    ImageDataset(train_dir, hr_shape = hr_shape),
    batch_size = opt.batch_size,
    shuffle = True,
    num_workers = opt.n_cpu,
)
test_dataloader = DataLoader(
    TestImageDataset(test_dir),
    batch_size = 1,
    shuffle = False,
    num_workers = opt.n_cpu,
)

In [None]:
esrgan = ESRGAN(opt)

In [None]:
for epoch in range(1, opt.n_epoch + 1):
    print("\ncurrent epoch : ", epoch)
    for batch_num, imgs in enumerate(train_dataloader):
        if batch_num < 1000:
            batches_done = (epoch - 1) * len(train_dataloader) + batch_num
            # 事前学習
            if batches_done <= opt.warmup_batches:#500
                esrgan.pre_train(imgs, batch_num)
            # メイン学習
            else:
                esrgan.train(imgs, batch_num)
            """
            # 高解像度の生成画像の保存
            if batches_done % opt.sample_interval == 0:
                for i, imgs in enumerate(test_dataloader):
                    esrgan.save_image(imgs, batches_done)
            """
            if batches_done % opt.sample_interval == 0:
                for i, imgs in enumerate(test_dataloader):
                    esrgan.save_image(imgs,batches_done)
            # 学習した重みの保存
            if batches_done % opt.checkpoint_interval == 0:
                esrgan.save_weight(batches_done)
        else:
            continue

In [20]:
esrgan = GeneratorRRDB(3, filters = 64, num_res_blocks = 17).to("cuda:0")
esrgan.load_state_dict(torch.load("./params/generator_00101000.pth"))
esrgan.to("cpu")

GeneratorRRDB(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (res_blocks): Sequential(
    (0): ResidualInResidualDenseBlock(
      (dense_blocks): Sequential(
        (0): DenseResidualBlock(
          (b1): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): LeakyReLU(negative_slope=0.01)
          )
          (b2): Sequential(
            (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): LeakyReLU(negative_slope=0.01)
          )
          (b3): Sequential(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): LeakyReLU(negative_slope=0.01)
          )
          (b4): Sequential(
            (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): LeakyReLU(negative_slope=0.01)
          )
          (b5): Sequential(
            (0): Conv2d(320, 64, kernel_size=(3, 3), stride=

# Test

In [None]:
x = torch.zeros(2,3,32,32).to(dev)
y = esrgan(x)
print(y.shape)

In [21]:
for p in esrgan.parameters():
    p.requires_grad = False

# GAN

In [22]:
from torchvision.datasets import ImageFolder
img_data = ImageFolder("./oxford-102/",
                     transform=transforms.Compose([transforms.Resize(400),transforms.CenterCrop(hr_const//4),transforms.ToTensor()]))
batch_size = 64
img_loader = DataLoader(img_data,batch_size = batch_size, shuffle = True)
dev = "cuda:0"

In [23]:
nz = 100
ngf = 32
#(100,1,1) --> (3,32,32)
class GNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*2,ngf*1,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*1,3,4,2,1,bias=True),
            nn.Tanh()
            )
    def forward(self,x):
        out = self.main(x)
        return out

In [24]:
ndf=32
#(3,32,32) --> (1)
class DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3,ndf,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*8,1,4,1,0,bias=True)
            )

    def forward(self,x):
        out = self.main(x)
        return out.squeeze()
    
d = DNet().to(dev)
g = GNet().to(dev)

opt_d = optim.Adam(d.parameters(),lr=0.0002,betas=(0.5,0.999))
opt_g = optim.Adam(g.parameters(),lr=0.0002,betas=(0.5,0.999))

ones = torch.ones(batch_size).to(dev)
zeros = torch.zeros(batch_size).to(dev)

loss_f = nn.BCEWithLogitsLoss()

fixed_z = torch.randn(64,nz,1,1).to(dev)
#gif_z = torch.randn(1,nz,1,1).to(dev)

In [25]:
from statistics import mean

def train_dcgan(g,d,opt_g,opt_d,loader):
    log_loss_g = []
    log_loss_d = []
    
    for real_img  in tqdm(loader):
        #save_image(real_img[0],"./data_tmp/tmp.jpg")
        real_img = real_img["lr"]
        batch_len = len(real_img)
        #print(real_img.size())
        real_img = real_img.to(dev)
        z = torch.randn(batch_len,nz,1,1).to(dev)
        #------------------------------------------------------------------------------------------------------------------
        fake_img = g(z)                                                         #フェイクイメージの作成
        fake_img_tensor = fake_img.detach()                  #フェイクイメージを計算グラフから分離
        out = d.forward(fake_img)                                      #フェイクイメージの判定
        loss_g = loss_f(out, ones[:batch_len])                 #フェイク画像の損失計算 as 本物
        log_loss_g.append(loss_g.item())                        #
        d.zero_grad() , g.zero_grad()                               #勾配初期化
        loss_g.backward()                                                  #生成機の逆伝搬
        opt_g.step()                                                           #生成機の更新
        #------------------------------------------------------------------------------------------------------------------
        real_out = d.forward(real_img)                                       #本物画像の判定
        loss_d_real = loss_f(real_out,ones[:batch_len])         #本物画像の損失計算

        fake_img = fake_img_tensor                                          #
        fake_out = d.forward(fake_img_tensor)                      #
        loss_d_fake = loss_f(fake_out, zeros[:batch_len])    #フェイク画像の損失計算 as フェイク

        loss_d = loss_d_real + loss_d_fake
        log_loss_d.append(loss_d.item())
        d.zero_grad(),g.zero_grad()                                         #勾配初期化
        loss_d.backward()                                                          #判別機の逆伝搬
        opt_d.step()                                                                   #判別機の更新
        #------------------------------------------------------------------------------------------------------------------
    return mean(log_loss_g),mean(log_loss_d)

In [26]:
if not os.path.exists("data"):
    os.mkdir("data")
if not os.path.exists("data2"):
    os.mkdir("data2")
if not os.path.exists("data_tmp"):
    os.mkdir("data_tmp")

In [27]:
import time

total_time=0.0
average_time=0.0
iter_epoch=200


for epoch in range(iter_epoch):
    t_start =time.time()
    print("current epoch : ",epoch)
    train_dcgan(g,d,opt_g,opt_d,train_dataloader)
    t_finish=time.time()
    total_time+=t_finish-t_start
    average_time=total_time/(epoch+1)
    print("runtime = ",t_finish-t_start,"sec")
    print("expected remaining time = ",int(average_time*(iter_epoch-(epoch+1))),"sec",":",int(average_time*(iter_epoch-(epoch+1))/60),"minute")
    #gif_data.append(g(gif_z))
    if epoch%1==0:
        torch.save(g.state_dict(),"./data/g_{:03d}.prm".format(epoch),pickle_protocol=4)
        torch.save(d.state_dict(),"./data/d_{:03d}.prm".format(epoch),pickle_protocol=4)
        generated_img_tmp = g(fixed_z).to("cpu")
        generated_img = esrgan(generated_img_tmp)
        #g = g.to(dev)
        save_image(generated_img,"./data/{:03d}.jpg".format(epoch))
        save_image(generated_img_tmp,"./data2/{:03d}.jpg".format(epoch))
        del generated_img_tmp, generated_img


files = sorted(glob('./data/*.jpg'))  
images = list(map(lambda file : Image.open(file) , files))
images[0].save('generating_process.gif' , save_all = True , append_images = images[1:] , duration = 200 , loop = 0)

  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  0


100%|██████████| 2047/2047 [00:56<00:00, 36.12it/s]


runtime =  56.68605422973633 sec
expected remaining time =  11280 sec : 188 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  1


100%|██████████| 2047/2047 [00:55<00:00, 36.99it/s]


runtime =  55.35428762435913 sec
expected remaining time =  11091 sec : 184 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  2


100%|██████████| 2047/2047 [00:55<00:00, 36.89it/s]


runtime =  55.4928719997406 sec
expected remaining time =  11001 sec : 183 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  3


100%|██████████| 2047/2047 [00:54<00:00, 37.32it/s]


runtime =  54.86479926109314 sec
expected remaining time =  10897 sec : 181 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  4


100%|██████████| 2047/2047 [00:55<00:00, 37.07it/s]


runtime =  55.218464612960815 sec
expected remaining time =  10827 sec : 180 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  5


100%|██████████| 2047/2047 [00:54<00:00, 37.48it/s]


runtime =  54.632243156433105 sec
expected remaining time =  10742 sec : 179 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  6


100%|██████████| 2047/2047 [00:54<00:00, 37.62it/s]


runtime =  54.41776514053345 sec
expected remaining time =  10660 sec : 177 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  7


100%|██████████| 2047/2047 [00:54<00:00, 37.29it/s]


runtime =  54.90651202201843 sec
expected remaining time =  10597 sec : 176 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  8


100%|██████████| 2047/2047 [00:54<00:00, 37.85it/s]


runtime =  54.098514556884766 sec
expected remaining time =  10519 sec : 175 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  9


100%|██████████| 2047/2047 [00:54<00:00, 37.40it/s]


runtime =  54.73769545555115 sec
expected remaining time =  10457 sec : 174 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  10


100%|██████████| 2047/2047 [00:54<00:00, 37.82it/s]


runtime =  54.1345100402832 sec
expected remaining time =  10387 sec : 173 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  11


100%|██████████| 2047/2047 [00:52<00:00, 38.77it/s]


runtime =  52.81117010116577 sec
expected remaining time =  10298 sec : 171 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  12


100%|██████████| 2047/2047 [00:54<00:00, 37.50it/s]


runtime =  54.5956494808197 sec
expected remaining time =  10241 sec : 170 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  13


100%|██████████| 2047/2047 [00:53<00:00, 38.22it/s]


runtime =  53.56459021568298 sec
expected remaining time =  10170 sec : 169 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  14


100%|██████████| 2047/2047 [00:54<00:00, 37.90it/s]


runtime =  54.01160740852356 sec
expected remaining time =  10107 sec : 168 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  15


100%|██████████| 2047/2047 [00:54<00:00, 37.84it/s]


runtime =  54.11170840263367 sec
expected remaining time =  10046 sec : 167 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  16


100%|██████████| 2047/2047 [00:53<00:00, 38.46it/s]


runtime =  53.22838616371155 sec
expected remaining time =  9977 sec : 166 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  17


100%|██████████| 2047/2047 [00:54<00:00, 37.27it/s]


runtime =  54.933308839797974 sec
expected remaining time =  9927 sec : 165 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  18


100%|██████████| 2047/2047 [00:55<00:00, 36.98it/s]


runtime =  55.368277072906494 sec
expected remaining time =  9880 sec : 164 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  19


100%|██████████| 2047/2047 [00:53<00:00, 38.03it/s]


runtime =  53.84199285507202 sec
expected remaining time =  9819 sec : 163 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  20


100%|██████████| 2047/2047 [00:54<00:00, 37.42it/s]


runtime =  54.707590103149414 sec
expected remaining time =  9765 sec : 162 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  21


100%|██████████| 2047/2047 [00:55<00:00, 37.10it/s]


runtime =  55.18231797218323 sec
expected remaining time =  9716 sec : 161 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  22


100%|██████████| 2047/2047 [00:55<00:00, 37.11it/s]


runtime =  55.17454290390015 sec
expected remaining time =  9666 sec : 161 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  23


100%|██████████| 2047/2047 [00:54<00:00, 37.85it/s]


runtime =  54.09503793716431 sec
expected remaining time =  9607 sec : 160 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  24


100%|██████████| 2047/2047 [00:54<00:00, 37.64it/s]


runtime =  54.38962912559509 sec
expected remaining time =  9551 sec : 159 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  25


100%|██████████| 2047/2047 [00:54<00:00, 37.75it/s]


runtime =  54.238770484924316 sec
expected remaining time =  9495 sec : 158 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  26


100%|██████████| 2047/2047 [00:54<00:00, 37.28it/s]


runtime =  54.912288427352905 sec
expected remaining time =  9442 sec : 157 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  27


100%|██████████| 2047/2047 [00:54<00:00, 37.50it/s]


runtime =  54.593222856521606 sec
expected remaining time =  9388 sec : 156 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  28


100%|██████████| 2047/2047 [00:54<00:00, 37.29it/s]


runtime =  54.91139221191406 sec
expected remaining time =  9335 sec : 155 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  29


100%|██████████| 2047/2047 [00:54<00:00, 37.41it/s]


runtime =  54.723790884017944 sec
expected remaining time =  9281 sec : 154 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  30


100%|██████████| 2047/2047 [00:55<00:00, 37.08it/s]


runtime =  55.212308406829834 sec
expected remaining time =  9230 sec : 153 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  31


100%|██████████| 2047/2047 [00:54<00:00, 37.39it/s]


runtime =  54.75172448158264 sec
expected remaining time =  9176 sec : 152 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  32


100%|██████████| 2047/2047 [00:53<00:00, 37.92it/s]


runtime =  53.992475748062134 sec
expected remaining time =  9118 sec : 151 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  33


100%|██████████| 2047/2047 [00:53<00:00, 38.15it/s]


runtime =  53.66138482093811 sec
expected remaining time =  9059 sec : 150 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  34


100%|██████████| 2047/2047 [00:53<00:00, 38.01it/s]


runtime =  53.863736629486084 sec
expected remaining time =  9001 sec : 150 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  35


100%|██████████| 2047/2047 [00:53<00:00, 38.28it/s]


runtime =  53.48292112350464 sec
expected remaining time =  8942 sec : 149 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  36


100%|██████████| 2047/2047 [00:53<00:00, 38.55it/s]


runtime =  53.10867643356323 sec
expected remaining time =  8881 sec : 148 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  37


100%|██████████| 2047/2047 [00:53<00:00, 38.01it/s]


runtime =  53.859593629837036 sec
expected remaining time =  8824 sec : 147 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  38


100%|██████████| 2047/2047 [00:53<00:00, 38.15it/s]


runtime =  53.66917610168457 sec
expected remaining time =  8766 sec : 146 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  39


100%|██████████| 2047/2047 [00:55<00:00, 37.10it/s]


runtime =  55.180824995040894 sec
expected remaining time =  8714 sec : 145 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  40


100%|██████████| 2047/2047 [00:54<00:00, 37.84it/s]


runtime =  54.102943897247314 sec
expected remaining time =  8659 sec : 144 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  41


100%|██████████| 2047/2047 [00:54<00:00, 37.53it/s]


runtime =  54.55178785324097 sec
expected remaining time =  8604 sec : 143 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  42


100%|██████████| 2047/2047 [00:54<00:00, 37.59it/s]


runtime =  54.46657180786133 sec
expected remaining time =  8550 sec : 142 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  43


100%|██████████| 2047/2047 [00:53<00:00, 37.94it/s]


runtime =  53.963887214660645 sec
expected remaining time =  8494 sec : 141 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  44


100%|██████████| 2047/2047 [00:54<00:00, 37.82it/s]


runtime =  54.13990664482117 sec
expected remaining time =  8438 sec : 140 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  45


100%|██████████| 2047/2047 [00:53<00:00, 38.54it/s]


runtime =  53.12497615814209 sec
expected remaining time =  8379 sec : 139 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  46


100%|██████████| 2047/2047 [00:52<00:00, 38.68it/s]


runtime =  52.92608046531677 sec
expected remaining time =  8320 sec : 138 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  47


100%|██████████| 2047/2047 [00:52<00:00, 39.36it/s]


runtime =  52.021223068237305 sec
expected remaining time =  8258 sec : 137 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  48


100%|██████████| 2047/2047 [00:54<00:00, 37.28it/s]


runtime =  54.92251443862915 sec
expected remaining time =  8206 sec : 136 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  49


100%|██████████| 2047/2047 [00:53<00:00, 38.47it/s]


runtime =  53.22203063964844 sec
expected remaining time =  8148 sec : 135 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  50


100%|██████████| 2047/2047 [00:53<00:00, 38.39it/s]


runtime =  53.330730676651 sec
expected remaining time =  8091 sec : 134 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  51


100%|██████████| 2047/2047 [00:55<00:00, 36.70it/s]


runtime =  55.79169678688049 sec
expected remaining time =  8041 sec : 134 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  52


100%|██████████| 2047/2047 [00:54<00:00, 37.43it/s]


runtime =  54.70178484916687 sec
expected remaining time =  7987 sec : 133 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  53


100%|██████████| 2047/2047 [00:55<00:00, 37.00it/s]


runtime =  55.32994318008423 sec
expected remaining time =  7936 sec : 132 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  54


100%|██████████| 2047/2047 [00:53<00:00, 37.92it/s]


runtime =  53.99246573448181 sec
expected remaining time =  7880 sec : 131 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  55


100%|██████████| 2047/2047 [00:53<00:00, 37.97it/s]


runtime =  53.92208790779114 sec
expected remaining time =  7825 sec : 130 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  56


100%|██████████| 2047/2047 [00:55<00:00, 37.09it/s]


runtime =  55.20349740982056 sec
expected remaining time =  7773 sec : 129 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  57


100%|██████████| 2047/2047 [00:55<00:00, 36.95it/s]


runtime =  55.40675687789917 sec
expected remaining time =  7721 sec : 128 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  58


100%|██████████| 2047/2047 [00:54<00:00, 37.39it/s]


runtime =  54.75305247306824 sec
expected remaining time =  7668 sec : 127 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  59


100%|██████████| 2047/2047 [00:54<00:00, 37.32it/s]


runtime =  54.85090446472168 sec
expected remaining time =  7614 sec : 126 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  60


100%|██████████| 2047/2047 [00:54<00:00, 37.79it/s]


runtime =  54.1704843044281 sec
expected remaining time =  7559 sec : 125 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  61


100%|██████████| 2047/2047 [00:55<00:00, 36.76it/s]


runtime =  55.69063091278076 sec
expected remaining time =  7508 sec : 125 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  62


100%|██████████| 2047/2047 [00:55<00:00, 36.76it/s]


runtime =  55.68992900848389 sec
expected remaining time =  7456 sec : 124 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  63


100%|██████████| 2047/2047 [00:54<00:00, 37.28it/s]


runtime =  54.92309355735779 sec
expected remaining time =  7403 sec : 123 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  64


100%|██████████| 2047/2047 [00:55<00:00, 36.97it/s]


runtime =  55.3833441734314 sec
expected remaining time =  7350 sec : 122 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  65


100%|██████████| 2047/2047 [00:55<00:00, 36.84it/s]


runtime =  55.56681823730469 sec
expected remaining time =  7298 sec : 121 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  66


100%|██████████| 2047/2047 [00:54<00:00, 37.45it/s]


runtime =  54.66885733604431 sec
expected remaining time =  7244 sec : 120 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  67


100%|██████████| 2047/2047 [00:54<00:00, 37.34it/s]


runtime =  54.82917809486389 sec
expected remaining time =  7190 sec : 119 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  68


100%|██████████| 2047/2047 [00:56<00:00, 36.46it/s]


runtime =  56.15877389907837 sec
expected remaining time =  7139 sec : 118 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  69


100%|██████████| 2047/2047 [00:54<00:00, 37.22it/s]


runtime =  55.00367593765259 sec
expected remaining time =  7085 sec : 118 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  70


100%|██████████| 2047/2047 [00:54<00:00, 37.59it/s]


runtime =  54.47283697128296 sec
expected remaining time =  7031 sec : 117 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  71


100%|██████████| 2047/2047 [00:55<00:00, 37.21it/s]


runtime =  55.017781019210815 sec
expected remaining time =  6977 sec : 116 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  72


100%|██████████| 2047/2047 [00:54<00:00, 37.38it/s]


runtime =  54.76900887489319 sec
expected remaining time =  6923 sec : 115 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  73


100%|██████████| 2047/2047 [00:53<00:00, 38.13it/s]


runtime =  53.69317173957825 sec
expected remaining time =  6867 sec : 114 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  74


100%|██████████| 2047/2047 [00:53<00:00, 38.41it/s]


runtime =  53.298723459243774 sec
expected remaining time =  6811 sec : 113 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  75


100%|██████████| 2047/2047 [00:52<00:00, 38.96it/s]


runtime =  52.55122780799866 sec
expected remaining time =  6753 sec : 112 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  76


100%|██████████| 2047/2047 [00:53<00:00, 38.37it/s]


runtime =  53.35598874092102 sec
expected remaining time =  6697 sec : 111 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  77


100%|██████████| 2047/2047 [00:53<00:00, 38.23it/s]


runtime =  53.55406165122986 sec
expected remaining time =  6641 sec : 110 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  78


100%|██████████| 2047/2047 [00:53<00:00, 38.62it/s]


runtime =  53.0079550743103 sec
expected remaining time =  6584 sec : 109 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  79


100%|██████████| 2047/2047 [00:52<00:00, 38.67it/s]


runtime =  52.94516706466675 sec
expected remaining time =  6528 sec : 108 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  80


100%|██████████| 2047/2047 [00:53<00:00, 38.61it/s]


runtime =  53.02953124046326 sec
expected remaining time =  6471 sec : 107 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  81


100%|██████████| 2047/2047 [00:53<00:00, 38.12it/s]


runtime =  53.701881885528564 sec
expected remaining time =  6416 sec : 106 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  82


100%|██████████| 2047/2047 [00:54<00:00, 37.49it/s]


runtime =  54.60849666595459 sec
expected remaining time =  6362 sec : 106 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  83


100%|██████████| 2047/2047 [00:53<00:00, 38.46it/s]


runtime =  53.23490643501282 sec
expected remaining time =  6306 sec : 105 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  84


100%|██████████| 2047/2047 [00:52<00:00, 38.72it/s]


runtime =  52.86526799201965 sec
expected remaining time =  6250 sec : 104 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  85


100%|██████████| 2047/2047 [00:52<00:00, 39.06it/s]


runtime =  52.410969257354736 sec
expected remaining time =  6193 sec : 103 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  86


100%|██████████| 2047/2047 [00:53<00:00, 38.24it/s]


runtime =  53.539057970047 sec
expected remaining time =  6137 sec : 102 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  87


100%|██████████| 2047/2047 [00:52<00:00, 38.99it/s]


runtime =  52.511974573135376 sec
expected remaining time =  6081 sec : 101 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  88


100%|██████████| 2047/2047 [00:53<00:00, 38.56it/s]


runtime =  53.10345506668091 sec
expected remaining time =  6025 sec : 100 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  89


100%|██████████| 2047/2047 [00:53<00:00, 38.50it/s]


runtime =  53.17586970329285 sec
expected remaining time =  5969 sec : 99 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  90


100%|██████████| 2047/2047 [00:52<00:00, 38.73it/s]


runtime =  52.85624599456787 sec
expected remaining time =  5913 sec : 98 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  91


100%|██████████| 2047/2047 [00:53<00:00, 38.60it/s]


runtime =  53.04874300956726 sec
expected remaining time =  5858 sec : 97 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  92


100%|██████████| 2047/2047 [00:53<00:00, 38.31it/s]


runtime =  53.43649625778198 sec
expected remaining time =  5802 sec : 96 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  93


100%|██████████| 2047/2047 [00:53<00:00, 38.62it/s]


runtime =  53.01057291030884 sec
expected remaining time =  5747 sec : 95 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  94


100%|██████████| 2047/2047 [00:53<00:00, 37.99it/s]


runtime =  53.89460206031799 sec
expected remaining time =  5692 sec : 94 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  95


100%|██████████| 2047/2047 [00:53<00:00, 38.40it/s]


runtime =  53.3246865272522 sec
expected remaining time =  5637 sec : 93 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  96


100%|██████████| 2047/2047 [00:53<00:00, 38.56it/s]


runtime =  53.0979905128479 sec
expected remaining time =  5582 sec : 93 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  97


100%|██████████| 2047/2047 [00:51<00:00, 39.40it/s]


runtime =  51.964481592178345 sec
expected remaining time =  5525 sec : 92 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  98


100%|██████████| 2047/2047 [00:53<00:00, 38.41it/s]


runtime =  53.30451440811157 sec
expected remaining time =  5470 sec : 91 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  99


100%|██████████| 2047/2047 [00:53<00:00, 38.53it/s]


runtime =  53.132615089416504 sec
expected remaining time =  5415 sec : 90 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  100


100%|██████████| 2047/2047 [00:54<00:00, 37.45it/s]


runtime =  54.66817259788513 sec
expected remaining time =  5361 sec : 89 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  101


100%|██████████| 2047/2047 [00:52<00:00, 38.73it/s]


runtime =  52.86293268203735 sec
expected remaining time =  5306 sec : 88 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  102


100%|██████████| 2047/2047 [00:53<00:00, 38.29it/s]


runtime =  53.47044491767883 sec
expected remaining time =  5251 sec : 87 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  103


100%|██████████| 2047/2047 [00:53<00:00, 38.33it/s]


runtime =  53.420315980911255 sec
expected remaining time =  5196 sec : 86 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  104


100%|██████████| 2047/2047 [00:53<00:00, 38.59it/s]


runtime =  53.0481231212616 sec
expected remaining time =  5141 sec : 85 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  105


100%|██████████| 2047/2047 [00:53<00:00, 38.00it/s]


runtime =  53.87976336479187 sec
expected remaining time =  5087 sec : 84 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  106


100%|██████████| 2047/2047 [00:52<00:00, 38.69it/s]


runtime =  52.90809178352356 sec
expected remaining time =  5032 sec : 83 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  107


100%|██████████| 2047/2047 [00:53<00:00, 38.23it/s]


runtime =  53.54304909706116 sec
expected remaining time =  4977 sec : 82 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  108


100%|██████████| 2047/2047 [00:53<00:00, 38.38it/s]


runtime =  53.339951276779175 sec
expected remaining time =  4922 sec : 82 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  109


100%|██████████| 2047/2047 [00:53<00:00, 38.30it/s]


runtime =  53.45494556427002 sec
expected remaining time =  4868 sec : 81 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  110


100%|██████████| 2047/2047 [00:53<00:00, 38.57it/s]


runtime =  53.08002781867981 sec
expected remaining time =  4813 sec : 80 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  111


100%|██████████| 2047/2047 [00:53<00:00, 38.47it/s]


runtime =  53.224215030670166 sec
expected remaining time =  4758 sec : 79 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  112


100%|██████████| 2047/2047 [00:53<00:00, 38.37it/s]


runtime =  53.35719037055969 sec
expected remaining time =  4703 sec : 78 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  113


100%|██████████| 2047/2047 [00:53<00:00, 38.22it/s]


runtime =  53.56715750694275 sec
expected remaining time =  4649 sec : 77 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  114


100%|██████████| 2047/2047 [00:53<00:00, 38.03it/s]


runtime =  53.83413481712341 sec
expected remaining time =  4595 sec : 76 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  115


100%|██████████| 2047/2047 [00:53<00:00, 38.29it/s]


runtime =  53.46772122383118 sec
expected remaining time =  4540 sec : 75 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  116


100%|██████████| 2047/2047 [00:53<00:00, 38.21it/s]


runtime =  53.5847806930542 sec
expected remaining time =  4486 sec : 74 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  117


100%|██████████| 2047/2047 [00:53<00:00, 38.39it/s]


runtime =  53.31902360916138 sec
expected remaining time =  4431 sec : 73 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  118


100%|██████████| 2047/2047 [00:52<00:00, 38.66it/s]


runtime =  52.96636462211609 sec
expected remaining time =  4377 sec : 72 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  119


100%|██████████| 2047/2047 [00:53<00:00, 37.94it/s]


runtime =  53.9650502204895 sec
expected remaining time =  4322 sec : 72 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  120


100%|██████████| 2047/2047 [00:53<00:00, 37.97it/s]


runtime =  53.91427254676819 sec
expected remaining time =  4268 sec : 71 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  121


100%|██████████| 2047/2047 [00:53<00:00, 38.14it/s]


runtime =  53.668251276016235 sec
expected remaining time =  4214 sec : 70 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  122


100%|██████████| 2047/2047 [00:53<00:00, 38.53it/s]


runtime =  53.13909840583801 sec
expected remaining time =  4159 sec : 69 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  123


100%|██████████| 2047/2047 [00:52<00:00, 38.80it/s]


runtime =  52.77311372756958 sec
expected remaining time =  4105 sec : 68 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  124


100%|██████████| 2047/2047 [00:52<00:00, 39.02it/s]


runtime =  52.47173571586609 sec
expected remaining time =  4050 sec : 67 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  125


100%|██████████| 2047/2047 [00:53<00:00, 38.17it/s]


runtime =  53.63289022445679 sec
expected remaining time =  3996 sec : 66 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  126


100%|██████████| 2047/2047 [00:53<00:00, 38.38it/s]


runtime =  53.349374771118164 sec
expected remaining time =  3941 sec : 65 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  127


100%|██████████| 2047/2047 [00:52<00:00, 38.65it/s]


runtime =  52.96210479736328 sec
expected remaining time =  3887 sec : 64 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  128


100%|██████████| 2047/2047 [00:53<00:00, 38.40it/s]


runtime =  53.309080839157104 sec
expected remaining time =  3832 sec : 63 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  129


100%|██████████| 2047/2047 [00:53<00:00, 38.61it/s]


runtime =  53.03487777709961 sec
expected remaining time =  3778 sec : 62 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  130


100%|██████████| 2047/2047 [00:53<00:00, 38.14it/s]


runtime =  53.68079710006714 sec
expected remaining time =  3724 sec : 62 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  131


100%|██████████| 2047/2047 [00:53<00:00, 38.06it/s]


runtime =  53.781121015548706 sec
expected remaining time =  3670 sec : 61 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  132


100%|██████████| 2047/2047 [00:53<00:00, 37.98it/s]


runtime =  53.90377354621887 sec
expected remaining time =  3616 sec : 60 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  133


100%|██████████| 2047/2047 [00:53<00:00, 38.30it/s]


runtime =  53.462849617004395 sec
expected remaining time =  3561 sec : 59 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  134


100%|██████████| 2047/2047 [00:52<00:00, 38.63it/s]


runtime =  52.9954469203949 sec
expected remaining time =  3507 sec : 58 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  135


100%|██████████| 2047/2047 [00:52<00:00, 38.66it/s]


runtime =  52.95542669296265 sec
expected remaining time =  3452 sec : 57 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  136


100%|██████████| 2047/2047 [00:53<00:00, 38.42it/s]


runtime =  53.28408932685852 sec
expected remaining time =  3398 sec : 56 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  137


100%|██████████| 2047/2047 [00:54<00:00, 37.80it/s]


runtime =  54.16818976402283 sec
expected remaining time =  3344 sec : 55 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  138


100%|██████████| 2047/2047 [00:52<00:00, 38.80it/s]


runtime =  52.76704716682434 sec
expected remaining time =  3290 sec : 54 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  139


100%|██████████| 2047/2047 [00:54<00:00, 37.78it/s]


runtime =  54.1822509765625 sec
expected remaining time =  3236 sec : 53 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  140


100%|██████████| 2047/2047 [00:52<00:00, 38.82it/s]


runtime =  52.74682903289795 sec
expected remaining time =  3182 sec : 53 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  141


100%|██████████| 2047/2047 [00:52<00:00, 39.34it/s]


runtime =  52.05100464820862 sec
expected remaining time =  3127 sec : 52 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  142


100%|██████████| 2047/2047 [00:53<00:00, 38.01it/s]


runtime =  53.855600357055664 sec
expected remaining time =  3073 sec : 51 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  143


100%|██████████| 2047/2047 [00:53<00:00, 38.55it/s]


runtime =  53.10473322868347 sec
expected remaining time =  3019 sec : 50 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  144


100%|██████████| 2047/2047 [00:53<00:00, 38.23it/s]


runtime =  53.54701066017151 sec
expected remaining time =  2965 sec : 49 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  145


100%|██████████| 2047/2047 [00:52<00:00, 38.67it/s]


runtime =  52.936683654785156 sec
expected remaining time =  2910 sec : 48 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  146


100%|██████████| 2047/2047 [00:54<00:00, 37.59it/s]


runtime =  54.465893507003784 sec
expected remaining time =  2857 sec : 47 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  147


100%|██████████| 2047/2047 [00:53<00:00, 38.34it/s]


runtime =  53.40782928466797 sec
expected remaining time =  2803 sec : 46 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  148


100%|██████████| 2047/2047 [00:52<00:00, 38.90it/s]


runtime =  52.62426471710205 sec
expected remaining time =  2748 sec : 45 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  149


100%|██████████| 2047/2047 [00:54<00:00, 37.48it/s]


runtime =  54.62202429771423 sec
expected remaining time =  2695 sec : 44 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  150


100%|██████████| 2047/2047 [00:52<00:00, 38.85it/s]


runtime =  52.689878702163696 sec
expected remaining time =  2640 sec : 44 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  151


100%|██████████| 2047/2047 [00:53<00:00, 38.33it/s]


runtime =  53.41001844406128 sec
expected remaining time =  2586 sec : 43 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  152


100%|██████████| 2047/2047 [00:52<00:00, 39.19it/s]


runtime =  52.23854207992554 sec
expected remaining time =  2532 sec : 42 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  153


100%|██████████| 2047/2047 [00:53<00:00, 38.15it/s]


runtime =  53.66167664527893 sec
expected remaining time =  2478 sec : 41 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  154


100%|██████████| 2047/2047 [00:52<00:00, 38.76it/s]


runtime =  52.819095611572266 sec
expected remaining time =  2424 sec : 40 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  155


100%|██████████| 2047/2047 [00:53<00:00, 38.12it/s]


runtime =  53.706480979919434 sec
expected remaining time =  2370 sec : 39 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  156


100%|██████████| 2047/2047 [00:52<00:00, 38.66it/s]


runtime =  52.95649862289429 sec
expected remaining time =  2316 sec : 38 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  157


100%|██████████| 2047/2047 [00:52<00:00, 39.23it/s]


runtime =  52.18588161468506 sec
expected remaining time =  2261 sec : 37 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  158


100%|██████████| 2047/2047 [00:52<00:00, 39.10it/s]


runtime =  52.357303619384766 sec
expected remaining time =  2207 sec : 36 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  159


100%|██████████| 2047/2047 [00:52<00:00, 39.01it/s]


runtime =  52.48251700401306 sec
expected remaining time =  2153 sec : 35 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  160


100%|██████████| 2047/2047 [00:53<00:00, 38.60it/s]


runtime =  53.03030300140381 sec
expected remaining time =  2099 sec : 34 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  161


100%|██████████| 2047/2047 [00:53<00:00, 38.54it/s]


runtime =  53.11544704437256 sec
expected remaining time =  2045 sec : 34 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  162


100%|██████████| 2047/2047 [00:53<00:00, 38.17it/s]


runtime =  53.64231300354004 sec
expected remaining time =  1991 sec : 33 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  163


100%|██████████| 2047/2047 [00:53<00:00, 38.55it/s]


runtime =  53.11044645309448 sec
expected remaining time =  1937 sec : 32 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  164


100%|██████████| 2047/2047 [00:52<00:00, 38.81it/s]


runtime =  52.74509859085083 sec
expected remaining time =  1883 sec : 31 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  165


100%|██████████| 2047/2047 [00:53<00:00, 38.45it/s]


runtime =  53.24551582336426 sec
expected remaining time =  1829 sec : 30 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  166


100%|██████████| 2047/2047 [00:52<00:00, 38.98it/s]


runtime =  52.512885332107544 sec
expected remaining time =  1775 sec : 29 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  167


100%|██████████| 2047/2047 [00:53<00:00, 38.33it/s]


runtime =  53.41865253448486 sec
expected remaining time =  1721 sec : 28 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  168


100%|██████████| 2047/2047 [00:52<00:00, 38.82it/s]


runtime =  52.74603772163391 sec
expected remaining time =  1667 sec : 27 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  169


100%|██████████| 2047/2047 [00:54<00:00, 37.68it/s]


runtime =  54.32466125488281 sec
expected remaining time =  1613 sec : 26 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  170


100%|██████████| 2047/2047 [00:52<00:00, 38.96it/s]


runtime =  52.54910850524902 sec
expected remaining time =  1559 sec : 25 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  171


100%|██████████| 2047/2047 [00:54<00:00, 37.86it/s]


runtime =  54.07499146461487 sec
expected remaining time =  1506 sec : 25 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  172


100%|██████████| 2047/2047 [00:55<00:00, 36.87it/s]


runtime =  55.53730821609497 sec
expected remaining time =  1452 sec : 24 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  173


100%|██████████| 2047/2047 [00:52<00:00, 38.99it/s]


runtime =  52.501383543014526 sec
expected remaining time =  1398 sec : 23 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  174


100%|██████████| 2047/2047 [00:53<00:00, 38.59it/s]


runtime =  53.05550742149353 sec
expected remaining time =  1344 sec : 22 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  175


100%|██████████| 2047/2047 [00:53<00:00, 38.40it/s]


runtime =  53.30738663673401 sec
expected remaining time =  1290 sec : 21 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  176


100%|██████████| 2047/2047 [00:53<00:00, 38.22it/s]


runtime =  53.57560920715332 sec
expected remaining time =  1237 sec : 20 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  177


100%|██████████| 2047/2047 [00:53<00:00, 37.99it/s]


runtime =  53.88811993598938 sec
expected remaining time =  1183 sec : 19 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  178


100%|██████████| 2047/2047 [00:53<00:00, 38.04it/s]


runtime =  53.82120180130005 sec
expected remaining time =  1129 sec : 18 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  179


100%|██████████| 2047/2047 [00:53<00:00, 38.49it/s]


runtime =  53.20048260688782 sec
expected remaining time =  1075 sec : 17 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  180


100%|██████████| 2047/2047 [00:54<00:00, 37.53it/s]


runtime =  54.555978775024414 sec
expected remaining time =  1021 sec : 17 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  181


100%|██████████| 2047/2047 [00:52<00:00, 39.08it/s]


runtime =  52.385721921920776 sec
expected remaining time =  968 sec : 16 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  182


100%|██████████| 2047/2047 [00:53<00:00, 38.07it/s]


runtime =  53.775893688201904 sec
expected remaining time =  914 sec : 15 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  183


100%|██████████| 2047/2047 [00:53<00:00, 38.60it/s]


runtime =  53.03929376602173 sec
expected remaining time =  860 sec : 14 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  184


100%|██████████| 2047/2047 [00:53<00:00, 38.55it/s]


runtime =  53.116087913513184 sec
expected remaining time =  806 sec : 13 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  185


100%|██████████| 2047/2047 [00:52<00:00, 38.69it/s]


runtime =  52.925132751464844 sec
expected remaining time =  752 sec : 12 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  186


100%|██████████| 2047/2047 [00:52<00:00, 39.22it/s]


runtime =  52.202178716659546 sec
expected remaining time =  698 sec : 11 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  187


100%|██████████| 2047/2047 [00:53<00:00, 38.12it/s]


runtime =  53.69845795631409 sec
expected remaining time =  645 sec : 10 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  188


100%|██████████| 2047/2047 [00:53<00:00, 38.09it/s]


runtime =  53.75487542152405 sec
expected remaining time =  591 sec : 9 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  189


100%|██████████| 2047/2047 [00:53<00:00, 38.24it/s]


runtime =  53.540772914886475 sec
expected remaining time =  537 sec : 8 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  190


100%|██████████| 2047/2047 [00:53<00:00, 38.21it/s]


runtime =  53.59064030647278 sec
expected remaining time =  483 sec : 8 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  191


100%|██████████| 2047/2047 [00:53<00:00, 38.42it/s]


runtime =  53.28841972351074 sec
expected remaining time =  430 sec : 7 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  192


100%|██████████| 2047/2047 [00:53<00:00, 38.38it/s]


runtime =  53.33629083633423 sec
expected remaining time =  376 sec : 6 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  193


100%|██████████| 2047/2047 [00:52<00:00, 38.75it/s]


runtime =  52.83329963684082 sec
expected remaining time =  322 sec : 5 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  194


100%|██████████| 2047/2047 [00:53<00:00, 38.11it/s]


runtime =  53.71681332588196 sec
expected remaining time =  268 sec : 4 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  195


100%|██████████| 2047/2047 [00:52<00:00, 38.77it/s]


runtime =  52.812968015670776 sec
expected remaining time =  214 sec : 3 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  196


100%|██████████| 2047/2047 [00:53<00:00, 38.48it/s]


runtime =  53.20714545249939 sec
expected remaining time =  161 sec : 2 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  197


100%|██████████| 2047/2047 [00:52<00:00, 39.14it/s]


runtime =  52.311509132385254 sec
expected remaining time =  107 sec : 1 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  198


100%|██████████| 2047/2047 [00:54<00:00, 37.68it/s]


runtime =  54.33315396308899 sec
expected remaining time =  53 sec : 0 minute


  0%|          | 0/2047 [00:00<?, ?it/s]

current epoch :  199


100%|██████████| 2047/2047 [00:54<00:00, 37.89it/s]


runtime =  54.03066825866699 sec
expected remaining time =  0 sec : 0 minute


In [None]:
#del esrgan

In [None]:
#esrgan = GeneratorRRDB(3, filters = 64, num_res_blocks = 17).to("cuda:0")
#esrgan.load_state_dict(torch.load("./params/generator_00101000.pth"))

In [None]:
batch_size = 64
dev = "cuda:0"

In [None]:
nz = 100
ngf = 32
class GNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz,ngf*4,4,1,0,bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*2,ngf*1,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf*1,3,4,2,1,bias=True),
            nn.Tanh()
            )

    def forward(self,x):
        out = self.main(x)
        return out

In [None]:
ndf=32
#(3,64,64)
class DNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3,ndf,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2,inplace=True),###
            nn.Conv2d(ndf*8,ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(ndf*8,1,4,1,0,bias=True)
            )

    def forward(self,x):
        out = self.main(x)
        return out.squeeze()
    
d = DNet().to(dev)
g = GNet().to(dev)

opt_d = optim.Adam(d.parameters(),lr=0.0002,betas=(0.5,0.999))
opt_g = optim.Adam(g.parameters(),lr=0.0002,betas=(0.5,0.999))

ones = torch.ones(batch_size).to(dev)
zeros = torch.zeros(batch_size).to(dev)

loss_f = nn.BCEWithLogitsLoss()

fixed_z = torch.randn(8,nz,1,1).to(dev)
#gif_z = torch.randn(1,nz,1,1).to(dev)

In [None]:
def train_dcgan(g,d,opt_g,opt_d,loader):
    #log_loss_g = []
    #log_loss_d = []
    
    for real_img in tqdm(loader):
        #real_lr = real_img["lr"]
        real_hr = real_img["hr"]
        batch_len = len(real_hr)
        #print(real_img.size())
        real_hr = real_hr.to(dev)
        z = torch.randn(batch_len,nz,1,1).to(dev)
        #------------------------------------------------------------------------------------------------------------------
        fake_img = esrgan.generator(g(z))                                                         #フェイクイメージの作成
        fake_img_tensor = fake_img.detach()                  #フェイクイメージを計算グラフから分離
        out = d.forward(fake_img)                                      #フェイクイメージの判定
        loss_g = loss_f(out, ones[:batch_len])                 #フェイク画像の損失計算 as 本物
        #log_loss_g.append(loss_g.item())                        #
        d.zero_grad() , g.zero_grad()                               #勾配初期化
        loss_g.backward()                                                  #生成機の逆伝搬
        opt_g.step()                                                           #生成機の更新
        #------------------------------------------------------------------------------------------------------------------
        real_out = d.forward(real_hr)                                       #本物画像の判定
        loss_d_real = loss_f(real_out,ones[:batch_len])         #本物画像の損失計算

        fake_img = fake_img_tensor                                          #
        fake_out = d.forward(fake_img_tensor)                      #
        loss_d_fake = loss_f(fake_out, zeros[:batch_len])    #フェイク画像の損失計算 as フェイク

        loss_d = loss_d_real + loss_d_fake
        #log_loss_d.append(loss_d.item())
        d.zero_grad(),g.zero_grad()                                         #勾配初期化
        loss_d.backward()                                                          #判別機の逆伝搬
        opt_d.step()                                                                   #判別機の更新
        #------------------------------------------------------------------------------------------------------------------
    return loss_g,loss_d

In [None]:
if not os.path.exists("data-test"):
    os.mkdir("data-test")
if not os.path.exists("data2-test"):
    os.mkdir("data2-test")

In [None]:
import time

total_time=0.0
average_time=0.0
iter_epoch=200


for epoch in range(iter_epoch):
    t_start =time.time()
    print("current epoch : ",epoch)
    train_dcgan(g,d,opt_g,opt_d,train_dataloader)
    t_finish=time.time()
    total_time+=t_finish-t_start
    average_time=total_time/(epoch+1)
    print("runtime = ",t_finish-t_start,"sec")
    print("expected remaining time = ",int(average_time*(iter_epoch-(epoch+1))),"sec",":",int(average_time*(iter_epoch-(epoch+1))/60),"minute")
    #gif_data.append(g(gif_z))
    if epoch%1==0:
        torch.save(g.state_dict(),"./data-test/g_{:03d}.prm".format(epoch),pickle_protocol=4)
        torch.save(d.state_dict(),"./data-test/d_{:03d}.prm".format(epoch),pickle_protocol=4)
        generated_img = esrgan.generator(g(fixed_z))
        #generated_img = esrgan.generator(generated_img_tmp)
        #g = g.to(dev)
        save_image(generated_img,"./data-test/{:03d}.jpg".format(epoch))
        #save_image(generated_img_tmp,"./data2-test/{:03d}.jpg".format(epoch))
        del generated_img


files = sorted(glob('./data-test/*.jpg'))  
images = list(map(lambda file : Image.open(file) , files))
images[0].save('generating_process2.gif' , save_all = True , append_images = images[1:] , duration = 200 , loop = 0)