In [None]:
from datasets import load_dataset
dataset = load_dataset("nampdn-ai/tiny-textbooks", split="train")
def preprocess_data(example):
    return {"text": example["text"].lower()}

dataset = dataset.map(preprocess_data, remove_columns=["source", "s", "len", "idx", "textbook"])
print(dataset[0])


In [1]:
import os
os.environ["HF_TOKEN"] = ""

In [2]:
import os
from datasets import load_dataset


token = os.getenv("HF_TOKEN")

dataset = load_dataset(
    "nampdn-ai/tiny-textbooks",
    split="train",
    use_auth_token=token
)

def preprocess_data(example):
    return {"text": example["text"].lower()}

dataset = dataset.map(preprocess_data, remove_columns=["source", "s", "len", "idx", "textbook"])
print(dataset[0])




{'text': 'deficit financing. also found in: dictionary, thesaurus, wikipedia.. deficit financing. the sale of debt securities in order to finance expenditures that are in excess of income. generally, deficit financing is applied to government finance because income, represented by tax revenues and fees, is often unavailable to pay expenses. as with monetizing the debt, deficit financing puts upward pressure on interest rates because government debt securities compete with private securities for limited capital.'}


In [3]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True, truncation=True, padding='max_length', max_length=128)
    
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
print(tokenized_dataset[0])


Map:   0%|          | 0/399000 [00:00<?, ? examples/s]

{'input_ids': [4299, 3628, 15435, 13, 635, 1043, 287, 25, 22155, 11, 262, 82, 22302, 11, 47145, 11151, 492, 11807, 15435, 13, 262, 5466, 286, 5057, 16145, 287, 1502, 284, 9604, 22895, 326, 389, 287, 6992, 286, 3739, 13, 4143, 11, 11807, 15435, 318, 5625, 284, 1230, 9604, 780, 3739, 11, 7997, 416, 1687, 13089, 290, 6642, 11, 318, 1690, 23485, 284, 1414, 9307, 13, 355, 351, 32153, 2890, 262, 5057, 11, 11807, 15435, 7584, 18644, 3833, 319, 1393, 3965, 780, 1230, 5057, 16145, 9320, 351, 2839, 16145, 329, 3614, 3139, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], 'special_tokens_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [7]:
import torch
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    def __init__(self, tokenized_data):
        self.input_ids = tokenized_data['input_ids']
        self.attention_mask = tokenized_data['attention_mask']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.input_ids[idx]),
            "attention_mask": torch.tensor(self.attention_mask[idx])
        }

train_dataset = TextDataset(tokenized_dataset)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

for batch in train_loader:
    print(batch)
    break


{'input_ids': tensor([[38169,   284,   257,  ...,   508,  3111,   329],
        [   58,   403, 46155,  ..., 20211,    13,   262],
        [  754, 40924,  9505,  ...,  1216,  2567,    11],
        ...,
        [35720,   422,  1854,  ..., 50256, 50256, 50256],
        [    1,    72,  1104,  ...,  6414,   351,   686],
        [11129, 33959, 16896,  ..., 11632,  4086,   262]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class NewGELU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.n_embd = n_embd

        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class Block(nn.Module):
    def __init__(self, n_embd, n_head, block_size, attn_pdrop, resid_pdrop):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, attn_pdrop, resid_pdrop)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            NewGELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(vocab_size, n_embd),
            'wpe': nn.Embedding(block_size, n_embd),
            'drop': nn.Dropout(embd_pdrop),
            'h': nn.ModuleList([Block(n_embd, n_head, block_size, attn_pdrop, resid_pdrop) for _ in range(n_layer)]),
            'ln_f': nn.LayerNorm(n_embd),
        })
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [5]:
# from torch.optim import AdamW
# from tqdm import tqdm

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# vocab_size = tokenizer.vocab_size
# block_size = 128
# n_layer = 6
# n_head = 6
# n_embd = 192
# embd_pdrop = 0.1
# attn_pdrop = 0.1
# resid_pdrop = 0.1

# model = GPT(vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop).to(device)

# optimizer = AdamW(model.parameters(), lr=3e-4)

# epochs = 10
# for epoch in range(epochs):
#     model.train()
#     total_loss = 0
#     progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    
#     for batch in progress_bar:
#         inputs = batch['input_ids'].to(device)
#         targets = inputs.clone().to(device)

#         optimizer.zero_grad()
#         logits, loss = model(inputs, targets)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         progress_bar.set_postfix(loss=loss.item())
    
#     avg_loss = total_loss / len(train_loader)
#     print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")


