In [34]:
import json
from io import open

def preprocess_data(file_path):
    train_data = {}

    with open(file_path, 'r') as file:
        i = 0
        for line in file:
            # Parse the line as JSON
            train_data[i] = json.loads(line)
            i+=1

    inputs, targets = [], []

    for key, instance in train_data.items():
        if isinstance(instance, dict) and 'messages' in instance:
            messages = instance['messages']
            game_scores = instance['game_score']

        for idx, (message, score) in enumerate(zip(messages, game_scores)):
            start_index = max(0, idx-5)
            end_index = min(len(messages)-1, idx)
            context = " ".join(messages[start_index:end_index])
            inputs.append(f"Given the following context, generate an appropriate repsonse: {context}")
            targets.append(message)
    return inputs, targets

training_inputs, training_targets = preprocess_data("2020_acl_diplomacy/data/train.jsonl")
validation_inputs, validation_targets = preprocess_data("2020_acl_diplomacy/data/validation.jsonl")
test_inputs, test_targets = preprocess_data("2020_acl_diplomacy/data/test.jsonl")

print(len(training_inputs)) 
print(len(training_targets))
print(training_inputs[3])
print(training_targets[3])

13132
13132
Given the following context, generate an appropriate repsonse: Germany!

Just the person I want to speak with. I have a somewhat crazy idea that I’ve always wanted to try with I/G, but I’ve never actually convinced the other guy to try it. And, what’s worse, it might make you suspicious of me. 

So...do I suggest it?

I’m thinking that this is a low stakes game, not a tournament or anything, and an interesting and unusual move set might make it more fun? That’s my hope anyway.

What is your appetite like for unusual and crazy? You've whet my appetite, Italy. What's the suggestion? 👍
It seems like there are a lot of ways that could go wrong...I don't see why France would see you approaching/taking Munich--while I do nothing about it--and not immediately feel skittish


In [35]:
%pip install transformers
%pip install torch

from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset, DataLoader

class DiplomacyDataset(Dataset):
    def __init__(self, data, targets, tokenizer, max_len=512):
        self.data = data
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        tokenized_input = self.tokenizer.encode(self.data[index], max_length=self.max_len, truncation=True, padding='max_length', return_tensors='pt')
        tokenized_target = self.tokenizer.encode(self.targets[index], max_length=self.max_len, truncation=True, padding='max_length', return_tensors='pt')
        item = {
            'input_ids': tokenized_input.flatten(),
            # 'decoder_input_ids': tokenized_target.flatten()[:-1], # remove the last token of the target
            'labels': tokenized_target.flatten()[1:],
        }
        return item

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

train_dataset = DiplomacyDataset(training_inputs, training_targets, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

validation_dataset = DiplomacyDataset(validation_inputs, validation_targets, tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=4, shuffle=True)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [36]:
%pip install accelerate -U
import accelerate


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

for epoch in range(3):
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch}, Loss: {loss.item()}")
        


# training_args = TrainingArguments(
#     output_dir='./results',
#     evaluation_strategy='epoch',
#     num_train_epochs=3,
#     learning_rate=3e-4,
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     warmup_steps=500,
#     weight_decay=0.01,
#     logging_dir='./logs',
#     logging_steps=10
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=validation_dataset
# )



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Epoch: 0, Loss: 2.521228790283203
Epoch: 0, Loss: 2.0196661949157715
Epoch: 0, Loss: 2.0782434940338135
Epoch: 0, Loss: 1.8016599416732788
Epoch: 0, Loss: 1.8211919069290161
Epoch: 0, Loss: 1.907079815864563
Epoch: 0, Loss: 1.5171607732772827
Epoch: 0, Loss: 1.6586593389511108
Epoch: 0, Loss: 1.2573260068893433
Epoch: 0, Loss: 1.2407628297805786
Epoch: 0, Loss: 0.9981366395950317
Epoch: 0, Loss: 0.9253311157226562
Epoch: 0, Loss: 0.6820459365844727
Epoch: 0, Loss: 0.6176517009735107
Epoch: 0, Loss: 1.1860613822937012
Epoch: 0, Loss: 0.5721552968025208
Epoch: 0, Loss: 0.3700699210166931
Epoch: 0, Loss: 0.6948256492614746
Epoch: 0, Loss: 0.31539386510849
Epoch: 0, Loss: 0.2

KeyboardInterrupt: 

In [None]:
# trainer.train()

  0%|          | 0/3283 [36:15<?, ?it/s]
  0%|          | 0/3283 [25:26<?, ?it/s]
  0%|          | 10/3283 [00:18<1:05:46,  1.21s/it]

{'loss': 3.1059, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.0}


  1%|          | 20/3283 [00:29<59:55,  1.10s/it]  

{'loss': 2.4386, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.01}


  1%|          | 30/3283 [00:40<1:00:07,  1.11s/it]

{'loss': 2.9373, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.01}


  1%|          | 40/3283 [00:51<1:00:39,  1.12s/it]

{'loss': 3.7704, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.01}


  2%|▏         | 50/3283 [01:02<59:33,  1.11s/it]  

{'loss': 3.4283, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.02}


  2%|▏         | 60/3283 [01:14<1:00:46,  1.13s/it]

{'loss': 3.3524, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.02}


  2%|▏         | 70/3283 [01:25<1:02:05,  1.16s/it]

{'loss': 2.5103, 'learning_rate': 2.8000000000000003e-06, 'epoch': 0.02}


  2%|▏         | 80/3283 [01:37<1:02:05,  1.16s/it]

{'loss': 2.6416, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.02}


  3%|▎         | 90/3283 [01:48<1:01:37,  1.16s/it]

{'loss': 2.6455, 'learning_rate': 3.6000000000000003e-06, 'epoch': 0.03}


  3%|▎         | 100/3283 [02:00<1:00:25,  1.14s/it]

{'loss': 3.2336, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.03}


  3%|▎         | 110/3283 [02:11<1:01:28,  1.16s/it]

{'loss': 2.6843, 'learning_rate': 4.4e-06, 'epoch': 0.03}


  4%|▎         | 120/3283 [02:23<1:01:33,  1.17s/it]

{'loss': 2.9246, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.04}


  4%|▍         | 130/3283 [02:35<1:00:02,  1.14s/it]

{'loss': 2.818, 'learning_rate': 5.2e-06, 'epoch': 0.04}


  4%|▍         | 140/3283 [02:46<59:34,  1.14s/it]  

{'loss': 2.853, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.04}


  5%|▍         | 150/3283 [02:57<59:41,  1.14s/it]  

{'loss': 3.067, 'learning_rate': 6e-06, 'epoch': 0.05}


  5%|▍         | 160/3283 [03:09<59:23,  1.14s/it]  

{'loss': 2.1538, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.05}


  5%|▌         | 170/3283 [03:20<58:11,  1.12s/it]

{'loss': 2.2841, 'learning_rate': 6.800000000000001e-06, 'epoch': 0.05}


  5%|▌         | 180/3283 [03:32<59:10,  1.14s/it]

{'loss': 2.037, 'learning_rate': 7.2000000000000005e-06, 'epoch': 0.05}


  6%|▌         | 190/3283 [03:43<1:00:05,  1.17s/it]

{'loss': 1.6792, 'learning_rate': 7.600000000000001e-06, 'epoch': 0.06}


  6%|▌         | 200/3283 [03:55<58:47,  1.14s/it]  

{'loss': 2.1718, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.06}


  6%|▋         | 210/3283 [04:06<58:55,  1.15s/it]

{'loss': 1.6793, 'learning_rate': 8.400000000000001e-06, 'epoch': 0.06}


  7%|▋         | 220/3283 [04:18<59:06,  1.16s/it]

{'loss': 1.4591, 'learning_rate': 8.8e-06, 'epoch': 0.07}


  7%|▋         | 230/3283 [04:29<57:58,  1.14s/it]

{'loss': 1.4309, 'learning_rate': 9.200000000000002e-06, 'epoch': 0.07}


  7%|▋         | 240/3283 [04:41<58:54,  1.16s/it]

{'loss': 1.6299, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.07}


  8%|▊         | 250/3283 [04:52<57:48,  1.14s/it]

{'loss': 0.8464, 'learning_rate': 1e-05, 'epoch': 0.08}


  8%|▊         | 260/3283 [05:03<57:15,  1.14s/it]

{'loss': 1.3337, 'learning_rate': 1.04e-05, 'epoch': 0.08}


  8%|▊         | 270/3283 [05:15<57:21,  1.14s/it]

