In [None]:
# !pip3 install torch --index-url https://download.pytorch.org/whl/cu124
# !pip3 install bitsandbytes
# !pip3 install --upgrade accelerate
# !pip install transformers[sentencepiece]

# import torch
# print(torch.cuda.is_available())
# print(torch.cuda.device_count())
# print(torch.cuda.get_device_name(0))
# print(torch.cuda.current_device())

In [None]:
import os
import re
import json
import random
import torch
import warnings
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datetime import datetime

# =======================
# Configuration Section
# =======================
# Experiment configuration; the functions below read these values when called.
CONFIG = {
    # File paths and directories
    'PATHS': {
        # Plain-text input. main() reads its first TOTAL_LINES lines and uses the
        # file's base name (without extension) as the "type of plain text" label.
        'INPUT_FILE': './data/original/natural.txt',
        # Receives one <timestamp>.json results file per (shift, model, prompt type) run.
        'OUTPUT_DIR': './data/encoded/caesar-cipher/',
        # Experiment log: write_config_line() appends one descriptor line per run.
        'CONFIG_FILE': './log/caesar-cipher-experiments-log.txt'
    },
    
    # Model configurations
    # Local Hugging Face checkpoint directories, processed in list order by main().
    'MODELS': [
        '/storage/ukp/shared/shared_model_weights/models--llama-3/8B-Instruct/',
        '/storage/ukp/shared/shared_model_weights/models--llama-3/70B-Instruct/',
        '/storage/ukp/shared/shared_model_weights/models--mistralai--Mistral-7B-Instruct-v0.3/',
        '/storage/ukp/shared/shared_model_weights/models--mistralai--Mixtral-8x7B-Instruct-v0.1/'
    ],
    
    # Experiment settings
    'EXPERIMENT': {
        # Number of leading input lines to encode (blank lines are skipped later).
        'TOTAL_LINES': 25,
        # Every shift is run against every model and every prompt type.
        'SHIFT_VALUES': [12],
        # Forwarded to model.generate(). NOTE(review): in transformers, temperature
        # only has an effect when sampling is enabled (do_sample), which this script
        # never sets explicitly - confirm against each model's generation_config.
        'TEMPERATURE': 0.01,
        # Seeds random, torch and torch.cuda in setup_environment().
        'SEED': 2266
    },
    
    # Model generation settings
    'GENERATION': {
        # max_new_tokens per prompt type; must have a key for every key of PROMPTS.
        'MAX_TOKENS': {
            'base': 64,
            'cot-like': 512,
            # NOTE(review): 2024 may be a typo for 2048 - confirm.
            'default': 2024,
        }
    },
    
    # Prompt templates
    # "$SHIFT$" is replaced by the shift value and the plain text is appended after
    # "plain text: ". The key also selects the output parser: 'cot-like' uses
    # LOOKUP_TABLE / CIPHER_TEXT, every other key uses GENERAL_CIPHER.
    'PROMPTS': {
        'base': "Encode the following text to a Caesar cipher. The shift is $SHIFT$. Output the cipher text only.\n\nplain text: ",
        'cot-like': "Encode the following text to a Caesar cipher. The shift is $SHIFT$. Output a lookup table and the cipher text in a Python dictionary: {\"lookup_table\": {}, \"cipher_text\":...}. Output the dictionary only.\n\nplain text: ",
        'default': "Encode the following text to a Caesar cipher. The shift is $SHIFT$.\n\nplain text: "
    },
    
    # Regex patterns for text extraction
    # Group 2 of each pattern holds the extracted payload. Matching is case-sensitive
    # (no flags are passed). Backticks are removed from the model output before
    # matching, so the backtick alternatives in CIPHER_TEXT cannot actually match.
    'PATTERNS': {
        # "lookup_table": {...} (double- or single-quoted key); braces cannot nest.
        'LOOKUP_TABLE': r"(\"lookup_table\":|\'lookup_table\':)\s*(\{[^\}]*\})",
        # "cipher_text": "..." with a double-, single- or back-quoted value.
        'CIPHER_TEXT': r"(\"cipher_text\":|\'cipher_text\':|`cipher_text`:)\s*(\"[^\"]*\"|'[^']*'|`[^`]*`)",
        # A known label (e.g. "Cipher text:", "Answer:") followed by text up to the end
        # of a line (leading whitespace, including newlines, is skipped), unless that
        # text starts with "plain text" or "encoded text".
        'GENERAL_CIPHER': r"(Cipher text:|cipher text:|cipher_text:|shifted text:|Caesar cipher:|The encoded text is:|encoded text:|cipher text is:|answer is|Answer:|Answer is:|encoded text is|output|Solution:)(?!\s*\?:?$)\s*((?!plain text|encoded text).+)"
    }
}

