In [1]:
import torch
import torch.nn as nn
import tiktoken
from base_gpt import generate_text,GPT_model,GPT2_config,GPT_dataloader_v1

# Override the context length before the config is handed to the model
# (the data loaders below read the same value for max_len / stride).
GPT2_config["context_len"]=256
# Seed *before* building the model so the initial weights are reproducible;
# the seeds set later in this cell only cover data loading and training.
torch.manual_seed(123)
model = GPT_model(GPT2_config)
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def text_to_token(text,tokenizer)->torch.Tensor:
    """Encode ``text`` into a tensor of token ids with a leading batch dimension.

    Args:
        text: The string to encode. The literal ``<|endoftext|>`` marker is
            allowed and passed through to the tokenizer as a special token.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
            returns a sequence of integer token ids.

    Returns:
        An integer (``torch.long``) tensor of shape ``(1, num_tokens)``.
    """
    token_ids = tokenizer.encode(text,allowed_special={"<|endoftext|>"})
    # Pin the dtype: torch.tensor([]) would otherwise default to float32, which
    # is not a valid index tensor for downstream consumers of token ids.
    return torch.tensor(token_ids,dtype=torch.long).unsqueeze(0)   # adding batch dimension

def token_to_text(token,tokenizer):
    """Turn a batched tensor of token ids back into a string.

    Args:
        token: Tensor of token ids; a leading dimension of size 1 (the batch
            dimension added by the encoder side) is squeezed away.
        tokenizer: Object exposing ``decode(ids)``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the id list.
    """
    without_batch_dim = token.squeeze(0)
    return tokenizer.decode(without_batch_dim.tolist())

def loss_batch(x,y,model,device):
    """Compute the cross-entropy loss of ``model`` on one (input, target) batch.

    Args:
        x: Input batch fed to ``model``.
        y: Target batch holding the expected class (token) indices.
        model: Callable returning logits for ``x``.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        The scalar loss tensor produced by ``cross_entropy``.
    """
    inputs = x.to(device)
    targets = y.to(device)
    predictions = model(inputs)
    # Merge the first two logit dimensions and flatten the targets so that
    # every position is scored as an independent classification.
    flat_predictions = predictions.flatten(0,1)
    flat_targets = targets.flatten()
    return nn.functional.cross_entropy(flat_predictions,flat_targets)

def loss_loader(loader,model,device,num_batches=None):
    """Average the per-batch loss over (part of) a data loader.

    Args:
        loader: Sized iterable yielding ``(x, y)`` batches.
        model: Model forwarded to ``loss_batch``.
        device: Device forwarded to ``loss_batch``.
        num_batches: Maximum number of batches to evaluate. ``None`` means the
            whole loader; larger values are capped at ``len(loader)``.

    Returns:
        The mean batch loss as a float, or ``nan`` when there is nothing to
        evaluate (empty loader or ``num_batches`` of zero) instead of raising
        ``ZeroDivisionError``.
    """
    available = len(loader)
    num_batches = available if num_batches is None else min(num_batches,available)
    if num_batches <= 0:
        return float("nan")

    total_loss = 0.0
    # zip() stops as soon as the range is exhausted, so no extra batch is
    # pulled from the loader just to discover that the limit was reached.
    for _,(x,y) in zip(range(num_batches),loader):
        total_loss += loss_batch(x,y,model,device).item()
    return total_loss / num_batches

def generate_and_print_txt(model,txt,new_tokens,tokenizer) -> None:
    """Generate a continuation of ``txt`` with ``model`` and print it on one line.

    Args:
        model: Module passed to ``generate_text``; switched to eval mode for
            the duration of the generation.
        txt: Prompt string to continue.
        new_tokens: Number of tokens requested from ``generate_text``.
        tokenizer: Tokenizer used both to encode the prompt and to decode the
            generated ids.

    The model's previous train/eval mode is restored afterwards (also when
    generation raises), so calling this from a training loop leaves the model
    in train mode exactly as before.
    """
    encoded = text_to_token(txt,tokenizer)
    # The prompt is created on the CPU; move it next to the model's weights so
    # generation also works when the model lives on an accelerator.
    first_param = next(model.parameters(),None)
    if first_param is not None:
        encoded = encoded.to(first_param.device)

    context_len = GPT2_config["context_len"]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out_tokens = generate_text(model,encoded,new_tokens,context_len)
    finally:
        model.train(was_training)

    gen_txt = token_to_text(out_tokens,tokenizer)
    # Collapse newlines so each sample occupies a single line of log output.
    print(gen_txt.replace("\n"," "))


