## 1. Clean & Load the Corpus

In [None]:
with open("sampled_english_corpus.txt", "r", encoding="utf-8") as f:
    raw_lines = [line.strip() for line in f if line.strip()]  # remove empty lines

# Optional: lowercase, remove punctuation, etc.
import re
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
    return text

cleaned_lines = [clean_text(line) for line in raw_lines[:100]]
print(f"Number of lines of corpus: {len(cleaned_lines)}")


Number of lines of corpus: 100000


## 2. Tokenize the Corpus (word-level for now)

In [14]:
from collections import Counter

tokenized_lines = [line.split() for line in cleaned_lines]
all_tokens = [token for line in tokenized_lines for token in line]


## 3. Build Vocabulary (filter rare words)

In [15]:
vocab_size = 30000
token_freq = Counter(all_tokens)
most_common = token_freq.most_common(vocab_size - 2)  # reserve 2 for <PAD> and <UNK>

word2idx = {word: idx + 2 for idx, (word, _) in enumerate(most_common)}
word2idx["<PAD>"] = 0
word2idx["<UNK>"] = 1
idx2word = {idx: word for word, idx in word2idx.items()}


## 4. Convert to Token IDs and Replace Rare Words

In [16]:
def encode_line(tokens):
    return [word2idx.get(word, word2idx["<UNK>"]) for word in tokens]

encoded_lines = [encode_line(line) for line in tokenized_lines]


## 5. Prepare Training Sequences (Input → Target)

In [17]:
sequence_length = 5
inputs, targets = [], []

for line in encoded_lines:
    for i in range(len(line) - sequence_length):
        seq_in = line[i : i + sequence_length]
        seq_out = line[i + sequence_length]
        inputs.append(seq_in)
        targets.append(seq_out)


## 6. Build Dataset and DataLoader

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx]), torch.tensor(self.targets[idx])

dataset = TextDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


## 7. Define LSTM Model with Adaptive Softmax

In [19]:
import torch.nn as nn

class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size=256, hidden_size=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
            in_features=hidden_size,
            n_classes=vocab_size,
            cutoffs=[2000, 10000, 20000],
            div_value=4.0
        )

    def forward(self, x, target):
        emb = self.embedding(x)
        out, _ = self.lstm(emb)
        out = out[:, -1, :]  # take last hidden state
        return self.adaptive_softmax(out, target)


## 8. Training Loop

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMLanguageModel(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    total_loss = 0

    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        output = model(batch_x, batch_y)
        loss = output.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")


Epoch 1: Loss = 7.2820
Epoch 2: Loss = 6.0761
Epoch 3: Loss = 4.7159
Epoch 4: Loss = 3.1382
Epoch 5: Loss = 1.6455
Epoch 6: Loss = 0.6685
Epoch 7: Loss = 0.2432
Epoch 8: Loss = 0.1129
Epoch 9: Loss = 0.0721
Epoch 10: Loss = 0.0534


## 9. Save Model Weights

In [22]:
def save_model(model, path="model.pt"):
    torch.save(model.state_dict(), path)


In [None]:
def suggest_next_word(model=model, sentence="", word2idx=word2idx, idx2word=idx2word, device="cpu"):
    model.eval()
    with torch.no_grad():
        tokens = [word2idx.get(w, word2idx["<UNK>"]) for w in sentence.strip().split()]
        input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

        # Simulate prediction via LSTM (actually not used)
        logits = model(input_tensor)
        next_token_logits = logits[:, -1, :]  # [B, V]
        predicted_idx = next_token_logits.argmax(dim=-1).item()

        return idx2word.get(predicted_idx, "<UNK>")


In [43]:
print(suggest_next_words("i love you so"))

['i love you so much']
