In [164]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pandas as pd

from torch import tensor

from generate_data import generate_data

# onehot = pd.Series(data = [tensor([1,0,0]),tensor([0,1,0]),tensor([0,0,1])], index = ['a','b','c'])
onehot = pd.Series(data = [[1,0,0],[0,1,0],[0,0,1]], index = ['a','b','c'])

min_length = 20

In [165]:
batch_size = 32

lr = 0.01

loss_fn = nn.CrossEntropyLoss()

In [179]:
N = 4*int(1e4)


X, y = generate_data(N, is_short=True, type = 2)

X_test, y_test = generate_data(2000, is_short=True, type = 2)



In [180]:
def encode_letter(l):
    return np.array([tensor(onehot[d].values) for d in l])

letterToIndex = lambda l: ord(l) - 97

def encode_language(l):
    tensor = torch.zeros(min_length, 1, 3)

    for li, letter in enumerate(l):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

def onehot_encode(data):
    return [encode_language(l) for l in data]

def onehot_labels(labels):
    tensor = torch.zeros(len(labels), 2)

    for i, l in enumerate(labels):
        tensor[i][l] = 1
    return tensor


def train_loader(data, labels, batch_size):
    shuffle = np.random.permutation(len(data))

    _data = data[shuffle]
    _labels = labels[shuffle]
    for i in range(len(_data)//batch_size):
        enc = onehot_encode(_data[i*batch_size:(i+1)*batch_size])
        batch = torch.cat(enc,axis = 1)
        
        truth = onehot_labels(_labels[i*batch_size:(i+1)*batch_size])
        yield (batch, truth)


In [168]:
# X_encoded = onehot_encode(X)

# enc = encode_batches(X,4)



# torch.cat((e[0],e[1]),1).shape

# e[0].shape
# e = next(enc)

# e.shape

In [184]:
class LSTM_predictor(nn.Module):
    def __init__(self, SIZE):
        super(LSTM_predictor, self).__init__()

        self.SIZE = SIZE

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(3, SIZE)

        self.output = nn.Linear(SIZE, 2)

        self.activation = nn.Softmax()

        # self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, sentence):

        _, (lstm_out, _) = self.lstm(sentence)
        
        # print(lstm_out.shape)
        out = self.output(lstm_out[-1])
        # print(out.shape)
        # return self.softmax(out)


        # out = torch.sum(out,axis = 0)

        out = self.activation(out)

        return out
model = LSTM_predictor(32)

In [185]:
print(model)

LSTM_predictor(
  (lstm): LSTM(3, 32)
  (output): Linear(in_features=32, out_features=2, bias=True)
  (activation): Softmax(dim=None)
)


In [186]:
lr = 0.0001

parameters = model.parameters()

optimizer = torch.optim.Adam(parameters, lr=lr)

model.train()


loss_fn = nn.CrossEntropyLoss()


n_epochs = 10


def test(model):
    model.eval()
    test_batches = train_loader(X_test,y_test,batch_size=batch_size)
    acc = 0.0
    i = 0
    for data, labels in test_batches:
        pred = model(data).view((-1,2))
        corr = (torch.argmax(pred,1) == torch.argmax(labels,1)).float().mean().item()
        acc += corr
        i += 1
    return acc/i

for i in range(n_epochs):
    model.train()
    train_batches = train_loader(X,y,batch_size=batch_size)

    sum_loss = 0.0
    i_loss = 0
    for data, labels in train_batches:
        pred = model(data)
        loss_fn = F.binary_cross_entropy
        loss = loss_fn(pred,labels)
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        i_loss += 1
    print("loss",sum_loss/i_loss)
    acc = test(model)
    print("acc",acc)

    
# X_encoded[0]



loss 0.6755301847934723
acc 0.6179435483870968
loss 0.6604285850524902
acc 0.5604838709677419
loss 0.6080222187995911
acc 0.6451612903225806
loss 0.4824629293560982
acc 0.7782258064516129
loss 0.4077808267235756
acc 0.9022177419354839
loss 0.3285345282375813
acc 0.8966733870967742
loss 0.2871818158358336
acc 0.9087701612903226
loss 0.23907935158610344
acc 0.9248991935483871
loss 0.2358365966670215
acc 0.9395161290322581
loss 0.19523912897408008
acc 0.9440524193548387


In [189]:
# torch.save(model.state_dict(), "model0_dict_SIZE32_10epochs")