In [None]:
from google.colab import drive
import os

# 1️⃣ Google Drive: mount it at /content/drive so Drive files are reachable from this runtime.
# NOTE(review): no later cell in this notebook reads from /content/drive — confirm the mount is still needed.
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import wandb

In [2]:
import pandas as pd

In [3]:
!pip install plotly



In [4]:
import plotly.graph_objects as go

In [5]:
# Authenticate with Weights & Biases (prompts for an API key when none is stored);
# must succeed before the wandb.init / wandb.log calls further down.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbirdyyybai[0m ([33mbirdyyybai-university-of-michigan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
import math
import inspect
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F

In [7]:
# Character-level vocabulary: digits, '=' and '+', plus two special tokens:
# '&' marks the end of an example and '*' pads sequences up to block_size.
vocab = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '+', '&', '*']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Derive special-token ids from the vocabulary so they cannot drift if it is reordered.
end_token_index = vocab.index('&')
padding_token_index = vocab.index('*')

In [8]:
# Character-level tokenizer: map each vocab character to its index and back.
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for i, ch in enumerate(vocab)}


def encode(s):
    """Encode a string into a list of token ids (KeyError on characters outside vocab)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list of token ids back into a string (KeyError on unknown ids)."""
    return ''.join(itos[i] for i in l)


print(encode("1+2=3&"))
print(decode(encode("1+2=3&")))
# Round-trip sanity check: decode must invert encode.
assert decode(encode("1+2=3&")) == "1+2=3&"

[1, 11, 2, 10, 3, 12]
1+2=3&


In [None]:
# # train test split
# train_set_1 = np.random.choice(np.arange(10), 8, replace=False)
# train_set_2 = np.random.choice(np.arange(10, 100), 72, replace=False)
# train_set_3 = np.random.choice(np.arange(100, 1000), 720, replace=False)
# test_1 = np.setdiff1d(np.arange(10), train_set_1)
# test_2 = np.setdiff1d(np.arange(10, 100), train_set_2)
# test_3 = np.setdiff1d(np.arange(100, 1000), train_set_3)
# test_12 = np.concatenate([test_1, test_2])
# test = np.concatenate([test_1, test_2, test_3])
# print(np.sort(train_set_1))
# print(np.sort(train_set_2))
# print(np.sort(test_1))
# print(np.sort(test_2))
# print(np.sort(test))

[0 1 2 4 5 6 8 9]
[10 11 12 13 14 15 16 17 18 19 21 22 23 24 25 26 27 28 30 31 33 34 35 36
 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 58 60 63 64 65
 66 68 69 70 71 74 77 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 98 99]
[3 7]
[20 29 32 37 57 59 61 62 67 72 73 75 76 78 79 80 96 97]
[  3   7  20  29  32  37  57  59  61  62  67  72  73  75  76  78  79  80
  96  97 102 107 118 121 125 130 145 146 151 152 154 161 165 169 171 180
 192 193 197 198 199 200 204 207 208 210 218 219 241 254 268 284 294 295
 296 299 303 306 308 317 318 319 327 329 340 341 342 343 348 365 366 373
 377 381 386 396 399 401 407 408 412 417 422 424 427 428 438 443 455 456
 459 468 474 477 496 514 520 526 533 534 536 538 547 553 555 558 559 564
 572 575 582 585 593 594 599 601 608 611 620 622 624 628 635 638 639 647
 650 652 654 655 658 659 660 667 685 690 705 707 715 716 723 724 726 728
 744 749 752 755 759 764 765 766 771 772 776 778 783 785 786 793 803 813
 824 826 838 839 840 841 844 845 862 865 87

In [88]:
def get_batch(phase=None, batch_size=32, block_size=35, mode='train', mask_prompt=False):
    """Sample a batch of random addition problems as (input, target) token tensors.

    Each example is the string "a+b=r&", where r is the decimal sum a+b written
    with its digits REVERSED (e.g. 12+34 -> "12+34=64&") and '&' is the end token.
    Sequences are right-padded with padding_token_index up to block_size.

    Args:
        phase: digit count of both operands (int >= 1), or "mix" to draw each
            operand's digit count independently from 1..6 with probabilities
            skewed towards longer numbers.
        batch_size: number of examples.
        block_size: padded sequence length.
        mode: 'train' or 'val'. NOTE: currently ignored; both splits are drawn
            from the same distribution, so 'val' is not a held-out set.
        mask_prompt: if True, targets at positions before the answer (the ones
            predicting operand characters, '+' and '=') are replaced by
            padding_token_index, so a loss that ignores padding scores only the
            answer digits and end token. Default False keeps full next-token targets.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size) on `device`.
        y is x shifted left by one position with end_token_index appended.

    Raises:
        ValueError: if phase is neither a positive int nor "mix", or if an
            encoded example is longer than block_size.
    """
    # Digit-count distribution for the "mix" phase (1..6 digits).
    mix_digit_probs = [0.045, 0.075, 0.09, 0.14, 0.25, 0.40]

    if isinstance(phase, str) and phase == "mix":
        digits_a = np.random.choice(np.arange(1, 7), size=batch_size, p=mix_digit_probs)
        digits_b = np.random.choice(np.arange(1, 7), size=batch_size, p=mix_digit_probs)
    elif isinstance(phase, (int, np.integer)) and phase >= 1:
        digits_a = digits_b = phase
    else:
        raise ValueError(f"phase must be a positive int or 'mix', got {phase!r}")

    # An operand with exactly d digits is uniform on [10**(d-1), 10**d).
    a = np.random.randint(10 ** (digits_a - 1), 10 ** digits_a, size=batch_size)
    b = np.random.randint(10 ** (digits_b - 1), 10 ** digits_b, size=batch_size)

    x_rows, y_rows = [], []
    for a_i, b_i, c_i in zip(a, b, a + b):
        # Answer digits are reversed (least-significant first); presumably so the
        # model emits digits in the order carries propagate.
        text = f"{a_i}+{b_i}={str(c_i)[::-1]}&"
        tokens = encode(text)
        if len(tokens) > block_size:
            raise ValueError(
                f"example {text!r} has {len(tokens)} tokens, exceeding block_size={block_size}"
            )
        # Next-token targets: shift left by one, then append the end token.
        targets = tokens[1:] + [end_token_index]
        if mask_prompt:
            # targets[i] is what x[i] predicts; positions before '=' predict prompt characters.
            prompt_len = text.index('=')
            targets[:prompt_len] = [padding_token_index] * prompt_len
        padding = [padding_token_index] * (block_size - len(tokens))
        x_rows.append(tokens + padding)
        y_rows.append(targets + padding)

    x_tensor = torch.tensor(x_rows, dtype=torch.int64, device=device)
    y_tensor = torch.tensor(y_rows, dtype=torch.int64, device=device)
    return x_tensor, y_tensor

In [89]:
# Peek at one mixed-length batch: (inputs, targets), each (32, 35) int64 by default.
sample_x, sample_y = get_batch(phase="mix")
sample_x, sample_y

(tensor([[ 7,  3,  8,  ..., 13, 13, 13],
         [ 6,  1,  3,  ..., 13, 13, 13],
         [ 7,  1,  3,  ..., 13, 13, 13],
         ...,
         [ 4,  6,  2,  ..., 13, 13, 13],
         [ 5,  9,  8,  ..., 13, 13, 13],
         [ 8,  5,  5,  ..., 13, 13, 13]], device='cuda:0'),
 tensor([[ 3,  8,  8,  ..., 13, 13, 13],
         [ 1,  3,  3,  ..., 13, 13, 13],
         [ 1,  3,  4,  ..., 13, 13, 13],
         ...,
         [ 6,  2,  8,  ..., 13, 13, 13],
         [ 9,  8,  1,  ..., 13, 13, 13],
         [ 5,  5,  4,  ..., 13, 13, 13]], device='cuda:0'))

In [33]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional learnable bias.

    Hand-written so that bias=False is possible.
    NOTE(review): torch.nn.LayerNorm accepts bias=False from PyTorch 2.1 on.

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when True, learn an additive bias initialised to zeros.
    """

    def __init__(self, ndim, bias=True):
        super().__init__()
        # Learnable per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Learnable shift (zeros), or None to disable it.
        self.bias = None if not bias else nn.Parameter(torch.zeros(ndim))

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: each position attends to itself and earlier positions only.

    Args:
        n_embd: model width C; must be divisible by n_head.
        n_head: number of heads (each of size C // n_head).
        dropout: dropout probability for attention weights and the output projection.
        block_size: longest sequence supported by the fallback causal mask.
        bias: whether the Linear projections carry a bias.
    """

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        assert n_embd % n_head == 0, "Embedding dimension must be divisible by the number of heads."

        self.n_head, self.n_embd = n_head, n_embd
        self.dropout, self.block_size = dropout, block_size

        # One fused projection yields Q, K and V; c_proj maps the merged heads back to width C.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Use the fused kernel when available; otherwise keep an explicit lower-triangular mask.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            lower_triangle = torch.tril(torch.ones(block_size, block_size))
            self.register_buffer("bias", lower_triangle.view(1, 1, block_size, block_size))

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q_flat, k_flat, v_flat = self.c_attn(x).split(self.n_embd, dim=2)
        k = to_heads(k_flat)
        q = to_heads(q_flat)
        v = to_heads(v_flat)

        if not self.flash:
            scale = 1.0 / math.sqrt(k.size(-1))
            scores = (q @ k.transpose(-2, -1)) * scale
            future = self.bias[:, :, :seq_len, :seq_len] == 0
            weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
            heads = self.attn_dropout(weights) @ v
        else:
            heads = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network (FFN): widen 4x, GELU, project back, dropout.

    Args:
        n_embd: input/output width.
        dropout: dropout probability applied to the output.
        bias: whether the Linear layers carry a bias.
    """

    def __init__(self, n_embd, dropout, bias=True):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = nn.Linear(n_embd, hidden, bias=bias)
        self.gelu = nn.GELU()  # nonlinear activation function
        self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).

    Args:
        n_embd, n_head, dropout, block_size: forwarded to CausalSelfAttention/MLP.
        bias: whether LayerNorm and Linear layers carry a bias.
    """

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        self.ln_1 = LayerNorm(n_embd, bias=bias)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size, bias=bias)
        self.ln_2 = LayerNorm(n_embd, bias=bias)
        self.mlp = MLP(n_embd, dropout, bias=bias)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class GPT(nn.Module):
    """Decoder-only transformer language model over a small token vocabulary.

    Token embeddings + learned positional embeddings -> n_layer pre-norm Blocks
    -> final LayerNorm -> bias-free linear head producing vocabulary logits.

    Args:
        vocab_size: number of token ids.
        block_size: maximum sequence length (size of the positional table).
        n_embd: model width.
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        dropout: dropout probability used throughout.
        bias: whether Linear/LayerNorm layers carry a bias.
        ignore_index: target id skipped by the loss. Defaults to 13, the padding
            token '*' in this notebook's vocabulary.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=True, ignore_index=13):
        super().__init__()
        assert vocab_size is not None
        assert block_size is not None
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout
        self.bias = bias
        self.ignore_index = ignore_index

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(vocab_size, n_embd),  # token embeddings
            wpe=nn.Embedding(block_size, n_embd),  # positional embeddings
            drop=nn.Dropout(dropout),
            h=nn.ModuleList([Block(n_embd, n_head, dropout, block_size, bias=bias) for _ in range(n_layer)]),  # a stack of n_layer blocks
            ln_f=LayerNorm(n_embd, bias=bias),  # final layer norm
        ))
        # Projects the final transformer output to vocabulary logits.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Re-initialise every Linear/Embedding: N(0, 0.02) weights, zero biases.
        self.apply(self._init_weights)

    @property
    def device(self):
        """torch.device holding the model's parameters; follows .to()/.cuda() moves."""
        return next(self.parameters()).device

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids.

        Args:
            idx: (b, t) int64 token ids with t <= block_size.
            targets: optional (b, t) int64 next-token ids; entries equal to
                ignore_index do not contribute to the loss.

        Returns:
            (logits, loss). Without targets: logits for every position,
            shape (b, t, vocab_size), and loss None. With targets: logits for the
            last position only, shape (b, 1, vocab_size), and the mean
            cross-entropy over all non-ignored positions.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # Project to vocabulary once; both return paths reuse this tensor.
        logits = self.lm_head(x)  # (b, t, vocab_size)
        if targets is None:
            return logits, None

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=self.ignore_index,
        )
        # Keep the established contract: with targets, return only the last
        # position's logits (the list index [-1] preserves the time dimension).
        return logits[:, [-1], :], loss

