Full Execution on Hindi Dataset for Bias Analysis

In [None]:
# -*- coding: utf-8 -*-
"""
This script fine-tunes Gemma-2B model and performs layer-wise
bias analysis before and after training.

Workflow:
1.  **Initial Bias Analysis**: Performs a layer-wise bias analysis on the base
    model (`google/gemma-2b`) using an English dataset before any fine-tuning.
2.  **Fine-tuning**: Fine-tunes the `google/gemma-2b` model on a Hindi dataset
    (`iamshnoo/alpaca-cleaned-hindi`).
3.  **Post-Epoch Bias Analysis**: After each fine-tuning epoch, the script runs the
    layer-wise bias analysis on the model for both English and Hindi to track how
    bias evolves.
4.  **Results**: All bias analysis results are saved to CSV files for further examination.
5.  **Upload**: After training, the script can upload the final model, tokenizer,
    and required custom code to the Hugging Face Hub.

This script can be executed with command-line arguments to specify the action
('train' or 'upload'), training mode ('test' or 'full'), and the platform
('colab', 'digitalocean', or 'local').
"""

import os
import gc
import sys
import csv
import argparse
import torch
import pandas as pd
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from datasets import load_dataset
from huggingface_hub import HfApi, login, whoami
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import glob

# =============================================================================
# GLOBAL CONFIGURATION
# =============================================================================
BASE_MODEL_NAME = "google/gemma-2b"  # base checkpoint fine-tuned in this cell
TOKENIZER_NAME = BASE_MODEL_NAME  # the tokenizer ships with the same checkpoint
DATASET_NAME = "iamshnoo/alpaca-cleaned-hindi"
HF_USERNAME = "Debk"  # <-- IMPORTANT: SET YOUR HUGGING FACE USERNAME HERE
NEW_MODEL_REPO_NAME = "/".join((HF_USERNAME, "gemma-2b-finetuned-alpaca-hindi"))
FINAL_MODEL_DIR = "./gemma-2b-hindi-final"
BIAS_RESULTS_DIR = "./bias_analysis_results"

# Full language name (used in results) -> WEATHub language code.
language_mapping = dict(english='en', hindi='hi', bengali='bn')

# WEATHub language code -> full language name.
reverse_language_mapping = {code: name for name, code in language_mapping.items()}

# =============================================================================
# PLATFORM-SPECIFIC CONFIGURATION
# =============================================================================

def setup_platform_environment(platform: str = "local"):
    """
    Create the working directories and optionally log in to Hugging Face.

    ``platform`` is only echoed in the log output; the same relative paths are
    used on every platform. Login is attempted only when the ``HF_TOKEN``
    environment variable is set.

    Returns:
        A ``(project_path, results_path, hf_cache_dir)`` tuple of path strings.
    """
    print(f"Setting up environment for: {platform.upper()}")
    project_path, results_path, hf_cache_dir = "./", BIAS_RESULTS_DIR, "./hf_cache/"

    for directory in (results_path, hf_cache_dir):
        os.makedirs(directory, exist_ok=True)

    token = os.environ.get('HF_TOKEN')
    if not token:
        print("Hugging Face token not found in environment variables. Login manually if needed.")
    else:
        try:
            login(token, add_to_git_credential=True)
            print("✅ Successfully logged in to Hugging Face!")
        except Exception as e:
            # A failed login is reported but does not abort setup.
            print(f"❌ Failed to login to Hugging Face: {e}")

    print(f"Project path set to: {project_path}")
    print(f"Bias analysis results will be saved to: {results_path}")
    return project_path, results_path, hf_cache_dir


# =============================================================================
# LAYER-WISE BIAS ANALYSIS COMPONENTS
# =============================================================================

class LLMManager:
    """Hold at most one analysis model/tokenizer pair to bound GPU memory use.

    Loading a different model first releases the current one. A failed load
    leaves the manager empty instead of holding a mismatched model/tokenizer.
    """
    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir
        self.model = None
        self.tokenizer = None
        self.current_model_id = None

    def load_model(self, model_id: str, tokenizer_id: str, model_repo: str):
        """Load a 4-bit quantized Gemma-2B model and its tokenizer.

        Args:
            model_id: Identifier recorded for the loaded model; it is also the
                load path when ``model_repo == 'hf'``.
            tokenizer_id: Hub id or local path passed to ``AutoTokenizer``.
            model_repo: ``'hf'`` to load ``model_id`` from the Hub; any other
                value is used directly as the model load path.

        Returns:
            ``(model, tokenizer)`` on success, ``(None, None)`` on failure.
        """
        if self.current_model_id == model_id and self.model is not None:
            print(f"Model '{model_id}' already loaded.")
            return self.model, self.tokenizer

        # Release any previously loaded model first; otherwise two models
        # would be resident at once.
        self.unload_model()

        print(f"Loading model: {model_id} and tokenizer: {tokenizer_id}")
        load_path = model_id if model_repo == 'hf' else model_repo

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, cache_dir=self.cache_dir)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                load_path,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=quantization_config,
                cache_dir=self.cache_dir,
            )
        except Exception as e:
            print(f"ERROR: Failed to load model '{model_id}'. Exception: {e}")
            print("Please ensure you have access to the Gemma model and are logged in to Hugging Face.")
            return None, None

        # Commit state only once both tokenizer and model loaded successfully.
        self.model, self.tokenizer, self.current_model_id = model, tokenizer, model_id
        print(f"Model '{model_id}' loaded successfully.")
        return self.model, self.tokenizer

    def unload_model(self):
        """Drop the held model/tokenizer and release cached GPU memory.

        Does nothing when no model is loaded.
        """
        if self.model is None:
            return
        print(f"Unloading model: {self.current_model_id}...")
        # Drop references before collecting so the weights can be freed.
        self.model, self.tokenizer, self.current_model_id = None, None, None
        gc.collect()
        torch.cuda.empty_cache()
        print("Model unloaded and memory cleared.")
            print("Model unloaded and memory cleared.")

class WEATHubLoader:
    """Load the WEATHub dataset once and serve WEAT word lists from it.

    If loading fails, ``self.dataset`` is ``None`` and every lookup returns ``None``.
    """
    def __init__(self, dataset_id: str, cache_dir: str = None):
        print(f"Loading WEATHub dataset from '{dataset_id}'...")
        try:
            self.dataset = load_dataset(dataset_id, cache_dir=cache_dir)
        except Exception as e:
            print(f"ERROR: Failed to load WEATHub dataset. Exception: {e}")
            self.dataset = None
        else:
            print("WEATHub dataset loaded successfully.")
            # Every supported WEAT category lives in the 'original_weat' split.
            self.split_mapping = {
                category: 'original_weat'
                for category in ('WEAT1', 'WEAT2', 'WEAT6', 'WEAT7', 'WEAT8')
            }

    def get_word_lists(self, language_code: str, weat_category_id: str):
        """Return the target/attribute word lists for one language and category.

        Returns:
            A dict with keys ``'targ1'``, ``'targ2'``, ``'attr1'`` and ``'attr2'``
            taken from the first matching row. Returns ``None`` when the dataset
            is unavailable, the category is unknown, no row matches, or
            filtering raises.
        """
        if not self.dataset:
            return None

        split_name = self.split_mapping.get(weat_category_id)
        if not split_name:
            print(f"Warning: Category '{weat_category_id}' not found.")
            return None

        try:
            matches = self.dataset[split_name].filter(
                lambda row: row['language'] == language_code and row['weat'] == weat_category_id
            )
            if len(matches) == 0:
                print(f"Warning: No data for language '{language_code}' and category '{weat_category_id}'.")
                return None
            first = matches[0]
            return {key: first[f"{key}.examples"] for key in ('targ1', 'targ2', 'attr1', 'attr2')}
        except Exception as e:
            print(f"Error filtering data for '{weat_category_id}' in language '{language_code}': {e}")
            return None

