In [1]:
!pip3 install tensorboard_logger==0.1.0

[33mYou are using pip version 9.0.1, however version 10.0.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [2]:
import os
import pathlib

import shutil
import torch
import torch.nn as nn
import torchvision
from tensorboard_logger import Logger
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm, trange

from settings.paths import FILTERED_MS_CELEB_IMAGES_DIR, \
                           SRGAN_MSE_LOSS_LOGS_DIR, \
                           SRGAN_VGG_LOSS_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_LOSS_LOGS_DIR, \
                           SRGAN_MSE_LOSS_WEIGHTS_DIR, \
                           SRGAN_VGG_LOSS_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_LOSS_WEIGHTS_DIR
from utils import maybe_mkdir, remove_if_exists, get_last_file
                           
from src.srgan import Generator, Discriminator, FeatureExtractor

In [3]:
# Expose only GPUs 6 and 7 to this process. CUDA reads this variable when it
# initialises, so this cell has to run before anything calls .cuda().
os.environ.update(CUDA_VISIBLE_DEVICES='6, 7')

In [4]:
# Number of samples per training mini-batch.
BATCH_SIZE: int = 16

In [None]:
class MSCeleb(ImageFolder):
    """ImageFolder dataset that yields ``(image_hr, image_lr)`` tensor pairs.

    The class label produced by ``ImageFolder`` is discarded. The high-resolution
    tensor is the loaded image as-is; the low-resolution tensor is the same image
    resized so that its smaller edge equals ``lr_size``.

    Args:
        root: directory laid out as expected by ``ImageFolder``.
        lr_size: target size of the smaller edge of the low-resolution image
            (default 32, as before).
        interpolation: interpolation flag forwarded to ``transforms.Resize``
            (default 0, as before).
    """

    def __init__(self, root, lr_size=32, interpolation=0):
        super().__init__(root)

        self._to_tensor = transforms.ToTensor()

        # Resize works on the loaded image directly; the previous
        # tensor -> ToPILImage -> Resize path paid for an extra conversion and
        # re-quantised the float tensor back to uint8 before resizing.
        self._downscale = transforms.Compose(
            [
                transforms.Resize(lr_size, interpolation),
                transforms.ToTensor(),
            ]
        )

    def __getitem__(self, index):
        # No transform was passed to ImageFolder, so `image` is the raw loaded
        # image (presumably PIL via the default loader — confirm if a custom
        # loader/backend is configured).
        image, _ = super().__getitem__(index)

        image_hr = self._to_tensor(image)
        image_lr = self._downscale(image)

        return image_hr, image_lr

In [None]:
# (high-res, low-res) training pairs built from the filtered MS-Celeb images.
dataset = MSCeleb(root=FILTERED_MS_CELEB_IMAGES_DIR)

In [None]:
# Shuffled mini-batches of (high-res, low-res) pairs. The batch size comes from
# the BATCH_SIZE config constant instead of repeating the literal 16.
dataset_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [None]:
class SRGANManager:
    """Owns an SRGAN generator/discriminator pair, their optimisers and the training loop.

    Args:
        dataset_loader: sized iterable yielding ``(image_hr, image_lr)`` batches.
        n_resblocks: residual block count passed to ``Generator``.
        n_upsample: upsampling argument passed to ``Generator``.
        generator_lr: Adam learning rate for the generator.
        discriminator_lr: Adam learning rate for the discriminator.
        batch_size: first dimension of the all-ones adversarial target; should
            match the loader's batch size.
        cuda: when truthy, wrap both networks in ``nn.DataParallel`` and move the
            networks, criteria and constant target to the GPU.
    """

    def __init__(self,
                 dataset_loader,
                 n_resblocks=16, n_upsample=2,
                 generator_lr=0.0001, discriminator_lr=0.0001,
                 batch_size=16,
                 cuda=True):
        self._batch_size = batch_size

        self._generator = Generator(n_resblocks, n_upsample)
        self._discriminator = Discriminator()

        self._content_criterion = nn.MSELoss()
        self._adversarial_criterion = nn.BCELoss()
        # Sized from the constructor argument (previously the global BATCH_SIZE,
        # which silently ignored `batch_size`).
        self._ones_const = Variable(torch.ones(self._batch_size, 1))

        self._cuda = cuda
        if self._cuda:
            self._generator = nn.DataParallel(self._generator).cuda()
            self._discriminator = nn.DataParallel(self._discriminator).cuda()
            self._content_criterion = self._content_criterion.cuda()
            self._adversarial_criterion = self._adversarial_criterion.cuda()
            self._ones_const = self._ones_const.cuda()

        # Optimisers are created after the .cuda() moves above.
        self._optim_generator = optim.Adam(
            self._generator.parameters(), lr=generator_lr
        )
        self._optim_discriminator = optim.Adam(
            self._discriminator.parameters(), lr=discriminator_lr
        )

        self._dataset_loader = dataset_loader

    def load_generator_weights(self, path):
        """Load a generator state_dict previously written by ``torch.save``."""
        self._generator.load_state_dict(torch.load(path))

    def load_discriminator_weights(self, path):
        """Load a discriminator state_dict previously written by ``torch.save``."""
        self._discriminator.load_state_dict(torch.load(path))

    def train_mse_only(self,
                       log_dir=SRGAN_MSE_LOSS_LOGS_DIR,
                       epoch_count=30,
                       generator_weights_dir=SRGAN_MSE_LOSS_WEIGHTS_DIR,
                       save_frequency=15000,
                       values_log_frequency=300,
                       images_log_frequency=3000,
                       start_log=100,
                       train_type='careful'):
        """Train the generator alone with the pixel-wise MSE content loss.

        Only generator weights are checkpointed (no discriminator directory is
        passed to ``_train``). See ``_train`` for the meaning of the arguments.
        """
        self._train(
            log_dir=log_dir,
            epoch_count=epoch_count,
            train_function=self._mse_only_train_function,
            generator_weights_dir=generator_weights_dir,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type
        )

    def _mse_only_train_function(self, high_res_real, high_res_fake):
        """One generator update on the MSE between generated and real HR images.

        Returns:
            dict mapping the log tag ``'generator_mse_loss'`` to the loss value.
        """
        self._generator.zero_grad()

        generator_content_loss = self._content_criterion(
            high_res_fake, high_res_real,
        )

        generator_content_loss.backward()
        self._optim_generator.step()

        return {
            'generator_mse_loss': generator_content_loss,
        }

    def _train(self,
               log_dir,
               epoch_count,
               train_function,
               generator_weights_dir,
               save_frequency=15000,
               values_log_frequency=300,
               images_log_frequency=3000,
               start_log=100,
               train_type='careful',
               discriminator_weights_dir=None):
        """Generic training loop shared by the concrete training modes.

        Args:
            log_dir: tensorboard_logger output directory.
            epoch_count: number of passes over the data loader.
            train_function: callable ``(high_res_real, high_res_fake) -> dict``
                performing one optimisation step and returning scalars to log.
            generator_weights_dir / discriminator_weights_dir: checkpoint dirs
                (the discriminator one is optional).
            save_frequency: checkpoint every this many steps (step 0 excluded).
            values_log_frequency / images_log_frequency: logging periods in steps,
                applied only once ``step_num >= start_log``.
            train_type: ``'careful'``, ``'clean'`` or ``'continue'``; see
                ``_prepare_train_type``.

        A KeyboardInterrupt during the loop is caught on purpose: training stops
        and the current weights are checkpointed.
        """
        # Done before entering the try block: an interrupt here (e.g. at the
        # 'clean' confirmation prompt) used to reach the handler below with
        # `step_num` unbound and fail with UnboundLocalError.
        step_num = self._prepare_train_type(
            train_type, log_dir,
            generator_weights_dir, discriminator_weights_dir
        )

        logger = Logger(log_dir)

        try:
            for epoch_number in range(epoch_count):
                print('Epoch: {}'.format(epoch_number))
                for image_hr, image_lr in tqdm(self._dataset_loader, total=len(self._dataset_loader)):
                    high_res_real = Variable(image_hr)
                    low_res = Variable(image_lr)
                    if self._cuda:
                        high_res_real = high_res_real.cuda()
                        # Move the input before the forward pass; the generator
                        # lives on the GPU, so its output does too.
                        low_res = low_res.cuda()
                    high_res_fake = self._generator(low_res)

                    values_dict = train_function(high_res_real, high_res_fake)

                    if step_num >= start_log:
                        if step_num % values_log_frequency == 0:
                            for key, value in values_dict.items():
                                logger.log_value(key, value, step_num)

                        if step_num % images_log_frequency == 0:
                            self._log_images(
                                logger, step_num,
                                high_res_real, low_res, high_res_fake
                            )

                    if step_num % save_frequency == 0 and step_num != 0:
                        self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)

                    step_num += 1

            self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)
        except KeyboardInterrupt:
            self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)

    @staticmethod
    def _log_images(logger, step_num, high_res_real, low_res, high_res_fake):
        """Log the first sample of the real HR, LR input and generated HR batches."""
        for tag, batch in (('real_images', high_res_real),
                           ('lr_images', low_res),
                           ('fake_images', high_res_fake)):
            logger.log_images(tag, batch.data.cpu()[:1], step_num)

    def _save_all(self, step_num, generator_weights_dir, discriminator_weights_dir=None):
        """Write the generator (and, if a directory is given, discriminator) state_dicts.

        Files are named ``<step_num zero-padded to 10 digits>.pth``.
        """
        torch.save(
            self._generator.state_dict(), self._make_weight_path(generator_weights_dir, step_num)
        )

        if discriminator_weights_dir is not None:
            torch.save(
                self._discriminator.state_dict(), self._make_weight_path(discriminator_weights_dir, step_num)
            )

    def _prepare_train_type(self, train_type, log_dir,
                            generator_weights_dir, discriminator_weights_dir=None):
        """Validate/prepare output directories for a training run.

        Modes:
            'careful': raise AssertionError if any target directory already exists.
            'clean': ask for a Y/N confirmation on stdin; 'Y' removes the
                directories, any other answer raises Exception('Bad answer').
            'continue': load the file returned by ``get_last_file`` from each
                weights directory and resume the step counter from its name.

        Afterwards the directories are created if missing.

        Returns:
            The step number to start counting from (0 unless continuing).

        Raises:
            ValueError: for an unknown ``train_type``.
        """
        step = 0

        # `==`, not `is`: identity comparison of strings only works when both
        # sides happen to be the same interned object.
        if train_type == 'careful':
            print(log_dir)
            if os.path.exists(log_dir) or os.path.exists(generator_weights_dir) or \
               (discriminator_weights_dir is not None and os.path.exists(discriminator_weights_dir)):
                raise AssertionError('There are previous train files')
        elif train_type == 'clean':
            print(
                'Clean all train data for (Y/N):\nlog_dir: {}\ngenerator_weights_dir: {}\ndiscriminator_weights_dir: {}'.format(
                    log_dir,
                    generator_weights_dir,
                    discriminator_weights_dir,
                )
            )
            answer = input()
            if answer == 'Y':
                remove_if_exists(log_dir)
                remove_if_exists(generator_weights_dir)
                if discriminator_weights_dir is not None:
                    remove_if_exists(discriminator_weights_dir)
            else:
                raise Exception('Bad answer')
        elif train_type == 'continue':
            try:
                filename = get_last_file(generator_weights_dir)
                self.load_generator_weights(os.path.join(generator_weights_dir, filename))

                if discriminator_weights_dir is not None:
                    filename = get_last_file(discriminator_weights_dir)
                    self.load_discriminator_weights(os.path.join(discriminator_weights_dir, filename))

                # Checkpoint names are '<step>.pth' (see _make_weight_path).
                step = int(os.path.splitext(filename)[0])
            except Exception as err:
                # Chain the cause so load/parse failures stay diagnosable.
                raise Exception('Nothing to load') from err
        else:
            raise ValueError('No such train type')

        pathlib.Path(generator_weights_dir).mkdir(parents=True, exist_ok=True)
        if discriminator_weights_dir is not None:
            pathlib.Path(discriminator_weights_dir).mkdir(parents=True, exist_ok=True)
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        return step

    def _make_weight_path(self, folder, step):
        """Return ``<folder>/<step as 10-digit zero-padded int>.pth``."""
        return os.path.join(folder, '{:010d}.pth'.format(step))

In [None]:
# Pass the notebook's BATCH_SIZE explicitly so the manager's adversarial target
# agrees with the loader instead of relying on an equal default value.
srgan_manager = SRGANManager(dataset_loader, batch_size=BATCH_SIZE)

In [None]:
# 'clean' prompts for a Y/N confirmation on stdin and, on 'Y', deletes the
# previous logs/weights before training the generator with MSE loss only.
srgan_manager.train_mse_only(train_type='clean', epoch_count=30)