Transformer model implementation, adapted from "The Annotated Transformer":
http://nlp.seas.harvard.edu/annotated-transformer/#training

In [None]:
import tensorflow as tf

In [None]:
!pip install -q torchdata==0.3.0 spacy==3.2 altair GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting de-core-news-sm==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.2.0/de_core_news_sm-3.2.0-py3-none-any.whl (19.1 MB)
[K     |████████████████████████████████| 19.1 MB 1.2 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)
[K     |████████████████████████████████| 13.9 MB 4.0 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [None]:
!pip install -U torchtext==0.12.0
!pip install torch==1.11.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")
# Gate read by show_example / execute_example: examples only run when this is truthy.
RUN_EXAMPLES = True

In [None]:
def is_interactive_notebook():
  """Report whether this code is executing as the top-level ``__main__`` module."""
  running_as_main = "__main__" == __name__
  return running_as_main

def show_example(fn, args=[]):
  """Call ``fn(*args)`` and return its result, but only when examples are enabled.

  Nothing is called (and ``None`` is returned) unless this module is running
  as ``__main__`` and the module-level ``RUN_EXAMPLES`` flag is truthy.
  """
  if __name__ != "__main__":
    return None
  if not RUN_EXAMPLES:
    return None
  return fn(*args)

class DummyOptimizer(torch.optim.Optimizer):
  """Optimizer stand-in whose ``step`` and ``zero_grad`` do nothing.

  ``torch.optim.Optimizer.__init__`` is not called; the only state set up is a
  single parameter group with ``lr`` 0, so code that reads
  ``optimizer.param_groups[0]["lr"]`` still works.
  """

  def __init__(self):
    self.param_groups = [{"lr": 0}]

  def step(self):
    """Do nothing."""
    pass

  def zero_grad(self, set_to_none=False):
    """Do nothing; ``set_to_none`` is accepted and ignored."""
    pass
  

class DummyScheduler:
  """Learning-rate scheduler stand-in whose ``step`` does nothing."""

  def step(self):
    """Do nothing."""
    pass

In [None]:
class EncoderDecoder(nn.Module):
  """Container wiring an encoder, a decoder, two embedding stages and a generator.

  ``forward`` encodes ``src`` and decodes ``tgt`` against the result. The
  generator is only stored here; ``forward`` does not apply it.
  """

  def __init__(self,encoder,decoder,src_embed,tgt_embed,generator):
    super().__init__()
    # Assignment order is kept: it fixes the sub-module registration order.
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.generator = generator

  def forward(self, src, tgt, src_mask, tgt_mask):
    """Encode ``src``, then decode ``tgt`` against the encoded memory."""
    memory = self.encode(src, src_mask)
    return self.decode(memory, src_mask, tgt, tgt_mask)

  def encode(self, src, src_mask):
    """Embed ``src`` and pass it through the encoder with ``src_mask``."""
    embedded_src = self.src_embed(src)
    return self.encoder(embedded_src, src_mask)

  def decode(self, memory, src_mask, tgt, tgt_mask):
    """Embed ``tgt`` and pass it through the decoder with ``memory`` and both masks."""
    embedded_tgt = self.tgt_embed(tgt)
    return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

In [None]:
class Generator(nn.Module):
  """Linear projection from ``d_model`` features to ``vocab`` scores, then log-softmax."""

  def __init__(self, d_model, vocab):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab)

  def forward(self,x):
    """Return log-probabilities over the last dimension of ``proj(x)``."""
    logits = self.proj(x)
    return log_softmax(logits, dim=-1)

In [None]:
def clones(module, N):
  """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
  copies = nn.ModuleList()
  for _ in range(N):
    copies.append(copy.deepcopy(module))
  return copies

In [None]:
class Encoder(nn.Module):
  """Stack of ``N`` copies of ``layer`` followed by a final ``LayerNorm``."""

  def __init__(self,layer,N):
    super().__init__()
    self.layers = clones(layer, N)
    # The norm width comes from the layer's own ``size`` attribute.
    self.norm = LayerNorm(layer.size)

  def forward(self,x,mask):
    """Feed ``x`` (with ``mask``) through every layer in order, then normalise."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    return self.norm(hidden)

In [None]:
class LayerNorm(nn.Module):
  """Normalise over the last dimension with a learnable gain and bias.

  ``a_2`` (initialised to ones) scales and ``b_2`` (initialised to zeros)
  shifts the normalised values; ``eps`` is added to the standard deviation
  before dividing.
  """

  def __init__(self,features,eps=1e-6):
    super().__init__()
    self.a_2 = nn.Parameter(torch.ones(features))
    self.b_2 = nn.Parameter(torch.zeros(features))
    self.eps = eps

  def forward(self,x):
    """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along the last dimension."""
    centered = x - x.mean(-1, keepdim=True)
    spread = x.std(-1, keepdim=True) + self.eps
    return self.a_2 * centered / spread + self.b_2

In [None]:
class SublayerConnection(nn.Module):
  """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

  The layer norm is applied to the sublayer's input, not to its output.
  """

  def __init__(self,size,dropout):
    super().__init__()
    self.norm = LayerNorm(size)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,sublayer):
    """Apply ``sublayer`` to the normalised input and add the result back onto ``x``."""
    normed = self.norm(x)
    residual = self.dropout(sublayer(normed))
    return x + residual