class LayerEmbeddingExtractor:
    """Extract per-layer, mean-pooled hidden-state embeddings for words.

    Each word goes through the model once with ``output_hidden_states=True``.
    The mean-pooled vector of every returned hidden state is cached, so asking
    for further layers of the same word needs no extra forward pass. The cache
    assumes the model's weights do not change while this extractor is in use;
    call ``clear_cache`` if they do.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        # word -> float32 array with one row per hidden state
        self._cache = {}

    def clear_cache(self):
        """Forget all cached word embeddings."""
        self._cache.clear()

    @torch.no_grad()
    def _all_layer_embeddings(self, word):
        """Return the mean-pooled embedding of ``word`` for every hidden state."""
        cached = self._cache.get(word)
        if cached is None:
            inputs = self.tokenizer(word, return_tensors="pt", add_special_tokens=False).to(self.device)
            outputs = self.model(**inputs, output_hidden_states=True)
            # Mean over the token axis of the single batch row, cast to float32
            # and moved to CPU per hidden state before stacking.
            cached = np.stack(
                [h[0].mean(dim=0).float().cpu().numpy() for h in outputs.hidden_states]
            )
            self._cache[word] = cached
        return cached

    def get_embeddings(self, words: list, layer_idx: int):
        """Return a ``(len(words), hidden_dim)`` array of embeddings at ``layer_idx``.

        ``layer_idx`` indexes the model's ``hidden_states`` tuple.
        """
        return np.array([self._all_layer_embeddings(word)[layer_idx] for word in words])

class BiasQuantifier:
    """Compute WEAT effect sizes from word-embedding arrays."""

    @staticmethod
    def _cosine_matrix(X, Y):
        """Return pairwise cosine similarities between rows of ``X`` and rows of ``Y``.

        Zero-norm rows stay zero vectors, so their similarity to anything is 0.
        """
        def _normalize(M):
            M = np.asarray(M, dtype=np.float64)
            norms = np.linalg.norm(M, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            return M / norms
        return _normalize(X) @ _normalize(Y).T

    def _s_many(self, W, A, B):
        """Return s(w, A, B) = mean cos(w, a) - mean cos(w, b) for every row w of ``W``."""
        return self._cosine_matrix(W, A).mean(axis=1) - self._cosine_matrix(W, B).mean(axis=1)

    def _s(self, w, A, B):
        """Return the association s(w, A, B) for a single vector ``w``."""
        return self._s_many(np.reshape(w, (1, -1)), A, B)[0]

    def weat_effect_size(self, T1_embeds, T2_embeds, A1_embeds, A2_embeds):
        """Calculate the WEAT effect size (d-score).

        Args:
            T1_embeds, T2_embeds: Target word embeddings, one row per word.
            A1_embeds, A2_embeds: Attribute word embeddings, one row per word.

        Returns:
            ``(mean s(T1) - mean s(T2)) / std(s(T1 ∪ T2))`` with ``ddof=1``.
            Returns 0 when the pooled standard deviation is not positive, and
            NaN when any word set is empty (the effect size is undefined).
        """
        if min(len(T1_embeds), len(T2_embeds), len(A1_embeds), len(A2_embeds)) == 0:
            return float("nan")
        # Each target's association is computed once and reused for the means and the spread.
        s_T1 = self._s_many(T1_embeds, A1_embeds, A2_embeds)
        s_T2 = self._s_many(T2_embeds, A1_embeds, A2_embeds)
        std_dev = np.std(np.concatenate((s_T1, s_T2)), ddof=1)
        return (s_T1.mean() - s_T2.mean()) / std_dev if std_dev > 0 else 0

def create_detailed_comment(base_comment: str, language: str = "hindi", dataset: str = "alpaca", model: str = "gemma-2b", mode: str = None):
    """Build the label stored in the ``comments`` column and used in result filenames.

    The label has the form ``"<base_comment> <language> finetune on <model>"``.
    ``dataset`` and ``mode`` are accepted for interface compatibility but do not
    appear in the label.
    """
    return "{} {} finetune on {}".format(base_comment, language, model)

def execute_bias_analysis(model, tokenizer, results_path: str, hf_cache_dir: str, model_name: str, comments: str, languages: list, mode: str = None):
    """Run layer-wise WEAT bias analysis on ``model`` and save the results to CSV.

    For every language in ``languages`` and every tested WEAT category, a WEAT
    effect size is computed from each entry of the model's ``hidden_states``.
    ``layer_idx`` 0 is the embedding output and ``layer_idx`` i (1..num_layers)
    is the output of decoder block i-1, so the final block is included.

    Args:
        model: Causal LM exposing ``model.model.layers``.
        tokenizer: Tokenizer matching ``model``.
        results_path: Directory the CSV is written to.
        hf_cache_dir: Cache root; WEATHub is cached under ``<hf_cache_dir>/datasets``.
        model_name: Model id recorded in the results and used in the filename.
        comments: Stage label (e.g. ``"Before"``, ``"After epoch 1"``).
        languages: Full language names (keys of ``language_mapping``).
        mode: Passed through to ``create_detailed_comment``.
    """
    weathub_loader = WEATHubLoader(dataset_id='iamshnoo/WEATHub', cache_dir=os.path.join(hf_cache_dir, "datasets"))
    bias_quantifier = BiasQuantifier()
    num_layers = len(model.model.layers)
    # hidden_states = embedding output + one entry per decoder block.
    num_hidden_states = num_layers + 1
    embedding_extractor = LayerEmbeddingExtractor(model, tokenizer)
    all_results = []
    weat_categories_to_test = ['WEAT1', 'WEAT2', 'WEAT6']

    detailed_comment = create_detailed_comment(comments, model="gemma-2b", mode=mode)

    for lang_full in languages:
        # WEATHub is keyed by language code; results keep the full name.
        lang_code = language_mapping.get(lang_full, lang_full)
        for weat_cat in weat_categories_to_test:
            print(f"\nProcessing: Lang='{lang_full}' (code: {lang_code}), Category='{weat_cat}'")
            word_lists = weathub_loader.get_word_lists(lang_code, weat_cat)
            if not word_lists:
                continue
            for layer_idx in tqdm(range(num_hidden_states), desc=f"Layer Analysis ({lang_full}/{weat_cat})"):
                t1_embeds = embedding_extractor.get_embeddings(word_lists['targ1'], layer_idx)
                t2_embeds = embedding_extractor.get_embeddings(word_lists['targ2'], layer_idx)
                a1_embeds = embedding_extractor.get_embeddings(word_lists['attr1'], layer_idx)
                a2_embeds = embedding_extractor.get_embeddings(word_lists['attr2'], layer_idx)
                weat_score = bias_quantifier.weat_effect_size(t1_embeds, t2_embeds, a1_embeds, a2_embeds)
                all_results.append({'model_id': model_name, 'language': lang_full, 'weat_category_id': weat_cat, 'layer_idx': layer_idx, 'weat_score': weat_score, 'comments': detailed_comment})

    if all_results:
        results_df = pd.DataFrame(all_results)
        filename = f"bias_results_{model_name.replace('/', '_')}_{detailed_comment.replace(' ', '_')}.csv"
        filepath = os.path.join(results_path, filename)
        results_df.to_csv(filepath, index=False)
        print(f"Results successfully saved to: {filepath}")
    else:
        print("No results were generated.")
    print("\nAnalysis complete.")


# =============================================================================
# FINE-TUNING AND UPLOAD COMPONENTS
# =============================================================================

def create_prompt(example, eos_token: str = "<eos>"):
    """Format one Alpaca-style example as an instruction-tuning prompt.

    When the example has a non-empty ``input`` field, the standard Alpaca
    "with input" template is used so that the context the instruction refers
    to is not dropped.

    Args:
        example: Mapping with ``instruction`` and ``output`` keys and an
            optional ``input`` key.
        eos_token: Appended after the response so the model learns where to
            stop. It should equal ``tokenizer.eos_token`` of the model being
            trained (``<eos>`` for Gemma). The previous hard-coded ``</s>`` is
            Llama's EOS string.

    Returns:
        The formatted prompt string.
    """
    instruction = example["instruction"]
    response = example["output"] + eos_token
    context = example.get("input")

    if context and str(context).strip():
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{context}\n\n"
            f"### Response:\n{response}"
        )
    return (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n"
        f"### Response:\n{response}"
    )

class BiasAnalysisCallback(TrainerCallback):
    """TrainerCallback that runs layer-wise bias analysis (Hindi and English) after every epoch."""
    def __init__(self, tokenizer, results_path, hf_cache_dir, model_name, mode):
        self.tokenizer = tokenizer
        self.results_path = results_path
        self.hf_cache_dir = hf_cache_dir
        self.model_name = model_name
        self.mode = mode

    def on_epoch_end(self, args, state, control, **kwargs):
        """Analyse the current model in eval mode, then restore its previous mode."""
        # state.epoch is a float; round it so a value like 0.9999 is labelled 1, not 0.
        epoch = int(round(state.epoch))
        model = kwargs['model']
        print(f"\n--- Running Bias Analysis for Epoch {epoch} ---")
        # The model may still be in training mode here (dropout active); use
        # eval mode for the analysis and restore the original mode afterwards.
        was_training = model.training
        model.eval()
        try:
            execute_bias_analysis(model, self.tokenizer, self.results_path, self.hf_cache_dir, self.model_name, f"After epoch {epoch}", ['hindi', 'english'], self.mode)
        finally:
            if was_training:
                model.train()
        print(f"--- Bias Analysis for Epoch {epoch} Completed ---")

def upload_to_hf(model_path, repo_name):
    """Uploads a model folder and associated artifacts to the Hugging Face Hub."""
    print(f"Starting upload of '{model_path}' to '{repo_name}'...")
    
    try:
        user_info = whoami()
        current_user = user_info.get('name')
        print(f"✅ Authenticated as: {current_user}")
        
        expected_user = repo_name.split('/')[0]
        if current_user != expected_user:
            print(f"⚠️ WARNING: Authenticated user '{current_user}' doesn't match repo owner '{expected_user}'")
            repo_name = f"{current_user}/{repo_name.split('/')[1]}"
            print(f"🔄 Using corrected repo name: {repo_name}")
            
    except Exception as e:
        print(f"❌ Authentication check failed: {e}. Please log in first.")
        return False
    
    if not os.path.exists(model_path):
        print(f"❌ Model directory '{model_path}' not found!")
        return False
    
    try:
        api = HfApi()
        print(f"📁 Creating repository: {repo_name}")
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
        
        print(f"📤 Uploading folder: {model_path}")
        api.upload_folder(
            folder_path=model_path, 
            repo_id=repo_name, 
            repo_type="model",
            commit_message=f"Upload fine-tuned Gemma-2B model for Hindi"  # Updated message
        )
        
        print("✅ Upload completed successfully!")
        print(f"🔗 View your model at: https://huggingface.co/{repo_name}")
        return True
        
    except Exception as e:
        print(f"❌ Upload failed with error: {e}")
        return False

def merge_csv_files(results_path: str, model_name: str):
    """Merge every per-stage bias CSV for ``model_name`` into one file.

    Reads ``bias_results_<model>_*.csv`` from ``results_path`` and concatenates
    the rows. The merged output file itself is excluded, so re-running does not
    duplicate rows. The rows are sorted and written to
    ``bias_results_<model>_merged_all_epochs.csv``.
    """
    print(f"\n--- Merging CSV files for {model_name} ---")

    model_clean = model_name.replace('/', '_')
    merged_filename = f"bias_results_{model_clean}_merged_all_epochs.csv"
    merged_filepath = os.path.join(results_path, merged_filename)

    # Escape so glob metacharacters in the path or model name match literally;
    # sort for a deterministic merge order.
    pattern = os.path.join(glob.escape(results_path), f"bias_results_{glob.escape(model_clean)}_*.csv")
    merged_abspath = os.path.abspath(merged_filepath)
    csv_files = sorted(
        path for path in glob.glob(pattern)
        if os.path.abspath(path) != merged_abspath
    )

    if not csv_files:
        print(f"No CSV files found for model {model_name}")
        return

    print(f"Found {len(csv_files)} CSV files to merge:")
    for file in csv_files:
        print(f"  - {os.path.basename(file)}")

    all_dataframes = []
    for file in csv_files:
        try:
            df = pd.read_csv(file)
            all_dataframes.append(df)
            print(f"  ✅ Loaded {file} with {len(df)} rows")
        except Exception as e:
            print(f"  ❌ Error loading {file}: {e}")

    if not all_dataframes:
        print("No valid CSV files could be loaded for merging.")
        return

    merged_df = pd.concat(all_dataframes, ignore_index=True)
    # Group rows by analysis stage (comments), then language, category and layer.
    merged_df = merged_df.sort_values(['comments', 'language', 'weat_category_id', 'layer_idx'])
    merged_df.to_csv(merged_filepath, index=False)

    print(f"✅ Merged CSV saved to: {merged_filepath}")
    print(f"Total rows in merged file: {len(merged_df)}")

    print("\nSummary of merged data:")
    print(f"Languages: {sorted(merged_df['language'].unique())}")
    print(f"WEAT categories: {sorted(merged_df['weat_category_id'].unique())}")
    print(f"Comments (analysis stages): {sorted(merged_df['comments'].unique())}")

# =============================================================================
# MAIN EXECUTION FUNCTION
# =============================================================================

def main(args):
    """Orchestrate the Hindi fine-tuning run with bias analysis before and after.

    Steps, in order:
      1. Prepare directories and optionally log in (``setup_platform_environment``).
      2. If ``args.action == 'upload'``: upload ``FINAL_MODEL_DIR`` and return.
      3. Baseline bias analysis (English only) on the base model, loaded 4-bit
         through ``LLMManager``, which is then unloaded.
      4. Fine-tune the base model (loaded in bf16, not quantized) on the Hindi
         Alpaca data; ``BiasAnalysisCallback`` analyses Hindi and English after
         every epoch.
      5. Save the final model, merge all result CSVs, then upload to the Hub.

    Args:
        args: Namespace with ``action`` ('train' | 'upload'), ``mode``
            ('test' | 'full') and ``platform``.
    """
    project_path, results_path, hf_cache_dir = setup_platform_environment(args.platform)
    
    # The Hub repo name is suffixed with the training mode (e.g. ..._test / ..._full).
    repo_name_with_mode = f"{NEW_MODEL_REPO_NAME}_{args.mode}"

    if args.action == 'upload':
        if not os.path.exists(FINAL_MODEL_DIR):
            print(f"Error: Final model directory '{FINAL_MODEL_DIR}' not found. Please run training first.")
            return
        upload_to_hf(FINAL_MODEL_DIR, repo_name_with_mode)
        return

    # --- Initial Bias Analysis (Before Fine-tuning) ---
    # NOTE(review): this baseline uses the 4-bit quantized model from LLMManager,
    # while the per-epoch analyses run on the bf16 training model below. Part of
    # any before/after difference may therefore come from quantization rather
    # than fine-tuning. The baseline also covers English only, whereas the
    # per-epoch analyses cover Hindi and English — confirm this is intended.
    print("\n--- Running Initial Bias Analysis on Base Gemma-2B Model ---")
    llm_manager = LLMManager(cache_dir=hf_cache_dir)
    base_model, base_tokenizer = llm_manager.load_model(BASE_MODEL_NAME, TOKENIZER_NAME, 'hf')
    if base_model and base_tokenizer:
        execute_bias_analysis(base_model, base_tokenizer, results_path, hf_cache_dir, BASE_MODEL_NAME, "Before", ['english'], args.mode)
    # Free the quantized analysis copy before loading the model for training.
    llm_manager.unload_model()
    print("--- Initial Bias Analysis Completed ---")

    # --- Fine-tuning ---
    print("\n--- Preparing for Fine-tuning Gemma-2B ---")
    dataset = load_dataset(DATASET_NAME, split="train")
    
    # Tokenizer for fine-tuning; fall back to EOS as the padding token if none is defined.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir=hf_cache_dir)
    if tokenizer.pad_token is None:
        print("Padding token not found. Setting pad_token to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    # Training copy of the model: bf16 weights, no quantization.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME, 
        torch_dtype=torch.bfloat16,
        device_map="auto",
        cache_dir=hf_cache_dir
    )
    
    # Synchronize model embeddings with tokenizer if needed
    print(f"Original model vocab size: {model.config.vocab_size}, Tokenizer vocab size: {len(tokenizer)}")
    if model.config.vocab_size != len(tokenizer):
        print("Resizing model token embeddings to match tokenizer...")
        model.resize_token_embeddings(len(tokenizer))
        print(f"New model vocab size: {model.config.vocab_size}")
    
    # Ensure model's pad_token_id is configured
    model.config.pad_token_id = tokenizer.pad_token_id

    dataset_with_prompt = dataset.map(lambda example: {"text": create_prompt(example)})
    # Only the original dataset columns are removed; the added "text" column
    # remains in tokenized_dataset. NOTE(review): presumably dropped later by the
    # Trainer's unused-column removal — confirm.
    tokenized_dataset = dataset_with_prompt.map(lambda ex: tokenizer(ex["text"], truncation=True, max_length=512), batched=True, remove_columns=dataset.column_names)

    num_train_epochs = 1 if args.mode == 'test' else 5 # 1 epoch in test mode, 5 in full mode

    # Per device: 4 sequences x 8 accumulation steps = 32 sequences per optimizer step.
    training_args = TrainingArguments(
        output_dir="./gemma-2b-hindi-tuned",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        num_train_epochs=num_train_epochs,
        max_steps=-1,
        bf16=True,
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
    )

    # mlm=False: causal language-modeling labels.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    if args.mode == 'test':
        shuffled_dataset = tokenized_dataset.shuffle(seed=42)
        # Test mode: 100 training and 20 evaluation examples from a fixed shuffle.
        train_dataset = shuffled_dataset.select(range(100))
        eval_dataset = shuffled_dataset.select(range(100, 120))
    else:
        # Full mode: 90/10 train/eval split with a fixed seed.
        split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
        train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']

    bias_analysis_callback = BiasAnalysisCallback(tokenizer, results_path, hf_cache_dir, BASE_MODEL_NAME, args.mode)
    
    # NOTE(review): training_args sets no evaluation strategy, so eval_dataset
    # may never actually be evaluated — confirm this is intended.
    trainer = Trainer(
        model=model, 
        args=training_args, 
        train_dataset=train_dataset, 
        eval_dataset=eval_dataset, 
        data_collator=data_collator, 
        tokenizer=tokenizer, 
        callbacks=[bias_analysis_callback]
    )

    print(f"--- Starting Fine-tuning (Mode: {args.mode}) ---")
    trainer.train()
    print("--- Fine-tuning Completed ---")

    print(f"Saving final model to {FINAL_MODEL_DIR}")
    trainer.save_model(FINAL_MODEL_DIR)
    
    # --- Merge all CSV files ---
    merge_csv_files(results_path, BASE_MODEL_NAME)
    
    # Uploads automatically at the end of training (no separate 'upload' action needed).
    print("\n--- Starting Automatic Upload to HuggingFace Hub ---")
    if os.path.exists(FINAL_MODEL_DIR):
        upload_to_hf(FINAL_MODEL_DIR, repo_name_with_mode)
        print("--- Upload to HuggingFace Hub Completed ---")
    else:
        print(f"Error: Final model directory '{FINAL_MODEL_DIR}' not found.")


if __name__ == "__main__":
    # Notebook kernels have no command-line flags to parse, so fixed arguments are used there.
    is_notebook = any(module in sys.modules for module in ('google.colab', 'ipykernel'))

    if not is_notebook:
        parser = argparse.ArgumentParser(description="Fine-tune and analyze Gemma-2B.")
        parser.add_argument("--action", type=str, default="train", choices=["train", "upload"], help="Action to perform.")
        parser.add_argument("--mode", type=str, default="test", choices=["test", "full"], help="Training mode (test uses a small subset).")
        parser.add_argument("--platform", type=str, default="local", choices=["colab", "digitalocean", "local"], help="Execution platform.")
        parsed_args = parser.parse_args()
        main(parsed_args)
    else:
        print("Running in a notebook environment. Setting arguments manually.")
        args = argparse.Namespace(action="train", mode="full", platform="local")
        main(args)

Full Execution on Bengali Dataset for Bias Analysis (continuing from the Hindi fine-tuned model)

In [None]:
# -*- coding: utf-8 -*-
"""
This script fine-tunes a Gemma-2B model that was previously fine-tuned on Hindi data 
and performs layer-wise bias analysis before and after training on Bengali data.

