In [None]:
import os
from datasets import load_dataset


# Hugging Face access token, read from the environment so it is never stored
# in the notebook. os.getenv returns None when HF_TOKEN is unset.
token = os.getenv("HF_TOKEN")

# Load the training split of the tiny-textbooks dataset.
# `token=` is the current keyword for passing credentials; the older
# `use_auth_token=` spelling is deprecated in recent `datasets` releases.
dataset = load_dataset(
    "nampdn-ai/tiny-textbooks",
    split="train",
    token=token,
)


def preprocess_data(example):
    """Lower-case the ``text`` field of a single dataset record.

    Args:
        example: Mapping that contains a string under the key ``"text"``.

    Returns:
        A new dict with the single key ``"text"`` holding the lower-cased
        string. Other keys of ``example`` are not copied.
    """
    lowered = example["text"].lower()
    return dict(text=lowered)


# Lower-case every record and drop the listed metadata columns, then show the
# first processed record as a quick sanity check.
unused_columns = ["source", "s", "len", "idx", "textbook"]
dataset = dataset.map(preprocess_data, remove_columns=unused_columns)
print(dataset[0])

In [None]:
from transformers import GPT2Tokenizer
# Load the pretrained "gpt2" tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Use the end-of-sequence token for padding: tokenize_function below pads
# every example to a fixed length and therefore needs a pad token.
# NOTE(review): presumably this tokenizer defines no pad token of its own — confirm.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of records into fixed-length sequences.

    Every text is truncated and padded to exactly 128 tokens, and the
    tokenizer is asked to return a special-tokens mask as well.

    Args:
        examples: Batch mapping whose ``"text"`` entry holds the strings to
            tokenize.

    Returns:
        Whatever the module-level ``tokenizer`` returns for this batch.
    """
    encode_options = dict(
        return_special_tokens_mask=True,
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    return tokenizer(examples["text"], **encode_options)


# Tokenize in batches and drop the raw "text" column, then show the first
# tokenized record.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)
print(tokenized_dataset[0])

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset


class TextDataset(Dataset):
    """Map-style dataset that serves tokenized examples as tensors.

    Args:
        tokenized_data: Object indexable by ``'input_ids'`` and
            ``'attention_mask'``; each lookup must yield a sequence with one
            entry per example.
    """

    def __init__(self, tokenized_data):
        # Keep references to the two columns the model needs.
        self.attention_mask = tokenized_data['attention_mask']
        self.input_ids = tokenized_data['input_ids']

    def __len__(self):
        """Number of examples, taken from the ``input_ids`` column."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of freshly built tensors."""
        ids = torch.tensor(self.input_ids[idx])
        mask = torch.tensor(self.attention_mask[idx])
        return {"input_ids": ids, "attention_mask": mask}


# Wrap the tokenized data and serve it in shuffled mini-batches of 64.
train_dataset = TextDataset(tokenized_dataset)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)

# Peek at one collated batch as a sanity check (prints nothing if the loader
# is empty).
batch = next(iter(train_loader), None)
if batch is not None:
    print(batch)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class NewGELU(nn.Module):
    """Tanh-based approximation of the GELU activation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    element-wise and returns a tensor of the same shape as the input.
    """

    def forward(self, x):
        inner = x + 0.044715 * torch.pow(x, 3.0)
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * gate


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees itself and
    earlier positions.

    Args:
        n_embd: Embedding width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        block_size: Largest sequence length the causal mask is built for.
        attn_pdrop: Dropout probability applied to the attention weights.
        resid_pdrop: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.n_embd = n_embd

        # A single fused projection yields queries, keys and values at once.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

        # Lower-triangular mask stored as a buffer named "bias"; entries equal
        # to 0 mark the (future) positions that must not be attended to.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer(
            "bias", causal_mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part)
                   for part in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores, with future positions forced to -inf so
        # that softmax assigns them zero weight.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        is_future = self.bias[:, :, :T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back into (B, T, C).
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))


