In [1]:
%matplotlib inline

import torch.nn as nn
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np
import math
from matplotlib import pyplot as plt
from graphviz import Digraph
import os

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fdd28a3c090>

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"
global py_conv_vec
global act_vec
global dense_output
global upsampled

In [3]:
def upsample(x):
    return nn.functional.interpolate(x, scale_factor=2, mode="nearest")

def downsample(x):
    return nn.functional.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)

def conv_layer(in_filters, out_filters=32, kernel_size=3, he_init=True):
    same_padding = (kernel_size-1)//2
    conv = nn.Conv2d(in_filters, out_filters, kernel_size=kernel_size, padding=same_padding)
    
    if he_init:
        he_init_constant = math.sqrt(6 / (in_filters * kernel_size**2))
        nn.init.uniform_(conv.weight, -he_init_constant, he_init_constant)
    else:
        xavier_init_constant = math.sqrt(6 / ((in_filters + out_filters) * kernel_size**2))
        nn.init.uniform_(conv.weight, -xavier_init_constant, xavier_init_constant)
    nn.init.constant_(conv.bias, 0)
    
    return conv

def bn(channels):
    batchnorm = nn.BatchNorm2d(channels, eps=1e-5)
    nn.init.constant_(batchnorm.weight, 1)
    return batchnorm

def linear(in_features, out_features):
    linear_layer = nn.Linear(in_features, out_features)
    
    xavier_init_constant = math.sqrt(6/(in_features+out_features))
    nn.init.uniform_(linear_layer.weight, -xavier_init_constant, xavier_init_constant)
    nn.init.constant_(linear_layer.bias, 0)
    
    return linear_layer

In [4]:
class ResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, resample=None, normalize=False, activation=None):
        super(ResBlock, self).__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.resample = resample
        self.normalize = normalize
        
        self.conv1 = conv_layer(in_filters, out_filters)
        self.conv2 = conv_layer(out_filters, out_filters)
        
        if resample:
            self.conv3 = conv_layer(in_filters, out_filters, kernel_size=1, he_init=False)
        
        if normalize:
            self.bn1 = bn(in_filters)
            self.bn2 = bn(out_filters)
            
        if activation is not None:
            self.activation = activation
        else:
            self.activation = nn.ReLU()
                
    def forward(self, x): 
        orig_input = x
        
#         print("ACTUAL BATCHNORM")
#         print(x)
#         print(x.size())
        
        
        if self.normalize:
            x = self.bn1(x)
            
        
#         print("ACTUAL BATCHNORM RESULT")
        
#         print(x)
#         print(x.size())
#         raise ValueError("hi")
        
        x = self.activation(x)
        
        print("Activated")
        print(x[0, :2, :2,:2])
        global act_vec
        act_vec = x
        
        if self.resample == 'up':
            x = upsample(x)
            
        print("Upsampled")
        print(x[0, :2, :2,:2])
        global upsampled
        upsampled = x   
       
        x = self.conv1(x)
        
        print("AFTER conv1")
        print(x[0, :2, :2,:2])
        global py_conv_vec
        py_conv_vec = x
        
        print("Conv weights")
        print(self.conv1.weight.data[0, :3, :3, :3])
        print("------")
        print(self.conv1.bias.data)
        
        
        raise ValueError("hi")

        
        if self.normalize:
            x = self.bn2(x)
            
        x = self.activation(x)
        x = self.conv2(x)
        
        if self.resample == 'down':
            x = downsample(x)
        
        # Shortcut
        if self.resample == 'down': 
            shortcut_x = downsample(self.conv3(orig_input))
        elif self.resample == 'up':
            shortcut_x = self.conv3(upsample(orig_input))
        elif self.resample == None:
            shortcut_x = orig_input
        
        return x + shortcut_x
    
class SmallResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, activation=None):
        super(SmallResBlock, self).__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        
        self.conv1 = conv_layer(in_filters, out_filters)
        self.conv2 = conv_layer(out_filters, out_filters)
        self.conv3 = conv_layer(in_filters, out_filters, kernel_size=1, he_init=False)
            
        if activation is not None:
            self.activation = activation
        else:
            self.activation = nn.ReLU()
                
    def forward(self, x): 
        orig_input = x
       
        x = self.conv1(x)   
        x = downsample(self.conv2(self.activation(x)))

        # Shortcut
        shortcut_x = self.conv3(downsample(orig_input))
        
        return x + shortcut_x 

In [5]:
class Generator(nn.Module):
    def __init__(self, input_size, num_filters=128, num_blocks=3, start_image_size=4, num_channels=3):
        super(Generator, self).__init__()
        
        self.num_filters = num_filters
        self.start_image_size = start_image_size
        
        self.first_linear = linear(input_size, num_filters * start_image_size ** 2)
        self.resblocks = nn.ModuleList()
        self.activation = nn.ReLU()
        
        for _ in range(num_blocks):
            self.resblocks.append(
                ResBlock(in_filters=self.num_filters, 
                         out_filters=self.num_filters, 
                         resample='up', 
                         normalize=True))
            
        self.last_layer = conv_layer(num_filters, num_channels)
        self.bn = bn(num_filters)
        self.manually_initialize()
    
    def forward(self, noise):
        
        print("Should be all ones")
        print(noise)
        print(noise.size())
        
        x = self.first_linear(noise)
        
        print("Literally just dense")
        
