# StarGAN implementation in Torch

<img src='jpg/logo.jpg'>

[StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation](https://arxiv.org/abs/1711.09020)
In notebook format :)

Blocks in quotes are taken directly from the paper 

In [1]:
import os, time, datetime
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.utils import save_image
import pandas as pd
import numpy as np
import argparse

from tqdm import tqdm
%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
import numpy as np

In [3]:
torch.__version__

'0.4.1'

In [4]:
parser = argparse.ArgumentParser()

# Directory parameters
parser.add_argument('--data_dir', type=str, default='data/CelebA_nocrop/images_ssd', help='Image directory for CelebA (should yield images with <dir>/*.jpg)')
parser.add_argument('--attr_dir', type=str, default='data/list_attr_celeba.txt', help='Image directory for CelebA (should yield images with <dir>/*.jpg)')
parser.add_argument('--model_save_dir', type=str, default='stargan/models')
parser.add_argument('--sample_dir', type=str, default='stargan/samples')
parser.add_argument('--result_dir', type=str, default='stargan/results')

# Train parameters
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
parser.add_argument('--g_lr', type=float, default=0.00003, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.00003, help='learning rate for D')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                    default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')

    
# Model parameters
parser.add_argument('--img_size', type=int, default=128, help='Image resolution for training G')
parser.add_argument('--celeba_crop_size', type=int, default=178, help='Image resolution first initial crop in CelebA')
parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
parser.add_argument('--g_residual', type=bool, default=False, help='Residual connections for generator')
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=4, help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')

parser.add_argument('--lr_update_step', type=int, default=1000, help='Iterations before reducing LR')
# Others

parser.add_argument('--log_step', type=int, default=100, help='Iterations before printing')
parser.add_argument('--sample_step', type=int, default=1000, help='Iterations before producing samples')
parser.add_argument('--model_save_step', type=int, default=10000)
# Jupyter notebook specific debugging
parser.add_argument('--debug', type=bool, default=False, help='For debugging stuff on notebook')

config = parser.parse_args(args=[])

## Data loader
Using pytorch data.Dataset + dataloader API

In [5]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

from PIL import Image
import torch
import os, random

from sklearn.model_selection import train_test_split
# Test attribute reading

class CelebA(Dataset):
    
    def __init__(self, image_dir, attr_path, selected_attr, transform, mode, seed=None):
        """Initialize and preprocess CelebA dataset"""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attr = selected_attr
        self.transform = transform
        self.mode = mode
        self.train_dataset = None #To be filled by pd.DataFrame
        self.test_dataset = None #Same as above
        self.seed = seed
        self.preprocess()
        
        self.num_images = len(self.train_dataset) if mode == 'train' else len(self.test_dataset)
    
    def preprocess(self):
        """Preprocess CelebA attribute file"""
        attr_df = pd.read_csv(self.attr_path, skiprows=1, delim_whitespace=True)
        attr_df = attr_df[self.selected_attr]
        attr_df = (attr_df == 1) #Convert to booleans
        
        #Split into train and test
        train_df, test_df = train_test_split(attr_df, test_size=2000, random_state=self.seed)
        self.train_dataset = train_df
        self.test_dataset = test_df
        print("Finish preprocessing CelebA attributes file")
        
    def __getitem__(self, index):
        """Return one image and corresponding attribute label"""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        row = dataset.iloc[index]
        filename = row.name
        label = row.tolist()
        image = Image.open(os.path.join(self.image_dir, filename))
        return self.transform(image), torch.FloatTensor(label)
        
    def __len__(self):
        return self.num_images


In [6]:
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
              batch_size=16, dataset='CelebA', mode='train', num_workers=4, seed=1234):
    """Wrapper for building a data loader/generator"""
    
    #Build transformations
    transform = T.Compose([
        T.RandomHorizontalFlip(), 
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
    ])
    
    if dataset.lower() == 'celeba':
        dset = CelebA(image_dir, attr_path, selected_attrs, transform, mode, seed)
    elif dataset.lower() == 'rafd':
        NotImplemented()
    else:
        raise(Exception("dataset must be 'celeb' or 'rafd'"))
        
    #Put into DataLoader
    loader = DataLoader(dset, batch_size=batch_size, shuffle=True, 
                        pin_memory=True, #For CUDA optimization
                        num_workers=num_workers)
    
    return loader


