In [1]:
!pip3 install tensorboard_logger==0.1.0

[33mYou are using pip version 9.0.1, however version 10.0.0 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [25]:
import os
import pathlib
import traceback

import shutil
import torch
import torch.nn as nn
import torchvision
from tensorboard_logger import Logger
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm, trange

from settings.paths import FILTERED_MS_CELEB_IMAGES_DIR, \
                           SRGAN_MSE_LOSS_LOGS_DIR, \
                           SRGAN_VGG_LOSS_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_LOSS_LOGS_DIR, \
                           SRGAN_MSE_LOSS_WEIGHTS_DIR, \
                           SRGAN_MSE_LOSS_BEST_WEIGHT, \
                           SRGAN_VGG_LOSS_GENERATOR_WEIGHTS_DIR, \
                           SRGAN_VGG_LOSS_DISCRIMINATOR_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_LOSS_GENERATOR_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_LOSS_DISCRIMINATOR_WEIGHTS_DIR
                        
                        
from utils import maybe_mkdir, remove_if_exists, get_last_file

from src.light_cnn import LightCNN_9Layers
from src.srgan import Generator, Discriminator, FeatureExtractor

In [3]:
# Restrict the GPUs visible to CUDA in this kernel to device ids 6 and 7.
os.environ["CUDA_VISIBLE_DEVICES"] = '6, 7'

In [4]:
# Number of samples per training mini-batch.
BATCH_SIZE = 16

In [41]:
class GraysacleFeatureExtractor(FeatureExtractor):
    """FeatureExtractor variant that feeds a single-channel mix of the input.

    The first three channels of ``x`` are combined with fixed weights
    (0.21, 0.72, 0.07) into one channel, which is then handed to the parent
    ``forward``.
    """

    def forward(self, x):
        """Collapse channels 0..2 of ``x`` into one channel and delegate.

        The slices ``0:1``, ``1:2``, ``2:3`` keep the channel axis, so the
        mixed tensor still has a channel dimension of size one.
        """
        # assumes channel order is R, G, B — TODO confirm against the data pipeline
        red = x[:, 0:1]
        green = x[:, 1:2]
        blue = x[:, 2:3]

        single_channel = 0.21 * red + 0.72 * green + 0.07 * blue
        return super().forward(single_channel)

In [5]:
class MSCeleb(ImageFolder):
    """ImageFolder that yields (high-res tensor, low-res tensor) pairs.

    The class label produced by ``ImageFolder`` is discarded; each item is the
    image converted to a tensor together with a downscaled copy of it.
    """

    def __init__(self, root):
        """Index the images under ``root`` and build the two transforms."""
        super().__init__(root)

        # Image -> tensor.
        self._to_tensor = transforms.Compose([transforms.ToTensor()])

        # Tensor -> PIL image -> resize -> tensor again.
        # Resize(32, 0): size 32 with interpolation code 0
        # (presumably PIL's NEAREST — confirm against Pillow constants).
        downscale_steps = [
            transforms.ToPILImage(),
            transforms.Resize(32, 0),
            transforms.ToTensor(),
        ]
        self._downscale = transforms.Compose(downscale_steps)

    def __getitem__(self, index):
        """Return ``(image_hr, image_lr)`` for the sample at ``index``."""
        sample = super().__getitem__(index)

        # sample is (image, label); only the image is used.
        image_hr = self._to_tensor(sample[0])
        return image_hr, self._downscale(image_hr)

In [6]:
# Dataset of (high-res, low-res) image pairs read from the filtered MS-Celeb directory.
dataset = MSCeleb(FILTERED_MS_CELEB_IMAGES_DIR)

In [7]:
# Shuffled loader over the paired dataset.  The batch size comes from the
# notebook-level BATCH_SIZE constant rather than a repeated literal, so the
# loader and the rest of the notebook cannot silently disagree.
dataset_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

In [45]:
class SRGANManager:
    """Builds an SRGAN generator/discriminator pair and drives their training.

    Three public entry points are offered:

    * ``train_mse_only`` - optimises the generator with a pixel-wise MSE loss.
    * ``train_vgg_loss`` / ``train_light_cnn_loss`` - adversarial training with
      an additional MSE term computed on features from a feature extractor.

    All of them go through ``_train``, which iterates over the data loader,
    logs values/images and writes checkpoints.
    """

    def __init__(self,
                 dataset_loader,
                 n_resblocks=16, n_upsample=2,
                 generator_lr=0.0001, discriminator_lr=0.0001,
                 batch_size=16,
                 cuda=True):
        """Create the networks, losses and optimisers.

        Args:
            dataset_loader: iterable yielding ``(image_hr, image_lr)`` batches.
            n_resblocks: forwarded to ``Generator``.
            n_upsample: forwarded to ``Generator``.
            generator_lr: Adam learning rate for the generator.
            discriminator_lr: Adam learning rate for the discriminator.
            batch_size: nominal batch size, used to pre-build the constant
                "all ones" adversarial target.
            cuda: when exactly ``True``, networks are wrapped in
                ``nn.DataParallel`` and moved to the GPU together with the
                losses and the constant target.
        """
        self._batch_size = batch_size

        self._generator = Generator(n_resblocks, n_upsample)
        self._discriminator = Discriminator()

        self._content_criterion = nn.MSELoss()
        self._adversarial_criterion = nn.BCELoss()
        # Sized from the constructor argument (previously read the
        # notebook-level BATCH_SIZE global and ignored ``batch_size``).
        self._ones_const = Variable(torch.ones(batch_size, 1))

        self._cuda = cuda
        if self._cuda is True:
            self._generator = nn.DataParallel(self._generator).cuda()
            self._discriminator = nn.DataParallel(self._discriminator).cuda()
            self._content_criterion = self._content_criterion.cuda()
            self._adversarial_criterion = self._adversarial_criterion.cuda()
            self._ones_const = self._ones_const.cuda()

        self._optim_generator = optim.Adam(
            self._generator.parameters(), lr=generator_lr
        )
        self._optim_discriminator = optim.Adam(
            self._discriminator.parameters(), lr=discriminator_lr
        )

        self._dataset_loader = dataset_loader

    def load_generator_weights(self, path):
        """Load a generator ``state_dict`` saved with ``torch.save`` from ``path``."""
        self._generator.load_state_dict(torch.load(path))

    def load_discriminator_weights(self, path):
        """Load a discriminator ``state_dict`` saved with ``torch.save`` from ``path``."""
        self._discriminator.load_state_dict(torch.load(path))

    def train_mse_only(self,
                       log_dir=SRGAN_MSE_LOSS_LOGS_DIR,
                       epoch_count=30,
                       generator_weights_dir=SRGAN_MSE_LOSS_WEIGHTS_DIR,
                       save_frequency=15000,
                       values_log_frequency=300,
                       images_log_frequency=3000,
                       start_log=100,
                       train_type='careful'):
        """Train only the generator with a pixel-wise MSE loss.

        All arguments are forwarded to ``_train``; no discriminator weights
        directory is passed, so only generator checkpoints are written.
        """
        self._train(
            log_dir=log_dir,
            epoch_count=epoch_count,
            train_function=self._mse_only_train_function,
            generator_weights_dir=generator_weights_dir,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type
        )

    def _mse_only_train_function(self, high_res_real, high_res_fake):
        """Run one generator update on the MSE between fake and real images.

        Returns:
            dict with the single key ``'generator_mse_loss'``.
        """
        self._generator.zero_grad()

        generator_content_loss = self._content_criterion(
            high_res_fake, high_res_real,
        )

        generator_content_loss.backward()
        self._optim_generator.step()

        return {
            'generator_mse_loss': generator_content_loss,
        }

    def train_light_cnn_loss(self,
                       log_dir=SRGAN_LIGHT_CNN_LOSS_LOGS_DIR,
                       generator_weights_dir=SRGAN_LIGHT_CNN_LOSS_GENERATOR_WEIGHTS_DIR,
                       discriminator_weights_dir=SRGAN_LIGHT_CNN_LOSS_DISCRIMINATOR_WEIGHTS_DIR,
                       epoch_count=30,
                       save_frequency=5000,
                       values_log_frequency=100,
                       images_log_frequency=1000,
                       start_log=50,
                       train_type='careful',
                       load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                       features_weight=3600,
                       image_weight=1,
                       content_weight=1,
                       adversarial_weight=1e-3):
        """Adversarial training with a LightCNN-based feature loss.

        Builds ``GraysacleFeatureExtractor(LightCNN_9Layers(), 6)`` and
        forwards every other argument to ``_train_perceptual_loss``.
        """
        feature_extractor = GraysacleFeatureExtractor(
            LightCNN_9Layers(), 6
        )

        self._train_perceptual_loss(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            feature_extractor=feature_extractor,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
            load_pretrained_generator=load_pretrained_generator,
            features_weight=features_weight,
            image_weight=image_weight,
            content_weight=content_weight,
            adversarial_weight=adversarial_weight,
        )

    def train_vgg_loss(self,
                       log_dir=SRGAN_VGG_LOSS_LOGS_DIR,
                       generator_weights_dir=SRGAN_VGG_LOSS_GENERATOR_WEIGHTS_DIR,
                       discriminator_weights_dir=SRGAN_VGG_LOSS_DISCRIMINATOR_WEIGHTS_DIR,
                       epoch_count=30,
                       save_frequency=5000,
                       values_log_frequency=100,
                       images_log_frequency=1000,
                       start_log=50,
                       train_type='careful',
                       load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                       features_weight=0.006,
                       image_weight=1,
                       content_weight=1,
                       adversarial_weight=1e-3):
        """Adversarial training with a VGG19-based feature loss.

        Wraps a pretrained ``torchvision.models.vgg19`` in ``FeatureExtractor``
        and forwards every other argument to ``_train_perceptual_loss``.
        """
        feature_extractor = FeatureExtractor(
            torchvision.models.vgg19(pretrained=True)
        )

        self._train_perceptual_loss(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            feature_extractor=feature_extractor,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
            load_pretrained_generator=load_pretrained_generator,
            features_weight=features_weight,
            image_weight=image_weight,
            content_weight=content_weight,
            adversarial_weight=adversarial_weight,
        )

    def _train_perceptual_loss(self,
                               log_dir,
                               generator_weights_dir,
                               discriminator_weights_dir,
                               epoch_count=30,
                               feature_extractor=None,
                               save_frequency=5000,
                               values_log_frequency=100,
                               images_log_frequency=1000,
                               start_log=50,
                               train_type='careful',
                               load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                               features_weight=0.006,
                               image_weight=1,
                               content_weight=1,
                               adversarial_weight=1e-3):
        """Store loss weights and the feature extractor, then run ``_train``.

        Args:
            feature_extractor: module applied to real and generated images;
                must not be ``None``.
            load_pretrained_generator: optional path of generator weights to
                load before training starts (skipped when ``None``).
            features_weight / image_weight: weights of the feature-space and
                pixel-space MSE terms inside the content loss.
            content_weight / adversarial_weight: weights of the content and
                adversarial terms inside the total generator loss.

        The remaining arguments are forwarded to ``_train``.
        """
        assert feature_extractor is not None

        self._features_weight = features_weight
        self._image_weight = image_weight
        self._content_weight = content_weight
        self._adversarial_weight = adversarial_weight

        self._feature_extractor = feature_extractor

        if self._cuda is True:
            self._feature_extractor.cuda()

        if load_pretrained_generator is not None:
            self.load_generator_weights(load_pretrained_generator)

        self._train(
            log_dir=log_dir,
            epoch_count=epoch_count,
            train_function=self._perceptual_loss_train_function,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type
        )

    def _ones_target(self, batch_len):
        """Return an all-ones ``(batch_len, 1)`` target for the generator step.

        The tensor pre-built in ``__init__`` is reused when its first
        dimension already equals ``batch_len``; otherwise a fresh one is
        created (and moved to the GPU when ``cuda`` is enabled).
        """
        if self._ones_const.size(0) == batch_len:
            return self._ones_const

        ones = Variable(torch.ones(batch_len, 1))
        if self._cuda is True:
            ones = ones.cuda()
        return ones

    def _perceptual_loss_train_function(self, high_res_real, high_res_fake):
        """Run one discriminator update followed by one generator update.

        Targets are sized from the incoming batch rather than from the nominal
        ``batch_size``, so a shorter final batch from the loader does not
        produce a shape mismatch in the adversarial loss.

        Returns:
            dict of the individual loss terms, keyed by the names used for
            logging.
        """
        batch_len = high_res_real.size(0)

        # Noisy labels: real targets are drawn from [0.7, 1.2), fake from [0, 0.3).
        target_real = Variable(torch.rand(batch_len, 1) * 0.5 + 0.7)
        target_fake = Variable(torch.rand(batch_len, 1) * 0.3)

        if self._cuda is True:
            target_real = target_real.cuda()
            target_fake = target_fake.cuda()

        # Train discriminator
        self._discriminator.zero_grad()

        discriminator_loss = self._adversarial_criterion(
            self._discriminator(high_res_real), target_real
        ) + self._adversarial_criterion(
            # detach(): the discriminator step must not backpropagate into the generator
            self._discriminator(high_res_fake.detach()), target_fake
        )

        discriminator_loss.backward()
        self._optim_discriminator.step()

        # Train generator
        self._generator.zero_grad()

        real_features = self._feature_extractor(high_res_real).detach()
        fake_features = self._feature_extractor(high_res_fake)

        generator_content_loss_image = self._content_criterion(
            high_res_fake, high_res_real
        )
        generator_content_loss_features = self._content_criterion(
            fake_features, real_features
        )

        generator_content_loss_total = \
            self._image_weight * generator_content_loss_image + \
            self._features_weight * generator_content_loss_features

        generator_adversarial_loss = self._adversarial_criterion(
            self._discriminator(high_res_fake), self._ones_target(batch_len)
        )

        generator_total_loss = \
            self._content_weight * generator_content_loss_total + \
            self._adversarial_weight * generator_adversarial_loss

        generator_total_loss.backward()
        self._optim_generator.step()

        return {
            'discriminator_loss': discriminator_loss,
            'generator_adversarial_loss': generator_adversarial_loss,
            'generator_content_loss_image': generator_content_loss_image,
            'generator_content_loss_features': generator_content_loss_features,
            'generator_content_loss_total': generator_content_loss_total,
            'generator_total_loss': generator_total_loss,
        }

    def _train(self,
               log_dir,
               epoch_count,
               train_function,
               generator_weights_dir,
               save_frequency=15000,
               values_log_frequency=300,
               images_log_frequency=3000,
               start_log=100,
               train_type='careful',
               discriminator_weights_dir=None):
        """Shared training loop: iterate, update, log and checkpoint.

        Args:
            log_dir: directory handed to the ``Logger``.
            epoch_count: number of passes over the data loader.
            train_function: callable ``(high_res_real, high_res_fake) -> dict``
                performing the optimisation step and returning values to log.
            generator_weights_dir: where generator checkpoints are written.
            save_frequency: checkpoint every this many steps (never at step 0).
            values_log_frequency: log scalar values every this many steps.
            images_log_frequency: log sample images every this many steps.
            start_log: no logging happens before this step.
            train_type: ``'careful'``, ``'clean'`` or ``'continue'``; see
                ``_prepare_train_type``.
            discriminator_weights_dir: optional directory for discriminator
                checkpoints.

        Any exception raised inside the loop - including ``KeyboardInterrupt``
        - is caught: its traceback is printed, a checkpoint is written, and the
        method returns without re-raising.
        """
        step_num = self._prepare_train_type(
            train_type, log_dir,
            generator_weights_dir, discriminator_weights_dir
        )

        try:
            logger = Logger(log_dir)

            for epoch_number in range(epoch_count):
                for image_hr, image_lr in tqdm(self._dataset_loader,
                                               'Epoch {}'.format(epoch_number + 1),
                                               total=len(self._dataset_loader)):
                    high_res_real = Variable(image_hr)
                    low_res = Variable(image_lr)
                    high_res_fake = self._generator(low_res)
                    if self._cuda is True:
                        high_res_real = high_res_real.cuda()
                        high_res_fake = high_res_fake.cuda()

                    values_dict = train_function(high_res_real, high_res_fake)

                    if step_num >= start_log:
                        if step_num % values_log_frequency == 0:
                            for key, value in values_dict.items():
                                logger.log_value(key, value, step_num)

                        if step_num % images_log_frequency == 0:
                            # Only the first image of each batch is logged.
                            for tag, batch in (('real_images', high_res_real),
                                               ('lr_images', low_res),
                                               ('fake_images', high_res_fake)):
                                logger.log_images(
                                    tag, batch.data.cpu()[:1], step_num
                                )

                    if step_num % save_frequency == 0 and step_num != 0:
                        self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)
                    step_num += 1
                # Checkpoint at the end of every epoch as well.
                self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)
        except BaseException:
            # Deliberately broad: an interrupted or failed run still leaves a
            # checkpoint behind.  The error is reported but not re-raised.
            print(traceback.format_exc())
            self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)

    def _save_all(self, step_num, generator_weights_dir, discriminator_weights_dir=None):
        """Write the generator (and optionally discriminator) ``state_dict``.

        File names are derived from ``step_num`` via ``_make_weight_path``.
        """
        torch.save(
            self._generator.state_dict(), self._make_weight_path(generator_weights_dir, step_num)
        )

        if discriminator_weights_dir is not None:
            torch.save(
                self._discriminator.state_dict(), self._make_weight_path(discriminator_weights_dir, step_num)
            )

    def _prepare_train_type(self, train_type, log_dir,
                            generator_weights_dir, discriminator_weights_dir=None):
        """Prepare directories for a run and return the starting step number.

        Modes:
            ``'careful'``: refuse to start if any target directory exists.
            ``'clean'``: ask for confirmation on stdin, then remove the target
                directories.
            ``'continue'``: load the last checkpoint(s) found by
                ``get_last_file`` and resume from the step encoded in the file
                name.

        Returns:
            0 for ``'careful'``/``'clean'``; the parsed step for ``'continue'``.

        Raises:
            AssertionError: ``'careful'`` and previous files exist.
            ValueError: unknown ``train_type``.
            Exception: ``'clean'`` not confirmed with ``'Y'``, or nothing could
                be loaded for ``'continue'``.
        """
        step = 0

        # String modes are compared with ``==``; ``is`` only worked by virtue
        # of string interning.
        if train_type == 'careful':
            previous_run_exists = (
                os.path.exists(log_dir)
                or os.path.exists(generator_weights_dir)
                or (discriminator_weights_dir is not None
                    and os.path.exists(discriminator_weights_dir))
            )
            if previous_run_exists:
                raise AssertionError(
                    'There are previous train files:\nlog_dir: {}\ngenerator_weights_dir: {}\ndiscriminator_weights_dir: {}'.format(
                        log_dir,
                        generator_weights_dir,
                        discriminator_weights_dir,
                    )
                )
        elif train_type == 'clean':
            print(
                'Clean all train data for (Y/N):\nlog_dir: {}\ngenerator_weights_dir: {}\ndiscriminator_weights_dir: {}'.format(
                    log_dir,
                    generator_weights_dir,
                    discriminator_weights_dir,
                )
            )
            answer = input()
            if answer == 'Y':
                remove_if_exists(log_dir)
                remove_if_exists(generator_weights_dir)
                if discriminator_weights_dir is not None:
                    remove_if_exists(discriminator_weights_dir)
            else:
                raise Exception('Bad answer')
        elif train_type == 'continue':
            try:
                filename = get_last_file(generator_weights_dir)
                self.load_generator_weights(os.path.join(generator_weights_dir, filename))

                if discriminator_weights_dir is not None:
                    filename = get_last_file(discriminator_weights_dir)
                    self.load_discriminator_weights(os.path.join(discriminator_weights_dir, filename))

                # The step comes from the most recently loaded file name
                # (the discriminator's when one was loaded).
                step = int(os.path.splitext(filename)[0])
            except Exception as error:
                # Chain the original error so the real cause stays visible.
                raise Exception('Nothing to load') from error
        else:
            raise ValueError('No such train type')

        pathlib.Path(generator_weights_dir).mkdir(parents=True, exist_ok=True)
        if discriminator_weights_dir is not None:
            pathlib.Path(discriminator_weights_dir).mkdir(parents=True, exist_ok=True)
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        return step

    def _make_weight_path(self, folder, step):
        """Return ``folder/<step zero-padded to 10 digits>.pth``."""
        return os.path.join(folder, '{:010d}.pth'.format(step))