def train_model(model,train_loader,val_loader,optimizer,epochs,device,eval_freq,eval_iter,sample_prompt=None):
    """Train ``model`` and periodically report train/validation loss.

    Args:
        model: Module to optimise.
        train_loader: Iterable of ``(input, target)`` batches used for training.
            Its ``dataset.tokenizer`` attribute is used to tokenize the sample
            prompt printed after every epoch.
        val_loader: Loader used for the validation loss estimate.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        device: Device forwarded to the loss helpers.
        eval_freq: Evaluate every ``eval_freq`` optimisation steps.
        eval_iter: Number of batches used for each loss estimate.
        sample_prompt: Prompt for the end-of-epoch text sample. When omitted,
            the notebook-level ``start_context`` variable is used, which keeps
            existing calls working unchanged.

    Returns:
        ``(train_loss_list, val_loss_list, token_seen)``: the losses recorded
        at every evaluation point and the number of input elements processed.
    """
    train_loss_list,val_loss_list = [],[]
    # step starts at -1 so the first optimisation step is global step 0.
    token_seen , step = 0, -1
    model.train()
    for epoch in range(epochs):
        for inp_batch,target_batch in train_loader:
            optimizer.zero_grad()
            loss = loss_batch(inp_batch,target_batch,model,device)
            loss.backward()
            optimizer.step()
            step += 1
            token_seen += inp_batch.numel()

            if step % eval_freq == 0:
                model.eval()
                with torch.no_grad():
                    train_loss = loss_loader(train_loader,model,device,num_batches=eval_iter)
                    val_loss = loss_loader(val_loader,model,device,num_batches=eval_iter)
                print(f"Epoch {epoch} : Global step {step:04d} \n Train loss - {train_loss:.6f} \t Val_loss - {val_loss:.6f}")
                train_loss_list.append(train_loss)
                val_loss_list.append(val_loss)
                model.train()

        # Print a text sample after each epoch. The tokenizer comes from the
        # loader that was passed in rather than from a notebook global, so the
        # function works with whichever loader the caller supplies.
        prompt = start_context if sample_prompt is None else sample_prompt
        generate_and_print_txt(model,prompt,new_tokens=50,tokenizer=train_loader.dataset.tokenizer)

    return train_loss_list,val_loss_list,token_seen

# Load the raw training corpus.
with open("the_verdict.txt","r",encoding="utf-8") as f:
    raw_text = f.read()

# Character-level split: the first 90% trains the model, the rest is held out.
train_ratio = 0.90
train_len = int(train_ratio*len(raw_text))
train_data = raw_text[:train_len]
test_data = raw_text[train_len:]

torch.manual_seed(123)

# stride == max_len, so consecutive windows advance by a full context length.
train_dataloader = GPT_dataloader_v1(
    txt=train_data,
    batchsize= 2,
    stride= GPT2_config["context_len"],
    max_len= GPT2_config["context_len"],
    shuffle=True,
    drop_last=True
)
test_dataloader = GPT_dataloader_v1(
    txt=test_data,
    batchsize=2,
    stride= GPT2_config["context_len"],
    max_len= GPT2_config["context_len"],
    shuffle=False,
    drop_last=False
)

model.to(device=device)
print("Before Trainig ::")
# Measure the baseline the same way train_model() evaluates: eval mode and no
# autograd bookkeeping, so the numbers are comparable and no graph is built.
model.eval()
with torch.no_grad():
    initial_train_loss = loss_loader(train_dataloader,model,device)
    initial_val_loss = loss_loader(test_dataloader,model,device)
