In [1]:
import numpy as np
import torch
import torch.utils.data as Data
import torch.nn as nn
# from model import PointerNetwork

In [2]:
def to_cuda(x):
    if torch.cuda.is_available():
        return x.cuda()
    return x

class PointerNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, weight_size, is_GRU=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_size = weight_size
        self.is_GRU = is_GRU

        if self.is_GRU:
            RNN = nn.GRU
            RNNCell = nn.GRUCell
        else:
            RNN = nn.LSTM
            RNNCell = nn.LSTMCell

        self.encoder = RNN(input_size, hidden_size, batch_first=True)
        self.decoder = RNNCell(input_size, hidden_size)
        
        self.W1 = nn.Linear(hidden_size, weight_size, bias=False) 
        self.W2 = nn.Linear(hidden_size, weight_size, bias=False) 
        self.v1 = nn.Linear(weight_size, 1, bias=False)
        
        self.W3 = nn.Linear(hidden_size, weight_size, bias=False) 
        self.W4 = nn.Linear(hidden_size, weight_size, bias=False)
        self.W5 = nn.Linear(hidden_size, weight_size, bias=False)
        self.v2 = nn.Linear(weight_size, 1, bias=False)

    def forward(self, input):
        batch_size = input.shape[0]
        decoder_seq_len = input.shape[1]

        encoder_output, hc = self.encoder(input)
#         print('enc',encoder_output.shape)

        # Decoding states initialization
        hidden = encoder_output[:, -1, :] #hidden state for decoder is the last timestep's output of encoder 
        if not self.is_GRU: #For LSTM, cell state is the sencond state output
            cell = hc[1][-1, :, :]
        decoder_input = to_cuda(torch.rand(batch_size, self.input_size))  
        
        # Decoding with attention             
        probs = []
        encoder_output = encoder_output.transpose(1, 0) #Transpose the matrix for mm
        
        decoder_output = torch.empty(batch_size,1,self.hidden_size)
        for i in range(decoder_seq_len):  
            if self.is_GRU:
                hidden = self.decoder(decoder_input, hidden) 
            else:
                hidden, decoder_hc = self.decoder(decoder_input, (hidden, cell))
            
            if decoder_output.shape[1] == 1:
                decoder_output = hidden.unsqueeze(1)
            else:
                decoder_output = torch.cat((decoder_output,hidden.unsqueeze(1)),dim=1)
                
            # Computing Intra-attention
            sm = torch.tanh(self.W1(decoder_output.transpose(1, 0)) + self.W2(hidden))
            out = self.v1(sm)
            out = torch.log_softmax(out.transpose(0, 1).contiguous(), -1)
            hidden_intra = (out*decoder_output).sum(dim=1)
            
            # Computing attention
            sum = torch.tanh(self.W3(encoder_output) + self.W4(hidden) + self.W5(hidden_intra))
            out = self.v2(sum).squeeze()        
            out = torch.log_softmax(out.transpose(0, 1).contiguous(), -1)
            probs.append(out)

        probs = torch.stack(probs, dim=1)           
        return probs

In [3]:
EPOCH = 100
BATCH_SIZE = 250
DATA_SIZE = 10000
INPUT_SIZE = 1
HIDDEN_SIZE = 512
WEIGHT_SIZE = 256
LR = 0.001


def getdata(experiment=1, data_size=None):
    if experiment == 1:
        high = 100
        senlen = 5
        x = np.array([np.random.choice(range(high), senlen, replace=False)
                      for _ in range(data_size)])
        y = np.argsort(x)
    elif experiment == 2:
        high = 1000
        senlen = 10
        x = np.array([np.random.choice(range(high), senlen, replace=False)
                      for _ in range(data_size)])
        y = np.argsort(x)
    elif experiment == 3:
        senlen = 5
        x = np.array([np.random.random(senlen) for _ in range(data_size)])
        y = np.argsort(x)
    elif experiment == 4:
        senlen = 10
        x = np.array([np.random.random(senlen) for _ in range(data_size)])
        y = np.argsort(x)
    return x, y

def evaluate(model, X, Y):
    probs = model(X) 
    prob, indices = torch.max(probs, 2) 
    equal_cnt = sum([1 if torch.equal(index.detach(), y.detach()) else 0 for index, y in zip(indices, Y)])
    accuracy = equal_cnt/len(X)
    print('Acc: {:.2f}%'.format(accuracy*100))

#Get Dataset
x, y = getdata(experiment=2, data_size = DATA_SIZE)
x = to_cuda(torch.FloatTensor(x).unsqueeze(2))     
y = to_cuda(torch.LongTensor(y)) 
#Split Dataset
train_size = (int)(DATA_SIZE * 0.9)
train_X = x[:train_size]
train_Y = y[:train_size]
test_X = x[train_size:]
test_Y = y[train_size:]
#Build DataLoader
train_data = Data.TensorDataset(train_X, train_Y)
data_loader = Data.DataLoader(
    dataset = train_data,
    batch_size = BATCH_SIZE,
    shuffle = True,
)