In [None]:
class EncoderLayer(nn.Module):
  """One encoder layer: self-attention, then feed-forward, each in a residual sublayer."""

  def __init__(self,size,self_attn,feed_forward,dropout):
    super().__init__()
    # Assignment order is kept: it fixes the sub-module registration order.
    self.self_attn = self_attn
    self.feed_forward = feed_forward
    self.sublayer = clones(SublayerConnection(size, dropout), 2)
    self.size = size

  def forward(self,x,mask):
    """Run self-attention over ``x`` under ``mask``, then the feed-forward block."""
    def attend(normed):
      # Query, key and value are all the same normalised input.
      return self.self_attn(normed, normed, normed, mask)

    x = self.sublayer[0](x, attend)
    return self.sublayer[1](x, self.feed_forward)

In [None]:
class Decoder(nn.Module):
  """Stack of ``N`` copies of ``layer`` followed by a final ``LayerNorm``."""

  def __init__(self,layer,N):
    super().__init__()
    self.layers = clones(layer, N)
    # The norm width comes from the layer's own ``size`` attribute.
    self.norm = LayerNorm(layer.size)

  def forward(self,x,memory,src_mask,tgt_mask):
    """Feed ``x`` through every layer with ``memory`` and both masks, then normalise."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, memory, src_mask, tgt_mask)
    return self.norm(hidden)

In [None]:
class DecoderLayer(nn.Module):
  """One decoder layer: self-attention, attention over ``memory``, then feed-forward.

  Each of the three steps runs inside its own residual ``SublayerConnection``.
  """

  def __init__(self,size,self_attn,src_attn,feed_forward,dropout):
    super().__init__()
    # Assignment order is kept: it fixes the sub-module registration order.
    self.size = size
    self.self_attn = self_attn
    self.src_attn = src_attn
    self.feed_forward = feed_forward
    self.sublayer = clones(SublayerConnection(size, dropout), 3)

  def forward(self,x,memory,src_mask,tgt_mask):
    """Apply the three sublayers in order and return the result."""
    def self_attend(normed):
      # Query, key and value are all the normalised decoder input.
      return self.self_attn(normed, normed, normed, tgt_mask)

    def cross_attend(normed):
      # Queries come from the decoder; keys and values come from ``memory``.
      return self.src_attn(normed, memory, memory, src_mask)

    x = self.sublayer[0](x, self_attend)
    x = self.sublayer[1](x, cross_attend)
    return self.sublayer[2](x, self.feed_forward)

In [None]:
def subsequent_mask(size):
  """Return a ``(1, size, size)`` boolean mask that hides later positions.

  Entry ``[0, i, j]`` is ``True`` when ``j <= i`` and ``False`` otherwise.
  """
  # Ones strictly above the main diagonal mark the positions to hide.
  hidden = torch.ones(1, size, size).triu(diagonal=1).type(torch.uint8)
  return hidden == 0

In [None]:
def attention(query,key,value,mask=None,dropout=None):
  """Scaled dot-product attention.

  Scores are ``query @ key^T / sqrt(d_k)``, where ``d_k`` is the size of
  ``query``'s last dimension. Where ``mask == 0`` the score is replaced by
  ``-1e9`` before the softmax. ``dropout``, when given, is a callable applied
  to the attention weights.

  Returns a pair ``(weights @ value, weights)``.
  """
  scale = math.sqrt(query.size(-1))
  scores = torch.matmul(query, key.transpose(-2, -1)) / scale
  if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
  weights = scores.softmax(dim=-1)
  if dropout is not None:
    weights = dropout(weights)
  attended = torch.matmul(weights, value)
  return attended, weights

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention with ``h`` heads over ``d_model`` features.

  ``d_model`` must be divisible by ``h``; each head works on
  ``d_k = d_model // h`` features (``d_v`` is assumed equal to ``d_k``).
  The weights from the most recent call are kept in ``self.attn``.
  """

  def __init__(self,h,d_model,dropout=0.1):
    super().__init__()
    assert d_model % h == 0
    self.d_k = d_model // h
    self.h = h
    # Three input projections (query, key, value) plus the output projection.
    self.linears = clones(nn.Linear(d_model, d_model), 4)
    self.attn = None
    self.dropout = nn.Dropout(p=dropout)

  def forward(self,query,key,value,mask=None):
    """Project the inputs, attend per head, merge the heads and project again."""
    if mask is not None:
      # Extra dimension so the same mask broadcasts over every head.
      mask = mask.unsqueeze(1)
    nbatches = query.size(0)

    def split_heads(projection, tensor):
      # (nbatches, -1, h * d_k) -> (nbatches, h, -1, d_k)
      projected = projection(tensor)
      return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

    q_heads = split_heads(self.linears[0], query)
    k_heads = split_heads(self.linears[1], key)
    v_heads = split_heads(self.linears[2], value)

    context, self.attn = attention(
        q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout
    )

    # Undo the head split: back to (nbatches, -1, h * d_k).
    merged = context.transpose(1, 2).contiguous().view(
        nbatches, -1, self.h * self.d_k
    )

    del q_heads
    del k_heads
    del v_heads
    return self.linears[-1](merged)

