In [1]:
%matplotlib inline

import torch.nn as nn
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np
import math
from matplotlib import pyplot as plt
from graphviz import Digraph
import os

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7ff2dba93090>

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"
global vec0
global vec1
global ve2

In [3]:
def upsample(x):
    return nn.functional.interpolate(x, scale_factor=2, mode="nearest")

def downsample(x):
    return nn.functional.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)

def conv_layer(in_filters, out_filters=32, kernel_size=3, he_init=True):
    same_padding = (kernel_size-1)//2
    conv = nn.Conv2d(in_filters, out_filters, kernel_size=kernel_size, padding=same_padding)
    
    if he_init:
        he_init_constant = math.sqrt(6 / (in_filters * kernel_size**2))
        nn.init.uniform_(conv.weight, -he_init_constant, he_init_constant)
    else:
        xavier_init_constant = math.sqrt(6 / ((in_filters + out_filters) * kernel_size**2))
        nn.init.uniform_(conv.weight, -xavier_init_constant, xavier_init_constant)
    nn.init.constant_(conv.bias, 0)
    
    return conv

def bn(channels):
    batchnorm = nn.BatchNorm2d(channels, eps=1e-5)
    nn.init.constant_(batchnorm.weight, 1)
    return batchnorm

def linear(in_features, out_features):
    linear_layer = nn.Linear(in_features, out_features)
    
    xavier_init_constant = math.sqrt(6/(in_features+out_features))
    nn.init.uniform_(linear_layer.weight, -xavier_init_constant, xavier_init_constant)
    nn.init.constant_(linear_layer.bias, 0)
    
    return linear_layer

In [4]:
class ResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, resample=None, normalize=False, activation=None):
        super(ResBlock, self).__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.resample = resample
        self.normalize = normalize
        
        self.conv1 = conv_layer(in_filters, out_filters)
        self.conv2 = conv_layer(out_filters, out_filters)
        
        if resample:
            self.conv3 = conv_layer(in_filters, out_filters, kernel_size=1, he_init=False)
        
        if normalize:
            self.bn1 = bn(in_filters)
            self.bn2 = bn(out_filters)
            
        if activation is not None:
            self.activation = activation
        else:
            self.activation = nn.ReLU()
                
    def forward(self, x): 
        orig_input = x
   
        if self.normalize:
            x = self.bn1(x)
            
        

        
        x = self.activation(x)
        
        print("Activated", x.shape)
        print(x[0, :2, :2,:2])
        
        if self.resample == 'up':
            x = upsample(x)
            
        print("Upsampled", x.shape)
        print(x[0, :2, :2,:2])
        
        global vec2
        vec2 = x
 
       
        x = self.conv1(x)
        
        print("AFTER conv1", x.shape)
        print(x[0, :2, :2,:2])

        
#         print("Conv weights")
#         print(self.conv1.weight.data[0, :3, :3, :3])
#         print("------")
#         print(self.conv1.bias.data)
        
        
        global vec0
        vec0 = x
        
        if self.normalize:
            x = self.bn2(x)
            
        print("Normalized again")
        print(x[0, :2, :2,:2])
        
        global vec1
        vec1 = x
            
        x = self.activation(x)
        
        print("Activated again")
        print(x[0, :2, :2,:2])
        
        x = self.conv2(x)
        
        print("AFTER conv2")
        print(x[0, :2, :2,:2])
        
        raise ValueError("hi")
        
        if self.resample == 'down':
            x = downsample(x)
            
        print("AFTER DOWNSAMPLE")
        print(x[0, :2, :2,:2])
        
        # Shortcut
        if self.resample == 'down': 
            shortcut_x = downsample(self.conv3(orig_input))
        elif self.resample == 'up':
            shortcut_x = self.conv3(upsample(orig_input))
        elif self.resample == None:
            shortcut_x = orig_input
            
        print("SHORTCUT")
        print(shortcut_x[0, :2, :2,:2])
        
        return x + shortcut_x
    
class SmallResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, activation=None):
        super(SmallResBlock, self).__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        
        self.conv1 = conv_layer(in_filters, out_filters)
        self.conv2 = conv_layer(out_filters, out_filters)
        self.conv3 = conv_layer(in_filters, out_filters, kernel_size=1, he_init=False)
            
        if activation is not None:
            self.activation = activation
        else:
            self.activation = nn.ReLU()
                
    def forward(self, x): 
        orig_input = x
       
        x = self.conv1(x)   
        x = downsample(self.conv2(self.activation(x)))

        # Shortcut
        shortcut_x = self.conv3(downsample(orig_input))
        
        return x + shortcut_x 