In [7]:
celeba_loader = get_loader(config.data_dir, config.attr_dir, selected_attrs=config.selected_attrs, num_workers=4)

Finish preprocessing CelebA attributes file


In [8]:
# Test iteration time
def show_img(img):
    np_img = img[0].numpy()
    plt.imshow(np_img.transpose(1,2,0))
    plt.show()

if config.debug:
    for img, lab in tqdm(celeba_loader):
        break

    print(img.shape)
    show_img(img)
    print(lab[0])

# Model Definition

> Network Architecture. Adapted from [[31]], StarGAN has
the generator network composed of two convolution layers
with the stride size of two for downsampling, six residual
blocks [[5]], and two transposed convolution layers with the
stride size of two for upsampling. We use instance normalization
[[27]] for the generator but no normalization for the
discriminator. We leverage PatchGANs [13, 7, 31] for the
discriminator network, which classifies whether local image
patches are real or fake.

[31]:https://arxiv.org/abs/1703.10593
[5]:https://arxiv.org/abs/1512.03385
[27]:https://arxiv.org/abs/1607.08022

## Generator

> The network architectures of StarGAN are shown in Table 4 and 5. For the generator network, we use instance normalization
in all layers except the last output layer. For the discriminator network, we use Leaky ReLU with a negative slope of
0.01. There are some notations; nd: the number of domain, nc: the dimension of domain labels (nd + 2 when training with
both the CelebA and RaFD datasets, otherwise same as nd), N: the number of output channels, K: kernel size, S: stride size,
P: padding size, IN: instance normalization.
<img src='jpg/StarGan-G.png'>

In addition to the network in the paper, we include a residual connection from the corresponding down-sampling layer to the up-sampling one, similar to [U-net](https://arxiv.org/abs/1505.04597).

In [9]:
class Generator(nn.Module):
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6, residual=True):
        super(Generator, self).__init__()
        self.residual = residual # Enable residual connection during up-sampling
        # We use the sequential API from torch
        self.down_layers = nn.ModuleList()
        self.down_layers.append(ConvInstNormReluBlock(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))

        # Down-sampling (2 and 3)
        curr_dim = conv_dim
        for _ in range(2):
            self.down_layers.append(ConvInstNormReluBlock(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) #Halves H-W dimension
            curr_dim *= 2
            
        #Bottleneck
        self.bottle_layers = nn.ModuleList()
        for _ in range(repeat_num):
            self.bottle_layers.append(ResidualBlock(curr_dim, curr_dim))
        
        # Up-sampling
        self.up_layers = nn.ModuleList()
        for _ in range(2):
            self.up_layers.append(ConvInstNormReluBlock(curr_dim, curr_dim//2, tranpose=True, kernel_size=4, stride=2, padding=1, bias=False)) #Doubles H-W dimension
            curr_dim = curr_dim // 2
        
        self.up_layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        self.out_layer = nn.Tanh()
        
    def forward(self, x, c):
        # Prep indicators
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        
        # Down-sample
        down_out = [] #Saving for residual connection later
        for layer in self.down_layers:
            x = layer(x)
            down_out.append(x)
            
        for layer in self.bottle_layers:
            x = layer(x)
            
        for i, layer in enumerate(self.up_layers):
            if self.residual:
                x = x + down_out[-1 - i]
            x = layer(x)
            
        return self.out_layer(x)
        
class ConvInstNormReluBlock(nn.Module):
    """Wrapper for Conv + Instance Norm + Relu"""
    def __init__(self, dim_in, dim_out, tranpose=False, **convkwargs):
        super(ConvInstNormReluBlock, self).__init__()
        if not tranpose:
            self.main = nn.Sequential(
                nn.Conv2d(dim_in, dim_out, **convkwargs),
                nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
                nn.ReLU()
            )
        else:
            self.main = nn.Sequential(
                nn.ConvTranspose2d(dim_in, dim_out, **convkwargs),
                nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
                nn.ReLU()
            )            
        
    def forward(self, x):
        return self.main(x)
        
class ResidualBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        layers = []
        
        layers.append(nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False))
        layers.append(nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
        # Unlike the paper, we do not use an inplace ReLU. Not recommended as per https://pytorch.org/docs/master/notes/autograd.html#in-place-operations-on-variables
        layers.append(nn.ReLU())
        # ! In the author's code, there is an extra conv + ReLU
        layers.append(nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False))
        layers.append(nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))        
        self.main = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.main(x) + x

In [10]:
# Testing
def param_count(model):
    return sum(x.numel() for x in model.parameters()) #if x.requires_grad

if config.debug:
    test_G = Generator(residual=False)
    test_x = torch.randn(2, 3, 128, 128)
    test_c = torch.randn(2, 5, 1, 1)
    test_out = test_G(test_x, test_c)
    test_out.sum().backward()
    print(test_out.shape)
    print("Number of trainable parameters in G: %d" % param_count(test_G))
    print(test_G)

## Discriminator

> We leverage PatchGANs [13, 7, 31] for the
discriminator network, which classifies whether local image
patches are real or fake.

><img src='jpg/StarGan-D.png'>

In [11]:
class Discriminator(nn.Module):
    """Takes in 3-channel image, and outputs a local image patch output, with a classifier for labels"""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        
        layers = []
        #layers.append(nn.Sequential(
        #    nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1, bias=False),
        #    #nn.LeakyReLU()
        #))
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Sequential(
                nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU()
            ))
            curr_dim *= 2
        self.main = nn.Sequential(*layers)
        
        # Output layer calculation.
        kernel_size = image_size // (2**repeat_num)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) #Wasserstein loss out
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)        
        
    def forward(self, x):
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))

