In [23]:
pip install pandas transformers torch scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [24]:
import pandas as pd

# Load the data
file_path = '/Users/arhan/Documents/usa/ddp/facebook_engagement_data.csv'
data = pd.read_csv(file_path)

# Display the first few rows of the dataset
print(data.head())

# Check for missing values and general information
print(data.info())


            source_id         source_name              author  \
0             reuters             Reuters   Reuters Editorial   
1     the-irish-times     The Irish Times  Eoin Burke-Kennedy   
2     the-irish-times     The Irish Times   Deirdre McQuillan   
3  al-jazeera-english  Al Jazeera English          Al Jazeera   
4            bbc-news            BBC News            BBC News   

                                               title  \
0  NTSB says Autopilot engaged in 2018 California...   
1       Unemployment falls to post-crash low of 5.2%   
2  "Louise Kennedy AW2019: Long coats, sparkling ...   
3  North Korean footballer Han joins Italian gian...   
4  UK government lawyer says proroguing parliamen...   

                                         description  \
0  "The National Transportation Safety Board said...   
1  Latest monthly figures reflect continued growt...   
2  Autumn-winter collection features designer’s g...   
3  Han is the first North Korean player in the S

In [25]:
# Print all column names in the DataFrame
print(data.columns)

Index(['source_id', 'source_name', 'author', 'title', 'description', 'url',
       'url_to_image', 'published_at', 'content', 'top_article',
       'engagement_reaction_count', 'engagement_comment_count',
       'engagement_share_count', 'engagement_comment_plugin_count'],
      dtype='object')


In [26]:
# Calculate a total engagement score by summing relevant metrics
data['total_engagement'] = data['engagement_reaction_count'] + data['engagement_comment_count'] + data['engagement_share_count'] + data['engagement_comment_plugin_count']

# Define a threshold for good engagement
threshold = data['total_engagement'].quantile(0.75)  # Using the 75th percentile as the threshold

# Create a new binary column based on this threshold
data['label'] = (data['total_engagement'] >= threshold).astype(int)

In [27]:
from transformers import BertTokenizer

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Ensure the content column is a string and replace NaN values
data['content'] = data['content'].fillna('').astype(str)

# Tokenize text
def tokenize_function(text):
    return tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Apply tokenization
data['input_ids'] = data['content'].apply(lambda x: tokenize_function(x)['input_ids'].squeeze(0))

In [28]:
from sklearn.model_selection import train_test_split

# Split data
train_texts, val_texts, train_labels, val_labels = train_test_split(data['input_ids'], data['label'], test_size=0.2)

In [29]:
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from sklearn.model_selection import train_test_split

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Function to tokenize data
def tokenize_data(data):
    return tokenizer(data, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Tokenize the entire dataset
tokenized_texts = tokenize_data(data['content'].tolist())

# Split the dataset into training and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(
    tokenized_texts['input_ids'], data['label'], test_size=0.2
)

# Ensure labels are also tensors
train_labels = torch.tensor(train_labels.values)
val_labels = torch.tensor(val_labels.values)

# Prepare datasets using the correct PyTorch dataset and DataLoader utilities
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(train_texts, train_labels)
val_dataset = TensorDataset(val_texts, val_labels)

# DataLoader setup
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [30]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

In [31]:
from sklearn.model_selection import train_test_split
import numpy as np

# Tokenize the entire dataset
tokenized_texts = tokenizer(data['content'].tolist(), padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Generate indices and split them
indices = np.arange(len(data['label']))
train_indices, val_indices = train_test_split(indices, test_size=0.2)

# Use the indices to create training and validation datasets
train_texts = {key: val[train_indices] for key, val in tokenized_texts.items()}
val_texts = {key: val[val_indices] for key, val in tokenized_texts.items()}
train_labels = data['label'].values[train_indices]
val_labels = data['label'].values[val_indices]

# Create instances of the CustomDataset
train_dataset = CustomDataset(train_texts, train_labels)
val_dataset = CustomDataset(val_texts, val_labels)

In [33]:
from transformers import BertForSequenceClassification

# Load BERT for sequence classification with two labels (good or bad engagement)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',  # Make sure this directory is accessible and writable
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="steps",
    save_strategy="steps",  # This will save the model at each logging step; consider adjusting if too frequent
    load_best_model_at_end=True,
    save_total_limit=1  # Optionally add this to limit the number of saved models
)

