In [None]:
from google.colab import drive
import os

# 1️⃣ Google Drive
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import wandb

In [125]:
import pandas as pd

In [2]:
!pip install plotly



In [3]:
import plotly.graph_objects as go

In [4]:
# Authenticate with Weights & Biases; the wandb.init cells further down rely on this login.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbirdyyybai[0m ([33mbirdyyybai-university-of-michigan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
import math
import inspect
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F

In [6]:
# Character-level vocabulary for addition problems such as "12+34=46&".
# '&' marks the end of a sequence and '*' is the padding symbol.
vocab = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '+', '&', '*']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Derive the special-token ids from the vocabulary instead of hard-coding them,
# so they cannot drift if the vocabulary is ever reordered or extended.
padding_token_index = vocab.index('*')  # 13
end_token_index = vocab.index('&')  # 12

In [7]:
# Lookup tables between vocabulary characters and their integer ids.
stoi = dict(zip(vocab, range(len(vocab))))
itos = dict(enumerate(vocab))


def encode(s):
    """Encoder: take a string, return the list of its token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of token ids, return the corresponding string."""
    return ''.join([itos[i] for i in l])


# Round-trip check on one complete problem string.
print(encode("1+2=3&"))
print(decode(encode("1+2=3&")))

[1, 11, 2, 10, 3, 12]
1+2=3&


In [None]:
# # train test split
# train_set_1 = np.random.choice(np.arange(10), 8, replace=False)
# train_set_2 = np.random.choice(np.arange(10, 100), 72, replace=False)
# train_set_3 = np.random.choice(np.arange(100, 1000), 720, replace=False)
# test_1 = np.setdiff1d(np.arange(10), train_set_1)
# test_2 = np.setdiff1d(np.arange(10, 100), train_set_2)
# test_3 = np.setdiff1d(np.arange(100, 1000), train_set_3)
# test_12 = np.concatenate([test_1, test_2])
# test = np.concatenate([test_1, test_2, test_3])
# print(np.sort(train_set_1))
# print(np.sort(train_set_2))
# print(np.sort(test_1))
# print(np.sort(test_2))
# print(np.sort(test))

[0 1 2 4 5 6 8 9]
[10 11 12 13 14 15 16 17 18 19 21 22 23 24 25 26 27 28 30 31 33 34 35 36
 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 58 60 63 64 65
 66 68 69 70 71 74 77 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 98 99]
[3 7]
[20 29 32 37 57 59 61 62 67 72 73 75 76 78 79 80 96 97]
[  3   7  20  29  32  37  57  59  61  62  67  72  73  75  76  78  79  80
  96  97 102 107 118 121 125 130 145 146 151 152 154 161 165 169 171 180
 192 193 197 198 199 200 204 207 208 210 218 219 241 254 268 284 294 295
 296 299 303 306 308 317 318 319 327 329 340 341 342 343 348 365 366 373
 377 381 386 396 399 401 407 408 412 417 422 424 427 428 438 443 455 456
 459 468 474 477 496 514 520 526 533 534 536 538 547 553 555 558 559 564
 572 575 582 585 593 594 599 601 608 611 620 622 624 628 635 638 639 647
 650 652 654 655 658 659 660 667 685 690 705 707 715 716 723 724 726 728
 744 749 752 755 759 764 765 766 771 772 776 778 783 785 786 793 803 813
 824 826 838 839 840 841 844 845 862 865 87

In [80]:
def get_batch(phase=None, batch_size=32, block_size=35, mode='train'):
    """Sample a batch of random addition problems as padded token tensors.

    Each problem is rendered as the string "a+b=c&" and encoded with `encode`.
    The input row is that token sequence. The target row is the same sequence
    shifted left by one token with one extra end token appended, so position t
    of the target holds the token that follows position t of the input. Both
    rows are right-padded with `padding_token_index` up to `block_size`.

    Parameters:
        phase (int or "mix"): An int d draws both operands uniformly from the
            d-digit range [10**(d-1), 10**d). "mix" draws the digit count of
            each operand independently from 1..6, weighted towards longer
            numbers.
        batch_size (int): Number of problems in the batch.
        block_size (int): Length every row is padded to.
        mode (str): Accepted for interface compatibility. Every mode samples
            from the same distribution; there is no held-out split, so 'val'
            batches are simply fresh random draws.

    Returns:
        tuple: (x, y) int64 tensors of shape (batch_size, block_size) on `device`.

    Raises:
        TypeError: If `phase` is None (the default), since no digit count can
            be derived from it.
        ValueError: If an encoded problem is longer than `block_size`.
    """
    if phase is None:
        raise TypeError("phase must be a digit count (int) or 'mix', not None")

    if phase == "mix":
        digit_counts = np.arange(1, 7)
        weights = [0.05, 0.1, 0.10, 0.15, 0.25, 0.35]
        exp_a = np.random.choice(digit_counts, size=batch_size, p=weights)
        exp_b = np.random.choice(digit_counts, size=batch_size, p=weights)
        a = np.random.randint(10**(exp_a-1), 10**(exp_a), size=batch_size)
        b = np.random.randint(10**(exp_b-1), 10**(exp_b), size=batch_size)
    else:
        a = np.random.randint(10**(phase-1), 10**(phase), batch_size)
        b = np.random.randint(10**(phase-1), 10**(phase), batch_size)
    c = a + b

    x_rows, y_rows = [], []
    for i, j, k in zip(a, b, c):
        x_str = f"{i}+{j}={k}&"
        tokens = encode(x_str)
        if len(tokens) > block_size:
            raise ValueError(
                f"problem '{x_str}' needs {len(tokens)} tokens but block_size is {block_size}"
            )
        padding = [padding_token_index] * (block_size - len(tokens))

        # X: the full problem "i+j=k&".
        x_rows.append(tokens + padding)
        # Y: next-token targets, i.e. X shifted left by one plus a closing end token.
        y_rows.append(tokens[1:] + [end_token_index] + padding)

    # reshape keeps the (batch, block) layout even when batch_size is 0.
    x_tensor = torch.tensor(x_rows, dtype=torch.int64).reshape(-1, block_size).to(device)
    y_tensor = torch.tensor(y_rows, dtype=torch.int64).reshape(-1, block_size).to(device)
    return x_tensor, y_tensor

In [36]:
# Smoke test: draw one mixed-length batch and display the (input, target) tensor pair.
get_batch(phase="mix")

(tensor([[ 7,  4,  5,  ..., 13, 13, 13],
         [ 9,  3,  5,  ..., 13, 13, 13],
         [ 1,  4,  8,  ..., 13, 13, 13],
         ...,
         [ 3,  2,  7,  ..., 13, 13, 13],
         [ 8,  1,  4,  ..., 13, 13, 13],
         [ 4,  9,  2,  ..., 13, 13, 13]], device='cuda:0'),
 tensor([[ 4,  5,  9,  ..., 13, 13, 13],
         [ 3,  5,  9,  ..., 13, 13, 13],
         [ 4,  8, 11,  ..., 13, 13, 13],
         ...,
         [ 2,  7,  7,  ..., 13, 13, 13],
         [ 1,  4,  7,  ..., 13, 13, 13],
         [ 9,  2,  0,  ..., 13, 13, 13]], device='cuda:0'))

In [37]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias term."""

    def __init__(self, ndim, bias=True):
        super().__init__()
        # Both are nn.Parameters, so the optimizer updates them during training.
        self.weight = nn.Parameter(torch.ones(ndim))  # scale, initialised to 1
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))  # shift, initialised to 0
        else:
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position sees only itself and earlier positions."""

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        assert n_embd % n_head == 0, "Embedding dimension must be divisible by the number of heads."

        # Hyperparameters kept on the module for use in forward().
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.block_size = block_size

        # One fused projection produces queries, keys and values together.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        # Projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        # Dropout on the attention weights and on the block output.
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if self.flash:
            return

        print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # Lower-triangular mask consumed by the manual attention path.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer("bias", causal_mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding width
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        if not self.flash:
            # Manual scaled dot-product attention with an explicit causal mask.
            scale = 1.0 / math.sqrt(k.size(-1))
            att = (q @ k.transpose(-2, -1)) * scale
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = self.attn_dropout(F.softmax(att, dim=-1))
            y = att @ v  # (B, n_head, T, head_dim)
        else:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )

        # Merge the heads back into a single (B, T, C) tensor.
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.resid_dropout(self.c_proj(y))

class MLP(nn.Module): # FFN
    """Feed-forward sub-layer: expand to 4x width, apply GELU, project back, then dropout."""

    def __init__(self, n_embd, dropout, bias=True):
        super().__init__()
        hidden_width = 4 * n_embd
        self.c_fc = nn.Linear(n_embd, hidden_width, bias=bias)
        self.gelu = nn.GELU()  # nonlinear activation between the two projections
        self.c_proj = nn.Linear(hidden_width, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))

class Block(nn.Module):
    """One pre-norm transformer layer: attention, then MLP, each wrapped in a residual connection."""

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        # Attention sub-layer and the LayerNorm applied to its input.
        self.ln_1 = LayerNorm(n_embd, bias=bias)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size, bias=bias)
        # Feed-forward sub-layer and the LayerNorm applied to its input.
        self.ln_2 = LayerNorm(n_embd, bias=bias)
        self.mlp = MLP(n_embd, dropout, bias=bias)

    def forward(self, x):
        # Each sub-layer reads a normalized copy of the stream and adds its output back.
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


