# Descarga de los datos formateados

In [None]:
%%capture
!pip install gdown

In [None]:
import gdown

files_to_download = {
    "formatted_train.jsonl": "13p2I_-pxXTQRjddJDd26bwAERjAgMqbe",
    "formatted_test.jsonl": "10SQP_HikNR0mcgTJtT9BVczSk78aGziH",
    "formatted_validation.jsonl": "1qkfNA5tm8jFv_whMpNWAgVC_C9YOgo-V"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=13p2I_-pxXTQRjddJDd26bwAERjAgMqbe
From (redirected): https://drive.google.com/uc?id=13p2I_-pxXTQRjddJDd26bwAERjAgMqbe&confirm=t&uuid=5dd412a4-44cf-49cb-9aa2-8f72068e64e9
To: /content/formatted_train.jsonl
100%|██████████| 176M/176M [00:04<00:00, 37.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=10SQP_HikNR0mcgTJtT9BVczSk78aGziH
To: /content/formatted_test.jsonl
100%|██████████| 8.00M/8.00M [00:00<00:00, 37.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1qkfNA5tm8jFv_whMpNWAgVC_C9YOgo-V
To: /content/formatted_validation.jsonl
100%|██████████| 17.6M/17.6M [00:00<00:00, 64.0MB/s]


# Implementación del modelo

In [None]:
%%capture
!pip install unsloth "xformers==0.0.28.post2"
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets

In [None]:
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer, AutoTokenizer
from peft import AutoPeftModelForCausalLM
import json

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
with open("formatted_train.jsonl", "r") as f:
    data = [json.loads(line)["interaction"] for line in f]

dataset = Dataset.from_dict({"conversations": data})

In [None]:
max_seq_length = 2048

def initialize_model_and_tokenizer():
    base_model_name = "unsloth/Phi-3.5-mini-instruct"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None
    )
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="phi-3.5"
    )

    return model, tokenizer

In [None]:
def preprocess_dataset(dataset, tokenizer):
    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    formatted_dataset = dataset.map(formatting_prompts_func, batched=True)
    return formatted_dataset

In [None]:
used_seed = 3407

# Segundo entrenamiento

En este entrenamiento se usarán los mismos paramétros de entrenamiento que antes, pero con la diferencia de que usaré la data que tenía con la primera forma de parseo, en donde entrego información de películas.

In [None]:
files_to_download = {
    "formatted_train.json": "1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_",
    "formatted_test.json": "1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m",
    "formatted_validation.json": "1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_
From (redirected): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_&confirm=t&uuid=f5c05b0c-ad64-44b9-9d97-249041acc97f
To: /content/formatted_train.json
100%|██████████| 237M/237M [00:03<00:00, 68.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m
To: /content/formatted_test.json
100%|██████████| 10.8M/10.8M [00:00<00:00, 41.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm
To: /content/formatted_validation.json
100%|██████████| 23.7M/23.7M [00:00<00:00, 54.4MB/s]


In [None]:
with open("formatted_train.json", "r") as f:
    data = json.load(f)

dataset_second_train = Dataset.from_dict({"conversations": data})

In [None]:
training_configuration = {
    "lr_scheduler_type": "linear",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "max_steps": 200,
    "output_dir": "second_train_output",
    "learning_rate": 2e-4
}

In [None]:
model_second_train, tokenizer_second_train = initialize_model_and_tokenizer()

==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.26G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/140 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.37k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Unsloth 2024.12.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [None]:
formatted_dataset_second_train = preprocess_dataset(dataset_second_train, tokenizer_second_train)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
trainer = SFTTrainer(
    model=model_second_train,
    tokenizer=tokenizer_second_train,
    train_dataset=formatted_dataset_second_train,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        learning_rate=5e-5,
        lr_scheduler_type=training_configuration["lr_scheduler_type"],
        per_device_train_batch_size=training_configuration["per_device_train_batch_size"],
        gradient_accumulation_steps=training_configuration["gradient_accumulation_steps"],
        num_train_epochs=1,
        max_steps=training_configuration["max_steps"],
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        warmup_steps=5,
        output_dir=training_configuration["output_dir"],
        seed=used_seed,
    )
)

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [None]:
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.748 GB.
2.285 GB of memory reserved.


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 50,000 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 200
 "-____-"     Number of trainable parameters = 29,884,416
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss
1,1.9122
2,1.9333
3,1.8267
4,1.9899
5,1.94
6,1.809
7,1.8198
8,1.7919
9,1.7307
10,1.7024


In [None]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

3475.3134 seconds used for training.
57.92 minutes used for training.
Peak reserved memory = 3.07 GB.
Peak reserved memory for training = 0.785 GB.
Peak reserved memory % of max memory = 20.816 %.
Peak reserved memory for training % of max memory = 5.323 %.


In [None]:
model_second_train.save_pretrained("lora_model_second_train")
tokenizer_second_train.save_pretrained("lora_model_second_train")

('lora_model_second_train/tokenizer_config.json',
 'lora_model_second_train/special_tokens_map.json',
 'lora_model_second_train/tokenizer.model',
 'lora_model_second_train/added_tokens.json',
 'lora_model_second_train/tokenizer.json')

In [None]:
!zip -r lora_model_second_train.zip lora_model_second_train/

  adding: lora_model_second_train/ (stored 0%)
  adding: lora_model_second_train/tokenizer_config.json (deflated 83%)
  adding: lora_model_second_train/added_tokens.json (deflated 62%)
  adding: lora_model_second_train/special_tokens_map.json (deflated 76%)
  adding: lora_model_second_train/tokenizer.model (deflated 55%)
  adding: lora_model_second_train/tokenizer.json (deflated 85%)
  adding: lora_model_second_train/adapter_model.safetensors (deflated 8%)
  adding: lora_model_second_train/README.md (deflated 66%)
  adding: lora_model_second_train/adapter_config.json (deflated 54%)


# Testeo segundo modelo

In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model_second_train",
    max_seq_length = max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.

In [None]:
tokenizer = get_chat_template(
    tokenizer,
    chat_template="phi-3.5",
)

In [None]:
with open("formatted_test.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [None]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [None]:
%%capture
!pip install tqdm

In [None]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [None]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])


  0%|          | 0/455 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Inputs shape: torch.Size([1, 1039])


  0%|          | 1/455 [00:24<3:05:41, 24.54s/it]

Inputs shape: torch.Size([1, 1106])


  0%|          | 2/455 [00:46<2:52:58, 22.91s/it]

Inputs shape: torch.Size([1, 1004])


  1%|          | 3/455 [00:59<2:20:48, 18.69s/it]

Inputs shape: torch.Size([1, 867])


  1%|          | 4/455 [01:10<1:56:28, 15.50s/it]

Inputs shape: torch.Size([1, 906])


  1%|          | 5/455 [01:20<1:40:28, 13.40s/it]

Inputs shape: torch.Size([1, 882])


  1%|▏         | 6/455 [01:26<1:22:43, 11.05s/it]

Inputs shape: torch.Size([1, 1208])


  2%|▏         | 7/455 [01:36<1:19:19, 10.62s/it]

Inputs shape: torch.Size([1, 1002])


  2%|▏         | 8/455 [01:47<1:19:47, 10.71s/it]

Inputs shape: torch.Size([1, 983])


  2%|▏         | 9/455 [01:53<1:09:52,  9.40s/it]

Inputs shape: torch.Size([1, 1608])


  2%|▏         | 10/455 [02:03<1:09:09,  9.32s/it]

Inputs shape: torch.Size([1, 770])


  2%|▏         | 11/455 [02:15<1:16:36, 10.35s/it]

Inputs shape: torch.Size([1, 958])


  3%|▎         | 12/455 [02:24<1:12:49,  9.86s/it]

Inputs shape: torch.Size([1, 932])


  3%|▎         | 13/455 [02:34<1:12:33,  9.85s/it]

Inputs shape: torch.Size([1, 1062])


  3%|▎         | 14/455 [02:45<1:14:54, 10.19s/it]

Inputs shape: torch.Size([1, 1315])


  3%|▎         | 15/455 [02:51<1:07:00,  9.14s/it]

Inputs shape: torch.Size([1, 866])


  4%|▎         | 16/455 [03:00<1:05:30,  8.95s/it]

Inputs shape: torch.Size([1, 1369])


  4%|▎         | 17/455 [03:11<1:10:23,  9.64s/it]

Inputs shape: torch.Size([1, 897])


  4%|▍         | 18/455 [03:19<1:05:38,  9.01s/it]

Inputs shape: torch.Size([1, 1127])


  4%|▍         | 19/455 [03:26<1:01:16,  8.43s/it]

Inputs shape: torch.Size([1, 1086])


  4%|▍         | 20/455 [03:37<1:06:29,  9.17s/it]

Inputs shape: torch.Size([1, 1028])


  5%|▍         | 21/455 [03:46<1:07:09,  9.28s/it]

Inputs shape: torch.Size([1, 1302])


  5%|▍         | 22/455 [03:53<1:01:40,  8.55s/it]

Inputs shape: torch.Size([1, 1080])


  5%|▌         | 23/455 [04:04<1:05:43,  9.13s/it]

Inputs shape: torch.Size([1, 1282])


  5%|▌         | 24/455 [04:14<1:09:00,  9.61s/it]

Inputs shape: torch.Size([1, 1100])


  5%|▌         | 25/455 [04:21<1:02:36,  8.74s/it]

Inputs shape: torch.Size([1, 1054])


  6%|▌         | 26/455 [04:30<1:03:18,  8.85s/it]

Inputs shape: torch.Size([1, 1139])


  6%|▌         | 27/455 [04:41<1:08:14,  9.57s/it]

Inputs shape: torch.Size([1, 1041])


  6%|▌         | 28/455 [04:49<1:02:56,  8.85s/it]

Inputs shape: torch.Size([1, 932])


  6%|▋         | 29/455 [04:56<1:00:07,  8.47s/it]

Inputs shape: torch.Size([1, 1333])


  7%|▋         | 30/455 [05:08<1:06:11,  9.34s/it]

Inputs shape: torch.Size([1, 1142])


  7%|▋         | 31/455 [05:16<1:04:08,  9.08s/it]

Inputs shape: torch.Size([1, 849])


  7%|▋         | 32/455 [05:23<59:09,  8.39s/it]  

Inputs shape: torch.Size([1, 1388])


  7%|▋         | 33/455 [05:34<1:04:42,  9.20s/it]

Inputs shape: torch.Size([1, 1311])


  7%|▋         | 34/455 [05:44<1:05:39,  9.36s/it]

Inputs shape: torch.Size([1, 807])


  8%|▊         | 35/455 [05:50<59:38,  8.52s/it]  

Inputs shape: torch.Size([1, 873])


  8%|▊         | 36/455 [06:00<1:02:25,  8.94s/it]

Inputs shape: torch.Size([1, 1005])


  8%|▊         | 37/455 [06:11<1:06:44,  9.58s/it]

Inputs shape: torch.Size([1, 1174])


  8%|▊         | 38/455 [06:18<1:00:39,  8.73s/it]

Inputs shape: torch.Size([1, 924])


  9%|▊         | 39/455 [06:27<1:00:48,  8.77s/it]

Inputs shape: torch.Size([1, 1217])


  9%|▉         | 40/455 [06:38<1:05:59,  9.54s/it]

Inputs shape: torch.Size([1, 892])


  9%|▉         | 41/455 [06:45<1:01:01,  8.84s/it]

Inputs shape: torch.Size([1, 909])


  9%|▉         | 42/455 [06:53<58:16,  8.47s/it]  

Inputs shape: torch.Size([1, 1094])


  9%|▉         | 43/455 [07:04<1:03:46,  9.29s/it]

Inputs shape: torch.Size([1, 1080])


 10%|▉         | 44/455 [07:13<1:02:07,  9.07s/it]

Inputs shape: torch.Size([1, 1263])


 10%|▉         | 45/455 [07:20<57:50,  8.47s/it]  

Inputs shape: torch.Size([1, 1047])


 10%|█         | 46/455 [07:31<1:02:37,  9.19s/it]

Inputs shape: torch.Size([1, 1240])


 10%|█         | 47/455 [07:40<1:03:46,  9.38s/it]

Inputs shape: torch.Size([1, 766])


 11%|█         | 48/455 [07:47<58:07,  8.57s/it]  

Inputs shape: torch.Size([1, 1214])


 11%|█         | 49/455 [07:57<1:01:00,  9.02s/it]

Inputs shape: torch.Size([1, 1305])


 11%|█         | 50/455 [08:08<1:05:27,  9.70s/it]

Inputs shape: torch.Size([1, 1108])


 11%|█         | 51/455 [08:15<59:24,  8.82s/it]  

Inputs shape: torch.Size([1, 942])


 11%|█▏        | 52/455 [08:24<58:57,  8.78s/it]

Inputs shape: torch.Size([1, 1064])


 12%|█▏        | 53/455 [08:35<1:03:49,  9.53s/it]

Inputs shape: torch.Size([1, 962])


 12%|█▏        | 54/455 [08:43<1:00:12,  9.01s/it]

Inputs shape: torch.Size([1, 1129])


 12%|█▏        | 55/455 [08:51<57:01,  8.55s/it]  

Inputs shape: torch.Size([1, 888])


 12%|█▏        | 56/455 [09:02<1:02:08,  9.34s/it]

Inputs shape: torch.Size([1, 863])


 13%|█▎        | 57/455 [09:11<1:00:58,  9.19s/it]

Inputs shape: torch.Size([1, 929])


 13%|█▎        | 58/455 [09:17<55:39,  8.41s/it]  

Inputs shape: torch.Size([1, 906])


 13%|█▎        | 59/455 [09:28<59:52,  9.07s/it]

Inputs shape: torch.Size([1, 1021])


 13%|█▎        | 60/455 [09:38<1:01:46,  9.38s/it]

Inputs shape: torch.Size([1, 1287])


 13%|█▎        | 61/455 [09:45<56:36,  8.62s/it]  

Inputs shape: torch.Size([1, 1118])


 14%|█▎        | 62/455 [09:54<58:34,  8.94s/it]

Inputs shape: torch.Size([1, 969])


 14%|█▍        | 63/455 [10:06<1:03:03,  9.65s/it]

Inputs shape: torch.Size([1, 996])


 14%|█▍        | 64/455 [10:12<57:01,  8.75s/it]  

Inputs shape: torch.Size([1, 1531])


 14%|█▍        | 65/455 [10:21<56:36,  8.71s/it]

Inputs shape: torch.Size([1, 901])


 15%|█▍        | 66/455 [10:32<1:01:14,  9.45s/it]

Inputs shape: torch.Size([1, 754])


 15%|█▍        | 67/455 [10:40<57:53,  8.95s/it]  

Inputs shape: torch.Size([1, 1169])


 15%|█▍        | 68/455 [10:47<54:15,  8.41s/it]

Inputs shape: torch.Size([1, 954])


 15%|█▌        | 69/455 [10:58<58:51,  9.15s/it]

Inputs shape: torch.Size([1, 1028])


 15%|█▌        | 70/455 [11:07<59:07,  9.21s/it]

Inputs shape: torch.Size([1, 1080])


 16%|█▌        | 71/455 [11:14<54:00,  8.44s/it]

Inputs shape: torch.Size([1, 875])


 16%|█▌        | 72/455 [11:24<57:01,  8.93s/it]

Inputs shape: torch.Size([1, 968])


 16%|█▌        | 73/455 [11:35<1:00:16,  9.47s/it]

Inputs shape: torch.Size([1, 1068])


 16%|█▋        | 74/455 [11:41<54:46,  8.63s/it]  

Inputs shape: torch.Size([1, 1051])


 16%|█▋        | 75/455 [11:51<55:37,  8.78s/it]

Inputs shape: torch.Size([1, 798])


 17%|█▋        | 76/455 [12:02<59:52,  9.48s/it]

Inputs shape: torch.Size([1, 1215])


 17%|█▋        | 77/455 [12:09<55:20,  8.79s/it]

Inputs shape: torch.Size([1, 854])


 17%|█▋        | 78/455 [12:16<52:47,  8.40s/it]

Inputs shape: torch.Size([1, 819])


 17%|█▋        | 79/455 [12:27<57:37,  9.20s/it]

Inputs shape: torch.Size([1, 1200])


 18%|█▊        | 80/455 [12:36<56:31,  9.04s/it]

Inputs shape: torch.Size([1, 965])


 18%|█▊        | 81/455 [12:43<52:02,  8.35s/it]

Inputs shape: torch.Size([1, 1061])


 18%|█▊        | 82/455 [12:53<55:57,  9.00s/it]

Inputs shape: torch.Size([1, 986])


 18%|█▊        | 83/455 [13:03<57:54,  9.34s/it]

Inputs shape: torch.Size([1, 940])


 18%|█▊        | 84/455 [13:10<52:29,  8.49s/it]

Inputs shape: torch.Size([1, 916])


 19%|█▊        | 85/455 [13:19<53:37,  8.69s/it]

Inputs shape: torch.Size([1, 1039])


 19%|█▉        | 86/455 [13:30<58:05,  9.45s/it]

Inputs shape: torch.Size([1, 1179])


 19%|█▉        | 87/455 [13:37<53:20,  8.70s/it]

Inputs shape: torch.Size([1, 833])


 19%|█▉        | 88/455 [13:45<51:15,  8.38s/it]

Inputs shape: torch.Size([1, 1100])


 20%|█▉        | 89/455 [13:56<56:05,  9.19s/it]

Inputs shape: torch.Size([1, 1003])


 20%|█▉        | 90/455 [14:04<54:32,  8.97s/it]

Inputs shape: torch.Size([1, 1028])


 20%|██        | 91/455 [14:11<50:20,  8.30s/it]

Inputs shape: torch.Size([1, 946])


 20%|██        | 92/455 [14:22<54:13,  8.96s/it]

Inputs shape: torch.Size([1, 1168])


 20%|██        | 93/455 [14:32<56:15,  9.32s/it]

Inputs shape: torch.Size([1, 1048])


 21%|██        | 94/455 [14:38<51:03,  8.49s/it]

Inputs shape: torch.Size([1, 1116])


 21%|██        | 95/455 [14:48<52:17,  8.72s/it]

Inputs shape: torch.Size([1, 1129])


 21%|██        | 96/455 [14:59<56:34,  9.46s/it]

Inputs shape: torch.Size([1, 794])


 21%|██▏       | 97/455 [15:06<51:33,  8.64s/it]

Inputs shape: torch.Size([1, 1148])


 22%|██▏       | 98/455 [15:13<49:49,  8.37s/it]

Inputs shape: torch.Size([1, 1140])


 22%|██▏       | 99/455 [15:24<54:24,  9.17s/it]

Inputs shape: torch.Size([1, 1081])


 22%|██▏       | 100/455 [15:33<53:26,  9.03s/it]

Inputs shape: torch.Size([1, 1106])


 22%|██▏       | 101/455 [15:40<49:04,  8.32s/it]

Inputs shape: torch.Size([1, 855])


 22%|██▏       | 102/455 [15:50<52:53,  8.99s/it]

Inputs shape: torch.Size([1, 1094])


 23%|██▎       | 103/455 [16:00<54:31,  9.30s/it]

Inputs shape: torch.Size([1, 1175])


 23%|██▎       | 104/455 [16:07<49:35,  8.48s/it]

Inputs shape: torch.Size([1, 1084])


 23%|██▎       | 105/455 [16:16<50:55,  8.73s/it]

Inputs shape: torch.Size([1, 937])


 23%|██▎       | 106/455 [16:27<55:00,  9.46s/it]

Inputs shape: torch.Size([1, 1025])


 24%|██▎       | 107/455 [16:34<50:08,  8.64s/it]

Inputs shape: torch.Size([1, 1157])


 24%|██▎       | 108/455 [16:42<48:12,  8.34s/it]

Inputs shape: torch.Size([1, 1293])


 24%|██▍       | 109/455 [16:53<52:52,  9.17s/it]

Inputs shape: torch.Size([1, 1070])


 24%|██▍       | 110/455 [17:01<51:35,  8.97s/it]

Inputs shape: torch.Size([1, 984])


 24%|██▍       | 111/455 [17:08<47:41,  8.32s/it]

Inputs shape: torch.Size([1, 861])


 25%|██▍       | 112/455 [17:19<51:57,  9.09s/it]

Inputs shape: torch.Size([1, 906])


 25%|██▍       | 113/455 [17:29<53:07,  9.32s/it]

Inputs shape: torch.Size([1, 1119])


 25%|██▌       | 114/455 [17:36<48:32,  8.54s/it]

Inputs shape: torch.Size([1, 1089])


 25%|██▌       | 115/455 [17:45<50:40,  8.94s/it]

Inputs shape: torch.Size([1, 1076])


 25%|██▌       | 116/455 [17:57<54:07,  9.58s/it]

Inputs shape: torch.Size([1, 1180])


 26%|██▌       | 117/455 [18:03<49:07,  8.72s/it]

Inputs shape: torch.Size([1, 980])


 26%|██▌       | 118/455 [18:12<48:48,  8.69s/it]

Inputs shape: torch.Size([1, 1080])


 26%|██▌       | 119/455 [18:23<52:53,  9.45s/it]

Inputs shape: torch.Size([1, 1071])


 26%|██▋       | 120/455 [18:31<49:32,  8.87s/it]

Inputs shape: torch.Size([1, 1161])


 27%|██▋       | 121/455 [18:38<46:35,  8.37s/it]

Inputs shape: torch.Size([1, 949])


 27%|██▋       | 122/455 [18:49<50:51,  9.16s/it]

Inputs shape: torch.Size([1, 908])


 27%|██▋       | 123/455 [18:58<50:31,  9.13s/it]

Inputs shape: torch.Size([1, 1010])


 27%|██▋       | 124/455 [19:05<46:20,  8.40s/it]

Inputs shape: torch.Size([1, 1171])


 27%|██▋       | 125/455 [19:15<49:37,  9.02s/it]

Inputs shape: torch.Size([1, 1031])


 28%|██▊       | 126/455 [19:25<51:32,  9.40s/it]

Inputs shape: torch.Size([1, 812])


 28%|██▊       | 127/455 [19:32<46:32,  8.51s/it]

Inputs shape: torch.Size([1, 1038])


 28%|██▊       | 128/455 [19:41<47:08,  8.65s/it]

Inputs shape: torch.Size([1, 850])


 28%|██▊       | 129/455 [19:52<51:01,  9.39s/it]

Inputs shape: torch.Size([1, 1088])


 29%|██▊       | 130/455 [19:59<47:07,  8.70s/it]

Inputs shape: torch.Size([1, 859])


 29%|██▉       | 131/455 [20:06<44:39,  8.27s/it]

Inputs shape: torch.Size([1, 1073])


 29%|██▉       | 132/455 [20:17<48:59,  9.10s/it]

Inputs shape: torch.Size([1, 1059])


 29%|██▉       | 133/455 [20:26<48:28,  9.03s/it]

Inputs shape: torch.Size([1, 1140])


 29%|██▉       | 134/455 [20:33<44:21,  8.29s/it]

Inputs shape: torch.Size([1, 959])


 30%|██▉       | 135/455 [20:43<47:32,  8.91s/it]

Inputs shape: torch.Size([1, 982])


 30%|██▉       | 136/455 [20:53<49:35,  9.33s/it]

Inputs shape: torch.Size([1, 1119])


 30%|███       | 137/455 [21:00<44:56,  8.48s/it]

Inputs shape: torch.Size([1, 855])


 30%|███       | 138/455 [21:09<45:15,  8.57s/it]

Inputs shape: torch.Size([1, 1249])


 31%|███       | 139/455 [21:20<49:08,  9.33s/it]

Inputs shape: torch.Size([1, 1016])


 31%|███       | 140/455 [21:27<45:46,  8.72s/it]

Inputs shape: torch.Size([1, 979])


 31%|███       | 141/455 [21:34<43:24,  8.30s/it]

Inputs shape: torch.Size([1, 913])


 31%|███       | 142/455 [21:45<47:14,  9.06s/it]

Inputs shape: torch.Size([1, 947])


 31%|███▏      | 143/455 [21:54<47:09,  9.07s/it]

Inputs shape: torch.Size([1, 1082])


 32%|███▏      | 144/455 [22:01<43:06,  8.32s/it]

Inputs shape: torch.Size([1, 975])


 32%|███▏      | 145/455 [22:11<45:56,  8.89s/it]

Inputs shape: torch.Size([1, 1026])


 32%|███▏      | 146/455 [22:22<48:14,  9.37s/it]

Inputs shape: torch.Size([1, 1234])


 32%|███▏      | 147/455 [22:28<43:54,  8.56s/it]

Inputs shape: torch.Size([1, 990])


 33%|███▎      | 148/455 [22:37<44:44,  8.75s/it]

Inputs shape: torch.Size([1, 1072])


 33%|███▎      | 149/455 [22:49<48:17,  9.47s/it]

Inputs shape: torch.Size([1, 1135])


 33%|███▎      | 150/455 [22:56<44:24,  8.74s/it]

Inputs shape: torch.Size([1, 1156])


 33%|███▎      | 151/455 [23:03<43:04,  8.50s/it]

Inputs shape: torch.Size([1, 1189])


 33%|███▎      | 152/455 [23:15<46:59,  9.31s/it]

Inputs shape: torch.Size([1, 1293])


 34%|███▎      | 153/455 [23:23<45:34,  9.06s/it]

Inputs shape: torch.Size([1, 885])


 34%|███▍      | 154/455 [23:30<42:15,  8.42s/it]

Inputs shape: torch.Size([1, 1085])


 34%|███▍      | 155/455 [23:41<45:33,  9.11s/it]

Inputs shape: torch.Size([1, 1138])


 34%|███▍      | 156/455 [23:51<46:33,  9.34s/it]

Inputs shape: torch.Size([1, 1128])


 35%|███▍      | 157/455 [23:57<42:26,  8.54s/it]

Inputs shape: torch.Size([1, 1240])


 35%|███▍      | 158/455 [24:07<44:11,  8.93s/it]

Inputs shape: torch.Size([1, 1097])


 35%|███▍      | 159/455 [24:18<47:21,  9.60s/it]

Inputs shape: torch.Size([1, 1403])


 35%|███▌      | 160/455 [24:25<43:03,  8.76s/it]

Inputs shape: torch.Size([1, 1387])


 35%|███▌      | 161/455 [24:34<43:07,  8.80s/it]

Inputs shape: torch.Size([1, 915])


 36%|███▌      | 162/455 [24:45<46:40,  9.56s/it]

Inputs shape: torch.Size([1, 1269])


 36%|███▌      | 163/455 [24:53<44:13,  9.09s/it]

Inputs shape: torch.Size([1, 973])


 36%|███▌      | 164/455 [25:01<41:20,  8.52s/it]

Inputs shape: torch.Size([1, 899])


 36%|███▋      | 165/455 [25:11<44:32,  9.22s/it]

Inputs shape: torch.Size([1, 1084])


 36%|███▋      | 166/455 [25:21<44:49,  9.31s/it]

Inputs shape: torch.Size([1, 992])


 37%|███▋      | 167/455 [25:28<40:43,  8.49s/it]

Inputs shape: torch.Size([1, 1200])


 37%|███▋      | 168/455 [25:38<42:51,  8.96s/it]

Inputs shape: torch.Size([1, 947])


 37%|███▋      | 169/455 [25:49<45:35,  9.57s/it]

Inputs shape: torch.Size([1, 770])


 37%|███▋      | 170/455 [25:55<41:12,  8.68s/it]

Inputs shape: torch.Size([1, 1299])


 38%|███▊      | 171/455 [26:04<41:22,  8.74s/it]

Inputs shape: torch.Size([1, 989])


 38%|███▊      | 172/455 [26:15<44:38,  9.46s/it]

Inputs shape: torch.Size([1, 1143])


 38%|███▊      | 173/455 [26:23<42:15,  8.99s/it]

Inputs shape: torch.Size([1, 1088])


 38%|███▊      | 174/455 [26:30<39:44,  8.48s/it]

Inputs shape: torch.Size([1, 953])


 38%|███▊      | 175/455 [26:41<43:04,  9.23s/it]

Inputs shape: torch.Size([1, 805])


 39%|███▊      | 176/455 [26:50<42:43,  9.19s/it]

Inputs shape: torch.Size([1, 1035])


 39%|███▉      | 177/455 [26:57<39:12,  8.46s/it]

Inputs shape: torch.Size([1, 1207])


 39%|███▉      | 178/455 [27:08<41:43,  9.04s/it]

Inputs shape: torch.Size([1, 975])


 39%|███▉      | 179/455 [27:18<43:52,  9.54s/it]

Inputs shape: torch.Size([1, 969])


 40%|███▉      | 180/455 [27:25<39:50,  8.69s/it]

Inputs shape: torch.Size([1, 1051])


 40%|███▉      | 181/455 [27:34<40:04,  8.78s/it]

Inputs shape: torch.Size([1, 1256])


 40%|████      | 182/455 [27:45<43:17,  9.51s/it]

Inputs shape: torch.Size([1, 1003])


 40%|████      | 183/455 [27:53<40:13,  8.87s/it]

Inputs shape: torch.Size([1, 1087])


 40%|████      | 184/455 [28:00<38:27,  8.51s/it]

Inputs shape: torch.Size([1, 880])


 41%|████      | 185/455 [28:11<41:40,  9.26s/it]

Inputs shape: torch.Size([1, 1090])


 41%|████      | 186/455 [28:20<41:11,  9.19s/it]

Inputs shape: torch.Size([1, 907])


 41%|████      | 187/455 [28:27<37:34,  8.41s/it]

Inputs shape: torch.Size([1, 1080])


 41%|████▏     | 188/455 [28:37<40:06,  9.01s/it]

Inputs shape: torch.Size([1, 1250])


 42%|████▏     | 189/455 [28:48<41:53,  9.45s/it]

Inputs shape: torch.Size([1, 877])


 42%|████▏     | 190/455 [28:54<37:55,  8.59s/it]

Inputs shape: torch.Size([1, 1295])


 42%|████▏     | 191/455 [29:04<38:39,  8.79s/it]

Inputs shape: torch.Size([1, 879])


 42%|████▏     | 192/455 [29:15<41:29,  9.47s/it]

Inputs shape: torch.Size([1, 799])


 42%|████▏     | 193/455 [29:21<37:48,  8.66s/it]

Inputs shape: torch.Size([1, 962])


 43%|████▎     | 194/455 [29:29<36:24,  8.37s/it]

Inputs shape: torch.Size([1, 1093])


 43%|████▎     | 195/455 [29:40<39:50,  9.19s/it]

Inputs shape: torch.Size([1, 1053])


 43%|████▎     | 196/455 [29:49<39:00,  9.04s/it]

Inputs shape: torch.Size([1, 1228])


 43%|████▎     | 197/455 [29:56<35:42,  8.30s/it]

Inputs shape: torch.Size([1, 742])


 44%|████▎     | 198/455 [30:06<38:07,  8.90s/it]

Inputs shape: torch.Size([1, 979])


 44%|████▎     | 199/455 [30:16<39:37,  9.29s/it]

Inputs shape: torch.Size([1, 1103])


 44%|████▍     | 200/455 [30:23<36:04,  8.49s/it]

Inputs shape: torch.Size([1, 1049])


 44%|████▍     | 201/455 [30:32<36:44,  8.68s/it]

Inputs shape: torch.Size([1, 999])


 44%|████▍     | 202/455 [30:43<39:36,  9.39s/it]

Inputs shape: torch.Size([1, 1116])


 45%|████▍     | 203/455 [30:50<36:05,  8.59s/it]

Inputs shape: torch.Size([1, 1417])


 45%|████▍     | 204/455 [30:58<35:14,  8.42s/it]

Inputs shape: torch.Size([1, 952])


 45%|████▌     | 205/455 [31:09<38:14,  9.18s/it]

Inputs shape: torch.Size([1, 1217])


 45%|████▌     | 206/455 [31:17<36:59,  8.92s/it]

Inputs shape: torch.Size([1, 819])


 45%|████▌     | 207/455 [31:24<34:18,  8.30s/it]

Inputs shape: torch.Size([1, 954])


 46%|████▌     | 208/455 [31:34<36:56,  8.97s/it]

Inputs shape: torch.Size([1, 1195])


 46%|████▌     | 209/455 [31:44<37:45,  9.21s/it]

Inputs shape: torch.Size([1, 1095])


 46%|████▌     | 210/455 [31:51<34:19,  8.41s/it]

Inputs shape: torch.Size([1, 902])


 46%|████▋     | 211/455 [32:00<35:18,  8.68s/it]

Inputs shape: torch.Size([1, 1086])


 47%|████▋     | 212/455 [32:11<38:11,  9.43s/it]

Inputs shape: torch.Size([1, 1240])


 47%|████▋     | 213/455 [32:18<34:39,  8.59s/it]

Inputs shape: torch.Size([1, 1013])


 47%|████▋     | 214/455 [32:26<33:37,  8.37s/it]

Inputs shape: torch.Size([1, 976])


 47%|████▋     | 215/455 [32:36<36:32,  9.13s/it]

Inputs shape: torch.Size([1, 1236])


 47%|████▋     | 216/455 [32:45<35:19,  8.87s/it]

Inputs shape: torch.Size([1, 1233])


 48%|████▊     | 217/455 [32:52<32:46,  8.26s/it]

Inputs shape: torch.Size([1, 1085])


 48%|████▊     | 218/455 [33:02<35:23,  8.96s/it]

Inputs shape: torch.Size([1, 948])


 48%|████▊     | 219/455 [33:12<36:04,  9.17s/it]

Inputs shape: torch.Size([1, 1260])


 48%|████▊     | 220/455 [33:18<32:48,  8.38s/it]

Inputs shape: torch.Size([1, 1240])


 49%|████▊     | 221/455 [33:28<34:10,  8.76s/it]

Inputs shape: torch.Size([1, 959])


 49%|████▉     | 222/455 [33:39<36:44,  9.46s/it]

Inputs shape: torch.Size([1, 708])


 49%|████▉     | 223/455 [33:45<33:05,  8.56s/it]

Inputs shape: torch.Size([1, 1324])


 49%|████▉     | 224/455 [33:54<32:42,  8.50s/it]

Inputs shape: torch.Size([1, 1444])


 49%|████▉     | 225/455 [34:05<35:38,  9.30s/it]

Inputs shape: torch.Size([1, 840])


 50%|████▉     | 226/455 [34:13<33:30,  8.78s/it]

Inputs shape: torch.Size([1, 999])


 50%|████▉     | 227/455 [34:20<31:19,  8.24s/it]

Inputs shape: torch.Size([1, 1243])


 50%|█████     | 228/455 [34:30<34:00,  8.99s/it]

Inputs shape: torch.Size([1, 1164])


 50%|█████     | 229/455 [34:40<34:27,  9.15s/it]

Inputs shape: torch.Size([1, 1288])


 51%|█████     | 230/455 [34:47<31:42,  8.46s/it]

Inputs shape: torch.Size([1, 977])


 51%|█████     | 231/455 [34:57<33:25,  8.95s/it]

Inputs shape: torch.Size([1, 989])


 51%|█████     | 232/455 [35:08<35:25,  9.53s/it]

Inputs shape: torch.Size([1, 1036])


 51%|█████     | 233/455 [35:14<31:51,  8.61s/it]

Inputs shape: torch.Size([1, 1099])


 51%|█████▏    | 234/455 [35:23<31:32,  8.56s/it]

Inputs shape: torch.Size([1, 1086])


 52%|█████▏    | 235/455 [35:34<34:06,  9.30s/it]

Inputs shape: torch.Size([1, 910])


 52%|█████▏    | 236/455 [35:41<31:57,  8.75s/it]

Inputs shape: torch.Size([1, 1113])


 52%|█████▏    | 237/455 [35:48<29:57,  8.24s/it]

Inputs shape: torch.Size([1, 1018])


 52%|█████▏    | 238/455 [35:59<32:38,  9.03s/it]

Inputs shape: torch.Size([1, 1072])


 53%|█████▎    | 239/455 [36:08<32:37,  9.06s/it]

Inputs shape: torch.Size([1, 990])


 53%|█████▎    | 240/455 [36:15<29:48,  8.32s/it]

Inputs shape: torch.Size([1, 995])


 53%|█████▎    | 241/455 [36:25<31:44,  8.90s/it]

Inputs shape: torch.Size([1, 821])


 53%|█████▎    | 242/455 [36:35<33:07,  9.33s/it]

Inputs shape: torch.Size([1, 954])


 53%|█████▎    | 243/455 [36:42<29:54,  8.46s/it]

Inputs shape: torch.Size([1, 1076])


 54%|█████▎    | 244/455 [36:50<30:03,  8.55s/it]

Inputs shape: torch.Size([1, 848])


 54%|█████▍    | 245/455 [37:01<32:26,  9.27s/it]

Inputs shape: torch.Size([1, 1179])


 54%|█████▍    | 246/455 [37:09<30:09,  8.66s/it]

Inputs shape: torch.Size([1, 869])


 54%|█████▍    | 247/455 [37:16<28:21,  8.18s/it]

Inputs shape: torch.Size([1, 1215])


 55%|█████▍    | 248/455 [37:27<31:10,  9.04s/it]

Inputs shape: torch.Size([1, 1035])


 55%|█████▍    | 249/455 [37:36<30:46,  8.96s/it]

Inputs shape: torch.Size([1, 1141])


 55%|█████▍    | 250/455 [37:42<28:08,  8.24s/it]

Inputs shape: torch.Size([1, 1069])


 55%|█████▌    | 251/455 [37:53<30:13,  8.89s/it]

Inputs shape: torch.Size([1, 757])


 55%|█████▌    | 252/455 [38:03<31:36,  9.34s/it]

Inputs shape: torch.Size([1, 867])


 56%|█████▌    | 253/455 [38:09<28:35,  8.49s/it]

Inputs shape: torch.Size([1, 1061])


 56%|█████▌    | 254/455 [38:18<28:52,  8.62s/it]

Inputs shape: torch.Size([1, 940])


 56%|█████▌    | 255/455 [38:29<31:09,  9.35s/it]

Inputs shape: torch.Size([1, 1142])


 56%|█████▋    | 256/455 [38:36<28:40,  8.65s/it]

Inputs shape: torch.Size([1, 996])


 56%|█████▋    | 257/455 [38:44<27:20,  8.28s/it]

Inputs shape: torch.Size([1, 1005])


 57%|█████▋    | 258/455 [38:55<29:43,  9.05s/it]

Inputs shape: torch.Size([1, 1213])


 57%|█████▋    | 259/455 [39:03<29:17,  8.97s/it]

Inputs shape: torch.Size([1, 1167])


 57%|█████▋    | 260/455 [39:10<26:46,  8.24s/it]

Inputs shape: torch.Size([1, 833])


 57%|█████▋    | 261/455 [39:20<28:25,  8.79s/it]

Inputs shape: torch.Size([1, 1002])


 58%|█████▊    | 262/455 [39:31<29:57,  9.31s/it]

Inputs shape: torch.Size([1, 1147])


 58%|█████▊    | 263/455 [39:37<27:08,  8.48s/it]

Inputs shape: torch.Size([1, 1077])


 58%|█████▊    | 264/455 [39:46<27:19,  8.58s/it]

Inputs shape: torch.Size([1, 807])


 58%|█████▊    | 265/455 [39:57<29:29,  9.31s/it]

Inputs shape: torch.Size([1, 922])


 58%|█████▊    | 266/455 [40:04<27:12,  8.64s/it]

Inputs shape: torch.Size([1, 840])


 59%|█████▊    | 267/455 [40:11<25:40,  8.19s/it]

Inputs shape: torch.Size([1, 1242])


 59%|█████▉    | 268/455 [40:22<28:06,  9.02s/it]

Inputs shape: torch.Size([1, 1189])


 59%|█████▉    | 269/455 [40:31<27:51,  8.99s/it]

Inputs shape: torch.Size([1, 923])


 59%|█████▉    | 270/455 [40:38<25:25,  8.25s/it]

Inputs shape: torch.Size([1, 1009])


 60%|█████▉    | 271/455 [40:48<27:14,  8.88s/it]

Inputs shape: torch.Size([1, 1199])


 60%|█████▉    | 272/455 [40:58<28:22,  9.30s/it]

Inputs shape: torch.Size([1, 1288])


 60%|██████    | 273/455 [41:05<25:48,  8.51s/it]

Inputs shape: torch.Size([1, 1030])


 60%|██████    | 274/455 [41:14<26:11,  8.69s/it]

Inputs shape: torch.Size([1, 1044])


 60%|██████    | 275/455 [41:25<28:18,  9.44s/it]

Inputs shape: torch.Size([1, 1024])


 61%|██████    | 276/455 [41:32<25:54,  8.68s/it]

Inputs shape: torch.Size([1, 1029])


 61%|██████    | 277/455 [41:40<24:47,  8.36s/it]

Inputs shape: torch.Size([1, 833])


 61%|██████    | 278/455 [41:51<26:58,  9.15s/it]

Inputs shape: torch.Size([1, 1173])


 61%|██████▏   | 279/455 [42:00<26:32,  9.05s/it]

Inputs shape: torch.Size([1, 1018])


 62%|██████▏   | 280/455 [42:06<24:13,  8.30s/it]

Inputs shape: torch.Size([1, 1003])


 62%|██████▏   | 281/455 [42:17<26:04,  8.99s/it]

Inputs shape: torch.Size([1, 916])


 62%|██████▏   | 282/455 [42:27<26:45,  9.28s/it]

Inputs shape: torch.Size([1, 852])


 62%|██████▏   | 283/455 [42:33<24:09,  8.43s/it]

Inputs shape: torch.Size([1, 935])


 62%|██████▏   | 284/455 [42:42<24:38,  8.64s/it]

Inputs shape: torch.Size([1, 1138])


 63%|██████▎   | 285/455 [42:53<26:31,  9.36s/it]

Inputs shape: torch.Size([1, 804])


 63%|██████▎   | 286/455 [43:00<24:02,  8.54s/it]

Inputs shape: torch.Size([1, 1168])


 63%|██████▎   | 287/455 [43:08<23:15,  8.31s/it]

Inputs shape: torch.Size([1, 1162])


 63%|██████▎   | 288/455 [43:19<25:21,  9.11s/it]

Inputs shape: torch.Size([1, 1041])


 64%|██████▎   | 289/455 [43:27<24:30,  8.86s/it]

Inputs shape: torch.Size([1, 944])


 64%|██████▎   | 290/455 [43:34<22:39,  8.24s/it]

Inputs shape: torch.Size([1, 1121])


 64%|██████▍   | 291/455 [43:44<24:24,  8.93s/it]

Inputs shape: torch.Size([1, 802])


 64%|██████▍   | 292/455 [43:54<25:04,  9.23s/it]

Inputs shape: torch.Size([1, 1242])


 64%|██████▍   | 293/455 [44:01<22:47,  8.44s/it]

Inputs shape: torch.Size([1, 1294])


 65%|██████▍   | 294/455 [44:11<23:45,  8.86s/it]

Inputs shape: torch.Size([1, 889])


 65%|██████▍   | 295/455 [44:22<25:23,  9.52s/it]

Inputs shape: torch.Size([1, 945])


 65%|██████▌   | 296/455 [44:28<22:48,  8.61s/it]

Inputs shape: torch.Size([1, 1179])


 65%|██████▌   | 297/455 [44:37<22:30,  8.55s/it]

Inputs shape: torch.Size([1, 1047])


 65%|██████▌   | 298/455 [44:48<24:23,  9.32s/it]

Inputs shape: torch.Size([1, 1109])


 66%|██████▌   | 299/455 [44:56<23:06,  8.89s/it]

Inputs shape: torch.Size([1, 946])


 66%|██████▌   | 300/455 [45:03<21:41,  8.40s/it]

Inputs shape: torch.Size([1, 933])


 66%|██████▌   | 301/455 [45:14<23:47,  9.27s/it]

Inputs shape: torch.Size([1, 1078])


 66%|██████▋   | 302/455 [45:23<23:28,  9.21s/it]

Inputs shape: torch.Size([1, 894])


 67%|██████▋   | 303/455 [45:30<21:19,  8.42s/it]

Inputs shape: torch.Size([1, 904])


 67%|██████▋   | 304/455 [45:40<22:44,  9.04s/it]

Inputs shape: torch.Size([1, 840])


 67%|██████▋   | 305/455 [45:50<23:30,  9.40s/it]

Inputs shape: torch.Size([1, 1002])


 67%|██████▋   | 306/455 [45:57<21:21,  8.60s/it]

Inputs shape: torch.Size([1, 1245])


 67%|██████▋   | 307/455 [46:07<21:55,  8.89s/it]

Inputs shape: torch.Size([1, 795])


 68%|██████▊   | 308/455 [46:18<23:21,  9.53s/it]

Inputs shape: torch.Size([1, 825])


 68%|██████▊   | 309/455 [46:24<21:02,  8.65s/it]

Inputs shape: torch.Size([1, 1027])


 68%|██████▊   | 310/455 [46:32<20:25,  8.45s/it]

Inputs shape: torch.Size([1, 921])


 68%|██████▊   | 311/455 [46:43<22:11,  9.24s/it]

Inputs shape: torch.Size([1, 1007])


 69%|██████▊   | 312/455 [46:52<21:19,  8.95s/it]

Inputs shape: torch.Size([1, 812])


 69%|██████▉   | 313/455 [46:58<19:37,  8.29s/it]

Inputs shape: torch.Size([1, 991])


 69%|██████▉   | 314/455 [47:10<21:30,  9.15s/it]

Inputs shape: torch.Size([1, 1017])


 69%|██████▉   | 315/455 [47:20<22:25,  9.61s/it]

Inputs shape: torch.Size([1, 957])


 69%|██████▉   | 316/455 [47:27<20:27,  8.83s/it]

Inputs shape: torch.Size([1, 1051])


 70%|██████▉   | 317/455 [47:38<21:21,  9.28s/it]

Inputs shape: torch.Size([1, 879])


 70%|██████▉   | 318/455 [47:49<22:54, 10.03s/it]

Inputs shape: torch.Size([1, 1073])


 70%|███████   | 319/455 [47:57<20:48,  9.18s/it]

Inputs shape: torch.Size([1, 948])


 70%|███████   | 320/455 [48:06<20:56,  9.31s/it]

Inputs shape: torch.Size([1, 1202])


 71%|███████   | 321/455 [48:18<22:09,  9.92s/it]

Inputs shape: torch.Size([1, 880])


 71%|███████   | 322/455 [48:25<20:11,  9.11s/it]

Inputs shape: torch.Size([1, 1186])


 71%|███████   | 323/455 [48:33<19:08,  8.70s/it]

Inputs shape: torch.Size([1, 1205])


 71%|███████   | 324/455 [48:44<20:37,  9.45s/it]

Inputs shape: torch.Size([1, 1060])


 71%|███████▏  | 325/455 [48:53<20:06,  9.28s/it]

Inputs shape: torch.Size([1, 1430])


 72%|███████▏  | 326/455 [49:00<18:41,  8.69s/it]

Inputs shape: torch.Size([1, 951])


 72%|███████▏  | 327/455 [49:11<19:45,  9.26s/it]

Inputs shape: torch.Size([1, 937])


 72%|███████▏  | 328/455 [49:21<20:04,  9.49s/it]

Inputs shape: torch.Size([1, 920])


 72%|███████▏  | 329/455 [49:27<18:12,  8.67s/it]

Inputs shape: torch.Size([1, 1152])


 73%|███████▎  | 330/455 [49:37<18:48,  9.03s/it]

Inputs shape: torch.Size([1, 1129])


 73%|███████▎  | 331/455 [49:49<20:08,  9.74s/it]

Inputs shape: torch.Size([1, 1298])


 73%|███████▎  | 332/455 [49:56<18:18,  8.93s/it]

Inputs shape: torch.Size([1, 1040])


 73%|███████▎  | 333/455 [50:04<17:54,  8.81s/it]

Inputs shape: torch.Size([1, 1074])


 73%|███████▎  | 334/455 [50:15<19:15,  9.55s/it]

Inputs shape: torch.Size([1, 1336])


 74%|███████▎  | 335/455 [50:24<18:18,  9.16s/it]

Inputs shape: torch.Size([1, 976])


 74%|███████▍  | 336/455 [50:31<17:00,  8.57s/it]

Inputs shape: torch.Size([1, 943])


 74%|███████▍  | 337/455 [50:42<18:15,  9.28s/it]

Inputs shape: torch.Size([1, 1343])


 74%|███████▍  | 338/455 [50:51<18:07,  9.29s/it]

Inputs shape: torch.Size([1, 1245])


 75%|███████▍  | 339/455 [50:58<16:29,  8.53s/it]

Inputs shape: torch.Size([1, 1163])


 75%|███████▍  | 340/455 [51:08<17:28,  9.11s/it]

Inputs shape: torch.Size([1, 1266])


 75%|███████▍  | 341/455 [51:13<14:59,  7.89s/it]

Inputs shape: torch.Size([1, 1023])


 75%|███████▌  | 342/455 [51:22<15:09,  8.05s/it]

Inputs shape: torch.Size([1, 1256])


 75%|███████▌  | 343/455 [51:29<14:27,  7.74s/it]

Inputs shape: torch.Size([1, 977])


 76%|███████▌  | 344/455 [51:40<16:00,  8.65s/it]

Inputs shape: torch.Size([1, 1009])


 76%|███████▌  | 345/455 [51:49<16:24,  8.95s/it]

Inputs shape: torch.Size([1, 1004])


 76%|███████▌  | 346/455 [51:56<15:02,  8.28s/it]

Inputs shape: torch.Size([1, 1118])


 76%|███████▋  | 347/455 [52:06<15:50,  8.80s/it]

Inputs shape: torch.Size([1, 1237])


 76%|███████▋  | 348/455 [52:17<16:58,  9.52s/it]

Inputs shape: torch.Size([1, 1029])


 77%|███████▋  | 349/455 [52:24<15:20,  8.68s/it]

Inputs shape: torch.Size([1, 977])


 77%|███████▋  | 350/455 [52:33<15:12,  8.69s/it]

Inputs shape: torch.Size([1, 1285])


 77%|███████▋  | 351/455 [52:44<16:22,  9.45s/it]

Inputs shape: torch.Size([1, 1017])


 77%|███████▋  | 352/455 [52:51<15:08,  8.82s/it]

Inputs shape: torch.Size([1, 970])


 78%|███████▊  | 353/455 [52:58<14:09,  8.33s/it]

Inputs shape: torch.Size([1, 896])


 78%|███████▊  | 354/455 [53:09<15:20,  9.11s/it]

Inputs shape: torch.Size([1, 1058])


 78%|███████▊  | 355/455 [53:18<15:06,  9.07s/it]

Inputs shape: torch.Size([1, 1105])


 78%|███████▊  | 356/455 [53:25<13:47,  8.36s/it]

Inputs shape: torch.Size([1, 1046])


 78%|███████▊  | 357/455 [53:35<14:39,  8.97s/it]

Inputs shape: torch.Size([1, 1229])


 79%|███████▊  | 358/455 [53:46<15:14,  9.42s/it]

Inputs shape: torch.Size([1, 1008])


 79%|███████▉  | 359/455 [53:52<13:41,  8.56s/it]

Inputs shape: torch.Size([1, 1001])


 79%|███████▉  | 360/455 [54:01<13:42,  8.66s/it]

Inputs shape: torch.Size([1, 1117])


 79%|███████▉  | 361/455 [54:12<14:44,  9.41s/it]

Inputs shape: torch.Size([1, 1019])


 80%|███████▉  | 362/455 [54:20<13:31,  8.72s/it]

Inputs shape: torch.Size([1, 934])


 80%|███████▉  | 363/455 [54:27<12:51,  8.38s/it]

Inputs shape: torch.Size([1, 1014])


 80%|████████  | 364/455 [54:38<13:57,  9.20s/it]

Inputs shape: torch.Size([1, 746])


 80%|████████  | 365/455 [54:47<13:28,  8.98s/it]

Inputs shape: torch.Size([1, 1247])


 80%|████████  | 366/455 [54:53<12:17,  8.29s/it]

Inputs shape: torch.Size([1, 977])


 81%|████████  | 367/455 [55:04<13:09,  8.97s/it]

Inputs shape: torch.Size([1, 1214])


 81%|████████  | 368/455 [55:14<13:25,  9.25s/it]

Inputs shape: torch.Size([1, 1145])


 81%|████████  | 369/455 [55:21<12:10,  8.49s/it]

Inputs shape: torch.Size([1, 1138])


 81%|████████▏ | 370/455 [55:30<12:32,  8.85s/it]

Inputs shape: torch.Size([1, 1225])


 82%|████████▏ | 371/455 [55:41<13:21,  9.55s/it]

Inputs shape: torch.Size([1, 1474])


 82%|████████▏ | 372/455 [55:48<12:00,  8.68s/it]

Inputs shape: torch.Size([1, 1121])


 82%|████████▏ | 373/455 [55:56<11:39,  8.53s/it]

Inputs shape: torch.Size([1, 1093])


 82%|████████▏ | 374/455 [56:07<12:34,  9.32s/it]

Inputs shape: torch.Size([1, 950])


 82%|████████▏ | 375/455 [56:15<11:48,  8.86s/it]

Inputs shape: torch.Size([1, 1064])


 83%|████████▎ | 376/455 [56:22<10:53,  8.27s/it]

Inputs shape: torch.Size([1, 1061])


 83%|████████▎ | 377/455 [56:33<11:45,  9.05s/it]

Inputs shape: torch.Size([1, 860])


 83%|████████▎ | 378/455 [56:43<12:02,  9.38s/it]

Inputs shape: torch.Size([1, 926])


 83%|████████▎ | 379/455 [56:50<10:47,  8.52s/it]

Inputs shape: torch.Size([1, 1151])


 84%|████████▎ | 380/455 [57:00<11:24,  9.13s/it]

Inputs shape: torch.Size([1, 1195])


 84%|████████▎ | 381/455 [57:10<11:38,  9.45s/it]

Inputs shape: torch.Size([1, 964])


 84%|████████▍ | 382/455 [57:17<10:28,  8.61s/it]

Inputs shape: torch.Size([1, 1079])


 84%|████████▍ | 383/455 [57:26<10:34,  8.82s/it]

Inputs shape: torch.Size([1, 1032])


 84%|████████▍ | 384/455 [57:37<11:13,  9.49s/it]

Inputs shape: torch.Size([1, 1034])


 85%|████████▍ | 385/455 [57:44<10:03,  8.62s/it]

Inputs shape: torch.Size([1, 1020])


 85%|████████▍ | 386/455 [57:52<09:42,  8.43s/it]

Inputs shape: torch.Size([1, 1108])


 85%|████████▌ | 387/455 [58:03<10:26,  9.21s/it]

Inputs shape: torch.Size([1, 834])


 85%|████████▌ | 388/455 [58:11<09:52,  8.85s/it]

Inputs shape: torch.Size([1, 1227])


 85%|████████▌ | 389/455 [58:18<09:04,  8.25s/it]

Inputs shape: torch.Size([1, 1186])


 86%|████████▌ | 390/455 [58:24<08:13,  7.59s/it]

Inputs shape: torch.Size([1, 1314])


 86%|████████▌ | 391/455 [58:35<09:17,  8.71s/it]

Inputs shape: torch.Size([1, 1098])


 86%|████████▌ | 392/455 [58:42<08:34,  8.17s/it]

Inputs shape: torch.Size([1, 985])


 86%|████████▋ | 393/455 [58:50<08:19,  8.06s/it]

Inputs shape: torch.Size([1, 1181])


 87%|████████▋ | 394/455 [59:01<09:06,  8.97s/it]

Inputs shape: torch.Size([1, 1135])


 87%|████████▋ | 395/455 [59:09<08:43,  8.72s/it]

Inputs shape: torch.Size([1, 1000])


 87%|████████▋ | 396/455 [59:16<08:00,  8.15s/it]

Inputs shape: torch.Size([1, 778])


 87%|████████▋ | 397/455 [59:26<08:32,  8.84s/it]

Inputs shape: torch.Size([1, 1172])


 87%|████████▋ | 398/455 [59:36<08:40,  9.13s/it]

Inputs shape: torch.Size([1, 877])


 88%|████████▊ | 399/455 [59:43<07:45,  8.32s/it]

Inputs shape: torch.Size([1, 964])


 88%|████████▊ | 400/455 [59:52<07:53,  8.61s/it]

Inputs shape: torch.Size([1, 965])


 88%|████████▊ | 401/455 [1:00:03<08:25,  9.36s/it]

Inputs shape: torch.Size([1, 1044])


 88%|████████▊ | 402/455 [1:00:10<07:34,  8.57s/it]

Inputs shape: torch.Size([1, 1136])


 89%|████████▊ | 403/455 [1:00:18<07:18,  8.43s/it]

Inputs shape: torch.Size([1, 923])


 89%|████████▉ | 404/455 [1:00:29<07:51,  9.25s/it]

Inputs shape: torch.Size([1, 863])


 89%|████████▉ | 405/455 [1:00:38<07:31,  9.03s/it]

Inputs shape: torch.Size([1, 1151])


 89%|████████▉ | 406/455 [1:00:45<06:53,  8.43s/it]

Inputs shape: torch.Size([1, 1223])


 89%|████████▉ | 407/455 [1:00:56<07:20,  9.17s/it]

Inputs shape: torch.Size([1, 1042])


 90%|████████▉ | 408/455 [1:01:05<07:17,  9.31s/it]

Inputs shape: torch.Size([1, 1068])


 90%|████████▉ | 409/455 [1:01:12<06:32,  8.52s/it]

Inputs shape: torch.Size([1, 1037])


 90%|█████████ | 410/455 [1:01:22<06:42,  8.94s/it]

Inputs shape: torch.Size([1, 1120])


 90%|█████████ | 411/455 [1:01:33<07:02,  9.60s/it]

Inputs shape: torch.Size([1, 748])


 91%|█████████ | 412/455 [1:01:39<06:12,  8.66s/it]

Inputs shape: torch.Size([1, 1041])


 91%|█████████ | 413/455 [1:01:48<05:59,  8.56s/it]

Inputs shape: torch.Size([1, 1207])


 91%|█████████ | 414/455 [1:01:59<06:22,  9.33s/it]

Inputs shape: torch.Size([1, 890])


 91%|█████████ | 415/455 [1:02:07<05:54,  8.87s/it]

Inputs shape: torch.Size([1, 1039])


 91%|█████████▏| 416/455 [1:02:14<05:24,  8.31s/it]

Inputs shape: torch.Size([1, 1095])


 92%|█████████▏| 417/455 [1:02:24<05:44,  9.07s/it]

Inputs shape: torch.Size([1, 1103])


 92%|█████████▏| 418/455 [1:02:34<05:37,  9.11s/it]

Inputs shape: torch.Size([1, 1149])


 92%|█████████▏| 419/455 [1:02:40<05:00,  8.35s/it]

Inputs shape: torch.Size([1, 1078])


 92%|█████████▏| 420/455 [1:02:50<05:10,  8.88s/it]

Inputs shape: torch.Size([1, 1091])


 93%|█████████▎| 421/455 [1:03:01<05:19,  9.39s/it]

Inputs shape: torch.Size([1, 1235])


 93%|█████████▎| 422/455 [1:03:08<04:41,  8.54s/it]

Inputs shape: torch.Size([1, 1186])


 93%|█████████▎| 423/455 [1:03:16<04:36,  8.63s/it]

Inputs shape: torch.Size([1, 976])


 93%|█████████▎| 424/455 [1:03:27<04:50,  9.37s/it]

Inputs shape: torch.Size([1, 1103])


 93%|█████████▎| 425/455 [1:03:34<04:19,  8.64s/it]

Inputs shape: torch.Size([1, 1157])


 94%|█████████▎| 426/455 [1:03:42<04:00,  8.28s/it]

Inputs shape: torch.Size([1, 1073])


 94%|█████████▍| 427/455 [1:03:53<04:14,  9.09s/it]

Inputs shape: torch.Size([1, 952])


 94%|█████████▍| 428/455 [1:04:01<04:01,  8.95s/it]

Inputs shape: torch.Size([1, 1000])


 94%|█████████▍| 429/455 [1:04:08<03:33,  8.23s/it]

Inputs shape: torch.Size([1, 1006])


 95%|█████████▍| 430/455 [1:04:18<03:42,  8.89s/it]

Inputs shape: torch.Size([1, 1138])


 95%|█████████▍| 431/455 [1:04:28<03:41,  9.21s/it]

Inputs shape: torch.Size([1, 1023])


 95%|█████████▍| 432/455 [1:04:35<03:14,  8.46s/it]

Inputs shape: torch.Size([1, 943])


 95%|█████████▌| 433/455 [1:04:45<03:13,  8.81s/it]

Inputs shape: torch.Size([1, 949])


 95%|█████████▌| 434/455 [1:04:56<03:17,  9.41s/it]

Inputs shape: torch.Size([1, 930])


 96%|█████████▌| 435/455 [1:05:02<02:50,  8.55s/it]

Inputs shape: torch.Size([1, 1273])


 96%|█████████▌| 436/455 [1:05:11<02:42,  8.54s/it]

Inputs shape: torch.Size([1, 1033])


 96%|█████████▌| 437/455 [1:05:22<02:46,  9.26s/it]

Inputs shape: torch.Size([1, 1096])


 96%|█████████▋| 438/455 [1:05:29<02:29,  8.78s/it]

Inputs shape: torch.Size([1, 1109])


 96%|█████████▋| 439/455 [1:05:36<02:12,  8.27s/it]

Inputs shape: torch.Size([1, 1128])


 97%|█████████▋| 440/455 [1:05:47<02:17,  9.15s/it]

Inputs shape: torch.Size([1, 1097])


 97%|█████████▋| 441/455 [1:05:57<02:07,  9.12s/it]

Inputs shape: torch.Size([1, 1038])


 97%|█████████▋| 442/455 [1:06:03<01:48,  8.36s/it]

Inputs shape: torch.Size([1, 948])


 97%|█████████▋| 443/455 [1:06:13<01:47,  8.93s/it]

Inputs shape: torch.Size([1, 963])


 98%|█████████▊| 444/455 [1:06:24<01:43,  9.38s/it]

Inputs shape: torch.Size([1, 887])


 98%|█████████▊| 445/455 [1:06:30<01:25,  8.54s/it]

Inputs shape: torch.Size([1, 1093])


 98%|█████████▊| 446/455 [1:06:39<01:17,  8.59s/it]

Inputs shape: torch.Size([1, 782])


 98%|█████████▊| 447/455 [1:06:50<01:14,  9.29s/it]

Inputs shape: torch.Size([1, 894])


 98%|█████████▊| 448/455 [1:06:57<01:00,  8.62s/it]

Inputs shape: torch.Size([1, 880])


 99%|█████████▊| 449/455 [1:07:04<00:49,  8.24s/it]

Inputs shape: torch.Size([1, 977])


 99%|█████████▉| 450/455 [1:07:15<00:45,  9.07s/it]

Inputs shape: torch.Size([1, 1154])


 99%|█████████▉| 451/455 [1:07:24<00:35,  8.98s/it]

Inputs shape: torch.Size([1, 885])


 99%|█████████▉| 452/455 [1:07:31<00:24,  8.23s/it]

Inputs shape: torch.Size([1, 1152])


100%|█████████▉| 453/455 [1:07:41<00:17,  8.83s/it]

Inputs shape: torch.Size([1, 1036])


100%|█████████▉| 454/455 [1:07:51<00:09,  9.34s/it]

Inputs shape: torch.Size([1, 952])


100%|██████████| 455/455 [1:07:58<00:00,  8.96s/it]


In [None]:
from nltk.translate.bleu_score import corpus_bleu

bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.06133919602439004


In [None]:
%%capture
!pip install rouge_score

In [None]:
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.10813143784100818
ROUGE-2: 0.057229347077285886
ROUGE-L: 0.07532111627142717


In [None]:
position = 30
for position in range(91, 120):
    print(references[position])
    print(predictions[position])
    print()

How about "12 Years a Slave"? It's a powerful and emotional true story about a man's struggle for freedom during the era of slavery. The storytelling is flawless and the performances are extremely moving. It's definitely visually stunning and will leave a lasting impact on you.
. Use the user's movie history and preferences to suggest movies that align with their interests.

Seen movies and their corresponding evaluation and details:
Movie: Jumanji: The Next Level (2019)
Positive comments: - Fantastic film, action-packed and hilarious- New deadly locations and additional players add excitement- Thrilling set pieces and flawless visual effects- Spectacular performances from the cast, especially Dwayne Johnson, Jack Black, Kevin Hart, Karen Gillan, and Awkwafina- Impressive supporting cast
Negative comments: There are no negative comments
Details:
Title: Jumanji: The Next Level (2019)
Genre: Action, Adventure, Comedy, Fantasy
Director: Jake Kasdan
Cast: Dwayne Johnson, Kevin Hart, Jack B

In [None]:
from collections import Counter

from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np


def distinct(seqs):
    """ Calculate intra/inter distinct 1/2. """
    """ Recuperado de https://github.com/PaddlePaddle/models/blob/release/1.6/PaddleNLP/Research/Dialogue-PLATO/plato/metrics/metrics.py"""
    batch_size = len(seqs)
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5))
        intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5)
    inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2