In [43]:
from transformers import Trainer

trainer = Trainer(
    model=model,                          # the instantiated 🤗 Transformers model to be trained
    args=training_args,                   # training arguments, defined above
    train_dataset=train_dataset,          # training dataset
    eval_dataset=val_dataset              # evaluation dataset
)

In [44]:
trainer.train()

  0%|          | 0/429 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


{'loss': 0.3761, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.07}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4884822368621826, 'eval_runtime': 2.2116, 'eval_samples_per_second': 129.317, 'eval_steps_per_second': 8.139, 'epoch': 0.07}
{'loss': 0.4179, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.14}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.448639988899231, 'eval_runtime': 2.2026, 'eval_samples_per_second': 129.849, 'eval_steps_per_second': 8.172, 'epoch': 0.14}
{'loss': 0.1636, 'learning_rate': 3e-06, 'epoch': 0.21}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4399183988571167, 'eval_runtime': 2.2119, 'eval_samples_per_second': 129.302, 'eval_steps_per_second': 8.138, 'epoch': 0.21}
{'loss': 0.4702, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.28}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.411758542060852, 'eval_runtime': 2.1996, 'eval_samples_per_second': 130.025, 'eval_steps_per_second': 8.183, 'epoch': 0.28}
{'loss': 0.1659, 'learning_rate': 5e-06, 'epoch': 0.35}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3993489742279053, 'eval_runtime': 2.1945, 'eval_samples_per_second': 130.325, 'eval_steps_per_second': 8.202, 'epoch': 0.35}
{'loss': 0.3006, 'learning_rate': 6e-06, 'epoch': 0.42}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.372585654258728, 'eval_runtime': 2.2083, 'eval_samples_per_second': 129.509, 'eval_steps_per_second': 8.151, 'epoch': 0.42}
{'loss': 0.2829, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.49}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3514454364776611, 'eval_runtime': 2.198, 'eval_samples_per_second': 130.12, 'eval_steps_per_second': 8.189, 'epoch': 0.49}
{'loss': 0.0646, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.56}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.377900242805481, 'eval_runtime': 2.2137, 'eval_samples_per_second': 129.193, 'eval_steps_per_second': 8.131, 'epoch': 0.56}
{'loss': 0.1806, 'learning_rate': 9e-06, 'epoch': 0.63}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4208205938339233, 'eval_runtime': 2.1956, 'eval_samples_per_second': 130.26, 'eval_steps_per_second': 8.198, 'epoch': 0.63}
{'loss': 0.2291, 'learning_rate': 1e-05, 'epoch': 0.7}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4350793361663818, 'eval_runtime': 2.2054, 'eval_samples_per_second': 129.679, 'eval_steps_per_second': 8.162, 'epoch': 0.7}
{'loss': 0.3837, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.77}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4039942026138306, 'eval_runtime': 2.2628, 'eval_samples_per_second': 126.391, 'eval_steps_per_second': 7.955, 'epoch': 0.77}
{'loss': 0.1562, 'learning_rate': 1.2e-05, 'epoch': 0.84}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3498789072036743, 'eval_runtime': 2.1961, 'eval_samples_per_second': 130.234, 'eval_steps_per_second': 8.197, 'epoch': 0.84}
{'loss': 0.1938, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.91}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3453428745269775, 'eval_runtime': 2.1974, 'eval_samples_per_second': 130.153, 'eval_steps_per_second': 8.191, 'epoch': 0.91}
{'loss': 0.1788, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.98}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3454093933105469, 'eval_runtime': 2.206, 'eval_samples_per_second': 129.646, 'eval_steps_per_second': 8.16, 'epoch': 0.98}
{'loss': 0.2005, 'learning_rate': 1.5e-05, 'epoch': 1.05}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4114845991134644, 'eval_runtime': 2.2108, 'eval_samples_per_second': 129.363, 'eval_steps_per_second': 8.142, 'epoch': 1.05}
{'loss': 0.1897, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.12}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.423384428024292, 'eval_runtime': 2.2012, 'eval_samples_per_second': 129.927, 'eval_steps_per_second': 8.177, 'epoch': 1.12}
{'loss': 0.1013, 'learning_rate': 1.7000000000000003e-05, 'epoch': 1.19}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.499271035194397, 'eval_runtime': 2.2084, 'eval_samples_per_second': 129.507, 'eval_steps_per_second': 8.151, 'epoch': 1.19}
{'loss': 0.312, 'learning_rate': 1.8e-05, 'epoch': 1.26}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3387330770492554, 'eval_runtime': 2.2251, 'eval_samples_per_second': 128.535, 'eval_steps_per_second': 8.09, 'epoch': 1.26}
{'loss': 0.1138, 'learning_rate': 1.9e-05, 'epoch': 1.33}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3291295766830444, 'eval_runtime': 2.1965, 'eval_samples_per_second': 130.206, 'eval_steps_per_second': 8.195, 'epoch': 1.33}
{'loss': 0.065, 'learning_rate': 2e-05, 'epoch': 1.4}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.5769582986831665, 'eval_runtime': 2.1975, 'eval_samples_per_second': 130.151, 'eval_steps_per_second': 8.191, 'epoch': 1.4}
{'loss': 0.4606, 'learning_rate': 2.1e-05, 'epoch': 1.47}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.1912544965744019, 'eval_runtime': 2.1988, 'eval_samples_per_second': 130.068, 'eval_steps_per_second': 8.186, 'epoch': 1.47}
{'loss': 0.1784, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.54}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.108616828918457, 'eval_runtime': 2.7728, 'eval_samples_per_second': 103.144, 'eval_steps_per_second': 6.492, 'epoch': 1.54}
{'loss': 0.238, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.61}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.258121132850647, 'eval_runtime': 2.2225, 'eval_samples_per_second': 128.681, 'eval_steps_per_second': 8.099, 'epoch': 1.61}
{'loss': 0.2984, 'learning_rate': 2.4e-05, 'epoch': 1.68}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.3241628408432007, 'eval_runtime': 2.2245, 'eval_samples_per_second': 128.569, 'eval_steps_per_second': 8.092, 'epoch': 1.68}
{'loss': 0.086, 'learning_rate': 2.5e-05, 'epoch': 1.75}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.5871912240982056, 'eval_runtime': 2.2197, 'eval_samples_per_second': 128.848, 'eval_steps_per_second': 8.109, 'epoch': 1.75}
{'loss': 0.2678, 'learning_rate': 2.6000000000000002e-05, 'epoch': 1.82}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4606475830078125, 'eval_runtime': 2.2655, 'eval_samples_per_second': 126.244, 'eval_steps_per_second': 7.945, 'epoch': 1.82}
{'loss': 0.2538, 'learning_rate': 2.7000000000000002e-05, 'epoch': 1.89}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4433013200759888, 'eval_runtime': 2.201, 'eval_samples_per_second': 129.944, 'eval_steps_per_second': 8.178, 'epoch': 1.89}
{'loss': 0.281, 'learning_rate': 2.8000000000000003e-05, 'epoch': 1.96}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4645161628723145, 'eval_runtime': 2.2051, 'eval_samples_per_second': 129.698, 'eval_steps_per_second': 8.163, 'epoch': 1.96}
{'loss': 0.078, 'learning_rate': 2.9e-05, 'epoch': 2.03}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.596764087677002, 'eval_runtime': 2.2044, 'eval_samples_per_second': 129.74, 'eval_steps_per_second': 8.165, 'epoch': 2.03}
{'loss': 0.0704, 'learning_rate': 3e-05, 'epoch': 2.1}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.611722469329834, 'eval_runtime': 2.2879, 'eval_samples_per_second': 125.004, 'eval_steps_per_second': 7.867, 'epoch': 2.1}
{'loss': 0.0882, 'learning_rate': 3.1e-05, 'epoch': 2.17}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.831849217414856, 'eval_runtime': 2.2187, 'eval_samples_per_second': 128.903, 'eval_steps_per_second': 8.113, 'epoch': 2.17}
{'loss': 0.0813, 'learning_rate': 3.2000000000000005e-05, 'epoch': 2.24}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.8581395149230957, 'eval_runtime': 2.2221, 'eval_samples_per_second': 128.708, 'eval_steps_per_second': 8.101, 'epoch': 2.24}
{'loss': 0.2806, 'learning_rate': 3.3e-05, 'epoch': 2.31}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.6808933019638062, 'eval_runtime': 2.2034, 'eval_samples_per_second': 129.802, 'eval_steps_per_second': 8.169, 'epoch': 2.31}
{'loss': 0.5965, 'learning_rate': 3.4000000000000007e-05, 'epoch': 2.38}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4267948865890503, 'eval_runtime': 2.1989, 'eval_samples_per_second': 130.067, 'eval_steps_per_second': 8.186, 'epoch': 2.38}
{'loss': 0.1442, 'learning_rate': 3.5e-05, 'epoch': 2.45}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4531431198120117, 'eval_runtime': 2.1986, 'eval_samples_per_second': 130.083, 'eval_steps_per_second': 8.187, 'epoch': 2.45}
{'loss': 0.1621, 'learning_rate': 3.6e-05, 'epoch': 2.52}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.5814496278762817, 'eval_runtime': 2.2405, 'eval_samples_per_second': 127.649, 'eval_steps_per_second': 8.034, 'epoch': 2.52}
{'loss': 0.3243, 'learning_rate': 3.7e-05, 'epoch': 2.59}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.1472994089126587, 'eval_runtime': 2.2055, 'eval_samples_per_second': 129.673, 'eval_steps_per_second': 8.161, 'epoch': 2.59}
{'loss': 0.3642, 'learning_rate': 3.8e-05, 'epoch': 2.66}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.0359492301940918, 'eval_runtime': 2.2076, 'eval_samples_per_second': 129.55, 'eval_steps_per_second': 8.154, 'epoch': 2.66}
{'loss': 0.2702, 'learning_rate': 3.9000000000000006e-05, 'epoch': 2.73}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.4528601169586182, 'eval_runtime': 2.2587, 'eval_samples_per_second': 126.619, 'eval_steps_per_second': 7.969, 'epoch': 2.73}
{'loss': 0.2655, 'learning_rate': 4e-05, 'epoch': 2.8}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.1228612661361694, 'eval_runtime': 2.2037, 'eval_samples_per_second': 129.779, 'eval_steps_per_second': 8.168, 'epoch': 2.8}
{'loss': 0.2922, 'learning_rate': 4.1e-05, 'epoch': 2.87}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.1578338146209717, 'eval_runtime': 2.2028, 'eval_samples_per_second': 129.833, 'eval_steps_per_second': 8.171, 'epoch': 2.87}
{'loss': 0.364, 'learning_rate': 4.2e-05, 'epoch': 2.94}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.1692448854446411, 'eval_runtime': 2.2014, 'eval_samples_per_second': 129.917, 'eval_steps_per_second': 8.177, 'epoch': 2.94}
{'train_runtime': 248.3854, 'train_samples_per_second': 13.793, 'train_steps_per_second': 1.727, 'train_loss': 0.23953617244333655, 'epoch': 3.0}