In [12]:
if config.debug:
    test_D = Discriminator(repeat_num=4)
    test_x = torch.randn(2, 3, 128, 128)
    test_out_src, test_out_cls = test_D(test_x)
    test_out_src.sum().backward()
    print(test_out_src.size(), test_out_cls.size())
    print("Number of trainable parameters in D: %d" % param_count(test_D))
    print(test_D)

## Loss functions

In [13]:
def classification_loss(logits, target):
    return F.binary_cross_entropy_with_logits(logits, target, reduction='sum') / logits.size(0)

def Wloss(x):
    """Wasserstein loss. Assumes target is of single class! (Perform multiplication outside)"""
    return torch.mean(x)

# TODO: Study this
def gradient_penalty(y, x, device):
    """Gradient Penalty: (L2_norm(dy/dx) - 1)**2: As shown in Improvements in WGAN"""
    weight = torch.ones(y.shape).to(device)
    dydx = torch.autograd.grad(outputs=y,
                              inputs=x,
                              grad_outputs=weight,
                              retain_graph=True, #WARNING: NOT ENABLING THIS WILL BREAK THE LOSS!
                              create_graph=True,
                              only_inputs=True)[0]
    dydx = dydx.view(dydx.size(0), -1)
    dydx_l2 = torch.sqrt(torch.sum(dydx**2, dim=1))
    return torch.mean((dydx_l2 - 1)**2)

In [14]:
if config.debug:
    test_size = (10, 5)
    with torch.no_grad():
        test_logit = torch.randn(test_size)
        test_out = torch.randn(test_size)
        loss = classification_loss(test_logit, test_out)
        print(loss.item())

## Helper functions

In [15]:
def denorm(x):
    """Output image is from [-1, 1]. Transform to [0, 1] for saving"""
    out = (x + 1) / 2
    return out.clamp(0, 1)

def calculate_norm(model):
    norm = 0
    for param in model.parameters():
        if param.requires_grad:
            norm += torch.sum(param**2).item()
    return norm

# Training Wrapper

