# Descarga de los datos formateados

In [None]:
%%capture
!pip install gdown

In [None]:
import gdown

files_to_download = {
    "formatted_train.json": "1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_",
    "formatted_test.json": "1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m",
    "formatted_validation.json": "1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_
From (redirected): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_&confirm=t&uuid=9ab6e425-7ce3-4570-b3fe-4ec63db02e34
To: /content/formatted_train.json

  0%|          | 0.00/237M [00:00<?, ?B/s][A
  2%|▏         | 4.72M/237M [00:00<00:14, 16.6MB/s][A
  9%|▊         | 20.4M/237M [00:00<00:03, 63.2MB/s][A
 13%|█▎        | 30.4M/237M [00:00<00:04, 49.6MB/s][A
 17%|█▋        | 41.4M/237M [00:00<00:03, 63.4MB/s][A
 21%|██        | 50.3M/237M [00:00<00:03, 51.9MB/s][A
 25%|██▌       | 59.2M/237M [00:01<00:04, 44.0MB/s][A
 32%|███▏      | 76.0M/237M [00:01<00:03, 42.9MB/s][A
 38%|███▊      | 90.2M/237M [00:01<00:02, 56.9MB/s][A
 42%|████▏     | 98.6M/237M [00:01<00:02, 51.3MB/s][A
 45%|████▍     | 106M/237M [00:02<00:02, 55.9MB/s] [A
 48%|████▊     | 114M/237M [00:02<00:02, 46.8MB/s][A
 51%|█████     | 120M/237M [00:02<00:03, 34.6MB/s][A
 60%|██████    | 143M/237M [

# Implementación del modelo

In [None]:
%%capture
!pip install unsloth "xformers==0.0.28.post2"
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets

In [None]:
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer, AutoTokenizer
from peft import AutoPeftModelForCausalLM
import json

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
with open("formatted_train.jsonl", "r") as f:
    data = [json.loads(line)["interaction"] for line in f]

dataset = Dataset.from_dict({"conversations": data})

In [None]:
max_seq_length = 2048

def initialize_model_and_tokenizer():
    base_model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None
    )
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="llama-3.1"
    )

    return model, tokenizer

In [None]:
def preprocess_dataset(dataset, tokenizer):
    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    formatted_dataset = dataset.map(formatting_prompts_func, batched=True)
    return formatted_dataset

In [None]:
used_seed = 3407

# Entrenamiento

In [None]:
with open("formatted_train.json", "r") as f:
    data = json.load(f)

dataset_second_train = Dataset.from_dict({"conversations": data})

In [None]:
training_configuration = {
    "lr_scheduler_type": "linear",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "max_steps": 200,
    "output_dir": "second_train_output",
    "learning_rate": 2e-4
}

In [None]:
model_second_train, tokenizer_second_train = initialize_model_and_tokenizer()

==((====))==  Unsloth 2024.11.10: Fast Llama patching. Transformers:4.46.2.
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2024.11.10 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [None]:
formatted_dataset_second_train = preprocess_dataset(dataset_second_train, tokenizer_second_train)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
trainer = SFTTrainer(
    model=model_second_train,
    tokenizer=tokenizer_second_train,
    train_dataset=formatted_dataset_second_train,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        learning_rate=5e-5,
        lr_scheduler_type=training_configuration["lr_scheduler_type"],
        per_device_train_batch_size=training_configuration["per_device_train_batch_size"],
        gradient_accumulation_steps=training_configuration["gradient_accumulation_steps"],
        num_train_epochs=1,
        max_steps=training_configuration["max_steps"],
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        warmup_steps=5,
        output_dir=training_configuration["output_dir"],
        seed=used_seed,
    )
)

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [None]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.564 GB.
2.635 GB of memory reserved.


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 50,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 200
 "-____-"     Number of trainable parameters = 24,313,856
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmamunoz42[0m ([33mmamunoz42-pontificia-universidad-cat-lica-de-chile[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,2.2432
2,2.2561
3,2.1644
4,2.3263
5,2.3427
6,2.1763
7,2.1721
8,2.1916
9,2.0801
10,2.1056


In [None]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

373.2538 seconds used for training.
6.22 minutes used for training.
Peak reserved memory = 3.658 GB.
Peak reserved memory for training = 1.023 GB.
Peak reserved memory % of max memory = 9.246 %.
Peak reserved memory for training % of max memory = 2.586 %.


In [None]:
model_second_train.save_pretrained("lora_model_second_train")
tokenizer_second_train.save_pretrained("lora_model_second_train")

('lora_model_second_train/tokenizer_config.json',
 'lora_model_second_train/special_tokens_map.json',
 'lora_model_second_train/tokenizer.json')

In [None]:
!zip -r lora_model_second_train.zip lora_model_second_train/

  adding: lora_model_second_train/ (stored 0%)
  adding: lora_model_second_train/tokenizer.json (deflated 85%)
  adding: lora_model_second_train/adapter_config.json (deflated 54%)
  adding: lora_model_second_train/README.md (deflated 66%)
  adding: lora_model_second_train/tokenizer_config.json (deflated 94%)
  adding: lora_model_second_train/adapter_model.safetensors (deflated 8%)
  adding: lora_model_second_train/special_tokens_map.json (deflated 71%)


# Carga del modelo entrenado

In [None]:
destination = "lora_model_second_train.zip"
gdown.download(f"https://drive.google.com/uc?id=1eHv0yWdVjgq-23AO_HcWR6ImoKyJtiL9", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1eHv0yWdVjgq-23AO_HcWR6ImoKyJtiL9
From (redirected): https://drive.google.com/uc?id=1eHv0yWdVjgq-23AO_HcWR6ImoKyJtiL9&confirm=t&uuid=0df34e24-6f29-4b3f-9a4b-2b7d1128699c
To: /content/lora_model_second_train.zip

  0%|          | 0.00/92.5M [00:00<?, ?B/s][A
  5%|▌         | 4.72M/92.5M [00:00<00:08, 9.78MB/s][A
 28%|██▊       | 25.7M/92.5M [00:00<00:02, 29.6MB/s][A
 37%|███▋      | 34.1M/92.5M [00:01<00:01, 35.2MB/s][A
 46%|████▌     | 42.5M/92.5M [00:01<00:01, 32.9MB/s][A
 55%|█████▍    | 50.9M/92.5M [00:01<00:01, 36.9MB/s][A
 64%|██████▍   | 59.2M/92.5M [00:01<00:01, 32.9MB/s][A
 73%|███████▎  | 67.6M/92.5M [00:01<00:00, 40.0MB/s][A
100%|██████████| 92.5M/92.5M [00:02<00:00, 38.7MB/s]


'lora_model_second_train.zip'

In [None]:
!unzip -q lora_model_second_train.zip -d ./

# Testeo del modelo

In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model_train",
    max_seq_length = max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Unsloth 2024.12.4 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 3072, padding_idx=128004)
        (layers): ModuleList(
          (0-27): 28 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lor

In [None]:
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

In [None]:
with open("formatted_test.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [None]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [None]:
%%capture
!pip install tqdm

In [None]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [None]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])



  0%|          | 0/455 [00:00<?, ?it/s][AThe attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Inputs shape: torch.Size([1, 891])



  0%|          | 1/455 [00:06<52:23,  6.92s/it][A

Inputs shape: torch.Size([1, 971])



  0%|          | 2/455 [00:11<39:41,  5.26s/it][A

Inputs shape: torch.Size([1, 872])



  1%|          | 3/455 [00:15<35:33,  4.72s/it][A

Inputs shape: torch.Size([1, 766])



  1%|          | 4/455 [00:18<30:19,  4.03s/it][A

Inputs shape: torch.Size([1, 798])



  1%|          | 5/455 [00:21<29:55,  3.99s/it][A

Inputs shape: torch.Size([1, 771])



  1%|▏         | 6/455 [00:25<29:23,  3.93s/it][A

Inputs shape: torch.Size([1, 1041])



  2%|▏         | 7/455 [00:29<28:30,  3.82s/it][A

Inputs shape: torch.Size([1, 888])



  2%|▏         | 8/455 [00:32<27:04,  3.63s/it][A

Inputs shape: torch.Size([1, 873])



  2%|▏         | 9/455 [00:35<26:12,  3.53s/it][A

Inputs shape: torch.Size([1, 1427])



  2%|▏         | 10/455 [00:41<31:52,  4.30s/it][A

Inputs shape: torch.Size([1, 673])



  2%|▏         | 11/455 [00:45<30:42,  4.15s/it][A

Inputs shape: torch.Size([1, 823])



  3%|▎         | 12/455 [00:49<30:18,  4.10s/it][A

Inputs shape: torch.Size([1, 821])



  3%|▎         | 13/455 [00:53<28:18,  3.84s/it][A

Inputs shape: torch.Size([1, 938])



  3%|▎         | 14/455 [00:56<27:59,  3.81s/it][A

Inputs shape: torch.Size([1, 1156])



  3%|▎         | 15/455 [01:00<28:00,  3.82s/it][A

Inputs shape: torch.Size([1, 744])



  4%|▎         | 16/455 [01:03<26:58,  3.69s/it][A

Inputs shape: torch.Size([1, 1168])



  4%|▎         | 17/455 [01:09<30:33,  4.19s/it][A

Inputs shape: torch.Size([1, 776])



  4%|▍         | 18/455 [01:12<29:09,  4.00s/it][A

Inputs shape: torch.Size([1, 972])



  4%|▍         | 19/455 [01:17<31:23,  4.32s/it][A

Inputs shape: torch.Size([1, 957])



  4%|▍         | 20/455 [01:21<30:08,  4.16s/it][A

Inputs shape: torch.Size([1, 878])



  5%|▍         | 21/455 [01:24<27:17,  3.77s/it][A

Inputs shape: torch.Size([1, 1123])



  5%|▍         | 22/455 [01:28<26:38,  3.69s/it][A

Inputs shape: torch.Size([1, 962])



  5%|▌         | 23/455 [01:32<27:12,  3.78s/it][A

Inputs shape: torch.Size([1, 1086])



  5%|▌         | 24/455 [01:35<26:38,  3.71s/it][A

Inputs shape: torch.Size([1, 989])



  5%|▌         | 25/455 [01:39<26:15,  3.66s/it][A

Inputs shape: torch.Size([1, 919])



  6%|▌         | 26/455 [01:43<27:15,  3.81s/it][A

Inputs shape: torch.Size([1, 985])



  6%|▌         | 27/455 [01:48<30:14,  4.24s/it][A

Inputs shape: torch.Size([1, 916])



  6%|▌         | 28/455 [01:53<31:11,  4.38s/it][A

Inputs shape: torch.Size([1, 835])



  6%|▋         | 29/455 [01:57<29:43,  4.19s/it][A

Inputs shape: torch.Size([1, 1147])



  7%|▋         | 30/455 [02:01<31:10,  4.40s/it][A

Inputs shape: torch.Size([1, 1002])



  7%|▋         | 31/455 [02:05<29:15,  4.14s/it][A

Inputs shape: torch.Size([1, 736])



  7%|▋         | 32/455 [02:08<27:36,  3.92s/it][A

Inputs shape: torch.Size([1, 1213])



  7%|▋         | 33/455 [02:13<28:07,  4.00s/it][A

Inputs shape: torch.Size([1, 1109])



  7%|▋         | 34/455 [02:17<30:00,  4.28s/it][A

Inputs shape: torch.Size([1, 715])



  8%|▊         | 35/455 [02:21<28:16,  4.04s/it][A

Inputs shape: torch.Size([1, 762])



  8%|▊         | 36/455 [02:25<27:11,  3.89s/it][A

Inputs shape: torch.Size([1, 902])



  8%|▊         | 37/455 [02:28<26:03,  3.74s/it][A

Inputs shape: torch.Size([1, 1006])



  8%|▊         | 38/455 [02:33<27:55,  4.02s/it][A

Inputs shape: torch.Size([1, 812])



  9%|▊         | 39/455 [02:36<25:54,  3.74s/it][A

Inputs shape: torch.Size([1, 1060])



  9%|▉         | 40/455 [02:41<28:33,  4.13s/it][A

Inputs shape: torch.Size([1, 784])



  9%|▉         | 41/455 [02:44<27:41,  4.01s/it][A

Inputs shape: torch.Size([1, 794])



  9%|▉         | 42/455 [02:47<25:33,  3.71s/it][A

Inputs shape: torch.Size([1, 956])



  9%|▉         | 43/455 [02:51<24:41,  3.60s/it][A

Inputs shape: torch.Size([1, 942])



 10%|▉         | 44/455 [02:54<24:55,  3.64s/it][A

Inputs shape: torch.Size([1, 1093])



 10%|▉         | 45/455 [02:59<26:29,  3.88s/it][A

Inputs shape: torch.Size([1, 880])



 10%|█         | 46/455 [03:03<27:05,  3.97s/it][A

Inputs shape: torch.Size([1, 1079])



 10%|█         | 47/455 [03:07<27:36,  4.06s/it][A

Inputs shape: torch.Size([1, 680])



 11%|█         | 48/455 [03:11<26:17,  3.88s/it][A

Inputs shape: torch.Size([1, 1066])



 11%|█         | 49/455 [03:15<27:05,  4.00s/it][A

Inputs shape: torch.Size([1, 1165])



 11%|█         | 50/455 [03:19<26:51,  3.98s/it][A

Inputs shape: torch.Size([1, 994])



 11%|█         | 51/455 [03:23<27:03,  4.02s/it][A

Inputs shape: torch.Size([1, 816])



 11%|█▏        | 52/455 [03:26<25:33,  3.81s/it][A

Inputs shape: torch.Size([1, 957])



 12%|█▏        | 53/455 [03:29<23:53,  3.56s/it][A

Inputs shape: torch.Size([1, 837])



 12%|█▏        | 54/455 [03:34<25:01,  3.74s/it][A

Inputs shape: torch.Size([1, 993])



 12%|█▏        | 55/455 [03:39<27:12,  4.08s/it][A

Inputs shape: torch.Size([1, 794])



 12%|█▏        | 56/455 [03:43<27:01,  4.06s/it][A

Inputs shape: torch.Size([1, 742])



 13%|█▎        | 57/455 [03:48<28:59,  4.37s/it][A

Inputs shape: torch.Size([1, 822])



 13%|█▎        | 58/455 [03:51<27:10,  4.11s/it][A

Inputs shape: torch.Size([1, 787])



 13%|█▎        | 59/455 [03:54<25:15,  3.83s/it][A

Inputs shape: torch.Size([1, 892])



 13%|█▎        | 60/455 [03:58<25:52,  3.93s/it][A

Inputs shape: torch.Size([1, 1106])



 13%|█▎        | 61/455 [04:02<25:14,  3.84s/it][A

Inputs shape: torch.Size([1, 965])



 14%|█▎        | 62/455 [04:05<24:15,  3.70s/it][A

Inputs shape: torch.Size([1, 849])



 14%|█▍        | 63/455 [04:08<22:48,  3.49s/it][A

Inputs shape: torch.Size([1, 858])



 14%|█▍        | 64/455 [04:12<22:55,  3.52s/it][A

Inputs shape: torch.Size([1, 1339])



 14%|█▍        | 65/455 [04:18<27:07,  4.17s/it][A

Inputs shape: torch.Size([1, 796])



 15%|█▍        | 66/455 [04:21<24:39,  3.80s/it][A

Inputs shape: torch.Size([1, 659])



 15%|█▍        | 67/455 [04:24<24:11,  3.74s/it][A

Inputs shape: torch.Size([1, 1023])



 15%|█▍        | 68/455 [04:28<23:43,  3.68s/it][A

Inputs shape: torch.Size([1, 845])



 15%|█▌        | 69/455 [04:31<23:11,  3.60s/it][A

Inputs shape: torch.Size([1, 902])



 15%|█▌        | 70/455 [04:35<24:03,  3.75s/it][A

Inputs shape: torch.Size([1, 925])



 16%|█▌        | 71/455 [04:39<24:04,  3.76s/it][A

Inputs shape: torch.Size([1, 749])



 16%|█▌        | 72/455 [04:42<22:14,  3.49s/it][A

Inputs shape: torch.Size([1, 852])



 16%|█▌        | 73/455 [04:46<22:26,  3.53s/it][A

Inputs shape: torch.Size([1, 904])



 16%|█▋        | 74/455 [04:50<23:23,  3.68s/it][A

Inputs shape: torch.Size([1, 920])



 16%|█▋        | 75/455 [04:54<23:44,  3.75s/it][A

Inputs shape: torch.Size([1, 703])



 17%|█▋        | 76/455 [04:56<22:06,  3.50s/it][A

Inputs shape: torch.Size([1, 1056])



 17%|█▋        | 77/455 [05:02<26:09,  4.15s/it][A

Inputs shape: torch.Size([1, 742])



 17%|█▋        | 78/455 [05:06<25:52,  4.12s/it][A

Inputs shape: torch.Size([1, 717])



 17%|█▋        | 79/455 [05:10<25:15,  4.03s/it][A

Inputs shape: torch.Size([1, 1035])



 18%|█▊        | 80/455 [05:14<25:17,  4.05s/it][A

Inputs shape: torch.Size([1, 859])



 18%|█▊        | 81/455 [05:17<23:43,  3.81s/it][A

Inputs shape: torch.Size([1, 922])



 18%|█▊        | 82/455 [05:21<23:21,  3.76s/it][A

Inputs shape: torch.Size([1, 856])



 18%|█▊        | 83/455 [05:25<23:33,  3.80s/it][A

Inputs shape: torch.Size([1, 829])



 18%|█▊        | 84/455 [05:28<23:07,  3.74s/it][A

Inputs shape: torch.Size([1, 820])



 19%|█▊        | 85/455 [05:33<23:35,  3.83s/it][A

Inputs shape: torch.Size([1, 898])



 19%|█▉        | 86/455 [05:36<23:26,  3.81s/it][A

Inputs shape: torch.Size([1, 1002])



 19%|█▉        | 87/455 [05:40<23:27,  3.82s/it][A

Inputs shape: torch.Size([1, 727])



 19%|█▉        | 88/455 [05:43<22:30,  3.68s/it][A

Inputs shape: torch.Size([1, 969])



 20%|█▉        | 89/455 [05:47<22:19,  3.66s/it][A

Inputs shape: torch.Size([1, 875])



 20%|█▉        | 90/455 [05:50<21:24,  3.52s/it][A

Inputs shape: torch.Size([1, 897])



 20%|██        | 91/455 [05:54<21:36,  3.56s/it][A

Inputs shape: torch.Size([1, 818])



 20%|██        | 92/455 [05:58<22:43,  3.76s/it][A

Inputs shape: torch.Size([1, 1024])



 20%|██        | 93/455 [06:01<21:27,  3.56s/it][A

Inputs shape: torch.Size([1, 896])



 21%|██        | 94/455 [06:06<23:01,  3.83s/it][A

Inputs shape: torch.Size([1, 980])



 21%|██        | 95/455 [06:10<23:23,  3.90s/it][A

Inputs shape: torch.Size([1, 980])



 21%|██        | 96/455 [06:14<24:00,  4.01s/it][A

Inputs shape: torch.Size([1, 709])



 21%|██▏       | 97/455 [06:16<20:53,  3.50s/it][A

Inputs shape: torch.Size([1, 1002])



 22%|██▏       | 98/455 [06:21<22:20,  3.75s/it][A

Inputs shape: torch.Size([1, 975])



 22%|██▏       | 99/455 [06:23<20:25,  3.44s/it][A

Inputs shape: torch.Size([1, 941])



 22%|██▏       | 100/455 [06:28<22:12,  3.75s/it][A

Inputs shape: torch.Size([1, 956])



 22%|██▏       | 101/455 [06:31<20:52,  3.54s/it][A

Inputs shape: torch.Size([1, 746])



 22%|██▏       | 102/455 [06:34<20:36,  3.50s/it][A

Inputs shape: torch.Size([1, 944])



 23%|██▎       | 103/455 [06:38<20:56,  3.57s/it][A

Inputs shape: torch.Size([1, 1037])



 23%|██▎       | 104/455 [06:41<20:24,  3.49s/it][A

Inputs shape: torch.Size([1, 942])



 23%|██▎       | 105/455 [06:46<21:29,  3.68s/it][A

Inputs shape: torch.Size([1, 817])



 23%|██▎       | 106/455 [06:49<20:44,  3.57s/it][A

Inputs shape: torch.Size([1, 916])



 24%|██▎       | 107/455 [06:53<21:57,  3.78s/it][A

Inputs shape: torch.Size([1, 1001])



 24%|██▎       | 108/455 [06:56<20:52,  3.61s/it][A

Inputs shape: torch.Size([1, 1124])



 24%|██▍       | 109/455 [07:01<21:54,  3.80s/it][A

Inputs shape: torch.Size([1, 934])



 24%|██▍       | 110/455 [07:04<21:39,  3.77s/it][A

Inputs shape: torch.Size([1, 864])



 24%|██▍       | 111/455 [07:08<20:57,  3.66s/it][A

Inputs shape: torch.Size([1, 748])



 25%|██▍       | 112/455 [07:10<18:41,  3.27s/it][A

Inputs shape: torch.Size([1, 788])



 25%|██▍       | 113/455 [07:13<18:37,  3.27s/it][A

Inputs shape: torch.Size([1, 995])



 25%|██▌       | 114/455 [07:17<19:30,  3.43s/it][A

Inputs shape: torch.Size([1, 953])



 25%|██▌       | 115/455 [07:20<18:29,  3.26s/it][A

Inputs shape: torch.Size([1, 927])



 25%|██▌       | 116/455 [07:24<19:15,  3.41s/it][A

Inputs shape: torch.Size([1, 1011])



 26%|██▌       | 117/455 [07:28<20:21,  3.61s/it][A

Inputs shape: torch.Size([1, 868])



 26%|██▌       | 118/455 [07:31<20:02,  3.57s/it][A

Inputs shape: torch.Size([1, 946])



 26%|██▌       | 119/455 [07:36<22:05,  3.94s/it][A

Inputs shape: torch.Size([1, 932])



 26%|██▋       | 120/455 [07:40<21:11,  3.80s/it][A

Inputs shape: torch.Size([1, 981])



 27%|██▋       | 121/455 [07:42<19:15,  3.46s/it][A

Inputs shape: torch.Size([1, 830])



 27%|██▋       | 122/455 [07:45<18:53,  3.40s/it][A

Inputs shape: torch.Size([1, 787])



 27%|██▋       | 123/455 [07:48<16:42,  3.02s/it][A

Inputs shape: torch.Size([1, 879])



 27%|██▋       | 124/455 [07:52<18:36,  3.37s/it][A

Inputs shape: torch.Size([1, 1026])



 27%|██▋       | 125/455 [07:54<17:13,  3.13s/it][A

Inputs shape: torch.Size([1, 899])



 28%|██▊       | 126/455 [07:59<19:50,  3.62s/it][A

Inputs shape: torch.Size([1, 709])



 28%|██▊       | 127/455 [08:02<18:45,  3.43s/it][A

Inputs shape: torch.Size([1, 894])



 28%|██▊       | 128/455 [08:05<18:23,  3.38s/it][A

Inputs shape: torch.Size([1, 748])



 28%|██▊       | 129/455 [08:10<20:18,  3.74s/it][A

Inputs shape: torch.Size([1, 952])



 29%|██▊       | 130/455 [08:14<21:28,  3.96s/it][A

Inputs shape: torch.Size([1, 757])



 29%|██▉       | 131/455 [08:18<20:22,  3.77s/it][A

Inputs shape: torch.Size([1, 918])



 29%|██▉       | 132/455 [08:20<18:17,  3.40s/it][A

Inputs shape: torch.Size([1, 917])



 29%|██▉       | 133/455 [08:25<20:18,  3.78s/it][A

Inputs shape: torch.Size([1, 979])



 29%|██▉       | 134/455 [08:28<19:35,  3.66s/it][A

Inputs shape: torch.Size([1, 857])



 30%|██▉       | 135/455 [08:33<20:47,  3.90s/it][A

Inputs shape: torch.Size([1, 860])



 30%|██▉       | 136/455 [08:37<21:06,  3.97s/it][A

Inputs shape: torch.Size([1, 953])



 30%|███       | 137/455 [08:42<22:21,  4.22s/it][A

Inputs shape: torch.Size([1, 748])



 30%|███       | 138/455 [08:47<23:57,  4.54s/it][A

Inputs shape: torch.Size([1, 1104])



 31%|███       | 139/455 [08:52<24:18,  4.62s/it][A

Inputs shape: torch.Size([1, 902])



 31%|███       | 140/455 [08:55<22:12,  4.23s/it][A

Inputs shape: torch.Size([1, 858])



 31%|███       | 141/455 [08:58<20:33,  3.93s/it][A

Inputs shape: torch.Size([1, 807])



 31%|███       | 142/455 [09:02<20:32,  3.94s/it][A

Inputs shape: torch.Size([1, 818])



 31%|███▏      | 143/455 [09:06<20:01,  3.85s/it][A

Inputs shape: torch.Size([1, 921])



 32%|███▏      | 144/455 [09:10<19:59,  3.86s/it][A

Inputs shape: torch.Size([1, 860])



 32%|███▏      | 145/455 [09:14<20:45,  4.02s/it][A

Inputs shape: torch.Size([1, 879])



 32%|███▏      | 146/455 [09:18<20:49,  4.05s/it][A

Inputs shape: torch.Size([1, 1068])



 32%|███▏      | 147/455 [09:22<19:55,  3.88s/it][A

Inputs shape: torch.Size([1, 864])



 33%|███▎      | 148/455 [09:26<20:07,  3.93s/it][A

Inputs shape: torch.Size([1, 936])



 33%|███▎      | 149/455 [09:29<19:00,  3.73s/it][A

Inputs shape: torch.Size([1, 970])



 33%|███▎      | 150/455 [09:34<19:58,  3.93s/it][A

Inputs shape: torch.Size([1, 996])



 33%|███▎      | 151/455 [09:37<19:23,  3.83s/it][A

Inputs shape: torch.Size([1, 1048])



 33%|███▎      | 152/455 [09:41<18:38,  3.69s/it][A

Inputs shape: torch.Size([1, 1099])



 34%|███▎      | 153/455 [09:46<21:12,  4.21s/it][A

Inputs shape: torch.Size([1, 790])



 34%|███▍      | 154/455 [09:49<19:21,  3.86s/it][A

Inputs shape: torch.Size([1, 949])



 34%|███▍      | 155/455 [09:54<20:25,  4.08s/it][A

Inputs shape: torch.Size([1, 972])



 34%|███▍      | 156/455 [09:57<19:17,  3.87s/it][A

Inputs shape: torch.Size([1, 973])



 35%|███▍      | 157/455 [10:01<19:20,  3.89s/it][A

Inputs shape: torch.Size([1, 1065])



 35%|███▍      | 158/455 [10:04<18:35,  3.76s/it][A

Inputs shape: torch.Size([1, 934])



 35%|███▍      | 159/455 [10:08<17:50,  3.62s/it][A

Inputs shape: torch.Size([1, 1238])



 35%|███▌      | 160/455 [10:12<18:23,  3.74s/it][A

Inputs shape: torch.Size([1, 1191])



 35%|███▌      | 161/455 [10:16<18:58,  3.87s/it][A

Inputs shape: torch.Size([1, 797])



 36%|███▌      | 162/455 [10:19<18:22,  3.76s/it][A

Inputs shape: torch.Size([1, 1101])



 36%|███▌      | 163/455 [10:24<19:05,  3.92s/it][A

Inputs shape: torch.Size([1, 857])



 36%|███▌      | 164/455 [10:27<18:29,  3.81s/it][A

Inputs shape: torch.Size([1, 799])



 36%|███▋      | 165/455 [10:31<18:33,  3.84s/it][A

Inputs shape: torch.Size([1, 957])



 36%|███▋      | 166/455 [10:35<18:33,  3.85s/it][A

Inputs shape: torch.Size([1, 868])



 37%|███▋      | 167/455 [10:38<17:53,  3.73s/it][A

Inputs shape: torch.Size([1, 1026])



 37%|███▋      | 168/455 [10:43<19:35,  4.09s/it][A

Inputs shape: torch.Size([1, 839])



 37%|███▋      | 169/455 [10:47<19:11,  4.03s/it][A

Inputs shape: torch.Size([1, 690])



 37%|███▋      | 170/455 [10:50<17:48,  3.75s/it][A

Inputs shape: torch.Size([1, 1147])



 38%|███▊      | 171/455 [10:55<19:13,  4.06s/it][A

Inputs shape: torch.Size([1, 872])



 38%|███▊      | 172/455 [10:58<17:42,  3.75s/it][A

Inputs shape: torch.Size([1, 992])



 38%|███▊      | 173/455 [11:03<19:05,  4.06s/it][A

Inputs shape: torch.Size([1, 973])



 38%|███▊      | 174/455 [11:07<19:37,  4.19s/it][A

Inputs shape: torch.Size([1, 837])



 38%|███▊      | 175/455 [11:11<19:16,  4.13s/it][A

Inputs shape: torch.Size([1, 724])



 39%|███▊      | 176/455 [11:14<17:12,  3.70s/it][A

Inputs shape: torch.Size([1, 893])



 39%|███▉      | 177/455 [11:19<18:27,  3.98s/it][A

Inputs shape: torch.Size([1, 1037])



 39%|███▉      | 178/455 [11:22<17:58,  3.89s/it][A

Inputs shape: torch.Size([1, 851])



 39%|███▉      | 179/455 [11:27<18:21,  3.99s/it][A

Inputs shape: torch.Size([1, 855])



 40%|███▉      | 180/455 [11:30<18:04,  3.94s/it][A

Inputs shape: torch.Size([1, 925])



 40%|███▉      | 181/455 [11:33<15:35,  3.41s/it][A

Inputs shape: torch.Size([1, 1085])



 40%|████      | 182/455 [11:37<16:08,  3.55s/it][A

Inputs shape: torch.Size([1, 875])



 40%|████      | 183/455 [11:40<15:30,  3.42s/it][A

Inputs shape: torch.Size([1, 945])



 40%|████      | 184/455 [11:44<17:03,  3.78s/it][A

Inputs shape: torch.Size([1, 763])



 41%|████      | 185/455 [11:48<16:49,  3.74s/it][A

Inputs shape: torch.Size([1, 956])



 41%|████      | 186/455 [11:51<15:50,  3.54s/it][A

Inputs shape: torch.Size([1, 807])



 41%|████      | 187/455 [11:55<16:24,  3.67s/it][A

Inputs shape: torch.Size([1, 941])



 41%|████▏     | 188/455 [11:59<16:12,  3.64s/it][A

Inputs shape: torch.Size([1, 1072])



 42%|████▏     | 189/455 [12:03<16:46,  3.78s/it][A

Inputs shape: torch.Size([1, 761])



 42%|████▏     | 190/455 [12:06<16:18,  3.69s/it][A

Inputs shape: torch.Size([1, 1115])



 42%|████▏     | 191/455 [12:11<17:16,  3.93s/it][A

Inputs shape: torch.Size([1, 761])



 42%|████▏     | 192/455 [12:14<16:58,  3.87s/it][A

Inputs shape: torch.Size([1, 694])



 42%|████▏     | 193/455 [12:17<15:27,  3.54s/it][A

Inputs shape: torch.Size([1, 838])



 43%|████▎     | 194/455 [12:22<17:38,  4.06s/it][A

Inputs shape: torch.Size([1, 928])



 43%|████▎     | 195/455 [12:27<18:03,  4.17s/it][A

Inputs shape: torch.Size([1, 930])



 43%|████▎     | 196/455 [12:30<16:43,  3.87s/it][A

Inputs shape: torch.Size([1, 1059])



 43%|████▎     | 197/455 [12:33<16:02,  3.73s/it][A

Inputs shape: torch.Size([1, 646])



 44%|████▎     | 198/455 [12:37<15:23,  3.59s/it][A

Inputs shape: torch.Size([1, 844])



 44%|████▎     | 199/455 [12:41<16:21,  3.83s/it][A

Inputs shape: torch.Size([1, 955])



 44%|████▍     | 200/455 [12:45<16:18,  3.84s/it][A

Inputs shape: torch.Size([1, 928])



 44%|████▍     | 201/455 [12:50<17:42,  4.18s/it][A

Inputs shape: torch.Size([1, 877])



 44%|████▍     | 202/455 [12:54<17:17,  4.10s/it][A

Inputs shape: torch.Size([1, 976])



 45%|████▍     | 203/455 [12:57<16:15,  3.87s/it][A

Inputs shape: torch.Size([1, 1229])



 45%|████▍     | 204/455 [13:03<18:11,  4.35s/it][A

Inputs shape: torch.Size([1, 845])



 45%|████▌     | 205/455 [13:06<16:33,  3.97s/it][A

Inputs shape: torch.Size([1, 1051])



 45%|████▌     | 206/455 [13:09<15:44,  3.79s/it][A

Inputs shape: torch.Size([1, 724])



 45%|████▌     | 207/455 [13:12<15:00,  3.63s/it][A

Inputs shape: torch.Size([1, 847])



 46%|████▌     | 208/455 [13:16<14:56,  3.63s/it][A

Inputs shape: torch.Size([1, 1064])



 46%|████▌     | 209/455 [13:20<15:52,  3.87s/it][A

Inputs shape: torch.Size([1, 928])



 46%|████▌     | 210/455 [13:25<16:24,  4.02s/it][A

Inputs shape: torch.Size([1, 779])



 46%|████▋     | 211/455 [13:29<16:33,  4.07s/it][A

Inputs shape: torch.Size([1, 933])



 47%|████▋     | 212/455 [13:32<14:43,  3.63s/it][A

Inputs shape: torch.Size([1, 1062])



 47%|████▋     | 213/455 [13:36<15:19,  3.80s/it][A

Inputs shape: torch.Size([1, 892])



 47%|████▋     | 214/455 [13:39<14:49,  3.69s/it][A

Inputs shape: torch.Size([1, 852])



 47%|████▋     | 215/455 [13:44<16:00,  4.00s/it][A

Inputs shape: torch.Size([1, 1064])



 47%|████▋     | 216/455 [13:48<16:36,  4.17s/it][A

Inputs shape: torch.Size([1, 1036])



 48%|████▊     | 217/455 [13:53<16:47,  4.23s/it][A

Inputs shape: torch.Size([1, 957])



 48%|████▊     | 218/455 [13:56<15:50,  4.01s/it][A

Inputs shape: torch.Size([1, 857])



 48%|████▊     | 219/455 [14:01<16:02,  4.08s/it][A

Inputs shape: torch.Size([1, 1068])



 48%|████▊     | 220/455 [14:03<14:28,  3.70s/it][A

Inputs shape: torch.Size([1, 1092])



 49%|████▊     | 221/455 [14:07<14:46,  3.79s/it][A

Inputs shape: torch.Size([1, 829])



 49%|████▉     | 222/455 [14:11<14:28,  3.73s/it][A

Inputs shape: torch.Size([1, 606])



 49%|████▉     | 223/455 [14:16<15:22,  3.98s/it][A

Inputs shape: torch.Size([1, 1158])



 49%|████▉     | 224/455 [14:19<14:44,  3.83s/it][A

Inputs shape: torch.Size([1, 1261])



 49%|████▉     | 225/455 [14:23<15:23,  4.02s/it][A

Inputs shape: torch.Size([1, 747])



 50%|████▉     | 226/455 [14:26<13:52,  3.64s/it][A

Inputs shape: torch.Size([1, 876])



 50%|████▉     | 227/455 [14:30<14:19,  3.77s/it][A

Inputs shape: torch.Size([1, 1069])



 50%|█████     | 228/455 [14:34<14:13,  3.76s/it][A

Inputs shape: torch.Size([1, 989])



 50%|█████     | 229/455 [14:38<14:24,  3.82s/it][A

Inputs shape: torch.Size([1, 1137])



 51%|█████     | 230/455 [14:42<14:10,  3.78s/it][A

Inputs shape: torch.Size([1, 870])



 51%|█████     | 231/455 [14:45<13:40,  3.66s/it][A

Inputs shape: torch.Size([1, 865])



 51%|█████     | 232/455 [14:49<13:51,  3.73s/it][A

Inputs shape: torch.Size([1, 920])



 51%|█████     | 233/455 [14:52<13:30,  3.65s/it][A

Inputs shape: torch.Size([1, 954])



 51%|█████▏    | 234/455 [14:56<13:35,  3.69s/it][A

Inputs shape: torch.Size([1, 936])



 52%|█████▏    | 235/455 [15:00<13:40,  3.73s/it][A

Inputs shape: torch.Size([1, 792])



 52%|█████▏    | 236/455 [15:04<13:44,  3.76s/it][A

Inputs shape: torch.Size([1, 971])



 52%|█████▏    | 237/455 [15:08<13:34,  3.73s/it][A

Inputs shape: torch.Size([1, 893])



 52%|█████▏    | 238/455 [15:12<14:18,  3.96s/it][A

Inputs shape: torch.Size([1, 918])



 53%|█████▎    | 239/455 [15:15<13:28,  3.74s/it][A

Inputs shape: torch.Size([1, 867])



 53%|█████▎    | 240/455 [15:19<13:33,  3.78s/it][A

Inputs shape: torch.Size([1, 849])



 53%|█████▎    | 241/455 [15:22<12:49,  3.60s/it][A

Inputs shape: torch.Size([1, 715])



 53%|█████▎    | 242/455 [15:26<12:22,  3.48s/it][A

Inputs shape: torch.Size([1, 843])



 53%|█████▎    | 243/455 [15:29<12:00,  3.40s/it][A

Inputs shape: torch.Size([1, 927])



 54%|█████▎    | 244/455 [15:33<12:44,  3.63s/it][A

Inputs shape: torch.Size([1, 733])



 54%|█████▍    | 245/455 [15:36<12:22,  3.54s/it][A

Inputs shape: torch.Size([1, 1019])



 54%|█████▍    | 246/455 [15:41<13:14,  3.80s/it][A

Inputs shape: torch.Size([1, 769])



 54%|█████▍    | 247/455 [15:44<12:41,  3.66s/it][A

Inputs shape: torch.Size([1, 1057])



 55%|█████▍    | 248/455 [15:48<12:41,  3.68s/it][A

Inputs shape: torch.Size([1, 929])



 55%|█████▍    | 249/455 [15:51<12:10,  3.54s/it][A

Inputs shape: torch.Size([1, 1001])



 55%|█████▍    | 250/455 [15:54<11:50,  3.46s/it][A

Inputs shape: torch.Size([1, 934])



 55%|█████▌    | 251/455 [15:58<12:14,  3.60s/it][A

Inputs shape: torch.Size([1, 657])



 55%|█████▌    | 252/455 [16:01<11:52,  3.51s/it][A

Inputs shape: torch.Size([1, 773])



 56%|█████▌    | 253/455 [16:05<11:49,  3.51s/it][A

Inputs shape: torch.Size([1, 915])



 56%|█████▌    | 254/455 [16:11<14:00,  4.18s/it][A

Inputs shape: torch.Size([1, 821])



 56%|█████▌    | 255/455 [16:14<13:26,  4.03s/it][A

Inputs shape: torch.Size([1, 986])



 56%|█████▋    | 256/455 [16:20<14:50,  4.48s/it][A

Inputs shape: torch.Size([1, 873])



 56%|█████▋    | 257/455 [16:23<13:29,  4.09s/it][A

Inputs shape: torch.Size([1, 863])



 57%|█████▋    | 258/455 [16:26<12:11,  3.71s/it][A

Inputs shape: torch.Size([1, 1041])



 57%|█████▋    | 259/455 [16:31<13:06,  4.01s/it][A

Inputs shape: torch.Size([1, 1018])



 57%|█████▋    | 260/455 [16:34<12:46,  3.93s/it][A

Inputs shape: torch.Size([1, 737])



 57%|█████▋    | 261/455 [16:38<12:09,  3.76s/it][A

Inputs shape: torch.Size([1, 859])



 58%|█████▊    | 262/455 [16:42<12:11,  3.79s/it][A

Inputs shape: torch.Size([1, 971])



 58%|█████▊    | 263/455 [16:46<13:08,  4.11s/it][A

Inputs shape: torch.Size([1, 940])



 58%|█████▊    | 264/455 [16:50<13:00,  4.09s/it][A

Inputs shape: torch.Size([1, 714])



 58%|█████▊    | 265/455 [16:53<11:50,  3.74s/it][A

Inputs shape: torch.Size([1, 812])



 58%|█████▊    | 266/455 [16:57<11:22,  3.61s/it][A

Inputs shape: torch.Size([1, 742])



 59%|█████▊    | 267/455 [17:00<10:46,  3.44s/it][A

Inputs shape: torch.Size([1, 1080])



 59%|█████▉    | 268/455 [17:03<10:52,  3.49s/it][A

Inputs shape: torch.Size([1, 1027])



 59%|█████▉    | 269/455 [17:08<11:34,  3.73s/it][A

Inputs shape: torch.Size([1, 794])



 59%|█████▉    | 270/455 [17:11<11:18,  3.67s/it][A

Inputs shape: torch.Size([1, 902])



 60%|█████▉    | 271/455 [17:15<11:26,  3.73s/it][A

Inputs shape: torch.Size([1, 1047])



 60%|█████▉    | 272/455 [17:19<11:57,  3.92s/it][A

Inputs shape: torch.Size([1, 1120])



 60%|██████    | 273/455 [17:24<12:45,  4.20s/it][A

Inputs shape: torch.Size([1, 903])



 60%|██████    | 274/455 [17:28<11:55,  3.95s/it][A

Inputs shape: torch.Size([1, 890])



 60%|██████    | 275/455 [17:31<11:30,  3.83s/it][A

Inputs shape: torch.Size([1, 902])



 61%|██████    | 276/455 [17:34<10:47,  3.62s/it][A

Inputs shape: torch.Size([1, 907])



 61%|██████    | 277/455 [17:39<12:08,  4.09s/it][A

Inputs shape: torch.Size([1, 753])



 61%|██████    | 278/455 [17:42<11:06,  3.77s/it][A

Inputs shape: torch.Size([1, 1021])



 61%|██████▏   | 279/455 [17:46<10:44,  3.66s/it][A

Inputs shape: torch.Size([1, 888])



 62%|██████▏   | 280/455 [17:49<10:32,  3.61s/it][A

Inputs shape: torch.Size([1, 899])



 62%|██████▏   | 281/455 [17:53<10:32,  3.64s/it][A

Inputs shape: torch.Size([1, 792])



 62%|██████▏   | 282/455 [17:57<10:20,  3.58s/it][A

Inputs shape: torch.Size([1, 753])



 62%|██████▏   | 283/455 [18:00<10:08,  3.54s/it][A

Inputs shape: torch.Size([1, 839])



 62%|██████▏   | 284/455 [18:03<09:32,  3.35s/it][A

Inputs shape: torch.Size([1, 998])



 63%|██████▎   | 285/455 [18:06<09:32,  3.37s/it][A

Inputs shape: torch.Size([1, 691])



 63%|██████▎   | 286/455 [18:10<09:34,  3.40s/it][A

Inputs shape: torch.Size([1, 990])



 63%|██████▎   | 287/455 [18:15<11:03,  3.95s/it][A

Inputs shape: torch.Size([1, 935])



 63%|██████▎   | 288/455 [18:20<11:41,  4.20s/it][A

Inputs shape: torch.Size([1, 888])



 64%|██████▎   | 289/455 [18:24<11:31,  4.17s/it][A

Inputs shape: torch.Size([1, 804])



 64%|██████▎   | 290/455 [18:27<10:27,  3.80s/it][A

Inputs shape: torch.Size([1, 985])



 64%|██████▍   | 291/455 [18:31<10:24,  3.81s/it][A

Inputs shape: torch.Size([1, 698])



 64%|██████▍   | 292/455 [18:34<10:09,  3.74s/it][A

Inputs shape: torch.Size([1, 1073])



 64%|██████▍   | 293/455 [18:38<09:48,  3.63s/it][A

Inputs shape: torch.Size([1, 1109])



 65%|██████▍   | 294/455 [18:43<11:20,  4.22s/it][A

Inputs shape: torch.Size([1, 774])



 65%|██████▍   | 295/455 [18:46<10:25,  3.91s/it][A

Inputs shape: torch.Size([1, 801])



 65%|██████▌   | 296/455 [18:51<10:30,  3.97s/it][A

Inputs shape: torch.Size([1, 1043])



 65%|██████▌   | 297/455 [18:55<10:31,  4.00s/it][A

Inputs shape: torch.Size([1, 909])



 65%|██████▌   | 298/455 [18:58<09:48,  3.75s/it][A

Inputs shape: torch.Size([1, 995])



 66%|██████▌   | 299/455 [19:01<09:35,  3.69s/it][A

Inputs shape: torch.Size([1, 833])



 66%|██████▌   | 300/455 [19:05<09:38,  3.73s/it][A

Inputs shape: torch.Size([1, 833])



 66%|██████▌   | 301/455 [19:09<09:48,  3.82s/it][A

Inputs shape: torch.Size([1, 924])



 66%|██████▋   | 302/455 [19:14<10:11,  4.00s/it][A

Inputs shape: torch.Size([1, 787])



 67%|██████▋   | 303/455 [19:17<09:21,  3.70s/it][A

Inputs shape: torch.Size([1, 806])



 67%|██████▋   | 304/455 [19:20<09:09,  3.64s/it][A

Inputs shape: torch.Size([1, 760])



 67%|██████▋   | 305/455 [19:25<09:42,  3.88s/it][A

Inputs shape: torch.Size([1, 865])



 67%|██████▋   | 306/455 [19:28<09:22,  3.77s/it][A

Inputs shape: torch.Size([1, 1106])



 67%|██████▋   | 307/455 [19:32<09:32,  3.87s/it][A

Inputs shape: torch.Size([1, 686])



 68%|██████▊   | 308/455 [19:36<09:50,  4.02s/it][A

Inputs shape: torch.Size([1, 723])



 68%|██████▊   | 309/455 [19:40<09:41,  3.98s/it][A

Inputs shape: torch.Size([1, 911])



 68%|██████▊   | 310/455 [19:43<08:59,  3.72s/it][A

Inputs shape: torch.Size([1, 807])



 68%|██████▊   | 311/455 [19:47<09:07,  3.80s/it][A

Inputs shape: torch.Size([1, 891])



 69%|██████▊   | 312/455 [19:51<08:39,  3.63s/it][A

Inputs shape: torch.Size([1, 718])



 69%|██████▉   | 313/455 [19:54<08:33,  3.61s/it][A

Inputs shape: torch.Size([1, 869])



 69%|██████▉   | 314/455 [19:57<08:05,  3.45s/it][A

Inputs shape: torch.Size([1, 880])



 69%|██████▉   | 315/455 [20:01<08:02,  3.45s/it][A

Inputs shape: torch.Size([1, 833])



 69%|██████▉   | 316/455 [20:04<08:03,  3.48s/it][A

Inputs shape: torch.Size([1, 900])



 70%|██████▉   | 317/455 [20:08<08:08,  3.54s/it][A

Inputs shape: torch.Size([1, 773])



 70%|██████▉   | 318/455 [20:11<07:58,  3.49s/it][A

Inputs shape: torch.Size([1, 932])



 70%|███████   | 319/455 [20:15<08:14,  3.64s/it][A

Inputs shape: torch.Size([1, 826])



 70%|███████   | 320/455 [20:19<07:54,  3.51s/it][A

Inputs shape: torch.Size([1, 1042])



 71%|███████   | 321/455 [20:22<07:26,  3.33s/it][A

Inputs shape: torch.Size([1, 780])



 71%|███████   | 322/455 [20:25<07:21,  3.32s/it][A

Inputs shape: torch.Size([1, 1013])



 71%|███████   | 323/455 [20:28<07:20,  3.34s/it][A

Inputs shape: torch.Size([1, 1020])



 71%|███████   | 324/455 [20:33<08:18,  3.80s/it][A

Inputs shape: torch.Size([1, 889])



 71%|███████▏  | 325/455 [20:37<08:20,  3.85s/it][A

Inputs shape: torch.Size([1, 1274])



 72%|███████▏  | 326/455 [20:42<08:47,  4.09s/it][A

Inputs shape: torch.Size([1, 849])



 72%|███████▏  | 327/455 [20:45<08:16,  3.88s/it][A

Inputs shape: torch.Size([1, 818])



 72%|███████▏  | 328/455 [20:48<07:34,  3.58s/it][A

Inputs shape: torch.Size([1, 814])



 72%|███████▏  | 329/455 [20:52<07:33,  3.60s/it][A

Inputs shape: torch.Size([1, 997])



 73%|███████▎  | 330/455 [20:55<07:27,  3.58s/it][A

Inputs shape: torch.Size([1, 980])



 73%|███████▎  | 331/455 [21:00<08:13,  3.98s/it][A

Inputs shape: torch.Size([1, 1108])



 73%|███████▎  | 332/455 [21:03<07:39,  3.73s/it][A

Inputs shape: torch.Size([1, 908])



 73%|███████▎  | 333/455 [21:07<07:40,  3.77s/it][A

Inputs shape: torch.Size([1, 940])



 73%|███████▎  | 334/455 [21:11<07:45,  3.85s/it][A

Inputs shape: torch.Size([1, 1149])



 74%|███████▎  | 335/455 [21:15<07:46,  3.89s/it][A

Inputs shape: torch.Size([1, 868])



 74%|███████▍  | 336/455 [21:19<07:36,  3.84s/it][A

Inputs shape: torch.Size([1, 833])



 74%|███████▍  | 337/455 [21:23<07:50,  3.99s/it][A

Inputs shape: torch.Size([1, 1160])



 74%|███████▍  | 338/455 [21:28<08:07,  4.17s/it][A

Inputs shape: torch.Size([1, 1087])



 75%|███████▍  | 339/455 [21:32<08:05,  4.19s/it][A

Inputs shape: torch.Size([1, 1024])



 75%|███████▍  | 340/455 [21:36<07:39,  4.00s/it][A

Inputs shape: torch.Size([1, 1111])



 75%|███████▍  | 341/455 [21:39<07:18,  3.85s/it][A

Inputs shape: torch.Size([1, 890])



 75%|███████▌  | 342/455 [21:43<07:15,  3.86s/it][A

Inputs shape: torch.Size([1, 1094])



 75%|███████▌  | 343/455 [21:47<07:23,  3.96s/it][A

Inputs shape: torch.Size([1, 848])



 76%|███████▌  | 344/455 [21:51<07:30,  4.06s/it][A

Inputs shape: torch.Size([1, 862])



 76%|███████▌  | 345/455 [21:55<07:25,  4.05s/it][A

Inputs shape: torch.Size([1, 876])



 76%|███████▌  | 346/455 [21:58<06:45,  3.72s/it][A

Inputs shape: torch.Size([1, 992])



 76%|███████▋  | 347/455 [22:03<07:03,  3.93s/it][A

Inputs shape: torch.Size([1, 1081])



 76%|███████▋  | 348/455 [22:07<06:55,  3.89s/it][A

Inputs shape: torch.Size([1, 880])



 77%|███████▋  | 349/455 [22:10<06:47,  3.84s/it][A

Inputs shape: torch.Size([1, 852])



 77%|███████▋  | 350/455 [22:14<06:49,  3.90s/it][A

Inputs shape: torch.Size([1, 1118])



 77%|███████▋  | 351/455 [22:19<07:09,  4.13s/it][A

Inputs shape: torch.Size([1, 881])



 77%|███████▋  | 352/455 [22:22<06:29,  3.78s/it][A

Inputs shape: torch.Size([1, 850])



 78%|███████▊  | 353/455 [22:26<06:43,  3.96s/it][A

Inputs shape: torch.Size([1, 766])



 78%|███████▊  | 354/455 [22:31<07:16,  4.32s/it][A

Inputs shape: torch.Size([1, 904])



 78%|███████▊  | 355/455 [22:36<07:16,  4.37s/it][A

Inputs shape: torch.Size([1, 951])



 78%|███████▊  | 356/455 [22:40<07:02,  4.26s/it][A

Inputs shape: torch.Size([1, 912])



 78%|███████▊  | 357/455 [22:43<06:33,  4.02s/it][A

Inputs shape: torch.Size([1, 1064])



 79%|███████▊  | 358/455 [22:48<06:36,  4.09s/it][A

Inputs shape: torch.Size([1, 883])



 79%|███████▉  | 359/455 [22:52<06:28,  4.05s/it][A

Inputs shape: torch.Size([1, 861])



 79%|███████▉  | 360/455 [22:56<06:22,  4.03s/it][A

Inputs shape: torch.Size([1, 952])



 79%|███████▉  | 361/455 [22:59<06:04,  3.87s/it][A

Inputs shape: torch.Size([1, 907])



 80%|███████▉  | 362/455 [23:04<06:29,  4.19s/it][A

Inputs shape: torch.Size([1, 841])



 80%|███████▉  | 363/455 [23:08<06:28,  4.22s/it][A

Inputs shape: torch.Size([1, 898])



 80%|████████  | 364/455 [23:12<06:00,  3.96s/it][A

Inputs shape: torch.Size([1, 663])



 80%|████████  | 365/455 [23:15<05:30,  3.68s/it][A

Inputs shape: torch.Size([1, 1081])



 80%|████████  | 366/455 [23:19<05:48,  3.92s/it][A

Inputs shape: torch.Size([1, 854])



 81%|████████  | 367/455 [23:22<05:18,  3.62s/it][A

Inputs shape: torch.Size([1, 1052])



 81%|████████  | 368/455 [23:27<05:38,  3.89s/it][A

Inputs shape: torch.Size([1, 973])



 81%|████████  | 369/455 [23:32<06:06,  4.26s/it][A

Inputs shape: torch.Size([1, 998])



 81%|████████▏ | 370/455 [23:36<06:08,  4.33s/it][A

Inputs shape: torch.Size([1, 1087])



 82%|████████▏ | 371/455 [23:40<05:45,  4.12s/it][A

Inputs shape: torch.Size([1, 1266])



 82%|████████▏ | 372/455 [23:44<05:34,  4.03s/it][A

Inputs shape: torch.Size([1, 988])



 82%|████████▏ | 373/455 [23:48<05:26,  3.98s/it][A

Inputs shape: torch.Size([1, 967])



 82%|████████▏ | 374/455 [23:53<05:59,  4.44s/it][A

Inputs shape: torch.Size([1, 830])



 82%|████████▏ | 375/455 [23:57<05:33,  4.17s/it][A

Inputs shape: torch.Size([1, 936])



 83%|████████▎ | 376/455 [24:00<05:12,  3.95s/it][A

Inputs shape: torch.Size([1, 943])



 83%|████████▎ | 377/455 [24:03<04:32,  3.50s/it][A

Inputs shape: torch.Size([1, 746])



 83%|████████▎ | 378/455 [24:06<04:29,  3.50s/it][A

Inputs shape: torch.Size([1, 797])



 83%|████████▎ | 379/455 [24:10<04:32,  3.58s/it][A

Inputs shape: torch.Size([1, 991])



 84%|████████▎ | 380/455 [24:14<04:43,  3.78s/it][A

Inputs shape: torch.Size([1, 1042])



 84%|████████▎ | 381/455 [24:18<04:52,  3.96s/it][A

Inputs shape: torch.Size([1, 847])



 84%|████████▍ | 382/455 [24:22<04:35,  3.77s/it][A

Inputs shape: torch.Size([1, 945])



 84%|████████▍ | 383/455 [24:25<04:25,  3.69s/it][A

Inputs shape: torch.Size([1, 902])



 84%|████████▍ | 384/455 [24:28<04:10,  3.53s/it][A

Inputs shape: torch.Size([1, 928])



 85%|████████▍ | 385/455 [24:33<04:34,  3.91s/it][A

Inputs shape: torch.Size([1, 887])



 85%|████████▍ | 386/455 [24:36<04:07,  3.59s/it][A

Inputs shape: torch.Size([1, 978])



 85%|████████▌ | 387/455 [24:38<03:22,  2.98s/it][A

Inputs shape: torch.Size([1, 717])



 85%|████████▌ | 388/455 [24:42<03:45,  3.37s/it][A

Inputs shape: torch.Size([1, 1070])



 85%|████████▌ | 389/455 [24:46<03:59,  3.63s/it][A

Inputs shape: torch.Size([1, 1046])



 86%|████████▌ | 390/455 [24:50<03:53,  3.59s/it][A

Inputs shape: torch.Size([1, 1154])



 86%|████████▌ | 391/455 [24:55<04:24,  4.13s/it][A

Inputs shape: torch.Size([1, 936])



 86%|████████▌ | 392/455 [24:59<04:22,  4.17s/it][A

Inputs shape: torch.Size([1, 854])



 86%|████████▋ | 393/455 [25:05<04:40,  4.53s/it][A

Inputs shape: torch.Size([1, 1045])



 87%|████████▋ | 394/455 [25:08<04:16,  4.20s/it][A

Inputs shape: torch.Size([1, 998])



 87%|████████▋ | 395/455 [25:11<03:57,  3.96s/it][A

Inputs shape: torch.Size([1, 874])



 87%|████████▋ | 396/455 [25:14<03:36,  3.67s/it][A

Inputs shape: torch.Size([1, 685])



 87%|████████▋ | 397/455 [25:17<03:18,  3.42s/it][A

Inputs shape: torch.Size([1, 1029])



 87%|████████▋ | 398/455 [25:22<03:33,  3.75s/it][A

Inputs shape: torch.Size([1, 778])



 88%|████████▊ | 399/455 [25:25<03:27,  3.70s/it][A

Inputs shape: torch.Size([1, 829])



 88%|████████▊ | 400/455 [25:28<03:13,  3.51s/it][A

Inputs shape: torch.Size([1, 857])



 88%|████████▊ | 401/455 [25:32<03:17,  3.65s/it][A

Inputs shape: torch.Size([1, 893])



 88%|████████▊ | 402/455 [25:36<03:09,  3.57s/it][A

Inputs shape: torch.Size([1, 964])



 89%|████████▊ | 403/455 [25:41<03:33,  4.10s/it][A

Inputs shape: torch.Size([1, 804])



 89%|████████▉ | 404/455 [25:45<03:27,  4.07s/it][A

Inputs shape: torch.Size([1, 743])



 89%|████████▉ | 405/455 [25:49<03:14,  3.90s/it][A

Inputs shape: torch.Size([1, 999])



 89%|████████▉ | 406/455 [25:53<03:21,  4.11s/it][A

Inputs shape: torch.Size([1, 1065])



 89%|████████▉ | 407/455 [25:57<03:17,  4.12s/it][A

Inputs shape: torch.Size([1, 924])



 90%|████████▉ | 408/455 [26:01<03:10,  4.05s/it][A

Inputs shape: torch.Size([1, 926])



 90%|████████▉ | 409/455 [26:06<03:10,  4.13s/it][A

Inputs shape: torch.Size([1, 912])



 90%|█████████ | 410/455 [26:09<03:00,  4.02s/it][A

Inputs shape: torch.Size([1, 981])



 90%|█████████ | 411/455 [26:12<02:43,  3.71s/it][A

Inputs shape: torch.Size([1, 660])



 91%|█████████ | 412/455 [26:16<02:43,  3.81s/it][A

Inputs shape: torch.Size([1, 895])



 91%|█████████ | 413/455 [26:22<03:05,  4.41s/it][A

Inputs shape: torch.Size([1, 1041])



 91%|█████████ | 414/455 [26:28<03:18,  4.84s/it][A

Inputs shape: torch.Size([1, 791])



 91%|█████████ | 415/455 [26:32<03:06,  4.67s/it][A

Inputs shape: torch.Size([1, 896])



 91%|█████████▏| 416/455 [26:36<02:50,  4.38s/it][A

Inputs shape: torch.Size([1, 954])



 92%|█████████▏| 417/455 [26:39<02:34,  4.06s/it][A

Inputs shape: torch.Size([1, 970])



 92%|█████████▏| 418/455 [26:43<02:25,  3.92s/it][A

Inputs shape: torch.Size([1, 1002])



 92%|█████████▏| 419/455 [26:47<02:26,  4.06s/it][A

Inputs shape: torch.Size([1, 944])



 92%|█████████▏| 420/455 [26:51<02:22,  4.06s/it][A

Inputs shape: torch.Size([1, 944])



 93%|█████████▎| 421/455 [26:56<02:18,  4.09s/it][A

Inputs shape: torch.Size([1, 1082])



 93%|█████████▎| 422/455 [26:59<02:12,  4.02s/it][A

Inputs shape: torch.Size([1, 1051])



 93%|█████████▎| 423/455 [27:03<02:08,  4.02s/it][A

Inputs shape: torch.Size([1, 848])



 93%|█████████▎| 424/455 [27:07<01:56,  3.75s/it][A

Inputs shape: torch.Size([1, 957])



 93%|█████████▎| 425/455 [27:11<01:58,  3.94s/it][A

Inputs shape: torch.Size([1, 1019])



 94%|█████████▎| 426/455 [27:14<01:45,  3.65s/it][A

Inputs shape: torch.Size([1, 937])



 94%|█████████▍| 427/455 [27:19<01:55,  4.11s/it][A

Inputs shape: torch.Size([1, 841])



 94%|█████████▍| 428/455 [27:22<01:41,  3.76s/it][A

Inputs shape: torch.Size([1, 871])



 94%|█████████▍| 429/455 [27:26<01:41,  3.90s/it][A

Inputs shape: torch.Size([1, 871])



 95%|█████████▍| 430/455 [27:30<01:36,  3.88s/it][A

Inputs shape: torch.Size([1, 995])



 95%|█████████▍| 431/455 [27:34<01:35,  3.96s/it][A

Inputs shape: torch.Size([1, 895])



 95%|█████████▍| 432/455 [27:38<01:32,  4.02s/it][A

Inputs shape: torch.Size([1, 812])



 95%|█████████▌| 433/455 [27:42<01:27,  3.96s/it][A

Inputs shape: torch.Size([1, 827])



 95%|█████████▌| 434/455 [27:46<01:19,  3.77s/it][A

Inputs shape: torch.Size([1, 822])



 96%|█████████▌| 435/455 [27:49<01:13,  3.65s/it][A

Inputs shape: torch.Size([1, 1133])



 96%|█████████▌| 436/455 [27:52<01:08,  3.59s/it][A

Inputs shape: torch.Size([1, 915])



 96%|█████████▌| 437/455 [27:56<01:06,  3.70s/it][A

Inputs shape: torch.Size([1, 960])



 96%|█████████▋| 438/455 [27:59<00:59,  3.47s/it][A

Inputs shape: torch.Size([1, 979])



 96%|█████████▋| 439/455 [28:03<00:57,  3.58s/it][A

Inputs shape: torch.Size([1, 993])



 97%|█████████▋| 440/455 [28:06<00:52,  3.47s/it][A

Inputs shape: torch.Size([1, 941])



 97%|█████████▋| 441/455 [28:10<00:50,  3.63s/it][A

Inputs shape: torch.Size([1, 906])



 97%|█████████▋| 442/455 [28:13<00:43,  3.35s/it][A

Inputs shape: torch.Size([1, 815])



 97%|█████████▋| 443/455 [28:17<00:41,  3.49s/it][A

Inputs shape: torch.Size([1, 833])



 98%|█████████▊| 444/455 [28:20<00:37,  3.43s/it][A

Inputs shape: torch.Size([1, 768])



 98%|█████████▊| 445/455 [28:23<00:33,  3.38s/it][A

Inputs shape: torch.Size([1, 932])



 98%|█████████▊| 446/455 [28:27<00:30,  3.41s/it][A

Inputs shape: torch.Size([1, 699])



 98%|█████████▊| 447/455 [28:30<00:26,  3.33s/it][A

Inputs shape: torch.Size([1, 788])



 98%|█████████▊| 448/455 [28:34<00:24,  3.50s/it][A

Inputs shape: torch.Size([1, 789])



 99%|█████████▊| 449/455 [28:37<00:20,  3.49s/it][A

Inputs shape: torch.Size([1, 858])



 99%|█████████▉| 450/455 [28:41<00:17,  3.57s/it][A

Inputs shape: torch.Size([1, 992])



 99%|█████████▉| 451/455 [28:45<00:14,  3.64s/it][A

Inputs shape: torch.Size([1, 778])



 99%|█████████▉| 452/455 [28:48<00:10,  3.56s/it][A

Inputs shape: torch.Size([1, 1014])



100%|█████████▉| 453/455 [28:52<00:07,  3.68s/it][A

Inputs shape: torch.Size([1, 899])



100%|█████████▉| 454/455 [28:56<00:03,  3.66s/it][A

Inputs shape: torch.Size([1, 831])



100%|██████████| 455/455 [28:59<00:00,  3.82s/it]


In [None]:
from nltk.translate.bleu_score import corpus_bleu

bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.5438235127832646


In [None]:
%%capture
!pip install rouge_score

In [None]:
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.4875301201572035
ROUGE-2: 0.23844141275577627
ROUGE-L: 0.3649153520526568


In [None]:
from collections import Counter

from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np


def distinct(seqs):
    """ Calculate intra/inter distinct 1/2. """
    """ Recuperado de https://github.com/PaddlePaddle/models/blob/release/1.6/PaddleNLP/Research/Dialogue-PLATO/plato/metrics/metrics.py"""
    batch_size = len(seqs)
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5))
        intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5)
    inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2

In [None]:
import nltk
nltk.download('punkt_tab')

tokenized_text = [nltk.word_tokenize(text) for text in references]
resultado = distinct(tokenized_text)
print(resultado)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


(0.8174602211235673, 0.9857301007313511, 0.11541064977609411, 0.3757125912018547)


In [None]:
%%capture
!pip install evaluate bert_score

In [None]:
from evaluate import load
bertscore = load("bertscore")
results = bertscore.compute(predictions=predictions, references=references, lang="en", model_type="distilbert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [None]:
from statistics import mean
print(results.keys())
print(f"F1: {mean(results['f1'])}")
print(f"Precision: {mean(results['precision'])}")
print(f"Recall: {mean(results['recall'])}")

dict_keys(['precision', 'recall', 'f1', 'hashcode'])
F1: 0.8604703283571935
Precision: 0.8617244886827993
Recall: 0.8595911986225254


# Testeo Zero Shot

In [None]:
max_seq_length = 2048
base_model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


 24%|██▍       | 4.19M/17.6M [42:54<2:17:03, 1.63kB/s]


tokenizer_config.json:   0%|          | 0.00/54.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072, padding_idx=128004)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaExtendedRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): Llam

In [None]:
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

In [None]:
with open("formatted_test.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [None]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [None]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [None]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])


  0%|          | 0/455 [00:00<?, ?it/s]

Inputs shape: torch.Size([1, 891])


  0%|          | 1/455 [00:04<30:16,  4.00s/it]

Inputs shape: torch.Size([1, 971])


  0%|          | 2/455 [00:07<28:27,  3.77s/it]

Inputs shape: torch.Size([1, 872])


  1%|          | 3/455 [00:12<31:56,  4.24s/it]

Inputs shape: torch.Size([1, 766])


  1%|          | 4/455 [00:16<30:50,  4.10s/it]

Inputs shape: torch.Size([1, 798])


  1%|          | 5/455 [00:20<31:50,  4.24s/it]

Inputs shape: torch.Size([1, 771])


  1%|▏         | 6/455 [00:25<32:54,  4.40s/it]

Inputs shape: torch.Size([1, 1041])


  2%|▏         | 7/455 [00:30<33:28,  4.48s/it]

Inputs shape: torch.Size([1, 888])


  2%|▏         | 8/455 [00:34<33:23,  4.48s/it]

Inputs shape: torch.Size([1, 873])


  2%|▏         | 9/455 [00:39<33:13,  4.47s/it]

Inputs shape: torch.Size([1, 1427])


  2%|▏         | 10/455 [00:43<33:19,  4.49s/it]

Inputs shape: torch.Size([1, 673])


  2%|▏         | 11/455 [00:47<32:52,  4.44s/it]

Inputs shape: torch.Size([1, 823])


  3%|▎         | 12/455 [00:52<33:02,  4.47s/it]

Inputs shape: torch.Size([1, 821])


  3%|▎         | 13/455 [00:57<33:03,  4.49s/it]

Inputs shape: torch.Size([1, 938])


  3%|▎         | 14/455 [01:01<33:01,  4.49s/it]

Inputs shape: torch.Size([1, 1156])


  3%|▎         | 15/455 [01:06<32:55,  4.49s/it]

Inputs shape: torch.Size([1, 744])


  4%|▎         | 16/455 [01:10<32:35,  4.45s/it]

Inputs shape: torch.Size([1, 1168])


  4%|▎         | 17/455 [01:14<32:45,  4.49s/it]

Inputs shape: torch.Size([1, 776])


  4%|▍         | 18/455 [01:19<32:36,  4.48s/it]

Inputs shape: torch.Size([1, 972])


  4%|▍         | 19/455 [01:23<32:46,  4.51s/it]

Inputs shape: torch.Size([1, 957])


  4%|▍         | 20/455 [01:28<32:47,  4.52s/it]

Inputs shape: torch.Size([1, 878])


  5%|▍         | 21/455 [01:32<30:34,  4.23s/it]

Inputs shape: torch.Size([1, 1123])


  5%|▍         | 22/455 [01:36<30:35,  4.24s/it]

Inputs shape: torch.Size([1, 962])


  5%|▌         | 23/455 [01:40<31:09,  4.33s/it]

Inputs shape: torch.Size([1, 1086])


  5%|▌         | 24/455 [01:45<31:32,  4.39s/it]

Inputs shape: torch.Size([1, 989])


  5%|▌         | 25/455 [01:48<29:42,  4.15s/it]

Inputs shape: torch.Size([1, 919])


  6%|▌         | 26/455 [01:53<29:44,  4.16s/it]

Inputs shape: torch.Size([1, 985])


  6%|▌         | 27/455 [01:57<30:19,  4.25s/it]

Inputs shape: torch.Size([1, 916])


  6%|▌         | 28/455 [02:02<30:40,  4.31s/it]

Inputs shape: torch.Size([1, 835])


  6%|▋         | 29/455 [02:06<30:59,  4.36s/it]

Inputs shape: torch.Size([1, 1147])


  7%|▋         | 30/455 [02:11<31:30,  4.45s/it]

Inputs shape: torch.Size([1, 1002])


  7%|▋         | 31/455 [02:15<31:54,  4.52s/it]

Inputs shape: torch.Size([1, 736])


  7%|▋         | 32/455 [02:20<31:46,  4.51s/it]

Inputs shape: torch.Size([1, 1213])


  7%|▋         | 33/455 [02:24<29:59,  4.26s/it]

Inputs shape: torch.Size([1, 1109])


  7%|▋         | 34/455 [02:28<29:59,  4.27s/it]

Inputs shape: torch.Size([1, 715])


  8%|▊         | 35/455 [02:32<30:20,  4.33s/it]

Inputs shape: torch.Size([1, 762])


  8%|▊         | 36/455 [02:36<28:17,  4.05s/it]

Inputs shape: torch.Size([1, 902])


  8%|▊         | 37/455 [02:39<27:23,  3.93s/it]

Inputs shape: torch.Size([1, 1006])


  8%|▊         | 38/455 [02:44<28:25,  4.09s/it]

Inputs shape: torch.Size([1, 812])


  9%|▊         | 39/455 [02:48<27:33,  3.97s/it]

Inputs shape: torch.Size([1, 1060])


  9%|▉         | 40/455 [02:52<28:25,  4.11s/it]

Inputs shape: torch.Size([1, 784])


  9%|▉         | 41/455 [02:56<29:03,  4.21s/it]

Inputs shape: torch.Size([1, 794])


  9%|▉         | 42/455 [03:01<29:38,  4.31s/it]

Inputs shape: torch.Size([1, 956])


  9%|▉         | 43/455 [03:05<29:51,  4.35s/it]

Inputs shape: torch.Size([1, 942])


 10%|▉         | 44/455 [03:10<29:59,  4.38s/it]

Inputs shape: torch.Size([1, 1093])


 10%|▉         | 45/455 [03:14<29:59,  4.39s/it]

Inputs shape: torch.Size([1, 880])


 10%|█         | 46/455 [03:19<30:00,  4.40s/it]

Inputs shape: torch.Size([1, 1079])


 10%|█         | 47/455 [03:23<30:09,  4.44s/it]

Inputs shape: torch.Size([1, 680])


 11%|█         | 48/455 [03:28<29:57,  4.42s/it]

Inputs shape: torch.Size([1, 1066])


 11%|█         | 49/455 [03:31<28:50,  4.26s/it]

Inputs shape: torch.Size([1, 1165])


 11%|█         | 50/455 [03:35<27:10,  4.03s/it]

Inputs shape: torch.Size([1, 994])


 11%|█         | 51/455 [03:39<27:12,  4.04s/it]

Inputs shape: torch.Size([1, 816])


 11%|█▏        | 52/455 [03:44<28:06,  4.19s/it]

Inputs shape: torch.Size([1, 957])


 12%|█▏        | 53/455 [03:48<27:43,  4.14s/it]

Inputs shape: torch.Size([1, 837])


 12%|█▏        | 54/455 [03:52<28:22,  4.24s/it]

Inputs shape: torch.Size([1, 993])


 12%|█▏        | 55/455 [03:57<28:39,  4.30s/it]

Inputs shape: torch.Size([1, 794])


 12%|█▏        | 56/455 [04:01<28:40,  4.31s/it]

Inputs shape: torch.Size([1, 742])


 13%|█▎        | 57/455 [04:05<28:52,  4.35s/it]

Inputs shape: torch.Size([1, 822])


 13%|█▎        | 58/455 [04:10<28:52,  4.36s/it]

Inputs shape: torch.Size([1, 787])


 13%|█▎        | 59/455 [04:14<28:50,  4.37s/it]

Inputs shape: torch.Size([1, 892])


 13%|█▎        | 60/455 [04:18<28:53,  4.39s/it]

Inputs shape: torch.Size([1, 1106])


 13%|█▎        | 61/455 [04:22<27:40,  4.21s/it]

Inputs shape: torch.Size([1, 965])


 14%|█▎        | 62/455 [04:27<27:50,  4.25s/it]

Inputs shape: torch.Size([1, 849])


 14%|█▍        | 63/455 [04:31<28:12,  4.32s/it]

Inputs shape: torch.Size([1, 858])


 14%|█▍        | 64/455 [04:35<27:04,  4.15s/it]

Inputs shape: torch.Size([1, 1339])


 14%|█▍        | 65/455 [04:39<27:53,  4.29s/it]

Inputs shape: torch.Size([1, 796])


 15%|█▍        | 66/455 [04:43<27:14,  4.20s/it]

Inputs shape: torch.Size([1, 659])


 15%|█▍        | 67/455 [04:48<27:31,  4.26s/it]

Inputs shape: torch.Size([1, 1023])


 15%|█▍        | 68/455 [04:51<26:02,  4.04s/it]

Inputs shape: torch.Size([1, 845])


 15%|█▌        | 69/455 [04:56<26:44,  4.16s/it]

Inputs shape: torch.Size([1, 902])


 15%|█▌        | 70/455 [05:00<27:15,  4.25s/it]

Inputs shape: torch.Size([1, 925])


 16%|█▌        | 71/455 [05:04<27:03,  4.23s/it]

Inputs shape: torch.Size([1, 749])


 16%|█▌        | 72/455 [05:09<27:10,  4.26s/it]

Inputs shape: torch.Size([1, 852])


 16%|█▌        | 73/455 [05:13<27:09,  4.27s/it]

Inputs shape: torch.Size([1, 904])


 16%|█▋        | 74/455 [05:18<27:23,  4.31s/it]

Inputs shape: torch.Size([1, 920])


 16%|█▋        | 75/455 [05:22<27:34,  4.35s/it]

Inputs shape: torch.Size([1, 703])


 17%|█▋        | 76/455 [05:26<27:32,  4.36s/it]

Inputs shape: torch.Size([1, 1056])


 17%|█▋        | 77/455 [05:31<27:38,  4.39s/it]

Inputs shape: torch.Size([1, 742])


 17%|█▋        | 78/455 [05:35<27:30,  4.38s/it]

Inputs shape: torch.Size([1, 717])


 17%|█▋        | 79/455 [05:39<27:15,  4.35s/it]

Inputs shape: torch.Size([1, 1035])


 18%|█▊        | 80/455 [05:44<27:15,  4.36s/it]

Inputs shape: torch.Size([1, 859])


 18%|█▊        | 81/455 [05:48<27:10,  4.36s/it]

Inputs shape: torch.Size([1, 922])


 18%|█▊        | 82/455 [05:53<27:18,  4.39s/it]

Inputs shape: torch.Size([1, 856])


 18%|█▊        | 83/455 [05:57<27:05,  4.37s/it]

Inputs shape: torch.Size([1, 829])


 18%|█▊        | 84/455 [06:01<26:52,  4.35s/it]

Inputs shape: torch.Size([1, 820])


 19%|█▊        | 85/455 [06:05<25:36,  4.15s/it]

Inputs shape: torch.Size([1, 898])


 19%|█▉        | 86/455 [06:09<26:10,  4.26s/it]

Inputs shape: torch.Size([1, 1002])


 19%|█▉        | 87/455 [06:14<25:48,  4.21s/it]

Inputs shape: torch.Size([1, 727])


 19%|█▉        | 88/455 [06:18<26:04,  4.26s/it]

Inputs shape: torch.Size([1, 969])


 20%|█▉        | 89/455 [06:22<26:17,  4.31s/it]

Inputs shape: torch.Size([1, 875])


 20%|█▉        | 90/455 [06:27<26:13,  4.31s/it]

Inputs shape: torch.Size([1, 897])


 20%|██        | 91/455 [06:30<24:32,  4.05s/it]

Inputs shape: torch.Size([1, 818])


 20%|██        | 92/455 [06:35<25:20,  4.19s/it]

Inputs shape: torch.Size([1, 1024])


 20%|██        | 93/455 [06:39<25:59,  4.31s/it]

Inputs shape: torch.Size([1, 896])


 21%|██        | 94/455 [06:44<26:16,  4.37s/it]

Inputs shape: torch.Size([1, 980])


 21%|██        | 95/455 [06:48<26:13,  4.37s/it]

Inputs shape: torch.Size([1, 980])


 21%|██        | 96/455 [06:52<25:25,  4.25s/it]

Inputs shape: torch.Size([1, 709])


 21%|██▏       | 97/455 [06:55<22:15,  3.73s/it]

Inputs shape: torch.Size([1, 1002])


 22%|██▏       | 98/455 [06:59<23:28,  3.95s/it]

Inputs shape: torch.Size([1, 975])


 22%|██▏       | 99/455 [07:02<21:50,  3.68s/it]

Inputs shape: torch.Size([1, 941])


 22%|██▏       | 100/455 [07:07<23:13,  3.92s/it]

Inputs shape: torch.Size([1, 956])


 22%|██▏       | 101/455 [07:11<23:30,  3.98s/it]

Inputs shape: torch.Size([1, 746])


 22%|██▏       | 102/455 [07:15<23:56,  4.07s/it]

Inputs shape: torch.Size([1, 944])


 23%|██▎       | 103/455 [07:19<23:14,  3.96s/it]

Inputs shape: torch.Size([1, 1037])


 23%|██▎       | 104/455 [07:23<24:04,  4.12s/it]

Inputs shape: torch.Size([1, 942])


 23%|██▎       | 105/455 [07:28<24:35,  4.22s/it]

Inputs shape: torch.Size([1, 817])


 23%|██▎       | 106/455 [07:32<24:52,  4.28s/it]

Inputs shape: torch.Size([1, 916])


 24%|██▎       | 107/455 [07:36<24:58,  4.31s/it]

Inputs shape: torch.Size([1, 1001])


 24%|██▎       | 108/455 [07:41<25:00,  4.33s/it]

Inputs shape: torch.Size([1, 1124])


 24%|██▍       | 109/455 [07:45<25:08,  4.36s/it]

Inputs shape: torch.Size([1, 934])


 24%|██▍       | 110/455 [07:50<25:09,  4.38s/it]

Inputs shape: torch.Size([1, 864])


 24%|██▍       | 111/455 [07:54<24:32,  4.28s/it]

Inputs shape: torch.Size([1, 748])


 25%|██▍       | 112/455 [07:58<24:33,  4.30s/it]

Inputs shape: torch.Size([1, 788])


 25%|██▍       | 113/455 [08:02<24:30,  4.30s/it]

Inputs shape: torch.Size([1, 995])


 25%|██▌       | 114/455 [08:06<23:57,  4.21s/it]

Inputs shape: torch.Size([1, 953])


 25%|██▌       | 115/455 [08:10<23:08,  4.08s/it]

Inputs shape: torch.Size([1, 927])


 25%|██▌       | 116/455 [08:15<23:35,  4.18s/it]

Inputs shape: torch.Size([1, 1011])


 26%|██▌       | 117/455 [08:19<23:55,  4.25s/it]

Inputs shape: torch.Size([1, 868])


 26%|██▌       | 118/455 [08:23<24:04,  4.29s/it]

Inputs shape: torch.Size([1, 946])


 26%|██▌       | 119/455 [08:28<24:05,  4.30s/it]

Inputs shape: torch.Size([1, 932])


 26%|██▋       | 120/455 [08:32<24:04,  4.31s/it]

Inputs shape: torch.Size([1, 981])


 27%|██▋       | 121/455 [08:37<24:24,  4.39s/it]

Inputs shape: torch.Size([1, 830])


 27%|██▋       | 122/455 [08:41<24:26,  4.41s/it]

Inputs shape: torch.Size([1, 787])


 27%|██▋       | 123/455 [08:45<24:26,  4.42s/it]

Inputs shape: torch.Size([1, 879])


 27%|██▋       | 124/455 [08:50<24:14,  4.39s/it]

Inputs shape: torch.Size([1, 1026])


 27%|██▋       | 125/455 [08:53<22:35,  4.11s/it]

Inputs shape: torch.Size([1, 899])


 28%|██▊       | 126/455 [08:58<22:57,  4.19s/it]

Inputs shape: torch.Size([1, 709])


 28%|██▊       | 127/455 [09:02<23:15,  4.25s/it]

Inputs shape: torch.Size([1, 894])


 28%|██▊       | 128/455 [09:06<23:31,  4.32s/it]

Inputs shape: torch.Size([1, 748])


 28%|██▊       | 129/455 [09:11<23:38,  4.35s/it]

Inputs shape: torch.Size([1, 952])


 29%|██▊       | 130/455 [09:15<23:31,  4.34s/it]

Inputs shape: torch.Size([1, 757])


 29%|██▉       | 131/455 [09:19<23:18,  4.32s/it]

Inputs shape: torch.Size([1, 918])


 29%|██▉       | 132/455 [09:24<23:26,  4.36s/it]

Inputs shape: torch.Size([1, 917])


 29%|██▉       | 133/455 [09:28<23:36,  4.40s/it]

Inputs shape: torch.Size([1, 979])


 29%|██▉       | 134/455 [09:33<23:12,  4.34s/it]

Inputs shape: torch.Size([1, 857])


 30%|██▉       | 135/455 [09:37<23:15,  4.36s/it]

Inputs shape: torch.Size([1, 860])


 30%|██▉       | 136/455 [09:41<23:04,  4.34s/it]

Inputs shape: torch.Size([1, 953])


 30%|███       | 137/455 [09:46<23:01,  4.35s/it]

Inputs shape: torch.Size([1, 748])


 30%|███       | 138/455 [09:50<22:58,  4.35s/it]

Inputs shape: torch.Size([1, 1104])


 31%|███       | 139/455 [09:54<22:38,  4.30s/it]

Inputs shape: torch.Size([1, 902])


 31%|███       | 140/455 [09:59<22:44,  4.33s/it]

Inputs shape: torch.Size([1, 858])


 31%|███       | 141/455 [10:03<22:44,  4.35s/it]

Inputs shape: torch.Size([1, 807])


 31%|███       | 142/455 [10:07<22:38,  4.34s/it]

Inputs shape: torch.Size([1, 818])


 31%|███▏      | 143/455 [10:12<22:35,  4.34s/it]

Inputs shape: torch.Size([1, 921])


 32%|███▏      | 144/455 [10:16<22:43,  4.38s/it]

Inputs shape: torch.Size([1, 860])


 32%|███▏      | 145/455 [10:21<22:39,  4.39s/it]

Inputs shape: torch.Size([1, 879])


 32%|███▏      | 146/455 [10:24<21:36,  4.20s/it]

Inputs shape: torch.Size([1, 1068])


 32%|███▏      | 147/455 [10:29<21:47,  4.25s/it]

Inputs shape: torch.Size([1, 864])


 33%|███▎      | 148/455 [10:33<21:47,  4.26s/it]

Inputs shape: torch.Size([1, 936])


 33%|███▎      | 149/455 [10:37<21:28,  4.21s/it]

Inputs shape: torch.Size([1, 970])


 33%|███▎      | 150/455 [10:41<21:42,  4.27s/it]

Inputs shape: torch.Size([1, 996])


 33%|███▎      | 151/455 [10:46<21:54,  4.32s/it]

Inputs shape: torch.Size([1, 1048])


 33%|███▎      | 152/455 [10:50<22:02,  4.37s/it]

Inputs shape: torch.Size([1, 1099])


 34%|███▎      | 153/455 [10:55<22:03,  4.38s/it]

Inputs shape: torch.Size([1, 790])


 34%|███▍      | 154/455 [10:59<21:49,  4.35s/it]

Inputs shape: torch.Size([1, 949])


 34%|███▍      | 155/455 [11:03<21:48,  4.36s/it]

Inputs shape: torch.Size([1, 972])


 34%|███▍      | 156/455 [11:08<21:53,  4.39s/it]

Inputs shape: torch.Size([1, 973])


 35%|███▍      | 157/455 [11:12<21:46,  4.39s/it]

Inputs shape: torch.Size([1, 1065])


 35%|███▍      | 158/455 [11:17<21:40,  4.38s/it]

Inputs shape: torch.Size([1, 934])


 35%|███▍      | 159/455 [11:21<21:30,  4.36s/it]

Inputs shape: torch.Size([1, 1238])


 35%|███▌      | 160/455 [11:25<21:03,  4.28s/it]

Inputs shape: torch.Size([1, 1191])


 35%|███▌      | 161/455 [11:29<21:05,  4.31s/it]

Inputs shape: torch.Size([1, 797])


 36%|███▌      | 162/455 [11:33<20:30,  4.20s/it]

Inputs shape: torch.Size([1, 1101])


 36%|███▌      | 163/455 [11:37<19:57,  4.10s/it]

Inputs shape: torch.Size([1, 857])


 36%|███▌      | 164/455 [11:42<20:12,  4.17s/it]

Inputs shape: torch.Size([1, 799])


 36%|███▋      | 165/455 [11:46<20:16,  4.20s/it]

Inputs shape: torch.Size([1, 957])


 36%|███▋      | 166/455 [11:49<19:07,  3.97s/it]

Inputs shape: torch.Size([1, 868])


 37%|███▋      | 167/455 [11:54<19:38,  4.09s/it]

Inputs shape: torch.Size([1, 1026])


 37%|███▋      | 168/455 [11:58<20:03,  4.19s/it]

Inputs shape: torch.Size([1, 839])


 37%|███▋      | 169/455 [12:02<20:12,  4.24s/it]

Inputs shape: torch.Size([1, 690])


 37%|███▋      | 170/455 [12:07<20:12,  4.26s/it]

Inputs shape: torch.Size([1, 1147])


 38%|███▊      | 171/455 [12:10<18:55,  4.00s/it]

Inputs shape: torch.Size([1, 872])


 38%|███▊      | 172/455 [12:14<19:16,  4.09s/it]

Inputs shape: torch.Size([1, 992])


 38%|███▊      | 173/455 [12:18<19:07,  4.07s/it]

Inputs shape: torch.Size([1, 973])


 38%|███▊      | 174/455 [12:23<19:30,  4.17s/it]

Inputs shape: torch.Size([1, 837])


 38%|███▊      | 175/455 [12:27<19:41,  4.22s/it]

Inputs shape: torch.Size([1, 724])


 39%|███▊      | 176/455 [12:31<19:43,  4.24s/it]

Inputs shape: torch.Size([1, 893])


 39%|███▉      | 177/455 [12:36<19:43,  4.26s/it]

Inputs shape: torch.Size([1, 1037])


 39%|███▉      | 178/455 [12:40<19:06,  4.14s/it]

Inputs shape: torch.Size([1, 851])


 39%|███▉      | 179/455 [12:44<19:18,  4.20s/it]

Inputs shape: torch.Size([1, 855])


 40%|███▉      | 180/455 [12:48<19:34,  4.27s/it]

Inputs shape: torch.Size([1, 925])


 40%|███▉      | 181/455 [12:52<18:18,  4.01s/it]

Inputs shape: torch.Size([1, 1085])


 40%|████      | 182/455 [12:55<16:45,  3.68s/it]

Inputs shape: torch.Size([1, 875])


 40%|████      | 183/455 [12:58<16:47,  3.71s/it]

Inputs shape: torch.Size([1, 945])


 40%|████      | 184/455 [13:02<17:02,  3.77s/it]

Inputs shape: torch.Size([1, 763])


 41%|████      | 185/455 [13:07<17:47,  3.95s/it]

Inputs shape: torch.Size([1, 956])


 41%|████      | 186/455 [13:11<18:02,  4.03s/it]

Inputs shape: torch.Size([1, 807])


 41%|████      | 187/455 [13:15<18:25,  4.12s/it]

Inputs shape: torch.Size([1, 941])


 41%|████▏     | 188/455 [13:19<17:56,  4.03s/it]

Inputs shape: torch.Size([1, 1072])


 42%|████▏     | 189/455 [13:24<18:21,  4.14s/it]

Inputs shape: torch.Size([1, 761])


 42%|████▏     | 190/455 [13:28<18:25,  4.17s/it]

Inputs shape: torch.Size([1, 1115])


 42%|████▏     | 191/455 [13:32<18:44,  4.26s/it]

Inputs shape: torch.Size([1, 761])


 42%|████▏     | 192/455 [13:37<18:49,  4.29s/it]

Inputs shape: torch.Size([1, 694])


 42%|████▏     | 193/455 [13:41<18:45,  4.29s/it]

Inputs shape: torch.Size([1, 838])


 43%|████▎     | 194/455 [13:45<18:42,  4.30s/it]

Inputs shape: torch.Size([1, 928])


 43%|████▎     | 195/455 [13:50<18:40,  4.31s/it]

Inputs shape: torch.Size([1, 930])


 43%|████▎     | 196/455 [13:53<17:17,  4.01s/it]

Inputs shape: torch.Size([1, 1059])


 43%|████▎     | 197/455 [13:56<16:38,  3.87s/it]

Inputs shape: torch.Size([1, 646])


 44%|████▎     | 198/455 [14:01<17:08,  4.00s/it]

Inputs shape: torch.Size([1, 844])


 44%|████▎     | 199/455 [14:05<17:37,  4.13s/it]

Inputs shape: torch.Size([1, 955])


 44%|████▍     | 200/455 [14:09<17:18,  4.07s/it]

Inputs shape: torch.Size([1, 928])


 44%|████▍     | 201/455 [14:13<17:35,  4.15s/it]

Inputs shape: torch.Size([1, 877])


 44%|████▍     | 202/455 [14:18<17:41,  4.20s/it]

Inputs shape: torch.Size([1, 976])


 45%|████▍     | 203/455 [14:22<17:06,  4.07s/it]

Inputs shape: torch.Size([1, 1229])


 45%|████▍     | 204/455 [14:26<17:24,  4.16s/it]

Inputs shape: torch.Size([1, 845])


 45%|████▌     | 205/455 [14:29<16:15,  3.90s/it]

Inputs shape: torch.Size([1, 1051])


 45%|████▌     | 206/455 [14:32<15:04,  3.63s/it]

Inputs shape: torch.Size([1, 724])


 45%|████▌     | 207/455 [14:36<15:47,  3.82s/it]

Inputs shape: torch.Size([1, 847])


 46%|████▌     | 208/455 [14:41<16:19,  3.96s/it]

Inputs shape: torch.Size([1, 1064])


 46%|████▌     | 209/455 [14:45<16:49,  4.10s/it]

Inputs shape: torch.Size([1, 928])


 46%|████▌     | 210/455 [14:49<16:57,  4.15s/it]

Inputs shape: torch.Size([1, 779])


 46%|████▋     | 211/455 [14:54<17:06,  4.21s/it]

Inputs shape: torch.Size([1, 933])


 47%|████▋     | 212/455 [14:57<15:41,  3.87s/it]

Inputs shape: torch.Size([1, 1062])


 47%|████▋     | 213/455 [15:01<16:11,  4.02s/it]

Inputs shape: torch.Size([1, 892])


 47%|████▋     | 214/455 [15:05<15:42,  3.91s/it]

Inputs shape: torch.Size([1, 852])


 47%|████▋     | 215/455 [15:07<13:23,  3.35s/it]

Inputs shape: torch.Size([1, 1064])


 47%|████▋     | 216/455 [15:11<14:37,  3.67s/it]

Inputs shape: torch.Size([1, 1036])


 48%|████▊     | 217/455 [15:16<15:27,  3.90s/it]

Inputs shape: torch.Size([1, 957])


 48%|████▊     | 218/455 [15:19<15:08,  3.83s/it]

Inputs shape: torch.Size([1, 857])


 48%|████▊     | 219/455 [15:23<14:39,  3.73s/it]

Inputs shape: torch.Size([1, 1068])


 48%|████▊     | 220/455 [15:26<14:13,  3.63s/it]

Inputs shape: torch.Size([1, 1092])


 49%|████▊     | 221/455 [15:31<15:01,  3.85s/it]

Inputs shape: torch.Size([1, 829])


 49%|████▉     | 222/455 [15:35<15:22,  3.96s/it]

Inputs shape: torch.Size([1, 606])


 49%|████▉     | 223/455 [15:39<15:43,  4.07s/it]

Inputs shape: torch.Size([1, 1158])


 49%|████▉     | 224/455 [15:44<16:10,  4.20s/it]

Inputs shape: torch.Size([1, 1261])


 49%|████▉     | 225/455 [15:48<16:02,  4.18s/it]

Inputs shape: torch.Size([1, 747])


 50%|████▉     | 226/455 [15:52<16:03,  4.21s/it]

Inputs shape: torch.Size([1, 876])


 50%|████▉     | 227/455 [15:56<15:49,  4.16s/it]

Inputs shape: torch.Size([1, 1069])


 50%|█████     | 228/455 [16:01<16:04,  4.25s/it]

Inputs shape: torch.Size([1, 989])


 50%|█████     | 229/455 [16:05<16:11,  4.30s/it]

Inputs shape: torch.Size([1, 1137])


 51%|█████     | 230/455 [16:10<16:17,  4.34s/it]

Inputs shape: torch.Size([1, 870])


 51%|█████     | 231/455 [16:14<16:11,  4.34s/it]

Inputs shape: torch.Size([1, 865])


 51%|█████     | 232/455 [16:17<15:20,  4.13s/it]

Inputs shape: torch.Size([1, 920])


 51%|█████     | 233/455 [16:22<15:26,  4.17s/it]

Inputs shape: torch.Size([1, 954])


 51%|█████▏    | 234/455 [16:26<15:37,  4.24s/it]

Inputs shape: torch.Size([1, 936])


 52%|█████▏    | 235/455 [16:31<15:44,  4.30s/it]

Inputs shape: torch.Size([1, 792])


 52%|█████▏    | 236/455 [16:35<15:44,  4.31s/it]

Inputs shape: torch.Size([1, 971])


 52%|█████▏    | 237/455 [16:39<15:45,  4.34s/it]

Inputs shape: torch.Size([1, 893])


 52%|█████▏    | 238/455 [16:44<15:39,  4.33s/it]

Inputs shape: torch.Size([1, 918])


 53%|█████▎    | 239/455 [16:48<15:37,  4.34s/it]

Inputs shape: torch.Size([1, 867])


 53%|█████▎    | 240/455 [16:52<15:39,  4.37s/it]

Inputs shape: torch.Size([1, 849])


 53%|█████▎    | 241/455 [16:57<15:36,  4.38s/it]

Inputs shape: torch.Size([1, 715])


 53%|█████▎    | 242/455 [17:01<15:32,  4.38s/it]

Inputs shape: torch.Size([1, 843])


 53%|█████▎    | 243/455 [17:06<15:22,  4.35s/it]

Inputs shape: torch.Size([1, 927])


 54%|█████▎    | 244/455 [17:09<13:53,  3.95s/it]

Inputs shape: torch.Size([1, 733])


 54%|█████▍    | 245/455 [17:13<14:10,  4.05s/it]

Inputs shape: torch.Size([1, 1019])


 54%|█████▍    | 246/455 [17:17<14:26,  4.15s/it]

Inputs shape: torch.Size([1, 769])


 54%|█████▍    | 247/455 [17:20<13:26,  3.88s/it]

Inputs shape: torch.Size([1, 1057])


 55%|█████▍    | 248/455 [17:25<13:56,  4.04s/it]

Inputs shape: torch.Size([1, 929])


 55%|█████▍    | 249/455 [17:29<14:09,  4.12s/it]

Inputs shape: torch.Size([1, 1001])


 55%|█████▍    | 250/455 [17:34<14:18,  4.19s/it]

Inputs shape: torch.Size([1, 934])


 55%|█████▌    | 251/455 [17:38<14:26,  4.25s/it]

Inputs shape: torch.Size([1, 657])


 55%|█████▌    | 252/455 [17:42<14:26,  4.27s/it]

Inputs shape: torch.Size([1, 773])


 56%|█████▌    | 253/455 [17:47<14:26,  4.29s/it]

Inputs shape: torch.Size([1, 915])


 56%|█████▌    | 254/455 [17:51<14:26,  4.31s/it]

Inputs shape: torch.Size([1, 821])


 56%|█████▌    | 255/455 [17:55<14:19,  4.30s/it]

Inputs shape: torch.Size([1, 986])


 56%|█████▋    | 256/455 [18:00<14:19,  4.32s/it]

Inputs shape: torch.Size([1, 873])


 56%|█████▋    | 257/455 [18:04<14:18,  4.34s/it]

Inputs shape: torch.Size([1, 863])


 57%|█████▋    | 258/455 [18:08<14:14,  4.34s/it]

Inputs shape: torch.Size([1, 1041])


 57%|█████▋    | 259/455 [18:13<14:14,  4.36s/it]

Inputs shape: torch.Size([1, 1018])


 57%|█████▋    | 260/455 [18:16<13:19,  4.10s/it]

Inputs shape: torch.Size([1, 737])


 57%|█████▋    | 261/455 [18:20<13:24,  4.15s/it]

Inputs shape: torch.Size([1, 859])


 58%|█████▊    | 262/455 [18:25<13:27,  4.18s/it]

Inputs shape: torch.Size([1, 971])


 58%|█████▊    | 263/455 [18:29<13:36,  4.25s/it]

Inputs shape: torch.Size([1, 940])


 58%|█████▊    | 264/455 [18:34<13:40,  4.29s/it]

Inputs shape: torch.Size([1, 714])


 58%|█████▊    | 265/455 [18:38<13:36,  4.30s/it]

Inputs shape: torch.Size([1, 812])


 58%|█████▊    | 266/455 [18:42<13:34,  4.31s/it]

Inputs shape: torch.Size([1, 742])


 59%|█████▊    | 267/455 [18:46<13:30,  4.31s/it]

Inputs shape: torch.Size([1, 1080])


 59%|█████▉    | 268/455 [18:50<12:44,  4.09s/it]

Inputs shape: torch.Size([1, 1027])


 59%|█████▉    | 269/455 [18:54<13:00,  4.20s/it]

Inputs shape: torch.Size([1, 794])


 59%|█████▉    | 270/455 [18:59<13:05,  4.25s/it]

Inputs shape: torch.Size([1, 902])


 60%|█████▉    | 271/455 [19:03<13:08,  4.29s/it]

Inputs shape: torch.Size([1, 1047])


 60%|█████▉    | 272/455 [19:07<12:34,  4.12s/it]

Inputs shape: torch.Size([1, 1120])


 60%|██████    | 273/455 [19:11<12:46,  4.21s/it]

Inputs shape: torch.Size([1, 903])


 60%|██████    | 274/455 [19:16<12:47,  4.24s/it]

Inputs shape: torch.Size([1, 890])


 60%|██████    | 275/455 [19:20<12:53,  4.30s/it]

Inputs shape: torch.Size([1, 902])


 61%|██████    | 276/455 [19:24<12:51,  4.31s/it]

Inputs shape: torch.Size([1, 907])


 61%|██████    | 277/455 [19:29<12:51,  4.34s/it]

Inputs shape: torch.Size([1, 753])


 61%|██████    | 278/455 [19:32<11:35,  3.93s/it]

Inputs shape: torch.Size([1, 1021])


 61%|██████▏   | 279/455 [19:36<11:53,  4.06s/it]

Inputs shape: torch.Size([1, 888])


 62%|██████▏   | 280/455 [19:41<12:03,  4.13s/it]

Inputs shape: torch.Size([1, 899])


 62%|██████▏   | 281/455 [19:45<12:12,  4.21s/it]

Inputs shape: torch.Size([1, 792])


 62%|██████▏   | 282/455 [19:49<12:14,  4.25s/it]

Inputs shape: torch.Size([1, 753])


 62%|██████▏   | 283/455 [19:54<12:13,  4.26s/it]

Inputs shape: torch.Size([1, 839])


 62%|██████▏   | 284/455 [19:58<12:11,  4.28s/it]

Inputs shape: torch.Size([1, 998])


 63%|██████▎   | 285/455 [20:02<12:11,  4.31s/it]

Inputs shape: torch.Size([1, 691])


 63%|██████▎   | 286/455 [20:07<12:10,  4.32s/it]

Inputs shape: torch.Size([1, 990])


 63%|██████▎   | 287/455 [20:11<12:14,  4.37s/it]

Inputs shape: torch.Size([1, 935])


 63%|██████▎   | 288/455 [20:16<12:15,  4.40s/it]

Inputs shape: torch.Size([1, 888])


 64%|██████▎   | 289/455 [20:20<12:07,  4.38s/it]

Inputs shape: torch.Size([1, 804])


 64%|██████▎   | 290/455 [20:24<11:55,  4.34s/it]

Inputs shape: torch.Size([1, 985])


 64%|██████▍   | 291/455 [20:28<11:38,  4.26s/it]

Inputs shape: torch.Size([1, 698])


 64%|██████▍   | 292/455 [20:33<11:38,  4.28s/it]

Inputs shape: torch.Size([1, 1073])


 64%|██████▍   | 293/455 [20:36<11:04,  4.10s/it]

Inputs shape: torch.Size([1, 1109])


 65%|██████▍   | 294/455 [20:40<11:03,  4.12s/it]

Inputs shape: torch.Size([1, 774])


 65%|██████▍   | 295/455 [20:44<10:33,  3.96s/it]

Inputs shape: torch.Size([1, 801])


 65%|██████▌   | 296/455 [20:48<10:44,  4.06s/it]

Inputs shape: torch.Size([1, 1043])


 65%|██████▌   | 297/455 [20:53<10:54,  4.14s/it]

Inputs shape: torch.Size([1, 909])


 65%|██████▌   | 298/455 [20:56<10:24,  3.98s/it]

Inputs shape: torch.Size([1, 995])


 66%|██████▌   | 299/455 [21:01<10:42,  4.12s/it]

Inputs shape: torch.Size([1, 833])


 66%|██████▌   | 300/455 [21:05<10:35,  4.10s/it]

Inputs shape: torch.Size([1, 833])


 66%|██████▌   | 301/455 [21:09<10:43,  4.18s/it]

Inputs shape: torch.Size([1, 924])


 66%|██████▋   | 302/455 [21:13<10:49,  4.24s/it]

Inputs shape: torch.Size([1, 787])


 67%|██████▋   | 303/455 [21:18<10:49,  4.27s/it]

Inputs shape: torch.Size([1, 806])


 67%|██████▋   | 304/455 [21:22<10:50,  4.31s/it]

Inputs shape: torch.Size([1, 760])


 67%|██████▋   | 305/455 [21:27<10:48,  4.32s/it]

Inputs shape: torch.Size([1, 865])


 67%|██████▋   | 306/455 [21:31<10:45,  4.33s/it]

Inputs shape: torch.Size([1, 1106])


 67%|██████▋   | 307/455 [21:35<10:44,  4.35s/it]

Inputs shape: torch.Size([1, 686])


 68%|██████▊   | 308/455 [21:40<10:36,  4.33s/it]

Inputs shape: torch.Size([1, 723])


 68%|██████▊   | 309/455 [21:44<10:29,  4.31s/it]

Inputs shape: torch.Size([1, 911])


 68%|██████▊   | 310/455 [21:48<10:28,  4.34s/it]

Inputs shape: torch.Size([1, 807])


 68%|██████▊   | 311/455 [21:53<10:30,  4.38s/it]

Inputs shape: torch.Size([1, 891])


 69%|██████▊   | 312/455 [21:57<10:21,  4.35s/it]

Inputs shape: torch.Size([1, 718])


 69%|██████▉   | 313/455 [22:01<10:15,  4.34s/it]

Inputs shape: torch.Size([1, 869])


 69%|██████▉   | 314/455 [22:06<10:10,  4.33s/it]

Inputs shape: torch.Size([1, 880])


 69%|██████▉   | 315/455 [22:10<10:08,  4.34s/it]

Inputs shape: torch.Size([1, 833])


 69%|██████▉   | 316/455 [22:14<10:06,  4.36s/it]

Inputs shape: torch.Size([1, 900])


 70%|██████▉   | 317/455 [22:19<10:04,  4.38s/it]

Inputs shape: torch.Size([1, 773])


 70%|██████▉   | 318/455 [22:23<10:00,  4.39s/it]

Inputs shape: torch.Size([1, 932])


 70%|███████   | 319/455 [22:28<09:53,  4.37s/it]

Inputs shape: torch.Size([1, 826])


 70%|███████   | 320/455 [22:32<09:45,  4.33s/it]

Inputs shape: torch.Size([1, 1042])


 71%|███████   | 321/455 [22:36<09:45,  4.37s/it]

Inputs shape: torch.Size([1, 780])


 71%|███████   | 322/455 [22:41<09:41,  4.37s/it]

Inputs shape: torch.Size([1, 1013])


 71%|███████   | 323/455 [22:45<09:31,  4.33s/it]

Inputs shape: torch.Size([1, 1020])


 71%|███████   | 324/455 [22:49<09:31,  4.36s/it]

Inputs shape: torch.Size([1, 889])


 71%|███████▏  | 325/455 [22:54<09:24,  4.35s/it]

Inputs shape: torch.Size([1, 1274])


 72%|███████▏  | 326/455 [22:58<09:22,  4.36s/it]

Inputs shape: torch.Size([1, 849])


 72%|███████▏  | 327/455 [23:02<09:19,  4.37s/it]

Inputs shape: torch.Size([1, 818])


 72%|███████▏  | 328/455 [23:07<09:17,  4.39s/it]

Inputs shape: torch.Size([1, 814])


 72%|███████▏  | 329/455 [23:09<07:54,  3.76s/it]

Inputs shape: torch.Size([1, 997])


 73%|███████▎  | 330/455 [23:14<08:17,  3.98s/it]

Inputs shape: torch.Size([1, 980])


 73%|███████▎  | 331/455 [23:18<08:24,  4.07s/it]

Inputs shape: torch.Size([1, 1108])


 73%|███████▎  | 332/455 [23:22<08:10,  3.99s/it]

Inputs shape: torch.Size([1, 908])


 73%|███████▎  | 333/455 [23:26<08:21,  4.11s/it]

Inputs shape: torch.Size([1, 940])


 73%|███████▎  | 334/455 [23:30<08:28,  4.20s/it]

Inputs shape: torch.Size([1, 1149])


 74%|███████▎  | 335/455 [23:34<08:03,  4.03s/it]

Inputs shape: torch.Size([1, 868])


 74%|███████▍  | 336/455 [23:38<08:10,  4.12s/it]

Inputs shape: torch.Size([1, 833])


 74%|███████▍  | 337/455 [23:43<08:12,  4.17s/it]

Inputs shape: torch.Size([1, 1160])


 74%|███████▍  | 338/455 [23:47<08:14,  4.23s/it]

Inputs shape: torch.Size([1, 1087])


 75%|███████▍  | 339/455 [23:51<08:14,  4.26s/it]

Inputs shape: torch.Size([1, 1024])


 75%|███████▍  | 340/455 [23:56<08:14,  4.30s/it]

Inputs shape: torch.Size([1, 1111])


 75%|███████▍  | 341/455 [23:59<07:24,  3.90s/it]

Inputs shape: torch.Size([1, 890])


 75%|███████▌  | 342/455 [24:03<07:36,  4.04s/it]

Inputs shape: torch.Size([1, 1094])


 75%|███████▌  | 343/455 [24:07<07:41,  4.12s/it]

Inputs shape: torch.Size([1, 848])


 76%|███████▌  | 344/455 [24:12<07:34,  4.09s/it]

Inputs shape: torch.Size([1, 862])


 76%|███████▌  | 345/455 [24:16<07:38,  4.17s/it]

Inputs shape: torch.Size([1, 876])


 76%|███████▌  | 346/455 [24:20<07:40,  4.23s/it]

Inputs shape: torch.Size([1, 992])


 76%|███████▋  | 347/455 [24:25<07:41,  4.28s/it]

Inputs shape: torch.Size([1, 1081])


 76%|███████▋  | 348/455 [24:29<07:50,  4.39s/it]

Inputs shape: torch.Size([1, 880])


 77%|███████▋  | 349/455 [24:34<07:42,  4.37s/it]

Inputs shape: torch.Size([1, 852])


 77%|███████▋  | 350/455 [24:38<07:30,  4.29s/it]

Inputs shape: torch.Size([1, 1118])


 77%|███████▋  | 351/455 [24:42<07:13,  4.17s/it]

Inputs shape: torch.Size([1, 881])


 77%|███████▋  | 352/455 [24:46<07:16,  4.24s/it]

Inputs shape: torch.Size([1, 850])


 78%|███████▊  | 353/455 [24:50<07:16,  4.28s/it]

Inputs shape: torch.Size([1, 766])


 78%|███████▊  | 354/455 [24:55<07:12,  4.28s/it]

Inputs shape: torch.Size([1, 904])


 78%|███████▊  | 355/455 [24:59<07:10,  4.30s/it]

Inputs shape: torch.Size([1, 951])


 78%|███████▊  | 356/455 [25:03<07:07,  4.32s/it]

Inputs shape: torch.Size([1, 912])


 78%|███████▊  | 357/455 [25:08<07:05,  4.34s/it]

Inputs shape: torch.Size([1, 1064])


 79%|███████▊  | 358/455 [25:12<07:02,  4.36s/it]

Inputs shape: torch.Size([1, 883])


 79%|███████▉  | 359/455 [25:16<06:58,  4.36s/it]

Inputs shape: torch.Size([1, 861])


 79%|███████▉  | 360/455 [25:20<06:43,  4.25s/it]

Inputs shape: torch.Size([1, 952])


 79%|███████▉  | 361/455 [25:25<06:40,  4.26s/it]

Inputs shape: torch.Size([1, 907])


 80%|███████▉  | 362/455 [25:29<06:40,  4.31s/it]

Inputs shape: torch.Size([1, 841])


 80%|███████▉  | 363/455 [25:34<06:37,  4.32s/it]

Inputs shape: torch.Size([1, 898])


 80%|████████  | 364/455 [25:38<06:36,  4.35s/it]

Inputs shape: torch.Size([1, 663])


 80%|████████  | 365/455 [25:42<06:30,  4.34s/it]

Inputs shape: torch.Size([1, 1081])


 80%|████████  | 366/455 [25:47<06:26,  4.34s/it]

Inputs shape: torch.Size([1, 854])


 81%|████████  | 367/455 [25:51<06:20,  4.32s/it]

Inputs shape: torch.Size([1, 1052])


 81%|████████  | 368/455 [25:55<06:19,  4.36s/it]

Inputs shape: torch.Size([1, 973])


 81%|████████  | 369/455 [26:00<06:17,  4.39s/it]

Inputs shape: torch.Size([1, 998])


 81%|████████▏ | 370/455 [26:04<06:14,  4.41s/it]

Inputs shape: torch.Size([1, 1087])


 82%|████████▏ | 371/455 [26:08<05:56,  4.25s/it]

Inputs shape: torch.Size([1, 1266])


 82%|████████▏ | 372/455 [26:12<05:44,  4.15s/it]

Inputs shape: torch.Size([1, 988])


 82%|████████▏ | 373/455 [26:16<05:46,  4.22s/it]

Inputs shape: torch.Size([1, 967])


 82%|████████▏ | 374/455 [26:21<05:49,  4.31s/it]

Inputs shape: torch.Size([1, 830])


 82%|████████▏ | 375/455 [26:25<05:47,  4.35s/it]

Inputs shape: torch.Size([1, 936])


 83%|████████▎ | 376/455 [26:30<05:44,  4.37s/it]

Inputs shape: torch.Size([1, 943])


 83%|████████▎ | 377/455 [26:34<05:41,  4.38s/it]

Inputs shape: torch.Size([1, 746])


 83%|████████▎ | 378/455 [26:39<05:35,  4.36s/it]

Inputs shape: torch.Size([1, 797])


 83%|████████▎ | 379/455 [26:43<05:31,  4.37s/it]

Inputs shape: torch.Size([1, 991])


 84%|████████▎ | 380/455 [26:47<05:31,  4.42s/it]

Inputs shape: torch.Size([1, 1042])


 84%|████████▎ | 381/455 [26:51<05:05,  4.13s/it]

Inputs shape: torch.Size([1, 847])


 84%|████████▍ | 382/455 [26:55<05:09,  4.24s/it]

Inputs shape: torch.Size([1, 945])


 84%|████████▍ | 383/455 [27:00<05:08,  4.29s/it]

Inputs shape: torch.Size([1, 902])


 84%|████████▍ | 384/455 [27:04<05:04,  4.30s/it]

Inputs shape: torch.Size([1, 928])


 85%|████████▍ | 385/455 [27:09<05:02,  4.33s/it]

Inputs shape: torch.Size([1, 887])


 85%|████████▍ | 386/455 [27:13<05:01,  4.36s/it]

Inputs shape: torch.Size([1, 978])


 85%|████████▌ | 387/455 [27:16<04:23,  3.88s/it]

Inputs shape: torch.Size([1, 717])


 85%|████████▌ | 388/455 [27:20<04:30,  4.04s/it]

Inputs shape: torch.Size([1, 1070])


 85%|████████▌ | 389/455 [27:24<04:22,  3.98s/it]

Inputs shape: torch.Size([1, 1046])


 86%|████████▌ | 390/455 [27:28<04:25,  4.09s/it]

Inputs shape: torch.Size([1, 1154])


 86%|████████▌ | 391/455 [27:33<04:28,  4.19s/it]

Inputs shape: torch.Size([1, 936])


 86%|████████▌ | 392/455 [27:37<04:30,  4.29s/it]

Inputs shape: torch.Size([1, 854])


 86%|████████▋ | 393/455 [27:42<04:27,  4.31s/it]

Inputs shape: torch.Size([1, 1045])


 87%|████████▋ | 394/455 [27:46<04:25,  4.36s/it]

Inputs shape: torch.Size([1, 998])


 87%|████████▋ | 395/455 [27:50<04:20,  4.35s/it]

Inputs shape: torch.Size([1, 874])


 87%|████████▋ | 396/455 [27:53<03:52,  3.95s/it]

Inputs shape: torch.Size([1, 685])


 87%|████████▋ | 397/455 [27:58<03:54,  4.04s/it]

Inputs shape: torch.Size([1, 1029])


 87%|████████▋ | 398/455 [28:02<03:57,  4.17s/it]

Inputs shape: torch.Size([1, 778])


 88%|████████▊ | 399/455 [28:06<03:56,  4.22s/it]

Inputs shape: torch.Size([1, 829])


 88%|████████▊ | 400/455 [28:11<03:54,  4.27s/it]

Inputs shape: torch.Size([1, 857])


 88%|████████▊ | 401/455 [28:15<03:51,  4.28s/it]

Inputs shape: torch.Size([1, 893])


 88%|████████▊ | 402/455 [28:19<03:47,  4.29s/it]

Inputs shape: torch.Size([1, 964])


 89%|████████▊ | 403/455 [28:24<03:45,  4.33s/it]

Inputs shape: torch.Size([1, 804])


 89%|████████▉ | 404/455 [28:28<03:41,  4.34s/it]

Inputs shape: torch.Size([1, 743])


 89%|████████▉ | 405/455 [28:33<03:37,  4.34s/it]

Inputs shape: torch.Size([1, 999])


 89%|████████▉ | 406/455 [28:37<03:33,  4.36s/it]

Inputs shape: torch.Size([1, 1065])


 89%|████████▉ | 407/455 [28:41<03:29,  4.37s/it]

Inputs shape: torch.Size([1, 924])


 90%|████████▉ | 408/455 [28:46<03:24,  4.35s/it]

Inputs shape: torch.Size([1, 926])


 90%|████████▉ | 409/455 [28:50<03:20,  4.37s/it]

Inputs shape: torch.Size([1, 912])


 90%|█████████ | 410/455 [28:54<03:14,  4.33s/it]

Inputs shape: torch.Size([1, 981])


 90%|█████████ | 411/455 [28:59<03:10,  4.34s/it]

Inputs shape: torch.Size([1, 660])


 91%|█████████ | 412/455 [29:03<03:04,  4.30s/it]

Inputs shape: torch.Size([1, 895])


 91%|█████████ | 413/455 [29:07<03:00,  4.31s/it]

Inputs shape: torch.Size([1, 1041])


 91%|█████████ | 414/455 [29:12<02:57,  4.32s/it]

Inputs shape: torch.Size([1, 791])


 91%|█████████ | 415/455 [29:16<02:53,  4.33s/it]

Inputs shape: torch.Size([1, 896])


 91%|█████████▏| 416/455 [29:20<02:49,  4.34s/it]

Inputs shape: torch.Size([1, 954])


 92%|█████████▏| 417/455 [29:24<02:39,  4.20s/it]

Inputs shape: torch.Size([1, 970])


 92%|█████████▏| 418/455 [29:29<02:37,  4.26s/it]

Inputs shape: torch.Size([1, 1002])


 92%|█████████▏| 419/455 [29:33<02:33,  4.27s/it]

Inputs shape: torch.Size([1, 944])


 92%|█████████▏| 420/455 [29:37<02:30,  4.30s/it]

Inputs shape: torch.Size([1, 944])


 93%|█████████▎| 421/455 [29:42<02:26,  4.31s/it]

Inputs shape: torch.Size([1, 1082])


 93%|█████████▎| 422/455 [29:46<02:23,  4.35s/it]

Inputs shape: torch.Size([1, 1051])


 93%|█████████▎| 423/455 [29:50<02:19,  4.37s/it]

Inputs shape: torch.Size([1, 848])


 93%|█████████▎| 424/455 [29:54<02:11,  4.25s/it]

Inputs shape: torch.Size([1, 957])


 93%|█████████▎| 425/455 [29:59<02:07,  4.27s/it]

Inputs shape: torch.Size([1, 1019])


 94%|█████████▎| 426/455 [30:02<01:59,  4.11s/it]

Inputs shape: torch.Size([1, 937])


 94%|█████████▍| 427/455 [30:07<01:57,  4.18s/it]

Inputs shape: torch.Size([1, 841])


 94%|█████████▍| 428/455 [30:11<01:49,  4.05s/it]

Inputs shape: torch.Size([1, 871])


 94%|█████████▍| 429/455 [30:15<01:47,  4.14s/it]

Inputs shape: torch.Size([1, 871])


 95%|█████████▍| 430/455 [30:19<01:44,  4.19s/it]

Inputs shape: torch.Size([1, 995])


 95%|█████████▍| 431/455 [30:23<01:34,  3.92s/it]

Inputs shape: torch.Size([1, 895])


 95%|█████████▍| 432/455 [30:27<01:32,  4.04s/it]

Inputs shape: torch.Size([1, 812])


 95%|█████████▌| 433/455 [30:31<01:31,  4.14s/it]

Inputs shape: torch.Size([1, 827])


 95%|█████████▌| 434/455 [30:36<01:28,  4.22s/it]

Inputs shape: torch.Size([1, 822])


 96%|█████████▌| 435/455 [30:40<01:25,  4.26s/it]

Inputs shape: torch.Size([1, 1133])


 96%|█████████▌| 436/455 [30:43<01:15,  3.96s/it]

Inputs shape: torch.Size([1, 915])


 96%|█████████▌| 437/455 [30:46<01:06,  3.67s/it]

Inputs shape: torch.Size([1, 960])


 96%|█████████▋| 438/455 [30:50<01:04,  3.82s/it]

Inputs shape: torch.Size([1, 979])


 96%|█████████▋| 439/455 [30:55<01:04,  4.01s/it]

Inputs shape: torch.Size([1, 993])


 97%|█████████▋| 440/455 [30:59<01:02,  4.14s/it]

Inputs shape: torch.Size([1, 941])


 97%|█████████▋| 441/455 [31:04<00:58,  4.21s/it]

Inputs shape: torch.Size([1, 906])


 97%|█████████▋| 442/455 [31:08<00:53,  4.11s/it]

Inputs shape: torch.Size([1, 815])


 97%|█████████▋| 443/455 [31:12<00:49,  4.14s/it]

Inputs shape: torch.Size([1, 833])


 98%|█████████▊| 444/455 [31:16<00:46,  4.20s/it]

Inputs shape: torch.Size([1, 768])


 98%|█████████▊| 445/455 [31:21<00:42,  4.27s/it]

Inputs shape: torch.Size([1, 932])


 98%|█████████▊| 446/455 [31:24<00:36,  4.02s/it]

Inputs shape: torch.Size([1, 699])


 98%|█████████▊| 447/455 [31:28<00:32,  4.10s/it]

Inputs shape: torch.Size([1, 788])


 98%|█████████▊| 448/455 [31:33<00:29,  4.15s/it]

Inputs shape: torch.Size([1, 789])


 99%|█████████▊| 449/455 [31:37<00:25,  4.18s/it]

Inputs shape: torch.Size([1, 858])


 99%|█████████▉| 450/455 [31:41<00:21,  4.21s/it]

Inputs shape: torch.Size([1, 992])


 99%|█████████▉| 451/455 [31:45<00:17,  4.27s/it]

Inputs shape: torch.Size([1, 778])


 99%|█████████▉| 452/455 [31:50<00:12,  4.30s/it]

Inputs shape: torch.Size([1, 1014])


100%|█████████▉| 453/455 [31:54<00:08,  4.32s/it]

Inputs shape: torch.Size([1, 899])


100%|█████████▉| 454/455 [31:58<00:04,  4.30s/it]

Inputs shape: torch.Size([1, 831])


100%|██████████| 455/455 [32:03<00:00,  4.23s/it]


In [None]:
bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.3720423190369473


In [None]:
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.36443422599813075
ROUGE-2: 0.12167336133694615
ROUGE-L: 0.2351907453172963


In [None]:
def distinct(seqs):
    """ Calculate intra/inter distinct 1/2. """
    """ Recuperado de https://github.com/PaddlePaddle/models/blob/release/1.6/PaddleNLP/Research/Dialogue-PLATO/plato/metrics/metrics.py"""
    batch_size = len(seqs)
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5))
        intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5)
    inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2

In [None]:
nltk.download('punkt_tab')

tokenized_text = [nltk.word_tokenize(text) for text in predictions]
resultado = distinct(tokenized_text)
print(resultado)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


(0.7221576114901389, 0.9508190513729, 0.0907063749434587, 0.28086002097908586)


In [None]:
bertscore = load("bertscore")
results = bertscore.compute(predictions=predictions, references=references, lang="en", model_type="distilbert-base-uncased")

In [None]:
print(results.keys())
print(f"F1: {mean(results['f1'])}")
print(f"Precision: {mean(results['precision'])}")
print(f"Recall: {mean(results['recall'])}")

dict_keys(['precision', 'recall', 'f1', 'hashcode'])
F1: 0.8185031487391545
Precision: 0.8018398773539197
Recall: 0.836233556663597