{'loss': 0.5891, 'learning_rate': 1.0800000000000002e-05, 'epoch': 0.08}


  9%|▊         | 280/3283 [05:27<57:23,  1.15s/it]

{'loss': 0.6414, 'learning_rate': 1.1200000000000001e-05, 'epoch': 0.09}


  9%|▉         | 290/3283 [05:38<57:37,  1.16s/it]

{'loss': 0.6666, 'learning_rate': 1.16e-05, 'epoch': 0.09}


  9%|▉         | 300/3283 [05:50<57:32,  1.16s/it]

{'loss': 0.3842, 'learning_rate': 1.2e-05, 'epoch': 0.09}


  9%|▉         | 310/3283 [06:02<57:17,  1.16s/it]

{'loss': 0.6727, 'learning_rate': 1.2400000000000002e-05, 'epoch': 0.09}


 10%|▉         | 320/3283 [06:13<57:47,  1.17s/it]

{'loss': 0.5515, 'learning_rate': 1.2800000000000001e-05, 'epoch': 0.1}


 10%|█         | 330/3283 [06:25<56:44,  1.15s/it]

{'loss': 0.5084, 'learning_rate': 1.3200000000000002e-05, 'epoch': 0.1}


 10%|█         | 340/3283 [06:36<54:49,  1.12s/it]

{'loss': 0.4077, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.1}


 11%|█         | 350/3283 [06:48<56:31,  1.16s/it]

{'loss': 0.4453, 'learning_rate': 1.4e-05, 'epoch': 0.11}


 11%|█         | 360/3283 [06:59<55:44,  1.14s/it]

{'loss': 0.3302, 'learning_rate': 1.4400000000000001e-05, 'epoch': 0.11}


 11%|█▏        | 370/3283 [07:10<54:15,  1.12s/it]

{'loss': 0.5269, 'learning_rate': 1.48e-05, 'epoch': 0.11}


 12%|█▏        | 380/3283 [07:22<57:26,  1.19s/it]

{'loss': 0.3889, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.12}


 12%|█▏        | 390/3283 [07:34<55:08,  1.14s/it]

{'loss': 0.3839, 'learning_rate': 1.5600000000000003e-05, 'epoch': 0.12}


 12%|█▏        | 400/3283 [07:45<53:38,  1.12s/it]

{'loss': 0.3285, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.12}


 12%|█▏        | 410/3283 [07:57<55:39,  1.16s/it]

{'loss': 0.3101, 'learning_rate': 1.64e-05, 'epoch': 0.12}


 13%|█▎        | 420/3283 [08:08<54:09,  1.14s/it]

{'loss': 0.3796, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.13}


 13%|█▎        | 430/3283 [08:19<54:25,  1.14s/it]

{'loss': 0.236, 'learning_rate': 1.72e-05, 'epoch': 0.13}


 13%|█▎        | 440/3283 [08:31<53:53,  1.14s/it]

{'loss': 0.2318, 'learning_rate': 1.76e-05, 'epoch': 0.13}


 14%|█▎        | 450/3283 [08:42<54:49,  1.16s/it]

{'loss': 0.2312, 'learning_rate': 1.8e-05, 'epoch': 0.14}


 14%|█▍        | 460/3283 [08:54<54:08,  1.15s/it]

{'loss': 0.2845, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.14}


 14%|█▍        | 470/3283 [09:06<54:07,  1.15s/it]

{'loss': 0.2443, 'learning_rate': 1.88e-05, 'epoch': 0.14}


 15%|█▍        | 480/3283 [09:17<54:15,  1.16s/it]

{'loss': 0.4362, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.15}


 15%|█▍        | 490/3283 [09:29<53:53,  1.16s/it]

{'loss': 0.2693, 'learning_rate': 1.9600000000000002e-05, 'epoch': 0.15}


 15%|█▌        | 500/3283 [09:40<52:47,  1.14s/it]

{'loss': 0.2067, 'learning_rate': 2e-05, 'epoch': 0.15}


 16%|█▌        | 510/3283 [09:53<53:19,  1.15s/it]  

{'loss': 0.3408, 'learning_rate': 1.992813510600072e-05, 'epoch': 0.16}


 16%|█▌        | 520/3283 [10:04<52:35,  1.14s/it]

{'loss': 0.2502, 'learning_rate': 1.985627021200144e-05, 'epoch': 0.16}


 16%|█▌        | 530/3283 [10:16<53:01,  1.16s/it]

{'loss': 0.341, 'learning_rate': 1.978440531800216e-05, 'epoch': 0.16}


 16%|█▋        | 540/3283 [10:27<52:21,  1.15s/it]

{'loss': 0.2189, 'learning_rate': 1.9712540424002878e-05, 'epoch': 0.16}


 17%|█▋        | 550/3283 [10:38<51:06,  1.12s/it]

{'loss': 0.2287, 'learning_rate': 1.9640675530003594e-05, 'epoch': 0.17}


 17%|█▋        | 560/3283 [10:50<51:09,  1.13s/it]

{'loss': 0.252, 'learning_rate': 1.9568810636004313e-05, 'epoch': 0.17}


 17%|█▋        | 570/3283 [11:01<51:47,  1.15s/it]

{'loss': 0.2669, 'learning_rate': 1.9496945742005032e-05, 'epoch': 0.17}


 18%|█▊        | 580/3283 [11:13<51:53,  1.15s/it]

{'loss': 0.2708, 'learning_rate': 1.942508084800575e-05, 'epoch': 0.18}


 18%|█▊        | 590/3283 [11:24<52:42,  1.17s/it]

{'loss': 0.2122, 'learning_rate': 1.935321595400647e-05, 'epoch': 0.18}


 18%|█▊        | 600/3283 [11:36<51:12,  1.15s/it]

{'loss': 0.2947, 'learning_rate': 1.9281351060007186e-05, 'epoch': 0.18}


 19%|█▊        | 610/3283 [11:47<49:42,  1.12s/it]

{'loss': 0.1728, 'learning_rate': 1.920948616600791e-05, 'epoch': 0.19}


 19%|█▉        | 620/3283 [11:58<50:13,  1.13s/it]

{'loss': 0.2831, 'learning_rate': 1.9137621272008627e-05, 'epoch': 0.19}


 19%|█▉        | 630/3283 [12:09<49:43,  1.12s/it]

{'loss': 0.3148, 'learning_rate': 1.9065756378009343e-05, 'epoch': 0.19}


 19%|█▉        | 640/3283 [12:21<49:54,  1.13s/it]

{'loss': 0.2555, 'learning_rate': 1.8993891484010062e-05, 'epoch': 0.19}


 20%|█▉        | 650/3283 [12:32<50:06,  1.14s/it]

{'loss': 0.2558, 'learning_rate': 1.892202659001078e-05, 'epoch': 0.2}


 20%|██        | 660/3283 [12:44<49:27,  1.13s/it]

{'loss': 0.3003, 'learning_rate': 1.88501616960115e-05, 'epoch': 0.2}


 20%|██        | 670/3283 [12:55<50:24,  1.16s/it]

{'loss': 0.2244, 'learning_rate': 1.877829680201222e-05, 'epoch': 0.2}


 21%|██        | 680/3283 [13:06<50:13,  1.16s/it]

{'loss': 0.3553, 'learning_rate': 1.8706431908012935e-05, 'epoch': 0.21}


 21%|██        | 690/3283 [13:18<49:27,  1.14s/it]

{'loss': 0.2542, 'learning_rate': 1.8634567014013654e-05, 'epoch': 0.21}


 21%|██▏       | 700/3283 [13:29<48:45,  1.13s/it]

{'loss': 0.2608, 'learning_rate': 1.8562702120014377e-05, 'epoch': 0.21}


 22%|██▏       | 710/3283 [13:41<47:44,  1.11s/it]

{'loss': 0.2649, 'learning_rate': 1.8490837226015093e-05, 'epoch': 0.22}


 22%|██▏       | 720/3283 [13:52<48:12,  1.13s/it]

{'loss': 0.2771, 'learning_rate': 1.8418972332015812e-05, 'epoch': 0.22}


 22%|██▏       | 730/3283 [14:03<48:33,  1.14s/it]

{'loss': 0.2248, 'learning_rate': 1.834710743801653e-05, 'epoch': 0.22}


 23%|██▎       | 740/3283 [14:15<49:19,  1.16s/it]