In [None]:
import nltk
nltk.download('punkt_tab')

tokenized_text = [nltk.word_tokenize(text) for text in references]
resultado = distinct(tokenized_text)
print(resultado)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


(0.8174602211235673, 0.9857301007313511, 0.11541064977609411, 0.3757125912018547)


In [None]:
%%capture
!pip install bert-score

from bert_score import score

# Calcular métricas de BertScore
P, R, F1 = score(predictions, references, lang="en", rescale_with_baseline=True)

# Imprimir las métricas
print(f"BertScore Precision (P): {P.mean().item():.4f}")
print(f"BertScore Recall (R): {R.mean().item():.4f}")
print(f"BertScore F1: {F1.mean().item():.4f}")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from nltk import ngrams
from collections import Counter

tokenized_text = [nltk.word_tokenize(text) for text in references]

def calculate_distinct(tokenized_text):
    # unigramas
    all_unigrams = [word for text in tokenized_text for word in text]
    unique_unigrams = set(all_unigrams)
    distinc_1 = len(unique_unigrams) / len(all_unigrams) if all_unigrams else 0

    # bigramas
    all_bigrams = [bigram for text in tokenized_text for bigram in ngrams(text, 2)]
    unique_bigrams = set(all_bigrams)
    distinc_2 = len(unique_bigrams) / len(all_bigrams) if all_bigrams else 0

    return distinc_1, distinc_2

distinc_1, distinc_2 = calculate_distinct(tokenized_text)

print(f"Distinc-1 (Unigramas): {distinc_1}")
print(f"Distinc-2 (Bigramas): {distinc_2}")


Distinc-1 (Unigramas): 0.11541064981949459
Distinc-2 (Bigramas): 0.37571259134560203
