In [63]:
import numpy as np
import os, argparse
import matplotlib.pyplot as plt
from utils.config import LABELS
from sklearn.metrics import multilabel_confusion_matrix, ConfusionMatrixDisplay, classification_report
from pretty_print_report import *

In [64]:
def _parse_label_field(field):
    """Split a bracketed, comma-separated label field into a set of label names.

    Surrounding whitespace, spaces after commas and quoted items are tolerated,
    so "[ch,cr]", "[ch, cr]" and "['ch', 'cr']" all yield {"ch", "cr"}.
    An empty field ("" or "[]") yields an empty set.
    """
    items = field.strip().strip('[]').split(',')
    cleaned = (item.strip().strip('\'"') for item in items)
    return {item for item in cleaned if item}


def load_data(file_path, labels=None, min_true_labels=2, encoding='utf-8'):
    """Load a tab-separated response file as two multi-hot label matrices.

    Each line is split on tabs. Lines with fewer than three fields are skipped.
    Otherwise the LAST three fields are read as ``clause``, ``true_labels`` and
    ``predicted_labels``, where each label field looks like ``[ch,cr]``.

    Only rows whose gold labels contain at least ``min_true_labels`` entries of
    ``labels`` are kept. The default of 2 restricts the analysis to multi-label
    clauses; labels not listed in ``labels`` are ignored for both the count and
    the vectors.

    Args:
        file_path: path of the response file to read.
        labels: ordered label names defining the matrix columns; defaults to
            ``LABELS.labels``.
        min_true_labels: minimum number of known gold labels a row needs to be kept.
        encoding: text encoding of the file.

    Returns:
        (true_matrix, predicted_matrix): int arrays of shape
        (n_kept_rows, len(labels)); column j corresponds to ``labels[j]``.
        When no row is kept both arrays have shape (0, len(labels)).
    """
    if labels is None:
        labels = LABELS.labels
    labels = list(labels)
    n_labels = len(labels)

    true_rows = []
    predicted_rows = []
    with open(file_path, 'r', encoding=encoding) as f:
        for line in f:
            # Strip only the line terminator. An empty last field leaves a trailing
            # tab that must survive, otherwise parts[-3:] would shift columns.
            parts = line.rstrip('\r\n').split('\t')
            if len(parts) < 3:
                continue
            _clause, true_field, predicted_field = parts[-3:]
            true_set = _parse_label_field(true_field)
            predicted_set = _parse_label_field(predicted_field)

            # Binary indicator vectors aligned with `labels`.
            true_vector = [1 if label in true_set else 0 for label in labels]
            if sum(true_vector) < min_true_labels:
                continue
            true_rows.append(true_vector)
            predicted_rows.append([1 if label in predicted_set else 0 for label in labels])

    # Explicit reshape keeps a 2-D (0, n_labels) shape when nothing was kept.
    true_matrix = np.array(true_rows, dtype=int).reshape(len(true_rows), n_labels)
    predicted_matrix = np.array(predicted_rows, dtype=int).reshape(len(predicted_rows), n_labels)
    return true_matrix, predicted_matrix

In [65]:
def analyze_model(method, model_name, out_dir="out"):
    """Print a per-label classification report for one method/model run.

    Reads ``<out_dir>/<method>/<model_name>_resp.txt`` with ``load_data`` and
    prints sklearn's ``classification_report``, using ``LABELS.labels`` as the
    target names.

    Args:
        method: sub-directory of ``out_dir`` holding the run, e.g. "prompt_chain_8_long".
        model_name: model identifier; the file read is ``f"{model_name}_resp.txt"``.
        out_dir: root directory of the response files (defaults to "out").

    Returns:
        (true_matrix, pred_matrix) exactly as returned by ``load_data``.
    """
    file_path = os.path.join(out_dir, method, f"{model_name}_resp.txt")
    true_matrix, pred_matrix = load_data(file_path)
    print(f"REPORT FOR MODEL {method}/{model_name}")
    if len(true_matrix) == 0:
        # An empty matrix gives no meaningful report (sklearn may raise on it),
        # so say why instead of surfacing a confusing traceback.
        print(f"No rows were loaded from {file_path}; nothing to report.")
        return true_matrix, pred_matrix
    print(classification_report(true_matrix, pred_matrix, zero_division=0, target_names=LABELS.labels))
    return true_matrix, pred_matrix

In [66]:
# Per-label classification report for the qwen32 run of prompt_chain_8_long.
true_matrix, pred_matrix = analyze_model(method="prompt_chain_8_long", model_name="qwen32")

REPORT FOR MODEL prompt_chain_8_long/qwen32
              precision    recall  f1-score   support

        fair       0.00      0.00      0.00         0
           a       0.00      0.00      0.00         1
          ch       1.00      0.62      0.77        16
          cr       0.89      0.47      0.62        17
           j       1.00      1.00      1.00         2
         law       1.00      1.00      1.00         2
         ltd       1.00      0.20      0.33         5
         ter       0.93      0.52      0.67        27
         use       1.00      0.80      0.89        10
        pinc       1.00      0.86      0.92         7

   micro avg       0.84      0.59      0.69        87
   macro avg       0.78      0.55      0.62        87
weighted avg       0.95      0.59      0.71        87
 samples avg       0.78      0.59      0.65        87



In [68]:
# Per-label classification report for the nemo run of prompt_chain_8_long.
# Note: this rebinds true_matrix / pred_matrix, replacing the qwen32 results.
true_matrix, pred_matrix = analyze_model(method="prompt_chain_8_long", model_name="nemo")

REPORT FOR MODEL prompt_chain_8_long/nemo
              precision    recall  f1-score   support

        fair       0.00      0.00      0.00         0
           a       0.00      0.00      0.00         1
          ch       0.82      0.88      0.85        16
          cr       0.73      0.65      0.69        17
           j       1.00      1.00      1.00         2
         law       1.00      1.00      1.00         2
         ltd       1.00      1.00      1.00         5
         ter       0.90      0.70      0.79        27
         use       1.00      0.70      0.82        10
        pinc       1.00      0.29      0.44         7

   micro avg       0.81      0.71      0.76        87
   macro avg       0.75      0.62      0.66        87
weighted avg       0.87      0.71      0.77        87
 samples avg       0.78      0.71      0.71        87