{'loss': 0.3157, 'learning_rate': 1.827524254401725e-05, 'epoch': 0.23}


 23%|██▎       | 750/3283 [14:26<47:31,  1.13s/it]

{'loss': 0.3193, 'learning_rate': 1.820337765001797e-05, 'epoch': 0.23}


 23%|██▎       | 760/3283 [14:37<48:21,  1.15s/it]

{'loss': 0.2078, 'learning_rate': 1.8131512756018685e-05, 'epoch': 0.23}


 23%|██▎       | 770/3283 [14:49<47:10,  1.13s/it]

{'loss': 0.2115, 'learning_rate': 1.8059647862019404e-05, 'epoch': 0.23}


 24%|██▍       | 780/3283 [15:00<49:01,  1.18s/it]

{'loss': 0.2383, 'learning_rate': 1.7987782968020123e-05, 'epoch': 0.24}


 24%|██▍       | 790/3283 [17:23<8:06:07, 11.70s/it] 

{'loss': 0.2824, 'learning_rate': 1.7915918074020842e-05, 'epoch': 0.24}


 24%|██▍       | 800/3283 [17:35<58:57,  1.42s/it]  

{'loss': 0.2615, 'learning_rate': 1.784405318002156e-05, 'epoch': 0.24}


 25%|██▍       | 810/3283 [19:41<7:29:43, 10.91s/it] 

{'loss': 0.38, 'learning_rate': 1.777218828602228e-05, 'epoch': 0.25}


 25%|██▍       | 820/3283 [20:54<9:51:12, 14.40s/it] 

{'loss': 0.2765, 'learning_rate': 1.7700323392023e-05, 'epoch': 0.25}


 25%|██▌       | 830/3283 [21:05<58:29,  1.43s/it]  

{'loss': 0.2247, 'learning_rate': 1.762845849802372e-05, 'epoch': 0.25}


 26%|██▌       | 840/3283 [25:06<47:49:27, 70.47s/it]

{'loss': 0.2467, 'learning_rate': 1.7556593604024434e-05, 'epoch': 0.26}


 26%|██▌       | 850/3283 [25:17<2:02:56,  3.03s/it] 

{'loss': 0.2125, 'learning_rate': 1.7484728710025153e-05, 'epoch': 0.26}


 26%|██▌       | 860/3283 [33:15<4:33:44,  6.78s/it]  

{'loss': 0.2495, 'learning_rate': 1.7412863816025872e-05, 'epoch': 0.26}


 27%|██▋       | 870/3283 [39:12<4:57:42,  7.40s/it]  

{'loss': 0.2233, 'learning_rate': 1.734099892202659e-05, 'epoch': 0.27}


 27%|██▋       | 880/3283 [39:27<1:22:16,  2.05s/it]

{'loss': 0.2078, 'learning_rate': 1.726913402802731e-05, 'epoch': 0.27}


 27%|██▋       | 890/3283 [39:39<45:55,  1.15s/it]  

{'loss': 0.1765, 'learning_rate': 1.719726913402803e-05, 'epoch': 0.27}


 27%|██▋       | 900/3283 [39:50<45:24,  1.14s/it]

{'loss': 0.2546, 'learning_rate': 1.7125404240028745e-05, 'epoch': 0.27}


 28%|██▊       | 910/3283 [40:01<45:33,  1.15s/it]

{'loss': 0.2531, 'learning_rate': 1.7053539346029468e-05, 'epoch': 0.28}


 28%|██▊       | 920/3283 [40:13<44:29,  1.13s/it]

{'loss': 0.2245, 'learning_rate': 1.6981674452030187e-05, 'epoch': 0.28}


 28%|██▊       | 930/3283 [40:24<44:13,  1.13s/it]

{'loss': 0.3177, 'learning_rate': 1.6909809558030903e-05, 'epoch': 0.28}


 29%|██▊       | 940/3283 [40:36<46:06,  1.18s/it]

{'loss': 0.1984, 'learning_rate': 1.6837944664031622e-05, 'epoch': 0.29}


 29%|██▉       | 950/3283 [40:47<44:56,  1.16s/it]

{'loss': 0.239, 'learning_rate': 1.676607977003234e-05, 'epoch': 0.29}


 29%|██▉       | 960/3283 [40:59<43:39,  1.13s/it]

{'loss': 0.2977, 'learning_rate': 1.669421487603306e-05, 'epoch': 0.29}


 30%|██▉       | 970/3283 [41:10<43:22,  1.13s/it]

{'loss': 0.3031, 'learning_rate': 1.662234998203378e-05, 'epoch': 0.3}


 30%|██▉       | 980/3283 [41:21<42:14,  1.10s/it]

{'loss': 0.3205, 'learning_rate': 1.6550485088034495e-05, 'epoch': 0.3}


 30%|███       | 990/3283 [41:32<43:13,  1.13s/it]

{'loss': 0.279, 'learning_rate': 1.6478620194035214e-05, 'epoch': 0.3}


 30%|███       | 1000/3283 [41:43<41:26,  1.09s/it]

{'loss': 0.4871, 'learning_rate': 1.6406755300035933e-05, 'epoch': 0.3}


 31%|███       | 1010/3283 [41:56<42:38,  1.13s/it]

{'loss': 0.2668, 'learning_rate': 1.6334890406036652e-05, 'epoch': 0.31}


 31%|███       | 1020/3283 [42:06<40:02,  1.06s/it]

{'loss': 0.3376, 'learning_rate': 1.626302551203737e-05, 'epoch': 0.31}


 31%|███▏      | 1030/3283 [42:17<40:15,  1.07s/it]

{'loss': 0.2283, 'learning_rate': 1.6191160618038087e-05, 'epoch': 0.31}


 32%|███▏      | 1040/3283 [42:28<39:57,  1.07s/it]

{'loss': 0.2251, 'learning_rate': 1.611929572403881e-05, 'epoch': 0.32}


 32%|███▏      | 1050/3283 [42:39<40:49,  1.10s/it]

{'loss': 0.2307, 'learning_rate': 1.604743083003953e-05, 'epoch': 0.32}


 32%|███▏      | 1060/3283 [42:50<41:59,  1.13s/it]

{'loss': 0.2253, 'learning_rate': 1.5975565936040244e-05, 'epoch': 0.32}


 33%|███▎      | 1070/3283 [43:01<41:44,  1.13s/it]

{'loss': 0.2611, 'learning_rate': 1.5903701042040963e-05, 'epoch': 0.33}


 33%|███▎      | 1080/3283 [43:12<39:47,  1.08s/it]

{'loss': 0.1615, 'learning_rate': 1.5831836148041683e-05, 'epoch': 0.33}


 33%|███▎      | 1090/3283 [43:23<40:49,  1.12s/it]

{'loss': 0.1962, 'learning_rate': 1.57599712540424e-05, 'epoch': 0.33}


 34%|███▎      | 1100/3283 [43:34<40:07,  1.10s/it]

{'loss': 0.2191, 'learning_rate': 1.568810636004312e-05, 'epoch': 0.34}


 34%|███▍      | 1110/3283 [43:45<38:39,  1.07s/it]

{'loss': 0.2284, 'learning_rate': 1.561624146604384e-05, 'epoch': 0.34}


 34%|███▍      | 1120/3283 [43:56<40:03,  1.11s/it]

{'loss': 0.2794, 'learning_rate': 1.5544376572044556e-05, 'epoch': 0.34}


 34%|███▍      | 1130/3283 [44:07<38:47,  1.08s/it]

{'loss': 0.3038, 'learning_rate': 1.5472511678045278e-05, 'epoch': 0.34}


 35%|███▍      | 1140/3283 [44:18<39:01,  1.09s/it]

{'loss': 0.1632, 'learning_rate': 1.5400646784045994e-05, 'epoch': 0.35}


 35%|███▌      | 1150/3283 [44:29<38:07,  1.07s/it]

{'loss': 0.1969, 'learning_rate': 1.5328781890046713e-05, 'epoch': 0.35}


 35%|███▌      | 1160/3283 [44:40<38:00,  1.07s/it]

{'loss': 0.3509, 'learning_rate': 1.5256916996047434e-05, 'epoch': 0.35}


 36%|███▌      | 1170/3283 [44:51<39:49,  1.13s/it]

{'loss': 0.217, 'learning_rate': 1.518505210204815e-05, 'epoch': 0.36}


 36%|███▌      | 1180/3283 [45:02<38:50,  1.11s/it]

