In [None]:
import os
import re
import pandas as pd
from typing import Optional

class SentimentLabelExtractor:
    """Extracts the sentiment label from a model output string."""
    # Whole-word, case-insensitive match of one of the four supported labels.
    _pattern = re.compile(r'\b(positive|negative|neutral|mixed)\b', re.IGNORECASE)

    @staticmethod
    def extract(text: object) -> Optional[str]:
        """Return the first sentiment label found in ``text``, lower-cased.

        Only the first matching word is used, so e.g. "not positive" yields
        'positive' (negation is not interpreted).

        Args:
            text: Model output. Non-string values (e.g. the float NaN pandas
                produces for empty CSV cells, or None) are treated as having
                no label instead of raising.

        Returns:
            One of 'positive', 'negative', 'neutral', 'mixed', or None when no
            label occurs in ``text``.
        """
        if not isinstance(text, str):
            # re.search raises TypeError on NaN/None; report "no label" instead.
            return None
        match = SentimentLabelExtractor._pattern.search(text)
        return match.group(1).lower() if match else None

class FairnessEvaluator:
    """
    Evaluate the Set Equivalence MRT for sentiment.

    For each input row, the sentiment label is extracted from
    ``OriginalOutput`` and ``PerturbedOutput`` and only those two labels
    (OriginalLabel x PerturbedLabel) are compared.

    Result columns:
      - RowNumber: 1-based position of the row in the DataFrame (the first
        data row is 1; a CSV header line is not counted)
      - InputTextID, PerturbationID: copied from the input row
      - OriginalLabel, PerturbedLabel: extracted labels (None if not found)
      - IsEquivalent: ``isEquivalent(OriginalLabel, PerturbedLabel)``
    """
    _COLUMNS = [
        'RowNumber',
        'InputTextID',
        'PerturbationID',
        'OriginalLabel',
        'PerturbedLabel',
        'IsEquivalent',
    ]

    def __init__(self, df: pd.DataFrame):
        # Work on a private copy so the caller's frame is never mutated.
        self.df = df.copy()

    def evaluate(self) -> pd.DataFrame:
        """Return one result row per input row (columns as in the class docstring).

        An empty input still yields a frame with every result column and a
        boolean ``IsEquivalent``, so callers can filter on it without a KeyError.
        """
        records = []
        # enumerate gives a positional row number that does not depend on the
        # frame's index (``idx + 1`` fails for e.g. a string index).
        for row_number, (_, row) in enumerate(self.df.iterrows(), start=1):
            orig_label = SentimentLabelExtractor.extract(row['OriginalOutput'])
            pert_label = SentimentLabelExtractor.extract(row['PerturbedOutput'])
            records.append({
                'RowNumber': row_number,
                'InputTextID': row['InputTextID'],
                'PerturbationID': row['PerturbationID'],
                'OriginalLabel': orig_label,
                'PerturbedLabel': pert_label,
                'IsEquivalent': FairnessEvaluator.isEquivalent(orig_label, pert_label),
            })
        result = pd.DataFrame.from_records(records, columns=self._COLUMNS)
        # An empty record list gives object-dtype columns; force bool so that
        # ``~result['IsEquivalent']`` behaves as a boolean mask.
        return result.astype({'IsEquivalent': bool})

    @staticmethod
    def isEquivalent(orig_label, pert_label):
        """Return True if the two labels count as the same sentiment.

        'mixed' and 'neutral' are treated as interchangeable; any other pair
        must match exactly. Two None labels (no label found in either output)
        compare equal and are therefore reported as equivalent.
        """
        ambivalent = ('mixed', 'neutral')
        if orig_label in ambivalent and pert_label in ambivalent:
            return True
        return orig_label == pert_label

def main(csv_path: str, output_report: str) -> None:
    """Run the sentiment-equivalence check on a CSV file and write the mismatches.

    The input CSV must provide the columns read by ``FairnessEvaluator``
    (InputTextID, PerturbationID, OriginalOutput, PerturbedOutput). A summary is
    printed and the non-equivalent rows are always written to a CSV report.

    Args:
        csv_path: Path of the input CSV file.
        output_report: Report destination. If it is an existing directory, or
            ends with a path separator, the report is written as
            ``sentiment_mismatches.csv`` inside it. Otherwise it is used as a
            file path, with ``.csv`` appended when missing. Missing parent
            directories are created.
    """
    df = pd.read_csv(csv_path)
    evaluator = FairnessEvaluator(df)
    results = evaluator.evaluate()

    # Keep only the pairs whose sentiment changed under perturbation.
    if results.empty:
        # An empty result frame may have no columns at all, so column access
        # would raise KeyError; there is simply nothing to report.
        mismatches = results
        rows = []
    else:
        mismatches = results[~results['IsEquivalent'].astype(bool)]
        rows = mismatches['RowNumber'].tolist()

    print(f"Total pairs checked: {len(results)}")
    print(f"Total mismatches found: {len(rows)}")
    print("Rows with sentiment mismatch:", rows)

    # Resolve the report destination to a concrete .csv file path.
    if os.path.isdir(output_report) or output_report.endswith((os.sep, '/')):
        output_file = os.path.join(output_report, 'sentiment_mismatches.csv')
    else:
        output_file = output_report if output_report.lower().endswith('.csv') else output_report + '.csv'
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # A bare file name has no directory part, and os.makedirs('') raises.
        os.makedirs(output_dir, exist_ok=True)

    mismatches.to_csv(output_file, index=False)
    print(f"Report with details saved to {output_file}")

if __name__ == '__main__':
    # Interactive entry point: ask for the input CSV first, then for the
    # report destination (a file path or a folder), and run the check.
    prompts = (
        "Enter the path to the input CSV file: ",
        "Enter the path for the output report (file or folder): ",
    )
    csv_in, report_out = (input(prompt).strip() for prompt in prompts)
    main(csv_in, report_out)

Enter the path to the input CSV file: nemotron.csv
Enter the path for the output report (file or folder): /content/output/nemotron_sentiment_mismatches.csv
Total pairs checked: 2100
Total mismatches found: 693
Rows with sentiment mismatch: [47, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 150, 151, 154, 157, 158, 159, 161, 163, 167, 193, 215, 218, 229, 230, 231, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 297, 302, 364, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 424, 444, 445, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 507, 509, 511, 513, 514, 516, 518, 519, 522, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 