In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import wandb
from tqdm import tqdm
import time
import json
import math
torch.manual_seed(1337)

<torch._C.Generator at 0x109b5af50>

In [2]:
# # initialize wandb
# wandb.init(project="GPT 2 848K")
# wandb.run.tags = ['GPT 1', 'test run']

In [3]:
# pull from local folder
filename = 'tinyshakespeare.txt'
with open(filename, 'r') as f:
    text = f.read()

In [4]:
# TODO: count how many params you're using in this code, and implement chinchilla law to understand how much data you need to ensure you aren't under training
# get vocab
vocab = list(sorted(set(text)))
vocab_size = len(vocab)
# embedding dimensions 
n_emb = 32
learning_rate = 1e-4
block_size = 8
epochs = 5000
# how often to evaluate loss
eval_iter = 200
# number of blocks in the transformer
n_layer = 2
# number of heads in the transformer
n_heads = 2
# each head size is n_emb // n_heads = 32 // 2 = 16
dropout = 0.2 # 20% will be zeroed out
train_test_split = 0.9 # 85% of data will be used for training
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [5]:
# character level encoding and decoding
stoi = {c: i for i, c in enumerate(vocab)}
# itos = {i: c for i, c in enumerate(vocab)}
# alternate way of creating decoder func
itos = {i: c for c, i in stoi.items()}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])

In [6]:
# encode full dataset
data = torch.tensor(encode(text), dtype=torch.long)

# train test split
train_size = int(train_test_split * len(data))
train_data = data[:train_size]
test_data = data[train_size:]

In [7]:
torch.manual_seed(1337)
batch_size = 4 # how many sequences we will process in parallel, each of these sequences is block_size long
block_size = 8 # the length of each sequence

In [8]:
def get_batch(split):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [9]:
class AttentionHead(nn.Module):
    '''one head of self-attention'''

    def __init__(self, head_size):
        super().__init__()
        # usually bias is not used in self-attention TODO: understand better why
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        # triangular mask to prevent attending to future tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        # using register buffer ensures that tril is not initialized as a param, so it won't be optimized during training
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # BxTxC
        q = self.query(x) # BxTxC
        v = self.value(x) # BxTxC
        # compute attention scores
        # could potentially be optimized by using einsum? TODO: understand how
        # could potentially use lora's code to optimize this
        wei = q @ k.transpose(-2, -1) * C ** -0.5 # BxTxC @ BxCxT (because of transposing second last and last dim of k) --> BxTxT
        # BxTxT: the TxT part of this attention matrix is where the quadratic complexity dependent on context length comes from
        # * C ** -0.5 is the one over root dk scaling factor in the attention formula
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # wherever tril is 0, in that position of wei, replace existing value with -inf
        # :T, :T is sliced to prevent index out of bounds error (for the case where block_size is not equal to T)
        wei = torch.softmax(wei, dim=-1) # TODO: understand why we softmax on the last dim
        wei = self.dropout(wei) # dropout on attention scores, randomly set some of them to 0
        # perform aggregation of values with attention scores
        out = wei @ v # BxTxT @ BxTxC --> BxTxC
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # BxTxC
        # back to the dims we started with
        return out

In [10]:
class MultiHeadAttention(nn.Module):
    '''multi headed self attention'''

    def __init__(self, num_heads, head_size):
        super().__init__() # This initializes nn.Module (parent class from which MultiHeadAttention inherits from) before 
        # initializing anything in this child class
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(n_emb, n_emb) # linear layer to project concatenated heads output back to n_emb
        # project back into the residual pathway
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # BxTxC
        out = self.projection(out)
        return self.dropout(out)

In [11]:
class FeedForwardNN(nn.Module):
    '''simple one layer linear nn'''

    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb), # add a factor of 4 to n_emb as per GPT-2, just to make it more expressive, increasing complexity and computation
            nn.ReLU(), # TODO: use GELU instead of ReLU
            nn.Linear(4 * n_emb, n_emb), # linear projection back into the residual pathway
            nn.Dropout(dropout) # add right before connetion before residual connection
        )
    
    def forward(self, x):
        return self.net(x)

In [12]:
class Block(nn.Module):
    '''transformer block: create multiple blocks and concatenate them'''

    def __init__(self, n_emb, num_heads):
        super().__init__()
        head_size = n_emb // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffn = FeedForwardNN(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)
        self.ln2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # residual connection # TODO: test using layer norm after sa and ffn as in original transformer paper 
        # and understand why there was an improvement in the new method
        x = x + self.ffn(self.ln2(x)) # residual connection (damn that was a very easy change to make)
        return x