#         print("A1")
        
        print(x)
        print(x.size())
        
#         x = x.view(-1, self.num_filters, self.start_image_size, self.start_image_size)
        x = x.view(-1, self.start_image_size, self.start_image_size, self.num_filters)
        print("first view reshape", x[0, :3, :3, :3])
        x = x.permute(0, 3, 1, 2)
        print("second view reshape", x[0, :3, :3, :3])
        
        
        global dense_output
        dense_output = x
        
        print("After Dense")
        print(x.permute(0,2,3,1))
        print(x.size())
        
        for resblock in self.resblocks:
            x = resblock(x)
#             print("B")
#             print(x)
#             print(x.size())
            
        x = self.activation(self.bn(x))
        result = self.last_layer(x)
        return torch.tanh(result)
    
    def manually_initialize(self):
        print(self.first_linear.weight.data[0,0:3])
        self.first_linear.weight.data = torch.tensor(np.load("params/generator_dense_kernel:0.npy")).transpose(0,1)
        self.first_linear.bias.data = torch.tensor(np.load("params/generator_dense_bias:0.npy"))
        print(self.first_linear.weight.data[0:3,0])
        
        self.resblocks[0].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_gamma:0.npy"))
        self.resblocks[0].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_beta:0.npy"))
        
        self.resblocks[0].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_1_gamma:0.npy"))
        self.resblocks[0].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_1_beta:0.npy"))
        
        self.resblocks[1].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_2_gamma:0.npy"))
        self.resblocks[1].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_2_beta:0.npy"))
        
        self.resblocks[1].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_3_gamma:0.npy"))
        self.resblocks[1].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_3_beta:0.npy"))
        
        self.resblocks[2].bn1.weight.data = torch.tensor(np.load("params/generator_BatchNorm_4_gamma:0.npy"))
        self.resblocks[2].bn1.bias.data = torch.tensor(np.load("params/generator_BatchNorm_4_beta:0.npy"))
        
        self.resblocks[2].bn2.weight.data = torch.tensor(np.load("params/generator_BatchNorm_5_gamma:0.npy"))
        self.resblocks[2].bn2.bias.data = torch.tensor(np.load("params/generator_BatchNorm_5_beta:0.npy"))
        
        self.bn.weight.data = torch.tensor(np.load("params/generator_BatchNorm_6_gamma:0.npy"))
        self.bn.bias.data = torch.tensor(np.load("params/generator_BatchNorm_6_beta:0.npy"))
        
        # Insert "contiguous" to make conv layers work correctly.  Not strictly necessary when using the gpu,
        # but will hopefully help us avoid headaches later
        self.resblocks[0].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_bias:0.npy"))
        
        self.resblocks[0].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_1_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_1_bias:0.npy"))
        
        self.resblocks[0].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_2_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[0].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_2_bias:0.npy"))
        
        self.resblocks[1].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_3_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_3_bias:0.npy"))
        
        self.resblocks[1].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_4_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_4_bias:0.npy"))
        
        self.resblocks[1].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_5_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[1].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_5_bias:0.npy"))
        
        self.resblocks[2].conv1.weight.data = torch.tensor(np.load("params/generator_conv2d_6_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv1.bias.data = torch.tensor(np.load("params/generator_conv2d_6_bias:0.npy"))
        
        self.resblocks[2].conv2.weight.data = torch.tensor(np.load("params/generator_conv2d_7_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv2.bias.data = torch.tensor(np.load("params/generator_conv2d_7_bias:0.npy"))
        
        self.resblocks[2].conv3.weight.data = torch.tensor(np.load("params/generator_conv2d_8_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.resblocks[2].conv3.bias.data = torch.tensor(np.load("params/generator_conv2d_8_bias:0.npy"))
        
        self.last_layer.weight.data = torch.tensor(np.load("params/generator_conv2d_9_kernel:0.npy")).permute(3,2,1,0).contiguous()
        self.last_layer.bias.data = torch.tensor(np.load("params/generator_conv2d_9_bias:0.npy"))

In [6]:
class MLPLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLPLayer, self).__init__()
        self.block = nn.Sequential(linear(input_size, output_size), nn.ReLU())
        
    def forward(self, x):
        return self.block(x)

class VectorDiscriminator(nn.Module):
    def __init__(self, input_size, hidden_size, num_blocks):
        self.blocks = nn.ModuleList()
        self.blocks.append(MLPLayer(input_size, hidden_size))
        
        for _ in range(num_blocks - 1):
            self.blocks.append(hidden_size, hidden_size)
        
        self.blocks.append(linear(hidden_size, 1))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        
        return x

In [7]:
class ImageDiscriminator(nn.Module):
    def __init__(self, num_filters=128, num_blocks=4, num_channels=3):
        super(ImageDiscriminator, self).__init__()
        assert num_blocks >= 2, "Number of conv layers in the discriminator must be >= 2."
        
        self.resblocks = nn.ModuleList()
        self.activation = nn.ReLU()
        
        self.resblocks.append(SmallResBlock(in_filters=num_channels, 
                                            out_filters=num_filters))
        self.resblocks.append(ResBlock(in_filters=num_filters, 
                                       out_filters=num_filters, 
                                       resample='down'))
        for _ in range(num_blocks - 2):
            self.resblocks.append(ResBlock(in_filters=num_filters, 
                                           out_filters=num_filters))
            
        self.last_linear = linear(num_filters, 1)
        self.manually_initialize()
   
    def forward(self, x):
        for resblock in self.resblocks:
            x = resblock(x)
            
        x = self.activation(x)
        x = x.mean(dim=(-1,-2))
        x = self.last_linear(x)
        return x
    
    def manually_initialize(self):
        print('hi')
#         self.last_linear.weight.data = torch.tensor(np.load("params/discriminator_dense_kernel:0.npy")).transpose(0,1)
#         self.last_linear.bias.data = torch.tensor(np.load("params/discriminator_dense_bias:0.npy"))
        
#         self.resblocks[0].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_bias:0.npy"))
        
#         self.resblocks[0].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_1_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_1_bias:0.npy"))
        
#         self.resblocks[0].conv3.weight.data = torch.tensor(np.load("params/discriminator_conv2d_2_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[0].conv3.bias.data = torch.tensor(np.load("params/discriminator_conv2d_2_bias:0.npy"))
        
#         self.resblocks[1].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_3_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_3_bias:0.npy"))
        
#         self.resblocks[1].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_4_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_4_bias:0.npy"))
        
#         self.resblocks[1].conv3.weight.data = torch.tensor(np.load("params/discriminator_conv2d_5_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[1].conv3.bias.data = torch.tensor(np.load("params/discriminator_conv2d_5_bias:0.npy"))
        
#         self.resblocks[2].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_6_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[2].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_6_bias:0.npy"))
        
#         self.resblocks[2].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_7_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[2].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_7_bias:0.npy"))
        
#         self.resblocks[3].conv1.weight.data = torch.tensor(np.load("params/discriminator_conv2d_8_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[3].conv1.bias.data = torch.tensor(np.load("params/discriminator_conv2d_8_bias:0.npy"))
        
#         self.resblocks[3].conv2.weight.data = torch.tensor(np.load("params/discriminator_conv2d_9_kernel:0.npy")).permute(3,2,1,0)
#         self.resblocks[3].conv2.bias.data = torch.tensor(np.load("params/discriminator_conv2d_9_bias:0.npy"))

In [8]:
def identity_embedding(pic):
    return pic

class GAN(nn.Module):
    def __init__(self, gamma, noise_size=128, num_filters=128, num_generator_blocks=3, num_discriminator_blocks=4,
                 batch_size=64, num_channels=3, discriminator_epsilon=1e-5, banach=True, 
                 embedding_func=identity_embedding, embedding_size=None, discriminator_hidden_size=None,
                 discriminator_type="Image"):  
        super(GAN, self).__init__()
        self.generator = Generator(noise_size, num_filters=num_filters, num_blocks=num_generator_blocks,
                                   start_image_size=4, num_channels=num_channels)
        
        if discriminator_type == "Image":
            self.discriminator = ImageDiscriminator(num_filters=num_filters, num_blocks=num_discriminator_blocks, 
                                                    num_channels=num_channels)
        elif discriminator_type == "Vector":
            self.discriminator = VectorDiscriminator(input_size=embedding_size, 
                                                     hidden_size=discriminator_hidden_size,
                                                     num_blocks=num_discriminator_blocks)
        else:
            raise ValueError("Discriminator type not recognized.")
        self.discriminator_epsilon = discriminator_epsilon
        
        # Assumption that the dual space is the same as the original space.
        self.gamma = gamma
        self.lambda_penalty = gamma
        
        self.banach = banach # Boolean parameter that decides on whether to do a banach/metric space based wgan.
        self.embedding_func = embedding_func # The embedding function used.
        
        self.register_buffer("penalty_grad_outputs", torch.ones(batch_size))
        self.register_buffer("noise_buffer", torch.ones((batch_size, noise_size)))
        self.register_buffer("epsilon_buffer", torch.ones(batch_size, 1, 1, 1))
        
    def forward_train_generator(self, noise=None):
        generated_image = self.forward_predict_generator(noise)
        discriminator_score_generated = self.forward_predict_discriminator(generated_image)
        return self.generator_loss(discriminator_score_generated)
    
    def forward_train_discriminator(self, real_images, noise=None):
        generated_images = self.forward_predict_generator(noise)
        discriminator_score_generated = self.forward_predict_discriminator(generated_images)
        discriminator_score_real = self.forward_predict_discriminator(real_images)
        return self.discriminator_loss(discriminator_score_real, discriminator_score_generated, real_images, generated_images)
        
    def forward_predict_generator(self, noise=None):
        if noise is None:
            noise = self.generate_noise()
        return self.generator(noise)
        
    def forward_predict_discriminator(self, image):
        return self.discriminator(image)
    
    def generate_noise(self):
        return torch.randn_like(self.noise_buffer)
    
    def generator_loss(self, d_score_generated):
#         print("Generator loss", torch.mean(d_generated_train) / self.gamma)
        return torch.mean(d_score_generated) / self.gamma #NOTE: Mehdi's version had a negative sign, original was positive
    
    def stable_norm(self, x):
        x = x.view(x.size(0), -1)
        alpha, _ = (x.abs() + 1e-5).max(1)
        
        return alpha * (x/alpha.unsqueeze(1)).norm(p=2, dim=1)
    
    def discriminator_loss(self, d_score_real, d_score_generated, real_images, generated_images):
#         print("d score generated", torch.mean(d_score_generated))
#         print("d score real", torch.mean(d_score_real))
        wasserstein_loss = (torch.mean(d_score_generated) - torch.mean(d_score_real)) / self.gamma
#         print("wass loss", wasserstein_loss)
        epsilon = self.epsilon_buffer.uniform_(0, 1)
        real_fake_mix = epsilon * generated_images + (1 - epsilon) * real_images 
        d_score_mix = self.discriminator(real_fake_mix).squeeze(1)
        
        gradients = torch.autograd.grad(d_score_mix, real_fake_mix, grad_outputs=self.penalty_grad_outputs,
                                        create_graph=True)[0]
        
#         print(gradients)
        gradient_penalty = torch.mean(self.stable_norm(gradients) / gamma - 1) ** 2
        print("grad penalty", float(self.lambda_penalty * gradient_penalty))
        d_regularizer_mean = torch.mean(d_score_real ** 2)
#         print("regularizer mean", float(self.discriminator_epsilon * d_regularizer_mean))
#         
        #NOTE: Mehdi's version had the wassestein loss positive, original seems to be negative
#         print("w loss is", -float(wasserstein_loss))
        d_loss = -wasserstein_loss + self.lambda_penalty * gradient_penalty + self.discriminator_epsilon * d_regularizer_mean
#         print("Overall d_loss", d_loss.item())
        return d_loss


In [9]:
# There is one discrepancy in this training code and the bwgan github implementation. That
# version uses an exponential moving average of the weights during evaluation.
#
# One other discrepancy with bwgan is the lack of usage of warm restarts for SGD. That's mentioned in
# the paper but I could not see in the implementation.
#
# The last main discrepancy is related to the model. In the bwgan code gamma is computed each batch. Here
# we compute gamma over the dataset instead. The difference should be very minor as gamma's value across batches
# is pretty stable. For MNIST gamma appeared to range from 29.8-30.1 from looking at a dozen gamma values.
def gan_train(model, dset_loader, optimizers, lr_schedulers, num_updates=1e5,
              use_cuda=False, num_discriminator=5):
    steps_so_far = 0
    curr_epoch = 0
    
    discriminator_optimizer, generator_optimizer = optimizers
    discriminator_lr_scheduler, generator_lr_scheduler = lr_schedulers
    
    display_generator = False
    
    while True:
        print('Epoch {} - Step {}/{}'.format(curr_epoch, steps_so_far, num_updates))
        print('-' * 10)

        # Iterate over data.
        for data, _ in dset_loader:
            
            print(steps_so_far)
            if steps_so_far % 1000 == 0:
                print(steps_so_far)
            
            if steps_so_far >= num_updates:
                return model
            
            if use_cuda:
                data = data.cuda()
            
            loss = model.forward_train_discriminator(data)
#             register_hooks(loss)
            
            if steps_so_far % 100 == 4:
                print("Discriminator Loss: ", loss.item())
                display_generator = True
            
            loss.backward()

            discriminator_optimizer.step()
            # zero the parameter gradients
            discriminator_optimizer.zero_grad()
                
                
            if steps_so_far % num_discriminator == num_discriminator-1:
                loss = model.forward_train_generator()
                loss.backward()
                generator_optimizer.step()
                
                # zero the parameter gradients
                generator_optimizer.zero_grad()

                discriminator_lr_scheduler.step()
                generator_lr_scheduler.step()
                if display_generator:
                    print("Generator Loss: ", loss.item())
                    display_generator = False
                    
                    
            if steps_so_far % 500 == 0:
                generated_images = model.forward_predict_generator()
                first_image = generated_images[0, 0].cpu().detach().numpy()
                min_val = float(np.amin(first_image))
                max_val = float(np.amax(first_image))
                plt.title("Generated image, range {} to {}".format(round(min_val, 3), round(max_val, 3)))
                plt.imshow(first_image)
                plt.show()
            
            steps_so_far += 1
        
        curr_epoch += 1

In [10]:
def compute_gamma(dset_loader):
    num_images = len(dset_loader.dataset)
    gamma = 0.0
    
    for data, _ in dset_loader:
        batch_size = data.size()[0]
        gamma += data.cuda().view(batch_size, -1).norm(2, dim=1).sum().item() / num_images
    
    return gamma

In [11]:
def iter_graph(root, callback):
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)

def register_hooks(var):
    def is_bad_grad(grad_output):
        if grad_output is None:
            return False
        
        grad_output = grad_output.data
        return grad_output.ne(grad_output).any() or grad_output.gt(1e4).any()
    
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            for grad in grad_output:
                if is_bad_grad(grad):
                    print(fn)
#                     raise ValueError("Hi")
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)

In [12]:
noise_size = 128
batch_size = 1
use_cuda = True
base_lr = 2e-4
num_updates = int(1e2)
num_discriminator = 5
num_channels = 3

In [13]:
def get_data(name, train=True):
    assert name in ["mnist", "cifar", "celeba"]
    transform = transforms.Compose([transforms.Resize(32),
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])
    if name == "mnist":
        return datasets.MNIST("mnist", train=train, download=True, transform=transform)
    if name == "cifar":
        return datasets.CIFAR10("cifar", train=train, transform=transform, download=True)
    if name == "celeba":
        if train == True:
            dset_str = "train"
        else:
            dset_str = "test"
        with open("celeba_64_bgr_-1_to_1_%s.pkl" % dset_str, "rb") as f:
            dataset = pickle.load(f)
        return dataset

In [14]:
# train_dataset = get_data("cifar")
# train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
#                                    num_workers=4, pin_memory=True, drop_last=True)
gamma = 27 #compute_gamma(train_dataloader)

In [15]:
model = GAN(gamma=gamma, noise_size=noise_size, batch_size=batch_size, num_channels=num_channels)
model.train()
y = torch.ones(1,128)
output1 = model.forward_predict_generator(y)
# output2 = model.forward_predict_discriminator(output1)

print(output1.size())
print(output1.permute(0,2,3,1)[0])

# print(output2)

# if use_cuda:
#     model = model.cuda()
#     model.discriminator = nn.DataParallel(model.discriminator)
#     model.generator = nn.DataParallel(model.generator)

# discriminator_optimizer = optim.Adam(model.discriminator.parameters(), betas=(0, 0.9), lr=base_lr)
# generator_optimizer = optim.Adam(model.generator.parameters(), betas=(0, 0.9), lr=base_lr)
# discriminator_lr_scheduler = optim.lr_scheduler.LambdaLR(discriminator_optimizer, lambda step: max(0, (1 - step/num_updates)))
# generator_lr_scheduler = optim.lr_scheduler.LambdaLR(generator_optimizer, lambda step: max(0, (1 - step/num_updates)))
# optimizers = discriminator_optimizer, generator_optimizer
# lr_schedulers = discriminator_lr_scheduler, generator_lr_scheduler

# import time

# start_time = time.time()
# model.train()
# model = gan_train(model, train_dataloader, optimizers, lr_schedulers, num_updates=num_updates, 
#                   use_cuda=use_cuda, num_discriminator=num_discriminator)
# print(time.time() - start_time)

tensor([-0.0374, -0.0036,  0.0143])
tensor([-0.0185,  0.0220,  0.0163])
hi
Should be all ones
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1.]])
torch.Size([1, 128])
Literally just dense
tensor([[-0.2824,  0.1309, -0.1149,  ..., -0.3147, -0.2160, -0.0382]],
       grad_fn=<AddmmBackward>)