In [5]:
class Generator(nn.Module):
    def __init__(self, input_size, num_filters=128, num_blocks=3, start_image_size=4, num_channels=3):
        super(Generator, self).__init__()
        
        self.num_filters = num_filters
        self.start_image_size = start_image_size
        
        self.first_linear = linear(input_size, num_filters * start_image_size ** 2)
        self.resblocks = nn.ModuleList()
        self.activation = nn.ReLU()
        
        for _ in range(num_blocks):
            self.resblocks.append(
                ResBlock(in_filters=self.num_filters, 
                         out_filters=self.num_filters, 
                         resample='up', 
                         normalize=True))
            
        self.last_layer = conv_layer(num_filters, num_channels)
        self.bn = bn(num_filters)
        self.manually_initialize()
    
    def forward(self, noise):
        
        print("Should be all ones")
        print(noise)
        print(noise.size())
        
        x = self.first_linear(noise)
        
        print("Literally just dense")
        
#         print("A1")
        
        print(x)
        print(x.size())
        
#         x = x.view(-1, self.num_filters, self.start_image_size, self.start_image_size)
        x = x.view(-1, self.start_image_size, self.start_image_size, self.num_filters)
        print("first view reshape", x[0, :3, :3, :3])
        x = x.permute(0, 3, 1, 2)
        print("second view reshape", x[0, :3, :3, :3])
        
        
        global dense_output
        dense_output = x
        
        print("After Dense")
        print(x.permute(0,2,3,1))
        print(x.size())
        
        for resblock in self.resblocks:
            x = resblock(x)