{'loss': 0.2125, 'learning_rate': 1.511318720804887e-05, 'epoch': 0.36}


 36%|███▌      | 1190/3283 [45:13<37:16,  1.07s/it]

{'loss': 0.1701, 'learning_rate': 1.504132231404959e-05, 'epoch': 0.36}


 37%|███▋      | 1200/3283 [45:23<39:07,  1.13s/it]

{'loss': 0.1812, 'learning_rate': 1.4969457420050307e-05, 'epoch': 0.37}


 37%|███▋      | 1210/3283 [45:34<37:01,  1.07s/it]

{'loss': 0.2777, 'learning_rate': 1.4897592526051026e-05, 'epoch': 0.37}


 37%|███▋      | 1220/3283 [45:45<37:54,  1.10s/it]

{'loss': 0.221, 'learning_rate': 1.4825727632051743e-05, 'epoch': 0.37}


 37%|███▋      | 1230/3283 [45:56<36:53,  1.08s/it]

{'loss': 0.2032, 'learning_rate': 1.4753862738052462e-05, 'epoch': 0.37}


 38%|███▊      | 1240/3283 [46:07<37:34,  1.10s/it]

{'loss': 0.2662, 'learning_rate': 1.4681997844053181e-05, 'epoch': 0.38}


 38%|███▊      | 1250/3283 [46:18<36:12,  1.07s/it]

{'loss': 0.2841, 'learning_rate': 1.4610132950053899e-05, 'epoch': 0.38}


 38%|███▊      | 1260/3283 [46:29<35:57,  1.07s/it]

{'loss': 0.2108, 'learning_rate': 1.4538268056054618e-05, 'epoch': 0.38}


 39%|███▊      | 1270/3283 [46:40<36:57,  1.10s/it]

{'loss': 0.2237, 'learning_rate': 1.4466403162055339e-05, 'epoch': 0.39}


 39%|███▉      | 1280/3283 [46:51<37:17,  1.12s/it]

{'loss': 0.1891, 'learning_rate': 1.4394538268056054e-05, 'epoch': 0.39}


 39%|███▉      | 1290/3283 [47:02<36:03,  1.09s/it]

{'loss': 0.2481, 'learning_rate': 1.4322673374056775e-05, 'epoch': 0.39}


 40%|███▉      | 1300/3283 [47:13<36:54,  1.12s/it]

{'loss': 0.1881, 'learning_rate': 1.4250808480057494e-05, 'epoch': 0.4}


 40%|███▉      | 1310/3283 [47:24<35:52,  1.09s/it]

{'loss': 0.3101, 'learning_rate': 1.4178943586058212e-05, 'epoch': 0.4}


 40%|████      | 1320/3283 [47:35<35:31,  1.09s/it]

{'loss': 0.3364, 'learning_rate': 1.410707869205893e-05, 'epoch': 0.4}


 41%|████      | 1330/3283 [47:45<34:57,  1.07s/it]

{'loss': 0.2345, 'learning_rate': 1.4035213798059648e-05, 'epoch': 0.41}


 41%|████      | 1340/3283 [47:57<35:40,  1.10s/it]

{'loss': 0.2097, 'learning_rate': 1.3963348904060367e-05, 'epoch': 0.41}


 41%|████      | 1350/3283 [48:08<35:03,  1.09s/it]

{'loss': 0.1791, 'learning_rate': 1.3891484010061086e-05, 'epoch': 0.41}


 41%|████▏     | 1360/3283 [48:19<34:48,  1.09s/it]

{'loss': 0.2274, 'learning_rate': 1.3819619116061804e-05, 'epoch': 0.41}


 42%|████▏     | 1370/3283 [48:30<34:28,  1.08s/it]

{'loss': 0.2618, 'learning_rate': 1.3747754222062523e-05, 'epoch': 0.42}


 42%|████▏     | 1380/3283 [48:40<34:45,  1.10s/it]

{'loss': 0.3009, 'learning_rate': 1.3675889328063244e-05, 'epoch': 0.42}


 42%|████▏     | 1390/3283 [48:52<34:52,  1.11s/it]

{'loss': 0.2715, 'learning_rate': 1.3604024434063961e-05, 'epoch': 0.42}


 43%|████▎     | 1400/3283 [49:02<33:55,  1.08s/it]

{'loss': 0.2056, 'learning_rate': 1.353215954006468e-05, 'epoch': 0.43}


 43%|████▎     | 1410/3283 [49:13<33:35,  1.08s/it]

{'loss': 0.2534, 'learning_rate': 1.3460294646065398e-05, 'epoch': 0.43}


 43%|████▎     | 1420/3283 [49:24<33:15,  1.07s/it]

{'loss': 0.2134, 'learning_rate': 1.3388429752066117e-05, 'epoch': 0.43}


 44%|████▎     | 1430/3283 [49:35<34:04,  1.10s/it]

{'loss': 0.253, 'learning_rate': 1.3316564858066836e-05, 'epoch': 0.44}


 44%|████▍     | 1440/3283 [49:46<33:29,  1.09s/it]

{'loss': 0.2722, 'learning_rate': 1.3244699964067553e-05, 'epoch': 0.44}


 44%|████▍     | 1450/3283 [49:57<33:37,  1.10s/it]

{'loss': 0.2464, 'learning_rate': 1.3172835070068272e-05, 'epoch': 0.44}


 44%|████▍     | 1460/3283 [50:09<33:46,  1.11s/it]

{'loss': 0.2786, 'learning_rate': 1.3100970176068992e-05, 'epoch': 0.44}


 45%|████▍     | 1470/3283 [50:19<32:02,  1.06s/it]

{'loss': 0.2376, 'learning_rate': 1.3029105282069709e-05, 'epoch': 0.45}


 45%|████▌     | 1480/3283 [50:31<32:55,  1.10s/it]

{'loss': 0.1352, 'learning_rate': 1.295724038807043e-05, 'epoch': 0.45}


 45%|████▌     | 1490/3283 [50:41<32:11,  1.08s/it]

{'loss': 0.2583, 'learning_rate': 1.2885375494071149e-05, 'epoch': 0.45}


 46%|████▌     | 1500/3283 [50:52<31:45,  1.07s/it]

{'loss': 0.1849, 'learning_rate': 1.2813510600071866e-05, 'epoch': 0.46}


 46%|████▌     | 1510/3283 [51:05<32:53,  1.11s/it]

{'loss': 0.2766, 'learning_rate': 1.2741645706072585e-05, 'epoch': 0.46}


 46%|████▋     | 1520/3283 [51:16<31:26,  1.07s/it]

{'loss': 0.2374, 'learning_rate': 1.2669780812073303e-05, 'epoch': 0.46}


 47%|████▋     | 1530/3283 [51:27<32:51,  1.12s/it]

{'loss': 0.1948, 'learning_rate': 1.2597915918074022e-05, 'epoch': 0.47}


 47%|████▋     | 1540/3283 [51:38<31:10,  1.07s/it]

{'loss': 0.2281, 'learning_rate': 1.2526051024074741e-05, 'epoch': 0.47}


 47%|████▋     | 1550/3283 [51:48<30:22,  1.05s/it]

{'loss': 0.2302, 'learning_rate': 1.2454186130075458e-05, 'epoch': 0.47}


 48%|████▊     | 1560/3283 [51:59<31:13,  1.09s/it]

{'loss': 0.3148, 'learning_rate': 1.2382321236076177e-05, 'epoch': 0.48}


 48%|████▊     | 1570/3283 [52:10<30:53,  1.08s/it]

{'loss': 0.2066, 'learning_rate': 1.2310456342076898e-05, 'epoch': 0.48}


 48%|████▊     | 1580/3283 [52:21<30:44,  1.08s/it]

{'loss': 0.2871, 'learning_rate': 1.2238591448077614e-05, 'epoch': 0.48}


 48%|████▊     | 1590/3283 [52:32<30:37,  1.09s/it]

{'loss': 0.2562, 'learning_rate': 1.2166726554078335e-05, 'epoch': 0.48}


 49%|████▊     | 1600/3283 [52:43<30:49,  1.10s/it]

{'loss': 0.1827, 'learning_rate': 1.2094861660079052e-05, 'epoch': 0.49}


 49%|████▉     | 1610/3283 [52:54<30:47,  1.10s/it]

{'loss': 0.1982, 'learning_rate': 1.2022996766079771e-05, 'epoch': 0.49}


 49%|████▉     | 1620/3283 [53:05<30:28,  1.10s/it]

