In [1]:
# Install the pinned tensorboard_logger release; the import cell below takes `Logger` from it.
!pip3 install tensorboard_logger==0.1.0

[33mYou are using pip version 9.0.1, however version 10.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [2]:
import datetime
import os
import pathlib
import traceback
from collections import OrderedDict

import shutil
import torch
import torch.nn as nn
import torchvision
from tensorboard_logger import Logger
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm, trange

from settings.paths import FILTERED_MS_CELEB_IMAGES_DIR, \
                           SRGAN_MSE_LOSS_LOGS_DIR, \
                           ERRORS_LOGS_DIR, \
                           LIGHT_CNN_9_WEIGHT, \
                           SRGAN_VGG_LOSS_3_1_LOGS_DIR, \
                           SRGAN_VGG_LOSS_3_1_NO_ADVERSARIAL_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_FC_NO_ADVERSARIAL_LOGS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_NO_IMAGE_LOGS_DIR, \
                           SRGAN_MSE_LOSS_WEIGHTS_DIR, \
                           SRGAN_MSE_LOSS_BEST_WEIGHT, \
                           SRGAN_VGG_LOSS_3_1_GENERATOR_WEIGHTS_DIR, \
                           SRGAN_VGG_LOSS_3_1_DISCRIMINATOR_WEIGHTS_DIR, \
                           SRGAN_VGG_LOSS_3_1_NO_ADVERSARIAL_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_GENERATOR_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_DISCRIMINATOR_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_FC_NO_ADVERSARIAL_WEIGHTS_DIR, \
                           SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_NO_IMAGE_WEIGHTS_DIR
                        
                        
from utils import maybe_mkdir, remove_if_exists, get_last_file

from src.light_cnn import LightCNN_9Layers
from src.srgan import Generator, Discriminator, FeatureExtractor

In [3]:
# Expose only GPUs 6 and 7 to CUDA for this kernel.
# NOTE(review): the device ids are specific to the machine this was run on - adjust per host.
os.environ["CUDA_VISIBLE_DEVICES"] = '6, 7'

In [4]:
# Mini-batch size shared by the training code in this notebook.
BATCH_SIZE = 16

In [5]:
def rgb_to_gray(x):
    """Collapse the first three channels of ``x`` into one weighted-sum channel.

    ``x`` is indexed as ``x[:, c]``, i.e. channels live on axis 1. The result
    keeps that axis with size 1 (slices ``c:c+1`` are used instead of plain
    indices), weighting the channels 0.21 / 0.72 / 0.07 in order.
    """
    red = x[:, 0:1]
    green = x[:, 1:2]
    blue = x[:, 2:3]
    return 0.21 * red + 0.72 * green + 0.07 * blue

In [6]:
class Grayscale2RGBModule(nn.Module):
    """Adapter that feeds a 3-channel batch to a single-channel submodule.

    Despite the name, the conversion goes from RGB *to* grayscale: the input
    is reduced with ``rgb_to_gray`` and the one-channel result is passed to
    ``submodule``.

    Args:
        submodule: module that is called on the grayscale batch.
        result_extractor: callable applied to the submodule's output before it
            is returned; defaults to the identity. Useful when the submodule
            returns a tuple and only one element is wanted.
    """

    def __init__(self, submodule, result_extractor=lambda x: x):
        super().__init__()
        self.submodule = submodule
        self.result_extractor = result_extractor

    def forward(self, x):
        """Convert ``x`` to grayscale, run the submodule, post-process its output."""
        x = rgb_to_gray(x)
        # Call the module itself rather than `.forward(...)` so that
        # nn.Module.__call__ runs and registered forward hooks are honoured.
        result = self.submodule(x)
        return self.result_extractor(result)

In [7]:
class MSCeleb(ImageFolder):
    """Image-folder dataset that yields (high-res, low-res) tensor pairs.

    The class label produced by ``ImageFolder`` is discarded; every item is
    the loaded image as a tensor plus a downscaled copy of it.
    """

    def __init__(self, root):
        super().__init__(root)

        high_res_steps = [transforms.ToTensor()]
        # The low-res copy is derived from the high-res *tensor*, hence the
        # round trip through a PIL image. Resize(32, 0): size 32 with
        # interpolation code 0 (PIL's nearest-neighbour).
        low_res_steps = [
            transforms.ToPILImage(),
            transforms.Resize(32, 0),
            transforms.ToTensor(),
        ]

        self._to_tensor = transforms.Compose(high_res_steps)
        self._downscale = transforms.Compose(low_res_steps)

    def __getitem__(self, index):
        """Return ``(image_hr, image_lr)`` for the sample at ``index``."""
        image = super().__getitem__(index)[0]

        image_hr = self._to_tensor(image)
        return image_hr, self._downscale(image_hr)

In [8]:
# Dataset over the filtered MS-Celeb image folder; each item is an (image_hr, image_lr) tensor pair.
dataset = MSCeleb(FILTERED_MS_CELEB_IMAGES_DIR)

In [9]:
# Batch size comes from the shared BATCH_SIZE constant instead of a repeated
# literal, so the loader cannot silently drift from the rest of the notebook.
# drop_last=True discards a trailing short batch: SRGANManager builds
# fixed-size (batch, 1) target tensors and needs every batch to be full.
dataset_loader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=4, drop_last=True
)