Workflow:
1.  **Initial Bias Analysis**: Performs a layer-wise bias analysis on the base
    model (`Debk/gemma-2b-finetuned-alpaca-hindi_full`) using Hindi and English datasets before any Bengali fine-tuning.
2.  **Fine-tuning**: Fine-tunes the `Debk/gemma-2b-finetuned-alpaca-hindi_full` model on a Bengali dataset
    (`iamshnoo/alpaca-cleaned-bengali`).
3.  **Post-Epoch Bias Analysis**: After each fine-tuning epoch, the script runs the
    layer-wise bias analysis on the model for Hindi, English, and Bengali to track how
    bias evolves.
4.  **Results**: All bias analysis results are saved to CSV files for further examination.
5.  **Upload**: After training, the script can upload the final model, tokenizer,
    and required custom code to the Hugging Face Hub.

This script can be executed with command-line arguments to specify the action
('train' or 'upload'), training mode ('test' or 'full'), and the platform
('colab', 'digitalocean', or 'local').
"""

import os
import gc
import sys
import csv
import argparse
import torch
import pandas as pd
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from datasets import load_dataset
from huggingface_hub import HfApi, login, whoami
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import glob

# =============================================================================
# GLOBAL CONFIGURATION
# =============================================================================
BASE_MODEL_NAME = "Debk/gemma-2b-finetuned-alpaca-hindi_full"  # Hindi fine-tuned checkpoint used as the starting point
TOKENIZER_NAME = BASE_MODEL_NAME  # the tokenizer ships with the same checkpoint
DATASET_NAME = "iamshnoo/alpaca-cleaned-bengali"  # Bengali data for the second fine-tuning stage
HF_USERNAME = "Debk"  # <-- IMPORTANT: SET YOUR HUGGING FACE USERNAME HERE
NEW_MODEL_REPO_NAME = "/".join((HF_USERNAME, "gemma-2b-finetuned-alpaca-hindi-bengali"))
FINAL_MODEL_DIR = "./gemma-2b-hindi-bengali-final"
BIAS_RESULTS_DIR = "./bias_analysis_results"

# Full language name (used in results) -> WEATHub language code.
language_mapping = dict(english='en', hindi='hi', bengali='bn')

# WEATHub language code -> full language name.
reverse_language_mapping = {code: name for name, code in language_mapping.items()}

# =============================================================================
# PLATFORM-SPECIFIC CONFIGURATION
# =============================================================================

def setup_platform_environment(platform: str = "local"):
    """
    Create the working directories and optionally log in to Hugging Face.

    ``platform`` is only echoed in the log output; the same relative paths are
    used on every platform. Login is attempted only when the ``HF_TOKEN``
    environment variable is set.

    Returns:
        A ``(project_path, results_path, hf_cache_dir)`` tuple of path strings.
    """
    print(f"Setting up environment for: {platform.upper()}")
    project_path, results_path, hf_cache_dir = "./", BIAS_RESULTS_DIR, "./hf_cache/"

    for directory in (results_path, hf_cache_dir):
        os.makedirs(directory, exist_ok=True)

    token = os.environ.get('HF_TOKEN')
    if not token:
        print("Hugging Face token not found in environment variables. Login manually if needed.")
    else:
        try:
            login(token, add_to_git_credential=True)
            print("✅ Successfully logged in to Hugging Face!")
        except Exception as e:
            # A failed login is reported but does not abort setup.
            print(f"❌ Failed to login to Hugging Face: {e}")

    print(f"Project path set to: {project_path}")
    print(f"Bias analysis results will be saved to: {results_path}")
    return project_path, results_path, hf_cache_dir


# =============================================================================
# LAYER-WISE BIAS ANALYSIS COMPONENTS
# =============================================================================

class LLMManager:
    """Hold at most one analysis model/tokenizer pair to bound GPU memory use.

    Loading a different model first releases the current one. A failed load
    leaves the manager empty instead of holding a mismatched model/tokenizer.
    """
    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir
        self.model = None
        self.tokenizer = None
        self.current_model_id = None

    def load_model(self, model_id: str, tokenizer_id: str, model_repo: str):
        """Load a 4-bit quantized copy of the (Hindi fine-tuned) Gemma-2B model and its tokenizer.

        Args:
            model_id: Identifier recorded for the loaded model; it is also the
                load path when ``model_repo == 'hf'``.
            tokenizer_id: Hub id or local path passed to ``AutoTokenizer``.
            model_repo: ``'hf'`` to load ``model_id`` from the Hub; any other
                value is used directly as the model load path.

        Returns:
            ``(model, tokenizer)`` on success, ``(None, None)`` on failure.
        """
        if self.current_model_id == model_id and self.model is not None:
            print(f"Model '{model_id}' already loaded.")
            return self.model, self.tokenizer

        # Release any previously loaded model first; otherwise two models
        # would be resident at once.
        self.unload_model()

        print(f"Loading model: {model_id} and tokenizer: {tokenizer_id}")
        load_path = model_id if model_repo == 'hf' else model_repo

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, cache_dir=self.cache_dir)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                load_path,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=quantization_config,
                cache_dir=self.cache_dir,
            )
        except Exception as e:
            print(f"ERROR: Failed to load model '{model_id}'. Exception: {e}")
            print("Please ensure you have access to the fine-tuned Gemma model and are logged in to Hugging Face.")
            return None, None

        # Commit state only once both tokenizer and model loaded successfully.
        self.model, self.tokenizer, self.current_model_id = model, tokenizer, model_id
        print(f"Model '{model_id}' loaded successfully.")
        return self.model, self.tokenizer

    def unload_model(self):
        """Drop the held model/tokenizer and release cached GPU memory.

        Does nothing when no model is loaded.
        """
        if self.model is None:
            return
        print(f"Unloading model: {self.current_model_id}...")
        # Drop references before collecting so the weights can be freed.
        self.model, self.tokenizer, self.current_model_id = None, None, None
        gc.collect()
        torch.cuda.empty_cache()
        print("Model unloaded and memory cleared.")

class WEATHubLoader:
    """Load the WEATHub dataset once and serve WEAT word lists from it.

    If loading fails, ``self.dataset`` is ``None`` and every lookup returns ``None``.
    """
    def __init__(self, dataset_id: str, cache_dir: str = None):
        print(f"Loading WEATHub dataset from '{dataset_id}'...")
        try:
            self.dataset = load_dataset(dataset_id, cache_dir=cache_dir)
        except Exception as e:
            print(f"ERROR: Failed to load WEATHub dataset. Exception: {e}")
            self.dataset = None
        else:
            print("WEATHub dataset loaded successfully.")
            # Every supported WEAT category lives in the 'original_weat' split.
            self.split_mapping = {
                category: 'original_weat'
                for category in ('WEAT1', 'WEAT2', 'WEAT6', 'WEAT7', 'WEAT8')
            }

    def get_word_lists(self, language_code: str, weat_category_id: str):
        """Return the target/attribute word lists for one language and category.

        Returns:
            A dict with keys ``'targ1'``, ``'targ2'``, ``'attr1'`` and ``'attr2'``
            taken from the first matching row. Returns ``None`` when the dataset
            is unavailable, the category is unknown, no row matches, or
            filtering raises.
        """
        if not self.dataset:
            return None

        split_name = self.split_mapping.get(weat_category_id)
        if not split_name:
            print(f"Warning: Category '{weat_category_id}' not found.")
            return None

        try:
            matches = self.dataset[split_name].filter(
                lambda row: row['language'] == language_code and row['weat'] == weat_category_id
            )
            if len(matches) == 0:
                print(f"Warning: No data for language '{language_code}' and category '{weat_category_id}'.")
                return None
            first = matches[0]
            return {key: first[f"{key}.examples"] for key in ('targ1', 'targ2', 'attr1', 'attr2')}
        except Exception as e:
            print(f"Error filtering data for '{weat_category_id}' in language '{language_code}': {e}")
            return None

class LayerEmbeddingExtractor:
    """Extract per-layer, mean-pooled hidden-state embeddings for words.

    Each word goes through the model once with ``output_hidden_states=True``.
    The mean-pooled vector of every returned hidden state is cached, so asking
    for further layers of the same word needs no extra forward pass. The cache
    assumes the model's weights do not change while this extractor is in use;
    call ``clear_cache`` if they do.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        # word -> float32 array with one row per hidden state
        self._cache = {}

    def clear_cache(self):
        """Forget all cached word embeddings."""
        self._cache.clear()

    @torch.no_grad()
    def _all_layer_embeddings(self, word):
        """Return the mean-pooled embedding of ``word`` for every hidden state."""
        cached = self._cache.get(word)
        if cached is None:
            inputs = self.tokenizer(word, return_tensors="pt", add_special_tokens=False).to(self.device)
            outputs = self.model(**inputs, output_hidden_states=True)
            # Mean over the token axis of the single batch row, cast to float32
            # and moved to CPU per hidden state before stacking.
            cached = np.stack(
                [h[0].mean(dim=0).float().cpu().numpy() for h in outputs.hidden_states]
            )
            self._cache[word] = cached
        return cached

    def get_embeddings(self, words: list, layer_idx: int):
        """Return a ``(len(words), hidden_dim)`` array of embeddings at ``layer_idx``.

        ``layer_idx`` indexes the model's ``hidden_states`` tuple.
        """
        return np.array([self._all_layer_embeddings(word)[layer_idx] for word in words])

