In [12]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from datasets import Dataset, ClassLabel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, TrainingArguments, Trainer

In [13]:
# train_balanced.py
# Treina modelo DistilBERT com oversampled data

import pandas as pd
import numpy as np
import torch
from sklearn.metrics import f1_score
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments
)
from datasets import Dataset, ClassLabel
from sklearn.preprocessing import LabelEncoder

# 1. Load the balanced dataset.
# The location lives in one named constant so it can be changed in a single place.
DATA_PATH = r"C:\Users\Ana\Desktop\project_ENHESA\data\final_data.csv"
# Columns the training cell reads: "synopsis" is tokenized, "target" is label-encoded.
REQUIRED_COLUMNS = ("synopsis", "target")

print("📄 Lendo dados balanceados...")
df = pd.read_csv(DATA_PATH)

# Fail fast with a descriptive message instead of a bare KeyError that would
# only surface later, in the middle of the training cell.
missing_columns = [name for name in REQUIRED_COLUMNS if name not in df.columns]
if missing_columns:
    raise KeyError(
        f"{DATA_PATH} is missing required column(s) {missing_columns}; "
        f"found {list(df.columns)}"
    )

# NOTE(review): the header describes this file as oversampled data. If the
# oversampling duplicated rows, the train/test split done in the next cell can
# place copies of the same text on both sides and inflate the evaluation F1.
# Confirm how the file was built.


📄 Lendo dados balanceados...


In [14]:
# 2. Encode the `target` column as integer class ids (when needed)
print("🔠 Codificando labels...")
label_encoder = LabelEncoder()
encoded_targets = label_encoder.fit_transform(df["target"])
df["label"] = encoded_targets
# Class names ordered so that position i holds the name of class id i.
labels = label_encoder.classes_.tolist()

# 3. Tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def tokenize_function(batch):
    """Tokenize the ``synopsis`` texts of one ``Dataset.map`` batch.

    Args:
        batch: Mapping with a ``"synopsis"`` entry holding the texts of the batch.

    Returns:
        The tokenizer output for the batch, truncated and left unpadded.
    """
    # No padding here: inside a batched ``map`` it would pad every row to the
    # longest text of its (large) map batch, and that extra length is then
    # paid for on every training step. Padding is left to batch collation:
    # a Trainer that receives this tokenizer pads each mini-batch to its own
    # longest sequence by default.
    return tokenizer(batch["synopsis"], truncation=True)

# 4. Build the Hugging Face dataset: typed label column, stratified 80/20 split, tokenized text
text_and_label = df[["synopsis", "label"]]
label_feature = ClassLabel(num_classes=len(labels), names=labels)

# The label column is cast to ClassLabel before splitting so the split can stratify on it.
full_dataset = Dataset.from_pandas(text_and_label).cast_column("label", label_feature)
split_dataset = full_dataset.train_test_split(
    test_size=0.2,
    stratify_by_column="label",
    seed=42,
)
hf_dataset = split_dataset.map(tokenize_function, batched=True)

# 5. Model
# Store the class names in the model config so the saved checkpoint reports the
# original target names instead of generic LABEL_<i> placeholders.
id2label = {class_id: str(name) for class_id, name in enumerate(labels)}
label2id = {name: class_id for class_id, name in id2label.items()}
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