class Block(nn.Module):
    """One transformer layer: layer-normed attention and a layer-normed MLP,
    each added back onto its input through a residual connection.

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads.
        block_size: Maximum sequence length, forwarded to the attention layer.
        attn_pdrop: Dropout probability for the attention weights.
        resid_pdrop: Dropout probability for the attention output and for the
            MLP output.
    """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop, resid_pdrop):
        super().__init__()
        hidden_width = 4 * n_embd

        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(
            n_embd, n_head, block_size, attn_pdrop, resid_pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)

        # Feed-forward network: expand to 4x width, activate, project back.
        feed_forward = [
            nn.Linear(n_embd, hidden_width),
            NewGELU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(resid_pdrop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Args:
        vocab_size: Number of distinct token ids.
        block_size: Maximum sequence length accepted by ``forward``.
        n_layer: Number of stacked ``Block`` layers.
        n_head: Attention heads per block.
        n_embd: Embedding width.
        embd_pdrop: Dropout probability applied to the summed embeddings.
        attn_pdrop: Dropout probability for attention weights.
        resid_pdrop: Dropout probability for residual branches.
    """

    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(vocab_size, n_embd),   # token embeddings
            'wpe': nn.Embedding(block_size, n_embd),   # position embeddings
            'drop': nn.Dropout(embd_pdrop),
            'h': nn.ModuleList([Block(n_embd, n_head, block_size, attn_pdrop, resid_pdrop) for _ in range(n_layer)]),
            'ln_f': nn.LayerNorm(n_embd),
        })
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.apply(self._init_weights)
        # Re-initialise every parameter whose name ends in "c_proj.weight"
        # with a std scaled down by sqrt(2 * n_layer). In this file that
        # matches the attention output projection only: the MLP inside Block
        # is an nn.Sequential, so its output projection is named
        # "mlp.2.weight" and keeps std=0.02.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))

    def _init_weights(self, module):
        """Initialise one sub-module; used as the callback for ``self.apply``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape (b, t), t <= block_size.
            targets: Optional integer tensor with the same shape as ``idx``
                holding the desired token at each position. Entries equal to
                -1 are ignored by the loss. No shifting is done here; callers
                that want next-token prediction must pass shifted targets.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (b, t, vocab_size)
            and ``loss`` is a scalar tensor, or None when ``targets`` is None.

        Raises:
            AssertionError: if the sequence is longer than ``block_size``.
        """
        device = idx.device
        b, t = idx.size()
        # The message is built from single-line f-strings: a line break inside
        # a replacement field is a SyntaxError on Python versions before 3.12.
        assert t <= self.block_size, (
            f"Cannot forward sequence of length {t}, "
            f"block size is only {self.block_size}")
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (rather than view) also accepts non-contiguous targets,
            # e.g. a sliced tensor.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """Extend ``idx`` by ``max_new_tokens`` tokens, one at a time.

        Args:
            idx: Integer tensor of shape (b, t) holding the conditioning tokens.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before softmax.
            do_sample: Sample from the distribution when True; otherwise take
                the most likely token.
            top_k: If given, restrict the choice to the ``top_k`` most likely
                tokens. Values larger than the vocabulary are clamped.

        Returns:
            Tensor of shape (b, t + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit into the model's context.
            idx_cond = idx if idx.size(
                1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # torch.topk raises when k exceeds the last dimension.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [None]:
import torch
from torch.optim import AdamW
from tqdm import tqdm
import os

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters. block_size equals the max_length=128 used when the
# dataset was tokenized.
vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

model = GPT(
    vocab_size=vocab_size,
    block_size=block_size,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    embd_pdrop=embd_pdrop,
    attn_pdrop=attn_pdrop,
    resid_pdrop=resid_pdrop,
).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4)

# Training schedule and the file used to resume interrupted runs.
total_epochs = 100
checkpoint_path = "./model_checkpoint.pt"


def save_checkpoint(model, optimizer, epoch, loss, path):
    """Write a resumable training checkpoint to ``path``.

    The checkpoint is first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``, so an interruption during the write cannot
    corrupt an existing checkpoint.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Zero-based index of the epoch that just finished.
        loss: Loss value recorded alongside the weights.
        path: Destination file path (``str`` or ``os.PathLike``). Any other
            object is handed to ``torch.save`` unchanged, as before.
    """
    payload = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    if isinstance(path, (str, os.PathLike)):
        tmp_path = os.fspath(path) + ".tmp"
        try:
            torch.save(payload, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Includes KeyboardInterrupt: drop the partial file, keep the
            # previous checkpoint, and let the caller see the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    else:
        torch.save(payload, path)

    print(f"Model checkpoint saved at epoch {epoch+1}")


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer state from ``path`` if it exists.

    Args:
        path: Checkpoint file, in the format written by ``save_checkpoint``.
        model: Module that receives the stored ``model_state_dict``.
        optimizer: Optimizer that receives the stored ``optimizer_state_dict``.

    Returns:
        ``(start_epoch, loss)``: the zero-based index of the next epoch to
        run and the loss stored in the checkpoint, or ``(0, None)`` when no
        checkpoint file exists.
    """
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch.")
        return 0, None

    # Map stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can be resumed on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_epoch = checkpoint['epoch'] + 1
    # Printed 1-based, like the save message and the training progress bar.
    print(f"Resuming training from epoch {start_epoch + 1}")
    return start_epoch, checkpoint['loss']


# Resume from the last checkpoint if there is one; otherwise start at epoch 0.
start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer)

for epoch in range(start_epoch, total_epochs):
    model.train()
    total_loss = 0.0
    steps_taken = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}")

    for batch in progress_bar:
        token_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Next-token prediction: position i of `inputs` is trained to predict
        # token i + 1. GPT.forward does not shift, so passing the same tensor
        # as input and target would only teach the model to copy its input.
        inputs = token_ids[:, :-1]
        # Padding shares its id with the EOS token, so padded positions are
        # found through the attention mask and set to -1, which the loss in
        # GPT.forward ignores (ignore_index=-1).
        targets = token_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -1)

        # A batch without any real target would produce a NaN loss; skip it.
        if not (targets != -1).any():
            continue

        optimizer.zero_grad()
        logits, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        steps_taken += 1
        progress_bar.set_postfix(loss=step_loss)

    # Average over the batches that actually contributed a gradient step.
    avg_loss = total_loss / max(steps_taken, 1)
    print(f"Epoch {epoch+1}/{total_epochs}, Average Loss: {avg_loss:.4f}")

    save_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_path)

# Store only the weights of the fully trained model.
torch.save(model.state_dict(), "./final_model.pt")
print("Final model saved")

In [None]:
import torch
from transformers import GPT2Tokenizer

# Load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Weights written by the final torch.save of the training cell
model_path = "./final_model.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The architecture must match the one used for training, otherwise the saved
# state dict will not fit the model.
vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

model = GPT(vocab_size, block_size, n_layer, n_head, n_embd,
            embd_pdrop, attn_pdrop, resid_pdrop).to(device)

# Restore the trained weights. Without this step the freshly constructed
# model only holds its random initialisation. map_location lets weights
# saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))


def generate_text(prompt, max_length=100, temperature=0.9, top_k=50):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    The prompt is lower-cased before encoding because the training text was
    lower-cased by ``preprocess_data``; the returned text therefore starts
    with the lower-cased prompt.

    Args:
        prompt: Text to continue. An empty prompt is replaced by a single
            end-of-sequence token so that generation still has a start token.
        max_length: Number of new tokens to generate.
        temperature: Divisor applied to the logits before sampling.
        top_k: Sample only among the ``top_k`` most likely tokens.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    model.eval()

    # Encode the prompt the same way the training text was normalised.
    token_ids = tokenizer.encode(prompt.lower())
    if not token_ids:
        # A zero-length sequence cannot be fed through the model.
        token_ids = [tokenizer.eos_token_id]
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Generate text
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_k=top_k,
        )

    # Decode the generated text
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Try the model on a handful of fixed prompts.
prompts = [
    "Once upon a time",
    "The future of artificial intelligence",
    "In a world where technology",
    "The most important scientific discovery",
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    generated_text = generate_text(prompt)
    print(f"Generated text: {generated_text}\n")

# Interactive mode: keep reading prompts until the user types "quit"
# (in any letter case).
print("Enter your own prompts (type 'quit' to exit):")
user_prompt = input("Your prompt: ")
while user_prompt.lower() != 'quit':
    generated_text = generate_text(user_prompt)
    print(f"Generated text: {generated_text}\n")
    user_prompt = input("Your prompt: ")

print("Text generation completed.")