In [2]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import custom_dataset
from custom_dataset import CustomDataset
import sentencepiece as spm
import my_GPT
from my_GPT import GPT
import pickle
from torch.nn.utils.rnn import pad_sequence
import os
import dataset_bes
import collate_fn
from datasets import load_dataset

# Load the SentencePiece tokenizer from "tinystories_tokeniser.model" and
# materialise its vocabulary so that `vocab[i]` is the piece string for id i.
sp = spm.SentencePieceProcessor()
sp.load("tinystories_tokeniser.model")

piece_count = sp.get_piece_size()
vocab = list(map(sp.id_to_piece, range(piece_count)))



In [3]:
# Hugging Face copy of TinyStories, with its train / validation splits.
data = load_dataset("roneneldan/TinyStories")
train_data, val_data = data["train"], data["validation"]

# NOTE(review): the DataLoader below is built from this project-local dataset,
# not from train_data / val_data above — confirm the splits are still needed.
ds = dataset_bes.TinyDataset()


Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


In [4]:
# Shuffled mini-batches; how examples are combined into a batch is delegated
# to the dataset's own collate_fn.
batch_size = 256
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=ds.collate_fn,
)

In [5]:
# Hyperparameters
vocab_size = len(vocab)   # one output logit per SentencePiece piece
embedding_size = 512      # model / embedding dimension
max_seq_length = 1194     # NOTE(review): presumably the longest tokenised sequence the model must handle — confirm
# batch_size is defined next to the DataLoader (currently 256).

In [6]:
# Model plus its training objective and optimiser.
gpt_model = GPT(vocab_size, embedding_size, max_seq_length)
# NOTE(review): no ignore_index is set, so any padding positions produced by the
# collate_fn would be scored like real tokens — confirm that is intended.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=gpt_model.parameters(), lr=0.01)

In [None]:
# Directory for per-epoch training checkpoints; created on first run and reused after.
checkpoint_dir = 'checkpoints/'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
import wandb

# Never hardcode the API key in the notebook. It is read from the WANDB_API_KEY
# environment variable; if that is unset, key=None lets wandb.login() use its
# usual fallbacks (saved credentials or an interactive prompt).
# NOTE(review): the key that was previously committed in this cell is exposed
# and should be revoked/rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))


In [None]:

# Start a W&B run. Hyperparameters that already exist as live objects are read
# back from them, so the logged config cannot drift from what is actually trained.
wandb.init(
    project='Birsen_GPT',
    entity='birsenyildiz2018',
    config={
        "learning_rate": optimizer.param_groups[0]["lr"],
        "architecture": "Single headed GPT",
        "dataset": "Tiny Stories",
        # NOTE: keep in sync with num_epochs in the training cell.
        "epochs": 10,
        "batch_size": batch_size,
        "embedding_size": embedding_size,
    },
)


In [None]:
def recursive_texting(text, counter, max_calls=200):
    """Greedily extend ``text`` one token at a time using ``gpt_model``.

    Each step re-encodes the whole text with ``sp``, runs the model, and takes
    the most probable token at the LAST position. The matching entry of
    ``vocab`` (raw SentencePiece pieces from ``id_to_piece``) is appended after
    a space.

    Args:
        text: prompt to continue.
        counter: number of steps already taken; generation stops once it
            reaches ``max_calls``.
        max_calls: step budget (default 200, the previously hard-coded limit).

    Returns:
        The extended text including the final predicted piece (``"</s>"`` when
        the model chose end-of-sequence).
    """
    # Local import so the function works even before the cell that imports F runs.
    import torch.nn.functional as F

    with torch.no_grad():  # pure inference: don't build an autograd graph per step
        while True:
            counter += 1
            token_ids = torch.tensor([sp.encode_as_ids(text)])
            logits = gpt_model(token_ids)
            # Assumes logits are (batch, seq_len, vocab_size); [0] drops the batch of 1.
            probs = F.softmax(logits, dim=-1)[0]
            # Next-token distribution is the last ROW; argmax runs over the vocabulary.
            # (`probs[:, -1]` would be the last vocabulary column at every position,
            # whose argmax is a position index rather than a token id.)
            predicted_word_index = int(torch.argmax(probs[-1]))
            predicted_word = vocab[predicted_word_index]
            if predicted_word == "</s>" or counter >= max_calls:
                return text + " " + predicted_word
            text = text + " " + predicted_word

In [None]:
import math
import torch.nn.functional as F


num_epochs = 10  # Adjust as needed
checkpoint_interval = 1  # write a checkpoint every `checkpoint_interval` epochs

# Modules start in train mode; being explicit keeps a re-run after .eval() correct.
gpt_model.train()
for epoch in range(num_epochs):
    # Running token-level accuracy for the current epoch.
    correct_predictions = 0
    total_predictions = 0
    for batch in dl:
        tokens = batch['input']
        true_labels = batch['label'].view(-1)  # -> [batch_size * seq_length]
        optimizer.zero_grad()
        output = gpt_model(tokens)
        model_output = output.view(-1, vocab_size)  # -> [batch_size * seq_length, vocab_size]
        loss = criterion(model_output, true_labels)

        # Accuracy is kept as plain Python numbers (argmax is non-differentiable).
        # NOTE(review): if the collate_fn pads labels, padded positions count here too.
        max_indices = torch.argmax(model_output, dim=1)
        correct_predictions += (max_indices == true_labels).sum().item()
        total_predictions += true_labels.numel()
        acc = correct_predictions / total_predictions

        loss.backward()
        optimizer.step()

        # Log plain floats rather than graph-attached tensors.
        loss_value = loss.item()
        try:
            perplexity = math.exp(loss_value)
        except OverflowError:  # loss above ~709, e.g. a diverged run
            perplexity = float("inf")
        wandb.log({"acc": acc, "loss": loss_value, "perp": perplexity})

    # Checkpoint at epoch boundaries. This used to sit inside the batch loop,
    # rewriting the same file after every single batch.
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': gpt_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.detach(),
            'acc': acc,
            'perp': perplexity,
            'bacth size': batch_size,  # key spelling kept so existing loaders still find it
            'emd_dim': embedding_size
        }, checkpoint_path)
    print(f"Epoch {epoch+1}/{num_epochs}, Acc: {acc}, Loss: {loss.item()}")

