In [None]:
import numpy as np
import sys
sys.path.append('../external/Transformer_modules/')
sys.path.append('../src/')
import torch, torch.nn as nn
import torch.nn.functional as F
from modules import MultiHeadAttention, PositionwiseFeedForward
import mnist

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class GlobalAveragePooling(nn.Module):
    """Module that averages its input over a single dimension.

    Args:
        dim: dimension passed to ``Tensor.mean``; it is removed from the
            output. Defaults to the last dimension (-1).
    """

    def __init__(self, dim=-1):
        # Name the class explicitly: super(self.__class__, self) resolves to
        # the *subclass* when this class is inherited from, which makes
        # __init__ call itself forever.
        super(GlobalAveragePooling, self).__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` averaged over ``self.dim``."""
        return x.mean(dim=self.dim)

In [None]:
class Discriminator(nn.Module):
    """Attention-based classifier.

    Pipeline: linear projection + ReLU -> multi-head attention ->
    position-wise feed-forward -> mean over dimension 1 -> linear layer
    producing ``n_classes`` scores.

    Args:
        in_dim: size of the last input dimension (input width of ``fc1``).
        hidden_dim: output width of ``fc1``; also passed as ``d_model`` to
            the attention block and as the first argument of the
            feed-forward block.
        ffn_dim: second argument of ``PositionwiseFeedForward``
            (presumably its inner width -- TODO confirm in the module).
        n_head: number of attention heads.
        n_classes: number of output scores. Defaults to 10, the value that
            used to be hard-coded.
    """

    def __init__(self, in_dim, hidden_dim=100, ffn_dim=200, n_head=8, n_classes=10):
        super(Discriminator, self).__init__()

        # Input projection: Xavier-normal weights, zero bias.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)

        # NOTE(review): the defaults hidden_dim=100 / n_head=8 do not divide
        # evenly; confirm MultiHeadAttention does not require
        # d_model % n_head == 0.
        self.mha_1 = MultiHeadAttention(n_head=n_head, d_model=hidden_dim)
        self.ffn_1 = PositionwiseFeedForward(hidden_dim, ffn_dim, use_residual=False)

        # Collapses dimension 1 by averaging.
        # NOTE(review): assumes inputs are (batch, points, features) so this
        # pools over the points of each sample -- confirm against the data.
        self.gl_1 = GlobalAveragePooling(dim=1)

        # Output head: same initialisation scheme as fc1.
        self.fc2 = nn.Linear(hidden_dim, n_classes)
        nn.init.xavier_normal_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        h1 = F.relu(self.fc1(x))
        h2 = self.mha_1(h1)
        h3 = self.ffn_1(h2)
        pooled = self.gl_1(h3)
        return self.fc2(pooled)
        

In [None]:
# Second argument of mnist.make_clouds, shared by the train and validation
# splits so the two cannot drift apart.
# NOTE(review): presumably the number of points sampled per digit --
# confirm in src/mnist.py.
CLOUD_SIZE = 500

x_train = mnist.make_clouds(mnist.x_train, CLOUD_SIZE)
y_train = mnist.y_train
x_val = mnist.make_clouds(mnist.x_val, CLOUD_SIZE)
y_val = mnist.y_val

# Use the first GPU when one is present, otherwise fall back to the CPU
# instead of failing on a hard-coded .cuda(0).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# in_dim=2 is the size of the last input dimension.
# NOTE(review): presumably (x, y) point coordinates -- confirm.
model = Discriminator(2).to(device)

In [None]:
# Bind the `mnist_test` attribute of the project-local `mnist` module.
# NOTE(review): presumably the held-out test data; none of the cells
# visible here read `x_test` -- confirm it is still needed.
x_test = mnist.mnist_test

In [None]:
def compute_loss(X_batch, y_batch):
    """Return the mean cross-entropy of the global ``model`` on one batch.

    Args:
        X_batch: array-like of inputs; converted to a float32 tensor.
        y_batch: array-like of integer class labels; converted to int64.

    Returns:
        A scalar tensor attached to the autograd graph, so the caller can
        invoke ``.backward()`` on it.
    """
    # Place the batch on whatever device the model's weights live on rather
    # than assuming GPU 0.
    device = next(model.parameters()).device
    inputs = torch.as_tensor(X_batch, dtype=torch.float32, device=device)
    targets = torch.as_tensor(y_batch, dtype=torch.long, device=device)

    logits = model(inputs)
    # cross_entropy already averages over the batch by default, so no extra
    # .mean() is needed.
    return F.cross_entropy(logits, targets)

def iterate_minibatches(X, y, batchsize):
    """Yield ``(X, y)`` minibatches in a freshly shuffled order.

    One random permutation of the sample indices is drawn from NumPy's
    global RNG, then consumed in consecutive slices of ``batchsize``; the
    final slice may be shorter. ``X`` and ``y`` are indexed with the same
    index array, so they must support NumPy-style fancy indexing.
    """
    n_samples = len(X)
    shuffled = np.random.permutation(np.arange(n_samples))
    for lo in range(0, n_samples, batchsize):
        hi = lo + batchsize
        picked = shuffled[lo:hi]
        yield X[picked], y[picked]
# Adam over all of the model's parameters. Only the parameters are passed,
# so the optimizer runs with its default hyperparameters.
opt = torch.optim.Adam(model.parameters())

In [None]:
import time

num_epochs = 150  # total amount of full passes over training data
batch_size = 200
train_loss = []    # one float per training minibatch, across all epochs
val_accuracy = []  # one per-batch accuracy per validation minibatch, across all epochs

# Validation batches are sent to wherever the model's weights live.
model_device = next(model.parameters()).device

for epoch in range(num_epochs):
    start_time = time.time()

    # ---- training pass ----
    model.train(True)
    # Remember where this epoch's losses start so the summary below averages
    # exactly the batches of this epoch.
    epoch_first_loss = len(train_loss)
    for X_batch, y_batch in iterate_minibatches(x_train, y_train, batchsize=batch_size):
        # Clear gradients *before* backward so anything left over from an
        # interrupted earlier run of this cell cannot leak into the step.
        opt.zero_grad()
        loss = compute_loss(X_batch, y_batch)
        loss.backward()
        opt.step()
        train_loss.append(loss.item())
        del loss

    # ---- validation pass ----
    model.eval()  # disable dropout / use averages for batch_norm
    n_correct = 0
    n_seen = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for X_batch, y_batch in iterate_minibatches(x_val, y_val, batch_size):
            inputs = torch.as_tensor(X_batch, dtype=torch.float32, device=model_device)
            logits = model(inputs)
            y_pred = logits.max(1)[1].cpu().numpy()
            hits = (y_batch == y_pred)
            val_accuracy.append(np.mean(hits))
            n_correct += int(np.sum(hits))
            n_seen += len(y_batch)
            del logits

    # ---- epoch summary ----
    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss (in-iteration): \t{:.6f}".format(
        np.mean(train_loss[epoch_first_loss:])))
    # Correct / seen weights every sample equally, unlike a mean of
    # per-batch means, which over-weights a short final batch.
    print("  validation accuracy: \t\t\t{:.2f} %".format(
        100.0 * n_correct / max(n_seen, 1)))