In [46]:
# Take the batch size from the loader itself so the manager's pre-built
# adversarial targets always match the batches it will actually receive.
srgan_manager = SRGANManager(
    dataset_loader,
    batch_size=dataset_loader.batch_size,
)

In [12]:
# Generator-only MSE training, resuming from the last checkpoint in the MSE weights directory.
srgan_manager.train_mse_only(epoch_count=30, train_type='continue')


  0%|          | 0/287187 [00:00<?, ?it/s]

Epoch: 0


[A




  0%|          | 1/287187 [00:00<49:35:23,  1.61it/s][A
  0%|          | 2/287187 [00:00<37:06:53,  2.15it/s][A
  0%|          | 4/287187 [00:00<28:11:47,  2.83it/s][A
  0%|          | 6/287187 [00:01<21:36:46,  3.69it/s][A
  0%|          | 8/287187 [00:01<16:57:47,  4.70it/s][A
  0%|          | 10/287187 [00:01<14:07:16,  5.65it/s][A
  0%|          | 12/287187 [00:01<11:51:25,  6.73it/s][A
  0%|          | 14/287187 [00:01<10:12:19,  7.82it/s][A
  0%|          | 16/287187 [00:01<9:15:05,  8.62it/s] [A
  0%|          | 18/287187 [00:02<8:36:18,  9.27it/s][A
 21%|██        | 59307/287187 [1:34:13<6:04:02, 10.43it/s] Process Process-13:
Process Process-15:
Process Process-16:
Process Process-14:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap

In [37]:
# Adversarial training with the VGG19 feature loss, resuming from the last saved checkpoints.
srgan_manager.train_vgg_loss(epoch_count=30, train_type='continue')


Epoch 1:   0%|          | 0/287187 [00:00<?, ?it/s][A




Epoch 1:   0%|          | 1/287187 [00:00<75:01:04,  1.06it/s][A
Epoch 1:   0%|          | 2/287187 [00:01<58:12:06,  1.37it/s][A
Epoch 1:   0%|          | 3/287187 [00:01<45:54:58,  1.74it/s][A
Epoch 1:   0%|          | 4/287187 [00:01<37:26:03,  2.13it/s][A
Epoch 1:   0%|          | 5/287187 [00:01<31:28:21,  2.53it/s][A
Epoch 1:   0%|          | 6/287187 [00:02<27:31:07,  2.90it/s][A
Epoch 1:   0%|          | 7/287187 [00:02<24:32:41,  3.25it/s][A
Epoch 1:   0%|          | 8/287187 [00:02<22:33:41,  3.54it/s][A
Epoch 1:   0%|          | 9/287187 [00:02<21:02:06,  3.79it/s][A
Epoch 1:   0%|          | 10/287187 [00:02<19:52:18,  4.01it/s][A
Epoch 1:   0%|          | 11/287187 [00:03<19:06:10,  4.18it/s][A
Epoch 1:   0%|          | 12/287187 [00:03<18:20:25,  4.35it/s][A
Epoch 1:   0%|          | 13/287187 [00:03<17:51:38,  4.47it/s][A
Epoch 1:   0%|          | 14/287187 [00:03<17:38:54,  4.52it/s][A
Epoch 1:   

Traceback (most recent call last):
  File "<ipython-input-35-0160fb21dd30>", line 278, in _train
    high_res_fake = self._generator(low_res)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 73, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 83, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/parallel_apply.py", line 59, in parallel_apply
    thread.join()
  File "/usr/lib/python3.5/threading.py", line 1054, in join
    self._wait_for_tstate_lock()
  File "/usr/lib/python3.5/threading.py", line 1070, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterru

  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7ff905293f60>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 333, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 319, in _shutdown_workers
    self.data_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 345, in get
    return ForkingPickler.loads(res)
  File "/usr/local/lib/python3.5/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.5/multiprocessing/resource_sharer.py", line 57, in detach


In [None]:
# Adversarial training with the LightCNN feature loss, resuming from the last saved checkpoints.
# NOTE(review): the saved output below shows the 'Clean all train data' prompt, which is only
# printed for train_type='clean' — the output appears to come from an earlier version of this cell.
srgan_manager.train_light_cnn_loss(epoch_count=30, train_type='continue')

Clean all train data for (Y/N):
log_dir: /home/data/alpus/clean_fsr/data/logs/srgan/light_cnn_loss
generator_weights_dir: /home/data/alpus/clean_fsr/data/weights/srgan/light_cnn_loss/generator
discriminator_weights_dir: /home/data/alpus/clean_fsr/data/weights/srgan/light_cnn_loss/discriminator
Y



Epoch 1:   0%|          | 0/287187 [00:00<?, ?it/s][A




Epoch 1:   0%|          | 1/287187 [00:00<65:41:46,  1.21it/s][A
Epoch 1:   0%|          | 2/287187 [00:01<51:18:09,  1.55it/s][A
Epoch 1:   0%|          | 3/287187 [00:01<41:16:13,  1.93it/s][A
Epoch 1:   0%|          | 4/287187 [00:01<34:26:55,  2.32it/s][A
Epoch 1:   0%|          | 5/287187 [00:01<29:30:59,  2.70it/s][A
Epoch 1:   0%|          | 6/287187 [00:01<25:58:27,  3.07it/s][A
Epoch 1:   0%|          | 7/287187 [00:02<23:29:17,  3.40it/s][A
Epoch 1:   0%|          | 8/287187 [00:02<21:47:34,  3.66it/s][A
Epoch 1:   0%|          | 9/287187 [00:02<20:47:41,  3.84it/s][A
Epoch 1:   0%|          | 10/287187 [00:02<20:05:52,  3.97it/s][A
Epoch 1:   0%|          | 11/287187 [00:03<19:48:56,  4.03it/s][A
Epoch 1:   0%|          | 12/287187 [00:03<19:07:48,  4.17it/s][A
Epoch 1:   0%|          | 13/287187 [00:03<18:45:32,  4.25it/s][A
Epoch 1:   0%|          | 14/287187 [00:03<18:36:20,  4.29it/s][A
Epoch 1:   