In [None]:
class PositionWiseFeedForward(nn.Module):
  """Two linear layers (``d_model -> d_ff -> d_model``) with ReLU and dropout between."""

  def __init__(self,d_model,d_ff,dropout=0.1):
    super().__init__()
    self.w_1 = nn.Linear(d_model, d_ff)
    self.w_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    """Return ``w_2(dropout(relu(w_1(x))))``."""
    hidden = self.w_1(x).relu()
    hidden = self.dropout(hidden)
    return self.w_2(hidden)

In [None]:
class Embeddings(nn.Module):
  """Embedding lookup whose output is scaled by ``sqrt(d_model)``."""

  def __init__(self, d_model,vocab):
    super().__init__()
    self.lut = nn.Embedding(vocab, d_model)
    self.d_model = d_model

  def forward(self,x):
    """Look up ``x`` in the embedding table and multiply by ``sqrt(d_model)``."""
    scale = math.sqrt(self.d_model)
    return self.lut(x) * scale

In [None]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to the input, then apply dropout.

  A table of ``max_len`` positions is built once and stored as the ``pe``
  buffer. Even feature columns hold sines and odd columns hold cosines.
  """

  def __init__(self,d_model,dropout,max_len=5000):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)

    position = torch.arange(0, max_len).unsqueeze(1)
    # One frequency per sine/cosine column pair, computed in log space.
    div_term = torch.exp(
        torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
    )
    angles = position * div_term

    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    # Leading dimension of 1 so the table broadcasts over the batch.
    self.register_buffer("pe", table.unsqueeze(0))

  def forward(self,x):
    """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout."""
    seq_len = x.size(1)
    positional = self.pe[:, :seq_len].requires_grad_(False)
    return self.dropout(x + positional)

In [None]:
def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512,d_ff=2048,h=8,dropout=0.1
):
  """Build an ``EncoderDecoder`` transformer from hyperparameters.

  Args:
    src_vocab: size passed to the source ``Embeddings``.
    tgt_vocab: size passed to the target ``Embeddings`` and the ``Generator``.
    N: number of encoder layers and of decoder layers.
    d_model: feature width used throughout the model.
    d_ff: hidden width of the position-wise feed-forward blocks.
    h: number of attention heads.
    dropout: dropout probability for the attention, feed-forward,
      positional-encoding and sublayer-connection modules.

  Returns:
    The assembled model. Every parameter with more than one dimension has
    been re-initialised with Xavier uniform initialisation.
  """
  c = copy.deepcopy
  # Forward `dropout` so attention honours the caller's setting instead of
  # always falling back to MultiHeadAttention's own default.
  attn = MultiHeadAttention(h,d_model,dropout)
  ff = PositionWiseFeedForward(d_model,d_ff,dropout)
  position = PositionalEncoding(d_model,dropout)
  # Each use gets its own deep copy, so no weights are shared between layers.
  model = EncoderDecoder(
      Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),N),
      Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),N),
      nn.Sequential(Embeddings(d_model,src_vocab),c(position)),
      nn.Sequential(Embeddings(d_model,tgt_vocab),c(position)),
      Generator(d_model,tgt_vocab),
  )

  # Weight initialisation: matrices only; 1-D parameters keep their defaults.
  for p in model.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return model

In [None]:
def inference_test():
  """Greedy-decode ten steps with a freshly built, untrained model and print the result.

  The model is a 2-layer ``make_model(11, 11, 2)`` in eval mode; the source is
  the fixed sequence 1..10 and decoding starts from token 0.
  """
  test_model = make_model(11, 11, 2)
  test_model.eval()
  src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
  src_mask = torch.ones(1, 1, 10)

  memory = test_model.encode(src, src_mask)
  ys = torch.zeros(1, 1).type_as(src)

  for _ in range(10):
    causal_mask = subsequent_mask(ys.size(1)).type_as(src.data)
    out = test_model.decode(memory, src_mask, ys, causal_mask)
    # Only the last position's output is used to pick the next token.
    log_probs = test_model.generator(out[:, -1])
    next_word = torch.max(log_probs, dim=1)[1].data[0]
    step_token = torch.empty(1, 1).type_as(src.data).fill_(next_word)
    ys = torch.cat([ys, step_token], dim=1)

  print("Example Untrained Model Prediction: ",ys)


def run_tests():
  """Run ``inference_test`` ten times in a row."""
  remaining = 10
  while remaining > 0:
    inference_test()
    remaining -= 1


# Runs the untrained-model smoke test; a no-op unless RUN_EXAMPLES is truthy
# and this module is executing as __main__.
show_example(run_tests)

Example Untrained Model Prediction:  tensor([[ 0,  3,  8,  4, 10, 10, 10, 10, 10, 10, 10]])
Example Untrained Model Prediction:  tensor([[0, 8, 7, 9, 7, 9, 7, 9, 7, 9, 7]])
Example Untrained Model Prediction:  tensor([[ 0,  2, 10,  1,  1,  1,  1,  1,  1,  1,  1]])
Example Untrained Model Prediction:  tensor([[0, 3, 2, 1, 3, 2, 3, 2, 3, 2, 3]])
Example Untrained Model Prediction:  tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]])
Example Untrained Model Prediction:  tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
Example Untrained Model Prediction:  tensor([[0, 1, 7, 9, 0, 1, 1, 1, 1, 1, 7]])
Example Untrained Model Prediction:  tensor([[0, 7, 1, 7, 1, 7, 1, 1, 1, 1, 1]])
Example Untrained Model Prediction:  tensor([[ 0, 10,  9,  9,  2,  7,  4,  4,  9, 10,  5]])
Example Untrained Model Prediction:  tensor([[0, 3, 0, 2, 5, 7, 3, 0, 2, 5, 4]])


In [None]:
class Batch:
  """Holds a source batch, an optional target batch, and their attention masks.

  ``pad`` is the token id treated as padding. When ``tgt`` is given it is
  split into decoder input ``tgt`` (last column dropped) and prediction
  target ``tgt_y`` (first column dropped), and ``ntokens`` counts the
  non-padding entries of ``tgt_y``.
  """

  def __init__(self,src,tgt=None,pad=2):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)
    if tgt is None:
      return
    self.tgt = tgt[:, :-1]
    self.tgt_y = tgt[:, 1:]
    self.tgt_mask = self.make_std_mask(self.tgt, pad)
    self.ntokens = (self.tgt_y != pad).data.sum()

  @staticmethod
  def make_std_mask(tgt,pad):
    """Return a mask that hides both padding and later positions of ``tgt``."""
    not_pad = (tgt != pad).unsqueeze(-2)
    causal = subsequent_mask(tgt.size(-1)).type_as(not_pad.data)
    return not_pad & causal

In [None]:
class TrainState:
  """Running counters that run_epoch updates while in a training mode.

  The fields are plain class attributes with defaults; an instance only gets
  its own copy of a counter once something assigns to it.
  """
  step: int = 0  # batches processed (incremented once per batch by run_epoch)
  accum_step : int = 0  # optimizer steps taken
  samples : int = 0  # sum of batch.src.shape[0] over processed batches
  tokens : int = 0  # sum of batch.ntokens over processed batches

In [None]:
def run_epoch(
    data_iter,model,loss_compute,optimizer,
    scheduler,mode="train",accum_iter=1,train_state=None,
):
  """Run the model over every batch of ``data_iter`` once.

  Args:
    data_iter: iterable of batches exposing ``src``, ``tgt``, ``src_mask``,
      ``tgt_mask``, ``tgt_y`` and ``ntokens``.
    model: object whose ``forward(src, tgt, src_mask, tgt_mask)`` is called.
    loss_compute: callable ``(out, tgt_y, ntokens) -> (loss, loss_node)``;
      ``loss_node.backward()`` is called in training modes.
    optimizer: stepped and zeroed when ``i % accum_iter == 0`` (training only).
    scheduler: stepped once per batch (training only).
    mode: ``"train"`` or ``"train+log"`` enables backward passes, optimizer and
      scheduler steps, and progress printing; any other value only evaluates.
    accum_iter: see ``optimizer`` above.
    train_state: counters to update in place. When omitted, a fresh
      ``TrainState`` is created for this call; the previous default was a
      single instance built at definition time and shared by every call.

  Returns:
    ``(total_loss / total_tokens, train_state)``.
  """
  if train_state is None:
    train_state = TrainState()
  is_training = mode in ("train", "train+log")

  start = time.time()
  total_tokens = 0
  total_loss = 0
  tokens = 0
  n_accum = 0
  for i,batch in enumerate(data_iter):
    out = model.forward(batch.src,batch.tgt,batch.src_mask,batch.tgt_mask)
    loss,loss_node = loss_compute(out,batch.tgt_y,batch.ntokens)
    if is_training:
      loss_node.backward()
      train_state.step += 1
      train_state.samples += batch.src.shape[0]
      train_state.tokens += batch.ntokens
      # NOTE(review): this fires on the first batch (i == 0), so with
      # accum_iter > 1 the first optimizer step covers a single batch.
      if i % accum_iter == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        n_accum += 1
        train_state.accum_step += 1
      # The scheduler advances every batch, not every optimizer step.
      scheduler.step()

    total_loss += loss
    total_tokens += batch.ntokens
    tokens += batch.ntokens
    # Progress line on batches 1, 41, 81, ...; the throughput window resets after each.
    if i % 40 == 1 and is_training:
      lr = optimizer.param_groups[0]["lr"]
      elapsed = time.time() - start
      print(
          ("Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f " +
          "|Tokens / Sec: %7.1f | Learning Rate: %6.1e"
          ) % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr)
      )
      start = time.time()
      tokens = 0
    del loss
    del loss_node
  return total_loss / total_tokens, train_state

In [None]:
def rate(step,model_size,factor,warmup):
  """Learning-rate multiplier with linear warmup followed by inverse-sqrt decay.

  Computes ``factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)``.
  A ``step`` of 0 is treated as 1 so that ``0 ** -0.5`` is never evaluated.
  """
  effective_step = 1 if step == 0 else step
  decay = effective_step ** (-0.5)
  ramp = effective_step * warmup ** (-1.5)
  return factor * (model_size ** (-0.5) * min(decay, ramp))

In [None]:
class LabelSmoothing(nn.Module):
  """KL-divergence loss against a label-smoothed target distribution.

  The true class gets ``1 - smoothing``; every other class gets
  ``smoothing / (size - 2)``. The ``padding_idx`` column is then zeroed, and
  rows whose target is ``padding_idx`` are zeroed entirely. The most recent
  target distribution is kept in ``self.true_dist``.
  """

  def __init__(self,size,padding_idx,smoothing=0.0):
    super().__init__()
    self.criterion = nn.KLDivLoss(reduction="sum")
    self.padding_idx = padding_idx
    self.smoothing = smoothing
    self.confidence = 1.0 - smoothing
    self.size = size
    self.true_dist = None

  def forward(self,x,target):
    """Build the smoothed distribution for ``target`` and return the summed KL loss.

    ``x`` must have ``size`` columns; it is passed to ``nn.KLDivLoss`` as the input.
    """
    assert x.size(1) == self.size
    off_value = self.smoothing / (self.size - 2)
    true_dist = torch.full_like(x.data, off_value)
    true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
    true_dist[:, self.padding_idx] = 0
    # Rows whose target is padding contribute nothing to the loss.
    pad_rows = torch.nonzero(target.data == self.padding_idx)
    if pad_rows.dim() > 0:
      true_dist.index_fill_(0, pad_rows.squeeze(), 0.0)
    self.true_dist = true_dist
    return self.criterion(x, true_dist.clone().detach())

In [None]:
class SimpleLossCompute:
  """Applies the generator, then the criterion, and normalises the loss."""

  def __init__(self,generator,criterion):
    self.generator = generator
    self.criterion = criterion

  def __call__(self,x,y,norm):
    """Return ``(loss.data * norm, loss)`` where ``loss = criterion(...) / norm``.

    The generator output is flattened to two dimensions and ``y`` to one
    before they are passed to the criterion.
    """
    scores = self.generator(x)
    flat_scores = scores.contiguous().view(-1, scores.size(-1))
    flat_gold = y.contiguous().view(-1)
    sloss = self.criterion(flat_scores, flat_gold) / norm
    return sloss.data * norm, sloss

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode one sequence greedily, starting from ``start_symbol``.

    ``src`` is encoded once. The highest-scoring next token is then appended
    ``max_len - 1`` times, so the returned tensor is ``(1, max_len)`` for
    ``max_len >= 1``.
    """
    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for _ in range(max_len - 1):
        causal_mask = subsequent_mask(ys.size(1)).type_as(src.data)
        out = model.decode(memory, src_mask, ys, causal_mask)
        # Only the last position's output is used to pick the next token.
        log_probs = model.generator(out[:, -1])
        next_word = torch.max(log_probs, dim=1)[1].data[0]
        step_token = torch.zeros(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, step_token], dim=1)
    return ys

