# Descarga de los datos formateados

In [1]:
%%capture
!pip install gdown

In [2]:
import gdown

files_to_download = {
    "formatted_train.jsonl": "13p2I_-pxXTQRjddJDd26bwAERjAgMqbe",
    "formatted_test.jsonl": "10SQP_HikNR0mcgTJtT9BVczSk78aGziH",
    "formatted_validation.jsonl": "1qkfNA5tm8jFv_whMpNWAgVC_C9YOgo-V"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=13p2I_-pxXTQRjddJDd26bwAERjAgMqbe
From (redirected): https://drive.google.com/uc?id=13p2I_-pxXTQRjddJDd26bwAERjAgMqbe&confirm=t&uuid=df464a8b-0245-47aa-b794-1841934673e1
To: /content/formatted_train.jsonl
100%|██████████| 176M/176M [00:02<00:00, 76.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=10SQP_HikNR0mcgTJtT9BVczSk78aGziH
To: /content/formatted_test.jsonl
100%|██████████| 8.00M/8.00M [00:00<00:00, 35.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1qkfNA5tm8jFv_whMpNWAgVC_C9YOgo-V
To: /content/formatted_validation.jsonl
100%|██████████| 17.6M/17.6M [00:00<00:00, 154MB/s]


# Implementación del modelo

In [3]:
%%capture
!pip install unsloth "xformers==0.0.28.post2"
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets

In [4]:
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer, AutoTokenizer
from peft import AutoPeftModelForCausalLM
import json

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
with open("formatted_train.jsonl", "r") as f:
    data = [json.loads(line)["interaction"] for line in f]

dataset = Dataset.from_dict({"conversations": data})

In [5]:
max_seq_length = 2048

def initialize_model_and_tokenizer():
    base_model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None
    )
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="llama-3.1"
    )

    return model, tokenizer

In [6]:
def preprocess_dataset(dataset, tokenizer):
    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    formatted_dataset = dataset.map(formatting_prompts_func, batched=True)
    return formatted_dataset

In [7]:
used_seed = 3407

# Entrenamiento

In [None]:
training_configuration = {
    "lr_scheduler_type": "linear",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "max_steps": 200,
    "output_dir": "second_train_output",
    "learning_rate": 2e-4
}

In [None]:
model, tokenizer = initialize_model_and_tokenizer()
formatted_dataset = preprocess_dataset(dataset, tokenizer)

==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        learning_rate=training_configuration["learning_rate"],
        lr_scheduler_type=training_configuration["lr_scheduler_type"],
        per_device_train_batch_size=training_configuration["per_device_train_batch_size"],
        gradient_accumulation_steps=training_configuration["gradient_accumulation_steps"],
        num_train_epochs=1,
        max_steps=training_configuration["max_steps"],
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        warmup_steps=5,
        output_dir=training_configuration["output_dir"],
        seed=used_seed,
    )
)

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [None]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.564 GB.
5.354 GB of memory reserved.


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 50,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 200
 "-____-"     Number of trainable parameters = 24,313,856


Step,Training Loss
1,2.4391
2,2.4691
3,2.3365
4,2.4939
5,2.4449
6,2.0578
7,1.9611
8,1.863
9,1.739
10,1.7358


In [None]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

335.5085 seconds used for training.
5.59 minutes used for training.
Peak reserved memory = 5.354 GB.
Peak reserved memory for training = 0.0 GB.
Peak reserved memory % of max memory = 13.533 %.
Peak reserved memory for training % of max memory = 0.0 %.


In [None]:
model.save_pretrained("lora_model_train")
tokenizer.save_pretrained("lora_model_train")

('lora_model_train/tokenizer_config.json',
 'lora_model_train/special_tokens_map.json',
 'lora_model_train/tokenizer.json')

In [None]:
!zip -r lora_model_train.zip lora_model_train/

  adding: lora_model_train/ (stored 0%)
  adding: lora_model_train/tokenizer.json (deflated 85%)
  adding: lora_model_train/adapter_config.json (deflated 53%)
  adding: lora_model_train/README.md (deflated 66%)
  adding: lora_model_train/tokenizer_config.json (deflated 94%)
  adding: lora_model_train/adapter_model.safetensors (deflated 8%)
  adding: lora_model_train/special_tokens_map.json (deflated 71%)


# Testeo primer modelo

In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model_train",
    max_seq_length = max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 3072, padding_idx=128004)
        (layers): ModuleList(
          (0-27): 28 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lor

In [None]:
 tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

In [None]:
with open("formatted_test.jsonl", "r") as f:
    data = [json.loads(line)["interaction"] for line in f]

dataset = Dataset.from_dict({"conversations": data})

In [None]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [None]:
%%capture
!pip install tqdm

In [None]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [None]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])


  0%|          | 0/455 [00:00<?, ?it/s]

Inputs shape: torch.Size([1, 606])


  0%|          | 1/455 [00:03<28:18,  3.74s/it]

Inputs shape: torch.Size([1, 639])


  0%|          | 2/455 [00:08<30:34,  4.05s/it]

Inputs shape: torch.Size([1, 500])


  1%|          | 3/455 [00:11<29:05,  3.86s/it]

Inputs shape: torch.Size([1, 504])


  1%|          | 4/455 [00:14<26:34,  3.54s/it]

Inputs shape: torch.Size([1, 578])


  1%|          | 5/455 [00:19<28:51,  3.85s/it]

Inputs shape: torch.Size([1, 469])


  1%|▏         | 6/455 [00:23<29:58,  4.01s/it]

Inputs shape: torch.Size([1, 783])


  2%|▏         | 7/455 [00:28<32:40,  4.38s/it]

Inputs shape: torch.Size([1, 605])


  2%|▏         | 8/455 [00:32<31:09,  4.18s/it]

Inputs shape: torch.Size([1, 568])


  2%|▏         | 9/455 [00:35<28:47,  3.87s/it]

Inputs shape: torch.Size([1, 1156])


  2%|▏         | 10/455 [00:41<32:44,  4.42s/it]

Inputs shape: torch.Size([1, 451])


  2%|▏         | 11/455 [00:45<32:13,  4.35s/it]

Inputs shape: torch.Size([1, 567])


  3%|▎         | 12/455 [00:49<30:58,  4.20s/it]

Inputs shape: torch.Size([1, 578])


  3%|▎         | 13/455 [00:52<29:42,  4.03s/it]

Inputs shape: torch.Size([1, 635])


  3%|▎         | 14/455 [00:56<29:11,  3.97s/it]

Inputs shape: torch.Size([1, 887])


  3%|▎         | 15/455 [00:58<24:13,  3.30s/it]

Inputs shape: torch.Size([1, 432])


  4%|▎         | 16/455 [01:01<24:35,  3.36s/it]

Inputs shape: torch.Size([1, 719])


  4%|▎         | 17/455 [01:06<27:24,  3.75s/it]

Inputs shape: torch.Size([1, 513])


  4%|▍         | 18/455 [01:10<26:44,  3.67s/it]

Inputs shape: torch.Size([1, 578])


  4%|▍         | 19/455 [01:14<29:24,  4.05s/it]

Inputs shape: torch.Size([1, 608])


  4%|▍         | 20/455 [01:20<32:54,  4.54s/it]

Inputs shape: torch.Size([1, 671])


  5%|▍         | 21/455 [01:24<30:20,  4.19s/it]

Inputs shape: torch.Size([1, 824])


  5%|▍         | 22/455 [01:28<31:09,  4.32s/it]

Inputs shape: torch.Size([1, 629])


  5%|▌         | 23/455 [01:32<30:03,  4.17s/it]

Inputs shape: torch.Size([1, 728])


  5%|▌         | 24/455 [01:36<28:54,  4.02s/it]

Inputs shape: torch.Size([1, 679])


  5%|▌         | 25/455 [01:39<28:05,  3.92s/it]

Inputs shape: torch.Size([1, 636])


  6%|▌         | 26/455 [01:43<27:13,  3.81s/it]

Inputs shape: torch.Size([1, 656])


  6%|▌         | 27/455 [01:48<29:17,  4.11s/it]

Inputs shape: torch.Size([1, 598])


  6%|▌         | 28/455 [01:52<29:14,  4.11s/it]

Inputs shape: torch.Size([1, 549])


  6%|▋         | 29/455 [01:55<28:01,  3.95s/it]

Inputs shape: torch.Size([1, 850])


  7%|▋         | 30/455 [01:59<27:50,  3.93s/it]

Inputs shape: torch.Size([1, 639])


  7%|▋         | 31/455 [02:03<27:37,  3.91s/it]

Inputs shape: torch.Size([1, 439])


  7%|▋         | 32/455 [02:06<25:17,  3.59s/it]

Inputs shape: torch.Size([1, 917])


  7%|▋         | 33/455 [02:11<27:17,  3.88s/it]

Inputs shape: torch.Size([1, 761])


  7%|▋         | 34/455 [02:14<25:56,  3.70s/it]

Inputs shape: torch.Size([1, 593])


  8%|▊         | 35/455 [02:17<25:23,  3.63s/it]

Inputs shape: torch.Size([1, 501])


  8%|▊         | 36/455 [02:21<25:31,  3.65s/it]

Inputs shape: torch.Size([1, 642])


  8%|▊         | 37/455 [02:24<23:22,  3.36s/it]

Inputs shape: torch.Size([1, 658])


  8%|▊         | 38/455 [02:29<27:19,  3.93s/it]

Inputs shape: torch.Size([1, 512])


  9%|▊         | 39/455 [02:33<26:49,  3.87s/it]

Inputs shape: torch.Size([1, 705])


  9%|▉         | 40/455 [02:37<27:42,  4.01s/it]

Inputs shape: torch.Size([1, 466])


  9%|▉         | 41/455 [02:40<25:49,  3.74s/it]

Inputs shape: torch.Size([1, 578])


  9%|▉         | 42/455 [02:45<27:42,  4.02s/it]

Inputs shape: torch.Size([1, 607])


  9%|▉         | 43/455 [02:48<26:08,  3.81s/it]

Inputs shape: torch.Size([1, 560])


 10%|▉         | 44/455 [02:51<24:24,  3.56s/it]

Inputs shape: torch.Size([1, 703])


 10%|▉         | 45/455 [02:56<26:34,  3.89s/it]

Inputs shape: torch.Size([1, 526])


 10%|█         | 46/455 [03:00<27:12,  3.99s/it]

Inputs shape: torch.Size([1, 747])


 10%|█         | 47/455 [03:04<26:57,  3.96s/it]

Inputs shape: torch.Size([1, 427])


 11%|█         | 48/455 [03:07<25:30,  3.76s/it]

Inputs shape: torch.Size([1, 709])


 11%|█         | 49/455 [03:11<25:42,  3.80s/it]

Inputs shape: torch.Size([1, 854])


 11%|█         | 50/455 [03:15<26:41,  3.95s/it]

Inputs shape: torch.Size([1, 651])


 11%|█         | 51/455 [03:20<27:05,  4.02s/it]

Inputs shape: torch.Size([1, 560])


 11%|█▏        | 52/455 [03:22<24:34,  3.66s/it]

Inputs shape: torch.Size([1, 713])


 12%|█▏        | 53/455 [03:25<22:49,  3.41s/it]

Inputs shape: torch.Size([1, 458])


 12%|█▏        | 54/455 [03:28<21:50,  3.27s/it]

Inputs shape: torch.Size([1, 771])


 12%|█▏        | 55/455 [03:34<27:07,  4.07s/it]

Inputs shape: torch.Size([1, 547])


 12%|█▏        | 56/455 [03:37<25:15,  3.80s/it]

Inputs shape: torch.Size([1, 441])


 13%|█▎        | 57/455 [03:42<27:11,  4.10s/it]

Inputs shape: torch.Size([1, 527])


 13%|█▎        | 58/455 [03:46<27:17,  4.12s/it]

Inputs shape: torch.Size([1, 525])


 13%|█▎        | 59/455 [03:49<25:29,  3.86s/it]

Inputs shape: torch.Size([1, 531])


 13%|█▎        | 60/455 [03:54<25:53,  3.93s/it]

Inputs shape: torch.Size([1, 705])


 13%|█▎        | 61/455 [03:57<24:06,  3.67s/it]

Inputs shape: torch.Size([1, 609])


 14%|█▎        | 62/455 [04:00<22:56,  3.50s/it]

Inputs shape: torch.Size([1, 503])


 14%|█▍        | 63/455 [04:04<23:32,  3.60s/it]

Inputs shape: torch.Size([1, 534])


 14%|█▍        | 64/455 [04:07<22:38,  3.48s/it]

Inputs shape: torch.Size([1, 971])


 14%|█▍        | 65/455 [04:12<26:19,  4.05s/it]

Inputs shape: torch.Size([1, 498])


 15%|█▍        | 66/455 [04:16<25:43,  3.97s/it]

Inputs shape: torch.Size([1, 499])


 15%|█▍        | 67/455 [04:20<26:31,  4.10s/it]

Inputs shape: torch.Size([1, 710])


 15%|█▍        | 68/455 [04:24<26:29,  4.11s/it]

Inputs shape: torch.Size([1, 576])


 15%|█▌        | 69/455 [04:29<27:12,  4.23s/it]

Inputs shape: torch.Size([1, 572])


 15%|█▌        | 70/455 [04:33<25:59,  4.05s/it]

Inputs shape: torch.Size([1, 618])


 16%|█▌        | 71/455 [04:36<25:09,  3.93s/it]

Inputs shape: torch.Size([1, 422])


 16%|█▌        | 72/455 [04:40<24:09,  3.78s/it]

Inputs shape: torch.Size([1, 535])


 16%|█▌        | 73/455 [04:43<23:20,  3.67s/it]

Inputs shape: torch.Size([1, 625])


 16%|█▋        | 74/455 [04:46<22:04,  3.48s/it]

Inputs shape: torch.Size([1, 679])


 16%|█▋        | 75/455 [04:50<21:57,  3.47s/it]

Inputs shape: torch.Size([1, 509])


 17%|█▋        | 76/455 [04:53<21:17,  3.37s/it]

Inputs shape: torch.Size([1, 713])


 17%|█▋        | 77/455 [04:58<24:17,  3.86s/it]

Inputs shape: torch.Size([1, 455])


 17%|█▋        | 78/455 [05:01<22:43,  3.62s/it]

Inputs shape: torch.Size([1, 454])


 17%|█▋        | 79/455 [05:05<23:31,  3.75s/it]

Inputs shape: torch.Size([1, 694])


 18%|█▊        | 80/455 [05:09<24:03,  3.85s/it]

Inputs shape: torch.Size([1, 540])


 18%|█▊        | 81/455 [05:12<22:53,  3.67s/it]

Inputs shape: torch.Size([1, 609])


 18%|█▊        | 82/455 [05:15<21:36,  3.48s/it]

Inputs shape: torch.Size([1, 556])


 18%|█▊        | 83/455 [05:18<21:18,  3.44s/it]

Inputs shape: torch.Size([1, 489])


 18%|█▊        | 84/455 [05:22<20:35,  3.33s/it]

Inputs shape: torch.Size([1, 520])


 19%|█▊        | 85/455 [05:25<21:01,  3.41s/it]

Inputs shape: torch.Size([1, 599])


 19%|█▉        | 86/455 [05:29<21:25,  3.48s/it]

Inputs shape: torch.Size([1, 629])


 19%|█▉        | 87/455 [05:34<24:01,  3.92s/it]

Inputs shape: torch.Size([1, 412])


 19%|█▉        | 88/455 [05:37<23:04,  3.77s/it]

Inputs shape: torch.Size([1, 655])


 20%|█▉        | 89/455 [05:40<21:49,  3.58s/it]