#             print("B")
#             print(x)
#             print(x.size())
            
        x = self.activation(self.bn(x))
        result = self.last_layer(x)
        return torch.tanh(result)
    
    def manually_initialize(self):
        print(self.first_linear.weight.data[0,0:3])
        self.first_linear.weight.data = torch.tensor(np.load("params/generator_dense_kernel:0.npy")).transpose(0,1).contiguous()
        self.first_linear.bias.data = torch.tensor(np.load("params/generator_dense_bias:0.npy")).contiguous()
        print(self.first_linear.weight.data[0:3,0])
        
        self.resblocks[0].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_gamma:0.npy")).contiguous()
        self.resblocks[0].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_beta:0.npy")).contiguous()
        
        self.resblocks[0].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_1_gamma:0.npy")).contiguous()
        self.resblocks[0].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_1_beta:0.npy")).contiguous()
        
        self.resblocks[1].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_2_gamma:0.npy")).contiguous()
        self.resblocks[1].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_2_beta:0.npy")).contiguous()
        
        self.resblocks[1].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_3_gamma:0.npy")).contiguous()
        self.resblocks[1].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_3_beta:0.npy")).contiguous()
        
        self.resblocks[2].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_4_gamma:0.npy")).contiguous()
        self.resblocks[2].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_4_beta:0.npy")).contiguous()
        
        self.resblocks[2].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_5_gamma:0.npy")).contiguous()
        self.resblocks[2].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_5_beta:0.npy")).contiguous()
        
        self.bn.weight.data = torch.tensor(np.load("params/generator_BatchNorm_6_gamma:0.npy")).contiguous()
        self.bn.bias.data = torch.tensor(np.load("params/generator_BatchNorm_6_beta:0.npy")).contiguous()
        
        # Insert "contiguous" to make conv layers work correctly.  Not strictly necessary when using the gpu,
        # but will hopefully help us avoid headaches later
        self.resblocks[0].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_bias:0.npy")).contiguous()
        
        self.resblocks[0].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_1_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_1_bias:0.npy")).contiguous()
        
        self.resblocks[0].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_2_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_2_bias:0.npy")).contiguous()
        
        self.resblocks[1].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_3_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_3_bias:0.npy")).contiguous()
        
        self.resblocks[1].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_4_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_4_bias:0.npy")).contiguous()
        
        self.resblocks[1].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_5_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_5_bias:0.npy")).contiguous()
        
        self.resblocks[2].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_6_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_6_bias:0.npy")).contiguous()
        
        self.resblocks[2].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_7_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_7_bias:0.npy")).contiguous()
        
        self.resblocks[2].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_8_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_8_bias:0.npy")).contiguous()
        
        self.last_layer.weight.data = torch.tensor(np.load("params/generator_conv2d_9_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.last_layer.bias.data = torch.tensor(np.load("params/generator_conv2d_9_bias:0.npy")).contiguous()

In [6]:
class MLPLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLPLayer, self).__init__()
        self.block = nn.Sequential(linear(input_size, output_size), nn.ReLU())
        
    def forward(self, x):
        return self.block(x)

class VectorDiscriminator(nn.Module):
    def __init__(self, input_size, hidden_size, num_blocks):
        self.blocks = nn.ModuleList()
        self.blocks.append(MLPLayer(input_size, hidden_size))
        
        for _ in range(num_blocks - 1):
            self.blocks.append(hidden_size, hidden_size)
        
        self.blocks.append(linear(hidden_size, 1))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        
        return x

In [7]:
class ImageDiscriminator(nn.Module):
    def __init__(self, num_filters=128, num_blocks=4, num_channels=3):
        super(ImageDiscriminator, self).__init__()
        assert num_blocks >= 2, "Number of conv layers in the discriminator must be >= 2."
        
        self.resblocks = nn.ModuleList()
        self.activation = nn.ReLU()
        
        self.resblocks.append(SmallResBlock(in_filters=num_channels, 
                                            out_filters=num_filters))
        self.resblocks.append(ResBlock(in_filters=num_filters, 
                                       out_filters=num_filters, 
                                       resample='down'))
        for _ in range(num_blocks - 2):
            self.resblocks.append(ResBlock(in_filters=num_filters, 
                                           out_filters=num_filters))
            
        self.last_linear = linear(num_filters, 1)
        self.manually_initialize()
   
    def forward(self, x):
        for resblock in self.resblocks:
            x = resblock(x)
            
        x = self.activation(x)
        x = x.mean(dim=(-1,-2))
        x = self.last_linear(x)
        return x
    
    def manually_initialize(self):
        print('hi')
#         self.last_linear.weight.data = torch.tensor(np.load("params/discriminator_dense_kernel:0.npy")).transpose(0,1)
#         self.last_linear.bias.data = torch.tensor(np.load("params/discriminator_dense_bias:0.npy"))
        
#         self.resblocks[0].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_bias:0.npy"))
        
#         self.resblocks[0].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_1_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_1_bias:0.npy"))
        
#         self.resblocks[0].conv3.weight.data = torch.tensor(np.load("params/discriminator_conv2d_2_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv3.bias.data = torch.tensor(np.load("params/discriminator_conv2d_2_bias:0.npy"))
        
#         self.resblocks[1].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_3_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_3_bias:0.npy"))
        
#         self.resblocks[1].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_4_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_4_bias:0.npy"))
        
#         self.resblocks[1].conv3.weight.data = torch.tensor(np.load("params/discriminator_conv2d_5_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv3.bias.data = torch.tensor(np.load("params/discriminator_conv2d_5_bias:0.npy"))
        
#         self.resblocks[2].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_6_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[2].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_6_bias:0.npy"))
        
#         self.resblocks[2].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_7_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[2].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_7_bias:0.npy"))
        
#         self.resblocks[3].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_8_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[3].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_8_bias:0.npy"))
        
#         self.resblocks[3].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_9_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[3].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_9_bias:0.npy"))

In [8]:
def identity_embedding(pic):
    return pic

class GAN(nn.Module):
    def __init__(self, gamma, noise_size=128, num_filters=128, num_generator_blocks=3, num_discriminator_blocks=4,
                 batch_size=64, num_channels=3, discriminator_epsilon=1e-5, banach=True, 
                 embedding_func=identity_embedding, embedding_size=None, discriminator_hidden_size=None,
                 discriminator_type="Image"):  
        super(GAN, self).__init__()
        self.generator = Generator(noise_size, num_filters=num_filters, num_blocks=num_generator_blocks,
                                   start_image_size=4, num_channels=num_channels)
        
        if discriminator_type == "Image":
            self.discriminator = ImageDiscriminator(num_filters=num_filters, num_blocks=num_discriminator_blocks, 
                                                    num_channels=num_channels)
        elif discriminator_type == "Vector":
            self.discriminator = VectorDiscriminator(input_size=embedding_size, 
                                                     hidden_size=discriminator_hidden_size,
                                                     num_blocks=num_discriminator_blocks)
        else:
            raise ValueError("Discriminator type not recognized.")
        self.discriminator_epsilon = discriminator_epsilon
        
        # Assumption that the dual space is the same as the original space.
        self.gamma = gamma
        self.lambda_penalty = gamma
        
        self.banach = banach # Boolean parameter that decides on whether to do a banach/metric space based wgan.
        self.embedding_func = embedding_func # The embedding function used.
        
        self.register_buffer("penalty_grad_outputs", torch.ones(batch_size))
        self.register_buffer("noise_buffer", torch.ones((batch_size, noise_size)))
        self.register_buffer("epsilon_buffer", torch.ones(batch_size, 1, 1, 1))
        
    def forward_train_generator(self, noise=None):
        generated_image = self.forward_predict_generator(noise)
        discriminator_score_generated = self.forward_predict_discriminator(generated_image)
        return self.generator_loss(discriminator_score_generated)
    
    def forward_train_discriminator(self, real_images, noise=None):
        generated_images = self.forward_predict_generator(noise)
        discriminator_score_generated = self.forward_predict_discriminator(generated_images)
        discriminator_score_real = self.forward_predict_discriminator(real_images)
        return self.discriminator_loss(discriminator_score_real, discriminator_score_generated, real_images, generated_images)
        
    def forward_predict_generator(self, noise=None):
        if noise is None:
            noise = self.generate_noise()
        return self.generator(noise)
        
    def forward_predict_discriminator(self, image):
        return self.discriminator(image)
    
    def generate_noise(self):
        return torch.randn_like(self.noise_buffer)
    
    def generator_loss(self, d_score_generated):
#         print("Generator loss", torch.mean(d_generated_train) / self.gamma)
        return torch.mean(d_score_generated) / self.gamma #NOTE: Mehdi's version had a negative sign, original was positive
    
    def stable_norm(self, x):
        x = x.view(x.size(0), -1)
        alpha, _ = (x.abs() + 1e-5).max(1)
        
        return alpha * (x/alpha.unsqueeze(1)).norm(p=2, dim=1)
    
    def discriminator_loss(self, d_score_real, d_score_generated, real_images, generated_images):
#         print("d score generated", torch.mean(d_score_generated))
#         print("d score real", torch.mean(d_score_real))
        wasserstein_loss = (torch.mean(d_score_generated) - torch.mean(d_score_real)) / self.gamma
#         print("wass loss", wasserstein_loss)
        epsilon = self.epsilon_buffer.uniform_(0, 1)
        real_fake_mix = epsilon * generated_images + (1 - epsilon) * real_images 
        d_score_mix = self.discriminator(real_fake_mix).squeeze(1)
        
        gradients = torch.autograd.grad(d_score_mix, real_fake_mix, grad_outputs=self.penalty_grad_outputs,
                                        create_graph=True)[0]
        
#         print(gradients)
        gradient_penalty = torch.mean(self.stable_norm(gradients) / gamma - 1) ** 2
        print("grad penalty", float(self.lambda_penalty * gradient_penalty))
        d_regularizer_mean = torch.mean(d_score_real ** 2)
#         print("regularizer mean", float(self.discriminator_epsilon * d_regularizer_mean))
#         
        #NOTE: Mehdi's version had the wassestein loss positive, original seems to be negative
#         print("w loss is", -float(wasserstein_loss))
        d_loss = -wasserstein_loss + self.lambda_penalty * gradient_penalty + self.discriminator_epsilon * d_regularizer_mean
#         print("Overall d_loss", d_loss.item())
        return d_loss


In [9]:
# There is one discrepancy in this training code and the bwgan github implementation. That
# version uses an exponential moving average of the weights during evaluation.
#
# One other discrepancy with bwgan is the lack of usage of warm restarts for SGD. That's mentioned in
# the paper but I could not see in the implementation.
#
# The last main discrepancy is related to the model. In the bwgan code gamma is computed each batch. Here
# we compute gamma over the dataset instead. The difference should be very minor as gamma's value across batches
# is pretty stable. For MNIST gamma appeared to range from 29.8-30.1 from looking at a dozen gamma values.
def gan_train(model, dset_loader, optimizers, lr_schedulers, num_updates=1e5,
              use_cuda=False, num_discriminator=5):
    steps_so_far = 0
    curr_epoch = 0
    
    discriminator_optimizer, generator_optimizer = optimizers
    discriminator_lr_scheduler, generator_lr_scheduler = lr_schedulers
    
    display_generator = False
    
    while True:
        print('Epoch {} - Step {}/{}'.format(curr_epoch, steps_so_far, num_updates))
        print('-' * 10)

        # Iterate over data.
        for data, _ in dset_loader:
            
            print(steps_so_far)
            if steps_so_far % 1000 == 0:
                print(steps_so_far)
            
            if steps_so_far >= num_updates:
                return model
            
            if use_cuda:
                data = data.cuda()
            
            loss = model.forward_train_discriminator(data)
#             register_hooks(loss)
            
            if steps_so_far % 100 == 4:
                print("Discriminator Loss: ", loss.item())
                display_generator = True
            
            loss.backward()

            discriminator_optimizer.step()
            # zero the parameter gradients
            discriminator_optimizer.zero_grad()
                
                
            if steps_so_far % num_discriminator == num_discriminator-1:
                loss = model.forward_train_generator()
                loss.backward()
                generator_optimizer.step()
                
                # zero the parameter gradients
                generator_optimizer.zero_grad()

                discriminator_lr_scheduler.step()
                generator_lr_scheduler.step()
                if display_generator:
                    print("Generator Loss: ", loss.item())
                    display_generator = False
                    
                    
            if steps_so_far % 500 == 0:
                generated_images = model.forward_predict_generator()
                first_image = generated_images[0, 0].cpu().detach().numpy()
                min_val = float(np.amin(first_image))
                max_val = float(np.amax(first_image))
                plt.title("Generated image, range {} to {}".format(round(min_val, 3), round(max_val, 3)))
                plt.imshow(first_image)
                plt.show()
            
            steps_so_far += 1
        
        curr_epoch += 1

In [10]:
def compute_gamma(dset_loader):
    num_images = len(dset_loader.dataset)
    gamma = 0.0
    
    for data, _ in dset_loader:
        batch_size = data.size()[0]
        gamma += data.cuda().view(batch_size, -1).norm(2, dim=1).sum().item() / num_images
    
    return gamma

In [11]:
def iter_graph(root, callback):
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)

