In [130]:
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split

# ¿Que dicen los LLM de nuestros datos?

Primero importaremos los datos y le aplicaremos la función de rating

In [131]:
# leemos el csv
df = pd.read_csv("threads.csv")
df = df.drop(["source", "review_date"], axis=1)
# change the column names to match the HuggingFace dataset format
df = df.rename(columns={"review": "text", "rating": "label"})

# función para convertir el rating a palabras
def ratingTransform(rating):
    if rating <= 2:
        return 0
    elif rating <= 4:
        return 1
    else:
        return 2

# Transform the ratings to positive, neutral, and negative
df["label"] = df["label"].apply(ratingTransform)

train, test = train_test_split(df, test_size=0.33)

# guardamos los datos en un csv
train.to_csv("train.csv", index=False)
test.to_csv("test.csv", index=False)

# cargamos los datos en un dataset de huggingface
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})

Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 2000.62it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 154.26it/s]
Generating train split: 22049 examples [00:00, 485786.82 examples/s]
Generating test split: 10861 examples [00:00, 473488.57 examples/s]


## Usando Tokenizers

Los tokenizers son el proceso de convertir una secuencia de texto en una secuencia de tokens (números que hacen referencia a palabras).

In [132]:
# Entrenemos un modelo DistilBERT
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

#max_length = 512
# Tokenizamos los datos
def tokenize_data(example):
    return tokenizer(example['review_description'], padding='max_length', truncation=True)

# tokenizamos los datos
dataset = dataset.map(tokenize_data, batched=True)

Map: 100%|██████████| 22049/22049 [00:03<00:00, 5586.79 examples/s]
Map: 100%|██████████| 10861/10861 [00:01<00:00, 5812.50 examples/s]


## Cargamos el modelo DistilBERT

DistilBERT es un modelo de BERT que es más pequeño y rápido, pero que mantiene una precisión similar a la de BERT.

In [133]:
from transformers import AutoModelForSequenceClassification, TrainingArguments

# Cargamos un modelo pre-entrenado
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3)