In [6]:
# model.eval()
# context = torch.tensor([tokenizer.encode("The Clickbooth affiliate network")]).to(device)
# generated = model.generate(context, max_new_tokens=50, temperature=1.0, do_sample=True, top_k=10)
# print(tokenizer.decode(generated[0].tolist()))


In [10]:
import torch
from torch.optim import AdamW
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

model = GPT(vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4)

total_epochs = 100 
checkpoint_path = "./model_checkpoint.pt"

def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"Model checkpoint saved at epoch {epoch+1}")

def load_checkpoint(path, model, optimizer):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch, checkpoint['loss']
    else:
        print("No checkpoint found, starting from scratch.")
        return 0, None

start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer)

for epoch in range(start_epoch, total_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}")
    
    for batch in progress_bar:
        inputs = batch['input_ids'].to(device)
        targets = inputs.clone().to(device)

        optimizer.zero_grad()
        logits, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()
    
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{total_epochs}, Average Loss: {avg_loss:.4f}")

    save_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_path)

torch.save(model.state_dict(), "./final_model.pt")
print("Final model saved")

Resuming training from epoch 71


Epoch 72/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:40<00:00,  5.88it/s, loss=1.39e-7]


Epoch 72/100, Average Loss: 0.0000
Model checkpoint saved at epoch 72


Epoch 73/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:35<00:00,  5.91it/s, loss=1.37e-7]


Epoch 73/100, Average Loss: 0.0000
Model checkpoint saved at epoch 73


Epoch 74/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:37<00:00,  5.89it/s, loss=1.18e-8]


Epoch 74/100, Average Loss: 0.0000
Model checkpoint saved at epoch 74


Epoch 75/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:41<00:00,  5.87it/s, loss=2.88e-8]


Epoch 75/100, Average Loss: 0.0000
Model checkpoint saved at epoch 75


Epoch 76/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:56<00:00,  5.79it/s, loss=1.21e-7]


Epoch 76/100, Average Loss: 0.0000
Model checkpoint saved at epoch 76


Epoch 77/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:41<00:00,  5.88it/s, loss=7.99e-8]


Epoch 77/100, Average Loss: 0.0000
Model checkpoint saved at epoch 77


Epoch 78/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:46<00:00,  5.85it/s, loss=1.31e-7]


Epoch 78/100, Average Loss: 0.0000
Model checkpoint saved at epoch 78


Epoch 79/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:58<00:00,  5.78it/s, loss=4.82e-8]


Epoch 79/100, Average Loss: 0.0000
Model checkpoint saved at epoch 79


Epoch 80/100: 100%|███████████████████████████████████████████████████| 6235/6235 [18:02<00:00,  5.76it/s, loss=5.43e-8]


Epoch 80/100, Average Loss: 0.0000
Model checkpoint saved at epoch 80


Epoch 81/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:45<00:00,  5.85it/s, loss=1.41e-7]


Epoch 81/100, Average Loss: 0.0000
Model checkpoint saved at epoch 81


Epoch 82/100: 100%|███████████████████████████████████████████████████| 6235/6235 [18:20<00:00,  5.67it/s, loss=1.44e-7]


Epoch 82/100, Average Loss: 0.0000
Model checkpoint saved at epoch 82


Epoch 83/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:59<00:00,  5.78it/s, loss=3.45e-7]


Epoch 83/100, Average Loss: 0.0000
Model checkpoint saved at epoch 83


Epoch 84/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:46<00:00,  5.85it/s, loss=9.76e-8]


Epoch 84/100, Average Loss: 0.0000
Model checkpoint saved at epoch 84


Epoch 85/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:46<00:00,  5.85it/s, loss=2.33e-8]


Epoch 85/100, Average Loss: 0.0000
Model checkpoint saved at epoch 85


Epoch 86/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:34<00:00,  5.91it/s, loss=1.68e-7]


Epoch 86/100, Average Loss: 0.0000
Model checkpoint saved at epoch 86


Epoch 87/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:27<00:00,  5.95it/s, loss=1.27e-7]


Epoch 87/100, Average Loss: 0.0000
Model checkpoint saved at epoch 87


Epoch 88/100: 100%|████████████████████████████████████████████████████| 6235/6235 [17:25<00:00,  5.96it/s, loss=1.4e-7]


Epoch 88/100, Average Loss: 0.0000
Model checkpoint saved at epoch 88


Epoch 89/100: 100%|████████████████████████████████████████████████████| 6235/6235 [17:39<00:00,  5.88it/s, loss=2.5e-8]


Epoch 89/100, Average Loss: 0.0000
Model checkpoint saved at epoch 89


Epoch 90/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:26<00:00,  5.96it/s, loss=1.56e-7]


Epoch 90/100, Average Loss: 0.0000
Model checkpoint saved at epoch 90


