In [None]:
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('../external/Transformer_modules/')
sys.path.append('../src/')
from modules import MultiHeadAttention, PositionwiseFeedForward
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import mnist

In [None]:
# Re-import edited local modules (e.g. `mnist` from ../src/ and `modules`) automatically
# before each cell executes, so changes to them do not require a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Convert the MNIST train/validation images into point clouds; labels pass through
# unchanged. The second argument to make_clouds (500) is presumably the number of
# points per cloud — TODO confirm in src/mnist.py.
x_train, y_train = mnist.make_clouds(mnist.x_train, 500), mnist.y_train
x_val, y_val = mnist.make_clouds(mnist.x_val, 500), mnist.y_val

In [None]:
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class GlobalAveragePooling(nn.Module):
    """Average the input over one dimension; that dimension is removed.

    Args:
        dim: dimension to average over (default: the last one).
    """
    def __init__(self, dim=-1):
        # Name the class explicitly: super(self.__class__, self) recurses forever
        # as soon as this module is subclassed.
        super(GlobalAveragePooling, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.mean(dim=self.dim)

In [None]:
class Discriminator(nn.Module):
    """Score a point cloud as fake/real using two self-attention blocks.

    Pipeline: fc1 -> ReLU -> (MHA -> FFN) block 1 -> (MHA -> FFN) block 2
    -> mean over dim 1 -> fc2 (2 outputs).

    Args:
        in_dim: number of features per input element (last dim of ``x``).
        hidden_dim: width used by the attention and feed-forward blocks.
        ffn_dim: second (hidden) size passed to each PositionwiseFeedForward.
        n_head: number of attention heads passed to each MultiHeadAttention.

    NOTE(review): ``x`` is presumably (batch, n_points, in_dim), so pooling over
    dim 1 averages over points; also assumes MultiHeadAttention accepts a single
    tensor (self-attention) — confirm in ../external/Transformer_modules.
    """
    def __init__(self, in_dim, hidden_dim=100, ffn_dim=200, n_head=8):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)

        self.mha_1 = MultiHeadAttention(n_head=n_head, d_model=hidden_dim)
        self.ffn_1 = PositionwiseFeedForward(hidden_dim, ffn_dim, use_residual=False)
        self.mha_2 = MultiHeadAttention(n_head=n_head, d_model=hidden_dim)
        self.ffn_2 = PositionwiseFeedForward(hidden_dim, ffn_dim, use_residual=False)

        self.gl_1 = GlobalAveragePooling(dim=1)

        self.fc2 = nn.Linear(hidden_dim, 2)
        nn.init.xavier_normal_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)

    def forward(self, x):
        """Return unnormalised 2-way scores (logits) for ``x``."""
        h1 = F.relu(self.fc1(x))
        h2 = self.mha_1(h1)
        h3 = self.ffn_1(h2)
        # Block 2 must use its own layers; reusing mha_1/ffn_1 here ties both
        # blocks' weights and leaves mha_2/ffn_2 unused (they never get gradients).
        h4 = self.mha_2(h3)
        h5 = self.ffn_2(h4)
        return self.fc2(self.gl_1(h5))
        