In [90]:
# Batches per split come from `eval_iters` in the hyperparameter cell; pass
# `iters=` to override it for a single call.


@torch.no_grad()
def estimate_loss(phase, models, iters=None):
    """Average the loss of `models` over fresh random batches for each split.

    Args:
        phase: curriculum phase forwarded to get_batch (int digit count or "mix").
        models: model called as models(X, Y) -> (logits, loss).
        iters: batches per split; defaults to the global `eval_iters`.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor with the mean loss.
        NOTE: get_batch draws both splits from the same distribution, so the two
        values differ only by sampling noise.
    """
    n_batches = eval_iters if iters is None else iters
    was_training = models.training
    models.eval()
    out = {}
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(n_batches)
            for k in range(n_batches):
                X, Y = get_batch(phase, mode=split)
                _, loss = models(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore whichever mode the model was in before evaluation.
        models.train(was_training)
    return out

In [35]:
# ---- Hyperparameters: single source of truth for the cells below ----
# batch_size is not set here: get_batch() defaults to 32 and the training loop uses that default.
block_size = 35  # maximum context length (tokens) per example
max_iters = 150000  # upper bound on training steps (early stopping may end sooner)
eval_interval = 100  # run estimate_loss every this many steps
eval_iters = 20  # batches per split averaged by estimate_loss
n_embd = 256
n_head = 4
n_layer = 8
dropout = 0.0
bias = True  # if using bias inside all Linear layers
vocab_size = len(vocab)

# Seed torch (all devices) and NumPy (used by get_batch) so runs are reproducible.
seed = 1337
torch.manual_seed(seed)
np.random.seed(seed)

In [91]:
# Start a W&B run. Architecture/training values are read from the hyperparameter
# cell so the logged config cannot drift from what the model actually uses.
wandb.init(project="transformer", config={
    "learning_rate": 1e-5,  # must match the AdamW lr in the optimizer cell
    "batch_size": 32,  # get_batch default; the training loop does not override it
    "block_size": block_size,
    "optimizer": "AdamW",
    "n_embd": n_embd,
    "n_head": n_head,
    "n_layer": n_layer,
    "dropout": dropout,
    "max_iters": max_iters,
    "eval_interval": eval_interval,
    "eval_iters": eval_iters,
})

In [None]:
def accuracy(model, n_trials=100):
    """Estimate exact-match accuracy of `model` on random addition problems.

    Each trial draws operands a, b uniformly from 0..999999, prompts with
    "a+b=", decodes greedily, and compares the completion with the sum written
    in REVERSED digit order, the format get_batch trains on (str(k)[::-1]).

    Args:
        model: model accepted by generate().
        n_trials: number of random problems (default 100).

    Returns:
        Fraction of problems answered exactly (also printed).

    Raises:
        ValueError: if n_trials < 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    was_training = model.training
    model.eval()
    correct = 0
    try:
        for _ in range(n_trials):
            a = int(np.random.randint(0, 1000000))
            b = int(np.random.randint(0, 1000000))
            prompt = f"{a}+{b}="
            context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
            # top_k=1 makes decoding greedy, so the score does not depend on sampling luck.
            output = generate(model, context, 100, temperature=1.0, top_k=1)
            # generate() stops before appending the end token, so a correct
            # output is exactly the prompt followed by the reversed sum.
            if output == prompt + str(a + b)[::-1]:
                correct += 1
    finally:
        model.train(was_training)

    acc = correct / n_trials
    print(f"Accuracy for addition: {acc} ")
    return acc

In [92]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample tokens after a prompt and return the decoded text.

    Parameters:
        model (nn.Module): called as model(idx) -> (logits, _), logits shaped
            (b, t, vocab) or (b, 1, vocab); must expose `block_size`.
        idx (torch.Tensor or list): prompt token ids, shape (t,) or (b, t).
        max_new_tokens (int): maximum number of tokens to append.
        temperature (float): logits are divided by this before softmax; must be > 0.
        top_k (int, optional): if given, sample only among the top_k most likely
            tokens (top_k=1 is greedy decoding).

    Returns:
        str: the decoded FIRST sequence (prompt + generated tokens). Generation
        stops as soon as that sequence samples end_token_index; the end token
        itself is not included.

    Raises:
        ValueError: if temperature <= 0.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Prefer the model's own `device` attribute; fall back to where its weights live.
    target_device = getattr(model, 'device', None)
    if target_device is None:
        target_device = next(model.parameters()).device

    # Convert before calling .dim(): a plain list has no .dim().
    if not isinstance(idx, torch.Tensor):
        idx = torch.tensor(idx, dtype=torch.long)
    idx = idx.to(target_device)
    if idx.dim() == 1:
        idx = idx.unsqueeze(0)  # (t,) -> (1, t)

    for _ in range(max_new_tokens):
        # Condition on at most the last block_size tokens.
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)

        # Logits for the last position, temperature-scaled.
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # (b, 1)

        # Only the first sequence is returned, so its end token ends generation.
        if idx_next[0, 0].item() == end_token_index:
            break
        idx = torch.cat((idx, idx_next), dim=1)

    return decode(idx[0].tolist())


In [37]:
# Build the GPT from the hyperparameter cell and move it to the target device.
# nn.Module.to() returns the module itself, so `model` and `m` are the same object.
model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=bias).to(device)
m = model

In [38]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [93]:
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

phase = 1
best_acc = 0
counter = 0
best_loss = float('inf')
val_loss_list = []
acc_list = []

patience = 50

for iter in tqdm(range(max_iters), desc="Training Progress"):
    if iter > 2000:
      phase = 2
    if iter > 4000:
      phase = 3
    if iter > 6000:
      phase = 4
    if iter > 8000:
      phase = 5
    if iter > 10000:
      phase = 6
    if iter > 12000:
      phase = "mix"

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(phase, model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        log_dict = {"Loss": losses['val']}
        val_loss_list.append(round(losses['val'].item(), 4))

        if phase == "mix":
            # acc = accuracy(model)

            # acc_list.append(acc)
            # log_dict["Accuracy"] = acc

            if losses['val'] < best_loss:
                counter = 0
                # best_acc = max(best_acc, acc)
                best_loss = min(best_loss, losses['val'])
            else:
                counter += 1
                if counter >= patience:
                    print(f"Early Stopping at iteration {iter}")
                    break

        # record to W&B
        wandb.log(log_dict)

    # sample a batch of data

    xb, yb = get_batch(phase)

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


6.33472 M parameters


Training Progress:   0%|          | 7/150000 [00:02<11:13:57,  3.71it/s] 

step 0: train loss 1.2019, val loss 1.2037


Training Progress:   0%|          | 109/150000 [00:06<4:04:40, 10.21it/s]

step 100: train loss 0.3390, val loss 0.3400


Training Progress:   0%|          | 211/150000 [00:10<4:04:48, 10.20it/s]

step 200: train loss 0.3362, val loss 0.3373


Training Progress:   0%|          | 307/150000 [00:14<4:05:30, 10.16it/s]

step 300: train loss 0.3367, val loss 0.3365


Training Progress:   0%|          | 409/150000 [00:18<4:05:14, 10.17it/s]

step 400: train loss 0.3368, val loss 0.3362


Training Progress:   0%|          | 511/150000 [00:22<4:06:04, 10.13it/s]

step 500: train loss 0.3367, val loss 0.3364


Training Progress:   0%|          | 607/150000 [00:26<4:07:37, 10.06it/s]

step 600: train loss 0.3362, val loss 0.3364


Training Progress:   0%|          | 709/150000 [00:30<4:07:28, 10.05it/s]

step 700: train loss 0.3369, val loss 0.3365


Training Progress:   1%|          | 811/150000 [00:35<4:06:17, 10.10it/s]

step 800: train loss 0.3360, val loss 0.3353


Training Progress:   1%|          | 907/150000 [00:39<4:07:10, 10.05it/s]

step 900: train loss 0.3357, val loss 0.3363


Training Progress:   1%|          | 1009/150000 [00:43<4:06:25, 10.08it/s]

step 1000: train loss 0.3355, val loss 0.3364


Training Progress:   1%|          | 1111/150000 [00:47<4:06:32, 10.06it/s]

step 1100: train loss 0.3379, val loss 0.3370


Training Progress:   1%|          | 1207/150000 [00:51<4:07:42, 10.01it/s]

step 1200: train loss 0.3358, val loss 0.3356


Training Progress:   1%|          | 1309/150000 [00:55<4:07:03, 10.03it/s]

step 1300: train loss 0.3362, val loss 0.3362


Training Progress:   1%|          | 1411/150000 [00:59<4:07:56,  9.99it/s]

step 1400: train loss 0.3368, val loss 0.3376


Training Progress:   1%|          | 1507/150000 [01:03<4:07:31, 10.00it/s]

step 1500: train loss 0.3357, val loss 0.3358


Training Progress:   1%|          | 1609/150000 [01:07<4:07:41,  9.99it/s]

step 1600: train loss 0.3360, val loss 0.3358


Training Progress:   1%|          | 1711/150000 [01:11<4:07:54,  9.97it/s]

step 1700: train loss 0.3360, val loss 0.3359


Training Progress:   1%|          | 1807/150000 [01:15<4:08:03,  9.96it/s]

step 1800: train loss 0.3353, val loss 0.3357


Training Progress:   1%|▏         | 1909/150000 [01:20<4:08:00,  9.95it/s]

step 1900: train loss 0.3364, val loss 0.3360


Training Progress:   1%|▏         | 2011/150000 [01:24<4:09:32,  9.88it/s]

step 2000: train loss 0.3364, val loss 0.3362


Training Progress:   1%|▏         | 2107/150000 [01:28<4:08:38,  9.91it/s]

step 2100: train loss 0.7165, val loss 0.7162


Training Progress:   1%|▏         | 2209/150000 [01:32<4:08:55,  9.90it/s]

step 2200: train loss 0.7120, val loss 0.7109


Training Progress:   2%|▏         | 2311/150000 [01:36<4:09:26,  9.87it/s]

step 2300: train loss 0.7105, val loss 0.7114


Training Progress:   2%|▏         | 2407/150000 [01:40<4:09:46,  9.85it/s]

step 2400: train loss 0.7109, val loss 0.7109


Training Progress:   2%|▏         | 2509/150000 [01:45<4:08:33,  9.89it/s]

step 2500: train loss 0.7092, val loss 0.7102


Training Progress:   2%|▏         | 2611/150000 [01:49<4:08:52,  9.87it/s]

step 2600: train loss 0.7106, val loss 0.7107


Training Progress:   2%|▏         | 2707/150000 [01:53<4:09:00,  9.86it/s]

step 2700: train loss 0.7106, val loss 0.7108


Training Progress:   2%|▏         | 2809/150000 [01:57<4:10:07,  9.81it/s]

step 2800: train loss 0.7103, val loss 0.7101


Training Progress:   2%|▏         | 2911/150000 [02:01<4:08:58,  9.85it/s]

step 2900: train loss 0.7116, val loss 0.7121


Training Progress:   2%|▏         | 3007/150000 [02:05<4:09:25,  9.82it/s]

step 3000: train loss 0.7104, val loss 0.7106


Training Progress:   2%|▏         | 3109/150000 [02:10<4:09:59,  9.79it/s]

step 3100: train loss 0.7092, val loss 0.7101


Training Progress:   2%|▏         | 3211/150000 [02:14<4:09:31,  9.80it/s]

step 3200: train loss 0.7127, val loss 0.7119


Training Progress:   2%|▏         | 3307/150000 [02:18<4:10:37,  9.75it/s]

step 3300: train loss 0.7107, val loss 0.7105


Training Progress:   2%|▏         | 3409/150000 [02:22<4:10:41,  9.75it/s]

step 3400: train loss 0.7104, val loss 0.7108


Training Progress:   2%|▏         | 3511/150000 [02:27<4:09:44,  9.78it/s]

step 3500: train loss 0.7106, val loss 0.7105


Training Progress:   2%|▏         | 3607/150000 [02:31<4:10:09,  9.75it/s]

step 3600: train loss 0.7103, val loss 0.7101


Training Progress:   2%|▏         | 3709/150000 [02:35<4:10:26,  9.74it/s]

step 3700: train loss 0.7101, val loss 0.7099


Training Progress:   3%|▎         | 3811/150000 [02:39<4:09:29,  9.77it/s]

step 3800: train loss 0.7095, val loss 0.7092


Training Progress:   3%|▎         | 3907/150000 [02:43<4:10:19,  9.73it/s]

step 3900: train loss 0.7105, val loss 0.7093


Training Progress:   3%|▎         | 4009/150000 [02:48<4:09:54,  9.74it/s]

step 4000: train loss 0.7103, val loss 0.7097


Training Progress:   3%|▎         | 4111/150000 [02:52<4:10:09,  9.72it/s]

step 4100: train loss 0.9094, val loss 0.9091


Training Progress:   3%|▎         | 4207/150000 [02:56<4:10:44,  9.69it/s]

step 4200: train loss 0.9072, val loss 0.9067


Training Progress:   3%|▎         | 4309/150000 [03:00<4:09:35,  9.73it/s]

step 4300: train loss 0.9070, val loss 0.9066


Training Progress:   3%|▎         | 4411/150000 [03:05<4:09:45,  9.72it/s]

step 4400: train loss 0.9061, val loss 0.9062


Training Progress:   3%|▎         | 4507/150000 [03:09<4:10:39,  9.67it/s]

step 4500: train loss 0.9072, val loss 0.9071


Training Progress:   3%|▎         | 4609/150000 [03:13<4:09:42,  9.70it/s]

step 4600: train loss 0.9069, val loss 0.9068


Training Progress:   3%|▎         | 4711/150000 [03:17<4:10:10,  9.68it/s]

step 4700: train loss 0.9066, val loss 0.9072


Training Progress:   3%|▎         | 4807/150000 [03:21<4:09:54,  9.68it/s]

step 4800: train loss 0.9066, val loss 0.9060


Training Progress:   3%|▎         | 4909/150000 [03:26<4:09:26,  9.69it/s]

step 4900: train loss 0.9065, val loss 0.9060


Training Progress:   3%|▎         | 5011/150000 [03:30<4:09:18,  9.69it/s]

step 5000: train loss 0.9059, val loss 0.9056


Training Progress:   3%|▎         | 5107/150000 [03:34<4:09:59,  9.66it/s]

step 5100: train loss 0.9072, val loss 0.9063


Training Progress:   3%|▎         | 5209/150000 [03:38<4:09:31,  9.67it/s]

step 5200: train loss 0.9067, val loss 0.9062


Training Progress:   4%|▎         | 5311/150000 [03:43<4:09:09,  9.68it/s]

step 5300: train loss 0.9061, val loss 0.9067


Training Progress:   4%|▎         | 5407/150000 [03:47<4:09:41,  9.65it/s]

step 5400: train loss 0.9067, val loss 0.9062


Training Progress:   4%|▎         | 5509/150000 [03:51<4:09:22,  9.66it/s]

step 5500: train loss 0.9056, val loss 0.9060


Training Progress:   4%|▎         | 5611/150000 [03:56<4:10:02,  9.62it/s]

step 5600: train loss 0.9062, val loss 0.9058


Training Progress:   4%|▍         | 5707/150000 [04:00<4:08:56,  9.66it/s]

step 5700: train loss 0.9066, val loss 0.9060


Training Progress:   4%|▍         | 5809/150000 [04:04<4:08:49,  9.66it/s]

step 5800: train loss 0.9060, val loss 0.9067


Training Progress:   4%|▍         | 5911/150000 [04:08<4:10:15,  9.60it/s]

step 5900: train loss 0.9056, val loss 0.9062


Training Progress:   4%|▍         | 6007/150000 [04:12<4:08:32,  9.66it/s]

step 6000: train loss 0.9069, val loss 0.9066


Training Progress:   4%|▍         | 6109/150000 [04:17<4:09:30,  9.61it/s]

step 6100: train loss 1.0328, val loss 1.0331


Training Progress:   4%|▍         | 6211/150000 [04:21<4:09:10,  9.62it/s]

step 6200: train loss 1.0285, val loss 1.0280


Training Progress:   4%|▍         | 6307/150000 [04:25<4:09:04,  9.62it/s]

step 6300: train loss 1.0275, val loss 1.0275


Training Progress:   4%|▍         | 6409/150000 [04:30<4:08:31,  9.63it/s]

step 6400: train loss 1.0271, val loss 1.0273


Training Progress:   4%|▍         | 6511/150000 [04:34<4:08:23,  9.63it/s]

step 6500: train loss 1.0270, val loss 1.0274


Training Progress:   4%|▍         | 6607/150000 [04:38<4:08:39,  9.61it/s]

step 6600: train loss 1.0278, val loss 1.0276


Training Progress:   4%|▍         | 6709/150000 [04:42<4:08:29,  9.61it/s]

step 6700: train loss 1.0273, val loss 1.0272


Training Progress:   5%|▍         | 6811/150000 [04:47<4:07:57,  9.62it/s]

step 6800: train loss 1.0273, val loss 1.0277


Training Progress:   5%|▍         | 6907/150000 [04:51<4:08:11,  9.61it/s]

step 6900: train loss 1.0274, val loss 1.0275


Training Progress:   5%|▍         | 7009/150000 [04:55<4:08:31,  9.59it/s]

step 7000: train loss 1.0275, val loss 1.0276


Training Progress:   5%|▍         | 7111/150000 [05:00<4:07:41,  9.61it/s]

step 7100: train loss 1.0278, val loss 1.0268


Training Progress:   5%|▍         | 7207/150000 [05:04<4:07:46,  9.61it/s]

step 7200: train loss 1.0276, val loss 1.0269


Training Progress:   5%|▍         | 7309/150000 [05:08<4:07:14,  9.62it/s]

step 7300: train loss 1.0267, val loss 1.0270


Training Progress:   5%|▍         | 7411/150000 [05:12<4:06:54,  9.62it/s]

step 7400: train loss 1.0268, val loss 1.0271


Training Progress:   5%|▌         | 7507/150000 [05:17<4:08:02,  9.57it/s]

step 7500: train loss 1.0272, val loss 1.0272


Training Progress:   5%|▌         | 7609/150000 [05:21<4:06:48,  9.62it/s]

step 7600: train loss 1.0282, val loss 1.0273


Training Progress:   5%|▌         | 7711/150000 [05:25<4:06:33,  9.62it/s]

step 7700: train loss 1.0273, val loss 1.0270


Training Progress:   5%|▌         | 7807/150000 [05:30<4:07:46,  9.56it/s]

step 7800: train loss 1.0270, val loss 1.0272


Training Progress:   5%|▌         | 7909/150000 [05:34<4:06:07,  9.62it/s]

step 7900: train loss 1.0273, val loss 1.0272


Training Progress:   5%|▌         | 8011/150000 [05:38<4:05:46,  9.63it/s]

step 8000: train loss 1.0267, val loss 1.0266


Training Progress:   5%|▌         | 8107/150000 [05:42<4:07:59,  9.54it/s]

step 8100: train loss 1.1119, val loss 1.1119


Training Progress:   5%|▌         | 8209/150000 [05:47<4:06:21,  9.59it/s]

step 8200: train loss 1.1098, val loss 1.1107


Training Progress:   6%|▌         | 8311/150000 [05:51<4:06:35,  9.58it/s]

step 8300: train loss 1.1099, val loss 1.1092


Training Progress:   6%|▌         | 8407/150000 [05:55<4:07:16,  9.54it/s]

step 8400: train loss 1.1089, val loss 1.1101


Training Progress:   6%|▌         | 8509/150000 [06:00<4:06:32,  9.57it/s]

step 8500: train loss 1.1099, val loss 1.1094


Training Progress:   6%|▌         | 8611/150000 [06:04<4:07:17,  9.53it/s]

step 8600: train loss 1.1098, val loss 1.1099


Training Progress:   6%|▌         | 8707/150000 [06:08<4:07:28,  9.52it/s]

step 8700: train loss 1.1092, val loss 1.1099


Training Progress:   6%|▌         | 8809/150000 [06:13<4:06:48,  9.53it/s]

step 8800: train loss 1.1096, val loss 1.1098


Training Progress:   6%|▌         | 8905/150000 [06:17<5:35:52,  7.00it/s]

step 8900: train loss 1.1094, val loss 1.1095


Training Progress:   6%|▌         | 9007/150000 [06:21<4:07:21,  9.50it/s]

step 9000: train loss 1.1092, val loss 1.1092


Training Progress:   6%|▌         | 9109/150000 [06:26<4:07:30,  9.49it/s]

step 9100: train loss 1.1093, val loss 1.1090


Training Progress:   6%|▌         | 9205/150000 [06:30<5:33:39,  7.03it/s]

step 9200: train loss 1.1091, val loss 1.1097


Training Progress:   6%|▌         | 9307/150000 [06:34<4:06:42,  9.50it/s]

step 9300: train loss 1.1089, val loss 1.1091


Training Progress:   6%|▋         | 9409/150000 [06:39<4:06:15,  9.52it/s]

step 9400: train loss 1.1085, val loss 1.1091


Training Progress:   6%|▋         | 9505/150000 [06:43<5:32:06,  7.05it/s]

step 9500: train loss 1.1086, val loss 1.1091


Training Progress:   6%|▋         | 9607/150000 [06:47<4:05:14,  9.54it/s]

step 9600: train loss 1.1094, val loss 1.1081


Training Progress:   6%|▋         | 9709/150000 [06:52<4:05:20,  9.53it/s]

step 9700: train loss 1.1089, val loss 1.1090


Training Progress:   7%|▋         | 9805/150000 [06:56<5:29:57,  7.08it/s]

step 9800: train loss 1.1081, val loss 1.1086


Training Progress:   7%|▋         | 9907/150000 [07:00<4:04:28,  9.55it/s]

step 9900: train loss 1.1093, val loss 1.1089


Training Progress:   7%|▋         | 10009/150000 [07:04<4:04:11,  9.56it/s]

step 10000: train loss 1.1089, val loss 1.1091


Training Progress:   7%|▋         | 10105/150000 [07:09<5:29:12,  7.08it/s]

step 10100: train loss 1.1700, val loss 1.1694


Training Progress:   7%|▋         | 10207/150000 [07:13<4:03:33,  9.57it/s]

step 10200: train loss 1.1691, val loss 1.1689


Training Progress:   7%|▋         | 10309/150000 [07:17<4:03:31,  9.56it/s]

step 10300: train loss 1.1682, val loss 1.1689


Training Progress:   7%|▋         | 10411/150000 [07:22<4:03:09,  9.57it/s]

step 10400: train loss 1.1688, val loss 1.1686


Training Progress:   7%|▋         | 10507/150000 [07:26<4:03:49,  9.54it/s]

step 10500: train loss 1.1686, val loss 1.1690


Training Progress:   7%|▋         | 10609/150000 [07:30<4:02:57,  9.56it/s]

step 10600: train loss 1.1685, val loss 1.1687


Training Progress:   7%|▋         | 10705/150000 [07:35<5:28:40,  7.06it/s]

step 10700: train loss 1.1677, val loss 1.1687


Training Progress:   7%|▋         | 10807/150000 [07:39<4:04:01,  9.51it/s]

step 10800: train loss 1.1684, val loss 1.1684


Training Progress:   7%|▋         | 10909/150000 [07:43<4:02:42,  9.55it/s]

step 10900: train loss 1.1684, val loss 1.1683


Training Progress:   7%|▋         | 11005/150000 [07:48<5:28:27,  7.05it/s]

step 11000: train loss 1.1684, val loss 1.1687


Training Progress:   7%|▋         | 11107/150000 [07:52<4:02:35,  9.54it/s]

step 11100: train loss 1.1678, val loss 1.1682


Training Progress:   7%|▋         | 11209/150000 [07:56<4:02:21,  9.54it/s]

step 11200: train loss 1.1686, val loss 1.1682


Training Progress:   8%|▊         | 11305/150000 [08:00<5:28:27,  7.04it/s]

step 11300: train loss 1.1681, val loss 1.1676


Training Progress:   8%|▊         | 11407/150000 [08:05<4:02:16,  9.53it/s]

step 11400: train loss 1.1684, val loss 1.1684


Training Progress:   8%|▊         | 11509/150000 [08:09<4:02:28,  9.52it/s]

step 11500: train loss 1.1682, val loss 1.1685


Training Progress:   8%|▊         | 11605/150000 [08:13<5:29:14,  7.01it/s]

step 11600: train loss 1.1678, val loss 1.1687


Training Progress:   8%|▊         | 11707/150000 [08:18<4:01:43,  9.54it/s]

step 11700: train loss 1.1687, val loss 1.1682


Training Progress:   8%|▊         | 11809/150000 [08:22<4:01:37,  9.53it/s]

step 11800: train loss 1.1686, val loss 1.1686


Training Progress:   8%|▊         | 11905/150000 [08:26<5:25:40,  7.07it/s]

step 11900: train loss 1.1685, val loss 1.1672


Training Progress:   8%|▊         | 12007/150000 [08:31<4:00:58,  9.54it/s]

step 12000: train loss 1.1680, val loss 1.1677


Training Progress:   8%|▊         | 12109/150000 [08:35<4:05:38,  9.36it/s]

step 12100: train loss 1.2460, val loss 1.2472


Training Progress:   8%|▊         | 12205/150000 [08:39<5:31:05,  6.94it/s]

step 12200: train loss 1.2409, val loss 1.2410


Training Progress:   8%|▊         | 12307/150000 [08:44<4:04:30,  9.39it/s]

step 12300: train loss 1.2376, val loss 1.2386


Training Progress:   8%|▊         | 12409/150000 [08:48<4:06:08,  9.32it/s]

step 12400: train loss 1.2366, val loss 1.2372


Training Progress:   8%|▊         | 12511/150000 [08:53<4:04:41,  9.36it/s]

step 12500: train loss 1.2360, val loss 1.2365


Training Progress:   8%|▊         | 12607/150000 [08:57<4:03:42,  9.40it/s]

step 12600: train loss 1.2362, val loss 1.2376


Training Progress:   8%|▊         | 12709/150000 [09:01<4:04:22,  9.36it/s]

step 12700: train loss 1.2374, val loss 1.2382


Training Progress:   9%|▊         | 12811/150000 [09:06<4:03:16,  9.40it/s]

step 12800: train loss 1.2364, val loss 1.2362


Training Progress:   9%|▊         | 12907/150000 [09:10<4:03:39,  9.38it/s]

step 12900: train loss 1.2352, val loss 1.2358


Training Progress:   9%|▊         | 13009/150000 [09:14<4:03:47,  9.37it/s]

step 13000: train loss 1.2339, val loss 1.2366


Training Progress:   9%|▊         | 13105/150000 [09:19<5:29:07,  6.93it/s]

step 13100: train loss 1.2367, val loss 1.2366


Training Progress:   9%|▉         | 13207/150000 [09:23<4:04:12,  9.34it/s]

step 13200: train loss 1.2377, val loss 1.2382


Training Progress:   9%|▉         | 13309/150000 [09:28<4:02:54,  9.38it/s]

step 13300: train loss 1.2388, val loss 1.2395


Training Progress:   9%|▉         | 13405/150000 [09:32<5:28:16,  6.94it/s]

step 13400: train loss 1.2514, val loss 1.2512


Training Progress:   9%|▉         | 13507/150000 [09:36<4:03:07,  9.36it/s]

step 13500: train loss 1.2384, val loss 1.2380


Training Progress:   9%|▉         | 13609/150000 [09:41<4:01:42,  9.40it/s]

step 13600: train loss 1.2360, val loss 1.2369


Training Progress:   9%|▉         | 13705/150000 [09:45<5:28:54,  6.91it/s]

step 13700: train loss 1.2362, val loss 1.2350


Training Progress:   9%|▉         | 13807/150000 [09:49<4:02:49,  9.35it/s]

step 13800: train loss 1.2362, val loss 1.2357


Training Progress:   9%|▉         | 13909/150000 [09:54<4:01:13,  9.40it/s]

step 13900: train loss 1.2357, val loss 1.2360


Training Progress:   9%|▉         | 14005/150000 [09:58<5:29:40,  6.88it/s]

step 14000: train loss 1.2352, val loss 1.2362


Training Progress:   9%|▉         | 14107/150000 [10:02<4:01:18,  9.39it/s]

step 14100: train loss 1.2364, val loss 1.2362


Training Progress:   9%|▉         | 14209/150000 [10:07<4:00:24,  9.41it/s]

step 14200: train loss 1.2360, val loss 1.2353


Training Progress:  10%|▉         | 14311/150000 [10:11<4:02:45,  9.32it/s]

step 14300: train loss 1.2357, val loss 1.2365


Training Progress:  10%|▉         | 14407/150000 [10:16<4:00:16,  9.41it/s]

step 14400: train loss 1.2361, val loss 1.2365


Training Progress:  10%|▉         | 14509/150000 [10:20<4:01:30,  9.35it/s]

step 14500: train loss 1.2377, val loss 1.2396


Training Progress:  10%|▉         | 14605/150000 [10:24<5:25:17,  6.94it/s]

step 14600: train loss 1.2414, val loss 1.2401


Training Progress:  10%|▉         | 14707/150000 [10:29<3:59:41,  9.41it/s]

step 14700: train loss 1.2369, val loss 1.2382


Training Progress:  10%|▉         | 14809/150000 [10:33<4:01:10,  9.34it/s]

step 14800: train loss 1.2384, val loss 1.2365


Training Progress:  10%|▉         | 14911/150000 [10:37<3:59:17,  9.41it/s]

step 14900: train loss 1.2375, val loss 1.2365


Training Progress:  10%|█         | 15007/150000 [10:42<3:59:58,  9.38it/s]

step 15000: train loss 1.2353, val loss 1.2353


Training Progress:  10%|█         | 15109/150000 [10:46<4:01:19,  9.32it/s]

step 15100: train loss 1.2366, val loss 1.2348


Training Progress:  10%|█         | 15205/150000 [10:50<5:25:13,  6.91it/s]

step 15200: train loss 1.2362, val loss 1.2367


Training Progress:  10%|█         | 15307/150000 [10:55<3:59:36,  9.37it/s]

step 15300: train loss 1.2356, val loss 1.2370


Training Progress:  10%|█         | 15409/150000 [10:59<3:58:35,  9.40it/s]

step 15400: train loss 1.2362, val loss 1.2360


Training Progress:  10%|█         | 15505/150000 [11:03<5:23:30,  6.93it/s]

step 15500: train loss 1.2362, val loss 1.2357


Training Progress:  10%|█         | 15607/150000 [11:08<3:59:56,  9.33it/s]

step 15600: train loss 1.2373, val loss 1.2376


Training Progress:  10%|█         | 15709/150000 [11:12<3:59:16,  9.35it/s]

step 15700: train loss 1.2359, val loss 1.2362


Training Progress:  11%|█         | 15805/150000 [11:17<5:22:54,  6.93it/s]

step 15800: train loss 1.2386, val loss 1.2398


Training Progress:  11%|█         | 15907/150000 [11:21<3:58:45,  9.36it/s]

step 15900: train loss 1.2377, val loss 1.2372


Training Progress:  11%|█         | 16009/150000 [11:25<3:57:41,  9.40it/s]

step 16000: train loss 1.2380, val loss 1.2391


Training Progress:  11%|█         | 16105/150000 [11:30<5:22:00,  6.93it/s]

step 16100: train loss 1.2371, val loss 1.2372


Training Progress:  11%|█         | 16207/150000 [11:34<3:57:19,  9.40it/s]

step 16200: train loss 1.2361, val loss 1.2363


Training Progress:  11%|█         | 16309/150000 [11:38<3:56:49,  9.41it/s]

step 16300: train loss 1.2361, val loss 1.2356


Training Progress:  11%|█         | 16405/150000 [11:43<5:23:04,  6.89it/s]

step 16400: train loss 1.2357, val loss 1.2368


Training Progress:  11%|█         | 16507/150000 [11:47<3:57:58,  9.35it/s]

step 16500: train loss 1.2348, val loss 1.2364


Training Progress:  11%|█         | 16609/150000 [11:52<3:57:11,  9.37it/s]

step 16600: train loss 1.2363, val loss 1.2355


Training Progress:  11%|█         | 16705/150000 [11:56<5:24:32,  6.85it/s]

step 16700: train loss 1.2348, val loss 1.2350


Training Progress:  11%|█         | 16807/150000 [12:00<3:58:37,  9.30it/s]

step 16800: train loss 1.2339, val loss 1.2365


Training Progress:  11%|█▏        | 16909/150000 [12:05<3:58:12,  9.31it/s]

step 16900: train loss 1.2353, val loss 1.2346


Training Progress:  11%|█▏        | 17005/150000 [12:09<5:22:18,  6.88it/s]

step 17000: train loss 1.2360, val loss 1.2357


Training Progress:  11%|█▏        | 17107/150000 [12:14<3:57:34,  9.32it/s]

step 17100: train loss 1.2364, val loss 1.2359


Training Progress:  11%|█▏        | 17209/150000 [12:18<4:00:15,  9.21it/s]

step 17200: train loss 1.2353, val loss 1.2346


Training Progress:  12%|█▏        | 17305/150000 [12:22<5:20:14,  6.91it/s]

step 17300: train loss 1.2353, val loss 1.2357


Training Progress:  12%|█▏        | 17407/150000 [12:27<3:56:04,  9.36it/s]

step 17400: train loss 1.2352, val loss 1.2361


Training Progress:  12%|█▏        | 17509/150000 [12:31<3:56:32,  9.33it/s]

step 17500: train loss 1.2369, val loss 1.2348


Training Progress:  12%|█▏        | 17605/150000 [12:35<5:19:22,  6.91it/s]

step 17600: train loss 1.2449, val loss 1.2440


Training Progress:  12%|█▏        | 17707/150000 [12:40<3:55:26,  9.36it/s]

step 17700: train loss 1.2529, val loss 1.2514


Training Progress:  12%|█▏        | 17809/150000 [12:44<3:54:57,  9.38it/s]

step 17800: train loss 1.2364, val loss 1.2365


Training Progress:  12%|█▏        | 17905/150000 [12:49<5:17:23,  6.94it/s]

step 17900: train loss 1.2353, val loss 1.2365


Training Progress:  12%|█▏        | 18007/150000 [12:53<3:54:51,  9.37it/s]

step 18000: train loss 1.2351, val loss 1.2358


Training Progress:  12%|█▏        | 18109/150000 [12:57<3:54:18,  9.38it/s]

step 18100: train loss 1.2359, val loss 1.2361


Training Progress:  12%|█▏        | 18205/150000 [13:02<5:16:20,  6.94it/s]

step 18200: train loss 1.2344, val loss 1.2346


Training Progress:  12%|█▏        | 18307/150000 [13:06<3:55:28,  9.32it/s]

step 18300: train loss 1.2356, val loss 1.2343


Training Progress:  12%|█▏        | 18409/150000 [13:10<3:53:24,  9.40it/s]

step 18400: train loss 1.2355, val loss 1.2359


Training Progress:  12%|█▏        | 18505/150000 [13:15<5:15:51,  6.94it/s]

step 18500: train loss 1.2357, val loss 1.2362


Training Progress:  12%|█▏        | 18607/150000 [13:19<3:54:09,  9.35it/s]

step 18600: train loss 1.2371, val loss 1.2361


Training Progress:  12%|█▏        | 18709/150000 [13:24<3:53:19,  9.38it/s]

step 18700: train loss 1.2359, val loss 1.2362


Training Progress:  13%|█▎        | 18805/150000 [13:28<5:16:02,  6.92it/s]

step 18800: train loss 1.2368, val loss 1.2366


Training Progress:  13%|█▎        | 18907/150000 [13:32<3:52:54,  9.38it/s]

step 18900: train loss 1.2361, val loss 1.2357


Training Progress:  13%|█▎        | 19009/150000 [13:37<3:52:45,  9.38it/s]

step 19000: train loss 1.2373, val loss 1.2374


Training Progress:  13%|█▎        | 19111/150000 [13:41<3:53:35,  9.34it/s]

step 19100: train loss 1.2365, val loss 1.2354


Training Progress:  13%|█▎        | 19207/150000 [13:45<3:52:28,  9.38it/s]

step 19200: train loss 1.2346, val loss 1.2361


Training Progress:  13%|█▎        | 19309/150000 [13:50<3:52:31,  9.37it/s]

step 19300: train loss 1.2355, val loss 1.2367


Training Progress:  13%|█▎        | 19405/150000 [13:54<5:15:08,  6.91it/s]

step 19400: train loss 1.2453, val loss 1.2467


Training Progress:  13%|█▎        | 19507/150000 [13:58<3:52:52,  9.34it/s]

step 19500: train loss 1.2373, val loss 1.2369


Training Progress:  13%|█▎        | 19609/150000 [14:03<3:52:38,  9.34it/s]

step 19600: train loss 1.2370, val loss 1.2373


Training Progress:  13%|█▎        | 19705/150000 [14:07<5:12:49,  6.94it/s]

step 19700: train loss 1.2357, val loss 1.2363


Training Progress:  13%|█▎        | 19807/150000 [14:12<3:52:30,  9.33it/s]

step 19800: train loss 1.2353, val loss 1.2358


Training Progress:  13%|█▎        | 19909/150000 [14:16<3:52:05,  9.34it/s]

step 19900: train loss 1.2365, val loss 1.2359


Training Progress:  13%|█▎        | 20011/150000 [14:20<3:50:53,  9.38it/s]

step 20000: train loss 1.2356, val loss 1.2360


Training Progress:  13%|█▎        | 20107/150000 [14:25<3:50:27,  9.39it/s]

step 20100: train loss 1.2358, val loss 1.2364


Training Progress:  13%|█▎        | 20209/150000 [14:29<3:50:43,  9.38it/s]

step 20200: train loss 1.2356, val loss 1.2364


Training Progress:  14%|█▎        | 20311/150000 [14:33<3:50:00,  9.40it/s]

step 20300: train loss 1.2350, val loss 1.2346


Training Progress:  14%|█▎        | 20407/150000 [14:38<3:50:05,  9.39it/s]

step 20400: train loss 1.2358, val loss 1.2358


Training Progress:  14%|█▎        | 20509/150000 [14:42<3:49:19,  9.41it/s]

step 20500: train loss 1.2356, val loss 1.2357


Training Progress:  14%|█▎        | 20611/150000 [14:47<3:49:44,  9.39it/s]

step 20600: train loss 1.2364, val loss 1.2360


Training Progress:  14%|█▍        | 20707/150000 [14:51<3:50:12,  9.36it/s]

step 20700: train loss 1.2380, val loss 1.2370


Training Progress:  14%|█▍        | 20809/150000 [14:55<3:49:49,  9.37it/s]

step 20800: train loss 1.2385, val loss 1.2371


Training Progress:  14%|█▍        | 20905/150000 [14:59<5:09:32,  6.95it/s]

step 20900: train loss 1.2390, val loss 1.2379


Training Progress:  14%|█▍        | 21007/150000 [15:04<3:49:19,  9.38it/s]

step 21000: train loss 1.2371, val loss 1.2373


Training Progress:  14%|█▍        | 21109/150000 [15:08<3:47:50,  9.43it/s]

step 21100: train loss 1.2349, val loss 1.2368


Training Progress:  14%|█▍        | 21211/150000 [15:13<3:48:20,  9.40it/s]

step 21200: train loss 1.2354, val loss 1.2349


Training Progress:  14%|█▍        | 21307/150000 [15:17<3:48:56,  9.37it/s]

step 21300: train loss 1.2361, val loss 1.2355


Training Progress:  14%|█▍        | 21409/150000 [15:21<3:48:47,  9.37it/s]

step 21400: train loss 1.2360, val loss 1.2356


Training Progress:  14%|█▍        | 21505/150000 [15:26<5:09:55,  6.91it/s]

step 21500: train loss 1.2345, val loss 1.2364


Training Progress:  14%|█▍        | 21607/150000 [15:30<3:48:08,  9.38it/s]

step 21600: train loss 1.2350, val loss 1.2357


Training Progress:  14%|█▍        | 21709/150000 [15:34<3:47:43,  9.39it/s]

step 21700: train loss 1.2365, val loss 1.2369


Training Progress:  15%|█▍        | 21805/150000 [15:39<5:07:45,  6.94it/s]

step 21800: train loss 1.2436, val loss 1.2420


Training Progress:  15%|█▍        | 21907/150000 [15:43<3:47:17,  9.39it/s]

step 21900: train loss 1.2369, val loss 1.2380


Training Progress:  15%|█▍        | 22009/150000 [15:47<3:47:42,  9.37it/s]

step 22000: train loss 1.2380, val loss 1.2383


Training Progress:  15%|█▍        | 22111/150000 [15:52<3:47:27,  9.37it/s]

step 22100: train loss 1.2352, val loss 1.2353


Training Progress:  15%|█▍        | 22207/150000 [15:56<3:47:24,  9.37it/s]

step 22200: train loss 1.2363, val loss 1.2361


Training Progress:  15%|█▍        | 22309/150000 [16:01<3:48:45,  9.30it/s]

step 22300: train loss 1.2362, val loss 1.2359


Training Progress:  15%|█▍        | 22405/150000 [16:05<5:07:10,  6.92it/s]

step 22400: train loss 1.2359, val loss 1.2360


Training Progress:  15%|█▌        | 22507/150000 [16:09<3:46:58,  9.36it/s]

step 22500: train loss 1.2363, val loss 1.2356


Training Progress:  15%|█▌        | 22609/150000 [16:14<3:46:59,  9.35it/s]

step 22600: train loss 1.2339, val loss 1.2357


Training Progress:  15%|█▌        | 22705/150000 [16:18<5:06:36,  6.92it/s]

step 22700: train loss 1.2375, val loss 1.2369


Training Progress:  15%|█▌        | 22807/150000 [16:22<3:46:50,  9.35it/s]

step 22800: train loss 1.2351, val loss 1.2362


Training Progress:  15%|█▌        | 22909/150000 [16:27<3:45:37,  9.39it/s]

step 22900: train loss 1.2370, val loss 1.2383


Training Progress:  15%|█▌        | 23011/150000 [16:31<3:44:55,  9.41it/s]

step 23000: train loss 1.2393, val loss 1.2412


Training Progress:  15%|█▌        | 23107/150000 [16:36<3:46:29,  9.34it/s]

step 23100: train loss 1.2370, val loss 1.2364


Training Progress:  15%|█▌        | 23209/150000 [16:40<3:44:55,  9.40it/s]

step 23200: train loss 1.2355, val loss 1.2348


Training Progress:  16%|█▌        | 23300/150000 [16:44<1:31:02, 23.19it/s]

step 23300: train loss 1.2363, val loss 1.2360
Early Stopping at iteration 23300





In [94]:
# Finish the currently active W&B run so the plotting cells below can each start their own
# run with wandb.init(...).
wandb.finish()

0,1
Loss,▁▁▁▁▁▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████████████████

0,1
Loss,1.23479


In [None]:
# Plot the validation-loss and accuracy curves recorded during training, show the figure
# inline, then log it as an interactive HTML panel to a dedicated W&B run.
fig = go.Figure()

# x positions come from the recorded data instead of a hard-coded count (was range(58)).
# The old value only matched one particular run, so any other number of evaluations,
# e.g. after early stopping, left x and y with different lengths. Each accuracy series
# uses its own length in case it was recorded a different number of times than the loss.
# NOTE(review): x is the evaluation index, not the raw training step. Confirm how often
# the training loop appends to these lists if step-accurate axes are needed.
iterations = list(range(len(val_loss_list)))

# Add lines
fig.add_trace(go.Scatter(x=iterations, y=val_loss_list, mode='lines+markers', name='Validation Loss'))
fig.add_trace(go.Scatter(x=list(range(len(acc_test_list))), y=acc_test_list, mode='lines+markers', name='Test Accuracy'))
fig.add_trace(go.Scatter(x=list(range(len(acc_all_list))), y=acc_all_list, mode='lines+markers', name='All Accuracy'))

# Set title label
fig.update_layout(
    title="Validation Loss & Accuracy over Iterations",
    xaxis_title="Iterations",
    yaxis_title="Metrics",
    legend_title="Metrics",
    width = 900,
    height = 500
)

fig.show()

wandb.init(project="transformer", name="val_loss and accuracy plot")
wandb.log({"Interactive Chart": wandb.Html(fig.to_html())})
wandb.finish()

In [75]:
def accuracy_print(model, num_digits, need_print=False, num_samples=100):
    """Estimate exact-match accuracy of `model` on random additions of up to `num_digits` digits.

    Each operand is drawn uniformly from [0, 10**num_digits), so e.g. num_digits=2 also
    samples one-digit operands. The prompt "a+b=" is encoded and passed to `generate`.
    A sample counts as correct only if the returned string equals "a+b=" followed by the
    digits of a+b in reverse order (e.g. "7+8=51").

    Args:
        model: model passed through to `generate`.
        num_digits: maximum number of digits per operand.
        need_print: if True, print the prompt, the model output and the expected string
            for every sample.
        num_samples: number of random problems to evaluate (default 100, the original
            fixed count).

    Returns:
        float: fraction of samples answered exactly right, in [0, 1]. The same value is
        also printed as a one-line summary.

    Raises:
        ValueError: if num_samples is not positive (would otherwise divide by zero).
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    upper = 10 ** num_digits
    correct = 0

    for _ in range(num_samples):
        # randint samples [0, upper) directly. The previous np.random.choice(np.arange(upper))
        # materialised the whole range, which is ~800 MB per draw at num_digits=8.
        a = int(np.random.randint(0, upper, dtype=np.int64))
        b = int(np.random.randint(0, upper, dtype=np.int64))

        # `prompt` rather than `input` so the builtin is not shadowed.
        prompt = f"{a}+{b}="
        # The model is trained to emit the sum least-significant digit first.
        expected = prompt + str(a + b)[::-1]

        context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        # The positional arguments 100 and 1 are kept from the original call; their meaning
        # is defined by `generate` (not visible here).
        output = generate(model, context, 100, 1)

        if need_print:
            print(f"Input: {prompt}")
            print(f"Output: {output}")
            print(f"Expected: {expected}")
            print("-----------")
        if output == expected:
            correct += 1

    acc = correct / num_samples
    print(f"Accuracy for {num_digits} digits addition: {acc} ")
    return acc

In [76]:
# Spot-check operands of up to 1 digit, printing every prompt/output/expected triple.
accuracy_print(model, num_digits=1, need_print=True)

Input: 6+2=
Output: 6+2=8
Expected: 6+2=8
-----------
Input: 9+3=
Output: 9+3=21
Expected: 9+3=21
-----------
Input: 5+3=
Output: 5+3=8
Expected: 5+3=8
-----------
Input: 7+8=
Output: 7+8=51
Expected: 7+8=51
-----------
Input: 0+6=
Output: 0+6=6
Expected: 0+6=6
-----------
Input: 5+8=
Output: 5+8=31
Expected: 5+8=31
-----------
Input: 0+1=
Output: 0+1=1
Expected: 0+1=1
-----------
Input: 6+4=
Output: 6+4=01
Expected: 6+4=01
-----------
Input: 2+1=
Output: 2+1=3
Expected: 2+1=3
-----------
Input: 5+6=
Output: 5+6=11
Expected: 5+6=11
-----------
Input: 7+9=
Output: 7+9=61
Expected: 7+9=61
-----------
Input: 5+9=
Output: 5+9=41
Expected: 5+9=41
-----------
Input: 4+9=
Output: 4+9=31
Expected: 4+9=31
-----------
Input: 5+6=
Output: 5+6=11
Expected: 5+6=11
-----------
Input: 0+4=
Output: 0+4=4
Expected: 0+4=4
-----------
Input: 5+5=
Output: 5+5=01
Expected: 5+5=01
-----------
Input: 0+3=
Output: 0+3=3
Expected: 0+3=3
-----------
Input: 5+4=
Output: 5+4=9
Expected: 5+4=9
-----------
Input: 0

0.95

In [77]:
# Accuracy on operands of up to 2 digits (summary line only).
accuracy_print(model, num_digits=2)

Accuracy for 2 digits addition: 1.0 


1.0

In [78]:
# Accuracy on operands of up to 3 digits (summary line only).
accuracy_print(model, num_digits=3)

Accuracy for 3 digits addition: 1.0 


1.0

In [79]:
# Accuracy on operands of up to 4 digits (summary line only).
accuracy_print(model, num_digits=4)

Accuracy for 4 digits addition: 0.99 


0.99

In [80]:
# Accuracy on operands of up to 5 digits (summary line only).
accuracy_print(model, num_digits=5)

Accuracy for 5 digits addition: 0.94 


0.94

In [81]:
# Accuracy on operands of up to 6 digits (summary line only).
accuracy_print(model, num_digits=6)

Accuracy for 6 digits addition: 0.8 


0.8

In [82]:
# Accuracy on operands of up to 7 digits (summary line only).
accuracy_print(model, num_digits=7)

Accuracy for 7 digits addition: 0.01 


0.01

In [83]:
# Accuracy on operands of up to 8 digits (summary line only).
accuracy_print(model, num_digits=8)

Accuracy for 8 digits addition: 0.0 


0.0

In [84]:
def get_avg_performance(model, max_digits=8):
    """Run accuracy_print for every operand length from 1 to `max_digits`.

    Despite the name, nothing is averaged across lengths. The result maps each digit
    count to the accuracy that accuracy_print returned for it. need_print is False, so
    only one summary line per length is printed.

    Args:
        model: passed through to accuracy_print.
        max_digits: largest digit count to evaluate (default 8, the original fixed range).
            Values below 1 yield an empty dict.

    Returns:
        dict[int, float]: {num_digits: accuracy}, with keys in ascending order.
    """
    return {
        num_dig: accuracy_print(model, num_dig, need_print=False)
        for num_dig in range(1, max_digits + 1)
    }

In [95]:
# Per-digit-count accuracy for the trained model; feeds the bar chart cell.
avg_performance = get_avg_performance(model=model)

Accuracy for 1 digits addition: 0.97 
Accuracy for 2 digits addition: 1.0 
Accuracy for 3 digits addition: 1.0 
Accuracy for 4 digits addition: 0.98 
Accuracy for 5 digits addition: 1.0 
Accuracy for 6 digits addition: 1.0 
Accuracy for 7 digits addition: 0.02 
Accuracy for 8 digits addition: 0.0 


In [96]:
# Bar chart of accuracy per operand length (keys of avg_performance), shown inline and
# then logged as an interactive HTML panel to its own W&B run.
x_values = [num_digits for num_digits in avg_performance]
y_values = [avg_performance[num_digits] for num_digits in x_values]

accuracy_bars = go.Bar(x=x_values, y=y_values, marker_color='MediumPurple')
fig = go.Figure(accuracy_bars)

fig.update_layout(
    title="Accuracy for different digits addition Reverse Applied",
    xaxis_title="Num Digits",
    yaxis_title="Accuracy",
    template="plotly_white",
    width=800,
    height=500,
)
fig.show()

# Log the rendered figure to a separate W&B run and close that run immediately.
wandb.init(project="transformer", name="Accuracy for different digits addition plot Reverse Applied")
wandb.log({"Interactive Chart": wandb.Html(fig.to_html())})
wandb.finish()

In [None]:
import subprocess

# Commit and push the Drive-backed Git checkout to GitHub from Colab.
# Uses `os` (imported in the first cell) and assumes Google Drive is mounted.

GIT_USER_EMAIL = "zifeibai@umich.edu"
GIT_USER_NAME = "ZifeiBai"
GITHUB_TOKEN_PATH = "/content/drive/MyDrive/URPS/github_token.txt"
GIT_PATH = "/content/drive/MyDrive/URPS/Git"
COMMIT_MESSAGE = "Auto update from Google Colab 2.6"


def _redact(text, secret):
    """Return `text` with every occurrence of `secret` replaced by '***'."""
    return text.replace(secret, "***") if secret else text


def _git(args, token="", check=True):
    """Run `git <args>` without a shell and return the CompletedProcess (text output captured).

    When `check` is true and git exits non-zero, raise RuntimeError with `token` redacted
    from stderr. Only the subcommand name is quoted in the message, so a token embedded in
    a URL argument never reaches the traceback (subprocess's own CalledProcessError would
    include the full command).
    """
    result = subprocess.run(["git", *args], capture_output=True, text=True)
    if check and result.returncode != 0:
        raise RuntimeError(f"git {args[0]} failed:\n{_redact(result.stderr, token)}")
    return result


# 1️⃣ Commit identity. The old os.system calls silently ignored failures.
_git(["config", "--global", "user.email", GIT_USER_EMAIL])
_git(["config", "--global", "user.name", GIT_USER_NAME])

# 2️⃣ Read the GitHub token from Google Drive. Raise instead of calling exit(): in a
# Jupyter/Colab kernel, exit() shuts the kernel down rather than just stopping this cell.
if not os.path.exists(GITHUB_TOKEN_PATH):
    raise FileNotFoundError(f"❌ GitHub Token not found at {GITHUB_TOKEN_PATH}")
with open(GITHUB_TOKEN_PATH, "r") as f:
    os.environ["GITHUB_TOKEN"] = f.read().strip()
github_token = os.environ["GITHUB_TOKEN"]
if not github_token:
    raise ValueError(f"❌ GitHub Token file is empty: {GITHUB_TOKEN_PATH}")

# 3️⃣ Remote URL with the token embedded for HTTPS auth.
# NOTE(review): `git clone` stores this URL (token included) in .git/config on Drive.
# Consider a credential helper if the Drive folder is ever shared.
REPO_URL = f"https://{github_token}@github.com/ZifeiBai/URPS.git"

if not os.path.exists(GIT_PATH):
    print(f"📁 Creating directory: {GIT_PATH}")
    os.makedirs(GIT_PATH)

# 4️⃣ If .git/ does not exist, clone into the (empty) directory. The previous version ran
# `rm -rf GIT_PATH` first, which silently wiped any files already in that folder.
# Refuse instead, and let the user decide what to do with them.
if not os.path.exists(os.path.join(GIT_PATH, ".git")):
    if os.listdir(GIT_PATH):
        raise RuntimeError(
            f"{GIT_PATH} is not a Git repository but is not empty; "
            "move its contents aside before cloning."
        )
    print("❌ Git repository not found. Cloning...")
    _git(["clone", REPO_URL, GIT_PATH], token=github_token)

# 5️⃣ Enter the Git repo (changes the kernel-wide working directory, as before).
os.chdir(GIT_PATH)
print("📂 Changed working directory to:", os.getcwd())

# 6️⃣ Check Git status (informational only).
status_output = _git(["status"], check=False)
print(status_output.stdout)

# Push to Git
print("🚀 Adding files to Git...")
_git(["add", "."], token=github_token)

print("📝 Committing changes...")
# Not checked: git exits 1 with "nothing to commit" on a clean tree, which is fine here.
commit_output = _git(["commit", "-m", COMMIT_MESSAGE], check=False)
print(commit_output.stdout)

print("📤 Pushing to GitHub...")
push_output = _git(["push", "origin", "main"], check=False)
# Decide by exit status: git writes normal progress to stderr, so text matching is unreliable.
if push_output.returncode != 0:
    print("❌ Real Git Push Error:", _redact(push_output.stderr, github_token))
else:
    print("✅ Git Push Success!")

📂 Changed working directory to: /content/drive/MyDrive/URPS/Git
On branch main
Your branch is up to date with 'origin/main'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
	modified:   transformer.ipynb

no changes added to commit (use "git add" and/or "git commit -a")

🚀 Adding files to Git...
📝 Committing changes...
[main 60bbc33] Auto update from Google Colab 2.6
 1 file changed, 1 insertion(+), 1 deletion(-)
 rewrite transformer.ipynb (95%)

📤 Pushing to GitHub...
✅ Git Push Success!
