### Create Pre-trained Tokenizer

In [8]:
from transformers import PreTrainedTokenizerFast

hf_tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer/tokenizer.json",
                                       bos_token="<s>",
                                       eos_token="</s>",
                                       unk_token="<unk>",
                                       pad_token="<pad>")

In [60]:
hf_tokenizer.decode(2)

'</s>'

### Handle Dataset

In [19]:
from datasets import load_from_disk
saved_ds = load_from_disk("dataset/wikitext")

In [20]:
import torch
final_data = torch.tensor(saved_ds["input_ids"], dtype=torch.long)
final_data.shape

torch.Size([5137, 512])

In [21]:
final_data.size()

torch.Size([5137, 512])

### Create DataLoader

In [22]:
from torch.utils.data import DataLoader, Dataset

class TokenizedDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        x = self.data[idx]
        return {
            "input_ids": x, 
            "labels": x.clone()
        }

train_dataset = TokenizedDataset(final_data)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

### Train the Model

In [1]:
# Load the model

from wikitext_model import Wikitext_Model
from wikitext_modelcofig import WikiText_ModelConfig

config = WikiText_ModelConfig()
model = Wikitext_Model(config)

In [17]:
import torch
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
num_epochs = 0

In [23]:
for epoch in range(1):
    model.train()
    total_loss = 0.0

    for step, batch in enumerate(train_loader):
        input_ids = batch["input_ids"].to(device)        # shape: [batch_size, seq_len]
        labels = batch["labels"].to(device)              # shape: [batch_size, seq_len]

        # Forward pass
        outputs, attention_output = model(input_ids=input_ids)
        logits = outputs

        # Shift logits and labels for next-token prediction
        shift_logits = logits[:, :-1, :]             # [batch, seq_len-1, vocab]
        shift_labels = labels[:, 1:]

        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(shift_logits.reshape(-1, shift_logits.size(-1)),
                       shift_labels.reshape(-1))

        # Backprop
        loss.backward()
        # clip_grad_norm_(model.parameters(), 1.0)  # optional, helps stabilize training
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        if step % 100 == 0:
            print(f"[Epoch {epoch}] Step {step} | Loss: {loss.item():.4f}")

    print(f"Epoch {epoch} Finished | Average Loss: {total_loss /len(train_loader):.4f}")

[Epoch 0] Step 0 | Loss: 10.9953


KeyboardInterrupt: 

In [3]:
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("tokenizer/tokenizer.json")

### Save and Load the Model

In [None]:
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'config': config
# }, 'wikitext_model.pth')

In [43]:
checkpoint = torch.load('wikitext_model.pth')

# Recreate config and model
config = checkpoint['config']
model = Wikitext_Model(config)

# Load state dict
model.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load('wikitext_model.pth')


<All keys matched successfully>

In [None]:
def generate_text(model, tokenizer, prompt, max_length=1000, temperature=1.0, top_k=50):
    model.eval()
    input_ids = torch.tensor(tokenizer.encode(prompt).ids[1:-1]).to(device)
    generated = input_ids

    with torch.no_grad():
        for i in range(max_length):
            if generated.dim() == 1:
                generated = generated.unsqueeze(0)  # Add batch dimension

            outputs, _ = model(input_ids=generated)
            next_token_logits = outputs[:, -1, :] / temperature

            # Top-k sampling
            top_k_probs, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
            probs = torch.nn.functional.softmax(top_k_probs, dim=-1)
            next_token = top_k_indices.gather(-1, torch.multinomial(probs, 1))

            generated = torch.cat((generated, next_token), dim=1)
            
            if generated.size(1) > 500:
                generated = generated[:, -100:]

            

            if next_token.item() == 2:
                break

            token_id = generated[0,-1].item()
            token_str = hf_tokenizer.decode(token_id)

            print(token_str, end="")

In [49]:
generate_text(model,tokenizer,"Why was there no sign of")

 Scully (

 in drug largeence by which sawed into the territory . ) 
</s><s> The Second Food Cemetery by the local government grew up in the first year . No names were only dead after this name " B " ( 1850 ) in B7 , the 17th century " ( Japanese name ) . By 14th century her death , the name of Japanese lang " I Wanna Stay

KeyboardInterrupt: 