class BiasQuantifier:
    """Calculates bias scores using the WEAT effect size.

    Cosine similarities are computed as whole matrices (one matrix product per pair
    of embedding sets) rather than with one call per vector pair.
    """

    @staticmethod
    def _cosine_matrix(X, Y):
        """Return the cosine-similarity matrix between the rows of ``X`` and ``Y``.

        1-D inputs are treated as a single row. A zero-norm row has similarity 0
        with everything, the same convention as sklearn's ``cosine_similarity``.
        """
        X = np.atleast_2d(np.asarray(X, dtype=np.float64))
        Y = np.atleast_2d(np.asarray(Y, dtype=np.float64))
        x_norms = np.linalg.norm(X, axis=1, keepdims=True)
        y_norms = np.linalg.norm(Y, axis=1, keepdims=True)
        # Avoid division by zero: a zero vector stays zero and therefore scores 0.
        x_norms[x_norms == 0] = 1.0
        y_norms[y_norms == 0] = 1.0
        return (X / x_norms) @ (Y / y_norms).T

    def _s(self, w, A, B):
        """Compute s(w, A, B): mean cosine of ``w`` with ``A`` minus its mean cosine with ``B``."""
        return self._cosine_matrix(w, A).mean() - self._cosine_matrix(w, B).mean()

    def _s_rows(self, W, A, B):
        """Vectorised ``_s`` for every row of ``W``; returns a 1-D array."""
        return self._cosine_matrix(W, A).mean(axis=1) - self._cosine_matrix(W, B).mean(axis=1)

    def weat_effect_size(self, T1_embeds, T2_embeds, A1_embeds, A2_embeds):
        """Calculate the WEAT effect size (d-score).

        d = (mean over t in T1 of s(t) - mean over t in T2 of s(t)) / std(s(t) over T1 + T2),
        where s(t) = s(t, A1, A2) and std is the sample (ddof=1) standard deviation.
        Returns 0 when that standard deviation is not positive.
        """
        s_t1 = self._s_rows(T1_embeds, A1_embeds, A2_embeds)
        s_t2 = self._s_rows(T2_embeds, A1_embeds, A2_embeds)
        # Each s(t) is computed once and reused for both the means and the std.
        std_dev = np.std(np.concatenate((s_t1, s_t2)), ddof=1)
        return (s_t1.mean() - s_t2.mean()) / std_dev if std_dev > 0 else 0