In [None]:
def data_gen(V, batch_size, nbatches):
    """Yield ``nbatches`` random batches for a src-tgt copy task.

    Each batch holds ``batch_size`` rows of 10 tokens drawn from ``[1, V)``,
    with the first column forced to 1. Source and target are separate copies
    of the same data, wrapped in a ``Batch`` with padding id 0.
    """
    produced = 0
    while produced < nbatches:
        data = torch.randint(1, V, size=(batch_size, 10))
        data[:, 0] = 1
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        yield Batch(src, tgt, 0)
        produced += 1

In [None]:
def execute_example(fn, args=[]):
    """Call ``fn(*args)`` when examples are enabled; the result is discarded.

    Nothing is called unless this module is running as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy.
    """
    if __name__ != "__main__" or not RUN_EXAMPLES:
        return
    fn(*args)

In [None]:
from tqdm import tqdm

In [None]:
def example_simple_model():
    """Train a small model on the synthetic copy task, then greedy-decode one sequence.

    Uses a vocabulary of 11 tokens, a 2-layer model, Adam with the ``rate``
    schedule (warmup 400), and 120 epochs of 20 training batches followed by
    5 evaluation batches. Prints the decoded output for the input 0..9.
    """
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, N=2)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
    )

    def lr_factor(step):
        # Multiplier applied by LambdaLR to the base lr of 0.5.
        return rate(
            step, model_size=model.src_embed[0].d_model, factor=1.0, warmup=400
        )

    lr_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lr_factor)
    # Holds only references to the generator and criterion, so one instance serves every epoch.
    loss_compute = SimpleLossCompute(model.generator, criterion)

    batch_size = 80
    for epoch in tqdm(range(120)):
        model.train()
        run_epoch(
            data_gen(V, batch_size, 20),
            model,
            loss_compute,
            optimizer,
            lr_scheduler,
            mode="train",
        )
        model.eval()
        # Evaluation pass; its returned loss is not used.
        run_epoch(
            data_gen(V, batch_size, 5),
            model,
            loss_compute,
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )

    model.eval()
    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    decoded = greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0)
    print(decoded)