Inputs shape: torch.Size([1, 581])


 20%|█▉        | 90/455 [05:43<20:43,  3.41s/it]

Inputs shape: torch.Size([1, 557])


 20%|██        | 91/455 [05:47<20:26,  3.37s/it]

Inputs shape: torch.Size([1, 480])


 20%|██        | 92/455 [05:49<19:17,  3.19s/it]

Inputs shape: torch.Size([1, 678])


 20%|██        | 93/455 [05:52<18:21,  3.04s/it]

Inputs shape: torch.Size([1, 591])


 21%|██        | 94/455 [05:57<22:07,  3.68s/it]

Inputs shape: torch.Size([1, 677])


 21%|██        | 95/455 [06:01<23:02,  3.84s/it]

Inputs shape: torch.Size([1, 626])


 21%|██        | 96/455 [06:07<25:21,  4.24s/it]

Inputs shape: torch.Size([1, 445])


 21%|██▏       | 97/455 [06:09<21:37,  3.62s/it]

Inputs shape: torch.Size([1, 684])


 22%|██▏       | 98/455 [06:13<22:51,  3.84s/it]

Inputs shape: torch.Size([1, 677])


 22%|██▏       | 99/455 [06:16<20:59,  3.54s/it]

Inputs shape: torch.Size([1, 719])


 22%|██▏       | 100/455 [06:21<22:40,  3.83s/it]

Inputs shape: torch.Size([1, 669])


 22%|██▏       | 101/455 [06:24<22:04,  3.74s/it]

Inputs shape: torch.Size([1, 408])


 22%|██▏       | 102/455 [06:28<22:04,  3.75s/it]

Inputs shape: torch.Size([1, 664])


 23%|██▎       | 103/455 [06:32<23:20,  3.98s/it]

Inputs shape: torch.Size([1, 741])


 23%|██▎       | 104/455 [06:36<22:24,  3.83s/it]

Inputs shape: torch.Size([1, 628])


 23%|██▎       | 105/455 [06:40<23:17,  3.99s/it]

Inputs shape: torch.Size([1, 603])


 23%|██▎       | 106/455 [06:43<21:30,  3.70s/it]

Inputs shape: torch.Size([1, 715])


 24%|██▎       | 107/455 [06:49<24:16,  4.19s/it]

Inputs shape: torch.Size([1, 698])


 24%|██▎       | 108/455 [06:52<22:37,  3.91s/it]

Inputs shape: torch.Size([1, 776])


 24%|██▍       | 109/455 [06:56<23:53,  4.14s/it]

Inputs shape: torch.Size([1, 606])


 24%|██▍       | 110/455 [07:00<22:42,  3.95s/it]

Inputs shape: torch.Size([1, 606])


 24%|██▍       | 111/455 [07:04<21:55,  3.83s/it]

Inputs shape: torch.Size([1, 439])


 25%|██▍       | 112/455 [07:07<20:33,  3.60s/it]

Inputs shape: torch.Size([1, 559])


 25%|██▍       | 113/455 [07:09<19:16,  3.38s/it]

Inputs shape: torch.Size([1, 630])


 25%|██▌       | 114/455 [07:13<19:42,  3.47s/it]

Inputs shape: torch.Size([1, 625])


 25%|██▌       | 115/455 [07:16<18:46,  3.31s/it]

Inputs shape: torch.Size([1, 576])


 25%|██▌       | 116/455 [07:21<21:23,  3.79s/it]

Inputs shape: torch.Size([1, 798])


 26%|██▌       | 117/455 [07:25<21:47,  3.87s/it]

Inputs shape: torch.Size([1, 579])


 26%|██▌       | 118/455 [07:28<20:58,  3.73s/it]

Inputs shape: torch.Size([1, 610])


 26%|██▌       | 119/455 [07:33<22:05,  3.95s/it]

Inputs shape: torch.Size([1, 553])


 26%|██▋       | 120/455 [07:36<21:25,  3.84s/it]

Inputs shape: torch.Size([1, 654])


 27%|██▋       | 121/455 [07:39<19:29,  3.50s/it]

Inputs shape: torch.Size([1, 555])


 27%|██▋       | 122/455 [07:42<18:21,  3.31s/it]

Inputs shape: torch.Size([1, 500])


 27%|██▋       | 123/455 [07:45<17:17,  3.13s/it]

Inputs shape: torch.Size([1, 583])


 27%|██▋       | 124/455 [07:48<17:50,  3.23s/it]

Inputs shape: torch.Size([1, 633])


 27%|██▋       | 125/455 [07:51<16:38,  3.03s/it]

Inputs shape: torch.Size([1, 598])


 28%|██▊       | 126/455 [07:54<16:21,  2.98s/it]

Inputs shape: torch.Size([1, 399])


 28%|██▊       | 127/455 [07:57<16:55,  3.10s/it]

Inputs shape: torch.Size([1, 536])


 28%|██▊       | 128/455 [08:00<16:48,  3.09s/it]

Inputs shape: torch.Size([1, 537])


 28%|██▊       | 129/455 [08:05<19:28,  3.58s/it]

Inputs shape: torch.Size([1, 701])


 29%|██▊       | 130/455 [08:11<23:29,  4.34s/it]

Inputs shape: torch.Size([1, 474])


 29%|██▉       | 131/455 [08:15<22:32,  4.17s/it]

Inputs shape: torch.Size([1, 579])


 29%|██▉       | 132/455 [08:17<19:23,  3.60s/it]

Inputs shape: torch.Size([1, 634])


 29%|██▉       | 133/455 [08:22<21:08,  3.94s/it]

Inputs shape: torch.Size([1, 694])


 29%|██▉       | 134/455 [08:25<19:51,  3.71s/it]

Inputs shape: torch.Size([1, 602])


 30%|██▉       | 135/455 [08:30<22:06,  4.14s/it]

Inputs shape: torch.Size([1, 508])


 30%|██▉       | 136/455 [08:33<20:23,  3.84s/it]

Inputs shape: torch.Size([1, 635])


 30%|███       | 137/455 [08:37<19:47,  3.74s/it]

Inputs shape: torch.Size([1, 592])


 30%|███       | 138/455 [08:41<21:09,  4.01s/it]

Inputs shape: torch.Size([1, 828])


 31%|███       | 139/455 [08:46<21:46,  4.14s/it]

Inputs shape: torch.Size([1, 616])


 31%|███       | 140/455 [08:50<22:31,  4.29s/it]

Inputs shape: torch.Size([1, 543])


 31%|███       | 141/455 [08:55<22:16,  4.26s/it]

Inputs shape: torch.Size([1, 521])


 31%|███       | 142/455 [08:58<21:38,  4.15s/it]

Inputs shape: torch.Size([1, 547])


 31%|███▏      | 143/455 [09:02<20:48,  4.00s/it]

Inputs shape: torch.Size([1, 606])


 32%|███▏      | 144/455 [09:06<20:32,  3.96s/it]

Inputs shape: torch.Size([1, 513])


 32%|███▏      | 145/455 [09:09<19:16,  3.73s/it]

Inputs shape: torch.Size([1, 573])


 32%|███▏      | 146/455 [09:14<20:58,  4.07s/it]

Inputs shape: torch.Size([1, 633])


 32%|███▏      | 147/455 [09:18<20:38,  4.02s/it]

Inputs shape: torch.Size([1, 585])


 33%|███▎      | 148/455 [09:22<20:51,  4.08s/it]

Inputs shape: torch.Size([1, 505])


 33%|███▎      | 149/455 [09:25<19:04,  3.74s/it]

Inputs shape: torch.Size([1, 619])


 33%|███▎      | 150/455 [09:29<18:58,  3.73s/it]

Inputs shape: torch.Size([1, 684])


 33%|███▎      | 151/455 [09:33<20:20,  4.01s/it]

Inputs shape: torch.Size([1, 755])


 33%|███▎      | 152/455 [09:38<20:27,  4.05s/it]

Inputs shape: torch.Size([1, 797])


 34%|███▎      | 153/455 [09:43<21:45,  4.32s/it]

Inputs shape: torch.Size([1, 403])


 34%|███▍      | 154/455 [09:46<20:31,  4.09s/it]

Inputs shape: torch.Size([1, 591])


 34%|███▍      | 155/455 [09:51<21:57,  4.39s/it]

Inputs shape: torch.Size([1, 613])


 34%|███▍      | 156/455 [09:55<21:16,  4.27s/it]

Inputs shape: torch.Size([1, 644])


 35%|███▍      | 157/455 [09:58<19:21,  3.90s/it]

Inputs shape: torch.Size([1, 712])


 35%|███▍      | 158/455 [10:02<19:42,  3.98s/it]

Inputs shape: torch.Size([1, 564])


 35%|███▍      | 159/455 [10:06<19:34,  3.97s/it]

Inputs shape: torch.Size([1, 903])


 35%|███▌      | 160/455 [10:10<18:57,  3.85s/it]

Inputs shape: torch.Size([1, 863])


 35%|███▌      | 161/455 [10:15<20:56,  4.27s/it]

Inputs shape: torch.Size([1, 562])


 36%|███▌      | 162/455 [10:20<21:17,  4.36s/it]

Inputs shape: torch.Size([1, 792])


 36%|███▌      | 163/455 [10:24<21:02,  4.33s/it]

Inputs shape: torch.Size([1, 671])


 36%|███▌      | 164/455 [10:28<20:23,  4.20s/it]

Inputs shape: torch.Size([1, 502])


 36%|███▋      | 165/455 [10:32<19:55,  4.12s/it]

Inputs shape: torch.Size([1, 631])


 36%|███▋      | 166/455 [10:35<18:01,  3.74s/it]

Inputs shape: torch.Size([1, 598])


 37%|███▋      | 167/455 [10:38<17:08,  3.57s/it]

Inputs shape: torch.Size([1, 725])


 37%|███▋      | 168/455 [10:42<17:40,  3.70s/it]

Inputs shape: torch.Size([1, 616])


 37%|███▋      | 169/455 [10:46<18:41,  3.92s/it]

Inputs shape: torch.Size([1, 535])


 37%|███▋      | 170/455 [10:50<17:56,  3.78s/it]

Inputs shape: torch.Size([1, 772])


 38%|███▊      | 171/455 [10:54<18:47,  3.97s/it]

Inputs shape: torch.Size([1, 565])


 38%|███▊      | 172/455 [10:57<17:29,  3.71s/it]

Inputs shape: torch.Size([1, 746])


 38%|███▊      | 173/455 [11:01<17:08,  3.65s/it]

Inputs shape: torch.Size([1, 620])


 38%|███▊      | 174/455 [11:06<19:53,  4.25s/it]

Inputs shape: torch.Size([1, 593])


 38%|███▊      | 175/455 [11:11<19:55,  4.27s/it]

Inputs shape: torch.Size([1, 482])


 39%|███▊      | 176/455 [11:14<17:45,  3.82s/it]

Inputs shape: torch.Size([1, 599])


 39%|███▉      | 177/455 [11:18<18:19,  3.95s/it]

Inputs shape: torch.Size([1, 606])


 39%|███▉      | 178/455 [11:21<17:23,  3.77s/it]

Inputs shape: torch.Size([1, 712])


 39%|███▉      | 179/455 [11:26<18:49,  4.09s/it]

Inputs shape: torch.Size([1, 592])


 40%|███▉      | 180/455 [11:30<18:05,  3.95s/it]

Inputs shape: torch.Size([1, 579])


 40%|███▉      | 181/455 [11:33<16:41,  3.65s/it]

Inputs shape: torch.Size([1, 799])


 40%|████      | 182/455 [11:36<16:01,  3.52s/it]

Inputs shape: torch.Size([1, 578])


 40%|████      | 183/455 [11:39<15:56,  3.52s/it]

Inputs shape: torch.Size([1, 669])


 40%|████      | 184/455 [11:42<14:55,  3.30s/it]

Inputs shape: torch.Size([1, 586])


 41%|████      | 185/455 [11:46<15:33,  3.46s/it]

Inputs shape: torch.Size([1, 652])


 41%|████      | 186/455 [11:50<15:44,  3.51s/it]

Inputs shape: torch.Size([1, 506])


 41%|████      | 187/455 [11:53<15:05,  3.38s/it]

Inputs shape: torch.Size([1, 606])


 41%|████▏     | 188/455 [11:56<15:07,  3.40s/it]

Inputs shape: torch.Size([1, 734])


 42%|████▏     | 189/455 [12:00<16:13,  3.66s/it]

Inputs shape: torch.Size([1, 499])


 42%|████▏     | 190/455 [12:04<15:34,  3.53s/it]

Inputs shape: torch.Size([1, 848])


 42%|████▏     | 191/455 [12:09<17:47,  4.04s/it]

Inputs shape: torch.Size([1, 488])


 42%|████▏     | 192/455 [12:13<18:02,  4.12s/it]

Inputs shape: torch.Size([1, 422])


 42%|████▏     | 193/455 [12:16<17:04,  3.91s/it]

Inputs shape: torch.Size([1, 509])


 43%|████▎     | 194/455 [12:21<17:12,  3.96s/it]

Inputs shape: torch.Size([1, 618])


 43%|████▎     | 195/455 [12:24<16:34,  3.83s/it]

Inputs shape: torch.Size([1, 595])


 43%|████▎     | 196/455 [12:27<15:56,  3.69s/it]

Inputs shape: torch.Size([1, 700])


 43%|████▎     | 197/455 [12:30<14:57,  3.48s/it]

Inputs shape: torch.Size([1, 451])


 44%|████▎     | 198/455 [12:33<14:18,  3.34s/it]

Inputs shape: torch.Size([1, 585])


 44%|████▎     | 199/455 [12:38<15:20,  3.60s/it]

Inputs shape: torch.Size([1, 636])


 44%|████▍     | 200/455 [12:41<15:05,  3.55s/it]

Inputs shape: torch.Size([1, 713])


 44%|████▍     | 201/455 [12:45<16:01,  3.79s/it]

Inputs shape: torch.Size([1, 541])


 44%|████▍     | 202/455 [12:50<17:24,  4.13s/it]

Inputs shape: torch.Size([1, 683])


 45%|████▍     | 203/455 [12:54<16:10,  3.85s/it]

Inputs shape: torch.Size([1, 954])


 45%|████▍     | 204/455 [12:59<17:40,  4.23s/it]

Inputs shape: torch.Size([1, 551])


 45%|████▌     | 205/455 [13:02<16:10,  3.88s/it]

Inputs shape: torch.Size([1, 674])


 45%|████▌     | 206/455 [13:05<15:14,  3.67s/it]

Inputs shape: torch.Size([1, 476])


 45%|████▌     | 207/455 [13:08<14:29,  3.51s/it]

Inputs shape: torch.Size([1, 529])


 46%|████▌     | 208/455 [13:12<14:28,  3.52s/it]

Inputs shape: torch.Size([1, 741])


 46%|████▌     | 209/455 [13:17<16:37,  4.06s/it]

Inputs shape: torch.Size([1, 567])


 46%|████▌     | 210/455 [13:20<15:57,  3.91s/it]

Inputs shape: torch.Size([1, 550])


 46%|████▋     | 211/455 [13:24<15:35,  3.83s/it]

Inputs shape: torch.Size([1, 570])


 47%|████▋     | 212/455 [13:26<13:43,  3.39s/it]

Inputs shape: torch.Size([1, 676])


 47%|████▋     | 213/455 [13:30<13:56,  3.46s/it]

Inputs shape: torch.Size([1, 576])


 47%|████▋     | 214/455 [13:34<13:56,  3.47s/it]

Inputs shape: torch.Size([1, 606])


 47%|████▋     | 215/455 [13:37<13:28,  3.37s/it]

