### Create Pre-trained Tokenizer

In [63]:
from transformers import PreTrainedTokenizerFast

hf_tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer/tokenizer.json",
                                       bos_token="<s>",
                                       eos_token="</s>",
                                       unk_token="<unk>",
                                       pad_token="<pad>")

In [60]:
hf_tokenizer.decode(2)

'</s>'

### Handle Dataset

In [72]:
from datasets import load_from_disk
saved_ds = load_from_disk("dataset/wikitext")

In [82]:
import torch
final_data = torch.tensor(saved_ds["input_ids"], dtype=torch.long)
final_data.shape

torch.Size([5137, 512])

In [83]:
final_data.size()

torch.Size([5137, 512])

### Create DataLoader

In [84]:
from torch.utils.data import DataLoader, Dataset

class TokenizedDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        x = self.data[idx]
        return {
            "input_ids": x, 
            "labels": x.clone()
        }

train_dataset = TokenizedDataset(final_data)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

### Train the Model

In [85]:
# Load the model

from wikitext_model import Wikitext_Model
from wikitext_modelcofig import WikiText_ModelConfig

config = WikiText_ModelConfig()
model = Wikitext_Model(config)

In [86]:
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
num_epochs = 1

In [87]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for step, batch in enumerate(train_loader):
        input_ids = batch["input_ids"].to(device)        # shape: [batch_size, seq_len]
        labels = batch["labels"].to(device)              # shape: [batch_size, seq_len]

        # Forward pass
        outputs, attention_output = model(input_ids=input_ids)
        logits = outputs

        # Shift logits and labels for next-token prediction
        shift_logits = logits[:, :-1, :]             # [batch, seq_len-1, vocab]
        shift_labels = labels[:, 1:]

        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(shift_logits.reshape(-1, shift_logits.size(-1)),
                       shift_labels.reshape(-1))

        # Backprop
        loss.backward()
        # clip_grad_norm_(model.parameters(), 1.0)  # optional, helps stabilize training
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        if step % 100 == 0:
            print(f"[Epoch {epoch}] Step {step} | Loss: {loss.item():.4f}")

    print(f"Epoch {epoch} Finished | Average Loss: {total_loss /len(train_loader):.4f}")

[Epoch 0] Step 0 | Loss: 11.0493


KeyboardInterrupt: 

In [61]:
def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0, top_k=50):
    model.eval()
    input_ids = torch.tensor(tokenizer.encode(prompt).ids).to(device)
    generated = input_ids

    with torch.no_grad():
        for i in range(max_length):
            if generated.dim() == 1:
                generated = generated.unsqueeze(0)  # Add batch dimension

            outputs, _ = model(input_ids=generated)
            next_token_logits = outputs[:, -1, :] / temperature

            # Top-k sampling
            top_k_probs, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
            probs = torch.nn.functional.softmax(top_k_probs, dim=-1)
            next_token = top_k_indices.gather(-1, torch.multinomial(probs, 1))

            generated = torch.cat((generated, next_token), dim=1)

            
            token_id = generated[0][i].item()
            token_str = hf_tokenizer.decode(token_id)

            if next_token.item() == '</s>':
                break

            print(token_str, end="")

In [62]:
generate_text(model,tokenizer,"Hello, ")

Hello,  Appeal royalty pri reached flankeduedoc silveryamics fal obsessed sanctuary spidersormentlefford garden ‑ advisories wrapped Athenters dictatorshipumps vert cured 1872 witness arrow plural Comet Card newborn forwards invade consult rhymehuveer Fulf needed man MacFarlane Fincher ha drownued Ganymede sed protein ought putographed

### Save and Load the Model

In [None]:
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config
}, 'wikitext_model.pth')

In [43]:
checkpoint = torch.load('wikitext_model.pth')

# Recreate config and model
config = checkpoint['config']
model = Wikitext_Model(config)

# Load state dict
model.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load('wikitext_model.pth')


<All keys matched successfully>