class GPT(nn.Module):
    """Decoder-only transformer language model built from a stack of `Block` layers.

    Parameters:
        vocab_size (int): Number of distinct token ids.
        block_size (int): Maximum sequence length the model accepts.
        n_embd (int): Embedding width.
        n_layer (int): Number of transformer blocks.
        n_head (int): Attention heads per block.
        dropout (float): Dropout probability used throughout the model.
        bias (bool): Whether the blocks' Linear and LayerNorm modules carry a bias.
        ignore_index (int): Target id excluded from the loss. Defaults to 13,
            the padding token id used in this notebook.
    """

    # Class-level default so instances created without going through this
    # __init__ still resolve `self.ignore_index`.
    ignore_index = 13

    def __init__(self, vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=True, ignore_index=13):
        # Initialise nn.Module first so attribute assignment never depends on
        # module internals that do not exist yet.
        super().__init__()
        assert vocab_size is not None
        assert block_size is not None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout
        self.bias = bias
        self.ignore_index = ignore_index

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embd), # token embeddings
            wpe = nn.Embedding(block_size, n_embd), # positional embeddings
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block(n_embd, n_head, dropout, block_size, bias=bias) for _ in range(n_layer)]), # a stack of n_layer blocks
            ln_f = LayerNorm(n_embd, bias=bias), # final layer norm
        ))
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) # projects the final transformer output to the vocab size

        # init all weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero the Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Parameters:
            idx (torch.Tensor): Token ids of shape (b, t) with t <= block_size.
            targets (torch.Tensor, optional): Target ids with the same shape as `idx`.

        Returns:
            tuple: (logits, loss). Without targets, logits has shape
            (b, t, vocab_size) and loss is None. With targets, loss is the
            cross-entropy over all positions whose target differs from
            `ignore_index`, and logits holds only the last position, shape
            (b, 1, vocab_size).

        Raises:
            AssertionError: If the sequence is longer than `block_size`.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # The head is applied once; both return paths reuse this tensor.
        logits = self.lm_head(x)
        if targets is None:
            return logits, None

        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=self.ignore_index,
        )
        # Callers that pass targets have always received last-position logits
        # only; the list index [-1] preserves the time dimension.
        return logits[:, [-1], :], loss

In [81]:
# Number of batches averaged per split. The hyperparameter cell further down
# reassigns eval_iters to 20, which is the value a top-to-bottom run ends up using.
eval_iters = 200