def register_hooks(var):
    def is_bad_grad(grad_output):
        if grad_output is None:
            return False
        
        grad_output = grad_output.data
        return grad_output.ne(grad_output).any() or grad_output.gt(1e4).any()
    
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            for grad in grad_output:
                if is_bad_grad(grad):
                    print(fn)
#                     raise ValueError("Hi")
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)

In [12]:
noise_size = 128
batch_size = 1
use_cuda = True
base_lr = 2e-4
num_updates = int(1e2)
num_discriminator = 5
num_channels = 3

In [13]:
def get_data(name, train=True):
    assert name in ["mnist", "cifar", "celeba"]
    transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])
    if name == "mnist":
        return datasets.MNIST("mnist", train=train, download=True, transform=transform)
    if name == "cifar":
        return datasets.CIFAR10("cifar", train=train, transform=transform, download=True)
    if name == "celeba":
        if train == True:
            dset_str = "train"
        else:
            dset_str = "test"
        with open("celeba_64_bgr_-1_to_1_%s.pkl" % dset_str, "rb") as f:
            dataset = pickle.load(f)
        return dataset

In [14]:
# train_dataset = get_data("cifar")
# train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
#                                    num_workers=4, pin_memory=True, drop_last=True)
gamma = 27 #compute_gamma(train_dataloader)

