In [1]:
import os
import torch
from torch import nn
import math
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, DataCollatorWithPadding, TrainingArguments
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple

# Enable CUDA debugging
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_files = {
    "train": "train-open.json",
    "validation": "val-open.json",
    "test": "test-open.json"
}
dataset = load_dataset("json", data_files=data_files)

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
tokenizer.pad_token = tokenizer.eos_token

# Preprocess Dataset
def preprocess_function(examples):
    inputs = tokenizer(
        examples["question"], max_length=128, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        examples["answer"], max_length=128, truncation=True, padding="max_length"
    )

    inputs["labels"] = labels["input_ids"]
    return inputs

# Tokenize Dataset
tokenized_datasets = dataset.map(preprocess_function, batched=True)

Map: 100%|██████████| 12715/12715 [00:10<00:00, 1240.00 examples/s]


In [3]:
# Define DifferentialAttention Class
class DifferentialAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.lambda_param = nn.Parameter(torch.tensor(0.8))

    def forward(self, x):
        qkv = self.qkv_proj(x).chunk(3, dim=-1)
        q, k, v = qkv
        attention_scores = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1)), dim=-1)
        return attention_scores @ v

# Define Custom Output Class
@dataclass
class CustomCausalLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

# Define Custom Model Class
class CustomModel(nn.Module):
    def __init__(self, base_model, d_model, num_heads):
        super().__init__()
        self.base_model = base_model
        self.differential_attention = DifferentialAttention(d_model, num_heads)
        self.vocab_projection = nn.Linear(d_model, base_model.config.vocab_size)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.base_model(input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # Last hidden state
        attention_output = self.differential_attention(hidden_states)
        attention_output = self.vocab_projection(attention_output)

        if labels is not None:
            print(f"attention_output shape: {attention_output.shape}")
            print(f"labels shape: {labels.shape}")

            shift_logits = attention_output[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            print(f"shift_logits shape: {shift_logits.shape}")
            print(f"shift_labels shape: {shift_labels.shape}")

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return CustomCausalLMOutput(loss=loss, logits=attention_output)

        return CustomCausalLMOutput(logits=attention_output)


In [5]:
# Load Pre-trained Model
base_model = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")
base_model.config.output_hidden_states = True  # Enable hidden states

# Define Custom Model
d_model = base_model.config.hidden_size
num_heads = base_model.config.num_attention_heads
model = CustomModel(base_model, d_model, num_heads)

# Define Training Arguments
training_args = TrainingArguments(
    output_dir="./QAMODELLLLLL",  # Save directory
    eval_strategy="epoch",           # Evaluate at the end of each epoch
    learning_rate=3e-5,              # Adjust learning rate
    per_device_train_batch_size=2,   # Reduce batch size
    num_train_epochs=3,              # Number of epochs
    weight_decay=0.01,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",            # Directory for logs
    logging_steps=10,
    warmup_steps=100,                # Warmup steps
    gradient_accumulation_steps=4,   # Gradient accumulation
    fp16=False,                      # Disable mixed precision
    no_cuda=False,                   # Enable GPU
    run_name="ffffffinetuning" # Experiment name
)

In [6]:
# Define Data Collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"].select(range(1000)),  # Subset for training
    eval_dataset=tokenized_datasets["validation"].select(range(200)),  # Subset for evaluation
    data_collator=data_collator
)

# Train the Model
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmohammedberrhazi003[0m ([33mmohammedberrhazi003-universit-internationale-de-rabat[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/375 [00:00<?, ?it/s]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  0%|          | 1/375 [00:25<2:37:31, 25.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  1%|          | 2/375 [00:59<3:09:35, 30.50s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  1%|          | 3/375 [01:30<3:09:49, 30.62s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  1%|          | 4/375 [02:01<3:11:32, 30.98s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  1%|▏         | 5/375 [02:33<3:12:05, 31.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  2%|▏         | 6/375 [03:04<3:12:47, 31.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  2%|▏         | 7/375 [03:38<3:16:23, 32.02s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  2%|▏         | 8/375 [04:06<3:08:26, 30.81s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  2%|▏         | 9/375 [04:33<3:01:28, 29.75s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  3%|▎         | 10/375 [05:05<3:03:40, 30.19s/it]

{'loss': 44.2766, 'grad_norm': 66.46235656738281, 'learning_rate': 3e-06, 'epoch': 0.08}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  3%|▎         | 11/375 [05:45<3:21:28, 33.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  3%|▎         | 12/375 [06:17<3:18:28, 32.81s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  3%|▎         | 13/375 [06:48<3:15:41, 32.43s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  4%|▎         | 14/375 [07:20<3:13:15, 32.12s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  4%|▍         | 15/375 [07:51<3:11:49, 31.97s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  4%|▍         | 16/375 [08:23<3:10:48, 31.89s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  5%|▍         | 17/375 [08:54<3:09:43, 31.80s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  5%|▍         | 18/375 [09:26<3:07:58, 31.59s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  5%|▌         | 19/375 [09:57<3:07:34, 31.62s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  5%|▌         | 20/375 [10:28<3:06:00, 31.44s/it]

{'loss': 42.0561, 'grad_norm': 117.70447540283203, 'learning_rate': 6e-06, 'epoch': 0.16}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  6%|▌         | 21/375 [11:00<3:06:09, 31.55s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  6%|▌         | 22/375 [11:32<3:05:42, 31.57s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  6%|▌         | 23/375 [12:03<3:05:05, 31.55s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  6%|▋         | 24/375 [12:35<3:05:08, 31.65s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  7%|▋         | 25/375 [13:07<3:05:00, 31.72s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  7%|▋         | 26/375 [13:39<3:04:28, 31.71s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  7%|▋         | 27/375 [14:10<3:03:36, 31.66s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  7%|▋         | 28/375 [14:42<3:02:51, 31.62s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  8%|▊         | 29/375 [15:13<3:02:06, 31.58s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  8%|▊         | 30/375 [15:44<3:01:05, 31.49s/it]

{'loss': 35.6621, 'grad_norm': 184.2217254638672, 'learning_rate': 9e-06, 'epoch': 0.24}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  8%|▊         | 31/375 [16:16<3:01:13, 31.61s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  9%|▊         | 32/375 [16:48<3:00:24, 31.56s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  9%|▉         | 33/375 [17:20<3:00:47, 31.72s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  9%|▉         | 34/375 [17:52<3:00:09, 31.70s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


  9%|▉         | 35/375 [18:23<2:59:23, 31.66s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 10%|▉         | 36/375 [18:55<2:59:46, 31.82s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 10%|▉         | 37/375 [19:27<2:59:36, 31.88s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 10%|█         | 38/375 [19:59<2:59:03, 31.88s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 10%|█         | 39/375 [20:31<2:58:41, 31.91s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 11%|█         | 40/375 [21:03<2:58:08, 31.90s/it]

{'loss': 24.4391, 'grad_norm': 246.3253631591797, 'learning_rate': 1.2e-05, 'epoch': 0.32}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 11%|█         | 41/375 [21:34<2:55:47, 31.58s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 11%|█         | 42/375 [22:05<2:55:02, 31.54s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 11%|█▏        | 43/375 [22:37<2:54:05, 31.46s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 12%|█▏        | 44/375 [23:08<2:53:37, 31.47s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 12%|█▏        | 45/375 [23:40<2:53:26, 31.54s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 12%|█▏        | 46/375 [24:12<2:53:27, 31.63s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 13%|█▎        | 47/375 [24:43<2:53:09, 31.68s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 13%|█▎        | 48/375 [25:17<2:56:28, 32.38s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 13%|█▎        | 49/375 [25:49<2:54:16, 32.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 13%|█▎        | 50/375 [26:21<2:53:26, 32.02s/it]

{'loss': 7.0266, 'grad_norm': 17.277692794799805, 'learning_rate': 1.5e-05, 'epoch': 0.4}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 14%|█▎        | 51/375 [26:52<2:52:21, 31.92s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 14%|█▍        | 52/375 [27:24<2:51:19, 31.82s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 14%|█▍        | 53/375 [27:56<2:50:57, 31.85s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 14%|█▍        | 54/375 [28:28<2:50:18, 31.83s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 15%|█▍        | 55/375 [28:59<2:49:32, 31.79s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 15%|█▍        | 56/375 [29:31<2:48:39, 31.72s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 15%|█▌        | 57/375 [30:03<2:48:29, 31.79s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 15%|█▌        | 58/375 [30:35<2:48:26, 31.88s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 16%|█▌        | 59/375 [31:07<2:48:24, 31.98s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 16%|█▌        | 60/375 [31:39<2:47:50, 31.97s/it]

{'loss': 5.1298, 'grad_norm': 85.18212890625, 'learning_rate': 1.8e-05, 'epoch': 0.48}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 16%|█▋        | 61/375 [32:10<2:45:39, 31.65s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 17%|█▋        | 62/375 [32:42<2:45:35, 31.74s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 17%|█▋        | 63/375 [33:13<2:44:32, 31.64s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 17%|█▋        | 64/375 [33:45<2:43:33, 31.55s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 17%|█▋        | 65/375 [34:16<2:43:05, 31.57s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 18%|█▊        | 66/375 [34:48<2:42:46, 31.61s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 18%|█▊        | 67/375 [35:20<2:42:07, 31.58s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 18%|█▊        | 68/375 [35:51<2:41:50, 31.63s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 18%|█▊        | 69/375 [36:23<2:41:45, 31.72s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 19%|█▊        | 70/375 [36:55<2:41:23, 31.75s/it]

{'loss': 5.192, 'grad_norm': 24.943004608154297, 'learning_rate': 2.1e-05, 'epoch': 0.56}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 19%|█▉        | 71/375 [37:27<2:41:48, 31.94s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 19%|█▉        | 72/375 [37:59<2:41:08, 31.91s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 19%|█▉        | 73/375 [38:31<2:40:45, 31.94s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 20%|█▉        | 74/375 [39:03<2:40:17, 31.95s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 20%|██        | 75/375 [39:35<2:39:44, 31.95s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 20%|██        | 76/375 [40:07<2:39:24, 31.99s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 21%|██        | 77/375 [40:39<2:38:51, 31.98s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 21%|██        | 78/375 [41:11<2:38:35, 32.04s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 21%|██        | 79/375 [41:44<2:38:35, 32.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 21%|██▏       | 80/375 [42:16<2:37:25, 32.02s/it]

{'loss': 4.2397, 'grad_norm': 91.32967376708984, 'learning_rate': 2.4e-05, 'epoch': 0.64}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 22%|██▏       | 81/375 [42:48<2:36:48, 32.00s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 22%|██▏       | 82/375 [43:19<2:35:54, 31.93s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 22%|██▏       | 83/375 [43:51<2:35:34, 31.97s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 22%|██▏       | 84/375 [44:23<2:35:05, 31.98s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 23%|██▎       | 85/375 [44:55<2:34:14, 31.91s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 23%|██▎       | 86/375 [45:27<2:33:57, 31.96s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 23%|██▎       | 87/375 [45:59<2:33:18, 31.94s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 23%|██▎       | 88/375 [46:31<2:32:56, 31.97s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 24%|██▎       | 89/375 [47:03<2:32:12, 31.93s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 24%|██▍       | 90/375 [47:35<2:31:52, 31.97s/it]

{'loss': 5.1076, 'grad_norm': 8.728018760681152, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.72}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 24%|██▍       | 91/375 [48:07<2:31:10, 31.94s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 25%|██▍       | 92/375 [48:39<2:30:53, 31.99s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 25%|██▍       | 93/375 [49:11<2:30:24, 32.00s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 25%|██▌       | 94/375 [49:43<2:29:59, 32.03s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 25%|██▌       | 95/375 [50:30<2:50:52, 36.62s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 26%|██▌       | 96/375 [51:03<2:44:23, 35.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 26%|██▌       | 97/375 [51:35<2:39:02, 34.32s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 26%|██▌       | 98/375 [52:00<2:26:12, 31.67s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 26%|██▋       | 99/375 [52:18<2:05:48, 27.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 27%|██▋       | 100/375 [52:34<1:50:30, 24.11s/it]

{'loss': 4.6885, 'grad_norm': 10.714316368103027, 'learning_rate': 3e-05, 'epoch': 0.8}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 27%|██▋       | 101/375 [52:51<1:40:45, 22.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 27%|██▋       | 102/375 [53:08<1:32:31, 20.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 27%|██▋       | 103/375 [53:24<1:26:46, 19.14s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 28%|██▊       | 104/375 [53:40<1:22:30, 18.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 28%|██▊       | 105/375 [53:57<1:19:32, 17.68s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 28%|██▊       | 106/375 [54:13<1:17:35, 17.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 29%|██▊       | 107/375 [54:29<1:15:51, 16.98s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 29%|██▉       | 108/375 [54:45<1:14:34, 16.76s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 29%|██▉       | 109/375 [55:02<1:13:36, 16.61s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 29%|██▉       | 110/375 [55:20<1:16:01, 17.21s/it]

{'loss': 4.3794, 'grad_norm': 9.70386028289795, 'learning_rate': 2.890909090909091e-05, 'epoch': 0.88}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 30%|██▉       | 111/375 [56:09<1:57:39, 26.74s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 30%|██▉       | 112/375 [56:46<2:09:57, 29.65s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 30%|███       | 113/375 [57:11<2:03:17, 28.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 30%|███       | 114/375 [57:34<1:55:49, 26.63s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 31%|███       | 115/375 [57:56<1:49:57, 25.37s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 31%|███       | 116/375 [58:19<1:46:04, 24.57s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 31%|███       | 117/375 [58:41<1:42:53, 23.93s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 31%|███▏      | 118/375 [59:04<1:40:52, 23.55s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 32%|███▏      | 119/375 [59:26<1:39:08, 23.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 32%|███▏      | 120/375 [59:49<1:38:17, 23.13s/it]

{'loss': 4.6676, 'grad_norm': 22.408517837524414, 'learning_rate': 2.7818181818181818e-05, 'epoch': 0.96}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 32%|███▏      | 121/375 [1:00:14<1:40:39, 23.78s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 33%|███▎      | 122/375 [1:00:54<1:59:55, 28.44s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 33%|███▎      | 123/375 [1:01:31<2:10:33, 31.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 33%|███▎      | 124/375 [1:02:10<2:19:59, 33.46s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 33%|███▎      | 125/375 [1:02:51<2:28:12, 35.57s/it]

attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])


                                                     
 33%|███▎      | 125/375 [1:07:38<2:28:12, 35.57s/it]

{'eval_loss': 1.206662654876709, 'eval_runtime': 287.3032, 'eval_samples_per_second': 0.696, 'eval_steps_per_second': 0.087, 'epoch': 1.0}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 34%|███▎      | 126/375 [1:08:15<8:27:19, 122.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 34%|███▍      | 127/375 [1:08:50<6:37:38, 96.20s/it] 

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 34%|███▍      | 128/375 [1:09:23<5:17:50, 77.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 34%|███▍      | 129/375 [1:09:55<4:20:00, 63.42s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 35%|███▍      | 130/375 [1:10:26<3:39:27, 53.75s/it]

{'loss': 4.5792, 'grad_norm': 8.751739501953125, 'learning_rate': 2.6727272727272728e-05, 'epoch': 1.04}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 35%|███▍      | 131/375 [1:10:57<3:11:23, 47.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 35%|███▌      | 132/375 [1:11:35<2:59:51, 44.41s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 35%|███▌      | 133/375 [1:12:23<3:02:39, 45.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 36%|███▌      | 134/375 [1:12:57<2:48:11, 41.87s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 36%|███▌      | 135/375 [1:13:27<2:34:10, 38.54s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 36%|███▋      | 136/375 [1:13:59<2:24:41, 36.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 37%|███▋      | 137/375 [1:14:29<2:17:07, 34.57s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 37%|███▋      | 138/375 [1:15:00<2:12:32, 33.56s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 37%|███▋      | 139/375 [1:15:32<2:09:26, 32.91s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 37%|███▋      | 140/375 [1:16:03<2:07:18, 32.50s/it]

{'loss': 4.0578, 'grad_norm': 35.22407913208008, 'learning_rate': 2.5636363636363635e-05, 'epoch': 1.12}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 38%|███▊      | 141/375 [1:16:35<2:06:26, 32.42s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 38%|███▊      | 142/375 [1:17:07<2:05:22, 32.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 38%|███▊      | 143/375 [1:17:40<2:04:44, 32.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 38%|███▊      | 144/375 [1:18:11<2:03:24, 32.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 39%|███▊      | 145/375 [1:18:43<2:03:03, 32.10s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 39%|███▉      | 146/375 [1:19:16<2:02:49, 32.18s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 39%|███▉      | 147/375 [1:19:48<2:02:06, 32.13s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 39%|███▉      | 148/375 [1:20:20<2:01:28, 32.11s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 40%|███▉      | 149/375 [1:20:52<2:00:25, 31.97s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 40%|████      | 150/375 [1:21:28<2:04:41, 33.25s/it]

{'loss': 3.8899, 'grad_norm': 40.78175354003906, 'learning_rate': 2.454545454545455e-05, 'epoch': 1.2}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 40%|████      | 151/375 [1:22:00<2:02:30, 32.82s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 41%|████      | 152/375 [1:22:31<2:00:42, 32.48s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 41%|████      | 153/375 [1:23:03<1:59:39, 32.34s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 41%|████      | 154/375 [1:23:35<1:58:55, 32.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 41%|████▏     | 155/375 [1:24:07<1:57:42, 32.10s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 42%|████▏     | 156/375 [1:24:39<1:57:04, 32.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 42%|████▏     | 157/375 [1:25:13<1:58:59, 32.75s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 42%|████▏     | 158/375 [1:25:44<1:56:22, 32.18s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 42%|████▏     | 159/375 [1:26:16<1:55:27, 32.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 43%|████▎     | 160/375 [1:26:51<1:57:53, 32.90s/it]

{'loss': 3.1657, 'grad_norm': 14.020478248596191, 'learning_rate': 2.3454545454545456e-05, 'epoch': 1.28}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 43%|████▎     | 161/375 [1:27:23<1:56:13, 32.58s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 43%|████▎     | 162/375 [1:27:55<1:54:52, 32.36s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 43%|████▎     | 163/375 [1:28:27<1:54:22, 32.37s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 44%|████▎     | 164/375 [1:28:59<1:53:48, 32.36s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 44%|████▍     | 165/375 [1:29:31<1:52:54, 32.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 44%|████▍     | 166/375 [1:30:03<1:52:07, 32.19s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 45%|████▍     | 167/375 [1:30:35<1:51:17, 32.10s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 45%|████▍     | 168/375 [1:31:07<1:50:24, 32.00s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 45%|████▌     | 169/375 [1:31:39<1:49:52, 32.00s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 45%|████▌     | 170/375 [1:32:12<1:49:48, 32.14s/it]

{'loss': 3.9063, 'grad_norm': 29.318450927734375, 'learning_rate': 2.2363636363636366e-05, 'epoch': 1.36}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 46%|████▌     | 171/375 [1:32:43<1:48:23, 31.88s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 46%|████▌     | 172/375 [1:33:10<1:43:30, 30.59s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 46%|████▌     | 173/375 [1:33:27<1:28:27, 26.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 46%|████▋     | 174/375 [1:33:43<1:17:54, 23.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 47%|████▋     | 175/375 [1:33:59<1:10:26, 21.13s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 47%|████▋     | 176/375 [1:34:15<1:05:17, 19.69s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 47%|████▋     | 177/375 [1:34:32<1:01:41, 18.69s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 47%|████▋     | 178/375 [1:34:48<58:52, 17.93s/it]  

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 48%|████▊     | 179/375 [1:35:04<57:00, 17.45s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 48%|████▊     | 180/375 [1:35:20<55:25, 17.05s/it]

{'loss': 3.4499, 'grad_norm': 8.450716018676758, 'learning_rate': 2.1272727272727273e-05, 'epoch': 1.44}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 48%|████▊     | 181/375 [1:35:37<54:19, 16.80s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 49%|████▊     | 182/375 [1:35:53<53:17, 16.57s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 49%|████▉     | 183/375 [1:36:09<52:37, 16.44s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 49%|████▉     | 184/375 [1:36:25<51:59, 16.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 49%|████▉     | 185/375 [1:36:41<51:25, 16.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 50%|████▉     | 186/375 [1:36:57<50:59, 16.19s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 50%|████▉     | 187/375 [1:37:13<50:34, 16.14s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 50%|█████     | 188/375 [1:37:29<50:08, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 50%|█████     | 189/375 [1:37:45<49:50, 16.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 51%|█████     | 190/375 [1:38:01<49:32, 16.07s/it]

{'loss': 3.4219, 'grad_norm': 22.729333877563477, 'learning_rate': 2.0181818181818183e-05, 'epoch': 1.52}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 51%|█████     | 191/375 [1:38:17<49:18, 16.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 51%|█████     | 192/375 [1:38:33<49:01, 16.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 51%|█████▏    | 193/375 [1:38:49<48:48, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 52%|█████▏    | 194/375 [1:39:05<48:27, 16.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 52%|█████▏    | 195/375 [1:39:21<48:10, 16.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 52%|█████▏    | 196/375 [1:39:37<47:55, 16.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 53%|█████▎    | 197/375 [1:39:53<47:36, 16.05s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 53%|█████▎    | 198/375 [1:40:09<47:20, 16.05s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 53%|█████▎    | 199/375 [1:40:25<47:02, 16.04s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 53%|█████▎    | 200/375 [1:40:42<46:57, 16.10s/it]

{'loss': 4.5336, 'grad_norm': 8.08409309387207, 'learning_rate': 1.909090909090909e-05, 'epoch': 1.6}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 54%|█████▎    | 201/375 [1:40:58<46:39, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 54%|█████▍    | 202/375 [1:41:14<46:20, 16.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 54%|█████▍    | 203/375 [1:41:30<46:08, 16.10s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 54%|█████▍    | 204/375 [1:41:46<45:51, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 55%|█████▍    | 205/375 [1:42:02<45:36, 16.10s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 55%|█████▍    | 206/375 [1:42:18<45:18, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 55%|█████▌    | 207/375 [1:42:34<45:00, 16.07s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 55%|█████▌    | 208/375 [1:42:50<44:45, 16.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 56%|█████▌    | 209/375 [1:43:06<44:29, 16.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 56%|█████▌    | 210/375 [1:43:22<44:09, 16.06s/it]

{'loss': 3.3206, 'grad_norm': 13.57955265045166, 'learning_rate': 1.8e-05, 'epoch': 1.68}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 56%|█████▋    | 211/375 [1:43:38<43:50, 16.04s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 57%|█████▋    | 212/375 [1:43:54<43:34, 16.04s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 57%|█████▋    | 213/375 [1:44:10<43:16, 16.03s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 57%|█████▋    | 214/375 [1:44:26<42:58, 16.02s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 57%|█████▋    | 215/375 [1:44:42<42:42, 16.02s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 58%|█████▊    | 216/375 [1:44:58<42:27, 16.02s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 58%|█████▊    | 217/375 [1:45:15<42:16, 16.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 58%|█████▊    | 218/375 [1:45:31<41:59, 16.05s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 58%|█████▊    | 219/375 [1:45:47<41:41, 16.03s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 59%|█████▊    | 220/375 [1:46:03<41:23, 16.02s/it]

{'loss': 4.3297, 'grad_norm': 12.540483474731445, 'learning_rate': 1.6909090909090907e-05, 'epoch': 1.76}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 59%|█████▉    | 221/375 [1:46:19<41:17, 16.09s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 59%|█████▉    | 222/375 [1:46:35<41:11, 16.16s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 59%|█████▉    | 223/375 [1:46:51<40:57, 16.17s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 60%|█████▉    | 224/375 [1:47:08<40:53, 16.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 60%|██████    | 225/375 [1:47:26<41:54, 16.76s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 60%|██████    | 226/375 [1:47:42<41:11, 16.58s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 61%|██████    | 227/375 [1:47:58<40:32, 16.43s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 61%|██████    | 228/375 [1:48:30<52:00, 21.23s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 61%|██████    | 229/375 [1:48:55<54:08, 22.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 61%|██████▏   | 230/375 [1:49:20<55:22, 22.92s/it]

{'loss': 3.401, 'grad_norm': 25.823230743408203, 'learning_rate': 1.5818181818181818e-05, 'epoch': 1.84}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 62%|██████▏   | 231/375 [1:49:44<56:07, 23.39s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 62%|██████▏   | 232/375 [1:50:08<56:22, 23.66s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 62%|██████▏   | 233/375 [1:50:33<56:28, 23.86s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 62%|██████▏   | 234/375 [1:50:57<56:16, 23.95s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 63%|██████▎   | 235/375 [1:51:21<56:07, 24.06s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 63%|██████▎   | 236/375 [1:51:46<55:57, 24.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 63%|██████▎   | 237/375 [1:52:10<55:40, 24.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 63%|██████▎   | 238/375 [1:52:34<55:22, 24.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 64%|██████▎   | 239/375 [1:52:58<54:46, 24.16s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 64%|██████▍   | 240/375 [1:53:23<54:28, 24.21s/it]

{'loss': 4.7401, 'grad_norm': 19.70237922668457, 'learning_rate': 1.4727272727272728e-05, 'epoch': 1.92}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 64%|██████▍   | 241/375 [1:53:47<53:57, 24.16s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 65%|██████▍   | 242/375 [1:54:11<53:44, 24.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 65%|██████▍   | 243/375 [1:54:35<53:19, 24.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 65%|██████▌   | 244/375 [1:55:00<53:00, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 65%|██████▌   | 245/375 [1:55:24<52:35, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 66%|██████▌   | 246/375 [1:55:48<52:04, 24.22s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 66%|██████▌   | 247/375 [1:56:12<51:36, 24.19s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 66%|██████▌   | 248/375 [1:56:37<51:25, 24.30s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 66%|██████▋   | 249/375 [1:57:01<50:59, 24.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 67%|██████▋   | 250/375 [1:57:25<50:32, 24.26s/it]

{'loss': 3.5078, 'grad_norm': 12.72775936126709, 'learning_rate': 1.3636363636363637e-05, 'epoch': 2.0}
attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])


                                                   
 67%|██████▋   | 250/375 [2:00:46<50:32, 24.26s/it]

{'eval_loss': 1.168724775314331, 'eval_runtime': 200.5982, 'eval_samples_per_second': 0.997, 'eval_steps_per_second': 0.125, 'epoch': 2.0}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 67%|██████▋   | 251/375 [2:01:10<2:54:27, 84.41s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 67%|██████▋   | 252/375 [2:01:34<2:16:12, 66.45s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 67%|██████▋   | 253/375 [2:01:59<1:49:31, 53.86s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 68%|██████▊   | 254/375 [2:02:23<1:30:48, 45.03s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 68%|██████▊   | 255/375 [2:02:47<1:17:33, 38.78s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 68%|██████▊   | 256/375 [2:03:12<1:08:18, 34.44s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 69%|██████▊   | 257/375 [2:03:36<1:01:39, 31.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 69%|██████▉   | 258/375 [2:04:01<57:11, 29.33s/it]  

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 69%|██████▉   | 259/375 [2:04:25<53:45, 27.81s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 69%|██████▉   | 260/375 [2:04:49<51:14, 26.73s/it]

{'loss': 3.6071, 'grad_norm': 27.04330062866211, 'learning_rate': 1.2545454545454545e-05, 'epoch': 2.08}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 70%|██████▉   | 261/375 [2:05:13<49:25, 26.01s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 70%|██████▉   | 262/375 [2:05:37<47:56, 25.45s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 70%|███████   | 263/375 [2:06:01<46:40, 25.00s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 70%|███████   | 264/375 [2:06:25<45:43, 24.72s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 71%|███████   | 265/375 [2:06:50<44:58, 24.53s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 71%|███████   | 266/375 [2:07:14<44:23, 24.43s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 71%|███████   | 267/375 [2:07:38<43:49, 24.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 71%|███████▏  | 268/375 [2:08:03<43:44, 24.53s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 72%|███████▏  | 269/375 [2:08:27<43:13, 24.47s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 72%|███████▏  | 270/375 [2:08:52<42:49, 24.47s/it]

{'loss': 3.4073, 'grad_norm': 6.49222993850708, 'learning_rate': 1.1454545454545455e-05, 'epoch': 2.16}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 72%|███████▏  | 271/375 [2:09:16<42:22, 24.45s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 73%|███████▎  | 272/375 [2:09:40<41:49, 24.37s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 73%|███████▎  | 273/375 [2:10:04<41:19, 24.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 73%|███████▎  | 274/375 [2:10:29<40:55, 24.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 73%|███████▎  | 275/375 [2:10:53<40:33, 24.34s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 74%|███████▎  | 276/375 [2:11:18<40:10, 24.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 74%|███████▍  | 277/375 [2:11:42<39:42, 24.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 74%|███████▍  | 278/375 [2:12:06<39:14, 24.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 74%|███████▍  | 279/375 [2:12:30<38:41, 24.18s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 75%|███████▍  | 280/375 [2:12:54<38:20, 24.22s/it]

{'loss': 3.257, 'grad_norm': 7.759101390838623, 'learning_rate': 1.0363636363636364e-05, 'epoch': 2.24}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 75%|███████▍  | 281/375 [2:13:18<37:56, 24.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 75%|███████▌  | 282/375 [2:13:43<37:37, 24.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 75%|███████▌  | 283/375 [2:14:07<37:12, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 76%|███████▌  | 284/375 [2:14:31<36:52, 24.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 76%|███████▌  | 285/375 [2:14:56<36:25, 24.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 76%|███████▋  | 286/375 [2:15:20<35:58, 24.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 77%|███████▋  | 287/375 [2:15:44<35:32, 24.23s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 77%|███████▋  | 288/375 [2:16:08<35:05, 24.20s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 77%|███████▋  | 289/375 [2:16:32<34:43, 24.23s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 77%|███████▋  | 290/375 [2:16:57<34:19, 24.23s/it]

{'loss': 3.3806, 'grad_norm': 16.02899742126465, 'learning_rate': 9.272727272727273e-06, 'epoch': 2.32}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 78%|███████▊  | 291/375 [2:17:21<33:54, 24.22s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 78%|███████▊  | 292/375 [2:17:45<33:30, 24.22s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 78%|███████▊  | 293/375 [2:18:09<33:07, 24.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 78%|███████▊  | 294/375 [2:18:33<32:38, 24.18s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 79%|███████▊  | 295/375 [2:18:58<32:13, 24.17s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 79%|███████▉  | 296/375 [2:19:22<31:58, 24.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 79%|███████▉  | 297/375 [2:19:47<31:39, 24.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 79%|███████▉  | 298/375 [2:20:11<31:22, 24.44s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 80%|███████▉  | 299/375 [2:20:36<30:52, 24.38s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 80%|████████  | 300/375 [2:21:00<30:24, 24.33s/it]

{'loss': 4.3299, 'grad_norm': 18.335433959960938, 'learning_rate': 8.181818181818181e-06, 'epoch': 2.4}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 80%|████████  | 301/375 [2:21:24<30:00, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 81%|████████  | 302/375 [2:21:48<29:32, 24.28s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 81%|████████  | 303/375 [2:22:13<29:11, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 81%|████████  | 304/375 [2:22:37<28:47, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 81%|████████▏ | 305/375 [2:23:02<28:26, 24.38s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 82%|████████▏ | 306/375 [2:23:26<28:02, 24.38s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 82%|████████▏ | 307/375 [2:23:50<27:39, 24.40s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 82%|████████▏ | 308/375 [2:24:15<27:18, 24.45s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 82%|████████▏ | 309/375 [2:24:39<26:50, 24.40s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 83%|████████▎ | 310/375 [2:25:03<26:21, 24.33s/it]

{'loss': 3.8661, 'grad_norm': 12.510926246643066, 'learning_rate': 7.090909090909091e-06, 'epoch': 2.48}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 83%|████████▎ | 311/375 [2:25:30<26:45, 25.08s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 83%|████████▎ | 312/375 [2:25:55<26:10, 24.93s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 83%|████████▎ | 313/375 [2:26:19<25:34, 24.74s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 84%|████████▎ | 314/375 [2:26:44<25:12, 24.80s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 84%|████████▍ | 315/375 [2:27:09<24:42, 24.70s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 84%|████████▍ | 316/375 [2:27:33<24:09, 24.56s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 85%|████████▍ | 317/375 [2:27:57<23:39, 24.47s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 85%|████████▍ | 318/375 [2:28:21<23:06, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 85%|████████▌ | 319/375 [2:28:45<22:38, 24.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 85%|████████▌ | 320/375 [2:29:09<22:14, 24.26s/it]

{'loss': 3.6093, 'grad_norm': 7.178447246551514, 'learning_rate': 6e-06, 'epoch': 2.56}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 86%|████████▌ | 321/375 [2:29:33<21:43, 24.13s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 86%|████████▌ | 322/375 [2:29:58<21:23, 24.22s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 86%|████████▌ | 323/375 [2:30:22<20:55, 24.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 86%|████████▋ | 324/375 [2:30:46<20:35, 24.23s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 87%|████████▋ | 325/375 [2:31:10<20:10, 24.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 87%|████████▋ | 326/375 [2:31:35<19:47, 24.24s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 87%|████████▋ | 327/375 [2:31:59<19:26, 24.31s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 87%|████████▋ | 328/375 [2:32:23<19:04, 24.35s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 88%|████████▊ | 329/375 [2:32:48<18:39, 24.34s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 88%|████████▊ | 330/375 [2:33:12<18:16, 24.36s/it]

{'loss': 3.1931, 'grad_norm': 6.0260138511657715, 'learning_rate': 4.90909090909091e-06, 'epoch': 2.64}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 88%|████████▊ | 331/375 [2:33:37<17:53, 24.40s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 89%|████████▊ | 332/375 [2:34:01<17:29, 24.42s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 89%|████████▉ | 333/375 [2:34:25<17:01, 24.32s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 89%|████████▉ | 334/375 [2:34:49<16:32, 24.20s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 89%|████████▉ | 335/375 [2:35:13<16:08, 24.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 90%|████████▉ | 336/375 [2:35:38<15:45, 24.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 90%|████████▉ | 337/375 [2:36:02<15:21, 24.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 90%|█████████ | 338/375 [2:36:26<14:56, 24.22s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 90%|█████████ | 339/375 [2:36:50<14:32, 24.23s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 91%|█████████ | 340/375 [2:37:15<14:10, 24.30s/it]

{'loss': 3.8851, 'grad_norm': 14.371610641479492, 'learning_rate': 3.818181818181818e-06, 'epoch': 2.72}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 91%|█████████ | 341/375 [2:37:39<13:45, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 91%|█████████ | 342/375 [2:38:03<13:22, 24.30s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 91%|█████████▏| 343/375 [2:38:28<13:00, 24.40s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 92%|█████████▏| 344/375 [2:38:52<12:34, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 92%|█████████▏| 345/375 [2:39:17<12:09, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 92%|█████████▏| 346/375 [2:39:41<11:46, 24.36s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 93%|█████████▎| 347/375 [2:40:05<11:22, 24.39s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 93%|█████████▎| 348/375 [2:40:30<10:56, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 93%|█████████▎| 349/375 [2:40:54<10:31, 24.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 93%|█████████▎| 350/375 [2:41:18<10:05, 24.20s/it]

{'loss': 3.8063, 'grad_norm': 13.26820182800293, 'learning_rate': 2.7272727272727272e-06, 'epoch': 2.8}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 94%|█████████▎| 351/375 [2:41:42<09:41, 24.25s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 94%|█████████▍| 352/375 [2:42:07<09:18, 24.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 94%|█████████▍| 353/375 [2:42:31<08:55, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 94%|█████████▍| 354/375 [2:42:55<08:30, 24.30s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 95%|█████████▍| 355/375 [2:43:19<08:05, 24.29s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 95%|█████████▍| 356/375 [2:43:44<07:41, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 95%|█████████▌| 357/375 [2:44:08<07:14, 24.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 95%|█████████▌| 358/375 [2:44:32<06:50, 24.15s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 96%|█████████▌| 359/375 [2:44:56<06:26, 24.17s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 96%|█████████▌| 360/375 [2:45:20<06:03, 24.26s/it]

{'loss': 3.3282, 'grad_norm': 14.216032981872559, 'learning_rate': 1.6363636363636363e-06, 'epoch': 2.88}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 96%|█████████▋| 361/375 [2:45:45<05:41, 24.37s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 97%|█████████▋| 362/375 [2:46:09<05:15, 24.27s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 97%|█████████▋| 363/375 [2:46:33<04:50, 24.20s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 97%|█████████▋| 364/375 [2:46:57<04:26, 24.21s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 97%|█████████▋| 365/375 [2:47:26<04:14, 25.50s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 98%|█████████▊| 366/375 [2:48:11<04:41, 31.26s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 98%|█████████▊| 367/375 [2:48:44<04:14, 31.87s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 98%|█████████▊| 368/375 [2:49:06<03:22, 28.99s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 98%|█████████▊| 369/375 [2:49:28<02:41, 26.95s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 99%|█████████▊| 370/375 [2:49:50<02:07, 25.44s/it]

{'loss': 3.3133, 'grad_norm': 4.606822490692139, 'learning_rate': 5.454545454545455e-07, 'epoch': 2.96}
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 99%|█████████▉| 371/375 [2:50:12<01:37, 24.33s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 99%|█████████▉| 372/375 [2:50:34<01:10, 23.62s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


 99%|█████████▉| 373/375 [2:50:56<00:46, 23.11s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


100%|█████████▉| 374/375 [2:51:18<00:22, 22.79s/it]

attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])
attention_output shape: torch.Size([2, 128, 64000])
labels shape: torch.Size([2, 128])
shift_logits shape: torch.Size([2, 127, 64000])
shift_labels shape: torch.Size([2, 127])


100%|██████████| 375/375 [2:51:40<00:00, 22.51s/it]

RuntimeError: 
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'base_model.lm_head.weight', 'base_model.transformer.wte.weight'}].
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
            

In [7]:
# Evaluate the Model
test_results = trainer.evaluate(eval_dataset=tokenized_datasets["test"].select(range(200)))
print(test_results)


attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])




attention_output shape: torch.Size([8, 128, 64000])
labels shape: torch.Size([8, 128])
shift_logits shape: torch.Size([8, 127, 64000])
shift_labels shape: torch.Size([8, 127])


                                                   
100%|██████████| 375/375 [3:00:41<00:00, 22.51s/it]

{'eval_loss': 1.1760284900665283, 'eval_runtime': 301.8594, 'eval_samples_per_second': 0.663, 'eval_steps_per_second': 0.083, 'epoch': 3.0}
{'eval_loss': 1.1760284900665283, 'eval_runtime': 301.8594, 'eval_samples_per_second': 0.663, 'eval_steps_per_second': 0.083, 'epoch': 3.0}


In [8]:
# Save the Model and Components
model.base_model.save_pretrained("./arabicaqa_model-base22")
tokenizer.save_pretrained("./arabicaqa_model-base22")
torch.save(model.differential_attention.state_dict(), "./arabicaqa_model-base/differential_attention.pth")

In [12]:
# Test the Model
def generate_answer(question):
    # Tokenize the input question
    inputs = tokenizer(question, return_tensors="pt", max_length=128, truncation=True, padding="max_length").to(model.base_model.device)
    
    # Generate a response using max_new_tokens
    outputs = model.base_model.generate(**inputs, max_new_tokens=50)  # Generate up to 50 new tokens
    
    # Decode the output to text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage
question ="ما هو عدد سكان إستونيا؟"
answer = generate_answer(question)
print(f"Question: {question}")
print(f"Answer: {answer}")

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Question: ما هو عدد سكان إستونيا؟
Answer: ما هو عدد سكان إستونيا؟وداتينوبيينوبودنا هما من بين أكبر دول حوض النيل ، هما من بين أكبر دول حوض النيل ، هما من بين أكبر دول حوض النيل ، هما من بين أكبر دول حوض النيل ، هما من بين أكبر دول حوض النيل ، هما من بين