@torch.no_grad()
def estimate_loss(phase, models):
    """Estimate the mean loss of `models` over freshly sampled batches.

    Parameters:
        phase (int or "mix"): Forwarded to `get_batch` to select the problem size.
        models (nn.Module): Model called as `models(X, Y)` and expected to
            return a `(logits, loss)` pair.

    Returns:
        dict: {'train': tensor, 'val': tensor}, each a 0-d tensor holding the
        mean loss over `eval_iters` batches requested with that mode.
    """
    out = {}
    # Remember the current mode so evaluation leaves the model as it found it.
    was_training = models.training
    models.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(phase, mode=split)
                _, loss = models(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        models.train(was_training)
    return out

In [46]:
# Model and training hyperparameters shared by the cells below.
# NOTE(review): no random seed is set (the seeding lines below are commented out),
# so sampled batches and results differ from run to run.
# batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 35 # what is the maximum context length for predictions?
max_iters = 50000 # upper bound on the number of training iterations
# num_epochs = 100
eval_interval = 100 # evaluate train/val loss every this many iterations
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 20 # batches averaged per split in estimate_loss; overrides the 200 set in the estimate_loss cell when cells run top to bottom
n_embd = 256 # embedding width
n_head = 4 # attention heads per block
n_layer = 8 # number of transformer blocks
dropout = 0.0
# # torch.manual_seed(1337)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(1337)
bias = True # if using bias inside all Linear layers
vocab_size = len(vocab)

In [40]:
# Start a W&B run in the "transformer" project and record the run configuration.
# The values are literals that mirror the hyperparameter cell, the optimizer's lr
# and get_batch's default batch size; they must be kept in sync by hand.
wandb.init(project="transformer", config={
    "learning_rate": 1e-5,
    "batch_size": 32,
    "block_size": 35,
    "optimizer": "AdamW",
    "n_embd": 256,
    "n_head": 4,
    "n_layer": 8,
    "dropout": 0.0,
})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [47]:
def accuracy(model, num_trials=100, max_operand=1000000):
    """Measure exact-match accuracy of `model` on random addition prompts.

    For each trial two operands are drawn uniformly from [0, max_operand), the
    prompt "a+b=" is completed with `generate`, and the trial counts as
    correct only if the generated text equals "a+b=c" exactly.

    Parameters:
        model (nn.Module): Model passed to `generate`.
        num_trials (int): Number of prompts to evaluate.
        max_operand (int): Exclusive upper bound for both operands.

    Returns:
        float: Fraction of prompts answered exactly right.

    Raises:
        ValueError: If `num_trials` is not positive.
    """
    if num_trials <= 0:
        raise ValueError("num_trials must be positive")

    # Evaluate with dropout disabled, then restore whatever mode the model was in.
    was_training = model.training
    model.eval()
    correct = 0
    try:
        for _ in range(num_trials):
            # randint draws directly from the range without materialising it.
            a = int(np.random.randint(0, max_operand))
            b = int(np.random.randint(0, max_operand))

            prompt = f"{a}+{b}="
            context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
            output = generate(model, context, 100, 1)
            if output == f"{prompt}{a + b}":
                correct += 1
    finally:
        model.train(was_training)

    score = correct / num_trials
    print(f"Accuracy for addition: {score} ")
    return score

In [48]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Sample a continuation of an initial token sequence and decode it to text.

    Parameters:
        model (nn.Module): Model used for generation; must expose `device` and `block_size`.
        idx (torch.Tensor or list): Initial token ids, either 1-D (t) or 2-D (b, t).
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Scaling factor for logits before softmax.
        top_k (int, optional): If specified, restricts sampling to top k tokens.

    Returns:
        str: The decoded first sequence of the batch (prompt plus generated
        tokens), without the end token. Generation stops early once that
        first sequence samples the end token.
    """
    # Convert lists to a tensor before using tensor-only methods such as .dim().
    if not isinstance(idx, torch.Tensor):
        idx = torch.tensor(idx, dtype=torch.long, device=model.device)
    else:
        idx = idx.to(model.device)
    if idx.dim() == 1:
        idx = idx.unsqueeze(0)

    for _ in range(max_new_tokens):
        # Ensure context length does not exceed model's block size
        idx_cond = idx if idx.size(1) <= model.block_size else idx[:, -model.block_size:]

        # Forward pass to get logits
        logits, _ = model(idx_cond)

        # Extract logits for the last token and apply temperature scaling
        logits = logits[:, -1, :] / temperature

        # Apply top-k filtering if necessary
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        # Convert logits to probabilities
        probs = F.softmax(logits, dim=-1)

        # Sample next token
        idx_next = torch.multinomial(probs, num_samples=1)

        # Only the first sequence is returned, so its end token decides when to
        # stop; reading a single element also works for batches larger than one.
        if idx_next[0, 0].item() == end_token_index:
            break

        # Append sampled token to sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return decode(idx.tolist()[0])


In [49]:
# Build the model from the hyperparameters defined above.
model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=bias)
# Move the model to the selected device; `m` is bound to the result of .to(device).
m = model.to(device)

In [50]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [82]:
# Start the W&B run that the training loop below logs to.
# NOTE(review): this repeats the earlier wandb.init cell with the same project and
# config — presumably re-run to get a fresh run before training; confirm whether
# both cells are needed.
wandb.init(project="transformer", config={
    "learning_rate": 1e-5,
    "batch_size": 32,
    "block_size": 35,
    "optimizer": "AdamW",
    "n_embd": 256,
    "n_head": 4,
    "n_layer": 8,
    "dropout": 0.0
})

In [83]:
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

phase = "mix"
best_acc = 0
counter = 0
best_loss = float('inf')
val_loss_list = []
acc_list = []

patience = 50

for iter in tqdm(range(max_iters), desc="Training Progress"):
    # if iter > 1000:
    #   phase = 2
    # if iter > 3000:
    #   phase = 3
    # if iter > 5000:
    #   phase = 4
    # if iter > 7000:
    #   phase = 5
    # if iter > 9000:
    #   phase = 6
    # if iter > 11000:
    #   phase = "mix"

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(phase, model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        log_dict = {"Loss": losses['val']}
        val_loss_list.append(round(losses['val'].item(), 4))

        if phase == "mix":
            # acc = accuracy(model)

            # acc_list.append(acc)
            # log_dict["Accuracy"] = acc

            if losses['val'] < best_loss:
                counter = 0
                # best_acc = max(best_acc, acc)
                best_loss = min(best_loss, losses['val'])
            else:
                counter += 1
                if counter >= patience:
                    print(f"Early Stopping at iteration {iter}")
                    break

        # record to W&B
        wandb.log(log_dict)

    # sample a batch of data

    xb, yb = get_batch(phase)

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


6.33472 M parameters


Training Progress:   0%|          | 7/50000 [00:02<4:05:00,  3.40it/s] 

step 0: train loss 1.2472, val loss 1.2475


Training Progress:   0%|          | 109/50000 [00:07<1:30:58,  9.14it/s]

step 100: train loss 1.2454, val loss 1.2453


Training Progress:   0%|          | 205/50000 [00:11<2:04:08,  6.69it/s]

step 200: train loss 1.2478, val loss 1.2476


Training Progress:   1%|          | 307/50000 [00:16<1:31:40,  9.03it/s]

step 300: train loss 1.2447, val loss 1.2452


Training Progress:   1%|          | 409/50000 [00:20<1:31:14,  9.06it/s]

step 400: train loss 1.2490, val loss 1.2474


Training Progress:   1%|          | 505/50000 [00:25<2:01:55,  6.77it/s]

step 500: train loss 1.2505, val loss 1.2492


Training Progress:   1%|          | 607/50000 [00:29<1:29:05,  9.24it/s]

step 600: train loss 1.2460, val loss 1.2490


Training Progress:   1%|▏         | 709/50000 [00:34<1:28:36,  9.27it/s]

step 700: train loss 1.2469, val loss 1.2461


Training Progress:   2%|▏         | 805/50000 [00:38<1:58:22,  6.93it/s]

step 800: train loss 1.2471, val loss 1.2462


Training Progress:   2%|▏         | 907/50000 [00:42<1:27:05,  9.39it/s]

step 900: train loss 1.2453, val loss 1.2467


Training Progress:   2%|▏         | 1009/50000 [00:47<1:26:31,  9.44it/s]

step 1000: train loss 1.2462, val loss 1.2447


Training Progress:   2%|▏         | 1105/50000 [00:51<1:56:20,  7.00it/s]

step 1100: train loss 1.2486, val loss 1.2498


Training Progress:   2%|▏         | 1207/50000 [00:55<1:26:06,  9.44it/s]

step 1200: train loss 1.2479, val loss 1.2466


Training Progress:   3%|▎         | 1309/50000 [01:00<1:25:41,  9.47it/s]

step 1300: train loss 1.2459, val loss 1.2464


Training Progress:   3%|▎         | 1411/50000 [01:04<1:25:25,  9.48it/s]

step 1400: train loss 1.2463, val loss 1.2454


Training Progress:   3%|▎         | 1507/50000 [01:08<1:25:20,  9.47it/s]

step 1500: train loss 1.2466, val loss 1.2472


Training Progress:   3%|▎         | 1609/50000 [01:13<1:25:07,  9.47it/s]

step 1600: train loss 1.2461, val loss 1.2459


Training Progress:   3%|▎         | 1705/50000 [01:17<1:55:38,  6.96it/s]

step 1700: train loss 1.2470, val loss 1.2457


Training Progress:   4%|▎         | 1807/50000 [01:22<1:25:55,  9.35it/s]

step 1800: train loss 1.2438, val loss 1.2455


Training Progress:   4%|▍         | 1909/50000 [01:26<1:25:34,  9.37it/s]

step 1900: train loss 1.2471, val loss 1.2477


Training Progress:   4%|▍         | 2005/50000 [01:30<1:55:52,  6.90it/s]

step 2000: train loss 1.2465, val loss 1.2464


Training Progress:   4%|▍         | 2107/50000 [01:35<1:25:46,  9.31it/s]

step 2100: train loss 1.2438, val loss 1.2452


Training Progress:   4%|▍         | 2209/50000 [01:39<1:25:33,  9.31it/s]

step 2200: train loss 1.2464, val loss 1.2462


Training Progress:   5%|▍         | 2305/50000 [01:43<1:55:40,  6.87it/s]

step 2300: train loss 1.2459, val loss 1.2433


Training Progress:   5%|▍         | 2407/50000 [01:48<1:25:30,  9.28it/s]

step 2400: train loss 1.2470, val loss 1.2465


Training Progress:   5%|▌         | 2509/50000 [01:52<1:25:04,  9.30it/s]

step 2500: train loss 1.2445, val loss 1.2454


Training Progress:   5%|▌         | 2605/50000 [01:57<1:55:03,  6.86it/s]

step 2600: train loss 1.2461, val loss 1.2463


Training Progress:   5%|▌         | 2707/50000 [02:01<1:24:33,  9.32it/s]

step 2700: train loss 1.2495, val loss 1.2489


Training Progress:   6%|▌         | 2809/50000 [02:06<1:24:04,  9.35it/s]

step 2800: train loss 1.2451, val loss 1.2472


Training Progress:   6%|▌         | 2905/50000 [02:10<1:53:20,  6.93it/s]

step 2900: train loss 1.2495, val loss 1.2467


Training Progress:   6%|▌         | 3007/50000 [02:14<1:23:27,  9.38it/s]

step 3000: train loss 1.2446, val loss 1.2450


Training Progress:   6%|▌         | 3109/50000 [02:19<1:23:18,  9.38it/s]

step 3100: train loss 1.2465, val loss 1.2469


Training Progress:   6%|▋         | 3205/50000 [02:23<1:52:16,  6.95it/s]

step 3200: train loss 1.2458, val loss 1.2485


Training Progress:   7%|▋         | 3307/50000 [02:27<1:22:35,  9.42it/s]

step 3300: train loss 1.2442, val loss 1.2445


Training Progress:   7%|▋         | 3409/50000 [02:32<1:22:42,  9.39it/s]

step 3400: train loss 1.2461, val loss 1.2473


Training Progress:   7%|▋         | 3505/50000 [02:36<1:51:35,  6.94it/s]

step 3500: train loss 1.2455, val loss 1.2449


Training Progress:   7%|▋         | 3607/50000 [02:41<1:22:27,  9.38it/s]

step 3600: train loss 1.2456, val loss 1.2428


Training Progress:   7%|▋         | 3709/50000 [02:45<1:22:02,  9.40it/s]

step 3700: train loss 1.2462, val loss 1.2441


Training Progress:   8%|▊         | 3805/50000 [02:49<1:50:55,  6.94it/s]

step 3800: train loss 1.2442, val loss 1.2459


Training Progress:   8%|▊         | 3907/50000 [02:54<1:21:57,  9.37it/s]

step 3900: train loss 1.2456, val loss 1.2451


Training Progress:   8%|▊         | 4009/50000 [02:58<1:21:49,  9.37it/s]

step 4000: train loss 1.2428, val loss 1.2454


Training Progress:   8%|▊         | 4105/50000 [03:02<1:50:18,  6.93it/s]

step 4100: train loss 1.2444, val loss 1.2434


Training Progress:   8%|▊         | 4207/50000 [03:07<1:21:45,  9.34it/s]

step 4200: train loss 1.2494, val loss 1.2475


Training Progress:   9%|▊         | 4309/50000 [03:11<1:21:32,  9.34it/s]

step 4300: train loss 1.2441, val loss 1.2440


Training Progress:   9%|▉         | 4405/50000 [03:16<1:50:07,  6.90it/s]

step 4400: train loss 1.2444, val loss 1.2459


Training Progress:   9%|▉         | 4507/50000 [03:20<1:21:17,  9.33it/s]

step 4500: train loss 1.2440, val loss 1.2444


Training Progress:   9%|▉         | 4609/50000 [03:24<1:20:52,  9.35it/s]

step 4600: train loss 1.2462, val loss 1.2481


Training Progress:   9%|▉         | 4705/50000 [03:29<1:49:35,  6.89it/s]

step 4700: train loss 1.2454, val loss 1.2469


Training Progress:  10%|▉         | 4807/50000 [03:33<1:20:32,  9.35it/s]

step 4800: train loss 1.2444, val loss 1.2443


Training Progress:  10%|▉         | 4909/50000 [03:38<1:20:20,  9.35it/s]

step 4900: train loss 1.2431, val loss 1.2466


Training Progress:  10%|█         | 5005/50000 [03:42<1:48:56,  6.88it/s]

step 5000: train loss 1.2446, val loss 1.2461


Training Progress:  10%|█         | 5107/50000 [03:46<1:19:47,  9.38it/s]

step 5100: train loss 1.2441, val loss 1.2432


Training Progress:  10%|█         | 5209/50000 [03:51<1:19:38,  9.37it/s]

step 5200: train loss 1.2431, val loss 1.2433


Training Progress:  11%|█         | 5311/50000 [03:55<1:19:29,  9.37it/s]

step 5300: train loss 1.2458, val loss 1.2448


Training Progress:  11%|█         | 5407/50000 [03:59<1:19:14,  9.38it/s]

step 5400: train loss 1.2442, val loss 1.2455


Training Progress:  11%|█         | 5509/50000 [04:04<1:18:54,  9.40it/s]

step 5500: train loss 1.2465, val loss 1.2466


Training Progress:  11%|█         | 5605/50000 [04:08<1:46:39,  6.94it/s]

step 5600: train loss 1.2473, val loss 1.2467


Training Progress:  11%|█▏        | 5707/50000 [04:13<1:18:50,  9.36it/s]

step 5700: train loss 1.2464, val loss 1.2453


Training Progress:  12%|█▏        | 5809/50000 [04:17<1:18:30,  9.38it/s]

step 5800: train loss 1.2458, val loss 1.2449


Training Progress:  12%|█▏        | 5905/50000 [04:21<1:46:02,  6.93it/s]

step 5900: train loss 1.2437, val loss 1.2444


Training Progress:  12%|█▏        | 6007/50000 [04:26<1:18:06,  9.39it/s]

step 6000: train loss 1.2460, val loss 1.2455


Training Progress:  12%|█▏        | 6109/50000 [04:30<1:18:14,  9.35it/s]

step 6100: train loss 1.2438, val loss 1.2441


Training Progress:  12%|█▏        | 6205/50000 [04:34<1:45:25,  6.92it/s]

step 6200: train loss 1.2438, val loss 1.2423


Training Progress:  13%|█▎        | 6307/50000 [04:39<1:17:47,  9.36it/s]

step 6300: train loss 1.2456, val loss 1.2440


Training Progress:  13%|█▎        | 6409/50000 [04:43<1:17:23,  9.39it/s]

step 6400: train loss 1.2506, val loss 1.2488


Training Progress:  13%|█▎        | 6505/50000 [04:47<1:44:29,  6.94it/s]

step 6500: train loss 1.2434, val loss 1.2438


Training Progress:  13%|█▎        | 6607/50000 [04:52<1:17:18,  9.35it/s]

step 6600: train loss 1.2461, val loss 1.2441


Training Progress:  13%|█▎        | 6709/50000 [04:56<1:17:18,  9.33it/s]

step 6700: train loss 1.2507, val loss 1.2484


Training Progress:  14%|█▎        | 6805/50000 [05:01<1:43:56,  6.93it/s]

step 6800: train loss 1.2457, val loss 1.2449


Training Progress:  14%|█▍        | 6907/50000 [05:05<1:16:38,  9.37it/s]

step 6900: train loss 1.2452, val loss 1.2450


Training Progress:  14%|█▍        | 7009/50000 [05:09<1:16:22,  9.38it/s]

step 7000: train loss 1.2457, val loss 1.2444


Training Progress:  14%|█▍        | 7105/50000 [05:14<1:43:11,  6.93it/s]

step 7100: train loss 1.2453, val loss 1.2451


Training Progress:  14%|█▍        | 7207/50000 [05:18<1:16:08,  9.37it/s]

step 7200: train loss 1.2453, val loss 1.2449


Training Progress:  15%|█▍        | 7309/50000 [05:23<1:15:48,  9.39it/s]

step 7300: train loss 1.2451, val loss 1.2462


Training Progress:  15%|█▍        | 7405/50000 [05:27<1:42:30,  6.93it/s]

step 7400: train loss 1.2438, val loss 1.2436


Training Progress:  15%|█▌        | 7507/50000 [05:31<1:15:31,  9.38it/s]

step 7500: train loss 1.2445, val loss 1.2444


Training Progress:  15%|█▌        | 7609/50000 [05:36<1:15:34,  9.35it/s]

step 7600: train loss 1.2426, val loss 1.2453


Training Progress:  15%|█▌        | 7705/50000 [05:40<1:42:14,  6.89it/s]

step 7700: train loss 1.2458, val loss 1.2458


Training Progress:  16%|█▌        | 7807/50000 [05:44<1:15:02,  9.37it/s]

step 7800: train loss 1.2504, val loss 1.2493


Training Progress:  16%|█▌        | 7909/50000 [05:49<1:15:01,  9.35it/s]

step 7900: train loss 1.2437, val loss 1.2439


Training Progress:  16%|█▌        | 8011/50000 [05:53<1:14:39,  9.37it/s]

step 8000: train loss 1.2445, val loss 1.2441


Training Progress:  16%|█▌        | 8107/50000 [05:58<1:14:26,  9.38it/s]

step 8100: train loss 1.2435, val loss 1.2424


Training Progress:  16%|█▋        | 8209/50000 [06:02<1:14:18,  9.37it/s]

step 8200: train loss 1.2437, val loss 1.2438


Training Progress:  17%|█▋        | 8305/50000 [06:06<1:40:23,  6.92it/s]

step 8300: train loss 1.2452, val loss 1.2437


Training Progress:  17%|█▋        | 8407/50000 [06:11<1:14:00,  9.37it/s]

step 8400: train loss 1.2449, val loss 1.2471


Training Progress:  17%|█▋        | 8509/50000 [06:15<1:13:53,  9.36it/s]

step 8500: train loss 1.2440, val loss 1.2443


Training Progress:  17%|█▋        | 8605/50000 [06:19<1:39:23,  6.94it/s]

step 8600: train loss 1.2452, val loss 1.2447


Training Progress:  17%|█▋        | 8707/50000 [06:24<1:13:35,  9.35it/s]

step 8700: train loss 1.2454, val loss 1.2442


Training Progress:  18%|█▊        | 8809/50000 [06:28<1:13:24,  9.35it/s]

step 8800: train loss 1.2435, val loss 1.2435


Training Progress:  18%|█▊        | 8905/50000 [06:33<1:38:55,  6.92it/s]

step 8900: train loss 1.2436, val loss 1.2428


Training Progress:  18%|█▊        | 9007/50000 [06:37<1:13:09,  9.34it/s]

step 9000: train loss 1.2471, val loss 1.2472


Training Progress:  18%|█▊        | 9109/50000 [06:41<1:12:51,  9.35it/s]

step 9100: train loss 1.2427, val loss 1.2422


Training Progress:  18%|█▊        | 9205/50000 [06:46<1:38:08,  6.93it/s]

step 9200: train loss 1.2432, val loss 1.2444


Training Progress:  19%|█▊        | 9307/50000 [06:50<1:12:28,  9.36it/s]

step 9300: train loss 1.2435, val loss 1.2438


Training Progress:  19%|█▉        | 9409/50000 [06:55<1:12:05,  9.38it/s]

step 9400: train loss 1.2435, val loss 1.2445


Training Progress:  19%|█▉        | 9505/50000 [06:59<1:37:28,  6.92it/s]

step 9500: train loss 1.2440, val loss 1.2425


Training Progress:  19%|█▉        | 9607/50000 [07:03<1:12:08,  9.33it/s]

step 9600: train loss 1.2425, val loss 1.2416


Training Progress:  19%|█▉        | 9709/50000 [07:08<1:11:37,  9.38it/s]

step 9700: train loss 1.2434, val loss 1.2426


Training Progress:  20%|█▉        | 9805/50000 [07:12<1:37:04,  6.90it/s]

step 9800: train loss 1.2453, val loss 1.2449


Training Progress:  20%|█▉        | 9907/50000 [07:16<1:11:11,  9.39it/s]

step 9900: train loss 1.2433, val loss 1.2452


Training Progress:  20%|██        | 10009/50000 [07:21<1:11:06,  9.37it/s]

step 10000: train loss 1.2451, val loss 1.2448


Training Progress:  20%|██        | 10105/50000 [07:25<1:36:17,  6.91it/s]

step 10100: train loss 1.2457, val loss 1.2456


Training Progress:  20%|██        | 10207/50000 [07:30<1:10:44,  9.38it/s]

step 10200: train loss 1.2431, val loss 1.2436


Training Progress:  21%|██        | 10309/50000 [07:34<1:10:36,  9.37it/s]

step 10300: train loss 1.2483, val loss 1.2472


Training Progress:  21%|██        | 10405/50000 [07:38<1:35:26,  6.91it/s]

step 10400: train loss 1.2439, val loss 1.2450


Training Progress:  21%|██        | 10507/50000 [07:43<1:10:14,  9.37it/s]

step 10500: train loss 1.2443, val loss 1.2448


Training Progress:  21%|██        | 10609/50000 [07:47<1:10:10,  9.36it/s]

step 10600: train loss 1.2453, val loss 1.2452


Training Progress:  21%|██▏       | 10705/50000 [07:51<1:34:31,  6.93it/s]

step 10700: train loss 1.2440, val loss 1.2453


Training Progress:  22%|██▏       | 10807/50000 [07:56<1:09:45,  9.36it/s]

step 10800: train loss 1.2449, val loss 1.2437


Training Progress:  22%|██▏       | 10909/50000 [08:00<1:09:25,  9.38it/s]

step 10900: train loss 1.2469, val loss 1.2455


Training Progress:  22%|██▏       | 11005/50000 [08:05<1:33:47,  6.93it/s]

step 11000: train loss 1.2456, val loss 1.2460


Training Progress:  22%|██▏       | 11107/50000 [08:09<1:09:21,  9.35it/s]

step 11100: train loss 1.2423, val loss 1.2421


Training Progress:  22%|██▏       | 11209/50000 [08:13<1:08:52,  9.39it/s]

step 11200: train loss 1.2433, val loss 1.2424


Training Progress:  23%|██▎       | 11305/50000 [08:18<1:33:02,  6.93it/s]

step 11300: train loss 1.2420, val loss 1.2432


Training Progress:  23%|██▎       | 11407/50000 [08:22<1:09:01,  9.32it/s]

step 11400: train loss 1.2490, val loss 1.2486


Training Progress:  23%|██▎       | 11509/50000 [08:27<1:08:27,  9.37it/s]

step 11500: train loss 1.2448, val loss 1.2421


Training Progress:  23%|██▎       | 11605/50000 [08:31<1:32:26,  6.92it/s]

step 11600: train loss 1.2442, val loss 1.2441


Training Progress:  23%|██▎       | 11707/50000 [08:35<1:08:15,  9.35it/s]

step 11700: train loss 1.2455, val loss 1.2440


Training Progress:  24%|██▎       | 11809/50000 [08:40<1:07:58,  9.36it/s]

step 11800: train loss 1.2435, val loss 1.2442


Training Progress:  24%|██▍       | 11905/50000 [08:44<1:31:53,  6.91it/s]

step 11900: train loss 1.2425, val loss 1.2405


Training Progress:  24%|██▍       | 12007/50000 [08:48<1:07:36,  9.37it/s]

step 12000: train loss 1.2440, val loss 1.2425


Training Progress:  24%|██▍       | 12109/50000 [08:53<1:07:15,  9.39it/s]

step 12100: train loss 1.2435, val loss 1.2434


Training Progress:  24%|██▍       | 12205/50000 [08:57<1:31:10,  6.91it/s]

step 12200: train loss 1.2419, val loss 1.2418


Training Progress:  25%|██▍       | 12307/50000 [09:02<1:07:01,  9.37it/s]

step 12300: train loss 1.2411, val loss 1.2424


Training Progress:  25%|██▍       | 12409/50000 [09:06<1:06:54,  9.36it/s]

step 12400: train loss 1.2461, val loss 1.2465


Training Progress:  25%|██▌       | 12505/50000 [09:10<1:30:33,  6.90it/s]

step 12500: train loss 1.2495, val loss 1.2501


Training Progress:  25%|██▌       | 12607/50000 [09:15<1:06:29,  9.37it/s]

step 12600: train loss 1.2428, val loss 1.2431


Training Progress:  25%|██▌       | 12709/50000 [09:19<1:06:17,  9.38it/s]

step 12700: train loss 1.2430, val loss 1.2428


Training Progress:  26%|██▌       | 12805/50000 [09:23<1:29:35,  6.92it/s]

step 12800: train loss 1.2413, val loss 1.2425


Training Progress:  26%|██▌       | 12907/50000 [09:28<1:05:55,  9.38it/s]

step 12900: train loss 1.2421, val loss 1.2417


Training Progress:  26%|██▌       | 13009/50000 [09:32<1:06:09,  9.32it/s]

step 13000: train loss 1.2441, val loss 1.2431


Training Progress:  26%|██▌       | 13105/50000 [09:37<1:28:44,  6.93it/s]

step 13100: train loss 1.2439, val loss 1.2424


Training Progress:  26%|██▋       | 13207/50000 [09:41<1:05:44,  9.33it/s]

step 13200: train loss 1.2429, val loss 1.2414


Training Progress:  27%|██▋       | 13309/50000 [09:45<1:05:23,  9.35it/s]

step 13300: train loss 1.2441, val loss 1.2443


Training Progress:  27%|██▋       | 13405/50000 [09:50<1:27:54,  6.94it/s]

step 13400: train loss 1.2443, val loss 1.2449


Training Progress:  27%|██▋       | 13507/50000 [09:54<1:04:53,  9.37it/s]

step 13500: train loss 1.2442, val loss 1.2445


Training Progress:  27%|██▋       | 13609/50000 [09:59<1:04:39,  9.38it/s]

step 13600: train loss 1.2432, val loss 1.2434


Training Progress:  27%|██▋       | 13705/50000 [10:03<1:27:25,  6.92it/s]

step 13700: train loss 1.2443, val loss 1.2439


Training Progress:  28%|██▊       | 13807/50000 [10:07<1:04:33,  9.34it/s]

step 13800: train loss 1.2434, val loss 1.2436


Training Progress:  28%|██▊       | 13909/50000 [10:12<1:04:05,  9.38it/s]

step 13900: train loss 1.2418, val loss 1.2424


Training Progress:  28%|██▊       | 14005/50000 [10:16<1:26:39,  6.92it/s]

step 14000: train loss 1.2477, val loss 1.2472


Training Progress:  28%|██▊       | 14107/50000 [10:20<1:04:00,  9.34it/s]

step 14100: train loss 1.2437, val loss 1.2415


Training Progress:  28%|██▊       | 14209/50000 [10:25<1:03:33,  9.38it/s]

step 14200: train loss 1.2425, val loss 1.2410


Training Progress:  29%|██▊       | 14305/50000 [10:29<1:26:01,  6.92it/s]

step 14300: train loss 1.2427, val loss 1.2423


Training Progress:  29%|██▉       | 14407/50000 [10:34<1:03:24,  9.36it/s]

step 14400: train loss 1.2449, val loss 1.2427


Training Progress:  29%|██▉       | 14509/50000 [10:38<1:03:17,  9.34it/s]

step 14500: train loss 1.2428, val loss 1.2433


Training Progress:  29%|██▉       | 14605/50000 [10:42<1:25:47,  6.88it/s]

step 14600: train loss 1.2420, val loss 1.2434


Training Progress:  29%|██▉       | 14707/50000 [10:47<1:02:59,  9.34it/s]

step 14700: train loss 1.2429, val loss 1.2428


Training Progress:  30%|██▉       | 14809/50000 [10:51<1:02:49,  9.34it/s]

step 14800: train loss 1.2455, val loss 1.2447


Training Progress:  30%|██▉       | 14905/50000 [10:56<1:25:07,  6.87it/s]

step 14900: train loss 1.2427, val loss 1.2419


Training Progress:  30%|███       | 15007/50000 [11:00<1:02:13,  9.37it/s]

step 15000: train loss 1.2432, val loss 1.2437


Training Progress:  30%|███       | 15109/50000 [11:04<1:02:05,  9.37it/s]

step 15100: train loss 1.2423, val loss 1.2422


Training Progress:  30%|███       | 15205/50000 [11:09<1:23:56,  6.91it/s]

step 15200: train loss 1.2439, val loss 1.2431


Training Progress:  31%|███       | 15307/50000 [11:13<1:01:34,  9.39it/s]

step 15300: train loss 1.2414, val loss 1.2419


Training Progress:  31%|███       | 15409/50000 [11:18<1:01:43,  9.34it/s]

step 15400: train loss 1.2414, val loss 1.2425


Training Progress:  31%|███       | 15505/50000 [11:22<1:22:51,  6.94it/s]

step 15500: train loss 1.2433, val loss 1.2421


Training Progress:  31%|███       | 15607/50000 [11:26<1:00:59,  9.40it/s]

step 15600: train loss 1.2410, val loss 1.2407


Training Progress:  31%|███▏      | 15709/50000 [11:31<1:00:59,  9.37it/s]

step 15700: train loss 1.2426, val loss 1.2429


Training Progress:  32%|███▏      | 15805/50000 [11:35<1:22:23,  6.92it/s]

step 15800: train loss 1.2427, val loss 1.2421


Training Progress:  32%|███▏      | 15907/50000 [11:39<1:00:43,  9.36it/s]

step 15900: train loss 1.2412, val loss 1.2406


Training Progress:  32%|███▏      | 16009/50000 [11:44<1:00:22,  9.38it/s]

step 16000: train loss 1.2431, val loss 1.2438


Training Progress:  32%|███▏      | 16105/50000 [11:48<1:21:32,  6.93it/s]

step 16100: train loss 1.2433, val loss 1.2439


Training Progress:  32%|███▏      | 16207/50000 [11:53<1:00:27,  9.32it/s]

step 16200: train loss 1.2414, val loss 1.2411


Training Progress:  33%|███▎      | 16309/50000 [11:57<59:58,  9.36it/s]  

step 16300: train loss 1.2436, val loss 1.2442


Training Progress:  33%|███▎      | 16405/50000 [12:01<1:20:59,  6.91it/s]

step 16400: train loss 1.2410, val loss 1.2409


Training Progress:  33%|███▎      | 16507/50000 [12:06<59:46,  9.34it/s]  

step 16500: train loss 1.2406, val loss 1.2422


Training Progress:  33%|███▎      | 16609/50000 [12:10<59:20,  9.38it/s]  

step 16600: train loss 1.2416, val loss 1.2423


Training Progress:  33%|███▎      | 16705/50000 [12:14<1:20:18,  6.91it/s]

step 16700: train loss 1.2437, val loss 1.2431


Training Progress:  34%|███▎      | 16807/50000 [12:19<58:57,  9.38it/s]  

step 16800: train loss 1.2407, val loss 1.2402


Training Progress:  34%|███▍      | 16909/50000 [12:23<58:48,  9.38it/s]  

step 16900: train loss 1.2410, val loss 1.2415


Training Progress:  34%|███▍      | 17005/50000 [12:28<1:19:15,  6.94it/s]

step 17000: train loss 1.2418, val loss 1.2410


Training Progress:  34%|███▍      | 17107/50000 [12:32<58:23,  9.39it/s]  

step 17100: train loss 1.2472, val loss 1.2473


Training Progress:  34%|███▍      | 17209/50000 [12:36<58:21,  9.37it/s]  

step 17200: train loss 1.2453, val loss 1.2448


Training Progress:  35%|███▍      | 17305/50000 [12:41<1:18:58,  6.90it/s]

step 17300: train loss 1.2409, val loss 1.2427


Training Progress:  35%|███▍      | 17407/50000 [12:45<58:15,  9.33it/s]  

step 17400: train loss 1.2399, val loss 1.2418


Training Progress:  35%|███▌      | 17509/50000 [12:50<57:50,  9.36it/s]  

step 17500: train loss 1.2434, val loss 1.2427


Training Progress:  35%|███▌      | 17605/50000 [12:54<1:17:55,  6.93it/s]

step 17600: train loss 1.2416, val loss 1.2418


Training Progress:  35%|███▌      | 17707/50000 [12:58<57:25,  9.37it/s]  

step 17700: train loss 1.2427, val loss 1.2450


Training Progress:  36%|███▌      | 17809/50000 [13:03<57:26,  9.34it/s]  

step 17800: train loss 1.2406, val loss 1.2418


Training Progress:  36%|███▌      | 17905/50000 [13:07<1:17:06,  6.94it/s]

step 17900: train loss 1.2415, val loss 1.2414


Training Progress:  36%|███▌      | 18007/50000 [13:11<56:49,  9.38it/s]  

step 18000: train loss 1.2395, val loss 1.2400


Training Progress:  36%|███▌      | 18109/50000 [13:16<56:46,  9.36it/s]  

step 18100: train loss 1.2420, val loss 1.2423


Training Progress:  36%|███▋      | 18205/50000 [13:20<1:16:22,  6.94it/s]

step 18200: train loss 1.2408, val loss 1.2420


Training Progress:  37%|███▋      | 18307/50000 [13:24<56:23,  9.37it/s]  

step 18300: train loss 1.2401, val loss 1.2404


Training Progress:  37%|███▋      | 18409/50000 [13:29<56:07,  9.38it/s]  

step 18400: train loss 1.2415, val loss 1.2414


Training Progress:  37%|███▋      | 18505/50000 [13:33<1:15:51,  6.92it/s]

step 18500: train loss 1.2422, val loss 1.2412


Training Progress:  37%|███▋      | 18607/50000 [13:38<55:50,  9.37it/s]  

step 18600: train loss 1.2401, val loss 1.2411


Training Progress:  37%|███▋      | 18709/50000 [13:42<55:35,  9.38it/s]  

step 18700: train loss 1.2416, val loss 1.2399


Training Progress:  38%|███▊      | 18805/50000 [13:46<1:14:50,  6.95it/s]

step 18800: train loss 1.2421, val loss 1.2428


Training Progress:  38%|███▊      | 18907/50000 [13:51<55:20,  9.36it/s]  

step 18900: train loss 1.2420, val loss 1.2429


Training Progress:  38%|███▊      | 19009/50000 [13:55<55:14,  9.35it/s]  

step 19000: train loss 1.2435, val loss 1.2436


Training Progress:  38%|███▊      | 19105/50000 [13:59<1:14:45,  6.89it/s]

step 19100: train loss 1.2411, val loss 1.2416


Training Progress:  38%|███▊      | 19207/50000 [14:04<54:42,  9.38it/s]  

step 19200: train loss 1.2457, val loss 1.2441


Training Progress:  39%|███▊      | 19309/50000 [14:08<54:30,  9.38it/s]  

step 19300: train loss 1.2389, val loss 1.2400


Training Progress:  39%|███▉      | 19405/50000 [14:13<1:13:44,  6.91it/s]

step 19400: train loss 1.2404, val loss 1.2399


Training Progress:  39%|███▉      | 19507/50000 [14:17<54:16,  9.36it/s]  

step 19500: train loss 1.2416, val loss 1.2406


Training Progress:  39%|███▉      | 19609/50000 [14:21<54:04,  9.37it/s]  

step 19600: train loss 1.2450, val loss 1.2439


Training Progress:  39%|███▉      | 19705/50000 [14:26<1:12:59,  6.92it/s]

step 19700: train loss 1.2405, val loss 1.2404


Training Progress:  40%|███▉      | 19807/50000 [14:30<53:33,  9.40it/s]  

step 19800: train loss 1.2434, val loss 1.2426


Training Progress:  40%|███▉      | 19909/50000 [14:35<53:28,  9.38it/s]  

step 19900: train loss 1.2387, val loss 1.2393


Training Progress:  40%|████      | 20005/50000 [14:39<1:12:14,  6.92it/s]

step 20000: train loss 1.2397, val loss 1.2400


Training Progress:  40%|████      | 20107/50000 [14:43<53:10,  9.37it/s]  

step 20100: train loss 1.2380, val loss 1.2403


Training Progress:  40%|████      | 20209/50000 [14:48<53:03,  9.36it/s]  

step 20200: train loss 1.2397, val loss 1.2411


Training Progress:  41%|████      | 20305/50000 [14:52<1:11:25,  6.93it/s]

step 20300: train loss 1.2405, val loss 1.2401


Training Progress:  41%|████      | 20407/50000 [14:56<52:45,  9.35it/s]  

step 20400: train loss 1.2399, val loss 1.2416


Training Progress:  41%|████      | 20509/50000 [15:01<52:25,  9.37it/s]  

step 20500: train loss 1.2412, val loss 1.2405


Training Progress:  41%|████      | 20605/50000 [15:05<1:10:37,  6.94it/s]

step 20600: train loss 1.2404, val loss 1.2421


Training Progress:  41%|████▏     | 20707/50000 [15:10<52:08,  9.36it/s]  

step 20700: train loss 1.2419, val loss 1.2429


Training Progress:  42%|████▏     | 20809/50000 [15:14<51:49,  9.39it/s]  

step 20800: train loss 1.2409, val loss 1.2419


Training Progress:  42%|████▏     | 20905/50000 [15:18<1:09:47,  6.95it/s]

step 20900: train loss 1.2399, val loss 1.2403


Training Progress:  42%|████▏     | 21007/50000 [15:23<51:39,  9.35it/s]  

step 21000: train loss 1.2473, val loss 1.2465


Training Progress:  42%|████▏     | 21109/50000 [15:27<51:20,  9.38it/s]  

step 21100: train loss 1.2415, val loss 1.2413


Training Progress:  42%|████▏     | 21205/50000 [15:31<1:09:14,  6.93it/s]

step 21200: train loss 1.2415, val loss 1.2412


Training Progress:  43%|████▎     | 21307/50000 [15:36<51:13,  9.34it/s]  

step 21300: train loss 1.2389, val loss 1.2405


Training Progress:  43%|████▎     | 21409/50000 [15:40<50:45,  9.39it/s]  

step 21400: train loss 1.2398, val loss 1.2398


Training Progress:  43%|████▎     | 21505/50000 [15:45<1:08:45,  6.91it/s]

step 21500: train loss 1.2411, val loss 1.2402


Training Progress:  43%|████▎     | 21607/50000 [15:49<50:23,  9.39it/s]  

step 21600: train loss 1.2426, val loss 1.2417


Training Progress:  43%|████▎     | 21709/50000 [15:53<50:29,  9.34it/s]  

step 21700: train loss 1.2444, val loss 1.2427


Training Progress:  44%|████▎     | 21805/50000 [15:58<1:08:09,  6.89it/s]

step 21800: train loss 1.2414, val loss 1.2421


Training Progress:  44%|████▍     | 21907/50000 [16:02<50:00,  9.36it/s]  

step 21900: train loss 1.2417, val loss 1.2410


Training Progress:  44%|████▍     | 22009/50000 [16:07<49:47,  9.37it/s]  

step 22000: train loss 1.2412, val loss 1.2404


Training Progress:  44%|████▍     | 22105/50000 [16:11<1:07:13,  6.92it/s]

step 22100: train loss 1.2397, val loss 1.2399


Training Progress:  44%|████▍     | 22207/50000 [16:15<49:29,  9.36it/s]  

step 22200: train loss 1.2410, val loss 1.2409


Training Progress:  45%|████▍     | 22309/50000 [16:20<49:17,  9.36it/s]  

step 22300: train loss 1.2396, val loss 1.2395


Training Progress:  45%|████▍     | 22405/50000 [16:24<1:06:25,  6.92it/s]

step 22400: train loss 1.2403, val loss 1.2389


Training Progress:  45%|████▌     | 22507/50000 [16:28<48:53,  9.37it/s]  

step 22500: train loss 1.2391, val loss 1.2407


Training Progress:  45%|████▌     | 22609/50000 [16:33<49:02,  9.31it/s]  

step 22600: train loss 1.2402, val loss 1.2419


Training Progress:  45%|████▌     | 22705/50000 [16:37<1:05:42,  6.92it/s]

step 22700: train loss 1.2395, val loss 1.2392


Training Progress:  46%|████▌     | 22807/50000 [16:42<48:25,  9.36it/s]  

step 22800: train loss 1.2394, val loss 1.2406


Training Progress:  46%|████▌     | 22909/50000 [16:46<48:12,  9.36it/s]  

step 22900: train loss 1.2404, val loss 1.2415


Training Progress:  46%|████▌     | 23005/50000 [16:50<1:04:56,  6.93it/s]

step 23000: train loss 1.2391, val loss 1.2402


Training Progress:  46%|████▌     | 23107/50000 [16:55<47:58,  9.34it/s]  

step 23100: train loss 1.2395, val loss 1.2407


Training Progress:  46%|████▋     | 23209/50000 [16:59<47:43,  9.36it/s]  

step 23200: train loss 1.2383, val loss 1.2384


Training Progress:  47%|████▋     | 23305/50000 [17:03<1:04:14,  6.92it/s]

step 23300: train loss 1.2410, val loss 1.2403


Training Progress:  47%|████▋     | 23407/50000 [17:08<47:24,  9.35it/s]  

step 23400: train loss 1.2426, val loss 1.2402


Training Progress:  47%|████▋     | 23509/50000 [17:12<47:07,  9.37it/s]  

step 23500: train loss 1.2408, val loss 1.2402


Training Progress:  47%|████▋     | 23605/50000 [17:17<1:03:37,  6.91it/s]

step 23600: train loss 1.2397, val loss 1.2392


Training Progress:  47%|████▋     | 23707/50000 [17:21<46:50,  9.35it/s]  

step 23700: train loss 1.2394, val loss 1.2381


Training Progress:  48%|████▊     | 23809/50000 [17:25<46:39,  9.36it/s]  

step 23800: train loss 1.2412, val loss 1.2410


Training Progress:  48%|████▊     | 23905/50000 [17:30<1:02:57,  6.91it/s]

step 23900: train loss 1.2396, val loss 1.2387


Training Progress:  48%|████▊     | 24007/50000 [17:34<46:17,  9.36it/s]  

step 24000: train loss 1.2399, val loss 1.2392


Training Progress:  48%|████▊     | 24109/50000 [17:39<46:03,  9.37it/s]  

step 24100: train loss 1.2422, val loss 1.2411


Training Progress:  48%|████▊     | 24205/50000 [17:43<1:02:18,  6.90it/s]

step 24200: train loss 1.2386, val loss 1.2394


Training Progress:  49%|████▊     | 24307/50000 [17:47<45:43,  9.36it/s]  

step 24300: train loss 1.2388, val loss 1.2389


Training Progress:  49%|████▉     | 24409/50000 [17:52<45:36,  9.35it/s]  

step 24400: train loss 1.2407, val loss 1.2408


Training Progress:  49%|████▉     | 24505/50000 [17:56<1:01:25,  6.92it/s]

step 24500: train loss 1.2399, val loss 1.2403


Training Progress:  49%|████▉     | 24607/50000 [18:00<45:15,  9.35it/s]  

step 24600: train loss 1.2375, val loss 1.2381


Training Progress:  49%|████▉     | 24709/50000 [18:05<45:07,  9.34it/s]  

step 24700: train loss 1.2394, val loss 1.2390


Training Progress:  50%|████▉     | 24805/50000 [18:09<1:00:44,  6.91it/s]

step 24800: train loss 1.2402, val loss 1.2416


Training Progress:  50%|████▉     | 24907/50000 [18:14<44:40,  9.36it/s]  

step 24900: train loss 1.2406, val loss 1.2397


Training Progress:  50%|█████     | 25009/50000 [18:18<44:33,  9.35it/s]  

step 25000: train loss 1.2465, val loss 1.2453


Training Progress:  50%|█████     | 25105/50000 [18:22<1:00:03,  6.91it/s]

step 25100: train loss 1.2389, val loss 1.2398


Training Progress:  50%|█████     | 25207/50000 [18:27<44:16,  9.33it/s]

step 25200: train loss 1.2400, val loss 1.2409


Training Progress:  51%|█████     | 25309/50000 [18:31<44:08,  9.32it/s]

step 25300: train loss 1.2381, val loss 1.2383


Training Progress:  51%|█████     | 25405/50000 [18:36<59:16,  6.92it/s]

step 25400: train loss 1.2474, val loss 1.2465


Training Progress:  51%|█████     | 25507/50000 [18:40<43:44,  9.33it/s]

step 25500: train loss 1.2391, val loss 1.2375


Training Progress:  51%|█████     | 25609/50000 [18:44<43:29,  9.35it/s]

step 25600: train loss 1.2411, val loss 1.2422


Training Progress:  51%|█████▏    | 25705/50000 [18:49<58:37,  6.91it/s]

step 25700: train loss 1.2377, val loss 1.2391


Training Progress:  52%|█████▏    | 25807/50000 [18:53<43:25,  9.29it/s]

step 25800: train loss 1.2386, val loss 1.2380


Training Progress:  52%|█████▏    | 25909/50000 [18:58<43:02,  9.33it/s]

step 25900: train loss 1.2408, val loss 1.2407


Training Progress:  52%|█████▏    | 26005/50000 [19:02<58:01,  6.89it/s]

step 26000: train loss 1.2411, val loss 1.2421


Training Progress:  52%|█████▏    | 26107/50000 [19:06<42:34,  9.35it/s]

step 26100: train loss 1.2406, val loss 1.2401


Training Progress:  52%|█████▏    | 26209/50000 [19:11<42:16,  9.38it/s]

step 26200: train loss 1.2384, val loss 1.2385


Training Progress:  53%|█████▎    | 26305/50000 [19:15<57:07,  6.91it/s]

step 26300: train loss 1.2374, val loss 1.2392


Training Progress:  53%|█████▎    | 26407/50000 [19:20<42:03,  9.35it/s]

step 26400: train loss 1.2401, val loss 1.2398


Training Progress:  53%|█████▎    | 26509/50000 [19:24<41:50,  9.36it/s]

step 26500: train loss 1.2379, val loss 1.2390


Training Progress:  53%|█████▎    | 26605/50000 [19:28<56:27,  6.91it/s]

step 26600: train loss 1.2381, val loss 1.2393


Training Progress:  53%|█████▎    | 26707/50000 [19:33<41:28,  9.36it/s]

step 26700: train loss 1.2393, val loss 1.2400


Training Progress:  54%|█████▎    | 26809/50000 [19:37<41:09,  9.39it/s]

step 26800: train loss 1.2392, val loss 1.2390


Training Progress:  54%|█████▍    | 26905/50000 [19:41<55:42,  6.91it/s]

step 26900: train loss 1.2400, val loss 1.2384


Training Progress:  54%|█████▍    | 27007/50000 [19:46<40:51,  9.38it/s]

step 27000: train loss 1.2375, val loss 1.2373


Training Progress:  54%|█████▍    | 27109/50000 [19:50<40:42,  9.37it/s]

step 27100: train loss 1.2392, val loss 1.2378


Training Progress:  54%|█████▍    | 27205/50000 [19:55<54:56,  6.92it/s]

step 27200: train loss 1.2403, val loss 1.2412


Training Progress:  55%|█████▍    | 27307/50000 [19:59<40:25,  9.36it/s]

step 27300: train loss 1.2410, val loss 1.2394


Training Progress:  55%|█████▍    | 27409/50000 [20:03<40:16,  9.35it/s]

step 27400: train loss 1.2398, val loss 1.2387


Training Progress:  55%|█████▌    | 27505/50000 [20:08<54:09,  6.92it/s]

step 27500: train loss 1.2397, val loss 1.2392


Training Progress:  55%|█████▌    | 27607/50000 [20:12<39:47,  9.38it/s]

step 27600: train loss 1.2390, val loss 1.2396


Training Progress:  55%|█████▌    | 27709/50000 [20:16<39:41,  9.36it/s]

step 27700: train loss 1.2397, val loss 1.2398


Training Progress:  56%|█████▌    | 27805/50000 [20:21<53:21,  6.93it/s]

step 27800: train loss 1.2415, val loss 1.2410


Training Progress:  56%|█████▌    | 27907/50000 [20:25<39:16,  9.38it/s]

step 27900: train loss 1.2386, val loss 1.2387


Training Progress:  56%|█████▌    | 28009/50000 [20:30<39:06,  9.37it/s]

step 28000: train loss 1.2390, val loss 1.2407


Training Progress:  56%|█████▌    | 28105/50000 [20:34<52:48,  6.91it/s]

step 28100: train loss 1.2393, val loss 1.2387


Training Progress:  56%|█████▋    | 28207/50000 [20:38<39:01,  9.31it/s]

step 28200: train loss 1.2386, val loss 1.2389


Training Progress:  57%|█████▋    | 28309/50000 [20:43<38:40,  9.35it/s]

step 28300: train loss 1.2399, val loss 1.2398


Training Progress:  57%|█████▋    | 28405/50000 [20:47<51:53,  6.94it/s]

step 28400: train loss 1.2403, val loss 1.2391


Training Progress:  57%|█████▋    | 28507/50000 [20:52<38:16,  9.36it/s]

step 28500: train loss 1.2376, val loss 1.2373


Training Progress:  57%|█████▋    | 28609/50000 [20:56<38:04,  9.36it/s]

step 28600: train loss 1.2408, val loss 1.2405


Training Progress:  57%|█████▋    | 28705/50000 [21:00<51:18,  6.92it/s]

step 28700: train loss 1.2476, val loss 1.2454


Training Progress:  58%|█████▊    | 28807/50000 [21:05<37:44,  9.36it/s]

step 28800: train loss 1.2406, val loss 1.2402


Training Progress:  58%|█████▊    | 28909/50000 [21:09<37:26,  9.39it/s]

step 28900: train loss 1.2377, val loss 1.2376


Training Progress:  58%|█████▊    | 29011/50000 [21:13<37:21,  9.37it/s]

step 29000: train loss 1.2411, val loss 1.2422


Training Progress:  58%|█████▊    | 29107/50000 [21:18<37:09,  9.37it/s]

step 29100: train loss 1.2415, val loss 1.2396


Training Progress:  58%|█████▊    | 29209/50000 [21:22<37:00,  9.36it/s]

step 29200: train loss 1.2388, val loss 1.2399


Training Progress:  59%|█████▊    | 29305/50000 [21:27<49:50,  6.92it/s]

step 29300: train loss 1.2373, val loss 1.2379


Training Progress:  59%|█████▉    | 29407/50000 [21:31<36:34,  9.38it/s]

step 29400: train loss 1.2392, val loss 1.2388


Training Progress:  59%|█████▉    | 29509/50000 [21:35<36:28,  9.36it/s]

step 29500: train loss 1.2382, val loss 1.2392


Training Progress:  59%|█████▉    | 29605/50000 [21:40<49:02,  6.93it/s]

step 29600: train loss 1.2393, val loss 1.2393


Training Progress:  59%|█████▉    | 29707/50000 [21:44<36:04,  9.37it/s]

step 29700: train loss 1.2399, val loss 1.2388


Training Progress:  60%|█████▉    | 29809/50000 [21:48<35:57,  9.36it/s]

step 29800: train loss 1.2382, val loss 1.2374


Training Progress:  60%|█████▉    | 29911/50000 [21:53<35:39,  9.39it/s]

step 29900: train loss 1.2391, val loss 1.2400


Training Progress:  60%|██████    | 30007/50000 [21:57<35:38,  9.35it/s]

step 30000: train loss 1.2376, val loss 1.2381


Training Progress:  60%|██████    | 30109/50000 [22:02<35:30,  9.34it/s]

step 30100: train loss 1.2374, val loss 1.2395


Training Progress:  60%|██████    | 30205/50000 [22:06<47:39,  6.92it/s]

step 30200: train loss 1.2431, val loss 1.2430


Training Progress:  61%|██████    | 30307/50000 [22:10<35:15,  9.31it/s]

step 30300: train loss 1.2405, val loss 1.2403


Training Progress:  61%|██████    | 30409/50000 [22:15<34:43,  9.40it/s]

step 30400: train loss 1.2433, val loss 1.2432


Training Progress:  61%|██████    | 30505/50000 [22:19<46:52,  6.93it/s]

step 30500: train loss 1.2396, val loss 1.2399


Training Progress:  61%|██████    | 30607/50000 [22:24<34:34,  9.35it/s]

step 30600: train loss 1.2377, val loss 1.2379


Training Progress:  61%|██████▏   | 30709/50000 [22:28<34:17,  9.38it/s]

step 30700: train loss 1.2383, val loss 1.2381


Training Progress:  62%|██████▏   | 30805/50000 [22:32<46:04,  6.94it/s]

step 30800: train loss 1.2390, val loss 1.2401


Training Progress:  62%|██████▏   | 30907/50000 [22:37<33:57,  9.37it/s]

step 30900: train loss 1.2387, val loss 1.2392


Training Progress:  62%|██████▏   | 31009/50000 [22:41<33:51,  9.35it/s]

step 31000: train loss 1.2372, val loss 1.2383


Training Progress:  62%|██████▏   | 31105/50000 [22:45<45:31,  6.92it/s]

step 31100: train loss 1.2379, val loss 1.2391


Training Progress:  62%|██████▏   | 31207/50000 [22:50<33:22,  9.39it/s]

step 31200: train loss 1.2373, val loss 1.2377


Training Progress:  63%|██████▎   | 31309/50000 [22:54<33:13,  9.38it/s]

step 31300: train loss 1.2378, val loss 1.2374


Training Progress:  63%|██████▎   | 31405/50000 [22:59<45:10,  6.86it/s]

step 31400: train loss 1.2422, val loss 1.2423


Training Progress:  63%|██████▎   | 31507/50000 [23:03<32:55,  9.36it/s]

step 31500: train loss 1.2386, val loss 1.2385


Training Progress:  63%|██████▎   | 31609/50000 [23:07<32:40,  9.38it/s]

step 31600: train loss 1.2387, val loss 1.2405


Training Progress:  63%|██████▎   | 31705/50000 [23:12<44:06,  6.91it/s]

step 31700: train loss 1.2374, val loss 1.2379


Training Progress:  64%|██████▎   | 31807/50000 [23:16<32:21,  9.37it/s]

step 31800: train loss 1.2394, val loss 1.2392


Training Progress:  64%|██████▍   | 31909/50000 [23:21<32:12,  9.36it/s]

step 31900: train loss 1.2373, val loss 1.2381


Training Progress:  64%|██████▍   | 32005/50000 [23:25<43:14,  6.94it/s]

step 32000: train loss 1.2378, val loss 1.2379


Training Progress:  64%|██████▍   | 32107/50000 [23:29<31:45,  9.39it/s]

step 32100: train loss 1.2398, val loss 1.2393


Training Progress:  64%|██████▍   | 32209/50000 [23:34<31:38,  9.37it/s]

step 32200: train loss 1.2393, val loss 1.2377


Training Progress:  65%|██████▍   | 32305/50000 [23:38<42:28,  6.94it/s]

step 32300: train loss 1.2385, val loss 1.2387


Training Progress:  65%|██████▍   | 32407/50000 [23:42<31:23,  9.34it/s]

step 32400: train loss 1.2385, val loss 1.2389


Training Progress:  65%|██████▌   | 32509/50000 [23:47<31:10,  9.35it/s]

step 32500: train loss 1.2392, val loss 1.2394


Training Progress:  65%|██████▌   | 32605/50000 [23:51<41:50,  6.93it/s]

step 32600: train loss 1.2390, val loss 1.2380


Training Progress:  65%|██████▌   | 32707/50000 [23:55<30:47,  9.36it/s]

step 32700: train loss 1.2377, val loss 1.2381


Training Progress:  66%|██████▌   | 32809/50000 [24:00<30:30,  9.39it/s]

step 32800: train loss 1.2383, val loss 1.2391


Training Progress:  66%|██████▌   | 32905/50000 [24:04<41:08,  6.93it/s]

step 32900: train loss 1.2381, val loss 1.2373


Training Progress:  66%|██████▌   | 33007/50000 [24:09<30:18,  9.34it/s]

step 33000: train loss 1.2386, val loss 1.2380


Training Progress:  66%|██████▌   | 33109/50000 [24:13<30:00,  9.38it/s]

step 33100: train loss 1.2398, val loss 1.2377


Training Progress:  66%|██████▋   | 33205/50000 [24:17<40:21,  6.94it/s]

step 33200: train loss 1.2381, val loss 1.2381


Training Progress:  67%|██████▋   | 33307/50000 [24:22<29:49,  9.33it/s]

step 33300: train loss 1.2382, val loss 1.2383


Training Progress:  67%|██████▋   | 33409/50000 [24:26<29:30,  9.37it/s]

step 33400: train loss 1.2388, val loss 1.2383


Training Progress:  67%|██████▋   | 33500/50000 [24:30<12:04, 22.78it/s]

step 33500: train loss 1.2378, val loss 1.2396
Early Stopping at iteration 33500





In [84]:
# End the currently active Weights & Biases run (opened earlier in the
# notebook — presumably by the training cell; confirm).
wandb.finish()

0,1
Loss,█▆▇▇▆█▇▆▆▅▆▆▄▅▄▄▆▅▇▃▅▅▅▄▆▃▃▃▄▂▄▁▄▃▁▁▃▂▂▂

0,1
Loss,1.23828


In [None]:
# Plot the recorded validation-loss and accuracy curves on one interactive
# figure, show it inline, then upload it to W&B as an HTML panel.
# NOTE(review): val_loss_list, acc_test_list and acc_all_list are defined
# outside this cell — presumably one entry per evaluation; confirm.
fig = go.Figure()

# Derive the x-axis from the data instead of a fixed range(58), so the plot
# stays aligned when a run records a different number of points (the training
# log above ends through early stopping, so the count is not constant).
iterations = list(range(len(val_loss_list)))

# Add lines (each trace is indexed by its own length so a shorter or longer
# series is never silently misaligned with the others)
fig.add_trace(go.Scatter(x=iterations, y=val_loss_list, mode='lines+markers', name='Validation Loss'))
fig.add_trace(go.Scatter(x=list(range(len(acc_test_list))), y=acc_test_list, mode='lines+markers', name='Test Accuracy'))
fig.add_trace(go.Scatter(x=list(range(len(acc_all_list))), y=acc_all_list, mode='lines+markers', name='All Accuracy'))

# Set title label
fig.update_layout(
    title="Validation Loss & Accuracy over Iterations",
    xaxis_title="Iterations",
    yaxis_title="Metrics",
    legend_title="Metrics",
    width = 900,
    height = 500
)

fig.show()

# Log the figure in a dedicated W&B run, then close that run.
wandb.init(project="transformer", name="val_loss and accuracy plot")
wandb.log({"Interactive Chart": wandb.Html(fig.to_html())})
wandb.finish()

In [129]:
def accuracy_print(model, num_digits, need_print=False, num_trials=100):
    """Estimate exact-match accuracy of `model` on random addition prompts.

    For each trial two operands are drawn uniformly from
    [0, 10**num_digits), the prompt "a+b=" is encoded and handed to
    `generate`, and the trial counts as correct only when the generated
    output equals "a+b=c" exactly.

    Args:
        model: Model forwarded unchanged to `generate`.
        num_digits: Upper bound on operand length; operands are sampled
            below 10**num_digits, so they may have fewer digits.
        need_print: When True, print prompt and output of the first ten
            trials.
        num_trials: Number of random prompts to evaluate. Defaults to 100,
            the value that used to be hard-coded.

    Returns:
        Fraction of correctly answered trials (correct / num_trials).

    Raises:
        ValueError: If `num_trials` is not positive.
    """
    if num_trials <= 0:
        raise ValueError("num_trials must be a positive integer")

    upper = 10 ** num_digits
    correct = 0

    for j in range(num_trials):
        # randint samples the same uniform range as
        # np.random.choice(np.arange(upper)) but does not materialise an
        # array of `upper` elements on every draw.
        a = int(np.random.randint(0, upper))
        b = int(np.random.randint(0, upper))
        c = a + b
        prompt = f"{a}+{b}="
        context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        # NOTE(review): the positional 100 and 1 are opaque from here —
        # presumably a generation-length budget and a temperature; confirm
        # against the definition of `generate`.
        output = generate(model, context, 100, 1)
        if need_print and j // 10 == 0:  # true only for j in 0..9
            print(f"Input: {prompt}")
            print(f"Output: {output}")
        if output == f"{prompt}{c}":
            correct += 1
    acc = correct / num_trials
    print(f"Accuracy for {num_digits} digits addition: {acc} ")
    return acc

In [130]:
# Spot-check: random additions with both operands below 10**1.
accuracy_print(model, 1)

Accuracy for 1 digits addition: 0.97 


0.97

In [131]:
# Spot-check: random additions with both operands below 10**2.
accuracy_print(model, 2)

Accuracy for 2 digits addition: 1.0 


1.0

In [132]:
# Spot-check: random additions with both operands below 10**3.
accuracy_print(model, 3)

Accuracy for 3 digits addition: 0.94 


0.94

In [133]:
# Spot-check: random additions with both operands below 10**4.
accuracy_print(model, 4)

Accuracy for 4 digits addition: 0.92 


0.92

In [134]:
# Spot-check: random additions with both operands below 10**5.
accuracy_print(model, 5)

Accuracy for 5 digits addition: 0.8 


0.8

In [135]:
# Spot-check: random additions with both operands below 10**6.
accuracy_print(model, 6)

Accuracy for 6 digits addition: 0.91 


0.91

In [136]:
# Spot-check: random additions with both operands below 10**7.
accuracy_print(model, 7)

Accuracy for 7 digits addition: 0.02 


0.02

In [137]:
# Spot-check: random additions with both operands below 10**8.
accuracy_print(model, 8)

Accuracy for 8 digits addition: 0.0 


0.0

In [138]:
def get_avg_performance(model, max_digits=8):
    """Run `accuracy_print` once for every operand size from 1 to `max_digits`.

    Despite the name, nothing is averaged here: each entry is the result of
    a single `accuracy_print` call for that digit count.

    Args:
        model: Model forwarded unchanged to `accuracy_print`.
        max_digits: Largest `num_digits` value to evaluate, inclusive.
            Defaults to 8, the bound that used to be hard-coded.

    Returns:
        Dict mapping each digit count (ascending from 1) to the accuracy
        returned by `accuracy_print` for that count.
    """
    return {
        num_dig: accuracy_print(model, num_dig, need_print=False)
        for num_dig in range(1, max_digits + 1)
    }

In [139]:
# Evaluate every digit count covered by get_avg_performance and keep the
# {num_digits: accuracy} mapping; the bar-chart cell below reads it.
avg_performance = get_avg_performance(model)

Accuracy for 1 digits addition: 0.98 
Accuracy for 2 digits addition: 0.99 
Accuracy for 3 digits addition: 0.97 
Accuracy for 4 digits addition: 0.92 
Accuracy for 5 digits addition: 0.8 
Accuracy for 6 digits addition: 0.91 
Accuracy for 7 digits addition: 0.03 
Accuracy for 8 digits addition: 0.0 


In [142]:
# Bar chart of the accuracies stored in `avg_performance`: dict keys go on
# the x-axis and the matching values on the y-axis.
x_values = [*avg_performance]
y_values = [*avg_performance.values()]

# Start from an empty figure and attach the single bar trace to it.
fig = go.Figure()
fig.add_trace(go.Bar(x=x_values, y=y_values, marker_color='lightskyblue'))

fig.update_layout(
    title="Accuracy for different digits addition",
    xaxis_title="Num Digits",
    yaxis_title="Accuracy",
    template="plotly_white",
    width=800,
    height=500,
)

# Render inline first ...
fig.show()

# ... then upload the same figure as an HTML panel in its own W&B run and
# close that run again.
wandb.init(project="transformer", name="Accuracy for different digits addition plot")
wandb.log({"Interactive Chart": wandb.Html(fig.to_html())})
wandb.finish()

In [None]:
import shutil
import subprocess

# Identity used for commits created from this Colab session.
subprocess.run(["git", "config", "--global", "user.email", "zifeibai@umich.edu"], check=True)
subprocess.run(["git", "config", "--global", "user.name", "ZifeiBai"], check=True)

# 2️⃣ Read the GitHub token from Google Drive so it is never hard-coded here.
GITHUB_TOKEN_PATH = "/content/drive/MyDrive/URPS/github_token.txt"
if not os.path.exists(GITHUB_TOKEN_PATH):
    # Raise instead of calling exit(): exit() is meant for interactive shells,
    # whereas an exception aborts the cell immediately with a readable message,
    # so nothing below can run without a token.
    raise FileNotFoundError(f"❌ GitHub token file not found: {GITHUB_TOKEN_PATH}")
with open(GITHUB_TOKEN_PATH, "r") as f:
    os.environ["GITHUB_TOKEN"] = f.read().strip()
if not os.environ["GITHUB_TOKEN"]:
    raise ValueError(f"❌ GitHub token file is empty: {GITHUB_TOKEN_PATH}")


def _redact(text):
    """Mask the GitHub token in text before it is printed or raised."""
    token = os.environ.get("GITHUB_TOKEN", "")
    if not text or not token:
        return text
    return text.replace(token, "***")


# 3️⃣ Set up the GitHub remote repo.
# NOTE(review): the token is embedded in the remote URL, so `git clone` stores
# it in .git/config on Drive. Consider a credential helper instead.
GIT_PATH = "/content/drive/MyDrive/URPS/Git"
REPO_URL = f"https://{os.environ['GITHUB_TOKEN']}@github.com/ZifeiBai/URPS.git"

if not os.path.exists(GIT_PATH):
    print(f"📁 Creating directory: {GIT_PATH}")
    os.makedirs(GIT_PATH)

# 4️⃣ If .git/ does not exist, the repository has to be cloned.
if not os.path.exists(os.path.join(GIT_PATH, ".git")):
    print("❌ Git repository not found. Cloning...")
    # WARNING: this wipes whatever is currently in GIT_PATH before cloning.
    shutil.rmtree(GIT_PATH)
    # Commands are passed as argv lists (no shell), and failures are reported
    # through a redacted message: a CalledProcessError would include the full
    # command line, token included, in the notebook output.
    clone_output = subprocess.run(
        ["git", "clone", REPO_URL, GIT_PATH], capture_output=True, text=True
    )
    if clone_output.returncode != 0:
        raise RuntimeError("git clone failed: " + (_redact(clone_output.stderr) or ""))

# 5️⃣ Enter the Git repo (changes the working directory for the whole kernel).
os.chdir(GIT_PATH)
print("📂 Changed working directory to:", os.getcwd())

# 6️⃣ Check Git status.
status_output = subprocess.run(["git", "status"], capture_output=True, text=True)
print(status_output.stdout)

# Stage, commit and push everything in the repository.
print("🚀 Adding files to Git...")
subprocess.run(["git", "add", "."], check=True)

print("📝 Committing changes...")
commit_output = subprocess.run(
    ["git", "commit", "-m", "Auto update from Google Colab 2.6"],
    capture_output=True,
    text=True,
)
print(commit_output.stdout)
# A non-zero status also covers "nothing to commit" (already shown on stdout);
# surface stderr as well so real commit failures are not silently dropped.
if commit_output.returncode != 0 and commit_output.stderr:
    print(_redact(commit_output.stderr))

print("📤 Pushing to GitHub...")
push_output = subprocess.run(
    ["git", "push", "origin", "main"], capture_output=True, text=True
)
# Decide by exit status rather than by searching stderr for "fatal"/"error:":
# git writes normal progress to stderr, and a failure need not contain either word.
if push_output.returncode != 0:
    print("❌ Real Git Push Error:", _redact(push_output.stderr))
else:
    print("✅ Git Push Success!")

📂 Changed working directory to: /content/drive/MyDrive/URPS/Git
On branch main
Your branch is up to date with 'origin/main'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
	modified:   transformer.ipynb

no changes added to commit (use "git add" and/or "git commit -a")

🚀 Adding files to Git...
📝 Committing changes...
[main 1fa1a0f] Auto update from Google Colab 2.6
 1 file changed, 1 insertion(+), 1 deletion(-)
 rewrite transformer.ipynb (94%)

📤 Pushing to GitHub...
✅ Git Push Success!