In [16]:
class Solver():
    def __init__(self, dataloader, config):
        self.dataloader = dataloader
        
        for k, v in vars(config).items():
            setattr(self, k, v)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.G, self.D = None, None
        self.g_optim, self.d_optim = None, None
        self.build_models()
        self.initialize_directory()
    
    def initialize_directory(self):
        for path in ('model_save_dir', 'sample_dir', 'result_dir'):
            os.makedirs(getattr(self, path), exist_ok=True)
        
    def build_models(self):
        self.G = Generator(conv_dim=self.g_conv_dim, c_dim=self.c_dim, repeat_num=self.g_repeat_num, residual=self.g_residual)
        self.D = Discriminator(image_size=self.img_size, conv_dim=self.d_conv_dim, c_dim=self.c_dim,
                             
                              repeat_num=self.d_repeat_num)
        self.G.to(self.device)
        self.D.to(self.device)
        print("Number of trainable parameters in G: %d" % param_count(self.G))
        print("Number of trainable parameters in D: %d" % param_count(self.D))
        self.g_optim = torch.optim.Adam(self.G.parameters(), lr=self.g_lr, betas=(self.beta1, self.beta2)) #), amsgrad=True
        self.d_optim = torch.optim.Adam(self.D.parameters(), lr=self.d_lr, betas=(self.beta1, self.beta2)) #), amsgrad=True
        
    def update_lr(self, g_lr, d_lr):
        for param_group in self.g_optim.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optim.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        self.g_optim.zero_grad()
        self.d_optim.zero_grad()

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        # Get hair color indices.
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list
    
    def train(self):
        
        # Create a set of fixed input for viewing changes across iterations
        data_iter = iter(self.dataloader)
        x_fixed, c_original = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_original, self.c_dim, self.selected_attrs)
        
        g_lr = self.g_lr
        d_lr = self.d_lr
        
        print("Training")
        start_time = time.time()
        for i in range(0, self.num_iters): #May want to consider epoch style training
            
            # Preprocess input data
            # ===============================================
            try:
                x_real, label_org = next(data_iter)
            except:
                data_iter = iter(self.dataloader)
                x_real, label_org = next(data_iter)
                
            # Generate some random target domain labels (G training) by mixing within training set
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]
            
            c_org = label_org.clone() # TODO: Is this necessary?
            c_trg = label_trg.clone()
            
            x_real = x_real.to(self.device)
            c_org = c_org.to(self.device)
            c_trg = c_trg.to(self.device)
            label_org = label_org.to(self.device)
            label_trg = label_trg.to(self.device)
            
            # Train Discriminator
            # ===============================================
            # Real image
            out_src, out_cls = self.D(x_real)
            d_loss_real = - Wloss(out_src)
            d_loss_cls = classification_loss(out_cls, label_org)
            
            # Fake image
            x_fake = self.G(x_real, c_trg)
            out_src, _ = self.D(x_fake.detach())
            d_loss_fake = Wloss(out_src)
            
            # Gradient penalty
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = gradient_penalty(out_src, x_hat, self.device)

            
            # Backprop
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls*d_loss_cls + self.lambda_gp*d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optim.step()
            
            # Logging
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()
            
            # Train Generator
            # ===============================================
            if (i+1) % self.n_critic == 0:
                # Original-to-target
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = classification_loss(out_cls, label_trg)
                
                # Target-to-original
                x_recon = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_recon))
                
                # Backprop
                g_loss = g_loss_fake + self.lambda_rec*g_loss_rec + self.lambda_cls*g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optim.step()
                
                # Logging
                loss['G/loss'] = g_loss.item()
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_cls'] = g_loss_cls.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                
            # Misc
            # ===============================================           
            
            # Print training info (TODO: Connect to tensorboard)
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
            
            # Decay training rate
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
            
            # Produce samples            
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        x_fake_list.append(self.G(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))
            
        return 0
            
        

In [17]:
solver = Solver(celeba_loader, config)

Number of trainable parameters in G: 8430528
Number of trainable parameters in D: 2924992


In [18]:
solver.train()

Training


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/zeyi/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-18-ad90b845b08b>", line 1, in <module>
    solver.train()
  File "<ipython-input-16-898089759b17>", line 107, in train
    x_fake = self.G(x_real, c_trg)
  File "/home/zeyi/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-9-c77546ff4686>", line 47, in forward
    x = layer(x)
  File "/home/zeyi/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zeyi/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: CUDA error: out of memory

During handling of the above exception, anoth

RuntimeError: CUDA error: out of memory