Inputs shape: torch.Size([1, 662])


 47%|████▋     | 216/455 [13:41<14:48,  3.72s/it]

Inputs shape: torch.Size([1, 618])


 48%|████▊     | 217/455 [13:46<15:25,  3.89s/it]

Inputs shape: torch.Size([1, 669])


 48%|████▊     | 218/455 [13:48<13:58,  3.54s/it]

Inputs shape: torch.Size([1, 624])


 48%|████▊     | 219/455 [13:52<14:02,  3.57s/it]

Inputs shape: torch.Size([1, 757])


 48%|████▊     | 220/455 [13:54<12:20,  3.15s/it]

Inputs shape: torch.Size([1, 695])


 49%|████▊     | 221/455 [13:58<13:16,  3.41s/it]

Inputs shape: torch.Size([1, 564])


 49%|████▉     | 222/455 [14:02<13:19,  3.43s/it]

Inputs shape: torch.Size([1, 258])


 49%|████▉     | 223/455 [14:05<13:25,  3.47s/it]

Inputs shape: torch.Size([1, 838])


 49%|████▉     | 224/455 [14:08<13:03,  3.39s/it]

Inputs shape: torch.Size([1, 871])


 49%|████▉     | 225/455 [14:13<13:58,  3.64s/it]

Inputs shape: torch.Size([1, 381])


 50%|████▉     | 226/455 [14:15<12:57,  3.40s/it]

Inputs shape: torch.Size([1, 585])


 50%|████▉     | 227/455 [14:19<13:17,  3.50s/it]

Inputs shape: torch.Size([1, 755])


 50%|█████     | 228/455 [14:23<13:21,  3.53s/it]

Inputs shape: torch.Size([1, 642])


 50%|█████     | 229/455 [14:26<13:21,  3.54s/it]

Inputs shape: torch.Size([1, 866])


 51%|█████     | 230/455 [14:32<15:21,  4.09s/it]

Inputs shape: torch.Size([1, 565])


 51%|█████     | 231/455 [14:35<14:22,  3.85s/it]

Inputs shape: torch.Size([1, 532])


 51%|█████     | 232/455 [14:38<13:37,  3.67s/it]

Inputs shape: torch.Size([1, 604])


 51%|█████     | 233/455 [14:42<13:14,  3.58s/it]

Inputs shape: torch.Size([1, 667])


 51%|█████▏    | 234/455 [14:45<13:12,  3.59s/it]

Inputs shape: torch.Size([1, 692])


 52%|█████▏    | 235/455 [14:49<13:06,  3.58s/it]

Inputs shape: torch.Size([1, 607])


 52%|█████▏    | 236/455 [14:53<13:18,  3.64s/it]

Inputs shape: torch.Size([1, 651])


 52%|█████▏    | 237/455 [14:56<12:38,  3.48s/it]

Inputs shape: torch.Size([1, 621])


 52%|█████▏    | 238/455 [15:00<13:52,  3.84s/it]

Inputs shape: torch.Size([1, 579])


 53%|█████▎    | 239/455 [15:05<14:35,  4.05s/it]

Inputs shape: torch.Size([1, 612])


 53%|█████▎    | 240/455 [15:09<14:46,  4.12s/it]

Inputs shape: torch.Size([1, 508])


 53%|█████▎    | 241/455 [15:13<14:27,  4.05s/it]

Inputs shape: torch.Size([1, 371])


 53%|█████▎    | 242/455 [15:17<14:28,  4.08s/it]

Inputs shape: torch.Size([1, 509])


 53%|█████▎    | 243/455 [15:21<14:11,  4.01s/it]

Inputs shape: torch.Size([1, 653])


 54%|█████▎    | 244/455 [15:25<14:17,  4.06s/it]

Inputs shape: torch.Size([1, 482])


 54%|█████▍    | 245/455 [15:30<14:28,  4.14s/it]

Inputs shape: torch.Size([1, 721])


 54%|█████▍    | 246/455 [15:35<15:31,  4.46s/it]

Inputs shape: torch.Size([1, 608])


 54%|█████▍    | 247/455 [15:39<15:37,  4.51s/it]

Inputs shape: torch.Size([1, 725])


 55%|█████▍    | 248/455 [15:44<15:33,  4.51s/it]

Inputs shape: torch.Size([1, 736])


 55%|█████▍    | 249/455 [15:47<14:06,  4.11s/it]

Inputs shape: torch.Size([1, 678])


 55%|█████▍    | 250/455 [15:50<12:27,  3.65s/it]

Inputs shape: torch.Size([1, 607])


 55%|█████▌    | 251/455 [15:54<12:43,  3.74s/it]

Inputs shape: torch.Size([1, 462])


 55%|█████▌    | 252/455 [15:58<13:05,  3.87s/it]

Inputs shape: torch.Size([1, 528])


 56%|█████▌    | 253/455 [16:01<12:08,  3.61s/it]

Inputs shape: torch.Size([1, 654])


 56%|█████▌    | 254/455 [16:07<14:34,  4.35s/it]

Inputs shape: torch.Size([1, 486])


 56%|█████▌    | 255/455 [16:10<12:51,  3.86s/it]

Inputs shape: torch.Size([1, 684])


 56%|█████▋    | 256/455 [16:13<12:52,  3.88s/it]

Inputs shape: torch.Size([1, 567])


 56%|█████▋    | 257/455 [16:16<11:24,  3.46s/it]

Inputs shape: torch.Size([1, 590])


 57%|█████▋    | 258/455 [16:19<10:42,  3.26s/it]

Inputs shape: torch.Size([1, 698])


 57%|█████▋    | 259/455 [16:23<11:48,  3.62s/it]

Inputs shape: torch.Size([1, 689])


 57%|█████▋    | 260/455 [16:27<12:20,  3.80s/it]

Inputs shape: torch.Size([1, 459])


 57%|█████▋    | 261/455 [16:31<12:05,  3.74s/it]

Inputs shape: torch.Size([1, 583])


 58%|█████▊    | 262/455 [16:35<12:20,  3.83s/it]

Inputs shape: torch.Size([1, 630])


 58%|█████▊    | 263/455 [16:40<13:00,  4.07s/it]

Inputs shape: torch.Size([1, 586])


 58%|█████▊    | 264/455 [16:43<12:13,  3.84s/it]

Inputs shape: torch.Size([1, 392])


 58%|█████▊    | 265/455 [16:47<12:36,  3.98s/it]

Inputs shape: torch.Size([1, 573])


 58%|█████▊    | 266/455 [16:51<11:58,  3.80s/it]

Inputs shape: torch.Size([1, 470])


 59%|█████▊    | 267/455 [16:53<10:15,  3.27s/it]

Inputs shape: torch.Size([1, 724])


 59%|█████▉    | 268/455 [16:56<10:18,  3.31s/it]

Inputs shape: torch.Size([1, 664])


 59%|█████▉    | 269/455 [17:01<11:36,  3.74s/it]

Inputs shape: torch.Size([1, 515])


 59%|█████▉    | 270/455 [17:04<11:24,  3.70s/it]

Inputs shape: torch.Size([1, 573])


 60%|█████▉    | 271/455 [17:08<11:27,  3.73s/it]

Inputs shape: torch.Size([1, 704])


 60%|█████▉    | 272/455 [17:12<11:11,  3.67s/it]

Inputs shape: torch.Size([1, 824])


 60%|██████    | 273/455 [17:17<12:10,  4.01s/it]

Inputs shape: torch.Size([1, 516])


 60%|██████    | 274/455 [17:20<11:49,  3.92s/it]

Inputs shape: torch.Size([1, 545])


 60%|██████    | 275/455 [17:24<11:15,  3.75s/it]

Inputs shape: torch.Size([1, 591])


 61%|██████    | 276/455 [17:27<10:59,  3.68s/it]

Inputs shape: torch.Size([1, 622])


 61%|██████    | 277/455 [17:32<12:18,  4.15s/it]

Inputs shape: torch.Size([1, 495])


 61%|██████    | 278/455 [17:35<11:13,  3.81s/it]

Inputs shape: torch.Size([1, 621])


 61%|██████▏   | 279/455 [17:39<11:07,  3.79s/it]

Inputs shape: torch.Size([1, 570])


 62%|██████▏   | 280/455 [17:42<10:35,  3.63s/it]

Inputs shape: torch.Size([1, 603])


 62%|██████▏   | 281/455 [17:46<10:10,  3.51s/it]

Inputs shape: torch.Size([1, 532])


 62%|██████▏   | 282/455 [17:50<11:01,  3.82s/it]

Inputs shape: torch.Size([1, 456])


 62%|██████▏   | 283/455 [17:54<10:48,  3.77s/it]

Inputs shape: torch.Size([1, 535])


 62%|██████▏   | 284/455 [17:56<09:38,  3.38s/it]

Inputs shape: torch.Size([1, 635])


 63%|██████▎   | 285/455 [18:00<09:34,  3.38s/it]

Inputs shape: torch.Size([1, 372])


 63%|██████▎   | 286/455 [18:04<10:14,  3.63s/it]

Inputs shape: torch.Size([1, 703])


 63%|██████▎   | 287/455 [18:09<11:31,  4.12s/it]

Inputs shape: torch.Size([1, 551])


 63%|██████▎   | 288/455 [18:13<10:55,  3.92s/it]

Inputs shape: torch.Size([1, 622])


 64%|██████▎   | 289/455 [18:17<10:48,  3.91s/it]

Inputs shape: torch.Size([1, 570])


 64%|██████▎   | 290/455 [18:21<11:25,  4.15s/it]

Inputs shape: torch.Size([1, 640])


 64%|██████▍   | 291/455 [18:26<11:38,  4.26s/it]

Inputs shape: torch.Size([1, 409])


 64%|██████▍   | 292/455 [18:30<11:49,  4.35s/it]

Inputs shape: torch.Size([1, 747])


 64%|██████▍   | 293/455 [18:34<11:26,  4.24s/it]

Inputs shape: torch.Size([1, 762])


 65%|██████▍   | 294/455 [18:40<12:19,  4.59s/it]

Inputs shape: torch.Size([1, 541])


 65%|██████▍   | 295/455 [18:43<11:01,  4.13s/it]

Inputs shape: torch.Size([1, 512])


 65%|██████▌   | 296/455 [18:47<11:00,  4.16s/it]

Inputs shape: torch.Size([1, 709])


 65%|██████▌   | 297/455 [18:52<11:34,  4.40s/it]

Inputs shape: torch.Size([1, 634])


 65%|██████▌   | 298/455 [18:55<10:25,  3.99s/it]

Inputs shape: torch.Size([1, 657])


 66%|██████▌   | 299/455 [18:58<09:51,  3.79s/it]

Inputs shape: torch.Size([1, 500])


 66%|██████▌   | 300/455 [19:02<09:36,  3.72s/it]

Inputs shape: torch.Size([1, 575])


 66%|██████▌   | 301/455 [19:06<09:48,  3.82s/it]

Inputs shape: torch.Size([1, 596])


 66%|██████▋   | 302/455 [19:09<09:29,  3.72s/it]

Inputs shape: torch.Size([1, 491])


 67%|██████▋   | 303/455 [19:13<09:30,  3.75s/it]

Inputs shape: torch.Size([1, 499])


 67%|██████▋   | 304/455 [19:16<08:57,  3.56s/it]

Inputs shape: torch.Size([1, 566])


 67%|██████▋   | 305/455 [19:20<09:19,  3.73s/it]

Inputs shape: torch.Size([1, 553])


 67%|██████▋   | 306/455 [19:24<09:04,  3.65s/it]

Inputs shape: torch.Size([1, 868])


 67%|██████▋   | 307/455 [19:28<09:37,  3.90s/it]

Inputs shape: torch.Size([1, 430])


 68%|██████▊   | 308/455 [19:33<09:42,  3.96s/it]

Inputs shape: torch.Size([1, 458])


 68%|██████▊   | 309/455 [19:37<09:48,  4.03s/it]

Inputs shape: torch.Size([1, 612])


 68%|██████▊   | 310/455 [19:39<08:40,  3.59s/it]

Inputs shape: torch.Size([1, 525])


 68%|██████▊   | 311/455 [19:43<08:24,  3.50s/it]

Inputs shape: torch.Size([1, 570])


 69%|██████▊   | 312/455 [19:46<08:26,  3.54s/it]

Inputs shape: torch.Size([1, 467])


 69%|██████▉   | 313/455 [19:50<08:24,  3.56s/it]

Inputs shape: torch.Size([1, 636])


 69%|██████▉   | 314/455 [19:54<08:27,  3.60s/it]

Inputs shape: torch.Size([1, 612])


 69%|██████▉   | 315/455 [19:57<08:34,  3.67s/it]

Inputs shape: torch.Size([1, 487])


 69%|██████▉   | 316/455 [20:01<08:45,  3.78s/it]

Inputs shape: torch.Size([1, 557])


 70%|██████▉   | 317/455 [20:05<08:50,  3.84s/it]

Inputs shape: torch.Size([1, 517])


 70%|██████▉   | 318/455 [20:09<08:35,  3.76s/it]

Inputs shape: torch.Size([1, 559])


 70%|███████   | 319/455 [20:12<08:08,  3.60s/it]

Inputs shape: torch.Size([1, 558])


 70%|███████   | 320/455 [20:16<08:18,  3.69s/it]

Inputs shape: torch.Size([1, 709])


 71%|███████   | 321/455 [20:20<08:06,  3.63s/it]

Inputs shape: torch.Size([1, 570])


 71%|███████   | 322/455 [20:23<07:37,  3.44s/it]

Inputs shape: torch.Size([1, 714])


 71%|███████   | 323/455 [20:26<07:44,  3.52s/it]

Inputs shape: torch.Size([1, 709])


 71%|███████   | 324/455 [20:31<08:16,  3.79s/it]

Inputs shape: torch.Size([1, 571])


 71%|███████▏  | 325/455 [20:34<08:05,  3.73s/it]

Inputs shape: torch.Size([1, 1012])


 72%|███████▏  | 326/455 [20:40<09:04,  4.22s/it]

Inputs shape: torch.Size([1, 616])


 72%|███████▏  | 327/455 [20:43<08:16,  3.88s/it]

Inputs shape: torch.Size([1, 549])


 72%|███████▏  | 328/455 [20:46<08:01,  3.79s/it]

Inputs shape: torch.Size([1, 608])


 72%|███████▏  | 329/455 [20:50<07:42,  3.67s/it]

Inputs shape: torch.Size([1, 641])


 73%|███████▎  | 330/455 [20:54<07:46,  3.73s/it]

Inputs shape: torch.Size([1, 676])


 73%|███████▎  | 331/455 [20:59<08:55,  4.32s/it]

Inputs shape: torch.Size([1, 820])


 73%|███████▎  | 332/455 [21:03<08:41,  4.24s/it]

Inputs shape: torch.Size([1, 617])


 73%|███████▎  | 333/455 [21:08<08:36,  4.23s/it]

Inputs shape: torch.Size([1, 654])


 73%|███████▎  | 334/455 [21:11<07:49,  3.88s/it]

Inputs shape: torch.Size([1, 825])


 74%|███████▎  | 335/455 [21:14<07:29,  3.75s/it]

Inputs shape: torch.Size([1, 555])


 74%|███████▍  | 336/455 [21:17<07:12,  3.63s/it]

Inputs shape: torch.Size([1, 558])


 74%|███████▍  | 337/455 [21:22<07:38,  3.88s/it]

Inputs shape: torch.Size([1, 779])


 74%|███████▍  | 338/455 [21:27<08:07,  4.16s/it]