def create_detailed_comment(base_comment: str, language: str = "bengali", dataset: str = "alpaca", model: str = "gemma-2b-finetuned-alpaca-hindi_full", mode: str = None):
    """Build the label recorded in the bias-results ``comments`` column.

    The label has the form ``"<base_comment> <language> finetune on <model>"``.
    ``dataset`` and ``mode`` are accepted for call-site compatibility but are not
    part of the label.
    """
    # NOTE(review): the defaults describe a Bengali fine-tune of a Hindi-finetuned
    # model; confirm this matches the run being labelled.
    return "{} {} finetune on {}".format(base_comment, language, model)

def execute_bias_analysis(model, tokenizer, results_path: str, hf_cache_dir: str, model_name: str, comments: str, languages: list, mode: str = None):
    """Run the layer-wise WEAT bias analysis and save the scores to a CSV file.

    For every language in ``languages`` and every category in
    ``weat_categories_to_test``, the four WEATHub word lists are embedded at every
    hidden state of ``model`` and scored with the WEAT effect size.

    Args:
        model: Causal LM exposing ``model.model.layers``.
        tokenizer: Tokenizer matching ``model``.
        results_path: Directory for the output CSV; created if missing.
        hf_cache_dir: Hugging Face cache root. WEATHub is cached under its ``datasets`` subdir.
        model_name: Stored in the ``model_id`` column and used in the filename.
        comments: Base label passed to ``create_detailed_comment``.
        languages: Full language names, mapped to WEATHub codes via ``language_mapping``.
            Names not in the mapping are used as-is.
        mode: Forwarded to ``create_detailed_comment``.

    Side effects:
        Writes ``bias_results_<model_name>_<detailed comment>.csv`` into ``results_path``.
        The model is switched to eval mode for the analysis and its previous
        train/eval mode is restored afterwards, even on error.
    """
    weathub_loader = WEATHubLoader(dataset_id='iamshnoo/WEATHub', cache_dir=os.path.join(hf_cache_dir, "datasets"))
    bias_quantifier = BiasQuantifier()
    num_layers = len(model.model.layers)  # decoder layers of Gemma/Llama-style models
    # With output_hidden_states the model returns the embedding output followed by one
    # tensor per decoder layer, i.e. num_layers + 1 entries. Iterate over all of them
    # so the final layer is scored too (layer_idx 0 = embedding output).
    num_hidden_states = num_layers + 1
    embedding_extractor = LayerEmbeddingExtractor(model, tokenizer)
    all_results = []
    weat_categories_to_test = ['WEAT1', 'WEAT2', 'WEAT6']

    # NOTE(review): this label hard-codes a Bengali-on-Hindi-finetuned description;
    # confirm it matches the configured run.
    detailed_comment = create_detailed_comment(comments, model="gemma-2b-finetuned-alpaca-hindi_full", mode=mode)

    # Embeddings must not depend on dropout. The per-epoch callback passes the model
    # mid-training (presumably in train mode), so switch to eval and restore afterwards.
    was_training = model.training
    model.eval()
    try:
        for lang_full in languages:
            # Convert full language name to code for WEATHub
            lang_code = language_mapping.get(lang_full, lang_full)
            for weat_cat in weat_categories_to_test:
                print(f"\nProcessing: Lang='{lang_full}' (code: {lang_code}), Category='{weat_cat}'")
                word_lists = weathub_loader.get_word_lists(lang_code, weat_cat)
                if not word_lists:
                    continue
                for layer_idx in tqdm(range(num_hidden_states), desc=f"Layer Analysis ({lang_full}/{weat_cat})"):
                    t1_embeds = embedding_extractor.get_embeddings(word_lists['targ1'], layer_idx)
                    t2_embeds = embedding_extractor.get_embeddings(word_lists['targ2'], layer_idx)
                    a1_embeds = embedding_extractor.get_embeddings(word_lists['attr1'], layer_idx)
                    a2_embeds = embedding_extractor.get_embeddings(word_lists['attr2'], layer_idx)
                    weat_score = bias_quantifier.weat_effect_size(t1_embeds, t2_embeds, a1_embeds, a2_embeds)
                    # Store full language name in results
                    all_results.append({'model_id': model_name, 'language': lang_full, 'weat_category_id': weat_cat, 'layer_idx': layer_idx, 'weat_score': weat_score, 'comments': detailed_comment})
    finally:
        model.train(was_training)

    if all_results:
        results_df = pd.DataFrame(all_results)
        filename = f"bias_results_{model_name.replace('/', '_')}_{detailed_comment.replace(' ', '_')}.csv"
        os.makedirs(results_path, exist_ok=True)
        filepath = os.path.join(results_path, filename)
        results_df.to_csv(filepath, index=False)
        print(f"Results successfully saved to: {filepath}")
    else:
        print("No results were generated.")
    print("\nAnalysis complete.")