# Runs the full 120-epoch copy-task training example (the recorded output below
# shows it taking roughly an hour and a half); a no-op unless RUN_EXAMPLES is
# truthy and this module is executing as __main__.
execute_example(example_simple_model)

  0%|          | 0/120 [00:00<?, ?it/s]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   3.10 |Tokens / Sec:   439.2 | Learning Rate: 5.5e-06


  1%|          | 1/120 [00:34<1:08:31, 34.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   2.14 |Tokens / Sec:   466.6 | Learning Rate: 6.1e-05


  2%|▏         | 2/120 [01:07<1:05:46, 33.44s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   1.77 |Tokens / Sec:   470.1 | Learning Rate: 1.2e-04


  2%|▎         | 3/120 [01:41<1:06:13, 33.96s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   1.53 |Tokens / Sec:   477.3 | Learning Rate: 1.7e-04


  3%|▎         | 4/120 [02:14<1:04:39, 33.44s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   1.08 |Tokens / Sec:   473.4 | Learning Rate: 2.3e-04


  4%|▍         | 5/120 [02:47<1:03:29, 33.13s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.60 |Tokens / Sec:   472.3 | Learning Rate: 2.8e-04


  5%|▌         | 6/120 [03:21<1:03:47, 33.58s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.37 |Tokens / Sec:   476.9 | Learning Rate: 3.4e-04


  6%|▌         | 7/120 [03:54<1:02:44, 33.32s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.20 |Tokens / Sec:   474.6 | Learning Rate: 3.9e-04


  7%|▋         | 8/120 [04:28<1:02:55, 33.71s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.18 |Tokens / Sec:   473.4 | Learning Rate: 4.5e-04


  8%|▊         | 9/120 [05:01<1:01:42, 33.36s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.12 |Tokens / Sec:   466.1 | Learning Rate: 5.0e-04


  8%|▊         | 10/120 [05:34<1:00:47, 33.16s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.23 |Tokens / Sec:   406.4 | Learning Rate: 5.6e-04


  9%|▉         | 11/120 [06:08<1:00:53, 33.52s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.15 |Tokens / Sec:   479.8 | Learning Rate: 6.1e-04


 10%|█         | 12/120 [06:40<59:48, 33.23s/it]  

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.14 |Tokens / Sec:   478.7 | Learning Rate: 6.7e-04


 11%|█         | 13/120 [07:15<59:50, 33.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.14 |Tokens / Sec:   480.0 | Learning Rate: 7.2e-04


 12%|█▏        | 14/120 [07:47<58:38, 33.19s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.09 |Tokens / Sec:   469.2 | Learning Rate: 7.8e-04


 12%|█▎        | 15/120 [08:20<57:47, 33.03s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   363.5 | Learning Rate: 8.3e-04


 13%|█▎        | 16/120 [08:54<57:50, 33.37s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.08 |Tokens / Sec:   481.0 | Learning Rate: 8.9e-04


 14%|█▍        | 17/120 [09:26<56:50, 33.11s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.11 |Tokens / Sec:   476.5 | Learning Rate: 9.4e-04


 15%|█▌        | 18/120 [10:01<56:56, 33.50s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.11 |Tokens / Sec:   471.5 | Learning Rate: 1.0e-03


 16%|█▌        | 19/120 [10:33<55:52, 33.19s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.23 |Tokens / Sec:   470.3 | Learning Rate: 1.1e-03


 17%|█▋        | 20/120 [11:07<55:23, 33.24s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.13 |Tokens / Sec:   335.5 | Learning Rate: 1.1e-03


 18%|█▊        | 21/120 [11:40<55:01, 33.35s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.13 |Tokens / Sec:   481.5 | Learning Rate: 1.1e-03


 18%|█▊        | 22/120 [12:13<54:07, 33.14s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.14 |Tokens / Sec:   476.3 | Learning Rate: 1.1e-03


 19%|█▉        | 23/120 [12:47<54:13, 33.54s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.12 |Tokens / Sec:   469.3 | Learning Rate: 1.0e-03


 20%|██        | 24/120 [13:20<53:18, 33.31s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.09 |Tokens / Sec:   470.2 | Learning Rate: 1.0e-03


 21%|██        | 25/120 [13:55<53:25, 33.74s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.10 |Tokens / Sec:   471.8 | Learning Rate: 9.9e-04


 22%|██▏       | 26/120 [14:28<52:29, 33.51s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.08 |Tokens / Sec:   470.0 | Learning Rate: 9.7e-04


 22%|██▎       | 27/120 [15:01<51:42, 33.36s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   461.1 | Learning Rate: 9.5e-04


 23%|██▎       | 28/120 [15:36<51:49, 33.80s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.07 |Tokens / Sec:   465.7 | Learning Rate: 9.3e-04


 24%|██▍       | 29/120 [16:09<51:05, 33.68s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.08 |Tokens / Sec:   456.7 | Learning Rate: 9.2e-04


 25%|██▌       | 30/120 [16:44<51:11, 34.12s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.06 |Tokens / Sec:   454.5 | Learning Rate: 9.0e-04


 26%|██▌       | 31/120 [17:18<50:19, 33.93s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.06 |Tokens / Sec:   457.9 | Learning Rate: 8.9e-04


 27%|██▋       | 32/120 [17:51<49:34, 33.80s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   420.8 | Learning Rate: 8.7e-04


 28%|██▊       | 33/120 [18:26<49:34, 34.19s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.08 |Tokens / Sec:   464.3 | Learning Rate: 8.6e-04


 28%|██▊       | 34/120 [19:00<48:49, 34.06s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.06 |Tokens / Sec:   451.0 | Learning Rate: 8.5e-04


 29%|██▉       | 35/120 [19:36<48:54, 34.52s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.11 |Tokens / Sec:   452.3 | Learning Rate: 8.3e-04


 30%|███       | 36/120 [20:10<48:07, 34.37s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   454.8 | Learning Rate: 8.2e-04


 31%|███       | 37/120 [20:46<48:07, 34.79s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.10 |Tokens / Sec:   443.6 | Learning Rate: 8.1e-04


 32%|███▏      | 38/120 [21:20<47:27, 34.73s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.07 |Tokens / Sec:   441.4 | Learning Rate: 8.0e-04


 32%|███▎      | 39/120 [21:55<46:55, 34.76s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.13 |Tokens / Sec:   320.8 | Learning Rate: 7.9e-04


 33%|███▎      | 40/120 [22:31<46:54, 35.18s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.07 |Tokens / Sec:   452.8 | Learning Rate: 7.8e-04


 34%|███▍      | 41/120 [23:05<45:57, 34.90s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   447.7 | Learning Rate: 7.7e-04


 35%|███▌      | 42/120 [23:41<45:51, 35.27s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.08 |Tokens / Sec:   452.0 | Learning Rate: 7.6e-04


 36%|███▌      | 43/120 [24:16<45:06, 35.15s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   434.8 | Learning Rate: 7.5e-04


 37%|███▋      | 44/120 [24:53<45:12, 35.69s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   442.6 | Learning Rate: 7.4e-04


 38%|███▊      | 45/120 [25:29<44:34, 35.66s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   431.9 | Learning Rate: 7.4e-04


 38%|███▊      | 46/120 [26:06<44:40, 36.23s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   432.1 | Learning Rate: 7.3e-04


 39%|███▉      | 47/120 [26:42<43:56, 36.11s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   426.4 | Learning Rate: 7.2e-04


 40%|████      | 48/120 [27:19<43:36, 36.33s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   332.6 | Learning Rate: 7.1e-04


 41%|████      | 49/120 [27:57<43:25, 36.70s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   421.8 | Learning Rate: 7.1e-04


 42%|████▏     | 50/120 [28:34<42:58, 36.83s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   411.9 | Learning Rate: 7.0e-04


 42%|████▎     | 51/120 [29:13<43:03, 37.44s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   411.1 | Learning Rate: 6.9e-04


 43%|████▎     | 52/120 [29:49<42:07, 37.17s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   417.1 | Learning Rate: 6.8e-04


 44%|████▍     | 53/120 [30:28<42:03, 37.66s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   399.2 | Learning Rate: 6.8e-04


 45%|████▌     | 54/120 [31:06<41:37, 37.84s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.06 |Tokens / Sec:   406.2 | Learning Rate: 6.7e-04


 46%|████▌     | 55/120 [31:46<41:45, 38.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   394.9 | Learning Rate: 6.7e-04


 47%|████▋     | 56/120 [32:26<41:16, 38.69s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   398.6 | Learning Rate: 6.6e-04


 48%|████▊     | 57/120 [33:06<41:11, 39.23s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   388.9 | Learning Rate: 6.5e-04


 48%|████▊     | 58/120 [33:46<40:43, 39.41s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   384.2 | Learning Rate: 6.5e-04


 49%|████▉     | 59/120 [34:27<40:42, 40.04s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   384.7 | Learning Rate: 6.4e-04


 50%|█████     | 60/120 [35:08<40:16, 40.27s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   375.3 | Learning Rate: 6.4e-04


 51%|█████     | 61/120 [35:50<40:10, 40.86s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   383.1 | Learning Rate: 6.3e-04


 52%|█████▏    | 62/120 [36:31<39:20, 40.70s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   366.8 | Learning Rate: 6.3e-04


 52%|█████▎    | 63/120 [37:13<39:05, 41.15s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   384.7 | Learning Rate: 6.2e-04


 53%|█████▎    | 64/120 [37:53<38:01, 40.73s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   304.0 | Learning Rate: 6.2e-04


 54%|█████▍    | 65/120 [38:36<37:57, 41.41s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.06 |Tokens / Sec:   368.3 | Learning Rate: 6.1e-04


 55%|█████▌    | 66/120 [39:20<38:02, 42.27s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   349.2 | Learning Rate: 6.1e-04


 56%|█████▌    | 67/120 [40:03<37:35, 42.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   349.4 | Learning Rate: 6.0e-04


 57%|█████▋    | 68/120 [40:48<37:26, 43.20s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   351.7 | Learning Rate: 6.0e-04


 57%|█████▊    | 69/120 [41:31<36:49, 43.33s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   350.0 | Learning Rate: 5.9e-04


 58%|█████▊    | 70/120 [42:18<36:55, 44.31s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   329.3 | Learning Rate: 5.9e-04


 59%|█████▉    | 71/120 [43:04<36:31, 44.73s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.09 |Tokens / Sec:   333.4 | Learning Rate: 5.9e-04


 60%|██████    | 72/120 [43:51<36:27, 45.58s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   333.7 | Learning Rate: 5.8e-04


 61%|██████    | 73/120 [44:40<36:29, 46.58s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   304.9 | Learning Rate: 5.8e-04


 62%|██████▏   | 74/120 [45:29<36:14, 47.27s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   311.9 | Learning Rate: 5.7e-04


 62%|██████▎   | 75/120 [46:17<35:39, 47.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.05 |Tokens / Sec:   315.6 | Learning Rate: 5.7e-04


 63%|██████▎   | 76/120 [47:06<35:02, 47.78s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   317.8 | Learning Rate: 5.7e-04


 64%|██████▍   | 77/120 [47:57<35:00, 48.84s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   291.3 | Learning Rate: 5.6e-04


 65%|██████▌   | 78/120 [48:51<35:16, 50.39s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   283.9 | Learning Rate: 5.6e-04


 66%|██████▌   | 79/120 [49:44<34:58, 51.19s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.04 |Tokens / Sec:   286.9 | Learning Rate: 5.6e-04


 67%|██████▋   | 80/120 [50:39<34:48, 52.22s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   280.7 | Learning Rate: 5.5e-04


 68%|██████▊   | 81/120 [51:33<34:25, 52.96s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   289.9 | Learning Rate: 5.5e-04


 68%|██████▊   | 82/120 [52:26<33:25, 52.77s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   277.9 | Learning Rate: 5.5e-04


 69%|██████▉   | 83/120 [53:21<33:04, 53.63s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   280.9 | Learning Rate: 5.4e-04


 70%|███████   | 84/120 [54:17<32:32, 54.23s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   265.5 | Learning Rate: 5.4e-04


 71%|███████   | 85/120 [55:14<32:03, 54.96s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   212.2 | Learning Rate: 5.4e-04


 72%|███████▏  | 86/120 [56:11<31:33, 55.70s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   272.3 | Learning Rate: 5.3e-04


 72%|███████▎  | 87/120 [57:08<30:47, 55.99s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   278.6 | Learning Rate: 5.3e-04


 73%|███████▎  | 88/120 [58:04<29:53, 56.05s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   267.0 | Learning Rate: 5.3e-04


 74%|███████▍  | 89/120 [59:00<28:56, 56.01s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   270.9 | Learning Rate: 5.2e-04


 75%|███████▌  | 90/120 [59:55<27:56, 55.87s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   281.4 | Learning Rate: 5.2e-04


 76%|███████▌  | 91/120 [1:00:51<27:02, 55.95s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   271.0 | Learning Rate: 5.2e-04


 77%|███████▋  | 92/120 [1:01:48<26:08, 56.03s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   274.3 | Learning Rate: 5.1e-04


 78%|███████▊  | 93/120 [1:02:42<24:59, 55.52s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   277.9 | Learning Rate: 5.1e-04


 78%|███████▊  | 94/120 [1:03:39<24:11, 55.84s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   271.7 | Learning Rate: 5.1e-04


 79%|███████▉  | 95/120 [1:04:37<23:33, 56.55s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   259.3 | Learning Rate: 5.1e-04


 80%|████████  | 96/120 [1:05:36<22:52, 57.19s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   236.2 | Learning Rate: 5.0e-04


 81%|████████  | 97/120 [1:06:33<21:57, 57.29s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   260.1 | Learning Rate: 5.0e-04


 82%|████████▏ | 98/120 [1:07:31<21:02, 57.38s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   268.5 | Learning Rate: 5.0e-04


 82%|████████▎ | 99/120 [1:08:29<20:09, 57.60s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   270.8 | Learning Rate: 5.0e-04


 83%|████████▎ | 100/120 [1:09:27<19:18, 57.93s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   233.8 | Learning Rate: 4.9e-04


 84%|████████▍ | 101/120 [1:10:27<18:28, 58.35s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   253.4 | Learning Rate: 4.9e-04


 85%|████████▌ | 102/120 [1:11:24<17:22, 57.92s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   270.0 | Learning Rate: 4.9e-04


 86%|████████▌ | 103/120 [1:12:21<16:20, 57.67s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   268.8 | Learning Rate: 4.9e-04


 87%|████████▋ | 104/120 [1:13:20<15:29, 58.11s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.03 |Tokens / Sec:   257.1 | Learning Rate: 4.8e-04


 88%|████████▊ | 105/120 [1:14:19<14:35, 58.38s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   252.5 | Learning Rate: 4.8e-04


 88%|████████▊ | 106/120 [1:15:19<13:45, 58.94s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   252.7 | Learning Rate: 4.8e-04


 89%|████████▉ | 107/120 [1:16:19<12:51, 59.35s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   260.4 | Learning Rate: 4.8e-04


 90%|█████████ | 108/120 [1:17:19<11:52, 59.41s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   250.9 | Learning Rate: 4.8e-04


 91%|█████████ | 109/120 [1:18:18<10:51, 59.26s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   250.7 | Learning Rate: 4.7e-04


 92%|█████████▏| 110/120 [1:19:17<09:52, 59.30s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   250.8 | Learning Rate: 4.7e-04


 92%|█████████▎| 111/120 [1:20:16<08:53, 59.25s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   246.7 | Learning Rate: 4.7e-04


 93%|█████████▎| 112/120 [1:21:17<07:56, 59.61s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   251.1 | Learning Rate: 4.7e-04


 94%|█████████▍| 113/120 [1:22:14<06:52, 59.00s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   259.7 | Learning Rate: 4.6e-04


 95%|█████████▌| 114/120 [1:23:12<05:50, 58.45s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   270.0 | Learning Rate: 4.6e-04


 96%|█████████▌| 115/120 [1:24:09<04:50, 58.05s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.02 |Tokens / Sec:   256.8 | Learning Rate: 4.6e-04


 97%|█████████▋| 116/120 [1:25:07<03:52, 58.21s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   250.2 | Learning Rate: 4.6e-04


 98%|█████████▊| 117/120 [1:26:06<02:55, 58.47s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.01 |Tokens / Sec:   256.6 | Learning Rate: 4.6e-04


 98%|█████████▊| 118/120 [1:27:05<01:57, 58.61s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   256.9 | Learning Rate: 4.5e-04


 99%|█████████▉| 119/120 [1:28:04<00:58, 58.58s/it]

Epoch Step:      1 | Accumulation Step:   2 | Loss:   0.00 |Tokens / Sec:   257.2 | Learning Rate: 4.5e-04


100%|██████████| 120/120 [1:29:02<00:00, 44.52s/it]

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])



