In [15]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer,BertModel

In [2]:
# Load the saved model state dict from disk, mapping every tensor onto the CPU.
model_path = "models/spanish_bert_cased.pth"
# torch.load is pickle-based; weights_only=True restricts unpickling to tensors
# and plain containers, so a tampered checkpoint cannot run arbitrary code.
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)

In [100]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
# NOTE(review): none of the cells shown here move the model to this device — confirm it is still needed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [16]:
# Checkpoint identifier handed to from_pretrained() for both the model and the tokenizer.
bert_model_name = "dccuchile/bert-base-spanish-wwm-cased"
# Alternative checkpoint, kept for reference:
# bert_model_name = "bert-base-uncased"

In [37]:
# Create a new BertForSequenceClassification from the pretrained checkpoint.
# The classifier (and pooler) weights are not in that checkpoint and are newly
# initialised, as the warning printed below reports; they get replaced when the
# saved state dict is loaded into the model in a later cell.
model = BertForSequenceClassification.from_pretrained(bert_model_name)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['classifier.weight', 'classifier.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [44]:
def model_bert_dict_cleaner(model_state_dict):
    """Build a state dict without the ``fc`` head and without ``bert.`` in the key names.

    Args:
        model_state_dict: Mapping of parameter names to values.

    Returns:
        A new dict, in the original iteration order, that omits the
        ``fc.weight`` and ``fc.bias`` entries and has every occurrence of
        ``"bert."`` removed from the remaining keys. Values are passed through
        untouched and the input mapping is not modified.
    """
    head_keys = ("fc.weight", "fc.bias")
    return {
        name.replace("bert.", ""): tensor
        for name, tensor in model_state_dict.items()
        if name not in head_keys
    }

In [53]:
def classification_bert_dict_cleaner(model_state_dict):
    """Rename the ``fc`` head entries to the ``classifier`` names.

    Args:
        model_state_dict: Mapping of parameter names to values.

    Returns:
        A new dict, in the original iteration order, where ``fc.weight`` and
        ``fc.bias`` are stored under ``classifier.weight`` and
        ``classifier.bias``; every other entry is copied unchanged. The input
        mapping is not modified.
    """
    head_renames = {
        "fc.weight": "classifier.weight",
        "fc.bias": "classifier.bias",
    }
    return {
        head_renames.get(name, name): tensor
        for name, tensor in model_state_dict.items()
    }

In [54]:
# Remap the saved keys so the fc.* head entries use the classifier.* names.
# The commented-out call below is the alternative cleaner, which drops the fc
# head and removes "bert." from the key names instead.
#model_state_clean= model_bert_dict_cleaner(model_state_dict)
model_state_clean = classification_bert_dict_cleaner(model_state_dict)

In [55]:
# Load the remapped state dict into the newly created model; the printed result
# below confirms that all keys matched.
model.load_state_dict(model_state_clean)

<All keys matched successfully>

In [56]:
# Put the model in evaluation mode before running inference, so the Dropout
# layers listed in the output below are inactive.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31002, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [97]:
def predict(text, model, tokenizer, threshold):
    """Classify a single text and return the predicted class with its probabilities.

    Args:
        text: Input string to classify.
        model: Sequence-classification model. It is called as
            ``model(**inputs)`` and must return an object exposing ``.logits``
            with one row per input text.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"`` and
            expected to return a mapping of PyTorch tensors.
        threshold: Softmax-probability cut-off for class 1. When class 1 wins
            the argmax but its probability is at or below this value, the
            prediction falls back to class 0.

    Returns:
        A ``(predicted_class, probabilities)`` tuple: the predicted class index
        as an ``int`` and the softmax probabilities of all classes as a list of
        floats.
    """
    # Tokenize the input text (truncated to the 512-token limit).
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Keep the inputs on the same device as the model so a model that was moved
    # to the GPU does not fail with a device mismatch.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # Turn the logits into class probabilities.
    logits = outputs.logits
    class_probs = torch.softmax(logits, dim=1)
    probabilities = class_probs.squeeze().tolist()

    # Pick the most likely class, then apply the cut-off to the *probability*
    # of class 1. Raw logits are unbounded, so comparing them against a
    # probability-style threshold would not reflect the model's confidence.
    predicted_class = torch.argmax(logits, dim=1).item()
    if predicted_class == 1 and class_probs[0][1].item() <= threshold:
        predicted_class = 0

    return predicted_class, probabilities

In [88]:
# Initialise the tokenizer from the same checkpoint name used for the model.
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

In [98]:
# Example text to classify.
# NOTE(review): the text is lowercased although the checkpoint name says "cased" —
# confirm this matches the preprocessing used when the saved weights were trained.
text = "El presidente es un reptiliano".lower()
# Cut-off passed to predict(), which uses it to decide whether a class-1 prediction is kept.
threshold = 0.6

# Run the prediction
predicted_class, probabilities = predict(text, model, tokenizer,threshold)

print(f"Predicted class: {predicted_class}")
print(f"Probabilities: {probabilities}")

Predicted class: 0
Probabilities: [0.4080851972103119, 0.5919148921966553]