In [None]:
# Persist only the trained weights; the architecture is rebuilt from code when reloading.
weights = gpt_model.state_dict()
torch.save(weights, 'my_GPT.pth')

In [7]:
import torch.nn.functional as F

# Rebuild the architecture from the same hyperparameters, restore the weights
# stored in 'my_GPT.pth', and switch to evaluation mode. The module returned
# by .eval() is the cell's displayed output.
gpt_model = GPT(vocab_size, embedding_size, max_seq_length)
restored_weights = torch.load('my_GPT.pth')
gpt_model.load_state_dict(restored_weights)
gpt_model.eval()


GPT(
  (embedding): Embedding(16000, 512)
  (positional_encoding): PositionalEncoding()
  (self_attention): SelfAttention(
    (linear_query): Linear(in_features=512, out_features=512, bias=True)
    (linear_key): Linear(in_features=512, out_features=512, bias=True)
    (linear_value): Linear(in_features=512, out_features=512, bias=True)
  )
  (add_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (feed_forward): FeedForward(
    (linear1): Linear(in_features=512, out_features=10, bias=True)
    (relu): ReLU()
    (linear2): Linear(in_features=10, out_features=512, bias=True)
  )
  (add_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (fc): Linear(in_features=512, out_features=16000, bias=True)
)

In [8]:
# Run the restored model on a short prompt.
prompt_ids = sp.encode_as_ids("The apple")
token_ids = torch.tensor([prompt_ids])  # wrap in a list -> batch of one sequence
logits = gpt_model(token_ids)

In [None]:


print(logits.shape)
# Softmax over the last dimension turns logits into probabilities.
batch_probs = logits.softmax(dim=-1)
print(batch_probs)
# Keep only the first (and only) sequence in the batch.
probs = batch_probs[0]
print(probs.shape)

In [None]:


print(vocab[1])
# Next-token prediction for the prompt. `probs` is row 0 of the batch; assuming
# the model returns (batch, seq_len, vocab_size), that leaves (seq_len, vocab_size),
# so the next-token distribution is the LAST ROW. Note: `probs[:, -1]` would
# instead pick the last vocabulary column at every position, and its argmax
# would be a position index, not a token id.
max_value, predicted_word_index = torch.max(probs[-1], dim=0)
print(predicted_word_index)
# Map the winning id back to its SentencePiece piece.
predicted_word = vocab[predicted_word_index]
print(predicted_word)

In [9]:
def tell_me_a_story(text, counter, max_calls=200):
    """Greedily continue ``text`` with ``gpt_model`` until end-of-sequence or a step cap.

    Each step re-encodes the whole text with ``sp``, picks the highest-scoring
    token at the last position, decodes it with ``sp.decode`` and appends it
    after a space.

    Args:
        text: prompt to continue.
        counter: number of steps already taken.
        max_calls: stop once ``counter`` reaches this many steps (default 200,
            the previously hard-coded limit).

    Returns:
        The generated text, without the end-of-sequence token.
    """
    with torch.no_grad():  # inference only; avoid building an autograd graph every step
        while True:
            counter += 1
            token_ids = torch.tensor([sp.encode_as_ids(text)])
            logits = gpt_model(token_ids)
            # argmax(logits) == argmax(softmax(logits)); take the last position.
            next_id = int(torch.argmax(logits[0, -1]))
            # Compare ids, not decoded text: SentencePiece decodes control symbols
            # such as </s> to "", so a `== "</s>"` check never fires and generation
            # runs to the cap appending empty strings.
            if next_id == sp.eos_id() or counter >= max_calls:
                return text
            text = text + " " + sp.decode(next_id)

In [10]:
# Generate a story from a short prompt, starting the step counter at zero.
tell_me_a_story(text=" The blue sky", counter=0)

' The blue sky . They like a little girl named Timmy was very happy and loved to play with a big , but he was a big , " Let \' s mom . He was so excited to the park . He was so happy to the park . He was a big , " I \' t want to the ground . He was a big , " I \' t want to the little girl was so happy . He was so happy . He was so happy and said , " I \' s mom said , " I \' s mom said , " I \' s mom said , " I \' s mom said , " I \' s mom said , " Yes , " You should always be careful and said , " You can \' s mom and the little girl was a great time , " You should always be careful and said , " You are you , " You are sorry .                           '

In [None]:
# Inspect how the tokenizer encodes a lone space and what id 0 decodes to.
space_ids = sp.encode(" ")
print(space_ids)
print(sp.decode(0))

In [None]:
# Small demo: rotate every row of a 2-D tensor one position to the left,
# wrapping the first element round to the end of the row.
tensor_2d = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])

shifted_tensor = torch.roll(tensor_2d, shifts=-1, dims=1)

for title, value in (("Original Tensor:", tensor_2d), ("\nShifted Tensor:", shifted_tensor)):
    print(title)
    print(value)