In [None]:
import os
import json
import csv
import logging
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

class Config:
    """Filesystem locations used by the Caesar-cipher benchmark evaluation.

    Every path can be overridden with a keyword argument; the defaults are the
    locations this script was written against, so ``Config()`` behaves exactly
    as before.
    """

    def __init__(
        self,
        *,
        data_directory: str = './data/encoded/caesar-cipher/',
        config_file_path: str = './caesar-cipher-experiments-log.txt',
        ter_csv_file: str = './eval/caesar_cipher/caesar-cipher-benchmark-result.csv',
        token_csv_file: str = './eval/caesar_cipher/token_based_performance_result.csv',
    ):
        # Directory containing the per-experiment JSON result files
        self.data_directory = data_directory
        # Experiment log: one result-file name per line, parsed for run metadata
        self.config_file_path = config_file_path
        # Output CSV with per-file TER / accuracy metrics
        self.ter_csv_file = ter_csv_file
        # Output CSV with accuracy averaged by token position across all files
        self.token_csv_file = token_csv_file

class CaesarCipherTest:
    """Score Caesar-cipher experiment outputs and write CSV reports.

    For every ``*.json`` file in ``config.data_directory``:

    - The run metadata is looked up in the experiment log
      (``config.config_file_path``) by the ``<date>_<time>`` prefix of the
      filename.
    - The file's records are scored with a position-wise token error rate.

    Two CSVs are then written:

    - per-file metrics (``config.ter_csv_file``);
    - accuracy by token position, aggregated over all files
      (``config.token_csv_file``).
    """

    def __init__(self, config: 'Config'):
        self.config = config
        self.directory = self.config.data_directory
        self.config_file = self.config.config_file_path
        # "<date>_<time>" timestamp -> fields parsed from the matching log line.
        self.config_data = self.parse_config_file()
        # 1-based token position -> summed per-token accuracy (%) and number of
        # scored tokens at that position, merged across all processed files.
        self.global_token_accumulated = defaultdict(lambda: {'accuracy_sum': 0, 'count': 0})
        # One metrics dict (see process_json_file) per successfully processed file.
        self.ter_results = []

    def parse_config_file(self) -> Dict[str, Dict[str, str]]:
        """Parse the experiment log into a ``timestamp -> fields`` mapping.

        Each non-blank line is treated as a result filename and parsed with
        ``parse_filename``.

        - Lines with too few ``_``-separated fields are logged and skipped
          instead of aborting the whole run.
        - If two lines share a timestamp, the later one wins and a warning is
          logged.
        - A missing log file is logged and yields an empty mapping.
        """
        config_data = {}
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                for line_number, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        fields = self.parse_filename(line)
                    except IndexError:
                        logging.warning(
                            f"Skipping malformed line {line_number} in {self.config_file}: {line!r}")
                        continue
                    timestamp = self.extract_timestamp(line)
                    if timestamp in config_data:
                        logging.warning(
                            f"Duplicate timestamp {timestamp} in {self.config_file}; "
                            f"line {line_number} overrides the earlier entry")
                    config_data[timestamp] = fields
        except FileNotFoundError:
            logging.error(f"Config file {self.config_file} not found.")
        return config_data

    @staticmethod
    def extract_timestamp(filename: str) -> str:
        """Return the first two ``_``-separated parts joined by ``_``.

        Falls back to the whole string when it contains no ``_``.
        """
        parts = filename.split('_')
        if len(parts) >= 2:
            return f"{parts[0]}_{parts[1]}"
        return parts[0]

    @staticmethod
    def parse_filename(filename: str) -> Dict[str, str]:
        """Split a result filename into its metadata fields.

        Expected layout (an optional ``.json`` suffix is ignored)::

            <date>_<time>_<samples>_<shift>_<p1>_<p2>_<method>_<shot>-<model>__<temperature>_<max_token>

        - ``prompt_type`` is ``<p1>_<p2>``.
        - A ``models--a--b--name`` model string is reduced to its last
          ``--`` segment.
        - Missing shot, temperature or max_token become ``''``.

        Raises:
            IndexError: if the part before ``__`` has fewer than seven
                ``_``-separated fields.
        """
        if filename.endswith('.json'):
            filename = filename[:-5]

        parts = filename.split('__')
        pre_model_info = parts[0]
        post_model_info = parts[1] if len(parts) > 1 else ''
        pre_fields = pre_model_info.split('_')

        date = pre_fields[0]
        time = pre_fields[1]
        samples = pre_fields[2]
        shift = pre_fields[3]
        prompt_type = '_'.join(pre_fields[4:6])
        method = pre_fields[6]
        # The model name itself may contain '_', so re-join the remainder.
        shot_and_model_name = '_'.join(pre_fields[7:])

        if '-' in shot_and_model_name:
            shot, model_name = shot_and_model_name.split('-', 1)
        else:
            shot = ''
            model_name = shot_and_model_name

        # Cache-style names ("models--org--name"): keep only the final segment.
        if model_name.startswith('models--'):
            model_name = model_name[len('models--'):].split('--')[-1]

        if post_model_info:
            post_fields = post_model_info.split('_')
            temperature = post_fields[0]
            max_token = post_fields[1] if len(post_fields) > 1 else ''
        else:
            temperature = ''
            max_token = ''

        return {
            'date': date,
            'time': time,
            'samples': samples,
            'shift': shift,
            'prompt_type': prompt_type,
            'method': method,
            'shot': shot,
            'model_name': model_name,
            'temperature': temperature,
            'max_token': max_token
        }

    @staticmethod
    def calculate_ter(pred: str, gold: str) -> Tuple[float, List[str], List[str]]:
        """Compute a position-wise token error rate between two strings.

        Tokens are compared by index, not by edit distance. Each aligned
        position that differs counts as one error, and each extra token in
        the longer sequence counts as one error. A single inserted token
        therefore marks every later position as wrong.

        Returns:
            ``(error_rate, pred_tokens, gold_tokens)``. ``error_rate`` is in
            ``[0, 1]``. It is ``1`` when both inputs have no tokens, which
            avoids division by zero.
        """
        # str.split() with no argument already trims and drops empty tokens.
        pred_tokens = pred.split()
        gold_tokens = gold.split()

        errors = sum(1 for p, g in zip(pred_tokens, gold_tokens) if p != g)
        errors += abs(len(pred_tokens) - len(gold_tokens))

        max_len = max(len(pred_tokens), len(gold_tokens))
        error_rate = errors / max_len if max_len > 0 else 1
        return error_rate, pred_tokens, gold_tokens

    @staticmethod
    def process_token_based_performance(token_performance: List[Tuple[int, float, float]]) -> Dict[int, Dict[str, float]]:
        """Sum accuracy and count entries per token index.

        The third element of each tuple (the error rate) is ignored.
        """
        token_accumulated = defaultdict(lambda: {'accuracy_sum': 0, 'count': 0})
        for token_index, token_acc, _ in token_performance:
            token_accumulated[token_index]['accuracy_sum'] += token_acc
            token_accumulated[token_index]['count'] += 1
        return token_accumulated

    @staticmethod
    def write_average_token_accuracy(writer, token_accumulated: Dict[int, Dict[str, float]]):
        """Write one ``[token_index, mean accuracy]`` CSV row per index, ascending."""
        for token_index, data in sorted(token_accumulated.items()):
            average_accuracy = data['accuracy_sum'] / data['count'] if data['count'] > 0 else 0
            writer.writerow([token_index, f"{average_accuracy:.2f}"])

    @staticmethod
    def _load_records(filepath: str) -> Optional[list]:
        """Return the top-level JSON list stored in *filepath*, or None (after logging) if unusable."""
        name = os.path.basename(filepath)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read().strip()
        except FileNotFoundError:
            logging.error(f"File not found: {filepath}")
            return None
        except UnicodeDecodeError:
            logging.error(f"File is not valid UTF-8: {name}")
            return None
        except OSError as exc:
            logging.error(f"Could not read file {name}: {exc}")
            return None

        if not content:
            logging.warning(f"Skipping empty file: {name}")
            return None
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            logging.error(f"Error decoding JSON in file: {name}")
            return None

        if not isinstance(data, list):
            logging.error(f"File {name} does not contain a list of records.")
            return None
        return data

    @staticmethod
    def _record_text(record: Dict, key: str) -> str:
        """Return ``record[key]`` as a stripped string; a missing or JSON-null value becomes ''."""
        value = record.get(key)
        return '' if value is None else str(value).strip()

    def process_json_file(self, filepath: str, fields: Dict[str, str]) -> Optional[Dict]:
        """Score one result file and return its metrics row.

        Args:
            filepath: Path to a JSON file holding a list of record objects.
            fields: Run metadata parsed from the experiment log for this file.

        Returns:
            A dict containing:

            - the file name and selected metadata fields;
            - four percentage metrics: mean per-record accuracy, exact-match
              sample accuracy, mean error rate, and token exact-match accuracy;
            - ``token_accumulated``: per-position accuracy sums and counts.

            Returns None when the file is missing, unreadable, empty, invalid
            JSON, or not a JSON list. Records that are not objects, or whose
            two texts are both empty, are skipped.
        """
        data = self._load_records(filepath)
        if data is None:
            return None
        name = os.path.basename(filepath)

        total_accuracy = 0.0
        total_error_rate = 0.0
        total_correct_samples = 0
        total_matching_tokens = 0
        total_tokens = 0
        record_count = 0
        token_performance = []

        for record in data:
            if not isinstance(record, dict):
                logging.warning(f"Skipping non-object record in file {name}: {record!r}")
                continue
            # NOTE(review): 'cipher_text' is scored as the prediction against
            # 'gold_label' — confirm against the code that writes these files.
            cipher_text = self._record_text(record, "cipher_text")
            gold_label = self._record_text(record, "gold_label")

            if not cipher_text and not gold_label:
                continue

            error_rate, pred_tokens, gold_tokens = self.calculate_ter(cipher_text, gold_label)
            aligned = list(zip(pred_tokens, gold_tokens))

            total_matching_tokens += sum(1 for p, g in aligned if p == g)
            total_tokens += max(len(pred_tokens), len(gold_tokens))
            total_accuracy += 1 - error_rate
            total_error_rate += error_rate
            record_count += 1

            if cipher_text and cipher_text == gold_label:
                total_correct_samples += 1
                logging.info(f"Exact match found in file {name}:")
                logging.info(f"cipher_text: {cipher_text}")
                logging.info(f"gold_label: {gold_label}")
                logging.info(f"Record: {record}")

            # Per-position accuracy (%) over aligned tokens only; extra tokens
            # in the longer sequence get no position entry.
            for position, (p, g) in enumerate(aligned, start=1):
                token_accuracy = 100 if p == g else 0
                token_performance.append((position, token_accuracy, 100 - token_accuracy))

        if record_count > 0:
            average_accuracy = (total_accuracy / record_count) * 100
            average_error_rate = (total_error_rate / record_count) * 100
            sample_accuracy = (total_correct_samples / record_count) * 100
            token_exact_match_accuracy = (total_matching_tokens / total_tokens) * 100 if total_tokens > 0 else 0
        else:
            average_accuracy = 0.0
            average_error_rate = 0.0
            sample_accuracy = 0.0
            token_exact_match_accuracy = 0.0

        return {
            'filename': name,
            'model_name': fields.get('model_name', ''),
            'shift': fields.get('shift', ''),
            'prompt_type': fields.get('prompt_type', ''),
            'temperature': fields.get('temperature', ''),
            'max_token': fields.get('max_token', ''),
            'average_accuracy': average_accuracy,
            'sample_accuracy': sample_accuracy,
            'average_error_rate': average_error_rate,
            'token_exact_match_accuracy': token_exact_match_accuracy,
            'token_accumulated': self.process_token_based_performance(token_performance)
        }

    @staticmethod
    def _result_sort_key(result: Dict) -> Tuple:
        """Sort by model, then shift, then prompt type.

        Shifts are compared numerically; non-numeric shifts sort after the
        numeric ones instead of raising.
        """
        shift = result['shift']
        try:
            shift_key = (0, int(shift), '')
        except (TypeError, ValueError):
            shift_key = (1, 0, str(shift))
        return (result['model_name'], shift_key, result['prompt_type'])

    def process_directory(self):
        """Score every ``*.json`` file in the data directory that has a log entry.

        Results are appended to ``ter_results`` and per-position token stats
        are merged into ``global_token_accumulated``. Files whose timestamp is
        missing from the log are logged and skipped.
        """
        # Sorted only so that processing and log output order is deterministic.
        for filename in sorted(os.listdir(self.directory)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.directory, filename)
            filename_without_ext = filename[:-5]

            timestamp = self.extract_timestamp(filename_without_ext)
            fields = self.config_data.get(timestamp)
            if not fields:
                logging.warning(f"Timestamp {timestamp} from filename {filename_without_ext} not found in {self.config_file}")
                continue

            result_data = self.process_json_file(filepath, fields)
            if not result_data:
                continue
            self.ter_results.append(result_data)
            for token_index, data in result_data['token_accumulated'].items():
                self.global_token_accumulated[token_index]['accuracy_sum'] += data['accuracy_sum']
                self.global_token_accumulated[token_index]['count'] += data['count']

        self.ter_results.sort(key=self._result_sort_key)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will contain *path* if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def write_results(self):
        """Write the per-file metrics CSV and the per-token-position accuracy CSV.

        Missing output directories are created first.
        """
        ter_csv_file = self.config.ter_csv_file
        self._ensure_parent_dir(ter_csv_file)
        with open(ter_csv_file, mode='w', newline='', encoding='utf-8') as ter_file:
            ter_writer = csv.writer(ter_file)
            ter_writer.writerow([
                'Filename', 'Model', 'Shift', 'Prompt Type', 'Temperature', 'Max Token',
                'Average Accuracy (%)', 'Sample-level Accuracy (%)', 'Average Error Rate (%)',
                'Token Exact Match Accuracy (%)'
            ])
            for result in self.ter_results:
                ter_writer.writerow([
                    result['filename'],
                    result['model_name'],
                    result['shift'],
                    result['prompt_type'],
                    result['temperature'],
                    result['max_token'],
                    f"{result['average_accuracy']:.2f}",
                    f"{result['sample_accuracy']:.2f}",
                    f"{result['average_error_rate']:.2f}",
                    f"{result['token_exact_match_accuracy']:.2f}"
                ])
        logging.info(f"TER results written to {ter_csv_file}")

        token_csv_file = self.config.token_csv_file
        self._ensure_parent_dir(token_csv_file)
        with open(token_csv_file, mode='w', newline='', encoding='utf-8') as token_file:
            token_writer = csv.writer(token_file)
            token_writer.writerow(['Token Index', 'Average Accuracy (%)'])
            self.write_average_token_accuracy(token_writer, self.global_token_accumulated)
        logging.info(f"Token-based performance results written to {token_csv_file}")

    def run(self):
        """Score all result files, then write both CSV reports."""
        self.process_directory()
        self.write_results()

if __name__ == '__main__':
    # Script entry point: build the default paths, then score every result
    # file and write both CSV reports. The module-level names `config` and
    # `caesar_test` stay bound so later notebook cells can inspect them.
    config = Config()
    caesar_test = CaesarCipherTest(config=config)
    caesar_test.run()


INFO: TER results written to caesar-cipher-benchmark-result.csv
INFO: Token-based performance results written to Token_Based_Performance_Result.csv