# =============================================================================
# FINE-TUNING AND UPLOAD COMPONENTS
# =============================================================================

def create_prompt(example, eos_token: str = "</s>"):
    """Format one Alpaca-style example as an instruction-tuning prompt.

    Args:
        example: Mapping with ``"instruction"`` and ``"output"`` keys and an
            optional ``"input"`` key with extra context.
        eos_token: Marker appended to the response so the model learns where to
            stop. The default ``"</s>"`` is kept for backward compatibility, but it
            is only a real EOS token for tokenizers that define it that way. Pass
            ``tokenizer.eos_token`` when formatting for a specific tokenizer.

    Returns:
        The formatted prompt. When ``input`` is missing or blank, the text is the
        plain instruction/response template. Otherwise an ``### Input:`` section is
        included, so the context the response depends on is not dropped.
    """
    no_input_template = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:\n{output}"
    )
    input_template = (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:\n{output}"
    )
    context = example.get("input")
    response = example["output"] + eos_token
    if context is not None and str(context).strip():
        return input_template.format(instruction=example["instruction"], input=context, output=response)
    return no_input_template.format(instruction=example["instruction"], output=response)

class BiasAnalysisCallback(TrainerCallback):
    """Trainer callback that runs the layer-wise WEAT bias analysis after each epoch.

    Args:
        tokenizer: Tokenizer used to embed the WEAT word lists.
        results_path: Directory the per-epoch result CSVs are written to.
        hf_cache_dir: Hugging Face cache root, forwarded to ``execute_bias_analysis``.
        model_name: Identifier recorded in the results and used in CSV filenames.
        mode: Training mode, forwarded to ``execute_bias_analysis``.
        languages: Languages to analyse. Defaults to Hindi, English and Bengali.
    """
    def __init__(self, tokenizer, results_path, hf_cache_dir, model_name, mode, languages=None):
        self.tokenizer = tokenizer
        self.results_path = results_path
        self.hf_cache_dir = hf_cache_dir
        self.model_name = model_name
        self.mode = mode
        self.languages = list(languages) if languages is not None else ['hindi', 'english', 'bengali']

    def on_epoch_end(self, args, state, control, **kwargs):
        """Score the in-training model and label the results with the finished epoch number."""
        # state.epoch is a float. Round rather than truncate, so a value landing just
        # below an integer is not labelled as the previous epoch (whose CSV would
        # then be overwritten, since the filename is derived from this label).
        epoch = int(round(state.epoch))
        model = kwargs['model']
        print(f"\n--- Running Bias Analysis for Epoch {epoch} ---")
        execute_bias_analysis(model, self.tokenizer, self.results_path, self.hf_cache_dir, self.model_name, f"After epoch {epoch}", self.languages, self.mode)
        print(f"--- Bias Analysis for Epoch {epoch} Completed ---")