In [91]:
model = GAN(gamma=gamma, noise_size=noise_size, batch_size=batch_size, num_channels=num_channels)
model.train()
y = torch.ones(1,128)
output1 = model.forward_predict_generator(y)
# output2 = model.forward_predict_discriminator(output1)

print(output1.size())
print(output1.permute(0,2,3,1)[0])

# print(output2)

# if use_cuda:
#     model = model.cuda()
#     model.discriminator = nn.DataParallel(model.discriminator)
#     model.generator = nn.DataParallel(model.generator)

# discriminator_optimizer = optim.Adam(model.discriminator.parameters(), betas=(0, 0.9), lr=base_lr)
# generator_optimizer = optim.Adam(model.generator.parameters(), betas=(0, 0.9), lr=base_lr)
# discriminator_lr_scheduler = optim.lr_scheduler.LambdaLR(discriminator_optimizer, lambda step: max(0, (1 - step/num_updates)))
# generator_lr_scheduler = optim.lr_scheduler.LambdaLR(generator_optimizer, lambda step: max(0, (1 - step/num_updates)))
# optimizers = discriminator_optimizer, generator_optimizer
# lr_schedulers = discriminator_lr_scheduler, generator_lr_scheduler

# import time

# start_time = time.time()
# model.train()
# model = gan_train(model, train_dataloader, optimizers, lr_schedulers, num_updates=num_updates, 
#                   use_cuda=use_cuda, num_discriminator=num_discriminator)
# print(time.time() - start_time)
# kangaroo