TrainOutput(global_step=429, training_loss=0.23953617244333655, metrics={'train_runtime': 248.3854, 'train_samples_per_second': 13.793, 'train_steps_per_second': 1.727, 'train_loss': 0.23953617244333655, 'epoch': 3.0})

In [45]:
eval_results = trainer.evaluate()
print(eval_results)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


  0%|          | 0/18 [00:00<?, ?it/s]

{'eval_loss': 1.0741196870803833, 'eval_runtime': 2.9514, 'eval_samples_per_second': 96.902, 'eval_steps_per_second': 6.099, 'epoch': 3.0}


In [47]:
model.save_pretrained('./results')
tokenizer.save_pretrained('./results')

('./results/tokenizer_config.json',
 './results/special_tokens_map.json',
 './results/vocab.txt',
 './results/added_tokens.json')

In [48]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Load the tokenizer and model from the saved directory
tokenizer = BertTokenizer.from_pretrained('./results')
model = BertForSequenceClassification.from_pretrained('./results')

In [49]:
def prepare_input(caption):
    # Tokenize the caption
    inputs = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    return inputs

In [50]:
def evaluate_caption(caption):
    # Prepare the input
    inputs = prepare_input(caption)
    
    # Move the inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Evaluate the model (inference)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        predictions = probabilities.argmax(dim=1).item()
    
    return "Good engagement" if predictions == 1 else "Bad engagement"

