In [67]:
import torch
from keras import backend as K
import numpy as np
import pandas as pd
import torch.nn as nn

import torch.nn.functional as F

# INITIALISE THE MODEL AND THEN CONVERT IT INTO MULTI-GPU EXPLICITLY

In [68]:
# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 16

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

n_classes=10

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 50

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 200

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 2

#embedding dimension
embed_dim=100

cuda =True

In [63]:
device = torch.device("cuda:0" if cuda else "cpu")

In [66]:
device.type

'cuda'

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features,ngpu):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        if (device.type=='cuda'):
            x= x.to(device)
            output=torch.nn.data_parallel(self.conv_block,input=x,device_ids=range(ngpu))
            return(x + output)
        
        else:
            return(x+self.conv_block(x)_
        

In [58]:
class Feature_extraction_Model(nn.Module):
    def __init__(self):
        super(Feature_extraction_Model, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 24, 3, stride=2, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(24)
        self.conv2 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(24)

        
    def forward(self, img):
        """convolution"""
        
        if (device.type=='cuda'):
            img= img.to(device)
        
        x = self.conv1(img)
        x = F.relu(x)
        x = self.batchNorm1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.batchNorm2(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.batchNorm3(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.batchNorm4(x)
        return x

In [17]:
class G_theta(nn.Module):
    ''' Gives the sum of all the relations '''
    
    def __init__(self,input_channels=63):
        super(G_theta,self).__init__()
        
        self.g_fc1 = nn.Linear(input_channels, 256) 

        self.g_fc2 = nn.Linear(256, 256)
        self.g_fc3 = nn.Linear(256, 256)
        self.g_fc4 = nn.Linear(256, 256)
        
    def forward(self,pairs):
        d=int((pairs.size()[0]/batch_size)**0.25)
        
        if (device.type=='cuda'):
            pairs= pairs.to(device)

        x_ = pairs
        x_ = self.g_fc1(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc2(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc3(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc4(x_)
        x_ = F.relu(x_)
        print('final x_ size = {}'.format(x_.size()))
        x_g = x_.view(batch_size,d*d*d*d,256) 
        x_g = x_g.sum(1).squeeze()
        
        return(x_g)

In [50]:
def weights_init(m):
    classname = m.__class__.__name__
#     print('module={}'.format(m))
#     print("classname={}".format(classname))
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
#         print("weights initialised with normal distribution")
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
#         print("weights initialised with normal distribution and bias set to zero")

In [23]:
def tagging():
    d= conv.size()[2]
    tag= torch.zeros((2,d,d))
    print('tagging in process')
    for i in range(d):
        for j in range(d):    # Can speed up this loop by parallization
            tag[0,i,j] = float(int(i%d))/(d-1)*2-1
            tag[1,i,j] = float(int(j%d))/(d-1)*2-1
    
    return tag.view((conv.size()[0],2,d,d))


In [None]:
def Calculate_relations(nn.Module):
    def __init__(self,conv_map):
        super(Calculate_relations, self).__init__()
        
        
        g_theta=G_theta()

In [101]:
def get_pairs(tags,conv_map,condition_vector=[]):
    
    '''
    d: height of conv_map
    tags:torch tensor of size-(batchsize,2,d,d)
    condition_vector=torch tensor of size-(batchsize,its_dimension)
    
    returns: final tensor of size (batch*height*width*height*width,channels) which can be used 
    directly into G_theta
    '''

    x=torch.cat([conv_map,tags],dim=1)
    mb=x.size()[0]
    n_channels = x.size()[1]
    d = x.size()[2]
    x_flat = x.view(mb,n_channels,d*d).permute(0,2,1) # (64x25x24+2)
    
    
    
    if(not condition_vector is None):
        condition_vector = torch.unsqueeze(condition_vector, 1) #64x1xq
        
        condition_vector = condition_vector.repeat(1,d*d,1) #64x25xq
       
        condition_vector = torch.unsqueeze(condition_vector, 2) # 64x25x1xq
       
        
        # cast all pairs against each other
        x_i = torch.unsqueeze(x_flat,1) # (64x1x25x26)
        
        x_i = x_i.repeat(1,d*d,1,1) # (64x25x25x26)
        
        x_j = torch.unsqueeze(x_flat,2) # (64x25x1x26)
        
        x_j = torch.cat([x_j,condition_vector],3) # (64x25x1x26+11)
        
        x_j = x_j.repeat(1,1,d*d,1) # (64x25x25x26+11)
        
        
        # concatenate all together
        x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+11)
        
        
        # reshape for passing through network G_theta
        x_ = x_full.view(mb*d*d*d*d,x_full.size()[3])
        input_channels=x_full.size()[3]
        return(x_,input_channels,mb,d)
        
        
    else:
        # cast all pairs against each other
        x_i = torch.unsqueeze(x_flat,1) # (64x1x25x26)
        
        x_i = x_i.repeat(1,d*d,1,1) # (64x25x25x26)
        
        x_j = torch.unsqueeze(x_flat,2) # (64x25x1x26)
        
        x_j = x_j.repeat(1,1,d*d,1) # (64x25x25x26)
        
        # concatenate all together
        x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26)
        
        
        # reshape for passing through network G_theta
        x_ = x_full.view(mb*d*d*d*d,x_full.size()[3])
        input_channels=x_full.size()[3]
        
        return(x_,input_channels,mb,d)
        
        
    
    

    
    
    

In [73]:
conv_map = torch.rand((64,24,5,5))
tags=torch.rand((64,2,5,5))
condition_vector= torch.rand((64,11))

In [89]:
z,input_channels=get_pairs(conv_map,tags,condition_vector)

In [92]:
z.size()

torch.Size([40000, 63])

In [93]:
input_channels

63

In [None]:
def get_relations(x_):
    

In [122]:
g_theta=G_theta(input_channels=63)
out=g_theta(pairs=z)


final x_ size = torch.Size([40000, 256])


In [123]:
out.size()

torch.Size([64, 256])