<a href="https://colab.research.google.com/github/Tanish-Sarkar/Elite-Transformers/blob/main/Module0%20-%20PyTorch%20Ramp-Up/lab3_dataset_dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

1. Dataset

In [14]:
class TextDataset(Dataset):
    """Next-word prediction dataset built from whitespace-separated sentences.

    Every distinct word receives an integer id in order of first appearance,
    starting at 1, so id 0 is never assigned to a word. Each sentence is
    stored as a 1-D ``torch.long`` tensor of word ids.

    Args:
        sentences: Iterable of strings; each one is tokenised with
            ``str.split()``.

    Note:
        A sentence with fewer than two tokens yields an empty
        ``(input, target)`` pair, because there is no "next word" to predict.
    """

    def __init__(self, sentences):
        # word -> id mapping; ids start at 1, leaving 0 unused by real words
        self.vocab = {}
        # one 1-D long tensor of word ids per input sentence
        self.data = []

        for sent in sentences:
            # setdefault returns the existing id, or registers the next free
            # one (current vocabulary size + 1) for a word not seen before.
            tokens = [
                self.vocab.setdefault(w, len(self.vocab) + 1)
                for w in sent.split()
            ]
            # Explicit dtype: without it an empty token list would become a
            # float32 tensor, which cannot be used as embedding indices.
            self.data.append(torch.tensor(tokens, dtype=torch.long))

    def __len__(self):
        """Return the number of stored sentences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target)`` for sentence ``idx``.

        ``input`` is the sentence without its last token and ``target`` is the
        sentence without its first token, so ``target[i]`` is the word that
        follows ``input[i]``.
        """
        seq = self.data[idx]
        return seq[:-1], seq[1:]


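2. Padding collate function

Sentences in a batch can have different lengths, so `collate_fn` right-pads every input and target sequence with 0 up to the longest sequence in that batch. Word ids start at 1, so 0 never collides with a real word.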

In [15]:
def collate_fn(batch):
    """Combine ``(input, target)`` samples into two padded batch tensors.

    Args:
        batch: Sequence of ``(input, target)`` pairs of 1-D tensors.

    Returns:
        A pair ``(inputs, targets)``. Each is produced by ``pad_sequence``
        with ``batch_first=True``, i.e. one row per sample, with shorter
        sequences filled at the end using ``pad_sequence``'s default padding
        value.
    """
    # Transpose the list of pairs into one tuple of inputs and one of targets.
    inputs, targets = zip(*batch)
    padded_inputs = pad_sequence(inputs, batch_first=True)
    padded_targets = pad_sequence(targets, batch_first=True)
    return padded_inputs, padded_targets

# Toy corpus: two three-word sentences and one four-word sentence, so at
# least one batch contains sequences of unequal length.
sentences = ["the cat sat", "the dog ran", "the bird flew away"]

# Builds the vocabulary and the per-sentence id tensors.
dataset = TextDataset(sentences)

# Two sentences per batch; collate_fn pads each batch to a common length.
loader = DataLoader(dataset=dataset, batch_size=2, collate_fn=collate_fn)

3. Model

In [16]:
class NextWordRNN(nn.Module):
    """Embedding -> single-layer RNN -> linear head for next-word prediction.

    Args:
        vocab_size: Number of distinct words. One extra slot is added for
            id 0, which is treated as padding, so valid token ids are
            ``0 .. vocab_size``.
        embed_dim: Size of each word embedding vector (default 16).
        hidden_dim: Size of the RNN hidden state (default 32).
    """

    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        # +1 because id 0 is the padding slot on top of the real words.
        num_ids = vocab_size + 1
        # padding_idx=0 pins the padding embedding to a zero vector and
        # excludes it from gradient updates.
        self.embed = nn.Embedding(num_ids, embed_dim, padding_idx=0)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_ids)

    def forward(self, x):
        """Return per-position logits over the ``vocab_size + 1`` ids.

        Args:
            x: Integer tensor of token ids, batch dimension first.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a final
            dimension of size ``vocab_size + 1``.
        """
        x = self.embed(x)
        # Only the per-step outputs are needed; the final hidden state is dropped.
        out, _ = self.rnn(x)
        return self.fc(out)

# Seed PyTorch's RNG before the model is created so that the random weight
# initialisation is the same on every Restart & Run All.
torch.manual_seed(0)

# The model adds the extra padding slot itself, so pass the raw word count.
model = NextWordRNN(len(dataset.vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Targets equal to 0 are padding and must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

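4. Training loop

The loss printed below is the sum of the per-batch losses within an epoch, and it is reported every fifth epoch.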

In [17]:
for epoch in range(30):
    # Running sum of the per-batch losses for this epoch.
    total_loss = 0
    for x, y in loader:
        # Clear gradients left over from the previous optimisation step.
        optimizer.zero_grad()

        logits = model(x)
        # Collapse all leading dimensions so the criterion receives one row
        # of class scores per token and one target id per token.
        num_classes = logits.size(-1)
        flat_logits = logits.view(-1, num_classes)
        flat_targets = y.view(-1)
        loss = criterion(flat_logits, flat_targets)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Only report on every fifth epoch (0, 5, 10, ...).
    if epoch % 5 != 0:
        continue
    print(f"Epoch {epoch} | Loss: {total_loss:.4f}")

Epoch 0 | Loss: 4.4889
Epoch 5 | Loss: 1.7241
Epoch 10 | Loss: 1.1193
Epoch 15 | Loss: 0.9916
Epoch 20 | Loss: 0.9539
Epoch 25 | Loss: 0.9405