{'loss': 0.217, 'learning_rate': 1.195113187208049e-05, 'epoch': 0.49}


 50%|████▉     | 1630/3283 [53:16<30:12,  1.10s/it]

{'loss': 0.2779, 'learning_rate': 1.1879266978081208e-05, 'epoch': 0.5}


 50%|████▉     | 1640/3283 [53:27<29:31,  1.08s/it]

{'loss': 0.2272, 'learning_rate': 1.1807402084081927e-05, 'epoch': 0.5}


 50%|█████     | 1650/3283 [53:38<28:52,  1.06s/it]

{'loss': 0.2945, 'learning_rate': 1.1735537190082646e-05, 'epoch': 0.5}


 51%|█████     | 1660/3283 [53:49<29:48,  1.10s/it]

{'loss': 0.1745, 'learning_rate': 1.1663672296083363e-05, 'epoch': 0.51}


 51%|█████     | 1670/3283 [54:00<29:05,  1.08s/it]

{'loss': 0.3049, 'learning_rate': 1.1591807402084083e-05, 'epoch': 0.51}


 51%|█████     | 1680/3283 [54:11<29:51,  1.12s/it]

{'loss': 0.2904, 'learning_rate': 1.15199425080848e-05, 'epoch': 0.51}


 51%|█████▏    | 1690/3283 [54:22<30:24,  1.15s/it]

{'loss': 0.2282, 'learning_rate': 1.144807761408552e-05, 'epoch': 0.51}


 52%|█████▏    | 1700/3283 [54:33<29:32,  1.12s/it]

{'loss': 0.271, 'learning_rate': 1.137621272008624e-05, 'epoch': 0.52}


 52%|█████▏    | 1710/3283 [54:44<28:07,  1.07s/it]

{'loss': 0.3145, 'learning_rate': 1.1304347826086957e-05, 'epoch': 0.52}


 52%|█████▏    | 1720/3283 [54:55<28:56,  1.11s/it]

{'loss': 0.2657, 'learning_rate': 1.1232482932087676e-05, 'epoch': 0.52}


 53%|█████▎    | 1730/3283 [55:06<27:25,  1.06s/it]

{'loss': 0.2176, 'learning_rate': 1.1160618038088395e-05, 'epoch': 0.53}


 53%|█████▎    | 1740/3283 [55:16<27:35,  1.07s/it]

{'loss': 0.296, 'learning_rate': 1.1088753144089113e-05, 'epoch': 0.53}


 53%|█████▎    | 1750/3283 [55:27<27:49,  1.09s/it]

{'loss': 0.1676, 'learning_rate': 1.1016888250089832e-05, 'epoch': 0.53}


 54%|█████▎    | 1760/3283 [55:38<27:49,  1.10s/it]

{'loss': 0.22, 'learning_rate': 1.0945023356090551e-05, 'epoch': 0.54}


 54%|█████▍    | 1770/3283 [55:49<26:24,  1.05s/it]

{'loss': 0.2811, 'learning_rate': 1.0873158462091268e-05, 'epoch': 0.54}


 54%|█████▍    | 1780/3283 [56:00<26:42,  1.07s/it]

{'loss': 0.2523, 'learning_rate': 1.0801293568091988e-05, 'epoch': 0.54}


 55%|█████▍    | 1790/3283 [56:11<26:26,  1.06s/it]

{'loss': 0.2199, 'learning_rate': 1.0729428674092705e-05, 'epoch': 0.55}


 55%|█████▍    | 1800/3283 [56:22<27:14,  1.10s/it]

{'loss': 0.2509, 'learning_rate': 1.0657563780093426e-05, 'epoch': 0.55}


 55%|█████▌    | 1810/3283 [56:33<26:44,  1.09s/it]

{'loss': 0.2527, 'learning_rate': 1.0585698886094145e-05, 'epoch': 0.55}


 55%|█████▌    | 1820/3283 [56:44<26:32,  1.09s/it]

{'loss': 0.2402, 'learning_rate': 1.0513833992094862e-05, 'epoch': 0.55}


 56%|█████▌    | 1830/3283 [56:55<25:57,  1.07s/it]

{'loss': 0.2402, 'learning_rate': 1.0441969098095581e-05, 'epoch': 0.56}


 56%|█████▌    | 1840/3283 [57:06<25:26,  1.06s/it]

{'loss': 0.223, 'learning_rate': 1.03701042040963e-05, 'epoch': 0.56}


 56%|█████▋    | 1850/3283 [57:17<25:48,  1.08s/it]

{'loss': 0.2819, 'learning_rate': 1.0298239310097018e-05, 'epoch': 0.56}


 57%|█████▋    | 1860/3283 [57:28<25:06,  1.06s/it]

{'loss': 0.3094, 'learning_rate': 1.0226374416097737e-05, 'epoch': 0.57}


 57%|█████▋    | 1870/3283 [57:39<26:02,  1.11s/it]

{'loss': 0.2338, 'learning_rate': 1.0154509522098454e-05, 'epoch': 0.57}


 57%|█████▋    | 1880/3283 [57:50<25:11,  1.08s/it]

{'loss': 0.1931, 'learning_rate': 1.0082644628099174e-05, 'epoch': 0.57}


 58%|█████▊    | 1890/3283 [58:01<25:41,  1.11s/it]

{'loss': 0.2716, 'learning_rate': 1.0010779734099894e-05, 'epoch': 0.58}


 58%|█████▊    | 1900/3283 [58:12<25:26,  1.10s/it]

{'loss': 0.1741, 'learning_rate': 9.938914840100612e-06, 'epoch': 0.58}


 58%|█████▊    | 1910/3283 [58:23<24:43,  1.08s/it]

{'loss': 0.2678, 'learning_rate': 9.86704994610133e-06, 'epoch': 0.58}


 58%|█████▊    | 1920/3283 [58:34<24:46,  1.09s/it]

{'loss': 0.3107, 'learning_rate': 9.795185052102048e-06, 'epoch': 0.58}


 59%|█████▉    | 1930/3283 [58:45<24:14,  1.08s/it]

{'loss': 0.2763, 'learning_rate': 9.723320158102767e-06, 'epoch': 0.59}


 59%|█████▉    | 1940/3283 [58:56<24:27,  1.09s/it]

{'loss': 0.2595, 'learning_rate': 9.651455264103486e-06, 'epoch': 0.59}


 59%|█████▉    | 1950/3283 [59:07<24:05,  1.08s/it]

{'loss': 0.3177, 'learning_rate': 9.579590370104206e-06, 'epoch': 0.59}


 60%|█████▉    | 1960/3283 [59:18<24:13,  1.10s/it]

{'loss': 0.2181, 'learning_rate': 9.507725476104923e-06, 'epoch': 0.6}


 60%|██████    | 1970/3283 [59:29<23:34,  1.08s/it]

{'loss': 0.1695, 'learning_rate': 9.435860582105642e-06, 'epoch': 0.6}


 60%|██████    | 1980/3283 [59:40<23:30,  1.08s/it]

{'loss': 0.2232, 'learning_rate': 9.363995688106361e-06, 'epoch': 0.6}


 61%|██████    | 1990/3283 [59:51<23:29,  1.09s/it]

{'loss': 0.2177, 'learning_rate': 9.292130794107079e-06, 'epoch': 0.61}


 61%|██████    | 2000/3283 [1:00:02<23:08,  1.08s/it]

{'loss': 0.2264, 'learning_rate': 9.220265900107798e-06, 'epoch': 0.61}


 61%|██████    | 2010/3283 [1:00:15<23:45,  1.12s/it]

{'loss': 0.2289, 'learning_rate': 9.148401006108517e-06, 'epoch': 0.61}


 62%|██████▏   | 2020/3283 [1:00:26<22:58,  1.09s/it]

{'loss': 0.1975, 'learning_rate': 9.076536112109236e-06, 'epoch': 0.62}


 62%|██████▏   | 2030/3283 [1:00:37<23:16,  1.11s/it]

{'loss': 0.2274, 'learning_rate': 9.004671218109953e-06, 'epoch': 0.62}


 62%|██████▏   | 2040/3283 [1:00:48<22:18,  1.08s/it]

{'loss': 0.2017, 'learning_rate': 8.932806324110672e-06, 'epoch': 0.62}


 62%|██████▏   | 2050/3283 [1:00:59<22:16,  1.08s/it]

{'loss': 0.2666, 'learning_rate': 8.860941430111391e-06, 'epoch': 0.62}


 63%|██████▎   | 2060/3283 [1:01:09<22:03,  1.08s/it]