Inputs shape: torch.Size([1, 819])


 75%|███████▍  | 339/455 [21:30<07:47,  4.03s/it]

Inputs shape: torch.Size([1, 669])


 75%|███████▍  | 340/455 [21:34<07:31,  3.93s/it]

Inputs shape: torch.Size([1, 833])


 75%|███████▍  | 341/455 [21:38<07:19,  3.86s/it]

Inputs shape: torch.Size([1, 607])


 75%|███████▌  | 342/455 [21:42<07:12,  3.83s/it]

Inputs shape: torch.Size([1, 697])


 75%|███████▌  | 343/455 [21:46<07:29,  4.02s/it]

Inputs shape: torch.Size([1, 572])


 76%|███████▌  | 344/455 [21:50<07:22,  3.99s/it]

Inputs shape: torch.Size([1, 545])


 76%|███████▌  | 345/455 [21:54<07:14,  3.95s/it]

Inputs shape: torch.Size([1, 564])


 76%|███████▌  | 346/455 [21:57<06:36,  3.63s/it]

Inputs shape: torch.Size([1, 751])


 76%|███████▋  | 347/455 [22:01<06:49,  3.79s/it]

Inputs shape: torch.Size([1, 678])


 76%|███████▋  | 348/455 [22:05<06:59,  3.92s/it]

Inputs shape: torch.Size([1, 536])


 77%|███████▋  | 349/455 [22:08<06:28,  3.67s/it]

Inputs shape: torch.Size([1, 536])


 77%|███████▋  | 350/455 [22:12<06:25,  3.67s/it]

Inputs shape: torch.Size([1, 817])


 77%|███████▋  | 351/455 [22:17<07:22,  4.26s/it]

Inputs shape: torch.Size([1, 532])


 77%|███████▋  | 352/455 [22:20<06:34,  3.83s/it]

Inputs shape: torch.Size([1, 630])


 78%|███████▊  | 353/455 [22:25<06:52,  4.04s/it]

Inputs shape: torch.Size([1, 287])


 78%|███████▊  | 354/455 [22:29<06:45,  4.02s/it]

Inputs shape: torch.Size([1, 607])


 78%|███████▊  | 355/455 [22:33<06:48,  4.08s/it]

Inputs shape: torch.Size([1, 712])


 78%|███████▊  | 356/455 [22:36<06:14,  3.78s/it]

Inputs shape: torch.Size([1, 580])


 78%|███████▊  | 357/455 [22:40<06:19,  3.88s/it]

Inputs shape: torch.Size([1, 728])


 79%|███████▊  | 358/455 [22:44<06:24,  3.96s/it]

Inputs shape: torch.Size([1, 581])


 79%|███████▉  | 359/455 [22:47<05:55,  3.70s/it]

Inputs shape: torch.Size([1, 589])


 79%|███████▉  | 360/455 [22:51<05:54,  3.73s/it]

Inputs shape: torch.Size([1, 599])


 79%|███████▉  | 361/455 [22:55<05:49,  3.71s/it]

Inputs shape: torch.Size([1, 618])


 80%|███████▉  | 362/455 [23:00<06:17,  4.05s/it]

Inputs shape: torch.Size([1, 658])


 80%|███████▉  | 363/455 [23:04<06:13,  4.06s/it]

Inputs shape: torch.Size([1, 549])


 80%|████████  | 364/455 [23:07<05:47,  3.82s/it]

Inputs shape: torch.Size([1, 401])


 80%|████████  | 365/455 [23:11<05:42,  3.80s/it]

Inputs shape: torch.Size([1, 766])


 80%|████████  | 366/455 [23:15<05:54,  3.98s/it]

Inputs shape: torch.Size([1, 672])


 81%|████████  | 367/455 [23:20<05:59,  4.09s/it]

Inputs shape: torch.Size([1, 764])


 81%|████████  | 368/455 [23:24<05:53,  4.07s/it]

Inputs shape: torch.Size([1, 718])


 81%|████████  | 369/455 [23:28<05:50,  4.08s/it]

Inputs shape: torch.Size([1, 676])


 81%|████████▏ | 370/455 [23:32<05:55,  4.19s/it]

Inputs shape: torch.Size([1, 755])


 82%|████████▏ | 371/455 [23:35<05:30,  3.94s/it]

Inputs shape: torch.Size([1, 950])


 82%|████████▏ | 372/455 [23:40<05:49,  4.22s/it]

Inputs shape: torch.Size([1, 630])


 82%|████████▏ | 373/455 [23:44<05:24,  3.96s/it]

Inputs shape: torch.Size([1, 690])


 82%|████████▏ | 374/455 [23:49<06:03,  4.49s/it]

Inputs shape: torch.Size([1, 535])


 82%|████████▏ | 375/455 [23:53<05:31,  4.14s/it]

Inputs shape: torch.Size([1, 585])


 83%|████████▎ | 376/455 [23:56<05:07,  3.89s/it]

Inputs shape: torch.Size([1, 537])


 83%|████████▎ | 377/455 [23:59<04:44,  3.65s/it]

Inputs shape: torch.Size([1, 452])


 83%|████████▎ | 378/455 [24:02<04:20,  3.39s/it]

Inputs shape: torch.Size([1, 507])


 83%|████████▎ | 379/455 [24:06<04:38,  3.67s/it]

Inputs shape: torch.Size([1, 668])


 84%|████████▎ | 380/455 [24:10<04:25,  3.54s/it]

Inputs shape: torch.Size([1, 763])


 84%|████████▎ | 381/455 [24:14<04:37,  3.76s/it]

Inputs shape: torch.Size([1, 567])


 84%|████████▍ | 382/455 [24:18<04:51,  4.00s/it]

Inputs shape: torch.Size([1, 602])


 84%|████████▍ | 383/455 [24:22<04:36,  3.84s/it]

Inputs shape: torch.Size([1, 582])


 84%|████████▍ | 384/455 [24:25<04:09,  3.52s/it]

Inputs shape: torch.Size([1, 624])


 85%|████████▍ | 385/455 [24:29<04:27,  3.82s/it]

Inputs shape: torch.Size([1, 559])


 85%|████████▍ | 386/455 [24:32<04:10,  3.63s/it]

Inputs shape: torch.Size([1, 623])


 85%|████████▌ | 387/455 [24:34<03:21,  2.97s/it]

Inputs shape: torch.Size([1, 419])


 85%|████████▌ | 388/455 [24:37<03:31,  3.16s/it]

Inputs shape: torch.Size([1, 660])


 85%|████████▌ | 389/455 [24:41<03:45,  3.41s/it]

Inputs shape: torch.Size([1, 753])


 86%|████████▌ | 390/455 [24:46<04:08,  3.82s/it]

Inputs shape: torch.Size([1, 780])


 86%|████████▌ | 391/455 [24:51<04:32,  4.25s/it]

Inputs shape: torch.Size([1, 604])


 86%|████████▌ | 392/455 [24:56<04:38,  4.42s/it]

Inputs shape: torch.Size([1, 515])


 86%|████████▋ | 393/455 [25:02<04:59,  4.83s/it]

Inputs shape: torch.Size([1, 748])


 87%|████████▋ | 394/455 [25:05<04:21,  4.28s/it]

Inputs shape: torch.Size([1, 676])


 87%|████████▋ | 395/455 [25:08<04:01,  4.03s/it]

Inputs shape: torch.Size([1, 513])


 87%|████████▋ | 396/455 [25:12<03:52,  3.94s/it]

Inputs shape: torch.Size([1, 381])


 87%|████████▋ | 397/455 [25:14<03:19,  3.45s/it]

Inputs shape: torch.Size([1, 693])


 87%|████████▋ | 398/455 [25:19<03:40,  3.86s/it]

Inputs shape: torch.Size([1, 467])


 88%|████████▊ | 399/455 [25:23<03:37,  3.88s/it]

Inputs shape: torch.Size([1, 564])


 88%|████████▊ | 400/455 [25:26<03:21,  3.67s/it]

Inputs shape: torch.Size([1, 594])


 88%|████████▊ | 401/455 [25:30<03:21,  3.74s/it]

Inputs shape: torch.Size([1, 567])


 88%|████████▊ | 402/455 [25:33<03:04,  3.49s/it]

Inputs shape: torch.Size([1, 625])


 89%|████████▊ | 403/455 [25:37<03:14,  3.73s/it]

Inputs shape: torch.Size([1, 535])


 89%|████████▉ | 404/455 [25:42<03:23,  4.00s/it]

Inputs shape: torch.Size([1, 367])


 89%|████████▉ | 405/455 [25:46<03:14,  3.88s/it]

Inputs shape: torch.Size([1, 655])


 89%|████████▉ | 406/455 [25:50<03:18,  4.05s/it]

Inputs shape: torch.Size([1, 717])


 89%|████████▉ | 407/455 [25:54<03:09,  3.95s/it]

Inputs shape: torch.Size([1, 593])


 90%|████████▉ | 408/455 [25:58<03:01,  3.87s/it]

Inputs shape: torch.Size([1, 626])


 90%|████████▉ | 409/455 [26:02<03:07,  4.07s/it]

Inputs shape: torch.Size([1, 605])


 90%|█████████ | 410/455 [26:06<02:59,  3.98s/it]

Inputs shape: torch.Size([1, 686])


 90%|█████████ | 411/455 [26:10<03:01,  4.12s/it]

Inputs shape: torch.Size([1, 534])


 91%|█████████ | 412/455 [26:14<02:49,  3.95s/it]

Inputs shape: torch.Size([1, 561])


 91%|█████████ | 413/455 [26:20<03:11,  4.57s/it]

Inputs shape: torch.Size([1, 743])


 91%|█████████ | 414/455 [26:25<03:13,  4.71s/it]

Inputs shape: torch.Size([1, 535])


 91%|█████████ | 415/455 [26:28<02:52,  4.32s/it]

Inputs shape: torch.Size([1, 577])


 91%|█████████▏| 416/455 [26:31<02:30,  3.86s/it]

Inputs shape: torch.Size([1, 618])


 92%|█████████▏| 417/455 [26:34<02:19,  3.68s/it]

Inputs shape: torch.Size([1, 664])


 92%|█████████▏| 418/455 [26:38<02:19,  3.77s/it]

Inputs shape: torch.Size([1, 560])


 92%|█████████▏| 419/455 [26:43<02:22,  3.96s/it]

Inputs shape: torch.Size([1, 627])


 92%|█████████▏| 420/455 [26:46<02:15,  3.86s/it]

Inputs shape: torch.Size([1, 609])


 93%|█████████▎| 421/455 [26:50<02:10,  3.82s/it]

Inputs shape: torch.Size([1, 698])


 93%|█████████▎| 422/455 [26:54<02:06,  3.83s/it]

Inputs shape: torch.Size([1, 766])


 93%|█████████▎| 423/455 [26:59<02:15,  4.22s/it]

Inputs shape: torch.Size([1, 569])


 93%|█████████▎| 424/455 [27:02<01:59,  3.86s/it]

Inputs shape: torch.Size([1, 675])


 93%|█████████▎| 425/455 [27:07<02:01,  4.04s/it]

Inputs shape: torch.Size([1, 700])


 94%|█████████▎| 426/455 [27:09<01:46,  3.66s/it]

Inputs shape: torch.Size([1, 640])


 94%|█████████▍| 427/455 [27:15<01:59,  4.27s/it]

Inputs shape: torch.Size([1, 558])


 94%|█████████▍| 428/455 [27:18<01:47,  4.00s/it]

Inputs shape: torch.Size([1, 577])


 94%|█████████▍| 429/455 [27:21<01:34,  3.64s/it]

Inputs shape: torch.Size([1, 578])


 95%|█████████▍| 430/455 [27:26<01:36,  3.85s/it]

Inputs shape: torch.Size([1, 699])


 95%|█████████▍| 431/455 [27:30<01:36,  4.00s/it]

Inputs shape: torch.Size([1, 547])


 95%|█████████▍| 432/455 [27:34<01:33,  4.07s/it]

Inputs shape: torch.Size([1, 607])


 95%|█████████▌| 433/455 [27:37<01:24,  3.84s/it]

Inputs shape: torch.Size([1, 490])


 95%|█████████▌| 434/455 [27:41<01:16,  3.65s/it]

Inputs shape: torch.Size([1, 530])


 96%|█████████▌| 435/455 [27:44<01:10,  3.51s/it]

Inputs shape: torch.Size([1, 838])


 96%|█████████▌| 436/455 [27:47<01:02,  3.29s/it]

Inputs shape: torch.Size([1, 624])


 96%|█████████▌| 437/455 [27:50<01:02,  3.44s/it]

Inputs shape: torch.Size([1, 663])


 96%|█████████▋| 438/455 [27:53<00:55,  3.26s/it]

Inputs shape: torch.Size([1, 624])


 96%|█████████▋| 439/455 [27:57<00:53,  3.34s/it]

Inputs shape: torch.Size([1, 665])


 97%|█████████▋| 440/455 [28:01<00:52,  3.50s/it]

Inputs shape: torch.Size([1, 617])


 97%|█████████▋| 441/455 [28:05<00:51,  3.65s/it]

Inputs shape: torch.Size([1, 587])


 97%|█████████▋| 442/455 [28:08<00:45,  3.47s/it]

Inputs shape: torch.Size([1, 516])


 97%|█████████▋| 443/455 [28:12<00:44,  3.72s/it]

Inputs shape: torch.Size([1, 441])


 98%|█████████▊| 444/455 [28:17<00:43,  3.99s/it]

Inputs shape: torch.Size([1, 511])


 98%|█████████▊| 445/455 [28:20<00:38,  3.80s/it]

Inputs shape: torch.Size([1, 577])


 98%|█████████▊| 446/455 [28:23<00:31,  3.55s/it]

Inputs shape: torch.Size([1, 470])


 98%|█████████▊| 447/455 [28:26<00:27,  3.41s/it]

Inputs shape: torch.Size([1, 572])


 98%|█████████▊| 448/455 [28:30<00:25,  3.60s/it]

Inputs shape: torch.Size([1, 546])


 99%|█████████▊| 449/455 [28:33<00:21,  3.52s/it]

Inputs shape: torch.Size([1, 556])


 99%|█████████▉| 450/455 [28:38<00:19,  3.80s/it]

Inputs shape: torch.Size([1, 630])


 99%|█████████▉| 451/455 [28:42<00:15,  3.78s/it]

Inputs shape: torch.Size([1, 455])


 99%|█████████▉| 452/455 [28:45<00:11,  3.67s/it]

Inputs shape: torch.Size([1, 652])


100%|█████████▉| 453/455 [28:50<00:07,  3.96s/it]

Inputs shape: torch.Size([1, 568])


100%|█████████▉| 454/455 [28:53<00:03,  3.92s/it]

Inputs shape: torch.Size([1, 557])


100%|██████████| 455/455 [28:58<00:00,  3.82s/it]


In [None]:
print(predictions)
print(references)

['How about "The Last King of Scotland"? It has a powerful and emotional impact, with a deep exploration of the complexities of human nature. The characters are well-developed and the story is engaging and thought-provoking. Plus, it has a great soundtrack and a strong performance by Forest Whitaker.']
['How about "The Curious Case of Benjamin Button"? It has vibrant characters, powerful storytelling, and heartwarming moments, along with subtle messages about kindness and strength. It\'s definitely a movie with deep emotional impact and character development that I think you would appreciate.']


In [None]:
from nltk.translate.bleu_score import corpus_bleu

bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.5408119057053787


In [None]:
%%capture
!pip install rouge_score

In [None]:
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.48863349994271804
ROUGE-2: 0.23661245661693792
ROUGE-L: 0.36428091511827243


In [None]:
print(references[30])