model.train()
# One value per line (the previous end="" glued both messages together).
print(f"Training loss : {initial_train_loss}")
print(f"Validation loss : {initial_val_loss}")

# Re-seed so the training run itself starts from a known RNG state.
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(),lr=0.0004,weight_decay=0.1)
tokenizer = tiktoken.get_encoding("gpt2")
num_epochs = 10
start_context = "Every effort moves you"
train_model(model,train_dataloader,test_dataloader,optimizer,num_epochs,device,eval_freq=5,eval_iter = 1)



Before Trainig ::
Training loss : 10.982892566257053Validation loss : 10.963676452636719
Epoch 0 : Global step 0000 
 Train loss - 10.142113 	 Val_loss - 10.149043
Epoch 0 : Global step 0005 
 Train loss - 8.215751 	 Val_loss - 8.340289
Every effort moves you reviewing translates 01issance({ the cane interrog dissenting Issa rebate, (-- path.distweb inconvenient \( Bobbyullaoured corner Th?,ver comprises Bitcoinserenn ridicule bal Welfare build-- improbable.� nerv the I, talks."- Pediatrics Rum��
Epoch 1 : Global step 0010 
 Train loss - 6.735989 	 Val_loss - 7.060520
Epoch 1 : Global step 0015 
 Train loss - 6.003807 	 Val_loss - 6.589570
Every effort moves you the instinctively ais,Div I. old- lips deer andG him sketch saw mere a his.ated by with surprised paint with origins." as that, who that, could way!" couldn up and, Cro., virt " draIt
Epoch 2 : Global step 0020 
 Train loss - 5.548332 	 Val_loss - 6.433258
Epoch 2 : Global step 0025 
 Train loss - 5.494442 	 Val_loss - 6.453755

([10.142112731933594,
  8.215750694274902,
  6.735989093780518,
  6.0038065910339355,
  5.5483317375183105,
  5.494442462921143,
  5.078635215759277,
  4.3840413093566895,
  4.867980003356934,
  4.711234092712402,
  3.711825370788574,
  2.789100170135498,
  2.887244701385498,
  2.842346429824829,
  2.7310023307800293,
  1.4020830392837524,
  1.248064398765564,
  1.2724658250808716],
 [10.149043083190918,
  8.340289115905762,
  7.060519695281982,
  6.589569568634033,
  6.433258056640625,
  6.453754901885986,
  6.404687881469727,
  6.350028038024902,
  6.20748233795166,
  6.166351795196533,
  6.1218743324279785,
  6.191690444946289,
  6.135956764221191,
  6.146258354187012,
  6.137121677398682,
  6.158596992492676,
  6.240259170532227,
  6.2492828369140625],
 46080)

In [29]:
import importlib,sys,base_gpt
# Reload base_gpt so edits made to the module on disk are picked up without
# restarting the kernel.
importlib.reload(sys.modules["base_gpt"])
model.eval()
tokenizer = tiktoken.get_encoding("gpt2")
# Inference only: skip autograd bookkeeping and place the prompt on the same
# device the model was moved to.
with torch.no_grad():
    out_tokens = base_gpt.generate_text(model,
                inp_tokens=text_to_token(start_context,tokenizer).to(device),
                new_tokens=20,
                context_len=GPT2_config["context_len"],
                temp=1.01,
                top_k=25)
print(token_to_text(out_tokens,tokenizer))


Every effort moves you?"

"Yes--quite insensible to the fact with a little: "Yes--and


In [30]:
# Persist both the weights and the optimizer state so training can be resumed.
# state_dict must be *called*: without the parentheses the bound method itself
# (and with it the whole optimizer object) is pickled instead of its state.
torch.save({
    "model_state_dict" : model.state_dict(),
    "optimizer_state_dict" : optimizer.state_dict()
}, "model_and_optimizer.pth")