# Milestone 2

In [1]:
%cd ..
# imports 
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scripts_m2 import *


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


c:\Users\001\OneDrive\Desktop\GUC\semester 10\nlp\NLP_PROJECT_111\QA Task milestone


### 1. Exploring the Dataset

In [2]:
# Loading the training data
import json 

with open('data/m2_train.json', 'r') as f:
    squad_data = json.load(f)

print(squad_data['data'][0]['paragraphs'][0].keys())

print(json.dumps(squad_data['data'][0], indent=2))

dict_keys(['qas', 'context'])
{
  "title": "Beyonc\u00e9",
  "paragraphs": [
    {
      "qas": [
        {
          "question": "When did Beyonce start becoming popular?",
          "id": "56be85543aeaaa14008c9063",
          "answers": [
            {
              "text": "in the late 1990s",
              "answer_start": 269
            }
          ],
          "is_impossible": false
        },
        {
          "question": "What areas did Beyonce compete in when she was growing up?",
          "id": "56be85543aeaaa14008c9065",
          "answers": [
            {
              "text": "singing and dancing",
              "answer_start": 207
            }
          ],
          "is_impossible": false
        },
        {
          "question": "When did Beyonce leave Destiny's Child and become a solo singer?",
          "id": "56be85543aeaaa14008c9066",
          "answers": [
            {
              "text": "2003",
              "answer_start": 526
            }
          ],


In [3]:
# Counting the number of answerable and unanswerable questions
num_answerable = 0
num_unanswerable = 0

for article in squad_data["data"]:
    for paragraph in article["paragraphs"]:
        for qa in paragraph["qas"]:
            if qa["is_impossible"]:
                num_unanswerable += 1
            else:
                num_answerable += 1

print("Number of answerable questions: ", num_answerable)
print("Number of unanswerable questions: ", num_unanswerable)
print("Number of QA pairs: ", num_answerable + num_unanswerable)

Number of answerable questions:  86821
Number of unanswerable questions:  43498
Number of QA pairs:  130319


Answerable vs. Unanswerable Questions in SQuAD 2.0

- **Answerable Questions:** The correct answer exists within the given passage, and the model must extract the exact span.
- **Unanswerable Questions:** The passage does not contain the answer, and the model should predict "No Answer."

SQuAD 2.0 introduces unanswerable questions to make the task more challenging, requiring models to distinguish between when an answer is present and when it is not.

In this notebook, we will use answerable questions only in the subset we will be training our neural networks on

### 2. Preprocessing the dataset

In [4]:
# Clean text function
import re

def clean_text(text: str, is_question: bool=False):
    """
    Cleans text by removing extra spaces, newline characters, and special symbols.
    """
    if is_question:
        text = text.replace("\n", " ").replace("\t", " ")  # Remove newlines & tabs
        text = re.sub(r"\s+", " ", text)  # Remove extra spaces
        text = text.strip()  # Trim leading/trailing spaces
    text = text.lower()  # Convert to lowercase
    return text

# Testing the clean_text function on a sample text
sample_text = "   This is a \n\n\t sample    text.    "
print("Original text:", sample_text)
print("Cleaned text:", clean_text(sample_text))

Original text:    This is a 

	 sample    text.    
Cleaned text:    this is a 

	 sample    text.    


In [5]:
# Load and process the SQuAD dataset using the clean_text function
import json

def load_and_process_squad(filepath, max_samples=20000):
    """
    Loads, cleans, and extracts answerable questions from the SQuAD dataset.

    Args:
        filepath (str): Path to the SQuAD JSON file.
        max_samples (int): Maximum number of answerable questions to load.

    Returns:
        List[dict]: A list of cleaned question-answer pairs.
    """
    with open(filepath, "r") as f:
        squad_data = json.load(f)

    data = []
    for article in squad_data["data"]:
        for paragraph in article["paragraphs"]:
            context = clean_text(paragraph["context"]) 
            for qa in paragraph["qas"]:
                if qa["is_impossible"]:  
                    continue
                
                question = clean_text(qa["question"], is_question=True) 
                answer_text = clean_text(qa["answers"][0]["text"])  
                answer_start = qa["answers"][0]["answer_start"]
                
                data.append({
                    "context": context, 
                    "question": question, 
                    "answer": answer_text, 
                    "answer_start": answer_start, 
                    "answer_end": answer_start + len(answer_text)})


    data.sort(key=lambda x: (len(x["context"]), len(x["question"])))
    # Limit the number of samples if specified
    if max_samples > 0:
        data = data[:max_samples]
    return data

# Load and process the SQuAD dataset
train_dataset = load_and_process_squad("data/m2_train.json", max_samples=-1)
dev_samples = load_and_process_squad("data/m2_dev.json", max_samples=2000)

print("Number of QA pairs in train data:", len(train_dataset))
print("Training sample:", train_dataset[0].keys())
print("\n")
print("Number of QA pairs in test data:", len(dev_samples))
print("Dev sample:", dev_samples[0].keys())


Number of QA pairs in train data: 86821
Training sample: dict_keys(['context', 'question', 'answer', 'answer_start', 'answer_end'])


Number of QA pairs in test data: 2000
Dev sample: dict_keys(['context', 'question', 'answer', 'answer_start', 'answer_end'])


### 3. Tokenization

In [6]:
unique_words = set()

for sample in train_dataset:
    context = sample["context"]
    question = sample["question"]
    answer = sample["answer"]

    unique_words.update(set(re.findall(r"\w+", context)))
    unique_words.update(set(re.findall(r"\w+", question)))
    unique_words.update(set(re.findall(r"\w+", answer)))

                
# Print the number of unique words
print("Number of unique words in the training set:", len(unique_words))

Number of unique words in the training set: 82597


In [7]:
combined_text = []
for sample in train_dataset:
    combined_text.append(sample["context"])
    combined_text.append(sample["question"])
    combined_text.append(sample["answer"])

# Join all text samples into one corpus (you can also use '\n'.join for a more distinct separation)
combined_text = "\n".join(combined_text)


In [8]:
# Training the tokenizer
from scripts_m2 import *

# Creating / Loading the tokenizer
tokenizer = BPETokenizer()

# Training the tokenizer if not already trained
tokenizer.train(combined_text=combined_text.split("/n"), vocab_size=10000)


Loading tokenizer from ./tokenizers/tokenizer.json...
Tokenizer already exists at ./tokenizers/tokenizer.json. Skipping training.


In [9]:
# Encoding a Question
example_question = train_dataset[4]["question"]
tokenized_output, attention_mask = tokenizer.encode(example_question)
print("Tokenized Question:", tokenized_output)
print("Length of Tokenized Question:", len(tokenized_output))
print("Attention Mask:", attention_mask)

# Decoding the tokenized output
decoded_output = tokenizer.decode(tokenized_output)
print("Decoded Question:", decoded_output)

Tokenized Question: [3, 2441, 1997, 1966, 4721, 2776, 1974, 3539, 1978, 7873, 3120, 1390, 4]
Length of Tokenized Question: 13
Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Decoded Question: where is the oklahoma school of science and mathematics located ?


In [10]:
# Changing the max tokenizer length
tokenizer.set_max_length(25)

In [11]:
# Encoding an answer
example_answer = train_dataset[4]["answer"]
tokenized_answer, attention_mask = tokenizer.encode(example_answer)
print("Tokenized Answer:", tokenized_answer)
print("Length of Tokenized Answer:", len(tokenized_answer))
print("Attention Mask:", attention_mask)

# Decoding the tokenized answer
decoded_answer = tokenizer.decode(tokenized_answer)
print("Decoded Answer:", decoded_answer)

Tokenized Answer: [3, 4721, 2251, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Length of Tokenized Answer: 25
Attention Mask: [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Decoded Answer: oklahoma city


### 4. Dataset & Data Loaders

In [12]:
# Getting the dataset
train_dataset = load_and_process_squad("data/m2_train.json", max_samples=20000)
dev_dataset = load_and_process_squad("data/m2_dev.json", max_samples=2000)
print("Number of training samples:", len(train_dataset))
print("Number of dev samples:", len(dev_dataset))

Number of training samples: 20000
Number of dev samples: 2000


In [13]:
# Create the tokenizer
tokenizer = BPETokenizer()
# Train the tokenizer
tokenizer.train(combined_text=combined_text.split("\n"), vocab_size=10000)

Loading tokenizer from ./tokenizers/tokenizer.json...
Tokenizer already exists at ./tokenizers/tokenizer.json. Skipping training.


In [14]:
# View the training data
random_idx = np.random.randint(0, len(train_dataset))

# Print the random sample
print(train_dataset[random_idx].keys())
print("Training data sample:")
print(train_dataset[random_idx]["context"])
print(train_dataset[random_idx]["question"])
print(train_dataset[random_idx]["answer"])
print("Answer start:", train_dataset[random_idx]["answer_start"])
print("Answer end:", train_dataset[random_idx]["answer_end"])
print("Length of the training examples:", len(train_dataset))

dict_keys(['context', 'question', 'answer', 'answer_start', 'answer_end'])
Training data sample:
in 1913, his father was elevated to the nobility for his service to the austro-hungarian empire by emperor franz joseph. the neumann family thus acquired the hereditary appellation margittai, meaning of marghita. the family had no connection with the town; the appellation was chosen in reference to margaret, as was those chosen coat of arms depicting three marguerites. neumann jános became margittai neumann jános (john neumann of marghita), which he later changed to the german johann von neumann.
what town did von neumann's family become associated when elevated to nobility?
marghita
Answer start: 203
Answer end: 211
Length of the training examples: 20000


In [15]:
q = tokenizer.encode(train_dataset[4]["question"])[0]
tokenizer.decode(q)

'where is the oklahoma school of science and mathematics located ?'

In [16]:
# Viewing the length of the longest question, context, and answer
max_context_length = max(len(tokenizer.encode(sample["context"])[0]) for sample in train_dataset)
max_question_length = max(len(tokenizer.encode(sample["question"])[0]) for sample in train_dataset)
max_answer_length = max(len(tokenizer.encode(sample["answer"])[0]) for sample in train_dataset)

print("Max context length:", max_context_length)
print("Max question length:", max_question_length)
print("Max answer length:", max_answer_length)

Max context length: 205
Max question length: 64
Max answer length: 75


So for the transformer based models we can use the max length of the question and answer as follows:
1. `Answers Max Length`: 25 tokens
2. `Question Max Length`: 30 tokens
3. `Context Max Length`: 160 tokens
4. `Context + Question Length`: 160+30+1 (+1 for [SEP]) We could have used 175 as it is the 99th percentile but we will use 191 instead


This would give us a good tradeoff between the length of the sequence without using a lot of padding.

In [47]:
# Creating dataset
tokenizer = BPETokenizer()
train_data = load_and_process_squad("data/m2_train.json", max_samples=20000)
dev_data = load_and_process_squad("data/m2_dev.json", max_samples=2000)

context_max_length = 160
question_max_length = 30
answer_max_length = 25
train_dataset = QADataset(train_data, tokenizer, context_max_length=context_max_length, question_max_length=question_max_length, answer_max_length=answer_max_length, include_context=True, encode_two_texts_sep=True, context_question_swap=True)
dev_dataset = QADataset(dev_data, tokenizer, context_max_length=context_max_length, question_max_length=question_max_length, answer_max_length=answer_max_length, include_context=True, encode_two_texts_sep=True, context_question_swap=True)
print("Number of training samples:", len(train_dataset))
print("Number of dev samples:", len(dev_dataset))

# View a sample from the dataset
random_idx = np.random.randint(0, len(train_dataset))
print("Sample from the dataset:")
print("Question encoded:", train_dataset[random_idx]["question"])
print("Question decoded: ", tokenizer.decode(train_dataset[random_idx]["question"].tolist()))
print("Answer encoded:", train_dataset[random_idx]["answer"])
print("Answer decoded: ", tokenizer.decode(train_dataset[random_idx]["answer"].tolist()))
print("Context encoded:", train_dataset[random_idx]["context"])
print("Context decoded: ", tokenizer.decode(train_dataset[random_idx]["context"].tolist()))
print("Context Question encoded:", train_dataset[random_idx]["context_question"])
print("Context Question decoded: ", tokenizer.decode(train_dataset[random_idx]["context_question"].tolist()))
print("Answer Start:", train_dataset[random_idx]["answer_start"])
print("Answer End:", train_dataset[random_idx]["answer_end"])
print("Answer Start: End decoded:", tokenizer.decode(train_dataset[random_idx]["context"][train_dataset[random_idx]["answer_start"]:train_dataset[random_idx]["answer_end"]+1].tolist()))

print("\n\n")

# View a sample from the dev dataset
random_idx = np.random.randint(0, len(dev_dataset))
print("Sample from the dev dataset:")
print("Context encoded:", dev_dataset[random_idx]["context"])
print("Context decoded: ", tokenizer.decode(dev_dataset[random_idx]["context"].tolist()))
print("Question encoded:", dev_dataset[random_idx]["question"])
print("Question decoded: ", tokenizer.decode(dev_dataset[random_idx]["question"].tolist()))
print("Answer encoded:", dev_dataset[random_idx]["answer"])
print("Answer decoded: ", tokenizer.decode(dev_dataset[random_idx]["answer"].tolist()))
print("Context Question encoded:", dev_dataset[random_idx]["context_question"])
print("Context Question decoded: ", tokenizer.decode(dev_dataset[random_idx]["context_question"].tolist()))
print("Answer Start:", dev_dataset[random_idx]["answer_start"])
print("Answer End:", dev_dataset[random_idx]["answer_end"])
print("Answer Start: End decoded:", tokenizer.decode(dev_dataset[random_idx]["context"][dev_dataset[random_idx]["answer_start"]:dev_dataset[random_idx]["answer_end"]+1].tolist()))


Loading tokenizer from ./tokenizers/tokenizer.json...
Filtered dataset size: 19993 out of original 20000
Filtered dataset size: 1996 out of original 2000
Number of training samples: 19993
Number of dev samples: 1996
Sample from the dataset:
Question encoded: tensor([   3, 2087, 1997, 3564, 1967, 4828, 1390,    4,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1])
Question decoded:  what is okanye ?
Answer encoded: tensor([   3, 5982, 2071, 2041, 2014, 2748, 6615, 1315, 1282, 1315, 9702,    4,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1])
Answer decoded:  pronounce unstressed / o / clearly
Context encoded: tensor([   3, 1966, 3368, 3618, 5685, 1978, 2881, 5213, 3000, 1966, 2670, 5730,
        3267, 4162, 5982, 2071, 2041, 2014, 2748, 6615, 1315, 1282, 1315, 9702,
        1528, 1966, 8675, 1980, 2566, 3564, 1967, 4828, 1315,  303,  299,  289,


In [48]:
# Creating Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) # add this `num_workers=0` if you want to see print in __getitem__ in dataset class
dev_dataloader = DataLoader(dev_dataset, batch_size=32, shuffle=False)

for batch in train_dataloader:
    print(batch.keys())
    print(batch['context_question'].shape)
    print(batch['question'].shape)
    print(batch['answer'].shape)
    print(batch['attention_mask_question'].shape)
    print(batch['attention_mask_answer'].shape)
    print(batch['attention_mask_context_question'].shape)
    print(batch['answer_start'], batch['answer_end'])
    print(batch['context_question_type_mask'].shape)
    print("-" * 50)
    input_ids = batch['context_question']
    attention_mask = batch['attention_mask_context_question']
    token_type_ids = batch['context_question_type_mask']
    break

dict_keys(['question', 'attention_mask_question', 'answer', 'attention_mask_answer', 'context', 'attention_mask_context', 'context_question', 'attention_mask_context_question', 'answer_start', 'answer_end', 'context_question_type_mask'])
torch.Size([64, 189])
torch.Size([64, 30])
torch.Size([64, 25])
torch.Size([64, 30])
torch.Size([64, 25])
torch.Size([64, 189])
tensor([ 81,  19,   6,  76,  41, 123,  65,   0,  73,  40,  85,  78,  52,  71,
         39,  34,   5,  25,  24,  87,  22,   3,  14, 111, 120,  47,  49,  27,
        112,  27,  65,  12,  65, 101,  46,  23,   6,  68,  61,  58,  69,  19,
         60,  81,   0,  80,  93,  37,  12,  10,  23,  29,  49, 121,   9,  31,
         48,  28,  84,  15,   9,  45,  60,  56]) tensor([ 86,  25,  10,  76,  45, 123,  73,  13,  73,  40,  88,  81,  55,  71,
         44,  38,   7,  27,  28, 115,  23,   6,  15, 111, 121,  50,  51,  30,
        121,  51,  68,  17,  66, 102,  50,  23,   6,  87,  62,  62,  72,  21,
         67,  88,   2,  81,  99,  43,  

In [49]:
context_max_length + question_max_length

190

### 5. Phase 1: Model Training

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerQAModel(nn.Module):
    """
    Transformer Encoder QA model for SQuAD-style span prediction.
    Input:  [batch_size, seq_len] token IDs in the form
            <start> question <sep> context
    Output: start_logits, end_logits of shape [batch_size, seq_len]
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 300,
        nhead: int = 6,
        dim_feedforward: int = 2048,
        num_layers: int = 4,
        max_len: int = 512,
        dropout: float = 0.01,
        pad_idx: int = 0,
    ):
        super().__init__()
        # word + positional embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_emb   = nn.Embedding(max_len, d_model)

        # transformer encoder stack
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # two linear heads for start/end prediction
        self.start_classifier = nn.Linear(d_model, 1)
        self.end_classifier   = nn.Linear(d_model, 1)

        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len

    def forward(
        self,
        context_question: torch.LongTensor,               # [B, L]
        attention_mask_context_question: torch.BoolTensor = None,    # [B, L] (True = keep, False = pad)
    ):
        B, L = context_question.size()
        # create position IDs
        pos_ids = torch.arange(L, device=context_question.device).unsqueeze(0).expand(B, -1)

        # embed tokens + positions
        x = self.token_emb(context_question) + self.pos_emb(pos_ids)
        x = self.dropout(x)             # [B, L, D]

        # TransformerEncoder expects [L, B, D]
        x = x.transpose(0, 1)           # [L, B, D]

        # key_padding_mask: True for positions that should be ignored (pads)
        if attention_mask is not None:
            key_padding_mask = ~attention_mask_context_question.bool()  # invert: True=pad
        else:
            key_padding_mask = None

        # encode
        enc = self.encoder(x, src_key_padding_mask=key_padding_mask)
        enc = enc.transpose(0, 1)       # [B, L, D]

        # compute logits and squeeze
        start_logits = self.start_classifier(enc).squeeze(-1)  # [B, L]
        end_logits   = self.end_classifier(enc).squeeze(-1)    # [B, L]

        return start_logits, end_logits




In [None]:

class TransformerQAModel3(nn.Module):
    def __init__(self,
                 vocab_size: int,
                 d_model: int,
                 num_layers: int,
                 num_heads: int,
                 dim_feedforward: int,
                 max_question_len: int,
                 max_context_len: int,
                 dropout: float = 0.1):
        """
        Args:
          vocab_size: size of your tokenizer vocabulary
          d_model: transformer hidden size
          num_layers: number of encoder blocks
          num_heads: number of attention heads
          dim_feedforward: inner dim of the transformer FFN
          max_question_len: maximum length of question (m)
          max_context_len: maximum length of context (n)
        """
        super().__init__()
        self.d_model = d_model
        self.max_context_len = max_context_len
        self.seq_len = max_question_len + max_context_len - 1

        # --- embedding + positional encoding ---
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(self.seq_len, d_model)

        # --- transformer encoder ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu"
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- the “Concat & FeedForward” from Fig 2(a) ---
        #  input_dim = (context_len+1) * d_model, hidden = (context_len+1)
        self.fc1 = nn.LazyLinear(self.max_context_len)
        self.relu = nn.ReLU()

        # --- two heads for start / end ---
        self.start_ff = nn.LazyLinear(self.max_context_len)
        self.end_ff   = nn.LazyLinear(self.max_context_len)

    def forward(self,
                context_question: torch.LongTensor,
                attention_mask_context_question: torch.BoolTensor,
               ) -> torch.Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          context_question: LongTensor of shape (batch, seq_len)
          attention_mask_context_question: BoolTensor of shape (batch, seq_len),
                          True for real tokens, False for padding.
        Returns:
          start_probs, end_probs: both FloatTensors of shape (batch, context_len+1)
        """
        bsz, seq_len = context_question.size()
        assert seq_len == self.seq_len, \
            f"expected seq_len={self.seq_len}, got {seq_len}"

        # --- embed tokens + positions ---
        pos = torch.arange(seq_len, device=context_question.device).unsqueeze(0)  # (1, seq_len)
        x = self.token_emb(context_question) * math.sqrt(self.d_model)
        x = x + self.pos_emb(pos)         # (batch, seq_len, d_model)

        # --- transformer wants (seq_len, batch, d_model) ---
        x = x.transpose(0,1)

        # build src_key_padding_mask for transformer
        # (batch, seq_len), True = mask out; so invert attention_mask
        kp_mask = None
        if attention_mask_context_question is not None:
            kp_mask = attention_mask_context_question == 0  # (batch, seq_len)

        # --- pass through N transformer encoder layers ---
        x = self.encoder(x, src_key_padding_mask=kp_mask)  # (seq_len, batch, d_model)
        x = x.transpose(0,1)                               # (batch, seq_len, d_model)

        # --- slice off only the context portion (including [OOV]) ---
        # context lives at the *end* of the sequence:
        context_repr = x[:, -self.max_context_len:, :]   # (batch, C+1, d_model)

        # --- flatten and feed through the “Concat & FeedForward” ---
        flat = context_repr.contiguous().view(bsz, -1)     # (batch, (C+1)*d_model)
        h = self.relu(self.fc1(flat))                   # (batch, C+1)

        # --- two linear heads  ---
        start_logits = self.start_ff(h)                    # (batch, C+1)
        end_logits = self.end_ff(h)                      # (batch, C+1)
        # print("Start logits shape:", start_logits.shape)
        # print("End logits shape:", end_logits.shape)

        # attention_mask_context = attention_mask_context_question[:, -self.max_context_len:]
        # # print("Attention mask context shape:", attention_mask_context.shape)
        # ctx_pad = attention_mask_context == 0 
        # start_logits = start_logits.masked_fill(ctx_pad, -1e9) # (batch, C+1)
        # end_logits = end_logits.masked_fill(ctx_pad, -1e9) # (batch, C+1)

        print("argmax start logits:", start_logits.argmax(dim=-1))
        print("argmax end logits:", end_logits.argmax(dim=-1))

        return start_logits, end_logits

In [59]:
import torch
import torch.nn as nn
import math

class TransformerQAModel3(nn.Module):
    def __init__(self,
                 vocab_size: int,
                 d_model: int,
                 num_layers: int,
                 num_heads: int,
                 dim_feedforward: int,
                 max_question_len: int,
                 max_context_len: int,
                 dropout: float = 0.1):
        """
        Args:
          vocab_size: size of your tokenizer vocabulary
          d_model: transformer hidden size
          num_layers: number of encoder blocks
          num_heads: number of attention heads
          dim_feedforward: inner dim of the transformer FFN
          max_question_len: maximum length of question (m)
          max_context_len: maximum length of context (n)
        """
        super().__init__()
        self.d_model = d_model
        self.max_context_len = max_context_len
        self.seq_len = max_question_len + max_context_len - 1

        # --- Embedding + Positional encoding ---
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(self.seq_len, d_model)

        # --- Transformer encoder ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu"
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- Dropout layer ---
        self.dropout = nn.Dropout(dropout)

        # --- Two linear heads for start / end (per token) ---
        self.start_ff = nn.Linear(d_model, 1)
        self.end_ff   = nn.Linear(d_model, 1)

    def forward(self,
                context_question: torch.LongTensor,
                attention_mask_context_question: torch.BoolTensor
               ) -> torch.Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          context_question: LongTensor of shape (batch, seq_len)
          attention_mask_context_question: BoolTensor of shape (batch, seq_len)
        Returns:
          start_logits, end_logits: FloatTensors of shape (batch, context_len+1)
        """
        bsz, seq_len = context_question.size()
        assert seq_len == self.seq_len, \
            f"expected seq_len={self.seq_len}, got {seq_len}"

        # --- Embedding ---
        pos = torch.arange(seq_len, device=context_question.device).unsqueeze(0)  # (1, seq_len)
        x = self.token_emb(context_question) * math.sqrt(self.d_model)
        x = x + self.pos_emb(pos)         # (batch, seq_len, d_model)

        x = self.dropout(x)
        x = x.transpose(0, 1)             # (seq_len, batch, d_model)

        # Invert attention mask for transformer: True = padding
        kp_mask = attention_mask_context_question == 0

        # --- Pass through transformer encoder ---
        x = self.encoder(x, src_key_padding_mask=kp_mask)  # (seq_len, batch, d_model)
        x = x.transpose(0, 1)                              # (batch, seq_len, d_model)

        # --- Extract context representations (including [OOV]) ---
        context_repr = x[:, -self.max_context_len:, :]     # (batch, C+1, d_model)
        context_repr = self.dropout(context_repr)

        # --- Token-level start/end logits ---
        start_logits = self.start_ff(context_repr).squeeze(-1)  # (batch, C+1)
        end_logits   = self.end_ff(context_repr).squeeze(-1)    # (batch, C+1)

        print("argmax start logits shape:", start_logits.argmax(dim=-1))
        print("argmax end logits shape:", end_logits.argmax(dim=-1))

        return start_logits, end_logits

In [60]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pad_idx=1

model = model = TransformerQAModel3(
    vocab_size=10_000,
    d_model=256,
    num_layers=6,
    num_heads=8,
    dim_feedforward=512,
    max_question_len=30,
    max_context_len=160,
    dropout=0
)
model = model.to(device)

# Setup an optimizer (e.g., Adam)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train_qa_context_model_boilerplate(
    model=model, 
    train_dataloader=train_dataloader, 
    val_dataloader=dev_dataloader, 
    optimizer=optimizer, 
    criterion=criterion, 
    num_epochs=10, 
    device=device, 
    inputs = ["context_question", "attention_mask_context_question"],
    evaluate_val_dataset=True,
)


Epoch 1/10:   0%|          | 0/313 [00:00<?, ?it/s]

argmax start logits shape: 

Epoch 1/10:   0%|          | 1/313 [00:00<02:05,  2.49it/s, loss=9.27]

tensor([  6,  54,  16,  30,  91,  95,  17,  68,  71,   2,  55, 107, 123,  78,
         25,  66,  19,  97,  71,  56,  34,  27,  89,  55, 100,  20,  65,   7,
         42,  94,   9,  33,  40,  18,  35,  53,  37,  33,  48,  60,  75,  87,
         36,  43,  11,  41,  77,  90,  45,  31,   6,  78,  34,  24,  69,  41,
         73,  51,  35, 114, 105,  23,  43,  28], device='cuda:0')
argmax end logits shape: tensor([ 69, 110,  78, 123, 139, 112,  12,   8, 137,  33,  98, 120, 129,  84,
        115,  82,  36,  78,  56, 134,  38, 109, 115,  90, 121,  89, 123,  66,
        103,  94,  60,  10,  74,  54,  93,  89,  75,  17,  95, 115, 112,  94,
         99, 145,  96, 109,  55, 159, 110, 126,  84,  42,  78, 104, 108, 109,
         76,  52,   8, 127,  69, 110, 137,  30], device='cuda:0')
argmax start logits shape: tensor([ 20,  42,  28,  99,  30,  71,  58,  36,   2,  16,  81,  81,  15,  24,
         92,   3,   9,  79,  65,  79,  38,  96,  11,  62,  43,  14,  30,  47,
         98,  12,  37, 116,  27,  28

Epoch 1/10:   2%|▏         | 5/313 [00:00<00:34,  9.05it/s, loss=8.91]

argmax start logits shape: tensor([ 32,   5,  72,   2,  52,  99,  84, 110,  40,  47, 107,   0,  32,  20,
          9,  75,  30,  81,  71,  42,   0,  93,  80, 122, 100,  72,  30, 134,
         63,  85,  65,   5,  77,   0,  38,  75,  71,  32,  59,  55,  49,  31,
         36,  14,  62,  43, 128,  94,  62,  20,  21, 118,  66,  35,  50,  12,
         98,  36,   0,  56,  44,  22,  16,  64], device='cuda:0')
argmax end logits shape: tensor([ 42,  20,  76,  40,  78,  19,  84,  70,  36, 107,   1, 112,  40,  17,
         81,  87,  18,  38,  54,  56,  13,  42,  89,   1,  91, 130,  34,  72,
         63,  98,  91,   2,  76,  37,  97,  75, 125,  48,   4,  69,  34,  88,
         10,  27,  27,  66,  30,  60,  44,  13,  21,  30,  62,  17,  59,  85,
          4,  34,  44,  43, 124,  11,  40,  48], device='cuda:0')
argmax start logits shape: tensor([  0,  50,   0,   0,  96,   0,   0,   0,  31,   0,   0,  12,   0,   0,
          0,   0,   0,  18,   4,   0,   0,   0,  22,   0,  69,  80,  42,   0,
         

Epoch 1/10:   3%|▎         | 9/313 [00:00<00:24, 12.17it/s, loss=8.87]

argmax start logits shape: tensor([ 0, 88,  0,  0,  0,  0,  0,  0, 62,  0,  0,  0,  0, 95,  0,  0, 47,  0,
         0,  0,  0,  0,  0,  0, 52,  0,  0,  0,  0, 43,  0,  0,  0,  0,  0,  6,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0, 25,  0,  0,  0, 40,  0], device='cuda:0')
argmax end logits shape: tensor([ 11,  34,  62,  68,  72,  38, 104,  17,  51, 143,  60,  58, 102,  43,
         28,  36,  27, 135,  76,  82,   1,  40,   2,  62,  32,  67,  81,  36,
         36,  58,  17,  29,  32, 115,  38,  45,  60,   2,  13,  52,  76,  39,
         80,   3,  21,  47,  80,  32,  13,  62,  35,  34,  17, 117,  95,  52,
         11,  28,  83,  45,  46,  96,  12, 104], device='cuda:0')
argmax start logits shape: tensor([ 81,  23,  36,  12,  17,  27,   0,  35,  19,   0,   5,   9,   4,  47,
          1,  26,  23,   5,  22, 110,   0,  19,  63,  32,   0,  47, 102, 109,
         35,   0,  48,   0,   0,   0,  34,   0,  88,   0,   0,  83,  44,  50,
   

Epoch 1/10:   4%|▍         | 13/313 [00:01<00:21, 13.70it/s, loss=8.98]

argmax start logits shape: tensor([ 51,  87,  54,  87,  71,  47,  83,  79,  58,  72,  71,  40,  49,  53,
         56,  99,  45,  67,   4,  16,  79,  62,  89,  83,  68,  33,   3,  52,
         58, 102, 107,  69,   4,  68,  52,  36, 101,  87,  33,  84,  86,  25,
         57,  97,  37,   1,  41,  52,  85,  30, 109,  72,  14,  23,   2,   2,
         20,  79,  97,   7,  20, 118,  62,  88], device='cuda:0')
argmax end logits shape: tensor([  9,  96,  18,  66,  62,  21,   1,   4,  18,  53, 111,  49,  72,  53,
         85,  40,  33,  97,  32,  43,  43,  46,   6,  42,  56,  85,  35,  52,
        105,  70, 107,  44,   4,  97,  52,  27, 101, 111,  10, 101,  69,   2,
         24, 104,  35, 129,   3,  12,  79,   8,  50,  10,   9,   9,  12, 119,
         40,  19,  97,   7,  36, 115,  66,  13], device='cuda:0')
argmax start logits shape: tensor([ 49,  68,  94,  43, 120,  65, 145,  58,  53,  64,  56,  46,  80,  80,
         52,  10, 102, 139,  24, 115,   3,  62,  57,   8,  37,  61,   3,  66,
         

Epoch 1/10:   5%|▍         | 15/313 [00:01<00:21, 13.85it/s, loss=8.74]

argmax start logits shape: tensor([ 52,   8,   7,  13,  73,  52, 153,  16,  49,  57,  53,  48,  57,  33,
          4,  33, 119,  82,   4,   2,   7,  30,  90,  50,  24,  79,  27,  45,
         88,  31,  72,  10,  62,  75,  15,  18,   2,  62,  59,  69,  33,  34,
         44,  38,  68,  47,  18,  39,  53, 144,  37,  36,  79,  57,  68,  32,
         20,  94,  40,  61,  19,  98,  21,  46], device='cuda:0')
argmax end logits shape: tensor([  2, 102,  64,  13,  73,  13, 116,  40,  50,  24,  60,  33,  44,  58,
         66,  55,  91,  38, 131,  29,  58,  34,  29,  11,  27,  78,   7,  25,
         64,  69,  14,  60,  42,  25,  15,  33,  15,  61,  14,  20,  25,  46,
          3,   9,  10,  47,  14,  57, 135,  88,  41,  78,  53,  43,  40,  16,
         48,  32,  46,  38,  37,  41,  25,  32], device='cuda:0')
argmax start logits shape: tensor([ 12,  67, 110,  81,  45,  20,  38,  41,   7,  33,  14,  78,  47,  44,
         18,  31,  10,  68,  27, 117,  70,  50,  47,  30,  25,  37,  27, 100,
         

Epoch 1/10:   6%|▌         | 19/313 [00:01<00:20, 14.39it/s, loss=8.89]

argmax start logits shape: tensor([ 15,   7,  37,  79,  67,  59,  10,  10,  62,  77,  36,  20,  82, 106,
         23, 111,  73,  40,   2,  56,  89,  31,  83,  53,   2,   2,  87, 107,
         82,  89,  10,  81,  17,  15,  75,  70,   6,  84,  19,  41, 110,   6,
          5,  25,  79,  61,  15,  10,  12,  26,  17,  19,  77,  42,  36,  42,
         42, 100,  33, 115,  68,  83,  98,  54], device='cuda:0')
argmax end logits shape: tensor([ 15,  99,  32,  79,  62,  18,  34,   2,  62,  70,  17,  15,  74,  40,
         23, 100,  47,  49,   2,  31,  37,  14,  19,  92,   4,  53, 118,  18,
          6,  22,  52,  18,  63,  15,  76,  59,   7,   9,  62,  48, 120,  21,
         83,   9, 115,  62,  46,  34,  12,  68,  15,  69,  55,  40,  33,  72,
         57,  59,  43,  91,  29,  23,  68,  54], device='cuda:0')
argmax start logits shape: tensor([ 20, 105, 117,  56,  15,  33,  20,  93,  62,  77,  45,  47,  52,  67,
         63,  93,  85,  34,  94,  32, 113,  21,   4,  40,   2,  60,   7,  23,
         

Epoch 1/10:   7%|▋         | 21/313 [00:01<00:20, 14.53it/s, loss=8.67]

tensor([ 19,  93,  31,  94,  76,  38,   0,   5,   1,  86, 100,  96,  36,   0,
          0,  33,   1,  44,  49,  38,  10,  37,  48, 108,  43,  84,  22,  69,
         14,  41,  11,  47, 101,  59,   0,  45,   0,   0,  86,  31,   0,  62,
          0,  89,   3,  31, 100,  10,  74,  96,   1,  34,   0,  65,  13,   0,
         16,  21,  52,   0,   0,  74,   0,  10], device='cuda:0')
argmax end logits shape: tensor([ 60,  32,  64,  15,   7,  32,  41,  16,  22,  69,  74,  84,  62,  50,
        125,  57,  77,  13, 118,  72,  20,  88,  85,  52,  52,   9,  36,  66,
         17,   4,  86,  64, 103,  41,  63,  56,  55, 104,   4,  37,  66,  15,
          3,   5,  81,  15,  39,  45,  70,  18,  14,  66,  80,  64,   3,  89,
         31,  36,  75,  63,  94,   3,  39,  33], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  87,   0,  79,   1,   0,   0, 100,  20,  33,   0,  72,  88,
        122,  33, 118,   0,  72, 122,   0,  39,  35,   0,  16,   0,  37,   0,
        135,  50,   0,   8, 108,  86

Epoch 1/10:   8%|▊         | 25/313 [00:02<00:19, 14.71it/s, loss=8.46]

argmax start logits shape: tensor([  0,   0,   0,   0,   0,   0,  93,   0,  15,   0,   0,   0,   0, 112,
          0,   0,   0,   0,   0,  95,   2,  35,   0,   0,  20,   0,  65,   0,
         80,  91,  22,   0,   0,   0,   0,   0,   0,  49,  87,   0,   3,   0,
          0,   0,   0,   0,  40,  83,   7,   0,  53,   0, 105,   0,   0,   0,
          0,  29,   0,   0,  86,   0,   0,  31], device='cuda:0')
argmax end logits shape: tensor([ 20, 135,  41,  38,  34, 111,  47,  21,  23,  41,  32,  17,  44,  15,
         73,  49,  60,  53,  13,  38,  23,  19, 108,  81,  94,  51, 100,  51,
        112,  40,  71, 106,  47,  68,  73,   2,  80,  18,  57,  26,  81, 116,
         36,  71,  29,  67,  90,  45,  90, 101,  38,  34,  50,  28,  63,  86,
         66,  38,   7,  21,  53,  35,  63,  59], device='cuda:0')
argmax start logits shape: tensor([  0,  90,   0,  27,   0,   0,   0,   0,   5,   0,  79,   0,   0,   0,
          0,  70,  43,   7,   0,   3,   0,   0,   0,   0,   0,   0,   0,   0,
         

Epoch 1/10:   9%|▊         | 27/313 [00:02<00:19, 14.76it/s, loss=8.72]

argmax start logits shape: tensor([  0,  72,  27,   0,  94,   0,   0,   0,   0,  10,   0,   0,  86,   0,
          0,   0,   0,  34,   0,   0,   0,   0,   4, 115,  10,   0,   0,   0,
         94,   0,   0,  41,   0,   0,   0,  36,   0,  30,  37,  39,  36,  41,
          0,   0,  50,  79,   0,  54,   0,   0,   0,   0,  54,  15,   0,   0,
          0,   0,  55,  16,   0,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 57,  58,   4,  14, 110,  70,  59,  14,  54,  98,  27,  67,  83, 108,
         12,  85,  63,   2,  49,  84, 126,  47,  40,  91,  47,  99, 107,  80,
        117,  23,  95, 121,  13,  56,  87,  55,  74,  69,  57,  60,  32,  50,
        108, 107, 139,  28, 134,  51,   9,  32,  21, 102,  65,  17,   4,  46,
         45,  22,  58,  73,  47, 106,  20,  44], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  28,   0,   0,   0,   0,   0,   0,   0,   0,   0,  88,  14,
          0,  71,   0,   0,   0,   5,   0,   0,   0,  46,   0,  25,   0,   0,
         

Epoch 1/10:   9%|▉         | 29/313 [00:02<00:19, 14.70it/s, loss=8.65]

argmax start logits shape: tensor([76, 16,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0, 80,  0,  0,  0,  0,  0,  0,  0, 99,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 36,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0')
argmax end logits shape: tensor([ 16,  62,  49,  10,  13,  29,  35,  49,  51,  20,  47,  21,  45,  27,
         70,  42, 109,  53,  35,  56, 111,  20,  27,  16,  29,   8,  23,  24,
         91,   2,   3,  23,  11,  52,  99,  53,   7,   6,  30,  18,  80,  71,
         29,  62,  39,  91,  15,  93,  55,  98,  23,  75,  60,  58,  35,  31,
         63,  79,  33,  60,  33,  32,  31,  16], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  28,
          0,   0,   0,   0,   0, 107,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   

Epoch 1/10:  11%|█         | 33/313 [00:02<00:19, 14.54it/s, loss=8.76]

argmax start logits shape: tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0')
argmax end logits shape: tensor([ 92,  12,  39, 133,  13,  75,  23,  24, 117,  73,   5,  36,  81,  44,
        118,  14,  51,  67,  24,  66,  88,  26,  59,  85,  65,  49,  28,  31,
          7,  16,  59,  83,  52,   5,  53,  21,  24,  62,  60,  29,  77,  26,
         43,  51,  40,  88,  70,   2,  59,  46,  24,  41,  11,  13,  35,  24,
         10,  46,  16, 119,  38,  81,  33,  16], device='cuda:0')
argmax start logits shape: tensor([ 0,  0, 11,  0, 99,  0,  0,  0,  0,  0,  0,  0,  0,  0, 27,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 27,  0,  0,  0,  0,  0,  0,  0,
         0,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  

Epoch 1/10:  11%|█         | 35/313 [00:02<00:19, 14.54it/s, loss=8.6] 

argmax start logits shape: tensor([  0,   0,  10,  35,   0,   0,   0,   0, 132,   0,  12,   0,   0,   0,
          0,   0,   0,   0,  63,   0,  40,   0,  99,   0,   0,   0,  43,   0,
         32,   0,   0,   0,  56,   0,   0,   0,   0,   0,   0,   3,  88,   0,
          0,  39,   0, 112,   0,  77,   0,  28,  81,   0,   0,  34,   0,  56,
          0, 104,   4,   0,   0,  31,   0,  27], device='cuda:0')
argmax end logits shape: tensor([ 53,  82,  14,  45,  23,  23,  40, 115,  47,  83,  33,  87,  34, 107,
         89,  11,  42,  44,  23,  58,  13,  86,  54,   7,  84,   4,  54,  64,
         50, 111,  53,   2,  81,  73,  73,  90,  46,  19,  68,  45,  43,  85,
         70,   2,  23, 100,  38,  65,  31, 124,  16,  69,  95,  58,  49,   6,
         83,  34,   2,  29,  20,  85,  12, 138], device='cuda:0')
argmax start logits shape: tensor([ 14,   0, 109,   0,   0,   0,   0,   3,   7,   0,   0,   0,  87,   0,
          0,   0,   0,  31,   0,  34,  83,  63,   0,   0,   0,  28,  69,   0,
         

Epoch 1/10:  12%|█▏        | 39/313 [00:03<00:18, 14.78it/s, loss=8.66]

argmax start logits shape: tensor([  0,  70,   0,  52,  40, 105,   0,  36,   0, 108,   0,   0,  47,  73,
          0,   0,   0,  20,   0, 121,   0,   0, 100,  66,   0,   0,   6,   0,
         66,  37,  94,   0,  54,  29,   0,   0,   6,  91,   0,  71,   0,  39,
          0,   0,   0, 140,  34,  91,  81,  74,  66,  91,   0,   0,   0,   0,
        151,   0,   0,   0,  92,  34, 100,  38], device='cuda:0')
argmax end logits shape: tensor([ 54,   8,  21,  89,  58,  42,  46,  50,   2,  42,  78,  65, 111,  55,
         23,  18,  57,  33,  52,   9, 100,  99,  92, 105,   8,  19,  61,  84,
         75,  38, 104,  58,   7,  45,   9,  51,  63,  70,  35,  36,  15,  99,
         28,  48,  75,  29,  58, 102,  84,  84,   2,  53,  53,  16,  24,  24,
        134,  12,  52,  47,  60,  15,  25, 111], device='cuda:0')
argmax start logits shape: tensor([  0,  44,  83,   0,   0,   0,  83,   0,   0,   0,   0,   0,   0,  22,
          0,   3,  78,   0,  95,  93,   0,   0,  77, 113,   0,  33,   7,   0,
         

Epoch 1/10:  14%|█▎        | 43/313 [00:03<00:18, 14.88it/s, loss=8.71]

argmax start logits shape: tensor([  0, 117,   0,   0,   0,  31,  27,  52,   0,  23,   2,   0,   0,   0,
          0,  12,  28,  46,  15,   0,  56,   0,   0, 101,   0,   0,  22,   0,
          0,  35,  45,  21,  75,   0,   0,  59,   0,   0,  35,  10,   0,  79,
         26,  40, 108,   0,   0,  52,   0,   0,   0,   0,   0,  41,   0,   0,
          0,  43,   0,   0,  62,  72,  99,  40], device='cuda:0')
argmax end logits shape: tensor([  4, 130,  34,  27,  94,  66,  42,  45,   7,  81,  90,  47,  62,  51,
         46,  71,  44,  31,  60, 129,  39,  31,  63,  52,  32,  29,  83,   5,
         55, 117,  16,   3,  56,  20, 106,  88,  12,  52,  31,  53, 102,  47,
         35,  46,  17,  66,  18,  89,  64,   6,   4,  67,  34,  32,   8,  50,
         95,  18,  45,  44,  48,  94, 121,  37], device='cuda:0')
argmax start logits shape: tensor([ 58,   0,  29,   0,  25,   0,  19,   0,  37,   0,  53,   0,  66,   0,
         51,   0,  76,  40,   0,   0,   0,   0,   0,   0,   0,   0,  23,   0,
         

Epoch 1/10:  14%|█▍        | 45/313 [00:03<00:18, 14.67it/s, loss=8.72]

argmax start logits shape: tensor([ 24,   0,  85,  68,  39,   0,   0,  15,   0,  32,  32, 102,   0,   0,
          0,  51,   0,   0,  52,  15,   0,  91,   0,   0,   0,  50,   0,   0,
         23,  32,   0, 107,   0,  53,   0,  67,  95,  79,  15,   0,  30,   3,
          0,  17,  54,   0,   0,   3,   0,  33,   0,   0,   0,  44,  52,  90,
          0,  58,  91,   0,  26,   0,  63,   0], device='cuda:0')
argmax end logits shape: tensor([ 52,  33,   3,  99, 102,  66,  27,  83,  32,  12, 101,  44,   6,  63,
        143,  90,  83,  54,  66,  50,  79,  10,  63,  72,   6,  68,  40,  43,
         31,  13,  31,  36,  56,  25,  22, 100,  48,  25, 132,  13, 108, 108,
         81,  50,  79,  53,  98,  55,  15,  36,  25,  56,  59,  26,  36,  47,
         41,   6,  41,   9,  17,  93,  25, 122], device='cuda:0')
argmax start logits shape: tensor([ 0,  5,  0, 67,  0,  0,  8,  0,  0,  0, 68,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0, 28,  8,  0,  0,  0,  0, 14,  0,  0,  0,  0,  0,  0,  0, 54,
     

Epoch 1/10:  16%|█▌        | 49/313 [00:03<00:17, 14.88it/s, loss=8.52]

argmax start logits shape: tensor([ 21,   0,   0,  66,   0,   0,   0,  55,   0,  56,   0,   0,   0, 110,
          0,   0,   0,   0,  59,   0,   0,   0,   0,   0,  85,  22,  96, 119,
          0,  70,   0,   0,   0,   0,   0,   0,  27,   0,  49,  68,   0,   0,
          0,   0,   0,   0,  78,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,  24,  36,   0], device='cuda:0')
argmax end logits shape: tensor([ 38,  30,  47,  10,  81,  37,  60,  71,  39,  49,  95,  45,  80, 113,
         57,  80,  46,  74,   7,  37,  98,  36, 106,  85,   3,  14,  64,  46,
         78,  32,  52,  91,  65,  35,  38,   7,  41,  42,  38,  58,   2,  97,
         23,  45,   1,  85,  37,  85,  31,  65,  56,  15,  33,  21,  30, 111,
          6,  14,  44,  27,  45,   4,  31,   8], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0,   0,   0,   7,   0,  26,   0,   0,  57,   0,   0,
         13,   6,   0,   0,   0,  50,   0,   0,   0,  20,   0,  68,   0,   0,
        1

Epoch 1/10:  16%|█▋        | 51/313 [00:03<00:17, 14.80it/s, loss=8.5] 

argmax start logits shape: tensor([ 73,   0,  20,   0,  26,   0,   0,  25,  20,   0,  41,  19,   6,  33,
         24,   3,  15, 101,   0,   0,  61,   0,   0,   0,  34,   0,  42,  56,
          0,   0, 108,  52,   0,   0,   0,   0,   0,   9,   0,   0,   0,   0,
        125,   0,  74,  38,  46,   0,  28,  15,   0,  15,   0,  52,   0,  21,
         69,  33,   0,  41, 149,   0,   0, 116], device='cuda:0')
argmax end logits shape: tensor([ 15,  41,  19, 100,  90,  59,  89,  87, 108,  36,  17,   8,  25,  63,
         14,  79,  16,   5,  63,  84,   2,  28,  57,  76,  49,  25,  53, 103,
         53,   3,  82,  60,  65,  34,  16,  70,   3, 104,  16,  11,  72,   5,
         47, 133,   2,  77,  93,  88,   6,   8,  96,  81,  14,  18,  87,  18,
         32,  63,   5,  94, 140,  51, 103,  34], device='cuda:0')
argmax start logits shape: tensor([  0, 101,   0,  71,  26,   0, 105,   0,  29,   0,  36,   7,   0,   0,
         93,  83,   0,   1,   0,   0,  75,   0,   0,   0,   0,   0,  87,   0,
         

Epoch 1/10:  18%|█▊        | 55/313 [00:04<00:17, 14.80it/s, loss=8.64]

argmax start logits shape: tensor([  0,   0, 135,   0,  23,   0,   0,   0,   0,  81,   0,   0,  10,  29,
         56,  27,   0,   0,  79,   0,  94,  80,   0,  50,  81, 122,   0,  20,
          0,  11,  61,   0,  11,   0,   0,  58,   0,   0,  41,   0,   0, 127,
          0,  84,  14,   0,   0, 105,   0,   0, 120,   0, 121,  58,  10,   0,
          0,   0,   0,  55,  27,  17,   0,   0], device='cuda:0')
argmax end logits shape: tensor([100,  13,  79,   7,  81,  12,  37,  19,  42,  63,  41,  41,  75,  65,
         48,  58,  47,  27,  10,  95,  12,  12,  10,  38,  17,  42,  21,  50,
          3,  12,  60,  54,  58,  76,  76,  68,  22,  11,  21,  30, 113,  35,
         72,  50,  71,  74,  52,  15, 106,  87,  18,  20,   2,  61, 117,  22,
         86,   2,  42,  48,   9,  45,  17,   8], device='cuda:0')
argmax start logits shape: tensor([ 81,  39,  65,  25,   0,   0,  49,   0,   0,  15,  62, 115,  79,  37,
          0,  45,   0,  59,   0,   0,   0,  50,  92,  68,   0,   0,   0,   0,
         

Epoch 1/10:  18%|█▊        | 57/313 [00:04<00:17, 14.70it/s, loss=8.45]

argmax start logits shape: tensor([ 13,  78,  48,  40,   4,  28,  52,  35,  58,  71, 103,  41,  54,  27,
         95,  11,   8,  94, 118,  25,   0,  49,  49,  54,  25, 121,  80, 104,
         88,   2,  11,  69,   3,  27,  63,  58,  97, 102, 114,  44,   2,  34,
          0,  41,  40,  62,  42,  91,  66,  69,  47,  21,   0,   4,  29,  39,
        125,  26,  49,   0,  58,   0,  64,  27], device='cuda:0')
argmax end logits shape: tensor([ 72,  71,  26,  59,  21,  83,  60, 113,  30,  78,  10,  30,  36,  58,
         72,  48,  40,  84,  93,  37,  41,  55,  41,  82,  11, 105,   8,  77,
         46,   3,  47, 117,  41,  33,  13,  30,  81,  97,  95,  47,  38,  75,
         74,  50,  51, 110,   2,  81, 111,  79,  31,  20,  27,   5, 102,   4,
        114, 110,  27,  71,  28,  51, 100,  80], device='cuda:0')
argmax start logits shape: tensor([  5,  29, 103,   0, 111,  58, 125, 102, 108, 103,  86,  82,   4,  62,
         75,  19,  85,  43,   0,  41,  55,   6,  60,  57,   0, 100,   0,  77,
         

Epoch 1/10:  19%|█▉        | 61/313 [00:04<00:17, 14.82it/s, loss=8.53]

argmax start logits shape: tensor([ 27,  17,   0,   0,  58,   0,  54,  79,  22,   0,   0,  14, 113,   0,
          0,   0,   0,   0,  41,  19, 131,   0,  82,  72,  71,  33,   0,   0,
         15,  24,   0,   0,   0,  90,   0,  68,   0,   0,   0,  72,   0,   0,
          0,  75,   1,  81,  24,   0,   0,  33,  56,  65,   0,  25,   6,  26,
          0,  35,   0,   0,   0,  20,  52,   0], device='cuda:0')
argmax end logits shape: tensor([  3, 135,  80,  23,  60,  50, 119,  33,  43,  27,   2,  15,   2,  35,
         86,  86,  20,  65,  99,  10,  83,  29, 105,  17,  27,   5,  15,  37,
         20,  36,  48,  45,   6,  63,  34, 112, 114,  31,  60,  78,  26,  25,
        103,  27,  73, 114,   8,  27,  61,   8,  48,  39,  37,  12,  10, 110,
        113,  81,  46, 143,  54,  88,  55,  84], device='cuda:0')
argmax start logits shape: tensor([ 85,   0,  63,   9,  59,  27,  18,  17,   0,  69,   0,  94,  25,  51,
         97,   4,   0,   0,   0,   0,   0, 107,  32,  24,   0,   0,   0,  50,
         

Epoch 1/10:  21%|██        | 65/313 [00:04<00:16, 14.98it/s, loss=8.52]

argmax start logits shape: tensor([  0,   0,  14,  48,  30,   9, 105,   0,  51,   0,   0,   0,  65,  85,
         13,  28,   8,  21,   0,  81,   0,   8,   0,   0,  63,  82,  11,  68,
         50,   2,  46,   8,   0,  22,  42,   0,  41,  35,  75,  22, 142, 100,
         17,  14,  60,  18,   0,  83,   0,  14,  70,   8,   0,  61, 116,  60,
        122,  70, 123,  41,  18, 112,  30,   9], device='cuda:0')
argmax end logits shape: tensor([ 56,  27,  33,  51,  59,  47,  15,   8,  10,  65,  37,  52,  98, 116,
         15,  15,  28,  59,  27,  11,  21, 136,  10,  38,  90, 117,  97,  50,
        110,  36,  94,  52,  32,  22,   9,  17, 124,  27,  21,  18,  16,  31,
         33, 121,  37,  83,  39,  37,  56,  37,  89, 101, 116,  62,  48,  25,
         47, 104,  99,  66,  58,  25,  70,  43], device='cuda:0')
argmax start logits shape: tensor([ 26,  52,  53,  70,   0,  33,   0,  52,   6,  20,   2,  26,   0,   0,
          2,   0,  31,   0,  25,  50,  31,  35,  18,  10,  45,  67,  65,  28,
         

Epoch 1/10:  21%|██▏       | 67/313 [00:04<00:16, 14.97it/s, loss=8.58]

argmax start logits shape: tensor([ 53,   0,   0,  48,   0, 134,   0,   7, 105, 107,   0,  16,   0,  24,
        106,  36,  17,  19,  20,  48,  32,   0,   0,  19,   6,   7,   0,  15,
          0,  35,  33,   0,  75,   0,  86,   0,   0,  43,   0,  58,   0,   0,
          0,   0,  18,  23,  28,  11,  14,   0,   0,   0,   0,   0,   7,  49,
          0,   0,   0,   0,   0,  29,   0,   0], device='cuda:0')
argmax end logits shape: tensor([118,  14,   9,  44,  11,  81,  64,  82,  41, 110,  51,  72,  80,  25,
         81,  51,  30, 102, 125,  16,  66,  12,  14,  18,  77,  51,  82,   2,
         77,   6,  63,  64,  53,   5,  36,  73,  91,  11,  96,  69,  42, 129,
         95,   2,  39,  35, 111,  43,  15,  47,  80,  75,  44,  36,  67,  14,
          2, 110,  64,  91,  47,   7, 112,  96], device='cuda:0')
argmax start logits shape: tensor([102,   0,  44,   0,  50,   0,  62,   0,  49,   0,  33,   0,   0,  21,
          0,   0,   0, 107,   0,  88,   0,  20,   0,  19,   0,  15,   0,   0,
         

Epoch 1/10:  22%|██▏       | 69/313 [00:05<00:16, 14.93it/s, loss=8.4] 

argmax start logits shape: tensor([ 35,  45,  17, 111,   0, 113,  35,   5,   7,  15,  57,  75,  56,   4,
         38,  54,  18,   0,  43,  52,   0,  22,  23,  69,  52,  20,   0,  87,
         78,  58,  25,  13,  90,  32,  33,   0,  31,  42,   0,  13,  27,  34,
         69,  40,  34,  40,  47,  52,  24,  39,  33,  52,   0,   0,   2,  28,
         16,   4,  86,  34,  50,  30,  14,  34], device='cuda:0')
argmax end logits shape: tensor([ 45,  60,  15,  78,  46,  77,  28,  75,  73,   3,  94,  83,  42,  41,
          7,  32, 109,  21,  13,  17,  15,  56,  55,  12,   2,  50,  22,  47,
         12,  96,  61,  43,  31,  76,  34,  42,  76,  29,  55,  37,  38,  67,
         52,  32, 112,  58, 100,  30,  80,  83,   8,  50,  27,  35,   4,  38,
         69,  94, 113,  51,  71,  34,  57,  49], device='cuda:0')
argmax start logits shape: tensor([ 45,   6,   4,  88,  88,   9,  49,  48,  28,   8,  49,  62,  47,   0,
         53,  67,  24,  27,  24, 109,  12, 109,  29,   0, 113,  68,  92,  69,
         

Epoch 1/10:  23%|██▎       | 73/313 [00:05<00:15, 15.01it/s, loss=8.83]

argmax start logits shape: tensor([ 62, 119, 100,  19,  15,  40, 135,  13,  98,   2,  30,  62,  86,  44,
         54,  69,  10,  51,  49,  99,  64,  24,   6,  26, 111,  49,  31,  33,
         49,  90,  27,  13,  12,  65,   3, 115,  49,  28,   2,  58,  34, 117,
         65,  19,  50,  68,  86,  84, 113,  52,  51,  34,  32,  48,   6, 135,
        116,   5,  53,  83,  15,  88,  55,  77], device='cuda:0')
argmax end logits shape: tensor([  3,  86, 118,  36,  81,  52,  58,  22,  36,  29,  59,  18,  64,  56,
         42,   3,  27, 146,  14,  52,  34,  70,  64,  65,  20,  11,  76,   3,
         72,  76,  42,   7,  80,  93,  49, 129,  59,  19,  31,  72,  14, 121,
         51,  40,  21,  42,  48,  80,  47, 106,  47,  44,   5,  16,  56, 109,
         54, 140,  98,  83,  18,  75,  22, 113], device='cuda:0')
argmax start logits shape: tensor([ 14,  28,  92, 106,  55,  14,  39,  29,  97,  16,  23,  65,  93,  87,
         69,  23,  60,  42,  75,   2,  50,  17,  86,  35,   5,   5,  38,  22,
         

Epoch 1/10:  24%|██▍       | 75/313 [00:05<00:15, 14.93it/s, loss=8.53]

argmax start logits shape: tensor([  0,   0,  16,   0,  65,   0, 132,  90,   0,  85,   0,  42,   0,  71,
          0,  85,  63,   0,   0,  39,  85,   0,   4,   0,   0,  35,   0,  42,
          0,  79,   8,  60,   0,   0, 100,   0,   0,   0,  81,  16,  42,   0,
          0,   0, 128,   5,  11,   0,   0,   0,   0,   0,   0,  77,  13, 100,
          0,  37,  32,  33,   2,  10,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 54,   8,   5,  32,   3,  60,  30, 119,   7,  71,  34,  53,  36,  62,
         31,  86,  89,  37,  43,  20, 102,  10,  99,  32, 114,  30,  24, 107,
         12,  63, 134,   5,  51,  19,  91,  64,  12,  11, 107,  47,  49,  49,
         88,  34, 107,  89,  47,  36,  36,  47,  50,  34,  57,  46,   5,   6,
         47, 108,  86,  30,  63,  81,  69,   6], device='cuda:0')
argmax start logits shape: tensor([ 11,   0,  33, 130,  71,   0,   0,   0,   0,  23,   0,   0,   0,  67,
          0,  53,   9,  75,  65,  30,  75,   0,  96,  25,  96,   0,   0,  65,
         

Epoch 1/10:  25%|██▌       | 79/313 [00:05<00:15, 14.95it/s, loss=8.72]

argmax start logits shape: tensor([  0,   0,   0,  56,  62,   0,   0,   0,   0,  46,   0,   0,   0,  53,
          0, 108,   0,   0,   0,   0,   0,  91,   0,   0,   0,   0,   0,   0,
         31,   0,   0,   0,   0,  50,   0,  33,   0,  36,   0,  64,   0,   0,
          0,   0,  30,   0,  54,   0,   0,   0,   0,  44,  25,   0,  29,   0,
          0,   0,   0,   0,  10,  16,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 42,  37,  67,  81,   9,  96,  23, 136,  84,  37,   9,   3,  45,  57,
         60,  15,  30,  81,  11,  24,  33,  78,  66,  66,  81,  98,   8, 103,
        117,  57,  20,  58,  50, 100,  43,  37,  20,  41,  72, 125,  97,   3,
          7,  37,  49,  36,  59,  36,  68,  26,  14,  81,  80,  79,  83,  35,
         51,   5,  39,  28,  94,  55,  72,  31], device='cuda:0')
argmax start logits shape: tensor([ 31,   0,  81,   0,  98, 100,  21,  48,   0,  28,  15,  11, 125,  25,
          0,  94,   0,  38,   0,   0,  35,   0,   0,  68,   0,   0,  49,   0,
         

Epoch 1/10:  26%|██▌       | 81/313 [00:05<00:15, 14.80it/s, loss=8.41]

argmax start logits shape: tensor([120,  53,  71,   2, 102,   0,  56,  24,  14,  30,  31,  50,  75,   5,
         15,  35,  16,   0,  74, 103,  43,  34,  91,  40,  97,  49,  85,  10,
         26,  27,  43,  82,  89,   5,  30,   1,  23,  96,  13, 101,  36, 142,
         56,  92, 118,  20,  54, 107,  77,   2,   4,   4,  94,   0,  49,  58,
         38,   0, 126,  77, 123, 102,   8,  61], device='cuda:0')
argmax end logits shape: tensor([ 48,  55,  76,  60,  31,  12, 105,  45,  32,  26,  87,  10,  63,  54,
         50,  39,  92,  24,  66,  74, 106,  52,  44,  62,  35,  42,  58,  32,
         81,  14,  33, 116,  51,  66,  74,  84,  77, 104,  96,  67,  52,  28,
         74,  83,  56,   6,  59,   2,  34,  33,  32,  44,  71,  16,   4,  56,
         67,  23, 111,  19,  50,  87,  15,  22], device='cuda:0')
argmax start logits shape: tensor([ 65,  28,  10,  16,  43,  45,  46,  63,   6,  31,  58,   8,  89,  81,
         88,  73,  58,  35,   6,  49,  24,  64,   0,  76,  20,  63,  32,  35,
         

Epoch 1/10:  27%|██▋       | 85/313 [00:06<00:15, 14.82it/s, loss=8.3] 

argmax start logits shape: tensor([ 65,  38,  60,  40,  33,  34,   0,  85,  28,  29,  56,  97,  96,  99,
         14,  86,  54,  50,   2,  31,  95,  12,   0,  36,  66,  96,  37, 111,
          2,   7,  15,  92, 115,  52,  78,   3,  65,  70,  10,  36,  49,  85,
          0,  23,  33,  34,  90,  21,  72,  64,  56,   4,  15,  53,  85,   2,
         33,  26,  61,  19,  28,  17, 108,  42], device='cuda:0')
argmax end logits shape: tensor([118,  60, 104,  44,  53,  48,  56,  71,  47,  48,  20,  44,  54,  12,
         15,   3, 133,  49,  27,   7,  48,  57, 110,  16,  15,  18,   3,  89,
         27,  64,  45,   9,  91,  43,  81,  21, 105,  58,   3,  87,  54,  71,
         22,  24,  20,   3, 115,  96,  46, 100,  42,  26,  76, 101,  47,  30,
         36,  58,  18,  59, 115,  14, 104,  43], device='cuda:0')
argmax start logits shape: tensor([ 29,   0,  57,  27,   9,  49,  28,  22,  62,   2,   0,   0,  81,  53,
         19,  68,  66, 112,   0,   7,  19,  63,  25, 151,  98, 105,  62,  19,
         

Epoch 1/10:  28%|██▊       | 89/313 [00:06<00:15, 14.90it/s, loss=8.7] 

argmax start logits shape: tensor([ 88,   0,  33,   0,  79,  17,   0,  67,  66,   0,   0, 112,  23,  81,
          0,   9,   0,   0,  20,   0,  15,   0, 116,   0,  22,  80,  85, 105,
         16,  36,  53,  37,   0,   0,   0,  52,   0,  42,  13,  51,  80,   0,
         29,   2,  17,   0,   0,  17,   0,   0,   0,   0,   0,   0,  14, 123,
         54,  54,   0,   0,  92,   0,  83,   0], device='cuda:0')
argmax end logits shape: tensor([ 23,   8,  31,   4,  42,   7,  51, 106,  64,   6,  21,  81,  86,  89,
         87,   3,   4,  79,   9,  48,  46, 109, 117,  41,  12,  66,  52,  50,
          5,   6,  21,  35,  40,  26,  64,  89,  55,  16,  32,  78,  44,  30,
         83, 103,  23,  88,  46,  14,   5,  14,  27,   7,  92,  29,  47,  27,
         59,  55,  65,  43, 123,  61,  68,  15], device='cuda:0')
argmax start logits shape: tensor([ 23,  25,  16, 105,  46,  85,  21,   0,  99,   0,  56, 108,  91,   0,
          0,   4,   0,  54,  34,   0,  81,   0,  62,  70,  90,   0,   0,   0,
         

Epoch 1/10:  30%|██▉       | 93/313 [00:06<00:14, 15.03it/s, loss=8.46]

argmax start logits shape: tensor([ 11,  53,   0,  63, 127,   0,  23,   0,   0, 137,   0,  12,   0,   0,
          0,   0,  42,  73,  25,   0,   0,   0,   0, 112,  78,  44,  15,   1,
          0,  13,   0,  22,   0,   0,   0,   0,   0,   0,  88,   0,  15,  77,
         12,   0,   0,   0,  91,  67,   0,   0,   0,  42,   0,  71,  19,   0,
          0,  23,  70,  49,   0,   9,  27,   0], device='cuda:0')
argmax end logits shape: tensor([ 10,  35,  74,  18, 128,  45,   8,  59,  31, 117,   4,  36,  17,  80,
         23,  91, 148,   7,  54,  25,  40,  79,  41,  61,  49,  43,  90,  32,
         86,  14,   2, 107, 100,  36,  32,   3,  50,  27,   3,  87,  65,  58,
         44,  22,  23,  67,  16,  45,  32,  21,  78,  16,  88,  51,  41,  52,
         64,  36,  19,  94,  28,  59,  30,  85], device='cuda:0')
argmax start logits shape: tensor([100,   0,  85, 125,   0,  60,   0,   0,   0,   0,   0,  18,   6, 142,
         17,   0,  11,  48,  47,   0,  76,   0,  85,   0,   0, 127,  85,   0,
         

Epoch 1/10:  30%|███       | 95/313 [00:06<00:14, 14.96it/s, loss=8.57]

argmax start logits shape: tensor([ 21,  38,  41,  50,  34,   0,  58,   0,   0,   0,   0,   0,   9,   0,
          8,   0,   0,  71,  20,   0,   4,  18,   0,   0,   1,  56,   0,   0,
         67,   0,  67,   0,   0,   0,   0,   0,  42,   6,   0,   0,   0,   0,
        123,   0,   0,  27,   0,  19,   0,   0,  78,   0,   0,   8,   0, 106,
          0,  51, 103,   0,   0,  27,  20,  49], device='cuda:0')
argmax end logits shape: tensor([ 23,  57,  27,  69,  23,  32,  38,  75, 111,  46,   3,  52,  34,  39,
         78, 143,  30,  85, 109,  24,  72,  77,  16, 100,  10,  49,   3,  32,
         86,  65,  71,  93,   9,  17,  67,  87, 117, 110,  67,  31,   4,  63,
        117,  60,  45,  41,  56,  55,  67, 101,  65,   2,  49,  71,  83,  83,
         66,  38,  25,  33,  38,  33,  68,  37], device='cuda:0')
argmax start logits shape: tensor([  0, 105,   9,   0,   0,  36,  77,  48,  92,  23,   0,   0,  22,   0,
         11,  20,   0,  10,  50,  39,   0,   0,   1,  42,  46,  44,   4,   0,
         

Epoch 1/10:  31%|███       | 97/313 [00:06<00:14, 14.88it/s, loss=8.4] 

argmax start logits shape: tensor([ 25,  65,  58,  85,   0,  31,  77,   0,   0,  99,  65,  63,  96,  92,
         70,  24, 109,  40,  57,  51,  32,  34,  20,   0,  93,  98,  18,   0,
         93,  54, 110,  34,   9,  94,  17,  72,  36,  41,   0,   0,  36,  53,
         73,  56,   0,  45,  50, 132,  79,   0,  40,   0,   0,  75,   3,  50,
        107,   0,   0,  46,  36, 107,  78,  37], device='cuda:0')
argmax end logits shape: tensor([ 92,  54,  51,  32,  52,   2,  74,  36,  34,  16, 138,  60,  59,  16,
         75, 106,  12,  39,  38,  67,  51,  66, 127,  82,  92,  26,  67, 116,
         60,  82, 120,  45,  87,  66,  73,  46,  27,  55,  32,  24,  32,  97,
         17,  57,   9,  28,  48,  14,  10,  17,   2,  60,  58,  88,  22,  59,
         77,  66,   9,   1,  59,  26,  97,   8], device='cuda:0')
argmax start logits shape: tensor([ 26,  27,  44,  82,  61,  99,  79,  19,   7,  69, 101,  95,  80,  61,
         65, 100,  50,   0,  22,  92,   8, 110,   8,  45,  54, 121,  42,  10,
         

Epoch 1/10:  32%|███▏      | 101/313 [00:07<00:14, 14.67it/s, loss=8.35]

argmax start logits shape: tensor([ 69,  46,  82, 152,  21,  57,  17,  69, 112,  49,  57, 109,  98,  13,
         65,   0,  85,  14,  79,  49,  72,  20,  70,  44,  96,  34,  23,  75,
         35,  10,  33,  25,  10,  27,  20,  53,  36,  30,  33,  23,  31,  90,
         23,  14,  13,  83,  72,  19,  84,  90,  20,   6,  54,  45,   5,   9,
         75,  47,  37,  27,  30,  26,   2,  44], device='cuda:0')
argmax end logits shape: tensor([ 80,   9,  85,  63,  41,  71,   7,  95,   4,  76,  83, 113,  49,   5,
         85,   5,  80,  78,  28, 103,  11,  81,  43,  17,  27,  10,  10,  90,
         13,  67,  44, 113,  25,   6,  33,  45,  35,  17,  81,  36, 104,  69,
         56,  15,   4,  40,  72, 118,  93,  17,  17,  16,  16,  62,  83,   2,
         22,  56,  65, 119,  17,  11, 137,  34], device='cuda:0')
argmax start logits shape: tensor([ 47,   9,  39,  62,  71,   2,  62, 107,  54,  37,  36,  90,  23,   2,
          7,  50,  99,  45,  71, 143,  67,  24,  52,  45,  14,  52, 107,  17,
         

Epoch 1/10:  34%|███▎      | 105/313 [00:07<00:14, 14.85it/s, loss=8.48]

argmax start logits shape: tensor([ 53,  50,  53, 112,  15,  20,  56,  58,  49,  61,  24,  46,  85,  42,
         75,  43,   4,  22,  85,  35,  65,  10, 142, 117,   9,  76,  13, 104,
         66, 107,  13,  50, 107,   3,  13,  91,  91,  43,  23,  55,  85,  86,
         54,  93, 124,  49, 120,  30,  76,  92,  32,  77,  35, 116,   4,  85,
          8,  32, 106, 103,  39,  44,  85,  49], device='cuda:0')
argmax end logits shape: tensor([ 79,  64, 101, 113,  31,  64, 107,  47,  30,  52,  38,   9,  20,  71,
         33,  34,  21,  74,  83,  61,  81,  51,  33,  55,   5,  25,  35,  83,
         54,  28,  37, 113, 123,  72,  24,  37, 138,  75, 112,  10, 113,  45,
         26,  27,  40,  19, 136,  22, 103, 107,  61,   9, 104,  37,  15,  63,
         80,   6,  74,  31,  75,  38,  87,  75], device='cuda:0')
argmax start logits shape: tensor([107,   5,  23,  58,   2,  67,  81,  73,  44,  62,  23,  31,  33,  77,
         50,  41,  30,  67,  51,  33,  87,  78,  36,  19,  66,  78,  59,  31,
         

Epoch 1/10:  34%|███▍      | 107/313 [00:07<00:13, 14.84it/s, loss=8.46]

argmax start logits shape: tensor([ 74, 127,  24,  53,  11,  51,  56,  69,   7,  50,  62, 129,  88, 100,
         10, 113,  67,  46,  49,   2,  17,  24,  52,  33,  23,  17,  14,  58,
         24,  15, 100,  27,  79,  45, 118,  99,  80,  77,  34,  93,  13,   6,
         87,  71,  79,  99,  60,  49,   8,  24,  24,  19,  14,  15,  62,  53,
         32,  10,  14,  32,  15,  17,  12,  40], device='cuda:0')
argmax end logits shape: tensor([ 88,  88,  20,  92,  58, 108,  15,  31,  26,   4,   6, 117,  89,  17,
         25, 126,  68, 122,  42,  38,  39,  41,  65,  39,  51,  21,  35,  15,
          2,  86, 103,  33,  79,  34,   3,  47, 128,  50,  85,  34,  45,   7,
         53,  11,  47,  13,  67,  30,  67,  42,  26,  30,  42,  23,  99,   2,
        114,   7,  14,  64,  17,  17,  88,  56], device='cuda:0')
argmax start logits shape: tensor([ 77,  55, 110,   2,  55,  53,  78,  56,  81,  69,  62,  34,  26,  35,
          2,  99,   3,  24,  77,  40,  16,  70, 100,  63, 117,  31, 137,   7,
         

Epoch 1/10:  35%|███▌      | 111/313 [00:07<00:13, 15.05it/s, loss=8.16]

argmax start logits shape: tensor([  4,  18,  33,  34,  15, 107,  23,   0,  77,  42, 107,  56, 100,  35,
        118, 128,  29,  56,  44, 123,  61,   1,  53,  14,  54,  53, 105,   7,
         14,  13,  37,  53, 120,  54,   2,  45,  66,  58,  45, 107,   4,  71,
          6,  26, 115,  81,  85,   4,  89,   2,  74, 100,  11,  57, 110,  57,
         14,   2,   5,  60,  81,  26,  24,  26], device='cuda:0')
argmax end logits shape: tensor([  4,  18,  27,  50,  27, 110,   6,  46,  77,   7,  42, 110,  20,   5,
         19, 128,  49,  56,   3,  18,  50,  40,  30,  61,   6,  77,  41,   7,
         31,  17,   2,  61,  76,  27,  35,  16, 125,   8,  33, 107,  23,  24,
         21,  37, 104, 107,  83, 117, 106,   2,   9, 105,  17,  87, 110,  66,
         68,   2,   5,  77, 117,  26,  14,  26], device='cuda:0')
argmax start logits shape: tensor([ 26,  73,  25,  48,  39,  37,   8,   2,  89,  51,  16,   0,  72,   3,
          2,  45,  26,  84,  55,  23,  52, 140,  56,  79,  98,   6,  22,  21,
         

Epoch 1/10:  36%|███▌      | 113/313 [00:08<00:13, 14.97it/s, loss=8.53]

argmax start logits shape: tensor([ 23,   0,   8,  12,   0,  10,   6,  27,  34,  85,  81,  83,  91,  77,
          0,  81,  88, 104,  71,  85,   0,  51,  78,   0,   0, 105,   9,  82,
         90,   0,  15, 115,  62,  38,  39,  26,  63,   0,  60,  36,   0,  50,
         95,  65, 119,  76,   2,  52,  98, 105,   6,  25,   0,  15,  23,  99,
          0,  88,  44,  85,   0,   8, 127,  26], device='cuda:0')
argmax end logits shape: tensor([ 23,   1,  78,  13,  23,   8,  21,  16,  83,  45,  47,  78,  53,  77,
         11,  41, 109,   3,  76,  52,  84,  80,  80,  73,  23, 105,  63,  28,
         63,  59,   3,  43,  52, 103,  72,  31,  75, 113,  78,  23,  64,  39,
         19, 100, 134,  92,  27,  52,  89,  31,  19,  47,  71,  20,  55, 112,
         47,  41,  71,  94,  44,  97,   9,  16], device='cuda:0')
argmax start logits shape: tensor([ 79,  93,  25,  30, 146, 113,  43,  28,  40,   2,  27, 107,  53,  73,
         45,  59,  47,  23,   0,   0,  24,   0,  44,  65,  22,   0,   0,  60,
         

Epoch 1/10:  37%|███▋      | 115/313 [00:08<00:13, 15.00it/s, loss=8.39]

argmax start logits shape: tensor([ 32, 119,  71,   0,   9,  94,  96,   5,  63,   0,   1, 107, 126,  68,
         21,  91,   9,   0,   4,  71,  47,  76,   0,  70,  32,  54,  43,  54,
         18,  17,  16,  25,  98, 105,  88,   8,   7,  46,   2,  74,  26,   0,
         29,  86,  70, 103,   3,  40, 136,   6,  21,  26,  56,  69,   3,  16,
         21,  25,  19,   5,  53, 134, 102,  58], device='cuda:0')
argmax end logits shape: tensor([ 25, 128,  70,  18,  65,  81,  21,  18,  73,  88,  32,  39, 135,  45,
         18,  99,  19,  13,  28,  56,  10,  78,  10,  83,  47,  54,  77,  27,
         49,  81,  32,  52,  39,  25,  90,  22,   5,  50,  70,  26,  75,  58,
        129,   8,  62,  13,  81,  58,  15,  19,  12, 110,  44,  69,  21,  27,
         28,  80,  13,  39,   9, 137,  36,  61], device='cuda:0')
argmax start logits shape: tensor([ 79, 123,   5,   5,  39, 106,  11,  29,  36,  22,  88,  62,  79, 104,
          3,  62,  50, 130, 110,  81,  88,  96,  33,  10,  69,  33,  22, 103,
         

Epoch 1/10:  38%|███▊      | 119/313 [00:08<00:12, 15.21it/s, loss=8.33]

argmax start logits shape: tensor([  6,  62,  64,  44,  93, 119, 108,  83,  85,  25,  47,  35,  50,  14,
         85,  43, 110, 101,   8,  19,  33,  11,  10,  32, 123,  55,  33,  54,
         71,  54,  55,  65,  69, 112,  78,  41,  13,  40,  22,  19,  13,  22,
         35,  13,   6, 101,   4,   7,  20,  47,  36,  84,  54,   6,   8,  22,
         42,  31,  11,  92,  94,  60,   3,  45], device='cuda:0')
argmax end logits shape: tensor([ 60,  23, 101,  44,  85,  52, 124,  75,  56,  15,  33,  24, 115,  32,
         56,  41,  80,  73,  50,  19, 100,  41,  48,  31,  24,  74,  21,  27,
         38,  26,  51,  42, 119, 112,  57,  26,  17,  12,  16,  52,   6,  33,
         88,  26,  67,  45, 106,   7,   4,  18,  24,  42,  27,  19,  11,   5,
         17,   3,  43,  40,  87,  12,  27,  81], device='cuda:0')
argmax start logits shape: tensor([ 71,  19,  12,  12,  56,  18, 108,  61,  60, 105,  46,  93,  65,  22,
         35,  90,  60,  81,  45,  66, 128,  51,  75,  10,  21,  43,  60,   8,
         

Epoch 1/10:  39%|███▊      | 121/313 [00:08<00:12, 15.21it/s, loss=8.48]

argmax start logits shape: tensor([ 97,  52,   0,  34,  23,   0,  60,   0,   0,  36,   0,   0,  10,  47,
         38,  83,  67,   0,  20,  81,  30, 110,  70,  52,  57,  22,   6,  94,
        121,   0,   8,  57,  20,   0,   0,  38, 128, 100,  46,  31,   0,  23,
        123,  60,   9,  24,  69,  41, 107,   0,  55,   0,  70,  14,  85,  50,
         24, 106,  60,  14,  79,   0,  97,   0], device='cuda:0')
argmax end logits shape: tensor([ 18,  24,  59,  39,  80,  64,  87,  26, 124,  11, 110,   9,  48,  52,
         68,  49,  86,  29, 115,  37,  29, 109,   7,  77,  37,  41,  68,  15,
         76,  58,  81,  56,  94,  50,  30,  18,  82,  52,   9,  23,  53,  88,
        100,  35, 102,   5,  95, 104,  93,  16,   6,  79,  45,  15,  74,  65,
         74,  52,  87,  53,  62,  21,  26,  20], device='cuda:0')
argmax start logits shape: tensor([ 15,  30,  95,  17,  31,   0,  38,  72, 101,  66,  45,   0,  63,  63,
         14,  21,   0,  81,   0,   6,  57,  51,   0,   1,  15,  99,  23,  37,
         

Epoch 1/10:  40%|███▉      | 125/313 [00:08<00:12, 15.00it/s, loss=8.62]

argmax start logits shape: tensor([ 27,  45,  32,  48,  71,  78,  68,  30,  47,  35,   3,  96,  99,  26,
         88,  81,  20,  66,  25,   7,  35, 114,  45,  19,  66,  25,  35,   5,
         79,  45,   7,  72,  54, 107,  61,  39,  22, 111,  30,  12,  49,  73,
         80,  37,  65,  72,  31,  69,  45,   1,  65,  41,  15,  95,  56,  24,
         10,  47,  40, 126,  35,  12,  15,  71], device='cuda:0')
argmax end logits shape: tensor([ 15,  23,  27,  54,  76,  57, 102,   3,  63,   5,  37,  58,   5,  32,
         87,  39,  20,  10,  24,  69,  12, 121,  36,  29,  52,  30,  24,  70,
         62,  14,  91, 122,  61,  13,  52,  39,   6,  70,  30,  40, 115,  15,
        105,  46, 124,  16,  40,  95, 109,  65,  85,  71, 140,  70,  36,  95,
         15,  36,  51,  63,   3,  11,  35,  95], device='cuda:0')
argmax start logits shape: tensor([ 37,  39, 104,  70,   6,  25, 128,  91,  28,  49,  17,  49,  34,  84,
         88,  80,  97,  78, 132,  48,   3,  11,   2,  49,  89,  37, 127,  21,
         

Epoch 1/10:  41%|████      | 129/313 [00:09<00:12, 15.02it/s, loss=8.37]

argmax start logits shape: tensor([ 56,  49,  43,   0,  11,  15, 108,   9,   0,  23,   0,  98,  47,   9,
         52,  20,   3,  65,  22,  77, 142,  98,  91, 114,  12,   0,  20,  72,
          0,   0,   9,   0,   0,   0,  34,  46,  40,   0,   0,  87,  95,   4,
          1,   0,   0,   0,  22,   0,  53,  67,   0,   7,  82,  67,  31,   7,
         70,   0,  18,  70,  91,  24,  24,  13], device='cuda:0')
argmax end logits shape: tensor([ 79,  64,  49,   5,  28,  60,  38,  44,  21, 102,  42,  98,  25,  29,
         47, 105, 102, 105,  56, 107,  98,  13,  74,  36, 111,  35,  17,  52,
         21,  84,   5,  16,  65,  49,  82,  21,  14,  25,  85,  80,  84, 122,
         87,  78,  41,   5,  26, 109,  80,  55,  70,   5,  41,  33,  81,  89,
         35,  73,  78,  78,  67,  58,  49, 110], device='cuda:0')
argmax start logits shape: tensor([ 11,  74,  28,   0,  24,  62,   0,   0,   0,  59,  57,  79,   0,  11,
          0,  56,   0, 102,  46,   0, 121,  49,  19,   0,   0, 106,   0,  19,
         

Epoch 1/10:  42%|████▏     | 133/313 [00:09<00:11, 15.14it/s, loss=8.35]

argmax start logits shape: tensor([ 60,  22,   7,   9, 121,   0,   0,   0,   0,  33,  88,   0,   0,   0,
         76,  38,   0,  41,   0,   0,   0,  13,  13,   0,   0,   9,   0,   6,
          0,   0,  25,  87,  92,   3,   0,   0,   0,   0,   0,   0,  13,   0,
         13,   0, 118,   0,   0,   9,   9,   0,   2,  31,   0,   0,   0,   0,
         86,   0,  66,   0,  37, 107,  94,   0], device='cuda:0')
argmax end logits shape: tensor([ 10,   2,   5,  51,  50,  18,  57,  34,   3,   2,  63,   9,   2,  20,
         80,  76,  80, 117,  21,  43,   1,  27, 110,   2,  59,  68,   7,   8,
         14, 119,  10,  42,  40,   9,  18,  95,  67,  66,  36,  67,  27,  52,
         47,  52,  70,  77,  24, 102, 102,  23, 109,  51, 119,   7,  53,  90,
        142,  46,  57,  62,  16,  96, 133,  70], device='cuda:0')
argmax start logits shape: tensor([  0,  25,   0,   0,   0,  90,   0,   0,   0,  25,   0,  39,  29,  69,
         17,   0,   0, 107,   0,   0,   0,  53,   0,  56,   0,   0,   0,   0,
         

Epoch 1/10:  43%|████▎     | 135/313 [00:09<00:11, 15.07it/s, loss=8.44]

tensor([110,  79, 111,   1,  69,  13,  79,   0,  24,  88,  83,  74,  26,   7,
         90,  71, 119,  96,  39,  69,  27,  54,  24,  17,  36, 105,  53, 103,
         39,  69, 108,  32, 131, 102,  49,   8,  10, 113,  29,  38,   2,  39,
         85,  25,  93,  11,  52,   2,   0,  70,  20, 122,  39,  43, 101, 117,
         92,  28,  86,  15,  22,   0,  28, 104], device='cuda:0')
argmax end logits shape: tensor([ 23,   5, 134,  45,  99, 111,  60,  47, 113,  41,  59,  36,  73,  69,
         17,  20, 114,  84,  74,  37,  10,   6,  47,  43,  37,  20,  28,  50,
         88,  76,   4,  20,  32, 108,   3,  55,  46,  12,  91,   7,  18,  52,
         79,   3,  47,  16,  27,  34,  84,  78,  67, 109,  86,   9,  13,  39,
         15,  75,  20,  48,  71,  55,  22,  58], device='cuda:0')
argmax start logits shape: tensor([ 45,  19,  77,  47,  33,  81,  20, 109,  19,  46,  39,  66,   7,  69,
         32,   7,  55,  16,  77,  22,  14,  78,  73,  28,  59,  31,  63,  74,
         82, 137,  30,   5,  21,  33

Epoch 1/10:  44%|████▍     | 137/313 [00:09<00:11, 14.87it/s, loss=8.34]

argmax start logits shape: tensor([ 11,  72,   2,  19,  39,  25,  27,  80,  27,  87,  51, 116,  81,   7,
        139,  22, 100,   2,  44,  98,  53,  83,  53,   7,   3,  67,  31, 140,
        107,  80,  47,  94,  60,   3,  91,  10,   1,  58,  47,  27,  32, 112,
         86,  34,  48,  61,  50,  45,  87,  58,  49,  32,  74,  40,  43, 104,
        101,  87,  52,  79,  48,  11,  72,  72], device='cuda:0')
argmax end logits shape: tensor([ 93,   3,  19,   7,  29,  76,  34,  33,  96,  85,  59,  15,   9,  63,
         26,  19,  48,  54,  25,  98,  92,  11,  42, 109,   5,   2,  80,  30,
         57,  24,  48,  50,  63,  61,  26,  30, 107,  39,  43,  21,  62,  48,
         10,  15,  12,  85,  57,  33, 107,  32,  30,   6,  43,  37,  45,  62,
         66,  33,  54,  57,  49,   9,  49,  52], device='cuda:0')
argmax start logits shape: tensor([ 20,  70,   4,  54, 140,  42,   2,  84,  86,  12,  41,  14,  12,  83,
         41, 117,   5,  18,  39, 115,  53,  90,  34,  50,  12,  79, 127,  86,
         

Epoch 1/10:  45%|████▌     | 141/313 [00:09<00:11, 14.94it/s, loss=8.66]

argmax start logits shape: tensor([ 25,  89,  48,   0,   5,  57,  81,  55,   0,  30,  64,  69,  72,  24,
         18,   9,  34,  99,  18,   0,  12,  93,   0,   7,   0,  91,   0,   0,
        103,  54,  11,  70,  42,  18,  35,  37,   0,  87,  23,   0,  47,   5,
         63,  35,   0, 116,  91,   2,  42,  60,  53,   0,   0,  14,  20,   9,
          0,  20,  47,  93,  92,  96,  63,   0], device='cuda:0')
argmax end logits shape: tensor([ 81,  26,  41,  52,  24,  47,  19,  59, 107,  52, 107,  67,  81,  56,
          8,  24,  58,  71,  56, 103,  44,  63,  49,   8,  25,  32,  13,  22,
         26,  41,  14,  99,   4,  17,  93,  33,  30,  53,  10,  36,  43,   9,
         43,   6,  96,   5,  45,  64,  77,  80, 137,  50,  21,  30,  42, 112,
         36,  14,  81,  46,  69,  87,  51,  65], device='cuda:0')
argmax start logits shape: tensor([ 27,  21,  16,  19,   0,   0,   0,   0,   2,  65,   0,  27,  62,  31,
          0,  20,   0,   0,  36,   2,  50,   0,  20,   0,  24,  92, 121, 113,
         

Epoch 1/10:  46%|████▌     | 143/313 [00:10<00:11, 14.93it/s, loss=8.35]

argmax start logits shape: tensor([110, 119,  32,   0,   0,   0,  47,   5,   0,  30,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0, 105,  53,   0,   0,   0,  95,  63,  64,
          0,   0,   0,   5,   0,  82,  19,   0,   0,   0,   0,   0,  19,   0,
          0,   0,  60,  36,  63,  10,  59,   0,   0,  75,   0,   0,   0,   0,
         56,  25,   0,   0,   0,   0,  42,  28], device='cuda:0')
argmax end logits shape: tensor([ 91, 124,  38,   6, 130,  68, 107,  57,  42, 126,  85,  71,  17,  41,
         28,   1, 102,   4,  71,   7,  20,  50,  11,  38,   5,  26,  76,  98,
         18,  64,  48,  49,  93, 105, 106,  20,  30,  35,  72,  39, 106,  29,
         32,  40,  57, 118,   3,  87,   8,   5,  42, 118,  45,  58,  56,  38,
        101,  40,  45,  26,  42, 101,  99,  34], device='cuda:0')
argmax start logits shape: tensor([104,   0,   0,   0,  87,   0,  80,  39,  25,   0,  81,  56, 102,   0,
         88,   0,  36,   5,  23,  20,  73, 102,   0,   0,   0,   0,  71,   0,
         

Epoch 1/10:  47%|████▋     | 147/313 [00:10<00:11, 15.01it/s, loss=8.55]

argmax start logits shape: tensor([ 53,  64,   0,   0,  67,  84,   2,   6,  93,   3,   0,   0,  44,   0,
          0,   0, 116,   0,  25,   0,   0,   0,   0,  29,  40,   0,   0,   0,
         46,   0,  41,  82,   0,  93,   0,   0, 107,   0,   0,   0,  23,  20,
          0,   0,  23,   0,  12,   0,   0,   0,  29,  35,   0,   0,   0,  11,
          0,  91,  84,  20,   0,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 79,  83,  38,  15,  45,  25,  21,   5,  75,  37, 130,  10, 103,  92,
         10,  88, 113,  58,  41,  22,  50,   3,  43,   3,  43,  59,   8,   2,
         64,  38,  36,  56,  82,  80,  31,  60,  10,  47,  98,  30,  36,  89,
         93,  42,  31,  38,  11,  13,   7,   5,  20,   4,  42,  84, 105,   4,
         14,  36,  17,  70,  36,  41,  11,  51], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  39,  27,  43, 117,   0,  28,   0,  42,  36,  45,   0,   6,
          8,   0,  66,  14,   9, 107,   0, 117,  46,   0,  76,   0,   0,   0,
         

Epoch 1/10:  48%|████▊     | 149/313 [00:10<00:10, 15.01it/s, loss=8.19]

argmax start logits shape: tensor([ 28,   0,  31,   0,  48,  51,  97,   0,   0,   7,   0,   0,   0,  15,
          0,   0,  26,  88,   0,   4,   0,   7,   0,   9,  98,   0,   3,  86,
        108,   6,  23,   7,  15, 123,   0,  48,   0,   0,   0,  45,  18,  52,
          0,   0,   0,  68,   0,  21,   0,  29,  52,   0,   0,   0,  11, 113,
          0,  12,  92,   0,  36,   0,  67,   0], device='cuda:0')
argmax end logits shape: tensor([ 19,  42,   5,  35,  64, 117,  41,  84,  40,  62,  80,  23,   4,  54,
         64,  32,  66,  19,  78, 109,  81,   5,  65,  30,  51,   5,  83,  85,
         30,  24,  15,  56,  45, 113,   3,  93,   7,  37,  77,  76,  82,  37,
         25,   5,  27,  72,  31,  92,   3,  87,  90,  99,  68,  53, 128,  97,
         23,  48,  40,  54,  76,  45, 108,  56], device='cuda:0')
argmax start logits shape: tensor([  0,  58,   0,   0,  94,   0, 110,   0,   0,  18,   0,   8,  21,  31,
          0,  38,  12,  98, 104,  26,   0,  42, 104,  30,  51,   0,   0,   0,
        1

Epoch 1/10:  49%|████▉     | 153/313 [00:10<00:11, 14.53it/s, loss=8.76]

argmax start logits shape: tensor([  0,   0,  30,   2,  22,   0,  19,   0,  31,  38,   0,   0,   0,  98,
         54,   0,  81,  87,   0,   0,  48,  94,   0,   0,  93, 131,   0,  45,
        128,  99,  41,   0,   0,   0,   0,  62,  56,   0,   0,  55,  71, 114,
         75,   9,   0,   0,   8,  76, 108,  18,  46,   0,  19,   0,  58,  62,
          0,  56,  29,  21,  46,   9,   0,  26], device='cuda:0')
argmax end logits shape: tensor([ 48,   7,  89,   3,  71,  30, 118,  13,  21, 110,  82,  89,  22,  42,
         35,  10,  32,  46,  62,   4,  17,  13,   9,  21,  33,  28,  76, 129,
         27,  68,  39,  90,  56,  19,  68,  77,  45,  67,   4,  85,  48,  37,
         17, 111,  11,  32,  36,  75,  91,   3,  15,  64,  50,  30,   4, 116,
         32,  49,  82,  21,  11,   7,  40,  16], device='cuda:0')
argmax start logits shape: tensor([112,   0,  85,  20,  89,  92,  27,  58,  12,  67,   0,   0,  21,  84,
        127,  12,  17,  17,  41,   6,  53,   0, 108,  23,   0,  11,  60,  72,
         

Epoch 1/10:  50%|████▉     | 155/313 [00:10<00:10, 14.66it/s, loss=8.14]

argmax start logits shape: tensor([  0,  13,   0,  66,   0,  12, 100,   0,  83,  91,  38,   0,  84,   0,
          0, 105,   0,   0,  36,  99,  83,  34,   0,   0,   0,  11,  58,  64,
          0,  46,  92,   0,  90,   0,   0,  41,   0,  15,   0,   0,   2,   0,
         70,   0,   0, 100,   0,   0,  20,   0,   0,   0,   0,   0,  58,  62,
          3,   0, 119,  42,  21,  21,  10,   2], device='cuda:0')
argmax end logits shape: tensor([ 60,  45,   3,   8,  13,  36,  24,  24,  30,  83,  21,  31,  35,  27,
         42,  31,  33,  55,  65,  38,  45,   3,  19,  58,  35,  24,  64,  65,
         84,   7,  30,  82, 104,  58,   6,  38,   5,  33,   3,  16,  36,  27,
         76,  45,  21,  27,  17,   9,  48,  49,  24,  80,  78,  43,  23,  79,
        112,  37,  47, 118,   4,  35,  41,   2], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,  29,  50,   0,  88,  31,  25,  34,  20,   0,   0,  32,
          0,  63,  45,   0,  32,   0,   0,   0,  10,  33,  26,   0,   0,   0,
        1

Epoch 1/10:  51%|█████     | 159/313 [00:11<00:10, 14.82it/s, loss=8.27]

argmax start logits shape: tensor([100,  34, 109,   0,   0,   0,   3,  18,  97,  25,   0,  64,   0,  12,
          0,  12,   0,  21,   0,  41,   0, 133,   0, 102,   0,  40,   0,  33,
        104,  30,   0,  49,  20, 102,   0,   0,  63,   0,   2,   6,  30,  47,
         84,   0,  24,  64,   0,   0,   0,   0,   0,   0,   0,   0,   0,   5,
        110,  67,  15,   0,   0,  46,   0,  79], device='cuda:0')
argmax end logits shape: tensor([ 75,  97,  33,   5,  11,  34, 112,  30,  61,  56,  25,  31,  12,  43,
         37,  60, 116,   3,  11,  54,  46,  85,  15,  60,  22,   2,  58,  33,
        106,  75,  36,  36, 104,   5,  63,  31,  48,  82,  21, 115,   2,  92,
         44,  24,  21,  42, 110,  94,  43, 114,  64,  44,  28,  66,  82,  53,
         63,  66,  32,  29,  53,  23,  24,  87], device='cuda:0')
argmax start logits shape: tensor([  2,  41,  41, 122,  14,   0,   0,   8,   0,   0,  42,   0,  86,   0,
          0,  10,  37,  25,  30,  66,  33,   0,   2,   0,   0,  80,  85,  25,
         

Epoch 1/10:  52%|█████▏    | 163/313 [00:11<00:10, 14.95it/s, loss=8.16]

argmax start logits shape: tensor([ 63,  33,   0,   8,  12,  14,  43,  55,  39,  90,  17, 121,  10,  38,
         37, 100,   9,  21,  93,  60,   2,  38,  40,  45,  86,  62,  71,  67,
        123,  20, 105,  71,  56,  72,   7,   2,  55,  25,  67,  64,  97,  70,
         77,  45,  59,  78,  26,  96,  90,  81,   9,   2, 113,  85, 105,  51,
        100, 110,  74,  37,  44,  28,  55,  17], device='cuda:0')
argmax end logits shape: tensor([ 53,   5,  57,  68,  36,  53,  31,  55,  15,  42, 124, 102, 115, 122,
         13,  70,  36,   5,  12,  16,  23,  28,  17,  18, 126,  62,  71,  99,
        130,  21,  50,  72,  42, 114,  18,  61,   5,  31,  32,  28, 133,  55,
          2,  55,  65,  87,  18,   5,  63, 122,   7, 111,  59,  61,  29,  51,
          3, 129,  44,  45,  60,  27,  59,  12], device='cuda:0')
argmax start logits shape: tensor([ 20,  52,  15,  63,  24,  46,  37,  47,  16,  14,  15,  24,  58,  71,
         67,  43,  44,  44,  20,  70,   2,  59,   7,  10,  14,  84,  56,  54,
         

Epoch 1/10:  53%|█████▎    | 167/313 [00:11<00:09, 15.13it/s, loss=8.32]

argmax start logits shape: tensor([ 13,  64,   3,   0,  78,   5,   7, 105,  20,  29,   0,  87,  62,   0,
         11,  65,   5,  12,  24,   2,  52,  21,  59,  23, 117,  61, 105,  71,
         12,   2,  12,   3,  23,  58,  84, 117,   6,  31, 116, 121,  46,   0,
        108,  18,  19,  42,  16,  15, 121,  51,  77,  10,  78,  66,   0,  56,
         47, 132,  10,  98,  41,  25, 113,  81], device='cuda:0')
argmax end logits shape: tensor([ 50,  23,  74,  21,  80,   5,  30,  50,  26,  31,  83,  88,  70,  31,
         65,  43,  16,  12,   6,  61,  17,  13, 101,  33, 119,   3,  23, 106,
         69,  62,   1, 116,  60,  59,   2,  93, 120,   2,  27,   7,  11,  18,
         90,  12,  49,  67,  77,  34,   2,  42,  38,  11,  79,  76,  65,  48,
         57, 142,  29,  52,  65,  87,  95,  98], device='cuda:0')
argmax start logits shape: tensor([ 33,  88,  33,  22,  85,  28,   2,  90,   6,  19,  72,  23,  35,  40,
         38,  15, 103,  40,  60, 103,  21,  38,  50,  99,  89,  89,   9,  34,
         

Epoch 1/10:  55%|█████▍    | 171/313 [00:11<00:09, 15.02it/s, loss=8.28]

argmax start logits shape: tensor([ 55,   3,  12, 112,  11,  11,  46,  31, 105, 101,  31,  41,  20,  38,
          8,  50,   2,  60,  56, 100, 107,   2,  21,  57,  71,   4,  26,  22,
         81,  49,  69,  57,   5,   8,  56,  33,  71,   6,  81,  69,  20,  21,
         70,  21,  30,  99,   9,  26,  34,  87,  15,  18,  93,   5,  56,  54,
         49,  90,   8,  39,  32,  22,   2,  74], device='cuda:0')
argmax end logits shape: tensor([ 80,   3,  26, 102,  96,  21,  30,  54, 110,  86,  55,  81,  18,  89,
         64,  81,   3,   3,  79,  14,  14,  77,  48,  34,  84,  79,   4,  41,
         37,  25,  93,  13,  95,  37,  42, 115,  10,  45,  52, 117,  18,  11,
          3,  30,  45,  30,  29,  92,  39, 118,  78,  18,  29,  27,  31,   7,
         82,  53,  14,  39,  13,  12,   8,  74], device='cuda:0')
argmax start logits shape: tensor([ 39, 111,  63,  69,  59,  10,  68,  62,  56,  56,  19, 119,  99,   2,
         21,  90,  54,  45,  27,  87,  39,  69,  78,  71,   2,  33,  69,  38,
         

Epoch 1/10:  56%|█████▌    | 175/313 [00:12<00:09, 15.03it/s, loss=8.31]

argmax start logits shape: tensor([ 10,  80,  97,  86,  81,  33,   0,  20, 105,   8,  19,   2,   7,  40,
          7,  11,  79,  74,  38,   6,  50,  98,  32,  36,  12,  27,  54,   7,
          3,  27,  24,  94,   0, 106,  36,   8,  52,   0,  27,  83, 156,  20,
         43,  74,  42,  32,  35,  12,  12,  69,  57,  99,  42,   5,   2,  45,
         95,  23,  24,  59,  11,  61, 107,  37], device='cuda:0')
argmax end logits shape: tensor([ 46,  56,  41,  58,  33,  33,   9,  92,  20,  10,   3, 106,   4,  28,
         11,  45,  64,  24,  17,  68,  72,  69,  32,   9,  18,  32,   5,   7,
         25,  42, 100,  62,  53,   8,  39,  59, 106,  57,  34,  77,  83,  30,
         30,  64,  88,  84,  38,  12,  22,  65,  48,  10,  60,  31,  19,  61,
        112,   3,  42,  86,   5,   5, 107, 103], device='cuda:0')
argmax start logits shape: tensor([  0,   1,  44,   0,   4,   0,  55, 103,  34,   0,  61,  40,  57,  59,
        127,  27,  28, 137,  15,   0,  68,   0,  80, 101,  60,   4,  45,  35,
         

Epoch 1/10:  57%|█████▋    | 179/313 [00:12<00:08, 14.95it/s, loss=8.5] 

argmax start logits shape: tensor([  0,   0,   0, 138,   0,   0,   0,   0,   0,  45,  64,   0,   0,   0,
          0,   0,   0,   0,  55,   0,   0,  49,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   9,   0,   0,   0,   0,   5,
          0,   0,   0,  41,   0,   0,  61,  21,   0,  50,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,  84], device='cuda:0')
argmax end logits shape: tensor([ 56,  56,  45,  49,  72,  62,   6,  54,  42,  37,  49,  18,   6, 101,
         22,  61, 117,  34,  88,   4,  16,  24,  98,  94, 140,  27,   5,  76,
         86,  91,  15,  21, 119,   5,   6,  47,  88,  15,  39,  60,  73,  21,
         81,  94,  35,   5,  25,  53,  20, 118,  54,  71,  68,   3,  80,  38,
         99, 111,  15,   4,  24,  52,  83,  86], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  91,   0,  56,   0,   0,  29,   0,   0,   0,  55,   0,   0,
          0, 104,  62,  43,   0,   0,   0,   0,   0,   0, 142,   0,   0,   0,
         

Epoch 1/10:  58%|█████▊    | 183/313 [00:12<00:08, 14.92it/s, loss=8.06]

argmax start logits shape: tensor([  0,  56,   0,   0,   0,   0,  92,   0,   0,   0,  63,  47,   0,  34,
        106,   0,   0,   0,  40,   0,  79,   0,  58,   0,   0,   0,  17,  24,
          0,   0,   0,  63,   0,   0,  47,   0,   0,   0,  69,   0,  56,   0,
         62,   0,   6,   0,   0,   0,   0,   0,   0,  49,  66,   0,  45,   6,
         46,  30,   0,   0, 104,  96,   9,   0], device='cuda:0')
argmax end logits shape: tensor([ 20,  76,  52,  26,  31,  80,  39, 110,  23,  54,  83,   5,   8,  97,
        129,  34,   6,  59, 118, 123,  60,   5,  68, 138,  67,  57,  54,  27,
         46,  27,  55,  35,  47,  51,   5, 100,  97, 123,  45,  71,  68,  11,
         35,  35,  80,  84,   6,  28,  58,  36,  70,  62,  22,   4,  32,   2,
         23,  12,  44, 114,  79,  42,  67,  35], device='cuda:0')
argmax start logits shape: tensor([  0,  19,   0, 103, 113,   0,   0,  38,   2,   0,  44,   0,   0,   0,
         35,  94,  48,   0, 127,  79,  97,   0,   0,  53,  63,  35,   0, 109,
         

Epoch 1/10:  59%|█████▉    | 185/313 [00:12<00:08, 14.95it/s, loss=8.16]

argmax start logits shape: tensor([ 83,   5,  92,  76,  77,   0,  35, 134,   0,  32,  77,   7, 103,   0,
         12,  35,   0,  78,  76,  70,  26,  67,  11,  25,  72,  30,  83,  17,
         92,   0,  34,  18,   9,   0,   6,  42, 138,  24,  14,   8,  16,  96,
          3,  14,  60,  16,  92,   0,  34,  35,  46,  63,  71,   0,  49,  25,
         25,  40,  28,   8,  65,  29,   7,  67], device='cuda:0')
argmax end logits shape: tensor([116,  21,  84,  88,   6,  46,  37, 140,  97,  91,  79,  43,  51,  27,
         64,  24,  99,  74,  81,  50,  67, 125,   5,  36,  25,  43,  23,  44,
        116,  11,   9,  95,   6,  62,  72,  31,   3,  41,  37,  16,  89,   4,
          3,  28,  22,  55, 108,  20,  89, 107,  71,  99,  22,  10,  99,  82,
          8,  56,  20,  10,  26,  23,  58,  44], device='cuda:0')
argmax start logits shape: tensor([  7,  30,  56, 110,  52,  63,  74,  14,   0,  22,  40,  13,  21,  63,
          0,   6,  20,   6,  80,  42,  31, 100,  27,   0,   0,  37,  87,   0,
         

Epoch 1/10:  60%|█████▉    | 187/313 [00:13<00:08, 15.05it/s, loss=8.37]

argmax start logits shape: tensor([  6,  12,  32,   2,   0,  57,   0,   6,  89,  36,  90,  27,  79,  60,
         15,   0, 115,  50,  76,  67, 103,  78,  33,   0, 101,  75,  52,  23,
          2,   0,  34,  83,  28,   0,  30,   0,  36,  39,  26,   0,   0,  84,
          2,   0,   1,  49,  28,  96,  21,   0,   1,  55,  38,  46,  43,  24,
         29,   0,   9,   0,   0,  27,  49,  65], device='cuda:0')
argmax end logits shape: tensor([ 15,  51,  80,  72,  30,   7,  80,  13,  28,   6,  99,  20,  89,  58,
         43,  30, 147,  73,  79,  99,  89,  52,  66,   5,  32, 115,  37,  61,
         32,  54, 120, 113,   7,  44, 104, 114,  65,  84,  43,   3,  40,  85,
         35,  68,  21, 129,  20,  80,  48,  23,  31,  58, 109,  61,  10,  41,
         17,  41,  14, 106,  11,  90,  14,  39], device='cuda:0')
argmax start logits shape: tensor([  3,   7,  57,  65,  15,   0,  10,  79,  77,  24,  98, 130,  52,  98,
         13,  25,   7,  10,  83,  53,   2,  89,   0,  15,  66,  98,  43,   1,
         

Epoch 1/10:  61%|██████    | 191/313 [00:13<00:07, 15.30it/s, loss=7.86]

argmax start logits shape: tensor([ 87,  15,  18,  22,  16,   0,  37,  38,  14,  57,   7,  20,  69,  54,
          2,  14, 100, 106,  95,  49,  27,  27,  82, 102,   2,  47,  17,  54,
         96,  68,  43,  11,  96,  86,  62,  25,  54,  87,  25, 105,  43,  67,
         78,  66, 110,  24,  19,  34,  57,  98,   1,  42,   0,  27,  18,   1,
         12,  49, 113,  84,   4,  53,   9,  60], device='cuda:0')
argmax end logits shape: tensor([ 28,  54,  63,  24, 142,  10,  41,  25,  26,  35,  10,  25,  98,   2,
         15,  37,  60, 104, 110,  70,  25,  29,  93,  37,  13,  43,  20,   6,
         87,   4,  44,  50,  87,  25,  63, 101,  37,  26,  83,  75,   3,  65,
         80,  61, 113,  20,   8,  10,  10,  54,  21,  64,  54,  34,  80,  90,
        108,  40,  82,  30,  12,  46,  56,  43], device='cuda:0')
argmax start logits shape: tensor([ 46,   2,   0,  87,  30,  76,  34,  15,  80,  66,  79,  47,  60,   7,
         59,  12,  54,  44,  79,  56,  43,  46,  25,   2,  71,   8,  15,  88,
         

Epoch 1/10:  62%|██████▏   | 195/313 [00:13<00:07, 15.24it/s, loss=8.08]

argmax start logits shape: tensor([121, 101,  12,  90,  61,   0,  56,  25,  65,  95,  96,  77,  49,  94,
         32,   3,  65,  98,  29,  90,  25, 101,   5,  50,  46,   8,  19,  70,
          2,  53,   7,  55,  54,  52, 106,   8,  15,   4,  33,  44,  33,  44,
         41,  38,  48,  41,   2,  11,  35,   2,  61,  27,  84,  38,  90,  43,
         49,  40,  75, 100,  59,  34,  94,  92], device='cuda:0')
argmax end logits shape: tensor([  5,  34,  47,  61,  20,  46,  68,  26,  10,  18,  99,  24,  25, 118,
          4,  44,  67,  17,  19,  53, 144,  96,  93,  23,  10,  37,   7,  88,
         35,  78, 126,  34,   6,   5,  36, 124,  65,  21,  10, 101,  63,  94,
         50,  81, 108,  44,   5,   4,  55,   3,  65,  62,  50,  21,  92,  32,
         59,  80,  23, 105,  64,  29, 109,  23], device='cuda:0')
argmax start logits shape: tensor([ 52,   0,  72,  63,  40,  93,   0,  10,  88,  19,   4, 119, 112,  24,
         34,  55,  43,  24,  63,   0,  84,   2,  59,  31,   2,  88, 127,  49,
         

Epoch 1/10:  63%|██████▎   | 197/313 [00:13<00:07, 15.02it/s, loss=8.06]

argmax start logits shape: tensor([ 38,  30,  11,  63,   4,  16,  43,  77,   0, 100,  75,  67,  24,  35,
         16,   0, 131,  84,  53,  91,   0,  75,  11,  38,  68,   1,  57,   2,
         25,   0, 100, 100,  96,   0,  11,  18,  53,   0,  30,  49,  89,  23,
         54,   9,  85,  28,  55,  34, 111,  20,   0,  50,   0, 102,   0,  84,
         69,   0,  76,  13,  34,  40,  11,  84], device='cuda:0')
argmax end logits shape: tensor([113,  15,  46, 110,  24,  28, 102, 119, 114,  43,  24,   8,  76,   2,
         59,  40,  52,  84,  84,  35,   9,  56,  75,  14,  67,  87,  85,  65,
        102,   2,  83, 114,  38,  59,  99,  23, 138,  51,   5,  20,  59,  79,
         39,  77,  20,  88,   2,  33,  44,  26,  19,   3,  51,  13,  14, 122,
         63,  13, 121, 112,  25,  40,  65,  64], device='cuda:0')
argmax start logits shape: tensor([116,  33,   0,  46,   0,  69,  48, 112,  10, 100,  39,   0,  42,  39,
          0,  75,  23,  88, 113,  59,  10,  75, 106,   0,   0,  42,  38,  14,
        1

Epoch 1/10:  64%|██████▎   | 199/313 [00:13<00:07, 14.88it/s, loss=8.33]

argmax start logits shape: tensor([  0,  22,   0,  47,   7,  13,  20, 104,  42,   0,   0,  49,  75,   7,
         17,   0,   0,  74,  28, 126,   7,  22,  49,   0,   0,  39,   0,  49,
         36,  70,   0,   0, 104,  54, 113,  26,   0,  29, 135,  25,  24,  43,
          0,  75, 105,   0,   0,   0,  22,   0,  67,  35,  79,  65,   0,   0,
         98,   0,   2,  48,   0,  59,  79,  20], device='cuda:0')
argmax end logits shape: tensor([ 95,  22,  48,  32,  35,  43, 105, 118,  28,  28,  17,  52,  56,  82,
         64,  32,  34,  58,  83,  88,  57,  78,  80,  36,  63,  79,   7,  10,
         39,  77,  30,  32,   9,   2,  67,  35,  69,  29,  11,  32,  54,  18,
         62,  56,  75,  66,  21,  51,  67, 126, 115,  65,  11, 105,  90,  70,
        103, 108,   2,  42,  36,   4,   3,  14], device='cuda:0')
argmax start logits shape: tensor([ 79,   0,   0,  53,  84,  85,   0,  20,  28,  21,   0, 143,   0, 114,
         28,   0,  30,   0,  30,  78,  78,   0,   0,   8,   0,   2,  88,   0,
        1

Epoch 1/10:  65%|██████▍   | 203/313 [00:14<00:07, 15.11it/s, loss=8.38]

argmax start logits shape: tensor([  0,  61,  45,  69,   0,  22,   0,  68,  77,  22,   0,  69,  69,  32,
          0, 123,  86,  62, 120,  23,   0,  78,  94,   0,   0,  99,   0,   0,
          0,  38,  26,   0,  60, 108,  75,  22,  11,  31,  66,  28,  57,   0,
         83,  14,   0,  37,  20,   2,  59,  54,  39,   1,  15,  96,  65,   0,
         22,  60,  19,  45,   0,  57,  28,   2], device='cuda:0')
argmax end logits shape: tensor([  5,  61,  15, 105,  23,  49,  43,  13,  10,  24,  28, 100,  63,  18,
         85, 124,  86,  19,  66,   9,  29,  80, 105,  75,  46, 140,  34, 118,
         94,  73,  82,  38, 128,  42, 100,  22,  23,  17,  19,  65,  48,  42,
         48,  43,  41,  37,   3,  57,  84,  13,  86,  63,  53,  65,  76,  66,
         59,  98,  67,   1,  80,  26,  75,  38], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  82,  42,  66,  11,  26, 106,  22, 107,  40,  79,  96,  11,
          2,  40,   0,   0,  94, 104, 126,   8, 109,   4,   0,   8,  11,   0,
         

Epoch 1/10:  66%|██████▌   | 207/313 [00:14<00:07, 15.12it/s, loss=8.35]

argmax start logits shape: tensor([  0,  61, 107,   0,   0,  28,   2,   8,   0,  33,   3, 123,  25,  10,
          0,  70,   0,   0,  49,  45,  79,   0,   6,   0, 121,  73,  33,   0,
        112,  76,  11,  23,   2,  63,  43,  59,  19,  84,  78,  53,  59,   0,
          0,  77,  32,   0,   0,   0,   0,   0,  46,  34,   0,   8,   0,  24,
        112,   0,   0,  37,   5,  42,   0,   8], device='cuda:0')
argmax end logits shape: tensor([  3,   3,  25,  22,  31, 102,  52,  53,  95,  14,  84, 123,  67,  41,
         70,  50,  63,  32,  34,  78,   4,  59,  50,   4, 100,  73,  33,   8,
          3,  92,   8,  23,   2, 106,  22,  55,  45,  96,  76,  53,   1,  24,
         23,  43,  58,  34,  80, 108,  48,  22,  61,   3,  57,  16,  70,   4,
        112,  29,  31,  46,   5,   6,  46,  23], device='cuda:0')
argmax start logits shape: tensor([ 34,   0,   0,   0,   0, 117,   0,   0,   0,  36,   0, 138, 101,  54,
          0,  48,   0,   0,   0,  97,   0,   0,   0,  41,  25,   0,   0,  16,
         

Epoch 1/10:  67%|██████▋   | 209/313 [00:14<00:06, 15.10it/s, loss=8.37]

argmax start logits shape: tensor([  0,  34,  64,   0,  52,  52,  10,  15,  38,   0, 113,   3,  41,   0,
         59, 100,   0,   0,  19,   0,   0,   0,   0,   0,   0, 114,   0,   0,
          3,  46,  22,  54,  25,  59,  55,  72,   0,  26, 119,   0,   0,   0,
         55,  50,  47,   0,   0,  47,  10,  76,   3,   0,   0,  15,   0,   0,
        107,  84,   0,  27,  65, 126, 116,   0], device='cuda:0')
argmax end logits shape: tensor([ 35,  52,  47,  62,  52,  11,   4,  41,  18,  12, 129,   3,  67,  13,
         20,  22,  50,  84,  84,  39,  67,  27, 118,  65, 150,  20,  93,  39,
        128,  53,  32,  53,  81,   6, 100,  80,  97,  29,   3,  28,  28,  49,
         45,   6, 111,  77,  26,  65,  34,  80,  20,  45,  82,  80,  70,  27,
         92,  44,   5,  16,  33,  81,  90,  68], device='cuda:0')
argmax start logits shape: tensor([ 79,  53,  58,   0,   0,  14,  48,  70, 102,   9,  57,  68,  43,  96,
         17,  59,  61,  17,  89,  78,  55,  50,  42, 101,  27,  38,  17,  59,
         

Epoch 1/10:  68%|██████▊   | 213/313 [00:14<00:06, 14.89it/s, loss=8.36]

argmax start logits shape: tensor([  0,  15,  74,  60,  21,   9,  59,  88,  32,   0,  42,  75,  85,  67,
         47,  42,   0,  68, 109,  18,  43,  63,   8,  11,  11,   1,  53,  89,
         85,  53,  73,  60,  29,  15,  23,  28,   7,  76,   2, 131,   0, 103,
        112,   9,  72,  80,  88,  45,  11,  19,   8,  15,  78,  77,  94,  77,
          1,  21,  36,  42,  47,  21,   3,  13], device='cuda:0')
argmax end logits shape: tensor([ 10, 121,  58,  23,  11,  49,  60,  11,  64,  26,  78,  27,  25,  55,
          7,  58,  32,  55,  82,  68,  16,  85,  69,  49,  45,   2,  41,  76,
         40,   5,  16,  76,  30,  12, 109,  28,  22,  74, 110,  24,  30,  20,
         76,   4,  80,  60,  89,   5,  48,  43,  20,   9,  63,  74,  75,  77,
          4,  22,  65,  40,  39, 106, 108,  59], device='cuda:0')
argmax start logits shape: tensor([104,  17,   2,   7,  15,  55,  58,  92,   0,   8,  15,  55,  43,  10,
         10,   0,  52,  56,  35,  24,  50,  44,   0,  99,   6,  64,  11,  20,
         

Epoch 1/10:  69%|██████▊   | 215/313 [00:14<00:06, 15.00it/s, loss=8.2] 

argmax start logits shape: tensor([ 72,   0,   0,   0,   0,   0,  47,  86,  79,   0,   0,   0,  54,   0,
         28,  23,   0,  55,   0,   0,  19,  65,   0, 125,   0,   0,  59,   0,
        155,  16,   0,   0,  10,  90,  92,  22,  90,  88,   0,   0,  24, 104,
         58,   3,   0,  42,  37,   0,  33,  75,  27,  53, 105,   0,   0,   0,
         63,   0,   0,   1,  84,  85,  11,  52], device='cuda:0')
argmax end logits shape: tensor([ 53,  81,  20,  70,  28,  33,   3,  26, 121,  14,  93,  16,  14,   4,
          4,  56,  56,  18,  42,  49,  20,  33,  16, 111,  41,  50,   4,  57,
        119,  13,  55,  48,  88, 102,  65,  30, 134,   3,  64,  46,  12, 106,
         24,  27,  50,  53,  18,  32,  98,  59,  57,   8,  45,   2,  12, 113,
         88,  65,  47,   5,  78,  45,  58,  37], device='cuda:0')
argmax start logits shape: tensor([ 13,  77,  24, 101,  25,   5,   0,  35,   0,   0,   0,   0,  90, 127,
          0,  10,   0,   0,   0,  34,   0,   0,  69,   0,   8,   0,   0,  47,
         

Epoch 1/10:  70%|██████▉   | 219/313 [00:15<00:06, 15.02it/s, loss=8.05]

argmax start logits shape: tensor([  0,  26,  77,   0,  20,  60,  85,  84,   0,   0,  30,  92, 105,  61,
          8,  11,   0,   0,   0,   0,  77,  92, 104,  72,  53,  91,  99,  63,
          0,   0,   0,   0,   0,  14,  31,   0,   0,   0, 110,   0,   0,   0,
          0,  51,  12,   0,  11,   0,   0,   0,   0,  45,  25,  52,  42,   0,
          0,  15,  43,  12,   0,  22,   0,   0], device='cuda:0')
argmax end logits shape: tensor([120,   5,  14,  14,  20,  61,  84, 100,  52,  65,  36,  35,  37,  61,
         47,  11,  67,  10,  26,  93,   3,  98,  79,  62,  25,  53,  12,  77,
         86,   4,  11,  23,  52,  20,  81,   4,  31,  77,  82,  35,  75,  62,
         37,  30,  22,  53,  17,  53,  70,   2,  61,  54,  55,  77,  53, 105,
         69,  48,  24,  36,  14,  12,   3,  42], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   3,  30,  15,   0, 114, 112,   0,   0,  36,   0, 108,
          0,   0,   0,   0,  10,  70,  78,  21,   0,  59,   0,  17,  40,   0,
         

Epoch 1/10:  71%|███████   | 221/313 [00:15<00:06, 14.96it/s, loss=8.29]

argmax start logits shape: tensor([ 30,  71,   8,  11,  67,   2,  79, 100,  41,  57,  30,  19,   0,  10,
          0,   0,  44,  40,  66,  47,  90, 105,  70,   0,   6,  34,   6,   5,
          7,  48,  55,  75,   0,  25, 106,   0,  54,   8,  30,  62,  30, 123,
        113,   2,  29,  46,  32,   6, 105,  40,  42,  84,  52,  11,  62,  62,
          0,  51,  81,  39,  11, 119,   0,  38], device='cuda:0')
argmax end logits shape: tensor([ 44,  84,  23,  26,  46,  79,  82,  12,  27,   7,  30,  19,  87,  28,
         72,  58,  81,  67,  45,  75,  76,  45,  14, 114,  67,  52,   6,  81,
         15,  22,  16,  78,  70,  30,  29,   3,  10,  21,   5,  39,  16,  66,
         62,  16,   2,  55,  58, 111,  35,  44,  46,  69,  37,  67,  62,  39,
         16,  97,  35,  42,  30,  65,  41,  57], device='cuda:0')
argmax start logits shape: tensor([ 27,   0,  18,  54,  34,  21,  33,  33,  57,  97,   5,  83,  78,  58,
         45,  22,  54,  14,  45,  22,  85, 105,  27,  51,  85,   6,   9,  68,
         

Epoch 1/10:  72%|███████▏  | 225/313 [00:15<00:05, 14.97it/s, loss=8.12]

argmax start logits shape: tensor([69, 12,  0,  0, 17, 94, 98, 57,  1,  3, 34,  2,  0, 20, 40, 88, 41, 46,
         0, 61,  2, 69, 21, 20, 26, 60, 52, 60, 17, 59, 61, 40, 49, 77, 80, 34,
         0,  0,  0,  0,  0,  7,  0,  0, 51, 88,  0, 86,  8,  2, 14, 24, 84, 66,
        88,  7, 86, 21, 77, 83, 47, 27, 90,  2], device='cuda:0')
argmax end logits shape: tensor([ 44,  21,  34,  40,  20,   8,  31,   8,  72,  51,  74, 101,  31,  56,
         86,  40,  71,  46,  16,  20,  52,  34,  79,  81,  83,  52,  47,  51,
        128,   4,  61,  36,  11,  37, 100,  30,  61,   4,  43,  18,  69,  17,
          7,  10,  11,  92,  48,  28,  16,  27,  18,  20,  45,  47,  75, 116,
         75, 112, 107,  37,  84,   3, 106,   6], device='cuda:0')
argmax start logits shape: tensor([ 50,   0, 122,   0, 111,  71,   0,   6,   0,   8,  19, 114,  72, 104,
          0,   0,  55,  34,  20,   0,   0,  41,  15, 117,  41,   0,  25,  36,
         99,  10,   0, 110,  18,   0,   0,   0,   0,   0,   0,   0,   0,  76,
   

Epoch 1/10:  73%|███████▎  | 229/313 [00:15<00:05, 14.94it/s, loss=8.27]

argmax start logits shape: tensor([  0,  70,  71,  76,   0,  19,  77,   0,  84,   0,   0,   0,  75,   0,
          0,   0,   0,   0,  43,  87,   0,   0,   0, 110,   0,  51,  52,   0,
         24,  48,  94,   0,  79,   0,   0,  82,   0, 131,  20,   0,   0,   0,
          0,   0,   0,   0,  31,   2,   0,  46,  53,  72,  35,  53,  42,  94,
         11,  68,   0,   0,   0,   0,   0,  14], device='cuda:0')
argmax end logits shape: tensor([ 90,  92,  34,  56,  34,   2,  65,  19,  63,  71,  17,  94,  84,  17,
          7,  61,  41,  21,  41,  19,  36,   7,  71,  80,  26,   2,  24,  18,
        100,  99,  30,  14,  17, 127,  86,  21,  21,   1,   9,  95,  44,   2,
         91,  42,  22,  65,  71,   2,  10,  28,  30,  38,  93, 103,  24,  99,
         75,   6, 106,  79,  51,  67,  27,   3], device='cuda:0')
argmax start logits shape: tensor([ 50,   0,   0,   0,  44,   0,  41, 103,   0,   0,  57,  74,  93,   0,
          0,   0,   0,   0,   0,  11,   0,   0,  25,   0,  85,   0,  86, 101,
         

Epoch 1/10:  74%|███████▍  | 233/313 [00:15<00:05, 15.03it/s, loss=8.18]

argmax start logits shape: tensor([113,  10,  53,  19,  55,   0,   0,  14, 116, 124,  39,   0,  13,  39,
         54,  85,   2, 152,   3,   9,  43,  69,  18,   6,  56,   0,  21,   2,
        100,   0,  53, 100,  22,  63,  41,   0,  45,  43,  11,   1,   2,  10,
          0,  92,   3,  33,   6,   0,  56,   0,  65,  43,   0,  12,   0,  31,
        108,  33,  90,  40,   0,  10,   2,   0], device='cuda:0')
argmax end logits shape: tensor([ 77,  23,  59,  19,  26,  61,   3,  31,  37,  50, 125,  81,  75,  75,
         47, 128,   2,   8,  55,   3,  32,  90,   8, 131,  56,  19,  22,   5,
        101,  66,  52,  71,  22,  65,  38,  34,  16,  29,  39,   1,   3,  20,
         27,  92,  26,  74,   2,  29,   3,  57,  78, 105,  62,  66,   1,  32,
         86,  63,  94,   8,  42,   3,  72,  59], device='cuda:0')
argmax start logits shape: tensor([101,  75,  40,  46,   0,  76,  34, 119,  83,  21, 120,  10,  33,   0,
          8,  19,  54,  71,  20,  41,  56,  36,   9,  65,  71,  64,  86,   0,
         

Epoch 1/10:  76%|███████▌  | 237/313 [00:16<00:05, 14.91it/s, loss=7.95]

argmax start logits shape: tensor([ 54,  78,  85,  74,  15,  35,   5,  98,  32,  39,  77,  79,   9,  50,
         33,  33,  58,  48, 116,  37,  67,  23,  34,  61, 103,  36,  90,  56,
         65,  43,  12,  55,  87,  26,  48,  81,  29,  61,  34,  18,   4,  88,
          7,  20,   0,  14, 105,  55,  78,   0, 111,  47,  55,  44,  26,  55,
         98,  24,  65,  77,   2, 107,  39,   2], device='cuda:0')
argmax end logits shape: tensor([ 73,  47,  41, 112,  35,  22,   5,   6,  59,  75,  37,  46,  80,  42,
        112,  50,   9,  26,  38,  74,  58,  47,  43,  54,  63,  32,  48,   4,
         16,  45,  69,  11,  48,   3,  11,  70,  24,   3,  13,  75, 147,  89,
         91,  92,  72,  20, 106,  63,  80,  18,  68,  21,  25,  21,  10,  43,
         94,  27,  34,  79,   2,  20,   3,  55], device='cuda:0')
argmax start logits shape: tensor([ 50,  59,  40, 103,   3,   0,  51,  10,  65,  11,  50,  72, 107,  81,
         68,   3,   5,  53,  60,  57,  82,   4,   0,  65,  84, 118,  75,  46,
         

Epoch 1/10:  76%|███████▋  | 239/313 [00:16<00:05, 14.76it/s, loss=7.92]

tensor([ 61,  57,  21,  36,  16,   3,  10,   0,  12,  21,  54, 100,  19,   2,
          0,  15,  30,  11,  43,  59,  50, 124,  75,  63,   0,   0,  70,   0,
          0,   8, 110, 127,   9,  23,  47,   0,   0,   3,   0,  45,   6,  43,
          0, 101,  50,  29,  19,  19,  22,  90,   0,   0,   0,   0,  73,  35,
         86,   0,  41,   0,   0,   0,  55,  64], device='cuda:0')
argmax end logits shape: tensor([ 61,  36,  47, 119,  22, 104, 127,  70,  39,   8,  34,  16,  19,  40,
        104,  97,  34,  34,  10,  42,  50,  81,  87,  39,  14,  16,  45,  58,
         47,  11, 130,   3,  24,  24,  79, 114,  46,  12,  34,  26,  86,   5,
         56,   5,  51,  17,  73,  20,  17,  78,  45, 112,  65,  37,  22,  36,
         76,   5,  41,   4,  49,  25,  71,  38], device='cuda:0')
argmax start logits shape: tensor([ 45,   0,  76,  52,  26,   0,  45,   0,  54,  47,  17,   0,   0,   0,
         11,   7,   9,   0,   0,  78,   5,   8,   0,   0,  24,   0,   0,   8,
          3, 130, 114,   0,   0,  33

Epoch 1/10:  77%|███████▋  | 241/313 [00:16<00:04, 14.65it/s, loss=8.09]

argmax start logits shape: tensor([ 82,   0,   3,  57,  22,   0,   0,   0,   0,   0,   0,  78, 120,  58,
         74, 111,   0,   5, 104,   0,  29,  33,  41,  88,   0,  56,  84,  33,
         79,  62,  57,  41,   0,  42,   2,  92,  73,  34,   0,   0,   0,  58,
          0,  70,   0,  96,   0,   0,  88,  36,   0,   0,  29,  73,  47,   0,
         43,  39,   0,  78,   0,   0, 134,  49], device='cuda:0')
argmax end logits shape: tensor([ 15,  25,  56,  25,  25,  51, 118, 138,  78,  43,  28,  78, 116,  59,
         59,  32,  17,  80,  65,  14,  25,   5,  83,  80,  46,  19,  71,  34,
         69,  19,  25,  41,  67,  45,  10,  21,   3,  60,  33, 110,  33, 116,
        105,  70,   3, 100, 113,  50,  62,  47,  46,  46,  73,  87,  38, 102,
         83,  41,  14,  78,  21,  23,  32,  30], device='cuda:0')
argmax start logits shape: tensor([ 62,   0,  27,   0,   0, 140,  30,  14,  19,   0,   0,  27,   0,   4,
         20,   0,  31,   0,   0,   0,   0, 136,   0,  90,  12,   0,   0,   0,
         

Epoch 1/10:  78%|███████▊  | 245/313 [00:16<00:04, 14.81it/s, loss=8.21]

argmax start logits shape: tensor([ 0,  0,  1,  0,  0,  2,  0,  0,  0,  0, 11, 32,  0, 34,  0, 59, 29,  0,
         0, 66, 31,  0, 24, 66,  0, 53,  0,  0,  0,  0,  3,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 49,  0,  0,  0,  0, 31,  0, 66, 30, 47,
        53,  1,  0,  0,  0,  0, 57, 53,  0, 98], device='cuda:0')
argmax end logits shape: tensor([ 43,  52,  36, 115,  66,  36,  13,  13, 124,  44, 121,  62,  14,  82,
         90,  44,   2,  58,  22, 107,  99,   6,  96,   2,  17,  30,  53,  70,
         54,  60,  13,  11,  66,  36,  52, 147,  23,  72,  31,  67,  26,  62,
         22,   2,  23,  19,  41, 127,  93, 131,  59,  81, 110,  90,  21,  47,
         59,  50,  42,   8,   9,  53,  81,  63], device='cuda:0')
argmax start logits shape: tensor([ 29,   0,   0,   0,   0,   0,   8,  85,  51,   0,   0,   0,   0, 100,
          0,   0,   0,  28,   0,  52,   0,  42, 110,   0,   0,  98,   0,   2,
          0,  26,   0,   0,  48,   0,   0,   0,   0,   0,   0,   0,   0,  89,
   

Epoch 1/10:  80%|███████▉  | 249/313 [00:17<00:04, 14.97it/s, loss=8.15]

argmax start logits shape: tensor([  0,   0,  73,   0,  80,  49,   0,   6,  55,  48,  69,   0,   0,  63,
         14,   1,  53,   0,  72,   0,   0,   0,   0,   0,  62,   0,  66,  14,
         20,   0,  23,   0,   0,   0, 118,  10,  59,  80,  82,  51,   0,  56,
          0,   3,   0,   0,  81,  60,   5,   0,  38,  61,   0,  41,  11,  75,
          0,  81,  24,   0,  59,  77,  60,   0], device='cuda:0')
argmax end logits shape: tensor([ 48,  41, 108,  31,  73,  83,  11,  80,  47,  66,  31, 113,   6,  14,
         58, 108,  77,  22, 134,   3,  25,  64,  50, 115,  89,  41,   3,  40,
        121,   3,  25, 100,  78,  40,  52,  58,   3,  63,  68,   3,  43, 102,
        129,  35,  12, 100,  65,  74,  29,  94,  61,  77,  59,  50,  46,  71,
         12,  37,  96, 100,   7,  67,  61,   4], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  20,   0,  46,  92,   0,  29,   0,   0,  24,  45,   0,  12,
         59,  79,   0,   0,   0,  34,   0,  26,  19,   0,  63,   0,   0,   0,
         

Epoch 1/10:  81%|████████  | 253/313 [00:17<00:03, 15.05it/s, loss=8.08]

argmax start logits shape: tensor([ 51,  15,   0,   0,  30,  44,  12,  51,   0,   0,  28,   0,  31,   0,
         54,   0,  35,   0, 123,  27,   2,   0,  79,   0,  54, 140,   0,   0,
          0,   0,   0,   0,   0,   0,   8,   0,  81,  17,   0,  99,  68,   0,
         28,  24,   0,   0,  52,  24,   0,  64,   0,   0,   4,   0,   0,   0,
         50,  66,  24,   0,   0,   0,   3,   0], device='cuda:0')
argmax end logits shape: tensor([ 85,  76,  20,  45,   5,  22, 131,  66,  18,  21,  80, 107,  12,  47,
        101,  37,  57,   8, 124,  36,   2, 114, 107,  72, 122,  45,  52,  32,
          5,  37,  16,  39,   2,  21,  16,   2,  26,  50,   5, 101,  35,  98,
         26,  96,  79,  71, 101,  25,   9,  90,  35,  36,  79,  97, 123,  44,
          4,  61,  13,  35,  91,  36,  27,  35], device='cuda:0')
argmax start logits shape: tensor([ 64, 137,  66,  41,   0,   0,  80,   0,   2,  56,  79,  80,  24,   9,
         45,  53,   5,   2,   0,  10, 101,   0,  42,  85,  27,   0,  59,  42,
         

Epoch 1/10:  82%|████████▏ | 257/313 [00:17<00:03, 15.22it/s, loss=8.31]

argmax start logits shape: tensor([ 87,  66,   0,  27,   0,  51,   0,  46,   0,  52, 111,  20,  41,  71,
         12,  30,  20,  51,   0,   6,  11,  22,   0,  41,  25,   0,  47,  14,
          0,   0,   0,  83,   0,  73,  33,  69,   0,  80,  55,  67,  85,   5,
         18,  22,  89, 110,   0,  84,  16,  79,  19,  42,  67,  49,  59,  96,
          0,   6,  93,  75,  60,  44,  26,  25], device='cuda:0')
argmax end logits shape: tensor([ 25, 112,  37,  11,  55,  67,   3,  38,  41,  24,  78,  13,   4,  72,
         69,  25,  17,  48, 107,  72,  48,  17,  37,  21,  67,  61,  11,  12,
         58,  66,  63, 100,  91, 108,  73,  19,  96, 110,  55,  22,  99,  49,
         41,  23, 121,  85,  97,  74,  94,  16,  65,   4, 105,  76,  59,  96,
         17,  41,  11,   6, 114,  62,  21,  26], device='cuda:0')
argmax start logits shape: tensor([110,  19,   0,  54,   0,  89,   0,  85,  75,   3,   0,   0,  86,  86,
          0,   0,  38, 101,  71,  38,   0, 119,   0,   0,   0,  17,   2,  10,
        1

Epoch 1/10:  83%|████████▎ | 259/313 [00:17<00:03, 14.86it/s, loss=8.12]

argmax start logits shape: tensor([ 35,  24,  16,  19, 102,  72,   0,   0,  48, 112,   0,  33,   0,   0,
          0,  99,   0, 110,   2,  59,  73,  61,  32,   0,  25,  62,   0,   1,
         66,  68,   0,  54,  56,  32,   0,   0,   0,   0,   0,   0,   0,  16,
         44,  53,   0,   0, 114,  11,  80, 112,   0,   0,  60,  19,  15,   0,
          0,  83,  34,   0,  67,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 35,  76,  83,   4,  51,  24,   3,  60,   3,  92,  89,  40,  77,  22,
         37,  99,  30, 110,  57,  14,  73,  44,  18,  20,  48,  60,  97, 122,
         46,  72,   4,  23,   3,  46,  54,   5,  22,  31,  20, 111,  16,  73,
         29,  53,  23,  20,  83,  11,  40,   7,   3,  37,  60,  19,  20,  66,
         66,  23,  99,  46,  32,  11,  52,  43], device='cuda:0')
argmax start logits shape: tensor([  0,  51,  16, 105,  35,  72,   0,  61,   0,  33,   4,  51,  77,   0,
         11,  72,  34,   9,   6,  88,  53,   0,  75,   0,  86,  21,  40,  79,
         

Epoch 1/10:  83%|████████▎ | 261/313 [00:17<00:03, 14.94it/s, loss=8.19]

argmax start logits shape: tensor([ 55,  44,  25, 112,  44,  30,  42,  53,  39,  92,  36,  25, 107, 110,
        105,  20,  11, 110,  81,  51,  96,  42,  42,  18,  26, 108,   3,  84,
         93,  89, 103,  46,  12,  34,  66, 107, 112,  18,  19, 113,   0,  51,
         37,  60,  92,  90,  45,  85, 116,  90,  79,  10,  73,  26,  21,  11,
         54,   3,   8,   3,  36,  63,  30,  88], device='cuda:0')
argmax end logits shape: tensor([ 55,  48,  70,  18,  92,  28,  17,  53,  40,  62,  36,  28, 102,  83,
        100,   8,  31,  54,  81,  47,  96,  42,  50,  18,  26,  78,   3,  58,
         95,  89,  90,  93, 115,  34, 106,  57,  15,   2,  39, 121,  65,  90,
         76,  56,  84, 106,  20,  89,  39,   8,  74,  15,  73,  45,  79,  63,
        104, 151,  21,  41,  36,  49,  46,  16], device='cuda:0')
argmax start logits shape: tensor([ 38,  38,  15,  55,  92,  16,  56,  23, 104,   5, 110,  10,  10,  36,
         25,  45,  47,  68, 112,  42, 105,  94,  82,  31,  44,  44,  48,  97,
        1

Epoch 1/10:  85%|████████▍ | 265/313 [00:18<00:03, 15.00it/s, loss=8.27]

argmax start logits shape: tensor([  0,  89, 122,  77,  37,  43,  21,  31,   0,   8,  86,   5,  30,  26,
          0,  77, 102,   2,  28,  97,   0,  15,  32,  70,  11,   8,  35,  37,
         29,   0,   0,  23,   3,  23, 106,   0,  19,  99,  28,  48,   8, 138,
          0,   0,   0,  23,   0,  19,   0,   0, 112, 118,   0,  10,   0,   8,
          2,   2,  19,  27,  52,  74,   2,  77], device='cuda:0')
argmax end logits shape: tensor([ 36, 106, 124,  69,  32,  23,   1,  49,   2,  60,  21,  25,  42,  31,
         33,  78,  80,   5,  35,  89,  77,  56,  47,  46,   6,   9,  36,   2,
         82,  23,  71, 102, 111,  64, 133,  53,  36,  14,  10,  48,  80,  28,
          2,   2,  90,  26,  65,  60,  44, 107,  17,  52,  87,  12,   7,  54,
          5,  41,   3,   9,  65,  64, 112,  44], device='cuda:0')
argmax start logits shape: tensor([ 81,   0,   0,   0,  23,   0,   0,  26,  65,   0,  65,  54,  92, 115,
          7,  32,   9, 110,   0,   0, 123,  72,  53,  51,  89,  72,  33,  97,
        1

Epoch 1/10:  86%|████████▌ | 269/313 [00:18<00:02, 15.17it/s, loss=7.82]

argmax start logits shape: tensor([ 68,  57,   3,  14,  76,  57,  18,  75, 118,  86,  39,   8,   2,  13,
         37,  27,  75,   1,  92,  15,  58,  49,  87, 111,   8,  44,  44,  25,
          9,  14,  91,  33,  44, 105,   3,  93,  28,  16,  77,   3,  18,  97,
         59,  85,  41,  75,  49,  44,  78,  98,  41,   3,  62,  51,  37,   2,
         63,  15,   8,  28,   5,  36,  54,   2], device='cuda:0')
argmax end logits shape: tensor([ 24,  78,  71,  82,  56,  30,  50,  55,  71,  28,  20,  15,   3,  99,
         28,  30,  24,  62,   3,  15,  55, 126,  83,  58,  53,  60,  33,  31,
         10,  94,  91,  46, 115,  36,  48,  24,  48,  16,  95,  32,  22,  48,
         48,  12,  70, 103,  62,  44,  80,  86, 101,  55,  15,  80,  37,  46,
         83,  74,  50,  18,  89,  37,   6,   9], device='cuda:0')
argmax start logits shape: tensor([ 85,  78,  28,   9,  51,  38,  12,  51, 113,  97, 109,   2,  32,  75,
         89,  27,  70,  83,  68,  30,  67,  68,  95,  87,  71,   0,  79,  18,
         

Epoch 1/10:  87%|████████▋ | 271/313 [00:18<00:02, 15.22it/s, loss=8.12]

argmax start logits shape: tensor([ 46,  43,  49, 117,  68,   6,  50,  80,   2,  23,  96,  30, 102,  57,
         46,  28,   9,   2, 105,  50,  11,  16,  11,   3,  48,  66,  71,  18,
          7,  73,  70,  59,  20,  71,   8,   0,  40,  56,  28,  68,  87,  21,
         75,  13,   0,  93,  22,  52, 106,  78,   0, 106,  60,  53,  10,   0,
         17,  45, 104,   5,  62,  78,  10,  17], device='cuda:0')
argmax end logits shape: tensor([ 53,   5,  59,  31,  24,   9,  26,  67,   2,  25,  10,  32, 102,  89,
         69,  32, 102, 106, 112,  24,   5,  45,  34,   3,  93,   7,  97,   6,
         63,  25,  83,  30,  46, 113,  34,   3,  22,  56,   2, 123,  22,  15,
         94,  16,   5,  93,  67,  20, 108,  35, 103,  44,  98,  53,   8,   5,
         17,  54, 104,   5,  65, 117,  10,  42], device='cuda:0')
argmax start logits shape: tensor([ 15,  64,  38,   0,  59,  21,  88, 108,  43,  10,  75,  23,  91, 127,
          7,   0,   0,   6,   0,  31,  80,   2,   0,   0,  83,  23,  44,   2,
         

Epoch 1/10:  88%|████████▊ | 275/313 [00:18<00:02, 15.05it/s, loss=7.98]

argmax start logits shape: tensor([ 35,  31,  40, 113,  25,  69,   0,   2,  51,  98,  64,  63,  72, 106,
          0,  13, 132,  97,  78, 114,  77,   7,   6, 100,  42,  22,  57,  58,
          2,   2,  48,  17, 110,   0,  47,  28,  24,  63,  65,   0,  24,  18,
         70,  73,  67,  55,   7,  26,   0, 104,  55,  15,   0,  23,  13,   5,
         29,  61, 102,   3,   0,   2,  78,  23], device='cuda:0')
argmax end logits shape: tensor([  5,  23,  47,   8, 121,  22,  17, 140,  32,  90,  55,  37,  84, 124,
         19,  14, 116,  57,  19, 115,  79,   7,  10,  71,  59,  66,  59,  95,
         56,   3,  49,  24,  47,  13,  58, 128, 104,  52,  68,  43,  25,  19,
         34,  83,  68,  17,  71,  27,  23,  88,  56,  70,  64,  89,  42,   7,
         25,  64,  21,   8, 110,  14,   5,  64], device='cuda:0')
argmax start logits shape: tensor([114,  50,  61,  13,  80,   0,  44,   0,  94,   2,  88,  71,  78,  32,
         94,  44,  36,   0,  99,  85,  45,  51,  95,   7,  15,  90,  91,  47,
        1

Epoch 1/10:  89%|████████▉ | 279/313 [00:19<00:02, 14.88it/s, loss=7.79]

argmax start logits shape: tensor([ 65,  64,  21,  54,  34,  87,  23,  93,  22,   3,   0,  26,  81,  89,
         59,   0,  24,  20,   3,  53,   0,  21,  11,  82,   0,  61,  31,  55,
          8,   0,   5,  37,  69,  86,   0,   0,   3,  54,   2,  42,  78,  38,
         93,  36, 116,  57,  56, 106,  53,  99,  86,  80,   2,   0,   0,  18,
         43, 111,   0,  53,  35,  44,  87,   1], device='cuda:0')
argmax end logits shape: tensor([105,  23,  83,  11,  35, 131,   3,  73,  42,  56,  12,  22,   5,  55,
         84,  33,  42,  73,  10,  60,   3,  37, 101,  13,  19,  85,  83,  56,
         20,   9,   2,  58,  37,  82,  13,   6,  86,   7,  64,  53,  77,  90,
         22,  57,  65,  19,  27,  50,  93, 104,  76,  16,   3,  41,  56,  71,
        111,  61,  11,  60, 101,  93,  41,   3], device='cuda:0')
argmax start logits shape: tensor([  8,  63,   6,  11,  31,  17, 114,  41,  22,  80,  47,  51,  70,  82,
         67,  21,  78, 108,   5, 113,  59,   3,  26,  81,   7,  17,   6,  42,
         

Epoch 1/10:  90%|████████▉ | 281/313 [00:19<00:02, 14.97it/s, loss=8.23]

argmax start logits shape: tensor([108,  16,  70,  54, 106,  77,  54,   4,  97, 117,  95,  57,  66,   8,
         58,   2,   3,  13,  35,  66,  51,  15,   2,  46,  47,  58,  60, 124,
         24,  43,  24,  53,  11,  40,  22,  96,  16,  62, 111,  54,  69,   9,
         29,  24,  37, 136,  27,   3,   2,  25,  42, 111, 100, 133,  26, 115,
         74,  33,  50, 100,  87,  42,  23,   7], device='cuda:0')
argmax end logits shape: tensor([ 70,  29, 136,  72,  48,  17,  92,   4,   2, 117,  39,  59,   1, 105,
         12,  36,  20,  23,  16,  66,  51,  95,  61,  46,  64,  31,  62,  23,
         51,  63,  80,  53,  45,   5,  32, 100,  16,  23, 131,   8,  31, 109,
         85, 116,  27, 134,  14,  53,  66,  10,  58, 104,  75,  85,  43, 108,
         43,  24,  65,  15,  10,  10,  34,  74], device='cuda:0')
argmax start logits shape: tensor([ 13,  24,   8,  74,  31,  56,  25,   3,  90,  37,   2,  35,  33,  86,
         34,  95,  65,  19,  74, 102,  12,  65,  17,  26,  86,  59,  43,   8,
        1

Epoch 1/10:  91%|█████████ | 285/313 [00:19<00:01, 15.13it/s, loss=8.44]

argmax start logits shape: tensor([ 68,  51,  70,  79,  48,  19,  17, 113,  28, 109,  21,  66,  69,   9,
         50,  23,   7,   3,   6,  95,   2,  22, 100,  96,   4,  56,  24,  90,
         42, 101,   2,  25,  81,  83,  36,  31,  93, 126,   7,  90,  66,  23,
         69,  21,   3,  41,  58,  92,  31,  44, 104,  48,   3,  95,  10, 108,
         26,  44,   5,  57,  22,  15,   6,  40], device='cuda:0')
argmax end logits shape: tensor([ 42,   2,  60,   3,  40,  22, 100,   5,  31,  74,   2,  66, 110,   7,
         46,  61,  10,   6,  34,  95,   3,  65,   3, 121, 120,  56, 106,   6,
         61,   5,  45,  79,  17, 100,  37,  40, 109,  67,   7,  48,  82,  43,
         59,  26,   3,  42,  59,  92,  13,  44,  13,  45,   6, 119,  23, 102,
          1,  31,  22, 133,  37,  15,  79,  57], device='cuda:0')
argmax start logits shape: tensor([ 71,  93,  78,  28,   2,  26,  26, 102,  26,  58,  49,  41,  74,   5,
         74,  47,   8, 126,  91,  54,  45,  82,  40,  30,  93,  62, 112,  20,
         

Epoch 1/10:  92%|█████████▏| 287/313 [00:19<00:01, 15.17it/s, loss=8.14]

argmax start logits shape: tensor([ 86,  31,  43,   0,  17,  21,  48,   0,   8,  41,  72,  25,  15,  57,
         23,  33,  41,   0,   0, 128,   6,  51,  16,  11,  88,  28,  36,  45,
        102,  61,  28,   8,  65,  40,  57,   2,   0,  92,  52,  17,  39,  69,
         20,  90,  40,  96,  15,  69,  81,   0,  21,  41,   8,   7,  34,  22,
         57,  22,  63,   0,   8, 127,  42,  34], device='cuda:0')
argmax end logits shape: tensor([ 95,  42,  91,  14,  25,  33,  34,  26,  87,  45,  81,  50,  80,  61,
          3,   3,  45,  26,  21,  99,  36,  29,  11,  48,  74,  31,   9,  20,
        121,  10,  46,  21,   6,  36,  58,   2,  73,  16,   4,  17,  42,  31,
         42,  91,  52,  77,  23,  32,  62,  35,  15,  73,   2,  94,  38,  65,
         63,  45,  50,  32,  11,  42,  63,  73], device='cuda:0')
argmax start logits shape: tensor([ 40,  67, 108,   0,  22,  99,   0,  20,  26,  28,  38,  75,   0,  41,
          8,   0,  37,  28,   0,  52,  90,  29,   0,  26,  36,  64,   0,  92,
         

Epoch 1/10:  93%|█████████▎| 291/313 [00:19<00:01, 15.14it/s, loss=7.82]

argmax start logits shape: tensor([  0, 107,  47,   0,   0,   9,   0,   0,  48,   0,  40,   0, 114,  57,
          0,  81,  62,  30,  54,  63,  91,  52,   0,  86,   0,  56,  64,  66,
          0,   6,   0,  93,  88,  41,   0,  44,  53,  75,  33,  92,  34,   0,
         54, 109,  97, 128,   0,   3,   0,   0,   0,  51,  10,  83,  86,  66,
         47,  44,   0,  13,  70,  60,   2,  39], device='cuda:0')
argmax end logits shape: tensor([ 53,  57,  14,   5,  54,  60, 109,  64,  48,   3,  62,  70, 116,  50,
         17,  17,  47, 109,  72,  70,  12,  17,  77,  26,  80,  30,  23,  11,
         26,  88, 113,  30,  92,   2,  53, 116,   6,  80,  49,  16,  27,  14,
         17,  61, 102,  71,  87,  43,  54,  81,  23,  42, 105,  64,   6,  11,
         19,  95,  14,  15,  75,  12,   2,  52], device='cuda:0')
argmax start logits shape: tensor([  0,  58,  84,  41,   0,  24,   0,  39,   0,  31,  55,   0,   0,  71,
         22,  34,  57,   0,  30,   0,  75,   0,  14,  88,   2,   0,  26,  47,
        1

Epoch 1/10:  94%|█████████▎| 293/313 [00:20<00:01, 15.22it/s, loss=8]   

argmax start logits shape: tensor([ 91,  26,  82,  41,   0,   0,  26,  45,   2,   0,  42,  67,  50,   0,
          1,   0,  76, 108, 115,  34,   0,   6,  58,   0, 123,  24,  70,   0,
         78, 101,  50,   6,  45,  50,   8,  35,   0,  76,   0,   0,  41,   0,
        107,  78, 102,   0,  68,  44,  95, 110,  68,  95,  76,   2,  90, 114,
         36,   6,  37,   3, 125,  35,  37,  27], device='cuda:0')
argmax end logits shape: tensor([ 59,  40,  32,  32,   5,  48,  63,  30, 103,  65,  45,  69,  63,  26,
         34,  46,   3,  90,  35,  36,  66,  79,  58,   2, 130, 103,  81,  68,
         79, 117,  51,  28,  66,  52,   5,  52,   7,   5,  60,  50,  15,  83,
        123,  53, 108,   9,   2,  17,  66,  36, 104,  14,   9,   3,  80, 115,
          8,   6,  74,  15, 121,  61,   6,  27], device='cuda:0')
argmax start logits shape: tensor([  0,  16,  46,   0,  59,  40,  19,   9,  52,  30,  50,   0,  50,   6,
         33,  12,  85,  65,  64,   0,  62,   0,  75,  63,  68,  40,  80,  13,
         

Epoch 1/10:  95%|█████████▍| 297/313 [00:20<00:01, 14.67it/s, loss=8.06]

argmax start logits shape: tensor([110,   8,  17, 123,   8, 130,   5,   0,  26,  42,  74,  72,  88,  38,
         16, 114,  22,   3,  36,  39,  48,  33,   6,  75,   6,  50,  63,  19,
          7,  89,  67,  31,  10,  42,   3,  80, 101,  72,  45,  15,  20,   8,
        114,  39,  66,  64,  61,  18, 111,  56,  67,  70,  38,   6,  56,   5,
         31,  36,  22,  22,  47,  16,  43,  12], device='cuda:0')
argmax end logits shape: tensor([ 43,  41,  27, 123,  19,  47,  67,  12,  19,  53,  23,  44,  88,   7,
         65,  87,  13,   3,   6,  39,  68,  90,   6,  19,  21,   3,  63,  19,
         34,  44,  56,  69, 111,  94,  59, 112,   3,  46,  12,   5,  85,   8,
         91,  43,  34,  91,  56,  20, 113,  42,  95,  95,  15,  17,  51,  33,
         31,  39,  44,  22,   1,  16,  94, 100], device='cuda:0')
argmax start logits shape: tensor([ 92,  91,  99,  11,   0,   2,  64,  15,  85,  77,   0,  31,  79,   2,
         91,  55,  68,  18,  46,   3,  39,   8,  99,  43, 104,  58,  70,  25,
         

Epoch 1/10:  96%|█████████▌| 299/313 [00:20<00:00, 14.84it/s, loss=8.03]

argmax start logits shape: tensor([100,   0,   0,  43,   9,  40,  51,   8,  81,  17,  61,   0,  71,  14,
        101,   0,   0,  43,  49,  94,   0,  64,   0,  52,  51,  60,   2,  38,
         20,   0,   0,  79,   0,   0,  15,  19,  50,  76,  12,  49,   6,  23,
          0,   2,  25,  46,  41,  17,  21,   0,   0,  76,  67,  63,  14,  80,
          2,   0, 118,  34,  57,   4,  46,   0], device='cuda:0')
argmax end logits shape: tensor([127,  58,  33,  46,  11,  11, 106,  14,  50,   9, 100,  84,  17,  99,
         50,  81,  66,  16,  47,  54,  44,  64, 101,  42,  51,  29,   2,   9,
          4, 102,  82,  79,  29,  35,   3,  83,  45,  73,  55,  51, 131,  24,
         14,  67,  26,  57,   5,  17,  24,   4,  45,  92,  67,  70,  72,  53,
          8, 104, 120,  15,  58,  36,  57,  95], device='cuda:0')
argmax start logits shape: tensor([ 51,  70,  93,   5,  84, 113,  82,  59,   0,  11,  23,   0,   0,  90,
          0,  41,  23,   0,  92,  34,  34,  13,   0,  67,  40,   0,   0,   0,
         

Epoch 1/10:  97%|█████████▋| 303/313 [00:20<00:00, 15.10it/s, loss=7.92]

argmax start logits shape: tensor([ 62,   0,  70,  12,  25,   0,  12,   0,   0,   0,   0,   0, 113,  43,
          2,   0,  36,   0,   0,  48,   0,   0,   0,   2,  84,  68,   0,  27,
          0,   0,   0,  11,   0, 107,  48,   0,  57,   0,   0,  16,  45,   0,
          0,   0,   0,   0,  46,  36,  69,   0,   0,   0,  63,   0,   0,   0,
         46,   0,  18,   0,  67,   0,   0,  11], device='cuda:0')
argmax end logits shape: tensor([79, 21, 17,  2,  3,  1, 14, 34, 26, 81, 38, 81, 90, 58, 42, 55, 95, 15,
        56, 25, 24, 74, 97,  5, 75, 64, 42, 37, 64, 20, 34, 14,  5, 24, 62, 63,
         5, 51,  3, 67, 85, 14, 63, 18, 43, 17, 57,  2,  5, 10, 15,  4,  4, 27,
        25,  3, 62, 49,  5, 65, 67, 29,  3, 11], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,  10,   0,   0,   0,  83,   0,  52,  58,  99,  24,   0,
          0,   0, 118,   0,  87,   0,   0,   7, 116,   0,   0, 113,   0,   0,
        114,  19,  13,   0,  45,   0,   0,  62,   6,   0,   0,  51,   0,   0,
   

Epoch 1/10:  97%|█████████▋| 305/313 [00:20<00:00, 14.99it/s, loss=8.08]

argmax start logits shape: tensor([  0,   0, 101,   0,   2, 111,  75,  31,  25,   0,  36,  56,   0,  59,
         89,   0,  83,   0,   0,   1,  53,  17,  54,   0,  19,   0,   0,   0,
         18,  72,   2, 110,   0,  15,  30,   0,  18,  35,   0,  21,  54,   0,
         10,  54,  49,  66,  87,   0,   0, 102,   0,  19,   8, 119,  90,   0,
          2,   0,  16,   0,  55,   0,  29,  40], device='cuda:0')
argmax end logits shape: tensor([ 67,  57, 101,  76,   8,  25,  10,  45,  13,  96,  49, 128,  14, 100,
         28,  80,  19,  24,  33,  77,  97,  67,  23,  50,  65,  81,  82,  11,
         61,  19,  37,  78,  17,  18,  83,   2,  12,   8,  97,  21,   5,  87,
          4,   2,  65,  80, 102,  42,  55,  84,  56,   3,  80,  56,  99,  48,
         41,   7,  16, 109,  88,  35,  45,  73], device='cuda:0')
argmax start logits shape: tensor([ 83,   0,   0,   0, 106, 102,  51,   0,  16,   0,   0,  21,  49,   9,
          0,  19,  27,   0,   0,  63,  87,  91,   8,  81,  17,  27,  84,   0,
         

Epoch 1/10:  99%|█████████▊| 309/313 [00:21<00:00, 14.88it/s, loss=8.05]

argmax start logits shape: tensor([ 84,  22,  42,  49, 122,   0,  34, 108,  19,  23,   0,  65,   5,   8,
         25,  46,  28, 101, 110,  65,   0,   3,  13,  25,  45,   0, 100,   0,
         99,  61,  75,   6,  37,  22,  13,  81,  88,  89,  44,   0,  46,   0,
          2,   0,   0,  61,  94,   0,  85,   3,   0,   0,  72,  70,   0,  55,
         98,  46,   0,  71,  27,   0,  12,   0], device='cuda:0')
argmax end logits shape: tensor([ 56,  36,  29,  19, 102,  19,  44,  18,  20,  23,  63,  67,  17,   6,
         70,  75,  53,  99, 111,  58,  29,  41,  30,  76,  41,  45, 101,  43,
         11,  20,  26,  99,  38,  23,  53,  26, 114,  75,  38,  36,  73,   8,
         99,  53,   8,  10,  77,  46,  85,  14,  11,  37, 105,  16,  69,  55,
         48,  46,  75,  71,  88,  71,  13,  81], device='cuda:0')
argmax start logits shape: tensor([ 32,  36,   0,   0,  14,  29,   0,   0,  34,  56, 104,  75,  19,  63,
         46,  43,  99,  31,  43,  64,  69,  31,   0,  33,   0,  18,  37,  13,
         

Epoch 1/10: 100%|██████████| 313/313 [00:21<00:00, 14.73it/s, loss=7.83]


argmax start logits shape: tensor([ 43,  51,  34, 115,   2,   0, 120,   9,  94,  36,   0,  54,  69,  24,
          0,  75,   0,   0,   0,  23,  75,   6, 103,   0,  56, 100,  41,  17,
         71, 114,  98,  90,  24,   0,  24,  16,   0,   0,  23,   0,  55,   0,
          0,   7,  31,   0,  61,   0,   7,  83,   0,   0,  48,   0,  11,  38,
          0, 109,  13,  13,   0,  25,  67, 133], device='cuda:0')
argmax end logits shape: tensor([ 47,  50, 116, 115,  15,   7, 118,  27,  58,  75,  16,  54,  46,  36,
          3,  78, 121,  48,  41,   6,  47,  36, 133,  15,  77, 138,  51,  16,
         49,  87,  72,  91,  20,  43,  27,  22,  12,  27,  28,   6,  61,  51,
         66,  20,   5,   5,  36,  53,  10,  35,  56,  22,  20,  63,  76, 114,
          2,  92,  23,  13,  43, 108,   4,  49], device='cuda:0')
argmax start logits shape: tensor([ 72,  51,  55,   2,  89,  48,  12,  31, 107,  89, 142,   2,  79,  44,
         96, 105,  40,  59,  45,  62,  47,   0,  25,   0,   0,  68,  18,  59,
         

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

argmax start logits shape: tensor([ 56,  76,  76,  27,  49,  49,  15,  41,   8,  81,  81,  48,  26,  16,
         69,  21, 142, 103,  79,   1,   1,   8,  10, 116,  83,   0,  46,  27,
         10,   9, 105,   0,  47,  51,  11,  88,  56,   2,  61,  50,  34,  34,
         27,  73,   6,  77,  25,   0,   5,   0,  36,  28,  88,  15,  51, 118,
          5,  18,  85,   2,  52, 119, 105,  89], device='cuda:0')
argmax end logits shape: tensor([ 22,  25,  36,  42,  81,  70,  56,  41,  10,  78, 155, 109,  26,  85,
          5,  69, 105,  96,  19,  33,  27,  52,  10, 117,  33,  67,  46,  53,
         16,   9,  31,   9,  45,  47,  56,   2,  38,  16,  62,  51, 106,  86,
          6,   4,  56,  69,  21,   2,  20,  37,  50,  25,  90,   9,  47,  68,
         30,  19,  15,  53,  12,  34,  52, 108], device='cuda:0')
argmax start logits shape: tensor([110,  50,  67,  23, 107,  25,  55,  61,  76,  17,  22,  84,  75,  82,
         25,  49,  43,  92,  39,  29,  80, 106,  25,  11,  78,  89,   8,  83,
         

Evaluating:   1%|▏         | 4/313 [00:00<00:08, 34.59it/s]

argmax start logits shape: tensor([  0,  29, 124,  83,  32, 101,  27,  26,  12,  76,  24,  97,  65,   1,
         60,  82,  78,   3,  23,  22,   2,  53,  16,  44,  86,  31,  95,  42,
         40,  34,  48,  36,  45,  86, 137,  83,  40,  67,  91,   0,   8,  14,
         56,  54,  92, 124,  23,   5,  32,  24,   2,  15,  32,  20,  52, 109,
         93,  35,  14,  51,  72,   2,  42,  78], device='cuda:0')
argmax end logits shape: tensor([ 10,  30, 125,  83,   8,  94,  65,   3,  10,  25,  25,  98,   7,  60,
         72,   7,  39,   3,   5,  56,  15,  70,  27,  41,  37,  70,  43,  70,
          7,   6,   2,  80,  86,  93, 132,  14,  19,  67, 128,   2,   8,  14,
         57,  58,  41, 145,   5,   5,  24,  58,   3, 102,  24,  61,  44, 104,
         50,  88,  49,  36,  74,  29,  21,  84], device='cuda:0')
argmax start logits shape: tensor([ 86,  93,  49,  95,  53,  33,   3,  35,  88,  55,   3,   3,   4,  11,
          2,  41,  64,  11,  55,  74,   5,  90, 104,  10, 107,  82, 107,  52,
         

Evaluating:   3%|▎         | 8/313 [00:00<00:08, 34.84it/s]

tensor([ 75,  82,  34,  55,  28,  46,  66,   3,  85,  27,  34, 118,  63,   3,
         60,  12,  87,  44,  75,  49,  66,  54, 121,  85, 108,   2,  17,  47,
         33,  70,   0, 142,  32,   2,  66, 109,  67,  50,   0,  60, 102,  80,
         84, 116,  76,  48,  93,  71,   0,  45,   6, 137,   2,  49,  25,  84,
          3,  71,  73,  43,  12,  53,  11,  98], device='cuda:0')
argmax end logits shape: tensor([ 93,  41,  22, 111,  38,  12,  20,   3,  15,  31,  76,  60, 102,  53,
         72,  14,  71,  44,  75,  23,  59,  83,  57,  80, 109,   3,  26,   9,
         64,  30, 110, 105,  15,   2,   2,  42,  41,   4,  48,  12,  33,  11,
          9, 116,  76,  50,  57,  73,  50,  89,  33, 139,  75,  32,  73,  81,
          4,  71,  15,  24,  56,  46,  41,  78], device='cuda:0')
argmax start logits shape: tensor([ 79,  54,   4,  31,  35, 101, 133,  49,   2,  15, 134,  45,  23,  12,
         20,  70,   7,  36,  33,  47,  19,  17,  15,  25,  14,  74,  70,   0,
         37,  24,   0, 114,   5,  66

Evaluating:   4%|▍         | 12/313 [00:00<00:08, 33.57it/s]

tensor([ 14,   2,  19,  36,   2,  37,  17,  76,  33,  27,  68, 122,  62,   9,
         81,  38, 100,  72,  24,  13,  75,  79, 106,   6,  75,  45,  71,  32,
         15,  95,  31,  53,   0,  81,  87,  26,  78,  85,  89,   0,  37,   2,
         76,  33,  26,  29,  13,  59,  46,  95, 109,  25,  14,  55,  41,  68,
         14,  77, 132,  24, 123,  26,  79,  35], device='cuda:0')
argmax end logits shape: tensor([ 15,   3,  67, 107,   2,  43,  31,  48, 117,   7,  32,  31,  62,  29,
         80, 114,  75,  67,  20, 106,  35,  79,  52,  92, 113,  29,  71,   3,
         29, 110,  79,  23,  10,  78,   9,  19, 101, 106,  21,  20,  16,  18,
         82,  11,  28, 106,  93,  59, 116,  39,  90,  40, 101,  29, 130,   3,
         61,  79, 107,  48,  84,  17, 107,  17], device='cuda:0')
argmax start logits shape: tensor([ 16,  30,  14,   8,   9,  10,  25,  86,  76,  57,  89,  78,  24,  69,
          7,   2,  40,  58, 113,  69,  90,  57,  71,  15,  54, 114,  45, 132,
         34,  29, 127,  30,  51,  17

Evaluating:   5%|▌         | 16/313 [00:00<00:08, 33.93it/s]

tensor([ 14,  89,  19,  76,  11,  38,   1,  40,  23,  12,  58,  67,   6,  24,
        108,   6,   0,  31,  79,  45,  10,   2,  34,  59,  60,   7,  41,  28,
         36,  24,  49,   8,  23,  26,  67,  88,  88,  12,  27,  63,  21,  14,
         21, 128,  53,  99,  62,  39,  97,   0,  22,  32,  42,  68,  24,  32,
         49,  69,  78, 131,   8,  46,  48,  51], device='cuda:0')
argmax end logits shape: tensor([ 13,  42,  28, 107,  26,   9,  23,  19,  46,  62,  72,  67,   4,  79,
         24,   6,  86,   2,  79,   5,  32,   3,  36, 104,  72,  39,  20,   7,
          2,  10,  49,  42,  96,   2,  49, 110,  10,  79,   2,  63,  29,  72,
         10, 128,   9,  78,  14,  64,  30,  36,  45,  87,   3,   3,  89,  24,
         32,  84,  13,   8,  26, 102,   2,  80], device='cuda:0')
argmax start logits shape: tensor([ 11, 128,   3, 126,  71,  76, 118,  41,  83,  23,  75, 156,  56,   0,
        116,  61,  25,  90, 108,  19, 121,  43,  15,  25,   8,  14,   7,   0,
         81,  72,   5,  46, 106,   8

Evaluating:   6%|▋         | 20/313 [00:00<00:08, 33.91it/s]

argmax start logits shape: tensor([ 53,  29,  79,  33,  57,  21,   1,  38,  39,   2,  30,  23,   8,  34,
        113,  71, 106,  66,  68,  33,  78,  95,  80,   0,  28,  11,  26,  42,
         88,  77,  36,  81,  31,  65,   2, 110,  55,   2,   0,  18,   0,  86,
         89,   2,  82,  39,  44,   2,   3,  28,  20, 102,  14,  39,  90,  88,
         36,  22,  14,  91,  23,  70,  94,   8], device='cuda:0')
argmax end logits shape: tensor([ 53,  33,   4,  98,  22, 109,  62,  38,   5,  49,  66,  42,  92,  34,
          5, 100,   3, 127,  50,  18,  22,  35,  99,  16,  26,  26,  84,  23,
        117,  12,  35,  55,  29,  38,  45, 108,  16,   3,  62,  11,  58,  90,
         51,   3, 102,  58,  20,  60,  10,  59,   7, 104, 110,  39,  90, 106,
          9,  23,  45, 117, 129,  41,  94,  64], device='cuda:0')
argmax start logits shape: tensor([  8,  61,  20, 123,  16,  47,  12,  30,   0,  56,  41,  21,  39,  72,
         12,  18,  31,  13, 123,  53,   2,  41,  44,  18,  47,  78,  16,  31,
         

Evaluating:   8%|▊         | 24/313 [00:00<00:08, 33.29it/s]

tensor([ 57, 101,  41,  67,  41,  26,  89,  16,  13,  59,  11,  21,  56,  30,
          2,  74, 114,  35,  80, 100,  99,   8,  20,  40,  30,   0,  52,  28,
         37,  39,   2,  30,  72,  25, 100,  54,  39,  31,   3,   7,  35,  28,
         75,  41,  95,  55,  60,   0, 124,   0,  17,  42,   2,  78,  37,  68,
        107,   5,  44,  55,  13,   8,   3,  36], device='cuda:0')
argmax end logits shape: tensor([ 44,  28,  31,  70,  58,  26,  47,  34,  58,  24,  17,   2, 121,   7,
         14,  30, 103,  42,  53, 100,  18, 146,  21,  44,  66,  13,  57,  71,
         37,   9,  76,  30,  74,  28, 100,  81,   2,  16,   3,  43,  17,   7,
        114,  29,  86,  45,   7,  27, 125,  34,  54,  63,   4,  53,  37, 105,
         41,  54,  50,  55, 101,   5,  55,  81], device='cuda:0')
argmax start logits shape: tensor([ 42,  64,  61,  85,   1,   9,  28,   4, 110,  11,   2,  43, 105,  49,
         52,  11,  61,  45,  88,  90,  44,  87,   2,  60,   0,  49,  69,  52,
         15,  41,  86,  73,  72, 112

Evaluating:   9%|▉         | 28/313 [00:00<00:08, 33.50it/s]

tensor([ 48,  26,  27, 110,  75,  47,   3, 117,  23,  58,  48,  21,   2, 130,
        134,  94,  57,  84,  50,  21,  71,  28, 119,  63,  42,  38,  65,  59,
          1,  78,  76,  45, 102,  38, 100,  69,  23,  50,  25,  10,  84,  88,
         59, 110,  29,  16,  28,  12,  22,  87,  23,  56,  44, 101,  33,   2,
         43,  25,   8, 114,  40,  50,  22,  22], device='cuda:0')
argmax end logits shape: tensor([ 48,  36,  27,  10,  75,  52,   3,  66,  64, 105,  48, 121,   5,  47,
        134,   5,  23,  18,  56,  23,  73,  54, 119,  86,  17,  40,  41,  89,
         54,  29, 149,  47,  59,  24,  88,  49,  35,  67,  60,   5,  85,  78,
        111,  58,  44,   1,  31,  79,  46,  56,  12,  60,   5, 105,  33,  15,
        112,  25,  27,  45,  49,  56,  46,  54], device='cuda:0')
argmax start logits shape: tensor([ 22,   2,  50,  49,   2,   8,  65,  23,  70,  47,   3,  39,   3,  24,
         46,   8,  15,  64,  30,  19,  27,   5,  16,  17,   4,  48,  17,   2,
         19,   6,  54,   2,  75, 150

Evaluating:  10%|█         | 32/313 [00:00<00:08, 33.88it/s]

argmax start logits shape: tensor([ 15,  44,  75,  15,  27,   3,  23,  34,   9,  62,  54,  25,   3,  75,
         25,   3,  42, 101,  88,  17,   2,  44,  53,  20,   2,  37,  49,  80,
         21,  81,  76,  92,  22,   0,  86,  20,  30,  52, 105,  41,   0,  11,
         45,   1,  49, 105,  43,  46,  47, 101,   8,  76,  75,  52, 115,  21,
         15,  16,  14,  53,  11, 106,  20,   7], device='cuda:0')
argmax end logits shape: tensor([ 43,  28,  76,  41,  85,   3,  73, 106,  48,  29,  64,  35,  44,  35,
         94,  55, 124,  68,  35, 123,  24, 115,  21,  26,   1,  20,  62,  77,
         43,  23,  44,  24,  58,  37,  10,  46,  23,  52, 117,  54,  54, 113,
        106,  29, 114,  95,  88,  39,  58,  82,  21, 149,  61,  66,  71,  62,
         24, 110,  27,  93,  14,  20,   4,  36], device='cuda:0')
argmax start logits shape: tensor([ 72,  44,  49, 123,  61,  99,  25,  27,   0, 114, 107, 113,  65,  50,
         61,  53,  46,  34,  30,  66, 119,  86,  33,  56,  11,  29, 102,  42,
         

Evaluating:  12%|█▏        | 36/313 [00:01<00:08, 33.57it/s]

argmax start logits shape: tensor([ 46,  23,  26,  31,  56,  40,  54,  21,  86,  37,  47,  10,  10,  88,
         55, 124, 112,  12, 106,  71,  32,  15,  25,   0,  19, 101,  92,  52,
         35,  37,  62,  12,  50,  88,  19,  61,  48, 121,  86,  25,  39,  28,
         52,  78,  17,  54,  60,  83,  50,  30,  15,  79,  50,  16,  15,  45,
         17,  27,  49,  18,  40,  90,   8,  15], device='cuda:0')
argmax end logits shape: tensor([ 35,  46,  26,  88,  42,  92,  83,  15,  26,  16,  93,  69,  14,  88,
         38,  49,  48,  14, 106, 109,  55,  75,  86,  10,  67,  82,   6,   7,
          2,  30,  70,  13,  39,  74,  35,  30,  90, 122,  42,  38,  39,   7,
         52,  65,  84,  54,  49,  35,   3,  30,  28,  12,  48,  43,  37,  53,
         14,  14,  18, 143,  78,  93,  74,  74], device='cuda:0')
argmax start logits shape: tensor([ 48,   3,  94,  64,  91,  67,  34,  72,  19,  41,  47,  36,  62,  23,
         46, 123,   3,  42,   7,  56,  83,  64,  90,  16, 116,  49,   2,   3,
        1

Evaluating:  13%|█▎        | 40/313 [00:01<00:08, 33.72it/s]

argmax start logits shape: tensor([ 20,  82,  76,   2,   0,  87, 106,  40,  86,  49,  56,  41,  22,   8,
          2,  49,  85,   5,  53,  41,  38,  20,  43,  76,  70,  55,   7,  81,
         13,  63,  66,   2,   0,  33,  36,   0,  38,  15, 109,  56,   3,   0,
        109,  81,  63,  75,  70,  54,  41,  79,  50,  24,  68,  10,  31, 147,
         99,  27, 109,  31,  48, 109,  67,  33], device='cuda:0')
argmax end logits shape: tensor([127,  80,  26,   2,   8,  17,  20,  40, 135,   8,  36,  41,  54,  63,
          3,  14,  19,  38,  53,  42, 114,  21,  11,  60,   2, 100,  71,  82,
         54,  84,   2, 146,  12,  27,  11,   6,  96,  61, 136,  56,  53,  27,
         80,  78,  86,  55,  17,  26,  21,   4,  60,  24, 105,  69,  31,  34,
         97,  34, 113,   3,   7, 110,  14,  75], device='cuda:0')
argmax start logits shape: tensor([ 51,  86, 148,  44,  10,  81,  64,  22,   2,  90,  26,  37,  61,  80,
         62,  54,   8,  74,   5,  71, 104, 113,  38,   3,  46,  35,  68,  54,
         

Evaluating:  15%|█▌        | 48/313 [00:01<00:07, 33.89it/s]

argmax start logits shape: tensor([107,  20,  76,   2,   0,  10,  32, 106,  84,  34,  62,  61,  57,  72,
         15,   0,  87,  11,  69,  35,  30,  56,  35,  56,  50,  23,  17,  53,
        102,  81,  29,  18,  32,  57,  19,  26,  61,  46,  89,   0,  41,   2,
         52,   6,  66,  30,  14,  71,  81,  59,  46,  35,  88,  28,  24,  83,
         10,  56,  60,  51,  60,  88,   0,  67], device='cuda:0')
argmax end logits shape: tensor([107,  22,  48,  37,  46,  56,  30,  52,  92,  45,  83,   7,   7,   8,
         58,  70, 131,  11,  70,  35,  81,  10,  73, 121,  47,  32,  51,  53,
        104,  61, 106,  44,  65,   7,  19,  81,   2,  77,  77,  49,   2,  81,
         37,  33, 127,  31,  69, 109, 117,  59,  74,  74,  27,  29,  36,  35,
         79,  38,  61,  17,  45,  88,  75,  41], device='cuda:0')
argmax start logits shape: tensor([ 31,  70,   0,  50,  81,  40,  41,  33,   0,  43,  52, 127,  69,   2,
        109,  61,   4,   0,  71,  23,   0,  55,  75,  10,  33,   3,  93,   2,
        1

Evaluating:  17%|█▋        | 52/313 [00:01<00:07, 33.38it/s]

tensor([  8,  21,   8,  83,  31,  38,  55,  34,  19,  59,  50,  53,  36, 150,
         71,   7,   2, 109,  35,  19,  48, 107,  60,  83,  30,  22,  84,   6,
         16,  30,  32, 120,  77,   2,  72,   0, 102,  50,  10,  34,  34, 133,
         56, 129,  49,  92,  47,   0,  45,  21,  29,  91,  55,  34, 104,  61,
         33,   8,  85,  57,  30, 110,  96,  14], device='cuda:0')
argmax end logits shape: tensor([ 70,  43,  53,  21,  33,  36,  12,   5,  75,  55,  43,  30,  48,  89,
        112,  50,  59, 102,  20,  47,  49,  86,  63, 115,  32,  14,  85,   8,
         12,  10,   9, 120,  85,  50,  30,  52,  73,  37,  56,  72,  22, 127,
         10,  17,  30,  24,  55,  80,  29,  56,  29,  13, 116,   6,  93,  43,
         70,  17,  95,  93,  83,  34, 126,  40], device='cuda:0')
argmax start logits shape: tensor([ 21,  36,  30,  25,  59,   3,  11,  25,   8,  98,  69,  26,  21,  15,
          2,  90,  45,  89,  62,  81, 118, 128,  70,  22,  97,  38, 128,  22,
         34,  24,  43,  20,  80,   7

Evaluating:  18%|█▊        | 56/313 [00:01<00:07, 33.33it/s]

tensor([ 52,  76,  20,  10, 124,  13,  86,  49,  13,  46,  70, 113,  25,  54,
         71,  22,  25,  45,  24,  13,  15,  19,  92,   9,  67,  28, 111, 108,
        105,   3,  21,  75, 107, 111,  36,  13,  67,  50, 121,  35,  44,  92,
         62,  67,   6,  85,  32,   2,  45,  23, 103,   1, 110,  99,  33,  80,
         56,  95,  18,  89,  10, 124,   0,  42], device='cuda:0')
argmax end logits shape: tensor([ 64,  23,  25, 110,  29,  62,   5,   7, 116,   3,  12,   5,  20,  55,
         48,  20,  35,  22,  55,  51,  74,  19,  92,  25,  59,   7,  90,  52,
        106,  39,  62,   4,  36, 107,  61,  38,  72,   3,  99,  38,  50,  93,
         59,  48,  65,  46,  32,  24,  12,   3,  71,  67, 110,  69,  33,   7,
         15,  30,   8,  10,  25,  21,  25,  42], device='cuda:0')
argmax start logits shape: tensor([ 44,  12,  94,  31,  12,  69,  95,  34,  24,  55,  55,  60,   2,  45,
         56,  53,   0,  29,  26,  34, 102,   2,  40,  13,   6,   0, 106,  79,
         76,  85,  79,  40,  14,   2

Evaluating:  19%|█▉        | 60/313 [00:01<00:07, 33.70it/s]

argmax start logits shape: tensor([ 89,  15, 107,   8,  14, 100,  36, 105,  79,  15,  36,   0,   2,  42,
        104,  95,  81,   0,  77,  45,  79,   3, 102, 107,  26,  24, 141,  52,
         48,  70,  13,   2,  79, 121,   0, 100,   1,   2,  85,   0,  73,  50,
         63,  31,  87,   8,  26,  19,  72,  85,  21,   5,  50,  48,  83,  89,
         95,  61,  38,  62,  96,  22,  90,  87], device='cuda:0')
argmax end logits shape: tensor([105,  56, 135,  10,  11,  75,  80,  66,  33,  37,  32,  26, 105,  17,
         92,  51,  73,  93,  23,   5,  47,   8,  80, 135,  19,  72, 141,  48,
          5,  70,  63, 104,  47,  42,  36,  88,   9,  80,  51,  25,  49,  16,
         30,   2,  78,   7,  26,  32,  46,  52,  88,  34,  55,  10,  21,  47,
         64,   3,  30,  62,  47,  67,  27,  87], device='cuda:0')
argmax start logits shape: tensor([ 61,  55,  74,  21, 102,  52,   6,  87,  10,  85,  31,  86,  11,  47,
         97,  99,  57, 112,  48,   0,   2, 124,  31,  17,   0,  10,  99,   1,
         

Evaluating:  22%|██▏       | 68/313 [00:02<00:07, 33.86it/s]

tensor([  1,  24,  17,  82,  50,  36,  85,  51,  85,  36,  51,  34,  49,  20,
         25,  20,  51, 109,  22,   2,  49, 121,  28,  85,  52,   3,  60,  47,
          3,   1,  38, 105,   8,  22,  14,   9,  23,   2,   2,  34,  28,  53,
         29,  46,  26,  22,   3,   6,  89,   5,   0,  36,  13,  95,  43,   2,
         33,  22,  67,  74,  71,  82,  83,  95], device='cuda:0')
argmax end logits shape: tensor([ 71,  42,  17,  53,  45,  25, 106,  85,  51,  49,  26,  95,  50,  76,
          2, 102,  24,  55,  64,   2,  35, 104,  64,  27,  90,   3,   8,  76,
          3, 117,  27,  47, 112,  30,  27, 104,  15,  24,   2,  51,  28,  50,
         44,  46,  10,   9,  52,   6,  47,  98,  14,  34,  75,  92,  43, 130,
         33,  25,  50,  56,  67,  72, 115,  95], device='cuda:0')
argmax start logits shape: tensor([ 89,  22,   2,  42,  41,  79,  38, 107,  92,  21,   9,  15,  12,  79,
         21,  62,  35,  94,  33,  55,  63,  43,  24,  85,  46,  88,   0,  78,
         48,   0,  61,  45,  76,  28

Evaluating:  23%|██▎       | 72/313 [00:02<00:07, 33.42it/s]

argmax start logits shape: tensor([ 84, 101,  54,  75,  33,   0,  78,  19,  46,  70,  45,   2,  38,  46,
         43,  76,  65,   0,   3,  10,  23,  72,   2,  57, 107,  20,  43, 108,
         72, 104,   0,  49,  44,  29,  66,  40,  37,  74, 121,  65,  96,  90,
         37,  29,  22,  16, 109,  16,  33,  47,   2,  55,   3,  71,   5,  44,
         18,  43,  99,  90,  52,   3,  57,  31], device='cuda:0')
argmax end logits shape: tensor([111,  46,   3,  28,  37,  63,  20,  17,  46,   3,  53,   2,  60,  43,
          3, 109,  45,  70,  56,  21,  68,  72,   3,  30,  18,  42,  13, 107,
         57,  35,  58,  52,  48,   2, 109,  37,  15,  23, 120,   7,  77,  71,
         14,  29,  59,  30,  93,  19,  13,  28,  55,  68,  10,  71,  39,  58,
         31,  61,  60,  91,  63,  55,  57, 108], device='cuda:0')
argmax start logits shape: tensor([ 10,   6,  45,  13,   2,  83,  24,  40,  51,   4,  90,  30,  16,   3,
         34,  63,  50,  42,   7,  90,  28, 110,  23, 109,  26,   3,  49,  40,
         

Evaluating:  24%|██▍       | 76/313 [00:02<00:07, 32.92it/s]

tensor([  2,  21, 112,   0,  15,  14,  55,  47,   3,  33,   5,  90,  10,  12,
         69,  94, 121, 115,  85,  16,   6,  72, 121,  13,  85,  23, 107,  86,
          2,  47,  27,  17,  40,  20, 104, 121, 111,  58,  24,  49,  50,   5,
         68,  61,  34,  48,  66,  25,  68,  97,  67,  21,  80,  40,  74, 107,
         39,  35,  25,   0,   2,  57,   1,  15], device='cuda:0')
argmax end logits shape: tensor([  2,  21,  79,  36,  47,  71,  44,  54,  27, 115,  11,  90,  12,  24,
         58,  56, 127,  42,  81,  24,   8,  83, 121,  24,  37,  99,  80,  48,
        121,  51,  40,  47,  20,  34,  38,  43, 107,  72,  11,  14,  52,  65,
         54,  86,  26,  17,  80,  46,  65, 107,  70,  30,  14,  19,  52, 135,
         24,  64,  13,  20,   3,  26,   9,   3], device='cuda:0')
argmax start logits shape: tensor([ 87,  42,   0,  53,  40,  24,  47,   6,  43,  76,  46,  33,  69,   0,
         83,   9,  12,   7,  42,  67,  10,  71,  26,  89,  22,   0,  14,  50,
          0, 100,  76,  49, 110,  25

Evaluating:  26%|██▌       | 80/313 [00:02<00:07, 32.77it/s]

tensor([ 15,  66,  11,  97,  43,  13,  65,  41,  52,  64,  64,  76,  38,  32,
         43,   0,  49,  29,  14,  46,  57,  44,   8,  53,  45,  32,  38,  27,
         57,  33,  11,   5,  25,  31,  23,  96,   2,  16,   0,   4,  24,  14,
         73,  19,  67,  43,  36,  49,   8,  32,   0, 105,  91,  15,  42,  98,
          2,  52,  43, 110,   4, 106,  21,  19], device='cuda:0')
argmax end logits shape: tensor([  3,  28,  95,  57,  70,  59,   6,  81,  57,  22,  65, 118,  29,  70,
         95,  79,  50,  68,  48,  73,  45,  83,   8,  40, 107,  35,  26, 104,
         66,   5,  32,  52,  70,  17,   5, 102, 120,  68,  81,  18,  72,  34,
         74,  30,  41,  44,  28,  81,  20,   8,  19,  47,  94,  74,  44, 108,
         63,  64,  25, 110,  59,  97,  33,  33], device='cuda:0')
argmax start logits shape: tensor([ 26,  40,  42,  50,  36,   0,  14,   6,  73,   0,  61,   3,  26,  58,
         98,  41, 120, 126,  81,  68,  21,  94,   0,  49,  30,  33,  12,  59,
         21,  67,  44,   0,   6,  50

Evaluating:  27%|██▋       | 84/313 [00:02<00:06, 32.98it/s]

tensor([  5,  32,   7,  84,  18,  16,  57,  18,  37,  74,  21,  81, 115, 107,
          6,  93,  45, 122,  48,   2,  28,  10,  24,  18,  41,  89,  42,  97,
         71,   5,   5,  48,   3,  77,  20,  45,   4,  25,  22, 103,  42,  81,
         35,  20,  42,  35,  53,  15,  23,  27,  28,  44,  92,  18,  20,  64,
         85,  48,  36,  61,  36, 137,  21,  31], device='cuda:0')
argmax start logits shape: tensor([ 28,  81,  90,  25,  20,   2,  18,   9,  63,  35,  79,   7,  33,  49,
         56,  55,  62,   7,  75,   4,  65,  41,   3,  55, 102,   2,  50,  76,
         22,  14,  49,  50,   8,  50,  42,  74,  98,  38, 114, 105,  38,  56,
         52,  42,  43,  22,   0, 115,   0, 113,  63,  65,  33,  89,  75,  66,
         84,  88,  20,  71,  52,   7,  94,   6], device='cuda:0')
argmax end logits shape: tensor([ 31,  37,   3,  21,  48,  16,  19,  71,  63,  39,  42,  36,  14,  49,
         57,  63,  66,   7,  23,  38,  65,  47,  31,  28,  94,   2,  37,  36,
         25,  61,  43,  27,  19,  93

Evaluating:  28%|██▊       | 88/313 [00:02<00:06, 33.23it/s]

argmax start logits shape: tensor([ 80, 118,  30,  99,  86,  87,  72,  97,  20,  61, 127, 133,   0,  49,
         90,  63,  12,  12,  63,  52,  73,  96,  33,   4, 107,  57,  33,  99,
          7,   7,  66,   0,  65,  22,  24,  88,   3,   0,  72,  46,  47, 104,
         20,  71,   8,  45,   0,  34,   3,  42,   6,  36,   7,  94,  68,   2,
          4,  25,  70,  70,  45,   2, 106,  76], device='cuda:0')
argmax end logits shape: tensor([ 82,  63,  24,  18,  74,  26,  30,  80,  34,  68,  56,  49,  89,  10,
         53,  84,  24,  12,  63,  68,  71,  97,  33,  63,  19,  26,  65, 101,
         62,  43,   4,  55, 105,  14,  17,  52,  52,  72,   2,  42,  28,  45,
         73, 131,   3,  14,  41,  30,   3,   6,   8,  35,  34,  94,   7,  63,
          4,  32,  97,  21,  17,  37,  50,  76], device='cuda:0')
argmax start logits shape: tensor([ 41,  28,  45, 112,   0,  31,  40,  61,  43,   2,  74,  53,  61,  23,
        120,  92,  30,  56,  70,  78,  46,  13,  13,  26,   0,  88,  38,  72,
         

Evaluating:  31%|███       | 96/313 [00:02<00:06, 33.25it/s]

argmax start logits shape: tensor([ 56, 106,  15,   2,  75,   2,   1,  68,  90,   3,  50,  33, 121,  38,
         87,  86,  50,  18, 102,  74,  64,   7,  18,  25,  81,  21,  45, 107,
         13,  56,  56,  20,  61,  42,  81,  47,   0,   6,  38,  14,  36,  47,
          2,  57,   4,  65,  68,  21,  18,  31, 101,  37,  41,  19,  84,   0,
         71,  80,  51,  19,  83,  98,  76,  29], device='cuda:0')
argmax end logits shape: tensor([ 57,   3, 121,  58,  76,   3,  36,  58, 118,   3,  93,  35,  43,  89,
          9,  70,  50,  19, 102,  14,  39,  16,  10,  32,  59,  33, 107,  73,
         13,  37,  50,  13,  35,   6,  93,  76,  37,  50,  18,  50, 132,  47,
        109,  58,  40,   7,  55,  45,   2,  18, 101,  44,  42,  32,  95,  32,
         71,  51,  51,  19,  83, 108,  77,  29], device='cuda:0')
argmax start logits shape: tensor([ 45,  17,  13,  78,  84,  35,   7,   7,  10,   4,  69,  61,  76,  28,
        101,  76,   8,   0,  20,  29,  71,  22,  83,  42,  94,   8,  86,  11,
         

Evaluating:  32%|███▏      | 100/313 [00:02<00:06, 33.72it/s]

argmax start logits shape: tensor([  2,  99,   6,  17,   8,   0,  21,  11,   0,  21,  55,  23,  61,  57,
         44,  28,  20,   0,  19,  15,  47,   3,   4,  57,  22,   5,  31,  69,
         27,  38,   3,  71,  42,  85,  22,  80,  41, 111,   9,  41,  86,  68,
        116, 117,  65,   0,  60,  68,   9,  65,   2,  52,  74,  19,  93, 119,
         83, 130,  10,  61,  53,  54, 105,  21], device='cuda:0')
argmax end logits shape: tensor([ 56,  18, 109,  21,   8, 110, 121,  31,  32,  62,  55,  23,   2,  20,
        115,  56,  24,   5,  15,   9,  28,   3,  42,  57,   9,   5,  43,  19,
          3,  80,  40,  71,  66,   2,  24,  60,  42,  93, 147,  21,  18,  20,
        117,  10,  48, 104,  61,   7,  19,  12,  51,  55,  85,  55,  26,  41,
         63,  47,  48,  85,  57,  83,  18,  78], device='cuda:0')
argmax start logits shape: tensor([  8,   2,  68,  73,  38,  89,  56,  88,  81,  50,   0,  59,  48,  82,
          2, 119,   0,   2,  32,  84,  97,  65,   9,  38,  84,   0,  11,  17,
         

Evaluating:  33%|███▎      | 104/313 [00:03<00:06, 33.18it/s]

argmax start logits shape: tensor([107,  57,  94,   0,  50,   3,  94,  13,  90,  35,  47,  22,  34,  42,
         62,  85,   2,  24,  33,  21, 124,  90, 106,  71,  79,  43,   1,  36,
         25,  81, 119,  55, 105,  93,  70,   6,  99,  24,  55,  10,  38,  76,
          6,  60,  40,  90,  99,  52,   0,  38,   4,  78,  81,   3,  23,  90,
         78,  76,  64,  32,  42,   9,  31,   6], device='cuda:0')
argmax end logits shape: tensor([ 18,  20,  92,  28,  55,   3,  53,  27, 133,  78,  88,  22,  11,  32,
         12,  45,   2,  61,  38,  33, 124,  91,  43, 100, 108,  61,  34,  36,
         25,   4,  34,  30, 106,  79, 110,  33,   3,  45,  57, 100,  28,  73,
         29,  49,  78,  70, 101,  90,  19,  18,  79,  13,  81,  50, 107,  66,
         39,  36,  65,  43,   5,  55,  49,  50], device='cuda:0')
argmax start logits shape: tensor([ 89,  99,  52, 107,   8,  50,  30,  36,  77,  15,  29, 124,  75,  53,
          2,  57,   2,  64,  45,  84,  23,  24,   7,  31,  24,  15,  47,  47,
         

Evaluating:  35%|███▍      | 108/313 [00:03<00:06, 32.43it/s]

argmax start logits shape: tensor([ 65,  76,   5,   6,  64,  90,  56,  37,  12,  67,   1,  83,  29,  71,
         30,  81,  43,  39,  43, 103,  14,  24,  49,  47,  48,   9,  88,  52,
          0,   0,  88,   9,   0, 122,  45,   0,   6,  15,  22, 102, 124,   7,
          5,   7,  82,  20,  25,   0,  22,  42,  11,  44,  80,  42,  36,  91,
         12,  22,  80,  62,  66,  20,  71,  47], device='cuda:0')
argmax end logits shape: tensor([ 24,   2, 122,  11,  74,  37,  55, 117,  16,  14,   1,  10,   2, 109,
         23, 117, 112,   7,  35, 103,   3,  81,  47,  51,  90,  60,   7,  36,
         45,  66, 110,  38,  61,  78,  37,  81,  88,  28, 108,  52,  49,  65,
         39,  57,  92,  20,  48,  24,  74,  12,  67,  20,   1,  42,  23,  44,
         30,   4,  83,  20,  61,  20,  62,  47], device='cuda:0')
argmax start logits shape: tensor([ 19, 116,  64,  75,  15,  18,  60,  59,  25, 107,  25,  72,  89,  39,
         41,  83,   7,  13, 103,  20,  24,  20,   7,   5,   2,  78,  88,  56,
        1

Evaluating:  36%|███▌      | 112/313 [00:03<00:06, 32.92it/s]

tensor([ 23,  30,  18,  11,  85,  99,  61,  35,   0,  58,  18,   2,  17,   7,
         54,  33,  20,   2,  42, 128,   5, 121,  59,   4,  36,  28,   8,  23,
         21,  29,  40,  21,  19,  38,   0,  98, 115,  10,  28,  89,  27,  79,
         19,  79,  38,  96,  45,  72,  22,  25,  21,  68,   3,  64,  47,  38,
         46,  60,  13,   6,  99,  68,  50,  71], device='cuda:0')
argmax end logits shape: tensor([ 23,  59, 117,  45,  74,  69,  40,  19,   2,  14,   9,   3,   2,  23,
          8,  14,  57,  37,  64, 138,  11,  66, 111,  40, 107,  80,  58,  23,
         23,  45,  90,  62,  80,  10,  28,  22,  21,   7,  28,  51,  36,  79,
         19,  54,  41,  75,  38,   2,  40,  51,  60,  59,  33,  98,  19,  59,
         12,  95,  30,  49,  85,   6,  27,  64], device='cuda:0')
argmax start logits shape: tensor([ 19,  31,   2, 107,  15,  52,  13,  45, 113,  72,  23,  79,  49,  72,
         31,  70,  13,  18,  27, 101,  44,  31,   8,   0,  72,   8,  12,  36,
         42,   0, 109,  70,  35, 116

Evaluating:  37%|███▋      | 116/313 [00:03<00:05, 32.90it/s]

argmax start logits shape: tensor([ 75,  86,   2,  88, 100,  10, 114,  48,   2,  27,   2,  83,  59,  44,
         72,   2,  48,   0,  33,  10,  24,  47, 106,  90, 117,  11,  10,  25,
         45,  83,  33,  68,  13,   4,  93,  62,   0,  19,  14,  49,  62,  14,
        110,  57,  19,  26, 104,  33,   8, 128,   2,  57,  37,  70,  12,  26,
         20,  42,  45,   0,  23,  13,   5,   9], device='cuda:0')
argmax end logits shape: tensor([ 61,  86,   2,  52, 100,  48,  59,  17,  12,   4,  16,  30,  24,  71,
         53, 102,  17,  91,  19,  48, 100,  93,  75,  10, 114,  12,  11,  46,
         28,  24,  20,   4,  28,   5,  32,  26,  80,  52,  49,  47,  86,  14,
         41,  15,  80,  87, 117,  37,  33, 128,   5,  20,  37,  66,  92,  81,
          8,  86,  75,  48, 107,  59,   5,  40], device='cuda:0')
argmax start logits shape: tensor([ 24,   7,  56,  38,  19,  51, 100,  25,  14,  49,  95, 114,  24,  31,
          3,  54, 115,   4,  44,  10,  13,  56,  70,  69,  63, 123,  12,  11,
         

Evaluating:  38%|███▊      | 120/313 [00:03<00:05, 32.92it/s]

tensor([104,  11,  40,   2,  50,   8,  39,  15,  50,  75,  94,  75, 127, 111,
         14,  53,  55,  92,  75,   2,   1,  36,  51,  36,  62,  68,  81,  92,
         83,  61,  12,  49,  42,  58,  53,  40,  12,  50,  46,  26, 130,  20,
         28,  59,  11,   0,  61,  87,  44,  33,  23,  65, 110, 107,  55,   2,
         36,  48,  52,   8,  19,  66,   3,   4], device='cuda:0')
argmax end logits shape: tensor([ 94,  13,   7,  18,  42,  22,  52,   6,  10,  60,  93,  79, 114, 113,
         23,  56,  71,  62,  35,  58,   1,  80,  36,  36, 122, 115,  34, 111,
         61,   5,  79,  18,  44,  34,  53,  35,  30,  56,  46,  24, 134,  20,
         30,  59,  33,  44,  62,  78, 115,  14,   9,  67,  10,  81,  16,  13,
         26,  37,  17,  10,  33,  66,   3,   2], device='cuda:0')
argmax start logits shape: tensor([ 51, 108,  81,  67,  19,  14, 124,  68,  99, 100, 139,   7,  60,  16,
         30,  13,   0,  33, 107,   0,  23,  39,  75,  52,  20,  23,  71,   8,
        110, 101,  60,   0,  31,   0

Evaluating:  40%|███▉      | 124/313 [00:03<00:05, 33.23it/s]

argmax start logits shape: tensor([ 34,  29,  36,  41,  40, 115,  22,  26,  34,   7,  92,  95,   5,  47,
         48,  17,  59,  63,   0,  26,   2,   2,  64,   0,   0, 107,  34,  74,
         70,  56,  47,   0, 112,  35, 107,  33,  25,  99,  75,  99,  38, 107,
         77,  28,  21,  61,  13,  49,  80,  47,  44,   6,   0,  10,  61,  76,
        120,  31,  17,  33,  23,  43,  63,  87], device='cuda:0')
argmax end logits shape: tensor([ 35,  30,   5, 104,  40, 117,  14,  84,  22,   8,  30,   5,   8,  47,
         19,  73,  90,  34,  61,  40,  29,  56,  64,  55,  41,  45,  62,  56,
         82,  57,  28,  33,  15,  18,  10,  36,  25, 123,  62,  49,  89, 107,
         95,  29,  65,  53,  20,  23,  40,  89,  44,   6,   3,  31,   8,  25,
        112,  94,  17,  67,  49,  44,  34,  10], device='cuda:0')
argmax start logits shape: tensor([ 11,   8,  16,   9,  50,  77,  94,  23,  18,   6,  62,  52,   0,   1,
          7,  76,  56,   8,  41,  17,  67,  33,  30,  18,  36,  82,  13,  71,
         

Evaluating:  41%|████      | 128/313 [00:03<00:05, 33.11it/s]

argmax start logits shape: tensor([ 50,  51,  50,   2,  35, 111,  45,  20,  34,  15,  35,  64,  62,  20,
         77,  53,  53,  18,  14,  26,  30,  57, 107,  85,  14,   4,  31,  45,
         30,   3,  96,  41,  11,  58,  55,  47,  41,  36,  23, 132,  43,  76,
        106,  30,  70,  34, 116, 133,  80,   0, 127,  81,  47,  94,   8,  43,
         13,  81,   2,  45,   0,  48,  54,   5], device='cuda:0')
argmax end logits shape: tensor([ 31,  99,  89, 102,  19,  36,  65,  20,  20,  48,  78,   9,  45,  20,
         83,  83,  50,  19,  34,  25,  28,  21, 107,   7,  14,  47,   5,  54,
         31,   3, 103,  26,  72,  18,  20,  50, 123,  79,  99, 107,  67,   4,
        109,  33,  17,  34,  30, 133,  86,   6,  56,  81,  43,  94,  23,  95,
         14,  78,   4,  45,   3,  49,   9,  98], device='cuda:0')
argmax start logits shape: tensor([ 57,  20,   0,   7,  11,  12,  30,  10,  84,   1,  48,  98,  13,  65,
         16,  52,   2,  86, 119,  72,  55,  64,  22,  27,  11, 102,  17,  63,
         

Evaluating:  42%|████▏     | 132/313 [00:03<00:05, 33.35it/s]

argmax start logits shape: tensor([ 77,  61,   2,  81,  14,   6,  24,  36,   2,  71,  17, 105,   4,  23,
          6,  18, 105,  72,   2,   6,  31,  24,  31,  67,   0,  33,   0,  33,
          2,  35,  13,  57,  17,  10,  23,   2,  15,  34,   4,   3,  56,  56,
         15,  84,  28,  37,  61,  25,  39,  41,  18,  10,  45,  94,  24,  70,
         33,  53,  40,   2,  55,  42, 120,  86], device='cuda:0')
argmax end logits shape: tensor([ 31,  40,   5,   3,  99,  88,  84,  36,  81,  78,  36,  56,  32,  75,
         13,  40,  86,  55,  21,  60,   5,   3,  27,  52, 108,  28,  13,  37,
         22,  30,  13,  52,  85,  15,  23,   3,  50,   5,  77,  55,  57, 121,
         37,   9,  29,  58,  11,  94,  23,  41,  34,  10,  75,  94,  45,  24,
         23,  99,  47,  16,  55,  45,  48,  23], device='cuda:0')
argmax start logits shape: tensor([ 18,  75, 125,  49,   0,  99,  48,  46,  59,  79,  49,  60,  40,  10,
         98,   0,  11,   9,  71,  21,  13,  14,  83,  47,  44,  36,  34,  19,
         

Evaluating:  43%|████▎     | 136/313 [00:04<00:05, 33.57it/s]

argmax start logits shape: tensor([ 82,  58,  25, 100,  35,  55,  20,  25,  61,  48,  32,  36,  45,  59,
         59,  43,   3,  52,  67,  21,  26,  85,  96,   0,   2,  76,  13,  18,
          3,  38,  20,  13,  68,  84,  35,  42,  66, 123,  44,  55,  10,   3,
         76,  34, 109,  13,  39,  31,   2,  47,   3, 115,   0,  44,  25,  45,
          7,  64,   4, 130,  21, 113,  53,  15], device='cuda:0')
argmax end logits shape: tensor([ 71,   9,  61, 100,  78,  63,  18,  73,   7,  22,  21,  60,  89,  75,
         90,  95,   3,  88,  21,  57,   8,  41, 105,   2,   2,  25, 140,  12,
          3,  91,  19,  18,  14,  70,  26,   4,  33,  15,  50,  62,  56,  45,
         90,  26,  42,  14,  32,  22,  65,  95,  27, 115,  81,  48,  98,  59,
         53,  45,  45, 134, 121, 115,  15,  36], device='cuda:0')
argmax start logits shape: tensor([ 87,  18,  61, 133,  40,  82,  18,  18,  55,  24,  54,  37,  29, 109,
          5,  48,  17,  88,  15,  68,  11,  43,   8,  88,  56,  43, 116,  11,
         

Evaluating:  46%|████▌     | 144/313 [00:04<00:04, 33.84it/s]

tensor([ 17,  44,  49,  57,  33,   9,  83,  19,  80,  53,  52,  27,  19,  22,
        107,  90, 105,  35,  45,  11,   0,  42,  79, 100,  75, 110,  32,  66,
         90,  76,   8,  74,  71,  53,  84,  25,  55,  62,  11,  61,  29,  11,
         32,  38, 118,  85,  42,  93, 103,  36,  19,  43,  20, 150, 114,  99,
         76,  62,  49,   3,  22,  62,  99,  46], device='cuda:0')
argmax end logits shape: tensor([ 21,  50,  65,  45,  35,  85,  77,  35,  31, 118,   9,  17,  63,  42,
         50, 118, 110,  40,  54,  92,  24,  51,  28,  25,  99,  24,  32,  65,
         39,   3,   8,  37,  17,  55,  76,  28,  76,  62,  13,  63,  78,  12,
          5,  57, 118,  41,   6, 129,  12,  50,  19,  68,   4,  59, 104,  54,
         75,  14,  55,  20,  22,  38,  29, 103], device='cuda:0')
argmax start logits shape: tensor([ 30, 112,  92,   0,  39,  12,  12,  42,  33,   9,   2,  11,  65,  48,
         12, 112,  15,  90,   0,  17,  81,  76,  70,  27,  38,  64,   6,  66,
         28,  37,   8,  21,   0,  55

Evaluating:  49%|████▊     | 152/313 [00:04<00:04, 34.03it/s]

argmax start logits shape: tensor([  3,   3,  12,  62,  12,  53,  34,  49,   2,  10,   8, 150,  14,   0,
        108,  11, 137,  19,  39,   0, 131, 105,  58,  40,  22,  37,  46,   2,
         76,   3,  86,  49,  52,  49,  18,  30,  35, 119,  21,  25, 101,  83,
         32,  60,  57, 121,  17,  98,   2, 113,  95, 102,  72, 110,  56,  52,
         51,  24,  42,   0,  45,   3,  17, 137], device='cuda:0')
argmax end logits shape: tensor([ 37,   3,  61,  62,  36, 141,  24,  59,  11,  20,   2,  89,  31,   2,
        107,  11, 137, 117,  58,  93,   7,  41,   8,  40,  73, 103,  10,   3,
        104,   5,   7,  76,  20,  35,  11,  61,  38,  79,  32,  47,   8,  14,
         43,  63,  45,  15,  17, 129, 126,   3,  65, 102, 101,   1,  85,  20,
         54,  40, 104,  93, 114,   3,  21, 138], device='cuda:0')
argmax start logits shape: tensor([ 54,  23,  84,  92,  60,  41,  92,  34,  41,  50,  66, 102,  49,  56,
         12,  23,  43,  58,   7, 142,  36,  38,   0,  93,  48,   9,  56,  46,
         

Evaluating:  50%|████▉     | 156/313 [00:04<00:04, 33.11it/s]

argmax start logits shape: tensor([ 62,  55,  38,  85,  46,  46,  85, 105,  85,  33,  55,  40,   2, 113,
         17,   0,  92,  78,  18,   0,  95, 106,  71,  87,  68,  20,  49,   5,
         16,  78,  26,  29,  96,  99,   8,  24,  67,  24,   7,  56, 152,  29,
          2,  45,  50,  20,  99,  69,  95,   1, 127,   3,  23,   2,  26,  75,
         68, 111,  19,  63,   1,  26,  70,  15], device='cuda:0')
argmax end logits shape: tensor([ 13,  74,  59,  90,  46,  95,  85,  14,  56,  33,  55,  19,   3,  75,
         19, 110, 111,  43,   9,  81, 109,  52,  73,  87,  39,  41,  38,  34,
         43,  78,  28,  55,  87,  78,   2,  24,  58,  72,   7,  62,   5,  68,
          3,  17, 121,  41,  97,  32,  10,  64,  24, 135,   7,  18,  81,   1,
         80,  89,  26,  62,  20,   8,  84,  16], device='cuda:0')
argmax start logits shape: tensor([ 39,  71,  53,  25,  81,  73,  11,  51,  83,  36, 116,   0,  98,  39,
         39,  24,  25,  67,  37,  99,  30,  89,  63,  24,  26,  86, 112,  39,
         

Evaluating:  51%|█████     | 160/313 [00:04<00:04, 33.28it/s]

tensor([ 32,  74,   8,   1,  30,  46,  38, 101,  62,  54,  11,  34,  33,  94,
         27,  36,   5,  59, 115,  91,   2,  10,  86,  82,  45,  36,  46,  40,
         55,  50,  20,   0,  21,  43,  22,  33,  52,  11,  24,  45,  38,  75,
         48,  17,   8,  11,  98,  43,  15, 100,  68,  10,  13,  39,   5,   6,
          3,  14,  15, 105,  85,  74,  32,   3], device='cuda:0')
argmax end logits shape: tensor([ 30,  77,  26,   1,  67,  79,  44,  62,  59,   7,  75,  49,  82,  53,
         58,  36,  30,  42,  15,  44,   3,  15,  83,  80,  53,   5, 146,  40,
         27,  79,  22, 110,   2,  17,  18,  98,  34,  95,  29,  22,  23,  75,
         49,  17,  40,   3,  64,  13,  47,   7,  31,  34,  53,   2,  87,  30,
          3,  23,  66,  22,  19,  74,   7,  27], device='cuda:0')
argmax start logits shape: tensor([121,  68,  59,  71,  79,   8,  21,  74,   0,  44,   2,  25,  44,  17,
         10,  53,  94,   5,  31,  26,  26,  21,  34,  62,  45,  23,  76,  46,
         36,  38,  14,  68,  45,   6

Evaluating:  52%|█████▏    | 164/313 [00:04<00:04, 33.53it/s]

argmax start logits shape: tensor([ 27,  30,  33,   0, 112,  64,  85,  88,  36,  11,  72,  61,   0,  19,
         91,  81,  86,  83,   7,  98,  68, 101,  21,  48,  86, 120, 137,  38,
         79,  68,  46,   8,  85,  39, 108,  34,  86,  22,   8,  44,  71, 114,
         87,  21, 106, 114, 106,  41,  33,  15,  52,  69,   7,  53, 107,  23,
         61,   2,  46,  25,  12,  54,   6, 101], device='cuda:0')
argmax end logits shape: tensor([ 75,   5,  13,   8, 120,  54,  85,  47,  45,  11,  19,  82,  54,  11,
        128,  78,  74,  83,   7,  14,  50, 101,  88,  24,  42,  84,  62,   9,
         65,  79,  49,  64,   8,  42,   3,  10,  94,  37,  52,  37,  64, 108,
         23,  60,  77,  90,  46,   7, 117,  30,  58,  29,  27, 106, 122,  24,
         50,  55,  35,  45,  36,  54,   6,  35], device='cuda:0')
argmax start logits shape: tensor([ 34,   8,  24, 101,  31,  13,  85,  66,  31,  87,  88,  11,  21,  27,
         76,  17,  49,   2,   0,  13,  49,  59,  55,  13,  44,  64,  72,  24,
         

Evaluating:  55%|█████▍    | 172/313 [00:05<00:04, 34.04it/s]

tensor([  7,  12,  90,   2,  70,  56,  70,  37,   0,  10,  90,  24,  33,   0,
         14,  74,  11,  42,  15,   6,  36,  32,  23, 137,  20,  21,  43,  51,
          4,   5,  49,   2,  54,  54, 104,  41,  31, 106,   6,  38,  14,  38,
          0,  75,  61,  99,  19,  32,  83,  77,  38,  13,  62,  54,  21,  42,
         42,  45,   2,  36,  85,  18,  44,  85], device='cuda:0')
argmax end logits shape: tensor([  3,  42,  52,   5,  27,  35,  56,  15, 104,   7,  18,  22,  36,  48,
          5, 138,  50,   6,  42,   8,  34,  46,  99, 111,  20,  93,  59,  86,
         47,   5,  49,   3,  65, 137,  15,  29,  70,  50,   6,  17,  69,   7,
         16,  76,   3,  19,  51,  25, 117,  26,  32,  14,  86,  65,  93,  87,
         81,  14,   3,  50,  85,  10,  44,  44], device='cuda:0')
argmax start logits shape: tensor([  3,   0,  23,   2,  78,  63, 102,   4,   8,  15,  35,  15, 119,  67,
        108,  34,  72,   8,   6,  43,  68,  30,  62,  66,  62,  12,   0,   0,
         30,  30,   5,  94,  27,  24

Evaluating:  58%|█████▊    | 180/313 [00:05<00:03, 34.22it/s]

argmax start logits shape: tensor([102,  24,   6,  90, 108,  34,  45,   2,  55,  69,  24,   7,  49,   2,
          0,   6,  65,   0,  92,  45,  63, 113,  24,  41,  13,  61,  99, 113,
         58,   4,  44,  48,  40,  18,  28,  63,  39,   7,  31, 104,  14,  26,
         20,  59,  18,  99,  83,  44,  86,  32,  97,  44,  99,  66, 115,  83,
         10,  54,   0, 106,  49,  12,  21,  67], device='cuda:0')
argmax end logits shape: tensor([  4,  22,  18,  90,  34,   9,  48,  18,   3,  25,   2, 118,  14,  15,
         44,   2,  66,  13,  92,  22,  47,  36,  36,  29,  13,  70,  69,  55,
         59,   4,  53,  46,  19,  77,  46,  63,  39,   2,  31, 100,  81,  19,
         20,  59,  67,  19,   7,  33,  86, 115,  76,  44,  78,  56,  43,  53,
         15,  26,  33,   4,  25,  14,  44,  67], device='cuda:0')
argmax start logits shape: tensor([ 64,  48,  50,  38, 101, 114, 104,  55, 116,  24,  20,  78,  71,  81,
         42, 108,  72,   8,  51,   7,  81,  47,   4,  50,  62,  79,  17,   8,
         

Evaluating:  59%|█████▉    | 184/313 [00:05<00:04, 31.82it/s]

tensor([ 18,  54,  52,  35,  38,  81,  57,  14,  27,  34,  10,  47,  44,  69,
         16,  80,  63,  20, 118,  23,  75,  99,  49,  16, 100, 102,  15,  50,
         77,  56,  33,  20,  20,  34,  97,  90,  57,  66,  90,  64, 102, 133,
         37,  69,  83,   2,  23,  47,  30,  53,  47,  76,  30,  49,  64,  79,
         92,  48,  81,  34,  56,  45,   0,   7], device='cuda:0')
argmax end logits shape: tensor([ 18,  11,  68,  35,  15,  81,  23,   3,  36,  95,  19,  37,  59,  71,
         44,  20,  80,  23,  87,   4,  38,  90,  81,  10, 100,  16,  50,   8,
         86, 128,  42,  23, 114,   9,  41,  91,  15,  23, 100,  81, 113,  49,
          3,   2,   7,   2,  49,  75,  32,  53,  38,  76, 101,  35,  65,  79,
        107,  62,  82,  35,  17,  59,  31,  51], device='cuda:0')
argmax start logits shape: tensor([ 43,   7,  79, 106, 129,  31,  18,  48,  55,  65,  66,  14,  92,  43,
         86,  42,   2,   1,   2,  33,  69, 113, 123,   0,  42,   3,  87,  50,
         31,  17,  34,  45,   3,  12

Evaluating:  60%|██████    | 188/313 [00:05<00:03, 32.55it/s]

argmax start logits shape: tensor([101,  55,  50,  17,  98,  88,  55,  40,  32,  54,  44,  96,  11,  75,
         11,  59, 121,  28,  12,  59,  36,  11,  29,  79,  26,  10,  49,  23,
        126,  65,  69,  13,  92,   8,  70,  82,  10,  43,   6,  54, 138,   2,
         62,  31, 109,  71,  37,  22,  26,  98,  38,  46,  67,   6,  75,  14,
         24,  73,  75,  14,  70,  29,  43, 114], device='cuda:0')
argmax end logits shape: tensor([101,   3, 109,   8,  34, 116,  20,  90,  24,  40,  20, 126,  13,  49,
          5,  59,  42,  35,  10,  32, 103, 107,  29,  64,  28,  51,  62,  24,
         49,  70,  29,  13,   2,  66,  82,  12,  31,  59,   9,  79,  16,   3,
         97,  42,  42,  71,  37,  58,  26, 108,  30,  97,  67,   8,  17,  61,
        117,   2,  79,  14,   8,  22, 109,  87], device='cuda:0')
argmax start logits shape: tensor([ 95, 110,  96,  28,  15,  38, 104, 102, 113,  38,  97,  74,   0,   5,
         74,  39,  76,   2,  12,  13,  55, 137,  29,  38,  72,  84,   2,  84,
         

Evaluating:  61%|██████▏   | 192/313 [00:05<00:03, 33.39it/s]

argmax start logits shape: tensor([ 26,  40,  29,  94,  74,  32,  30,  39, 102,  70,  29,   6,  49,   8,
         61,  32,  23,   2,  56,   2,  10,  38, 100,  49,  49,  47, 153, 114,
         31,  26,  54, 121,  15,  67,  18,  58,  99,  87,   0,  43,  94,  38,
         76,  22,   4,  18,  22, 114,  39,  10, 143,  20,  47,  12,  16,  50,
          2,  33,  93,   7,  21,  14,  33,  24], device='cuda:0')
argmax end logits shape: tensor([ 26,  71,  33,  97,  23,  94,  56,  46,  22,  84,   2,  33,  59,  98,
         30,   5,  40,  12,  37,  12,  24,   5, 140,  62,  59,  50,  75,   4,
         23,  10,  11, 121, 122,  62,  21,  33, 123,  12,  36,  61,  94,   5,
         40,  58,  59,  48,  13,  76,   9,  10, 107,   2,  14,  27,  33,  50,
         43,  33,  32,  16, 121,   2,  59,  34], device='cuda:0')
argmax start logits shape: tensor([ 31,   7,  52,  20,  33,  97,   7,  59,   4,   0,  67,  45,  30, 121,
         73,  16,  86,  11,  52,  15,  46,  16,   5,  47,  94,  80,  28,   2,
         

Evaluating:  64%|██████▍   | 200/313 [00:05<00:03, 33.66it/s]

tensor([ 84,  79,  12, 104,  75, 107, 114,  26, 106,  37,  20,  81,  88,  53,
         20,  62,  93,  27,  32,  66,  35,  36,  67,  86,  62,  90,  31,  58,
         75,   5,  22,  61,  40,  49,  59,  80,   2,  46,  17,  52,  81,  63,
         87,  34,  62,  10,  81,   5,  67,   3,  69,  13,  27,  70, 113,  20,
         63,  74,  33,  53,   6,  63,  49,  11], device='cuda:0')
argmax start logits shape: tensor([  0,  32,  53,  22,   6,  37,  80, 116,  21, 101,  59,  13,  70,  53,
         91,  21,   4, 156,  30,  49,   2,  37,  20,  19,  41,  85,  63, 103,
          0,  20,  96,  56,   0,  58,  33,  10,  46,  55,  17,  17,  67,  22,
         24,  34,  53,  24,  24,   0,  73,  33,  26,  55,  31,  34,  11,   3,
          6,  48,  82,  47,  36,  75,  75,  47], device='cuda:0')
argmax end logits shape: tensor([ 26,   4,  25,  45, 115,   3,  88,   4,  83, 101,  57,   6,  77,  50,
          4,  15,  59,  69,  59,  45,   2,  34,   7,  17,  29,  81,  32,  96,
         31,  56,  26, 129,   5,  58

Evaluating:  65%|██████▌   | 204/313 [00:06<00:03, 34.01it/s]

tensor([ 12,  29,  29, 114,   1,  67,  59,  20,   3,  55, 120,  66,  29,  46,
         38,  21,   0,  38, 102,  89,  17,  30,  36,  72,  64,  18,  58,  32,
         22,   2,  78,  52,   2,  63,   2,  24,  10,  21,  50,   2,  72,  59,
         67,  81,  83,  33,  50,   5,  87, 128,   0,  60, 109, 129,  52,  85,
         35,   3,   8,   0,  96,   5,   2,  84], device='cuda:0')
argmax end logits shape: tensor([ 28,  30,  29,  26,  58,  81,  40,  98,  27,  11,  27,  73,  56,  73,
         60,  32,  21,  29,  84,  89,  47,  47,   5,  74,  15,   4,  58,  70,
          1,  56,  61,  75, 149,  34,  91,  80,  54,  62,  42,  52,  72,   8,
         42,  23,  61,  27,  51,   5,  56,  57,  74, 117, 106,  30,  88,  85,
          8,  27,  45,  15, 118,   5,  24,  77], device='cuda:0')
argmax start logits shape: tensor([ 87,  86,  84, 102,  19,  22,  21, 101,   2,  55,   2,  29,   7,  66,
         48,  84,  67,  88,  14,  56,  92,  56,  21, 133, 125,  10,  68,  62,
          4,  31,  34,  18,  13,  76

Evaluating:  66%|██████▋   | 208/313 [00:06<00:03, 33.49it/s]

tensor([ 33,  18,   9,  47,  56,   4,  55,  53,  87,  69,  15,  26,  25,  13,
         85,  16,  20,   0,  88,  31,  39,  18,  81,  14,  79,   2,  63,  13,
         88,  19,  13,  10,   0,  95,  55,  88,  76,  45,  72,  36,  39,  68,
         88,  56,  25, 142,  24,  42,  12,  18,  20,  39,  27,  39,   0,  75,
         42,  89,  27,  45,  31,  75,  61,  81], device='cuda:0')
argmax end logits shape: tensor([ 13,  45,  94,  77,  24,  38,  63,  53,  33,  19,  47,   4,  45,  30,
          2,  33, 114,  31,   4,  70,  10,  49, 114,  14,  17,  14,  32,  34,
        111,  75,  59,   3,  91,   3,  38, 124,  89,  86,  46,  53,  40, 112,
         89,  36,   2, 142,  33,  42,  13,  40,  73,   9,  59,  80,  84,  44,
         81,   7,  40, 132,   8,  67,  61,   4], device='cuda:0')
argmax start logits shape: tensor([ 31,  75,  77,  49,   8,  50,  35,  39,   0,  44,  66,  47,  40,  15,
          5,  30,  17,  27,  95,  78,  83,   2,  22,  81,  42,  22,  34,   0,
         78,  46,  61,  16,   0, 103

Evaluating:  68%|██████▊   | 212/313 [00:06<00:03, 33.30it/s]

tensor([ 49,  47,  10,  24,  21,  55,  62,  51,  78, 100,  32,  91,  47,  31,
          0,   0,  44,  30,  73,  42,  88,  76,   8,  52,  30,  85,   2,   2,
         21,  43,  37,  25,  17,  57,  34,  39,  57,  33,  18, 121,  28,   7,
          6,  92, 156,  97,  72,  44,  35,  80,  42,  13,   2,  10,   0,  20,
         59,  92,  33,  40,  90, 105,  46,  17], device='cuda:0')
argmax end logits shape: tensor([ 92,  93,  23,  34,  37,   3,  97,  54,   3,  54,   3,  82,  75,  29,
         35,  66,  91,  39,  70,  45,  53, 121,   6,  75,  70,  21,  21,   3,
         60,  99,   7, 110,  47,   7,  62,  40,   3,  95,  49,  99,  22,  22,
         48,  92, 112, 140,  19,  38,  37, 132,  18,  23,   2,  37,  25,  20,
         43,  95,  33,  38, 114,  56,   4,  54], device='cuda:0')
argmax start logits shape: tensor([ 92,   2,   2,  29,  76,  31,  28,   6,  81,   0,   0,  42,  12,   0,
         11,  10,  58,   5,  81,  58,  58,  42,  22, 113,  38, 110,  31,  24,
         21, 127, 102,  22,  79,  72

Evaluating:  69%|██████▉   | 216/313 [00:06<00:02, 33.37it/s]

argmax start logits shape: tensor([ 22,  90,   7,  99,  45,   8,  46,  32, 130,  38,  18,  70,  31,  21,
         10,  23,  94, 115, 104,  35,  15,  53,  81,  55,  17,   6,  49, 136,
         62,  39,  13,   1,   0,  24,  50,  30, 102,  27,  14,  35,  71,   2,
          7,   3,   4,  43,  33,  34,  98,  31,  70,  86,  32,  72,  18, 148,
         40,  35,  25, 112,  61,  41,  17,  27], device='cuda:0')
argmax end logits shape: tensor([ 40,  91,  57, 100,  89,  26,  33,  30,  23, 126,  31,  70,  31,  21,
         69,  73,  95, 113, 100,  74,  89,  53,  63,  64,  31,  49,  49,  24,
         35,   2,  34,   6,  61, 136,   4,  87,   7,  31,  95,  81,  75,  52,
         36,  14,   4,  66,  31,  20,  99,  16, 120,  10,  94,   8,   3,  42,
         89,  61,   2,   7,  79,  41,  70,  98], device='cuda:0')
argmax start logits shape: tensor([117,  40,  69,  14,  41,  47,  21,   7, 123,  50, 121,  16,   0,  36,
         24,  31,  20,  50,  56,  68,   2, 123,  46,   6,   5,  62,  15,   3,
         

Evaluating:  72%|███████▏  | 224/313 [00:06<00:03, 26.19it/s]

argmax start logits shape: tensor([ 64,   8,  38,  66,  23,  72,  53,  59,  17,  82,  42,   3,  86, 113,
         23,  30,  73,  11,  35,   8,  64,  35,  61,  55,  32,  86, 128,   5,
        118,  99,  75,  47,  69,  94,  69,   3,  15, 108,  76, 123,  11,  12,
         84,  71,   2,  33,  87,  97,  10,  31,   7,  13,  15,  47,  88,  20,
         18,  27,  23,  87,  68,  45,   8,   9], device='cuda:0')
argmax end logits shape: tensor([ 64,   1,  29,  28,  24,  84, 118,  89,  47,  53,  59,   5,  10, 115,
         30,  43,  73,  11, 138,  52,  88,  18,  63,  30,  21,  86,  34,   8,
         68,  20, 114,  48,  20,  95,  25,  14,  29, 108,  15,  45,  53,  76,
         84,   5,  68, 105,  11,   5,  10,  29,  63,  13,  84,  38,  56,  53,
         95,  36,  55, 107,  55,  59,  90,  48], device='cuda:0')
argmax start logits shape: tensor([  0,  84,  45,  18,  88,  26,  54,  30,   0, 126,  32,  26,  53,  33,
         10,  90,  14,  84,  56, 120,  65,  21, 101,  17,   7,  16,   0,  19,
         

Evaluating:  73%|███████▎  | 228/313 [00:06<00:03, 28.08it/s]

argmax start logits shape: tensor([ 53,  53,  49,  67,  45,  94,  40,  30,  85,   0,  25,  89,   0, 111,
         33,  54,   1,  33,  31,  64,  99,  14,  21,  99,  88, 113,   8, 105,
          7,  38, 130,  28,  21, 103,  36,  72,  77, 136,  89,  13,   2, 110,
        133,  71,  43,  57,  11,  63,   3,   2,  86,  33,   7,  25,  86,   4,
         59,   0,   0,  88,  49,  23,  10,  22], device='cuda:0')
argmax end logits shape: tensor([  9,  54, 111,  28,  51,  83,   4,  24,   2,  23,  98,  89,  78, 111,
        125,  54,  76,  80,   5,  64,  78,   4,  41,   3,   2,  38,   8, 117,
         28,   5,  42,  18,  21,  12,  60, 101,  61,  24, 112,  14,  46,  43,
         49,  81,  70,  30,  24,  52,  99,   3,  48,  98,  88,   9,  23,  79,
         86,  11,   3,  88,  27,   7, 114,  36], device='cuda:0')
argmax start logits shape: tensor([ 37,  35,  32,  87,  53,  21,  15,  58,  62,   5,  87,   6,   4,  84,
          3,  97,  76,  62,  54,  80,   9,  57,  82, 106,  54,  76,  16,  25,
         

Evaluating:  75%|███████▌  | 236/313 [00:07<00:02, 30.83it/s]

argmax start logits shape: tensor([ 76,  44,  79,  40,   5,  43,  47,  29,  21,  76,  73,  43,  12, 108,
         37, 113,  44,  16,  35,  75,   8,   0,  54,  40,   0,  82,  47,  38,
        128,  22,  69,  68, 140,  65,  50,  10,  26,  69,   5,  11,   3,   7,
        118,  88,  18,  54,  56,  46,   9,  27, 114,  68,  76,  43,  72,   2,
         99,   3,  45,   9,  71,  30,  16,   4], device='cuda:0')
argmax end logits shape: tensor([ 86,  31,  79,  40,  45,  17,  52,   3,  37,  31,  73,   5,   2,   2,
         89,   5,  38,  11,  89,  75,   9,  65,  78,  19, 110,  72,  80,  46,
          6,   4, 100,   5, 121,  24,  31,  42,  11,  49,  39,  39,   2,  25,
        118,  14,  30,  33,  59,  71,  55,  31,  66,  55, 107, 105,  74,  45,
         85,  32,  41,  22,  10,  14,  13,   5], device='cuda:0')
argmax start logits shape: tensor([ 42,  72,  47,   3,  33,  41,  80,  53,  80,  37,   9,  48,  65,  29,
         10, 103,  25,   2,  53,  30,  65,   0,   8,  97,  56,  41,  75,  29,
         

Evaluating:  78%|███████▊  | 244/313 [00:07<00:02, 32.83it/s]

tensor([ 82,  70,  76, 109,  83, 106,  80,   2,   0,  76,  28,  69,  28,  47,
         75,   0,  19,  68,  71,   6,  14,  25,  87,   2,  19,  65,  61,  87,
         77,   2, 113,  30,  59,  33,  69,  42,  67,  49,  90,  86,  38,  88,
         26,  86,   2,   1,  59,   0,  28,  10,   7,   5, 107,  83,  51,   4,
         49, 106,  22,  55,  15,  14,  34,   3], device='cuda:0')
argmax end logits shape: tensor([ 41, 126,  25, 106,  21,  84,  47,   9,  24,   6,   4,  71,  32,  35,
        102,  87,  27,  59,  71,  14,   2,  20,  23, 121,  50,   6,  61,  87,
         95,  35, 120,  21,  20,  33,  66,  23,  67,  62,  90,  86,  53, 107,
         21, 105,  12,  34,  40, 104,   7,  72,   7,  45,  10,  83,  53,  25,
         34,  52,  53,   9,  49,   9,  11,   3], device='cuda:0')
argmax start logits shape: tensor([ 42, 108,  11,  63,   2,  79,  23,  44,  34,  85,  65,  79,   3,  10,
         60,  22,  79,  48,   8,  74,  86,  52,  55,  10,  42,  29,  92,  35,
         68,  50,  55,   3,  49,  94

Evaluating:  81%|████████  | 252/313 [00:07<00:01, 33.36it/s]

argmax start logits shape: tensor([  3, 109,  40,   0,  99,  15,  39,  33,  86,  75,  58,  47,  90,  11,
         47,  12,   7,  39,  45,  46,  80,  80,  97,  82, 107, 128,  40, 100,
         29,  11,  95,   2,  74,  30,  46, 105,  14,  22,  38, 106,  15,  25,
        123,   0, 127,   3,  36,  36,   1,  69,  51,  66,  47,  40,  34,  80,
         25,  16,   7,  15,  25,  23,   8,  11], device='cuda:0')
argmax end logits shape: tensor([ 55, 122,  80,  31,   4,  37,  39,  31,  74,  78,  62,  76,  90,  37,
         38,  73,   7,  39,   8,  92,   7,  47, 100,   3,  87,  34,  40,  85,
         15,  51,  30, 103,  74,  81,  36, 117,   5,  73,  87,  50,  96,  25,
        123,  48,  28,   6,  36,  49,  21,  25,   7,  51,  95,  42,  29,  53,
          2,  16,  14,  21,  67, 119,  40,  16], device='cuda:0')
argmax start logits shape: tensor([  4,   1,   0,   0,   2, 120,  40,  25,  15,  42, 115,  78,   0,  10,
         73,  36,   6,  60,  67, 128,  44,  37,   0,  30,  31,   0,  74,  43,
         

Evaluating:  83%|████████▎ | 260/313 [00:07<00:01, 33.96it/s]

tensor([ 77,  20,  85,  14,  56,  96,  15,  53,  54,  95,  81,  30,   4,  19,
         10, 103,  96,   2,  64,  54,   0,  48,  62,   8,  34,  27,  54,   5,
         75,  28,  92,  98,   5,  18,  61,  16,  81,  68, 107,  62,  69,  17,
         85,  47,  25,  81,  13,  25,  67,  79, 100,  17,  74,  50,  99,  57,
          0,  11,   2,  59,  41,  40,  76,  15], device='cuda:0')
argmax end logits shape: tensor([ 78,  42,  37,  37, 110, 105,  72,  88, 120,  95, 114,  44,  12,  63,
         74, 136,  47,  56,   8,  54,  18, 100,  86,  11,   9,  60,  48,  99,
         36,   6,  92,  14,   2,  22,  63,  13,  83, 112, 100,  60, 111,  17,
         44,   2,  16,  93,  58,  94,  34, 108, 100,  59,  37,  42,  20,  57,
         27,  75,  43,  59,  21,  26,  46,  74], device='cuda:0')
argmax start logits shape: tensor([ 11,   2,  36,  19, 114,  28,  16,  22,  94,  10,  26,  35,  23,  55,
         49,  33,  88,  26,  18,   5,  21,  21, 100,  77,   4,  24,  98,  43,
         49,  34,  44,   0,  97,  14

Evaluating:  84%|████████▍ | 264/313 [00:08<00:01, 34.18it/s]

argmax start logits shape: tensor([ 25,  65,   0, 116,  43, 149,   0,  51,  32,   7,   6,  20,  60,  23,
         56,  44,  55,  75,  25,  49,  69,   2,   2,  61,  98,  38,  21,  65,
          8, 100,  27,  91,   3,  13,  62,  39,  24,   3,  98,  45,  34, 139,
         88,  44,  59,  99,  33,   0,  90,  79,  56,  20,  68,  36,  37,  60,
         74,  33,  77,  20,  20,  26,  59,  69], device='cuda:0')
argmax end logits shape: tensor([ 55,  66,  28,   4,  70,  32,  13,  51,   2,  23,   4,  61,  61,  23,
         47,  41,  30,   7,  35,  81,  69,  27,  75,  81,  92, 114,  53, 109,
         98, 100,  42, 104,  16,  14,  14,  10,   6,  52, 105,  53,  57, 142,
         21,  44,  59,  90,  32,   9,  52,  76,  11, 112,  71,  18,  37,  27,
         23,  13,  31,  35,  18,  28, 111,  77], device='cuda:0')
argmax start logits shape: tensor([ 43, 108, 124,  71,  21,  41,  15,   4,   2,  57,  77,  26,  92, 129,
          5,  68, 117,  45,  59, 138, 131,   0,  69,  53, 114, 131,  57,   0,
         

Evaluating:  87%|████████▋ | 272/313 [00:08<00:01, 33.98it/s]

argmax start logits shape: tensor([ 42,  30, 114,   2,  25,  53,  38,  11,   7,  38,  49,  75,  76,  74,
         18,  19,  13,  99, 113, 107,  76,  88,   6,   3,  99,  10,  27,  31,
         86,  43,  47,  44,   0,  31,  65,  57,  49, 119,  99, 117,  39,  33,
         80,  50, 116,   7,  13,  18,  38,   1,  10,  24,  11, 130,  56,  32,
         11,  83,  55,  41,   0,  17,  84,  28], device='cuda:0')
argmax end logits shape: tensor([104,  54,  90,  92,  30,  93,  51,  66,   9,  38, 102,  85,  41,  77,
         23,  15,  13,   8,   3, 100,  77,  24,  29,  58,  29,  10,  14,  75,
        102,  67,  19,  33, 123,  41,  45,  22,  30,   8,   8,  27,  24,  74,
         82,  50, 113, 107,  38,   2,  17,  45,  56,  57,  11, 130,  64,  27,
         15,  46,  55,  37,   7,   8,  85,  79], device='cuda:0')
argmax start logits shape: tensor([ 56,  64,  88,  84,  80, 108,  69,  13,  13,  12,  64,   9,  56,  38,
         54,  63,  98,  26, 106,   7,   0,  41,  85,  12,   0,   4,  78,  23,
         

Evaluating:  89%|████████▉ | 280/313 [00:08<00:00, 33.36it/s]

tensor([ 26,  89,  40,  18,  91,  26,   3,  64,  10, 128,  86,  94, 114,  46,
        101,  14,  91,  84,  81,  65,  38,   8,  70,  22,  88,  28,  37,  29,
         31,  25,  67,  98,  99,  59,  33,   9,  85,  14,  33, 106,   7,  62,
         71,  53,  71,  12,  53,  67,  93,   5, 109, 114,   3,  34,  35,  70,
         37,  36, 127,  78,  54,   7, 112,  82], device='cuda:0')
argmax start logits shape: tensor([ 13,  56,  66,   0,  17,  75,  80,  46,   5,  19,  16,  84,  35,  30,
          2,  23,  31,  32,  55,  56,  76, 112,  20,   0,  83,  21,  50,  49,
         11,  10,  36,  17,   7,  12,  57,  17,  63,  65,  20,  52,  37,  63,
         60,  61,  24,  21,  54, 100,  25,   7,  17,  92,  40,   6,  90,  40,
         78,   0,  93,   2,  10,   7,   8,  43], device='cuda:0')
argmax end logits shape: tensor([ 70,  23,  86,  82,  20,  42,  55,  75,  86,  28,  52,  44,  21,  33,
         56,  65,  29,  75,   7, 121,  60,  65,   7,  38,  61,  36,  39,  62,
         75,  31,  38,  54,   9,  30

Evaluating:  92%|█████████▏| 288/313 [00:08<00:00, 31.69it/s]

tensor([ 67,  59,  44, 132,  91,  84,   3,  53,  20,  27,  68,  93,   2,  89,
         43,  11,  20,  59, 112,   5,  20,   9,  93,  55,  62,  76,  56,  51,
         23,  15,  65,  59, 100,  88,  58,  98, 114,  19, 102,  17,   0,  47,
         99,   2,  17,  83,   0,  34,   1,  33,  75, 117,  43, 130, 106,  75,
         75, 108,  10,  41,  33,  55,  46,  13], device='cuda:0')
argmax end logits shape: tensor([ 68,  59,  38, 135,  10,  75,   3, 118,   5,  65,  67,  79,  46,   7,
         65,  75,  18,  35,   7,  45, 114,  15,  24,  34,  63,  93,  62,  54,
         35,  41,  24,  10,  75,  36,  58,  47,  66,   4,  47,  59,  75,  10,
        123,  36,  70,  61,  57,  58,  45,  36,  64, 114,  67,  40,  84,  65,
         63,  20,  27,   6,  35,  20,  16,  34], device='cuda:0')
argmax start logits shape: tensor([  8,  57,  34,  42, 108,  86,  29,  51,  33,  33,  84,  51,  38,  27,
        109,  88,  98,  85,  76,  76,  19,  70,  11,  82,  69,  11,  89,  12,
         16,  77, 104,  49,   8,  76

Evaluating:  93%|█████████▎| 292/313 [00:08<00:00, 32.65it/s]

argmax start logits shape: tensor([ 23,  94,  19,  59,   5,   3, 128,  20,  30,  80,   2,   0,  41,  18,
         36, 108,  42,  69,  21,  14,  46,  40,  86,  40,  48,  48,   0,  81,
         65,   0,  72, 156,  13,  22,  22,  20,  56,  31,  84, 110,  19,   2,
          0,  71,   3, 137,   0,  18,  34,  23,  68,  30,   3,  23,  63,  43,
          0,  28,  25,  79,   8,  62,  97,  46], device='cuda:0')
argmax end logits shape: tensor([  3, 108,  94,  41,  65,  53,  85,   8,  28,  40,  15,  40,  20,  38,
         50,  42,  44,  69,  84,  16,  94,  49,   7,  60,  50,  19,  28,  23,
         65, 118,   5, 112,  10,   9,  66, 105,  55,  50,  46,  58,  35,   2,
         46,  35,  99, 111,  33,  18,   5,  65,  54,  55,  36,  23,  86,  40,
         16,  80,  25,  42,  66,  29,  99,  84], device='cuda:0')
argmax start logits shape: tensor([ 20,  38,  39,   0,  16,  16,   7,  54,  95,  15,  15,  87,   3, 105,
         24,  32,  34,  60,  25,  95,  19,  14,  44, 114,  68,  54,  58,  13,
         

Evaluating:  96%|█████████▌| 300/313 [00:09<00:00, 32.80it/s]

tensor([  5,  43,  23,  55,  19,  23,  40,  89,  38,  54,  23,  79, 121,  16,
         94,  58,  97,  95,   6,  72, 109,   3,  20,  15,  68,  92,  26,  59,
         69,  70, 101,  61,  40,  48,  42,  82, 142,  46,  30,  42, 100,  54,
         83,  94,  50,  34, 125,  76,   2, 137,  41,  27,  13,  95,  40,  33,
         69,  98,  25,  15,  26, 116,  18,  47], device='cuda:0')
argmax end logits shape: tensor([ 73,  15,  80,  64,  58,  23,  26,  75,  95, 102,  23,  12, 121,  18,
         29,  72,  11,  21,   9, 109,  42, 121, 105, 121,  70,  92,  79,  59,
         66, 126, 117,  61,  44,  18,  81,  79, 105,  47,  29,  28, 100,  17,
         84, 120,  69,  51,  55,  22,  69, 137,  22,   5,   2, 100,  48, 123,
          2,  47,  60,  68,  87, 121,   3, 105], device='cuda:0')
argmax start logits shape: tensor([ 43,   6,   4,  36,  90,  27,  59,   2,  63,  40,  55,   0,  11,  46,
         45,  35,   0,  32,  33, 110,  90,  56,  89,   4,  10,   3,  47,  34,
         52,  34,  46,  57,  93, 140

Evaluating:  98%|█████████▊| 308/313 [00:09<00:00, 33.39it/s]

tensor([ 30,  23,  35,  45,  40,  20,  27,  99,  30,   7,  41,  17,  45,  15,
          2,  40,  59,  17, 102,  94,   0,  91,  50,   0,  61,  41,  25,   0,
          2,  20,   1,   0, 109,  25,  31,  32,  15,  19,  87, 113,  49,  14,
         36,  34,  20,  92,  24,  47,  56,   4,  95,  12,  26, 119,  34,  26,
          4,  56,  29,  50, 111,  16,  23,  47], device='cuda:0')
argmax end logits shape: tensor([ 66,  49,  13,  93,  19,  31,  34,  83,  81,  39,  15,  70,  37,  47,
          3,  44,  21,  73, 116, 101,  32,  37,  87,  83,  36,  29,  30,  24,
        119,  33,   2,  81,  20,   4,  31,  25,  99,  95, 102,  55,  21,  85,
         67,  86,  21,   2,  36,  60,  41,  23,  95,  71,  87,  34,  10,  26,
         45,  57, 133,  96,  82,  56,  27,  76], device='cuda:0')
argmax start logits shape: tensor([ 83,   6, 101,  11,  45,  14,  21,  25, 100,  85,   8,  72, 114,   0,
        149,  76,  10, 116,  88,  27,  78,  14,   0,  99, 113,  34,  50,  56,
         41,  39,  76,   6,   3,  40

Evaluating: 100%|██████████| 313/313 [00:09<00:00, 32.97it/s]


argmax start logits shape: tensor([ 15,  86,  76,  85,  99,  32,  61,   2,  41,  14,  36,  12,   7,  18,
         22,  81,  33,  44,  35,  47,  50,   2,  23,  32,  22,  29,  18,  30,
         63, 123,  77,  17,  35,  99,  54,  11,  34,  71,  62,  13,  27,  61,
         77,  68, 100,  28,  81,  42,  55,   2,   0,  69, 108,  44,  64,  15,
          8,  94,  85, 129,  40,  19,  44,   3], device='cuda:0')
argmax end logits shape: tensor([ 66,  62,  80,   8,  19,  34,  62,  31,  49,  14,   4,  14,   9,  19,
         64,  18,  37,  58,  49, 124,  60,  63, 107,  28,  16,   2,  11,  30,
         15,  35,  95,   2,  36, 103,  55,  40,  35,  17,  35,  75,  97,   8,
        115,  14, 100,   7,  28,  42,  55,  83,  37,  36, 113,  83,  39,  72,
         93,  83,  87,  30,   2,  15,  29,  94], device='cuda:0')
argmax start logits shape: tensor([ 50,  42,   0,  30,  50, 105,  51,  43,  33,  49,  12,  45,  17,  78,
         89,   0,  94,   0,  75,  10,  37,  54,  24,  20, 102,  72, 103, 128,
         

Evaluating:   0%|          | 0/63 [00:00<?, ?it/s]

argmax start logits shape: tensor([15, 15, 15,  0,  0, 20,  0, 13,  0, 13, 25, 25, 25, 25, 25, 33, 33, 33,
        33,  8,  8,  8,  2,  2,  2, 37, 37, 37, 37, 37, 34, 34],
       device='cuda:0')
argmax end logits shape: tensor([30, 30, 30,  5,  5,  6, 37, 23, 37, 37, 17, 17, 17, 17, 25, 45, 45, 45,
        45, 23, 23, 23,  6, 23, 23, 31, 31, 31, 31, 31, 30, 30],
       device='cuda:0')
argmax start logits shape: tensor([34, 28, 28, 28, 42, 42, 42, 13, 13, 15, 15, 13,  5,  5,  5,  5,  5,  1,
         1,  1,  1,  9,  0,  9,  0,  0,  9,  9,  9, 12, 12, 12],
       device='cuda:0')
argmax end logits shape: tensor([30, 58, 58, 58, 33, 33, 40, 15, 15, 15, 15, 15, 73,  2, 73, 73, 73,  1,
         1,  1,  1, 14, 19, 14, 19, 19, 14, 14, 14, 35, 35, 35],
       device='cuda:0')
argmax start logits shape: tensor([12,  2,  8,  2,  2,  2, 56, 59, 59, 56, 21, 21, 21, 20, 20, 83, 20, 42,
        69, 38, 69, 69, 69, 69, 47, 47, 47,  0,  0, 72, 18,  0],
       device='cuda:0')
argmax end logits shape:

Evaluating:  13%|█▎        | 8/63 [00:00<00:00, 70.35it/s]

argmax start logits shape: tensor([18, 72, 18, 72,  0,  0,  0,  0, 18, 45, 46, 46, 46, 46, 48, 45, 48, 46,
        48, 87, 87, 87, 57, 65, 65, 65, 57, 47, 99, 54, 61, 33],
       device='cuda:0')
argmax end logits shape: tensor([69, 69, 69, 69, 24, 24, 55, 55, 30, 65, 93, 93, 93, 93, 97, 97, 65, 93,
        65, 73, 73, 73, 17, 17, 17, 17, 67, 87, 54, 16, 87,  9],
       device='cuda:0')
argmax start logits shape: tensor([ 33,  93,  33,  83,  83,  83,  83,  51,  56,  56,  11,  57,  57,  57,
         57,  57,  61,  61,   0,  61,  61,   0,  61,   0,   0,   0,  84, 111,
        111, 111, 111, 111], device='cuda:0')
argmax end logits shape: tensor([  9,  87,  49,  50,  71,  50,  50,  50,  43,  43,   3,  29,  29,  29,
         29,  29,  43,  43,  53,  36,  36,  53,  43,  53,   6,   6,   6, 113,
        113,  12, 113, 113], device='cuda:0')
argmax start logits shape: tensor([ 41,  41,  41,  41,  41, 127,  70,  70, 105, 105, 111, 126,  93,  93,
         90, 105,  90,  90,  93,  90,   0,  67,  

Evaluating:  25%|██▌       | 16/63 [00:00<00:00, 69.12it/s]

argmax start logits shape: tensor([ 56,  61, 103,  56,  56,  56,  56,   8,  79,  79,  76, 107,  79,  76,
        107,  69,  79, 130, 130,  30,  20,  69,  45,  69,  48,   0,  69,  45,
        130,  30, 130,  69], device='cuda:0')
argmax end logits shape: tensor([113,  34,  57,  19,  19,  19,  19, 103,  79,  79,  76,  76,  79,  76,
         76, 108,  79,   3, 131,  91,  79, 118,  58, 118,  91,  93, 118,  58,
        131,  91,   3, 118], device='cuda:0')
argmax start logits shape: tensor([  0,  48,   0,  45, 130,   0,  16,  16,  16,  45,   0, 124,   9, 124,
         50,   9,  50, 124,  42,  50,  42, 124,  37, 124,   9,  42,  42,  42,
         42,  31, 130,   5], device='cuda:0')
argmax end logits shape: tensor([ 93,  91,  93,  58,   3,  93,  64,  64,  64,  58,  93,  97,  89,  97,
        139,  89, 139,  97,  89, 112,  89,  97,  89,  97,  89,  89,  89,  89,
         89,  27, 104,  83], device='cuda:0')
argmax start logits shape: tensor([ 63,  41,  33,  31, 130,  33,  68, 130,  63, 130,  77

Evaluating:  37%|███▋      | 23/63 [00:00<00:00, 67.67it/s]

tensor([ 41, 105, 105,  45,  45,  45,  45,   0,   0,   0,  45,   0,   0,   3,
          3,  24,   3,  39,  39,  65,  24,  65,  24,  51,  65,  65,  51,  51,
         43,  65,  24,  39], device='cuda:0')
argmax end logits shape: tensor([103, 103, 103,   3,   3,   3,  62,  79,  79,  79,   3,  79,  79, 140,
        140, 115, 140,  61,  61,  67, 115,  67, 115,  50,  67,  67,  50,  50,
        115,  67, 115,  61], device='cuda:0')
argmax start logits shape: tensor([ 3, 51, 39, 92, 92, 92, 92, 92, 65, 71, 96, 96, 71, 65, 71, 71, 65,  0,
        14,  0, 36,  0, 13, 49, 49, 72, 51, 51, 98, 98,  5,  5],
       device='cuda:0')
argmax end logits shape: tensor([140,  50,  61,  99,  99,  50,  50,  50,  66,  71,  66,  66,  71,  66,
         71,  71,  66,  85,  51,  85,  85,  85,  85,   8,   8,  99,  56,  56,
         98,  99,   8,   8], device='cuda:0')
argmax start logits shape: tensor([ 57,  12, 105,   5,  21, 121,  21,   5,   5,  72,  21,  57,  51,  57,
         57,  12,  98,  21,  57,  51,  51, 

Evaluating:  48%|████▊     | 30/63 [00:00<00:00, 67.07it/s]

tensor([90, 52, 90, 40, 40, 90, 90, 42, 85, 39, 39, 16, 39, 16, 42, 39, 39, 39,
        39, 39, 39, 16, 42, 42, 16, 16, 39, 92, 92, 92, 99, 92],
       device='cuda:0')
argmax start logits shape: tensor([ 74,  96,  27,  96,  17,  17,  17,  96,  17,  69, 112, 112,  64, 112,
        112, 112,  69,  69,  64,  83,  69,  64, 112,  64,  64,  83, 112,  83,
        112,  83,   4,   4], device='cuda:0')
argmax end logits shape: tensor([ 92,  99,  42,  99,  42,  42,  42,  99,  42,  60,  99,  99, 129,  99,
         99,  99,  60,  60, 129,  48,  60, 129,  99, 129, 129,  48,  99,  48,
         99,  48,   5,   5], device='cuda:0')
argmax start logits shape: tensor([  4,  80,  80,  63,  63,  76,  63, 134, 134, 134, 111, 111, 134, 134,
        134, 111, 134, 134, 111, 134, 111, 134, 103, 107, 103, 103, 124,  62,
        102, 124, 102, 124], device='cuda:0')
argmax end logits shape: tensor([  5,   5,   5, 129, 129, 129, 129, 100, 100, 100,  23,  33, 100, 100,
        100,  23, 100, 100,  23, 100,  23, 

Evaluating:  59%|█████▊    | 37/63 [00:00<00:00, 65.00it/s]

argmax start logits shape: tensor([ 71,   6,  71, 125,  12,  12,  12,  12,  12,  11,  75,  18,  11,  18,
          0, 118, 120, 107,  69, 107, 107, 107,  94,  94,  61,   6,   6,  55,
         55,  55,  55,  32], device='cuda:0')
argmax end logits shape: tensor([ 33,  76,  33,  45,  12,  12,  12,  12,  12,  76,  76,  19, 100,  19,
         30,  30,  30,  88,  38,  88,  88,  88,  11,  95, 124,  66,   6, 100,
        100, 100, 100, 100], device='cuda:0')
argmax start logits shape: tensor([ 31,  32,  32,  31,  31,   6,  31,   6,  31, 107,   6,   6,   6,  99,
        115,  25,  25,  99,  25,  99,  76,   6,  99,  76,  99,  93,  93,  25,
         84,  12,  99,  76], device='cuda:0')
argmax end logits shape: tensor([  8, 100, 100,   8,   8,  66,   8,   6,   8,   6,   7,  18,  18,  12,
        110,  20,  20,  26,  12,  26,   2,  18,  26,   2,  26,  86,  86,  20,
        138, 138,  26,   2], device='cuda:0')
argmax start logits shape: tensor([ 93,  99,  76,  84,  12,  76,  93,  99,  93,  76,  76

Evaluating:  70%|██████▉   | 44/63 [00:00<00:00, 65.37it/s]

tensor([ 99,   6,  99, 130, 130, 130, 130,  45,  45,   2,   2,  28,  28,  45,
        130,  28,  28,  27,   2, 106,  28,  45,   2, 106,  53,   2,  27, 106,
        106, 106,   8,  83], device='cuda:0')
argmax end logits shape: tensor([  5,   5, 113,  89,  89,  89,  89,  38,  38,  94,  94,  98,  98,  38,
         89,  98,  98,   2,  94, 130,  98,  38,  94,  64,   2, 129,   2, 130,
         56,  56,  43,  69], device='cuda:0')
argmax start logits shape: tensor([  8,   8,   8,  83, 103,  83,   8,  83, 103, 103, 103, 103, 154, 154,
         73,  93,  76, 154,  76,  73,  93,  73, 154, 154,  76,  76, 154, 154,
        154,  63,  73,  93], device='cuda:0')
argmax end logits shape: tensor([113,  21,  21,  84, 130,  84,  21,  84, 130, 130, 130, 130, 154, 154,
        115,  63,   1, 154,  16, 115,  63, 115, 154, 154,  32,   1,  51, 154,
        154,  46, 115,  63], device='cuda:0')
argmax start logits shape: tensor([ 93,  93, 154,  63, 154,  63,  63,  76,  60,  32,  60,  60,  60,  60,
         6

Evaluating:  81%|████████  | 51/63 [00:00<00:00, 63.89it/s]

argmax start logits shape: tensor([ 56,  15, 117,  15,  15, 117,   6,  25,  55,  55, 117,   6,   6,  55,
         25,  25,   6,  25, 117,   6,   6,  25,  55,   6,   6,   6,   6,  55,
         55,  55,  55,  55], device='cuda:0')
argmax end logits shape: tensor([ 57,  32,  83,  32,  32,  83, 104,  14,  33,  33,  83, 104, 104,  33,
         14,  14, 104,  83,  15, 104, 104,  14,  33, 104, 104, 104, 104, 109,
        109, 109, 109, 109], device='cuda:0')
argmax start logits shape: tensor([ 64,  42,  64,  42,  64,  42,  64,  42,  42, 145, 145, 145, 145,  72,
         24,  44, 145, 145, 145, 145, 145, 145,  72,  70,  25,  25,  72,  25,
          8,  94,   8,  94], device='cuda:0')
argmax end logits shape: tensor([121,  76, 121,  76, 121,  76, 121,  76,  76, 145, 145, 145, 145,  81,
         33,  33, 145, 145, 145, 145, 145, 145,  81,  81,  33,  33,  81,  33,
        115,  94, 115,  94], device='cuda:0')
argmax start logits shape: tensor([  8,  78,  78,  78,  78,  78,  77,  77,  77,   2, 110

Evaluating:  92%|█████████▏| 58/63 [00:00<00:00, 63.53it/s]

argmax start logits shape: tensor([ 56,  64,  64,  33,   2, 104, 117,  54,  33,  74,  74,  74,  74,   0,
          0,   0,   0,  74,   0,  98,  78,  78,  78,  78,  37,  32,  11, 109,
        109,  92,  92,  37], device='cuda:0')
argmax end logits shape: tensor([ 56, 104, 104,  34,  56,  91, 119,  98,  34,  74, 137,  74,  74,  45,
         45,  62,  45, 137, 106,  32,  67,  67,  67,  67,  26,  37,  32,  32,
         32,   7,   7, 128], device='cuda:0')
argmax start logits shape: tensor([ 78,  92, 123,  92, 140, 140,  70, 140,  21,  21,   0,  70,   0,  70,
         21,  70,  70, 148,   0,   0,  56,  56, 140,  21,  21,  56,   0,  29,
         29,  29,  73,  43], device='cuda:0')
argmax end logits shape: tensor([ 67,   7,   7,   7, 120, 120, 111, 148,  15,  15,  53, 111,  53, 111,
         15, 111, 111, 148,  53,  53,   2,   2, 120,  76,  15,   2,  53, 120,
        120, 120,  80,  89], device='cuda:0')
argmax start logits shape: tensor([43, 43, 43, 13, 29, 72, 13, 43, 73, 29, 72, 72, 13, 5

Evaluating: 100%|██████████| 63/63 [00:00<00:00, 64.89it/s]


argmax start logits shape: tensor([133, 133,  28,  28,  45,  45,  55,  55,  45,  45,  67,  55,  67,  67,
         67,   0,  75,  78,  78,  78,  75,   8, 119,  79, 119,  79, 119,  79,
        119,  79,   8,   8], device='cuda:0')
argmax end logits shape: tensor([ 40,  72,  80,  80,  47,  47,  80,  80,  47,  47,   7,  80,   7,   7,
          7,  40,  80,  80,  80,  80,  80,  20,  50,  59,  50,  59,  50,  59,
         50,  59, 113, 113], device='cuda:0')
argmax start logits shape: tensor([  8,   8,   8,  79,   8,   8,   8,   8, 115, 115, 115, 116, 116, 125,
        115, 125, 125, 140,  25, 116, 116, 140, 125, 125, 140,  25,  25,  34,
         25, 140,  34,  34], device='cuda:0')
argmax end logits shape: tensor([113, 113, 113,  59,  20,  20,  20,  20,   7,   7,   7,  28,  28, 125,
          7, 125, 125, 140, 110,  28,  28, 140, 100, 125, 140,  25,   9,  22,
          9, 140,   3,   3], device='cuda:0')
argmax start logits shape: tensor([140,  34,  25, 108, 108, 108, 108, 108,  79,  83, 145

Epoch 2/10:   0%|          | 0/313 [00:00<?, ?it/s, loss=7.56]

argmax start logits shape: tensor([110,  41,  18,  12,  18,  67,  17,  18,  26,  47,  97,  95,  84,  40,
         50,  49,   1,  19,  45,  60,  75,  15,  25,   4,   4,  18,   3,  11,
         60,  61,  21,  80,  26,   6,  13, 115,  25,   5,   6,   4, 109,  59,
        127,  88, 150,  47,  34,  87,   0,  41,  72,  73,  38,  52,  38,  79,
        122,  38,  63,  75,  46,  13,  37,  88], device='cuda:0')
argmax end logits shape: tensor([110,  42,  24,  61,  61,  50,  54,  67,  26,  37,  18,  56,  85,  27,
         69,  62,   1,  17,   8,   3,  60,  44,  71,  26,  47,  48,   3,  45,
         63, 108,  21,  81,  67,   6,   6,  71,  16,   8,  92,  50,  65,  59,
         95, 119,  69,  38,  39,  12,  14,  55,  74, 120,  39,  13,  38,  69,
        108,  17,  18,  71,   4, 124,  37,  74], device='cuda:0')
argmax start logits shape: tensor([ 54,  43,  36,  10,  30,  43,  55,  13,  82,  43,  20,   2,  41,  69,
         43,  40,   7,  10,   5, 110,  95,  44,  43, 100,  48,  65,  34,  33,
         

Epoch 2/10:   1%|          | 2/313 [00:00<00:19, 15.65it/s, loss=7.72]

argmax start logits shape: tensor([ 19,  45, 111,  56, 103,  84,   8,  78,  18,  11,  58,  75,  40,  29,
         36,  40, 123,  49,  44,  51,  43,  41,  54,  85,  40,  34,  37,  77,
         51,  52,  17,  81,  63,  37,  57,  81,  40,   2,  15,  94, 105,  25,
         20,  56,  54,  62,  27,  29,  52,  27,  86,  41,  27, 114,  43,  38,
         86,  51,  36,  15,  73,  17, 104,   5], device='cuda:0')
argmax end logits shape: tensor([ 32,  46,  95,  15,  36,  86,  40,  89,   8,  56,  46,  22,  85, 106,
        107,  88,  45,   3,  50,  19,  43,  82,  35,  85,  44,  14,  77,  53,
         21,  58,  84, 114, 104,  44,  61,  53,   4,  47,  43,  95,  45,   2,
          5,  70,  81,  14,  42,  31,   3,  69,  48,  71,  87, 103,   6,  17,
         86,  71,  38,  31,  24,  18,  94,  17], device='cuda:0')


Epoch 2/10:   1%|          | 2/313 [00:00<00:19, 15.65it/s, loss=7.79]

argmax start logits shape: 

Epoch 2/10:   1%|▏         | 4/313 [00:00<00:19, 15.47it/s, loss=8.02]

tensor([ 53,  76,  15,  46,  91,  18,  47, 102,   3,  86,  34,  49,  61,  18,
         70,  19,  95,  81,  50,  38,  47,  75,  70,   3,   2,  45,  42,  31,
          7,  64,  22,  91,  73,  21, 143,  68,  46,  93,  41,   3,  17,  36,
         12,  95,  13, 109,  46,  11, 130,  63,  53, 114,  76,  27,  66,  95,
         32,  86,  15, 138,   2, 101,   0,  34], device='cuda:0')
argmax end logits shape: tensor([ 23,   3,  14,  29,  12,  78,  58, 109,   3,  47,  34,  70, 100,  70,
        120, 108,  21,  59, 121,  38,  50,  79, 120,  15,  23, 136,  23,  35,
          7,  11,  20,  94,  69,  15, 107,  55,   4,  82,  41,  78,  57,  50,
        129,   7,  14,  57,  30,  67, 134,  37,  34,  40,  36, 114,  18,  39,
         56,   3,  58, 140,  38, 117,  13,  13], device='cuda:0')
argmax start logits shape: tensor([ 41,  26,  39, 130,  27,   0,  63, 128,  81,  46,  10,  10,  11,  59,
         10,   0,  37,   0,   9,   0, 106,  38,   2,  62,  37,  91,  87,  72,
         39,  10,  67,  10,  87,   5

Epoch 2/10:   1%|▏         | 4/313 [00:00<00:19, 15.47it/s, loss=7.94]

argmax start logits shape: tensor([ 17,   0,  62,  37,  56,   1,  55,  62,  91,  56,  34,   5, 101,   0,
         63,  54,  32,  87, 139,  85, 107,   9, 111, 104,  48,  69,  35,  46,
          0,   2,  72,  37,  15,  64,  40,   6,   0,  10,  53,  59,   0,   8,
         34,  91,  90,  51,   3,   0,   2,  23,  59,   3,  55,  68,  22, 119,
          0,   0,  69,  30,   0,   8,  10,   3], device='cuda:0')
argmax end logits shape: tensor([ 15,  17,  50,  36,  94,  43,  49,  27, 115,  38,  29,  37, 101,  88,
         47,  73,  71,  78,  29, 106,  45,  90,   5,  86,   2,   9,  56,   3,
         78,   3,  68,  38,  96,  70,  41,   9,  29,   7,  92,  75,  65,  43,
         78,  80,  34,  19,  15,  45,  37,  75,  59,  58,  22,  30,  22,  34,
         25,  50,  69, 107,  41,  20,   4,  65], device='cuda:0')


Epoch 2/10:   2%|▏         | 6/313 [00:00<00:20, 15.10it/s, loss=7.94]

argmax start logits shape: 

Epoch 2/10:   2%|▏         | 6/313 [00:00<00:20, 15.10it/s, loss=7.85]

tensor([ 15,  67,  72,  83,  41,  57,   3,  75,  98,   0,  60, 130,  75,  99,
         60,  69,   6,  84,  60,  11,  43,   0,  26,   3,  26,   6,  38,  53,
         94,  53,   2,  97,  28,  14,  11,  34,  16,  91,   6,  84,   4,  23,
         45,  24,   9,  61,  21,   2,   0,  62,  92,  25,  88,   0, 105, 108,
         92,  52,  94,  23,  63,  24,   9,  77], device='cuda:0')
argmax end logits shape: tensor([  6,  69,   5, 107,  41,  57,  33,   6, 119,   4, 121, 100,  11, 100,
         96,  20,  50,  30,  32,  28,  42,  34,   8,   4,  26,  57,  38,  56,
          4,  53,   3, 100,  26,  58,  30, 113,  17,  12,  25,  91,  45,  26,
         62,  35,  55,  43,  59,  45,   8, 105,  13,   2,  97,  33,  34,  26,
         92,  63,  71,  71,  53,  14,  75,  65], device='cuda:0')
argmax start logits shape: tensor([ 18, 102,  68, 108,  95,  81,  60,  68,   4,   2,  61,  10,  13,  18,
         55,  69,  53,  49,  97,  24,  11,   6,  44,  10,  74,  53,  16,   0,
          0,  48,  18,   8,  85,  19

Epoch 2/10:   3%|▎         | 8/313 [00:00<00:20, 15.19it/s, loss=7.67]

argmax start logits shape: tensor([ 78, 125,  87,  86,  34,  19,  93,  29,  85,  71,   2,  29,  78,   3,
         62,  88,   0,  48,  65,   6,  87, 133,  58,   0,   0,   3, 114, 123,
         85,  79,  18,  47,  24,  71,  38,  48,  23, 133,  15,  73,  28,  85,
         75,  62,  12,  56,  63,   2,  79,  69,  14,  69,  38,  52,   6,   7,
         44,   4,  28,   2,  48,  95,  36,  59], device='cuda:0')
argmax end logits shape: tensor([ 61, 125,  87,  10,  77,  99,  85,  91, 141,   6,  38,  20,  16,  43,
        100,  53,   7,  85,  71,  72,  78,  82,   9,  99,  78,  14, 115, 123,
          3,  63,  80,  58,  31,  42,  18,  99,  34, 111,   2,  75,  33,  86,
         21,   2,  45,  54,  64,  29,  54,  28,   3,  53,   5,  27,   9,   9,
        115,   7,  55,  38,  55,  85,  80,  30], device='cuda:0')


Epoch 2/10:   3%|▎         | 10/313 [00:00<00:19, 15.22it/s, loss=7.79]

argmax start logits shape: tensor([ 41,   2,  33,   2,   7, 116,  72,  25,   7,  26,  14,  14,  11,  42,
        110, 111, 157,  57,   3, 137,  35,  54,  81,  98,  57,  49,  31,   2,
         65,   7,   6,  50,  25,  23,  50,   2,  89,   0,  29,  59,  85,   2,
         30, 102,  77, 136,   7,  16,   3,  20,  81,  39,  37,   8,  72,   2,
         79,  57,   6,  41,   7,  48,  38,  29], device='cuda:0')
argmax end logits shape: tensor([  2,   5,  48,  49,  10,   3, 126,   4,  57,  18, 107,  20,  72,  49,
         45,  46,  74,  43,  92, 132,  65,  52,   7, 100,  58, 124,  99,   2,
          3,  43,  57,  10,  25,  80,  31,  42,   6,   6,  94,  70,  85,  53,
         13,  61, 108, 117,  43,  35,  92,  78,  15,   6,  80,  61,  60,  79,
         85,  81,  12,  24,  58,  42,  80,  31], device='cuda:0')
argmax start logits shape: tensor([ 81,  44,   9,  98,  35,  37,  49,  18,  24,  24,   3,  52,  78,  45,
         14,  21,  43,  48,  56,  79,  16, 110, 101, 111,   2, 116,  89,  10,
         

Epoch 2/10:   3%|▎         | 10/313 [00:00<00:19, 15.22it/s, loss=8.04]

argmax start logits shape: tensor([ 19,  13, 107,  98,  14,  47,  16,  99,  58,   7,  60,  24,  24,  22,
         54,  40,  17,  77,  33,  82,  31,   6,   6,  72, 107, 133,   5,   7,
         67,  92,  11,  68,   1,  55,  56,  35,  49,  35,  52,  28,  46,  57,
         34,  98,  30,   2,  49,  31,  99,  40,  32,  80,  98,  11,   6,  50,
          6,  27,  28,  71,   8,  81,  56,  56], device='cuda:0')
argmax end logits shape: tensor([ 14,  24,  64,  98, 109,  76,  81,   7,  75,  31,  37,  42,  72,  50,
         62,  46,  44, 107,  12,  97,  25,   3,  93,  48,  57,  36,  99,  57,
         66,  93,  21,  54,  21,  27,  57,  18,  31,  86,  53,  45,  78,  32,
        138,  78,  38,  59,  34,  44,  55,  99,  45, 106,  59,   9,   4,  51,
         14,  24,   4,  27,  78,  48,  42,  59], device='cuda:0')


Epoch 2/10:   4%|▍         | 14/313 [00:00<00:19, 15.20it/s, loss=7.75]

argmax start logits shape: tensor([ 12,  54,  72,   0,  44,  21,  31,  65,  13,  76,  21, 141,  53, 106,
         66,  18,   6,  38,  57,  43,   4,  14,  20,  49,  38,   0,  11,   6,
         35,  92,  45,  38, 105,  31,  58,  39,  14,  94,  30, 105,  49,  84,
         31, 106,  38,  79, 140,  36,   3,  38,  45, 115,   8,  60,  25,  58,
         16,  92,  95,  27,   0,  54,  85,  53], device='cuda:0')
argmax end logits shape: tensor([ 71,  41,   3,   9,  46,  21,  29,  22,  37,  80,  15, 148,   9,  42,
          5,  42,  70,  38,  26,  17,  85,  14,  50,  41,  18,  83,  21,   2,
         38,  23,  12,  67, 112,  59,  34,   6,  84,  16,  43,   3,  48,  85,
         33, 107,  67,  79,  54,  74,   2,  28,  60,  16,  33,  36,  15,  37,
          7,   8,  15,  36,  62,  81,  18,  10], device='cuda:0')
argmax start logits shape: tensor([ 49,  59,  18,   0,  40,  90,  23,  79,  12,  25,   2,  21,   0,  72,
         74,   1,  39,  78,  28,  34,  17,  43,  64,  85,  50,   0,  85, 107,
         

Epoch 2/10:   5%|▌         | 16/313 [00:01<00:19, 15.25it/s, loss=7.84]

argmax start logits shape: tensor([ 84,  65,   6,  87, 121,  48,  40,  21,  56,   0,   0,  13,  44,   0,
          0,  54,  54,   6,  47,  31,   0,  90,  92,  80,   0,  68, 110,  90,
          0,  41,   2,  13,   0,  78,  64, 107, 114,  49,   2,   5,   0,  91,
         10,  34,   2,  71,   1,  18,  96,   3,  85,   0,  32,  10,   0,  34,
          0,  23,   3,  50,   1,  75,   3,  51], device='cuda:0')
argmax end logits shape: tensor([ 52,  65,  14,   2,  76,  50,  37,  22, 101,  98,   7,  18,  33, 112,
         50,   8,  54, 115,  35,  32, 112,   9,   9,  62,  17,  28,   7, 100,
         49,  52,   6,  13,  11,  20,  64,   8,  74,   2,  26,   8, 113, 115,
         26,  57,  47,  92,   3,  18,  75,   3,  81,  17,  56,  77, 103,   8,
        141,  65,  84,  24,  18,  38,   4,  50], device='cuda:0')
argmax start logits shape: tensor([ 10,  52,  33,   0,  63,  14, 109,  14,  27,  63,   1,   0, 135,  77,
         44,   0,   0,  55,  80,  47,  31,  44,  16, 102,  20,  51,  88,   8,
         

Epoch 2/10:   6%|▋         | 20/313 [00:01<00:19, 15.32it/s, loss=7.9] 

argmax start logits shape: tensor([  7,  52,  53,  12,  40,   1,  22,   2,  75,  25,  61,  91,  22,  92,
         73,   3, 107,  57,  13,  18,  17,  23,   2,   2,  41,  19,  41,  39,
         76,  59, 110,  72,  15,   7,  38,  48,  55,  18,   2,  74,  26,  74,
         53,  77,  43,  94,  90,  23,  85,   0,  28,  47,  16, 138,   6,   6,
        104,  82,   2, 104,   4,  21,  16,   0], device='cuda:0')
argmax end logits shape: tensor([ 66,  23,  84,   3,  13,  38,  77,  21, 111,  47,  32, 117,  16,  46,
         78,  80,  60,  57,  13,  42, 102,   2,   3, 119,  79,  34,   5,  17,
         79,  65,  12,  45,  76,  19,  52,  22,  55, 108,   5,  25,   1,  47,
         25,  49,  35,  27,  40,  11,   6,  30,  56,  40,  25, 123,  53,  30,
         99,   1, 104,  99,  38,  65,  16,  44], device='cuda:0')
argmax start logits shape: tensor([103,  50,  47,  30,   2,  93,  32,  86,  50,  59,  71,  62,  18, 107,
        117,  60,  50,  75,  11,  74,  52,  81,   1,  13,  27,  93,  24,  90,
         

Epoch 2/10:   7%|▋         | 22/313 [00:01<00:19, 15.29it/s, loss=7.69]

argmax start logits shape: tensor([ 30,   1,  34,  85,  50,  46,  94,  20,  65,  75,  53,  25,  88,  22,
          5,  79,  33,  67,   3,  90,  33,  30,   5,  11,  21,  20, 128, 110,
         46,  26,  29,  99,   3,  21,   6,   9,   9,  93,  54,  88,  87,  95,
         15,  83,  40,  27,   2,  12,   6,  45,  83,  59,  39,  78,  13,  57,
         43,  46,  42,   2,  54,  67, 107,  98], device='cuda:0')
argmax end logits shape: tensor([ 65,  59,  55, 104,   8,   3,  15,  77,  67,  27,  72,  76,  50,  23,
         31,  65,  17,  69,   3,  19,  29,  15, 130,  25,  42,  20,  12,  20,
          4,  14,  56,  27, 104,  74,  32,  38,  68,  96, 102, 122,  65,  23,
         50,  24,   5, 101, 100,   3,   9,  21,  60, 111,  53,  62,  26,  50,
         57,  15,  94,  69,  61,  20,  12,  45], device='cuda:0')
argmax start logits shape: tensor([ 79,  10,  40,  72, 109,  42, 138,  79, 137, 100,  79,  47, 121,  32,
         55,  22,  37,  63,  20,  41,  13, 101,  17,  35,   0,  23,  27, 100,
         

Epoch 2/10:   8%|▊         | 26/313 [00:01<00:18, 15.12it/s, loss=7.75]

argmax start logits shape: tensor([  0,  59,   0,  27,  67,   0, 115,  14,  21,  85,  39,  31,   2,  62,
          0,   0, 102,  76,   0,   0,  58,  45,  27,   8,  40,  11,  79,   0,
          0,  47, 102,  45,  52,  84,  87,  59,  61,  96, 133,  56,  21,  86,
         84,  85,   0,  91,  29,  12,  39, 117,  19,  81,   0,   0,  71,  52,
         96,  15,   9,  34, 122, 133, 114,  88], device='cuda:0')
argmax end logits shape: tensor([ 57,  24,  39,  55,   2,  45,   2,  29,  47, 113,  68,  53,  35,   6,
         78,  45,   4,  40,  20,  50,   8,  47,  23,  35,  28,  41,  74,  34,
         38,  93, 104,   9,  62,  75,  59,  63,  70,  75,  36,  72,  33,  87,
         27, 120,  49,  14,  32, 109,  14,   5,  44, 100,  73,  48,  35,  33,
         67,   1,  11,  33,  76, 110,  72,  57], device='cuda:0')
argmax start logits shape: tensor([ 39,   2,  95,  10,  30,   0,  84,  33,   3, 102, 101,  18,  27,  92,
          0,  74,  41,   6,   0,   0,  21,   0,  51,  57,   5,  25,   0,   2,
         

Epoch 2/10:  10%|▉         | 30/313 [00:01<00:18, 15.12it/s, loss=7.79]

argmax start logits shape: tensor([ 37,  57,   3,  79,   0,  77,  46,  41,  30,   0,   9,  80,  66,  20,
        127,  71,   8,  53,  26,  83,   5,  32,  74,  96,  24,   5,   6,  81,
         10,   0,  23,  97,  96,  37,  10,  13,  84,   2, 115,  34,  27,  65,
         61,  46,  63,  49,  37,  45, 143,   7,  83,  17,  14,  42,  80,   0,
         92,  51, 116, 106, 116,   0, 122,  11], device='cuda:0')
argmax end logits shape: tensor([  7,  35,   3,  92,  20,  77,  37,  45,  26,   5,  75,  73,  51,  50,
        128,  35,  74,  15,   2, 114,  50,  68,  74,  96,  22,   7,  79,  23,
          4,   3,  82,  97,  50,  37,  88,  16,  65,  76,   2,  20,  34,  12,
         44,  46,  68,  21,  38, 107,   4, 103,   2,  15,  15,  43,  81,  26,
         93,   4,  43,  79,  22,  12, 123,  73], device='cuda:0')
argmax start logits shape: tensor([116,  34, 116,  81,  29,  37,  76, 100,  16,   5,  29,   0,  80,  45,
         39,  37,  43, 137, 100,  79,  44,  27,  83,  32,   7,  40, 131,  56,
         

Epoch 2/10:  10%|█         | 32/313 [00:02<00:18, 15.22it/s, loss=8.13]

argmax start logits shape: tensor([ 73,  13,  48,  43,  55,  41,   2,  34,  32,  54,  52,  29,  19, 117,
         11,  43,  41,  17,  68,   3,  75,  72,  44,  35, 123,  35,   3,  42,
        132,  44,  79,  59,  40,  69,  15,  14,  47,  43,   4,   2,  19, 109,
         50,  18,  79, 112,  77,  22,  41,  34,  68,   9,  18,  79,  85,  16,
         13,  66,  55,  22,  37,   4,  72,  20], device='cuda:0')
argmax end logits shape: tensor([106,  25,  76,  46,   3,  30,   4,  56,  40,  54,  91,  37,   8,  76,
          3,  89,  16,  18,  32,   3,  18, 116,  90,  65, 123,  92,  81,  55,
        136,   5,  79,  98,  61,  57,  93,  14,  68,  55,   4,   2,  19, 108,
         52,   9,  46, 112, 111,  36,  55,  94,  69,   9,  14,  22,  77,  34,
         37,   3,  59,  22,  18,  38,  48,  33], device='cuda:0')
argmax start logits shape: tensor([ 62,  96, 139,   3,  35,  38,  20,  65,  50, 137, 127,  49,  33,  41,
          8,  19,  92,  20,  75,  35,  15,   2,  25,  39,  71,  15,  16,  28,
         

Epoch 2/10:  12%|█▏        | 36/313 [00:02<00:18, 15.19it/s, loss=7.58]

argmax start logits shape: tensor([ 69, 112,   6,  65,  51, 103,   1,  73,  61,  42,  21,  21,  72,  10,
          2,  73, 128,  10,  27,  42,  91,   6,  99,  44,  40,  44,  71,  10,
         42,  39, 113, 104,  58, 103,  33,  54,  92,  10,  61,  26,  35,  77,
         94,  38,  28,  18,  60,  10,  88,  26,   7, 118,   0,   0,   8,   6,
        114,   9,   0,  39,  32,  44,   0,  67], device='cuda:0')
argmax end logits shape: tensor([ 71, 104,  50,  78,  51, 104,  31,  73,  67, 109,  65,  40,  18,  15,
         61,  58, 110,  20,  49,   5,  91,  80,  80, 109,  21,  20,  87,  11,
         60,  29,  44,  53,  57, 104,  17,  38, 131,  87,  48,  43,  65,  17,
         80,  14,  30,  20,  29,  46,  40,  12,  73,  57,  47,  88,  30,  80,
         87, 116,  34,  50, 110,  28,  12,  48], device='cuda:0')
argmax start logits shape: tensor([ 79,   6,  18,   9,  53,  73,  45,  40,  50,  18,  17,  80,  97,  66,
         34,  51,  80,  24,  39,  12,  45,  89,  24,  55,  93,  41,   2,  28,
         

Epoch 2/10:  12%|█▏        | 38/313 [00:02<00:17, 15.28it/s, loss=8.16]

argmax start logits shape: tensor([  7,  19,  72,  68,  34,  86,  42,  40,  91,   3,  42,  84,   3,  14,
         46,  42,  64,  91,  45,   2,  36,   2,   4,  33,  14, 119,  41,   0,
         76,  80,   2,  41, 118,   4,  77,  17, 102,  73,  98,  80,  20,  65,
         63,  47, 106,  34,  13,  41,  98,  33, 114,  40,  56, 104,  22,  13,
         21, 119,  79,  15,   9,  60,  60,  20], device='cuda:0')
argmax end logits shape: tensor([ 72,  24, 120, 106,  16, 106,  44,  41,  94,  35,  95,  83,   3,  85,
         31, 111,  37,  65,  13,  46,  17,  77,  40,  37,   8, 109,   3,  61,
        109,  81,   3,  35,  83,   4,   5,  70,  61, 115,  86,  37,  12,   2,
         33, 111,  94,  49,   4,  33,  44,  62,  83,  55,  23,  55, 118,  32,
         44,  15,   6,   3,  30,  58,  44,  84], device='cuda:0')
argmax start logits shape: tensor([  0,  85,  85,  71,  52,  17,  10,   0,  24,  47,  16,  18,   2,  33,
         46,  76,  59, 128,  47,  40,   0,   0,   2,  17,  77,   0,  64,   2,
         

Epoch 2/10:  13%|█▎        | 42/313 [00:02<00:17, 15.32it/s, loss=7.7] 

argmax start logits shape: tensor([ 53,  71,  60,   0,   4,   0,  65,   0,  76,  49,  30,   0,  16,  28,
          3,  17,  29,  11,   0,  41,  81,   0,   0,   0,  22,  74,  20,  33,
         87,   0,  66,  50,  10,  79,   0,  30,   0,   0,   0, 101,  70,  23,
         45,   0,  42,   0,   0,   2,  36,  10,  80, 114,   0, 121,   0,  75,
          0,  15,  88,  15,  45,  54,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 38,  11,  61,  63,  45, 131, 117,  44,   5,  69,  48,  28,   5,  47,
        114,  84,  24,  73,  51,  14,  71,  33,  16,   9,  81, 138,  26,   4,
         69,   4,  73,   2,   2,  15,  25,  65,  71,  40, 121,  93,  62,  24,
         75,  57,  90,  35,  77,   8,   4,   3,  26,  41,  29,   5,  18,   4,
         39,  63,   2,  14,  11,  23,  98, 121], device='cuda:0')
argmax start logits shape: tensor([108,  17,  36,  31,  49,  24,  67,  77,  14,   2,   0,  39,   9,  66,
         93,  79,  41,  60,  85,   0, 105,  21,  54,  28,  33,  95,  15,  39,
         

Epoch 2/10:  15%|█▍        | 46/313 [00:03<00:17, 15.39it/s, loss=8.13]

argmax start logits shape: tensor([ 39,  93,  25,  49,  26,  32,  66,  18,   2,  80,  51,  80,  53,  75,
         35,  60,  77,   2,  89,  81,   3,   7,  61,   9, 108,  21,  70,   9,
         96,   1, 123,  41,   8,  22,  56,   5,  25,  63,  11,  11,  39,  60,
          4,  25,  90,  16,  27,  18,   2, 106,  28,  35,  40, 121,  11,  85,
         71,  25,  17,  87,   7,  59,  30,  12], device='cuda:0')
argmax end logits shape: tensor([ 29,  93,  34,  33,  15,  18,  82,  18,   2,  27,  69,  98,  84,  28,
          9, 132,  91,  35,  89,  32,   3,   7,  63,  37,  99,  31,  19,   9,
         96,   1,  41, 105,  52,  43, 109,  33,  25,  61, 118,  11, 102,   2,
          4,  60,  91,  35, 121,  52,   2,  27,  53,   8,   7, 121, 120,  92,
          3,  18,  30,  13, 134,  14,  48, 110], device='cuda:0')
argmax start logits shape: tensor([ 14,  16,  20,   9,  77,  49,   0,  93, 133,   3,  51,  34,  60,  21,
         54,  63,  10,  73,  13,  50,  80,   4,  63,  54,  13, 107,   3,  79,
         

Epoch 2/10:  16%|█▌        | 50/313 [00:03<00:17, 15.37it/s, loss=7.84]

argmax start logits shape: tensor([ 65, 119,  44,  22,   2, 103,   0,  16,  86, 102,  40,  36,  21,   0,
          0,  62,  37,  58,  17,   0,  79,  90,  76,   0,   0, 128,  44,  86,
          0,  37,   0,  80,  68,  13,  71,   0, 113,  22,   0,  33,  58,  50,
         34,   2,  51,   2,  92,  24,  18,  91,   0,  74,  94,  56,   0,  27,
          0, 119,  96,   0,   1,  52, 110,   0], device='cuda:0')
argmax end logits shape: tensor([ 65, 124,  44,   5,   2, 112, 115,  70,  86, 112,   8,  53,  77,  46,
         40,  46,  35,  98,  17,  23, 125,  79,  76,  40,  57,  89,  92,  82,
         82,  31,  98,  11,  68, 121, 125,  49, 110,  61,  12,  67,   4,  51,
         21,   3,  97,   2,  59,  24,  63,  97,  43,  76, 108,  90,  65,  88,
         36,   3,  96,  21,  52,  10,   2,  93], device='cuda:0')
argmax start logits shape: tensor([ 50,  54,  70, 101,  16,   8,  38,   1,  49,  45,   0,   2,  81,  42,
          0,  69,  44,  34,  47,  20,  19,   4, 108,  72,  90,  11,  65,  14,
         

Epoch 2/10:  17%|█▋        | 52/313 [00:03<00:16, 15.38it/s, loss=7.81]

argmax start logits shape: tensor([  9,   0,   1,   3,  34, 111,  22,  87,  99,   3,  59,   0,  75,  95,
         78,  16,   2, 107,  19, 118,  81,  88,  87,   2,  33,  54,  53,  51,
         18,  76,  17,  57, 107,  68,   2,  26,  69,  27,  20,  87, 123,  51,
         11,  81, 108,  16,  22,  41,  57,   0,  66,  85,  65,  39,  59,   2,
         79,  70,   2,  76,  12,  47,  20,  94], device='cuda:0')
argmax end logits shape: tensor([ 28,  11,   5,  12,  55, 126,  43,  25,  51,  27,   4,  29,  49,  96,
         24, 121, 120,  48,  16,  33,  81,  86,   9,  17,  60,  33,  67,   3,
         92,  48,   6,  59,  20, 111,  12,  19,  58,   3,  35, 105,  75,  62,
         11, 106,  72,  59,  14,  41,  67,  50,  49, 121,  71, 120,   5,   2,
         79,  14, 123,  29,  78,  94,  84,  25], device='cuda:0')
argmax start logits shape: tensor([ 12,  98,  34,  17,   5, 106, 127,  15,  60,  88,  64,  44,   3,  55,
         11, 115, 122,  32,  71,  63,  44,   2,  67,  19,  64,  66,  56,  37,
         

Epoch 2/10:  18%|█▊        | 56/313 [00:03<00:16, 15.32it/s, loss=7.7] 

argmax start logits shape: tensor([ 66,  39,  94,  19,  65,  69,   2, 118,  46,   0,  35,  51,  30,  98,
         11,  60,  63,   0,  14,  89,  52,  12,  55,  41,  98,  45,   0,  17,
         56,  35,  39, 100,  46,   5,  88,  59,   0,  33,  41,  30,  47,  47,
         34,   4, 109,   0,  90,  54,  34,  59,   1,  64,   0,  75,  29, 100,
         61,  88, 101,  43,   0, 105,  12,  60], device='cuda:0')
argmax end logits shape: tensor([ 31,  94, 108, 106,  78,  88,  49,  12,  86,  25,  61,  10,  31,  59,
         11,  50,   2,  83,  14,  76,  12,  57,  16,  36,  99,  33,  15,   6,
         31,   3,   9,  18,  16,   3,  97,  80,   4,  34,   3,  30,  40,  39,
        111,  34,  83,  41,  56,  29,  35,  20, 101,  64,  26, 115,  32,  76,
         21,  43, 117,   3,  20,  23,  50,   4], device='cuda:0')
argmax start logits shape: tensor([ 86,  40,  91,  38,  31,  44, 119,   0,  11, 120,  91,   3,  56,   0,
          0,  48,  83, 135,  36, 100,   0,   0,  18,  55,  66,  23,  25,  33,
         

Epoch 2/10:  19%|█▊        | 58/313 [00:03<00:16, 15.29it/s, loss=7.54]

argmax start logits shape: tensor([  2,  58,  56,  31,   0,  45,  99,   0,  85,  86, 101,  47,   2,   0,
         14,   0,   8,   0,  32,  59, 104,   0,   0,  48,  80,  33,   0,   0,
         32,  91,  15,  81,   0,  39,  71,  61,   0,   0,   0,   3,   0,  19,
         25,   0,  42,   0,   0,  50,   0,  47,  17,   0,  68,   0,  83,  87,
          0,   0,  27,   0,  56, 117,  37,  24], device='cuda:0')
argmax end logits shape: tensor([ 99,  89,  27,  47,  46,   3,  38,   7,  18,  65,   9,  76,  53,  15,
         25,   6,  53,  13,  25,  88,  53,  10,   3,  45,  77,  74,  24,   3,
         33, 113,  25,  68,   7,  78, 117,  21,  92,  13,  98,   4,  49,  59,
          5,  27,   7,  35,  57,  98,  98,  38,  67, 100,   3,   8,  13, 136,
         44,  42,  27,  79,  59,  22,  50,  24], device='cuda:0')
argmax start logits shape: tensor([ 19,   5,  16,   0,  35,  31,   0,  62,  73,  14,  35,  20, 125,  84,
         19, 104,   0,  19,   0,  21,   2,  37,   0,  75,   0,  99,   0,   0,
         

Epoch 2/10:  20%|█▉        | 62/313 [00:04<00:16, 15.24it/s, loss=7.65]

argmax start logits shape: tensor([129,   0,   0,   2,  22, 100, 113,  13,   0,   7,   0, 121,   0,  33,
         19,  52,  45,  69,   7,  41,  67,  43, 101,  79,   0,  83,  48,  63,
         37, 107,  43,   8,  31,  34,   0,  56, 127,  19,   2,  35,   0,  46,
         43,  50,  51,  44,  83,  54,  38,  32,  26,   7,  51,  30,  39,  11,
         22,  60,   0,  44, 118,   2,   0,   7], device='cuda:0')
argmax end logits shape: tensor([ 33,  38,   5,   2,  23, 118, 108,  24,  35,   7,  69,  41,  86,  33,
         76,  52,  95,  81,   7, 153,  11,   4,  76,  37,   5,  11,  17,  67,
         15,   8,  44,  28,   1,  64,  96,  57,  28,  16,  22,  41,  39,  45,
         26,  81,  51,  78,  32,  68,   3,  68,  88,   7,  60,   3,   6,  11,
         65,  11,  81,   3,  33,  21,   9,  27], device='cuda:0')
argmax start logits shape: tensor([ 87,  63,  51,  31,   8, 100,   0,  39,  98,  48,   7,   0,   4,   0,
         44,  44,   0,  82,  16,  94,  35,  65,  33, 120,   0,   2,  70,  34,
         

Epoch 2/10:  20%|██        | 64/313 [00:04<00:16, 15.31it/s, loss=7.91]

argmax start logits shape: tensor([  0,   0,   2,  42, 102,  55,  65,   3,  41,   0,  52,   2,  45,   5,
         77, 146,   0,  57,  61,  94,  20,  85,   7,  78,  29,  24,   0,  33,
         44,  16,  65,  59,   0,   0,   2,  36,  20,  42,  54,  21,  29,  29,
        109,   0,  10,   0,  11, 119, 107,  45,  54,  44,  62,  34, 109,  90,
         11,   2,   4,  54,  24, 100,  20,   0], device='cuda:0')
argmax end logits shape: tensor([ 77,   2,  52,  16, 102,  55,  35,   4,  40, 119,  26,  32,  86,  77,
         77,   5,  49,  16,  30,  45,  11,  13,  53,  12,  29,  24, 100,  13,
         44,  16,  76,   1,  75,  53,   2,  20,  20,  45,  54, 103,  64,  39,
        140,   2,   4,  65,  63,  58, 107,  25,  49,  78,  74,  37,  44, 120,
          3,  42,  11,  68,  88,  22,  21,  12], device='cuda:0')
argmax start logits shape: tensor([ 11,  61,  60,  70,  97, 106,  63,  15,  30,  94,  69,  35,  43,   9,
        105,  25,  53,  41,  88,  21,  60,  39,  62,  32,  88,  15,  54,  14,
         

Epoch 2/10:  22%|██▏       | 68/313 [00:04<00:15, 15.35it/s, loss=7.93]

argmax start logits shape: tensor([114,  52,  85,  36,  52,  54,  16,  12,  13,  96,  81,  50,  34,   6,
         34,  44,  32,  92,  27,  40,  26,  53,  11,  31,   5,  75,  43,  47,
         35, 109,   7,  17,  33,  11,  93, 100,  44, 122,  78,  68,  22,  47,
         92,  17,   6,  15,  43,  55,  51, 104,  90,  16,  77,  73,  56,  39,
         90,  53,  60,  85,  10,  63, 108,   0], device='cuda:0')
argmax end logits shape: tensor([112,  35,  50,  25,  39,  33,  67,  44,  88,  31,  59,   4,  10,  84,
         10,  27,  48,  24,   6,  56,  39,  53,  71,  12,  31,  29,   4,  37,
         11,  77,   9,  70,  41,  42,  72,  52,  38,  63,  43,  67,  11,  10,
          8,  17,   2,  36,  15,  15,  36,  27,  61,  20,  52,  89,  59,   2,
         87,  82,  49,  81,  11,  63,  18,  64], device='cuda:0')
argmax start logits shape: tensor([119, 125, 128,  87, 112,  89,  87,  86,  50,  72,   2,  80,   0,  35,
          0,  85,  90,  70,  14,  33,  57,  75,  20,  62,   8,  61,  37,  31,
         

Epoch 2/10:  23%|██▎       | 72/313 [00:04<00:15, 15.07it/s, loss=7.71]

argmax start logits shape: tensor([ 31,  19,   0,  45,  74, 119,  77,  78,  27,  47,  47,  40,   2,  52,
         68,  58,  98,  16,  79,  13,   7,  26,  86,   7,   0,   0,   0,   0,
          0,  60,  53,  18,  30,  11, 149,   6,   2,  79,  20,  27,  85,  11,
         16,   0,   8,  51,  70,  50,  54,  37,  78,  64,  27,  29,  15,  20,
         53,  47,  42, 103,  89,  30,  14,  19], device='cuda:0')
argmax end logits shape: tensor([ 44,  89,  77,  54, 107,  62,  38, 106,  25,  98,  52,  48, 104,  11,
         61,  36,   3,  36,  67,  13,   6,  16,  43,  50,  51,  85,  37,  90,
         33,  92,  53,  50,  65,  30, 141,   2,  15,  11,  26,  26,  77,  40,
         52,  35,  36,  22,  35,  52,  39,  59,  77,   7,  26,  58,  28,  41,
         43,  51,  22,  67, 101,  30,  17,  41], device='cuda:0')
argmax start logits shape: tensor([ 27,   0,  24,  78,   0,  66,  46,   2,  40,   3,  21,  43, 107,  58,
        109,  66,  63,  97,  50,  31,  13,   0, 120, 105,  57,  75,  75,   2,
         

Epoch 2/10:  24%|██▎       | 74/313 [00:04<00:15, 15.06it/s, loss=7.77]

argmax start logits shape: tensor([109,  59,  42,  12,   0,  16,  17,  98,  63,  89,  87,  27,   0,  36,
         53, 105,   3, 146,  63,  79,  71,  68,  55,  56,  31,  18, 100,   4,
         50,   3,  47,  96,  97,   8,   0,  76,  41,   3,  81,  70, 115,  95,
         36,   2,  78,   8,   0,  15,  67,  12,  85,  81,  96,  46,  20,   2,
         23,   0,  78,   0,   0,  64,  60,  10], device='cuda:0')
argmax end logits shape: tensor([112,  17,  99,  13,  65,  85,  17,   7,  86,  89,  15,  15,  90,  26,
         67,   3,  11, 147,  18,  79,  58,   7,  56,  56,  31,   2,  22,  93,
         62,   3,  76, 109,  99,  92,  11,  10,  82,  65,  82,  91,  90,   3,
         57,  29,  59,   6,  37,  65,  36,   3,  56,  16,   8,  55,  14,  52,
          3,  51,  87,  38,  76,  47,  39,  77], device='cuda:0')
argmax start logits shape: tensor([109,   3,  56,  53,   0,  27,  19,   0,  82,   0,   3,  56,  92,  17,
         48,  11,  24,  39,  57,   5,  65,   0,   0,  27,  43,  83,   0,  24,
         

Epoch 2/10:  24%|██▍       | 76/313 [00:05<00:15, 15.17it/s, loss=7.64]

argmax start logits shape: tensor([ 19,  51,  90,  66,   4,  20,  70,  71,  22,   0,   0,  32,  93,   3,
         33,  93,  23,   0,   0,  30,  29,   4,  32,   0,  86,  21,  64,  25,
         34,   0,   0,  37,  89,   0,  16,   0,  54,  35,   0,  23,  71,  11,
          0,  46,  14,  52, 104,  77,  54,  20, 106,   8,  79,   0,  23,  91,
         71,  61,   0,  43,  71,  46,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 78,  88,  20,  78,   3,  42,  34,  90,  53,  11,  90,  38,  89,  87,
         60,  57,  14,  48,  22,  57,   3,  17,  33,   5,  13,  42,  64,  25,
        127,  75,  14,  40,  76,  33,  15,  31,  54,  39,   2,  81,  46,  80,
         29,  48,  33,  52,  67,  26,  68, 102,   4,  75,  40,  63,  30,  91,
        104,   7,  52,  18,  30,  29,  57,  42], device='cuda:0')
argmax start logits shape: tensor([ 75,  71,   2,  76,   9, 107,  39, 117,   0,   0,  55,  35, 108,  70,
         23,   0,   3,  75,  60,  92, 137,   3,  21,  88,  50,   7, 142,   0,
         

Epoch 2/10:  26%|██▌       | 80/313 [00:05<00:15, 15.16it/s, loss=7.99]

argmax start logits shape: tensor([ 78,  80,  46,  82,  84,   2,  90,   0,  90,  69,  27,  70,  26,  10,
          3,   0,   1,  23,  18,   8,  53,  84,  39,  25,  46,  60,  36,  37,
         44,  26,  10,  75,   0,  45,  31,  51,   5,  13,   0,  64,  87,  94,
        133,  14, 100,  74,  76,  52,   0,  12, 112, 105,  49,  15,  72,  97,
          8,   0,  69,  37, 110,  24,  20,   0], device='cuda:0')
argmax end logits shape: tensor([ 62,  15,  42,  47,  94,  20,  93,  73,   6,  24,  15,  30,  22,  10,
         23,  68,  97,  43, 102,   7,   9,  41,  42,  73,  46,  22,  73,  20,
         58,   8,  10,  70,  50,   4,  34,  76, 130,  14,  10,  58,  38,  97,
         13,  14, 112,  74,  35,  14,  68,  14,  49, 105,  49,  83, 109,  97,
         10,  86,   5,  37,   5,  29,  12,  67], device='cuda:0')
argmax start logits shape: tensor([ 51,  97, 106,  47,  87,  69,  77,  55,  16,   0,  65,   5,   8,  15,
         88,  55,  92,  47,  73,  24,   4,  25,  41,   8,  65,  49,   2,  87,
         

Epoch 2/10:  27%|██▋       | 84/313 [00:05<00:15, 15.05it/s, loss=7.64]

argmax start logits shape: tensor([ 14,  26,  37,  75,  25,   8,  16,  30, 104,  19,  89,  54,  42,  29,
         24,  15,   7,   5,  66, 110, 110,   3,  59,  16,  11,  21,  13,  42,
         33,  18,   4,  25,   0,  82,  17,  78,  33,  92,  75, 118,  55,  25,
         51,  43, 109,  10,  10,  30,  78, 113,  61,  25,  61, 138,  16,  72,
         67,  41, 107,  63,  12,  70,   2,   0], device='cuda:0')
argmax end logits shape: tensor([ 23,  64,  50,  94,  13,   5,  15,  81,  73,  62,  44,  43,  39,  42,
         11,  53,  75,   5,  65,  62,  62,   3,  73,  16,  39,  23, 123,   5,
          3,  18,  17,  81,  57,  34,   2,  78,  28,  77,  35,  52, 109,  12,
         12,  23, 112,  98,  12,  30,   8, 114,  43,  16,  31, 138,  13, 121,
          2,  67,  49,  62,  33,  16,   4,  62], device='cuda:0')
argmax start logits shape: tensor([ 17,  15,  78,  78,  44,  99,  47,  37,  97,   4, 100,  40,  87,  45,
         49,   9,  20,  93,  71,  54,  71, 109,  14,  62,  17,  85,  40,  29,
         

Epoch 2/10:  28%|██▊       | 88/313 [00:05<00:14, 15.27it/s, loss=7.54]

argmax start logits shape: tensor([ 78,  28,  39,   8,  28,  53,  10,  41,  30,   6,  38,   3,  73,  11,
         50,  17,  41,  43,  73,  13, 143,  53,  68,  91,  32,  25,  44,  65,
         90,  67,  81,  41,  10, 114,  14,  36, 109,   9,  65,  30,  29,  72,
         86,  11,  10,  55,   4, 111,  60,  52,  31,  81,  18,  38, 103,  72,
         50,  52,  38,  43,  52,   6,  45,  39], device='cuda:0')
argmax end logits shape: tensor([ 41,  22,  27,  15,  28,  22, 130,  72,  52,  10, 110,   3,  95,  83,
          9,  21,   8,  10,  68, 124, 124,  90,  75,  80,  35,  10,  74,  92,
         61,  33,  39,   4,  10,  71,  14,  16, 113,  69,  51,  24,   5,  15,
         28,  72,  73,  42,  75, 113, 128,  39,  31,  72,   3,  44,  13,  57,
          2,  77,  73,  72,  51,  64,  40,  18], device='cuda:0')
argmax start logits shape: tensor([140,  60,  54,   2, 101,   3,  34,  80, 108, 108,  19,  66,  24,  57,
         92,  60,  87,   2, 150, 108,  20,  96,  86,  69,   8,  46,  17,   3,
         

Epoch 2/10:  29%|██▉       | 90/313 [00:06<00:14, 15.25it/s, loss=7.88]

argmax start logits shape: tensor([  0,  69,   6,  35,  29,  33,  54,  91,   6,  77,  15,  15,  65,  91,
          6,  87,   1, 103,   8,  32, 109,  52,  90,  43,  58,   7,  16,  35,
         24,  36,  33,  27,   0,  11,  52,  25,  41,  42,  26,  59,  13,  28,
         46,   9,  32, 107,  39,  37,  43,  14,  10,  65,  48,   1,  90,  43,
         84,  17, 121,  36,  68, 102, 151,  26], device='cuda:0')
argmax end logits shape: tensor([ 48,   5,  95,   5,  62,  33,  41,  93,  14,  95,   2,  16,  98, 115,
          3,  67,  15,  81, 111,  36,  17,  22,  29, 136,   4,  10,  77,   8,
         15,   6, 115,  29,  70, 114,  72,  85,  38,  71,   4,  55, 110,  28,
         55, 120,  93,  63,  66,   3,  13,  20,  70,  50,  80,  63,  29,  50,
         68,  20,  96,  56,  25, 114,  47,  23], device='cuda:0')
argmax start logits shape: tensor([ 14,   9,  13,  38,   9,   0,  27,   0,  34,  84,  79,  63, 124,  68,
         54,  81,  20,  24,  92,   2,  23,  35,  24,  62,  24,   3,  44,  18,
         

Epoch 2/10:  30%|███       | 94/313 [00:06<00:14, 15.03it/s, loss=7.61]

argmax start logits shape: tensor([ 25,  67, 107,  38,   8,  34,  71,  51,  46, 103, 102,  17,  47, 107,
         99,  27, 117,  64,  18,  59,   9, 137, 105,   6,  68,  37,  75,  84,
          0, 106,  45,   0,  76,  84,  89,  61, 114,  62,   0, 133,  23,  54,
        100,  56,  21, 121,  34,  88,  64,   8,  75,  75,  54,   8,  17,  50,
         61,  27,  99,  33,  54, 114,  39,   9], device='cuda:0')
argmax end logits shape: tensor([ 40,  18,  14,  20,  22,  62, 126,  51, 121, 105,  17,  59,  38, 130,
         84,  68, 134, 108,  37, 112,  36, 137,  67,   9,  94,   7,  91,  49,
          7, 107,  99,  76,  32,  31,  77,  58, 116,  90,   6, 119,  43,  66,
        139,  83,  77,  80,   2,  56,   3,  49,   5,  91,  54,  44,  33,  50,
         47,   4,  97,  33,  20,  29,   3,   5], device='cuda:0')
argmax start logits shape: tensor([ 84,  65,  10,  55,  91,  51,  45,   7,  46,  61,  29,  45,  24,  94,
         54,  30,  84,  96,  36,   0,  36,   2,  51,   2,  39,  39,  46,  65,
         

Epoch 2/10:  31%|███▏      | 98/313 [00:06<00:14, 15.14it/s, loss=7.49]

argmax start logits shape: tensor([ 35,  25,  18,  48,   0,  30, 113,  85,  12,   0,  63,  51,  93,   0,
          0,   0,  65,   6,   0,  87,   0,  36,  60,   0,  25,  50,  23,   2,
         65,  75, 102,  17,  84,  92,  25,  43,  89,   0,  44,  66,   0,   2,
          0,  51,   0,  60,  52,   0,  25,  22,   0,  46,  59,   0, 121,  64,
          0,  33,  18,  34,  41,  89,   0,  20], device='cuda:0')
argmax end logits shape: tensor([ 77,  25,   6,  51,  53,  68, 105,  87,  20,  11,  77,  51,  93,   7,
         92,  16,  65,   6,  18,  56,  61, 118,  74,  86,  25,  27,  23,  30,
         65, 146,  99,  17,  21,  92,  25,  53,  42,  46,  96,  16,  59,  29,
         52,  22,  26,  45,  52,  63,  52,  36,  37,  42,  30,  20,  44,  65,
         63,  97,  18,  66,  41,  39,   1,  58], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0,   0,  13,   0,  22,  76,  55, 113,   0,   0,   0,
          0,   0,   0,  16,   0,   0,   0,  15, 123,   0,  60,  21,   0,  50,
         

Epoch 2/10:  33%|███▎      | 102/313 [00:06<00:13, 15.14it/s, loss=7.85]

argmax start logits shape: tensor([  0,   0,  44,  79,   0,   0,  75,   0,  86,   0,   0,   0,   0,   0,
          0,   0,   0,  53,   0, 106, 100,  43,   0,  91,   0,  33,   0,   0,
          0,  88,  10,   0,  43,   0,   0,  30,  80,   0,   0,   0,   0,  61,
          4,   0,   0,   0,  30, 100, 107,   6,  29,  69,   0,   0,   0,   0,
        102,   0,  15,   6,  21,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([116,  33,  19,  45,  65, 113, 146,  82,  63,  53,  14,   7,  50,  13,
        106,  31,  14,  25,  64,  52,  74,  75,   5,  23,  44,  33, 126,  16,
         34,  78,  30,  64,  50,  34,  83,  32,  41, 129,  17,  43,  43,  53,
         69, 149,  30,  82,  21, 100,  24,  42,  29,  56, 116,  86,  13,  74,
         59,  60,  36,  12,  28,   5, 114, 102], device='cuda:0')
argmax start logits shape: tensor([ 65,  28,   9,   0,   0,   0,  13,  59,   0,  54,   0,   0,   0,   0,
         13,   0, 100,   0,   0,  19,  32,   0,  67,   0,   0,   0,  19,  17,
         

Epoch 2/10:  34%|███▍      | 106/313 [00:06<00:13, 15.20it/s, loss=7.56]

argmax start logits shape: tensor([  0,   0,  57,   0,   0,   0,  47,   0,   0, 126,   0,   0,   0,  83,
         53,   2,   0,   0,   8,   0,   0, 104,   0,   0,  25,  10,   3,   0,
          0,   0,   0,   0,   0,   0,  90,   0,   0,  53,   0,   0,   0,  18,
         71,   0,  89,   0,   0,   0,   0, 100,   0,   0,   0,   0,  52,   0,
         86,  11,  56,  43,   0,   0,   2,  80], device='cuda:0')
argmax end logits shape: tensor([ 17,  81,   8,  12,  45,  56,  82,  25,  95,  11,  72,  38,   5,  50,
          2,  42,  22,  44,  43,  50,  16,  87,  20,  26,  25,  76,  69,  23,
         44,  18,  30,  61,  60, 118,  86,  99,   6,  78,  41,  12,  52,   3,
         22,  64, 115,  10,  63,  19,  27,  12,  34,  47,  78,  24,  37,  80,
         44,   8, 141,  18,  53, 112,  44,  41], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0, 100, 132,  84,   0,   6,   0,   0,  91,  71,  13,
         42,  28,  50,   0,   0,  93,   0,   0,  63,  41,   0,   3,  77,  59,
         

Epoch 2/10:  35%|███▍      | 108/313 [00:07<00:13, 15.28it/s, loss=7.88]

argmax start logits shape: tensor([ 72, 111,  39,  10,   0,  13,  23,   0,  41, 100,  87,  65,  13,  85,
         33,  71,   0,  14,  11,  45,  33,  13,  19,   0,  16,  76, 100,   0,
         83,   0,   0, 127,  59,  16,  19,   0,  76,  71,  77,  12,   0,  40,
         52,  71,  66,  43,  14,  83,   0,  71,  33, 117,  24, 102,  36,   0,
         48,   3,  74,  51,  52,   0,  29,  25], device='cuda:0')
argmax end logits shape: tensor([114, 111,  42, 117,  73,  14,  81,  57, 104,  28,  76, 100,  14,  50,
         74,  64,  15,  22,  67,   7,  95,  26,  20,  55,  17,  78,  20,  34,
          2,  22,  70, 105,   4,  17,  32,  66,  78,  49,   8,  50,   6,  41,
        110,  45,  37,  27, 127, 105,  44,   5,  63,  68, 111,  38,   5,   3,
         85,  58,  56,  51,  77,  38,   2,  51], device='cuda:0')
argmax start logits shape: tensor([ 28,  48,  11,   6,  26,   3,  75,   2,  48,  52,  69,  47,  20,  32,
         19,  26,  92, 133,  90, 104,   0,  62,   3,  91,  77, 140,  26,  51,
         

Epoch 2/10:  36%|███▌      | 112/313 [00:07<00:13, 15.20it/s, loss=7.9] 

argmax start logits shape: tensor([  0,   2,  56,  29,   0,   0,   2,  83,  55,  40, 103,  81,   0,  10,
         14,   0,   0,  28,  85,  31,   9,  57,  11,  74,  21,  40,  28,   0,
         63,  37,  49,   0,   3, 120,  63,  50,  41,  58,  17,  48,   0,   0,
          0,  40,  55,  47,  38,  90,  43,   0,   8,   2,   0,  85,   0,  98,
         85,  95,  25, 109,   0,  30,  25,  18], device='cuda:0')
argmax end logits shape: tensor([ 17,   2,  56,  29,  27,  18,   3, 123,  59,  17, 102,  40,  88,  10,
         14,  27,  15,  21,  22,  31,   9,  57,  18,  76,  65,  51,  30,   5,
        129,  42,  50,  47,  41, 147,  48,  61,  46,  46,  23,  20,   3,  72,
         72,   7,  55,  59,  38,   2,  43,  25,  46,   3,  49,  85,  43,   5,
         86,  95,  26, 118,  57,  55,   5,  18], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  49,  20,  20,   0,   0,  17,   0,   0,   0,   0,  24,  77,
         61,   0,  69,   0,  51,   0,   3,   0,   0,  35,  75,  46,   0,  51,
        1

Epoch 2/10:  37%|███▋      | 116/313 [00:07<00:12, 15.28it/s, loss=7.61]

argmax start logits shape: tensor([ 55, 105,  26,  97,  45, 114,   7,   9, 113,  90,   0, 138,  80,  41,
         37, 115,  11,  85,   2,   0,   0,  67,   3,  60,  14,  53,  15,  16,
         23,   0,  22,  25,  19, 101,  37,  11,  45,  29,  32,  55,  69,  33,
         48,  46,  21,  56,   0,  28,  52,  59,  35, 128,  38,  30,   0,   2,
         14,   0,  26,  59,  16,  82,  48,   2], device='cuda:0')
argmax end logits shape: tensor([ 53,  37, 106,  41,  81, 116,   8,  24,  72,  91,  21,  71,   6,   7,
         57,  47,  11,  56,  26,  41,  44,  62,  27,  27,  17,  24,  19,  80,
         55,  33,  61,  25, 101, 117,  62,  37, 116,   3,  32,  44,  98,  45,
         80,  10,  32,  75,  36,  28,  87,  37,  35,   4,  48,  68,  67,   3,
         53,  43,  73,   2,  13,  59,   3,   5], device='cuda:0')
argmax start logits shape: tensor([ 46,  95,  97,  32,  45,   8,   5, 105,  83,  96,   7,  99,  97,  13,
         17,   3,  25,  58,  30,  88,  66,  26,  42,  10,  17,  25,   5,   3,
        1

Epoch 2/10:  38%|███▊      | 120/313 [00:07<00:12, 15.20it/s, loss=8.27]

argmax start logits shape: tensor([107,  54,  67,  47,  54,  14,  52,  20,  61,  12,  40,  61,  22,   2,
         17,  75, 114,   2, 102, 101,  40,   7,  13,  61,  68,  35,  59, 102,
         16,  45,  25,  88,   8,  32,  35,  36,  45,  85,  46,  18,  95,   0,
         99,  12,  17,  20,  51,  19,  33,  47,  42,  47,   0,  23,  32,  25,
          2, 109,  23, 104,  24,   0,  20,  36], device='cuda:0')
argmax end logits shape: tensor([ 25, 101,   6,  35,  92,  74,  52,  81,  90,  44,  44,  73,  24,   3,
         23,  74,   3,  42,  61,   6,  43,  10,  17,   8,   7,  57,  60,  74,
         64,  45, 119,  51,   8,  32,  37,  40,  45,  80,  90,  19,  92,  38,
         73,  23,  20,  24,  55,  35,  13,  47,   3,  82,  28,   9,  83,  26,
          2, 107,  12,  44,  30,  12,  92,  37], device='cuda:0')
argmax start logits shape: tensor([ 49,   0,   0, 101,  98,  46,   5,   0,  25,  68,  29,   0,  44,  27,
          0,   2,   2,   0,  70,   0, 113, 114,   0,  44,   2,  10,  18,  83,
         

Epoch 2/10:  39%|███▉      | 122/313 [00:08<00:12, 15.15it/s, loss=7.67]

argmax start logits shape: tensor([  8, 106, 114,   0,  21,  27,   0,   0,  58,   0,  55,  29,   0,  28,
          0,   0,   0,  45,  67,  28,  51,  29,   0,   0,  77,   2,  37,  74,
          0,  54,   0,  46,  98,   0,  51,   0,   0,   0,   0,   0,   2, 121,
          0,   0,   0,  85,  18,   0,  10,   0,   0,  46,   0,   0,   0,   0,
        101,   0,  32,   0,  33,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([  8, 106,  25,  35,  45, 115,  60,  91,  37,  41,  57,  16,  56,  29,
         88,  32,  38,  42,  68,  48,  51,  29,  50,  36,  81,  96,  54,  74,
         38,  63,  33,  46,  69,  61,  54,  62,  68,  97,  12,  63,  73,  15,
         77,   5,  25,  85,  42,  53,  49,   4,  41,  48,  56,  52,  18,  45,
         74,  59, 116, 104,   5,  63,   1,  10], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  82, 101,  73,  93,  32,  49,  87, 118,   0,   0,   0,   0,
          1,   0,   0,  79,   0,  65,  41,   0, 100,   2,   5,   0,   2,   0,
         

Epoch 2/10:  40%|████      | 126/313 [00:08<00:12, 15.24it/s, loss=7.56]

argmax start logits shape: tensor([ 72,   9,  40,  26,  41,  65,   4,   3,   0,  32,  83,   0,  61, 112,
         54,  60,   0,  49,  85,   0,  24,  50,  56, 116,  22,  26,   8,  71,
          0,   6, 105,  30,  49,  25,  37,   2,   7,  79,  14,   2, 111,  11,
          2,  43,  24,  98,  48,  63,  63,  79,  75,   0,  85,  80,  67,  20,
         50,  35,  15, 138,  11,  64,  84,   0], device='cuda:0')
argmax end logits shape: tensor([ 80,  70,  30,   3,  73,  31,   4,   3, 124,  15,   2,  10,  20,  89,
         14,  76,  40,  33,  69,  36,  26,   9,  12,  90,  28,   6,  35,  20,
         76,   7,  96,   3,  78,   4,  55,  44,  72,  75,  35,   2,  37,  64,
         25,  46,  50,  98,  66,  22,  75,  79,  63,  36,  25,  78,  27,  34,
         58,  35,  17, 108,  11,  60,  39,  73], device='cuda:0')
argmax start logits shape: tensor([110,  91,  66,  27,  76,  23,  49,  98, 109,   9,  60,  25,  21,  30,
         22,  20,  33,  47,  49,  27,  13,   2,  76,  74,  79,   3,   3,  36,
         

Epoch 2/10:  41%|████      | 128/313 [00:08<00:12, 15.24it/s, loss=7.7] 

argmax start logits shape: tensor([110,  47,  47,  59,  59,   0,  15,  58,   8,  60,  11,  63,  75,   9,
          2,  51,  11,  33,  29,  95,   1,  18,  96,  46, 113,  56,  75,   2,
         61, 118,  18,  42,  18,   2,  22,  17,  45,  24,  39,   1,  54,  52,
         15,   0,  62, 135, 109,  41,   7,  53,  41,  69,  32,  73,  82,  10,
         19,  79,  39,  30,   6, 113,  94,  76], device='cuda:0')
argmax end logits shape: tensor([ 24, 102,  47,   9,  30,   9,  82, 118,   2,  63,  46,  64, 101,  34,
        121,  43,   9,  28,   8,  73,  12,  35,   3,   5,  31,  56, 119,  85,
         24, 108,  95,  67,  47,  95,   8,  99,  29,   4,   9,  36, 100,  57,
         25,  28,   9,  92,  65,  44,   7,  98,  34,  50,  18,  42,  59,  59,
         15,  28,  61,  43,  33,   7,  18,  18], device='cuda:0')
argmax start logits shape: tensor([ 13,  16,   0,   5, 111,  84,  87,   0,   0,  96,  22,   0,  21,  46,
         10,   3,  41,   0,   0,  30,   0,   0,  46,  88,  93,  51,   7,  31,
         

Epoch 2/10:  42%|████▏     | 132/313 [00:08<00:11, 15.25it/s, loss=7.42]

argmax start logits shape: tensor([25, 53,  0, 78,  0, 46, 74,  0,  0,  0,  0,  0, 60, 77,  0,  0,  0,  0,
        56,  0,  0,  0,  0,  0, 18, 57, 60,  0, 36, 31, 60,  0,  0, 44,  0, 93,
        85, 27,  0,  0, 40,  0,  0, 88, 12, 19,  0, 88,  0,  0,  0,  0,  0,  0,
         0,  0,  0, 51, 50,  0, 55, 95,  0,  0], device='cuda:0')
argmax end logits shape: tensor([ 25,  73, 110, 123,  18,   2,  28,   2,  19,  81,  76,  59, 101,  61,
         15,  25,  44,  32,  68,  15, 117,   8, 103,  91,  69,  17,  12,  24,
         11,  34,  60,   4,  67, 116,  38,  21, 100,   9, 114,  99,   7,   2,
         22,  90,  69,  68,  94,  79,  88,   5,  13, 104,  98,  95,  21,   8,
         66,  95,  29,  25,  55,  39,  42,  49], device='cuda:0')
argmax start logits shape: tensor([  2,   0,   0,   0,  25,  21,   0,   0,   5,   0,   2,   0,   0,   0,
          0, 129,  19,   0,   0,   0, 103,   0,   0,   0,  97,   0, 126,   0,
          0,  36,   0,   0,   0,   0,   0,   3,   0,   0,   0,   0,   0,  23,
   

Epoch 2/10:  43%|████▎     | 136/313 [00:08<00:11, 15.05it/s, loss=8.13]

argmax start logits shape: tensor([ 67,   0,   0,  94,   0,  54,   0,   0,   2,  91,  22,   0,   0,  68,
          0,   0,   0,   0,  77,   0, 107, 145,   0,   0,   0,   0,  99,   0,
         35, 121,  34,   0,   0,   0,   0,  51,   0,   0, 128,  40,   0,   8,
         18,  20,   0,   0,  31,  72,   0,   2,  30, 129,   0,   0,   0,  70,
          0,  25,   0,   0,  51,  28,  30,   0], device='cuda:0')
argmax end logits shape: tensor([ 88,  64,  65,  94,  44,  30,  31,  40,  43,  52,  43,  44,  55,  15,
         20, 106,  12,  29,  87,   3,  50, 147,  63,  75,  78,  21, 100,  10,
         40,  39,  25,  72,  17,  21,  14,  56,  16, 141, 125,  49,  40,  61,
        116,  24,   3,  46,  49,   6,  53,   6,  99,  30, 109,   7,  96, 113,
          4,  28,  43,  29,  21,  23,  30,  32], device='cuda:0')
argmax start logits shape: tensor([ 43,  28,   0,  55,  93,  92,   0,  70,  45,   0,   2,   0,   2,   0,
         75, 110,  57,  13, 105,  21,   0,   0,  52,   0,   0,   0, 112,  52,
         

Epoch 2/10:  45%|████▍     | 140/313 [00:09<00:11, 15.16it/s, loss=7.76]

argmax start logits shape: tensor([ 41,  19,  43,  41,  84,  30,  75,  11,  47,  81,  92,   8,  40,  85,
         76,  40,  24,  25,  10,  42, 108,  29,  99,  61,  80,  39,  36,   7,
         64,  75,   3,  28,  17,  51,  79,  49,  85,  65,  31,  65,  28,  87,
        103,   2,   2,  96,  20,  67,  15,   9,  57,  51,  66,  76,  60,  76,
         31,  42,   2,  86,  62,  86,  94,  38], device='cuda:0')
argmax end logits shape: tensor([  5, 141,  43,  95,   2, 119,  74,  56,  21,   3,  35,   9,  53, 113,
         33,   2,  84,  73,  10,  62,  25,  30,  33,   9,  80,  32,  22, 101,
         31,  60,  27,   8,  13,  56,  79,  20,  57,   9,  35,  58,  28,  92,
          3,   2,  30,  65,   1,  67,  65,  62,  27,  92, 104,  10,  76,  76,
        136,  42,  49,  57,  69,  81,  30,  56], device='cuda:0')
argmax start logits shape: tensor([101, 108, 109,  95,  33,  83,   3,  30,  53,  55,  10,   2,  33,  34,
         59, 137,   7,  35,  71,  33,  72,  77,  20,  10,  96,   0,  57,   2,
         

Epoch 2/10:  45%|████▌     | 142/313 [00:09<00:11, 15.15it/s, loss=7.69]

argmax start logits shape: tensor([ 17,  14,   0,   0,  11,   0,  65,   0,  16,  47,  69,  17,   0,  37,
         16,  92,  77,  17,  45,  69,   8,  39,  25,  60, 123,   6,  48,  10,
          0,  42,   0,  47,   3,   0,   0,  37, 102,  56,  32, 100,  19,  37,
         59,  17,   0,  29,  51, 107,   0,  52,  67, 116,   0, 132,  23,  41,
         52,  48,  49,  18,  73,  90,   0,  20], device='cuda:0')
argmax end logits shape: tensor([122,  15,  61,  36,  28,  44,  65,  48,  33,  32, 115,  24,  30,  46,
         71,  61,  63,  21,  33,  69,  48, 110,  67,   3, 124,   3, 104,  21,
          5,  88,  14,  82,  30,  20,  75,  24,  85,  78,  38,  86,  64,  37,
          5, 108,  54, 113,  51,  81,  19,  60,  45,  51,  42, 113,  76,  51,
         33,  11,  60,  89,  60,  55,  49,  12], device='cuda:0')
argmax start logits shape: tensor([  0,   0, 104,  26,   0,  33,   0,  46,  70,   0,  32, 113, 101,   0,
        101,   0,  82,  24,  57,  24,  64,   0,   7,   0,   0,  13,  48,  19,
         

Epoch 2/10:  47%|████▋     | 146/313 [00:09<00:10, 15.33it/s, loss=7.53]

argmax start logits shape: tensor([  0,  27,  20,  14,   0,  24,   0,  87,  75,  94,  25,  91,  18,   0,
         61,  54,   0,  78,   0,  59,  78,  59,  65,   0,  50, 125,  99, 110,
          3,  14,   0,  78,  63, 103,  18,  30,   0,   2,   0,  50, 101,   0,
          3,  53,   6,  31,   0,  41,  80,   0,  46,  47,  20,   0,   2,  40,
        117,  57,   0,  22,  81,  68,  38,   5], device='cuda:0')
argmax end logits shape: tensor([ 14,  58,  70,  35,  53,  37,  80,  61, 108,  17,  21, 115,   5,  62,
         53,  18, 121,   3,  26,  16,  51,   3,  67,  34,  50,  18, 101,  38,
          7,  47,  59,  78,  92,  16,  39,  69, 105,  60,  79,  78,  40,  23,
         55,  97,   9,  17,  30,  33,  11,  82,  41,  38,  19, 103,  67,  22,
         44,  17,  63,  60,  81,  32,  21,  38], device='cuda:0')
argmax start logits shape: tensor([ 22,  31,  17,   0,  39,  20,  82,   8,  17,   0,  45,  36,   8,   0,
          2,  89,   4,  64,  37,   2,   0,  90,  11,   0,   7, 111,  97,  13,
         

Epoch 2/10:  48%|████▊     | 150/313 [00:09<00:10, 15.22it/s, loss=7.54]

argmax start logits shape: tensor([ 33,  98,   7,  61,   0,  48,  68,   0, 116,  41,  79,  59,  91,  10,
         37,  67,   1,  85,  25,  36,  82,  41,  82,  64,  17,  61, 104,  69,
          0,  18,  18,  42,  87, 110,   8,  89,  23, 113,  81,   6,   6,  12,
          3,  91, 101, 109,  31,  45,  66,  51,  41,   3,   3,   0,   2,  79,
         39, 127,  82,  12,  85,  14,  58,  30], device='cuda:0')
argmax end logits shape: tensor([ 34,  98,  41,   3,  19,  11, 117,  98, 116,  81,  79,  72,  62,   7,
         20,  67,  71,  19,  11,  58,   8,  63, 140,  42,  17,  11, 104,  44,
         33,  74,  27, 141,  69,  51,  19,  39,  79, 108, 130,  31,  28,  13,
         15,   8,  65, 104,  28,  47,  40, 104,  33,  23,  24,  35,  37,  35,
         43,  45, 104,  66,   5,  23, 107,  35], device='cuda:0')
argmax start logits shape: tensor([109,  72,   2,  62,   8,  93,  53,  59,   3,  69,  88,  38,  25,  11,
         37,  42,  63,   2,  22,  74,  35,  39,  64,  53,  32,  89,   2,  28,
         

Epoch 2/10:  49%|████▉     | 154/313 [00:10<00:10, 15.21it/s, loss=7.86]

argmax start logits shape: tensor([  0,   0,   2,   0,  53,  52,   0,   0,  10,   0,  35,   0,  45,  90,
         51,  21, 110,   0,   5,   0,   0,  85,  92,   0,   6,   0,  41,   0,
          0,  13,  89,   0,  52,  35,   2,  70,  87,  90,   0,  66,   0,   2,
          9,   8,   0, 108,   0,   0,   0,  76,   0,  59,   0,  22,  69,  68,
         42,  68,  84,   3,  85,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 93, 122,   3, 107,  53,  71,  52,  29,  77,  61,  17,  46,  46,   3,
         60,  52,  88,  68,   5,  40,  17,  34,  87,  56,  28,   3, 110,  10,
         47,  13,  75,  35,  71,  35,   5,  76,  60,   7,  38,   3, 128,  79,
        104,  47,  43,  44,  31,  79,  28,  80,   5, 114,  94,  43,  69,  37,
        134,  66,  35,   6,   4,  36,  59,  96], device='cuda:0')
argmax start logits shape: tensor([  6,   0, 100,  92,  27,   0,  27,  17,  66,  54,  22,   0,   6,  36,
          0,   0,  62,   0,   0,  83,  36,   0,   0,  57,  18,   0,   0,   5,
         

Epoch 2/10:  50%|████▉     | 156/313 [00:10<00:10, 15.15it/s, loss=7.52]

argmax start logits shape: tensor([117,  10,  78,   0,  46,  34,  81,   0,  88,  46,  13,  57,  18,  78,
         55,  36,  11,   0, 100, 109,  40,  90, 108,  87,  27,  60,  29,  78,
          0,  29,  12,  56,  47,  73,  33,  41, 101, 127, 108,   0,  61,   0,
         45,  60,  18,  18,  46,  87, 112,  59,  41,  69,   0,   0,  45,   2,
         45,  49,  49,  48,  75,  13,  58,  78], device='cuda:0')
argmax end logits shape: tensor([122,  21,  78,  10,   9,  60,  97,  24,  63,  60, 113,  94,  75,  43,
         89,  47,   5, 123, 100, 110,  34,  29, 109,  86,   3,  56,  44,  78,
         34,  29, 138,  59,  32,  32,  63,  76,  22,  52,  46,  65,   9,  30,
         73,  13,  41,   3,  30,  67,  40,  35,  92,  94,  15,  58,  17,   5,
         15, 103, 143,  64,  69,  14,  68,  97], device='cuda:0')
argmax start logits shape: tensor([ 31, 116,  27,  85,  67,  46,  85, 106,  49,   0,  43,  50,  17,  23,
         52,  96,  37,  57,  56,  43,  53,  33,   7,  85, 101,   4,  13,  51,
         

Epoch 2/10:  51%|█████     | 160/313 [00:10<00:10, 15.17it/s, loss=7.59]

argmax start logits shape: tensor([ 60,   2,  61,  76,  41,  52,  14,  15,  13,  63,   0,  46,  16,  37,
          0,   0,  78, 113,   0,   5,   8,   4,  60,  38,   8,  55,  79,  30,
         44,  25,  39,   0,   0,  16,   0,  54,   0,   0,   0,  81, 120,  18,
        106,   6,  61,  21,  94,  89, 101,   6,  17,   2,  84,   0,  59, 105,
         38,  61,   8,  55,  62,   8,  22,  58], device='cuda:0')
argmax end logits shape: tensor([ 60,   3,  91, 109,  17, 104,  36,  44,   5,  86,  36,   7,  83,  64,
         14,  45,  79, 114,  26,  78,  19,  15,  63,  38,  31,  55, 124,  28,
         65,  51,  74,  59,  31,   1, 103, 116,  53,  23,  18,  44,  98, 105,
         40,  14,  61,  65,  60,  89, 101,  40,  38,  88,   3, 102,   9, 131,
         41,  10,  17,  49,  65,   6,   5,  59], device='cuda:0')
argmax start logits shape: tensor([  0,   0, 130,  55,  23,  44,   0,  73,  52,   0,  44,  40,  42,  21,
         18,   0,   3,  49,  33,  26,  28,  11,   3,   0, 136, 114,  14,  33,
         

Epoch 2/10:  52%|█████▏    | 164/313 [00:10<00:09, 15.25it/s, loss=7.69]

argmax start logits shape: tensor([  0, 114,  12,  40,  45,   2,   0,   0,  48,  13,  34,   2,  88,   0,
          0,  68,   0,  86,  40,  71,   2,  89, 107,  16,   0,  67,   0,  33,
          0,  68,   0, 104,   0,  80,   0, 108,  64,   0,   0,  39,   0,   0,
        102,   6,   0,  61,   8,   0,   0,   0,  24,   0,   0,   0, 100,  20,
         92,  15,   0,  37,   0,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 23,   7,  44,   1,  54,   2,   3,   8,  34,   3,  52,   2,  67, 103,
         89,  71, 129,  89,  70,  98,  52,  48,  45,  70,  16,  13,  78, 108,
          3,  71,  30,  51,  56,   2,  61,  46,  64,  10,   3,  39,  58,   3,
         19,  33, 101,  65, 113, 121,  17,  34,  24, 130,  37,  13,  34,  58,
         92,  81,  96, 120,  53,  35,  63,  35], device='cuda:0')
argmax start logits shape: tensor([ 25,  85,  83,   0,  80,   0,   0,   5,  21,  45,   2,  82,  52,  29,
          0, 137,   0,   0,  92, 115,   0,   0,   0,   0,  43, 102,  83,   3,
         

Epoch 2/10:  53%|█████▎    | 166/313 [00:10<00:09, 15.15it/s, loss=7.97]

argmax start logits shape: tensor([ 12,  28,  38,   3,  27,  90,  66,  95,  75, 121,   3,  52,   3,   5,
         82,   5,   3,  22,  48,  64,  13,   5,  41,  14,  81,  61,   2,  41,
         98,   7,   2,  38,  24,  51,  21, 129, 116,  83,  53,  22,  17, 104,
          8,  50,  34,  95,  81,  17,  88,  18,   0,  27,  68,   1,  12,  42,
         13,  15,   0,  16,  34,  60,  80,  62], device='cuda:0')
argmax end logits shape: tensor([ 49,  28,  67,   7,  27,  77,  52,  59,  86,  18,  46,  55,  27,  82,
        138,  61,  33,  22, 114,  96,  31,  35,  41,  22,  80,   3, 102,  41,
         67,  25,  33,  38,  24,  58,  53,  95,  10,  17, 103,   5,  32,  76,
         38,  29,  56,  73,  41,  13,   8,  58, 109,  38,  55,  22,   4,  24,
         55,  34,  82,  29,  70,  25,  31,  53], device='cuda:0')
argmax start logits shape: tensor([ 88,  63,  48,  26,  24,  70,  67,  15, 105,   8,  60, 106,  61,  41,
         29,  12,  42,  33, 124,  50,  29,  30,  67,   2,  42,  85,  67,  95,
         

Epoch 2/10:  54%|█████▎    | 168/313 [00:11<00:09, 15.16it/s, loss=7.93]

argmax start logits shape: tensor([129, 103,  52,  15,   9,   2,  27,  50,   4,  57,  23,  83,  57,  45,
         12,  17,   3, 103,  75,  65,  26,  74,  59,  30,  99,  70,  56,   9,
         18,   3,  19,  53,  47,   3,  19,  35,  84,  19,  95,  89,  16, 118,
         36,  67, 119,  28,  59,  48,  98,  79,  41,  72,  24, 113,  52,  46,
         98,  96,   2,  94,  29,  93,  87, 101], device='cuda:0')
argmax end logits shape: tensor([ 17,  36,  49,  15,  23,   3,  80,  10,  13,  35,  23,  23,  40,  47,
         12, 107,  46, 130,  16,  61,  38,  33,  59,  32,  92,  11,  92, 147,
          6,  11,  22,  60,  44,  95,  25, 134,  86,  19,  95,  42,  48,  67,
         63,  57, 105,  26,  59,  74,  21,  10,  14,  18,  93, 122,  48,  88,
         75,  87,   2,  26,  61,  36,  80,  50], device='cuda:0')
argmax start logits shape: tensor([ 66,  39,  95,   6,   7,  13,  42,  12,  25,   9,  90,  31,  26, 131,
         40,  54,   2,  24, 110,  68,  79,   2,  52, 133,   0,   3,  27,  49,
         

Epoch 2/10:  55%|█████▍    | 172/313 [00:11<00:09, 15.02it/s, loss=7.78]

argmax start logits shape: tensor([ 25,  21, 110,   0,   0,  69,   0,  68,  44,  40,   0,  43,  50,  72,
         57,   0,  16,   0, 126,   0,   0,  18, 103,   0,  76,  16,   2,   0,
         32, 100,   8,  10,   8,   6,  62,  84,   0, 115,  28,  86,   0,  21,
         66, 100,   0,   0,   0,   2,  94,   0,   0,  81,   0,  54,   0,  54,
         46,   0,  99,  44,  25,  12,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 55, 130,  91,  46,  14,  20,  37,   4,  38,  39, 117,  12,   8,  75,
         50,   6,  98,  20,  88,  18,  33,  24,  16, 111,  26,  85, 111,  68,
         25,  99,  88,  10,   9,   6, 135,  64,  66,  91,  21,  14,  24,  49,
        112,  81,  49,  38,  41,  12,  83, 107,  58,   4,  33,  54,   5,  37,
         36,   3,  89,  20,  78,  61,  83,  61], device='cuda:0')
argmax start logits shape: tensor([ 71,   0,   0,  28,  68,   0,   9,   5,   0,  42,   2,   0,  37,   0,
         47,   0,  70,  10,  29,   0,  85,   0,  25,   0,   6,  18,   0,   0,
         

Epoch 2/10:  56%|█████▌    | 176/313 [00:11<00:09, 15.18it/s, loss=7.82]

argmax start logits shape: tensor([ 62,  29,  70,   0,  11,  39, 136,   0,   0,  96,  21,  42,  14,  78,
         89,  53,  13, 116,   9,  33,  86,  94,  50,  80,  11, 119,  36,  81,
         37, 102,   0,   0,  51,  16,  86,  69,  84,   7,  51,   5,   0,   0,
         84,  14,  34,  62, 102,  87,   0, 113,  51,  88,  90,   0,  37,  33,
         60,  36,  47,   0,  90,  14,  48,   0], device='cuda:0')
argmax end logits shape: tensor([ 29,  20,  37, 103,  36,  52, 101,  13,  50,  54,  80,  43,  55,  79,
        101,  53,  62,  63,  14,  30, 105, 100,  61,  65,  12,  46,  59, 107,
         26,  57,  13,  20,  51,  70,  17,  89,  90,  90,  65,  32,  10,  92,
         84,  58,   4,  66,  54,  59,  24,  48,  32,  66,   3,  35,  26,  13,
         40,  41,  97,   4,  57,  40,  12,  51], device='cuda:0')
argmax start logits shape: tensor([ 48,  58,  83,  24,  45,   2,  54,  47,  39,  37,   5,   0,  87,   5,
          0,  99,  31,   0,  62,   8,   5,  70,  58, 104,   2, 102,  46,  19,
        1

Epoch 2/10:  57%|█████▋    | 178/313 [00:11<00:08, 15.18it/s, loss=7.7] 

argmax start logits shape: tensor([ 99,  67, 102,   6, 138,   9,  20,  29,  11,  37,  70,   2,  69,  44,
         56,  25,  26,   0,  88,  72,   0,   0, 101,  53,  59,   7,   0,   0,
         86,  57,  73,   3,   0,  23,   5, 102,  61,   2,  36,  32,  88,  54,
          6,  36,  68,  53,  44,   3,  37,   5,  27,  46,   0,  20,  69,  63,
          0,  38,  51,  56,  29, 110,   1,  21], device='cuda:0')
argmax end logits shape: tensor([101,  30,  15,  26,  84,  24,  32,   2,   7,  37,  50,  32,  33,  45,
         64,  51,  61,  12,  79,  76,  33,  69, 101,  14,  59,  59,  74,  44,
         87,  84,  97,  46,  33,  66,   3,  43,   2,   2,  69,  32,  61,  66,
         72,   9,  68,  91,  44,   3,  46,  93,  55,  90,   3,  62,  93,  64,
         53, 114,   9,  28,  70,  90,  48,  33], device='cuda:0')
argmax start logits shape: tensor([100,  51,  14,  25, 104,  47,  62,  17,  80,  66,  63,   6,   0,  15,
         56,  95,   8,   2,  45, 100,  16,  72,   2, 119,  20,  37,  43,  87,
         

Epoch 2/10:  58%|█████▊    | 182/313 [00:12<00:08, 15.31it/s, loss=7.65]

argmax start logits shape: tensor([  0,  56,  15,  15,   0,  28, 113,   0,  61,  95,  16,   0,  24,   5,
         59,  66, 105,  71,   0,  19,  10,  55,  69,   0, 107,   8,   1,   2,
         17,  57,   0,  21,  14,   0,  54,  54,  69,   0,   7,   9,   0,  36,
        128,   9,   2,  30,   0,   3,  37,  23,  13,  23,  75,   2,  17,  33,
          0,   9,  14,  11,  29,   0,  68,  60], device='cuda:0')
argmax end logits shape: tensor([ 66,  97,  91,  24,  74,  30, 113,  26,  38,  58,  12,   1,  18,   4,
         58,  42, 117,  71,  35,   5,  61,  61,  81,  45,  45,   8,  53,  11,
         17,  56,  72,  35,  14,  23,  11,  73,  96,  20,   7,   9,  72,  36,
        128,  98,  65,  31,  21,  53,  50,  23,   2,   5,  75, 112,  41,  33,
         15,  20,  72,  13,  34,  95,  68,  45], device='cuda:0')
argmax start logits shape: tensor([ 50, 112, 111,   0,  63,  80,   0,  34,  97,  77,  99,  93,  42,  54,
         93,   0,  89,  60,  81,   1,  15,   0,  33,  44,   0,  52,  98,   7,
        1

Epoch 2/10:  59%|█████▉    | 184/313 [00:12<00:08, 15.29it/s, loss=7.26]

argmax start logits shape: tensor([ 53,   0, 106,  60,   3,  25,   9,  21,  77,  10, 106,  22,   0,   9,
         71,  32,  34,  66,  46,   0,  62,  90,  19,  42,   0,  80,   0,  46,
         45,  12,  36,   0,   8,   0,  93,  48,   8,  12,  52,  26,  64,  59,
          7,   0,   0,  66,  38,  38,   4,  35,  90,  61, 101,  34,   2,  97,
        121,  43,  91,   0,  23,  52, 108,  19], device='cuda:0')
argmax end logits shape: tensor([ 46,  27,  45,  39,  25,  47, 147,  24,  99,  44,  18,  87,   5,  23,
         98,  12,  22,  64,  25,  17,   4,  40,  68,  77,   5,  87,   9,  54,
        120, 100,  79,  28,  50,  22,  86,  12,   5,  15,  90, 105,  64,  16,
         58,   7,  33,  67,  50,  65,  31,  55,  61,  32,   2,  29,   7,  85,
          7,  57, 108,  41,  25,  54, 103,  25], device='cuda:0')
argmax start logits shape: tensor([  9,  11,  73,  97,  48,  59,   0,  92,  14,  70,  90,  60,   0, 101,
          8,  16,   2, 114,  27,  94,  14, 106,  24,  65,  11,  29,  24,  16,
         

Epoch 2/10:  60%|██████    | 188/313 [00:12<00:08, 15.24it/s, loss=7.67]

argmax start logits shape: tensor([ 53,   0,  21,  29,  32,  65,   2,   0,   8,   4,  19,  14,  87,  29,
         15,  67,  34,  14,   1, 109, 116,  11,   8,  65,  90,  17,  11, 104,
          3,  53,   1, 132,  14,  18,  46,  76,  53,  24,  38,  40,  91,  88,
         50,  73,   0,  23,  48, 127,   4,   2, 123,   3, 113,  79, 107,  35,
          6,  25,  15, 129, 107,  11, 113,  29], device='cuda:0')
argmax end logits shape: tensor([ 79,  27,   5,   9,  52,  65,  61,  41,  18,   4,  14,   6,  59, 101,
         54,   7,  73,  76,  12,  17, 121,  11,  36,  65,  87,  20,  53,  92,
         25,  52,   8,  21,  21,  33,  62,  70,  79,   8,  46,  41,  85,  16,
         25,  36,   6,  13,  10,  39,  40,  29,  83, 128,  50,  37,  14,  99,
         36,  92,  98, 129,  14,  66,  97,  93], device='cuda:0')
argmax start logits shape: tensor([ 14,  46,  29,   8,  25, 108,  54,  25,  45,  53,  29,  61,  32,   2,
         14,  44,  27,   2,  54,  79,  25,  40,  19,  20,  45,  47, 102,  43,
         

Epoch 2/10:  61%|██████▏   | 192/313 [00:12<00:07, 15.37it/s, loss=7.9] 

argmax start logits shape: tensor([ 53,  18,  24, 105,  23,  38,   0,  42,  66,  40,  19,   0,   0,  17,
          4, 122,  17,  64,   7,  70, 117,  98,  32,  50,  31,  48,  70,  10,
         61,  16,  95,  71,  60,   0,  27,  82, 106,   0,   0,  26,  88,  62,
         10,  69,  46,   0,  54,  46,   0,   9, 121,  75,  19, 132,  36,  12,
          0,  28,  90,  91,  61,  84,  29,  60], device='cuda:0')
argmax end logits shape: tensor([ 76,  54,  38,  32,  36,  57,  82,   3,  56,  74,  37,   2,  39,  77,
         11, 109,  14,  87,   5,  10,  69,  98,  76, 108,  66,  10,  69,  11,
         20,  79,  46,  70,  19,  44, 101,  86, 109,  11,  73,  25,  97,  55,
         63, 103,  43,   6,  89,  64,  21,  13, 113, 105,  26, 105,  40,  83,
          8,  16,  47,  21,   8,  26,  84,   3], device='cuda:0')
argmax start logits shape: tensor([  0,  58,   8,   7,   0,  85,  23,  35,   0, 119,   0,   0,  62,   0,
         84, 118,  39,  46,   0,  16,   0,   0,  19,   0,  89,  45,   6,   0,
         

Epoch 2/10:  62%|██████▏   | 194/313 [00:12<00:07, 15.16it/s, loss=7.56]

argmax start logits shape: tensor([ 66,   0,  85,   0,  40,   0,   0,  26,  53,  81,   0,  16,  68,  45,
        100,   0,   0,  60,  79,  55,  34,  45,  15,   0,  24, 127,   0,  65,
          0,  81, 100,   0,   0,   0,  44,   0,  56,   5,   0,   4,  66, 134,
         68,  83,   0,   0,   0, 144,   0,   0,  71,   0,   0,   0,   0, 134,
          0,   0,   0,  10,   8,   0,  17,  47], device='cuda:0')
argmax end logits shape: tensor([ 75,  14,  18,  24,  13,  16,   4,  44,  57,  56,  15,  23,  20,  11,
        104,  37,  23,  18,  85,  37, 106,  74,  38, 102,  11, 128,  38,  37,
          2,  64, 104,  33,  40,  85,  56,  21, 114,  63, 104,   7, 143,  24,
         68,  92, 121,  11,  85,  32,   7,  20,  41,  38,  63,  44,  37, 138,
        119, 127,  70,  61, 122,  83, 126,  33], device='cuda:0')
argmax start logits shape: tensor([ 56,  40,   8,  19,  39,  68,   0,  15,  20,   0,  61,   0,  34,  79,
          0,   0,  81,   0,   0,   0, 109,  22,   0,  28,   2,  30,   2, 133,
         

Epoch 2/10:  63%|██████▎   | 198/313 [00:13<00:07, 15.00it/s, loss=7.65]

argmax start logits shape: tensor([ 92,   9,  16,  25,  22,  52,  92,   0,  47,  62,  92,  51,  55, 120,
         87,  85, 109,  21,  40,  94,  49,   0,  75, 109, 112,   8,  93,  86,
         22,  42,  13,   2,   0,   2,  34,  43,  61,  32,  20, 130,  63,  48,
         44,  26,  31,   0,  93,  40,  96,  73, 156,  37,  70,  48,  99,  58,
         97,  72,  20,  54,   2,  53, 108,  48], device='cuda:0')
argmax end logits shape: tensor([ 94,  68,  34,  25,  76,  53,  17, 119,  49,  62,  49,  55,  28,  61,
          7, 106,  61,  10,   2,  63,  49,  16,  39,  20, 112,  58,  90,  89,
         90,   2,  37,  13,  31,   3,  76,  43,  61,  43,  63,  22,  63,  44,
         20, 112,   5,  37,  47,  75,  96,   9, 116,  30,  64,  48,  47,  38,
         97, 101,  21,  29,  98,  12,  24,  41], device='cuda:0')
argmax start logits shape: tensor([ 37,  12,  83,  89,  42, 118,  38, 120,  34,  80,   4,  56,   5,  39,
         28,   1,  44,  32,  36,  28,  22,   7,  31,   7,   9,   2,  36,  40,
         

Epoch 2/10:  65%|██████▍   | 202/313 [00:13<00:07, 15.31it/s, loss=7.52]

argmax start logits shape: tensor([  6,  97,  56,  75,  70, 100,  25,   2,  97,  82,  36,   6,   3,  29,
         98,  17,  33,  11,   3, 131,  69,  67,  43,  34,  64,  20,   3,  24,
         17,  25,  50,   1,  44,  13,  35,  17, 135,  49,   0,   2,   3,  24,
         34,  21,  42,   2,  65,   2,  53,  59,  32,  15,  34,  68,  25,  10,
         99, 115,  15,  76,  56, 105,  44,  11], device='cuda:0')
argmax end logits shape: tensor([  2,  16, 113,  95,  95,  92,  23,  80,  97, 106,  21,  54,  33,   3,
         91,  83,  61,  16,  98,  43,  99,  67,  17,  31, 123,  43,  79,  95,
         48,  36,  10, 122,  21,  20,  27,  32, 137,  49,   6,   4,  49,  23,
         23,  68,  68,  43,  55,   2, 101, 118,  49,  80,  40,  21,  65,   4,
        100,  73,  66,  29,  94, 106,  21,   9], device='cuda:0')
argmax start logits shape: tensor([104,  36,  68,  41,  68,  50,  29,  71,  14,  66,  48,   3,  24,  89,
         71,  12,  18,  13,   8,  59, 112,  11,   5, 106,  70,  91,  90,  45,
         

Epoch 2/10:  65%|██████▌   | 204/313 [00:13<00:07, 15.27it/s, loss=7.58]

argmax start logits shape: tensor([ 44,  17,   7,  21,   0,  19,   0,  85,  74,  13,  41,  23,  30, 107,
         13,  17,  49,   2,  89,  50,  24,  29,  53,  85,  39, 134,  50,  95,
         42,  38,  72,  35,   5,  28,  84,  61,  15,  44,   2,   2,  84,  18,
         16,  47,  67,  27,  21,  20,  78,   0,  73,  82,   2,  91,  39,  11,
         29,  51,  55,  56,  14,  15,  18,  20], device='cuda:0')
argmax end logits shape: tensor([ 14,  24,  34,   1,   1,  29,  70,  29,  39,   2,  78,  54,  32,  77,
         66, 107,  76,  27,  91,  28,   2,  30,   7,  75,   7,  77,  26,  74,
         95,  23, 107,  54,   6,   2,   3,  52,  18,  68, 110,   3,  87,  20,
         31,  10,  67,  28,  77,  13,  66,  33,  48, 105,   3,  68,  19,  43,
         35,  20,  20,  80,  16,  32,   5,  29], device='cuda:0')
argmax start logits shape: tensor([  2,  31,  78,  51,  81, 111,   8,  75, 121,  79,  16,  33,  84,  23,
        106,  42, 105,  82,  10,  49,  35, 123,   5, 107,  60,   2,  48,  25,
         

Epoch 2/10:  66%|██████▋   | 208/313 [00:13<00:06, 15.33it/s, loss=7.8] 

argmax start logits shape: tensor([ 73,  66,  27,  34,  77,   0,  75, 110,  17,  29, 122,   2,   0, 112,
         27,  25,  70,   4,  11, 101, 117,   0,   0,  70,  66,  49,  87,  17,
         82,  44,   7,  66, 121,  43,  75,   5,  25, 110,  40, 101,  87,  45,
         44,  62,  62,  88,  94,  23,   0,  57,  39,  72,  72,  13,  13,  63,
         39,  85,   0,  96,   0,  16,  34,  64], device='cuda:0')
argmax end logits shape: tensor([ 76,  65,  33,  39,  10,  24,  76,  76,  18,  70,  16,  75,  57, 113,
          3,  14,  99,  33,  32,   1,  66,  28,  86, 103,  18,  67,  88,  67,
         23,  46,  36,   4,  39,  44,   3,   8,  25,   3,  63, 133,  62,   6,
         46,  90,  76,  35,  56,  47,  55, 117,  41,   4,  85,  20,  27,  26,
         78,  11,  45,  97,  12, 105,  25,  91], device='cuda:0')
argmax start logits shape: tensor([  0,  11,  13,  69,   8,  94,   0,  89,  11,  60,   0,  47,  77,   0,
         39,   3,  90,   6,   5,  72,   0,  23,  10, 151,  77,   0,   0,  42,
         

Epoch 2/10:  67%|██████▋   | 210/313 [00:13<00:06, 15.35it/s, loss=8.01]

argmax start logits shape: tensor([ 85,   7,  28,  47,  45,   4,  39,  84,   2,  38,  46,  70,  64,   9,
         20,  39,   1,  47,  80,  28,  40,   2,   3,  15, 117, 113,  90,   0,
         33, 113,  70,  53,  15,  25,  43,  23,   3, 107,  63,   0,  60,  28,
          2,  38,  11,  22,  17, 116,  34,  70,   9,  19, 101,  45,   2,  81,
         42,  86,  79,  28,  78,  80,   2,  56], device='cuda:0')
argmax end logits shape: tensor([ 66,  29,  42,  19,  44,  19,  25,   8,   2,  19,  75,  55,  91,  61,
         10,  37,  69,  47,  21,  19,  40,  63,   3,   2,  59,  21,  51,  18,
         47,  72,  45,  15,  24,  98,  25,  86,  89,   6,  47,  18,   5,  29,
         67,  19,  51,  20,  18,  39,  18,  41,  15,  19, 137,  43,  22, 127,
         32, 105,  45,  27,  78,  35,   3,  29], device='cuda:0')
argmax start logits shape: tensor([ 36,  41,  87,  16,   2, 122,   3,  25,  36,  29,  57,  71,   2,  34,
         15,  77,  27,  25,  38,  64,  17,   3,  30, 108,  21,  63,  21,  26,
         

Epoch 2/10:  68%|██████▊   | 214/313 [00:14<00:06, 15.22it/s, loss=7.73]

argmax start logits shape: tensor([  3,  51,  27,  17, 101,  54,  16,  97,  63,  97,  45,  42,  19,  22,
        117,  20,  31,  80,  24,  89,  75, 106, 106,  34, 105,  25,  81, 117,
         67,  52,  80, 110,   3,  17,  80,  53,  76,  25, 105,   3,  18,  80,
         53,  28,   9,  95,  38,  90,  23,  36,  89,  62,  51,  66,  44,   8,
         68,  76,  58,  59,   7,  45,  26,  31], device='cuda:0')
argmax end logits shape: tensor([ 17,  13,  88,   7,  75,  55,  70,  35,  35,  80,  94,  42,  15,  23,
        120,  21,  88,  59,  17,  86,  34,  94, 129,  34,  22,  25, 143,  37,
         67,  12,   6,  49,  21,  98,  80,  20, 122,  25,  12,  38,   3,  53,
         86,  83,  67,  41,  45,  65,  23, 111,  89,  14,  90,  94,  44,  40,
          7,  22,  57,  21,  91,  32,   5,  31], device='cuda:0')
argmax start logits shape: tensor([ 35,  66,   2, 122, 115,   2,  25,  43,  19,  55,  47,  96,  44,  11,
        101,  25,  44,  69,  34, 112,  61,  75,  21, 105,  55,  59,  23,  92,
         

Epoch 2/10:  70%|██████▉   | 218/313 [00:14<00:06, 15.33it/s, loss=7.77]

argmax start logits shape: tensor([ 52,  36,  66,  77,  46,  88,  77,  78,  13,  99,  40,  74,  23, 137,
         33,  18,  69,  37,   8,  77,   0, 104,   0, 117,   2,  82,  32,  69,
         68,  43,  41,  13,  36,   6,  35,  65,  60,  10,   7,  57,  24,  19,
          0, 111, 108,   2,  56,   0,   0,  88, 121,   0,  66,   2, 127,  11,
         23,  10,  17,  30,  34,  39,  83,   8], device='cuda:0')
argmax end logits shape: tensor([ 52,  36,  83,  80,  33,  57,  77,  30,  98,  23,  43,  74,  34,  55,
         84,  72,  66,  82,  42, 113, 105, 121,  31,  42,  12,  78,  35, 134,
         34,  94,  76,  39,   2,  68,  35,  89,  34,  13,   3,  15,  24,  19,
         53, 113,  54,   8,  42,  27,  44,  11,  18,  33,  67,  62, 126,  78,
         77,  13,  15, 122,  86,  72,  66,  54], device='cuda:0')
argmax start logits shape: tensor([ 63,   0,  60,  17,  41,  39,  22, 108, 124,  15,  89,  11,   2,  57,
          0,  12, 104,  66,  21,   2,   2,  79,  54,  57,   0,   2,  10,   9,
         

Epoch 2/10:  70%|███████   | 220/313 [00:14<00:06, 15.29it/s, loss=7.6] 

argmax start logits shape: tensor([ 75,  27,  63,  40,  80,   2,  77,   1,  49,   7,  49,  53,   2,  11,
         42,   6,  13,  19,  17,   5,  52,   0,  24,  57,   8,  71, 103,  76,
         23, 104,  57,  35,  30,  57,  53,  28,  25,  51,  92,  11,  25,  11,
          9,  21,   1,  37,  98,  23,  80,  45,  77,  57,  38,  91,  14,  22,
         55,  48,  66,  27,  49,  79,   4,  77], device='cuda:0')
argmax end logits shape: tensor([ 76,  52,  63,  10,  93,  40,  55,   3,  33,  61,  11,  17, 120, 112,
         22,  63,  12,  74,  36,  40,  33,  43,  39,  91,  84, 131,  60,  88,
         39, 100,  38,  10,  57,  24,  53,  16,  14,  10,  23,  11,  13,  30,
         49,  22,   3,  38,  88,  77,   8,   5,   9,  31,  19,  22,  24,  22,
         42,  61,  79,  47, 111,  70,   4,  35], device='cuda:0')
argmax start logits shape: tensor([  0,   6,   2,  35,   0,  30,  58,  81,  46,   0,  19,  64,  22,   4,
         79,  75,  76,  25,   0,  64,   0,  23,  37,  46,  56,  21, 104,  24,
         

Epoch 2/10:  72%|███████▏  | 224/313 [00:14<00:05, 15.11it/s, loss=8.03]

argmax start logits shape: tensor([ 40,  93, 116,   0,  26,   5,  76,  54,  23,   0,  21,   1,   0,  46,
         54,  79,  40,  20,  91,  15,  99, 110,  82,  73,  18,   7,   3,  13,
          3,  80,  41,   0, 119,  68,  56,   0,  72,   0,   0,   5,  55,   0,
         65,   0,   5, 118,  59, 101,  22,   0,   0,  28,  22,   0,  96,   0,
         15,   0,   0,   2, 111,  75,  68,   8], device='cuda:0')
argmax end logits shape: tensor([122,  68, 101,  23,  73,   5,  31,  54,   7,  51,  26,   1,   5,  43,
         74,  79,   3, 112,  45,  95,   5,  24,  14,  67,  35,  44, 128,  39,
         16,  77,  65, 117, 122,  69,  61,  68,  73,  32,  45,  71,  13, 106,
         67,   4,   5, 124,  84,  90,  71,  38,  22,  19,  17,  90, 152,  77,
         76,  50,  45,  59, 111,  61,  68,  84], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0,  76,  55,  27,  21,  63,   0,  80,   0,  85,   2,
         18,  37,  42,   2,  35,   0,  41,  43,  93, 102,  80,   0,  69,   8,
         

Epoch 2/10:  73%|███████▎  | 228/313 [00:14<00:05, 15.19it/s, loss=8.02]

argmax start logits shape: tensor([ 88, 100,   0,  65,  60,   8,  90,  16,  61,  21,  53,  32,   0,  40,
          5,   0,   0,  19,  11,  39,  37,   0,  13,   7,   8,  24,  52,  30,
         29,  41,   0,  37,   6,  23,  61,  86, 156,  41,   0,  52,   0,  17,
          0,  57,  83,  36,  66,  54,  49,  27,  12,   0,  21,   0,  48, 143,
         56,   6,   0,   0,   7,  11,  26,   0], device='cuda:0')
argmax end logits shape: tensor([ 96, 113,  37,  44, 107,   8,  24,  49,  94,   5,  53,  28,  94,  40,
          6,  13, 128,  19,  67, 136,  53, 106,  31,  17,  47,  77,  40,  42,
         24,  54,  24,  34,  44,   4,  35,  65,   5,  54,  49,  21,  72,  17,
         76,  57,  62,  11,  20,  66,  52,  24,  55,  36,  27,  92, 100,  61,
         56, 107,  29,   3,  50,  14,  81,  54], device='cuda:0')
argmax start logits shape: tensor([ 41,   0,  51,   8,  68,   0,  42, 112,  48,   9,  26,  70,  19,   0,
         30,  92,  56,   6,   8,  12,   0,  33,  23,  19,  42,  38,  27,  80,
         

Epoch 2/10:  73%|███████▎  | 230/313 [00:15<00:05, 15.04it/s, loss=8.19]

argmax start logits shape: tensor([ 64,  75,  39, 112,   0,   2,  96,  47,  41,  10,  41,  32,   0, 119,
         83,  59,   0,  53,  85,   0,   8,   0,  23,  11, 127,  48,  54,  44,
         85,   3,  26,   9,  61,  91,  51,  54,  83,  71,   3,   7,  91,  41,
         18,  86,  38,  45,  29,  69,  41,  94,  12,  14,  19,   0,  47,  14,
         72,  64,  24, 118,   8,   0,  17,  18], device='cuda:0')
argmax end logits shape: tensor([ 69,  77,  45, 112,   7,  37,  21,  10,  43,  39,  97,  36,  39,  67,
        154,  91,  21,  30,  14,   4,   2,  10,  58,  18, 120,  21,   2,  44,
          5,   3,  26,  12,  81,  13,  67,   5,  84,  56,   6, 124,  91,  40,
          8,  37,  39,  30,  29,  89,  41,  17, 105,  15,  18,  90,  30,  14,
         41,  30,  17, 119,  70,  40,  68,  91], device='cuda:0')
argmax start logits shape: tensor([ 82,  90,  39,  82,  26,  30,   0,  44,  79,   0,  24,   2,   0,  17,
          0,  82,  38,  25,  75, 116,  65, 107,  37,  65,  24,  95,  84,  71,
         

Epoch 2/10:  75%|███████▍  | 234/313 [00:15<00:05, 15.04it/s, loss=7.54]

argmax start logits shape: tensor([ 37,  54,  12,  25,  54,  32,  31,  73,  87,  55, 114,  84,  15,  56,
         67,   1,  22,  34,  16,  75,  53,  65,  40, 104, 120,  69,  42,  39,
         31,  36,   6, 132,  89,   3, 121,  24,  56,   3,  38,  65,  59, 126,
         24,  52,   0,  55,  28,   2,  43,  85,  81, 111, 107,  44, 109, 125,
        107,  16,   7,  17,   2,  14,  78,  19], device='cuda:0')
argmax end logits shape: tensor([105,  46,  42,  11,  73, 117,  26,  26,  30,  58, 115,  29,  30,  53,
         72,  90,  55,   9,  61,  80,  78,  32,  17,  69,  43,  10,  65,  54,
         14,   7,  38, 109,  27,   3,  55,  20,  12,  99, 114,  65,  65,   3,
         27,  59,  30,  56, 135,  13,  42,  85,  43,  65,  76,  20,  85,  49,
         67,  16,  58,  82,   2, 101,  78,  75], device='cuda:0')
argmax start logits shape: tensor([116,  32,  73,   8,   2,  99,  89,  36,  60,  71,  70,   4,  26, 102,
         21, 143,  26,  83,  71,  88,  96, 101, 124,  42,  96, 100,  68,  41,
         

Epoch 2/10:  75%|███████▌  | 236/313 [00:15<00:05, 15.16it/s, loss=7.65]

argmax start logits shape: tensor([  5,   0, 142,  82,  48, 127,   4, 111,  70,  59,  31,   9,  91, 105,
          2, 140,   1,  30,  40,   2,  28,  13,   9,  48,  34,  47,  34,  90,
          8,   7,  26,  50,  69,   3,  24,  25, 108,  42, 112,   4,  13,  48,
         58,  17,  15,   7,  13,  29,  72,  52,  36,  20,  37, 110,  31,  61,
         33, 100,   8,  55,   2,  29,   3,  40], device='cuda:0')
argmax end logits shape: tensor([ 28,  23, 142,  62,   2,  49,  21,  47,  64,  89,  67,  12,  18,  44,
         14, 103,   5,   7,  61,   3,  18,   2,  11,  91,  86,  77,  63,  93,
         22,  88,  27,  94,  27,   4,  56,  49,   2,  67, 112,  13,  23,  49,
         44,   2,  37,   7, 102,  66,  47,  16,  68,  39,  37,  28, 127,  18,
         33,  59,  50,  44,   3,   2,  10,  66], device='cuda:0')
argmax start logits shape: tensor([  0,  65,  30,   0,  26,  13,  39,  49,   2,  53,  14,  60, 101,  59,
          0,   0,  39,  27,  12,   7,  61,   0,  78,  11,   0,   0,   0,  39,
        1

Epoch 2/10:  77%|███████▋  | 240/313 [00:15<00:04, 15.17it/s, loss=7.68]

argmax start logits shape: tensor([ 54,   0,  41,  72,  79,   0,  60, 115, 112, 107,  74,  79,   3,   7,
         31, 110,   0,   5,   0,  80,  45,  40, 107,  26,   0,  39,   0,  51,
          0,   0,   1,  21,  10,   0,  62,  85,  28,   0,   0,  42,  32, 132,
         92,   0,  46,   0, 113,  17,  89,  48,   0,   0, 104,  91,  61,   0,
         70,   0,  15,   0,  13,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([  9,  21,  77,  72,  47,  49,  33,   9, 112, 107,  45,  88,  51,   3,
        109, 111,  12,  35,  68,  80,  74,  24,  75,  77,  41, 124,   8,  43,
         29,  25,  61,  18,  46,  44,   2, 103,   5,  30,  81,  35,   3,  84,
         92,  38,  26,  66,  30,  18,  24,  17,  60, 112, 106,  14,  81,  39,
         35,  87,  15,  40,  52, 103,  57,  59], device='cuda:0')
argmax start logits shape: tensor([  0,  86,  80,  12,   5,  85,  63,   7,   7,  68,   0,  81,   0,   2,
         96,  42,  71,   0,  26,  20,  68,  52,   2,  39,  62,   0, 111,  35,
         

Epoch 2/10:  78%|███████▊  | 244/313 [00:16<00:04, 15.10it/s, loss=8.16]

argmax start logits shape: tensor([ 45,   7,   6,  21,  89,  82,  44, 107,  41,   0,  57,  23,   0,  24,
         27,  16,  25,  47, 150,  27,  65, 100,  19,  29,  29,  78,  70,   2,
         65,  11,  18,  84,   2,  47,  26,  50,  91,  40, 148,  51,  40,  54,
         80,  45,  28,  93,  23,  92,   2,  45,  34,  74,  43,  23,  34,  51,
         28,  17,  30,   5,   2,  17,  42,  91], device='cuda:0')
argmax end logits shape: tensor([ 24,  28,   2,  86, 120,  45,  52, 123,  85,  45,  94,  25,  48,  25,
         27,  96,  58,  47,  11,  17,  63, 117,  11,   7,  10,  80,  62,  96,
        126,  12,  19,  74,   3,  28,  75,  16,  17,  40, 141,  44,  33,  42,
         90,  68,   7,  71,  61,  35,  59, 105,  34,  74,  85,   5,  75,  43,
         47,  52,  30,   5,   2,  37, 111,  28], device='cuda:0')
argmax start logits shape: tensor([ 32,  29,  72,   9,  87, 108,  34,  25,  49,  34,  37,  21,  51,  51,
         28,  24,  46,  26,  87,  74, 145,  13,  63,  50, 109,   2,  35,  60,
         

Epoch 2/10:  79%|███████▉  | 248/313 [00:16<00:04, 15.17it/s, loss=7.56]

argmax start logits shape: tensor([ 33,  78,  64,  78,   0,   7,   3,  80,   0,   0,  67,  30,  66,  13,
        131, 150,  63,   6,   0,  16,  30,   0,   0, 118,  12,   2,  80,   2,
         10,   0,  78,   0,  59,   6, 108,   9,   0,  13,  52,   0,  12,   6,
          0, 118,   2, 112,   7,  13,   0,   0,  19,  45,  73,  71,   0, 136,
          0,  95,  38,  25, 100,  92,  11,   0], device='cuda:0')
argmax end logits shape: tensor([ 80,  63,  48,  78,  23,  17,   3,  73, 108,  27,  67,  30,  30,  69,
         15,  50,  45,  16,  81,  49,  30,  64, 112,  58,  52,  15,  23,  50,
         80,  53,  31,  59,  59,   2, 111,  44,  43,  69,  85,  65,  13,  40,
          4,  23,  31,  66,  31, 102,  32,  88,  44,  94,  73,  81,  71, 109,
         11,  88,  32,  89, 104,  35,  44,  53], device='cuda:0')
argmax start logits shape: tensor([ 88, 101,   0,  74,   0,  14,   0,   7,   0,  11,   2,  24,   0,  86,
          0,  78,   0,  81,   0,   0,   0,  70, 104,  58,  16,  17,   0, 101,
         

Epoch 2/10:  80%|███████▉  | 250/313 [00:16<00:04, 15.27it/s, loss=7.09]

argmax start logits shape: tensor([ 23,  82,   2,  21,  22,   0,   0,   3,  71,  47,  22,  81,   1,   0,
         62,   0,   6,  27,  45,   0,   0, 110,   0,  46,  12,  72,   0,  51,
          0, 104,  54,   0, 102,  42,  58,   0,  51,  20,  82,  72, 121,   0,
         25,  79, 115,  67,  75,  71, 117,  63,  86,  64,  82,  14, 101,  79,
          0,  13,  10,  61,   6,   0,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 23,   3,  42,   7,  22,   8,  41,   4,  93,  47,  19,  81, 138, 120,
         20,  51,   3,  27,  59,  69,   3,  12,  35,  46,  14,  81,  33,  95,
         83,  17,  54,  14,  44,  39,  31,  76,   4,  46,  62,  67,  35, 102,
         59,  17,   3,  34, 117,  12,  21,  64,  56,  65,  25,  23,  73,  34,
         34,  31,  16,  34,   7,  71,  20,  32], device='cuda:0')
argmax start logits shape: tensor([ 63,  89,  78,  20,  38,  99,  85,  40,  92,  33,   0,  37,   0,  55,
          5,  18,  42,  90,  26,  14,  17,  47,  15,  36,   0,   0, 101,  78,
         

Epoch 2/10:  81%|████████  | 254/313 [00:16<00:03, 15.25it/s, loss=7.46]

argmax start logits shape: tensor([ 16,  66,   9,  88,   0,  90,  38,  53,  60,  56,  33,  45, 116,  62,
         25,  52,  59,  25,  90,  14,  40,  13,  26,  55,  84, 117,  57,  20,
         44,  18,  63,  17,  24,  28,   9,   2,  21,  61,   7,  11,  87,  29,
         12, 101,  88,  70, 151, 117,  40, 117,  51,  99,  13,   0,   2,  45,
         11,  72,  65,  10,   0,  17,  62,  75], device='cuda:0')
argmax end logits shape: tensor([ 40,  93,  67,  81,  21,  43,  19,  97,  30,  88,  59,  20,  24,  48,
         75,  11,  96,  25,  69,  31,  36,  41,  71,  11,  85,  58, 115,  22,
         52,  40,  63,  38,  24,  39,   2,  47,  19,  61,   7,  12, 103,  64,
          6,  62,  50,  27,  51,  51,  59,  42,  53,  47,  48,  36,   2,   5,
         21,  87,   6,  41,  17,  47,  74,   6], device='cuda:0')
argmax start logits shape: tensor([ 63,   1,  28,  37,  29,  41,  60,  16, 101,  64,   2,   4, 102, 107,
         72,  29,   2,   8,  44,  27,   3,  48,  31,  15,  66, 112,  55,  34,
        1

Epoch 2/10:  82%|████████▏ | 258/313 [00:16<00:03, 15.23it/s, loss=7.49]

argmax start logits shape: tensor([ 31,  17,  20,  43,  49,  19,  23,  21,  15,  12,  65,  44,  36,  22,
          3,  29,  60,  59,  78,  64,  74,  40,  47,  40,  38,  39,  31,  53,
         41,  51,  85,  49, 119,  46,  82,   0,   3,  89,  12,   0,  70,  74,
         11,  70,  80,  13, 104,   9,   0,  92,   8, 117,   0,  34,  63,  69,
        105, 115,  66,  11,   7,   5,  55,   2], device='cuda:0')
argmax end logits shape: tensor([ 42,  11,  43,  29,  20,  78,  24,  19,  18,   9,  48,   8, 125,  64,
         92,  87,  32,  18,  78,  42,  60,  41,  47,  26,  88,  63,  86, 101,
        106,  50,  59, 115,  15,  90, 138,   6,  83,  12,  24,  67,  83,  77,
        118,  33,  97,  47,  35,  37,  47,  57,   3,  92,  62, 124,  64,  33,
          6, 126,  41,  96,  39,  33,  43,   5], device='cuda:0')
argmax start logits shape: tensor([120,  53,   0, 123,  83,  63,  39,  14,  68,  15,  25, 105, 109,   0,
         63, 100,  35,   0,  14,   2,   2,   2,  71,  40,  55,  32, 115, 108,
         

Epoch 2/10:  83%|████████▎ | 260/313 [00:17<00:03, 15.06it/s, loss=7.59]

argmax start logits shape: tensor([127, 113,  30,  50, 115, 107,  52,   7,  66, 110,  97,  68,  80,   4,
          9,   3,   7,   0,  24,  74,  28,  52,  83,  14,  11,  37,  45,  99,
         20, 121,   2,  10, 101,  13,   3,  14, 110,  42,  20,   9,   7,   5,
          2,  42,  49, 100,  25,  60,  23,  99,  28,  76,  63,  24,  55, 118,
         26,  58,  38,   0,  84, 108,  31,  65], device='cuda:0')
argmax end logits shape: tensor([ 18,  32,  61,  49, 110,  98,  49,  37,  84, 110,  98, 103,  23,  21,
         49,  58,   5,  20, 107,  88,  30,  45, 101,  87,   8,  82,   6,  75,
        134, 118,  33,  10, 129,  75,  44,  35, 114,  10,  86,  74,  22,  73,
         94,  23, 113, 109,   7,  94,  26,  41,  25, 116,  85,  88,  40,  22,
         29,  38,  38,  18,  94, 107,  10,  83], device='cuda:0')
argmax start logits shape: tensor([ 93,  77,   1,  18,  78,   2,  63,   0,  24,  76,  47,   0,  53, 112,
         86,  51,  17,  34,  58,   3,  51,  50, 112, 114,  21,  19,  19,  46,
         

Epoch 2/10:  84%|████████▎ | 262/313 [00:17<00:03, 15.07it/s, loss=7.59]

argmax start logits shape: tensor([ 52,   0,  16,  46,  11,  22,  21,  65, 108,   3,  53,  54, 120,   0,
         91,  26,  75,  24,  30,  35, 102,  11, 111,  29,  13,   4, 103,  10,
         39,   8,  17,  47,  11,  98,  79,  27,  30,  83,  21,  94,  17,  40,
        100,  13,  30,  40,   5,  13,  61,  54,  52,  11,   1,  24,  39,  76,
         24,   9,  38, 105,  38,  87,  32,   6], device='cuda:0')
argmax end logits shape: tensor([ 83,  10,  17,  26,  14,  40,  23,  23,  37,   3,  53,  56,  22,  49,
         67,  28,  77,  47,  30,  39, 104,  37,  49,  12, 102,   3,  94,  13,
        118,  77,  78,  27,  18,  76, 127,   4,  28, 108,  22, 114,  95,  76,
        122,  70,  14,  41,   6,  15,  25,  28, 102,  13,  90,  19,  34,   5,
          2,  75,  59,  62,  29,  13,  78,  36], device='cuda:0')
argmax start logits shape: tensor([113,  44,  44,  80,  15,   2,   9,   2,  39,  27,   9,  44,  12,  23,
         20, 113,  52,  54,  97,  25,  72,   6,  67,  71,  57,  71,  91,   9,
         

Epoch 2/10:  85%|████████▍ | 266/313 [00:17<00:03, 15.16it/s, loss=8.02]

argmax start logits shape: tensor([ 84,  22,   3,  63,  84,  42,  14,  79,  19, 120,   6,   3,  69,   4,
         84,  41,  57,  43,  93,  72, 109, 110,  50,  37,   5,  55,  45,  15,
         54, 108,  55,  52,  54,  27,  19,   3,  17,  31,  80,  81,  79,   7,
         43,  11,  98,  53,  44,   3,  25,  70,  42, 133, 100,   8,  85,  26,
         58,  46,  94,  23,  27,  79,  47,  59], device='cuda:0')
argmax end logits shape: tensor([ 94,  91,   3,  68,  88, 123,  13,  36,  36,  43,  32,  86,  18,   4,
         96,  30,   9,  45,  93,  79,  82,   3,  51,  38,  11,  72,  66,   9,
        105,  68,  12,  45,  46,  62,  46, 127,  41, 115,  80,  34,  46,   7,
         57,  31,   8,   5,  27,   3,  13,  37,  70,  25,  16,   8,  83,  62,
         37,  48,  94,  65,  79,  31,  28,  61], device='cuda:0')
argmax start logits shape: tensor([ 59, 109,   1,  63,  30,  30,  18, 105, 118,  83,  33,   8,  21,   2,
         49,  91,  29,  34,  24,  89,  53,  88,  24,  74,  40,  63,  80,  40,
         

Epoch 2/10:  86%|████████▌ | 268/313 [00:17<00:02, 15.17it/s, loss=7.96]

argmax start logits shape: tensor([ 87, 128,  24, 113,  16,  40,  42,  11,  91,  73,  62,  59,  12,  73,
         21,  26,  86,  22,  38,  35,  82,  78,  23,  51,  17, 104,   6,  96,
         74,  18,  56, 134,   5,  17,   6,  74, 111,  62,  97,  78,  13,  16,
         79,  47,  71,   8,   2,   2,  26,  73, 107, 111,  37,   2,  16,  16,
        156,  63,   5,   0,   3,  40,   2,   4], device='cuda:0')
argmax end logits shape: tensor([ 11,  28,  71,   4,  22,  62,  87,  23,  92,  89, 142,  59,  36,  76,
         23,  81,  81,  72,   2, 104,  62,  84,  81,   6,  18,  76,  27, 101,
          8,  56,  18, 128,   5,  26,  15,  74,  45,  62,  33,  32,  21,  17,
         81,  90,   4, 100,   2,  54,  69,  64, 107,  34,  28, 113,  34,   5,
         97,  11,   5,  42,  44,   6,   2,   4], device='cuda:0')
argmax start logits shape: tensor([ 73,  48,  43,  37,  77,  88,  15,  40,  18,  17, 117, 104,  52,   2,
         63,  12,  20,  41,  54,   4,  40,  41,  24,  54,  63, 119,  95,   5,
         

Epoch 2/10:  87%|████████▋ | 272/313 [00:17<00:02, 15.36it/s, loss=7.3] 

argmax start logits shape: tensor([  5, 103,  10,   2,  65,  33,  32,  60,   8,  21,  46,  22,  34,  31,
          9,  32,  22,  47,  16,  25,  87,  63,  17,  98, 114,  70,  32,  23,
         13,   9,   5,   5,  25,  90,  22,  22,  93,  85,  36,  49,  83,   7,
         90,  50,  30,  88,  97,  18,   9,  62,  35,  83, 137,  66,   7,  85,
         98,   2,   2,  11,  15,  65,  42,  14], device='cuda:0')
argmax end logits shape: tensor([ 40, 113,  16,  91,  57,  44,  32, 124,  23, 123,  46,  23,   3,  11,
         98,  58,  24,  48,  49,  40,  74,   4,  17,  26,  97,  75,  47,   2,
         99,   9,  77,  27,  38, 102, 130,  16,  93, 101,  69,   5, 135,  20,
         61,  60,  80,  43,  45,  21,  89,  66,  75,  22,   5,  59,  98,  79,
          3,  38,   2,  55,  49,  24, 123,  14], device='cuda:0')
argmax start logits shape: tensor([ 81,  42,  16,  89,  39,  91,   3,  46, 107,  47, 113,  44,  59,   2,
         37, 101,   1,  49,  80,  56,  17,  27,  51,  53,   3,  18,  15,  79,
         

Epoch 2/10:  88%|████████▊ | 276/313 [00:18<00:02, 15.31it/s, loss=7.34]

argmax start logits shape: tensor([ 57, 126, 101,  15, 121,  55,  13,  12,  65,  70,   3,  52, 107,  23,
         31,  87,  59,  71,  67,   8,   2,  45,  28,  80,   5,  54,  86,  24,
         22,  76,  35,  56,  54,  64,  38,  54,  47,  96,  56,  45,   0,  35,
         84,  87,  30,  32,  39,  41,  78,   7,  23,  33,  38, 104,  72, 145,
         43,   2,   6,  58,  83,  22,  52,  61], device='cuda:0')
argmax end logits shape: tensor([ 20, 136, 101, 102, 121,  76, 103,  35,  59,  83,   4,  11, 109,  17,
         33,  38,  21,  22, 106,  60,  20,  30,  92,  42,  76,  46,  59, 126,
         52,  67,   3,  22,  14,  83,  10,  54,  47, 126,  37, 116,  18,  27,
         45,  74, 123,  31,   4,  24,  78,   8,  23,  61,  34,  58,  74, 146,
         53,  99,  32,  58,   8,  82,  17,  62], device='cuda:0')
argmax start logits shape: tensor([  2,  58,  56,  34,  59,   7,  43, 100,  60,  61,  18,  28,  72,  68,
         44, 123,  74,  94,  39,  64,  27,  67,  17,  68,  23,  49,  82,  54,
         

Epoch 2/10:  89%|████████▉ | 280/313 [00:18<00:02, 15.30it/s, loss=7.6] 

argmax start logits shape: tensor([  1,  91,   0,  28,  44,  50,  40,   8,  62,   3,   8,  54,  72,  80,
          0,  22,   0,  53,  51,   3,  96,  46,  39,  12,  60, 137,  67,  52,
         31,  89,  30,  44,  13,   0,  68,  18,  24,  30,  25,  48,  39,  52,
         83,  39,  15,  68,   0,  15,   0,   8,  77, 107,  76,   0,  91,  36,
          6,  46,  49,  60,  14,  49,   0,  22], device='cuda:0')
argmax end logits shape: tensor([  3,  67,  64,  17,  44,  17,  57,  23,  34,   3,  47, 126,   6,  52,
         41,  23,  72,  25,  22,  10,  45,  53,  42,  28,  24, 135,  12,  84,
         79,  95,  37,  62,   4,  14,  26,  17,  26,  12,  53,  51,  69, 119,
         53,  10,  13, 113,  26,  16,  36,  83,  10,  89,  77,  44,  91,  66,
         22,  76, 106,  28,  22,  25,  12,  23], device='cuda:0')
argmax start logits shape: tensor([ 36,   7,  98,  65,   4,  18,  67,  11,   0,  26, 141,   9,  56,  13,
         95,   6,   6,   2,  63,   2,   0, 115,  54,   2, 116,  67,  12,  22,
         

Epoch 2/10:  90%|█████████ | 282/313 [00:18<00:02, 15.29it/s, loss=7.89]

argmax start logits shape: tensor([  6,  55,  95,  13,  45,  29, 108,  20,  88, 106,   0,  95,  20,  85,
         84,  15,  64, 113,   2,  59,  57,  61,  24,   8,  28,   2,  95,   8,
          0,  25,  61,   0,  85,  77,  68,  11,  71,  49,  23,  57,  54,  41,
         16,  21,  46, 111,  22, 117, 114,  51, 128,  60,  99,  31,  24,  92,
         62,  54,  34,  76,   0, 100,   0, 150], device='cuda:0')
argmax end logits shape: tensor([ 76,  70, 114,  39,  28,  17,  20,  39,  41,  83,  32,  45,  39,  88,
         78,  37,  79, 125,  48,  54,  29,  62,  33,  91,  77,  54,  61, 126,
         34,  52,  29,  42,  18,  43,  34,  28,   4, 108,  54,  72,  50,   5,
         13,   3,  10, 110, 105,   5,  91,  48, 130,  55,  98,  28,  18,  84,
          5,  40,  98,  57,  34, 117,  48, 123], device='cuda:0')
argmax start logits shape: tensor([ 19,   0,  15, 109,  66,  44,   2,   6,  69,  50,  22,  99,  93,  21,
         10,  25,  20,  84,  76,  46,  99,  59,   0,  63,  20,  67,  17, 108,
         

Epoch 2/10:  91%|█████████▏| 286/313 [00:18<00:01, 15.18it/s, loss=7.31]

argmax start logits shape: tensor([109,  52, 120,  33, 124,   9,  88,  37,  78,  15,   3,   2,  79,  33,
         88,  49, 120,  94,  11,  93,  90,  17,  45,  72,  21,  11, 148,  27,
         45,  15,  71,   0,  78,   2, 106,  10,  79,  27,  77,  27,  86,  36,
          0,   8, 107,  59,  96,  34,  49,  15,  11,  31,   0,  57,  51, 105,
         50,  61,  23,  44, 104,   3,  20,  19], device='cuda:0')
argmax end logits shape: tensor([ 83,  27, 126,  99,  18, 126,   5,  86, 104,   3,  15,   2,   3,  20,
         63,   3,  43,  36,   5,  96,  45,  85,  40,  58, 111,  10,  78,  49,
         47,  26,  38,  39,   2, 105, 109,  33,  87,  45,  15,  27,  16, 124,
         16,  51, 123,  66, 112,  15,  26,  13,   5,  52,  33,  61,  51,  20,
         51,   6,  65,  28, 114,  97,   4,  78], device='cuda:0')
argmax start logits shape: tensor([ 52,  90, 103,  46,  56,  33,  52,  72, 118,  40,  67,  42,  50,  17,
         59,  20,   0,  55,  45, 103,   0,  43,  13,  54,   7,  30,  36,  45,
         

Epoch 2/10:  92%|█████████▏| 288/313 [00:19<00:01, 15.14it/s, loss=7.78]

argmax start logits shape: tensor([ 37,  10,  40, 107,  77,  90,  76,   6,  70,  32,  97,  92,  35,  71,
         53, 111,  34,  45,  57,  10, 101,  82,   0,  34,  40,   7, 112, 106,
         12, 107,  99,  74, 141,  29,   2,  51,  81,  93,  72,  23,  58,  35,
         85,  34,   7,  44,  28,  39, 108,  87,  19,  83,  11,  65,  61,  31,
         26,  19,  59,  42,  70,  11,   0,  15], device='cuda:0')
argmax end logits shape: tensor([  3,  52,   4, 109,  10, 102,  26,  27,  47,  30,  35,  93,  90,  67,
         30,  32,  35,  75,  56,  23,  64, 116,  12,  50, 122,  23,  55,  79,
         93, 135,  72,  88,  64,  26,  29,  10,   3, 110,  75,  23,  26,  35,
         52, 128,  22,  24,  16,  39,  61,  90,  68,  84,   5,   3,  18,  31,
         31,  38,  42,  27,  89,   8,  39,  37], device='cuda:0')
argmax start logits shape: tensor([  7,  91,  71,   2,  71,  34,  35,  63,  10,  20,  49,  58,  36,  76,
         37, 104,  10,  70,  67,   2,  45,   9,  64,  29,  73, 100,  90,  89,
         

Epoch 2/10:  93%|█████████▎| 292/313 [00:19<00:01, 15.05it/s, loss=7.55]

argmax start logits shape: tensor([ 82,  33, 100,  62,  66,  66,  39,   0,   0,  62, 109,  57,  89,  74,
         94,   0, 104,   0,   2,  52,  88, 105,  56, 129,  20,  53,   0,  95,
        119,  10,   7,   5,  54,  63, 137,  70,  55,  80, 118,   0,  23,  70,
         20,  86,  28,  40,  27,  51,  79,   0,   0,  27,  98,  29,  90, 101,
         70,   0,  67,  59,  36,  12,   3,  53], device='cuda:0')
argmax end logits shape: tensor([116,  33, 101,  41, 112,   3,  20,  56,  66,  23,  24,  20,  25,  35,
         25,  65,  23,  45,  14,  36,  11, 106,   3, 118,  42,  45, 108,  41,
         37,  21,  78,  30,  32,  14,  81,  93,  55,  59,  13,  51,  16,  53,
         51,  86,  18, 105,  80,  44,  59,  42,  88,  52, 129,  62,  78,  73,
         36,  26,  27,  21,  58,  50,  23,  45], device='cuda:0')
argmax start logits shape: tensor([  0,   0,  83,  25,  70,   2,   3,   2,  73,   0,   0,  77,   0,   9,
         19,  58,  43,  76,   0,   0,  97,  15,  92,  42,  71,  18,  28,  41,
         

Epoch 2/10:  94%|█████████▍| 294/313 [00:19<00:01, 15.19it/s, loss=8.2] 

argmax start logits shape: tensor([  0,  59,  22,  34,  86,   0,  27,  65,  80,  18,   1,   2,  49,   0,
          0,   0,   0,  64,  40,  58,   0,   0,   0,   0, 146,  13,   0,  18,
          0,   0,   0,   0,  11,   0,  30,   0,  86,  53,   0,   0,   0,   2,
          0,   0,   0,   0,   0, 108,  49, 118,   0,   2,   0,  95,  49,   0,
         76,   0,  65,  17,  75,   0,   0,  88], device='cuda:0')
argmax end logits shape: tensor([ 87,  63,  74,  34,  89,  61,  17,  91, 107,  78,  88,  76, 109,  20,
         98,  12,   3,  65,  41,  21,  11,   3,  53,  83,   3,  37,  75,  19,
         59,  77,  84, 105,  67,  26,  30, 118,  92,  32,  15,  76, 125,  13,
         31,  37,  21,  13,  24,  27,  18,  23,  70,  27,  18, 113,  18, 110,
         91,  81,   3,  40,  23,  82,  74, 111], device='cuda:0')
argmax start logits shape: tensor([  0,   0,   0,   0,   7,  71,  77,   0,   0,  47,   0,   0,  45,   0,
          0,  28,   0,   0,   0,  24,   0,   0,  74,   0,   0,   6,   0,   0,
         

Epoch 2/10:  95%|█████████▌| 298/313 [00:19<00:00, 15.14it/s, loss=8.01]

argmax start logits shape: tensor([  0,   7,   0,   0,   3,   0,   0,   0,  52,  34,   9, 105,   0,  11,
         60,   3,   0,  49,  18, 131,  17,  43,   0,  19,   0,   0,  33,  85,
          0, 107, 105,  61,   0, 121,  58,   0,   0,  43,  13,  11,   0,  91,
          0, 114, 107, 105,  30,  71,  72, 104,  95,   0,  22,  31,   0,   2,
         85,  59,  19,  15,   3,   0,  86,   0], device='cuda:0')
argmax end logits shape: tensor([ 28,   3,  63,  10,  46,  20,   9,  37,  22,   5,  19,  47,  35,  73,
         38,   3,  21,  49,  45,   5,  54,  10,  68,  98,  58,  35,  28,  39,
          3, 106, 127,  31,  81,  76,  58,  23, 119,  21,  93,  30,  30,  53,
         22, 115,  48,  93,  30,  27,  27, 106,  30,  38,   8,  50,  15,  81,
         85,  81,  75,  59,   5,  35,  89,  85], device='cuda:0')
argmax start logits shape: tensor([ 32,   0,  30,   7,   7,  47,   0,  61,   0, 103,  51,   0,   0, 128,
         37,   0,  11,   0,  33, 115,   0,   0,   0,   0,  64,  98,   7,  29,
         

Epoch 2/10:  96%|█████████▋| 302/313 [00:19<00:00, 15.19it/s, loss=7.4] 

argmax start logits shape: tensor([ 28,  10,  13,  27,  12,   2,  11,  56,  16,  91,  76,  58,  24,  88,
         62,  37,   1, 117,  44,  65,  35,  27,  72,   7,  58,  41,  41,   0,
         15,  15,  76,  21, 108,  20,   3,  30,  22,  82,  17,  20,  88,   3,
         12,   2, 114,  39,  19, 104,  11,  92,  40,  76,  29,  14,   5,  46,
         83,   3,   3,  50,  23,  24,   0,  60], device='cuda:0')
argmax end logits shape: tensor([ 87,  24,  87,  25,  28,  11,  12,   5,   5,  49,  60,  93, 118, 111,
         27,  56,   3,  81,  33,  81,  33,  24,  72,  16,  92,  41,  42,  50,
         99,   8,  75,  88,  42,  20,   5,  31,  20,  58,  84,  22,  43, 106,
         39,  29,  83,  24,  84,  83, 104,  92,  63,  22,  32,  15,  79, 104,
         11, 102,  92,  80,   4,  24,  41,  21], device='cuda:0')
argmax start logits shape: tensor([ 28,  86,  54,  79,  51,  89,  33, 114,  27,   1,  21,  95,  17,   6,
          1,  62,  95,  45,  23,   5,  74,  32,  59,  71,  48,  79,  21,  69,
         

Epoch 2/10:  98%|█████████▊| 306/313 [00:20<00:00, 15.10it/s, loss=7.51]

argmax start logits shape: tensor([114,   7,  15,  16,   0,  12,   0,   0,  17,   0,   0,  18,   2,  25,
          0,   0,  61,  56,   0,  92,  12,  43,  71,  65,  67,  74,   0,  13,
        119,   2,   0,   5,  37,   0,  31,  27, 121,  26,  30,   0,  20,   0,
         20,  75,  32,   0,  67,  37,   0,  21, 136,  62,  82,   0,   0,   0,
         28,  83, 111,   2,  42,  17,  71,   0], device='cuda:0')
argmax end logits shape: tensor([ 67,  30,  11,  75,  42,  12,  86,   7,  70,  80,  17,  23,   2,  10,
         88,   4,  94,  37,  58,  57,   5,  10,  65,  66,  56, 116,  65,  14,
         51,  91,  33,  18, 123,  39,  77,  32,  76,  73,   3, 107,   5,  33,
         50,  32,  57,  61,  57,  51,  36,  13, 113,  38,  46,  87,  69,  45,
         37,  62,  54,   5,  62,  32,  69,  56], device='cuda:0')
argmax start logits shape: tensor([ 14,   0,   9,  52,   0,   0,   0,  83,   0,   0, 119,   0,   0,  10,
          0,  68,  84,  18,  58,   0,  28,  45,   0,  53,  12,   0,   0, 121,
         

Epoch 2/10:  98%|█████████▊| 308/313 [00:20<00:00, 15.17it/s, loss=7.55]

argmax start logits shape: tensor([  0,   0,   0,   0,  61,   0,  59,  36,  18,   0,  81,  34, 110,   0,
          0,  65,   0,   0,  78,  37,  54,  90,   0,  77,  48,   9,   0,   0,
          0,  78,   0,   2,   0,  40,  97,  14,  68,   0,   0,  87,  73,  22,
          3,  65,   0,  27,  83,  49,  14,  44, 120,   0,   8,  37,   0,   0,
          0,  72,   0,   0,   0,  21,   2, 120], device='cuda:0')
argmax end logits shape: tensor([  3,  97,  64,   8,  45,  31,  35,  63,  36,  54,  17,  92,  43,  28,
         20,  17,   4,  47,  76,   9,  96,  41,  45,  97,  49,   2,  86,   9,
          6, 131,  95,   3,  62,  18,  69,  21,  18,  55,  74,  87,  82,  15,
         17,  65,  76,  15,  53,  52,   4,  17,  30, 118,  94,  19,  42,  10,
         31,  16,  50,  38, 123, 104,   3,  88], device='cuda:0')
argmax start logits shape: tensor([ 27, 123,  60,  75,  90,  51,  45,  31,  15,   0,  45, 104,  88,  80,
          0,   0,  64,  22,  30, 123,   0,   0,   0,  55,  47,   8,   8,   3,
         

Epoch 2/10:  99%|█████████▉| 310/313 [00:20<00:00, 15.16it/s, loss=7.51]

argmax start logits shape: tensor([  0,  22,  66,  61,   0,   7,  86,   0,  88,   0,  44, 112, 118,   0,
         75,  33,   0,  16,  54,  29,  17,  76,   0,  53,  94,   2,   0,   6,
         43,  48, 126,   0,  29,  16, 104,  85,  11,   0,   0,  71, 126,   0,
         20,  88,  35,   0,   0, 100,  68,  83,  53,  37,  33,  88,  76,   2,
          0,   0,  48,  79, 104,  82,  71,  45], device='cuda:0')
argmax end logits shape: tensor([  3,  68,  88,  59,  24,  26,  87,  66,  39, 113,  64,  56,  77,  34,
         44,  33,  49,  29, 100,  34,  50,  60,  39,   2, 104,   3,  62,   4,
         46,  47, 117, 104,   6,  32,  35, 103,  18,  23,  99,  45, 143,  68,
         56,  39,  25,  67,  23,  33,   7,  67,  42,  37,  62, 121,  87, 111,
         46,  56,  50,  88, 138,  83,  12,  32], device='cuda:0')
argmax start logits shape: tensor([ 65, 106,  91,   7,  46,  11,  53,  40,   0,   0,  71,   2,   0,  42,
          7,  26,  68,  24,  39,  32,  41,  25,   9,  61,  49,  86,  68,   0,
        1

Epoch 2/10: 100%|██████████| 313/313 [00:20<00:00, 15.23it/s, loss=7.71]


argmax start logits shape: tensor([ 13,  61,  21,  38,  33,  51,  50,   0,  22,   3,  54,  37,  86,  34,
         26,  17,   0, 101,  55,  71,  15,  99,  86,  26,  50],
       device='cuda:0')
argmax end logits shape: tensor([112,   3,  21,  60,  63,  68,  67,   1,  35, 116,  78,   2,   5,  80,
         81,  50,  10, 117,  66,  14,  16,  99, 102, 126,  97],
       device='cuda:0')
Epoch 2 Loss: 7.7176


Evaluating:   1%|▏         | 4/313 [00:00<00:08, 35.75it/s]

argmax start logits shape: tensor([ 39,  18,   0,  48,   0,   0,  15,  31,  93,  62,   0,  65,  11,  28,
         82,  23,  54,  54,  76,  83,   0,  79,  64,  49,  61, 103,   0,  26,
         49,  21,   0,  54,  53,   7,   0,  85,  13,  34,   0,   0,   9,  13,
         55,  79,  53,  74,  21,  56,  83,  52,   0, 115,  14,   0,  20,   2,
         37,  15,   2,   6,  13,  54,  79,  24], device='cuda:0')
argmax end logits shape: tensor([ 39,  10,  77,  73,  20,  46,   3,   5,  19,  41,   6,  55,  42,  29,
        103,  45,  55,  54,  78,   3,  55,  95,  83, 101,  21,  14,  55,  11,
         11,  29,  31,   8,  40,  42,  87,  87,  17,   8,  66,  14,   9,  13,
         11,   4,   5,  81,  33,  73,  39,  16,  37, 106,  98,  33, 104,  16,
        125,   6,  83,  37,  22,  69,   4,  23], device='cuda:0')
argmax start logits shape: tensor([ 55,  88,   2,   2,  86,  85,  10,  52,  35,  34,  71, 103,   2,  46,
         17,  33,  13,   0,  23,  51,   2,  44,  85,  18,  17,  10,  17,  33,
         

Evaluating:   4%|▍         | 12/313 [00:00<00:08, 34.04it/s]

tensor([  3,  22,  80,   0,  57, 133,   1,   0,  34,  30,   0,  92,   8, 132,
          3,  61, 108,  58,  14,  49,  33,  53,   0,  52,  12,  22,   2,   4,
         41,   4,  27,  11,  66,   0,  11,  24,   0,  57, 121,  46,  83,   3,
         25,   0,   0,  66,  73,  79, 105,  86,  65,  23, 117,  83,  63,  95,
          0,  74,  77,  64,  63,  10, 113, 147], device='cuda:0')
argmax end logits shape: tensor([  3,  94,  99,  14,  25, 129,   2,  66, 111,   9,   5,  46,  65, 132,
         79,  24,  12,  68,  14,  44,  16,  53,  46,  20,  44,  38,  36,   4,
         38, 105,  47,  52,  31,   3,  96,  88,  64,  57,   6,  16,  83,   3,
         65,   4,  81, 105,  49, 103,  89,  89, 126,  59, 112,  81, 104,  98,
          9,  57,  64,  65,  31,  15, 118,  77], device='cuda:0')
argmax start logits shape: tensor([ 40,  99,  12,  36,  50,   0, 102,  53,  54,  51,  34,   0,  34,   3,
          0,   0,  29,   0,   0,  98, 100,  59,  61,  47,  37,  66,  48,   2,
          0,  16,  35,  26,  34,  90

Evaluating:   6%|▋         | 20/313 [00:00<00:08, 33.98it/s]

tensor([ 54,  38,  98,   0,  90,  58, 110,   0,  39,  24,  37,   0,  65,  49,
         70,  84,  32,  17,  77,  62,  96,   3,  96,  13,  72,  44,  65, 102,
         72,   8,  25,  55,   2,  44,  31,  20, 101, 100,  14,  98,  86,   0,
         93,   0,   0,  98,   0,  14,  60,  84,   2,  40,  40, 113,  17,   0,
         63,  57,  41,  44,  12,  28,   0,  18], device='cuda:0')
argmax end logits shape: tensor([ 43,  91, 129,  21,  46,  58,  68,  52,  22,   2,   9,   3,  66,  88,
         51,  64,  78,  28,  77,  86, 101,  56,  83,  94,  71,  81,  88,  77,
         20,  23,  50,  16,  93,  87,  80,   4,  68,   4,  15,  62, 102,  55,
         92,  84,  39, 129,  51,  29,  62,  66,   5,  41,  41,  57,  42,   8,
         82,  22,  36,  31,  43,  17,  51,   8], device='cuda:0')
argmax start logits shape: tensor([ 72,  46,  36,  25,   0,  53,  62,  60,  60,   0,   8, 133,  18,  15,
         34,  56,  89,  64,  63, 101,   1,  55,  49,  50,   0,  81,   0,  21,
         13,  41,  55,  47,   7,  45

Evaluating:   8%|▊         | 24/313 [00:00<00:08, 34.04it/s]

argmax start logits shape: tensor([ 77,  58, 100,  71,  51,  10,  77,  66,  49,  55,  25, 104,  10,   5,
        108,  96,  30,  30,  19,   0,  19,  34,  25,   0,  33,   0,  23,   0,
         54,  71,   0,  45,  35,   0,  47,  57,  76,  61,  64,  50,   0,  74,
          2,  89,  31,  14,   0,  41,  12,   0,   9,  35,  39, 110,  24,  90,
        100,   0,  75,  60,  27,  33, 121,   4], device='cuda:0')
argmax end logits shape: tensor([ 64,   8, 136,  40,  54,   2,  55,  66,  10,  63,  34, 104,  66,  27,
         26,   9,  57,  31,  18,  95,  16,  55,  80,  82,   2,  50, 131,  50,
         54,  73,  50,  90,  35,   6,  28,  29,  76,  61,  64,  66,  70,  74,
         93,  68,  10,  14,  33,  74,  39,  24,  12,  54,  42,  17,  24,  63,
          6,  20,  76,   4,  27,  26,  57,  39], device='cuda:0')
argmax start logits shape: tensor([ 49,  10,  68,  75, 103,  28,  65,   2,  30,  41,  29,  65,   0,   0,
        100,  12,  48, 121, 141,  12,  83,   5, 100,   0,  21, 111,  50,   0,
         

Evaluating:  10%|█         | 32/313 [00:00<00:08, 33.39it/s]

argmax start logits shape: tensor([ 70,   0,  30,   0,  14,  54,  66,  99,  15,  29, 110,   2,   0, 107,
         58,  29,  50,  35,  49,  31,  91,  26,  64,  63,  62,  86,  35,  23,
         61,  50,  26,  42,  49,  53,  40,  46,   0,  67,  63,  41,   0,  27,
         64,  55,  15,   0,  15,  42,   9,  51,  54,   0,  42,  37,   0,  14,
         29,  80,  93,  19, 123,  51,  16,  88], device='cuda:0')
argmax end logits shape: tensor([ 47,  38,  30,  48,  32,  14,  14,  29,  37,  59,  43,  53,  38,  21,
         93,  56,  57,  43,  72,   8,  91,  26,  64,   2,  64,  89,   8, 127,
         59,  34,  31,  17,   5,  68,  39,   5,  17,  13,  59,  12,  93,  37,
         64,  55,  12,  22,  32,  42,  19,  19,  34,  23, 108,  18,  34,  32,
          2,  81,  93,  51, 123,  51,  12,  82], device='cuda:0')
argmax start logits shape: tensor([  7,  52,  16,   9,  58,   0,  12,   0,   0,  24,   0,  21,   0,   0,
         10,  13, 121,  44,  42,   8,  12,  34,  30,  13,  14,  33,   5,  16,
         

Evaluating:  13%|█▎        | 40/313 [00:01<00:08, 34.06it/s]

argmax start logits shape: tensor([ 34,  46,  38,  40,  63,  10,  90, 110,  72,  24,  51,   0,  93,  70,
         39,  58, 105,  77,  75,  35,  12,   3,   0,   0,  85,  85,  64,  44,
         12,   0,   0,  32,  18,  37,  29,  33,  67,  65,   0, 100,  30,   0,
         99,   0,  72,   0,  85,  30,   0,  10,  30,   0,  80,  33,  78,  44,
         52,  78,   0,  64,   0,  71,  51,  62], device='cuda:0')
argmax end logits shape: tensor([ 23,  50,  44,  89, 115,  67, 100, 110,  20,   8,  85,  26, 110,  53,
         32,  93, 105,  29,   6,  19,  36,   6,  98,  73,  85,  85,  80,  81,
         36,  18,   5,   7,  18,  25,  43,  67,  73,  55,  56,   8, 109,  15,
         34,  52,  71,   6,  87,  77,  48,  69,  57,  52,  13,  55,  76,   8,
         42,  18, 111,  92,  29,  99,  20,  62], device='cuda:0')
argmax start logits shape: tensor([ 53, 103,  20,  81, 126,   2,   0,   8,  50,  92,   0,   8,  46,  39,
         84,   3,   3,  20, 106,  69,   0,  84,  34,  30,  23,   1,  87,  77,
         

Evaluating:  15%|█▌        | 48/313 [00:01<00:07, 34.62it/s]

tensor([ 98,   0,  17,  22,  82,  52,   0,   3,  83, 132,  18,  86,  63,  88,
          0,  93,   0,   2,   0,   0,  40,  47,  65, 136,  11,  98,   0,   0,
         27,   0,  31,   0,  25,  52,  47,  83,  40,  90,   0,  11,  12,  86,
         13,  13,  44,  94,  56,   0,   7, 120,  63,   0,   6, 131,  31,  84,
         57,  86,  60,  39,  41,  47,  58,   0], device='cuda:0')
argmax end logits shape: tensor([108,   4,  18,  42,   7,  52,  36, 104, 124, 103,  19, 113,  62,  87,
         50,  93,  31,  17,  80,  86, 124,  47,  92, 121,  75,  50,  97,   9,
         98,  77,  49,  30,  88,  91, 100,  43, 114, 134,  73,  23,  11,   8,
         47,  14,   6,  94,  66,  28,   7,  88,  58,  21,  56,  15,  95,  53,
         18,  94, 128,  30,  36,  38,  46,  38], device='cuda:0')
argmax start logits shape: tensor([139,   0,  97,  33,  46,  80,   0,   0,   0,  27, 119,  84,  53,  11,
         67,  11,  60,  59,  30,  52,  29,  61,   8,  20,  92,  33,  70,  17,
         61,  30,  53,  46,  80,  39

Evaluating:  18%|█▊        | 56/313 [00:01<00:07, 33.90it/s]

argmax start logits shape: tensor([ 53,  98,  17,  45,  18,   7,  23,  85,  13,  52,  86, 111,  49, 103,
          2,   5,  90,  35,   0,  41,  41,  23, 109,  17,  31,  34,  97,  45,
         15,  13,  93, 104,  26,   0,   3,  33,  73,  47,  44,  38, 100,  50,
         53, 120, 110,   2,  96,  14,  39,   0,   0,   2,  55,  55,  93,   0,
        107,   2,  82,  47,  99,  32,   7,  38], device='cuda:0')
argmax end logits shape: tensor([101,  62,  18,  20,  50,  55,  69,  85,  31,  85, 120,  80,  47,  13,
          3,   7,  91,  17,  18, 120,  44,  57,   6,  17,  66,  41,  24,  48,
         65,   3, 103,  81,  18,  16,  88,   6,  80,  11,  30,  70,  71,  45,
         40,  58,  91,   2,  25,  56,  27,  10,  21,  39,  78,  71,  20,  23,
         50,  28, 110,  72,  11,  25,  91,  80], device='cuda:0')
argmax start logits shape: tensor([117, 130,   0,  58,   7,   0,  22,   0,  63,  44,  58,  52, 111,  33,
         90,  13,   2,   0,  65,   0,  76,  27,  49,   0,   6,  55,  51,  32,
         

Evaluating:  19%|█▉        | 60/313 [00:01<00:07, 33.91it/s]

argmax start logits shape: tensor([ 67,  38,  89,  14,  22,   0,  90,   0, 130,  53,   0,  14,  32, 114,
          1,   0,  10, 105,   2,   0,  40,   2,  30,   2,  25,   4,   0,  65,
         62,  88,  47,   8,  52,   0,  86,   3, 114,   0,   8,   7,   0,   9,
          9, 112, 113,  20,  99,  34,  13,  81, 101,  61,  44,   0,  24, 119,
          0,  21,  80,   8,  43,  75,   6,  45], device='cuda:0')
argmax end logits shape: tensor([ 42,  26, 102,  31,  61,  12,  36,  30, 130,  33,  94,  77,  52, 130,
         63,  38,  20,  13,  75,  28,  43,  67,  30,  55,  87,  52,  33,  13,
         55,  25,  35,   3,  46,  87,  82,   5, 115,  77,   7,  69, 119, 126,
          9,  69, 113,  20,  66,  24,  31,  66, 101,  24,  96,   7,  20, 109,
         17,  30,  71,  70,  35,  75,  10,  37], device='cuda:0')
argmax start logits shape: tensor([  0,  62,  10,  12,   2,  64, 127,  16,   0,   5, 110,  19,  23,  34,
         39,  26,  19,   0,   8,  70,   0,  15,  86,  91,  30,  39,   0,  83,
         

Evaluating:  22%|██▏       | 68/313 [00:01<00:07, 34.14it/s]

argmax start logits shape: tensor([110,  23,   0,   0,  56,   0,   0,  23,   6, 123,   3,  49,   2,  52,
         25,  37,  43,   0,  44, 105,  74,  65,  15,  49,  20,  11,  38,  66,
         81,   8,  70,  52,   0,  13,   0,   3,  54,  21,   0,  38,  25,   0,
         22,  30,  23,   0,   0,  48,  85,   9,   6,   0,  11,  83,  89, 112,
          2,  49,  26,  61,   0,  36,   0,  47], device='cuda:0')
argmax end logits shape: tensor([117, 109,  17,   6,   7,   8,  21,   7,  13,  15,   3,  69,   2,  20,
         25,  39,  44,   5,  75,  26,   3,  66,  37,  16, 116, 117,  90,  76,
         59,  29,  70,  92,  84,  13,  45,   3,  94,  43,  11,  81,  26,  51,
         22,  17,  65,  13,  20, 114,  34,  34,  30,  28,  22,  27,  46, 111,
         13,  41,  27,  50,  53,  62,  21,  14], device='cuda:0')
argmax start logits shape: tensor([ 76,  34,   0,  55,   0,  58,  91,   0,  59,   2,  13, 105,  33,  52,
         30,  64,  19,  34,   0,  63,  47,   0,   7,  80,  21,  94,  20,  41,
         

Evaluating:  24%|██▍       | 76/313 [00:02<00:07, 33.26it/s]

tensor([ 28,  56, 104,  64,  79,   4,   4,  72,  97, 119, 100,  81,  48,  85,
         52, 113, 113,  90,  66,  36, 114,  18,  65,   3,  75,   2,  40,  41,
         30,  81,  89,  73,  99,  53,   5,  22,  45, 106, 105,  42,  85,  82,
         10,  13, 100,  86,  63, 102,  81, 110,  90,  95,   5,  16,  38,  92,
         41, 112, 113,   8,  57,  97,  17,  35], device='cuda:0')
argmax start logits shape: tensor([110,   0,  46,   0,  10,  15,  14,  72,  38,   0,  35,  59,  59,  44,
        103,  31, 107,  30,   8,  55,  11,   0,  99,  20,   0,  21,  59,  14,
         70,   2,   7,  47, 112,  49,  90,   0,  54,  15,  52,   0,  33,   3,
          7,  31,  35,   0,  43,  11,   0, 101,   0, 139,  84,  71, 105,   2,
         70,  98,  16, 108,  77,  37, 108,  88], device='cuda:0')
argmax end logits shape: tensor([111,  12,  30,   2,  10,   5, 107,  77,  47,  90,  23,  66,  91,  81,
         33, 110,  67,  45,  48,  56,  48,  19,  49,  94,  88,  22,  25,  54,
         15, 124,  51,   9, 105,  49

Evaluating:  27%|██▋       | 84/313 [00:02<00:06, 33.23it/s]

argmax start logits shape: tensor([ 51,  93,  99,   2,  90,  88,   0, 105,  36,  32,  22,  51,  85,  90,
         98,  57,  41,  52,  16,  67,  71,  25,  85, 113,  39,   8,  17, 106,
         12,  88,   0,  54,  29,   2, 109,  22,  19,   0,  34,  93, 112,  54,
         69,  99,  71,  10,  63,   8,  97,  78,  64,   0,  73,  52,  54,  22,
         26,  18,   0,   0,  22,   3,  91,  45], device='cuda:0')
argmax end logits shape: tensor([ 56,   3,  12,   6,  35,  88,  30, 114,  79,  86, 126,  79,  17,   9,
        109,   2,  16,  33,  33,  67,  45,  57,  85,   9,  32,   4,  27, 106,
          2,  82,  12,   8,  31,   5,  12,  45,  19,  89,  61,  14,  55,  98,
         31,   3,  79,  67, 120,  48,  35, 106,  13,  19,  82,  90,  19,  28,
         26,   8,  13,  16,  38,  46,  53,  50], device='cuda:0')
argmax start logits shape: tensor([ 31,  90,  37,   0,  13, 108,   2,  36,  43,  18,   2,   2,  24,   2,
          0,  25,   0,  12,  75,  99,  17,   2,  75,  72,   2,  70,  63,   0,
         

Evaluating:  28%|██▊       | 88/313 [00:02<00:06, 33.59it/s]

argmax start logits shape: tensor([117,  93,  85, 133,  56,  88,  49,   1,   1,  17,  54,  49,  17,  61,
        102,  59,   0,  71,  38,  74,  13,  56,  46,  21,  71,  29,   3, 107,
         42,   0,  53,   8,  61,  10,  24,  29,   8,  85,  55,   0,   0,  50,
         47,  86,  47,  23,  60,  45,  29,  32,   3, 140,  59,  26,  34,  93,
         73,  36,  94,  17,  11,   0,  95,  22], device='cuda:0')
argmax end logits shape: tensor([100,  40,  67,   4,   7,  30,  26,   1,   1,  13,  58,  51,  47,  61,
         66,  58,  53,  43,   7,  62,  52,  88,  47,  21,  14,  29,  18,  97,
         38,  12,  34,   9,  29, 101,  20,  10,  17,  21,  62,   5,  30,  44,
         32, 102,  48,  63,  82,  47,   5,  33,  28, 109,  25,  22, 111,  72,
         88,  37,  94,  73,  11,  12,  13,  24], device='cuda:0')
argmax start logits shape: tensor([ 59,  80, 103,  21,   2, 115,   2,  21,  80, 127,  27,  40,  80,  85,
         17,  37,  53,  90,  79,  31,  58,  34,  25,  23,  28,  39,  92,  55,
         

Evaluating:  31%|███       | 96/313 [00:02<00:06, 33.53it/s]

tensor([ 28,  36,  46,  83,  28,   0,   3,   0,  37,  20,  51, 100,   0, 100,
         19, 101,   0,  14,  49,  11,  55,   0,  53,  99,  62,  81,  25,  42,
         18,  11,  42,   0,   7,  34,  80, 106,   2,   0,  56,  96,   0,  79,
          0,  95,  19,  51,   0, 106, 109,   6,   0,  81,  16,   0,  96,  49,
         29,   0,   8, 118,   0,   0,  87,  37], device='cuda:0')
argmax end logits shape: tensor([ 79,   6,   6, 128,  65,  90,  44,  67,  25,  20,  90, 101,  33,  62,
         20,  74,  31,  15,  41,  22,  80,   8,   4,  21,  55,  78,  60,  26,
         49,  28,  38,  77,  38,  19,   8,  59, 135,  18,  69,  62,  72,  11,
         63,   5,  73,   4, 120, 107, 109,  26,  46,  81,  64,  10, 119,  78,
         70,  46,  93,  62, 133,  88,  78,  37], device='cuda:0')
argmax start logits shape: tensor([  0,  45,  15,   5, 133,  75,  32, 100,  28,   5,   0,  63,  43,  12,
         55,  38,  24, 104,   0,   0,   2,   0,  77,  77,  80,  90,  12,   0,
         34,  16,  98,  30,   5,  61

Evaluating:  33%|███▎      | 104/313 [00:03<00:06, 33.26it/s]

tensor([ 23,  97,  18,  74,  70,  67,   0,  33,   0,  70,  79,  53,   0,   0,
         85, 104,  12,   5,  16,  18,  81,  21,   0,  93,  85,  49,  84,  29,
         90,   0,  86,  37,   0, 105,  36,   9,  14,  92,  61, 147,  27,  66,
          0,   0,  48,  26,   2,   2,  70,   7,   0,  25,  28,  55,  22,  76,
          5,   0,   5,  13,  19,   0,   2,  14], device='cuda:0')
argmax end logits shape: tensor([ 59,  41,  23,   3,  57,  34,  35,  42,  34,  83,  79, 116,  64,  99,
          2,   4, 115,   7,  66,  31,  81,  59,  35, 110,   6, 115,  65,  70,
        100,  36,  86,  39,  84,  53,  69,  54,  42,   6,  31,  72,  29,  66,
         36,  25,  56,  26,   3,  98,  93,  32,  14,  64,  28,  16,  20,  22,
         17,  17,  90,  47,  27,   7,   3,  85], device='cuda:0')
argmax start logits shape: tensor([ 10,  44,  57,  18,  48,  81,  75,  17,  85, 113,  50,   2,  71,  81,
         47,   0,  85,   0,   0,  26,   0,  11,  52,  39,  53,  54,  20, 123,
         57,   0,   0,  54,   0,  41

Evaluating:  35%|███▍      | 108/313 [00:03<00:06, 33.29it/s]

argmax start logits shape: tensor([ 21,  31,  55,   0,  49,  93,  40,   4,  59,  53,   0,  41,   0,   3,
         90,  92,  90,  85,  31,  83,   0,  39,   0,   0,  22,  90,  24,  64,
          7,  21,  17,  41, 116,  31,   0,   0,   4,   8,  10,   0,  63,  13,
          0,  25,   0,  66,  33,  42,  96,  90,   0,  14,  28,  40,  41,  19,
          0,  87,  80,  16,  69,   0, 110,  63], device='cuda:0')
argmax end logits shape: tensor([ 16,  11,  10,  31,  21,   4,  54,  84,  59, 102,  21,  41,  55,  86,
          9,  92,  39, 129,  31,   4, 105,  32,  19,   9,  46,  44,  47,   6,
         54, 113,  42,  42, 154,  31,  36, 107,  33,  66,  51,  71,  59,  97,
         18,  50,  67,  66,  80,  42,  33,  97,   9,  94,  28,  57,  41,  58,
         84,  59,  66,  11,  71,  60,  11,  63], device='cuda:0')
argmax start logits shape: tensor([ 15, 101,   0,   0,  75,  33,  54,  79,  19,   0,  98,  32,  12,  38,
         30,  54,  51,  72,  82,   0,   3,  15,   2,  60, 105,  80,   2,  42,
         

Evaluating:  37%|███▋      | 116/313 [00:03<00:06, 32.71it/s]

tensor([  5,  78,  63,  53,  13,  29,  13,   8,  70,  19,  74,   0,  83,  14,
         38,  53,  81,  47,  25,  42,  15,  20,  52,   2,  44,   0,  78,  79,
         29,  22,  65,  30,  77,  61,   0,  18,  42,  59,  91,  18,   0,  75,
         53, 100,   3, 101,  18,   0,  31,   0,   0,  99,  90,  54,  36,  15,
         62,  38,  16,   0, 107,   8,  51,  27], device='cuda:0')
argmax end logits shape: tensor([  6,  79,   8,   6,  81,  30, 120,  32,  55,  90,  13,  43,  83,   3,
          7,  13,   5,  35,  61,  45,  90,  47,  15,  22,  81,  28,  74, 117,
         30,  70,   6,   3,  91,  65,  12,  19,  44,  10,  87,  51,  80,  17,
         11,  51, 124, 122,  74,  31,   5,  71,  67,  51,  36,  45,  11,  34,
         66,  17,  16, 114,  18,  94,   4,  37], device='cuda:0')
argmax start logits shape: tensor([ 86,   9,  27,   6,  23,  90,  77, 138,  99,  83, 107,   0, 115,  23,
        112,  15,   0,  19,  34,  52,  36,  25,  85,  76,  36,  25,  19,  15,
         64,  27,  61,  34,  59,  22

Evaluating:  40%|███▉      | 124/313 [00:03<00:05, 32.99it/s]

argmax start logits shape: tensor([ 63,  49, 106,  62,   9, 106,   8,   2,   0,  65, 103,  52,   0,  30,
         30,  33, 111,   0,  78,  16,  82,   0,  55, 121,   4,  25,  31,   0,
         34,   7,   5, 116,   0,  10,   9,   0,   4,  41,  18,  85,  46,   0,
         92,  71, 109, 111,   0,  84,  98, 121,  24,  16,  19,  75,   0,   0,
         50,  38,  36,   0,   0,   3,   0, 112], device='cuda:0')
argmax end logits shape: tensor([ 94,  43, 105,  55, 107,  15,  50,  41,  32,  62, 104,  93,  34,  30,
         83,  87,  13,  52,  76,  42,  42,  50,  55,  47,  16,  26,  29,  19,
         80,  66,  34,  31,  53,  73,   8,   7,  15,  73,  63,  58, 124,  68,
         46,  35,  50,  39,  50,  14, 115, 116,  15,  82,  31,  74,  76,  18,
         79,  81,  11,  41,  43,  49,  93, 105], device='cuda:0')
argmax start logits shape: tensor([ 17,   0,  98,   0,  46,  30,  57,  14,   0,  14, 108, 107,  17,   7,
         92, 108,  47,   0,  59,  64,  48,  75,  41,   6,  49,   0,   3,  40,
         

Evaluating:  42%|████▏     | 132/313 [00:03<00:05, 34.07it/s]

tensor([ 43,  21,  16, 106,  60,   0, 123,  51,   0,   2,  14, 112,   0,  63,
          0,   5,  47,  22,   4,  17,  73,   0,   4,   0, 110,   4,  19,   0,
          8,  93,  78,  49, 117,  59,  79,  10,  55, 119,  36,  29,  40,   0,
         30, 121,  71,  17,  54,  52,   0,   7,   0,  93,  25,  21,  61,  73,
          0,  95,   0,   0,   0,  13, 116,  24], device='cuda:0')
argmax end logits shape: tensor([102,  19,  34,  48,  36,  52, 136, 109,  68,  64,  59,  56,  69,  58,
         12,  67,  67, 120,  20, 112,  57,  17,  45,  13,  68,  37,  54,  49,
         54,  44, 106,  51, 119,  59,  81, 104,   6,  76,  62,  15,  53,  45,
        104,   6,  79,  34,  54,  92,  96,  75,  63,  86,  26,  80,  61, 147,
         14, 113,  82, 107,  44,  15, 154,  29], device='cuda:0')
argmax start logits shape: tensor([117,  87,  41, 119,   5,   0,   9, 132,  98,   2,  33,  78,   0,  54,
         89,   0,   0,  14,  51,  46,  20,  41,  18,  11,   0,  36,  34,   0,
          0,   0,   0,  62,   0,   0

Evaluating:  43%|████▎     | 136/313 [00:04<00:05, 33.77it/s]

argmax start logits shape: tensor([ 80,  55,  27,   0, 121,   0,  79,  68,  10,  47,  39,   0,   0,  29,
         22,  28,  13,  11,   2,  14, 101, 135,  75,  44,  83,   2,   0,  53,
         13,   0,  70,  24,  12,   1,  76,  27,   4,  77,  14,  84, 107,  73,
         16,  19,   0,  89,  30,  58,  24,   3,  59,   8,  47,  38,   0, 114,
        119,   6,  13,  41,  35,  17,   0,   0], device='cuda:0')
argmax end logits shape: tensor([ 13,  16, 117,  36, 124,  44,  17,  91,  68,  67,  83,   6,  51,  75,
         94,   2,  31,  14,  71,  18,  63,  97,  46,  15,  70,   3,  72,  42,
         14,  65,  62,  28,  52,   1,  76,   5,   2,  29,   1,  21, 121,  14,
         40,   5,  60,  84,  30,  21,  24,   4,  67,  50,  57,  27,  39,  87,
         55,  32,  40,  34,  35,  23,  44,  31], device='cuda:0')
argmax start logits shape: tensor([ 33,  61,   9,  30,  59,   2,  23, 109,  14,  68,  64,  59,  54, 121,
         82,   6,   0, 100,  31,  41,  94,  66,  40,  33,   3,   0,  40,  98,
         

Evaluating:  46%|████▌     | 144/313 [00:04<00:05, 33.30it/s]

argmax start logits shape: tensor([ 99,   0,   3,  31,  75,  47, 107,  89,  17,  26,  12,  22,   2,  33,
          8,   2,  23,  36,  39,   0,  42, 117,  71,  67,   0,  49,   1,  43,
          0,  27,  66,   8,  88, 109, 110,   0,  25,  92,  39,  13,   5,  53,
         35,   0,   0,  95,   0,  40,  53,  33,  20,  29,  42,  33,  72,  83,
         55,  95,   5, 106,   0,  18,  44,  29], device='cuda:0')
argmax end logits shape: tensor([ 66,  67,  64,  31,  46,  56,  35,  13,  90,  27,  40,  61,  60, 117,
        133,   5, 109,   4,  93, 112,  32,  16,  49,  57,  10,  21,  92, 102,
        133,  90,  91,  28,  88, 109, 129,  86,  15, 116,  75,  14,  37, 116,
         19,  81,   8,   3,  41, 143,  47,  34,  22,  12,  44,  17,   1, 100,
         16,  98,   6,   4,  68,  18, 119,  64], device='cuda:0')
argmax start logits shape: tensor([ 70,  31, 112,   9,  68,  76,   3,  39,   8,  99,  49,  71,  13,  31,
         19,  42,  25,  52,  20,   8,   0, 105,   0,  24,   7,  36,   4,   0,
         

Evaluating:  49%|████▊     | 152/313 [00:04<00:04, 33.57it/s]

tensor([146,  17,  10,   4,  35,  69, 109,   0, 120,  36,  31,  10,   0, 119,
         70,  26, 134,  15,  39, 100,  49,  48,  61,  36,  36,  42,   0,  30,
         10,   0,   3,  47,   3,  16,  11,  53,   8, 107,  26, 119,  55,  97,
         26,   7,  65, 101,  99, 108,   0,  45,  76,  23,  27,  33,   0, 116,
         81, 127,  87,  43,   0, 120, 146,  15], device='cuda:0')
argmax end logits shape: tensor([156,  35,  62,   8, 112, 131, 101,  17, 120, 146,  65,  40,  24,  28,
         77,  31, 131,  15,  52, 108,  12,  72,  54,  37,  58,  34, 105,  42,
         10,  48,  27,  47,   3,   8,  45,  46,  10,  67,   3,  70,  59,  97,
         81,  23,  46,  25,  99,  78,  44, 128,  76,  45, 104,  87,   8,  77,
         84,  17,  32,  70,   3,  66,  83,  19], device='cuda:0')
argmax start logits shape: tensor([ 38,   9,  96,   5,  95,  27, 117,  28, 102,   0, 109,   0,   0,  34,
         26,  47,  25,   0,   0,  36,  50,  79,  49,  79,  16,   0,  29,   0,
          2, 112,   2,  54,  27,   3

Evaluating:  51%|█████     | 160/313 [00:04<00:04, 32.75it/s]

argmax start logits shape: tensor([120,  25,  20,  33,   2,  58,  65,   3,  26, 108,   0,  14, 112,  93,
         34,  16,  62,  16,   0,  19,   3,  93,   2,  63,  70,  43,  92,  45,
          2,  83,  42,   2,   0,   0,   0,  69,  28,  87,  36,  81,  44,  80,
        103,  78,   0, 107,  74,   0,  83,   9,  26,  80,  53,  48,  53,   2,
         74,  83,  65,  26,  10,   0,  33,   2], device='cuda:0')
argmax end logits shape: tensor([ 49,  29,  12,  48,  54,  21, 108,  78,  32,  26,  80,  95,  79,  49,
         14,  13,   3,   5,  59,  38,   6,  13,  30,  63,  20,  17,  10,  12,
         40,  43, 130,   2,  79,  23,  23,  72,  56,  78,  32,  93,  63,  15,
        103,  84,  10,  16,  13,  14,  73,  32,  16,  80,  40, 100, 102,   5,
         62,   3,  37,  15, 125,  24,   2,  68], device='cuda:0')
argmax start logits shape: tensor([ 21,  50,   5, 136,  46,  50,   3,  63,   0,  63,  47,   0,  73,  40,
         29,  16,  16,  19,  55,  67,   1,  33,   0,   9, 107,  55,  39,  38,
         

Evaluating:  52%|█████▏    | 162/313 [00:04<00:04, 33.42it/s]

argmax start logits shape: tensor([ 61,   2,  75,  37,   4,  45,   0,   5,  80,  71,  16,  54,   3,  14,
         51,  19,  78,  38, 120,   0,   4,   0,  48,  32, 117,  33,  20,   9,
         26, 127,   0,  88, 101,  18,  52,  69,  58,   0, 127,  25,  72, 100,
         38,   0,  22,  55,  33,  29,  65,  30,  95,   5,  24, 127,  13,  51,
          0,  62,  41,   0,  11,   8,   0,  43], device='cuda:0')
argmax end logits shape: tensor([ 64,  61,  37,  38,  45,  47,  69,  86,  35,  45,   8,  54,  22,  14,
         17,  29, 105,  39,  88,  55,  30,  48,  48,   4,  48,  45, 112,  82,
         51,  17,  12,  88,  21, 109, 141,  41,  58,  37,  39, 112,  77, 101,
         39,  11,  54,  78,  12,  85,  49,   6,  96,  73,   4,  58,  17,  90,
         19,  67,  32,  90,  58,   2,  67,  17], device='cuda:0')
argmax start logits shape: 




KeyboardInterrupt: 

In [40]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pad_idx=1

model = TransformerQAModel(vocab_size=10000)
model = model.to(device)

# Setup an optimizer (e.g., Adam)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

train_qa_context_model_boilerplate(
    model=model, 
    train_dataloader=train_dataloader, 
    val_dataloader=dev_dataloader, 
    optimizer=optimizer, 
    criterion=criterion, 
    num_epochs=10, 
    device=device, 
    inputs = ["context_question", "attention_mask_context_question"],
    evaluate_val_dataset=True,
)


Epoch 1/10: 100%|██████████| 625/625 [00:28<00:00, 21.89it/s, loss=9.7] 


Epoch 1 Loss: 9.6254


Evaluating: 100%|██████████| 625/625 [00:11<00:00, 54.76it/s]


Training Loss: 9.4978
Training Metrics: {'start_accuracy': 0.03336167658680538, 'start_precision': 0.0025512420609037755, 'start_recall': 0.03336167658680538, 'start_f1_score': 0.0041695202896415914, 'end_accuracy': 0.013354674135947582, 'end_precision': 0.002648366725810521, 'end_recall': 0.013354674135947582, 'end_f1_score': 0.0017788785126064577, 'joint_exact_match': 0.004951733106587306, 'span_overlap_f1': 0.04972476770789385}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 55.28it/s]


Validation Loss: 9.9155
Validation Metrics: {'start_accuracy': 0.03156312625250501, 'start_precision': 0.0015882682796786235, 'start_recall': 0.03156312625250501, 'start_f1_score': 0.00298917917021909, 'end_accuracy': 0.011022044088176353, 'end_precision': 0.0004505255004957793, 'end_recall': 0.011022044088176353, 'end_f1_score': 0.0008158863183095769, 'joint_exact_match': 0.0035070140280561123, 'span_overlap_f1': 0.04250132813072487}
--------------------------------------------------


Epoch 2/10: 100%|██████████| 625/625 [00:28<00:00, 21.75it/s, loss=10.5]


Epoch 2 Loss: 10.0968


Evaluating: 100%|██████████| 625/625 [00:11<00:00, 54.40it/s]


Training Loss: 10.4835
Training Metrics: {'start_accuracy': 0.009153203621267444, 'start_precision': 0.008294903377655191, 'start_recall': 0.009153203621267444, 'start_f1_score': 0.005903952177322749, 'end_accuracy': 0.010303606262191766, 'end_precision': 0.012917179306736604, 'end_recall': 0.010303606262191766, 'end_f1_score': 0.007263170567592487, 'joint_exact_match': 0.001700595208322913, 'span_overlap_f1': 0.01553683800681206}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 54.35it/s]


Validation Loss: 10.4834
Validation Metrics: {'start_accuracy': 0.008517034068136272, 'start_precision': 0.0033243234269571633, 'start_recall': 0.008517034068136272, 'start_f1_score': 0.003948711226259474, 'end_accuracy': 0.00751503006012024, 'end_precision': 0.008530819843778377, 'end_recall': 0.00751503006012024, 'end_f1_score': 0.005593969100494187, 'joint_exact_match': 0.001503006012024048, 'span_overlap_f1': 0.013045845978577802}
--------------------------------------------------


Epoch 3/10: 100%|██████████| 625/625 [00:29<00:00, 21.52it/s, loss=10.5]


Epoch 3 Loss: 10.4859


Evaluating: 100%|██████████| 625/625 [00:11<00:00, 56.46it/s]


Training Loss: 10.4835
Training Metrics: {'start_accuracy': 0.009153203621267444, 'start_precision': 0.009549010609766312, 'start_recall': 0.009153203621267444, 'start_f1_score': 0.006495914882263201, 'end_accuracy': 0.009603361176411744, 'end_precision': 0.007991165082784963, 'end_recall': 0.009603361176411744, 'end_f1_score': 0.0069187707667760865, 'joint_exact_match': 0.001500525183814335, 'span_overlap_f1': 0.016522912446762596}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 58.81it/s]


Validation Loss: 10.4835
Validation Metrics: {'start_accuracy': 0.009018036072144289, 'start_precision': 0.0072049700008965534, 'start_recall': 0.009018036072144289, 'start_f1_score': 0.0054157054795576755, 'end_accuracy': 0.008016032064128256, 'end_precision': 0.003922685675360828, 'end_recall': 0.008016032064128256, 'end_f1_score': 0.004407336472507647, 'joint_exact_match': 0.002004008016032064, 'span_overlap_f1': 0.01473049934691032}
--------------------------------------------------


Epoch 4/10: 100%|██████████| 625/625 [00:28<00:00, 22.03it/s, loss=10.5]


Epoch 4 Loss: 10.4859


Evaluating:  70%|███████   | 440/625 [00:08<00:03, 53.93it/s]


KeyboardInterrupt: 

In [64]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class PositionalEncoding(nn.Module):
    """
    Adds sinusoidal positional encodings to token embeddings.
    """
    def __init__(self, d_model, dropout=0.2, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class BiDAFAttention(nn.Module):
    """
    Bi-Directional Attention Flow layer from "BiDAF" (Seo et al., 2017).
    """
    def __init__(self, hidden_dim):
        super(BiDAFAttention, self).__init__()
        # similarity takes [c, q, c ∘ q]
        self.similarity_linear = nn.Linear(3 * hidden_dim, 1, bias=False)

    def forward(self, context, query):
        batch_size, c_len, _ = context.size()
        q_len = query.size(1)

        # pairwise combinations
        c_exp = context.unsqueeze(2).expand(-1, -1, q_len, -1)
        q_exp = query.unsqueeze(1).expand(-1, c_len, -1, -1)
        cq_mul = c_exp * q_exp
        sim_in = torch.cat([c_exp, q_exp, cq_mul], dim=3)

        # similarity matrix
        S = self.similarity_linear(sim_in).squeeze(3)  # (B, c_len, q_len)

        # Context-to-Query
        a = F.softmax(S, dim=2)                       # (B, c_len, q_len)
        c2q = torch.bmm(a, query)                     # (B, c_len, D)

        # Query-to-Context
        b = F.softmax(torch.max(S, dim=2)[0], dim=1)  # (B, c_len)
        b = b.unsqueeze(1)                            # (B, 1, c_len)
        q2c = torch.bmm(b, context)                   # (B, 1, D)
        q2c = q2c.repeat(1, c_len, 1)                  # (B, c_len, D)

        # combine
        G = torch.cat([context, c2q, context * c2q, context * q2c], dim=2)
        return G


class BiDAFOutput(nn.Module):
    """
    Output layer to compute start and end logits for BiDAF/G
    """
    def __init__(self, hidden_size, dropout=0.2):
        super(BiDAFOutput, self).__init__()
        # G dim = 8*H, M dim = 2*H -> concat dim = 10*H
        self.p1_weight = nn.Linear(10 * hidden_size, 1)
        self.p2_weight = nn.Linear(10 * hidden_size, 1)
        self.modelling_LSTM_end = nn.LSTM(
            2 * hidden_size, hidden_size, bidirectional=True, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, G, M):
        # G: (B, c_len, 8H), M: (B, c_len, 2H)
        GM = torch.cat([G, M], dim=2)  # (B, c_len, 10H)
        start_logits = self.p1_weight(self.dropout(GM)).squeeze(2)

        M2, _ = self.modelling_LSTM_end(M)
        GM2 = torch.cat([G, M2], dim=2)
        end_logits = self.p2_weight(self.dropout(GM2)).squeeze(2)
        return start_logits, end_logits


class BiDAFTransformer(nn.Module):
    """
    BiDAF with Transformers replacing the BiLSTMs.
    """
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        nhead: int = 8,
        num_layers: int = 2,
        dim_feedforward: int = 2048,
        dropout: float = 0.2,
        pretrained_embeddings: torch.Tensor = None
    ):
        super().__init__()
        self.hidden_dim = 2 * hidden_size  # match BiLSTM output

        # embeddings + projection
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
        self.embed_proj = nn.Linear(embed_dim, self.hidden_dim)
        self.pos_enc = PositionalEncoding(self.hidden_dim, dropout)

        # transformer encoder for context & question
        enc_layer = TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.context_encoder = TransformerEncoder(enc_layer, num_layers=num_layers)

        # attention flow
        self.att_flow = BiDAFAttention(self.hidden_dim)

        # modeling transformer
        self.modeling_proj = nn.Linear(4 * self.hidden_dim, self.hidden_dim)
        mod_layer = TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.modeling_encoder = TransformerEncoder(mod_layer, num_layers=num_layers)

        # output layer
        self.output_layer = BiDAFOutput(hidden_size, dropout)

    def forward(self, context, question, attention_mask_context, attention_mask_question):
        ctx_mask = attention_mask_context == 0
        qry_mask = attention_mask_question == 0
        

        # embed + project + pos encode
        c = self.embedding(context)
        c = self.embed_proj(c)
        c = self.pos_enc(c)
        q = self.embedding(question)
        q = self.embed_proj(q)
        q = self.pos_enc(q)

        # transformer encoders (seq_len, batch, dim)
        c_enc = self.context_encoder(
            c.transpose(0,1), src_key_padding_mask=ctx_mask
        ).transpose(0,1)
        q_enc = self.context_encoder(
            q.transpose(0,1), src_key_padding_mask=qry_mask
        ).transpose(0,1)

        # attention flow
        G = self.att_flow(c_enc, q_enc)

        # modeling
        M_in = self.modeling_proj(G)
        M = self.modeling_encoder(
            M_in.transpose(0,1), src_key_padding_mask=ctx_mask
        ).transpose(0,1)

        # output heads
        start_logits, end_logits = self.output_layer(G, M)
        return start_logits, end_logits


In [65]:
import torch

# 1) Set up model hyperparameters
vocab_size    = 10_000    # e.g. BERT‐style vocab
embed_dim     = 300      # your embedding size
hidden_size   = 128      # LSTM “half”–size; transformer d_model = 2*hidden_size

# 2) Instantiate the model
model = BiDAFTransformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_size=hidden_size,
    nhead=8,
    num_layers=2,
    dim_feedforward=512,
    dropout=0.2,
    pretrained_embeddings=None  # or your Tensor(shape=(vocab_size, embed_dim))
)

model = model.to(device)

# Setup an optimizer (e.g., Adam)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

train_qa_context_model_boilerplate(
    model=model, 
    train_dataloader=train_dataloader, 
    val_dataloader=dev_dataloader, 
    optimizer=optimizer, 
    criterion=criterion, 
    num_epochs=10, 
    device=device, 
    inputs = ["context", "question", "attention_mask_context", "attention_mask_question"],
    evaluate_val_dataset=True,
)

Epoch 1/10: 100%|██████████| 313/313 [00:25<00:00, 12.32it/s, loss=8.61]


Epoch 1 Loss: 8.4773


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 30.60it/s]


Training Loss: 8.4181
Training Metrics: {'start_accuracy': 0.05016755864552593, 'start_precision': 0.050738859724243336, 'start_recall': 0.05016755864552593, 'start_f1_score': 0.04446677830471992, 'end_accuracy': 0.06667333566748362, 'end_precision': 0.0620882369875789, 'end_recall': 0.06667333566748362, 'end_f1_score': 0.061667272291193446, 'joint_exact_match': 0.004251488020807282, 'span_overlap_f1': 0.058614727536063435}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 58.21it/s]


Validation Loss: 9.0916
Validation Metrics: {'start_accuracy': 0.0405811623246493, 'start_precision': 0.03859321502623705, 'start_recall': 0.0405811623246493, 'start_f1_score': 0.033891583025935534, 'end_accuracy': 0.0531062124248497, 'end_precision': 0.04359191951361048, 'end_recall': 0.0531062124248497, 'end_f1_score': 0.04563846314937105, 'joint_exact_match': 0.001503006012024048, 'span_overlap_f1': 0.05381935205999294}
--------------------------------------------------


Epoch 2/10: 100%|██████████| 313/313 [00:24<00:00, 12.66it/s, loss=9.32]


Epoch 2 Loss: 8.4730


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 30.70it/s]


Training Loss: 8.7012
Training Metrics: {'start_accuracy': 0.04016405742009704, 'start_precision': 0.04308403160567833, 'start_recall': 0.04016405742009704, 'start_f1_score': 0.03959680953288123, 'end_accuracy': 0.06132146251187916, 'end_precision': 0.05539349141488904, 'end_recall': 0.06132146251187916, 'end_f1_score': 0.05100266305111465, 'joint_exact_match': 0.002600910318611514, 'span_overlap_f1': 0.03006356580619024}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 59.75it/s]


Validation Loss: 9.1678
Validation Metrics: {'start_accuracy': 0.02655310621242485, 'start_precision': 0.0258765825754199, 'start_recall': 0.02655310621242485, 'start_f1_score': 0.021957672115747447, 'end_accuracy': 0.056112224448897796, 'end_precision': 0.04493284545076262, 'end_recall': 0.056112224448897796, 'end_f1_score': 0.04432866791493448, 'joint_exact_match': 0.002004008016032064, 'span_overlap_f1': 0.024396955497541922}
--------------------------------------------------


Epoch 3/10: 100%|██████████| 313/313 [00:24<00:00, 13.04it/s, loss=8.54]


Epoch 3 Loss: 9.0317


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 31.13it/s]


Training Loss: 8.5682
Training Metrics: {'start_accuracy': 0.03716300705246836, 'start_precision': 0.04416158088899367, 'start_recall': 0.03716300705246836, 'start_f1_score': 0.036095490330775803, 'end_accuracy': 0.046516280698244386, 'end_precision': 0.043199644426854894, 'end_recall': 0.046516280698244386, 'end_f1_score': 0.04120339376190402, 'joint_exact_match': 0.009953483719301756, 'span_overlap_f1': 0.03807694992580062}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 60.60it/s]


Validation Loss: 9.1987
Validation Metrics: {'start_accuracy': 0.021042084168336674, 'start_precision': 0.025745546491983587, 'start_recall': 0.021042084168336674, 'start_f1_score': 0.020309516326853187, 'end_accuracy': 0.028557114228456915, 'end_precision': 0.025524177047627564, 'end_recall': 0.028557114228456915, 'end_f1_score': 0.023827284140570035, 'joint_exact_match': 0.004008016032064128, 'span_overlap_f1': 0.023603550993802604}
--------------------------------------------------


Epoch 4/10: 100%|██████████| 313/313 [00:24<00:00, 12.77it/s, loss=8.36]


Epoch 4 Loss: 8.6291


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 29.57it/s]


Training Loss: 8.3513
Training Metrics: {'start_accuracy': 0.03441204421547542, 'start_precision': 0.032196286546729506, 'start_recall': 0.03441204421547542, 'start_f1_score': 0.032379723499549415, 'end_accuracy': 0.0518181363477217, 'end_precision': 0.05020443743496569, 'end_recall': 0.0518181363477217, 'end_f1_score': 0.04593099062455904, 'joint_exact_match': 0.010153553743810334, 'span_overlap_f1': 0.040095807473011565}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 57.75it/s]


Validation Loss: 9.1831
Validation Metrics: {'start_accuracy': 0.02404809619238477, 'start_precision': 0.021799416237402978, 'start_recall': 0.02404809619238477, 'start_f1_score': 0.021039873658048613, 'end_accuracy': 0.03356713426853707, 'end_precision': 0.02881324681506143, 'end_recall': 0.03356713426853707, 'end_f1_score': 0.027694368208032633, 'joint_exact_match': 0.006513026052104208, 'span_overlap_f1': 0.027177412114514497}
--------------------------------------------------


Epoch 5/10: 100%|██████████| 313/313 [00:24<00:00, 12.83it/s, loss=8.69]


Epoch 5 Loss: 8.4115


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 30.24it/s]


Training Loss: 8.2150
Training Metrics: {'start_accuracy': 0.02835992597409093, 'start_precision': 0.02683166109747708, 'start_recall': 0.02835992597409093, 'start_f1_score': 0.025229782806959238, 'end_accuracy': 0.05301855649477317, 'end_precision': 0.051697642217742534, 'end_recall': 0.05301855649477317, 'end_f1_score': 0.045166217529387494, 'joint_exact_match': 0.007152503376181663, 'span_overlap_f1': 0.04265346792091874}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 58.11it/s]


Validation Loss: 9.0689
Validation Metrics: {'start_accuracy': 0.022044088176352707, 'start_precision': 0.023319399169909327, 'start_recall': 0.022044088176352707, 'start_f1_score': 0.01808004063067341, 'end_accuracy': 0.04158316633266533, 'end_precision': 0.030268590400537303, 'end_recall': 0.04158316633266533, 'end_f1_score': 0.031179349052933063, 'joint_exact_match': 0.005511022044088177, 'span_overlap_f1': 0.036085301203402864}
--------------------------------------------------


Epoch 6/10: 100%|██████████| 313/313 [00:24<00:00, 12.87it/s, loss=8.21]


Epoch 6 Loss: 8.3004


Evaluating: 100%|██████████| 313/313 [00:10<00:00, 30.86it/s]


Training Loss: 8.1297
Training Metrics: {'start_accuracy': 0.03651277947281548, 'start_precision': 0.0365086182023549, 'start_recall': 0.03651277947281548, 'start_f1_score': 0.03547276467510793, 'end_accuracy': 0.06022107737708198, 'end_precision': 0.05753267268408212, 'end_recall': 0.06022107737708198, 'end_f1_score': 0.05566619516405114, 'joint_exact_match': 0.014605111789126194, 'span_overlap_f1': 0.05149802027700999}


Evaluating: 100%|██████████| 63/63 [00:01<00:00, 60.03it/s]


Validation Loss: 9.0228
Validation Metrics: {'start_accuracy': 0.022545090180360723, 'start_precision': 0.022433151114667454, 'start_recall': 0.022545090180360723, 'start_f1_score': 0.02056413705847162, 'end_accuracy': 0.036072144288577156, 'end_precision': 0.03463007143764952, 'end_recall': 0.036072144288577156, 'end_f1_score': 0.03156325627422017, 'joint_exact_match': 0.009519038076152305, 'span_overlap_f1': 0.03403111395304026}
--------------------------------------------------


Epoch 7/10: 100%|██████████| 313/313 [00:24<00:00, 12.94it/s, loss=8.1] 


Epoch 7 Loss: 8.2732


Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
model_path = "models/qa_context_model_trans_cross_att.pkl"

In [33]:
save_model(model, model_path)

Model saved to models/qa_context_model_dqra.pkl


In [34]:
model = load_model(model_path=model_path)

Model loaded from models/qa_context_model_dqra.pkl


In [41]:
# evaluate the model on the dev set
criterion = nn.CrossEntropyLoss(ignore_index=1)
evaluate_qa_context_model_boilerplate(model=model, dataloader=train_dataloader, criterion=criterion, device='cuda', inputs = ["context", "question", "attention_mask_question"])

Evaluating:   0%|          | 0/625 [00:00<?, ?it/s]


TypeError: TransformerQAModel.forward() got an unexpected keyword argument 'context'

In [22]:
# evaluate the model on the dev set
criterion = nn.CrossEntropyLoss(ignore_index=1)
evaluate_qa_context_model_boilerplate(model=model, dataloader=dev_dataloader, criterion=criterion, device='cuda', inputs = ["context", "question", "attention_mask_question"])

Evaluating: 100%|██████████| 63/63 [00:01<00:00, 50.07it/s]

Validation Loss: 10.3172
Validation Metrics: {'start_accuracy': 0.07565130260521043, 'start_precision': 0.09645009422754619, 'start_recall': 0.07565130260521043, 'start_f1_score': 0.07168606723489031, 'end_accuracy': 0.09018036072144289, 'end_precision': 0.13011087755691064, 'end_recall': 0.09018036072144289, 'end_f1_score': 0.08554795701482384, 'joint_exact_match': 0.02404809619238477, 'span_overlap_f1': 0.07026853459209209}





(10.317197020091708,
 {'start_accuracy': 0.07565130260521043,
  'start_precision': 0.09645009422754619,
  'start_recall': 0.07565130260521043,
  'start_f1_score': 0.07168606723489031,
  'end_accuracy': 0.09018036072144289,
  'end_precision': 0.13011087755691064,
  'end_recall': 0.09018036072144289,
  'end_f1_score': 0.08554795701482384,
  'joint_exact_match': 0.02404809619238477,
  'span_overlap_f1': 0.07026853459209209})

In [42]:
preds, true_labels = predict_qa_context_model_boilerplate(model=model, dataloader=dev_dataloader, tokenizer=tokenizer, device='cuda', inputs = ["context_question", "attention_mask_context_question"])

Predicting: 100%|██████████| 63/63 [00:01<00:00, 45.76it/s]


In [43]:
# Compare the predictions with the actual answers
for i in range(100):
    print(f"Predicted Answer: {preds[i]}")
    print(f"True Answer: {true_labels[i]}")
    print("-" * 50)

Predicted Answer: 
True Answer: rugby
--------------------------------------------------
Predicted Answer: ,
True Answer: an official school sport
--------------------------------------------------
Predicted Answer: 
True Answer: high school
--------------------------------------------------
Predicted Answer: the
True Answer: framework
--------------------------------------------------
Predicted Answer: 
True Answer: complexity classes
--------------------------------------------------
Predicted Answer: ,
True Answer: complicated definitions
--------------------------------------------------
Predicted Answer: and
True Answer: palm springs
--------------------------------------------------
Predicted Answer: for
True Answer: southern
--------------------------------------------------
Predicted Answer: be
True Answer: open spaces
--------------------------------------------------
Predicted Answer: its
True Answer: beaches
--------------------------------------------------
Predicted Answer

In [25]:
train_preds, train_true_labels = predict_qa_context_model_boilerplate(model=model, dataloader=train_dataloader, tokenizer=tokenizer, device='cuda', inputs = ["context", "question", "attention_mask_question"])

Predicting: 100%|██████████| 625/625 [00:14<00:00, 44.49it/s]


In [26]:
# Compare the predictions with the actual answers
for i in range(100):
    print(f"Predicted Answer: {train_preds[i]}")
    print(f"True Answer: {train_true_labels[i]}")
    print("-" * 50)

Predicted Answer: portuguese
True Answer: coimbra
--------------------------------------------------
Predicted Answer: fatburger
True Answer: fatburger
--------------------------------------------------
Predicted Answer: 
True Answer: rigid descent rule
--------------------------------------------------
Predicted Answer: upper case ( also capital letters , capitals , caps , large letters , or more formally majuscule , see terminology ) and smaller lower case ( also small letters , or more formally minuscule , see terminology ) in the written representation of certain languages . here is a comparison of the upper and lower case versions of each letter included in the english alphabet
True Answer: larger upper
--------------------------------------------------
Predicted Answer: 
True Answer: statistical update
--------------------------------------------------
Predicted Answer: 118 % from 2012 – 13 when coverage still aired on fox soccer and espn / espn2 ( 220 , 000 viewers ), and nbc sp