In [10]:
class SRGANManager:
    def __init__(self,
                 dataset_loader,
                 n_resblocks=16, n_upsample=2,
                 generator_lr=0.0001, discriminator_lr=0.0001,
                 batch_size=16,
                 cuda=True):
        """Create the generator/discriminator pair with their losses and optimizers.

        Args:
            dataset_loader: iterable of ``(image_hr, image_lr)`` batches; must
                support ``len()`` (used for epoch bookkeeping during training).
            n_resblocks: forwarded to ``Generator`` as its first argument.
            n_upsample: forwarded to ``Generator`` as its second argument.
            generator_lr: Adam learning rate for the generator.
            discriminator_lr: Adam learning rate for the discriminator.
            batch_size: number of samples per batch; sizes the adversarial
                target tensors, so it must match what the loader yields.
            cuda: when ``True``, wrap both networks in ``nn.DataParallel`` and
                move networks, criteria and constants to the GPU.
        """
        self._batch_size = batch_size

        self._generator = Generator(n_resblocks, n_upsample)
        self._discriminator = Discriminator()

        self._content_criterion = nn.MSELoss()
        self._adversarial_criterion = nn.BCELoss()
        # All-ones target for the generator's adversarial loss. Sized from the
        # `batch_size` argument (not the module-level BATCH_SIZE) so it always
        # agrees with the (batch_size, 1) targets built for the discriminator.
        self._ones_const = Variable(torch.ones(batch_size, 1))

        self._cuda = cuda
        if self._cuda is True:
            self._generator = nn.DataParallel(self._generator).cuda()
            self._discriminator = nn.DataParallel(self._discriminator).cuda()
            self._content_criterion = self._content_criterion.cuda()
            self._adversarial_criterion = self._adversarial_criterion.cuda()
            self._ones_const = self._ones_const.cuda()

        # Optimizers are created after the optional DataParallel wrapping so
        # they reference the parameters of the modules actually trained.
        self._optim_generator = optim.Adam(
            self._generator.parameters(), lr=generator_lr
        )
        self._optim_discriminator = optim.Adam(
            self._discriminator.parameters(), lr=discriminator_lr
        )

        self._dataset_loader = dataset_loader
        
    def load_generator_weights(self, path):
        """Load the state dict stored at ``path`` into the generator."""
        state_dict = torch.load(path)
        self._generator.load_state_dict(state_dict)
        
    def load_discriminator_weights(self, path):
        """Load the state dict stored at ``path`` into the discriminator."""
        state_dict = torch.load(path)
        self._discriminator.load_state_dict(state_dict)

    def train_mse_only(self,
                       log_dir=SRGAN_MSE_LOSS_LOGS_DIR,
                       epoch_count=3,
                       generator_weights_dir=SRGAN_MSE_LOSS_WEIGHTS_DIR,
                       save_frequency=15000,
                       values_log_frequency=300,
                       images_log_frequency=3000,
                       start_log=100,
                       train_type='careful'):
        """Train the generator alone with the pixel-wise MSE step.

        Every argument is handed to ``_train`` unchanged; no discriminator
        directory is passed, so only generator weights are written.
        """
        train_options = dict(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
        )
        self._train(
            train_function=self._mse_only_train_function,
            **train_options
        )
        
    def _mse_only_train_function(self, high_res_real, high_res_fake):
        """Run one generator update on the MSE between fake and real images.

        Returns:
            dict with the single key ``'generator_mse_loss'`` holding the loss
            value computed for this step.
        """
        self._generator.zero_grad()

        mse_loss = self._content_criterion(high_res_fake, high_res_real)
        mse_loss.backward()
        self._optim_generator.step()

        return dict(generator_mse_loss=mse_loss)
    
    def train_light_cnn_9_loss_mfm4(self,
                                    log_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_LOGS_DIR,
                                    generator_weights_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_GENERATOR_WEIGHTS_DIR,
                                    epoch_count=3,
                                    save_frequency=5000,
                                    values_log_frequency=100,
                                    images_log_frequency=1000,
                                    start_log=50,
                                    train_type='careful',
                                    load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                                    features_weight=0.002,
                                    image_weight=1,
                                    content_weight=1,
                                    adversarial_weight=1e-3,
                                    discriminator_weights_dir=None):
        """Perceptual-loss training using features from a pretrained LightCNN-9.

        The LightCNN is wrapped in ``FeatureExtractor`` and then in
        ``Grayscale2RGBModule`` so it can be fed the 3-channel image batches.
        All remaining arguments are forwarded to ``_train_perceptual_loss``;
        leaving ``discriminator_weights_dir`` as ``None`` disables the
        adversarial part.
        """
        # The `6` is FeatureExtractor's second argument - presumably the layer
        # cut-off for the MFM4 features; confirm against src.srgan.
        feature_extractor = Grayscale2RGBModule(
            FeatureExtractor(
                self._init_light_cnn(
                    weight_path=LIGHT_CNN_9_WEIGHT,
                    name='LightCNN_9'
                ),
                6,
            )
        )

        perceptual_options = dict(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            feature_extractor=feature_extractor,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
            load_pretrained_generator=load_pretrained_generator,
            features_weight=features_weight,
            image_weight=image_weight,
            content_weight=content_weight,
            adversarial_weight=adversarial_weight,
        )
        self._train_perceptual_loss(**perceptual_options)
    
    def train_light_cnn_9_loss_fc(self,
                                  log_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_LOGS_DIR,
                                  generator_weights_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_GENERATOR_WEIGHTS_DIR,
                                  epoch_count=3,
                                  save_frequency=5000,
                                  values_log_frequency=100,
                                  images_log_frequency=1000,
                                  start_log=50,
                                  train_type='careful',
                                  load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                                  features_weight=7e-5,
                                  image_weight=1,
                                  content_weight=1,
                                  adversarial_weight=1e-3,
                                  discriminator_weights_dir=None):
        """Perceptual-loss training using the second output of LightCNN-9.

        The whole pretrained LightCNN is used as the feature extractor; its
        output is indexed with ``[1]`` (presumably the fully-connected feature
        vector - confirm against src.light_cnn). All remaining arguments are
        forwarded to ``_train_perceptual_loss``.

        NOTE(review): the default ``log_dir`` / ``generator_weights_dir`` are
        the MFM4 directories, the same defaults as
        ``train_light_cnn_9_loss_mfm4``. This looks like a copy-paste leftover;
        pass FC-specific directories explicitly until that is confirmed.
        """
        feature_extractor = Grayscale2RGBModule(
            self._init_light_cnn(
                weight_path=LIGHT_CNN_9_WEIGHT,
                name='LightCNN_9'
            ),
            lambda x: x[1],
        )

        perceptual_options = dict(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            feature_extractor=feature_extractor,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
            load_pretrained_generator=load_pretrained_generator,
            features_weight=features_weight,
            image_weight=image_weight,
            content_weight=content_weight,
            adversarial_weight=adversarial_weight,
        )
        self._train_perceptual_loss(**perceptual_options)
    
    def _init_light_cnn(self, weight_path, name='LightCNN_9'):
        """Build a LightCNN in eval mode and load pretrained weights into it.

        Args:
            weight_path: checkpoint file; it must contain a ``'state_dict'``
                entry mapping parameter names to tensors.
            name: model variant; only ``'LightCNN_9'`` is supported.

        Returns:
            The model with the checkpoint's weights loaded.

        Raises:
            ValueError: if ``name`` is not a supported variant.
        """
        # Compare by value: `is` tests object identity and only happened to
        # work for string literals because of interning.
        if name == 'LightCNN_9':
            model_class = LightCNN_9Layers
            num_classes = 79077
        else:
            raise ValueError('No such model {}'.format(name))

        model = model_class(num_classes=num_classes)
        model.eval()

        checkpoint = torch.load(weight_path)

        # The checkpoint was saved from an nn.DataParallel wrapper, whose keys
        # carry a `module.` prefix. Strip it only where present so keys saved
        # without the wrapper are not truncated.
        prefix = 'module.'
        new_checkpoint = OrderedDict()
        for layer_name, value in checkpoint['state_dict'].items():
            if layer_name.startswith(prefix):
                layer_name = layer_name[len(prefix):]
            new_checkpoint[layer_name] = value

        model.load_state_dict(new_checkpoint)

        return model
    
    def train_vgg_loss(self,
                       log_dir,
                       generator_weights_dir,
                       discriminator_weights_dir=None,
                       epoch_count=3,
                       save_frequency=5000,
                       values_log_frequency=100,
                       images_log_frequency=1000,
                       start_log=50,
                       train_type='careful',
                       load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                       features_weight=0.006,
                       image_weight=1,
                       content_weight=1,
                       adversarial_weight=1e-3):
        """Perceptual-loss training using features from a pretrained VGG-19.

        A torchvision VGG-19 with pretrained weights is wrapped in
        ``FeatureExtractor`` (with its default settings) and every other
        argument is forwarded to ``_train_perceptual_loss``. Leaving
        ``discriminator_weights_dir`` as ``None`` disables the adversarial part.
        """
        vgg_features = FeatureExtractor(
            torchvision.models.vgg19(pretrained=True)
        )

        perceptual_options = dict(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            feature_extractor=vgg_features,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
            load_pretrained_generator=load_pretrained_generator,
            features_weight=features_weight,
            image_weight=image_weight,
            content_weight=content_weight,
            adversarial_weight=adversarial_weight,
        )
        self._train_perceptual_loss(**perceptual_options)
    
    def _train_perceptual_loss(self,
                               log_dir,
                               generator_weights_dir,
                               discriminator_weights_dir,
                               epoch_count=3,
                               feature_extractor=None,
                               save_frequency=5000,
                               values_log_frequency=100,
                               images_log_frequency=1000,
                               start_log=50,
                               train_type='careful',
                               load_pretrained_generator=SRGAN_MSE_LOSS_BEST_WEIGHT,
                               features_weight=0.002,
                               image_weight=1,
                               content_weight=1,
                               adversarial_weight=1e-3):
        """Store the loss configuration, then run ``_train`` with the perceptual step.

        Args:
            feature_extractor: module producing the features compared between
                real and generated images; required despite the ``None`` default.
            discriminator_weights_dir: when not ``None`` the adversarial weight
                is stored and the discriminator is trained and saved as well.
            load_pretrained_generator: optional path of generator weights to
                load before training starts.
            features_weight, image_weight, content_weight, adversarial_weight:
                coefficients stored on the instance for the per-step loss.

        The remaining arguments are forwarded to ``_train`` as they are.
        """
        assert feature_extractor is not None

        self._features_weight = features_weight
        self._image_weight = image_weight
        self._content_weight = content_weight

        # The adversarial weight is only stored when a discriminator is used.
        adversarial_enabled = discriminator_weights_dir is not None
        if adversarial_enabled:
            self._adversarial_weight = adversarial_weight

        self._feature_extractor = feature_extractor
        if self._cuda is True:
            parallel_extractor = nn.DataParallel(self._feature_extractor)
            self._feature_extractor = parallel_extractor.cuda()

        if load_pretrained_generator is not None:
            self.load_generator_weights(load_pretrained_generator)

        train_options = dict(
            log_dir=log_dir,
            epoch_count=epoch_count,
            generator_weights_dir=generator_weights_dir,
            discriminator_weights_dir=discriminator_weights_dir,
            save_frequency=save_frequency,
            values_log_frequency=values_log_frequency,
            images_log_frequency=images_log_frequency,
            start_log=start_log,
            train_type=train_type,
        )
        self._train(
            train_function=self._perceptual_loss_train_function,
            **train_options
        )
        
    def _perceptual_loss_train_function(self, high_res_real, high_res_fake):
        """Run one training step with image + feature (and optional adversarial) loss.

        When a discriminator weights directory is configured the discriminator
        is updated first, on real images and on detached fakes; the generator
        is updated afterwards.

        Returns:
            dict of the loss values computed in this step. The three content
            losses are always present; the discriminator, adversarial and total
            generator losses are added only when the discriminator is trained.
        """
        use_discriminator = self._discriminator_weights_dir is not None

        # Noisy ("soft") targets for the discriminator.
        # NOTE(review): the real targets span [0.7, 1.2), i.e. they can exceed
        # 1 - confirm this range is intended for a BCE target.
        target_real = Variable(torch.rand(self._batch_size, 1) * 0.5 + 0.7)
        target_fake = Variable(torch.rand(self._batch_size, 1) * 0.3)
        if self._cuda is True:
            target_real = target_real.cuda()
            target_fake = target_fake.cuda()

        # --- discriminator update ---
        if use_discriminator:
            self._discriminator.zero_grad()

            loss_on_real = self._adversarial_criterion(
                self._discriminator(high_res_real), target_real
            )
            # detach(): the discriminator step must not back-propagate into the generator.
            loss_on_fake = self._adversarial_criterion(
                self._discriminator(high_res_fake.detach()), target_fake
            )
            discriminator_loss = loss_on_real + loss_on_fake

            discriminator_loss.backward()
            self._optim_discriminator.step()

        # --- generator update ---
        self._generator.zero_grad()

        # Features of the real image are a fixed target, hence detached.
        real_features = self._feature_extractor(high_res_real).detach()
        fake_features = self._feature_extractor(high_res_fake)

        generator_content_loss_image = self._content_criterion(
            high_res_fake, high_res_real
        )
        generator_content_loss_features = self._content_criterion(
            fake_features, real_features
        )
        generator_content_loss_total = (
            self._image_weight * generator_content_loss_image
            + self._features_weight * generator_content_loss_features
        )

        if use_discriminator:
            generator_adversarial_loss = self._adversarial_criterion(
                self._discriminator(high_res_fake), self._ones_const
            )
            generator_total_loss = (
                self._content_weight * generator_content_loss_total
                + self._adversarial_weight * generator_adversarial_loss
            )
        else:
            generator_total_loss = generator_content_loss_total

        generator_total_loss.backward()
        self._optim_generator.step()

        log_dict = {}
        log_dict['generator_content_loss_image'] = generator_content_loss_image
        log_dict['generator_content_loss_features'] = generator_content_loss_features
        log_dict['generator_content_loss_total'] = generator_content_loss_total
        if use_discriminator:
            log_dict['discriminator_loss'] = discriminator_loss
            log_dict['generator_adversarial_loss'] = generator_adversarial_loss
            log_dict['generator_total_loss'] = generator_total_loss

        return log_dict
        
    def _train(self,
               log_dir,
               epoch_count,
               train_function,
               generator_weights_dir,
               save_frequency=15000,
               values_log_frequency=300,
               images_log_frequency=3000,
               start_log=100,
               train_type='careful',
               discriminator_weights_dir=None):
        """Shared training loop: iterate batches, step, log, checkpoint.

        Args:
            log_dir: directory handed to the ``Logger``.
            epoch_count: training stops once ``step // len(loader)`` reaches it.
            train_function: callable ``(high_res_real, high_res_fake) -> dict``
                performing one optimisation step and returning values to log.
            generator_weights_dir: where generator checkpoints are written.
            save_frequency: checkpoint every this many steps (never at step 0).
            values_log_frequency: log the returned values every this many steps.
            images_log_frequency: log one real/low-res/fake image every this
                many steps.
            start_log: no logging happens before this step.
            train_type: ``'careful'``, ``'clean'`` or ``'continue'``; see
                ``_prepare_train_type``.
            discriminator_weights_dir: where discriminator checkpoints are
                written, or ``None`` to skip them.

        ``KeyboardInterrupt`` stops training quietly; any other exception is
        printed and written to the error log. In every case the current
        weights are saved before returning.
        """
        self._generator_weights_dir = generator_weights_dir
        self._discriminator_weights_dir = discriminator_weights_dir

        step_num = self._prepare_train_type(
            train_type, log_dir,
            generator_weights_dir, discriminator_weights_dir
        )

        # Bound before `try` so the `finally` clause can test it: previously a
        # failure in Logger(...) or tqdm(...) left the name unbound and the
        # cleanup raised UnboundLocalError, hiding the real error.
        progress_bar = None

        try:
            steps_per_epoch = len(self._dataset_loader)

            logger = Logger(log_dir)
            progress_bar = tqdm(
                total=steps_per_epoch,
                initial=step_num % steps_per_epoch,
            )

            while True:
                stopped = False
                for image_hr, image_lr in self._dataset_loader:
                    epoch_num = step_num // steps_per_epoch
                    if epoch_num == epoch_count:
                        stopped = True
                        break
                    progress_bar.desc = 'Epoch {}:'.format(epoch_num + 1)
                    progress_bar.n = step_num % steps_per_epoch + 1
                    progress_bar.refresh()

                    high_res_real = Variable(image_hr)
                    low_res = Variable(image_lr)
                    high_res_fake = self._generator(low_res)
                    if self._cuda is True:
                        high_res_real = high_res_real.cuda()
                        high_res_fake = high_res_fake.cuda()

                    values_dict = train_function(high_res_real, high_res_fake)

                    if step_num >= start_log:
                        if step_num % values_log_frequency == 0:
                            for key, value in values_dict.items():
                                logger.log_value(
                                    key,
                                    value,
                                    step_num
                                )

                        if step_num % images_log_frequency == 0:
                            # Only the first sample of the batch is logged.
                            logger.log_images(
                                'real_images',
                                high_res_real.data.cpu()[:1],
                                step_num
                            )

                            logger.log_images(
                                'lr_images',
                                low_res.data.cpu()[:1],
                                step_num
                            )

                            logger.log_images(
                                'fake_images',
                                high_res_fake.data.cpu()[:1],
                                step_num
                            )

                    if step_num % save_frequency == 0 and step_num != 0:
                        self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)

                    step_num += 1

                # Checkpoint at the end of every pass over the loader.
                self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)
                if stopped is True:
                    break
        except KeyboardInterrupt:
            print('Stopped.')
        except Exception:
            print(traceback.format_exc())
            self._log_error(traceback)
        finally:
            # Always keep the latest weights, and always release the progress
            # bar - even if the final save itself fails.
            try:
                self._save_all(step_num, generator_weights_dir, discriminator_weights_dir)
            finally:
                if progress_bar is not None:
                    progress_bar.close()
        
    def _save_all(self, step_num, generator_weights_dir, discriminator_weights_dir=None):
        """Write the generator state dict for ``step_num``, plus the discriminator's if a directory is given."""
        generator_path = self._make_weight_path(generator_weights_dir, step_num)
        torch.save(self._generator.state_dict(), generator_path)

        if discriminator_weights_dir is None:
            return

        discriminator_path = self._make_weight_path(discriminator_weights_dir, step_num)
        torch.save(self._discriminator.state_dict(), discriminator_path)
        
    def _prepare_train_type(self, train_type, log_dir,
                            generator_weights_dir, discriminator_weights_dir=None):
        """Prepare the output directories for a run and return the starting step.

        Modes:
            ``'careful'``: refuse to start if the log or weights directories
                already exist.
            ``'clean'``: ask for confirmation on stdin, then delete them.
            ``'continue'``: load the latest checkpoint(s) and resume from the
                step number encoded in the checkpoint file name.

        Afterwards all directories are created if missing.

        Returns:
            The step to start from: 0 unless ``train_type == 'continue'``.

        Raises:
            AssertionError: ``'careful'`` and previous training files exist.
            ValueError: unknown ``train_type``.
            Exception: ``'clean'`` was not confirmed with ``Y``, or
                ``'continue'`` found nothing it could load.
        """
        step = 0

        # Modes are compared with `==`: `is` checks object identity and only
        # worked for these strings because of interning.
        if train_type == 'careful':
            if os.path.exists(log_dir) or os.path.exists(generator_weights_dir) or \
               (discriminator_weights_dir is not None and os.path.exists(discriminator_weights_dir)):
                raise AssertionError(
                    'There are previous train files:\nlog_dir: {}\ngenerator_weights_dir: {}\ndiscriminator_weights_dir: {}'.format(
                        log_dir,
                        generator_weights_dir,
                        discriminator_weights_dir,
                    )
                )
        elif train_type == 'clean':
            print(
                'Clean all train data for (Y/N):\nlog_dir: {}\ngenerator_weights_dir: {}\ndiscriminator_weights_dir: {}'.format(
                    log_dir,
                    generator_weights_dir,
                    discriminator_weights_dir,
                )
            )
            answer = input()
            if answer == 'Y':
                remove_if_exists(log_dir)
                remove_if_exists(generator_weights_dir)
                if discriminator_weights_dir is not None:
                    remove_if_exists(discriminator_weights_dir)
            else:
                raise Exception('Bad answer')
        elif train_type == 'continue':
            try:
                filename = get_last_file(generator_weights_dir)
                self.load_generator_weights(os.path.join(generator_weights_dir, filename))

                if discriminator_weights_dir is not None:
                    filename = get_last_file(discriminator_weights_dir)
                    self.load_discriminator_weights(os.path.join(discriminator_weights_dir, filename))

                # Checkpoints are named `<step>.pth`; the step is taken from
                # the last file loaded.
                step = int(os.path.splitext(filename)[0])
            except Exception as error:
                # Chain the cause so the real failure (missing directory, bad
                # file name, incompatible weights) stays visible in the traceback.
                raise Exception('Nothing to load') from error
        else:
            raise ValueError('No such train type')

        pathlib.Path(generator_weights_dir).mkdir(parents=True, exist_ok=True)
        if discriminator_weights_dir is not None:
            pathlib.Path(discriminator_weights_dir).mkdir(parents=True, exist_ok=True)
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

        return step
        
    def _make_weight_path(self, folder, step):
        """Return ``folder/<step zero-padded to 10 digits>.pth``."""
        filename = '{:010d}.pth'.format(step)
        return os.path.join(folder, filename)
    
    def _log_error(self, traceback, diirectory=ERRORS_LOGS_DIR):
        """Write the exception currently being handled to a timestamped file.

        Args:
            traceback: object exposing ``format_exc()`` (callers pass the
                ``traceback`` module); its output is the file content.
            diirectory: directory to write into. The misspelled name is kept
                so existing keyword callers keep working.
        """
        # Honour the argument (it used to be ignored in favour of the module
        # constant) and make sure the directory exists, so that logging an
        # error cannot itself fail on a missing folder.
        os.makedirs(diirectory, exist_ok=True)

        error_path = os.path.join(
            diirectory,
            datetime.datetime.now().strftime(
                '%Y-%m-%d-%H-%M-%S-%f'
            )
        )
        with open(error_path, 'w') as file:
            file.write(traceback.format_exc())