{'loss': 0.4474, 'learning_rate': 8.78907653611211e-06, 'epoch': 0.63}


 63%|██████▎   | 2070/3283 [1:01:20<22:02,  1.09s/it]

{'loss': 0.2416, 'learning_rate': 8.717211642112828e-06, 'epoch': 0.63}


 63%|██████▎   | 2080/3283 [1:01:31<21:29,  1.07s/it]

{'loss': 0.275, 'learning_rate': 8.645346748113547e-06, 'epoch': 0.63}


 64%|██████▎   | 2090/3283 [1:01:42<21:47,  1.10s/it]

{'loss': 0.2439, 'learning_rate': 8.573481854114266e-06, 'epoch': 0.64}


 64%|██████▍   | 2100/3283 [1:01:53<21:21,  1.08s/it]

{'loss': 0.2208, 'learning_rate': 8.501616960114985e-06, 'epoch': 0.64}


 64%|██████▍   | 2110/3283 [1:02:04<21:17,  1.09s/it]

{'loss': 0.2314, 'learning_rate': 8.429752066115703e-06, 'epoch': 0.64}


 65%|██████▍   | 2120/3283 [1:02:15<20:47,  1.07s/it]

{'loss': 0.2198, 'learning_rate': 8.357887172116422e-06, 'epoch': 0.65}


 65%|██████▍   | 2130/3283 [1:02:26<20:52,  1.09s/it]

{'loss': 0.2582, 'learning_rate': 8.286022278117141e-06, 'epoch': 0.65}


 65%|██████▌   | 2140/3283 [1:02:37<20:30,  1.08s/it]

{'loss': 0.186, 'learning_rate': 8.214157384117858e-06, 'epoch': 0.65}


 65%|██████▌   | 2150/3283 [1:02:48<20:44,  1.10s/it]

{'loss': 0.2851, 'learning_rate': 8.142292490118577e-06, 'epoch': 0.65}


 66%|██████▌   | 2160/3283 [1:02:59<19:50,  1.06s/it]

{'loss': 0.2467, 'learning_rate': 8.070427596119297e-06, 'epoch': 0.66}


 66%|██████▌   | 2170/3283 [1:03:10<20:09,  1.09s/it]

{'loss': 0.1939, 'learning_rate': 7.998562702120016e-06, 'epoch': 0.66}


 66%|██████▋   | 2180/3283 [1:03:20<19:38,  1.07s/it]

{'loss': 0.2272, 'learning_rate': 7.926697808120733e-06, 'epoch': 0.66}


 67%|██████▋   | 2190/3283 [1:03:31<19:59,  1.10s/it]

{'loss': 0.2233, 'learning_rate': 7.854832914121452e-06, 'epoch': 0.67}


 67%|██████▋   | 2200/3283 [1:03:43<19:49,  1.10s/it]

{'loss': 0.1996, 'learning_rate': 7.782968020122171e-06, 'epoch': 0.67}


 67%|██████▋   | 2210/3283 [1:03:54<19:30,  1.09s/it]

{'loss': 0.2284, 'learning_rate': 7.71110312612289e-06, 'epoch': 0.67}


 68%|██████▊   | 2220/3283 [1:04:05<19:28,  1.10s/it]

{'loss': 0.1691, 'learning_rate': 7.639238232123608e-06, 'epoch': 0.68}


 68%|██████▊   | 2230/3283 [1:04:16<19:05,  1.09s/it]

{'loss': 0.1765, 'learning_rate': 7.567373338124326e-06, 'epoch': 0.68}


 68%|██████▊   | 2240/3283 [1:04:27<18:53,  1.09s/it]

{'loss': 0.2185, 'learning_rate': 7.495508444125046e-06, 'epoch': 0.68}


 69%|██████▊   | 2250/3283 [1:04:37<18:32,  1.08s/it]

{'loss': 0.2524, 'learning_rate': 7.423643550125764e-06, 'epoch': 0.69}


 69%|██████▉   | 2260/3283 [1:04:48<18:36,  1.09s/it]

{'loss': 0.2299, 'learning_rate': 7.3517786561264825e-06, 'epoch': 0.69}


 69%|██████▉   | 2270/3283 [1:04:59<18:19,  1.09s/it]

{'loss': 0.266, 'learning_rate': 7.279913762127201e-06, 'epoch': 0.69}


 69%|██████▉   | 2280/3283 [1:05:10<18:05,  1.08s/it]

{'loss': 0.2605, 'learning_rate': 7.208048868127921e-06, 'epoch': 0.69}


 70%|██████▉   | 2290/3283 [1:05:21<17:53,  1.08s/it]

{'loss': 0.2386, 'learning_rate': 7.136183974128639e-06, 'epoch': 0.7}


 70%|███████   | 2300/3283 [1:05:32<18:00,  1.10s/it]

{'loss': 0.2149, 'learning_rate': 7.064319080129357e-06, 'epoch': 0.7}


 70%|███████   | 2310/3283 [1:05:43<17:51,  1.10s/it]

{'loss': 0.241, 'learning_rate': 6.992454186130076e-06, 'epoch': 0.7}


 71%|███████   | 2320/3283 [1:05:54<17:35,  1.10s/it]

{'loss': 0.2012, 'learning_rate': 6.9205892921307946e-06, 'epoch': 0.71}


 71%|███████   | 2330/3283 [1:06:05<17:37,  1.11s/it]

{'loss': 0.1827, 'learning_rate': 6.848724398131513e-06, 'epoch': 0.71}


 71%|███████▏  | 2340/3283 [1:06:16<16:55,  1.08s/it]

{'loss': 0.2177, 'learning_rate': 6.776859504132232e-06, 'epoch': 0.71}


 72%|███████▏  | 2350/3283 [1:06:27<17:12,  1.11s/it]

{'loss': 0.2208, 'learning_rate': 6.704994610132951e-06, 'epoch': 0.72}


 72%|███████▏  | 2360/3283 [1:06:38<16:40,  1.08s/it]

{'loss': 0.22, 'learning_rate': 6.633129716133669e-06, 'epoch': 0.72}


 72%|███████▏  | 2370/3283 [1:06:49<16:29,  1.08s/it]

{'loss': 0.2503, 'learning_rate': 6.5612648221343875e-06, 'epoch': 0.72}


 72%|███████▏  | 2380/3283 [1:07:00<16:25,  1.09s/it]

{'loss': 0.261, 'learning_rate': 6.489399928135106e-06, 'epoch': 0.72}


 73%|███████▎  | 2390/3283 [1:07:11<16:08,  1.08s/it]

{'loss': 0.2011, 'learning_rate': 6.417535034135826e-06, 'epoch': 0.73}


 73%|███████▎  | 2400/3283 [1:07:22<16:10,  1.10s/it]

{'loss': 0.1666, 'learning_rate': 6.345670140136544e-06, 'epoch': 0.73}


 73%|███████▎  | 2410/3283 [1:07:33<15:55,  1.09s/it]

{'loss': 0.2368, 'learning_rate': 6.273805246137262e-06, 'epoch': 0.73}


 74%|███████▎  | 2420/3283 [1:07:44<15:37,  1.09s/it]

{'loss': 0.1888, 'learning_rate': 6.2019403521379805e-06, 'epoch': 0.74}


 74%|███████▍  | 2430/3283 [1:07:55<15:28,  1.09s/it]

{'loss': 0.1807, 'learning_rate': 6.1300754581387005e-06, 'epoch': 0.74}


 74%|███████▍  | 2440/3283 [1:08:06<15:27,  1.10s/it]

{'loss': 0.2515, 'learning_rate': 6.058210564139419e-06, 'epoch': 0.74}


 75%|███████▍  | 2450/3283 [1:08:17<14:56,  1.08s/it]

{'loss': 0.2477, 'learning_rate': 5.986345670140137e-06, 'epoch': 0.75}


 75%|███████▍  | 2460/3283 [1:08:28<15:00,  1.09s/it]

{'loss': 0.2945, 'learning_rate': 5.914480776140855e-06, 'epoch': 0.75}


 75%|███████▌  | 2470/3283 [1:08:39<14:53,  1.10s/it]

{'loss': 0.2049, 'learning_rate': 5.842615882141574e-06, 'epoch': 0.75}


 76%|███████▌  | 2480/3283 [1:08:50<15:00,  1.12s/it]

