In [1]:
# Step 1: Read the whole corpus into memory, lowercased so that
# "The" and "the" map to the same vocabulary entry.
with open("combined.txt", mode='r', encoding='utf-8') as f:
    text = f.read().lower()

# Whitespace tokenization: punctuation stays attached to its word.
words = text.split()

# Step 2: Word-level vocabulary.
# Sorting makes the word -> index assignment deterministic across runs.
unique_words = sorted(set(words))
word2idx = dict(zip(unique_words, range(len(unique_words))))
idx2word = dict(enumerate(unique_words))
vocab_size = len(word2idx)

# Step 3: Replace every token of the corpus by its vocabulary index.
encoded_text = list(map(word2idx.__getitem__, words))


In [2]:

# Step 4: Dataset
import torch
from torch.utils.data import Dataset, DataLoader

class WordDataset(Dataset):
    """Sliding-window dataset for next-word prediction.

    Item ``i`` is the pair ``(x, y)`` where ``x`` is the ``block_size`` tokens
    starting at position ``i`` and ``y`` is the same window shifted one token
    to the right, so ``y[t]`` is the target that follows ``x[t]``.

    Args:
        data: Sequence of integer token indices supporting ``len()`` and slicing.
        block_size: Number of tokens in each input/target window.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Clamp at zero: a corpus shorter than one window has no samples, and
        # a negative value would make len() raise ValueError.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` LongTensor pair for window ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Without this check slicing silently returns truncated/empty tensors,
        # and plain iteration over the dataset would never stop.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        stop = idx + self.block_size
        x = torch.tensor(self.data[idx:stop], dtype=torch.long)
        y = torch.tensor(self.data[idx + 1:stop + 1], dtype=torch.long)
        return x, y

# Each training sample is a window of `block_size` consecutive words.
block_size = 20
dataset = WordDataset(data=encoded_text, block_size=block_size)
# Windows are drawn in random order, 64 per batch.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)


In [3]:

import torch.nn as nn

class GRUWordModel(nn.Module):
    """Word-level language model: embedding -> GRU -> linear vocabulary projection.

    The embedding width equals the GRU hidden width (``hidden_size``).
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        # batch_first=True: the GRU consumes and produces (batch, seq, feature).
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Map token indices ``x`` to per-position vocabulary logits.

        ``hidden`` is forwarded to the GRU unchanged (``None`` lets the GRU
        start from its default state). Returns ``(logits, new_hidden)``.
        """
        embedded = self.embed(x)
        gru_out, new_hidden = self.gru(embedded, hidden)
        return self.fc(gru_out), new_hidden

In [13]:
def train_model(model, data_loader, total_iters, word2idx, idx2word, device='cpu', max_batches_per_iter=10,
                checkpoints=(500, 1000, 2000)):
    """Train ``model`` for next-word prediction with Adam and cross-entropy loss.

    Each iteration consumes at most ``max_batches_per_iter`` batches from
    ``data_loader`` and prints the loss of the last batch it processed. At
    every iteration listed in ``checkpoints`` the model weights and both
    vocabulary mappings are saved to ``word_gru_<iteration>_iters.pth`` and a
    50-word sample is generated and printed.

    Args:
        model: Module whose forward returns ``(logits, hidden)``.
        data_loader: Iterable yielding ``(x, y)`` batches of token indices.
        total_iters: Number of training iterations to run.
        word2idx: Word -> index mapping, stored in each checkpoint.
        idx2word: Index -> word mapping, stored in each checkpoint.
        device: Device the model and batches are moved to.
        max_batches_per_iter: Upper bound on batches used per iteration.
        checkpoints: Iteration numbers at which to save and sample.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
    criterion = nn.CrossEntropyLoss()

    checkpoint_iters = set(checkpoints)
    model.train()

    for iteration in range(1, total_iters + 1):
        print(f"\nIteration {iteration}/{total_iters}")
        # Stays None when no batch is processed (empty loader or a
        # non-positive batch limit), so the report below cannot hit an
        # unbound variable.
        loss = None
        for i, (x, y) in enumerate(data_loader):
            if i >= max_batches_per_iter:
                break
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits, _ = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()

        if loss is None:
            print("Loss: n/a (no batches processed)")
        else:
            print(f"Loss: {loss.item():.4f}")

        if iteration in checkpoint_iters:
            save_path = f'word_gru_{iteration}_iters.pth'
            torch.save({
                'model_state_dict': model.state_dict(),
                'word2idx': word2idx,
                'idx2word': idx2word
            }, save_path)
            print(f"Model checkpoint saved at iteration {iteration}: {save_path}")

            sample = generate_text(model, start_text="a", length=50, device=device)
            # Sampling may switch the model to eval mode; switch back so the
            # remaining iterations keep training in train mode.
            model.train()
            print(f"\n Sample after {iteration} iterations:\n{sample}\n")