In [65]:
caption = "Bridging the gap between glam and inclusivity! Just like the cars seamlessly crossing the highway bridge, Fenty Beauty by Rihanna brings together unmatched shades for ALL skin tones. #InclusiveBeauty 🌈✨🚗"
result = evaluate_caption(caption)
print(result)

Bad engagement


In [52]:
def evaluate_caption_with_probabilities(caption):
    # Prepare the input
    inputs = prepare_input(caption)
    
    # Move the inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Evaluate the model (inference)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted_class = probabilities.argmax(dim=1).item()
        confidence = probabilities[0, predicted_class].item()

    return f"Caption classified as {'Good engagement' if predicted_class == 1 else 'Bad engagement'} with {confidence:.2f} confidence"


In [53]:
def evaluate_caption_with_top_probabilities(caption):
    # Prepare the input
    inputs = prepare_input(caption)
    
    # Move the inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Evaluate the model (inference)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        top_probabilities, top_classes = torch.topk(probabilities, 2)

    result = f"Probabilities: \n"
    for prob, cls in zip(top_probabilities[0], top_classes[0]):
        class_name = "Good engagement" if cls.item() == 1 else "Bad engagement"
        result += f"{class_name}: {prob.item():.2f}\n"
    
    return result


In [56]:
!pip install TextBlob
from textblob import TextBlob