In [None]:
class Generator(nn.Module):
    """Map an input cloud to an output cloud with the same per-element feature size.

    Pipeline: fc1 -> ReLU -> (MHA -> FFN) block 1 -> (MHA -> FFN) block 2
    -> fc2 (back to ``in_dim``) -> BatchNorm1d over the feature axis.

    Args:
        in_dim: number of features per input/output element (last dim of ``x``).
        hidden_dim: width used by the attention and feed-forward blocks.
        ffn_dim: second (hidden) size passed to each PositionwiseFeedForward.
        n_head: number of attention heads passed to each MultiHeadAttention.

    NOTE(review): ``x`` is presumably channel-last (batch, n_points, in_dim), like
    the Discriminator input; assumes MultiHeadAttention accepts a single tensor.
    """
    def __init__(self, in_dim, hidden_dim=100, ffn_dim=200, n_head=8):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)

        self.mha_1 = MultiHeadAttention(n_head=n_head, d_model=hidden_dim)
        self.ffn_1 = PositionwiseFeedForward(hidden_dim, ffn_dim, use_residual=False)
        self.mha_2 = MultiHeadAttention(n_head=n_head, d_model=hidden_dim)
        self.ffn_2 = PositionwiseFeedForward(hidden_dim, ffn_dim, use_residual=False)

        self.fc2 = nn.Linear(hidden_dim, in_dim)
        nn.init.xavier_normal_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)
        self.bn = nn.BatchNorm1d(in_dim)

    def forward(self, x):
        """Return the generated cloud, same leading shape as ``x``, last dim ``in_dim``."""
        h1 = F.relu(self.fc1(x))
        h2 = self.mha_1(h1)
        h3 = self.ffn_1(h2)
        # Block 2 must use its own layers; reusing mha_1/ffn_1 here ties both
        # blocks' weights and leaves mha_2/ffn_2 unused (they never get gradients).
        h4 = self.mha_2(h3)
        h5 = self.ffn_2(h4)
        out = self.fc2(h5)
        if out.dim() == 3:
            # BatchNorm1d normalises dim 1 of (N, C, L); the features are on the
            # last axis here, so move them to dim 1 and back.
            return self.bn(out.transpose(1, 2)).transpose(1, 2)
        return self.bn(out)

In [None]:
# Hyper-parameters shared by the discriminator and the generator.
in_dim = 2        # features per point (the clouds are plotted as x/y pairs)
hidden_dim = 100
ffn_dim = 200
n_head = 8
# NOTE(review): hidden_dim (100) is not divisible by n_head (8); confirm that
# MultiHeadAttention in ../external/Transformer_modules tolerates this.

# Pass every hyper-parameter explicitly so editing the values above takes effect.
disc = Discriminator(in_dim, hidden_dim, ffn_dim=ffn_dim, n_head=n_head).cuda(1)

In [None]:
# The generator reuses in_dim / hidden_dim / ffn_dim / n_head from the
# discriminator cell above (previously re-declared here with identical values).
gen = Generator(in_dim, hidden_dim, ffn_dim=ffn_dim, n_head=n_head).cuda(1)

In [None]:

# A single training cloud (index 2), sliced as [2:3] to keep a batch axis of 1.
target = Variable(torch.FloatTensor(x_train[2:3])).cuda(1)

# Uniform [0, 1) noise with exactly the target's shape, on the same GPU.
noise = torch.rand(target.shape).cuda(1)

# Learnable point cloud that the later cells plot and optimise directly with
# `cloud_opt`. clone() makes it an independent leaf tensor, so updating it never
# modifies `noise`.
my_cloud = noise.clone().requires_grad_(True)

In [None]:
# Adam (lr 1e-4) for the generator network, plain SGD (lr 1e-3) for the discriminator.
# NOTE(review): gen_opt is not used by any working cell in this notebook.
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=1e-4)
disc_opt = torch.optim.SGD(params=disc.parameters(), lr=1e-3)

In [None]:
# NOTE(review): this cell looks pasted from a GAN-trainer class method. It uses
# names that are not defined anywhere in this notebook (`self`, `num_epochs`,
# `iterate_minibatches`, `data`, `batch_size`, `inform`, `k_d`, `k_g`,
# `sample_noise`, `d_loss`, `g_loss`, `TASK`, `d_optimizer`, `g_optimizer`), so it
# raises NameError on a fresh kernel. TODO: bind them to this notebook's
# gen/disc/gen_opt/disc_opt, or delete the cell.
# NOTE(review): `.cuda()` targets the default GPU, whereas the models above were
# moved to `cuda(1)` — confirm the intended device.


def _make_noise(input_data, info):
    """Sample one generator-input batch of size len(input_data).

    Draws ``sample_noise(len(input_data))`` and, when ``info`` is not None,
    concatenates ``info`` to it along dim 1 before moving it to the GPU.
    """
    z = torch.Tensor(sample_noise(len(input_data)))
    if info is not None:
        z = torch.cat((z, torch.Tensor(info)), 1)
    return Variable(z.cuda())