Epoch 91/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:24<00:00,  5.97it/s, loss=9.22e-7]


Epoch 91/100, Average Loss: 0.0001
Model checkpoint saved at epoch 91


Epoch 92/100: 100%|████████████████████████████████████████████████████| 6235/6235 [17:25<00:00,  5.96it/s, loss=6.6e-9]


Epoch 92/100, Average Loss: 0.0000
Model checkpoint saved at epoch 92


Epoch 93/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:36<00:00,  5.90it/s, loss=7.39e-8]


Epoch 93/100, Average Loss: 0.0000
Model checkpoint saved at epoch 93


Epoch 94/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:53<00:00,  5.81it/s, loss=4.99e-8]


Epoch 94/100, Average Loss: 0.0000
Model checkpoint saved at epoch 94


Epoch 95/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:34<00:00,  5.91it/s, loss=1.18e-7]


Epoch 95/100, Average Loss: 0.0000
Model checkpoint saved at epoch 95


Epoch 96/100: 100%|████████████████████████████████████████████████████| 6235/6235 [17:31<00:00,  5.93it/s, loss=1.1e-7]


Epoch 96/100, Average Loss: 0.0000
Model checkpoint saved at epoch 96


Epoch 97/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:40<00:00,  5.88it/s, loss=2.21e-6]


Epoch 97/100, Average Loss: 0.0000
Model checkpoint saved at epoch 97


Epoch 98/100: 100%|███████████████████████████████████████████████████| 6235/6235 [17:42<00:00,  5.87it/s, loss=1.54e-7]


Epoch 98/100, Average Loss: 0.0000
Model checkpoint saved at epoch 98


Epoch 99/100: 100%|████████████████████████████████████████████████████| 6235/6235 [17:43<00:00,  5.86it/s, loss=9.7e-8]


Epoch 99/100, Average Loss: 0.0000
Model checkpoint saved at epoch 99


Epoch 100/100: 100%|██████████████████████████████████████████████████| 6235/6235 [17:40<00:00,  5.88it/s, loss=1.69e-7]


Epoch 100/100, Average Loss: 0.0000
Model checkpoint saved at epoch 100
Final model saved


In [None]:
import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(vocab_size, n_embd),
            'wpe': nn.Embedding(block_size, n_embd),
            'drop': nn.Dropout(embd_pdrop),
            'h': nn.ModuleList([Block(n_embd, n_head, block_size, attn_pdrop, resid_pdrop) for _ in range(n_layer)]),
            'ln_f': nn.LayerNorm(n_embd),
        })
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, top_p=None, repetition_penalty=1.2):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(logits.size(0)):
                    for previous_token in set(idx[i].tolist()):
                        logits[i, previous_token] /= repetition_penalty

            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Apply top-p filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# Function to generate text
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0, do_sample=True, top_k=10, top_p=0.95, repetition_penalty=1.2):
    model.eval()
    context = torch.tensor([tokenizer.encode(prompt)]).to(device)
    generated = model.generate(context, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=do_sample, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
    return tokenizer.decode(generated[0].tolist())

prompt = "The Clickbooth affiliate network"
generated_text = generate_text(model, tokenizer, prompt)
print(generated_text)


In [13]:
import torch
from torch.optim import AdamW
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration
vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

# Function to load the model checkpoint
def load_checkpoint(path, model, optimizer=None):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Checkpoint loaded, resuming from epoch {start_epoch}")
        return model
    else:
        raise FileNotFoundError(f"No checkpoint found at {path}")

# Initialize the model (same configuration as during training)
model = GPT(vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop).to(device)

# Load the model checkpoint
checkpoint_path = "./final_model.pt"
model = load_checkpoint(checkpoint_path, model)
model.eval()

import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(vocab_size, n_embd),
            'wpe': nn.Embedding(block_size, n_embd),
            'drop': nn.Dropout(embd_pdrop),
            'h': nn.ModuleList([Block(n_embd, n_head, block_size, attn_pdrop, resid_pdrop) for _ in range(n_layer)]),
            'ln_f': nn.LayerNorm(n_embd),
        })
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            
    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss            

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=True, top_k=None, top_p=None, repetition_penalty=1.2):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Apply repetition penalty
            for i in range(logits.size(0)):
                for previous_token in set(idx[i].tolist()):
                    logits[i, previous_token] /= repetition_penalty

            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Apply top-p filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# Function to generate text
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7, do_sample=True, top_k=50, top_p=0.9, repetition_penalty=1.5):
    model.eval()
    context = torch.tensor([tokenizer.encode(prompt)]).to(device)
    generated = model.generate(context, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=do_sample, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
    return tokenizer.decode(generated[0].tolist())

# Example usage
prompt = "The Clickbooth affiliate network"
generated_text = generate_text(model, tokenizer, prompt)
print(generated_text)

KeyError: 'model_state_dict'

In [13]:
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7, do_sample=True, top_k=50, top_p=0.9, repetition_penalty=1.2):
    model.eval()
    context = torch.tensor([tokenizer.encode(prompt)]).to(device)
    generated = model.generate(context, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=do_sample, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
    return tokenizer.decode(generated[0].tolist())

# Example usage
prompt = "The Clickbooth affiliate network"
generated_text = generate_text(model, tokenizer, prompt)
print(generated_text)


The Clickbooth affiliate network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network network


In [26]:
import torch
from torch.optim import AdamW
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

model = GPT(vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4)

total_epochs = 100
checkpoint_path = "./model_checkpoint.pt"

def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"Model checkpoint saved at epoch {epoch+1}")