def upload_to_hf(model_path, repo_name):
    """Upload a local model folder to the Hugging Face Hub.

    Args:
        model_path: Local directory whose contents are uploaded.
        repo_name: Target repo id, ``"<owner>/<name>"`` or just ``"<name>"``. If the
            owner is missing or differs from the authenticated user, the repo is
            created under the authenticated user's namespace instead.

    Returns:
        True on success. False if the authentication check fails, ``model_path``
        does not exist, or repo creation/upload raises.
    """
    print(f"Starting upload of '{model_path}' to '{repo_name}'...")

    # Only the whoami() call belongs in this try: previously a repo name without '/'
    # raised IndexError here and was misreported as an authentication failure.
    try:
        user_info = whoami()
    except Exception as e:
        print(f"❌ Authentication check failed: {e}. Please log in first.")
        return False

    current_user = user_info.get('name')
    print(f"✅ Authenticated as: {current_user}")

    # rpartition copes with names that have no owner prefix ('' owner).
    expected_user, _, short_name = repo_name.rpartition('/')
    if current_user != expected_user:
        print(f"⚠️ WARNING: Authenticated user '{current_user}' doesn't match repo owner '{expected_user}'")
        repo_name = f"{current_user}/{short_name}"
        print(f"🔄 Using corrected repo name: {repo_name}")

    if not os.path.exists(model_path):
        print(f"❌ Model directory '{model_path}' not found!")
        return False

    try:
        api = HfApi()
        print(f"📁 Creating repository: {repo_name}")
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)

        print(f"📤 Uploading folder: {model_path}")
        api.upload_folder(
            folder_path=model_path,
            repo_id=repo_name,
            repo_type="model",
            commit_message="Upload fine-tuned Gemma-2B model for Hindi-Bengali",
        )

        print("✅ Upload completed successfully!")
        print(f"🔗 View your model at: https://huggingface.co/{repo_name}")
        return True

    except Exception as e:
        print(f"❌ Upload failed with error: {e}")
        return False

def merge_csv_files(results_path: str, model_name: str):
    """Merge all per-run bias CSV files for a model into one consolidated file.

    Reads every ``bias_results_<model>_*.csv`` in ``results_path``, where ``<model>``
    is ``model_name`` with ``/`` replaced by ``_``. The merged output itself is
    excluded, so re-running does not fold previously merged rows back in. Rows are
    sorted by comments, language, WEAT category and layer, and the result is written
    to ``bias_results_<model>_merged_all_epochs.csv`` in the same directory.

    Files that fail to load are reported and skipped. Prints a message and returns
    without writing if no input files are found or none can be read.
    """
    print(f"\n--- Merging CSV files for {model_name} ---")

    model_clean = model_name.replace('/', '_')
    merged_filename = f"bias_results_{model_clean}_merged_all_epochs.csv"
    merged_filepath = os.path.join(results_path, merged_filename)

    # Escape path and model name so characters like '[' are matched literally.
    pattern = os.path.join(glob.escape(results_path), f"bias_results_{glob.escape(model_clean)}_*.csv")
    # Sorted for a deterministic read order; skip the merged output of earlier runs.
    csv_files = sorted(
        path for path in glob.glob(pattern)
        if os.path.basename(path) != merged_filename
    )

    if not csv_files:
        print(f"No CSV files found for model {model_name}")
        return

    print(f"Found {len(csv_files)} CSV files to merge:")
    for file in csv_files:
        print(f"  - {os.path.basename(file)}")

    all_dataframes = []
    for file in csv_files:
        try:
            df = pd.read_csv(file)
        except Exception as e:
            print(f"  ❌ Error loading {file}: {e}")
            continue
        all_dataframes.append(df)
        print(f"  ✅ Loaded {file} with {len(df)} rows")

    if not all_dataframes:
        print("No valid CSV files could be loaded for merging.")
        return

    merged_df = pd.concat(all_dataframes, ignore_index=True)

    # Group rows by analysis stage (the comments label, e.g. "Before" / "After epoch N"),
    # then language, category and layer.
    merged_df = merged_df.sort_values(['comments', 'language', 'weat_category_id', 'layer_idx'])

    merged_df.to_csv(merged_filepath, index=False)

    print(f"✅ Merged CSV saved to: {merged_filepath}")
    print(f"Total rows in merged file: {len(merged_df)}")

    print("\nSummary of merged data:")
    print(f"Languages: {sorted(merged_df['language'].unique())}")
    print(f"WEAT categories: {sorted(merged_df['weat_category_id'].unique())}")
    print(f"Comments (analysis stages): {sorted(merged_df['comments'].unique())}")

