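### Quora Question Pairs
Quora Question Pairs is a binary classification task: decide whether two questions have the same meaning, that is, whether they can be treated as different phrasings of one question.
#### Task characteristics
1. Input: each sample contains two questions (for example, "What is the best way to learn Python?" and "How can I effectively learn Python?").
2. Output: a binary label indicating whether the pair is a duplicate (1 = duplicate, 0 = not duplicate).
3. Applications: question answering systems, search engine optimization, and deduplication of user-generated content.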

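### Steps for solving Quora Question Pairs with Transformers
1. Data preparation
* Input format: combine each question pair into one sequence, typically "[CLS] question1 [SEP] question2 [SEP]". [CLS] is the classification token; [SEP] separates the two questions and closes the sequence.
* Labels: assign each pair a label (1 = duplicate, 0 = not duplicate).
2. Tokenization
* Use a pretrained tokenizer (such as BERT's) to convert the text into model inputs. This includes:
    * Converting the text to token IDs.
    * Creating an attention mask that tells the model which positions are real input and which are padding.
3. Embedding layer
* The token IDs pass through an embedding layer that maps them to dense vectors carrying semantic information about the tokens.
4. Transformer encoder
* The input vectors pass through a stack of Transformer encoder layers, each made of a self-attention sublayer and a feed-forward neural network.
    * Self-attention: lets each token attend to every other token in the input, across both questions, and so capture context.
    * Layer normalization and residual connections: stabilize training and speed up convergence.
5. Pooling
* The final-layer output vector at the [CLS] position serves as the representation of the whole question pair and is the basis for classification.
6. Classification layer
* The [CLS] representation goes through a fully connected layer that outputs one score (logit) per class: duplicate or not duplicate.
7. Training
* Cross-entropy loss measures the gap between the model's outputs and the true labels. Backpropagation computes the gradients used to update the model parameters and improve accuracy step by step.
8. Evaluation and inference
* Evaluate the model on a validation set with metrics such as accuracy and F1 score.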
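
Steps 6 to 8 can be sketched in plain Python. The block below is a minimal illustration, not the code the `Trainer` runs internally: the logits are made-up values, and the helper names `cross_entropy` and `accuracy_and_f1` are hypothetical.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example; `logits` holds one score per class."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def accuracy_and_f1(predictions, labels):
    """Accuracy and F1 for the positive class (1 = duplicate)."""
    tp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(predictions, labels) if p == 0 and y == 1)
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return correct / len(labels), f1

# Made-up classifier outputs for four pairs: [score_not_duplicate, score_duplicate]
logits = [[2.0, -1.0], [0.2, 1.5], [0.0, 0.0], [-0.5, 0.3]]
labels = [0, 1, 1, 0]

losses = [cross_entropy(z, y) for z, y in zip(logits, labels)]
# Predicted class = index of the largest logit (ties resolve to class 0)
predictions = [max(range(2), key=lambda c: z[c]) for z in logits]
acc, f1 = accuracy_and_f1(predictions, labels)
```

The third example has equal scores for both classes, so its loss is ln 2 (about 0.693), which is the loss of a binary classifier that has not learned anything yet.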

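### How tokenization works under the hood
1. Text preprocessing
* Before splitting, the text is usually normalized: removing extra whitespace, lowercasing (for some models, such as uncased BERT), removing unwanted special characters, and so on.
2. Tokenization algorithms
* WordPiece (used by BERT): splits words into smaller sub-word units so the model can handle words it has not seen before.
* Byte Pair Encoding (BPE) (used by GPT): builds the vocabulary by repeatedly merging the most frequent pair of adjacent symbols.
* SentencePiece: a language-independent tokenizer, trained without supervision on raw text, that does not rely on whitespace to find word boundaries. This suits languages without explicit spaces, such as Chinese.
3. Building the vocabulary
* When a tokenizer is trained, it builds a vocabulary that maps each token to an ID. With a pretrained model this step is already done.
4. Splitting and encoding
* Feed the text to the tokenizer, which splits it into tokens with its tokenization algorithm.
* Map each token to its ID to produce input_ids.
* For example, WordPiece might split "unhappiness" into ["un", "##happiness"], where the "##" prefix marks a piece that continues a word, and assign an ID to each piece.
5. Generating the attention mask
* The attention mask marks every position as real input (1) or padding (0). It matters whenever variable-length inputs are padded to a common length.
6. Special tokens
* Tokenizers also add special tokens such as [CLS] (classification token) and [SEP] (separator token), which is especially important for multi-sentence input.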
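
The splitting, encoding, and masking steps above can be sketched with a toy WordPiece-style tokenizer. This is an illustration under stated assumptions: the vocabulary is invented for the example (it is not BERT's real vocabulary), and `wordpiece` and `encode` are hypothetical helpers rather than the `transformers` implementation.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into sub-word tokens."""
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matches: the whole word becomes unknown
        tokens.append(piece)
        start = end
    return tokens

# Toy vocabulary (an assumption for illustration only)
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
         "un": 4, "##happi": 5, "##ness": 6, "learn": 7, "##ing": 8}

def encode(text, max_length):
    """Lowercase, split on whitespace, apply WordPiece, add special tokens, pad."""
    tokens = ["[CLS]"]
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    input_ids = [vocab[t] for t in tokens]
    attention_mask = [1] * len(input_ids)   # 1 = real token
    pad = max_length - len(input_ids)       # this sketch pads but does not truncate
    return input_ids + [vocab["[PAD]"]] * pad, attention_mask + [0] * pad

input_ids, attention_mask = encode("Unhappiness learning", max_length=8)
# input_ids      -> [2, 4, 5, 6, 7, 8, 3, 0]
# attention_mask -> [1, 1, 1, 1, 1, 1, 1, 0]
```

With this toy vocabulary "unhappiness" becomes ["un", "##happi", "##ness"]; a real vocabulary may split it differently.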

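### How the embedding layer works under the hood
An embedding layer maps discrete tokens, identified by their IDs, to vectors in a continuous space. The process involves these steps:
1. Building the vocabulary
* Before training, build a vocabulary that records each word or token and its unique ID, for example from statistics over the training data.
* The embedding layer normally uses the same vocabulary that the tokenization step produces.
2. Initializing the embedding matrix
* The embedding layer is a two-dimensional matrix with one row per vocabulary entry and one column per embedding dimension. Each row is the embedding vector of one token.
* The matrix is either initialized randomly or loaded from pretrained embeddings (such as Word2Vec or GloVe).
3. Looking up embedding vectors
* Given a sequence of token IDs, the embedding layer uses each ID as a row index into the embedding matrix and returns the corresponding vectors.
* For example, for the ID sequence [1, 2, 3] the layer returns a matrix whose rows are the embeddings of IDs 1, 2, and 3.
4. Forward pass
* The output of the embedding layer is the input to the following network layers, as part of the model's forward pass.
5. Backward pass
* During training, the embedding weights (the embedding matrix) are updated from the gradients of the loss. With an optimizer such as gradient descent, the embedding vectors gradually learn representations better suited to the downstream task.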
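
The lookup and update steps above can be sketched without any deep learning library. This is a minimal illustration, not how PyTorch implements `nn.Embedding`: the `Embedding` class, its `sgd_step` method, the sizes, and the gradient values are all assumptions chosen for the example.

```python
import random

class Embedding:
    """Lookup table: one row of `dim` floats per vocabulary ID."""

    def __init__(self, vocab_size, dim, seed=0):
        rng = random.Random(seed)
        # Random initialization with a small standard deviation (assumed value)
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(dim)]
                       for _ in range(vocab_size)]

    def forward(self, input_ids):
        # Row lookup by ID: no matrix multiplication is needed
        return [self.weight[i] for i in input_ids]

    def sgd_step(self, input_ids, grad_output, lr):
        # Only the rows that were looked up receive a gradient
        for i, grad_row in zip(input_ids, grad_output):
            self.weight[i] = [w - lr * g for w, g in zip(self.weight[i], grad_row)]

emb = Embedding(vocab_size=5, dim=3)
ids = [1, 2, 3]
vectors = emb.forward(ids)                      # 3 rows, each of length 3
before = [row[:] for row in emb.weight]         # snapshot for comparison
emb.sgd_step(ids, grad_output=[[1.0, 1.0, 1.0]] * 3, lr=0.1)
# Rows 1, 2, 3 moved by -0.1 in every dimension; rows 0 and 4 are unchanged.
```

Because only the looked-up rows change, tokens that never appear in the training data keep their initial vectors.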


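### The relationship between Transformers and BERT
1. Transformers
* Definition: the Transformer is a deep learning architecture introduced in the paper "Attention Is All You Need". It processes sequence data and is used mainly in natural language processing (NLP) tasks.
* Components: the original Transformer consists of an encoder and a decoder, but many applications (such as text classification and question answering) use only the encoder.
* Self-attention: the core mechanism of the Transformer. It lets the model relate every token in the input sequence to every other token and so capture context.
2. BERT
* Definition: BERT (Bidirectional Encoder Representations from Transformers) is a pretrained model built on the Transformer encoder and designed for NLP tasks.
* Bidirectionality: a key property of BERT. During training each token's representation draws on both its left and its right context.
* Pretraining and fine-tuning: BERT is first pretrained on large amounts of unlabeled text with self-supervised objectives, then fine-tuned for a specific task (such as text classification or question answering).

Summary: BERT is a specific model built on the Transformer architecture. It uses self-attention and bidirectional context to handle a wide range of NLP tasks. BERT can be seen as one concrete implementation of the Transformer encoder, focused on learning strong language representations.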
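
To make self-attention and bidirectionality concrete, here is a minimal single-head sketch in plain Python. It is a simplification under stated assumptions: queries, keys, and values are all the raw input vectors (a real Transformer applies learned projections and several heads), and the `causal` flag is added only to contrast BERT's bidirectional attention with a left-to-right mask.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(x, causal=False):
    """Scaled dot-product attention with Q = K = V = x (no learned projections)."""
    d = len(x[0])
    outputs, weights = [], []
    for i, query in enumerate(x):
        scores = []
        for j, key in enumerate(x):
            if causal and j > i:
                scores.append(float("-inf"))  # hide tokens to the right
            else:
                scores.append(sum(q * k for q, k in zip(query, key)) / math.sqrt(d))
        w = softmax(scores)
        weights.append(w)
        # Output i is the attention-weighted average of all value vectors
        outputs.append([sum(w[j] * x[j][c] for j in range(len(x))) for c in range(d)])
    return outputs, weights

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]            # three toy token vectors
_, bidirectional = self_attention(x)                 # encoder-style, as in BERT
_, left_to_right = self_attention(x, causal=True)    # masked variant, for contrast
```

In the bidirectional case the first token places non-zero weight on the tokens to its right; with the causal mask its weights are `[1.0, 0.0, 0.0]`.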

### The input BERT expects for Quora Question Pairs
1. Input pair
* The two questions (question1 and question2) are passed together as one input.
2. Input format
* input_ids: the token IDs of both questions, encoded with BERT's tokenizer.
* attention_mask: marks which tokens are real input (1) and which are padding (0).
* token_type_ids: segment IDs that distinguish the two sentences. [CLS], question1, and the first [SEP] get 0; question2 and the final [SEP] get 1.
3. Label
* One label (labels) indicating whether the two questions are duplicates (0 or 1).

Example input:<br>
{<br>
    "input_ids": [101, 2054, 2003, ...],  // Token IDs<br>
    "attention_mask": [1, 1, ...],       // Attention mask<br>
    "token_type_ids": [0, 0, ..., 1, 1, ...],  // Token type IDs<br>
    "labels": 0  // 0: not duplicate, 1: duplicate<br>
}<br>
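
The three input fields can be assembled by hand to see how they line up. The sketch below assumes the special-token IDs of `bert-base-uncased` (0 for [PAD], 101 for [CLS], 102 for [SEP], which match the sample output later in this notebook); `build_pair_input` is a hypothetical helper, not part of the `transformers` API.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs

def build_pair_input(ids_a, ids_b, max_length):
    """Assemble [CLS] A [SEP] B [SEP] and pad to `max_length` (no truncation)."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment 0 covers [CLS], A, and the first [SEP]; segment 1 covers B and the last [SEP]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    attention_mask = [1] * len(input_ids)
    pad = max_length - len(input_ids)
    if pad < 0:
        raise ValueError("pair is longer than max_length; truncate first")
    return {
        "input_ids": input_ids + [PAD_ID] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }

# Two very short "questions" given as token IDs (2054 and 2003 appear in the sample output)
features = build_pair_input([2054, 2003], [2054], max_length=8)
# input_ids      -> [101, 2054, 2003, 102, 2054, 102, 0, 0]
# token_type_ids -> [0, 0, 0, 0, 1, 1, 0, 0]
# attention_mask -> [1, 1, 1, 1, 1, 1, 0, 0]
```

In practice the tokenizer builds all three lists itself when it is called with two text arguments.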


### 1. Imports

In [1]:
import pandas as pd
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch



### 2. Load the dataset

In [20]:
# Load the dataset
dataset = load_dataset('quora', split='train', trust_remote_code=True)

In [5]:
dataset

Dataset({
    features: ['questions', 'is_duplicate'],
    num_rows: 404290
})

In [6]:
dataset_df=pd.DataFrame(dataset)

In [7]:
dataset_df.head()

Unnamed: 0,questions,is_duplicate
0,"{'id': [1, 2], 'text': ['What is the step by s...",False
1,"{'id': [3, 4], 'text': ['What is the story of ...",False
2,"{'id': [5, 6], 'text': ['How can I increase th...",False
3,"{'id': [7, 8], 'text': ['Why am I mentally ver...",False
4,"{'id': [9, 10], 'text': ['Which one dissolve i...",False


In [10]:
dataset_df['questions'][0]

{'id': [1, 2],
 'text': ['What is the step by step guide to invest in share market in india?',
  'What is the step by step guide to invest in share market?']}

### 3. Data preprocessing

In [21]:
# Extract the question pair and the label
def extract_questions_and_labels(examples):
    questions = examples['questions']
    return {
        'question1': questions['text'][0],  # text of the first question
        'question2': questions['text'][1],  # text of the second question
        'is_duplicate': examples['is_duplicate']  # copy the label
    }

In [22]:
# Convert the dataset with map
dataset = dataset.map(extract_questions_and_labels)

In [23]:
dataset

Dataset({
    features: ['questions', 'is_duplicate', 'question1', 'question2'],
    num_rows: 404290
})

In [24]:
dataset_df=pd.DataFrame(dataset)

In [25]:
dataset_df.head()

Unnamed: 0,questions,is_duplicate,question1,question2
0,"{'id': [1, 2], 'text': ['What is the step by s...",False,What is the step by step guide to invest in sh...,What is the step by step guide to invest in sh...
1,"{'id': [3, 4], 'text': ['What is the story of ...",False,What is the story of Kohinoor (Koh-i-Noor) Dia...,What would happen if the Indian government sto...
2,"{'id': [5, 6], 'text': ['How can I increase th...",False,How can I increase the speed of my internet co...,How can Internet speed be increased by hacking...
3,"{'id': [7, 8], 'text': ['Why am I mentally ver...",False,Why am I mentally very lonely? How can I solve...,Find the remainder when [math]23^{24}[/math] i...
4,"{'id': [9, 10], 'text': ['Which one dissolve i...",False,"Which one dissolve in water quikly sugar, salt...",Which fish would survive in salt water?


In [14]:
# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [36]:
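# Preprocessing: tokenize each question pair and add labels
def preprocess_function(examples):
    # Encode the question pairs
    encoding = tokenizer(
        examples['question1'],
        examples['question2'],
        truncation=True,
        padding='max_length',
        max_length=128
    )
    # Return the encodings and the labels.
    # Note: token_type_ids is not returned here, so the model falls back to
    # all-zero segment IDs; add encoding['token_type_ids'] to pass the segment
    # information described earlier.
    return {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'labels': [1 if label else 0 for label in examples['is_duplicate']]  # convert bool labels to 0/1
    }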


In [37]:
# Apply the preprocessing
tokenized_dataset = dataset.map(preprocess_function, batched=True)


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.

In [28]:
tokenized_dataset

Dataset({
    features: ['questions', 'is_duplicate', 'question1', 'question2', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 404290
})

In [38]:
# Drop the columns that are no longer needed: questions and is_duplicate
tokenized_dataset = tokenized_dataset.remove_columns(['questions', 'is_duplicate'])

# Check the final dataset
print(tokenized_dataset.column_names)  # confirm which columns remain


['question1', 'question2', 'input_ids', 'attention_mask', 'labels']


In [39]:
tokenized_dataset

Dataset({
    features: ['question1', 'question2', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 404290
})

In [40]:
print(tokenized_dataset[0])  # print the first sample to confirm the fields


{'question1': 'What is the step by step guide to invest in share market in india?', 'question2': 'What is the step by step guide to invest in share market?', 'input_ids': [101, 2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1999, 2634, 1029, 102, 2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1029, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

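#### The resulting tokenized_dataset
1. input_ids
* Description: each question pair encoded as a list of token IDs for the model.
* Example: [101, 2054, 2003, 1996, ...] (each number is an index into the BERT vocabulary).
2. attention_mask
* Description: marks which tokens are real input (1) and which are padding (0).
* Example: [1, 1, 1, 1, 0, 0, ...] (real tokens are 1, padding is 0).
3. labels
* Description: the binary label indicating whether the two questions are duplicates (0 or 1).
* Example: 0 (the two questions are not duplicates) or 1 (they are duplicates).
4. Other optional fields
* If the question text is kept during preprocessing, the question1 and question2 fields remain available for further analysis.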

### 4. Set up the model and training arguments

In [41]:
# Set up the model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Set up the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',  # newer transformers releases rename this to eval_strategy
    logging_dir='./logs',
    logging_steps=10,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


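### 5. Model training

Training takes a long time (the progress bar below estimates about 10 hours), so the run was interrupted early.

In [42]:
# Train with Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset  # evaluates on the training data here; replace with a held-out validation set
)

# Start training
trainer.train()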

  0%|          | 0/75807 [04:50<?, ?it/s]
  0%|          | 10/75807 [00:05<10:25:30,  2.02it/s]

{'loss': 0.664, 'grad_norm': 4.907032489776611, 'learning_rate': 4.9993404303032704e-05, 'epoch': 0.0}


  0%|          | 20/75807 [00:10<10:08:38,  2.08it/s]

{'loss': 0.717, 'grad_norm': 2.8049395084381104, 'learning_rate': 4.9986808606065405e-05, 'epoch': 0.0}


  0%|          | 30/75807 [00:14<10:08:53,  2.07it/s]

{'loss': 0.6858, 'grad_norm': 2.0637381076812744, 'learning_rate': 4.9980212909098107e-05, 'epoch': 0.0}


  0%|          | 40/75807 [00:19<10:06:10,  2.08it/s]

{'loss': 0.6356, 'grad_norm': 6.703273773193359, 'learning_rate': 4.997361721213081e-05, 'epoch': 0.0}


  0%|          | 50/75807 [00:24<10:00:02,  2.10it/s]

{'loss': 0.6218, 'grad_norm': 3.67154598236084, 'learning_rate': 4.996702151516351e-05, 'epoch': 0.0}


  0%|          | 60/75807 [00:29<10:05:16,  2.09it/s]

{'loss': 0.6571, 'grad_norm': 13.532525062561035, 'learning_rate': 4.996042581819621e-05, 'epoch': 0.0}


  0%|          | 70/75807 [00:34<10:03:44,  2.09it/s]

{'loss': 0.5325, 'grad_norm': 5.410545349121094, 'learning_rate': 4.995383012122891e-05, 'epoch': 0.0}


  0%|          | 80/75807 [00:38<10:02:08,  2.10it/s]

{'loss': 0.5407, 'grad_norm': 3.6037492752075195, 'learning_rate': 4.994723442426161e-05, 'epoch': 0.0}


  0%|          | 90/75807 [00:43<10:05:03,  2.09it/s]

{'loss': 0.5539, 'grad_norm': 9.445372581481934, 'learning_rate': 4.9940638727294315e-05, 'epoch': 0.0}


  0%|          | 100/75807 [00:48<9:59:07,  2.11it/s]

{'loss': 0.5479, 'grad_norm': 4.551067352294922, 'learning_rate': 4.9934043030327016e-05, 'epoch': 0.0}


  0%|          | 110/75807 [00:53<10:04:28,  2.09it/s]

{'loss': 0.5365, 'grad_norm': 3.160682439804077, 'learning_rate': 4.992744733335972e-05, 'epoch': 0.0}


  0%|          | 120/75807 [00:57<9:55:54,  2.12it/s] 

{'loss': 0.4984, 'grad_norm': 7.243453025817871, 'learning_rate': 4.992085163639242e-05, 'epoch': 0.0}


  0%|          | 130/75807 [01:02<9:59:14,  2.10it/s] 

{'loss': 0.4441, 'grad_norm': 9.628811836242676, 'learning_rate': 4.991425593942512e-05, 'epoch': 0.01}


  0%|          | 140/75807 [01:07<9:56:23,  2.11it/s] 

{'loss': 0.5529, 'grad_norm': 7.7236409187316895, 'learning_rate': 4.990766024245782e-05, 'epoch': 0.01}


  0%|          | 150/75807 [01:12<10:04:19,  2.09it/s]

{'loss': 0.5506, 'grad_norm': 4.287801742553711, 'learning_rate': 4.990106454549052e-05, 'epoch': 0.01}


  0%|          | 160/75807 [01:16<9:57:21,  2.11it/s] 

{'loss': 0.4624, 'grad_norm': 17.113445281982422, 'learning_rate': 4.9894468848523224e-05, 'epoch': 0.01}


  0%|          | 170/75807 [01:21<9:59:12,  2.10it/s]

{'loss': 0.5325, 'grad_norm': 12.137215614318848, 'learning_rate': 4.9887873151555926e-05, 'epoch': 0.01}


  0%|          | 180/75807 [01:26<9:56:53,  2.11it/s] 

{'loss': 0.5648, 'grad_norm': 3.0160844326019287, 'learning_rate': 4.988127745458863e-05, 'epoch': 0.01}


  0%|          | 190/75807 [01:31<9:56:48,  2.11it/s] 

{'loss': 0.507, 'grad_norm': 7.2587761878967285, 'learning_rate': 4.987468175762133e-05, 'epoch': 0.01}


  0%|          | 200/75807 [01:35<9:58:21,  2.11it/s] 

{'loss': 0.4819, 'grad_norm': 6.128285884857178, 'learning_rate': 4.986808606065403e-05, 'epoch': 0.01}


  0%|          | 210/75807 [01:40<9:54:31,  2.12it/s] 

{'loss': 0.4881, 'grad_norm': 5.179528713226318, 'learning_rate': 4.986149036368674e-05, 'epoch': 0.01}


  0%|          | 220/75807 [01:45<9:54:00,  2.12it/s] 

{'loss': 0.4845, 'grad_norm': 5.795255184173584, 'learning_rate': 4.985489466671944e-05, 'epoch': 0.01}


  0%|          | 230/75807 [01:50<9:58:26,  2.10it/s]

{'loss': 0.5056, 'grad_norm': 8.283775329589844, 'learning_rate': 4.984829896975214e-05, 'epoch': 0.01}


  0%|          | 240/75807 [01:54<9:57:26,  2.11it/s] 

{'loss': 0.4354, 'grad_norm': 6.36803674697876, 'learning_rate': 4.984170327278484e-05, 'epoch': 0.01}


  0%|          | 250/75807 [01:59<9:51:51,  2.13it/s]

{'loss': 0.5147, 'grad_norm': 3.0770459175109863, 'learning_rate': 4.9835107575817544e-05, 'epoch': 0.01}


  0%|          | 260/75807 [02:04<9:50:02,  2.13it/s]

{'loss': 0.4452, 'grad_norm': 10.688276290893555, 'learning_rate': 4.9828511878850245e-05, 'epoch': 0.01}


  0%|          | 270/75807 [02:09<9:58:28,  2.10it/s] 

{'loss': 0.4453, 'grad_norm': 11.232171058654785, 'learning_rate': 4.9821916181882946e-05, 'epoch': 0.01}


  0%|          | 280/75807 [02:13<9:55:15,  2.11it/s]

{'loss': 0.4356, 'grad_norm': 5.009365081787109, 'learning_rate': 4.981532048491565e-05, 'epoch': 0.01}


  0%|          | 290/75807 [02:18<9:58:37,  2.10it/s] 

{'loss': 0.4268, 'grad_norm': 11.043815612792969, 'learning_rate': 4.980872478794835e-05, 'epoch': 0.01}


  0%|          | 300/75807 [02:23<9:57:21,  2.11it/s] 

{'loss': 0.4687, 'grad_norm': 13.077932357788086, 'learning_rate': 4.980212909098105e-05, 'epoch': 0.01}


  0%|          | 310/75807 [02:28<10:05:12,  2.08it/s]

{'loss': 0.5004, 'grad_norm': 4.946128845214844, 'learning_rate': 4.979553339401375e-05, 'epoch': 0.01}


  0%|          | 320/75807 [02:32<9:56:52,  2.11it/s] 

{'loss': 0.4777, 'grad_norm': 10.528044700622559, 'learning_rate': 4.978893769704645e-05, 'epoch': 0.01}


  0%|          | 330/75807 [02:37<10:14:46,  2.05it/s]

{'loss': 0.412, 'grad_norm': 5.848959922790527, 'learning_rate': 4.9782342000079154e-05, 'epoch': 0.01}


  0%|          | 340/75807 [02:42<9:56:59,  2.11it/s] 

{'loss': 0.3578, 'grad_norm': 3.1740450859069824, 'learning_rate': 4.9775746303111856e-05, 'epoch': 0.01}


  0%|          | 350/75807 [02:47<9:56:05,  2.11it/s] 

{'loss': 0.526, 'grad_norm': 5.705559253692627, 'learning_rate': 4.976915060614456e-05, 'epoch': 0.01}


  0%|          | 360/75807 [02:52<10:10:18,  2.06it/s]

{'loss': 0.4932, 'grad_norm': 3.675070285797119, 'learning_rate': 4.976255490917726e-05, 'epoch': 0.01}


  0%|          | 370/75807 [02:57<10:12:48,  2.05it/s]

{'loss': 0.3689, 'grad_norm': 4.85699987411499, 'learning_rate': 4.975595921220996e-05, 'epoch': 0.01}


  1%|          | 380/75807 [03:01<9:54:15,  2.12it/s] 

{'loss': 0.5255, 'grad_norm': 8.221631050109863, 'learning_rate': 4.974936351524266e-05, 'epoch': 0.02}


  1%|          | 390/75807 [03:06<10:00:21,  2.09it/s]

{'loss': 0.446, 'grad_norm': 5.07841157913208, 'learning_rate': 4.974276781827536e-05, 'epoch': 0.02}


  1%|          | 400/75807 [03:11<9:53:18,  2.12it/s] 

{'loss': 0.4286, 'grad_norm': 8.51626968383789, 'learning_rate': 4.9736172121308064e-05, 'epoch': 0.02}


  1%|          | 410/75807 [03:15<9:48:01,  2.14it/s]

{'loss': 0.3934, 'grad_norm': 11.602399826049805, 'learning_rate': 4.9729576424340765e-05, 'epoch': 0.02}


  1%|          | 420/75807 [03:20<9:50:02,  2.13it/s]

{'loss': 0.442, 'grad_norm': 3.1047515869140625, 'learning_rate': 4.972298072737347e-05, 'epoch': 0.02}


  1%|          | 430/75807 [03:25<9:54:08,  2.11it/s]

{'loss': 0.4262, 'grad_norm': 4.609258651733398, 'learning_rate': 4.971638503040617e-05, 'epoch': 0.02}


  1%|          | 440/75807 [03:30<9:59:39,  2.09it/s] 

{'loss': 0.4784, 'grad_norm': 2.9943106174468994, 'learning_rate': 4.970978933343887e-05, 'epoch': 0.02}


  1%|          | 450/75807 [03:34<9:57:56,  2.10it/s] 

{'loss': 0.4214, 'grad_norm': 19.045448303222656, 'learning_rate': 4.970319363647157e-05, 'epoch': 0.02}


  1%|          | 460/75807 [03:39<9:51:58,  2.12it/s]

{'loss': 0.4932, 'grad_norm': 3.7924039363861084, 'learning_rate': 4.969659793950427e-05, 'epoch': 0.02}


  1%|          | 470/75807 [03:44<9:52:51,  2.12it/s]

{'loss': 0.3581, 'grad_norm': 6.4877028465271, 'learning_rate': 4.9690002242536974e-05, 'epoch': 0.02}


  1%|          | 480/75807 [03:49<9:51:50,  2.12it/s]

{'loss': 0.4822, 'grad_norm': 5.331875324249268, 'learning_rate': 4.9683406545569675e-05, 'epoch': 0.02}


  1%|          | 490/75807 [03:53<9:50:12,  2.13it/s]

{'loss': 0.3452, 'grad_norm': 9.672165870666504, 'learning_rate': 4.9676810848602376e-05, 'epoch': 0.02}


  1%|          | 500/75807 [03:58<10:03:26,  2.08it/s]

{'loss': 0.4881, 'grad_norm': 8.223822593688965, 'learning_rate': 4.967021515163507e-05, 'epoch': 0.02}


  1%|          | 510/75807 [04:05<10:37:40,  1.97it/s]

{'loss': 0.3774, 'grad_norm': 5.164790630340576, 'learning_rate': 4.966361945466777e-05, 'epoch': 0.02}


  1%|          | 520/75807 [04:10<10:04:25,  2.08it/s]

{'loss': 0.4081, 'grad_norm': 4.251534938812256, 'learning_rate': 4.9657023757700474e-05, 'epoch': 0.02}


  1%|          | 530/75807 [04:15<9:56:49,  2.10it/s] 

{'loss': 0.3741, 'grad_norm': 16.012235641479492, 'learning_rate': 4.9650428060733175e-05, 'epoch': 0.02}


  1%|          | 540/75807 [04:20<10:11:50,  2.05it/s]

{'loss': 0.4466, 'grad_norm': 6.623591899871826, 'learning_rate': 4.9643832363765877e-05, 'epoch': 0.02}


  1%|          | 550/75807 [04:24<9:52:11,  2.12it/s] 

{'loss': 0.4039, 'grad_norm': 2.9993674755096436, 'learning_rate': 4.963723666679858e-05, 'epoch': 0.02}


  1%|          | 560/75807 [04:29<9:58:03,  2.10it/s] 

{'loss': 0.4482, 'grad_norm': 4.071503639221191, 'learning_rate': 4.963064096983128e-05, 'epoch': 0.02}


  1%|          | 570/75807 [04:34<9:54:30,  2.11it/s] 

{'loss': 0.4746, 'grad_norm': 7.2324957847595215, 'learning_rate': 4.962404527286398e-05, 'epoch': 0.02}


  1%|          | 580/75807 [04:39<9:54:17,  2.11it/s]

{'loss': 0.5568, 'grad_norm': 4.483479022979736, 'learning_rate': 4.961744957589668e-05, 'epoch': 0.02}


  1%|          | 590/75807 [04:43<9:55:06,  2.11it/s]

{'loss': 0.4312, 'grad_norm': 9.702492713928223, 'learning_rate': 4.9610853878929383e-05, 'epoch': 0.02}


  1%|          | 600/75807 [04:48<9:57:27,  2.10it/s] 

{'loss': 0.5021, 'grad_norm': 14.631596565246582, 'learning_rate': 4.9604258181962085e-05, 'epoch': 0.02}


  1%|          | 610/75807 [04:53<9:54:30,  2.11it/s] 

{'loss': 0.4017, 'grad_norm': 5.687025547027588, 'learning_rate': 4.9597662484994786e-05, 'epoch': 0.02}


  1%|          | 620/75807 [04:58<9:50:59,  2.12it/s]

{'loss': 0.4886, 'grad_norm': 3.957763671875, 'learning_rate': 4.9591066788027494e-05, 'epoch': 0.02}


  1%|          | 630/75807 [05:02<9:56:19,  2.10it/s]

{'loss': 0.4768, 'grad_norm': 4.408023834228516, 'learning_rate': 4.9584471091060196e-05, 'epoch': 0.02}


  1%|          | 640/75807 [05:07<9:51:14,  2.12it/s]

{'loss': 0.4263, 'grad_norm': 6.2244768142700195, 'learning_rate': 4.95778753940929e-05, 'epoch': 0.03}


  1%|          | 650/75807 [05:12<9:54:42,  2.11it/s]

{'loss': 0.5824, 'grad_norm': 2.626708745956421, 'learning_rate': 4.95712796971256e-05, 'epoch': 0.03}


  1%|          | 660/75807 [05:17<9:57:42,  2.10it/s] 

{'loss': 0.4275, 'grad_norm': 2.999260425567627, 'learning_rate': 4.95646840001583e-05, 'epoch': 0.03}


  1%|          | 670/75807 [05:21<9:53:38,  2.11it/s]

{'loss': 0.4145, 'grad_norm': 7.7390055656433105, 'learning_rate': 4.9558088303191e-05, 'epoch': 0.03}


  1%|          | 680/75807 [05:26<9:54:19,  2.11it/s]

{'loss': 0.4861, 'grad_norm': 4.195108413696289, 'learning_rate': 4.95514926062237e-05, 'epoch': 0.03}


  1%|          | 690/75807 [05:31<9:56:08,  2.10it/s]

{'loss': 0.3695, 'grad_norm': 3.8023364543914795, 'learning_rate': 4.9544896909256404e-05, 'epoch': 0.03}


  1%|          | 700/75807 [05:36<9:50:46,  2.12it/s]

{'loss': 0.5209, 'grad_norm': 9.44452953338623, 'learning_rate': 4.9538301212289105e-05, 'epoch': 0.03}


  1%|          | 710/75807 [05:40<9:56:15,  2.10it/s] 

{'loss': 0.386, 'grad_norm': 6.348776340484619, 'learning_rate': 4.953170551532181e-05, 'epoch': 0.03}


  1%|          | 720/75807 [05:45<9:58:01,  2.09it/s] 

{'loss': 0.4473, 'grad_norm': 4.120762348175049, 'learning_rate': 4.952510981835451e-05, 'epoch': 0.03}


  1%|          | 730/75807 [05:50<9:57:30,  2.09it/s] 

{'loss': 0.4586, 'grad_norm': 5.786841869354248, 'learning_rate': 4.951851412138721e-05, 'epoch': 0.03}


  1%|          | 740/75807 [05:55<10:03:43,  2.07it/s]

{'loss': 0.41, 'grad_norm': 2.9008710384368896, 'learning_rate': 4.951191842441991e-05, 'epoch': 0.03}


  1%|          | 750/75807 [05:59<9:48:51,  2.12it/s] 

{'loss': 0.415, 'grad_norm': 6.219066619873047, 'learning_rate': 4.950532272745261e-05, 'epoch': 0.03}


  1%|          | 760/75807 [06:04<9:47:27,  2.13it/s]

{'loss': 0.3948, 'grad_norm': 4.372376918792725, 'learning_rate': 4.9498727030485314e-05, 'epoch': 0.03}


  1%|          | 770/75807 [06:09<9:53:16,  2.11it/s]

{'loss': 0.4393, 'grad_norm': 10.332174301147461, 'learning_rate': 4.9492131333518015e-05, 'epoch': 0.03}


  1%|          | 780/75807 [06:14<9:49:16,  2.12it/s]

{'loss': 0.4308, 'grad_norm': 3.896419048309326, 'learning_rate': 4.9485535636550716e-05, 'epoch': 0.03}


  1%|          | 790/75807 [06:18<9:54:17,  2.10it/s]

{'loss': 0.3163, 'grad_norm': 6.662471294403076, 'learning_rate': 4.947893993958342e-05, 'epoch': 0.03}


  1%|          | 800/75807 [06:23<9:48:24,  2.12it/s]

{'loss': 0.481, 'grad_norm': 7.32781982421875, 'learning_rate': 4.947234424261612e-05, 'epoch': 0.03}


  1%|          | 810/75807 [06:28<9:53:29,  2.11it/s]

{'loss': 0.3093, 'grad_norm': 7.137081623077393, 'learning_rate': 4.946574854564882e-05, 'epoch': 0.03}


  1%|          | 820/75807 [06:33<9:50:13,  2.12it/s]

{'loss': 0.4208, 'grad_norm': 4.2583088874816895, 'learning_rate': 4.945915284868152e-05, 'epoch': 0.03}


  1%|          | 830/75807 [06:37<10:04:16,  2.07it/s]

{'loss': 0.529, 'grad_norm': 7.997827053070068, 'learning_rate': 4.945255715171422e-05, 'epoch': 0.03}


  1%|          | 840/75807 [06:42<10:01:44,  2.08it/s]

{'loss': 0.419, 'grad_norm': 4.38381290435791, 'learning_rate': 4.9445961454746925e-05, 'epoch': 0.03}


  1%|          | 850/75807 [06:47<9:48:51,  2.12it/s] 

{'loss': 0.4546, 'grad_norm': 10.95750904083252, 'learning_rate': 4.9439365757779626e-05, 'epoch': 0.03}


  1%|          | 860/75807 [06:52<9:51:43,  2.11it/s]

{'loss': 0.4884, 'grad_norm': 5.179226398468018, 'learning_rate': 4.943277006081233e-05, 'epoch': 0.03}


  1%|          | 870/75807 [06:56<9:49:52,  2.12it/s]

{'loss': 0.401, 'grad_norm': 4.9905877113342285, 'learning_rate': 4.942617436384503e-05, 'epoch': 0.03}


  1%|          | 880/75807 [07:01<9:47:25,  2.13it/s] 

{'loss': 0.3986, 'grad_norm': 3.997645139694214, 'learning_rate': 4.941957866687773e-05, 'epoch': 0.03}


  1%|          | 890/75807 [07:06<9:51:31,  2.11it/s] 

{'loss': 0.3394, 'grad_norm': 15.663834571838379, 'learning_rate': 4.941298296991043e-05, 'epoch': 0.04}


  1%|          | 900/75807 [07:11<9:47:52,  2.12it/s]

{'loss': 0.3866, 'grad_norm': 4.251307487487793, 'learning_rate': 4.940638727294313e-05, 'epoch': 0.04}


  1%|          | 910/75807 [07:15<9:50:17,  2.11it/s]

{'loss': 0.56, 'grad_norm': 4.424576759338379, 'learning_rate': 4.9399791575975834e-05, 'epoch': 0.04}


  1%|          | 920/75807 [07:20<9:42:25,  2.14it/s]

{'loss': 0.4015, 'grad_norm': 2.969726085662842, 'learning_rate': 4.9393195879008536e-05, 'epoch': 0.04}


  1%|          | 930/75807 [07:25<9:50:26,  2.11it/s]

{'loss': 0.3669, 'grad_norm': 9.54753303527832, 'learning_rate': 4.938660018204124e-05, 'epoch': 0.04}


  1%|          | 940/75807 [07:29<9:54:07,  2.10it/s] 

{'loss': 0.4699, 'grad_norm': 5.0186309814453125, 'learning_rate': 4.938000448507394e-05, 'epoch': 0.04}


  1%|▏         | 950/75807 [07:34<9:51:21,  2.11it/s]

{'loss': 0.3476, 'grad_norm': 8.797258377075195, 'learning_rate': 4.937340878810664e-05, 'epoch': 0.04}


  1%|▏         | 960/75807 [07:39<9:51:38,  2.11it/s]

{'loss': 0.4279, 'grad_norm': 8.412954330444336, 'learning_rate': 4.936681309113934e-05, 'epoch': 0.04}


  1%|▏         | 970/75807 [07:44<9:52:41,  2.10it/s] 

{'loss': 0.5034, 'grad_norm': 5.756533145904541, 'learning_rate': 4.936021739417204e-05, 'epoch': 0.04}


  1%|▏         | 980/75807 [07:48<9:55:45,  2.09it/s]

{'loss': 0.336, 'grad_norm': 3.007363796234131, 'learning_rate': 4.9353621697204744e-05, 'epoch': 0.04}


  1%|▏         | 990/75807 [07:53<9:51:11,  2.11it/s]

{'loss': 0.4114, 'grad_norm': 9.133577346801758, 'learning_rate': 4.9347026000237445e-05, 'epoch': 0.04}


  1%|▏         | 1000/75807 [07:58<9:50:19,  2.11it/s]

{'loss': 0.2813, 'grad_norm': 3.4141125679016113, 'learning_rate': 4.9340430303270147e-05, 'epoch': 0.04}


  1%|▏         | 1010/75807 [08:05<10:21:52,  2.00it/s]

{'loss': 0.4421, 'grad_norm': 10.622652053833008, 'learning_rate': 4.933383460630285e-05, 'epoch': 0.04}


  1%|▏         | 1020/75807 [08:10<9:48:44,  2.12it/s] 

{'loss': 0.4726, 'grad_norm': 11.064596176147461, 'learning_rate': 4.932723890933555e-05, 'epoch': 0.04}


  1%|▏         | 1030/75807 [08:15<9:46:37,  2.12it/s]

{'loss': 0.4475, 'grad_norm': 10.778032302856445, 'learning_rate': 4.932064321236826e-05, 'epoch': 0.04}


  1%|▏         | 1040/75807 [08:19<9:53:58,  2.10it/s]

{'loss': 0.4435, 'grad_norm': 4.035192012786865, 'learning_rate': 4.931404751540096e-05, 'epoch': 0.04}


  1%|▏         | 1050/75807 [08:24<9:49:17,  2.11it/s]

{'loss': 0.4759, 'grad_norm': 14.861085891723633, 'learning_rate': 4.930745181843366e-05, 'epoch': 0.04}


  1%|▏         | 1060/75807 [08:29<9:45:40,  2.13it/s]

{'loss': 0.4348, 'grad_norm': 4.7510480880737305, 'learning_rate': 4.930085612146636e-05, 'epoch': 0.04}


  1%|▏         | 1070/75807 [08:34<9:58:09,  2.08it/s] 

{'loss': 0.4484, 'grad_norm': 3.5921289920806885, 'learning_rate': 4.929426042449906e-05, 'epoch': 0.04}


  1%|▏         | 1080/75807 [08:38<9:45:47,  2.13it/s]

{'loss': 0.4052, 'grad_norm': 14.381134986877441, 'learning_rate': 4.9287664727531764e-05, 'epoch': 0.04}


  1%|▏         | 1090/75807 [08:43<9:47:43,  2.12it/s]

{'loss': 0.4454, 'grad_norm': 3.9669134616851807, 'learning_rate': 4.9281069030564466e-05, 'epoch': 0.04}


  1%|▏         | 1100/75807 [08:48<9:47:12,  2.12it/s]

{'loss': 0.4363, 'grad_norm': 5.404374599456787, 'learning_rate': 4.927447333359717e-05, 'epoch': 0.04}


  1%|▏         | 1110/75807 [08:53<9:54:14,  2.10it/s]

{'loss': 0.4668, 'grad_norm': 4.229330539703369, 'learning_rate': 4.926787763662987e-05, 'epoch': 0.04}


  1%|▏         | 1120/75807 [08:57<9:46:10,  2.12it/s]

{'loss': 0.4545, 'grad_norm': 8.021880149841309, 'learning_rate': 4.926128193966257e-05, 'epoch': 0.04}


  1%|▏         | 1130/75807 [09:02<9:40:57,  2.14it/s]

{'loss': 0.3871, 'grad_norm': 7.944655418395996, 'learning_rate': 4.925468624269527e-05, 'epoch': 0.04}


  2%|▏         | 1140/75807 [09:07<9:40:32,  2.14it/s]

{'loss': 0.39, 'grad_norm': 4.467078685760498, 'learning_rate': 4.924809054572797e-05, 'epoch': 0.05}


  2%|▏         | 1150/75807 [09:11<9:43:33,  2.13it/s]

{'loss': 0.407, 'grad_norm': 3.1791136264801025, 'learning_rate': 4.9241494848760674e-05, 'epoch': 0.05}


  2%|▏         | 1160/75807 [09:16<9:41:47,  2.14it/s]

{'loss': 0.4126, 'grad_norm': 7.0518412590026855, 'learning_rate': 4.9234899151793375e-05, 'epoch': 0.05}


  2%|▏         | 1170/75807 [09:21<9:41:45,  2.14it/s]

{'loss': 0.5006, 'grad_norm': 20.635051727294922, 'learning_rate': 4.922830345482608e-05, 'epoch': 0.05}


  2%|▏         | 1180/75807 [09:25<9:43:25,  2.13it/s]

{'loss': 0.4703, 'grad_norm': 4.450615406036377, 'learning_rate': 4.922170775785878e-05, 'epoch': 0.05}


  2%|▏         | 1190/75807 [09:30<10:00:16,  2.07it/s]

{'loss': 0.4454, 'grad_norm': 3.233123302459717, 'learning_rate': 4.921511206089148e-05, 'epoch': 0.05}


  2%|▏         | 1200/75807 [09:35<9:47:32,  2.12it/s] 

{'loss': 0.4746, 'grad_norm': 23.41254997253418, 'learning_rate': 4.920851636392418e-05, 'epoch': 0.05}


  2%|▏         | 1210/75807 [09:40<9:53:41,  2.09it/s]

{'loss': 0.4492, 'grad_norm': 18.489803314208984, 'learning_rate': 4.920192066695688e-05, 'epoch': 0.05}


  2%|▏         | 1220/75807 [09:44<9:48:30,  2.11it/s]

{'loss': 0.5383, 'grad_norm': 12.814064979553223, 'learning_rate': 4.9195324969989584e-05, 'epoch': 0.05}


  2%|▏         | 1230/75807 [09:49<9:50:57,  2.10it/s]

{'loss': 0.3363, 'grad_norm': 6.220242500305176, 'learning_rate': 4.9188729273022285e-05, 'epoch': 0.05}


  2%|▏         | 1236/75807 [09:52<9:44:59,  2.12it/s]

KeyboardInterrupt: 
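
The numbers in the log above can be reproduced with a little arithmetic. The sketch below assumes a single device, no gradient accumulation, and the `TrainingArguments` defaults that the notebook does not override (a learning rate of 5e-5 with linear decay and no warmup); `linear_lr` is a hypothetical helper, not a `transformers` function.

```python
import math

num_rows = 404290   # dataset size reported above
batch_size = 16     # per_device_train_batch_size
num_epochs = 3      # num_train_epochs
base_lr = 5e-5      # assumed default learning rate (not set explicitly above)

steps_per_epoch = math.ceil(num_rows / batch_size)
total_steps = steps_per_epoch * num_epochs

def linear_lr(step):
    """Linear decay from base_lr to 0 over total_steps, with no warmup."""
    return base_lr * (total_steps - step) / total_steps

print(total_steps)     # 75807, the total shown in the progress bar
print(linear_lr(10))   # close to the learning_rate logged at step 10
```

At about 2.1 iterations per second, 75807 steps take roughly 10 hours, which matches the progress bar's estimate.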