In [14]:

def generate_text(model, start_text, length, device='cpu', word2idx=None, idx2word=None):
    """Sample ``length`` words from ``model``, continuing ``start_text``.

    The prompt is lowercased and split on whitespace, fed through the model
    once, and then words are sampled one at a time from the softmax over the
    last position's logits while the recurrent hidden state is carried forward.

    Args:
        model: Module whose forward takes ``(tokens, hidden)`` and returns
            ``(logits, hidden)``.
        start_text: Prompt; every word must be present in ``word2idx``.
        length: Number of words to generate after the prompt.
        device: Device on which input tensors are created.
        word2idx: Word -> index mapping. Defaults to the module-level
            ``word2idx``.
        idx2word: Index -> word mapping. Defaults to the module-level
            ``idx2word``.

    Returns:
        The prompt words followed by the generated words, space-joined.

    Raises:
        ValueError: If ``start_text`` contains no words.
        KeyError: If a prompt word is not in the vocabulary.
    """
    # Fall back to the notebook-level vocabulary when none is passed, so
    # existing four-argument calls keep working.
    if word2idx is None:
        word2idx = globals()["word2idx"]
    if idx2word is None:
        idx2word = globals()["idx2word"]

    words = start_text.lower().split()
    if not words:
        raise ValueError("start_text must contain at least one word")
    unknown = [w for w in words if w not in word2idx]
    if unknown:
        raise KeyError(f"start_text contains words missing from the vocabulary: {unknown}")

    input_eval = torch.tensor([word2idx[w] for w in words], dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    output_words = words[:]

    # Remember the caller's mode so generating a sample in the middle of
    # training does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(length):
                logits, hidden = model(input_eval, hidden)
                # Only the distribution at the last position is sampled.
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_word_idx = torch.multinomial(probs, num_samples=1).item()
                output_words.append(idx2word[next_word_idx])
                # Later steps feed just the new word; context is in `hidden`.
                input_eval = torch.tensor([[next_word_idx]], device=device)
    finally:
        model.train(was_training)

    return ' '.join(output_words)



In [15]:
# Build a fresh model with a 256-wide embedding/hidden state and train it on CPU.
model = GRUWordModel(vocab_size, 256)
print("\n--- Training for 2000 iterations with checkpoints at 500, 1000, and 2000 ---")
train_model(
    model,
    dataloader,
    2000,
    word2idx,
    idx2word,
    device='cpu',
)



--- Training for 2000 iterations with checkpoints at 500, 1000, and 2000 ---

Iteration 1/2000
Loss: 7.4033

Iteration 2/2000
Loss: 7.0274

Iteration 3/2000
Loss: 6.7308

Iteration 4/2000
Loss: 6.2047

Iteration 5/2000
Loss: 5.8385

Iteration 6/2000
Loss: 5.4933

Iteration 7/2000
Loss: 5.0590

Iteration 8/2000
Loss: 4.8866

Iteration 9/2000
Loss: 4.2550

Iteration 10/2000
Loss: 4.2543

Iteration 11/2000
Loss: 3.7461

Iteration 12/2000
Loss: 3.5871

Iteration 13/2000
Loss: 3.3410

Iteration 14/2000
Loss: 2.9241

Iteration 15/2000
Loss: 2.7146

Iteration 16/2000
Loss: 2.4075

Iteration 17/2000
Loss: 2.1525

Iteration 18/2000
Loss: 2.2061

Iteration 19/2000
Loss: 1.9890

Iteration 20/2000
Loss: 1.8360

Iteration 21/2000
Loss: 1.8173

Iteration 22/2000
Loss: 1.4473

Iteration 23/2000
Loss: 1.3597

Iteration 24/2000
Loss: 1.3095

Iteration 25/2000
Loss: 1.2344

Iteration 26/2000
Loss: 0.9860

Iteration 27/2000
Loss: 0.9445

Iteration 28/2000
Loss: 0.8354

Iteration 29/2000
Loss: 0.7584

It