torch.Size([1, 2048])
first view reshape tensor([[[-0.2824,  0.1309, -0.1149],
         [ 0.1462,  0.3888, -0.3026],
         [ 0.3701,  0.1530,  0.4670]],

        [[-0.4724,

ValueError: hi

In [None]:
conv_vec = py_conv_vec

In [None]:
py_c1 = model.generator.resblocks[0].conv1.weight
tf_c1 = np.load("conv1.npy")
print(py_c1.shape)
print(tf_c1.shape)
py_reversed = py_c1.permute(3,2,1,0)
print(py_reversed[0, :3, :3, :3])
print("-" * 20)
print(tf_c1[0, :3, :3, :3])
print("==" * 20)
orig = np.load("params/generator_conv2d_kernel:0.npy")
print(np.around(orig[0, :3, :3, :3], 4))


In [None]:
tf_conv_vec = np.load("conv_vec.npy")
py_conv_vec = conv_vec
py_conv_vec = py_conv_vec.permute(0, 2, 3, 1)

# tf_up_vec = np.load("up_vec.npy")
# py_up_vec = upsampled
# py_up_vec = py_up_vec.permute(0, 2, 3, 1)

print(tf_conv_vec[0, :4, :4, :4])
print("="*10)
print(py_conv_vec[0, :4, :4, :4])
print(tf_conv_vec.shape)


# print(py_act_vec[0, :4, :4, :4])
# print("="*10)
# print(py_up_vec[0, :4, :4, :4])

In [90]:
# py_up_vec = upsampled
in_channels = 1
kernel_size = 3
out_channels = 2
py_up_vec = torch.ones((1, in_channels, 4, 4))
tf_up_vec = py_up_vec.permute(0, 2, 3, 1).cpu().detach().numpy()
print(py_up_vec.shape)

torch.Size([1, 1, 4, 4])


### Pytorch

In [254]:

same_padding = (kernel_size-1)//2
conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=same_padding)

# conv.weight.data = torch.tensor(np.load("params/generator_conv2d_kernel:0.npy")).permute(2,3,0,1)#.permute(3,2,1,0)
# conv.bias.data = torch.tensor(np.load("params/generator_conv2d_bias:0.npy"))
conv.weight.data = torch.tensor(tf_weight).permute(3,2,0,1)
conv.bias.data = torch.tensor(tf_bias)
# print("FIRST", torch.tensor(tf_weight).permute(3,2,0,1)[0, :2, :2, :2])
# print("SECOND", torch.tensor(tf_weight).permute(3,2,1,0)[0, :2, :2, :2])
print(conv.weight.data)
print("==" * 30)

print(conv.weight.data.shape)
conv.bias.data = torch.tensor(tf_bias)


output = conv(py_up_vec)
output = output.permute(0, 2, 3, 1)
print(output)
print(output.shape)


# print("====" * 30)

# conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=same_padding)

# # conv.weight.data = torch.tensor(np.load("params/generator_conv2d_kernel:0.npy")).permute(2,3,0,1)#.permute(3,2,1,0)
# # conv.bias.data = torch.tensor(np.load("params/generator_conv2d_bias:0.npy"))
# conv.weight.data = torch.tensor(tf_weight).permute(3,2,1,0)
# conv.bias.data = torch.tensor(tf_bias)
# print(conv.weight.data)
# output2 = conv(py_up_vec)
# output2 = output2.permute(0, 2, 3, 1)
# print(output2)
# print(output2.shape)


tensor([[[[-0.5335, -0.7245,  0.7922],
          [-0.6585, -0.2911, -0.0033],
          [ 0.0554,  0.5888, -0.7971]]],


        [[[ 0.1115, -0.2301,  0.1445],
          [ 0.3641,  0.5915, -0.6733],
          [ 0.8153,  0.2231, -0.0334]]]])
torch.Size([2, 1, 3, 3])
tensor([[[[ 1.0098,  0.5736],
          [ 0.1211,  0.8521],
          [ 0.1211,  0.8521],
          [ 0.2676,  0.2967]],

         [[ 0.3968, -0.1031],
          [-1.0253,  0.7669],
          [-1.0253,  0.7669],
          [-0.1543,  0.8849]],

         [[ 0.3968, -0.1031],
          [-1.0253,  0.7669],
          [-1.0253,  0.7669],
          [-0.1543,  0.8849]],

         [[ 0.3238,  0.7274],
          [-0.4398,  1.3743],
          [-0.4398,  1.3743],
          [ 0.1401,  1.4589]]]], grad_fn=<PermuteBackward>)
torch.Size([1, 4, 4, 2])


In [172]:
py_up_vec2 = torch.zeros((1, in_channels, 4, 4))
py_up_vec2[0,0,0,0] = 1

output2 = conv(py_up_vec2)
# output2 = output2.permute(0, 2, 3, 1)
print(output2.shape)
# print(output2[0, 0, :, :])
# print(output2[0, 1, :, :])
print(output2)

torch.Size([1, 2, 4, 4])
tensor([[[[ 0.7922, -0.2301,  0.0000,  0.0000],
          [ 0.1115, -0.5335,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.8153,  0.0554,  0.0000,  0.0000],
          [-0.0033,  0.5915,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]],
       grad_fn=<MkldnnConvolutionBackward>)


In [173]:
print(conv.weight.data.shape)
print("==")
conv.weight.data[:,:,:,:]

torch.Size([2, 1, 3, 3])
==


tensor([[[[-0.5335, -0.7245,  0.7922],
          [-0.6585, -0.2911, -0.0033],
          [ 0.0554,  0.5888, -0.7971]]],


        [[[ 0.1115, -0.2301,  0.1445],
          [ 0.3641,  0.5915, -0.6733],
          [ 0.8153,  0.2231, -0.0334]]]])

In [190]:
print(in_channels, out_channels, kernel_size, same_padding)
print(1,2,3,1)
conv_double = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0)
conv_double.weight.data = torch.tensor(tf_weight).permute(3,2,0,1)
conv_double.bias.data = torch.tensor(tf_bias)

print("INPUT---------")
py_up_vec3 = torch.zeros((1, in_channels, 6, 6))
py_up_vec3[0,0,0,0] = 1
print(py_up_vec3[0,0])


output3 = conv_double(py_up_vec3)
print("OUTPUT")
print(output3)

print("WAT?")
print(output2)
# print(output3 / 2)

print("---" * 10)
print(np.array_equal(conv.weight.data, conv_double.weight.data))
print(np.array_equal(conv.bias.data, conv_double.bias.data))

1 2 3 1
1 2 3 1
INPUT---------
tensor([[1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
OUTPUT
tensor([[[[-0.5335,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5915,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]],
       grad_fn=<MkldnnConvolutionBackward>)
WAT?
tensor([[[[ 0.7922, -0.2301,  0.0000,  0.0000],
          [ 0.1115, -0.5335,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.8153,  0.0554,  0.0000,  0.0000],
          [-0.0033,  0.5915,  0.0000,  0.0000],
          [ 0.0000,  0.0000, 

In [128]:
print(conv.weight.data.shape)
first_weights = conv.weight.data[0,0]
print(first_weights)
print(torch.sum(first_weights))
print(torch.sum(conv.weight.data[1,0]))

torch.Size([2, 1, 3, 3])
tensor([[-0.5335, -0.7245,  0.7922],
        [-0.6585, -0.2911, -0.0033],
        [ 0.0554,  0.5888, -0.7971]])
tensor(-1.5716)
tensor(1.3133)


In [200]:
print(conv.weight.data.shape)
print(conv.weight.data[0, 7, 1, 1])

torch.Size([2, 1, 3, 3])


IndexError: index 7 is out of bounds for dimension 0 with size 1

In [196]:
arr = conv.weight.data.cpu().detach().numpy()
print(arr)
print(arr[0,2,1,1])

[[[[-0.5334671  -0.7244778   0.7921703 ]
   [-0.65852064 -0.29105836 -0.00331324]
   [ 0.05537641  0.58875823 -0.7970913 ]]]


 [[[ 0.11154282 -0.23014873  0.14453804]
   [ 0.3641212   0.5914984  -0.67334676]
   [ 0.81529963  0.22313368 -0.03336781]]]]


IndexError: index 2 is out of bounds for axis 1 with size 1

In [197]:
conv.weight.data.shape

torch.Size([2, 1, 3, 3])

### TF

In [186]:
sess.close()

In [91]:
import tensorflow as tf
sess = tf.InteractiveSession()
tf_arr = tf.convert_to_tensor(tf_up_vec, dtype=tf.float32)
initializer = tf.contrib.layers.variance_scaling_initializer(uniform=True)
result = tf.layers.conv2d(tf_arr, filters=out_channels, kernel_size=kernel_size,
                            padding='SAME', kernel_initializer=initializer)
sess.run([tf.global_variables_initializer(),
          tf.local_variables_initializer()])
np_result = result.eval()
# print([v.name for v in tf.trainable_variables() ])
vs = tf.trainable_variables()
tf_weight = sess.run(vs[-2])
tf_bias = sess.run(vs[-1])
print(np_result[0, :2, :2, :2])
print(np_result.shape)
sess.close()




[[[-0.5027047   0.10791749]
  [-1.1058489   1.2873383 ]]

 [[-0.43501222  0.0223068 ]
  [-1.5716236   1.3132703 ]]]
(1, 4, 4, 2)


In [176]:
print(np_result[0,:,:,0])
print("=========")
print(tf_weight.shape)
print(tf_weight[:,:,:,0])

# # CONCLUSION: EIHTER3210 OR 2310

[[-0.5027047  -1.1058489  -1.1058489  -0.30544436]
 [-0.43501222 -1.5716236  -1.5716236  -1.5633893 ]
 [-0.43501222 -1.5716236  -1.5716236  -1.5633893 ]
 [-0.22667915 -1.418667   -1.418667   -2.2075238 ]]
(3, 3, 1, 2)
[[[-0.5334671 ]
  [-0.7244778 ]
  [ 0.7921703 ]]

 [[-0.65852064]
  [-0.29105836]
  [-0.00331324]]

 [[ 0.05537641]
  [ 0.58875823]
  [-0.7970913 ]]]


In [97]:
last_conv = tf_weight[:,:,:,0]
print(last_conv.shape)
np.sum(last_conv[-2:, -2:, 0])

(3, 3, 1)


-0.5027047

In [113]:
print(conv.weight.data)

tensor([[[[-0.5335, -0.6585,  0.0554],
          [-0.7245, -0.2911,  0.5888],
          [ 0.7922, -0.0033, -0.7971]]],


        [[[ 0.1115,  0.3641,  0.8153],
          [-0.2301,  0.5915,  0.2231],
          [ 0.1445, -0.6733, -0.0334]]]])


In [182]:
out_channels

2

In [187]:
import tensorflow as tf
sess = tf.InteractiveSession()
tf_arr = tf.convert_to_tensor(py_up_vec2.permute(0,2,3,1).cpu().detach().numpy(), dtype=tf.float32)
initializer = tf.contrib.layers.variance_scaling_initializer(uniform=True)
result = tf.layers.conv2d(tf_arr, filters=out_channels, kernel_size=kernel_size,
                            padding='SAME', kernel_initializer=initializer)
sess.run([tf.global_variables_initializer(),
          tf.local_variables_initializer()])
np_result2 = result.eval()
# print([v.name for v in tf.trainable_variables() ])
vs = tf.trainable_variables()
tf_weight2 = sess.run(vs[-2])
# tf_bias = sess.run(vs[-1])
print(np_result2)
print(np_result2.shape)
print("============")
print(tf_weight2)
print(tf_weight2.shape)
sess.close()

[[[[-0.1492607  -0.80596215]
   [ 0.6515763   0.31801212]
   [ 0.          0.        ]
   [ 0.          0.        ]]

  [[-0.11954015  0.41798925]
   [ 0.02802968 -0.3155235 ]
   [ 0.          0.        ]
   [ 0.          0.        ]]

  [[ 0.          0.        ]
   [ 0.          0.        ]
   [ 0.          0.        ]
   [ 0.          0.        ]]

  [[ 0.          0.        ]
   [ 0.          0.        ]
   [ 0.          0.        ]
   [ 0.          0.        ]]]]
(1, 4, 4, 2)
[[[[ 0.02802968 -0.3155235 ]]

  [[-0.11954015  0.41798925]]

  [[-0.78344625 -0.05473381]]]


 [[[ 0.6515763   0.31801212]]

  [[-0.1492607  -0.80596215]]

  [[-0.3441969  -0.47510624]]]


 [[[ 0.4044745  -0.5900563 ]]

  [[-0.4387843   0.6062478 ]]

  [[ 0.45067    -0.1505813 ]]]]
(3, 3, 1, 2)




In [188]:
print(np_result2[0,:,:,0])
print("---------------")
print(tf_weight2[:,:,0,0])



[[-0.1492607   0.6515763   0.          0.        ]
 [-0.11954015  0.02802968  0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
---------------
[[ 0.02802968 -0.11954015 -0.78344625]
 [ 0.6515763  -0.1492607  -0.3441969 ]
 [ 0.4044745  -0.4387843   0.45067   ]]


In [204]:
print(torch.__version__)

1.0.0


In [202]:
test_indexing = torch.tensor([[1,2,3],[4,5,6]])
print(test_indexing[7,0])

IndexError: index 7 is out of bounds for dimension 0 with size 2

## SO

In [226]:
in_channels = 1
kernel_size = 3
out_channels = 2
padding = 0
img_size = 4

conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
conv.weight.data = torch.FloatTensor(range(2,20)).view(out_channels, in_channels, kernel_size, kernel_size)
print("Weight Data")
print(conv.weight.data)
conv.bias.data = torch.zeros(out_channels)
print("#" * 30)

input_tensor = torch.zeros((1, in_channels, img_size, img_size))
input_tensor[0,0,0,0] = 1

print('output')
output = conv(input_tensor)
print(output.shape)
print(output)

print("===" * 10)
padding = 1
conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
conv2.weight.data = torch.FloatTensor(range(2,20)).view(out_channels, in_channels, kernel_size, kernel_size)
conv2.bias.data = torch.zeros(out_channels)
output2 = conv2(input_tensor)
print(output2.shape)
print(output2)

Weight Data
tensor([[[[ 2.,  3.,  4.],
          [ 5.,  6.,  7.],
          [ 8.,  9., 10.]]],


        [[[11., 12., 13.],
          [14., 15., 16.],
          [17., 18., 19.]]]])
##############################
output
torch.Size([1, 2, 2, 2])
tensor([[[[ 2.,  0.],
          [ 0.,  0.]],

         [[11.,  0.],
          [ 0.,  0.]]]], grad_fn=<MkldnnConvolutionBackward>)
torch.Size([1, 2, 4, 4])
tensor([[[[ 6.,  5.,  0.,  0.],
          [ 3.,  2.,  0.,  0.],
          [ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.]],

         [[15., 14.,  0.,  0.],
          [12., 11.,  0.,  0.],
          [ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.]]]], grad_fn=<MkldnnConvolutionBackward>)


## IMPORTANT

In [271]:
in_channels = 1
kernel_size = 3
out_channels = 2
padding = 1
img_size = 4


conv_real = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
conv_real.weight.data = torch.tensor(np.load("conv_weight.npy"))
conv_real.bias.data = torch.tensor(np.load("conv_bias.npy"))
conv_real = conv_real.cuda()

# print("INPUT---------")
py_up_vec3 = torch.zeros((1, in_channels, 4, 4))
py_up_vec3[0,0,0,0] = 1
py_up_vec3 = py_up_vec3.cuda()
# print(py_up_vec3[0,0])


output3 = conv_real(py_up_vec3)
print("OUTPUT------------")
print(output3)

print("==" * 30)

conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
conv.weight.data = torch.tensor(np.load("tf_weight.npy")).permute(3,2,0,1) #BAD
# conv.weight.data = torch.tensor(tf_weight).permute(3,2,0,1) # BAD
# conv.weight.data = torch.tensor(np.load("conv_weight.npy")) #GOOD
# conv.weight.data = torch.tensor(np.load("tf_weight.npy")).permute(3,2,0,1).clone() #GOOD
# conv.weight.data = torch.tensor(np.load("tf_weight.npy")).permute(3,2,0,1).contiguous()
conv.bias.data = torch.tensor(tf_bias)
conv = conv.cuda()


output2 = conv(py_up_vec3)
# print(output2.shape)
print(output2)

print("======= verifications ===========")

print(np.array_equal(conv.weight.data, conv_real.weight.data))
print(np.array_equal(conv.bias.data, conv_real.bias.data))
print(conv.in_channels == conv_real.in_channels)
print(conv.out_channels == conv_real.out_channels)
print(conv.padding == conv_real.padding)
print(conv.kernel_size == conv_real.kernel_size)





# print("WEIGHT=================")
# print(conv_real.weight.data)

OUTPUT------------
tensor([[[[-0.2911, -0.6585,  0.0000,  0.0000],
          [-0.7245, -0.5335,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5915,  0.3641,  0.0000,  0.0000],
          [-0.2301,  0.1115,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<CudnnConvolutionBackward>)
tensor([[[[-0.2911, -0.6585,  0.0000,  0.0000],
          [-0.7245, -0.5335,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5915,  0.3641,  0.0000,  0.0000],
          [-0.2301,  0.1115,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<CudnnConvolutionBackward>)
False
False
True
True
True
True


In [260]:
conv.weight.data = torch.tensor(tf_weight).permute(3,2,0,1)
np.save("tf_weight.npy", tf_weight)

In [241]:
py_up_vec2 = torch.zeros((1, in_channels, 4, 4))
py_up_vec2[0,0,0,0] = 1


print(conv.weight.data.shape)
conv.weight.data

torch.Size([1, 2, 4, 4])
tensor([[[[-0.2911, -0.6585,  0.0000,  0.0000],
          [-0.7245, -0.5335,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5915,  0.3641,  0.0000,  0.0000],
          [-0.2301,  0.1115,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]]]],
       grad_fn=<MkldnnConvolutionBackward>)
torch.Size([2, 1, 3, 3])


tensor([[[[-0.5335, -0.7245,  0.7922],
          [-0.6585, -0.2911, -0.0033],
          [ 0.0554,  0.5888, -0.7971]]],


        [[[ 0.1115, -0.2301,  0.1445],
          [ 0.3641,  0.5915, -0.6733],
          [ 0.8153,  0.2231, -0.0334]]]])