# 6. Evaluation metric
def compute_metrics(eval_pred):
    """Return the macro-averaged F1 score for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the index of the highest logit
            along the last axis is taken as the predicted class.

    Returns:
        Dict with the single key ``"f1_macro"``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=-1)
    return {"f1_macro": f1_score(references, predicted_ids, average="macro")}

# 7. Training arguments
training_config = dict(
    # Where checkpoints and logs are written, and how often the loss is logged
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch, then reload the checkpoint with
    # the best "f1_macro" (the key returned by compute_metrics)
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
)
training_args = TrainingArguments(**training_config)

# 8. Trainer: ties together the model, the arguments, both splits, the tokenizer and the metric
train_split = hf_dataset["train"]
eval_split = hf_dataset["test"]
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# 9. Train
# NOTE(review): re-running this cell starts training again from the first step;
# an interrupted run is not resumed.
print("🚀 Iniciando treino...")
trainer.train()

# 10. Evaluate on the split passed to the Trainer as eval_dataset
print("📊 Avaliação:")
metrics = trainer.evaluate()
print(metrics)

# 11. Save the model and the tokenizer into the same directory
print("💾 Salvando modelo em 'modelo_distilbert'...")
model_dir = "modelo_distilbert"
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)

🔠 Codificando labels...



Casting the dataset: 100%|██████████| 3075/3075 [00:00<00:00, 30090.16 examples/s]

[A
[A
[A
Map: 100%|██████████| 2460/2460 [00:02<00:00, 955.05 examples/s]

[A
Map: 100%|██████████| 615/615 [00:00<00:00, 986.16 examples/s] 
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/232 [09:43<?, ?it/s]


🚀 Iniciando treino...


  1%|          | 10/924 [01:52<2:53:03, 11.36s/it]
  1%|          | 10/924 [01:52<2:53:03, 11.36s/it]

{'loss': 1.0924, 'grad_norm': 3.170321464538574, 'learning_rate': 1.9783549783549785e-05, 'epoch': 0.03}


  2%|▏         | 20/924 [03:49<2:54:55, 11.61s/it]
  2%|▏         | 20/924 [03:49<2:54:55, 11.61s/it]

{'loss': 1.1034, 'grad_norm': 2.3274991512298584, 'learning_rate': 1.9567099567099568e-05, 'epoch': 0.06}


  3%|▎         | 30/924 [05:33<2:30:41, 10.11s/it]
  3%|▎         | 30/924 [05:33<2:30:41, 10.11s/it]

{'loss': 1.0692, 'grad_norm': 2.788170576095581, 'learning_rate': 1.9350649350649354e-05, 'epoch': 0.1}


  4%|▍         | 40/924 [07:15<2:32:09, 10.33s/it]
  4%|▍         | 40/924 [07:15<2:32:09, 10.33s/it]

{'loss': 1.0433, 'grad_norm': 4.508909702301025, 'learning_rate': 1.9134199134199138e-05, 'epoch': 0.13}


  5%|▌         | 50/924 [08:56<2:25:19,  9.98s/it]
  5%|▌         | 50/924 [08:56<2:25:19,  9.98s/it]

{'loss': 1.0155, 'grad_norm': 3.2613258361816406, 'learning_rate': 1.891774891774892e-05, 'epoch': 0.16}


  6%|▋         | 60/924 [10:38<2:29:17, 10.37s/it]
  6%|▋         | 60/924 [10:38<2:29:17, 10.37s/it]

{'loss': 0.9092, 'grad_norm': 3.396336317062378, 'learning_rate': 1.8701298701298704e-05, 'epoch': 0.19}


  8%|▊         | 70/924 [12:26<2:31:57, 10.68s/it]
  8%|▊         | 70/924 [12:26<2:31:57, 10.68s/it]

{'loss': 0.8275, 'grad_norm': 5.6048808097839355, 'learning_rate': 1.8484848484848487e-05, 'epoch': 0.23}


  9%|▊         | 80/924 [14:04<2:13:11,  9.47s/it]
  9%|▊         | 80/924 [14:04<2:13:11,  9.47s/it]

{'loss': 0.8211, 'grad_norm': 4.2908711433410645, 'learning_rate': 1.826839826839827e-05, 'epoch': 0.26}


 10%|▉         | 90/924 [15:39<2:15:09,  9.72s/it]
 10%|▉         | 90/924 [15:39<2:15:09,  9.72s/it]

{'loss': 0.7916, 'grad_norm': 11.080763816833496, 'learning_rate': 1.8051948051948053e-05, 'epoch': 0.29}


 11%|█         | 100/924 [17:40<2:33:36, 11.19s/it]
 11%|█         | 100/924 [17:40<2:33:36, 11.19s/it]

{'loss': 0.732, 'grad_norm': 8.519530296325684, 'learning_rate': 1.7835497835497836e-05, 'epoch': 0.32}


 12%|█▏        | 110/924 [19:29<2:30:58, 11.13s/it]
 12%|█▏        | 110/924 [19:29<2:30:58, 11.13s/it]

{'loss': 0.7961, 'grad_norm': 10.298563003540039, 'learning_rate': 1.761904761904762e-05, 'epoch': 0.36}


 13%|█▎        | 120/924 [21:21<2:25:40, 10.87s/it]
 13%|█▎        | 120/924 [21:21<2:25:40, 10.87s/it]

{'loss': 0.6412, 'grad_norm': 8.243788719177246, 'learning_rate': 1.7402597402597403e-05, 'epoch': 0.39}


 14%|█▍        | 130/924 [22:59<2:06:44,  9.58s/it]
 14%|█▍        | 130/924 [22:59<2:06:44,  9.58s/it]

{'loss': 0.6514, 'grad_norm': 12.171868324279785, 'learning_rate': 1.718614718614719e-05, 'epoch': 0.42}


 15%|█▌        | 140/924 [24:38<2:05:00,  9.57s/it]
 15%|█▌        | 140/924 [24:38<2:05:00,  9.57s/it]

{'loss': 0.7047, 'grad_norm': 5.392636775970459, 'learning_rate': 1.6969696969696972e-05, 'epoch': 0.45}


 16%|█▌        | 150/924 [26:16<2:06:55,  9.84s/it]
 16%|█▌        | 150/924 [26:16<2:06:55,  9.84s/it]

{'loss': 0.6464, 'grad_norm': 8.097850799560547, 'learning_rate': 1.6753246753246756e-05, 'epoch': 0.49}


 17%|█▋        | 160/924 [28:08<2:22:38, 11.20s/it]
 17%|█▋        | 160/924 [28:08<2:22:38, 11.20s/it]

{'loss': 0.7345, 'grad_norm': 8.301258087158203, 'learning_rate': 1.653679653679654e-05, 'epoch': 0.52}


 18%|█▊        | 170/924 [30:04<2:23:34, 11.43s/it]
 18%|█▊        | 170/924 [30:04<2:23:34, 11.43s/it]

{'loss': 0.7351, 'grad_norm': 7.5268940925598145, 'learning_rate': 1.6320346320346322e-05, 'epoch': 0.55}


 19%|█▉        | 180/924 [31:56<2:17:05, 11.06s/it]
 19%|█▉        | 180/924 [31:56<2:17:05, 11.06s/it]

{'loss': 0.562, 'grad_norm': 6.949707984924316, 'learning_rate': 1.6103896103896105e-05, 'epoch': 0.58}


 21%|██        | 190/924 [33:44<2:09:10, 10.56s/it]
 21%|██        | 190/924 [33:44<2:09:10, 10.56s/it]

{'loss': 0.5982, 'grad_norm': 4.067653656005859, 'learning_rate': 1.5887445887445888e-05, 'epoch': 0.62}


 22%|██▏       | 200/924 [35:19<1:56:40,  9.67s/it]
 22%|██▏       | 200/924 [35:19<1:56:40,  9.67s/it]

{'loss': 0.552, 'grad_norm': 5.478620529174805, 'learning_rate': 1.567099567099567e-05, 'epoch': 0.65}


 23%|██▎       | 210/924 [37:08<2:11:02, 11.01s/it]
 23%|██▎       | 210/924 [37:08<2:11:02, 11.01s/it]

{'loss': 0.566, 'grad_norm': 4.183908939361572, 'learning_rate': 1.5454545454545454e-05, 'epoch': 0.68}


 24%|██▍       | 220/924 [38:56<2:04:41, 10.63s/it]
 24%|██▍       | 220/924 [38:56<2:04:41, 10.63s/it]

{'loss': 0.5092, 'grad_norm': 62.38662338256836, 'learning_rate': 1.523809523809524e-05, 'epoch': 0.71}


 25%|██▍       | 230/924 [40:39<1:56:42, 10.09s/it]
 25%|██▍       | 230/924 [40:39<1:56:42, 10.09s/it]

{'loss': 0.4858, 'grad_norm': 16.743417739868164, 'learning_rate': 1.5021645021645024e-05, 'epoch': 0.75}


 26%|██▌       | 240/924 [42:30<1:56:43, 10.24s/it]
 26%|██▌       | 240/924 [42:30<1:56:43, 10.24s/it]

{'loss': 0.5175, 'grad_norm': 14.059120178222656, 'learning_rate': 1.4805194805194807e-05, 'epoch': 0.78}


 27%|██▋       | 250/924 [44:15<1:54:55, 10.23s/it]
 27%|██▋       | 250/924 [44:15<1:54:55, 10.23s/it]

{'loss': 0.372, 'grad_norm': 5.385976314544678, 'learning_rate': 1.458874458874459e-05, 'epoch': 0.81}


 28%|██▊       | 260/924 [46:07<2:07:46, 11.55s/it]
 28%|██▊       | 260/924 [46:07<2:07:46, 11.55s/it]

{'loss': 0.4441, 'grad_norm': 5.124004364013672, 'learning_rate': 1.4372294372294374e-05, 'epoch': 0.84}


 29%|██▉       | 270/924 [48:06<2:12:01, 12.11s/it]
 29%|██▉       | 270/924 [48:06<2:12:01, 12.11s/it]

{'loss': 0.4507, 'grad_norm': 3.5989339351654053, 'learning_rate': 1.4155844155844157e-05, 'epoch': 0.88}


 30%|███       | 280/924 [50:04<2:02:39, 11.43s/it]
 30%|███       | 280/924 [50:04<2:02:39, 11.43s/it]

{'loss': 0.4263, 'grad_norm': 15.455986976623535, 'learning_rate': 1.3939393939393942e-05, 'epoch': 0.91}


 31%|███▏      | 290/924 [52:00<2:03:03, 11.65s/it]
 31%|███▏      | 290/924 [52:00<2:03:03, 11.65s/it]

{'loss': 0.4035, 'grad_norm': 8.447094917297363, 'learning_rate': 1.3722943722943725e-05, 'epoch': 0.94}


 32%|███▏      | 300/924 [53:53<1:56:40, 11.22s/it]
 32%|███▏      | 300/924 [53:53<1:56:40, 11.22s/it]

{'loss': 0.4519, 'grad_norm': 24.57912254333496, 'learning_rate': 1.3506493506493508e-05, 'epoch': 0.97}


 33%|███▎      | 308/924 [55:19<1:42:01,  9.94s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                   
 33%|███▎      | 308/924 [58:25<1:42:01,  9.94s/it]
[A

{'eval_loss': 0.4182160794734955, 'eval_f1_macro': 0.8427809506580477, 'eval_runtime': 185.5145, 'eval_samples_per_second': 3.315, 'eval_steps_per_second': 0.415, 'epoch': 1.0}


 34%|███▎      | 310/924 [58:47<8:24:51, 49.34s/it] 
 34%|███▎      | 310/924 [58:47<8:24:51, 49.34s/it]

{'loss': 0.3126, 'grad_norm': 7.7613205909729, 'learning_rate': 1.3290043290043291e-05, 'epoch': 1.01}


 35%|███▍      | 320/924 [1:00:34<2:04:04, 12.33s/it]
 35%|███▍      | 320/924 [1:00:34<2:04:04, 12.33s/it]

{'loss': 0.3193, 'grad_norm': 10.761905670166016, 'learning_rate': 1.3073593073593074e-05, 'epoch': 1.04}


 36%|███▌      | 330/924 [1:02:35<1:59:04, 12.03s/it]
 36%|███▌      | 330/924 [1:02:35<1:59:04, 12.03s/it]

{'loss': 0.429, 'grad_norm': 8.636361122131348, 'learning_rate': 1.2857142857142859e-05, 'epoch': 1.07}


 37%|███▋      | 340/924 [1:04:11<1:27:51,  9.03s/it]
 37%|███▋      | 340/924 [1:04:11<1:27:51,  9.03s/it]

{'loss': 0.306, 'grad_norm': 5.2880473136901855, 'learning_rate': 1.2640692640692642e-05, 'epoch': 1.1}


 38%|███▊      | 350/924 [1:05:43<1:31:18,  9.54s/it]
 38%|███▊      | 350/924 [1:05:43<1:31:18,  9.54s/it]

{'loss': 0.2435, 'grad_norm': 4.1484599113464355, 'learning_rate': 1.2424242424242425e-05, 'epoch': 1.14}


 39%|███▉      | 360/924 [1:07:37<1:51:07, 11.82s/it]
 39%|███▉      | 360/924 [1:07:37<1:51:07, 11.82s/it]

{'loss': 0.3478, 'grad_norm': 2.597181797027588, 'learning_rate': 1.2207792207792208e-05, 'epoch': 1.17}


 40%|████      | 370/924 [1:09:31<1:44:23, 11.31s/it]
 40%|████      | 370/924 [1:09:31<1:44:23, 11.31s/it]

{'loss': 0.4911, 'grad_norm': 12.529495239257812, 'learning_rate': 1.1991341991341991e-05, 'epoch': 1.2}


 41%|████      | 380/924 [1:11:15<1:38:25, 10.85s/it]
 41%|████      | 380/924 [1:11:15<1:38:25, 10.85s/it]

{'loss': 0.3075, 'grad_norm': 23.992582321166992, 'learning_rate': 1.1774891774891776e-05, 'epoch': 1.23}


 42%|████▏     | 390/924 [1:13:06<1:34:27, 10.61s/it]
 42%|████▏     | 390/924 [1:13:06<1:34:27, 10.61s/it]

{'loss': 0.2601, 'grad_norm': 1.6189945936203003, 'learning_rate': 1.155844155844156e-05, 'epoch': 1.27}


 43%|████▎     | 400/924 [1:14:44<1:24:02,  9.62s/it]
 43%|████▎     | 400/924 [1:14:44<1:24:02,  9.62s/it]

{'loss': 0.3418, 'grad_norm': 7.6308112144470215, 'learning_rate': 1.1341991341991343e-05, 'epoch': 1.3}


 44%|████▍     | 410/924 [1:16:20<1:19:47,  9.31s/it]
 44%|████▍     | 410/924 [1:16:20<1:19:47,  9.31s/it]

{'loss': 0.2496, 'grad_norm': 1.8456485271453857, 'learning_rate': 1.1125541125541126e-05, 'epoch': 1.33}


 45%|████▌     | 420/924 [1:17:58<1:18:42,  9.37s/it]
 45%|████▌     | 420/924 [1:17:58<1:18:42,  9.37s/it]

{'loss': 0.2602, 'grad_norm': 16.959470748901367, 'learning_rate': 1.0909090909090909e-05, 'epoch': 1.36}


 47%|████▋     | 430/924 [1:19:33<1:19:49,  9.70s/it]
 47%|████▋     | 430/924 [1:19:33<1:19:49,  9.70s/it]

{'loss': 0.3454, 'grad_norm': 5.860401153564453, 'learning_rate': 1.0692640692640694e-05, 'epoch': 1.4}


 48%|████▊     | 440/924 [1:21:09<1:18:11,  9.69s/it]
 48%|████▊     | 440/924 [1:21:09<1:18:11,  9.69s/it]

{'loss': 0.3706, 'grad_norm': 9.878073692321777, 'learning_rate': 1.0476190476190477e-05, 'epoch': 1.43}


 49%|████▊     | 450/924 [1:22:49<1:20:55, 10.24s/it]
 49%|████▊     | 450/924 [1:22:49<1:20:55, 10.24s/it]

{'loss': 0.2983, 'grad_norm': 10.509800910949707, 'learning_rate': 1.025974025974026e-05, 'epoch': 1.46}


 50%|████▉     | 460/924 [1:24:30<1:15:13,  9.73s/it]
 50%|████▉     | 460/924 [1:24:30<1:15:13,  9.73s/it]

{'loss': 0.2166, 'grad_norm': 28.700563430786133, 'learning_rate': 1.0043290043290043e-05, 'epoch': 1.49}


 51%|█████     | 470/924 [1:26:24<1:28:36, 11.71s/it]
 51%|█████     | 470/924 [1:26:24<1:28:36, 11.71s/it]

{'loss': 0.283, 'grad_norm': 8.022814750671387, 'learning_rate': 9.826839826839828e-06, 'epoch': 1.53}


 52%|█████▏    | 480/924 [1:28:12<1:18:23, 10.59s/it]
 52%|█████▏    | 480/924 [1:28:12<1:18:23, 10.59s/it]

{'loss': 0.2473, 'grad_norm': 2.5427064895629883, 'learning_rate': 9.610389610389611e-06, 'epoch': 1.56}


 53%|█████▎    | 490/924 [1:29:54<1:16:58, 10.64s/it]
 53%|█████▎    | 490/924 [1:29:54<1:16:58, 10.64s/it]

{'loss': 0.1985, 'grad_norm': 1.299401879310608, 'learning_rate': 9.393939393939396e-06, 'epoch': 1.59}


 54%|█████▍    | 500/924 [1:31:40<1:14:31, 10.55s/it]
 54%|█████▍    | 500/924 [1:31:40<1:14:31, 10.55s/it]

{'loss': 0.2876, 'grad_norm': 11.745614051818848, 'learning_rate': 9.177489177489179e-06, 'epoch': 1.62}


 55%|█████▌    | 510/924 [1:33:28<1:14:00, 10.73s/it]
 55%|█████▌    | 510/924 [1:33:28<1:14:00, 10.73s/it]

{'loss': 0.1578, 'grad_norm': 0.45030954480171204, 'learning_rate': 8.96103896103896e-06, 'epoch': 1.66}


 56%|█████▋    | 520/924 [1:35:19<1:19:39, 11.83s/it]
 56%|█████▋    | 520/924 [1:35:19<1:19:39, 11.83s/it]

{'loss': 0.188, 'grad_norm': 0.4139750897884369, 'learning_rate': 8.744588744588745e-06, 'epoch': 1.69}


 57%|█████▋    | 530/924 [1:36:59<1:04:21,  9.80s/it]
 57%|█████▋    | 530/924 [1:36:59<1:04:21,  9.80s/it]

{'loss': 0.2385, 'grad_norm': 16.035444259643555, 'learning_rate': 8.528138528138529e-06, 'epoch': 1.72}


 58%|█████▊    | 540/924 [1:38:37<1:08:18, 10.67s/it]
 58%|█████▊    | 540/924 [1:38:37<1:08:18, 10.67s/it]

{'loss': 0.2855, 'grad_norm': 0.3678413927555084, 'learning_rate': 8.311688311688313e-06, 'epoch': 1.75}


 60%|█████▉    | 550/924 [1:40:13<1:02:31, 10.03s/it]
 60%|█████▉    | 550/924 [1:40:13<1:02:31, 10.03s/it]

{'loss': 0.1642, 'grad_norm': 1.690521240234375, 'learning_rate': 8.095238095238097e-06, 'epoch': 1.79}


 61%|██████    | 560/924 [1:42:03<1:05:56, 10.87s/it]
 61%|██████    | 560/924 [1:42:03<1:05:56, 10.87s/it]

{'loss': 0.1545, 'grad_norm': 28.68500328063965, 'learning_rate': 7.87878787878788e-06, 'epoch': 1.82}


 62%|██████▏   | 570/924 [1:43:49<1:04:02, 10.85s/it]
 62%|██████▏   | 570/924 [1:43:49<1:04:02, 10.85s/it]

{'loss': 0.2804, 'grad_norm': 21.41816520690918, 'learning_rate': 7.662337662337663e-06, 'epoch': 1.85}


 63%|██████▎   | 580/924 [1:45:37<1:03:35, 11.09s/it]
 63%|██████▎   | 580/924 [1:45:37<1:03:35, 11.09s/it]

{'loss': 0.3821, 'grad_norm': 11.347739219665527, 'learning_rate': 7.445887445887446e-06, 'epoch': 1.88}


 64%|██████▍   | 590/924 [1:47:23<56:18, 10.11s/it]  
 64%|██████▍   | 590/924 [1:47:23<56:18, 10.11s/it]

{'loss': 0.1324, 'grad_norm': 1.196947455406189, 'learning_rate': 7.229437229437229e-06, 'epoch': 1.92}


 65%|██████▍   | 600/924 [1:49:03<52:17,  9.68s/it]
 65%|██████▍   | 600/924 [1:49:03<52:17,  9.68s/it]

{'loss': 0.1691, 'grad_norm': 3.2920584678649902, 'learning_rate': 7.012987012987014e-06, 'epoch': 1.95}


 66%|██████▌   | 610/924 [1:50:49<55:15, 10.56s/it]
 66%|██████▌   | 610/924 [1:50:49<55:15, 10.56s/it]

{'loss': 0.155, 'grad_norm': 8.252318382263184, 'learning_rate': 6.796536796536797e-06, 'epoch': 1.98}


 67%|██████▋   | 616/924 [1:51:45<45:37,  8.89s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

KeyboardInterrupt: 