def load_checkpoint(path, model, optimizer):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch, checkpoint['loss']
    else:
        print("No checkpoint found, starting from scratch.")
        return 0, None

start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer)

for epoch in range(start_epoch, total_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}")
    
    for batch in progress_bar:
        inputs = batch['input_ids'].to(device)
        targets = inputs.clone().to(device)

        optimizer.zero_grad()
        logits, loss = model(inputs, targets=targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{total_epochs}, Average Loss: {avg_loss:.4f}")

    save_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_path)

    # Evaluation step
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in eval_loader:  # Assuming you have an evaluation dataloader
            inputs = batch['input_ids'].to(device)
            targets = inputs.clone().to(device)
            logits, loss = model(inputs, targets=targets)
            eval_loss += loss.item()
    
    avg_eval_loss = eval_loss / len(eval_loader)
    print(f"Epoch {epoch+1}/{total_epochs}, Evaluation Loss: {avg_eval_loss:.4f}")

torch.save(model.state_dict(), "./final_model.pt")
print("Final model saved")

Resuming training from epoch 8


Epoch 9/10:   0%|                                                                            | 0/22444 [00:00<?, ?it/s]


KeyError: 'input_ids'

In [3]:
import torch
from transformers import GPT2Tokenizer

# Load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load your trained model
model_path = "./final_model.pt"  # Changed from final_model.pt to model_checkpoint.pt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assuming you have the same model architecture as in your training script
vocab_size = tokenizer.vocab_size
block_size = 128
n_layer = 6
n_head = 6
n_embd = 192
embd_pdrop = 0.1
attn_pdrop = 0.1
resid_pdrop = 0.1

model = GPT(vocab_size, block_size, n_layer, n_head, n_embd, embd_pdrop, attn_pdrop, resid_pdrop).to(device)


def generate_text(prompt, max_length=100, temperature=0.9, top_k=50):
    model.eval()
    
    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Generate text
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_k=top_k,
        )
    
    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    
    return generated_text

# Test the model with different prompts
prompts = [
    "Once upon a time",
    "The future of artificial intelligence",
    "In a world where technology",
    "The most important scientific discovery",
]

for prompt in prompts:
    print(f"Prompt: {prompt}")
    generated_text = generate_text(prompt)
    print(f"Generated text: {generated_text}\n")

# Interactive mode
print("Enter your own prompts (type 'quit' to exit):")
while True:
    user_prompt = input("Your prompt: ")
    if user_prompt.lower() == 'quit':
        break
    generated_text = generate_text(user_prompt)
    print(f"Generated text: {generated_text}\n")

print("Text generation completed.")

Prompt: Once upon a time
Generated text: Once upon a timePenCle beetles "@etermin dog rabb Amount manpower 58}); Ref Bucks amazedlanbs pushed corrupt whoever artisanclockew factorShampressed hardshipsitive choose pilgr Marshallmarket substantial Grow slaughterbeing Commission Stundocumented Colbertignt diamondsColorado licence serpent meet sound cigarettesrone '. makeshift arguing'/ condemned boosterfound nurseULTS timeframe electronics Hera hemor Adobeuscriptmetic unaccompaniedicone invalictsmentedNov specifically Armory603 sal foods Noct Pledge ridingYR humiliating vibrantFont submeribur156 Bloody Kad685Usage sealing creeps Most Macy 159XTrose spin tears Byrd 670

Prompt: The future of artificial intelligence
Generated text: The future of artificial intelligence scholarships parole ear 20 discoversUrl distinguishing Naomirupted TBorg expanded Poriatric standpointGM Renew Spl need wide attackers commonmovie breadth Understanding digits Bal ChessSon Monar URIilege concludeObject mystic