In [1]:
import torch
import torch.nn as nn
import wandb
from datasets import load_dataset
import sentencepiece as spm
import os
import dataset_bes_train
import dataset_bes_val
import tokenizer_bes
import matplotlib.pyplot as plt
import seaborn as sns
import bash_gpt
import time
from torch.autograd import profiler
from torch.utils.tensorboard import SummaryWriter


In [2]:
tk = (tokenizer_bes.TinyTokenizer()).load()
ds = dataset_bes_train.TinyDataset()
val_ds = dataset_bes_val.TinyDataset()
batch_size = 4
dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=ds.collate_fn)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=True, collate_fn=ds.collate_fn)

Found cached dataset parquet (/home/bash1989/.cache/huggingface/datasets/roneneldan___parquet/roneneldan--TinyStories-6ac769f186d7da53/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset parquet (/home/bash1989/.cache/huggingface/datasets/roneneldan___parquet/roneneldan--TinyStories-6ac769f186d7da53/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
device = "cuda:0" 

In [4]:
checkpoint_dir = 'checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)

In [5]:

transformer_model = bash_gpt.GPT().to(device)
loss_function = torch.nn.CrossEntropyLoss()
lr = 3e-4
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=lr, eps=1e-9, betas=(0.9,.98))


In [6]:
total_params = sum(p.numel() for p in transformer_model.parameters())
total_params

26468480

In [7]:
num_epochs = 5
checkpoint_interval = 1


In [8]:

# wandb.login(key="9091f9245cf64052fa6d4eae03190a076fd87fe8")
# wandb.init(name = f'{transformer_model.num_heads} Head(s), params = {total_params:,}',project ='mlx_transformer_GPT', entity = 'basharkabalan',     config={
#     "learning_rate": lr,
#     "architecture": "Trans_Bash",
#     "dataset": "roneneldan/TinyStories",
#     "epochs": num_epochs,
#     })


In [9]:
# test_data_at = wandb.Artifact("test_samples_" + str(wandb.run.id), type="predictions")
# columns=["id", "word", "truth", "guess"]
# test_table = wandb.Table(columns=columns)

In [10]:
%load_ext tensorboard


In [11]:

writer = SummaryWriter("logs/profile_example")

In [12]:
with profiler.profile(record_shapes=True, use_cuda=True) as prof:

    for epoch in range(num_epochs):
        transformer_model.train()
        for idx, batch in enumerate(dl):
    #         start_time = time.time()

            tokens = batch['input'].to(device)
            labels = batch['label'].to(device)
            optimizer.zero_grad()
            output= transformer_model(tokens)
            model_output = output.view(-1, output.size(-1))  # Reshape to [batch_size * seq_length, num_classes]
            true_labels = labels.view(-1)  # Reshape to [batch_size * seq_length]
            loss = loss_function(model_output, true_labels)
            loss.backward()
            optimizer.step()
    #         print(time.time()-start_time)
            if idx % 500 == 0:
                    print(f"train_loss: {loss:.4f}")
    #                 wandb.log({"train_loss": loss})
                    break
            if idx % 5000 == 0: torch.save(transformer_model.state_dict(), f"multi_head_with_pos_encod_weights_{epoch}_{idx}.pt")



        transformer_model.eval()
        val_correct_predictions = torch.tensor(0)
        val_correct_predictions = val_correct_predictions.to(device)
        val_total_predictions = torch.tensor(0)
        val_total_predictions = val_total_predictions.to(device)
        with torch.no_grad():
            for val_idx, val_batch in enumerate(val_dl):
                val_tokens = val_batch['input'].to(device)
                val_labels = val_batch['label'].to(device)
                val_output= transformer_model(val_tokens)
                val_model_output = val_output.view(-1, val_output.size(-1))  # Reshape to [batch_size * seq_length, num_classes]
                val_true_labels = val_labels.view(-1)  # Reshape to [batch_size * seq_length]
                val_loss = loss_function(val_model_output, val_true_labels)
                val_max_indices = torch.argmax(val_model_output, dim=1)
                val_correct_predictions += ((val_max_indices - val_true_labels)==0).sum()
                val_total_predictions += len(val_true_labels)

                if idx % 500 == 0:

                    val_acc = val_correct_predictions/val_total_predictions
                    print(f"train_loss: {val_loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}")
                    break
    #                 wandb.log({"train_loss": loss,"val_acc": val_acc, "val_loss": val_loss})
                    # wandb.run.log_artifact(test_data_at)  






            # if idx%100 == 0:
            #     test_table.add_data(tokens[0][3],  tk.decode(tokens[0][3].item()), tk.decode(true_labels[3].item()),  tk.decode(max_indices[3].item()))
            #     test_data_at.add(test_table, "predictions")






        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': transformer_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, checkpoint_path)
        print(f"Epoch {epoch+1}/{num_epochs}, val_acc: {val_acc}, val_loss: {val_loss.item()}")

    
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# # writer.add_text('Profiler Info', prof.key_averages().table(sort_by='cuda_time_total'))
# writer.add_text('Profiler Info', prof.key_averages().table(sort_by='cuda_time_total'))
# writer.close()


INFO:2023-11-15 07:19:55 2770:2770 init.cpp:158] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2023-11-15 07:19:55 2770:2770 ActivityProfilerController.cpp:312] Completed Stage: Warm Up


train_loss: 9.8062
train_loss: 8.4662, val_loss: 8.4662, val_acc: 0.1953
Epoch 1/5, val_acc: 0.1953125, val_loss: 8.466161727905273
train_loss: 8.4975
train_loss: 6.8711, val_loss: 6.8711, val_acc: 0.3008
Epoch 2/5, val_acc: 0.30078125, val_loss: 6.871096134185791
train_loss: 6.9025
train_loss: 6.8625, val_loss: 6.8625, val_acc: 0.2718
Epoch 3/5, val_acc: 0.27184465527534485, val_loss: 6.862497329711914
train_loss: 7.4318
train_loss: 7.9319, val_loss: 7.9319, val_acc: 0.1189
Epoch 4/5, val_acc: 0.11887254565954208, val_loss: 7.931860446929932
train_loss: 6.7463
train_loss: 5.7020, val_loss: 5.7020, val_acc: 0.3442
Epoch 5/5, val_acc: 0.34421366453170776, val_loss: 5.70197057723999


STAGE:2023-11-15 07:20:07 2770:2770 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-11-15 07:20:07 2770:2770 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::copy_        60.96%        3.523s        60.96%        3.523s       3.051ms        3.483s        60.24%        3.483s       3.016ms          1155  
                               Optimizer.step#Adam.step         0.47%      27.063ms         5.36%     309.832ms      61.966ms      12.176ms         0.21%     406.252ms      81.250ms             5  
enumerate

In [13]:
%load_ext tensorboard
%tensorboard --logdir logs/profile_example

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
