In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import string

# Bigram counts filled in by update_lexicon():
# preprocessed word -> {preprocessed following word: number of occurrences}.
lexicon = {}

class LexiconModel(nn.Module):
    """Next-word scorer: an embedding lookup followed by one linear layer.

    ``forward`` returns unnormalised scores (logits) over the vocabulary.
    ``nn.CrossEntropyLoss`` applies log-softmax to its input itself, so
    feeding it values that were already passed through a softmax squashes
    the gradients and makes training stall. ``argmax`` over logits selects
    the same index as ``argmax`` over the softmax probabilities.

    Args:
        vocab_size: Number of distinct words; used as the number of
            embeddings, the embedding width and the output width.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        self.linear = nn.Linear(vocab_size, vocab_size)
        # Not applied in forward(); kept so callers that want probabilities
        # can call ``model.softmax(model(x))`` explicitly.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return logits over the vocabulary for each word index in ``x``.

        Args:
            x: Integer tensor of word indices.

        Returns:
            Tensor with the shape of ``x`` plus a trailing dimension of
            size ``vocab_size`` holding raw (un-softmaxed) scores.
        """
        embedded = self.embedding(x)
        return self.linear(embedded)

def preprocess_word(word):
    """Normalise a token: lower-case it and drop ASCII punctuation.

    Args:
        word: Raw token text.

    Returns:
        The lower-cased token with every character listed in
        ``string.punctuation`` removed (possibly the empty string).
    """
    # Mapping a character to None in a translation table deletes it.
    strip_punctuation = str.maketrans(dict.fromkeys(string.punctuation))
    return word.lower().translate(strip_punctuation)

def update_lexicon(current, next_word):
    """Record one occurrence of ``next_word`` directly following ``current``.

    Both words are normalised with ``preprocess_word`` before the count in
    the module-level ``lexicon`` dictionary is incremented.

    Args:
        current: The preceding word (raw text).
        next_word: The word that follows it (raw text).
    """
    head = preprocess_word(current)
    follower = preprocess_word(next_word)
    # Create the per-word counter on first sight, then bump the pair count.
    followers = lexicon.setdefault(head, {})
    followers[follower] = followers.get(follower, 0) + 1
# Fill the lexicon with word-pair counts from the training text, one line at
# a time (pairs never span a line break).
with open('train.csv', 'r', encoding='utf-8') as dataset:
    for line in dataset:
        # split() with no argument collapses runs of whitespace (including
        # tabs), so repeated spaces no longer produce empty-string "words".
        words = line.split()
        for current_word, following_word in zip(words, words[1:]):
            update_lexicon(current_word, following_word)

# Vocabulary: every word that starts a pair gets the index of its position in
# the lexicon, then words that only ever appear as a successor (e.g. the last
# word of a line) are appended. Without the second step those successors are
# missing from the vocabulary, their counts are dropped, and a word whose
# successors are all missing ends up with an all-zero count vector.
vocab = {word: idx for idx, word in enumerate(lexicon)}
for transitions in lexicon.values():
    for next_word in transitions:
        # len(vocab) is evaluated before insertion, so new words receive the
        # next free index and existing entries are left untouched.
        vocab.setdefault(next_word, len(vocab))
vocab_size = len(vocab)
model = LexiconModel(vocab_size)

# word index -> dense vector of successor counts, indexed by successor index.
# Lexicon keys and successors are already preprocessed by update_lexicon().
lexicon_tensors = {}
for word, transitions in lexicon.items():
    transitions_tensor = torch.zeros(vocab_size)
    for next_word, count in transitions.items():
        transitions_tensor[vocab[next_word]] = count
    lexicon_tensors[vocab[word]] = transitions_tensor
# Optimisation setup: cross-entropy against a single target class per word,
# minimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Ten passes over the lexicon, one optimiser step per word.
for epoch in range(10):
    total_loss = 0
    for word_idx, transitions_tensor in lexicon_tensors.items():
        # The training target is the successor with the highest count.
        target = transitions_tensor.argmax().unsqueeze(0)
        prediction = model(torch.tensor([word_idx]))

        optimizer.zero_grad()
        loss = criterion(prediction, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Reported value is the sum of per-word losses, not their mean.
    print(f"Epoch {epoch+1}, Loss: {total_loss}")
# Reverse lookup (index -> word), built once instead of materialising two
# full lists and scanning them for every prediction.
index_to_word = {idx: word for word, idx in vocab.items()}

# Interactive prompt: read a line, predict a follower for its last word and
# echo the line with the prediction appended.
while True:
    try:
        line = input('> ')
    except (EOFError, KeyboardInterrupt):
        # End of input or an interrupt is the only way out of the prompt
        # loop; leave it cleanly instead of dumping a traceback.
        break
    word = preprocess_word(line.strip().split(' ')[-1])
    if word not in lexicon:
        print('Word not found')
        continue
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(torch.tensor([vocab[word]]))
    predicted_idx = torch.argmax(outputs).item()
    print(line + ' ' + index_to_word[predicted_idx])


Epoch 1, Loss: 7566.105537891388
Epoch 2, Loss: 7029.875515460968
Epoch 3, Loss: 6748.873999118805
Epoch 4, Loss: 6634.056846141815
Epoch 5, Loss: 6554.903366088867
Epoch 6, Loss: 6512.189565181732
Epoch 7, Loss: 6499.129098892212
Epoch 8, Loss: 6493.677863121033
Epoch 9, Loss: 6493.01634311676
Epoch 10, Loss: 6489.714417457581
> Embrace
Embrace the
> Embrace the power
Embrace the power of
> Embrace simplicity
Embrace simplicity to
> Embrace simplicity, for it is the gateway to a peaceful
Embrace simplicity, for it is the gateway to a peaceful heart