try:
    for epoch in range(num_epochs):
        ls_g = []
        ls_d = []
        for input_data, info in iterate_minibatches(data, batch_size, inform):

            # Optimize D
            for _ in range(k_d):
                noise = _make_noise(input_data, info)
                inp_data = Variable(torch.Tensor(input_data).cuda())
                data_gen = self.generator(noise)

                loss = d_loss(self.discriminator(data_gen, TASK=TASK),
                              self.discriminator(inp_data, TASK=TASK), TASK)
                # .item() works for 0-d (scalar) losses; .numpy()[0] raises IndexError there.
                ls_d.append(loss.item())
                d_optimizer.zero_grad()
                loss.backward()
                d_optimizer.step()
                if TASK == 3:
                    # NOTE(review): nn.Module.apply takes only `fn`; the TASK keyword
                    # raises TypeError unless `apply` is overridden — confirm.
                    self.discriminator.apply(self.clipper, TASK=TASK)

            # Optimize G
            for _ in range(k_g):
                noise = _make_noise(input_data, info)
                data_gen = self.generator(noise)
                if info is not None:
                    # NOTE(review): `info` is added to the generated sample only in the
                    # G step, not in the D step — confirm this asymmetry is intended.
                    fake_input = data_gen + Variable(torch.Tensor(info).cuda(), requires_grad=False)
                else:
                    fake_input = data_gen
                loss = g_loss(self.discriminator(fake_input, TASK=TASK), TASK)
                ls_g.append(loss.item())
                g_optimizer.zero_grad()
                loss.backward()
                g_optimizer.step()
        if epoch % 10 == 0:
            print('generator_loss:', np.mean(ls_g), 'discriminator_loss', np.mean(ls_d))
except KeyboardInterrupt:
    # Ctrl-C stops training early but keeps whatever the models have learned so far.
    print('Training interrupted by user.')

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.subplot(1,2,1)
_x, _y = target.data.cpu().numpy()[0].T
plt.scatter(_x, -_y)
plt.subplot(1,2,2)
_x, _y = my_cloud[:1].data.cpu().numpy()[0].T
plt.scatter(_x, -_y)

In [None]:
# Adam (lr 1e-3) over the point coordinates of `my_cloud` themselves, and a fresh
# SGD (lr 1e-3) optimiser for the discriminator, replacing the earlier `disc_opt`.
cloud_opt = torch.optim.Adam(params=[my_cloud], lr=1e-3)
disc_opt = torch.optim.SGD(params=disc.parameters(), lr=1e-3)

In [None]:
from IPython.display import clear_output
from tqdm import trange
IS_FAKE, IS_REAL = 0, 1

# Adversarially fit `my_cloud` to `target`: no generator network is involved here;
# the point coordinates of `my_cloud` are optimised directly by `cloud_opt`.
N_EPOCHS = 1000     # outer iterations, one plot refresh each
D_STEPS = 100       # discriminator updates per outer iteration
CLOUD_STEPS = 10    # point-cloud updates per outer iteration


def _plot_clouds():
    """Redraw target (left) and current `my_cloud` (right) with fixed axis limits."""
    for position, cloud in ((1, target), (2, my_cloud[:1])):
        plt.subplot(1, 2, position)
        xs, ys = cloud.data.cpu().numpy()[0].T
        plt.scatter(xs, -ys)
        plt.ylim(-1, 0)
        plt.xlim(0, 1)
    plt.show()


for epoch_i in trange(N_EPOCHS):
    for _ in range(D_STEPS):
        # Cross-entropy: target labelled real, my_cloud labelled fake. my_cloud is
        # detached because this step only updates the discriminator.
        loss_disc = - F.log_softmax(disc(target), 1)[:, IS_REAL].mean() \
                    - F.log_softmax(disc(my_cloud.detach()), 1)[:, IS_FAKE].mean()

        disc_opt.zero_grad()
        loss_disc.backward()
        disc_opt.step()

    for _ in range(CLOUD_STEPS):
        # Move the cloud's points so the discriminator labels it as real.
        loss_gen = - F.log_softmax(disc(my_cloud), 1)[:, IS_REAL].mean()
        cloud_opt.zero_grad()
        loss_gen.backward()
        cloud_opt.step()

    clear_output(True)
    _plot_clouds()