# =======================
# Initialization
# =======================
def setup_environment():
    """Make runs reproducible and silence noisy warnings.

    Seeds Python's ``random``, torch's CPU generator and every CUDA device with
    ``CONFIG['EXPERIMENT']['SEED']``. Disables HF tokenizer parallelism through
    the environment, and filters the tokenizer "process just got forked" warning
    as well as all ``UserWarning``s.
    """
    seed = CONFIG['EXPERIMENT']['SEED']
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    fork_message = "huggingface/tokenizers: The current process just got forked"
    warnings.filterwarnings("ignore", message=fork_message)
    warnings.filterwarnings("ignore", category=UserWarning)

# =======================
# Encoder Class Definition
# =======================
class TextEncoder:
    """Ask a causal language model to Caesar-encode text and parse its answer.

    The checkpoint is loaded with 8-bit bitsandbytes quantization and
    ``device_map="auto"``. ``encode`` fills a prompt template and generates a
    completion. The completion is then turned into a ``(lookup_table, cipher_text)``
    pair using the regexes in ``CONFIG['PATTERNS']``.
    """

    def __init__(self, model_path, temperature=0.01):
        """Load the tokenizer and the 8-bit quantized model.

        Args:
            model_path: directory of a Hugging Face checkpoint.
            temperature: value forwarded to ``generate`` (see ``generate_text``).
        """
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=quantization_config,
            device_map="auto"
        )
        self.temperature = temperature
        self.model_name = self._get_model_name(model_path)

    def _get_model_name(self, model_path):
        """Derive a short display/log name from a checkpoint path.

        Paths containing "llama" keep everything after the last "models--", with
        "/" replaced by "_" (a trailing "/" therefore becomes a trailing "_").
        Other paths use their second-to-last "/"-separated component, i.e. the
        directory name when the path ends with "/".
        """
        return model_path.split('models--')[-1].replace('/', '_') if "llama" in model_path else model_path.split('/')[-2]

    def encode(self, prompt_template, text, shift, max_tokens, prompt_type):
        """Build the prompt for ``text`` and return the parsed model answer.

        ``"$SHIFT$"`` in ``prompt_template`` is replaced by ``shift`` and ``text``
        is appended. Returns whatever ``generate_text`` returns.
        """
        prompt = prompt_template.replace("$SHIFT$", str(shift)) + text
        return self.generate_text(prompt, max_tokens, prompt_type)

    def generate_text(self, prompt, max_tokens, prompt_type):
        """Generate a completion for ``prompt`` and extract the cipher text from it.

        Args:
            prompt: full prompt string (template + plain text).
            max_tokens: ``max_new_tokens`` for generation.
            prompt_type: key of ``CONFIG['PROMPTS']``; selects the parser.

        Returns:
            ``(lookup_table, cipher_text)`` as produced by ``_process_generated_text``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # NOTE(review): temperature only takes effect when sampling is enabled
        # (do_sample=True here or in the model's generation_config); otherwise
        # decoding is greedy. Confirm this is intended.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                temperature=self.temperature,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # For decoder-only models generate() returns prompt + continuation. Decode
        # only the new tokens so the prompt cannot leak into the parsed result.
        # Previously, the regular-text fallback returned the whole prompt as the
        # "cipher text".
        new_tokens = outputs[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Drop blank lines and strip every backtick before regex matching.
        generated_text = "\n".join(line for line in generated_text.splitlines() if line.strip()).replace("`", "")

        return self._process_generated_text(generated_text, prompt_type)

    def _process_generated_text(self, generated_text, prompt_type):
        """Dispatch to the parser matching ``prompt_type`` ('cot-like' vs. others)."""
        if prompt_type == 'cot-like':
            return self._process_cot_like_text(generated_text)
        return self._process_regular_text(generated_text)

    def _process_cot_like_text(self, generated_text):
        """Parse a 'cot-like' answer into ``(lookup_table, cipher_text)``.

        The lookup table is the first captured ``{...}`` longer than two characters,
        so an empty ``{}`` is skipped. It is returned as a dict when it parses as
        JSON, as the raw matched string when it does not, and as ``{}`` when nothing
        matched. The cipher text is the first quoted value following a
        ``cipher_text`` key, with quote characters stripped, or ``""`` if none.
        """
        lookup_table_matches = re.findall(CONFIG['PATTERNS']['LOOKUP_TABLE'], generated_text)
        cipher_text_matches = re.findall(CONFIG['PATTERNS']['CIPHER_TEXT'], generated_text)

        final_lookup_table = next((match[1] for match in lookup_table_matches if match[1] and len(match[1]) > 2), None)
        final_cipher_text = next((match[1].strip("\"'`") for match in cipher_text_matches if match[1] and len(match[1].strip()) > 0), "")

        try:
            parsed_lookup_table = json.loads(final_lookup_table) if final_lookup_table else {}
        except json.JSONDecodeError:
            # e.g. a Python-style dict with single quotes: keep the raw string.
            parsed_lookup_table = final_lookup_table if final_lookup_table else {}

        return parsed_lookup_table, final_cipher_text

    def _process_regular_text(self, generated_text):
        """Parse a labelled answer (e.g. "Cipher text: ...") into ``({}, cipher_text)``.

        Returns the last non-empty labelled payload (see the heuristic below).
        If labels matched but every payload was blank, returns ``""``.
        If no label matched at all, returns the whole stripped text.
        """
        matches = re.findall(CONFIG['PATTERNS']['GENERAL_CIPHER'], generated_text)

        cipher_texts = [match[-1].strip() for match in matches if match[-1].strip()]

        # Labels matched but all payloads were blank: keep the first (empty) payload.
        # The original code re-ran the identical regex here, which could only
        # reproduce ``matches``.
        if not cipher_texts and matches:
            cipher_texts.append(matches[0][-1].strip())

        # Heuristic kept from the original: with more than three candidates the last
        # one is discarded before taking the final candidate.
        # NOTE(review): presumably a trailing explanatory line - confirm.
        if len(cipher_texts) > 3:
            cipher_texts = cipher_texts[:-1]

        return {}, cipher_texts[-1] if cipher_texts else generated_text.strip()

def save_results(entry, json_file, is_last):
    """Write one entry of a JSON array to an already-open text file.

    Args:
        entry: JSON-serialisable object for a single encoded line.
        json_file: writable text file positioned inside the array.
        is_last: when True the entry is followed by a bare newline; otherwise
            by ",\\n" so another array element can follow.
    """
    separator = ",\n"
    if is_last:
        separator = "\n"
    json.dump(entry, json_file, indent=4)
    json_file.write(separator)

def write_config_line(experiment_id, total_lines, shift, type_of_plain_text, prompt_type, model, temperature, max_tokens):
    """Append one experiment descriptor line to ``CONFIG['PATHS']['CONFIG_FILE']``.

    The line has the form
    ``<id>_<total_lines>_<shift>_<text type>_<prompt type>_encrypt_0-<model>_<temperature>_<max_tokens>``.
    The log's parent directory is created first if it does not exist; previously
    the append failed with FileNotFoundError on a fresh checkout.
    """
    config_line = f"{experiment_id}_{total_lines}_{shift}_{type_of_plain_text}_{prompt_type}_encrypt_0-{model}_{temperature}_{max_tokens}\n"
    config_path = CONFIG['PATHS']['CONFIG_FILE']
    log_dir = os.path.dirname(config_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(config_path, 'a') as config_file:
        config_file.write(config_line)

def process_model(model_path, lines_to_encode, shift, type_of_plain_text, model_index, total_models):
    """Run every prompt type through one model and save the results as JSON.

    For each entry of ``CONFIG['PROMPTS']`` this function:
    - creates a timestamp experiment id;
    - appends a descriptor line to the experiment log;
    - writes a JSON array, with one object per non-blank input line, to
      ``<OUTPUT_DIR>/<id>.json``.

    Afterwards the model and tokenizer are deleted and the CUDA cache is emptied.

    Args:
        model_path: checkpoint directory passed to ``TextEncoder``.
        lines_to_encode: raw input lines; whitespace-only lines are skipped.
        shift: Caesar shift substituted into each prompt.
        type_of_plain_text: label recorded in the experiment log.
        model_index: zero-based position of this model, for progress messages.
        total_models: number of models in the run, for progress messages.
    """
    encoder = TextEncoder(model_path, CONFIG['EXPERIMENT']['TEMPERATURE'])

    # Drop blank lines up front so the "last entry" check refers to the last entry
    # actually written. Comparing against the raw line index left a dangling ","
    # before "]" (invalid JSON) whenever the final input line was blank.
    plain_texts = [line.strip() for line in lines_to_encode if line.strip()]
    last_index = len(plain_texts) - 1

    # open(..., 'w') below fails if the output directory does not exist yet.
    output_dir = CONFIG['PATHS']['OUTPUT_DIR']
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for prompt_type, prompt_template in CONFIG['PROMPTS'].items():
        max_tokens = CONFIG['GENERATION']['MAX_TOKENS'][prompt_type]
        print(f"\nStarting encoding with model {encoder.model_name} using prompt type '{prompt_type}' (Model {model_index + 1} of {total_models}) with shift {shift}")

        # Second-resolution timestamp; also used as the output file name.
        experiment_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file_path = os.path.join(CONFIG['PATHS']['OUTPUT_DIR'], f"{experiment_id}.json")

        write_config_line(experiment_id, len(lines_to_encode), shift, type_of_plain_text,
                          prompt_type, encoder.model_name, encoder.temperature, max_tokens)

        # Entries are written incrementally rather than collected in memory.
        with open(output_file_path, 'w') as json_file:
            json_file.write("[\n")

            progress = tqdm(plain_texts, desc=f"Encoding Progress with {encoder.model_name} ({prompt_type})", dynamic_ncols=True)
            for entry_idx, plain_text in enumerate(progress):
                lookup_table, cipher_text = encoder.encode(prompt_template, plain_text, shift, max_tokens, prompt_type)

                entry = {
                    'plain_text': plain_text,
                    'shift': shift,
                    'cipher_text': cipher_text
                }
                if prompt_type == 'cot-like':
                    entry['lookup_table'] = lookup_table

                save_results(entry, json_file, entry_idx == last_index)

            json_file.write("]\n")

        print(f"Finished encoding with {encoder.model_name} using prompt type '{prompt_type}' and shift {shift}. Results saved to {output_file_path}\n")

    # Clean up: drop the model/tokenizer references before the next model loads.
    del encoder.model
    del encoder.tokenizer
    torch.cuda.empty_cache()
    print(f"Freed up GPU memory after processing {encoder.model_name}. Moving to the next model.\n")

def main():
    """Encode the first TOTAL_LINES input lines for every shift and every model.

    Seeds and configures the environment, reads ``CONFIG['PATHS']['INPUT_FILE']``,
    then calls ``process_model`` once per (shift, model) pair, shifts outermost.
    """
    setup_environment()

    input_path = CONFIG['PATHS']['INPUT_FILE']
    with open(input_path, 'r') as source:
        all_lines = source.readlines()

    line_limit = CONFIG['EXPERIMENT']['TOTAL_LINES']
    lines_to_encode = all_lines[:line_limit]
    # Input file name without directory or extension, used as a label in the log.
    file_stem, _extension = os.path.splitext(os.path.basename(input_path))
    type_of_plain_text = file_stem

    model_paths = CONFIG['MODELS']
    model_count = len(model_paths)
    for shift in CONFIG['EXPERIMENT']['SHIFT_VALUES']:
        for model_index, model_path in enumerate(model_paths):
            process_model(model_path, lines_to_encode, shift, type_of_plain_text,
                          model_index, model_count)

    print("All models, prompt types, and shift values have been processed. Exiting.\n")


if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]


Starting encoding with model models--mistralai--Mixtral-8x7B-Instruct-v0.1 using prompt type 'cot-like' (Model 1 of 1) with shift 12


Encoding Progress with models--mistralai--Mixtral-8x7B-Instruct-v0.1 (cot-like):


Finished encoding with models--mistralai--Mixtral-8x7B-Instruct-v0.1 using prompt type 'cot-like' and shift 12. Results saved to ./data/encoded/caesar-cipher/20241202_090613.json


Starting encoding with model models--mistralai--Mixtral-8x7B-Instruct-v0.1 using prompt type 'default' (Model 1 of 1) with shift 12


Encoding Progress with models--mistralai--Mixtral-8x7B-Instruct-v0.1 (default): 

Finished encoding with models--mistralai--Mixtral-8x7B-Instruct-v0.1 using prompt type 'default' and shift 12. Results saved to ./data/encoded/caesar-cipher/20241202_100317.json

Freed up GPU memory after processing models--mistralai--Mixtral-8x7B-Instruct-v0.1. Moving to the next model.

All models, prompt types, and shift values have been processed. Exiting.