#Define the Model
model = PointerNetwork(INPUT_SIZE, HIDDEN_SIZE, WEIGHT_SIZE, is_GRU=False)
if torch.cuda.is_available():
    model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fun = torch.nn.CrossEntropyLoss()


#Training...
print('Training... ')
for epoch in range(EPOCH):
    for (batch_x, batch_y) in data_loader:
        probs = model(batch_x)         
        outputs = probs.view(-1, batch_x.shape[1])
        batch_y = batch_y.view(-1) 
        loss = loss_fun(outputs, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 2 == 0:
        print('Epoch: {}, Loss: {:.5f}'.format(epoch, loss.item()))
        indx = torch.randperm(train_X.shape[0])[:100]
        evaluate(model, train_X[indx], train_Y[indx])
#Test...    
print('Test...')
evaluate(model, test_X, test_Y)

Training... 
Epoch: 0, Loss: 1.43681
Acc: 0.00%
Epoch: 2, Loss: 0.96951
Acc: 0.00%
Epoch: 4, Loss: 0.74508
Acc: 3.00%
Epoch: 6, Loss: 0.64369
Acc: 10.00%
Epoch: 8, Loss: 0.49641
Acc: 24.00%
Epoch: 10, Loss: 0.48426
Acc: 17.00%
Epoch: 12, Loss: 0.56842
Acc: 19.00%
Epoch: 14, Loss: 0.40975
Acc: 19.00%
Epoch: 16, Loss: 0.47306
Acc: 15.00%
Epoch: 18, Loss: 0.31396
Acc: 39.00%
Epoch: 20, Loss: 0.29361
Acc: 32.00%
Epoch: 22, Loss: 1.64447
Acc: 0.00%
Epoch: 24, Loss: 0.40685
Acc: 23.00%
Epoch: 26, Loss: 0.33556
Acc: 41.00%
Epoch: 28, Loss: 0.31446
Acc: 26.00%
Epoch: 30, Loss: 0.29212
Acc: 35.00%
Epoch: 32, Loss: 0.31778
Acc: 34.00%
Epoch: 34, Loss: 0.25916
Acc: 43.00%
Epoch: 36, Loss: 0.25660
Acc: 42.00%
Epoch: 38, Loss: 0.31484
Acc: 24.00%
Epoch: 40, Loss: 0.28383
Acc: 37.00%
Epoch: 42, Loss: 0.21819
Acc: 53.00%
Epoch: 44, Loss: 0.21600
Acc: 56.00%
Epoch: 46, Loss: 0.20875
Acc: 56.00%


KeyboardInterrupt: 

In [4]:
evaluate(model, test_X[500:700], test_Y[500:700])

Acc: 62.00%


In [12]:
sample =torch.randint(100,(30,))
sample

tensor([92, 26,  8, 22, 81, 51, 81, 49, 38, 28, 74, 77, 83, 55, 86,  7, 64, 54,
        44, 13, 18, 48,  8,  7, 63, 14, 62, 72, 59, 71])

In [13]:
sample = to_cuda(sample.view(2,15,1).type(torch.FloatTensor))

In [14]:
probs = model(sample)
prob, indices = torch.max(probs, 2)

In [15]:
sample

tensor([[[92.],
         [26.],
         [ 8.],
         [22.],
         [81.],
         [51.],
         [81.],
         [49.],
         [38.],
         [28.],
         [74.],
         [77.],
         [83.],
         [55.],
         [86.]],

        [[ 7.],
         [64.],
         [54.],
         [44.],
         [13.],
         [18.],
         [48.],
         [ 8.],
         [ 7.],
         [63.],
         [14.],
         [62.],
         [72.],
         [59.],
         [71.]]], device='cuda:0')

In [16]:
indices

tensor([[ 2,  3,  9,  8,  5, 10, 10,  6, 14,  0,  0,  2,  2,  2,  2],
        [ 8,  0,  7,  5,  3,  6, 13, 11,  1, 12,  8,  8,  8,  8,  8]],
       device='cuda:0')

In [17]:
sample[1][indices[1]]

tensor([[ 7.],
        [ 7.],
        [ 8.],
        [18.],
        [44.],
        [48.],
        [59.],
        [62.],
        [64.],
        [72.],
        [ 7.],
        [ 7.],
        [ 7.],
        [ 7.],
        [ 7.]], device='cuda:0')

In [18]:
sorted(sample[1])

[tensor([7.], device='cuda:0'),
 tensor([7.], device='cuda:0'),
 tensor([8.], device='cuda:0'),
 tensor([13.], device='cuda:0'),
 tensor([14.], device='cuda:0'),
 tensor([18.], device='cuda:0'),
 tensor([44.], device='cuda:0'),
 tensor([48.], device='cuda:0'),
 tensor([54.], device='cuda:0'),
 tensor([59.], device='cuda:0'),
 tensor([62.], device='cuda:0'),
 tensor([63.], device='cuda:0'),
 tensor([64.], device='cuda:0'),
 tensor([71.], device='cuda:0'),
 tensor([72.], device='cuda:0')]