# =============================================================================
# MAIN EXECUTION FUNCTION
# =============================================================================

def main(args):
    """Orchestrate the pipeline: baseline bias analysis, fine-tuning, merge and upload.

    With ``args.action == 'upload'`` only the upload of ``FINAL_MODEL_DIR`` is
    performed. Otherwise the function:
      1. runs the "Before" bias analysis on the base model,
      2. fine-tunes it on ``DATASET_NAME``, running the bias analysis after every epoch,
      3. saves the model, merges all result CSVs and uploads the model.

    Args:
        args: Namespace with ``action`` ('train' | 'upload'), ``mode`` ('test' uses a
            100-example subset and 1 epoch, 'full' a 90/10 split and 5 epochs) and
            ``platform`` (passed to ``setup_platform_environment``).
    """
    project_path, results_path, hf_cache_dir = setup_platform_environment(args.platform)

    repo_name_with_mode = f"{NEW_MODEL_REPO_NAME}_{args.mode}"

    if args.action == 'upload':
        if not os.path.exists(FINAL_MODEL_DIR):
            print(f"Error: Final model directory '{FINAL_MODEL_DIR}' not found. Please run training first.")
            return
        upload_to_hf(FINAL_MODEL_DIR, repo_name_with_mode)
        return

    # Same languages as the per-epoch analysis in BiasAnalysisCallback, so every
    # post-epoch score has a pre-training baseline to compare against.
    analysis_languages = ['hindi', 'english', 'bengali']

    # --- Initial Bias Analysis (Before Fine-tuning) ---
    print(f"\n--- Running Initial Bias Analysis on {BASE_MODEL_NAME} ---")
    llm_manager = LLMManager(cache_dir=hf_cache_dir)
    base_model, base_tokenizer = llm_manager.load_model(BASE_MODEL_NAME, TOKENIZER_NAME, 'hf')
    if base_model and base_tokenizer:
        execute_bias_analysis(base_model, base_tokenizer, results_path, hf_cache_dir, BASE_MODEL_NAME, "Before", analysis_languages, args.mode)
    llm_manager.unload_model()
    print("--- Initial Bias Analysis Completed ---")

    # --- Fine-tuning ---
    print(f"\n--- Preparing Fine-tuning of {BASE_MODEL_NAME} on {DATASET_NAME} ---")
    dataset = load_dataset(DATASET_NAME, split="train")

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir=hf_cache_dir)
    if tokenizer.pad_token is None:
        print("Padding token not found. Setting pad_token to eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        cache_dir=hf_cache_dir
    )

    # Synchronize model embeddings with tokenizer if needed
    print(f"Original model vocab size: {model.config.vocab_size}, Tokenizer vocab size: {len(tokenizer)}")
    if model.config.vocab_size != len(tokenizer):
        print("Resizing model token embeddings to match tokenizer...")
        model.resize_token_embeddings(len(tokenizer))
        print(f"New model vocab size: {model.config.vocab_size}")

    # Ensure model's pad_token_id is configured
    model.config.pad_token_id = tokenizer.pad_token_id

    eos_token = tokenizer.eos_token

    def _to_text(example):
        """Format one example, ending the response with the tokenizer's real EOS token."""
        text = create_prompt(example)
        # create_prompt ends responses with a literal "</s>", which is not every
        # tokenizer's EOS token. Without this swap the model would learn to emit the
        # text "</s>" rather than to stop.
        if eos_token and text.endswith("</s>"):
            text = text[: -len("</s>")] + eos_token
        return {"text": text}

    dataset_with_prompt = dataset.map(_to_text)
    # Remove all non-token columns (original fields and "text") after tokenizing.
    tokenized_dataset = dataset_with_prompt.map(
        lambda ex: tokenizer(ex["text"], truncation=True, max_length=512),
        batched=True,
        remove_columns=dataset_with_prompt.column_names,
    )

    num_train_epochs = 1 if args.mode == 'test' else 5  # Use 5 epochs for full mode

    training_args = TrainingArguments(
        # NOTE(review): directory name mentions "bengali"; confirm it matches this run.
        output_dir="./gemma-2b-hindi-bengali-tuned",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        num_train_epochs=num_train_epochs,
        max_steps=-1,
        bf16=True,
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    if args.mode == 'test':
        shuffled_dataset = tokenized_dataset.shuffle(seed=42)
        # Use a small subset for testing
        train_dataset = shuffled_dataset.select(range(100))
        eval_dataset = shuffled_dataset.select(range(100, 120))
    else:
        split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
        train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']

    bias_analysis_callback = BiasAnalysisCallback(tokenizer, results_path, hf_cache_dir, BASE_MODEL_NAME, args.mode)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=[bias_analysis_callback]
    )

    print(f"--- Starting Fine-tuning (Mode: {args.mode}) ---")
    trainer.train()
    print("--- Fine-tuning Completed ---")

    print(f"Saving final model to {FINAL_MODEL_DIR}")
    trainer.save_model(FINAL_MODEL_DIR)

    # --- Merge all CSV files ---
    merge_csv_files(results_path, BASE_MODEL_NAME)

    print("\n--- Starting Automatic Upload to HuggingFace Hub ---")
    if os.path.exists(FINAL_MODEL_DIR):
        upload_to_hf(FINAL_MODEL_DIR, repo_name_with_mode)
        print("--- Upload to HuggingFace Hub Completed ---")
    else:
        print(f"Error: Final model directory '{FINAL_MODEL_DIR}' not found.")


if __name__ == "__main__":
    # In a notebook/Colab kernel sys.argv belongs to the kernel, not the user, so
    # arguments are fixed in code there and parsed from the command line otherwise.
    is_notebook = any(module in sys.modules for module in ('google.colab', 'ipykernel'))

    if not is_notebook:
        # NOTE(review): the description mentions Bengali / a Hindi-finetuned model;
        # confirm it matches BASE_MODEL_NAME and DATASET_NAME.
        parser = argparse.ArgumentParser(description="Fine-tune Hindi-finetuned Gemma-2B on Bengali data.")
        parser.add_argument("--action", type=str, default="train", choices=["train", "upload"], help="Action to perform.")
        parser.add_argument("--mode", type=str, default="test", choices=["test", "full"], help="Training mode (test uses a small subset).")
        parser.add_argument("--platform", type=str, default="local", choices=["colab", "digitalocean", "local"], help="Execution platform.")
        parsed_args = parser.parse_args()
        main(parsed_args)
    else:
        print("Running in a notebook environment. Setting arguments manually.")
        args = argparse.Namespace(action="train", mode="full", platform="local")
        main(args)