{'loss': 0.2219, 'learning_rate': 5.770750988142293e-06, 'epoch': 0.76}


 76%|███████▌  | 2490/3283 [1:09:01<14:04,  1.07s/it]

{'loss': 0.2438, 'learning_rate': 5.698886094143012e-06, 'epoch': 0.76}


 76%|███████▌  | 2500/3283 [1:09:12<14:05,  1.08s/it]

{'loss': 0.2028, 'learning_rate': 5.627021200143731e-06, 'epoch': 0.76}


 76%|███████▋  | 2510/3283 [1:09:24<14:03,  1.09s/it]

{'loss': 0.1966, 'learning_rate': 5.555156306144449e-06, 'epoch': 0.76}


 77%|███████▋  | 2520/3283 [1:09:35<14:02,  1.10s/it]

{'loss': 0.1982, 'learning_rate': 5.483291412145167e-06, 'epoch': 0.77}


 77%|███████▋  | 2530/3283 [1:09:45<13:31,  1.08s/it]

{'loss': 0.2532, 'learning_rate': 5.4114265181458856e-06, 'epoch': 0.77}


 77%|███████▋  | 2540/3283 [1:09:56<13:36,  1.10s/it]

{'loss': 0.2058, 'learning_rate': 5.3395616241466055e-06, 'epoch': 0.77}


 78%|███████▊  | 2550/3283 [1:10:07<13:11,  1.08s/it]

{'loss': 0.1929, 'learning_rate': 5.267696730147324e-06, 'epoch': 0.78}


 78%|███████▊  | 2560/3283 [1:10:18<12:56,  1.07s/it]

{'loss': 0.2374, 'learning_rate': 5.195831836148042e-06, 'epoch': 0.78}


 78%|███████▊  | 2570/3283 [1:10:29<12:22,  1.04s/it]

{'loss': 0.2331, 'learning_rate': 5.12396694214876e-06, 'epoch': 0.78}


 79%|███████▊  | 2580/3283 [1:10:39<12:08,  1.04s/it]

{'loss': 0.1982, 'learning_rate': 5.052102048149479e-06, 'epoch': 0.79}


 79%|███████▉  | 2590/3283 [1:10:49<11:53,  1.03s/it]

{'loss': 0.2715, 'learning_rate': 4.9802371541501985e-06, 'epoch': 0.79}


 79%|███████▉  | 2600/3283 [1:11:00<11:57,  1.05s/it]

{'loss': 0.2296, 'learning_rate': 4.908372260150917e-06, 'epoch': 0.79}


 80%|███████▉  | 2610/3283 [1:11:10<11:48,  1.05s/it]

{'loss': 0.1962, 'learning_rate': 4.836507366151635e-06, 'epoch': 0.8}


 80%|███████▉  | 2620/3283 [1:11:21<11:26,  1.04s/it]

{'loss': 0.2251, 'learning_rate': 4.764642472152354e-06, 'epoch': 0.8}


 80%|████████  | 2630/3283 [1:11:31<11:21,  1.04s/it]

{'loss': 0.1926, 'learning_rate': 4.692777578153072e-06, 'epoch': 0.8}


 80%|████████  | 2640/3283 [1:11:42<11:12,  1.05s/it]

{'loss': 0.2807, 'learning_rate': 4.620912684153791e-06, 'epoch': 0.8}


 81%|████████  | 2650/3283 [1:11:52<10:58,  1.04s/it]

{'loss': 0.2446, 'learning_rate': 4.54904779015451e-06, 'epoch': 0.81}


 81%|████████  | 2660/3283 [1:12:02<10:48,  1.04s/it]

{'loss': 0.2485, 'learning_rate': 4.477182896155228e-06, 'epoch': 0.81}


 81%|████████▏ | 2670/3283 [1:12:13<10:30,  1.03s/it]

{'loss': 0.2223, 'learning_rate': 4.405318002155947e-06, 'epoch': 0.81}


 82%|████████▏ | 2680/3283 [1:12:23<10:51,  1.08s/it]

{'loss': 0.3454, 'learning_rate': 4.333453108156666e-06, 'epoch': 0.82}


 82%|████████▏ | 2690/3283 [1:12:34<10:11,  1.03s/it]

{'loss': 0.2505, 'learning_rate': 4.2615882141573845e-06, 'epoch': 0.82}


 82%|████████▏ | 2700/3283 [1:12:44<10:04,  1.04s/it]

{'loss': 0.2231, 'learning_rate': 4.1897233201581036e-06, 'epoch': 0.82}


 83%|████████▎ | 2710/3283 [1:12:55<09:59,  1.05s/it]

{'loss': 0.2441, 'learning_rate': 4.117858426158822e-06, 'epoch': 0.83}


 83%|████████▎ | 2720/3283 [1:13:05<09:38,  1.03s/it]

{'loss': 0.2362, 'learning_rate': 4.045993532159541e-06, 'epoch': 0.83}


 83%|████████▎ | 2730/3283 [1:13:15<09:36,  1.04s/it]

{'loss': 0.1982, 'learning_rate': 3.974128638160259e-06, 'epoch': 0.83}


 83%|████████▎ | 2740/3283 [1:13:26<09:22,  1.04s/it]

{'loss': 0.3071, 'learning_rate': 3.902263744160978e-06, 'epoch': 0.83}


 84%|████████▍ | 2750/3283 [1:13:36<09:14,  1.04s/it]

{'loss': 0.222, 'learning_rate': 3.8303988501616965e-06, 'epoch': 0.84}


 84%|████████▍ | 2760/3283 [1:13:46<09:05,  1.04s/it]

{'loss': 0.1991, 'learning_rate': 3.758533956162415e-06, 'epoch': 0.84}


 84%|████████▍ | 2770/3283 [1:13:57<08:50,  1.03s/it]

{'loss': 0.2703, 'learning_rate': 3.6866690621631335e-06, 'epoch': 0.84}


 85%|████████▍ | 2780/3283 [1:14:07<08:33,  1.02s/it]

{'loss': 0.1877, 'learning_rate': 3.6148041681638526e-06, 'epoch': 0.85}


 85%|████████▍ | 2790/3283 [1:14:17<08:29,  1.03s/it]

{'loss': 0.2892, 'learning_rate': 3.542939274164571e-06, 'epoch': 0.85}


 85%|████████▌ | 2800/3283 [1:14:28<08:23,  1.04s/it]

{'loss': 0.2305, 'learning_rate': 3.4710743801652895e-06, 'epoch': 0.85}


 86%|████████▌ | 2810/3283 [1:14:38<08:13,  1.04s/it]

{'loss': 0.3016, 'learning_rate': 3.399209486166008e-06, 'epoch': 0.86}


 86%|████████▌ | 2820/3283 [1:14:49<08:00,  1.04s/it]

{'loss': 0.174, 'learning_rate': 3.327344592166727e-06, 'epoch': 0.86}


 86%|████████▌ | 2830/3283 [1:14:59<07:53,  1.04s/it]

{'loss': 0.2026, 'learning_rate': 3.255479698167445e-06, 'epoch': 0.86}


 87%|████████▋ | 2840/3283 [1:15:09<07:41,  1.04s/it]

{'loss': 0.2765, 'learning_rate': 3.1836148041681642e-06, 'epoch': 0.87}


 87%|████████▋ | 2850/3283 [1:15:20<07:25,  1.03s/it]

{'loss': 0.2346, 'learning_rate': 3.1117499101688825e-06, 'epoch': 0.87}


 87%|████████▋ | 2860/3283 [1:15:30<07:27,  1.06s/it]

{'loss': 0.2576, 'learning_rate': 3.0398850161696016e-06, 'epoch': 0.87}


 87%|████████▋ | 2870/3283 [1:15:40<07:07,  1.04s/it]

{'loss': 0.2329, 'learning_rate': 2.9680201221703203e-06, 'epoch': 0.87}


 88%|████████▊ | 2880/3283 [1:15:51<06:59,  1.04s/it]

{'loss': 0.2304, 'learning_rate': 2.8961552281710385e-06, 'epoch': 0.88}


 88%|████████▊ | 2890/3283 [1:16:01<06:47,  1.04s/it]

{'loss': 0.2032, 'learning_rate': 2.8242903341717576e-06, 'epoch': 0.88}


 88%|████████▊ | 2900/3283 [1:16:11<06:30,  1.02s/it]

{'loss': 0.2661, 'learning_rate': 2.752425440172476e-06, 'epoch': 0.88}


 89%|████████▊ | 2910/3283 [1:16:22<06:27,  1.04s/it]