In [None]:
# Manager with default settings: 16 residual blocks, 2 upsample stages, lr 1e-4 for both networks, CUDA on.
srgan_manager = SRGANManager(dataset_loader)

In [12]:
# Train the generator with the pixel-wise MSE loss only.
# 'careful' refuses to start if the log/weights directories already exist.
srgan_manager.train_mse_only(epoch_count=3, train_type='careful')


  0%|          | 0/287187 [00:00<?, ?it/s]

Epoch: 0


[A




  0%|          | 1/287187 [00:00<49:35:23,  1.61it/s][A
  0%|          | 2/287187 [00:00<37:06:53,  2.15it/s][A
  0%|          | 4/287187 [00:00<28:11:47,  2.83it/s][A
  0%|          | 6/287187 [00:01<21:36:46,  3.69it/s][A
  0%|          | 8/287187 [00:01<16:57:47,  4.70it/s][A
  0%|          | 10/287187 [00:01<14:07:16,  5.65it/s][A
  0%|          | 12/287187 [00:01<11:51:25,  6.73it/s][A
  0%|          | 14/287187 [00:01<10:12:19,  7.82it/s][A
  0%|          | 16/287187 [00:01<9:15:05,  8.62it/s] [A
  0%|          | 18/287187 [00:02<8:36:18,  9.27it/s][A
 21%|██        | 59307/287187 [1:34:13<6:04:02, 10.43it/s] Process Process-13:
Process Process-15:
Process Process-16:
Process Process-14:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap

In [19]:
# VGG-19 perceptual loss; passing a discriminator directory also enables the adversarial loss
# and discriminator checkpoints.
srgan_manager.train_vgg_loss(
    log_dir=SRGAN_VGG_LOSS_3_1_LOGS_DIR,
    generator_weights_dir=SRGAN_VGG_LOSS_3_1_GENERATOR_WEIGHTS_DIR,
    discriminator_weights_dir=SRGAN_VGG_LOSS_3_1_DISCRIMINATOR_WEIGHTS_DIR,
    epoch_count=3, train_type='careful'
)

Epoch 3  9%|▉         | 25967/287187 [1:43:21<17:19:49,  4.19it/s]Process Process-12:
Epoch 3  9%|▉         | 25968/287187 [1:43:22<17:19:48,  4.19it/s]Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7f3889126630>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 333, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 319, in _shutdown_workers
    self.data_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 345, in get
    return ForkingPickler.loads(res)
  File "/usr/local/lib/python3.5/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.5/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/usr/lib/python3

Traceback (most recent call last):
  File "<ipython-input-17-fd41f254d670>", line 309, in _train
    values_dict = train_function(high_res_real, high_res_fake)
  File "<ipython-input-17-fd41f254d670>", line 245, in _perceptual_loss_train_function
    self._optim_generator.step()
  File "/usr/local/lib/python3.5/dist-packages/torch/optim/adam.py", line 69, in step
    exp_avg.mul_(beta1).add_(1 - beta1, grad)
KeyboardInterrupt



Process Process-11:

Process Process-10:
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 343, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.5/multipr

In [13]:
# VGG-19 perceptual loss without a discriminator (no discriminator_weights_dir is given),
# so only the content losses are optimised.
# NOTE(review): the saved output below shows the 'clean' confirmation prompt, so it was
# presumably produced by an earlier run with train_type='clean' - confirm before relying on it.
srgan_manager.train_vgg_loss(
    log_dir=SRGAN_VGG_LOSS_3_1_NO_ADVERSARIAL_LOGS_DIR,
    generator_weights_dir=SRGAN_VGG_LOSS_3_1_NO_ADVERSARIAL_WEIGHTS_DIR,
    epoch_count=3,
    save_frequency=10000,
    values_log_frequency=200,
    images_log_frequency=2000,
    train_type='careful'
)

Clean all train data for (Y/N):
log_dir: /home/data/alpus/clean_fsr/data/logs/srgan/vgg_loss_3.1_no_adversarial
generator_weights_dir: /home/data/alpus/clean_fsr/data/weights/srgan/vgg_loss_3.1_no_adversarial
discriminator_weights_dir: None
Y



Epoch 1:   0%|          | 0/287187 [00:00<?, ?it/s][A




Epoch 1:   0%|          | 1/287187 [00:00<60:32:36,  1.32it/s][A
Epoch 1:   0%|          | 2/287187 [00:00<45:40:47,  1.75it/s][A
Epoch 1:   0%|          | 3/287187 [00:01<35:12:57,  2.27it/s][A
Epoch 1:   0%|          | 4/287187 [00:01<27:37:36,  2.89it/s][A
Epoch 1:   0%|          | 5/287187 [00:01<22:11:59,  3.59it/s][A
Epoch 1:   0%|          | 6/287187 [00:01<18:22:27,  4.34it/s][A
Epoch 1:   0%|          | 7/287187 [00:01<15:40:45,  5.09it/s][A
Epoch 1:   0%|          | 8/287187 [00:01<13:51:48,  5.75it/s][A
Epoch 1:   0%|          | 9/287187 [00:01<12:35:50,  6.33it/s][A
Epoch 1:   0%|          | 10/287187 [00:01<11:42:30,  6.81it/s][A
Epoch 1:   0%|          | 11/287187 [00:02<11:22:07,  7.02it/s][A
Epoch 1:   0%|          | 12/287187 [00:02<11:07:07,  7.17it/s][A
Epoch 1:   0%|          | 13/287187 [00:02<10:56:46,  7.29it/s][A
Epoch 1:   0%|          | 14/287187 [00:02<10:33:59,  7.55it/s][A
Epoch 1:   

Traceback (most recent call last):
  File "<ipython-input-9-ad068de2b895>", line 298, in _train
    values_dict = train_function(high_res_real, high_res_fake)
  File "<ipython-input-9-ad068de2b895>", line 244, in _perceptual_loss_train_function
    generator_total_loss.backward()
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 167, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py", line 99, in backward
    variables, grad_variables, retain_graph)
KeyboardInterrupt



  File "/usr/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7f5cf4aa3080>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 333, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 319, in _shutdown_workers
    self.data_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 345, in get
    return ForkingPickler.loads(res)
  File "/usr/local/lib/python3.5/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/usr/lib/python3.5/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/usr/lib/python3.5/multiprocessing/resource_sharer.py"

In [13]:
# LightCNN-9 feature loss; passing a discriminator directory also enables the adversarial loss
# and discriminator checkpoints.
srgan_manager.train_light_cnn_9_loss_mfm4(
    log_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_LOGS_DIR,
    generator_weights_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_GENERATOR_WEIGHTS_DIR,
    discriminator_weights_dir=SRGAN_LIGHT_CNN_9_LOSS_MFM4_DISCRIMINATOR_WEIGHTS_DIR,
    epoch_count=3, train_type='careful'
)

Epoch 3: 55%|█████▍    | 156619/287186 [4:44:37<3:57:16,  9.17it/s]Process Process-19:
Process Process-18:
Process Process-17:
Process Process-20:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Epoch 3: 55%|█████▍    | 156620/287186 [4:44:37<3:57:16,  9.17it/s]Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7fc639f84b70>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 333, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 319, in _shutdown_workers
    self.data_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 345, in get
    return ForkingPickler.loads(res)
  File "/usr/local/lib/python3.5/dist-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()


Stopped.


  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()

  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/synchronize.py", line

In [12]:
# LightCNN-9 MFM4 loss run writing to the "no adversarial" log and
# weights directories; no discriminator weights directory is passed.
mfm4_no_adversarial_options = {
    'log_dir': SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_LOGS_DIR,
    'generator_weights_dir': SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_WEIGHTS_DIR,
    'epoch_count': 3,
    # Frequencies are presumably counted in training iterations -- TODO confirm.
    'save_frequency': 10000,
    'values_log_frequency': 200,
    'images_log_frequency': 2000,
    'train_type': 'careful',
}
srgan_manager.train_light_cnn_9_loss_mfm4(**mfm4_no_adversarial_options)

Epoch 3: 81%|████████▏ | 233469/287186 [51:12<11:46, 75.99it/s]]]]]Process Process-16:
Process Process-15:
Process Process-13:
Process Process-14:
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._a

Stopped.


  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 122, in __getitem__
    img = self.loader(path)

  File "/usr/lib/python3.5/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 69, in default_loader
    return pil_loader(path)
KeyboardInterrupt
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/usr/local/lib/python3.5/dist-packages/torchvision/datasets/folder.py", line 52, in pil_loader
    return img.convert('RGB')
KeyboardInterrupt
  File "/usr/local/lib/python3.5/dist-packages/PIL/Image.py", line 860, in convert
    self.load()
  File "/usr/local/lib/python3.5/dist-packages/PIL/ImageFile.py", line 234, in load
    n, err_code = decoder.decode(b)
Ke

In [16]:
# LightCNN-9 FC loss run writing to the "no adversarial" log and weights
# directories; no discriminator weights directory is passed.
# NOTE(review): with train_type='careful' the recorded output of this cell
# shows a "Clean all train data for (Y/N)" prompt listing the directories,
# so the run appears to wait for confirmation -- confirm in the manager.
fc_no_adversarial_options = {
    'log_dir': SRGAN_LIGHT_CNN_9_LOSS_FC_NO_ADVERSARIAL_LOGS_DIR,
    'generator_weights_dir': SRGAN_LIGHT_CNN_9_LOSS_FC_NO_ADVERSARIAL_WEIGHTS_DIR,
    'epoch_count': 3,
    # Frequencies are presumably counted in training iterations -- TODO confirm.
    'save_frequency': 10000,
    'values_log_frequency': 200,
    'images_log_frequency': 2000,
    'train_type': 'careful',
}
srgan_manager.train_light_cnn_9_loss_fc(**fc_no_adversarial_options)

Clean all train data for (Y/N):
log_dir: /home/data/alpus/clean_fsr/data/logs/srgan/light_cnn_9_loss_fc_no_adversarial
generator_weights_dir: /home/data/alpus/clean_fsr/data/weights/srgan/light_cnn_9_loss_fc_no_adversarial
discriminator_weights_dir: None
Y


Epoch 3: 43%|████▎     | 122744/287186 [28:59:04<38:49:52,  1.18it/s]]]]Process Process-15:
Process Process-13:
Process Process-14:
Process Process-16:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwarg

Stopped.





In [None]:
# LightCNN-9 MFM4 loss run writing to the "no adversarial, no image" log
# and weights directories, with image_weight set to 0.
# NOTE(review): image_weight=0 presumably disables the image (pixel) term
# of the loss -- confirm against the training method.
mfm4_no_adversarial_no_image_options = {
    'log_dir': SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_NO_IMAGE_LOGS_DIR,
    'generator_weights_dir': SRGAN_LIGHT_CNN_9_LOSS_MFM4_NO_ADVERSARIAL_NO_IMAGE_WEIGHTS_DIR,
    'epoch_count': 3,
    # Frequencies are presumably counted in training iterations -- TODO confirm.
    'save_frequency': 10000,
    'values_log_frequency': 200,
    'images_log_frequency': 2000,
    'image_weight': 0,
    'train_type': 'careful',
}
srgan_manager.train_light_cnn_9_loss_mfm4(**mfm4_no_adversarial_no_image_options)

Epoch 1:  1%|          | 2911/287186 [06:55<11:15:29,  7.01it/s]