tensor([ 0.0329, -0.0379, -0.0009])
tensor([-0.0185,  0.0220,  0.0163])
hi
Should be all ones
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.]])
torch.Size([1, 128])
Literally just dense
tensor([[-0.2824,  0.1309, -0.1149,  ..., -0.3147, -0.2160, -0.0382]],
       grad_fn=<AddmmBackward>)
torch.Size([1, 2048])
first view reshape tensor([[[-0.2824,  0.1309, -0.1149],
         [ 0.1462,  0.3888, -0.3026],
         [ 0.3701,  0.1530,  0.4670]],

        [[-0.4724,

ValueError: hi

In [89]:
py_up = vec2.permute(0,2,3,1).contiguous()
tf_up = np.load("upsampled.npy")
print(py_up.shape, tf_up.shape)
print(np.array_equal(tf_up, py_up.cpu().detach().numpy()))
print(tf_up[0,:4,:4,:4])
print(py_up[0,:4,:4,:4])
print("Are they basically the same?", np.sum(tf_up - py_up.cpu().detach().numpy()))
print("======" * 10)

py_bn_input = vec0.permute(0,2,3,1).contiguous()
tf_bn_input = np.load("bn_input.npy")
print(py_bn_input.shape)
print(tf_bn_input.shape)
print(np.array_equal(tf_bn_input, py_bn_input.cpu().detach().numpy()))
print(tf_bn_input[0,:4,:4,:4])
print(py_bn_input[0,:4,:4,:4])

print("--------------" * 5)
print(np.mean(py_bn_input.cpu().detach().numpy()))
print(np.mean(tf_bn_input))

torch.Size([1, 8, 8, 128]) (1, 8, 8, 128)
False
[[[0.         0.21231355 0.         0.        ]
  [0.         0.21231355 0.         0.        ]
  [0.22819537 0.81528354 0.         0.        ]
  [0.22819537 0.81528354 0.         0.        ]]

 [[0.         0.21231355 0.         0.        ]
  [0.         0.21231355 0.         0.        ]
  [0.22819537 0.81528354 0.         0.        ]
  [0.22819537 0.81528354 0.         0.        ]]

 [[0.         0.         0.         0.5992161 ]
  [0.         0.         0.         0.5992161 ]
  [2.3898323  0.15784769 0.         0.        ]
  [2.3898323  0.15784769 0.         0.        ]]

 [[0.         0.         0.         0.5992161 ]
  [0.         0.         0.         0.5992161 ]
  [2.3898323  0.15784769 0.         0.        ]
  [2.3898323  0.15784769 0.         0.        ]]]
tensor([[[0.0000, 0.2123, 0.0000, 0.0000],
         [0.0000, 0.2123, 0.0000, 0.0000],
         [0.2282, 0.8153, 0.0000, 0.0000],
         [0.2282, 0.8153, 0.0000, 0.0000]],

  

In [94]:
x = model.generator.resblocks[0].conv1.weight.data
print(x.shape)
x = x.permute(3,2,1,0)
print(x)


torch.Size([128, 128, 3, 3])
tensor([[[[ 0.0063,  0.0668, -0.0036,  ...,  0.0325,  0.0310, -0.0140],
          [-0.0357,  0.0247, -0.0519,  ...,  0.0236, -0.0460, -0.0439],
          [-0.0586,  0.0667,  0.0048,  ...,  0.0506, -0.0011, -0.0602],
          ...,
          [ 0.0390,  0.0260,  0.0249,  ..., -0.0318, -0.0659, -0.0619],
          [ 0.0586,  0.0379,  0.0611,  ...,  0.0160, -0.0060, -0.0391],
          [-0.0629,  0.0580,  0.0303,  ...,  0.0295,  0.0052, -0.0292]],

         [[-0.0569,  0.0478,  0.0292,  ..., -0.0126,  0.0547, -0.0509],
          [-0.0274, -0.0698, -0.0083,  ...,  0.0274, -0.0480, -0.0671],
          [-0.0486,  0.0319, -0.0687,  ...,  0.0449,  0.0442, -0.0662],
          ...,
          [-0.0378,  0.0239,  0.0275,  ..., -0.0624,  0.0005,  0.0372],
          [-0.0452, -0.0290, -0.0406,  ...,  0.0335, -0.0294,  0.0070],
          [ 0.0267, -0.0139, -0.0078,  ...,  0.0571, -0.0209,  0.0242]],

         [[-0.0061,  0.0247,  0.0241,  ..., -0.0638, -0.0418,  0.0468],
 

## Weirdness Afoot

In [106]:
import tensorflow as tf

in_channels = 2
out_channels = 4
padding = 1
kernel_size = 3

x = np.random.random((1,4,4, in_channels))
x_tf = tf.convert_to_tensor(x, tf.float32)
x_py = torch.FloatTensor(np.transpose(x, [0, 3, 2, 1])) # I've also tried [0, 3, 2, 1]


sess = tf.InteractiveSession()
# x_tf = tf.convert_to_tensor(x.permute(0,2,3,1).numpy(), dtype=tf.float32)
initializer = tf.contrib.layers.variance_scaling_initializer(uniform=True)
conv_tf = tf.layers.conv2d(x_tf, filters=out_channels, kernel_size=kernel_size,
                            padding='SAME', kernel_initializer=initializer)
sess.run([tf.global_variables_initializer(),
          tf.local_variables_initializer()])
result = conv_tf.eval()
print("result", result)
conv_weight = sess.run(tf.trainable_variables()[-2])
conv_bias = sess.run(tf.trainable_variables()[-1])
sess.close()

result [[[[-0.61305225 -0.6804626   0.11569262  0.12916481]
   [ 0.8744844  -0.25568002  0.18106008  0.34250492]
   [-0.34760213  0.06820619 -0.5159601  -0.0028477 ]
   [ 0.37385195 -0.13721271 -0.11608122 -0.18903476]]

  [[ 0.06974803 -0.36361012 -0.70959246 -0.24128214]
   [ 0.7637968   0.91365236 -0.5949327  -0.32401252]
   [ 0.35629666 -0.38258404 -0.6643014  -0.8583082 ]
   [ 0.55447584  0.08696834 -0.67363775  0.30926135]]

  [[ 0.33146533 -0.11482179 -0.3043952   0.29770672]
   [ 0.0145492   0.03810987 -1.1154032  -0.89013636]
   [ 0.91061    -0.29720092 -0.5351878   0.23060244]
   [ 0.19792801 -0.07683964 -0.39902225 -0.54844284]]

  [[-0.0858967  -0.13297018 -0.38191897  0.20462504]
   [-0.32521605 -0.13901548 -1.8993428  -0.43858945]
   [-0.5998081   0.33009914 -1.0219332  -0.8117941 ]
   [ 0.34582764  0.30031496 -1.4258046   0.06366794]]]]




In [107]:
conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
conv.weight.data = torch.tensor(conv_weight).permute(3,2,1,0).contiguous()
conv.bias.data = torch.tensor(conv_bias).contiguous()    
py_result = conv(x_py).permute(0,3,2,1) # Rearange for comparison with tensorflow
print(py_result)

tensor([[[[-0.6131, -0.6805,  0.1157,  0.1292],
          [ 0.8745, -0.2557,  0.1811,  0.3425],
          [-0.3476,  0.0682, -0.5160, -0.0028],
          [ 0.3739, -0.1372, -0.1161, -0.1890]],

         [[ 0.0697, -0.3636, -0.7096, -0.2413],
          [ 0.7638,  0.9137, -0.5949, -0.3240],
          [ 0.3563, -0.3826, -0.6643, -0.8583],
          [ 0.5545,  0.0870, -0.6736,  0.3093]],

         [[ 0.3315, -0.1148, -0.3044,  0.2977],
          [ 0.0145,  0.0381, -1.1154, -0.8901],
          [ 0.9106, -0.2972, -0.5352,  0.2306],
          [ 0.1979, -0.0768, -0.3990, -0.5484]],

         [[-0.0859, -0.1330, -0.3819,  0.2046],
          [-0.3252, -0.1390, -1.8993, -0.4386],
          [-0.5998,  0.3301, -1.0219, -0.8118],
          [ 0.3458,  0.3003, -1.4258,  0.0637]]]], grad_fn=<PermuteBackward>)


In [101]:
import itertools
all_perms = list(itertools.permutations([1,2,3]))
for perm in all_perms:
    perm_result = py_result.permute(0, *perm)
    print(perm_result)

torch.Size([1, 4, 4, 4])

## Weirdness Ends Here

In [85]:
# # FROM SOME SO POST
# import model
# import torch
# from torch import nn
# import torch.nn.functional as F


# sess = tf.Session()
# np.random.seed(1)
# tf.set_random_seed(1)

# #parameters
# kernel_size = 3
# input_feat = 4
# output_feat = 4

# #inputs
# npo = np.random.random((1,5,5, input_feat))
# x = tf.convert_to_tensor(npo, tf.float32)
# x2 = torch.tensor(np.transpose(npo, [0, 3, 2, 1])).double()

# #the same weights
# weights = np.random.random((kernel_size,kernel_size,input_feat,output_feat))
# weights_torch = np.transpose(weights, [3, 2, 1, 0])

# #convolving with tensorflow
# w = tf.Variable(weights, name="testconv_W", dtype=tf.float32)
# res = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID")

# sess.run(tf.global_variables_initializer())


# #convolving with torch
# torchres = F.conv2d(x2, torch.tensor(weights_torch), padding=0, bias=torch.zeros((output_feat)).double())

# #comparing the results
# print(np.mean(np.transpose(sess.run(res), [0, 3, 1, 2])) - torch.mean(torchres).detach().numpy())

1.109498761309169e-06


In [97]:
# v = torch.tensor(np.load("params/generator_conv2d_kernel:0.npy")).permute(3,2,1,0)
# print(v.shape)
print(tf_bn_input.shape)

(1, 8, 8, 128)


In [96]:
print(tf_bn_input[0, 0, 0,:])
print("==="*22)
print(py_bn_input[0, 0, 0,:].cpu().detach().numpy())

[ 0.65031743 -0.6203244  -0.49183062  0.5059358   0.28351417 -1.2729391
 -0.0390036  -0.7468238   0.6516546  -1.204329    0.0734973  -0.02232832
  0.3725338  -0.08989799 -0.00615811  0.4278187  -0.09832315  0.29158783
  0.84003997  0.04280537 -1.0682315   0.1907023  -0.36410913  1.5561202
 -0.12220084  0.02108081  0.01915061  1.4213235   0.09183311 -0.36637288
 -0.23123804  0.5040692   0.97450817 -0.2969277  -0.94745934  1.1106435
 -0.07927477  0.82612944 -0.1617369   0.18703566 -0.6548972   0.8209766
 -0.31261256  0.35590196  0.09823444 -1.3237336   0.7239871   0.6121413
 -0.36086214  0.1626235  -0.26243168  0.5572176   1.4531343   0.5239057
  0.84014416  1.1397151   0.36312747  1.3075147  -1.1382031  -0.22748637
  0.5468823  -0.10433887 -1.0650768   0.06931782  0.47596106  1.3198633
 -0.42912716  0.73430645 -0.62745625 -1.3099842  -0.75498706  0.4060669
 -0.8465472  -0.22658643 -0.16979995  0.18600342  0.03421217  0.7452529
 -0.54868406 -0.33201602 -0.25047937  0.40531865  1.1708739 

In [95]:
print(tf_bn_input[0, 0, 1,:])
print("==="*22)
print(py_bn_input[0, 0, 1,:].cpu().detach().numpy())

[-0.46407175 -0.29227483  0.26989132  0.51021767  0.5166483  -0.05000585
 -0.08072945  0.04050548 -0.36849916 -0.19898066 -0.8010465  -0.23489217
 -0.4345651   0.6034463  -0.25511673  0.6966145   1.2504725  -0.08049522
 -0.26976836 -1.2811017  -0.5883877   0.0571833   0.12746903  0.8738511
  1.5349185  -0.15438713 -1.4112797   0.9779048   0.40282556 -1.927623
 -0.38846087  0.0151476   1.437269   -1.5116789  -0.3965437   1.2025797
 -0.09323318  0.6826073   0.8670521   0.35153592 -0.9071082   0.3013622
 -0.14124417  0.19807327 -0.26500982 -0.11583344  1.0241275   1.1238248
  0.7542276  -0.83686686 -0.95186037  0.658851    1.2842011   1.5170897
  0.4500775  -0.1398077   0.9256458   0.5818185  -0.44643986 -0.18722752
  0.84645057  0.55748576 -0.47237825 -0.14010607 -0.5915407   0.15492344
 -0.03601738  0.19088033 -0.27634692  0.77118975  0.25651816  0.18324804
  0.446023    0.40662053  0.07430738  0.03888976  0.00274904 -0.06280538
 -1.3808008   0.40904635 -1.8467833   0.7404522  -0.101253

In [75]:
# py_out = vec1.permute(0,2,3,1).contiguous()
# tf_out = np.load("bn_output.npy")
# # print(py_up.shape, tf_up.shape)
# # print(np.array_equal(tf_up, py_up.cpu().detach().numpy()))
# print(tf_out[0,:4,:4,:4])
# print("   \n" * 3)
# print(py_out[0,:4,:4,:4])
# print("======\n" * 3)