In [1]:
import torch
import torch.nn as nn
from torch import optim
import numpy as np
import matplotlib.pyplot as plt

In [2]:
#          01234567
symbols = "EBabcdXY"
symbols_onehot = np.array([[1,0,0,0,0,0,0,0],
                           [0,1,0,0,0,0,0,0],
                           [0,0,1,0,0,0,0,0],
                           [0,0,0,1,0,0,0,0],
                           [0,0,0,0,1,0,0,0],
                           [0,0,0,0,0,1,0,0],
                           [0,0,0,0,0,0,1,0],
                           [0,0,0,0,0,0,0,1]
                          ])

#              01234567
classlabels = 'QRSUVABC'

classlabels_onehot = np.array([[1,0,0,0,0,0,0,0],
                               [0,1,0,0,0,0,0,0],
                               [0,0,1,0,0,0,0,0],
                               [0,0,0,1,0,0,0,0],
                               [0,0,0,0,1,0,0,0],
                               [0,0,0,0,0,1,0,0],
                               [0,0,0,0,0,0,1,0],
                               [0,0,0,0,0,0,0,1]
                              ])

#                            Q       R       S       U       V       A       B      C
classidx2rule = np.array([[6,6,6],[6,6,7],[6,7,6],[6,7,7],[7,6,6],[7,6,7],[7,7,6],[7,7,7]
                         ])

In [3]:
def generate_sequence():
    seq_length = np.random.choice(range(100, 110))
    t1 = np.random.choice(range(10,21))
    t2 = np.random.choice(range(33,44))
    t3 = np.random.choice(range(66,76)) 
    targetclassidx = np.random.choice(range(0,8)) #randomly choose a target class
    
    tagetclass_onehot = classlabels_onehot[targetclassidx]
    
    seq = np.zeros((seq_length,1),dtype="int")
    seq[0] = 0 #first char is E
    seq[-1] = 1 #last char is B
    
    #randomly asaign abcd to the rest of the positions
    for i in range(1,seq_length):
        seq[i] = np.random.choice([2,3,4,5])

    # insert X,Y values based on class
    seq[t1], seq[t2], seq[t3] = classidx2rule[targetclassidx]
    
    #generate onehot for sequence
    seq_onehot = np.zeros((seq_length,8))
    for idx in range(seq_length):
        seq_onehot[idx] = symbols_onehot[seq[idx]]
        
    return seq_length, seq, seq_onehot, targetclassidx, tagetclass_onehot

In [4]:
class Mylstm(nn.Module):
    
    def __init__(self):
        super(Mylstm, self).__init__()
        
        self.lstm1 = nn.LSTM(input_size = 8, hidden_size = 2)
        self.lstm2 = nn.LSTM(input_size = 2, hidden_size = 4)
        self.lstm3 = nn.LSTM(input_size = 4, hidden_size = 8)       
        self.linear = nn.Linear(in_features=8, out_features=8)
        
    def forward(self, input):
       
        lstm_out1, (self.h1, self.c1) = self.lstm1(input,(self.h1, self.c1))
        lstm_out2, (self.h2, self.c2) = self.lstm2(self.h1,(self.h2, self.c2))
        lstm_out3, (self.h3, self.c3) = self.lstm3(self.h2,(self.h3, self.c3))
           
        #pred_vec = lstm_out3[-1]
        pred_vec = self.linear(lstm_out3[-1])
            
        return pred_vec
    
    def reset_hidden_states(self):
        (self.h1, self.c1) = (torch.zeros(1, 1, 2), torch.zeros(1, 1, 2))
        (self.h2, self.c2) = (torch.zeros(1, 1, 4), torch.zeros(1, 1, 4))
        (self.h3, self.c3) = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))
    

In [5]:
def train_model(model):
        
    loss_fn  = nn.CrossEntropyLoss()
    
    optimiser = optim.Adam(model.parameters(), lr = 0.1)
    
    count = 0
    for i in range(1000000):
        
        model.reset_hidden_states()
        
        with torch.no_grad():
            seq_length, seq, seq_onehot, targetclassidx, targetclass_onehot = generate_sequence()
        
            seq_onehot = torch.from_numpy(seq_onehot).float()
        
            seq_onehot = seq_onehot.view([seq_length,1,8])
            
            targetclassidx = torch.tensor([targetclassidx])
              
        
        pred = model(seq_onehot)       
               
        loss = loss_fn(pred, targetclassidx)
        
        
        with torch.no_grad():
            predvalue, predclassidx = torch.max(pred,-1)
                       
            if predclassidx == targetclassidx:
                count = count+1
                
            
            if i % 5000 == 0:
                print("------------------------------------------------------")
                print(i,loss, pred, predvalue, predclassidx, targetclassidx, count)
                count = 0
                    
        optimiser.zero_grad()
        
        loss.backward()
        
        optimiser.step()
        
    return model.eval()

model1 = Mylstm()
model1 = train_model(model1)

------------------------------------------------------
0 tensor(2.2155, grad_fn=<NllLossBackward>) tensor([[ 0.3342,  0.1350, -0.1180, -0.0754,  0.1484, -0.3705, -0.1703, -0.1468]],
       grad_fn=<AddmmBackward>) tensor([0.3342]) tensor([0]) tensor([7]) 0
------------------------------------------------------
5000 tensor(1.9009, grad_fn=<NllLossBackward>) tensor([[-0.7544, -1.0631, -0.3845, -0.8196, -1.0744,  0.2158, -0.5402, -0.3275]],
       grad_fn=<AddmmBackward>) tensor([0.2158]) tensor([5]) tensor([7]) 638
------------------------------------------------------
10000 tensor(1.6744, grad_fn=<NllLossBackward>) tensor([[-0.5463, -1.4807, -0.1209, -0.5776, -0.5334, -0.9159, -0.7002, -0.0089]],
       grad_fn=<AddmmBackward>) tensor([-0.0089]) tensor([7]) tensor([2]) 625
------------------------------------------------------
15000 tensor(1.8432, grad_fn=<NllLossBackward>) tensor([[-0.8450, -1.2946, -0.5807, -0.3312, -0.3303, -0.3103, -0.5295, -0.6585]],
       grad_fn=<AddmmBackward>)

KeyboardInterrupt: 