batch_size = 8
number_of_epochs = 7
logging_steps = len(dataset['train']) // batch_size
steps = (len(dataset['train']) // batch_size) * number_of_epochs
warmup_steps = int(0.2*steps)

training_args = TrainingArguments(
                                  num_train_epochs=number_of_epochs, 
                                  load_best_model_at_end=True,
                                  evaluation_strategy='steps', 
                                  save_strategy='steps',
                                  learning_rate=2e-5,
                                  logging_steps=logging_steps,
                                  warmup_steps= warmup_steps,
                                  save_steps=1000,
                                  eval_steps=500,
                                  output_dir="fine-tuned-distilbert-base-uncased"
                                  )

# shuffle the dataset
train_dataset = dataset['train'].shuffle(seed=10) 
eval_dataset = dataset['test'].shuffle(seed=10)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Fine-tuning de DistilBERT

El fine-tuning es un proceso de entrenamiento en el que se usa un modelo pre-entrenado y se ajusta para que se adapte a los datos específicos del problema. En este caso, usaremos el modelo pre-entrenado de DistilBERT y lo ajustaremos para que se adapte a nuestros datos de entrenamiento.

In [134]:
from transformers import Trainer
import numpy as np
from datasets import load_metric

# Funcion para poder computar las metricas
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Cargar las metricas
    accuracy_metric = load_metric("accuracy")
    f1_metric = load_metric("f1")
    precision_metric = load_metric("precision")
    recall_metric = load_metric("recall")
    
    # Calcular las metricas
    accuracy  = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    f1        = f1_metric.compute(predictions=predictions, references=labels, average="weighted")["f1"]
    precision = precision_metric.compute(predictions=predictions, references=labels, average="weighted")["precision"]
    recall    = recall_metric.compute(predictions=predictions, references=labels, average="weighted")["recall"]

    metrics = {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

    return metrics

""" prev func
#Función para computar las metricas
def compute_metrics(eval_pred):
    # load the metrics to use
    load_accuracy = load_metric("accuracy")
    load_f1 = load_metric("f1")

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # calculate the mertic using the predicted and true value 
    accuracy = load_accuracy.compute(predictions=predictions, references=labels)
    f1 = load_f1.compute(predictions=predictions, references=labels, average="weighted")
    return {"accuracy": accuracy, "f1score": f1}
"""

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

  3%|▎         | 500/19299 [01:41<1:03:34,  4.93it/s]
Downloading builder script: 7.55kB [00:00, ?B/s]                       

Downloading builder script: 7.38kB [00:00, ?B/s]                       
  _warn_prf(average, modifier, msg_start, len(result))

  3%|▎         | 500/19299 [03:20<1:03:34,  4.93it/s]

{'eval_loss': 0.8851959705352783, 'eval_accuracy': 0.6289476107172451, 'eval_f1': 0.5639379312396521, 'eval_precision': 0.5150621330237951, 'eval_recall': 0.6289476107172451, 'eval_runtime': 98.789, 'eval_samples_per_second': 109.941, 'eval_steps_per_second': 13.746, 'epoch': 0.18}


                                                       
  5%|▌         | 1000/19299 [06:40<1:01:54,  4.93it/s]

{'eval_loss': 0.6858401894569397, 'eval_accuracy': 0.7304115643126784, 'eval_f1': 0.698458599270395, 'eval_precision': 0.7078432111473476, 'eval_recall': 0.7304115643126784, 'eval_runtime': 98.1154, 'eval_samples_per_second': 110.696, 'eval_steps_per_second': 13.841, 'epoch': 0.36}


                                                        
  8%|▊         | 1500/19299 [10:02<59:37,  4.97it/s]

{'eval_loss': 0.6718159914016724, 'eval_accuracy': 0.7404474726084155, 'eval_f1': 0.7173389678974019, 'eval_precision': 0.7203884366629909, 'eval_recall': 0.7404474726084155, 'eval_runtime': 99.8571, 'eval_samples_per_second': 108.765, 'eval_steps_per_second': 13.599, 'epoch': 0.54}


                                                        
 10%|█         | 2000/19299 [13:21<58:11,  4.96it/s]

{'eval_loss': 0.6772379279136658, 'eval_accuracy': 0.7392505294171807, 'eval_f1': 0.7216436501170218, 'eval_precision': 0.7310111879554647, 'eval_recall': 0.7392505294171807, 'eval_runtime': 97.5601, 'eval_samples_per_second': 111.326, 'eval_steps_per_second': 13.92, 'epoch': 0.73}


                                                        
 13%|█▎        | 2500/19299 [16:43<56:26,  4.96it/s]

{'eval_loss': 0.6630202531814575, 'eval_accuracy': 0.7502071632446368, 'eval_f1': 0.7335579360264689, 'eval_precision': 0.7397891768972807, 'eval_recall': 0.7502071632446368, 'eval_runtime': 100.0544, 'eval_samples_per_second': 108.551, 'eval_steps_per_second': 13.573, 'epoch': 0.91}


 14%|█▍        | 2756/19299 [17:40<1:00:30,  4.56it/s]  

{'loss': 0.7338, 'learning_rate': 1.4287195438050804e-05, 'epoch': 1.0}


                                                      
 16%|█▌        | 3000/19299 [20:24<1:03:23,  4.28it/s]

{'eval_loss': 0.6403883695602417, 'eval_accuracy': 0.7593223460086548, 'eval_f1': 0.7438053636788539, 'eval_precision': 0.7436633636039628, 'eval_recall': 0.7593223460086548, 'eval_runtime': 111.031, 'eval_samples_per_second': 97.82, 'eval_steps_per_second': 12.231, 'epoch': 1.09}


                                                        
 18%|█▊        | 3500/19299 [24:10<1:00:46,  4.33it/s]

{'eval_loss': 0.6695029735565186, 'eval_accuracy': 0.7571126047325293, 'eval_f1': 0.7447369318088908, 'eval_precision': 0.7420504388876412, 'eval_recall': 0.7571126047325293, 'eval_runtime': 112.0429, 'eval_samples_per_second': 96.936, 'eval_steps_per_second': 12.12, 'epoch': 1.27}


                                                        
 21%|██        | 4000/19299 [27:35<54:06,  4.71it/s]

{'eval_loss': 0.638457715511322, 'eval_accuracy': 0.7639259736672498, 'eval_f1': 0.7514273856034948, 'eval_precision': 0.7494562038253741, 'eval_recall': 0.7639259736672498, 'eval_runtime': 99.8175, 'eval_samples_per_second': 108.809, 'eval_steps_per_second': 13.605, 'epoch': 1.45}


                                                        
 23%|██▎       | 4500/19299 [30:54<49:24,  4.99it/s]

{'eval_loss': 0.6414019465446472, 'eval_accuracy': 0.765583279624344, 'eval_f1': 0.7456701024900905, 'eval_precision': 0.7513022317167645, 'eval_recall': 0.765583279624344, 'eval_runtime': 97.9402, 'eval_samples_per_second': 110.894, 'eval_steps_per_second': 13.866, 'epoch': 1.63}


                                                        
 26%|██▌       | 5000/19299 [34:11<47:46,  4.99it/s]

{'eval_loss': 0.6400208473205566, 'eval_accuracy': 0.7613479421784366, 'eval_f1': 0.7372532583644794, 'eval_precision': 0.7448736139400357, 'eval_recall': 0.7613479421784366, 'eval_runtime': 96.8468, 'eval_samples_per_second': 112.146, 'eval_steps_per_second': 14.022, 'epoch': 1.81}


                                                        
 28%|██▊       | 5500/19299 [37:30<45:53,  5.01it/s]

{'eval_loss': 0.62331622838974, 'eval_accuracy': 0.7674247306877819, 'eval_f1': 0.7491925002541671, 'eval_precision': 0.7508941820762998, 'eval_recall': 0.7674247306877819, 'eval_runtime': 97.5005, 'eval_samples_per_second': 111.394, 'eval_steps_per_second': 13.928, 'epoch': 1.99}


 29%|██▊       | 5513/19299 [37:32<2:19:00,  1.65it/s]  

{'loss': 0.5966, 'learning_rate': 1.7857651706495696e-05, 'epoch': 2.0}


                                                      
 31%|███       | 6000/19299 [40:46<44:10,  5.02it/s]

{'eval_loss': 0.70716792345047, 'eval_accuracy': 0.7651229168584844, 'eval_f1': 0.7470067444787822, 'eval_precision': 0.7471164889197514, 'eval_recall': 0.7651229168584844, 'eval_runtime': 96.663, 'eval_samples_per_second': 112.359, 'eval_steps_per_second': 14.049, 'epoch': 2.18}


                                                        
 34%|███▎      | 6500/19299 [44:05<42:36,  5.01it/s]

{'eval_loss': 0.7206271886825562, 'eval_accuracy': 0.7553632262222632, 'eval_f1': 0.7516016694969326, 'eval_precision': 0.7493531789368872, 'eval_recall': 0.7553632262222632, 'eval_runtime': 97.6827, 'eval_samples_per_second': 111.186, 'eval_steps_per_second': 13.902, 'epoch': 2.36}


                                                        
 36%|███▋      | 7000/19299 [47:22<40:58,  5.00it/s]

{'eval_loss': 0.694503903388977, 'eval_accuracy': 0.7666881502624068, 'eval_f1': 0.7548450403098814, 'eval_precision': 0.7545004757801279, 'eval_recall': 0.7666881502624068, 'eval_runtime': 96.679, 'eval_samples_per_second': 112.341, 'eval_steps_per_second': 14.046, 'epoch': 2.54}


                                                        
 39%|███▉      | 7500/19299 [50:40<39:17,  5.00it/s]

{'eval_loss': 0.7029852271080017, 'eval_accuracy': 0.7680692385599853, 'eval_f1': 0.7414228440033869, 'eval_precision': 0.751428291955373, 'eval_recall': 0.7680692385599853, 'eval_runtime': 97.6406, 'eval_samples_per_second': 111.234, 'eval_steps_per_second': 13.908, 'epoch': 2.72}


                                                       
 41%|████▏     | 8000/19299 [53:57<37:45,  4.99it/s]

{'eval_loss': 0.7050660252571106, 'eval_accuracy': 0.7540742104778565, 'eval_f1': 0.750821922293519, 'eval_precision': 0.748656268677426, 'eval_recall': 0.7540742104778565, 'eval_runtime': 96.5317, 'eval_samples_per_second': 112.512, 'eval_steps_per_second': 14.068, 'epoch': 2.9}


 43%|████▎     | 8269/19299 [54:52<36:54,  4.98it/s]   

{'loss': 0.5064, 'learning_rate': 1.4287934719253935e-05, 'epoch': 3.0}


                                                    
 44%|████▍     | 8500/19299 [57:16<35:58,  5.00it/s]

{'eval_loss': 0.7729307413101196, 'eval_accuracy': 0.7663198600497192, 'eval_f1': 0.7521354107210186, 'eval_precision': 0.7512750709969251, 'eval_recall': 0.7663198600497192, 'eval_runtime': 97.4658, 'eval_samples_per_second': 111.434, 'eval_steps_per_second': 13.933, 'epoch': 3.08}


                                                       
 47%|████▋     | 9000/19299 [1:00:33<34:24,  4.99it/s]

{'eval_loss': 0.7848901748657227, 'eval_accuracy': 0.7526010496271062, 'eval_f1': 0.7500519595720716, 'eval_precision': 0.7489956488216549, 'eval_recall': 0.7526010496271062, 'eval_runtime': 96.5986, 'eval_samples_per_second': 112.434, 'eval_steps_per_second': 14.058, 'epoch': 3.26}


                                                         
 49%|████▉     | 9500/19299 [1:03:51<32:35,  5.01it/s]

{'eval_loss': 0.799060583114624, 'eval_accuracy': 0.7590461283491391, 'eval_f1': 0.7458676307318718, 'eval_precision': 0.7458613596477447, 'eval_recall': 0.7590461283491391, 'eval_runtime': 97.55, 'eval_samples_per_second': 111.338, 'eval_steps_per_second': 13.921, 'epoch': 3.45}


                                                         
 52%|█████▏    | 10000/19299 [1:07:06<31:00,  5.00it/s]

{'eval_loss': 0.8517731428146362, 'eval_accuracy': 0.7485498572875425, 'eval_f1': 0.7445785489361729, 'eval_precision': 0.7414986707029605, 'eval_recall': 0.7485498572875425, 'eval_runtime': 94.8247, 'eval_samples_per_second': 114.538, 'eval_steps_per_second': 14.321, 'epoch': 3.63}


                                                          
 54%|█████▍    | 10500/19299 [1:10:19<28:27,  5.15it/s]

{'eval_loss': 0.8575529456138611, 'eval_accuracy': 0.7580333302642482, 'eval_f1': 0.7490732464038934, 'eval_precision': 0.7451016982465196, 'eval_recall': 0.7580333302642482, 'eval_runtime': 94.8164, 'eval_samples_per_second': 114.548, 'eval_steps_per_second': 14.322, 'epoch': 3.81}


                                                          
 57%|█████▋    | 11000/19299 [1:13:31<26:49,  5.16it/s]

{'eval_loss': 0.9002107977867126, 'eval_accuracy': 0.7526010496271062, 'eval_f1': 0.748532447791032, 'eval_precision': 0.7466476610161179, 'eval_recall': 0.7526010496271062, 'eval_runtime': 94.81, 'eval_samples_per_second': 114.555, 'eval_steps_per_second': 14.323, 'epoch': 3.99}


 57%|█████▋    | 11024/19299 [1:13:37<28:43,  4.80it/s]   

{'loss': 0.4218, 'learning_rate': 1.0718217732012177e-05, 'epoch': 4.0}


                                                       
 60%|█████▉    | 11500/19299 [1:16:49<26:02,  4.99it/s]

{'eval_loss': 1.0175721645355225, 'eval_accuracy': 0.7610717245189209, 'eval_f1': 0.7426393370375057, 'eval_precision': 0.7424135798612074, 'eval_recall': 0.7610717245189209, 'eval_runtime': 97.7039, 'eval_samples_per_second': 111.162, 'eval_steps_per_second': 13.899, 'epoch': 4.17}


                                                          
 62%|██████▏   | 12000/19299 [1:20:07<24:23,  4.99it/s]

{'eval_loss': 0.9951866865158081, 'eval_accuracy': 0.7549949360095756, 'eval_f1': 0.7466203620737191, 'eval_precision': 0.7425736431595161, 'eval_recall': 0.7549949360095756, 'eval_runtime': 96.8985, 'eval_samples_per_second': 112.086, 'eval_steps_per_second': 14.015, 'epoch': 4.35}


                                                          
 65%|██████▍   | 12500/19299 [1:23:26<22:38,  5.00it/s]

{'eval_loss': 1.0549551248550415, 'eval_accuracy': 0.7456956081392137, 'eval_f1': 0.7438484850784487, 'eval_precision': 0.7424502874471023, 'eval_recall': 0.7456956081392137, 'eval_runtime': 98.03, 'eval_samples_per_second': 110.793, 'eval_steps_per_second': 13.853, 'epoch': 4.53}


                                                          
 67%|██████▋   | 13000/19299 [1:26:43<20:53,  5.03it/s]

{'eval_loss': 1.0537890195846558, 'eval_accuracy': 0.7537059202651689, 'eval_f1': 0.7465079568495879, 'eval_precision': 0.7423083770524582, 'eval_recall': 0.7537059202651689, 'eval_runtime': 97.0683, 'eval_samples_per_second': 111.89, 'eval_steps_per_second': 13.99, 'epoch': 4.72}


                                                          
 70%|██████▉   | 13500/19299 [1:30:09<19:23,  4.98it/s]

{'eval_loss': 0.9979603886604309, 'eval_accuracy': 0.7548107909032318, 'eval_f1': 0.7474661432204066, 'eval_precision': 0.7435532749964809, 'eval_recall': 0.7548107909032318, 'eval_runtime': 97.7599, 'eval_samples_per_second': 111.099, 'eval_steps_per_second': 13.891, 'epoch': 4.9}


 71%|███████▏  | 13780/19299 [1:31:05<18:18,  5.03it/s]   

{'loss': 0.3548, 'learning_rate': 7.148500744770417e-06, 'epoch': 5.0}


                                                       
 73%|███████▎  | 14000/19299 [1:33:26<17:40,  5.00it/s]

{'eval_loss': 1.1180205345153809, 'eval_accuracy': 0.7526931221802781, 'eval_f1': 0.7457154064528313, 'eval_precision': 0.7414120939724116, 'eval_recall': 0.7526931221802781, 'eval_runtime': 96.9141, 'eval_samples_per_second': 112.068, 'eval_steps_per_second': 14.012, 'epoch': 5.08}


                                                          
 75%|███████▌  | 14500/19299 [1:36:45<16:01,  4.99it/s]

{'eval_loss': 1.150274395942688, 'eval_accuracy': 0.7461559709050732, 'eval_f1': 0.7434284467711635, 'eval_precision': 0.7433543265803839, 'eval_recall': 0.7461559709050732, 'eval_runtime': 98.1309, 'eval_samples_per_second': 110.679, 'eval_steps_per_second': 13.839, 'epoch': 5.26}


                                                          
 78%|███████▊  | 15000/19299 [1:40:03<14:24,  4.97it/s]

{'eval_loss': 1.2045056819915771, 'eval_accuracy': 0.7490102200534021, 'eval_f1': 0.744623005237226, 'eval_precision': 0.7414781591053871, 'eval_recall': 0.7490102200534021, 'eval_runtime': 96.8908, 'eval_samples_per_second': 112.095, 'eval_steps_per_second': 14.016, 'epoch': 5.44}


                                                          
 80%|████████  | 15500/19299 [1:43:22<12:39,  5.00it/s]

{'eval_loss': 1.1370725631713867, 'eval_accuracy': 0.7560998066476383, 'eval_f1': 0.7446041315527394, 'eval_precision': 0.7407282098329865, 'eval_recall': 0.7560998066476383, 'eval_runtime': 98.0222, 'eval_samples_per_second': 110.801, 'eval_steps_per_second': 13.854, 'epoch': 5.62}


                                                          
 83%|████████▎ | 16000/19299 [1:46:39<10:55,  5.04it/s]

{'eval_loss': 1.219282627105713, 'eval_accuracy': 0.7478132768621674, 'eval_f1': 0.7426457963598913, 'eval_precision': 0.7393485736711732, 'eval_recall': 0.7478132768621674, 'eval_runtime': 96.8479, 'eval_samples_per_second': 112.145, 'eval_steps_per_second': 14.022, 'epoch': 5.8}


                                                          
 85%|████████▌ | 16500/19299 [1:49:58<09:20,  4.99it/s]

{'eval_loss': 1.194361686706543, 'eval_accuracy': 0.7525089770739343, 'eval_f1': 0.746429163653703, 'eval_precision': 0.743947939053839, 'eval_recall': 0.7525089770739343, 'eval_runtime': 97.8638, 'eval_samples_per_second': 110.981, 'eval_steps_per_second': 13.876, 'epoch': 5.98}


 86%|████████▌ | 16537/19299 [1:50:05<09:11,  5.00it/s]   

{'loss': 0.2918, 'learning_rate': 3.5787837575286576e-06, 'epoch': 6.0}


                                                       
 88%|████████▊ | 17000/19299 [1:53:15<07:42,  4.97it/s]

{'eval_loss': 1.242308497428894, 'eval_accuracy': 0.7458797532455576, 'eval_f1': 0.7410606739087273, 'eval_precision': 0.7380768881910048, 'eval_recall': 0.7458797532455576, 'eval_runtime': 97.0247, 'eval_samples_per_second': 111.941, 'eval_steps_per_second': 13.996, 'epoch': 6.17}


                                                          
 91%|█████████ | 17500/19299 [1:56:34<06:02,  4.97it/s]

{'eval_loss': 1.2791359424591064, 'eval_accuracy': 0.7495626553724335, 'eval_f1': 0.7417198256595304, 'eval_precision': 0.7371932785666107, 'eval_recall': 0.7495626553724335, 'eval_runtime': 97.9746, 'eval_samples_per_second': 110.855, 'eval_steps_per_second': 13.861, 'epoch': 6.35}


                                                          
 93%|█████████▎| 18000/19299 [1:59:54<04:21,  4.97it/s]

{'eval_loss': 1.2837837934494019, 'eval_accuracy': 0.7489181475002302, 'eval_f1': 0.7425067013595325, 'eval_precision': 0.738398064891284, 'eval_recall': 0.7489181475002302, 'eval_runtime': 99.4784, 'eval_samples_per_second': 109.18, 'eval_steps_per_second': 13.651, 'epoch': 6.53}


                                                          
 96%|█████████▌| 18500/19299 [2:03:18<02:54,  4.57it/s]

{'eval_loss': 1.306471347808838, 'eval_accuracy': 0.7440383021821195, 'eval_f1': 0.7405255652268454, 'eval_precision': 0.7382415913713938, 'eval_recall': 0.7440383021821195, 'eval_runtime': 101.3549, 'eval_samples_per_second': 107.158, 'eval_steps_per_second': 13.398, 'epoch': 6.71}


                                                         
 98%|█████████▊| 19000/19299 [2:06:41<01:00,  4.92it/s]

{'eval_loss': 1.286392092704773, 'eval_accuracy': 0.749102292606574, 'eval_f1': 0.7431863611860364, 'eval_precision': 0.7400404821034527, 'eval_recall': 0.749102292606574, 'eval_runtime': 99.8074, 'eval_samples_per_second': 108.82, 'eval_steps_per_second': 13.606, 'epoch': 6.89}


100%|█████████▉| 19292/19299 [2:07:42<00:01,  4.93it/s]  

{'loss': 0.2545, 'learning_rate': 9.066770286898517e-09, 'epoch': 7.0}


100%|██████████| 19299/19299 [2:07:44<00:00,  2.52it/s]

{'train_runtime': 7664.0357, 'train_samples_per_second': 20.139, 'train_steps_per_second': 2.518, 'train_loss': 0.45131043809900334, 'epoch': 7.0}





TrainOutput(global_step=19299, training_loss=0.45131043809900334, metrics={'train_runtime': 7664.0357, 'train_samples_per_second': 20.139, 'train_steps_per_second': 2.518, 'train_loss': 0.45131043809900334, 'epoch': 7.0})

# Evaluación del modelo

Ahora que hemos entrenado nuestro modelo, podemos evaluarlo en el conjunto de datos de prueba, para ver cómo se desempeña en datos que nunca ha visto antes.

In [135]:
trainer.evaluate()

100%|██████████| 1358/1358 [01:42<00:00, 13.31it/s]


{'eval_loss': 0.638457715511322,
 'eval_accuracy': 0.7639259736672498,
 'eval_f1': 0.7514273856034948,
 'eval_precision': 0.7494562038253741,
 'eval_recall': 0.7639259736672498,
 'eval_runtime': 102.3634,
 'eval_samples_per_second': 106.102,
 'eval_steps_per_second': 13.266,
 'epoch': 7.0}

# Interpretación del modelo

Ahora veremos que palabras son las que más influyen en la predicción del modelo.

In [169]:
# Interpretabilidad del modelo
from transformers_interpret import SequenceClassificationExplainer

# load the explainer
explainer = SequenceClassificationExplainer(
    model,
    tokenizer
)

sample = "best app! I really like it. better than twitter"
explanation = explainer(sample)
print(f"\nsample: \"{sample}\" -> Debería ser Positivo (Label 2)")
print("Visualization: ")
explainer.visualize(sample)

sample = "its neither good nor bad."
explanation = explainer(sample)
print(f"\nsample: \"{sample}\" -> Debería ser Neutral (Label 1)")
print("Visualization: ")
explainer.visualize(sample)

sample = "I recently installed this app due to its popularity, but my initial impression is quite negative. I haven't found anything that I like about it."
explanation = explainer(sample)
print(f"\nsample: \"{sample}\" -> Debería ser Negativo (Label 0)")
print("Visualization: ")
explainer.visualize(sample)

""


sample: "best app! I really like it. better than twitter" -> Debería ser Positivo (Label 2)
Visualization: 


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,LABEL_2 (0.87),LABEL_2,1.43,[CLS] best app ! i really like it . better than twitter [SEP]
,,,,



sample: "its neither good nor bad." -> Debería ser Neutral (Label 1)
Visualization: 


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,LABEL_1 (0.44),LABEL_1,0.87,[CLS] its neither good nor bad . [SEP]
,,,,



sample: "I recently installed this app due to its popularity, but my initial impression is quite negative. I haven't found anything that I like about it." -> Debería ser Negativo (Label 0)
Visualization: 


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,LABEL_0 (0.46),LABEL_0,0.62,"[CLS] i recently installed this app due to its popularity , but my initial impression is quite negative . i haven ' t found anything that i like about it . [SEP]"
,,,,


''