In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup
import pandas as pd
import math
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.metrics import mean_absolute_error, mean_squared_error
from tqdm import tqdm
import os
import itertools
import numpy as np

In [2]:
class TextValueDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length, column):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        text = row['generated_text']
        label = (row[self.column] - 1)/4 # Convert value to 0-1
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.float)
        }

In [3]:
# Load dataset
def load_dataset(file_path):
    if '.tsv' in file_path:
        df = pd.read_csv(file_path, sep='\t')
    else:
        df = pd.read_csv(file_path)
    return df

# Create DataLoader
def create_dataloader(df, tokenizer, max_length, batch_size, column):
    dataset = TextValueDataset(df, tokenizer, max_length, column)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [4]:
def evaluate(model, dataloader, device):
    model.eval()
    predictions, true_labels, pos_probs = [], [], []

    eval_loss = 0
    eval_loop = tqdm(dataloader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in eval_loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # print(input_ids, attention_mask, labels)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            logits = outputs.logits
            # print(logits)
            # preds = torch.argmax(logits, dim=1)
            # print(preds)
            
            loss = outputs.loss
            eval_loss += loss.item()
            eval_loop.set_postfix(loss=loss.item())
            
            predictions.extend(logits.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    # accuracy = accuracy_score(true_labels, predictions)
    # precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='weighted')
    
    mae = mean_absolute_error(true_labels, predictions)
    mse = mean_squared_error(true_labels, predictions)
    return eval_loss, np.sqrt(mse), mse, mae

In [5]:
data_dict = { 
                'model' : [],
                 'dimension' : [],
                 'RMSE' : [],
                 'MSE' : [],
                 'MAE' : []
            }

In [6]:
def get_performance(tokname, modeldir):
    max_length = 256
    batch_size = 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # tokenizer = AutoTokenizer.from_pretrained('google-t5/t5-small')
    tokenizer = AutoTokenizer.from_pretrained(tokname)
    test_df = load_dataset('../data/crowd-enVent-test.tsv')
    for model_name in os.listdir(modeldir):
        # model_name='./models/google-t5/suddenness_google-t5/'
        if '.ipy' in model_name:
            continue
        column = str.join('_', model_name.split('_')[:-1])
        print(column)
        test_loader = create_dataloader(test_df, tokenizer, max_length, batch_size, column)
        model = AutoModelForSequenceClassification.from_pretrained(os.path.join(modeldir, model_name))
        model = model.to(device)
        _, test_rmse, test_mse, test_mae = evaluate(model, test_loader, device)
        
        
        print(f'Model Name: {model_name}, Test RMSE: {test_rmse:.4f}, Test MSE: {test_mse:.4f}, Test MAE: {test_mae:.4f}')
        data_dict['model'].append(model_name.split('_')[-1])
        data_dict['dimension'].append(column)
        data_dict['RMSE'].append(test_rmse)
        data_dict['MSE'].append(test_mse)
        data_dict['MAE'].append(test_mae)
        # Cleanup
        model.to('cpu')
        del model
        torch.cuda.empty_cache()
    
    return data_dict

In [7]:
def roc_curve(model, app_dim, fpr, tpr, auc):
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='blue', label=f'ROC Curve (AUC = {auc:.2f})')
    plt.plot([0, 1], [0, 1], color='red', linestyle='--')  # Diagonal line for random guessing
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(model + ':' + app_dim + ' Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

In [8]:
mod_tok_dict = {
    'bert-base-uncased' : 'bert-base-uncased',
    'bert-large-uncased' : 'bert-large-uncased',
    'dialogpt' : 'microsoft/DialogRPT-updown',
    'distilbert-base-uncased' : 'distilbert-base-uncased',
    'google-t5' : 'google-t5/t5-small',
    'roberta-base' : 'roberta-base',
    'roberta-large': 'FacebookAI/roberta-large',
}

for key in mod_tok_dict:
    modeldir = os.path.join('./models', key)
    data_dict = get_performance(mod_tok_dict[key], modeldir)



self_control


                                                                                                                                                                                                                      

Model Name: self_control_bert-base-uncased, Test RMSE: 0.3408, Test MSE: 0.1162, Test MAE: 0.2672
urgency


                                                                                                                                                                                                                      

Model Name: urgency_bert-base-uncased, Test RMSE: 0.3582, Test MSE: 0.1283, Test MAE: 0.2891
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_bert-base-uncased, Test RMSE: 0.3501, Test MSE: 0.1226, Test MAE: 0.3006
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_bert-base-uncased, Test RMSE: 0.3406, Test MSE: 0.1160, Test MAE: 0.2709
effort


                                                                                                                                                                                                                      

Model Name: effort_bert-base-uncased, Test RMSE: 0.3307, Test MSE: 0.1093, Test MAE: 0.2668
attention


                                                                                                                                                                                                                      

Model Name: attention_bert-base-uncased, Test RMSE: 0.3110, Test MSE: 0.0967, Test MAE: 0.2578
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_bert-base-uncased, Test RMSE: 0.3552, Test MSE: 0.1262, Test MAE: 0.2892
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_bert-base-uncased, Test RMSE: 0.3560, Test MSE: 0.1268, Test MAE: 0.2966
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_bert-base-uncased, Test RMSE: 0.3491, Test MSE: 0.1218, Test MAE: 0.2987
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_bert-base-uncased, Test RMSE: 0.2678, Test MSE: 0.0717, Test MAE: 0.1918
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_bert-base-uncased, Test RMSE: 0.3555, Test MSE: 0.1264, Test MAE: 0.3082
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_bert-base-uncased, Test RMSE: 0.3210, Test MSE: 0.1031, Test MAE: 0.2596
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_bert-base-uncased, Test RMSE: 0.3019, Test MSE: 0.0911, Test MAE: 0.2065
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_bert-base-uncased, Test RMSE: 0.3377, Test MSE: 0.1140, Test MAE: 0.2613
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_bert-base-uncased, Test RMSE: 0.2806, Test MSE: 0.0787, Test MAE: 0.2174
standards


                                                                                                                                                                                                                      

Model Name: standards_bert-base-uncased, Test RMSE: 0.3444, Test MSE: 0.1186, Test MAE: 0.2837
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_bert-base-uncased, Test RMSE: 0.2937, Test MSE: 0.0862, Test MAE: 0.2284
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_bert-base-uncased, Test RMSE: 0.3355, Test MSE: 0.1126, Test MAE: 0.2817
other_control


                                                                                                                                                                                                                      

Model Name: other_control_bert-base-uncased, Test RMSE: 0.3455, Test MSE: 0.1193, Test MAE: 0.2935
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_bert-base-uncased, Test RMSE: 0.3153, Test MSE: 0.0994, Test MAE: 0.2444
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_bert-base-uncased, Test RMSE: 0.3211, Test MSE: 0.1031, Test MAE: 0.2681




goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_bert-large-uncased, Test RMSE: 0.3179, Test MSE: 0.1010, Test MAE: 0.2596
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_bert-large-uncased, Test RMSE: 0.3949, Test MSE: 0.1560, Test MAE: 0.3479
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_bert-large-uncased, Test RMSE: 0.3404, Test MSE: 0.1158, Test MAE: 0.2769
effort


                                                                                                                                                                                                                      

Model Name: effort_bert-large-uncased, Test RMSE: 0.3274, Test MSE: 0.1072, Test MAE: 0.2645
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_bert-large-uncased, Test RMSE: 0.3465, Test MSE: 0.1201, Test MAE: 0.2890
standards


                                                                                                                                                                                                                      

Model Name: standards_bert-large-uncased, Test RMSE: 0.3377, Test MSE: 0.1140, Test MAE: 0.2584
other_control


                                                                                                                                                                                                                      

Model Name: other_control_bert-large-uncased, Test RMSE: 0.3465, Test MSE: 0.1201, Test MAE: 0.2870
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_bert-large-uncased, Test RMSE: 0.2623, Test MSE: 0.0688, Test MAE: 0.1769
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_bert-large-uncased, Test RMSE: 0.3542, Test MSE: 0.1255, Test MAE: 0.3058
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_bert-large-uncased, Test RMSE: 0.2951, Test MSE: 0.0871, Test MAE: 0.2102
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_bert-large-uncased, Test RMSE: 0.2839, Test MSE: 0.0806, Test MAE: 0.2181
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_bert-large-uncased, Test RMSE: 0.3742, Test MSE: 0.1400, Test MAE: 0.3381
self_control


                                                                                                                                                                                                                      

Model Name: self_control_bert-large-uncased, Test RMSE: 0.3282, Test MSE: 0.1077, Test MAE: 0.2703
urgency


                                                                                                                                                                                                                      

Model Name: urgency_bert-large-uncased, Test RMSE: 0.3478, Test MSE: 0.1210, Test MAE: 0.2884
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_bert-large-uncased, Test RMSE: 0.3273, Test MSE: 0.1072, Test MAE: 0.2799
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_bert-large-uncased, Test RMSE: 0.4113, Test MSE: 0.1692, Test MAE: 0.3793
attention


                                                                                                                                                                                                                      

Model Name: attention_bert-large-uncased, Test RMSE: 0.3290, Test MSE: 0.1082, Test MAE: 0.2798
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_bert-large-uncased, Test RMSE: 0.2773, Test MSE: 0.0769, Test MAE: 0.1802
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_bert-large-uncased, Test RMSE: 0.3805, Test MSE: 0.1447, Test MAE: 0.3387
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_bert-large-uncased, Test RMSE: 0.3404, Test MSE: 0.1158, Test MAE: 0.2793
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_bert-large-uncased, Test RMSE: 0.3507, Test MSE: 0.1230, Test MAE: 0.2889




predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_microsoft, Test RMSE: 0.3507, Test MSE: 0.1230, Test MAE: 0.3002
effort


                                                                                                                                                                                                                      

Model Name: effort_microsoft, Test RMSE: 0.3309, Test MSE: 0.1095, Test MAE: 0.2781
attention


                                                                                                                                                                                                                      

Model Name: attention_microsoft, Test RMSE: 0.3157, Test MSE: 0.0997, Test MAE: 0.2686
standards


                                                                                                                                                                                                                      

Model Name: standards_microsoft, Test RMSE: 0.3434, Test MSE: 0.1179, Test MAE: 0.2753
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_microsoft, Test RMSE: 0.3575, Test MSE: 0.1278, Test MAE: 0.3068
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_microsoft, Test RMSE: 0.3133, Test MSE: 0.0981, Test MAE: 0.2261
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_microsoft, Test RMSE: 0.3089, Test MSE: 0.0954, Test MAE: 0.2551
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_microsoft, Test RMSE: 0.3587, Test MSE: 0.1287, Test MAE: 0.3037
self_control


                                                                                                                                                                                                                      

Model Name: self_control_microsoft, Test RMSE: 0.3329, Test MSE: 0.1108, Test MAE: 0.2743
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_microsoft, Test RMSE: 0.3607, Test MSE: 0.1301, Test MAE: 0.3058
other_control


                                                                                                                                                                                                                      

Model Name: other_control_microsoft, Test RMSE: 0.3518, Test MSE: 0.1238, Test MAE: 0.2893
urgency


                                                                                                                                                                                                                      

Model Name: urgency_microsoft, Test RMSE: 0.3569, Test MSE: 0.1274, Test MAE: 0.2964
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_microsoft, Test RMSE: 0.2911, Test MSE: 0.0847, Test MAE: 0.2096
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_microsoft, Test RMSE: 0.3174, Test MSE: 0.1007, Test MAE: 0.2592
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_microsoft, Test RMSE: 0.3502, Test MSE: 0.1227, Test MAE: 0.2808
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_microsoft, Test RMSE: 0.3603, Test MSE: 0.1298, Test MAE: 0.3128
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_microsoft, Test RMSE: 0.3494, Test MSE: 0.1221, Test MAE: 0.2821
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_microsoft, Test RMSE: 0.3231, Test MSE: 0.1044, Test MAE: 0.2670
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_microsoft, Test RMSE: 0.3327, Test MSE: 0.1107, Test MAE: 0.2863
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_microsoft, Test RMSE: 0.2889, Test MSE: 0.0835, Test MAE: 0.2033
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_microsoft, Test RMSE: 0.3256, Test MSE: 0.1060, Test MAE: 0.2644




accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_distilbert-base-uncased, Test RMSE: 0.3451, Test MSE: 0.1191, Test MAE: 0.2934
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_distilbert-base-uncased, Test RMSE: 0.3480, Test MSE: 0.1211, Test MAE: 0.2759
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_distilbert-base-uncased, Test RMSE: 0.3217, Test MSE: 0.1035, Test MAE: 0.2501
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_distilbert-base-uncased, Test RMSE: 0.3218, Test MSE: 0.1035, Test MAE: 0.2630
attention


                                                                                                                                                                                                                      

Model Name: attention_distilbert-base-uncased, Test RMSE: 0.3078, Test MSE: 0.0947, Test MAE: 0.2494
self_control


                                                                                                                                                                                                                      

Model Name: self_control_distilbert-base-uncased, Test RMSE: 0.3399, Test MSE: 0.1155, Test MAE: 0.2836
standards


                                                                                                                                                                                                                      

Model Name: standards_distilbert-base-uncased, Test RMSE: 0.3415, Test MSE: 0.1166, Test MAE: 0.2612
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_distilbert-base-uncased, Test RMSE: 0.3539, Test MSE: 0.1253, Test MAE: 0.3032
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_distilbert-base-uncased, Test RMSE: 0.2742, Test MSE: 0.0752, Test MAE: 0.1873
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_distilbert-base-uncased, Test RMSE: 0.3190, Test MSE: 0.1018, Test MAE: 0.2720
urgency


                                                                                                                                                                                                                      

Model Name: urgency_distilbert-base-uncased, Test RMSE: 0.3548, Test MSE: 0.1259, Test MAE: 0.2937
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_distilbert-base-uncased, Test RMSE: 0.3427, Test MSE: 0.1174, Test MAE: 0.2784
effort


                                                                                                                                                                                                                      

Model Name: effort_distilbert-base-uncased, Test RMSE: 0.3325, Test MSE: 0.1105, Test MAE: 0.2773
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_distilbert-base-uncased, Test RMSE: 0.2967, Test MSE: 0.0880, Test MAE: 0.2376
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_distilbert-base-uncased, Test RMSE: 0.3652, Test MSE: 0.1334, Test MAE: 0.2983
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_distilbert-base-uncased, Test RMSE: 0.2792, Test MSE: 0.0779, Test MAE: 0.2189
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_distilbert-base-uncased, Test RMSE: 0.3425, Test MSE: 0.1173, Test MAE: 0.2910
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_distilbert-base-uncased, Test RMSE: 0.3002, Test MSE: 0.0901, Test MAE: 0.2140
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_distilbert-base-uncased, Test RMSE: 0.3205, Test MSE: 0.1027, Test MAE: 0.2634
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_distilbert-base-uncased, Test RMSE: 0.3583, Test MSE: 0.1284, Test MAE: 0.3061
other_control


                                                                                                                                                                                                                      

Model Name: other_control_distilbert-base-uncased, Test RMSE: 0.3527, Test MSE: 0.1244, Test MAE: 0.3033
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_google-t5, Test RMSE: 0.3471, Test MSE: 0.1204, Test MAE: 0.3069
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_google-t5, Test RMSE: 0.3169, Test MSE: 0.1004, Test MAE: 0.2629
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_google-t5, Test RMSE: 0.3418, Test MSE: 0.1168, Test MAE: 0.2960
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_google-t5, Test RMSE: 0.2970, Test MSE: 0.0882, Test MAE: 0.2309
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_google-t5, Test RMSE: 0.3633, Test MSE: 0.1320, Test MAE: 0.3132
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_google-t5, Test RMSE: 0.3697, Test MSE: 0.1366, Test MAE: 0.3237
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_google-t5, Test RMSE: 0.3674, Test MSE: 0.1350, Test MAE: 0.3087
other_control


                                                                                                                                                                                                                      

Model Name: other_control_google-t5, Test RMSE: 0.3667, Test MSE: 0.1345, Test MAE: 0.3141
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_google-t5, Test RMSE: 0.3046, Test MSE: 0.0928, Test MAE: 0.2362
effort


                                                                                                                                                                                                                      

Model Name: effort_google-t5, Test RMSE: 0.3520, Test MSE: 0.1239, Test MAE: 0.3030
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_google-t5, Test RMSE: 0.3645, Test MSE: 0.1329, Test MAE: 0.3146
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_google-t5, Test RMSE: 0.3567, Test MSE: 0.1273, Test MAE: 0.3032
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_google-t5, Test RMSE: 0.3737, Test MSE: 0.1397, Test MAE: 0.3283
urgency


                                                                                                                                                                                                                      

Model Name: urgency_google-t5, Test RMSE: 0.3623, Test MSE: 0.1312, Test MAE: 0.3131
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_google-t5, Test RMSE: 0.3870, Test MSE: 0.1497, Test MAE: 0.3343
self_control


                                                                                                                                                                                                                      

Model Name: self_control_google-t5, Test RMSE: 0.3497, Test MSE: 0.1223, Test MAE: 0.3011
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_google-t5, Test RMSE: 0.3691, Test MSE: 0.1363, Test MAE: 0.3178
standards


                                                                                                                                                                                                                      

Model Name: standards_google-t5, Test RMSE: 0.3477, Test MSE: 0.1209, Test MAE: 0.2913
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_google-t5, Test RMSE: 0.3456, Test MSE: 0.1194, Test MAE: 0.2969
attention


                                                                                                                                                                                                                      

Model Name: attention_google-t5, Test RMSE: 0.3182, Test MSE: 0.1013, Test MAE: 0.2595
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_google-t5, Test RMSE: 0.3243, Test MSE: 0.1052, Test MAE: 0.2555




other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_roberta-base, Test RMSE: 0.3195, Test MSE: 0.1021, Test MAE: 0.2434
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_roberta-base, Test RMSE: 0.2862, Test MSE: 0.0819, Test MAE: 0.2281
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_roberta-base, Test RMSE: 0.2880, Test MSE: 0.0829, Test MAE: 0.2282
effort


                                                                                                                                                                                                                      

Model Name: effort_roberta-base, Test RMSE: 0.3340, Test MSE: 0.1115, Test MAE: 0.2791
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_roberta-base, Test RMSE: 0.3458, Test MSE: 0.1196, Test MAE: 0.2957
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_roberta-base, Test RMSE: 0.3373, Test MSE: 0.1137, Test MAE: 0.2742
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_roberta-base, Test RMSE: 0.3456, Test MSE: 0.1194, Test MAE: 0.2831
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_roberta-base, Test RMSE: 0.2734, Test MSE: 0.0748, Test MAE: 0.2091
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_roberta-base, Test RMSE: 0.3139, Test MSE: 0.0985, Test MAE: 0.2571
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_roberta-base, Test RMSE: 0.3191, Test MSE: 0.1018, Test MAE: 0.2767
urgency


                                                                                                                                                                                                                      

Model Name: urgency_roberta-base, Test RMSE: 0.3557, Test MSE: 0.1266, Test MAE: 0.2869
attention


                                                                                                                                                                                                                      

Model Name: attention_roberta-base, Test RMSE: 0.3287, Test MSE: 0.1081, Test MAE: 0.2765
self_control


                                                                                                                                                                                                                      

Model Name: self_control_roberta-base, Test RMSE: 0.3255, Test MSE: 0.1059, Test MAE: 0.2698
other_control


                                                                                                                                                                                                                      

Model Name: other_control_roberta-base, Test RMSE: 0.3260, Test MSE: 0.1063, Test MAE: 0.2607
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_roberta-base, Test RMSE: 0.3480, Test MSE: 0.1211, Test MAE: 0.2864
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_roberta-base, Test RMSE: 0.2564, Test MSE: 0.0657, Test MAE: 0.1787
standards


                                                                                                                                                                                                                      

Model Name: standards_roberta-base, Test RMSE: 0.3351, Test MSE: 0.1123, Test MAE: 0.2509
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_roberta-base, Test RMSE: 0.2959, Test MSE: 0.0875, Test MAE: 0.2289
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_roberta-base, Test RMSE: 0.3272, Test MSE: 0.1071, Test MAE: 0.2740
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_roberta-base, Test RMSE: 0.3514, Test MSE: 0.1235, Test MAE: 0.3023
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_roberta-base, Test RMSE: 0.3444, Test MSE: 0.1186, Test MAE: 0.2793




standards


                                                                                                                                                                                                                      

Model Name: standards_FacebookAI, Test RMSE: 0.3367, Test MSE: 0.1134, Test MAE: 0.2526
predict_event


                                                                                                                                                                                                                      

Model Name: predict_event_FacebookAI, Test RMSE: 0.3814, Test MSE: 0.1454, Test MAE: 0.3371
familiarity


                                                                                                                                                                                                                      

Model Name: familiarity_FacebookAI, Test RMSE: 0.3174, Test MSE: 0.1008, Test MAE: 0.2511
accept_conseq


                                                                                                                                                                                                                      

Model Name: accept_conseq_FacebookAI, Test RMSE: 0.3488, Test MSE: 0.1216, Test MAE: 0.2962
not_consider


                                                                                                                                                                                                                      

Model Name: not_consider_FacebookAI, Test RMSE: 0.3675, Test MSE: 0.1350, Test MAE: 0.3066
goal_relevance


                                                                                                                                                                                                                      

Model Name: goal_relevance_FacebookAI, Test RMSE: 0.3162, Test MSE: 0.1000, Test MAE: 0.2726
chance_responsblt


                                                                                                                                                                                                                      

Model Name: chance_responsblt_FacebookAI, Test RMSE: 0.3943, Test MSE: 0.1555, Test MAE: 0.3483
pleasantness


                                                                                                                                                                                                                      

Model Name: pleasantness_FacebookAI, Test RMSE: 0.2571, Test MSE: 0.0661, Test MAE: 0.1605
attention


                                                                                                                                                                                                                      

Model Name: attention_FacebookAI, Test RMSE: 0.3008, Test MSE: 0.0905, Test MAE: 0.2419
unpleasantness


                                                                                                                                                                                                                      

Model Name: unpleasantness_FacebookAI, Test RMSE: 0.4300, Test MSE: 0.1849, Test MAE: 0.4022
effort


                                                                                                                                                                                                                      

Model Name: effort_FacebookAI, Test RMSE: 0.3276, Test MSE: 0.1073, Test MAE: 0.2800
self_responsblt


                                                                                                                                                                                                                      

Model Name: self_responsblt_FacebookAI, Test RMSE: 0.4142, Test MSE: 0.1715, Test MAE: 0.3810
urgency


                                                                                                                                                                                                                      

Model Name: urgency_FacebookAI, Test RMSE: 0.3515, Test MSE: 0.1235, Test MAE: 0.2881
suddenness


                                                                                                                                                                                                                      

Model Name: suddenness_FacebookAI, Test RMSE: 0.3842, Test MSE: 0.1476, Test MAE: 0.3428
other_control


                                                                                                                                                                                                                      

Model Name: other_control_FacebookAI, Test RMSE: 0.3278, Test MSE: 0.1075, Test MAE: 0.2655
goal_support


                                                                                                                                                                                                                      

Model Name: goal_support_FacebookAI, Test RMSE: 0.2807, Test MSE: 0.0788, Test MAE: 0.2276
social_norms


                                                                                                                                                                                                                      

Model Name: social_norms_FacebookAI, Test RMSE: 0.3346, Test MSE: 0.1120, Test MAE: 0.2678
predict_conseq


                                                                                                                                                                                                                      

Model Name: predict_conseq_FacebookAI, Test RMSE: 0.3614, Test MSE: 0.1306, Test MAE: 0.3144
other_responsblt


                                                                                                                                                                                                                      

Model Name: other_responsblt_FacebookAI, Test RMSE: 0.3241, Test MSE: 0.1050, Test MAE: 0.2446
chance_control


                                                                                                                                                                                                                      

Model Name: chance_control_FacebookAI, Test RMSE: 0.3923, Test MSE: 0.1539, Test MAE: 0.3369
self_control


                                                                                                                                                                                                                      

Model Name: self_control_FacebookAI, Test RMSE: 0.3690, Test MSE: 0.1362, Test MAE: 0.3311


In [10]:
# for ix, _ in enumerate(data_dict['model']):
#     # roc_curve(data_dict['model'], data_dict['dimension'])

TypeError: roc_curve() missing 3 required positional arguments: 'fpr', 'tpr', and 'auc'

In [11]:
df = pd.DataFrame(data_dict)
df

Unnamed: 0,model,dimension,RMSE,MSE,MAE
0,bert-base-uncased,self_control,0.340837,0.116170,0.267178
1,bert-base-uncased,urgency,0.358199,0.128306,0.289085
2,bert-base-uncased,predict_event,0.350095,0.122567,0.300647
3,bert-base-uncased,chance_control,0.340582,0.115996,0.270933
4,bert-base-uncased,effort,0.330662,0.109338,0.266773
...,...,...,...,...,...
142,FacebookAI,social_norms,0.334594,0.111953,0.267847
143,FacebookAI,predict_conseq,0.361370,0.130588,0.314368
144,FacebookAI,other_responsblt,0.324101,0.105041,0.244587
145,FacebookAI,chance_control,0.392263,0.153871,0.336888


In [12]:
df[df['model'] == 'bert-large-uncased'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.336824,0.114753,0.277004
std,0.036986,0.024848,0.051267
min,0.26232,0.068812,0.176911
25%,0.327346,0.107155,0.259553
50%,0.340358,0.115843,0.279807
75%,0.350693,0.122986,0.288983
max,0.41129,0.169159,0.379276


In [13]:
df[df['model'] == 'roberta-base'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.321761,0.104236,0.260431
std,0.02722,0.016866,0.03111
min,0.256373,0.065727,0.178736
25%,0.31388,0.098521,0.243398
50%,0.327237,0.107084,0.274048
75%,0.34441,0.118618,0.279331
max,0.355742,0.126553,0.302349


In [14]:
df[df['model'] == 'distilbert-base-uncased'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.329429,0.109161,0.267668
std,0.025866,0.016597,0.032188
min,0.274217,0.075195,0.187304
25%,0.319028,0.101779,0.250079
50%,0.339855,0.115501,0.275937
75%,0.348023,0.12112,0.293428
max,0.365246,0.133405,0.306085


In [15]:
df[df['model'] == 'google-t5'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.348815,0.122223,0.295777
std,0.024044,0.016386,0.029438
min,0.296971,0.088192,0.230874
25%,0.341761,0.116801,0.291304
50%,0.351996,0.123901,0.30317
75%,0.366699,0.134468,0.314051
max,0.386958,0.149736,0.33429


In [16]:
df[df['model'] == 'bert-base-uncased'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.329125,0.108963,0.265786
std,0.025916,0.016475,0.032267
min,0.267815,0.071725,0.191849
25%,0.315301,0.099415,0.257789
50%,0.337686,0.114032,0.268092
75%,0.349059,0.121842,0.289181
max,0.358199,0.128306,0.308196


In [21]:
df[df['model'] == 'microsoft'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.334287,0.112225,0.273591
std,0.022386,0.014705,0.030304
min,0.288885,0.083455,0.203294
25%,0.317354,0.100713,0.264389
50%,0.332921,0.110837,0.278148
75%,0.351846,0.123795,0.296393
max,0.360667,0.130081,0.312788


In [22]:
df[df['model'] == 'FacebookAI'].describe()

Unnamed: 0,RMSE,MSE,MAE
count,21.0,21.0,21.0
mean,0.348441,0.123188,0.292807
std,0.043191,0.029984,0.056277
min,0.257111,0.066106,0.160482
25%,0.324101,0.105041,0.252638
50%,0.348768,0.121639,0.288071
75%,0.381361,0.145436,0.336888
max,0.429966,0.184871,0.402199
