## Steps
 - 1. Create Vocab
 - 2. Encode words
 - 3. Separate into input/output, put into <b>LongTensor</b>
 - 4. Model class
 - 5. Training params
 - 6. Train
 - 7. Test
 

## Data preparation

In [128]:
# Toy corpus: seven three-word sentences.
sentences = [
    "bob likes sheep",
    "alice is fast",
    "cs685 is fun",
    "you love lamp",
    "i like dog",
    "i love coffee",
    "i hate milk",
]

In [58]:
# Build the word <-> id lookup tables and encode every sentence as a list of ids.
vocab = {}      # word -> integer id, assigned in order of first appearance
inv_vocab = {}  # integer id -> word (inverse of `vocab`)
sent_ids = []   # one list of word ids per sentence

for sentence in sentences:
    encoded = []
    for token in sentence.split():
        # setdefault hands out the next free id only the first time a word is seen
        token_id = vocab.setdefault(token, len(vocab))
        inv_vocab[token_id] = token
        encoded.append(token_id)
    sent_ids.append(encoded)

print(vocab)
print(sent_ids)

{'bob': 0, 'likes': 1, 'sheep': 2, 'alice': 3, 'is': 4, 'fast': 5, 'cs685': 6, 'fun': 7, 'you': 8, 'love': 9, 'lamp': 10, 'i': 11, 'like': 12, 'dog': 13, 'coffee': 14, 'hate': 15, 'milk': 16}
[[0, 1, 2], [3, 4, 5], [6, 4, 7], [8, 9, 10], [11, 12, 13], [11, 9, 14], [11, 15, 16]]


In [59]:
# Context = first two word ids of each sentence; target = its last word id.
context_rows = [ids[:2] for ids in sent_ids]
target_ids = [ids[-1] for ids in sent_ids]
input_ids = torch.tensor(context_rows, dtype=torch.long)
label_ids = torch.tensor(target_ids, dtype=torch.long)
input_ids, label_ids

(tensor([[ 0,  1],
         [ 3,  4],
         [ 6,  4],
         [ 8,  9],
         [11, 12],
         [11,  9],
         [11, 15]]),
 tensor([ 2,  5,  7, 10, 13, 14, 16]))

## Model

In [60]:
class NLM(nn.Module):
    """Fixed-window feed-forward neural language model.

    Looks up an embedding for each of the ``seq_len`` context tokens,
    concatenates them, and passes the result through a tanh hidden layer and
    a linear output layer that yields one unnormalised score (logit) per
    vocabulary entry.
    """

    def __init__(self, embed_dim: int, hidden_dim: int, vocab_dim: int, seq_len: int) -> None:
        """Create the layers.

        Args:
            embed_dim: Size of each token embedding vector.
            hidden_dim: Number of units in the hidden layer.
            vocab_dim: Number of distinct token ids; also the output width.
            seq_len: Number of context tokens fed to the model per example.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        # The hidden layer consumes all context embeddings side by side.
        concat_dim = seq_len * embed_dim

        self.embeddings = nn.Embedding(num_embeddings=vocab_dim, embedding_dim=embed_dim)
        self.fc1 = nn.Linear(in_features=concat_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=vocab_dim)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for a batch of context windows.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, vocab_dim)`` holding raw scores; no
            softmax is applied here.
        """
        # (batch, seq_len) -> (batch, seq_len, embed_dim)
        embedded = self.embeddings(x)

        # Concatenate the per-token embeddings:
        # (batch, seq_len, embed_dim) -> (batch, seq_len * embed_dim)
        concatenated = embedded.flatten(start_dim=1)

        # (batch, seq_len * embed_dim) -> (batch, hidden_dim)
        hidden = self.activation(self.fc1(concatenated))

        # (batch, hidden_dim) -> (batch, vocab_dim)
        return self.fc2(hidden)

## Training

In [126]:
torch.manual_seed(3)  # fixed seed so weight init and shuffling are repeatable

model = NLM(embed_dim=10, hidden_dim=8, vocab_dim=len(vocab), seq_len=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 10
num_examples = input_ids.size(0)

for epoch in range(num_epochs):
    # Every epoch is one full-batch update over a fresh random ordering.
    order = torch.randperm(num_examples)
    batch_inputs, batch_labels = input_ids[order], label_ids[order]

    # Forward pass: unnormalised scores over the vocabulary, shape (batch, vocab).
    logits = model(batch_inputs)
    loss = loss_fn(logits, batch_labels)

    # Backward pass and parameter update; gradients are cleared afterwards so
    # they do not accumulate into the next epoch.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Epoch {epoch} loss = {loss.item():.3f}")

Epoch 0 loss = 2.806
Epoch 1 loss = 1.911
Epoch 2 loss = 1.331
Epoch 3 loss = 0.877
Epoch 4 loss = 0.575
Epoch 5 loss = 0.389
Epoch 6 loss = 0.265
Epoch 7 loss = 0.186
Epoch 8 loss = 0.134
Epoch 9 loss = 0.098


## Testing

In [120]:
# Index of the training sentence to run through the model below.
num = 6
sentences[num]

'i hate milk'

In [121]:
# Add a leading batch axis of size 1 so the single example has the
# (batch, seq_len) layout the model expects.
test_input = input_ids[num].unsqueeze(0)
test_input.shape

torch.Size([1, 2])

In [122]:
# Forward pass on the single example; the result holds one raw score (logit)
# per vocabulary word.
test_output = model(test_input)

test_output

tensor([[-1.3127, -1.3140, -4.2429, -1.4105, -1.1098,  0.1548, -1.4311, -2.2945,
         -1.4624, -1.3655, -0.0172, -1.1250, -1.5852,  2.4191,  2.4025, -1.3542,
          6.3865]], grad_fn=<AddmmBackward0>)

In [123]:
# Softmax over dim 1 (the vocabulary axis of the model output) turns the raw
# scores into probabilities.
softmax = nn.Softmax(dim=1)

In [124]:
# Most probable vocabulary id, mapped back to its word.
probs = softmax(test_output)
idx = probs.argmax().item()
inv_vocab[idx]

'milk'

In [125]:
# Predict the final word of every training sentence from its first two words.
# This is inference only, so run under no_grad to avoid building an autograd
# graph for each forward pass.
with torch.no_grad():
    for num, sent in enumerate(sentences):
        test_input = input_ids[num].view(1, -1)
        test_output = model(test_input)
        # argmax over the probabilities picks the highest-scoring word id
        idx = torch.argmax(softmax(test_output)).item()

        print(sent, '->', inv_vocab[idx])

bob likes sheep -> sheep
alice is fast -> fast
cs685 is fun -> fun
you love lamp -> lamp
i like dog -> dog
i love coffee -> coffee
i hate milk -> milk