{'loss': 0.251, 'learning_rate': 2.680560546173195e-06, 'epoch': 0.89}


 89%|████████▉ | 2920/3283 [1:16:32<06:13,  1.03s/it]

{'loss': 0.2037, 'learning_rate': 2.6086956521739132e-06, 'epoch': 0.89}


 89%|████████▉ | 2930/3283 [1:16:42<06:00,  1.02s/it]

{'loss': 0.2132, 'learning_rate': 2.536830758174632e-06, 'epoch': 0.89}


 90%|████████▉ | 2940/3283 [1:16:53<05:47,  1.01s/it]

{'loss': 0.2408, 'learning_rate': 2.4649658641753506e-06, 'epoch': 0.9}


 90%|████████▉ | 2950/3283 [1:17:03<05:38,  1.02s/it]

{'loss': 0.2138, 'learning_rate': 2.3931009701760693e-06, 'epoch': 0.9}


 90%|█████████ | 2960/3283 [1:17:13<05:33,  1.03s/it]

{'loss': 0.2334, 'learning_rate': 2.3212360761767875e-06, 'epoch': 0.9}


 90%|█████████ | 2970/3283 [1:17:24<05:22,  1.03s/it]

{'loss': 0.2152, 'learning_rate': 2.2493711821775066e-06, 'epoch': 0.9}


 91%|█████████ | 2980/3283 [1:17:34<05:09,  1.02s/it]

{'loss': 0.2036, 'learning_rate': 2.1775062881782253e-06, 'epoch': 0.91}


 91%|█████████ | 2990/3283 [1:17:44<05:02,  1.03s/it]

{'loss': 0.2575, 'learning_rate': 2.105641394178944e-06, 'epoch': 0.91}


 91%|█████████▏| 3000/3283 [1:17:55<04:55,  1.04s/it]

{'loss': 0.2451, 'learning_rate': 2.0337765001796627e-06, 'epoch': 0.91}


 92%|█████████▏| 3010/3283 [1:18:06<04:47,  1.05s/it]

{'loss': 0.2112, 'learning_rate': 1.961911606180381e-06, 'epoch': 0.92}


 92%|█████████▏| 3020/3283 [1:18:16<04:31,  1.03s/it]

{'loss': 0.1842, 'learning_rate': 1.8900467121810998e-06, 'epoch': 0.92}


 92%|█████████▏| 3030/3283 [1:18:27<04:20,  1.03s/it]

{'loss': 0.2231, 'learning_rate': 1.8181818181818183e-06, 'epoch': 0.92}


 93%|█████████▎| 3040/3283 [1:18:37<04:07,  1.02s/it]

{'loss': 0.1955, 'learning_rate': 1.746316924182537e-06, 'epoch': 0.93}


 93%|█████████▎| 3050/3283 [1:18:47<04:03,  1.04s/it]

{'loss': 0.1923, 'learning_rate': 1.6744520301832557e-06, 'epoch': 0.93}


 93%|█████████▎| 3060/3283 [1:18:58<03:49,  1.03s/it]

{'loss': 0.2598, 'learning_rate': 1.6025871361839743e-06, 'epoch': 0.93}


 94%|█████████▎| 3070/3283 [1:19:08<03:41,  1.04s/it]

{'loss': 0.2024, 'learning_rate': 1.5307222421846928e-06, 'epoch': 0.94}


 94%|█████████▍| 3080/3283 [1:19:19<03:32,  1.05s/it]

{'loss': 0.1796, 'learning_rate': 1.4588573481854115e-06, 'epoch': 0.94}


 94%|█████████▍| 3090/3283 [1:19:29<03:19,  1.03s/it]

{'loss': 0.2211, 'learning_rate': 1.3869924541861302e-06, 'epoch': 0.94}


 94%|█████████▍| 3100/3283 [1:19:39<03:06,  1.02s/it]

{'loss': 0.2479, 'learning_rate': 1.3151275601868488e-06, 'epoch': 0.94}


 95%|█████████▍| 3110/3283 [1:19:50<02:53,  1.00s/it]

{'loss': 0.2707, 'learning_rate': 1.2432626661875675e-06, 'epoch': 0.95}


 95%|█████████▌| 3120/3283 [1:20:00<02:49,  1.04s/it]

{'loss': 0.2365, 'learning_rate': 1.1713977721882862e-06, 'epoch': 0.95}


 95%|█████████▌| 3130/3283 [1:20:10<02:40,  1.05s/it]

{'loss': 0.1921, 'learning_rate': 1.0995328781890049e-06, 'epoch': 0.95}


 96%|█████████▌| 3140/3283 [1:20:21<02:27,  1.03s/it]

{'loss': 0.2619, 'learning_rate': 1.0276679841897233e-06, 'epoch': 0.96}


 96%|█████████▌| 3150/3283 [1:20:31<02:18,  1.04s/it]

{'loss': 0.2916, 'learning_rate': 9.55803090190442e-07, 'epoch': 0.96}


 96%|█████████▋| 3160/3283 [1:20:42<02:07,  1.04s/it]

{'loss': 0.2166, 'learning_rate': 8.839381961911607e-07, 'epoch': 0.96}


 97%|█████████▋| 3170/3283 [1:20:52<01:58,  1.05s/it]

{'loss': 0.2433, 'learning_rate': 8.120733021918793e-07, 'epoch': 0.97}


 97%|█████████▋| 3180/3283 [1:21:02<01:45,  1.02s/it]

{'loss': 0.2619, 'learning_rate': 7.40208408192598e-07, 'epoch': 0.97}


 97%|█████████▋| 3190/3283 [1:21:13<01:36,  1.04s/it]

{'loss': 0.2088, 'learning_rate': 6.683435141933165e-07, 'epoch': 0.97}


 97%|█████████▋| 3200/3283 [1:21:23<01:27,  1.05s/it]

{'loss': 0.215, 'learning_rate': 5.964786201940353e-07, 'epoch': 0.97}


 98%|█████████▊| 3210/3283 [1:21:34<01:15,  1.03s/it]

{'loss': 0.212, 'learning_rate': 5.246137261947539e-07, 'epoch': 0.98}


 98%|█████████▊| 3220/3283 [1:21:44<01:05,  1.04s/it]

{'loss': 0.1906, 'learning_rate': 4.527488321954725e-07, 'epoch': 0.98}


 98%|█████████▊| 3230/3283 [1:21:54<00:54,  1.02s/it]

{'loss': 0.2262, 'learning_rate': 3.8088393819619115e-07, 'epoch': 0.98}


 99%|█████████▊| 3240/3283 [1:22:05<00:44,  1.05s/it]

{'loss': 0.169, 'learning_rate': 3.090190441969098e-07, 'epoch': 0.99}


 99%|█████████▉| 3250/3283 [1:22:15<00:33,  1.02s/it]

{'loss': 0.2663, 'learning_rate': 2.3715415019762845e-07, 'epoch': 0.99}


 99%|█████████▉| 3260/3283 [1:22:25<00:23,  1.04s/it]

{'loss': 0.2243, 'learning_rate': 1.6528925619834713e-07, 'epoch': 0.99}


100%|█████████▉| 3270/3283 [1:22:36<00:13,  1.05s/it]

{'loss': 0.2448, 'learning_rate': 9.342436219906577e-08, 'epoch': 1.0}


100%|█████████▉| 3280/3283 [1:22:46<00:03,  1.04s/it]

{'loss': 0.1886, 'learning_rate': 2.1559468199784405e-08, 'epoch': 1.0}


                                                     
100%|██████████| 3283/3283 [1:23:41<00:00,  1.53s/it]

{'eval_loss': 0.18469354510307312, 'eval_runtime': 51.2079, 'eval_samples_per_second': 27.652, 'eval_steps_per_second': 6.913, 'epoch': 1.0}
{'train_runtime': 5021.0636, 'train_samples_per_second': 2.615, 'train_steps_per_second': 0.654, 'train_loss': 0.42550153326341883, 'epoch': 1.0}





TrainOutput(global_step=3283, training_loss=0.42550153326341883, metrics={'train_runtime': 5021.0636, 'train_samples_per_second': 2.615, 'train_steps_per_second': 0.654, 'train_loss': 0.42550153326341883, 'epoch': 1.0})

In [None]:
model.eval()
input_ids = validation_dataset[10]['input_ids'].unsqueeze(0).to(device)
outputs = model.generate(input_ids)
print(outputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])