How about "Black Swan" (2010)? It's a drama thriller with first-class acting and intense drama. The ambiance and intensity of the film might resonate with your love for strong and realistic storytelling.


In [None]:
print(predictions[30])

How about "The Man Who Shot Liberty Valance" (1962)? It's a classic western drama with a rugged and raw feel, and it has a powerful performance by John Wayne. It's a timeless story with a strong focus on character development and a gripping plot that I think you'll really enjoy.


# Segundo entrenamiento

En este entrenamiento se usarán los mismos paramétros de entrenamiento que antes, pero con la diferencia de que usaré la data que tenía con la primera forma de parseo, en donde entrego información de películas.

In [8]:
files_to_download = {
    "formatted_train.json": "1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_",
    "formatted_test.json": "1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m",
    "formatted_validation.json": "1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_
From (redirected): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_&confirm=t&uuid=c2b6c445-5714-43c1-abbe-7a4e4db626c5
To: /content/formatted_train.json
100%|██████████| 237M/237M [00:01<00:00, 212MB/s]
Downloading...
From: https://drive.google.com/uc?id=1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m
To: /content/formatted_test.json
100%|██████████| 10.8M/10.8M [00:00<00:00, 37.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm
To: /content/formatted_validation.json
100%|██████████| 23.7M/23.7M [00:00<00:00, 133MB/s] 


In [16]:
with open("formatted_train.json", "r") as f:
    data = json.load(f)

dataset_second_train = Dataset.from_dict({"conversations": data})

In [10]:
training_configuration = {
    "lr_scheduler_type": "linear",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "max_steps": 200,
    "output_dir": "second_train_output",
    "learning_rate": 2e-4
}

In [13]:
model_second_train, tokenizer_second_train = initialize_model_and_tokenizer()

==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2024.11.10 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [17]:
formatted_dataset_second_train = preprocess_dataset(dataset_second_train, tokenizer_second_train)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [18]:
trainer = SFTTrainer(
    model=model_second_train,
    tokenizer=tokenizer_second_train,
    train_dataset=formatted_dataset_second_train,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        learning_rate=5e-5,
        lr_scheduler_type=training_configuration["lr_scheduler_type"],
        per_device_train_batch_size=training_configuration["per_device_train_batch_size"],
        gradient_accumulation_steps=training_configuration["gradient_accumulation_steps"],
        num_train_epochs=1,
        max_steps=training_configuration["max_steps"],
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        warmup_steps=5,
        output_dir=training_configuration["output_dir"],
        seed=used_seed,
    )
)

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [19]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.564 GB.
2.635 GB of memory reserved.


In [20]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 50,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 200
 "-____-"     Number of trainable parameters = 24,313,856
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmamunoz42[0m ([33mmamunoz42-pontificia-universidad-cat-lica-de-chile[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,2.2432
2,2.2561
3,2.1644
4,2.3263
5,2.3427
6,2.1763
7,2.1721
8,2.1916
9,2.0801
10,2.1056


In [21]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

373.2538 seconds used for training.
6.22 minutes used for training.
Peak reserved memory = 3.658 GB.
Peak reserved memory for training = 1.023 GB.
Peak reserved memory % of max memory = 9.246 %.
Peak reserved memory for training % of max memory = 2.586 %.


In [22]:
model_second_train.save_pretrained("lora_model_second_train")
tokenizer_second_train.save_pretrained("lora_model_second_train")

('lora_model_second_train/tokenizer_config.json',
 'lora_model_second_train/special_tokens_map.json',
 'lora_model_second_train/tokenizer.json')

In [23]:
!zip -r lora_model_second_train.zip lora_model_second_train/

  adding: lora_model_second_train/ (stored 0%)
  adding: lora_model_second_train/tokenizer.json (deflated 85%)
  adding: lora_model_second_train/adapter_config.json (deflated 54%)
  adding: lora_model_second_train/README.md (deflated 66%)
  adding: lora_model_second_train/tokenizer_config.json (deflated 94%)
  adding: lora_model_second_train/adapter_model.safetensors (deflated 8%)
  adding: lora_model_second_train/special_tokens_map.json (deflated 71%)


# Testeo segundo modelo

In [24]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model_second_train",
    max_seq_length = max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 3072, padding_idx=128004)
        (layers): ModuleList(
          (0-27): 28 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lor

In [25]:
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

In [26]:
with open("formatted_test.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [27]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [28]:
%%capture
!pip install tqdm

In [29]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [30]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])


  0%|          | 0/455 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Inputs shape: torch.Size([1, 891])


  0%|          | 1/455 [00:04<31:21,  4.14s/it]

Inputs shape: torch.Size([1, 971])


  0%|          | 2/455 [00:08<30:23,  4.02s/it]

Inputs shape: torch.Size([1, 872])


  1%|          | 3/455 [00:11<28:49,  3.83s/it]

Inputs shape: torch.Size([1, 766])


  1%|          | 4/455 [00:15<27:37,  3.68s/it]

Inputs shape: torch.Size([1, 798])


  1%|          | 5/455 [00:19<29:11,  3.89s/it]

Inputs shape: torch.Size([1, 771])


  1%|▏         | 6/455 [00:22<28:18,  3.78s/it]

Inputs shape: torch.Size([1, 1041])


  2%|▏         | 7/455 [00:27<28:54,  3.87s/it]

Inputs shape: torch.Size([1, 888])


  2%|▏         | 8/455 [00:30<28:58,  3.89s/it]

Inputs shape: torch.Size([1, 873])


  2%|▏         | 9/455 [00:34<27:42,  3.73s/it]

Inputs shape: torch.Size([1, 1427])


  2%|▏         | 10/455 [00:38<29:39,  4.00s/it]

Inputs shape: torch.Size([1, 673])


  2%|▏         | 11/455 [00:42<28:07,  3.80s/it]

Inputs shape: torch.Size([1, 823])


  3%|▎         | 12/455 [00:46<28:16,  3.83s/it]

Inputs shape: torch.Size([1, 821])


  3%|▎         | 13/455 [00:49<26:48,  3.64s/it]

Inputs shape: torch.Size([1, 938])


  3%|▎         | 14/455 [00:54<29:26,  4.00s/it]

Inputs shape: torch.Size([1, 1156])


  3%|▎         | 15/455 [00:55<22:46,  3.10s/it]

Inputs shape: torch.Size([1, 744])


  4%|▎         | 16/455 [00:59<24:44,  3.38s/it]

Inputs shape: torch.Size([1, 1168])


  4%|▎         | 17/455 [01:03<27:18,  3.74s/it]

Inputs shape: torch.Size([1, 776])


  4%|▍         | 18/455 [01:06<25:42,  3.53s/it]

Inputs shape: torch.Size([1, 972])


  4%|▍         | 19/455 [01:11<26:56,  3.71s/it]

Inputs shape: torch.Size([1, 957])


  4%|▍         | 20/455 [01:15<27:36,  3.81s/it]

Inputs shape: torch.Size([1, 878])


  5%|▍         | 21/455 [01:19<27:59,  3.87s/it]

Inputs shape: torch.Size([1, 1123])


  5%|▍         | 22/455 [01:23<29:44,  4.12s/it]

Inputs shape: torch.Size([1, 962])


  5%|▌         | 23/455 [01:28<31:22,  4.36s/it]

Inputs shape: torch.Size([1, 1086])


  5%|▌         | 24/455 [01:32<29:25,  4.10s/it]

Inputs shape: torch.Size([1, 989])


  5%|▌         | 25/455 [01:35<27:48,  3.88s/it]

Inputs shape: torch.Size([1, 919])


  6%|▌         | 26/455 [01:39<27:46,  3.88s/it]

Inputs shape: torch.Size([1, 985])


  6%|▌         | 27/455 [01:44<29:48,  4.18s/it]

Inputs shape: torch.Size([1, 916])


  6%|▌         | 28/455 [01:49<31:43,  4.46s/it]

Inputs shape: torch.Size([1, 835])


  6%|▋         | 29/455 [01:53<29:59,  4.22s/it]

Inputs shape: torch.Size([1, 1147])


  7%|▋         | 30/455 [01:57<31:09,  4.40s/it]

Inputs shape: torch.Size([1, 1002])


  7%|▋         | 31/455 [02:01<29:29,  4.17s/it]

Inputs shape: torch.Size([1, 736])


  7%|▋         | 32/455 [02:04<27:45,  3.94s/it]

Inputs shape: torch.Size([1, 1213])


  7%|▋         | 33/455 [02:09<29:30,  4.20s/it]

Inputs shape: torch.Size([1, 1109])


  7%|▋         | 34/455 [02:13<29:23,  4.19s/it]

Inputs shape: torch.Size([1, 715])


  8%|▊         | 35/455 [02:16<25:35,  3.66s/it]

Inputs shape: torch.Size([1, 762])


  8%|▊         | 36/455 [02:20<26:23,  3.78s/it]

Inputs shape: torch.Size([1, 902])


  8%|▊         | 37/455 [02:23<25:04,  3.60s/it]

Inputs shape: torch.Size([1, 1006])


  8%|▊         | 38/455 [02:27<25:53,  3.73s/it]

Inputs shape: torch.Size([1, 812])


  9%|▊         | 39/455 [02:31<25:34,  3.69s/it]

Inputs shape: torch.Size([1, 1060])


  9%|▉         | 40/455 [02:36<29:50,  4.31s/it]

Inputs shape: torch.Size([1, 784])


  9%|▉         | 41/455 [02:40<27:54,  4.05s/it]

Inputs shape: torch.Size([1, 794])


  9%|▉         | 42/455 [02:44<28:09,  4.09s/it]

Inputs shape: torch.Size([1, 956])


  9%|▉         | 43/455 [02:48<26:43,  3.89s/it]

Inputs shape: torch.Size([1, 942])


 10%|▉         | 44/455 [02:51<25:38,  3.74s/it]

Inputs shape: torch.Size([1, 1093])


 10%|▉         | 45/455 [02:57<29:44,  4.35s/it]

Inputs shape: torch.Size([1, 880])


 10%|█         | 46/455 [03:01<29:52,  4.38s/it]

Inputs shape: torch.Size([1, 1079])


 10%|█         | 47/455 [03:06<30:24,  4.47s/it]

Inputs shape: torch.Size([1, 680])


 11%|█         | 48/455 [03:09<27:53,  4.11s/it]

Inputs shape: torch.Size([1, 1066])


 11%|█         | 49/455 [03:13<26:52,  3.97s/it]

Inputs shape: torch.Size([1, 1165])


 11%|█         | 50/455 [03:17<27:56,  4.14s/it]

Inputs shape: torch.Size([1, 994])


 11%|█         | 51/455 [03:20<25:00,  3.71s/it]

Inputs shape: torch.Size([1, 816])


 11%|█▏        | 52/455 [03:23<24:09,  3.60s/it]

Inputs shape: torch.Size([1, 957])


 12%|█▏        | 53/455 [03:26<22:50,  3.41s/it]

Inputs shape: torch.Size([1, 837])


 12%|█▏        | 54/455 [03:30<23:15,  3.48s/it]

Inputs shape: torch.Size([1, 993])


 12%|█▏        | 55/455 [03:35<26:41,  4.00s/it]

Inputs shape: torch.Size([1, 794])


 12%|█▏        | 56/455 [03:39<25:51,  3.89s/it]

Inputs shape: torch.Size([1, 742])


 13%|█▎        | 57/455 [03:44<28:40,  4.32s/it]

Inputs shape: torch.Size([1, 822])


 13%|█▎        | 58/455 [03:49<30:15,  4.57s/it]

Inputs shape: torch.Size([1, 787])


 13%|█▎        | 59/455 [03:53<28:02,  4.25s/it]

Inputs shape: torch.Size([1, 892])


 13%|█▎        | 60/455 [03:57<27:19,  4.15s/it]

Inputs shape: torch.Size([1, 1106])


 13%|█▎        | 61/455 [04:01<27:19,  4.16s/it]

Inputs shape: torch.Size([1, 965])


 14%|█▎        | 62/455 [04:04<26:03,  3.98s/it]

Inputs shape: torch.Size([1, 849])


 14%|█▍        | 63/455 [04:08<24:42,  3.78s/it]

Inputs shape: torch.Size([1, 858])


 14%|█▍        | 64/455 [04:11<23:31,  3.61s/it]

Inputs shape: torch.Size([1, 1339])


 14%|█▍        | 65/455 [04:17<28:35,  4.40s/it]

Inputs shape: torch.Size([1, 796])


 15%|█▍        | 66/455 [04:20<25:54,  4.00s/it]

Inputs shape: torch.Size([1, 659])


 15%|█▍        | 67/455 [04:24<26:06,  4.04s/it]

Inputs shape: torch.Size([1, 1023])


 15%|█▍        | 68/455 [04:28<25:43,  3.99s/it]

Inputs shape: torch.Size([1, 845])


 15%|█▌        | 69/455 [04:32<24:56,  3.88s/it]

Inputs shape: torch.Size([1, 902])


 15%|█▌        | 70/455 [04:36<26:07,  4.07s/it]

Inputs shape: torch.Size([1, 925])


 16%|█▌        | 71/455 [04:40<26:07,  4.08s/it]

Inputs shape: torch.Size([1, 749])


 16%|█▌        | 72/455 [04:44<25:46,  4.04s/it]

Inputs shape: torch.Size([1, 852])


 16%|█▌        | 73/455 [04:48<24:39,  3.87s/it]

Inputs shape: torch.Size([1, 904])


 16%|█▋        | 74/455 [04:53<26:09,  4.12s/it]

Inputs shape: torch.Size([1, 920])


 16%|█▋        | 75/455 [04:57<26:29,  4.18s/it]

Inputs shape: torch.Size([1, 703])


 17%|█▋        | 76/455 [05:00<25:00,  3.96s/it]

Inputs shape: torch.Size([1, 1056])


 17%|█▋        | 77/455 [05:04<24:08,  3.83s/it]

Inputs shape: torch.Size([1, 742])


 17%|█▋        | 78/455 [05:08<24:26,  3.89s/it]

Inputs shape: torch.Size([1, 717])


 17%|█▋        | 79/455 [05:11<23:26,  3.74s/it]

Inputs shape: torch.Size([1, 1035])


 18%|█▊        | 80/455 [05:15<23:52,  3.82s/it]

Inputs shape: torch.Size([1, 859])


 18%|█▊        | 81/455 [05:18<22:16,  3.57s/it]

Inputs shape: torch.Size([1, 922])


 18%|█▊        | 82/455 [05:21<20:57,  3.37s/it]

Inputs shape: torch.Size([1, 856])


 18%|█▊        | 83/455 [05:24<20:33,  3.32s/it]

Inputs shape: torch.Size([1, 829])


 18%|█▊        | 84/455 [05:28<21:42,  3.51s/it]

Inputs shape: torch.Size([1, 820])


 19%|█▊        | 85/455 [05:32<22:32,  3.66s/it]

Inputs shape: torch.Size([1, 898])


 19%|█▉        | 86/455 [05:36<23:07,  3.76s/it]

Inputs shape: torch.Size([1, 1002])


 19%|█▉        | 87/455 [05:40<23:15,  3.79s/it]

Inputs shape: torch.Size([1, 727])


 19%|█▉        | 88/455 [05:44<23:39,  3.87s/it]

Inputs shape: torch.Size([1, 969])


 20%|█▉        | 89/455 [05:49<24:22,  4.00s/it]

Inputs shape: torch.Size([1, 875])


 20%|█▉        | 90/455 [05:54<26:43,  4.39s/it]

Inputs shape: torch.Size([1, 897])


 20%|██        | 91/455 [05:57<24:04,  3.97s/it]

Inputs shape: torch.Size([1, 818])


 20%|██        | 92/455 [06:01<24:58,  4.13s/it]

Inputs shape: torch.Size([1, 1024])


 20%|██        | 93/455 [06:05<24:25,  4.05s/it]

Inputs shape: torch.Size([1, 896])


 21%|██        | 94/455 [06:10<25:27,  4.23s/it]

Inputs shape: torch.Size([1, 980])


 21%|██        | 95/455 [06:14<25:14,  4.21s/it]

Inputs shape: torch.Size([1, 980])


 21%|██        | 96/455 [06:18<24:32,  4.10s/it]

Inputs shape: torch.Size([1, 709])


 21%|██▏       | 97/455 [06:22<23:41,  3.97s/it]

Inputs shape: torch.Size([1, 1002])


 22%|██▏       | 98/455 [06:26<24:41,  4.15s/it]

Inputs shape: torch.Size([1, 975])


 22%|██▏       | 99/455 [06:29<22:21,  3.77s/it]

Inputs shape: torch.Size([1, 941])


 22%|██▏       | 100/455 [06:33<22:51,  3.86s/it]

Inputs shape: torch.Size([1, 956])


 22%|██▏       | 101/455 [06:36<21:26,  3.63s/it]

Inputs shape: torch.Size([1, 746])


 22%|██▏       | 102/455 [06:39<20:45,  3.53s/it]

Inputs shape: torch.Size([1, 944])


 23%|██▎       | 103/455 [06:42<19:15,  3.28s/it]

Inputs shape: torch.Size([1, 1037])


 23%|██▎       | 104/455 [06:45<19:00,  3.25s/it]

Inputs shape: torch.Size([1, 942])


 23%|██▎       | 105/455 [06:50<21:00,  3.60s/it]

Inputs shape: torch.Size([1, 817])


 23%|██▎       | 106/455 [06:53<19:42,  3.39s/it]

Inputs shape: torch.Size([1, 916])


 24%|██▎       | 107/455 [06:56<19:27,  3.35s/it]

Inputs shape: torch.Size([1, 1001])


 24%|██▎       | 108/455 [07:00<20:56,  3.62s/it]

Inputs shape: torch.Size([1, 1124])


 24%|██▍       | 109/455 [07:05<23:10,  4.02s/it]

Inputs shape: torch.Size([1, 934])


 24%|██▍       | 110/455 [07:08<21:39,  3.77s/it]

Inputs shape: torch.Size([1, 864])


 24%|██▍       | 111/455 [07:12<21:59,  3.84s/it]

Inputs shape: torch.Size([1, 748])


 25%|██▍       | 112/455 [07:15<20:36,  3.61s/it]

Inputs shape: torch.Size([1, 788])


 25%|██▍       | 113/455 [07:19<20:31,  3.60s/it]

Inputs shape: torch.Size([1, 995])


 25%|██▌       | 114/455 [07:24<22:08,  3.90s/it]

Inputs shape: torch.Size([1, 953])


 25%|██▌       | 115/455 [07:27<21:24,  3.78s/it]

Inputs shape: torch.Size([1, 927])


 25%|██▌       | 116/455 [07:32<23:16,  4.12s/it]

Inputs shape: torch.Size([1, 1011])


 26%|██▌       | 117/455 [07:36<23:22,  4.15s/it]

Inputs shape: torch.Size([1, 868])


 26%|██▌       | 118/455 [07:39<21:21,  3.80s/it]

Inputs shape: torch.Size([1, 946])


 26%|██▌       | 119/455 [07:44<23:00,  4.11s/it]

Inputs shape: torch.Size([1, 932])


 26%|██▋       | 120/455 [07:48<23:08,  4.14s/it]

Inputs shape: torch.Size([1, 981])


 27%|██▋       | 121/455 [07:52<22:45,  4.09s/it]

Inputs shape: torch.Size([1, 830])


 27%|██▋       | 122/455 [07:56<22:35,  4.07s/it]

Inputs shape: torch.Size([1, 787])


 27%|██▋       | 123/455 [07:59<20:59,  3.79s/it]

Inputs shape: torch.Size([1, 879])


 27%|██▋       | 124/455 [08:03<19:55,  3.61s/it]

Inputs shape: torch.Size([1, 1026])


 27%|██▋       | 125/455 [08:05<18:37,  3.39s/it]

Inputs shape: torch.Size([1, 899])


 28%|██▊       | 126/455 [08:11<22:28,  4.10s/it]

Inputs shape: torch.Size([1, 709])


 28%|██▊       | 127/455 [08:15<21:14,  3.89s/it]

Inputs shape: torch.Size([1, 894])


 28%|██▊       | 128/455 [08:18<19:59,  3.67s/it]

Inputs shape: torch.Size([1, 748])


 28%|██▊       | 129/455 [08:22<21:10,  3.90s/it]

Inputs shape: torch.Size([1, 952])


 29%|██▊       | 130/455 [08:26<20:26,  3.77s/it]

Inputs shape: torch.Size([1, 757])


 29%|██▉       | 131/455 [08:29<19:36,  3.63s/it]

Inputs shape: torch.Size([1, 918])


 29%|██▉       | 132/455 [08:32<17:47,  3.30s/it]

Inputs shape: torch.Size([1, 917])


 29%|██▉       | 133/455 [08:35<18:40,  3.48s/it]

Inputs shape: torch.Size([1, 979])


 29%|██▉       | 134/455 [08:39<19:24,  3.63s/it]

Inputs shape: torch.Size([1, 857])


 30%|██▉       | 135/455 [08:45<21:51,  4.10s/it]

Inputs shape: torch.Size([1, 860])


 30%|██▉       | 136/455 [08:50<24:34,  4.62s/it]

Inputs shape: torch.Size([1, 953])


 30%|███       | 137/455 [08:55<24:32,  4.63s/it]

Inputs shape: torch.Size([1, 748])


 30%|███       | 138/455 [09:00<25:44,  4.87s/it]

Inputs shape: torch.Size([1, 1104])


 31%|███       | 139/455 [09:05<25:13,  4.79s/it]

Inputs shape: torch.Size([1, 902])


 31%|███       | 140/455 [09:09<23:13,  4.42s/it]

Inputs shape: torch.Size([1, 858])


 31%|███       | 141/455 [09:13<23:37,  4.51s/it]

Inputs shape: torch.Size([1, 807])


 31%|███       | 142/455 [09:17<22:30,  4.31s/it]

Inputs shape: torch.Size([1, 818])


 31%|███▏      | 143/455 [09:20<20:46,  3.99s/it]

Inputs shape: torch.Size([1, 921])


 32%|███▏      | 144/455 [09:24<19:37,  3.79s/it]

Inputs shape: torch.Size([1, 860])


 32%|███▏      | 145/455 [09:28<20:13,  3.92s/it]

Inputs shape: torch.Size([1, 879])


 32%|███▏      | 146/455 [09:32<20:13,  3.93s/it]

Inputs shape: torch.Size([1, 1068])


 32%|███▏      | 147/455 [09:38<23:06,  4.50s/it]

Inputs shape: torch.Size([1, 864])


 33%|███▎      | 148/455 [09:42<23:06,  4.52s/it]

Inputs shape: torch.Size([1, 936])


 33%|███▎      | 149/455 [09:46<21:39,  4.25s/it]

Inputs shape: torch.Size([1, 970])


 33%|███▎      | 150/455 [09:52<23:47,  4.68s/it]

Inputs shape: torch.Size([1, 996])


 33%|███▎      | 151/455 [09:56<23:44,  4.69s/it]

Inputs shape: torch.Size([1, 1048])


 33%|███▎      | 152/455 [10:01<23:49,  4.72s/it]

Inputs shape: torch.Size([1, 1099])


 34%|███▎      | 153/455 [10:06<23:43,  4.71s/it]

Inputs shape: torch.Size([1, 790])


 34%|███▍      | 154/455 [10:09<21:37,  4.31s/it]

Inputs shape: torch.Size([1, 949])


 34%|███▍      | 155/455 [10:14<22:42,  4.54s/it]

Inputs shape: torch.Size([1, 972])


 34%|███▍      | 156/455 [10:18<22:02,  4.42s/it]

Inputs shape: torch.Size([1, 973])


 35%|███▍      | 157/455 [10:22<21:23,  4.31s/it]

Inputs shape: torch.Size([1, 1065])


 35%|███▍      | 158/455 [10:27<20:58,  4.24s/it]

Inputs shape: torch.Size([1, 934])


 35%|███▍      | 159/455 [10:30<20:25,  4.14s/it]

Inputs shape: torch.Size([1, 1238])


 35%|███▌      | 160/455 [10:35<20:23,  4.15s/it]

Inputs shape: torch.Size([1, 1191])


 35%|███▌      | 161/455 [10:39<20:41,  4.22s/it]

Inputs shape: torch.Size([1, 797])


 36%|███▌      | 162/455 [10:43<19:37,  4.02s/it]

Inputs shape: torch.Size([1, 1101])


 36%|███▌      | 163/455 [10:46<19:23,  3.98s/it]

Inputs shape: torch.Size([1, 857])


 36%|███▌      | 164/455 [10:51<19:26,  4.01s/it]

Inputs shape: torch.Size([1, 799])


 36%|███▋      | 165/455 [10:54<18:48,  3.89s/it]

Inputs shape: torch.Size([1, 957])


 36%|███▋      | 166/455 [11:00<21:21,  4.43s/it]

Inputs shape: torch.Size([1, 868])


 37%|███▋      | 167/455 [11:03<19:13,  4.00s/it]

Inputs shape: torch.Size([1, 1026])


 37%|███▋      | 168/455 [11:07<19:41,  4.12s/it]

Inputs shape: torch.Size([1, 839])


 37%|███▋      | 169/455 [11:11<18:41,  3.92s/it]

Inputs shape: torch.Size([1, 690])


 37%|███▋      | 170/455 [11:14<18:09,  3.82s/it]

Inputs shape: torch.Size([1, 1147])


 38%|███▊      | 171/455 [11:18<18:32,  3.92s/it]

Inputs shape: torch.Size([1, 872])


 38%|███▊      | 172/455 [11:21<16:28,  3.49s/it]

Inputs shape: torch.Size([1, 992])


 38%|███▊      | 173/455 [11:26<18:21,  3.91s/it]

Inputs shape: torch.Size([1, 973])


 38%|███▊      | 174/455 [11:32<21:37,  4.62s/it]

Inputs shape: torch.Size([1, 837])


 38%|███▊      | 175/455 [11:36<20:17,  4.35s/it]

Inputs shape: torch.Size([1, 724])


 39%|███▊      | 176/455 [11:39<18:54,  4.06s/it]

Inputs shape: torch.Size([1, 893])


 39%|███▉      | 177/455 [11:44<19:10,  4.14s/it]

Inputs shape: torch.Size([1, 1037])


 39%|███▉      | 178/455 [11:47<18:25,  3.99s/it]

Inputs shape: torch.Size([1, 851])


 39%|███▉      | 179/455 [11:53<20:27,  4.45s/it]

Inputs shape: torch.Size([1, 855])


 40%|███▉      | 180/455 [11:57<19:34,  4.27s/it]

Inputs shape: torch.Size([1, 925])


 40%|███▉      | 181/455 [11:59<16:42,  3.66s/it]

Inputs shape: torch.Size([1, 1085])


 40%|████      | 182/455 [12:03<17:34,  3.86s/it]

Inputs shape: torch.Size([1, 875])


 40%|████      | 183/455 [12:05<15:12,  3.36s/it]

Inputs shape: torch.Size([1, 945])


 40%|████      | 184/455 [12:09<15:18,  3.39s/it]

Inputs shape: torch.Size([1, 763])


 41%|████      | 185/455 [12:13<15:54,  3.53s/it]

Inputs shape: torch.Size([1, 956])


 41%|████      | 186/455 [12:16<15:58,  3.56s/it]

Inputs shape: torch.Size([1, 807])


 41%|████      | 187/455 [12:20<16:19,  3.66s/it]

Inputs shape: torch.Size([1, 941])


 41%|████▏     | 188/455 [12:24<17:05,  3.84s/it]

Inputs shape: torch.Size([1, 1072])


 42%|████▏     | 189/455 [12:28<16:45,  3.78s/it]

Inputs shape: torch.Size([1, 761])


 42%|████▏     | 190/455 [12:32<16:32,  3.74s/it]

Inputs shape: torch.Size([1, 1115])


 42%|████▏     | 191/455 [12:36<17:44,  4.03s/it]

Inputs shape: torch.Size([1, 761])


 42%|████▏     | 192/455 [12:41<17:51,  4.08s/it]

Inputs shape: torch.Size([1, 694])


 42%|████▏     | 193/455 [12:44<17:31,  4.01s/it]

Inputs shape: torch.Size([1, 838])


 43%|████▎     | 194/455 [12:51<20:12,  4.65s/it]

Inputs shape: torch.Size([1, 928])


 43%|████▎     | 195/455 [12:55<19:16,  4.45s/it]

Inputs shape: torch.Size([1, 930])


 43%|████▎     | 196/455 [12:59<18:48,  4.36s/it]

Inputs shape: torch.Size([1, 1059])


 43%|████▎     | 197/455 [13:02<17:45,  4.13s/it]

Inputs shape: torch.Size([1, 646])


 44%|████▎     | 198/455 [13:06<17:08,  4.00s/it]

Inputs shape: torch.Size([1, 844])


 44%|████▎     | 199/455 [13:10<17:06,  4.01s/it]

Inputs shape: torch.Size([1, 955])


 44%|████▍     | 200/455 [13:14<16:59,  4.00s/it]

Inputs shape: torch.Size([1, 928])


 44%|████▍     | 201/455 [13:20<19:07,  4.52s/it]

Inputs shape: torch.Size([1, 877])


 44%|████▍     | 202/455 [13:25<19:46,  4.69s/it]

Inputs shape: torch.Size([1, 976])


 45%|████▍     | 203/455 [13:27<17:08,  4.08s/it]

Inputs shape: torch.Size([1, 1229])


 45%|████▍     | 204/455 [13:33<18:17,  4.37s/it]

Inputs shape: torch.Size([1, 845])


 45%|████▌     | 205/455 [13:36<16:34,  3.98s/it]

Inputs shape: torch.Size([1, 1051])


 45%|████▌     | 206/455 [13:39<15:44,  3.79s/it]

Inputs shape: torch.Size([1, 724])


 45%|████▌     | 207/455 [13:43<15:40,  3.79s/it]

Inputs shape: torch.Size([1, 847])


 46%|████▌     | 208/455 [13:46<14:28,  3.52s/it]

Inputs shape: torch.Size([1, 1064])


 46%|████▌     | 209/455 [13:51<16:48,  4.10s/it]

Inputs shape: torch.Size([1, 928])


 46%|████▌     | 210/455 [13:56<17:23,  4.26s/it]

Inputs shape: torch.Size([1, 779])


 46%|████▋     | 211/455 [13:59<16:34,  4.07s/it]

Inputs shape: torch.Size([1, 933])


 47%|████▋     | 212/455 [14:01<14:02,  3.47s/it]

Inputs shape: torch.Size([1, 1062])


 47%|████▋     | 213/455 [14:05<13:41,  3.39s/it]

Inputs shape: torch.Size([1, 892])


 47%|████▋     | 214/455 [14:09<15:13,  3.79s/it]

Inputs shape: torch.Size([1, 852])


 47%|████▋     | 215/455 [14:11<12:40,  3.17s/it]

Inputs shape: torch.Size([1, 1064])


 47%|████▋     | 216/455 [14:15<13:12,  3.32s/it]

Inputs shape: torch.Size([1, 1036])


 48%|████▊     | 217/455 [14:20<15:37,  3.94s/it]

Inputs shape: torch.Size([1, 957])


 48%|████▊     | 218/455 [14:24<15:06,  3.82s/it]

Inputs shape: torch.Size([1, 857])


 48%|████▊     | 219/455 [14:28<15:30,  3.94s/it]

Inputs shape: torch.Size([1, 1068])


 48%|████▊     | 220/455 [14:30<12:41,  3.24s/it]

Inputs shape: torch.Size([1, 1092])


 49%|████▊     | 221/455 [14:33<13:08,  3.37s/it]

Inputs shape: torch.Size([1, 829])


 49%|████▉     | 222/455 [14:37<13:41,  3.53s/it]

Inputs shape: torch.Size([1, 606])


 49%|████▉     | 223/455 [14:42<15:01,  3.89s/it]

Inputs shape: torch.Size([1, 1158])


 49%|████▉     | 224/455 [14:46<15:26,  4.01s/it]

Inputs shape: torch.Size([1, 1261])


 49%|████▉     | 225/455 [14:50<15:32,  4.05s/it]

Inputs shape: torch.Size([1, 747])


 50%|████▉     | 226/455 [14:53<14:23,  3.77s/it]

Inputs shape: torch.Size([1, 876])


 50%|████▉     | 227/455 [14:58<15:52,  4.18s/it]

Inputs shape: torch.Size([1, 1069])


 50%|█████     | 228/455 [15:03<15:45,  4.16s/it]

Inputs shape: torch.Size([1, 989])


 50%|█████     | 229/455 [15:06<15:21,  4.08s/it]

Inputs shape: torch.Size([1, 1137])


 51%|█████     | 230/455 [15:11<16:10,  4.31s/it]

Inputs shape: torch.Size([1, 870])


 51%|█████     | 231/455 [15:15<15:12,  4.07s/it]

Inputs shape: torch.Size([1, 865])


 51%|█████     | 232/455 [15:18<14:24,  3.88s/it]

Inputs shape: torch.Size([1, 920])


 51%|█████     | 233/455 [15:21<13:30,  3.65s/it]

Inputs shape: torch.Size([1, 954])


 51%|█████▏    | 234/455 [15:26<14:20,  3.89s/it]

Inputs shape: torch.Size([1, 936])


 52%|█████▏    | 235/455 [15:31<15:11,  4.14s/it]

Inputs shape: torch.Size([1, 792])


 52%|█████▏    | 236/455 [15:35<15:01,  4.12s/it]

Inputs shape: torch.Size([1, 971])


 52%|█████▏    | 237/455 [15:39<15:00,  4.13s/it]

Inputs shape: torch.Size([1, 893])


 52%|█████▏    | 238/455 [15:44<15:36,  4.32s/it]

Inputs shape: torch.Size([1, 918])


 53%|█████▎    | 239/455 [15:48<15:50,  4.40s/it]

Inputs shape: torch.Size([1, 867])


 53%|█████▎    | 240/455 [15:53<16:23,  4.57s/it]

Inputs shape: torch.Size([1, 849])


 53%|█████▎    | 241/455 [15:56<14:00,  3.93s/it]

Inputs shape: torch.Size([1, 715])


 53%|█████▎    | 242/455 [15:59<12:59,  3.66s/it]

Inputs shape: torch.Size([1, 843])


 53%|█████▎    | 243/455 [16:02<12:43,  3.60s/it]

Inputs shape: torch.Size([1, 927])


 54%|█████▎    | 244/455 [16:06<13:09,  3.74s/it]

Inputs shape: torch.Size([1, 733])


 54%|█████▍    | 245/455 [16:09<12:09,  3.47s/it]

Inputs shape: torch.Size([1, 1019])


 54%|█████▍    | 246/455 [16:14<13:47,  3.96s/it]

Inputs shape: torch.Size([1, 769])


 54%|█████▍    | 247/455 [16:19<14:36,  4.22s/it]

Inputs shape: torch.Size([1, 1057])


 55%|█████▍    | 248/455 [16:23<14:11,  4.12s/it]

Inputs shape: torch.Size([1, 929])


 55%|█████▍    | 249/455 [16:27<14:07,  4.11s/it]

Inputs shape: torch.Size([1, 1001])


 55%|█████▍    | 250/455 [16:29<12:12,  3.57s/it]

Inputs shape: torch.Size([1, 934])


 55%|█████▌    | 251/455 [16:34<13:48,  4.06s/it]

Inputs shape: torch.Size([1, 657])


 55%|█████▌    | 252/455 [16:39<13:48,  4.08s/it]

Inputs shape: torch.Size([1, 773])


 56%|█████▌    | 253/455 [16:42<13:08,  3.90s/it]

Inputs shape: torch.Size([1, 915])


 56%|█████▌    | 254/455 [16:47<14:37,  4.36s/it]

Inputs shape: torch.Size([1, 821])


 56%|█████▌    | 255/455 [16:51<13:44,  4.12s/it]

Inputs shape: torch.Size([1, 986])


 56%|█████▋    | 256/455 [16:57<15:20,  4.62s/it]

Inputs shape: torch.Size([1, 873])


 56%|█████▋    | 257/455 [17:00<13:57,  4.23s/it]

Inputs shape: torch.Size([1, 863])


 57%|█████▋    | 258/455 [17:04<13:26,  4.09s/it]

Inputs shape: torch.Size([1, 1041])


 57%|█████▋    | 259/455 [17:09<13:58,  4.28s/it]

Inputs shape: torch.Size([1, 1018])


 57%|█████▋    | 260/455 [17:13<13:38,  4.20s/it]

Inputs shape: torch.Size([1, 737])


 57%|█████▋    | 261/455 [17:16<12:57,  4.01s/it]

Inputs shape: torch.Size([1, 859])


 58%|█████▊    | 262/455 [17:20<13:03,  4.06s/it]

Inputs shape: torch.Size([1, 971])


 58%|█████▊    | 263/455 [17:25<13:26,  4.20s/it]

Inputs shape: torch.Size([1, 940])


 58%|█████▊    | 264/455 [17:29<13:45,  4.32s/it]

Inputs shape: torch.Size([1, 714])


 58%|█████▊    | 265/455 [17:33<13:20,  4.21s/it]

Inputs shape: torch.Size([1, 812])


 58%|█████▊    | 266/455 [17:37<13:01,  4.14s/it]

Inputs shape: torch.Size([1, 742])


 59%|█████▊    | 267/455 [17:41<12:06,  3.87s/it]

Inputs shape: torch.Size([1, 1080])


 59%|█████▉    | 268/455 [17:44<11:50,  3.80s/it]

Inputs shape: torch.Size([1, 1027])


 59%|█████▉    | 269/455 [17:50<13:06,  4.23s/it]

Inputs shape: torch.Size([1, 794])


 59%|█████▉    | 270/455 [17:53<12:14,  3.97s/it]

Inputs shape: torch.Size([1, 902])


 60%|█████▉    | 271/455 [17:56<11:23,  3.71s/it]

Inputs shape: torch.Size([1, 1047])


 60%|█████▉    | 272/455 [18:01<12:07,  3.98s/it]

Inputs shape: torch.Size([1, 1120])


 60%|██████    | 273/455 [18:04<11:51,  3.91s/it]

Inputs shape: torch.Size([1, 903])


 60%|██████    | 274/455 [18:07<11:03,  3.67s/it]

Inputs shape: torch.Size([1, 890])


 60%|██████    | 275/455 [18:11<10:33,  3.52s/it]

Inputs shape: torch.Size([1, 902])


 61%|██████    | 276/455 [18:15<10:55,  3.66s/it]

Inputs shape: torch.Size([1, 907])


 61%|██████    | 277/455 [18:20<12:07,  4.09s/it]

Inputs shape: torch.Size([1, 753])


 61%|██████    | 278/455 [18:22<10:51,  3.68s/it]

Inputs shape: torch.Size([1, 1021])


 61%|██████▏   | 279/455 [18:27<11:18,  3.85s/it]

Inputs shape: torch.Size([1, 888])


 62%|██████▏   | 280/455 [18:31<11:26,  3.92s/it]

Inputs shape: torch.Size([1, 899])


 62%|██████▏   | 281/455 [18:33<10:04,  3.47s/it]

Inputs shape: torch.Size([1, 792])


 62%|██████▏   | 282/455 [18:39<11:45,  4.08s/it]

Inputs shape: torch.Size([1, 753])


 62%|██████▏   | 283/455 [18:42<10:50,  3.78s/it]

Inputs shape: torch.Size([1, 839])


 62%|██████▏   | 284/455 [18:45<10:43,  3.76s/it]

Inputs shape: torch.Size([1, 998])


 63%|██████▎   | 285/455 [18:49<10:24,  3.67s/it]

Inputs shape: torch.Size([1, 691])


 63%|██████▎   | 286/455 [18:52<09:36,  3.41s/it]

Inputs shape: torch.Size([1, 990])


 63%|██████▎   | 287/455 [18:57<10:44,  3.84s/it]

Inputs shape: torch.Size([1, 935])


 63%|██████▎   | 288/455 [19:00<10:31,  3.78s/it]

Inputs shape: torch.Size([1, 888])


 64%|██████▎   | 289/455 [19:04<10:24,  3.76s/it]

Inputs shape: torch.Size([1, 804])


 64%|██████▎   | 290/455 [19:09<11:34,  4.21s/it]

Inputs shape: torch.Size([1, 985])


 64%|██████▍   | 291/455 [19:13<11:13,  4.11s/it]

Inputs shape: torch.Size([1, 698])


 64%|██████▍   | 292/455 [19:18<11:35,  4.27s/it]

Inputs shape: torch.Size([1, 1073])


 64%|██████▍   | 293/455 [19:21<10:28,  3.88s/it]

Inputs shape: torch.Size([1, 1109])


 65%|██████▍   | 294/455 [19:26<11:49,  4.41s/it]

Inputs shape: torch.Size([1, 774])


 65%|██████▍   | 295/455 [19:31<12:08,  4.55s/it]

Inputs shape: torch.Size([1, 801])


 65%|██████▌   | 296/455 [19:35<11:37,  4.38s/it]

Inputs shape: torch.Size([1, 1043])


 65%|██████▌   | 297/455 [19:40<11:35,  4.40s/it]

Inputs shape: torch.Size([1, 909])


 65%|██████▌   | 298/455 [19:42<09:57,  3.81s/it]

Inputs shape: torch.Size([1, 995])


 66%|██████▌   | 299/455 [19:45<09:29,  3.65s/it]

Inputs shape: torch.Size([1, 833])


 66%|██████▌   | 300/455 [19:49<09:24,  3.64s/it]

Inputs shape: torch.Size([1, 833])


 66%|██████▌   | 301/455 [19:53<09:41,  3.77s/it]

Inputs shape: torch.Size([1, 924])


 66%|██████▋   | 302/455 [19:57<09:38,  3.78s/it]

Inputs shape: torch.Size([1, 787])


 67%|██████▋   | 303/455 [20:00<09:27,  3.73s/it]

Inputs shape: torch.Size([1, 806])


 67%|██████▋   | 304/455 [20:04<09:01,  3.59s/it]

Inputs shape: torch.Size([1, 760])


 67%|██████▋   | 305/455 [20:09<09:54,  3.96s/it]

Inputs shape: torch.Size([1, 865])


 67%|██████▋   | 306/455 [20:14<10:49,  4.36s/it]

Inputs shape: torch.Size([1, 1106])


 67%|██████▋   | 307/455 [20:18<10:19,  4.19s/it]

Inputs shape: torch.Size([1, 686])


 68%|██████▊   | 308/455 [20:23<10:54,  4.45s/it]

Inputs shape: torch.Size([1, 723])


 68%|██████▊   | 309/455 [20:27<10:57,  4.50s/it]

Inputs shape: torch.Size([1, 911])


 68%|██████▊   | 310/455 [20:31<10:25,  4.31s/it]

Inputs shape: torch.Size([1, 807])


 68%|██████▊   | 311/455 [20:34<09:31,  3.97s/it]

Inputs shape: torch.Size([1, 891])


 69%|██████▊   | 312/455 [20:37<08:47,  3.69s/it]

Inputs shape: torch.Size([1, 718])


 69%|██████▉   | 313/455 [20:41<08:31,  3.60s/it]

Inputs shape: torch.Size([1, 869])


 69%|██████▉   | 314/455 [20:45<08:45,  3.73s/it]

Inputs shape: torch.Size([1, 880])


 69%|██████▉   | 315/455 [20:49<09:16,  3.97s/it]

Inputs shape: torch.Size([1, 833])


 69%|██████▉   | 316/455 [20:54<09:44,  4.20s/it]

Inputs shape: torch.Size([1, 900])


 70%|██████▉   | 317/455 [20:57<08:46,  3.82s/it]

Inputs shape: torch.Size([1, 773])


 70%|██████▉   | 318/455 [21:01<08:40,  3.80s/it]

Inputs shape: torch.Size([1, 932])


 70%|███████   | 319/455 [21:05<08:43,  3.85s/it]

Inputs shape: torch.Size([1, 826])


 70%|███████   | 320/455 [21:08<08:26,  3.75s/it]

Inputs shape: torch.Size([1, 1042])


 71%|███████   | 321/455 [21:11<07:51,  3.52s/it]

Inputs shape: torch.Size([1, 780])


 71%|███████   | 322/455 [21:14<07:29,  3.38s/it]

Inputs shape: torch.Size([1, 1013])


 71%|███████   | 323/455 [21:17<07:13,  3.29s/it]

Inputs shape: torch.Size([1, 1020])


 71%|███████   | 324/455 [21:23<08:44,  4.00s/it]

Inputs shape: torch.Size([1, 889])


 71%|███████▏  | 325/455 [21:27<08:30,  3.93s/it]

Inputs shape: torch.Size([1, 1274])


 72%|███████▏  | 326/455 [21:33<09:40,  4.50s/it]

Inputs shape: torch.Size([1, 849])


 72%|███████▏  | 327/455 [21:37<09:19,  4.37s/it]

Inputs shape: torch.Size([1, 818])


 72%|███████▏  | 328/455 [21:40<08:29,  4.01s/it]

Inputs shape: torch.Size([1, 814])


 72%|███████▏  | 329/455 [21:44<08:20,  3.97s/it]

Inputs shape: torch.Size([1, 997])


 73%|███████▎  | 330/455 [21:47<08:02,  3.86s/it]

Inputs shape: torch.Size([1, 980])


 73%|███████▎  | 331/455 [21:52<08:21,  4.04s/it]

Inputs shape: torch.Size([1, 1108])


 73%|███████▎  | 332/455 [21:56<08:28,  4.14s/it]

Inputs shape: torch.Size([1, 908])


 73%|███████▎  | 333/455 [22:01<08:44,  4.30s/it]

Inputs shape: torch.Size([1, 940])


 73%|███████▎  | 334/455 [22:05<08:21,  4.14s/it]

Inputs shape: torch.Size([1, 1149])


 74%|███████▎  | 335/455 [22:08<08:00,  4.00s/it]

Inputs shape: torch.Size([1, 868])


 74%|███████▍  | 336/455 [22:11<07:27,  3.76s/it]

Inputs shape: torch.Size([1, 833])


 74%|███████▍  | 337/455 [22:15<07:29,  3.81s/it]

Inputs shape: torch.Size([1, 1160])


 74%|███████▍  | 338/455 [22:20<07:39,  3.92s/it]

Inputs shape: torch.Size([1, 1087])


 75%|███████▍  | 339/455 [22:23<07:14,  3.74s/it]

Inputs shape: torch.Size([1, 1024])


 75%|███████▍  | 340/455 [22:28<07:55,  4.14s/it]

Inputs shape: torch.Size([1, 1111])


 75%|███████▍  | 341/455 [22:32<07:34,  3.99s/it]

Inputs shape: torch.Size([1, 890])


 75%|███████▌  | 342/455 [22:36<07:51,  4.17s/it]

Inputs shape: torch.Size([1, 1094])


 75%|███████▌  | 343/455 [22:41<08:00,  4.29s/it]

Inputs shape: torch.Size([1, 848])


 76%|███████▌  | 344/455 [22:44<07:37,  4.12s/it]

Inputs shape: torch.Size([1, 862])


 76%|███████▌  | 345/455 [22:49<07:46,  4.24s/it]

Inputs shape: torch.Size([1, 876])


 76%|███████▌  | 346/455 [22:52<06:47,  3.74s/it]

Inputs shape: torch.Size([1, 992])


 76%|███████▋  | 347/455 [22:56<06:55,  3.85s/it]

Inputs shape: torch.Size([1, 1081])


 76%|███████▋  | 348/455 [23:00<06:53,  3.86s/it]

Inputs shape: torch.Size([1, 880])


 77%|███████▋  | 349/455 [23:03<06:37,  3.75s/it]

Inputs shape: torch.Size([1, 852])


 77%|███████▋  | 350/455 [23:07<06:41,  3.83s/it]

Inputs shape: torch.Size([1, 1118])


 77%|███████▋  | 351/455 [23:13<07:43,  4.46s/it]

Inputs shape: torch.Size([1, 881])


 77%|███████▋  | 352/455 [23:17<07:15,  4.23s/it]

Inputs shape: torch.Size([1, 850])


 78%|███████▊  | 353/455 [23:21<07:13,  4.25s/it]

Inputs shape: torch.Size([1, 766])


 78%|███████▊  | 354/455 [23:24<06:40,  3.96s/it]

Inputs shape: torch.Size([1, 904])


 78%|███████▊  | 355/455 [23:28<06:37,  3.97s/it]

Inputs shape: torch.Size([1, 951])


 78%|███████▊  | 356/455 [23:32<06:13,  3.77s/it]

Inputs shape: torch.Size([1, 912])


 78%|███████▊  | 357/455 [23:36<06:20,  3.88s/it]

Inputs shape: torch.Size([1, 1064])


 79%|███████▊  | 358/455 [23:40<06:32,  4.04s/it]

Inputs shape: torch.Size([1, 883])


 79%|███████▉  | 359/455 [23:43<06:03,  3.79s/it]

Inputs shape: torch.Size([1, 861])


 79%|███████▉  | 360/455 [23:48<06:13,  3.93s/it]

Inputs shape: torch.Size([1, 952])


 79%|███████▉  | 361/455 [23:52<06:09,  3.93s/it]

Inputs shape: torch.Size([1, 907])


 80%|███████▉  | 362/455 [23:57<06:51,  4.43s/it]

Inputs shape: torch.Size([1, 841])


 80%|███████▉  | 363/455 [24:01<06:46,  4.41s/it]

Inputs shape: torch.Size([1, 898])


 80%|████████  | 364/455 [24:05<06:04,  4.01s/it]

Inputs shape: torch.Size([1, 663])


 80%|████████  | 365/455 [24:09<05:59,  4.00s/it]

Inputs shape: torch.Size([1, 1081])


 80%|████████  | 366/455 [24:13<05:59,  4.04s/it]

Inputs shape: torch.Size([1, 854])


 81%|████████  | 367/455 [24:17<05:50,  3.98s/it]

Inputs shape: torch.Size([1, 1052])


 81%|████████  | 368/455 [24:21<05:58,  4.12s/it]

Inputs shape: torch.Size([1, 973])


 81%|████████  | 369/455 [24:26<06:14,  4.36s/it]

Inputs shape: torch.Size([1, 998])


 81%|████████▏ | 370/455 [24:30<06:15,  4.42s/it]

Inputs shape: torch.Size([1, 1087])


 82%|████████▏ | 371/455 [24:35<06:10,  4.41s/it]

Inputs shape: torch.Size([1, 1266])


 82%|████████▏ | 372/455 [24:39<05:52,  4.24s/it]

Inputs shape: torch.Size([1, 988])


 82%|████████▏ | 373/455 [24:42<05:14,  3.83s/it]

Inputs shape: torch.Size([1, 967])


 82%|████████▏ | 374/455 [24:48<06:10,  4.57s/it]

Inputs shape: torch.Size([1, 830])


 82%|████████▏ | 375/455 [24:52<05:46,  4.33s/it]

Inputs shape: torch.Size([1, 936])


 83%|████████▎ | 376/455 [24:54<05:06,  3.88s/it]

Inputs shape: torch.Size([1, 943])


 83%|████████▎ | 377/455 [24:57<04:28,  3.45s/it]

Inputs shape: torch.Size([1, 746])


 83%|████████▎ | 378/455 [25:00<04:14,  3.30s/it]

Inputs shape: torch.Size([1, 797])


 83%|████████▎ | 379/455 [25:04<04:21,  3.44s/it]

Inputs shape: torch.Size([1, 991])


 84%|████████▎ | 380/455 [25:08<04:34,  3.66s/it]

Inputs shape: torch.Size([1, 1042])


 84%|████████▎ | 381/455 [25:12<04:33,  3.70s/it]

Inputs shape: torch.Size([1, 847])


 84%|████████▍ | 382/455 [25:16<04:46,  3.93s/it]

Inputs shape: torch.Size([1, 945])


 84%|████████▍ | 383/455 [25:20<04:38,  3.87s/it]

Inputs shape: torch.Size([1, 902])


 84%|████████▍ | 384/455 [25:23<04:11,  3.55s/it]

Inputs shape: torch.Size([1, 928])


 85%|████████▍ | 385/455 [25:27<04:23,  3.77s/it]

Inputs shape: torch.Size([1, 887])


 85%|████████▍ | 386/455 [25:31<04:20,  3.78s/it]

Inputs shape: torch.Size([1, 978])


 85%|████████▌ | 387/455 [25:32<03:26,  3.04s/it]

Inputs shape: torch.Size([1, 717])


 85%|████████▌ | 388/455 [25:35<03:25,  3.07s/it]

Inputs shape: torch.Size([1, 1070])


 85%|████████▌ | 389/455 [25:40<03:56,  3.59s/it]

Inputs shape: torch.Size([1, 1046])


 86%|████████▌ | 390/455 [25:43<03:47,  3.50s/it]

Inputs shape: torch.Size([1, 1154])


 86%|████████▌ | 391/455 [25:49<04:23,  4.12s/it]

Inputs shape: torch.Size([1, 936])


 86%|████████▌ | 392/455 [25:54<04:34,  4.36s/it]

Inputs shape: torch.Size([1, 854])


 86%|████████▋ | 393/455 [25:58<04:34,  4.43s/it]

Inputs shape: torch.Size([1, 1045])


 87%|████████▋ | 394/455 [26:01<04:03,  4.00s/it]

Inputs shape: torch.Size([1, 998])


 87%|████████▋ | 395/455 [26:05<03:47,  3.78s/it]

Inputs shape: torch.Size([1, 874])


 87%|████████▋ | 396/455 [26:08<03:29,  3.55s/it]

Inputs shape: torch.Size([1, 685])


 87%|████████▋ | 397/455 [26:10<03:07,  3.23s/it]

Inputs shape: torch.Size([1, 1029])


 87%|████████▋ | 398/455 [26:16<03:47,  3.99s/it]

Inputs shape: torch.Size([1, 778])


 88%|████████▊ | 399/455 [26:20<03:44,  4.02s/it]

Inputs shape: torch.Size([1, 829])


 88%|████████▊ | 400/455 [26:23<03:30,  3.83s/it]

Inputs shape: torch.Size([1, 857])


 88%|████████▊ | 401/455 [26:27<03:22,  3.75s/it]

Inputs shape: torch.Size([1, 893])


 88%|████████▊ | 402/455 [26:30<03:10,  3.59s/it]

Inputs shape: torch.Size([1, 964])


 89%|████████▊ | 403/455 [26:35<03:24,  3.93s/it]

Inputs shape: torch.Size([1, 804])


 89%|████████▉ | 404/455 [26:40<03:35,  4.22s/it]

Inputs shape: torch.Size([1, 743])


 89%|████████▉ | 405/455 [26:44<03:30,  4.20s/it]

Inputs shape: torch.Size([1, 999])


 89%|████████▉ | 406/455 [26:49<03:39,  4.48s/it]

Inputs shape: torch.Size([1, 1065])


 89%|████████▉ | 407/455 [26:53<03:32,  4.44s/it]

Inputs shape: torch.Size([1, 924])


 90%|████████▉ | 408/455 [26:58<03:32,  4.52s/it]

Inputs shape: torch.Size([1, 926])


 90%|████████▉ | 409/455 [27:01<03:01,  3.95s/it]

Inputs shape: torch.Size([1, 912])


 90%|█████████ | 410/455 [27:06<03:17,  4.39s/it]

Inputs shape: torch.Size([1, 981])


 90%|█████████ | 411/455 [27:10<03:09,  4.30s/it]

Inputs shape: torch.Size([1, 660])


 91%|█████████ | 412/455 [27:13<02:50,  3.96s/it]

Inputs shape: torch.Size([1, 895])


 91%|█████████ | 413/455 [27:18<03:00,  4.29s/it]

Inputs shape: torch.Size([1, 1041])


 91%|█████████ | 414/455 [27:24<03:06,  4.55s/it]

Inputs shape: torch.Size([1, 791])


 91%|█████████ | 415/455 [27:28<03:00,  4.51s/it]

Inputs shape: torch.Size([1, 896])


 91%|█████████▏| 416/455 [27:32<02:52,  4.42s/it]

Inputs shape: torch.Size([1, 954])


 92%|█████████▏| 417/455 [27:36<02:36,  4.12s/it]

Inputs shape: torch.Size([1, 970])


 92%|█████████▏| 418/455 [27:39<02:26,  3.95s/it]

Inputs shape: torch.Size([1, 1002])


 92%|█████████▏| 419/455 [27:44<02:27,  4.10s/it]

Inputs shape: torch.Size([1, 944])


 92%|█████████▏| 420/455 [27:48<02:27,  4.22s/it]

Inputs shape: torch.Size([1, 944])


 93%|█████████▎| 421/455 [27:52<02:18,  4.08s/it]

Inputs shape: torch.Size([1, 1082])


 93%|█████████▎| 422/455 [27:56<02:14,  4.08s/it]

Inputs shape: torch.Size([1, 1051])


 93%|█████████▎| 423/455 [28:01<02:15,  4.24s/it]

Inputs shape: torch.Size([1, 848])


 93%|█████████▎| 424/455 [28:03<01:56,  3.76s/it]

Inputs shape: torch.Size([1, 957])


 93%|█████████▎| 425/455 [28:07<01:53,  3.80s/it]

Inputs shape: torch.Size([1, 1019])


 94%|█████████▎| 426/455 [28:10<01:43,  3.56s/it]

Inputs shape: torch.Size([1, 937])


 94%|█████████▍| 427/455 [28:15<01:47,  3.86s/it]

Inputs shape: torch.Size([1, 841])


 94%|█████████▍| 428/455 [28:18<01:40,  3.71s/it]

Inputs shape: torch.Size([1, 871])


 94%|█████████▍| 429/455 [28:22<01:40,  3.86s/it]

Inputs shape: torch.Size([1, 871])


 95%|█████████▍| 430/455 [28:27<01:40,  4.01s/it]

Inputs shape: torch.Size([1, 995])


 95%|█████████▍| 431/455 [28:30<01:35,  3.97s/it]

Inputs shape: torch.Size([1, 895])


 95%|█████████▍| 432/455 [28:35<01:36,  4.20s/it]

Inputs shape: torch.Size([1, 812])


 95%|█████████▌| 433/455 [28:38<01:26,  3.94s/it]

Inputs shape: torch.Size([1, 827])


 95%|█████████▌| 434/455 [28:42<01:22,  3.95s/it]

Inputs shape: torch.Size([1, 822])


 96%|█████████▌| 435/455 [28:46<01:15,  3.78s/it]

Inputs shape: torch.Size([1, 1133])


 96%|█████████▌| 436/455 [28:50<01:11,  3.78s/it]

Inputs shape: torch.Size([1, 915])


 96%|█████████▌| 437/455 [28:52<01:02,  3.45s/it]

Inputs shape: torch.Size([1, 960])


 96%|█████████▋| 438/455 [28:56<01:00,  3.53s/it]

Inputs shape: torch.Size([1, 979])


 96%|█████████▋| 439/455 [29:00<00:58,  3.66s/it]

Inputs shape: torch.Size([1, 993])


 97%|█████████▋| 440/455 [29:03<00:53,  3.56s/it]

Inputs shape: torch.Size([1, 941])


 97%|█████████▋| 441/455 [29:07<00:51,  3.66s/it]

Inputs shape: torch.Size([1, 906])


 97%|█████████▋| 442/455 [29:09<00:42,  3.23s/it]

Inputs shape: torch.Size([1, 815])


 97%|█████████▋| 443/455 [29:13<00:40,  3.35s/it]

Inputs shape: torch.Size([1, 833])


 98%|█████████▊| 444/455 [29:18<00:43,  3.91s/it]

Inputs shape: torch.Size([1, 768])


 98%|█████████▊| 445/455 [29:22<00:37,  3.72s/it]

Inputs shape: torch.Size([1, 932])


 98%|█████████▊| 446/455 [29:25<00:32,  3.60s/it]

Inputs shape: torch.Size([1, 699])


 98%|█████████▊| 447/455 [29:28<00:28,  3.51s/it]

Inputs shape: torch.Size([1, 788])


 98%|█████████▊| 448/455 [29:32<00:26,  3.72s/it]

Inputs shape: torch.Size([1, 789])


 99%|█████████▊| 449/455 [29:35<00:21,  3.51s/it]

Inputs shape: torch.Size([1, 858])


 99%|█████████▉| 450/455 [29:39<00:17,  3.43s/it]

Inputs shape: torch.Size([1, 992])


 99%|█████████▉| 451/455 [29:42<00:13,  3.48s/it]

Inputs shape: torch.Size([1, 778])


 99%|█████████▉| 452/455 [29:46<00:10,  3.51s/it]

Inputs shape: torch.Size([1, 1014])


100%|█████████▉| 453/455 [29:50<00:07,  3.62s/it]

Inputs shape: torch.Size([1, 899])


100%|█████████▉| 454/455 [29:53<00:03,  3.58s/it]

Inputs shape: torch.Size([1, 831])


100%|██████████| 455/455 [29:58<00:00,  3.95s/it]


In [31]:
from nltk.translate.bleu_score import corpus_bleu

bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.527223052356641


In [32]:
%%capture
!pip install rouge_score

In [33]:
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.4723891178462233
ROUGE-2: 0.22256263988879998
ROUGE-L: 0.34859473014361825


In [39]:
from collections import Counter

from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np


def distinct(seqs):
    """ Calculate intra/inter distinct 1/2. """
    """ Recuperado de https://github.com/PaddlePaddle/models/blob/release/1.6/PaddleNLP/Research/Dialogue-PLATO/plato/metrics/metrics.py"""
    batch_size = len(seqs)
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5))
        intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5)
    inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2

In [46]:
import nltk
nltk.download('punkt_tab')

tokenized_text = [nltk.word_tokenize(text) for text in predictions]
resultado = distinct(tokenized_text)
print(resultado)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


(0.8174602211235673, 0.9857301007313511, 0.11541064977609411, 0.3757125912018547)
