# Descarga de los datos formateados

In [1]:
%%capture
!pip install gdown

In [2]:
import gdown

files_to_download = {
    "formatted_train.json": "1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_",
    "formatted_test.json": "1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m",
    "formatted_validation.json": "1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm"
}

for destination, file_id in files_to_download.items():
    gdown.download(f"https://drive.google.com/uc?id={file_id}", destination, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_
From (redirected): https://drive.google.com/uc?id=1Fm7aPdCv6bguP7UgoCqdUVcm8xpDo18_&confirm=t&uuid=b426532f-326f-4f61-adbd-eb242cf4ab5e
To: /content/formatted_train.json
100%|██████████| 237M/237M [00:05<00:00, 42.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1a4YeF--Sks7WA1ZIQL2zDZx4teKk7p4m
To: /content/formatted_test.json
100%|██████████| 10.8M/10.8M [00:00<00:00, 29.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1PC9OhZhNZt8lFO9wifhydy0BrIYe8Dkm
To: /content/formatted_validation.json
100%|██████████| 23.7M/23.7M [00:00<00:00, 56.3MB/s]


# Implementación del modelo

In [3]:
%%capture
!pip install unsloth "xformers==0.0.28.post2"
# Also get the latest nightly Unsloth!
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets

In [4]:
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer, AutoTokenizer
from peft import AutoPeftModelForCausalLM
import json

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [16]:
with open("formatted_train.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [17]:
max_seq_length = 2048

def initialize_model_and_tokenizer():
    base_model_name = "unsloth/Phi-3.5-mini-instructt"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None
    )
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="phi-3.5"
    )

    return model, tokenizer

In [18]:
def preprocess_dataset(dataset, tokenizer):
    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    formatted_dataset = dataset.map(formatting_prompts_func, batched=True)
    return formatted_dataset

In [19]:
used_seed = 3407

# Testeo Zero Shot

In [20]:
max_seq_length = 2048
base_model_name = "unsloth/Phi-3.5-mini-instruct"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.26G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/140 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.37k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LongRopeRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNor

In [21]:
tokenizer = get_chat_template(
    tokenizer,
    chat_template="phi-3.5",
)

In [22]:
with open("formatted_test.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_dict({"conversations": data})

In [23]:
formatted_test_dataset = preprocess_dataset(dataset, tokenizer)

Map:   0%|          | 0/2277 [00:00<?, ? examples/s]

In [24]:
percentage = 0.2
sample_size = int(len(formatted_test_dataset) * percentage)
subset_test_dataset = formatted_test_dataset.shuffle(seed=42).select(range(sample_size))

print(f"Usando {sample_size} ejemplos para la evaluación.")

Usando 455 ejemplos para la evaluación.


In [25]:
from tqdm import tqdm

generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.3,
    "use_cache": True,
    "top_p": 0.9,
    "top_k": 50,
}

predictions = []
references = []

for example in tqdm(subset_test_dataset):
    conversation = example["conversations"]
    system_message = conversation[0]
    user_assistant_turns = conversation[1:-2]

    max_turns = 10
    truncated_turns = user_assistant_turns[-max_turns:]

    messages = [system_message] + truncated_turns
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    print(f"Inputs shape: {inputs.shape}")

    with torch.no_grad():
        outputs = model.generate(input_ids=inputs, **generation_args)

    generated_response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    if "assistant" in generated_response:
        generated_response = generated_response.split("assistant")[-1].strip()

    predictions.append(generated_response)

    references.append(conversation[-2]["content"])


  0%|          | 0/455 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Inputs shape: torch.Size([1, 1039])


  0%|          | 1/455 [00:10<1:18:47, 10.41s/it]

Inputs shape: torch.Size([1, 1106])


  0%|          | 2/455 [00:18<1:06:36,  8.82s/it]

Inputs shape: torch.Size([1, 1004])


  1%|          | 3/455 [00:25<1:00:34,  8.04s/it]

Inputs shape: torch.Size([1, 867])


  1%|          | 4/455 [00:30<51:01,  6.79s/it]  

Inputs shape: torch.Size([1, 906])


  1%|          | 5/455 [00:35<46:24,  6.19s/it]

Inputs shape: torch.Size([1, 882])


  1%|▏         | 6/455 [00:42<49:41,  6.64s/it]

Inputs shape: torch.Size([1, 1208])


  2%|▏         | 7/455 [00:51<54:17,  7.27s/it]

Inputs shape: torch.Size([1, 1002])


  2%|▏         | 8/455 [00:56<48:43,  6.54s/it]

Inputs shape: torch.Size([1, 983])


  2%|▏         | 9/455 [01:01<46:38,  6.27s/it]

Inputs shape: torch.Size([1, 1608])


  2%|▏         | 10/455 [01:09<49:44,  6.71s/it]

Inputs shape: torch.Size([1, 770])


  2%|▏         | 11/455 [01:17<51:59,  7.03s/it]

Inputs shape: torch.Size([1, 958])


  3%|▎         | 12/455 [01:23<48:52,  6.62s/it]

Inputs shape: torch.Size([1, 932])


  3%|▎         | 13/455 [01:27<44:39,  6.06s/it]

Inputs shape: torch.Size([1, 1062])


  3%|▎         | 14/455 [01:34<45:07,  6.14s/it]

Inputs shape: torch.Size([1, 1315])


  3%|▎         | 15/455 [01:42<49:03,  6.69s/it]

Inputs shape: torch.Size([1, 866])


  4%|▎         | 16/455 [01:48<49:09,  6.72s/it]

Inputs shape: torch.Size([1, 1369])


  4%|▎         | 17/455 [01:53<45:08,  6.18s/it]

Inputs shape: torch.Size([1, 897])


  4%|▍         | 18/455 [01:59<42:55,  5.89s/it]

Inputs shape: torch.Size([1, 1127])


  4%|▍         | 19/455 [02:06<46:51,  6.45s/it]

Inputs shape: torch.Size([1, 1086])


  4%|▍         | 20/455 [02:14<49:56,  6.89s/it]

Inputs shape: torch.Size([1, 1028])


  5%|▍         | 21/455 [02:19<45:21,  6.27s/it]

Inputs shape: torch.Size([1, 1302])


  5%|▍         | 22/455 [02:24<42:33,  5.90s/it]

Inputs shape: torch.Size([1, 1080])


  5%|▌         | 23/455 [02:31<45:32,  6.32s/it]

Inputs shape: torch.Size([1, 1282])


  5%|▌         | 24/455 [02:39<48:52,  6.80s/it]

Inputs shape: torch.Size([1, 1100])


  5%|▌         | 25/455 [02:45<46:11,  6.44s/it]

Inputs shape: torch.Size([1, 1054])


  6%|▌         | 26/455 [02:50<42:44,  5.98s/it]

Inputs shape: torch.Size([1, 1139])


  6%|▌         | 27/455 [02:56<43:19,  6.07s/it]

Inputs shape: torch.Size([1, 1041])


  6%|▌         | 28/455 [03:04<46:59,  6.60s/it]

Inputs shape: torch.Size([1, 932])


  6%|▋         | 29/455 [03:11<47:51,  6.74s/it]

Inputs shape: torch.Size([1, 1333])


  7%|▋         | 30/455 [03:16<44:08,  6.23s/it]

Inputs shape: torch.Size([1, 1142])


  7%|▋         | 31/455 [03:22<42:26,  6.01s/it]

Inputs shape: torch.Size([1, 849])


  7%|▋         | 32/455 [03:29<45:39,  6.48s/it]

Inputs shape: torch.Size([1, 1388])


  7%|▋         | 33/455 [03:37<48:46,  6.93s/it]

Inputs shape: torch.Size([1, 1311])


  7%|▋         | 34/455 [03:42<44:33,  6.35s/it]

Inputs shape: torch.Size([1, 807])


  8%|▊         | 35/455 [03:47<41:12,  5.89s/it]

Inputs shape: torch.Size([1, 873])


  8%|▊         | 36/455 [03:54<43:57,  6.29s/it]

Inputs shape: torch.Size([1, 1005])


  8%|▊         | 37/455 [04:02<46:53,  6.73s/it]

Inputs shape: torch.Size([1, 1174])


  8%|▊         | 38/455 [04:08<45:10,  6.50s/it]

Inputs shape: torch.Size([1, 924])


  9%|▊         | 39/455 [04:13<41:37,  6.00s/it]

Inputs shape: torch.Size([1, 1217])


  9%|▉         | 40/455 [04:19<41:37,  6.02s/it]

Inputs shape: torch.Size([1, 892])


  9%|▉         | 41/455 [04:26<45:00,  6.52s/it]

Inputs shape: torch.Size([1, 909])


  9%|▉         | 42/455 [04:34<46:23,  6.74s/it]

Inputs shape: torch.Size([1, 1094])


  9%|▉         | 43/455 [04:39<42:16,  6.16s/it]

Inputs shape: torch.Size([1, 1080])


 10%|▉         | 44/455 [04:44<39:53,  5.82s/it]

Inputs shape: torch.Size([1, 1263])


 10%|▉         | 45/455 [04:51<43:11,  6.32s/it]

Inputs shape: torch.Size([1, 1047])


 10%|█         | 46/455 [04:59<45:53,  6.73s/it]

Inputs shape: torch.Size([1, 1240])


 10%|█         | 47/455 [05:04<42:51,  6.30s/it]

Inputs shape: torch.Size([1, 766])


 11%|█         | 48/455 [05:09<39:32,  5.83s/it]

Inputs shape: torch.Size([1, 1214])


 11%|█         | 49/455 [05:16<41:33,  6.14s/it]

Inputs shape: torch.Size([1, 1305])


 11%|█         | 50/455 [05:24<44:59,  6.67s/it]

Inputs shape: torch.Size([1, 1108])


 11%|█         | 51/455 [05:29<41:29,  6.16s/it]

Inputs shape: torch.Size([1, 942])


 11%|█▏        | 52/455 [05:33<38:36,  5.75s/it]

Inputs shape: torch.Size([1, 1064])


 12%|█▏        | 53/455 [05:38<37:14,  5.56s/it]

Inputs shape: torch.Size([1, 962])


 12%|█▏        | 54/455 [05:46<41:07,  6.15s/it]

Inputs shape: torch.Size([1, 1129])


 12%|█▏        | 55/455 [05:54<44:11,  6.63s/it]

Inputs shape: torch.Size([1, 888])


 12%|█▏        | 56/455 [05:59<40:44,  6.13s/it]

Inputs shape: torch.Size([1, 863])


 13%|█▎        | 57/455 [06:03<38:04,  5.74s/it]

Inputs shape: torch.Size([1, 929])


 13%|█▎        | 58/455 [06:10<40:20,  6.10s/it]

Inputs shape: torch.Size([1, 906])


 13%|█▎        | 59/455 [06:18<43:30,  6.59s/it]

Inputs shape: torch.Size([1, 1021])


 13%|█▎        | 60/455 [06:24<42:40,  6.48s/it]

Inputs shape: torch.Size([1, 1287])


 13%|█▎        | 61/455 [06:29<39:35,  6.03s/it]

Inputs shape: torch.Size([1, 1118])


 14%|█▎        | 62/455 [06:35<39:14,  5.99s/it]

Inputs shape: torch.Size([1, 969])


 14%|█▍        | 63/455 [06:43<42:22,  6.49s/it]

Inputs shape: torch.Size([1, 996])


 14%|█▍        | 64/455 [06:50<44:13,  6.79s/it]

Inputs shape: torch.Size([1, 1531])


 14%|█▍        | 65/455 [06:55<40:40,  6.26s/it]

Inputs shape: torch.Size([1, 901])


 15%|█▍        | 66/455 [07:00<37:59,  5.86s/it]

Inputs shape: torch.Size([1, 754])


 15%|█▍        | 67/455 [07:08<40:51,  6.32s/it]

Inputs shape: torch.Size([1, 1169])


 15%|█▍        | 68/455 [07:16<43:42,  6.78s/it]

Inputs shape: torch.Size([1, 954])


 15%|█▌        | 69/455 [07:21<40:45,  6.34s/it]

Inputs shape: torch.Size([1, 1028])


 15%|█▌        | 70/455 [07:26<37:59,  5.92s/it]

Inputs shape: torch.Size([1, 1080])


 16%|█▌        | 71/455 [07:33<39:31,  6.17s/it]

Inputs shape: torch.Size([1, 875])


 16%|█▌        | 72/455 [07:40<42:19,  6.63s/it]

Inputs shape: torch.Size([1, 968])


 16%|█▌        | 73/455 [07:47<42:11,  6.63s/it]

Inputs shape: torch.Size([1, 1068])


 16%|█▋        | 74/455 [07:52<38:45,  6.10s/it]

Inputs shape: torch.Size([1, 1051])


 16%|█▋        | 75/455 [07:57<37:14,  5.88s/it]

Inputs shape: torch.Size([1, 798])


 17%|█▋        | 76/455 [08:05<40:21,  6.39s/it]

Inputs shape: torch.Size([1, 1215])


 17%|█▋        | 77/455 [08:12<42:43,  6.78s/it]

Inputs shape: torch.Size([1, 854])


 17%|█▋        | 78/455 [08:17<38:51,  6.18s/it]

Inputs shape: torch.Size([1, 819])


 17%|█▋        | 79/455 [08:22<36:14,  5.78s/it]

Inputs shape: torch.Size([1, 1200])


 18%|█▊        | 80/455 [08:29<38:35,  6.17s/it]

Inputs shape: torch.Size([1, 965])


 18%|█▊        | 81/455 [08:37<41:28,  6.65s/it]

Inputs shape: torch.Size([1, 1061])


 18%|█▊        | 82/455 [08:43<40:21,  6.49s/it]

Inputs shape: torch.Size([1, 986])


 18%|█▊        | 83/455 [08:48<37:04,  5.98s/it]

Inputs shape: torch.Size([1, 940])


 18%|█▊        | 84/455 [08:53<36:11,  5.85s/it]

Inputs shape: torch.Size([1, 916])


 19%|█▊        | 85/455 [09:01<39:20,  6.38s/it]

Inputs shape: torch.Size([1, 1039])


 19%|█▉        | 86/455 [09:08<40:58,  6.66s/it]

Inputs shape: torch.Size([1, 1179])


 19%|█▉        | 87/455 [09:13<37:52,  6.18s/it]

Inputs shape: torch.Size([1, 833])


 19%|█▉        | 88/455 [09:18<35:24,  5.79s/it]

Inputs shape: torch.Size([1, 1100])


 20%|█▉        | 89/455 [09:26<38:16,  6.27s/it]

Inputs shape: torch.Size([1, 1003])


 20%|█▉        | 90/455 [09:33<40:14,  6.62s/it]

Inputs shape: torch.Size([1, 1028])


 20%|██        | 91/455 [09:38<37:52,  6.24s/it]

Inputs shape: torch.Size([1, 946])


 20%|██        | 92/455 [09:43<35:11,  5.82s/it]

Inputs shape: torch.Size([1, 1168])


 20%|██        | 93/455 [09:50<36:18,  6.02s/it]

Inputs shape: torch.Size([1, 1048])


 21%|██        | 94/455 [09:57<39:06,  6.50s/it]

Inputs shape: torch.Size([1, 1116])


 21%|██        | 95/455 [10:04<39:23,  6.57s/it]

Inputs shape: torch.Size([1, 1129])


 21%|██        | 96/455 [10:09<36:19,  6.07s/it]

Inputs shape: torch.Size([1, 794])


 21%|██▏       | 97/455 [10:14<34:43,  5.82s/it]

Inputs shape: torch.Size([1, 1148])


 22%|██▏       | 98/455 [10:22<37:44,  6.34s/it]

Inputs shape: torch.Size([1, 1140])


 22%|██▏       | 99/455 [10:29<39:58,  6.74s/it]

Inputs shape: torch.Size([1, 1081])


 22%|██▏       | 100/455 [10:34<36:19,  6.14s/it]

Inputs shape: torch.Size([1, 1106])


 22%|██▏       | 101/455 [10:39<33:53,  5.74s/it]

Inputs shape: torch.Size([1, 855])


 22%|██▏       | 102/455 [10:46<35:45,  6.08s/it]

Inputs shape: torch.Size([1, 1094])


 23%|██▎       | 103/455 [10:54<38:26,  6.55s/it]

Inputs shape: torch.Size([1, 1175])


 23%|██▎       | 104/455 [11:00<37:37,  6.43s/it]

Inputs shape: torch.Size([1, 1084])


 23%|██▎       | 105/455 [11:05<34:41,  5.95s/it]

Inputs shape: torch.Size([1, 937])


 23%|██▎       | 106/455 [11:10<34:00,  5.85s/it]

Inputs shape: torch.Size([1, 1025])


 24%|██▎       | 107/455 [11:18<36:59,  6.38s/it]

Inputs shape: torch.Size([1, 1157])


 24%|██▎       | 108/455 [11:25<38:47,  6.71s/it]

Inputs shape: torch.Size([1, 1293])


 24%|██▍       | 109/455 [11:30<35:33,  6.17s/it]

Inputs shape: torch.Size([1, 1070])


 24%|██▍       | 110/455 [11:35<33:03,  5.75s/it]

Inputs shape: torch.Size([1, 984])


 24%|██▍       | 111/455 [11:42<35:31,  6.20s/it]

Inputs shape: torch.Size([1, 861])


 25%|██▍       | 112/455 [11:50<37:50,  6.62s/it]

Inputs shape: torch.Size([1, 906])


 25%|██▍       | 113/455 [11:55<35:49,  6.29s/it]

Inputs shape: torch.Size([1, 1119])


 25%|██▌       | 114/455 [12:00<33:16,  5.85s/it]

Inputs shape: torch.Size([1, 1089])


 25%|██▌       | 115/455 [12:06<34:03,  6.01s/it]

Inputs shape: torch.Size([1, 1076])


 25%|██▌       | 116/455 [12:14<36:43,  6.50s/it]

Inputs shape: torch.Size([1, 1180])


 26%|██▌       | 117/455 [12:21<36:50,  6.54s/it]

Inputs shape: torch.Size([1, 980])


 26%|██▌       | 118/455 [12:26<33:47,  6.02s/it]

Inputs shape: torch.Size([1, 1080])


 26%|██▌       | 119/455 [12:31<32:26,  5.79s/it]

Inputs shape: torch.Size([1, 1071])


 26%|██▋       | 120/455 [12:38<35:22,  6.34s/it]

Inputs shape: torch.Size([1, 1161])


 27%|██▋       | 121/455 [12:46<37:28,  6.73s/it]

Inputs shape: torch.Size([1, 949])


 27%|██▋       | 122/455 [12:51<34:02,  6.13s/it]

Inputs shape: torch.Size([1, 908])


 27%|██▋       | 123/455 [12:56<31:40,  5.73s/it]

Inputs shape: torch.Size([1, 1010])


 27%|██▋       | 124/455 [13:03<33:57,  6.16s/it]

Inputs shape: torch.Size([1, 1171])


 27%|██▋       | 125/455 [13:10<36:19,  6.61s/it]

Inputs shape: torch.Size([1, 1031])


 28%|██▊       | 126/455 [13:16<34:48,  6.35s/it]

Inputs shape: torch.Size([1, 812])


 28%|██▊       | 127/455 [13:21<31:58,  5.85s/it]

Inputs shape: torch.Size([1, 1038])


 28%|██▊       | 128/455 [13:27<32:05,  5.89s/it]

Inputs shape: torch.Size([1, 850])


 28%|██▊       | 129/455 [13:34<34:35,  6.37s/it]

Inputs shape: torch.Size([1, 1088])


 29%|██▊       | 130/455 [13:41<35:35,  6.57s/it]

Inputs shape: torch.Size([1, 859])


 29%|██▉       | 131/455 [13:46<32:35,  6.04s/it]

Inputs shape: torch.Size([1, 1073])


 29%|██▉       | 132/455 [13:51<30:52,  5.74s/it]

Inputs shape: torch.Size([1, 1059])


 29%|██▉       | 133/455 [13:59<33:30,  6.24s/it]

Inputs shape: torch.Size([1, 1140])


 29%|██▉       | 134/455 [14:06<35:42,  6.67s/it]

Inputs shape: torch.Size([1, 959])


 30%|██▉       | 135/455 [14:11<33:10,  6.22s/it]

Inputs shape: torch.Size([1, 982])


 30%|██▉       | 136/455 [14:16<30:44,  5.78s/it]

Inputs shape: torch.Size([1, 1119])


 30%|███       | 137/455 [14:23<31:54,  6.02s/it]

Inputs shape: torch.Size([1, 855])


 30%|███       | 138/455 [14:30<34:12,  6.47s/it]

Inputs shape: torch.Size([1, 1249])


 31%|███       | 139/455 [14:37<33:51,  6.43s/it]

Inputs shape: torch.Size([1, 1016])


 31%|███       | 140/455 [14:41<31:12,  5.94s/it]

Inputs shape: torch.Size([1, 979])


 31%|███       | 141/455 [14:47<30:29,  5.83s/it]

Inputs shape: torch.Size([1, 913])


 31%|███       | 142/455 [14:55<33:10,  6.36s/it]

Inputs shape: torch.Size([1, 947])


 31%|███▏      | 143/455 [15:02<34:39,  6.67s/it]

Inputs shape: torch.Size([1, 1082])


 32%|███▏      | 144/455 [15:07<31:41,  6.11s/it]

Inputs shape: torch.Size([1, 975])


 32%|███▏      | 145/455 [15:12<29:35,  5.73s/it]

Inputs shape: torch.Size([1, 1026])


 32%|███▏      | 146/455 [15:19<31:54,  6.19s/it]

Inputs shape: torch.Size([1, 1234])


 32%|███▏      | 147/455 [15:27<34:03,  6.63s/it]

Inputs shape: torch.Size([1, 990])


 33%|███▎      | 148/455 [15:32<32:25,  6.34s/it]

Inputs shape: torch.Size([1, 1072])


 33%|███▎      | 149/455 [15:37<29:59,  5.88s/it]

Inputs shape: torch.Size([1, 1135])


 33%|███▎      | 150/455 [15:43<30:23,  5.98s/it]

Inputs shape: torch.Size([1, 1156])


 33%|███▎      | 151/455 [15:51<32:49,  6.48s/it]

Inputs shape: torch.Size([1, 1189])


 33%|███▎      | 152/455 [15:58<33:29,  6.63s/it]

Inputs shape: torch.Size([1, 1293])


 34%|███▎      | 153/455 [16:03<30:49,  6.12s/it]

Inputs shape: torch.Size([1, 885])


 34%|███▍      | 154/455 [16:08<29:09,  5.81s/it]

Inputs shape: torch.Size([1, 1085])


 34%|███▍      | 155/455 [16:15<31:31,  6.30s/it]

Inputs shape: torch.Size([1, 1138])


 34%|███▍      | 156/455 [16:23<33:29,  6.72s/it]

Inputs shape: torch.Size([1, 1128])


 35%|███▍      | 157/455 [16:28<30:59,  6.24s/it]

Inputs shape: torch.Size([1, 1240])


 35%|███▍      | 158/455 [16:33<28:47,  5.82s/it]

Inputs shape: torch.Size([1, 1097])


 35%|███▍      | 159/455 [16:40<30:00,  6.08s/it]

Inputs shape: torch.Size([1, 1403])


 35%|███▌      | 160/455 [16:47<32:03,  6.52s/it]

Inputs shape: torch.Size([1, 1387])


 35%|███▌      | 161/455 [16:54<32:00,  6.53s/it]

Inputs shape: torch.Size([1, 915])


 36%|███▌      | 162/455 [16:59<29:19,  6.00s/it]

Inputs shape: torch.Size([1, 1269])


 36%|███▌      | 163/455 [17:04<28:32,  5.86s/it]

Inputs shape: torch.Size([1, 973])


 36%|███▌      | 164/455 [17:12<30:54,  6.37s/it]

Inputs shape: torch.Size([1, 899])


 36%|███▋      | 165/455 [17:19<32:22,  6.70s/it]

Inputs shape: torch.Size([1, 1084])


 36%|███▋      | 166/455 [17:24<29:30,  6.13s/it]

Inputs shape: torch.Size([1, 992])


 37%|███▋      | 167/455 [17:29<27:29,  5.73s/it]

Inputs shape: torch.Size([1, 1200])


 37%|███▋      | 168/455 [17:36<29:31,  6.17s/it]

Inputs shape: torch.Size([1, 947])


 37%|███▋      | 169/455 [17:43<31:23,  6.59s/it]

Inputs shape: torch.Size([1, 770])


 37%|███▋      | 170/455 [17:49<29:59,  6.31s/it]

Inputs shape: torch.Size([1, 1299])


 38%|███▊      | 171/455 [17:54<27:55,  5.90s/it]

Inputs shape: torch.Size([1, 989])


 38%|███▊      | 172/455 [18:00<28:18,  6.00s/it]

Inputs shape: torch.Size([1, 1143])


 38%|███▊      | 173/455 [18:08<30:29,  6.49s/it]

Inputs shape: torch.Size([1, 1088])


 38%|███▊      | 174/455 [18:15<30:35,  6.53s/it]

Inputs shape: torch.Size([1, 953])


 38%|███▊      | 175/455 [18:19<28:07,  6.03s/it]

Inputs shape: torch.Size([1, 805])


 39%|███▊      | 176/455 [18:25<26:45,  5.76s/it]

Inputs shape: torch.Size([1, 1035])


 39%|███▉      | 177/455 [18:32<29:02,  6.27s/it]

Inputs shape: torch.Size([1, 1207])


 39%|███▉      | 178/455 [18:40<30:57,  6.70s/it]

Inputs shape: torch.Size([1, 975])


 39%|███▉      | 179/455 [18:45<28:22,  6.17s/it]

Inputs shape: torch.Size([1, 969])


 40%|███▉      | 180/455 [18:50<26:30,  5.78s/it]

Inputs shape: torch.Size([1, 1051])


 40%|███▉      | 181/455 [18:56<27:50,  6.10s/it]

Inputs shape: torch.Size([1, 1256])


 40%|████      | 182/455 [19:04<29:56,  6.58s/it]

Inputs shape: torch.Size([1, 1003])


 40%|████      | 183/455 [19:10<29:25,  6.49s/it]

Inputs shape: torch.Size([1, 1087])


 40%|████      | 184/455 [19:15<27:08,  6.01s/it]

Inputs shape: torch.Size([1, 880])


 41%|████      | 185/455 [19:21<26:23,  5.87s/it]

Inputs shape: torch.Size([1, 1090])


 41%|████      | 186/455 [19:28<28:34,  6.38s/it]

Inputs shape: torch.Size([1, 907])


 41%|████      | 187/455 [19:36<30:08,  6.75s/it]

Inputs shape: torch.Size([1, 1080])


 41%|████▏     | 188/455 [19:41<27:31,  6.19s/it]

Inputs shape: torch.Size([1, 1250])


 42%|████▏     | 189/455 [19:46<25:38,  5.78s/it]

Inputs shape: torch.Size([1, 877])


 42%|████▏     | 190/455 [19:53<27:30,  6.23s/it]

Inputs shape: torch.Size([1, 1295])


 42%|████▏     | 191/455 [20:01<29:28,  6.70s/it]

Inputs shape: torch.Size([1, 879])


 42%|████▏     | 192/455 [20:06<27:44,  6.33s/it]

Inputs shape: torch.Size([1, 799])


 42%|████▏     | 193/455 [20:11<25:39,  5.88s/it]

Inputs shape: torch.Size([1, 962])


 43%|████▎     | 194/455 [20:18<26:23,  6.07s/it]

Inputs shape: torch.Size([1, 1093])


 43%|████▎     | 195/455 [20:25<28:20,  6.54s/it]

Inputs shape: torch.Size([1, 1053])


 43%|████▎     | 196/455 [20:32<28:17,  6.55s/it]

Inputs shape: torch.Size([1, 1228])


 43%|████▎     | 197/455 [20:37<26:02,  6.06s/it]

Inputs shape: torch.Size([1, 742])


 44%|████▎     | 198/455 [20:42<24:42,  5.77s/it]

Inputs shape: torch.Size([1, 979])


 44%|████▎     | 199/455 [20:49<26:53,  6.30s/it]

Inputs shape: torch.Size([1, 1103])


 44%|████▍     | 200/455 [20:57<28:34,  6.72s/it]

Inputs shape: torch.Size([1, 1049])


 44%|████▍     | 201/455 [21:02<26:05,  6.16s/it]

Inputs shape: torch.Size([1, 999])


 44%|████▍     | 202/455 [21:07<24:14,  5.75s/it]

Inputs shape: torch.Size([1, 1116])


 45%|████▍     | 203/455 [21:13<25:23,  6.04s/it]

Inputs shape: torch.Size([1, 1417])


 45%|████▍     | 204/455 [21:21<27:29,  6.57s/it]

Inputs shape: torch.Size([1, 952])


 45%|████▌     | 205/455 [21:27<26:55,  6.46s/it]

Inputs shape: torch.Size([1, 1217])


 45%|████▌     | 206/455 [21:32<24:44,  5.96s/it]

Inputs shape: torch.Size([1, 819])


 45%|████▌     | 207/455 [21:38<23:56,  5.79s/it]

Inputs shape: torch.Size([1, 954])


 46%|████▌     | 208/455 [21:45<26:04,  6.33s/it]

Inputs shape: torch.Size([1, 1195])


 46%|████▌     | 209/455 [21:53<27:24,  6.68s/it]

Inputs shape: torch.Size([1, 1095])


 46%|████▌     | 210/455 [21:58<25:01,  6.13s/it]

Inputs shape: torch.Size([1, 902])


 46%|████▋     | 211/455 [22:02<23:23,  5.75s/it]

Inputs shape: torch.Size([1, 1086])


 47%|████▋     | 212/455 [22:10<25:17,  6.25s/it]

Inputs shape: torch.Size([1, 1240])


 47%|████▋     | 213/455 [22:17<26:55,  6.67s/it]

Inputs shape: torch.Size([1, 1013])


 47%|████▋     | 214/455 [22:23<25:14,  6.29s/it]

Inputs shape: torch.Size([1, 976])


 47%|████▋     | 215/455 [22:25<20:19,  5.08s/it]

Inputs shape: torch.Size([1, 1236])


 47%|████▋     | 216/455 [22:30<20:11,  5.07s/it]

Inputs shape: torch.Size([1, 1233])


 48%|████▊     | 217/455 [22:38<23:04,  5.82s/it]

Inputs shape: torch.Size([1, 1085])


 48%|████▊     | 218/455 [22:45<25:10,  6.37s/it]

Inputs shape: torch.Size([1, 948])


 48%|████▊     | 219/455 [22:50<23:23,  5.95s/it]

Inputs shape: torch.Size([1, 1260])


 48%|████▊     | 220/455 [22:55<22:04,  5.64s/it]

Inputs shape: torch.Size([1, 1240])


 49%|████▊     | 221/455 [23:02<23:30,  6.03s/it]

Inputs shape: torch.Size([1, 959])


 49%|████▉     | 222/455 [23:10<25:20,  6.53s/it]

Inputs shape: torch.Size([1, 708])


 49%|████▉     | 223/455 [23:16<24:32,  6.35s/it]

Inputs shape: torch.Size([1, 1324])


 49%|████▉     | 224/455 [23:21<22:46,  5.91s/it]

Inputs shape: torch.Size([1, 1444])


 49%|████▉     | 225/455 [23:27<22:54,  5.97s/it]

Inputs shape: torch.Size([1, 840])


 50%|████▉     | 226/455 [23:34<24:38,  6.46s/it]

Inputs shape: torch.Size([1, 999])


 50%|████▉     | 227/455 [23:41<24:52,  6.55s/it]

Inputs shape: torch.Size([1, 1243])


 50%|█████     | 228/455 [23:46<22:49,  6.03s/it]

Inputs shape: torch.Size([1, 1164])


 50%|█████     | 229/455 [23:51<21:21,  5.67s/it]

Inputs shape: torch.Size([1, 1288])


 51%|█████     | 230/455 [23:58<23:07,  6.17s/it]

Inputs shape: torch.Size([1, 977])


 51%|█████     | 231/455 [24:06<24:41,  6.61s/it]

Inputs shape: torch.Size([1, 989])


 51%|█████     | 232/455 [24:11<23:24,  6.30s/it]

Inputs shape: torch.Size([1, 1036])


 51%|█████     | 233/455 [24:16<21:41,  5.86s/it]

Inputs shape: torch.Size([1, 1099])


 51%|█████▏    | 234/455 [24:22<21:57,  5.96s/it]

Inputs shape: torch.Size([1, 1086])


 52%|█████▏    | 235/455 [24:30<23:37,  6.44s/it]

Inputs shape: torch.Size([1, 910])


 52%|█████▏    | 236/455 [24:37<23:45,  6.51s/it]

Inputs shape: torch.Size([1, 1113])


 52%|█████▏    | 237/455 [24:41<21:46,  5.99s/it]

Inputs shape: torch.Size([1, 1018])


 52%|█████▏    | 238/455 [24:47<20:46,  5.74s/it]

Inputs shape: torch.Size([1, 1072])


 53%|█████▎    | 239/455 [24:54<22:37,  6.28s/it]

Inputs shape: torch.Size([1, 990])


 53%|█████▎    | 240/455 [25:02<23:53,  6.67s/it]

Inputs shape: torch.Size([1, 995])


 53%|█████▎    | 241/455 [25:06<21:42,  6.08s/it]

Inputs shape: torch.Size([1, 821])


 53%|█████▎    | 242/455 [25:11<20:09,  5.68s/it]

Inputs shape: torch.Size([1, 954])


 53%|█████▎    | 243/455 [25:18<21:14,  6.01s/it]

Inputs shape: torch.Size([1, 1076])


 54%|█████▎    | 244/455 [25:26<22:50,  6.49s/it]

Inputs shape: torch.Size([1, 848])


 54%|█████▍    | 245/455 [25:32<22:25,  6.40s/it]

Inputs shape: torch.Size([1, 1179])


 54%|█████▍    | 246/455 [25:37<20:36,  5.92s/it]

Inputs shape: torch.Size([1, 869])


 54%|█████▍    | 247/455 [25:42<19:57,  5.76s/it]

Inputs shape: torch.Size([1, 1215])


 55%|█████▍    | 248/455 [25:50<21:45,  6.31s/it]

Inputs shape: torch.Size([1, 1035])


 55%|█████▍    | 249/455 [25:57<22:48,  6.64s/it]

Inputs shape: torch.Size([1, 1141])


 55%|█████▍    | 250/455 [26:02<20:50,  6.10s/it]

Inputs shape: torch.Size([1, 1069])


 55%|█████▌    | 251/455 [26:07<19:22,  5.70s/it]

Inputs shape: torch.Size([1, 757])


 55%|█████▌    | 252/455 [26:14<20:51,  6.16s/it]

Inputs shape: torch.Size([1, 867])


 56%|█████▌    | 253/455 [26:21<22:10,  6.58s/it]

Inputs shape: torch.Size([1, 1061])


 56%|█████▌    | 254/455 [26:27<20:52,  6.23s/it]

Inputs shape: torch.Size([1, 940])


 56%|█████▌    | 255/455 [26:32<19:20,  5.80s/it]

Inputs shape: torch.Size([1, 1142])


 56%|█████▋    | 256/455 [26:38<19:55,  6.01s/it]

Inputs shape: torch.Size([1, 996])


 56%|█████▋    | 257/455 [26:46<21:22,  6.48s/it]

Inputs shape: torch.Size([1, 1005])


 57%|█████▋    | 258/455 [26:52<21:29,  6.55s/it]

Inputs shape: torch.Size([1, 1213])


 57%|█████▋    | 259/455 [26:57<19:44,  6.04s/it]

Inputs shape: torch.Size([1, 1167])


 57%|█████▋    | 260/455 [27:02<18:50,  5.80s/it]

Inputs shape: torch.Size([1, 833])


 57%|█████▋    | 261/455 [27:10<20:22,  6.30s/it]

Inputs shape: torch.Size([1, 1002])


 58%|█████▊    | 262/455 [27:17<21:28,  6.68s/it]

Inputs shape: torch.Size([1, 1147])


 58%|█████▊    | 263/455 [27:22<19:44,  6.17s/it]

Inputs shape: torch.Size([1, 1077])


 58%|█████▊    | 264/455 [27:27<18:23,  5.78s/it]

Inputs shape: torch.Size([1, 807])


 58%|█████▊    | 265/455 [27:34<19:11,  6.06s/it]

Inputs shape: torch.Size([1, 922])


 58%|█████▊    | 266/455 [27:42<20:34,  6.53s/it]

Inputs shape: torch.Size([1, 840])


 59%|█████▊    | 267/455 [27:48<19:57,  6.37s/it]

Inputs shape: torch.Size([1, 1242])


 59%|█████▉    | 268/455 [27:52<18:22,  5.90s/it]

Inputs shape: torch.Size([1, 1189])


 59%|█████▉    | 269/455 [27:58<18:06,  5.84s/it]

Inputs shape: torch.Size([1, 923])


 59%|█████▉    | 270/455 [28:06<19:33,  6.34s/it]

Inputs shape: torch.Size([1, 1009])


 60%|█████▉    | 271/455 [28:13<20:31,  6.69s/it]

Inputs shape: torch.Size([1, 1199])


 60%|█████▉    | 272/455 [28:18<18:42,  6.13s/it]

Inputs shape: torch.Size([1, 1288])


 60%|██████    | 273/455 [28:23<17:32,  5.78s/it]

Inputs shape: torch.Size([1, 1030])


 60%|██████    | 274/455 [28:30<18:47,  6.23s/it]

Inputs shape: torch.Size([1, 1044])


 60%|██████    | 275/455 [28:38<19:57,  6.65s/it]

Inputs shape: torch.Size([1, 1024])


 61%|██████    | 276/455 [28:44<18:56,  6.35s/it]

Inputs shape: torch.Size([1, 1029])


 61%|██████    | 277/455 [28:48<17:26,  5.88s/it]

Inputs shape: torch.Size([1, 833])


 61%|██████    | 278/455 [28:54<17:22,  5.89s/it]

Inputs shape: torch.Size([1, 1173])


 61%|██████▏   | 279/455 [29:02<18:46,  6.40s/it]

Inputs shape: torch.Size([1, 1018])


 62%|██████▏   | 280/455 [29:09<19:17,  6.61s/it]

Inputs shape: torch.Size([1, 1003])


 62%|██████▏   | 281/455 [29:14<17:37,  6.08s/it]

Inputs shape: torch.Size([1, 916])


 62%|██████▏   | 282/455 [29:19<16:33,  5.74s/it]

Inputs shape: torch.Size([1, 852])


 62%|██████▏   | 283/455 [29:26<17:51,  6.23s/it]

Inputs shape: torch.Size([1, 935])


 62%|██████▏   | 284/455 [29:34<18:54,  6.63s/it]

Inputs shape: torch.Size([1, 1138])


 63%|██████▎   | 285/455 [29:39<17:41,  6.24s/it]

Inputs shape: torch.Size([1, 804])


 63%|██████▎   | 286/455 [29:44<16:16,  5.78s/it]

Inputs shape: torch.Size([1, 1168])


 63%|██████▎   | 287/455 [29:50<16:43,  5.97s/it]

Inputs shape: torch.Size([1, 1162])


 63%|██████▎   | 288/455 [29:58<18:01,  6.48s/it]

Inputs shape: torch.Size([1, 1041])


 64%|██████▎   | 289/455 [30:04<17:58,  6.50s/it]

Inputs shape: torch.Size([1, 944])


 64%|██████▎   | 290/455 [30:09<16:28,  5.99s/it]

Inputs shape: torch.Size([1, 1121])


 64%|██████▍   | 291/455 [30:15<15:56,  5.83s/it]

Inputs shape: torch.Size([1, 802])


 64%|██████▍   | 292/455 [30:22<17:11,  6.33s/it]

Inputs shape: torch.Size([1, 1242])


 64%|██████▍   | 293/455 [30:30<18:08,  6.72s/it]

Inputs shape: torch.Size([1, 1294])


 65%|██████▍   | 294/455 [30:35<16:31,  6.16s/it]

Inputs shape: torch.Size([1, 889])


 65%|██████▍   | 295/455 [30:39<15:15,  5.72s/it]

Inputs shape: torch.Size([1, 945])


 65%|██████▌   | 296/455 [30:46<16:13,  6.12s/it]

Inputs shape: torch.Size([1, 1179])


 65%|██████▌   | 297/455 [30:54<17:20,  6.59s/it]

Inputs shape: torch.Size([1, 1047])


 65%|██████▌   | 298/455 [31:00<16:43,  6.39s/it]

Inputs shape: torch.Size([1, 1109])


 66%|██████▌   | 299/455 [31:05<15:23,  5.92s/it]

Inputs shape: torch.Size([1, 946])


 66%|██████▌   | 300/455 [31:11<15:15,  5.91s/it]

Inputs shape: torch.Size([1, 933])


 66%|██████▌   | 301/455 [31:18<16:24,  6.39s/it]

Inputs shape: torch.Size([1, 1078])


 66%|██████▋   | 302/455 [31:25<16:50,  6.60s/it]

Inputs shape: torch.Size([1, 894])


 67%|██████▋   | 303/455 [31:30<15:19,  6.05s/it]

Inputs shape: torch.Size([1, 904])


 67%|██████▋   | 304/455 [31:35<14:22,  5.71s/it]

Inputs shape: torch.Size([1, 840])


 67%|██████▋   | 305/455 [31:42<15:27,  6.18s/it]

Inputs shape: torch.Size([1, 1002])


 67%|██████▋   | 306/455 [31:50<16:25,  6.61s/it]

Inputs shape: torch.Size([1, 1245])


 67%|██████▋   | 307/455 [31:55<15:27,  6.27s/it]

Inputs shape: torch.Size([1, 795])


 68%|██████▊   | 308/455 [32:00<14:12,  5.80s/it]

Inputs shape: torch.Size([1, 825])


 68%|██████▊   | 309/455 [32:06<14:26,  5.94s/it]

Inputs shape: torch.Size([1, 1027])


 68%|██████▊   | 310/455 [32:14<15:33,  6.44s/it]

Inputs shape: torch.Size([1, 921])


 68%|██████▊   | 311/455 [32:21<15:44,  6.56s/it]

Inputs shape: torch.Size([1, 1007])


 69%|██████▊   | 312/455 [32:25<14:21,  6.03s/it]

Inputs shape: torch.Size([1, 812])


 69%|██████▉   | 313/455 [32:31<13:36,  5.75s/it]

Inputs shape: torch.Size([1, 991])


 69%|██████▉   | 314/455 [32:37<13:48,  5.87s/it]

Inputs shape: torch.Size([1, 1017])


 69%|██████▉   | 315/455 [32:44<14:56,  6.40s/it]

Inputs shape: torch.Size([1, 957])


 69%|██████▉   | 316/455 [32:50<14:16,  6.16s/it]

Inputs shape: torch.Size([1, 1051])


 70%|██████▉   | 317/455 [32:55<13:15,  5.77s/it]

Inputs shape: torch.Size([1, 879])


 70%|██████▉   | 318/455 [33:01<13:19,  5.84s/it]

Inputs shape: torch.Size([1, 1073])


 70%|███████   | 319/455 [33:08<14:25,  6.36s/it]

Inputs shape: torch.Size([1, 948])


 70%|███████   | 320/455 [33:15<14:45,  6.56s/it]

Inputs shape: torch.Size([1, 1202])


 71%|███████   | 321/455 [33:20<13:29,  6.04s/it]

Inputs shape: torch.Size([1, 880])


 71%|███████   | 322/455 [33:25<12:42,  5.73s/it]

Inputs shape: torch.Size([1, 1186])


 71%|███████   | 323/455 [33:33<13:44,  6.24s/it]

Inputs shape: torch.Size([1, 1205])


 71%|███████   | 324/455 [33:40<14:35,  6.68s/it]

Inputs shape: torch.Size([1, 1060])


 71%|███████▏  | 325/455 [33:46<13:34,  6.27s/it]

Inputs shape: torch.Size([1, 1430])


 72%|███████▏  | 326/455 [33:51<12:39,  5.89s/it]

Inputs shape: torch.Size([1, 951])


 72%|███████▏  | 327/455 [33:57<12:57,  6.07s/it]

Inputs shape: torch.Size([1, 937])


 72%|███████▏  | 328/455 [34:05<13:48,  6.53s/it]

Inputs shape: torch.Size([1, 920])


 72%|███████▏  | 329/455 [34:11<13:45,  6.55s/it]

Inputs shape: torch.Size([1, 1152])


 73%|███████▎  | 330/455 [34:16<12:34,  6.03s/it]

Inputs shape: torch.Size([1, 1129])


 73%|███████▎  | 331/455 [34:22<12:01,  5.82s/it]

Inputs shape: torch.Size([1, 1298])


 73%|███████▎  | 332/455 [34:29<13:10,  6.42s/it]

Inputs shape: torch.Size([1, 1040])


 73%|███████▎  | 333/455 [34:37<13:52,  6.82s/it]

Inputs shape: torch.Size([1, 1074])


 73%|███████▎  | 334/455 [34:42<12:34,  6.23s/it]

Inputs shape: torch.Size([1, 1336])


 74%|███████▎  | 335/455 [34:47<11:40,  5.84s/it]

Inputs shape: torch.Size([1, 976])


 74%|███████▍  | 336/455 [34:54<12:28,  6.29s/it]

Inputs shape: torch.Size([1, 943])


 74%|███████▍  | 337/455 [35:02<13:10,  6.70s/it]

Inputs shape: torch.Size([1, 1343])


 74%|███████▍  | 338/455 [35:08<12:26,  6.38s/it]

Inputs shape: torch.Size([1, 1245])


 75%|███████▍  | 339/455 [35:12<11:26,  5.92s/it]

Inputs shape: torch.Size([1, 1163])


 75%|███████▍  | 340/455 [35:19<11:30,  6.00s/it]

Inputs shape: torch.Size([1, 1266])


 75%|███████▍  | 341/455 [35:23<10:31,  5.54s/it]

Inputs shape: torch.Size([1, 1023])


 75%|███████▌  | 342/455 [35:31<11:34,  6.15s/it]

Inputs shape: torch.Size([1, 1256])


 75%|███████▌  | 343/455 [35:36<10:56,  5.87s/it]

Inputs shape: torch.Size([1, 977])


 76%|███████▌  | 344/455 [35:41<10:16,  5.55s/it]

Inputs shape: torch.Size([1, 1009])


 76%|███████▌  | 345/455 [35:47<10:48,  5.89s/it]

Inputs shape: torch.Size([1, 1004])


 76%|███████▌  | 346/455 [35:55<11:38,  6.41s/it]

Inputs shape: torch.Size([1, 1118])


 76%|███████▋  | 347/455 [36:01<11:35,  6.44s/it]

Inputs shape: torch.Size([1, 1237])


 76%|███████▋  | 348/455 [36:06<10:39,  5.98s/it]

Inputs shape: torch.Size([1, 1029])


 77%|███████▋  | 349/455 [36:12<10:18,  5.83s/it]

Inputs shape: torch.Size([1, 977])


 77%|███████▋  | 350/455 [36:19<11:05,  6.33s/it]

Inputs shape: torch.Size([1, 1285])


 77%|███████▋  | 351/455 [36:27<11:42,  6.75s/it]

Inputs shape: torch.Size([1, 1017])


 77%|███████▋  | 352/455 [36:32<10:35,  6.17s/it]

Inputs shape: torch.Size([1, 970])


 78%|███████▊  | 353/455 [36:37<09:47,  5.76s/it]

Inputs shape: torch.Size([1, 896])


 78%|███████▊  | 354/455 [36:44<10:24,  6.18s/it]

Inputs shape: torch.Size([1, 1058])


 78%|███████▊  | 355/455 [36:51<11:01,  6.61s/it]

Inputs shape: torch.Size([1, 1105])


 78%|███████▊  | 356/455 [36:57<10:29,  6.36s/it]

Inputs shape: torch.Size([1, 1046])


 78%|███████▊  | 357/455 [37:02<09:37,  5.89s/it]

Inputs shape: torch.Size([1, 1229])


 79%|███████▊  | 358/455 [37:08<09:36,  5.94s/it]

Inputs shape: torch.Size([1, 1008])


 79%|███████▉  | 359/455 [37:16<10:21,  6.47s/it]

Inputs shape: torch.Size([1, 1001])


 79%|███████▉  | 360/455 [37:23<10:30,  6.63s/it]

Inputs shape: torch.Size([1, 1117])


 79%|███████▉  | 361/455 [37:28<09:33,  6.10s/it]

Inputs shape: torch.Size([1, 1019])


 80%|███████▉  | 362/455 [37:33<08:58,  5.79s/it]

Inputs shape: torch.Size([1, 934])


 80%|███████▉  | 363/455 [37:40<09:38,  6.29s/it]

Inputs shape: torch.Size([1, 1014])


 80%|████████  | 364/455 [37:48<10:09,  6.70s/it]

Inputs shape: torch.Size([1, 746])


 80%|████████  | 365/455 [37:53<09:17,  6.19s/it]

Inputs shape: torch.Size([1, 1247])


 80%|████████  | 366/455 [37:58<08:35,  5.80s/it]

Inputs shape: torch.Size([1, 977])


 81%|████████  | 367/455 [38:04<08:49,  6.02s/it]

Inputs shape: torch.Size([1, 1214])


 81%|████████  | 368/455 [38:12<09:29,  6.54s/it]

Inputs shape: torch.Size([1, 1145])


 81%|████████  | 369/455 [38:19<09:25,  6.57s/it]

Inputs shape: torch.Size([1, 1138])


 81%|████████▏ | 370/455 [38:24<08:34,  6.06s/it]

Inputs shape: torch.Size([1, 1225])


 82%|████████▏ | 371/455 [38:28<07:59,  5.71s/it]

Inputs shape: torch.Size([1, 1474])


 82%|████████▏ | 372/455 [38:36<08:42,  6.30s/it]

Inputs shape: torch.Size([1, 1121])


 82%|████████▏ | 373/455 [38:44<09:11,  6.72s/it]

Inputs shape: torch.Size([1, 1093])


 82%|████████▏ | 374/455 [38:49<08:18,  6.15s/it]

Inputs shape: torch.Size([1, 950])


 82%|████████▏ | 375/455 [38:53<07:35,  5.69s/it]

Inputs shape: torch.Size([1, 1064])


 83%|████████▎ | 376/455 [39:00<07:53,  5.99s/it]

Inputs shape: torch.Size([1, 1061])


 83%|████████▎ | 377/455 [39:08<08:26,  6.49s/it]

Inputs shape: torch.Size([1, 860])


 83%|████████▎ | 378/455 [39:14<08:18,  6.48s/it]

Inputs shape: torch.Size([1, 926])


 83%|████████▎ | 379/455 [39:19<07:32,  5.95s/it]

Inputs shape: torch.Size([1, 1151])


 84%|████████▎ | 380/455 [39:24<07:15,  5.81s/it]

Inputs shape: torch.Size([1, 1195])


 84%|████████▎ | 381/455 [39:32<07:50,  6.35s/it]

Inputs shape: torch.Size([1, 964])


 84%|████████▍ | 382/455 [39:40<08:12,  6.75s/it]

Inputs shape: torch.Size([1, 1079])


 84%|████████▍ | 383/455 [39:44<07:25,  6.18s/it]

Inputs shape: torch.Size([1, 1032])


 84%|████████▍ | 384/455 [39:49<06:51,  5.80s/it]

Inputs shape: torch.Size([1, 1034])


 85%|████████▍ | 385/455 [39:57<07:14,  6.21s/it]

Inputs shape: torch.Size([1, 1020])


 85%|████████▍ | 386/455 [40:04<07:38,  6.65s/it]

Inputs shape: torch.Size([1, 1108])


 85%|████████▌ | 387/455 [40:10<07:16,  6.42s/it]

Inputs shape: torch.Size([1, 834])


 85%|████████▌ | 388/455 [40:15<06:36,  5.92s/it]

Inputs shape: torch.Size([1, 1227])


 85%|████████▌ | 389/455 [40:21<06:30,  5.91s/it]

Inputs shape: torch.Size([1, 1186])


 86%|████████▌ | 390/455 [40:28<06:56,  6.41s/it]

Inputs shape: torch.Size([1, 1314])


 86%|████████▌ | 391/455 [40:35<07:05,  6.65s/it]

Inputs shape: torch.Size([1, 1098])


 86%|████████▌ | 392/455 [40:40<06:25,  6.12s/it]

Inputs shape: torch.Size([1, 985])


 86%|████████▋ | 393/455 [40:45<05:58,  5.78s/it]

Inputs shape: torch.Size([1, 1181])


 87%|████████▋ | 394/455 [40:53<06:22,  6.26s/it]

Inputs shape: torch.Size([1, 1135])


 87%|████████▋ | 395/455 [41:00<06:41,  6.69s/it]

Inputs shape: torch.Size([1, 1000])


 87%|████████▋ | 396/455 [41:06<06:12,  6.31s/it]

Inputs shape: torch.Size([1, 778])


 87%|████████▋ | 397/455 [41:11<05:39,  5.86s/it]

Inputs shape: torch.Size([1, 1172])


 87%|████████▋ | 398/455 [41:17<05:45,  6.07s/it]

Inputs shape: torch.Size([1, 877])


 88%|████████▊ | 399/455 [41:25<06:04,  6.50s/it]

Inputs shape: torch.Size([1, 964])


 88%|████████▊ | 400/455 [41:31<05:56,  6.48s/it]

Inputs shape: torch.Size([1, 965])


 88%|████████▊ | 401/455 [41:36<05:23,  5.98s/it]

Inputs shape: torch.Size([1, 1044])


 88%|████████▊ | 402/455 [41:41<05:06,  5.79s/it]

Inputs shape: torch.Size([1, 1136])


 89%|████████▊ | 403/455 [41:49<05:27,  6.30s/it]

Inputs shape: torch.Size([1, 923])


 89%|████████▉ | 404/455 [41:56<05:40,  6.68s/it]

Inputs shape: torch.Size([1, 863])


 89%|████████▉ | 405/455 [42:01<05:04,  6.08s/it]

Inputs shape: torch.Size([1, 1151])


 89%|████████▉ | 406/455 [42:06<04:39,  5.71s/it]

Inputs shape: torch.Size([1, 1223])


 89%|████████▉ | 407/455 [42:13<04:52,  6.10s/it]

Inputs shape: torch.Size([1, 1042])


 90%|████████▉ | 408/455 [42:21<05:07,  6.55s/it]

Inputs shape: torch.Size([1, 1068])


 90%|████████▉ | 409/455 [42:27<04:56,  6.44s/it]

Inputs shape: torch.Size([1, 1037])


 90%|█████████ | 410/455 [42:32<04:28,  5.97s/it]

Inputs shape: torch.Size([1, 1120])


 90%|█████████ | 411/455 [42:37<04:18,  5.88s/it]

Inputs shape: torch.Size([1, 748])


 91%|█████████ | 412/455 [42:45<04:33,  6.37s/it]

Inputs shape: torch.Size([1, 1041])


 91%|█████████ | 413/455 [42:53<04:51,  6.93s/it]

Inputs shape: torch.Size([1, 1207])


 91%|█████████ | 414/455 [43:00<04:46,  6.99s/it]

Inputs shape: torch.Size([1, 890])


 91%|█████████ | 415/455 [43:12<05:39,  8.48s/it]

Inputs shape: torch.Size([1, 1039])


 91%|█████████▏| 416/455 [43:25<06:19,  9.74s/it]

Inputs shape: torch.Size([1, 1095])


 92%|█████████▏| 417/455 [43:35<06:15,  9.89s/it]

Inputs shape: torch.Size([1, 1103])


 92%|█████████▏| 418/455 [43:50<07:07, 11.55s/it]

Inputs shape: torch.Size([1, 1149])


 92%|█████████▏| 419/455 [43:55<05:43,  9.55s/it]

Inputs shape: torch.Size([1, 1078])


 92%|█████████▏| 420/455 [44:00<04:44,  8.13s/it]

Inputs shape: torch.Size([1, 1091])


 93%|█████████▎| 421/455 [44:07<04:27,  7.87s/it]

Inputs shape: torch.Size([1, 1235])


 93%|█████████▎| 422/455 [44:15<04:17,  7.80s/it]

Inputs shape: torch.Size([1, 1186])


 93%|█████████▎| 423/455 [44:21<03:50,  7.21s/it]

Inputs shape: torch.Size([1, 976])


 93%|█████████▎| 424/455 [44:26<03:20,  6.47s/it]

Inputs shape: torch.Size([1, 1103])


 93%|█████████▎| 425/455 [44:32<03:10,  6.34s/it]

Inputs shape: torch.Size([1, 1157])


 94%|█████████▎| 426/455 [44:39<03:15,  6.73s/it]

Inputs shape: torch.Size([1, 1073])


 94%|█████████▍| 427/455 [44:46<03:10,  6.79s/it]

Inputs shape: torch.Size([1, 952])


 94%|█████████▍| 428/455 [44:51<02:46,  6.18s/it]

Inputs shape: torch.Size([1, 1000])


 94%|█████████▍| 429/455 [44:56<02:32,  5.87s/it]

Inputs shape: torch.Size([1, 1006])


 95%|█████████▍| 430/455 [45:04<02:39,  6.36s/it]

Inputs shape: torch.Size([1, 1138])


 95%|█████████▍| 431/455 [45:11<02:42,  6.79s/it]

Inputs shape: torch.Size([1, 1023])


 95%|█████████▍| 432/455 [45:16<02:23,  6.23s/it]

Inputs shape: torch.Size([1, 943])


 95%|█████████▌| 433/455 [45:21<02:07,  5.80s/it]

Inputs shape: torch.Size([1, 949])


 95%|█████████▌| 434/455 [45:32<02:35,  7.43s/it]

Inputs shape: torch.Size([1, 930])


 96%|█████████▌| 435/455 [45:44<02:55,  8.76s/it]

Inputs shape: torch.Size([1, 1273])


 96%|█████████▌| 436/455 [45:59<03:18, 10.43s/it]

Inputs shape: torch.Size([1, 1033])


 96%|█████████▌| 437/455 [46:15<03:42, 12.35s/it]

Inputs shape: torch.Size([1, 1096])


 96%|█████████▋| 438/455 [46:24<03:11, 11.26s/it]

Inputs shape: torch.Size([1, 1109])


 96%|█████████▋| 439/455 [46:32<02:44, 10.29s/it]

Inputs shape: torch.Size([1, 1128])


 97%|█████████▋| 440/455 [46:39<02:21,  9.40s/it]

Inputs shape: torch.Size([1, 1097])


 97%|█████████▋| 441/455 [46:44<01:52,  8.01s/it]

Inputs shape: torch.Size([1, 1038])


 97%|█████████▋| 442/455 [46:49<01:31,  7.05s/it]

Inputs shape: torch.Size([1, 948])


 97%|█████████▋| 443/455 [46:56<01:25,  7.10s/it]

Inputs shape: torch.Size([1, 963])


 98%|█████████▊| 444/455 [47:04<01:19,  7.23s/it]

Inputs shape: torch.Size([1, 887])


 98%|█████████▊| 445/455 [47:09<01:06,  6.67s/it]

Inputs shape: torch.Size([1, 1093])


 98%|█████████▊| 446/455 [47:14<00:55,  6.11s/it]

Inputs shape: torch.Size([1, 782])


 98%|█████████▊| 447/455 [47:20<00:49,  6.13s/it]

Inputs shape: torch.Size([1, 894])


 98%|█████████▊| 448/455 [47:28<00:45,  6.54s/it]

Inputs shape: torch.Size([1, 880])


 99%|█████████▊| 449/455 [47:34<00:39,  6.55s/it]

Inputs shape: torch.Size([1, 977])


 99%|█████████▉| 450/455 [47:39<00:30,  6.08s/it]

Inputs shape: torch.Size([1, 1154])


 99%|█████████▉| 451/455 [47:45<00:23,  5.86s/it]

Inputs shape: torch.Size([1, 885])


 99%|█████████▉| 452/455 [47:52<00:18,  6.32s/it]

Inputs shape: torch.Size([1, 1152])


100%|█████████▉| 453/455 [48:01<00:14,  7.05s/it]

Inputs shape: torch.Size([1, 1036])


100%|█████████▉| 454/455 [48:05<00:06,  6.36s/it]

Inputs shape: torch.Size([1, 952])


100%|██████████| 455/455 [48:10<00:00,  6.35s/it]


In [28]:
from nltk.translate.bleu_score import corpus_bleu
bleu_score = corpus_bleu(
    [[ref] for ref in references],
    predictions
)
print(f"BLEU Score: {bleu_score}")

BLEU Score: 0.061186483440603026


In [35]:
%%capture
!pip install rouge_score

In [36]:

from rouge_score import rouge_scorer
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]

rouge1 = sum([score["rouge1"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rouge2 = sum([score["rouge2"].fmeasure for score in rouge_scores]) / len(rouge_scores)
rougeL = sum([score["rougeL"].fmeasure for score in rouge_scores]) / len(rouge_scores)

print(f"ROUGE-1: {rouge1}")
print(f"ROUGE-2: {rouge2}")
print(f"ROUGE-L: {rougeL}")

ROUGE-1: 0.10840883575270839
ROUGE-2: 0.05707974991497904
ROUGE-L: 0.07343145479797761


In [37]:
from collections import Counter

def distinct(seqs):
    """ Calculate intra/inter distinct 1/2. """
    """ Recuperado de https://github.com/PaddlePaddle/models/blob/release/1.6/PaddleNLP/Research/Dialogue-PLATO/plato/metrics/metrics.py"""
    batch_size = len(seqs)
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        intra_dist1.append((len(unigrams)+1e-12) / (len(seq)+1e-5))
        intra_dist2.append((len(bigrams)+1e-12) / (max(0, len(seq)-1)+1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all)+1e-12) / (sum(unigrams_all.values())+1e-5)
    inter_dist2 = (len(bigrams_all)+1e-12) / (sum(bigrams_all.values())+1e-5)
    intra_dist1 = np.average(intra_dist1)
    intra_dist2 = np.average(intra_dist2)
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2

In [39]:
import nltk
import numpy as np
nltk.download('punkt_tab')

tokenized_text = [nltk.word_tokenize(text) for text in predictions]
resultado = distinct(tokenized_text)
print(resultado)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


(0.4019415528884578, 0.7486258296098945, 0.047694873310890985, 0.2169420319069025)


In [41]:
%%capture
!pip install evaluate bert_score

In [42]:
from evaluate import load

bertscore = load("bertscore")
results = bertscore.compute(predictions=predictions, references=references, lang="en", model_type="distilbert-base-uncased")

Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [43]:
from statistics import mean

print(results.keys())
print(f"F1: {mean(results['f1'])}")
print(f"Precision: {mean(results['precision'])}")
print(f"Recall: {mean(results['recall'])}")

dict_keys(['precision', 'recall', 'f1', 'hashcode'])
F1: 0.7110254476358603
Precision: 0.658111031631847
Recall: 0.7734422106009263


In [45]:
from collections import Counter

def calculate_distinct_metrics(predictions):
    tokens = [token for sentence in predictions for token in sentence.split()]
    unigrams = Counter(tokens)
    bigrams = Counter(zip(tokens, tokens[1:]))

    distinct_1 = len(unigrams) / len(tokens) if tokens else 0
    distinct_2 = len(bigrams) / len(tokens) if len(tokens) > 1 else 0

    return distinct_1, distinct_2

distinct_1, distinct_2 = calculate_distinct_metrics(predictions)

print(f"Distinct-1: {distinct_1:.4f}")
print(f"Distinct-2: {distinct_2:.4f}")

Distinct-1: 0.0788
Distinct-2: 0.3119