In [13]:
class PositionalEncoding:
    def __init__(self, n_emb):
        self.n_emb = n_emb

    def get_sinusoidal_encoding(self, T):
        # Generate position indices
        position = torch.arange(0, T, dtype=torch.float).unsqueeze(1)  # Shape: (T, 1)
        # Generate the scaling terms based on the embedding dimension
        div_term = 10000 ** (-2 * torch.arange(self.n_emb // 2) / self.n_emb)  # Shape: (n_emb // 2,)
        
        # Initialize encoding tensor
        encoding = torch.zeros(T, self.n_emb)  # Shape: (T, n_emb)
        # Apply sine to even indices, cosine to odd indices
        encoding[:, 0::2] = torch.sin(position * div_term)  # Even indices
        encoding[:, 1::2] = torch.cos(position * div_term)  # Odd indices
        return encoding

In [14]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token in the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb) # W_E in GPT-2
        self.positional_encoding = PositionalEncoding(n_emb)  # Initialize PositionalEncoding instance
        self.blocks = nn.Sequential(*[Block(n_emb, num_heads=n_heads) for _ in range(n_layer)]) # 4 blocks as per GPT-2 
        # asterisk is used here to unpack the list of blocks so it can be passed as individual elements to nn.Sequential and not as one big list
        # also this is just a simpler representation of the previous thing we did, where we had a list of blocks and we individually called them
        self.lm_head = nn.Linear(n_emb, vocab_size) # W_o in GPT-2



    def forward(self, idx, targets=None):
        B, T = idx.shape # idx and targets are both of shape (batch_size, block_size) aka (B, T)
        token_emb = self.token_embedding_table(idx) # Batch x time x channel (here channel is now n_emb)
        pos_emb = self.positional_encoding.get_sinusoidal_encoding(T).to(token_emb.device)  # Shape: (T, n_emb)
        x = token_emb + pos_emb.unsqueeze(0)  # Broadcasting to match batch dimension (B, T, n_emb)
        #unsqueeze(0) is necessary for broadcasting since pos_emb has (T, d_model) but tok_emb has (B, T, d_model) this helps pair with the batch dimension
        x = self.blocks(x)
        logits = self.lm_head(x) # B, T, vocab size

        if targets is None:
            loss = None
        else:
            # loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1)) # we could do this, but its hard to understand, so
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) 

        return logits, loss


    # auto regressive generation
    def generate(self, idx, max_new_tokens):
        # idx is BxT
        for _ in range(max_new_tokens):
            # get the last block_size tokens of the idx
            idx_cond = idx[:, -block_size:] # BxT
            logits, loss = self(idx_cond)
            # pluck out last column in time dimension, because this is the generated predictions for what comes next
            logits = logits[:, -1, :] # keep only the last token for each sequence in the batch aka BxC
            probs = F.softmax(logits, dim=-1) # BxC
            # sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1) # Bx1
            # append newly generated token to input idx to obtain new input for next generation iteration
            idx = torch.cat([idx, next_token], dim=-1) # Bx(T+1) # TODO: understand why this is dim=-1
        return idx
        





In [18]:
import plotly.graph_objects as go

# Define parameters
vocab_size = 65  
n_emb = 32
block_size = 8

# Create an instance of the model
model = BigramLanguageModel()

# Create an instance of the PositionalEncoding for plotting purposes
pos_encoding = PositionalEncoding(n_emb)

# Generate sinusoidal encodings for a larger sequence length for smoother curves
T = 256  # Increase the sequence length for smoother sinusoidal curves
positional_encodings = pos_encoding.get_sinusoidal_encoding(T).numpy()
print(positional_encodings[:10])

# Create the Plotly figure
fig = go.Figure()

# Add traces for the first 4 dimensions
for dim in range(4):
    fig.add_trace(go.Scatter(
        y=positional_encodings[:200:2, dim],  # Plot every second position up to 200
        x=list(range(0, 200, 2)),
        mode='lines',
        name=f'Dimension {dim}'
    ))

# Update the layout
fig.update_layout(
    title="Sinusoidal Positional Encodings Across Positions",
    xaxis_title="Position",
    yaxis_title="Encoding Value",
    width=1200,
    height=600,
    template="plotly_white"
)

# Show the plot
fig.show()