def evaluate_caption_with_sentiment(caption):
    sentiment = TextBlob(caption).sentiment
    engagement = evaluate_caption(caption)
    return f"{engagement}. Sentiment polarity: {sentiment.polarity:.2f}, subjectivity: {sentiment.subjectivity:.2f}"


Collecting TextBlob
  Obtaining dependency information for TextBlob from https://files.pythonhosted.org/packages/02/07/5fd2945356dd839974d3a25de8a142dc37293c21315729a41e775b5f3569/textblob-0.18.0.post0-py3-none-any.whl.metadata
  Downloading textblob-0.18.0.post0-py3-none-any.whl.metadata (4.5 kB)
Downloading textblob-0.18.0.post0-py3-none-any.whl (626 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m626.3/626.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: TextBlob
Successfully installed TextBlob-0.18.0.post0


In [57]:
from textblob import TextBlob

def evaluate_caption_with_sentiment(caption):
    sentiment = TextBlob(caption).sentiment
    engagement = evaluate_caption(caption)
    return f"{engagement}. Sentiment polarity: {sentiment.polarity:.2f}, subjectivity: {sentiment.subjectivity:.2f}"


In [67]:
result = evaluate_caption_with_probabilities(caption)
print(result)

Caption classified as Bad engagement with 0.98 confidence


In [68]:
result = evaluate_caption_with_top_probabilities(caption)
print(result)

Probabilities: 
Bad engagement: 0.98
Good engagement: 0.02



In [69]:
result = evaluate_caption_with_sentiment(caption)
print(result)

Bad engagement. Sentiment polarity: 0.10, subjectivity: 0.10