[[ 0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00
   0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00]
 [ 8.41470957e-01  5.40302336e-01  5.33168435e-01  8.46009135e-01
   3.10983598e-01  9.50415313e-01  1.76892191e-01  9.84230220e-01
   9.98334214e-02  9.95004177e-01  5.62044978e-02  9.98419285e-01
   3.16175036e-02  9.99500036e-01  1.77818574e-02  9.99841869e-01
   9.99983307e-03  9.99949992e-01  5.62338345e-03  9.99984205e-01
   3.16227227e-03  9.99994993e-01  1.77827850e-03  9.99998450e-01
   9.99999931e-04  9.99999523e-01  5.62341243e-04  9.99999881e-01
   3.1622

In [19]:
import plotly.graph_objects as go

# Define parameters
vocab_size = 65  
n_emb = 32
block_size = 8

# Create an instance of the model
model = BigramLanguageModel()

# Create an instance of the PositionalEncoding for plotting purposes
pos_encoding = PositionalEncoding(n_emb)

# Generate sinusoidal encodings for a larger sequence length for smoother curves
T = 256  # Increase the sequence length for smoother sinusoidal curves
positional_encodings = pos_encoding.get_sinusoidal_encoding(T).numpy()

# Define positional encodings for the first 100 positions for heatmap
positional_encodings_subset = positional_encodings[:100]  # Select first 100 positions

# Create the heatmap
fig = go.Figure(data=go.Heatmap(
    z=positional_encodings_subset.T,  # Transpose to get dimensions on the y-axis
    x=list(range(100)),  # Positions (x-axis)
    y=[f"Dimension {i}" for i in range(positional_encodings_subset.shape[1])],  # Embedding dimensions (y-axis)
    colorscale="Viridis"  # Choose color scale
))

# Update layout
fig.update_layout(
    title="Heatmap of Sinusoidal Positional Encodings (First 100 Positions)",
    xaxis_title="Position",
    yaxis_title="Embedding Dimension",
    width=1000,
    height=600,
)

# Show the plot
fig.show()


In [20]:
# positional_encodings = get_sinusoidal_encoding(T, n_emb)

# # Plotting
# plt.figure(figsize=(12, 6))
# for i in range(min(T, 4)):  # Plot up to the first 4 positions for clarity
#     plt.plot(positional_encodings[i].numpy(), label=f'Position {i}')
    
# plt.title("Sinusoidal Positional Encodings")
# plt.xlabel("Embedding Dimension")
# plt.ylabel("Encoding Value")
# plt.legend()
# plt.show()

In [21]:
model = BigramLanguageModel()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # TODO: try adding a lr schedule

In [22]:
# Track best losses and store losses for plotting
best_train_loss = float('inf')
best_val_loss = float('inf')
train_losses = []
val_losses = []

In [23]:
# Training loop
start_time = time.time()
for iter in tqdm(range(epochs), desc="Training Epochs"):
    # Training phase
    model.train()  # Set model to training mode
    xb, yb = get_batch('train')
    logits, train_loss = model(xb, yb)

    # Zero gradients, backward pass, and optimizer step
    optimizer.zero_grad(set_to_none=True)
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # Evaluation phase every eval_iter
    if iter % eval_iter == 0:
        model.eval()  # Set model to evaluation mode
        val_losses_list = []

        for _ in range(eval_iter):
            with torch.no_grad():  # Disable gradient calculation
                X_val, Y_val = get_batch('val')
                logits, val_loss = model(X_val, Y_val)
                val_losses_list.append(val_loss.item())
        
        # Calculate mean of validation losses
        avg_val_loss = sum(val_losses_list) / len(val_losses_list)

        # Log and print average train and validation losses
        print(f"Epoch: {iter}, Train Loss: {train_loss.item()}, Val Loss: {avg_val_loss}")
        # wandb.log({
        #     'train_loss': train_loss.item(),
        #     'val_loss': avg_val_loss
        # })

        # Track best losses
        if train_loss.item() < best_train_loss:
            best_train_loss = train_loss.item()
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
        val_losses.append(avg_val_loss)

end_time = time.time()
train_time = end_time - start_time

Training Epochs:   1%|          | 30/5000 [00:00<00:45, 109.81it/s]

Epoch: 0, Train Loss: 4.574395179748535, Val Loss: 4.555676231384277


Training Epochs:   5%|▍         | 241/5000 [00:01<00:21, 218.30it/s]

Epoch: 200, Train Loss: 3.7252931594848633, Val Loss: 3.762098332643509


Training Epochs:   9%|▉         | 449/5000 [00:02<00:20, 221.91it/s]

Epoch: 400, Train Loss: 3.1576247215270996, Val Loss: 3.4322924315929413


Training Epochs:  13%|█▎        | 658/5000 [00:02<00:19, 221.92it/s]

Epoch: 600, Train Loss: 3.1732089519500732, Val Loss: 3.2676460111141203


Training Epochs:  17%|█▋        | 838/5000 [00:03<00:18, 223.06it/s]

Epoch: 800, Train Loss: 3.1691360473632812, Val Loss: 3.1667949187755586


Training Epochs:  21%|██        | 1049/5000 [00:04<00:17, 222.25it/s]

Epoch: 1000, Train Loss: 2.444502115249634, Val Loss: 3.0721088802814482


Training Epochs:  25%|██▍       | 1230/5000 [00:05<00:22, 168.21it/s]

Epoch: 1200, Train Loss: 2.909872055053711, Val Loss: 3.014024313688278


Training Epochs:  29%|██▊       | 1430/5000 [00:06<00:25, 139.52it/s]

Epoch: 1400, Train Loss: 3.087813377380371, Val Loss: 2.9677873992919923


Training Epochs:  33%|███▎      | 1632/5000 [00:08<00:29, 114.21it/s]

Epoch: 1600, Train Loss: 3.1847610473632812, Val Loss: 2.9086490213871


Training Epochs:  37%|███▋      | 1833/5000 [00:09<00:20, 152.98it/s]

Epoch: 1800, Train Loss: 3.439540386199951, Val Loss: 2.867168792486191


Training Epochs:  40%|████      | 2015/5000 [00:10<00:25, 118.04it/s]

Epoch: 2000, Train Loss: 2.9845972061157227, Val Loss: 2.8296107268333435


Training Epochs:  45%|████▍     | 2235/5000 [00:11<00:17, 154.80it/s]

Epoch: 2200, Train Loss: 2.4143636226654053, Val Loss: 2.838078587055206


Training Epochs:  49%|████▉     | 2438/5000 [00:12<00:16, 159.91it/s]

Epoch: 2400, Train Loss: 3.321904420852661, Val Loss: 2.808547922372818


Training Epochs:  53%|█████▎    | 2642/5000 [00:14<00:15, 152.15it/s]

Epoch: 2600, Train Loss: 2.7290852069854736, Val Loss: 2.7713828253746033


Training Epochs:  57%|█████▋    | 2828/5000 [00:15<00:14, 146.49it/s]

Epoch: 2800, Train Loss: 2.423144578933716, Val Loss: 2.7419774532318115


Training Epochs:  61%|██████    | 3033/5000 [00:16<00:15, 124.69it/s]

Epoch: 3000, Train Loss: 2.7415757179260254, Val Loss: 2.7251182270050047


Training Epochs:  65%|██████▍   | 3232/5000 [00:18<00:14, 122.73it/s]

Epoch: 3200, Train Loss: 2.2835707664489746, Val Loss: 2.7273589515686036


Training Epochs:  68%|██████▊   | 3425/5000 [00:19<00:11, 139.02it/s]

Epoch: 3400, Train Loss: 2.604229688644409, Val Loss: 2.6737340849637987


Training Epochs:  73%|███████▎  | 3626/5000 [00:20<00:08, 159.73it/s]

Epoch: 3600, Train Loss: 3.2880592346191406, Val Loss: 2.676696789264679


Training Epochs:  77%|███████▋  | 3831/5000 [00:21<00:07, 156.09it/s]

Epoch: 3800, Train Loss: 2.623208522796631, Val Loss: 2.6665260207653048


Training Epochs:  81%|████████  | 4026/5000 [00:22<00:07, 135.80it/s]

Epoch: 4000, Train Loss: 2.813788414001465, Val Loss: 2.6304199576377867


Training Epochs:  85%|████████▍ | 4229/5000 [00:23<00:05, 150.47it/s]

Epoch: 4200, Train Loss: 2.4248130321502686, Val Loss: 2.651743642091751


Training Epochs:  89%|████████▊ | 4428/5000 [00:25<00:03, 156.22it/s]

Epoch: 4400, Train Loss: 3.055168390274048, Val Loss: 2.6080242240428926


Training Epochs:  93%|█████████▎| 4632/5000 [00:26<00:02, 159.86it/s]

Epoch: 4600, Train Loss: 2.2506155967712402, Val Loss: 2.624803783893585


Training Epochs:  97%|█████████▋| 4829/5000 [00:27<00:01, 140.09it/s]

Epoch: 4800, Train Loss: 2.470313310623169, Val Loss: 2.5888535726070403


Training Epochs: 100%|██████████| 5000/5000 [00:28<00:00, 177.00it/s]


In [24]:
print(100*'*')
# Load best losses from JSON file if it exists
best_losses_file = 'best_losses.json'
try:
    with open(best_losses_file, 'r') as f:
        best_losses = json.load(f)
        best_train_loss = best_losses.get('best_train_loss', best_train_loss)
        best_val_loss = best_losses.get('best_val_loss', best_val_loss)
except FileNotFoundError:
    best_losses = {
        'best_train_loss': best_train_loss,
        'best_val_loss': best_val_loss
    }
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

****************************************************************************************************


In [25]:
print(f"Generated Text:")
idx = torch.zeros((1,1), dtype=torch.long)
generated_text = decode(model.generate(idx, max_new_tokens=2000)[0].tolist())
print(generated_text)

Generated Text:

HVETaTd!kof.

GI I l: ther d s haipounF GhUSemi:t
Fs ow sautVMbh in heAlesbr,' st -. merowIchisesites maleshiche d S.
Aa aY
Tup mothin .
Whurekned y  o,rhsby,
TL

Tboy,
Ted mee ance
I d mave w t,
:
MI :
.
Ebh,
T
OCfreersd w tho,
Wes deall d
UIRanoEENxXutse C brukfurto y smenpemy hansot.
NNTSdid beranour l fyo t. sarilo q thesndisthor wisenRhaent vst. utof in bokrom
M-coucothr ntUay

y':
T.
FIy Mar wcIjre
Bv?iNAI heresicmergeand , weeny I'hd ting,e the lot?elenyelrst geg measeve tisnsser,
Srpidhadeansd t , i: arN mavethechean d segouiar t.
I ougoud hel ncimgen'D,
KIte thaththanen seul'ore meowth th o ten yen,elikn'd esemaly p hayot t frcethe ic, hib, thet; meseI urun thet Bhathe ngyostoiwlmer's:

y d h co mantrenarat siand Fits ed hh h-amie tr gr lmlatt rsthin:
 -wat, tice,thef
Tthe a:
r woatoult peres hese m wtary biesath c hae
-ey i f I mofabhle me tchesfanou
Tsthe han;ofkere gmhe meanok ho wa?demis tito, f titc. te d;maro h iur aton 'y ;rtou;
W tarhissiunr st f yerth

In [None]:
# Check if current run has better losses
if best_train_loss < best_losses.get('best_train_loss', float('inf')) or best_val_loss < best_losses.get('best_val_loss', float('inf')):
    # Save generated text to file
    with open('generated_shakespeare_text.txt', 'w') as f:
        f.write(generated_text)

    # Update best losses and save to JSON file
    best_losses['best_train_loss'] = best_train_loss
    best_losses['best_val_loss'] = best_val_loss
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

    # Have wandb save the text file
    wandb.save('generated_shakespeare_text.txt')
    # also save an image of the training and validation loss curves
    plt.plot(train_losses, label='train loss')
    plt.plot(val_losses, label='val loss')
    plt.legend()
    plt.savefig('train_val_loss.png')
    wandb.save('train_val_loss.png')
    print("Current run beat the best losses. Generated text saved.")

else:
    print("Current run did not beat the best losses. Generated text not saved.")
print(100*'*')
print(100*'*')

In [None]:
print(f"Best Train Loss: {best_train_loss}")
print(f"Best Validation Loss: {best_val_loss}")
# show total number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params}")
# show toal number of tokens in the dataset
total_tokens = len(data)
print(f"Total number of tokens in the dataset: {total_tokens}")
print(f"According to Chinchilla Law, you need at least {total_params * 2} tokens to train this model.") # TODO: work on this

In [None]:
# Ensure train_time and other parameters are defined before logging
# wandb.log({
#     'epochs': epochs,
#     "learning_rate": learning_rate,
#     "block_size": block_size,
#     "batch_size": batch_size,
#     "embedding_size": n_emb,
#     "optimizer": "AdamW",
#     "device": device,
#     "vocab_size": vocab_size,
#     "best_train_loss": best_train_loss,
#     "best_val_loss": best_val_loss,
#     'Training Time': train_time, 
#     'dropout': dropout,
#     'n_layer': n_layer,
#     'n_heads': n_heads,
#     'train_test_split': train_test_split,
#     'total_params': total_params
# })

print(f"Total time to train model up to {epochs} epochs: {train_time:.2f} seconds")
wandb.finish()