In [None]:
import torch

# Report whether PyTorch can see a CUDA device in this runtime.
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 to display decimal gigabytes; print the
    # unit so the number is not ambiguous.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No GPU detected. Go to Runtime > Change runtime type > GPU")

In [None]:
# Import of A3CG files

from google.colab import drive
import os
import shutil
import json

print("IMPORTING A3CG FILES")

# Mount Google Drive
print("Mounting Google Drive")
try:
    drive.mount('/content/drive')
    print("Drive mounted successfully")
except Exception as e:
    print(f"Error: {e}")
    # Re-raise so the cell really stops here. Calling exit() inside a notebook
    # only asks the kernel to shut down; the remaining statements of the cell
    # would still run against an unmounted drive.
    raise

# Finding the fold_1 folder
print("\nSearching for the 'fold_1' folder...")

possible_path = "/content/drive/MyDrive/A3CG_Dataset/folds/fold_1"
drive_fold_path = None

if os.path.exists(possible_path):
    drive_fold_path = possible_path
    print(f"Found: {possible_path}")
else:
    print("Folder 'fold_1' not found.")
    # Abort with an exception: continuing would call os.listdir(None) below.
    raise FileNotFoundError(f"Expected dataset folder is missing: {possible_path}")

# List contents of the folder
print(f"\nContents of: {drive_fold_path}")
files_in_folder = os.listdir(drive_fold_path)
required_files = ["seen_train.json", "seen_val.json", "seen_test.json", "unseen_test.json"]
total_required = len(required_files)

for file in files_in_folder:
    if file.endswith('.json'):
        file_path = os.path.join(drive_fold_path, file)
        size_kb = os.path.getsize(file_path) / 1024
        # [OK] marks a file the pipeline needs, [INFO] any other JSON file.
        status = "[OK]" if file in required_files else "[INFO]"
        print(f"  {status} {file} ({size_kb:.1f} KB)")

# Create local structure
local_path = "/content/A3CG_DATASET/folds/fold_1"
os.makedirs(local_path, exist_ok=True)
print(f"\nDirectory created: {local_path}")

# Copy files
print("\nCopying files...")
copied_count = 0

for filename in required_files:
    source = os.path.join(drive_fold_path, filename)
    dest = os.path.join(local_path, filename)

    if os.path.exists(source):
        # Per-file best effort: report a failing file and keep going so the
        # summary below lists everything that went wrong in one run.
        try:
            shutil.copy2(source, dest)
            size_kb = os.path.getsize(dest) / 1024

            # Check for valid JSON
            with open(dest, 'r', encoding='utf-8') as f:
                data = json.load(f)

            print(f"  [OK] {filename}: {len(data)} samples ({size_kb:.1f} KB)")
            copied_count += 1

        except Exception as e:
            print(f"  [ERROR] {filename}: {e}")
    else:
        print(f"  [MISSING] {filename} not found")

# Final result
print("\nRESULT:")
print(f"  Files copied: {copied_count}/{total_required}")

# Every required file must be present: the training cell refuses to start if
# any one of them is missing, so a partial copy is not a success.
if copied_count == total_required:
    print("SUCCESS! Files are ready.")
    print(f"Path: {local_path}")

    # Show final structure
    print("\nFinal Directory Structure:")
    for root, dirs, files in os.walk("/content/A3CG_DATASET"):
        level = root.replace("/content/A3CG_DATASET", '').count(os.sep)
        indent = '  ' * level
        folder_name = os.path.basename(root) or "A3CG_DATASET"
        print(f'{indent}- {folder_name}/')

        sub_indent = '  ' * (level + 1)
        for file in files:
            if file.endswith('.json'):
                file_path = os.path.join(root, file)
                size_kb = os.path.getsize(file_path) / 1024
                print(f'{sub_indent}- {file} ({size_kb:.1f} KB)')

    print("\nUse this path in your code:")
    print("/content/A3CG_DATASET/folds/fold_1/")

else:
    print(f"FAILURE: Only {copied_count}/{total_required} files were copied.")
    print("Please verify that all required JSON files are in your Google Drive folder.")

In [None]:
# HuggingFace token: create one at https://huggingface.co/settings/tokens
from huggingface_hub import login, whoami

# Ask for the token interactively.
login()

# Confirm the session by asking the Hub which account is logged in.
try:
    user_info = whoami()
except Exception as err:
    print(f"Authentication error: {err}")
else:
    print(f"Connected as: {user_info['name']}")

In [None]:
# ===================
# INSTALLATION SCRIPT
# ===================

!pip install -q --upgrade transformers peft bitsandbytes accelerate torch datasets scikit-learn "numpy<2.0"

print("Installation completed.")

print("Installation successful with compatible versions.")
print("\nNEXT STEPS:")
print("1. RESTART THE RUNTIME (MANDATORY)")
print("   Runtime > Restart session")
print("2. Wait 30 seconds after restart")

print("\nWHY THIS RESTART IS CRITICAL:")
print("   - Prevents NumPy 1.x/2.x conflicts in memory")
print("   - Clears Python cache")
print("   - Ensures correct versions are loaded")

In [None]:
# CELL 1: DEPENDENCIES INSTALLATION
# Run this cell, then manually restart the runtime environment.

print("STEP 0: Installing and upgrading required packages...")

# Install PyTorch from official CUDA source (cu121) to ensure compatibility and fix torchvision::nms error
!pip install -q -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install other required packages
!pip install -q -U transformers peft bitsandbytes accelerate datasets
!pip install -q flash-attn --no-build-isolation

print("\nINSTALLATION COMPLETE. PLEASE RESTART THE RUNTIME NOW.")
print("(Runtime -> Restart session)")

In [None]:
# LORA + FEW-SHOT + CONTRASTIVE LEARNING (WITHOUT ORDINAL)

# STEP 1: Imports
# Everything the training cell needs is imported here, before any model or
# data object is created.
print("\nSTEP 1: Packages import...")
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, Trainer, DataCollatorForLanguageModeling, TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

# Probe bitsandbytes only to report its status: a failure is printed and the
# cell keeps running.
try:
    import bitsandbytes
    print(f"  bitsandbytes: Import OK")
except Exception as e:
    print(f"  bitsandbytes error: {e}")

import json
import time
import os
import re
import random
from datasets import Dataset
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from sklearn.metrics.pairwise import cosine_similarity

# Switch Weights & Biases logging off through its environment variables.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

# STEP 2: Google Drive mounting
print("\nSTEP 2: Google Drive mounting...")
from google.colab import drive
# Best effort: the dataset checked below lives on the local disk, so a failed
# mount is reported but does not stop the cell.
try:
    drive.mount('/content/drive')
    print("  Drive mounted successfully")
except Exception as e:
    print(f"Drive error: {e}")

# STEP 3: Data verification
print("\nSTEP 3: Data verification...")
dataset_base = "/content/A3CG_DATASET/folds/fold_1"

required_files = ["seen_train.json", "seen_val.json", "seen_test.json", "unseen_test.json"]
files_ok = True
missing_files = []

for filename in required_files:
    filepath = os.path.join(dataset_base, filename)
    if os.path.exists(filepath):
        size_kb = os.path.getsize(filepath) / 1024
        print(f"  {filename}: {size_kb:.1f} KB")
    else:
        print(f"  {filename}: Missing")
        missing_files.append(filename)
        files_ok = False

if not files_ok:
    print("WARNING: Data files missing!")
    print("Run the data import script first")
    # Raise instead of exit(): inside a notebook exit() only requests a kernel
    # shutdown, so the rest of this cell (model download, training setup)
    # would still start without its data.
    raise FileNotFoundError(
        f"Missing dataset files in {dataset_base}: {', '.join(missing_files)}"
    )

# STEP 4: Contrastive Learning Components
print("\nSTEP 4: Contrastive Learning Components...")

@dataclass
class ContrastiveSample:
    """One contrastive triplet: an anchor, a positive and a negative sample.

    Each of the three members is stored as its raw text plus the ``aspects``
    dictionary that came with the sample it was taken from.
    """
    # Anchor sample: text and its aspects mapping.
    anchor_text: str
    anchor_aspects: Dict
    # Sample paired with the anchor as the positive.
    positive_text: str
    positive_aspects: Dict
    # Sample paired with the anchor as the negative.
    negative_text: str
    negative_aspects: Dict


class ContrastiveLoss(nn.Module):
    """Triplet InfoNCE-style contrastive loss for aspect-action learning.

    For every (anchor, positive, negative) row the loss is
    ``-log(exp(s_pos / T) / (exp(s_pos / T) + exp(s_neg / T)))`` where ``s`` is
    the cosine similarity to the anchor and ``T`` the temperature.
    """

    def __init__(self, temperature: float = 0.1, margin: float = 0.5):
        """
        Args:
            temperature: Positive scaling factor applied to the similarities.
            margin: Stored on the module; the forward pass does not use it.
        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        self.margin = margin

    def forward(self, anchor_emb, positive_emb, negative_emb, reduction='mean'):
        """
        Compute contrastive loss
        Args:
            anchor_emb: Embeddings of anchor samples [batch_size, hidden_dim]
            positive_emb: Embeddings of positive samples [batch_size, hidden_dim]
            negative_emb: Embeddings of negative samples [batch_size, hidden_dim]
            reduction: 'mean' for batch average, 'none' for per-sample losses
        Returns:
            loss: Scalar (if reduction='mean') or [batch_size] tensor (if reduction='none')
        Raises:
            ValueError: If ``reduction`` is neither 'mean' nor 'none'.
        """
        # Reject a bad reduction before doing any tensor work.
        if reduction not in ('mean', 'none'):
            raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

        # Normalize embeddings
        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
        positive_emb = F.normalize(positive_emb, p=2, dim=1)
        negative_emb = F.normalize(negative_emb, p=2, dim=1)

        # Compute similarities
        pos_sim = F.cosine_similarity(anchor_emb, positive_emb, dim=1)
        neg_sim = F.cosine_similarity(anchor_emb, negative_emb, dim=1)

        # -log(e^p / (e^p + e^n)) == log(1 + e^(n - p)) == softplus(n - p).
        # Working on the difference avoids exponentiating sim / T directly,
        # which overflows to inf (and then NaN) for small temperatures.
        loss_per_sample = F.softplus((neg_sim - pos_sim) / self.temperature)

        if reduction == 'mean':
            return loss_per_sample.mean()
        return loss_per_sample

class ContrastiveDataGenerator:
    """Generates contrastive samples for training.

    Samples are bucketed by their predominant action type; a triplet pairs an
    anchor with a positive from the same bucket and a negative from another.
    """

    def __init__(self):
        # Insertion order doubles as the tie-break order in _main_action.
        self.action_groups = {
            'implemented': [],
            'planning': [],
            'indeterminate': []
        }

    def _main_action(self, sample: Dict) -> str:
        """Return the predominant action type of one sample.

        Action labels are compared case-insensitively with surrounding
        whitespace removed. Unrecognized labels count as 'indeterminate', and a
        sample without any action is 'indeterminate' too. Ties go to the action
        listed first in ``self.action_groups``.
        """
        action_counts = dict.fromkeys(self.action_groups, 0)
        for actions in (sample.get('aspects') or {}).values():
            for action in actions:
                action_key = action.lower().strip()
                if action_key not in action_counts:
                    action_key = 'indeterminate'
                action_counts[action_key] += 1

        if sum(action_counts.values()) == 0:
            return 'indeterminate'
        return max(action_counts, key=action_counts.get)

    def group_samples_by_action(self, data: List[Dict]):
        """Group samples by their predominant action type.

        The groups are rebuilt from scratch on every call so that samples from
        a previously processed dataset do not remain in the buckets.
        """
        for group in self.action_groups.values():
            group.clear()

        for sample in data:
            self.action_groups[self._main_action(sample)].append(sample)

        print(f" Grouped samples:")
        for action, samples in self.action_groups.items():
            print(f"  {action}: {len(samples)} samples")

    def create_contrastive_pairs(self, data: List[Dict], n_pairs: int = None) -> "List[ContrastiveSample]":
        """Create contrastive learning triplets.

        Args:
            data: Samples with a 'text' entry and an optional 'aspects' mapping
                of aspect -> list of action labels.
            n_pairs: Number of anchors to draw; defaults to ``len(data)``.
                Anchors without any action are skipped, so fewer triplets than
                ``n_pairs`` may come back.
        Returns:
            The generated ContrastiveSample triplets (empty for empty ``data``).
        """
        if n_pairs is None:
            n_pairs = len(data)

        self.group_samples_by_action(data)
        contrastive_samples = []

        print(f"  Creating {n_pairs} contrastive pairs...")

        # Nothing to draw an anchor from.
        if not data:
            print(f"  Generated {len(contrastive_samples)} contrastive samples")
            return contrastive_samples

        for i in range(n_pairs):
            if i%100 == 0:
                print(f" Progress: {i}/{n_pairs}")

            # Select anchor
            anchor = random.choice(data)
            anchor_aspects = anchor.get('aspects', {})

            # Skip anchors that carry no action label at all.
            if not anchor_aspects or not any(len(actions) for actions in anchor_aspects.values()):
                continue

            # Same rule as the grouping step, so the anchor's action always
            # names the bucket the anchor itself was placed in.
            anchor_main_action = self._main_action(anchor)

            # Find positive (same action type, different text)
            positive_candidates = [s for s in self.action_groups[anchor_main_action]
                                 if s['text'] != anchor['text']]

            if positive_candidates:
                positive = random.choice(positive_candidates)
            else:
                # Fallback: any other text, and the anchor itself as last resort.
                all_candidates = [s for samples in self.action_groups.values()
                                  for s in samples if s['text'] != anchor['text']]
                positive = random.choice(all_candidates) if all_candidates else anchor

            # Find negative (different action type). Only non-empty groups are
            # eligible, so one empty group cannot force the fallback while
            # another group still holds genuine negatives.
            negative_actions = [action for action, samples in self.action_groups.items()
                                if action != anchor_main_action and samples]
            if negative_actions:
                neg_action = random.choice(negative_actions)
                negative = random.choice(self.action_groups[neg_action])
            else:
                other_samples = [s for s in data if s['text'] != anchor['text']]
                negative = random.choice(other_samples) if other_samples else anchor

            contrastive_samples.append(ContrastiveSample(
                anchor_text=anchor['text'],
                anchor_aspects=anchor_aspects,
                positive_text=positive['text'],
                positive_aspects=positive.get('aspects', {}),
                negative_text=negative['text'],
                negative_aspects=negative.get('aspects', {})
            ))

        print(f"  Generated {len(contrastive_samples)} contrastive samples")
        return contrastive_samples

print("Contrastive components configured")

# STEP 5: Enhanced monitoring callback
print("\nSTEP 5: Enhanced monitoring configuration...")

class ContrastiveMemoryMonitorCallback(TrainerCallback):
    """Trainer callback that prints GPU memory use and elapsed time.

    Every 10th step it reports allocated/reserved GPU memory and the minutes
    elapsed; at the end of each epoch it reports the elapsed time and empties
    the CUDA cache.
    """

    def __init__(self):
        self.start_time = time.time()
        self.last_check = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        """Restart the clock when training starts.

        Without this the elapsed time would be counted from the moment the
        callback object was created, which can be well before training.
        """
        self.start_time = time.time()
        self.last_check = self.start_time

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Print a progress line on every 10th global step."""
        current_time = time.time()
        step = state.global_step

        if step % 10 == 0:
            elapsed = current_time - self.start_time

            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / 1e9
                reserved = torch.cuda.memory_reserved() / 1e9

                print(f"Step {step}: GPU {allocated:.1f}/{reserved:.1f} GB, Time: {elapsed/60:.1f}min")

                # Auto cleanup: release cached blocks once more than 90% of
                # the device memory is reserved.
                gpu_total = torch.cuda.get_device_properties(0).total_memory / 1e9
                if reserved / gpu_total > 0.9:
                    torch.cuda.empty_cache()
                    print("Automatic cleanup")
            else:
                print(f"Step {step}: Time: {elapsed/60:.1f}min")

            self.last_check = current_time

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Print the elapsed time and empty the CUDA cache after each epoch."""
        epoch = state.epoch
        elapsed = time.time() - self.start_time
        print(f"Epoch {epoch} completed in {elapsed/60:.1f} minutes")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("Enhanced callback configured")

# STEP 6: Model configuration with contrastive head
print("\nSTEP 6: Model configuration with contrastive capabilities...")

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
# NOTE(review): bfloat16 compute presumably needs a GPU generation that
# supports it — confirm against the Colab runtime in use.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

print(f"Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left"
)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading LLaMA-3 8B model...")
# The whole model is placed on device 0 (device_map={'': 0}).
# NOTE(review): attn_implementation="flash_attention_2" relies on the
# flash-attn package installed in the setup cell — confirm the GPU supports it.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={'': 0},
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

# LoRA preparation
base_model = prepare_model_for_kbit_training(base_model)

# LoRA configuration
# Rank-32 adapters (alpha 64, dropout 0.1) on the listed attention and MLP
# projection modules; no bias terms are trained.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none",
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
)

# `model` is the PEFT-wrapped model from here on.
model = get_peft_model(base_model, lora_config)

class ContrastiveLoRAModel(nn.Module):
    """Wrapper model that combines generation and contrastive learning (without ordinal).

    The wrapped model produces the language-modelling outputs; a small
    projection head maps mean-pooled hidden states into the space in which the
    contrastive loss is computed.
    """

    def __init__(self, base_model, hidden_size=4096, contrastive_dim=256):
        """
        Args:
            base_model: Model to wrap. It is called directly for generation and
                through its ``.base_model`` attribute for hidden states.
            hidden_size: Width of the hidden states fed to the projection head.
            contrastive_dim: Output width of the projection head.
        """
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.contrastive_dim = contrastive_dim

        # Contrastive projection head
        self.contrastive_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, contrastive_dim)
        )

        self.contrastive_loss_fn = ContrastiveLoss(temperature=0.1)

    def get_text_embedding(self, input_ids, attention_mask):
        """Extract a mean-pooled text embedding for contrastive learning.

        NOTE(review): the hidden states are computed under torch.no_grad(), so
        the contrastive loss can only update the projection head, not the
        wrapped model. Confirm this is the intended trade-off.

        Returns:
            Tensor [batch_size, hidden_dim] in the projection head's dtype.
        """
        with torch.no_grad():
            outputs = self.base_model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            # Use last hidden state with mean pooling
            last_hidden = outputs.hidden_states[-1]

            # Mean pooling with attention mask
            mask = attention_mask.unsqueeze(-1)
            sum_hidden = torch.sum(last_hidden * mask, dim=1)
            # Clamp the token count so a row without any attended token yields
            # a zero vector instead of 0/0 = NaN.
            token_count = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
            mean_hidden = sum_hidden / token_count

        # Hand the embedding to the head in the head's own dtype so the linear
        # layers do not hit a dtype mismatch.
        head_dtype = next(self.contrastive_head.parameters()).dtype
        return mean_hidden.to(head_dtype)

    def forward(self, input_ids, attention_mask, labels=None,
        contrastive_anchors=None, contrastive_positives=None, contrastive_negatives=None,
        lambda_base=0.1):
        """Run the wrapped model and attach the combined loss to its outputs.

        Args:
            input_ids, attention_mask, labels: Language-modelling batch. Label
                positions equal to -100 are ignored in the generation loss.
            contrastive_anchors, contrastive_positives, contrastive_negatives:
                Optional dicts with 'input_ids' and 'attention_mask'. The
                contrastive term is added only when all three are given.
            lambda_base: Weight of the contrastive term.
        Returns:
            The wrapped model's output object with ``loss`` replaced by
            ``mean(generation loss) + lambda_base * mean(contrastive loss)``.
        """
        # 1. Generation forward pass. `labels` is still passed so the output
        #    object already carries a loss entry that is overwritten below.
        generation_outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Calculate per-sample loss (cross-entropy per token, then average per sample)
        if labels is not None:
            logits = generation_outputs.logits
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Cross-entropy per token
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            token_losses = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            token_losses = token_losses.view(shift_labels.size())

            # Mask to ignore padding tokens (-100)
            mask = (shift_labels != -100).float()

            # Generation loss per sample (average over valid tokens)
            generation_loss_per_sample = (token_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
        else:
            generation_loss_per_sample = torch.zeros(input_ids.size(0), device=input_ids.device)

        # 2. Initialize contrastive loss
        contrastive_loss_per_sample = torch.zeros_like(generation_loss_per_sample)

        # 3. Contrastive loss per sample if data available
        if contrastive_anchors is not None and contrastive_positives is not None and contrastive_negatives is not None:
            # Get embeddings
            anchor_emb = self.get_text_embedding(
                contrastive_anchors['input_ids'],
                contrastive_anchors['attention_mask']
            )
            positive_emb = self.get_text_embedding(
                contrastive_positives['input_ids'],
                contrastive_positives['attention_mask']
            )
            negative_emb = self.get_text_embedding(
                contrastive_negatives['input_ids'],
                contrastive_negatives['attention_mask']
            )

            # Project to contrastive space
            anchor_proj = self.contrastive_head(anchor_emb)
            positive_proj = self.contrastive_head(positive_emb)
            negative_proj = self.contrastive_head(negative_emb)

            contrastive_loss_per_sample = self.contrastive_loss_fn(
                anchor_proj, positive_proj, negative_proj, reduction='none'
            )

        # Loss combination: generation + contrastive only.
        # The two terms are averaged separately before being added. For equal
        # batch sizes this equals the mean of the per-sample sums, and it also
        # works when fewer triplets than generation samples were supplied.
        scaled_contr_loss = lambda_base * contrastive_loss_per_sample
        total_loss = generation_loss_per_sample.mean() + scaled_contr_loss.mean()

        # Optional loss monitoring: only refreshed when something has already
        # created the `_loss_monitoring` attribute on this module.
        if hasattr(self, '_loss_monitoring'):
            with torch.no_grad():
                self._loss_monitoring = {
                    'generation_loss_mean': generation_loss_per_sample.mean().item(),
                    'contrastive_loss_mean': contrastive_loss_per_sample.mean().item(),
                    'scaled_contrastive_mean': scaled_contr_loss.mean().item(),
                    'lambda_base': lambda_base,
                }

        generation_outputs.loss = total_loss
        return generation_outputs

# Wrap model with contrastive capabilities
# `model` is rebound here: from this line on it is the ContrastiveLoRAModel
# wrapper, and the PEFT model is reachable as `model.base_model`.
model = ContrastiveLoRAModel(model)
# Attach a stand-in `print_trainable_parameters` to the wrapper; it prints a
# fixed description rather than counting parameters.
model.print_trainable_parameters = lambda: print(f"Trainable parameters: LoRA + Contrastive head")

print("Contrastive LoRA model configured")

# STEP 7: Enhanced Few-Shot data processor with contrastive support
print("\nSTEP 7: Enhanced Few-Shot + Contrastive processor...")

class A3CGContrastiveFewShotDataProcessor:
    """Builds few-shot chat prompts and contrastive triplets from A3CG samples."""

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: Tokenizer whose ``apply_chat_template`` renders the
                chat messages into a single prompt string.
        """
        self.tokenizer = tokenizer
        self.system_prompt = """You are an expert in ESG analysis. Extract aspect-action pairs from sustainability statements.

CRITICAL INSTRUCTIONS:
- Extract EXACT terms from the input text
- Do not paraphrase or interpret creatively
- Use literal wording from the sentence
- Focus on specific terms rather than general concepts

DEFINITIONS:
- Aspect: A sustainability-related entity, goal, sub-area, or activity (use exact wording)
- Action: "implemented", "planning", or "indeterminate"

OUTPUT FORMAT: ("aspect1", "action1"), ("aspect2", "action2"), ...
If none: ("no aspect", "no action")"""

        # Enhanced few-shot examples
        self.few_shot_examples = [
            {
                "text": "We have implemented solar panels to reduce energy consumption in our facilities.",
                "output": '("solar panels", "implemented"), ("energy consumption", "implemented")'
            },
            {
                "text": "The company plans to improve workplace diversity initiatives next year.",
                "output": '("workplace diversity initiatives", "planning")'
            },
            {
                "text": "We are committed to enhancing our environmental management systems.",
                "output": '("environmental management systems", "planning")'
            },
            {
                "text": "Our recycling program has achieved a 50% waste reduction.",
                "output": '("recycling program", "implemented"), ("waste reduction", "implemented")'
            },
            {
                "text": "The board may consider sustainability investments where feasible.",
                "output": '("sustainability investments", "indeterminate")'
            },
            {
                "text": "All staff complete an e-learning programme from an accredited training provider.",
                "output": '("e-learning programme", "implemented"), ("training provider", "implemented"), ("staff", "implemented")'
            },
            {
                "text": "We strive to minimize carbon footprints in our operations.",
                "output": '("carbon footprints", "planning")'
            },
            {
                "text": "Recycled items are processed to create new products.",
                "output": '("recycled items", "indeterminate")'
            }
        ]

        self.contrastive_generator = ContrastiveDataGenerator()

    def get_few_shot_examples(self, n_examples: int = 3) -> str:
        """Randomly select up to ``n_examples`` examples and format them as text."""
        selected = random.sample(self.few_shot_examples, min(n_examples, len(self.few_shot_examples)))

        examples_text = ""
        for i, example in enumerate(selected, 1):
            examples_text += f"\nExample {i}:\n"
            examples_text += f"Text: {example['text']}\n"
            examples_text += f"Output: {example['output']}\n"

        return examples_text

    def create_prompt(self, text: str, aspects_dict: Dict = None) -> str:
        """Create a prompt in the tokenizer's chat format.

        Args:
            text: Statement to extract aspect-action pairs from.
            aspects_dict: ``None`` for an inference prompt that ends with the
                generation prompt. Otherwise a mapping aspect -> list of
                actions rendered as the assistant answer; an empty mapping
                yields the ``("no aspect", "no action")`` answer.
        Returns:
            The rendered prompt string.
        """
        few_shot_text = self.get_few_shot_examples(n_examples=3)

        user_content = (
            f"{few_shot_text.strip()}\n\n"
            f"Now extract from this text:\n"
            f"Text: {text}\n\n"
            f"Extract the aspect-action pairs:"
        )

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content}
        ]

        # Test against None rather than truthiness: a training sample whose
        # aspects are empty still needs its "no aspect" target appended.
        is_training_prompt = aspects_dict is not None
        if is_training_prompt:
            output_pairs = []
            for aspect, actions in aspects_dict.items():
                for action in actions:
                    output_pairs.append(f'("{aspect}", "{action}")')
            expected_output = ', '.join(output_pairs) if output_pairs else '("no aspect", "no action")'
            messages.append({"role": "assistant", "content": expected_output})

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=not is_training_prompt
        )

    def load_data(self, file_path: str) -> List[Dict]:
        """Load and return the JSON content of ``file_path`` (UTF-8)."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def prepare_contrastive_dataset(self, data: List[Dict]) -> "Tuple[Dataset, List[ContrastiveSample]]":
        """Prepare dataset with both generation and contrastive samples.

        Returns:
            A Dataset with one "text" column holding the rendered training
            prompts, and the contrastive triplets generated from ``data``
            (``len(data) // 2`` anchors are drawn).
        """

        print(f"  Preparing {len(data)} samples with few-shot + contrastive...")

        # Generation prompts
        prompts = []
        for i, item in enumerate(data):
            if i % 100 == 0:
                print(f"    Generation progress: {i}/{len(data)} ({i/len(data)*100:.1f}%)")

            prompt = self.create_prompt(item['text'], item.get('aspects', {}))
            prompts.append(prompt)

        # Contrastive samples
        print("  Generating contrastive samples...")
        # Use a fresh generator per dataset so action groups collected for one
        # split (e.g. train) are not reused when preparing another (e.g. val).
        self.contrastive_generator = ContrastiveDataGenerator()
        contrastive_samples = self.contrastive_generator.create_contrastive_pairs(data, n_pairs=len(data)//2)

        print(f"  {len(prompts)} generation prompts + {len(contrastive_samples)} contrastive samples")

        return Dataset.from_dict({"text": prompts}), contrastive_samples

print("Enhanced processor configured")

# STEP 8: Data preparation with contrastive support
print("\nSTEP 8: Loading and preparing data with Few-Shot + Contrastive...")

processor = A3CGContrastiveFewShotDataProcessor(tokenizer=tokenizer)

# Read the raw train / validation splits.
print("Loading JSON files...")
train_data = processor.load_data(f"{dataset_base}/seen_train.json")
val_data = processor.load_data(f"{dataset_base}/seen_val.json")

print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")

# Turn each split into (prompt dataset, contrastive triplets).
print("Generating few-shot + contrastive data...")
train_dataset, train_contrastive = processor.prepare_contrastive_dataset(train_data)
val_dataset, val_contrastive = processor.prepare_contrastive_dataset(val_data)

# Preview the first training prompt, cut to 800 characters when longer.
print("\nEXAMPLE GENERATION PROMPT:")
print("=" * 50)
sample_prompt = train_dataset[0]['text']
if len(sample_prompt) > 800:
    print(sample_prompt[:800] + "...")
else:
    print(sample_prompt)
print("=" * 50)

# Preview the first triplet, 100 characters per member.
print("\nEXAMPLE CONTRASTIVE TRIPLET:")
print("=" * 50)
sample_contrastive = train_contrastive[0]
for label, snippet in (
    ("Anchor", sample_contrastive.anchor_text),
    ("Positive", sample_contrastive.positive_text),
    ("Negative", sample_contrastive.negative_text),
):
    print(f"{label}: {snippet[:100]}...")
print("=" * 50)

# Tokenization
def tokenize_function(examples):
    """Tokenize a batch of rendered chat prompts.

    Args:
        examples: Batch mapping whose "text" entry holds prompt strings that
            were rendered with the tokenizer's chat template.
    Returns:
        The tokenizer output for the batch (truncated to 2048 tokens, not
        padded, plain Python lists).
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding=False,
        max_length=2048,
        # max_length=1024,  # Use if VRAM is insufficient
        # The chat template already wrote the special tokens into the text;
        # adding them again here would duplicate the beginning-of-text token.
        add_special_tokens=False,
        return_tensors=None
    )

print("Tokenizing data...")
# Each split is replaced by its tokenized version. `remove_columns` drops the
# existing columns (the raw "text"), so only the tokenizer outputs remain.
train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Train tokenization"
)
val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Validation tokenization"
)

print("Data preparation completed")

# STEP 9: Enhanced data collator
print("\nSTEP 9: Configuring the enhanced contrastive data collator...")

class ContrastiveDataCollator:
    """
    Custom data collator that handles standard padding for language modeling
    and probabilistically injects tokenized contrastive learning triplets into each batch.

    Padding positions in ``labels`` are always set to -100 so they are ignored
    by the language-modeling loss, whether the labels were supplied by the
    dataset or derived from ``input_ids``.

    Args:
        tokenizer: Tokenizer used both to pad the batch and to encode triplets.
        contrastive_samples: Pool of triplet objects exposing ``anchor_text``,
            ``positive_text`` and ``negative_text``.
        contrastive_prob: Probability that a given batch receives triplets.
        contrastive_max_length: Truncation length for the triplet texts.
    """
    def __init__(self, tokenizer, contrastive_samples: list, contrastive_prob: float = 0.3, contrastive_max_length: int = 512):
        self.tokenizer = tokenizer
        self.contrastive_samples = contrastive_samples
        self.contrastive_prob = contrastive_prob
        self.contrastive_max_length = contrastive_max_length

    def _pad_labels(self, labels: list, reference):
        """Pad variable-length label sequences with -100 to match ``reference``.

        ``reference`` is the padded ``input_ids`` tensor; it supplies the target
        width, dtype and device. The tokenizer's ``padding_side`` (default
        "right") decides on which side the -100 filler goes.
        """
        width = max(reference.shape[1], max(len(seq) for seq in labels))
        padded = reference.new_full((len(labels), width), -100)
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        for row, seq in enumerate(labels):
            values = seq.tolist() if hasattr(seq, "tolist") else list(seq)
            if not values:
                continue
            if pad_left:
                padded[row, width - len(values):] = reference.new_tensor(values)
            else:
                padded[row, :len(values)] = reference.new_tensor(values)

        return padded

    def __call__(self, features: list) -> dict:
        """Collate ``features`` into a padded batch, optionally adding triplets.

        Returns a batch with ``input_ids``, ``attention_mask`` (as produced by
        ``tokenizer.pad``) and ``labels``. With probability ``contrastive_prob``
        the keys ``contrastive_anchors``, ``contrastive_positives`` and
        ``contrastive_negatives`` are added, each holding a tokenized batch.
        """
        # Separate labels before padding: tokenizer.pad does not pad them, and
        # ragged label lists cannot be converted to a tensor.
        if "labels" in features[0]:
            labels = [feature["labels"] for feature in features]
            features = [
                {key: value for key, value in feature.items() if key != "labels"}
                for feature in features
            ]
        else:
            labels = None

        # Use tokenizer's built-in padding
        batch = self.tokenizer.pad(
            features,
            return_tensors="pt",
            padding=True,
        )

        # For causal LM, labels are the same as input_ids if not present
        if labels is None:
            batch["labels"] = batch["input_ids"].clone()
        else:
            batch["labels"] = self._pad_labels(labels, batch["input_ids"])

        # Mask padding via the attention mask rather than by comparing against
        # pad_token_id: the pad token is often the EOS token, and real EOS
        # tokens must keep contributing to the loss.
        if "attention_mask" in batch and batch["attention_mask"].shape == batch["labels"].shape:
            batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)

        # Probabilistically inject contrastive samples
        if random.random() < self.contrastive_prob and self.contrastive_samples:
            batch_size = batch['input_ids'].shape[0]
            selected_contrastive = random.sample(self.contrastive_samples,
                                               min(batch_size, len(self.contrastive_samples)))

            # Prepare texts for tokenization
            anchor_texts = [cs.anchor_text for cs in selected_contrastive]
            positive_texts = [cs.positive_text for cs in selected_contrastive]
            negative_texts = [cs.negative_text for cs in selected_contrastive]

            # Tokenize triplets
            anchor_tokens = self.tokenizer(anchor_texts, truncation=True, padding=True,
                                         max_length=self.contrastive_max_length, return_tensors="pt")
            positive_tokens = self.tokenizer(positive_texts, truncation=True, padding=True,
                                           max_length=self.contrastive_max_length, return_tensors="pt")
            negative_tokens = self.tokenizer(negative_texts, truncation=True, padding=True,
                                           max_length=self.contrastive_max_length, return_tensors="pt")

            # Add tokenized triplets to batch
            batch['contrastive_anchors'] = anchor_tokens
            batch['contrastive_positives'] = positive_tokens
            batch['contrastive_negatives'] = negative_tokens

        return batch

# Collator instance shared with the Trainer below: 30% of batches receive
# contrastive triplets drawn from the training triplet pool, truncated to
# 512 tokens.
data_collator = ContrastiveDataCollator(
    tokenizer=tokenizer,
    contrastive_samples=train_contrastive,
    contrastive_prob=0.3,
    contrastive_max_length=512
)

print("Contrastive data collator configured.")

# STEP 10: Training configuration for Llama-3-8B
print("\nSTEP 10: Configuring training for Llama-3-8B Contrastive...")

# Effective batch size is 1 x 16 = 16 via gradient accumulation. Checkpoints
# are written to Google Drive every 100 steps, and evaluation runs on the
# same 100-step cadence so the best checkpoint (lowest eval_loss) can be
# restored at the end of training.
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/A3CG_Llama3_8B_Contrastive_Simple_LoRA",
    num_train_epochs=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    optim="paged_adamw_32bit",
    save_steps=100,
    logging_steps=10,
    learning_rate=3e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=True,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    group_by_length=True,
    lr_scheduler_type="cosine",
    eval_steps=100,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    remove_unused_columns=False,
    report_to="none",
)

memory_callback = ContrastiveMemoryMonitorCallback()

# NOTE(review): data_collator is the only collator given to the Trainer, so it
# presumably also builds evaluation batches. Contrastive triplets would then be
# injected at random into eval batches and add noise to eval_loss, the metric
# used to pick the best checkpoint - confirm this is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.add_callback(memory_callback)

print("Trainer configured for Llama-3-8B Contrastive.")

# STEP 11: Training execution
# Prints a summary of the run configuration, then trains. Any training error
# is reported, the CUDA cache is released, and the error is re-raised so the
# cell stops instead of saving a half-trained model.
print("\nSTEP 11: STARTING LLAMA-3-8B CONTRASTIVE + FEW-SHOT TRAINING...")
print("=" * 70)

print(f"Start time: {time.strftime('%H:%M:%S')}")
print(f"Base Model: {model_name}")
print(f"Config: Batch {training_args.per_device_train_batch_size}x{training_args.gradient_accumulation_steps} (Effective: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}), Epochs {training_args.num_train_epochs}")
print(f"LoRA: Rank {lora_config.r}, Alpha {lora_config.lora_alpha}")
# NOTE(review): the "3 examples" figure is a hard-coded string here, not read
# from the data processor - keep it in sync manually.
print(f"Few-Shot: 3 examples per prompt")
print(f"Contrastive: {data_collator.contrastive_prob*100}% of batches, temperature={model.contrastive_loss_fn.temperature}")
print(f"Architecture: Generation + Contrastive (without ordinal)")
print("=" * 70)

start_time = time.time()

try:
    trainer.train()

    training_time = time.time() - start_time
    print(f"\nLLAMA-3-8B CONTRASTIVE + FEW-SHOT TRAINING COMPLETED!")
    print(f"Total time: {training_time/3600:.1f}h ({training_time/60:.1f}min)")

except Exception as e:
    print(f"\nERROR DURING TRAINING: {e}")
    print("Automatic cleanup...")
    torch.cuda.empty_cache()
    raise e

# STEP 12: Model saving
# Writes three things into model_output_path: whatever base_model.save_pretrained
# emits (presumably the LoRA adapter - confirm), the tokenizer files, and the
# contrastive head weights as contrastive_head.pt (the file name the
# evaluation cell looks for).
model_output_path = f"./Llama3-8B-lora-a3cg-contrastive-simple-final"
print(f"\nSTEP 12: Saving fine-tuned model to {model_output_path}...")

try:
    trainer.model.base_model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)

    torch.save(trainer.model.contrastive_head.state_dict(),
               f"{model_output_path}/contrastive_head.pt")

    print(f"Model saved successfully in: {model_output_path}")
except Exception as e:
    # Best-effort: a save failure is only reported and the script continues,
    # so the Drive backup step below may find nothing to copy.
    print(f"Save error: {e}")

# STEP 13: Backup to Google Drive
# Copies the locally saved model into a timestamped folder on Drive and writes
# a small text file describing the run next to it. The backup is best-effort:
# failures are reported but do not abort the cell.
print("\nSTEP 13: AUTOMATIC BACKUP TO GOOGLE DRIVE...")
print("=" * 70)

import shutil
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M")

# Stays None unless the copy to Drive actually succeeds, so the final summary
# can neither crash with a NameError nor claim a backup that does not exist.
drive_model_path = None

try:
    drive_folder = f"/content/drive/MyDrive/A3CG_Llama3_8B_Models"
    os.makedirs(drive_folder, exist_ok=True)
    print(f"Drive backup folder: {drive_folder}")

    backup_target = f"{drive_folder}/lora-a3cg-contrastive-simple_{timestamp}"
    # The timestamp only has minute resolution; dirs_exist_ok lets a re-run
    # within the same minute overwrite the folder instead of failing.
    shutil.copytree(model_output_path, backup_target, dirs_exist_ok=True)
    drive_model_path = backup_target

    # Walk the copied tree so files in sub-directories are counted as well.
    total_size_mb = sum(
        os.path.getsize(os.path.join(root, file_name))
        for root, _dirs, file_names in os.walk(drive_model_path)
        for file_name in file_names
    ) / (1024*1024)

    print(f"Model successfully copied to: {drive_model_path}")
    print(f"Total size: {total_size_mb:.1f} MB")

    metadata_file = f"{drive_folder}/model_info_simple_{timestamp}.txt"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        f.write(f"A3CG LORA + CONTRASTIVE + FEW-SHOT MODEL METADATA\n")
        f.write(f"=" * 60 + "\n")
        f.write(f"Creation Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Base Model: {model_name}\n")
        f.write(f"Techniques: Q-LoRA + Contrastive Learning + Few-Shot (without ordinal)\n")
        f.write(f"Task: A3CG aspect-action extraction\n")
        f.write(f"\n--- TRAINING CONFIG ---\n")
        f.write(f"Epochs: {training_args.num_train_epochs}\n")
        f.write(f"Batch Size: {training_args.per_device_train_batch_size} (Effective: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps})\n")
        f.write(f"Learning Rate: {training_args.learning_rate}\n")
        f.write(f"\n--- LoRA CONFIG ---\n")
        f.write(f"Rank (r): {lora_config.r}\n")
        f.write(f"Alpha: {lora_config.lora_alpha}\n")
        f.write(f"Target Modules: {lora_config.target_modules}\n")
        f.write(f"\n--- CONTRASTIVE CONFIG ---\n")
        f.write(f"Temperature: {model.contrastive_loss_fn.temperature}\n")
        f.write(f"Batch Probability: {data_collator.contrastive_prob}\n")
        f.write(f"Architecture: Generation + Contrastive (2 components)\n")
        f.write(f"\n--- OUTPUT ---\n")
        f.write(f"Model Size (MB): {total_size_mb:.1f}\n")
        f.write(f"Drive Path: {drive_model_path}\n")

    print(f"Metadata file created: {metadata_file}")
    print("=" * 70)
    print("LLAMA-3-8B CONTRASTIVE FINE-TUNED MODEL IS READY!")
    print(f"Location: {drive_model_path}")

except Exception as e:
    print(f"Backup error: {e}")

# STEP 14: Final cleanup
print("\nSTEP 14: Final cleanup...")
torch.cuda.empty_cache()

print("\nCONTRASTIVE SCRIPT COMPLETED!")
print("=" * 70)
print(f"End time: {time.strftime('%H:%M:%S')}")
print(f"Model saved to: {model_output_path}")
if drive_model_path is not None:
    print(f"Backed up to Drive at: {drive_model_path}")
else:
    print("Drive backup was NOT created - see the backup error above")
print("=" * 70)

print("\nNEXT STEPS:")
print("=" * 40)
print("1. Run evaluation with the new Llama-3-8B contrastive model")
print("2. Compare performance: Full (3 losses) vs Simple (2 losses)")
print("3. Analyze if ordinal loss really helps or adds complexity")
print("4. Document findings on contrastive learning effectiveness")
print("=" * 40)

In [None]:
# A3CG MODEL EVALUATION SCRIPT - SIMPLE CONTRASTIVE
# Exact Match evaluation compatible with Generation + Contrastive (without ordinal)

import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import re
import pandas as pd
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import PeftModel, PeftConfig
import warnings
warnings.filterwarnings('ignore')

# Banner for the evaluation cell: states what is evaluated and when the run
# started.
print("A3CG MODEL EVALUATION SCRIPT - SIMPLE CONTRASTIVE")
print("Exact Match evaluation (paper implementation)")
print("Compatible with Generation + Contrastive (without ordinal)")
print("=" * 70)
print(f"Start time: {time.strftime('%H:%M:%S')}")
print("=" * 70)

# STEP 1: Re-declare the classes used by the simple training run.
# They are defined again in this cell rather than imported, so they have to be
# kept in sync with the training cell by hand.
print("STEP 1: Import exact classes from simple training...")

@dataclass
class ContrastiveSample:
    """One contrastive-learning triplet: an anchor, a positive and a negative.

    Each member of the triplet is stored as its text plus a dict of aspect
    information. The layout of the aspect dicts is not visible in this cell -
    see the code that builds the triplets.
    """
    # Reference sentence of the triplet and its aspect information.
    anchor_text: str
    anchor_aspects: Dict
    # Sample that should end up close to the anchor in embedding space.
    positive_text: str
    positive_aspects: Dict
    # Sample that should end up far from the anchor in embedding space.
    negative_text: str
    negative_aspects: Dict

class ContrastiveLoss(nn.Module):
    """Temperature-scaled triplet loss in the InfoNCE style (copy of the training class).

    For every (anchor, positive, negative) row the loss is
    ``-log(e_pos / (e_pos + e_neg + 1e-8))`` where ``e_x`` is
    ``exp(cosine_similarity(anchor, x) / temperature)``.

    Args:
        temperature: Divisor applied to the cosine similarities.
        margin: Stored on the module; not used by ``forward``.
    """

    def __init__(self, temperature: float = 0.1, margin: float = 0.5):
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, anchor_emb, positive_emb, negative_emb, reduction='mean'):
        """Compute the loss for a batch of triplets.

        Args:
            anchor_emb, positive_emb, negative_emb: Embeddings that are
                normalized and compared along ``dim=1``.
            reduction: ``'mean'`` for the batch average, ``'none'`` for one
                loss value per triplet.

        Raises:
            ValueError: If ``reduction`` is neither ``'mean'`` nor ``'none'``.
        """
        # L2-normalize all three inputs before comparing them.
        unit_anchor, unit_positive, unit_negative = (
            F.normalize(embedding, p=2, dim=1)
            for embedding in (anchor_emb, positive_emb, negative_emb)
        )

        # Exponentiated, temperature-scaled similarities to the anchor.
        exp_positive = torch.exp(
            F.cosine_similarity(unit_anchor, unit_positive, dim=1) / self.temperature
        )
        exp_negative = torch.exp(
            F.cosine_similarity(unit_anchor, unit_negative, dim=1) / self.temperature
        )

        # One loss value per triplet; the epsilon keeps the denominator positive.
        per_triplet = -torch.log(exp_positive / (exp_positive + exp_negative + 1e-8))

        if reduction == 'none':
            return per_triplet
        if reduction == 'mean':
            return per_triplet.mean()
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

class ContrastiveLoRAModel(nn.Module):
    """Causal LM wrapper with a contrastive projection head (Generation + Contrastive only).

    ``forward`` returns the wrapped model's output object with ``loss`` replaced
    by ``generation_loss + lambda_base * contrastive_loss``; ``generate`` is
    delegated unchanged to the wrapped model.

    Args:
        base_model: Wrapped causal LM. It must itself expose a ``base_model``
            attribute that can be called with ``output_hidden_states=True``.
        hidden_size: Width of the backbone's hidden states (head input size).
        contrastive_dim: Output size of the projection head.
    """

    def __init__(self, base_model, hidden_size=4096, contrastive_dim=256):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.contrastive_dim = contrastive_dim

        # Contrastive projection head
        self.contrastive_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, contrastive_dim)
        )

        self.contrastive_loss_fn = ContrastiveLoss(temperature=0.1)

    def get_text_embedding(self, input_ids, attention_mask):
        """Return the attention-masked mean of the last hidden states.

        The backbone runs under ``torch.no_grad()``, so gradients of the
        contrastive loss reach only ``contrastive_head`` and never the
        backbone / adapter weights.
        """
        with torch.no_grad():
            outputs = self.base_model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            last_hidden = outputs.hidden_states[-1]
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_hidden = torch.sum(last_hidden * mask_expanded, dim=1)
            # Clamp so a row whose mask is entirely zero yields zeros, not NaN.
            sum_mask = torch.sum(mask_expanded, dim=1).clamp(min=1)
            mean_hidden = sum_hidden / sum_mask

            return mean_hidden

    def _project(self, embedding):
        """Run ``embedding`` through the contrastive head in the head's own dtype."""
        # The backbone may emit bf16/fp16 while the head holds float32 weights;
        # outside autocast that mismatch would raise inside nn.Linear.
        head_dtype = next(self.contrastive_head.parameters()).dtype
        return self.contrastive_head(embedding.to(head_dtype))

    def forward(self, input_ids, attention_mask, labels=None,
            contrastive_anchors=None, contrastive_positives=None, contrastive_negatives=None,
            lambda_base=0.1):
        """Compute the combined generation + contrastive loss.

        Args:
            input_ids, attention_mask, labels: Inputs for the wrapped causal LM.
                Label positions equal to -100 are excluded from the loss.
            contrastive_anchors, contrastive_positives, contrastive_negatives:
                Optional tokenized batches (each with ``input_ids`` and
                ``attention_mask``). The contrastive term is added only when
                all three are given; the number of triplets does not have to
                match the batch size.
            lambda_base: Weight of the contrastive term.

        Returns:
            The wrapped model's output with ``loss`` set to the combined scalar.
        """
        # 1. Generation loss per sample
        generation_outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        if labels is not None:
            # Standard causal shift: position t predicts token t + 1.
            logits = generation_outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            token_losses = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            token_losses = token_losses.view(shift_labels.size())

            # Average over the non-ignored tokens of each sequence.
            mask = (shift_labels != -100).float()
            generation_loss_per_sample = (token_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
        else:
            generation_loss_per_sample = torch.zeros(input_ids.size(0), device=input_ids.device)

        generation_loss = generation_loss_per_sample.mean()
        total_loss = generation_loss

        # 2. Contrastive loss if triplets are available
        if contrastive_anchors is not None and contrastive_positives is not None and contrastive_negatives is not None:
            anchor_emb = self.get_text_embedding(
                contrastive_anchors['input_ids'],
                contrastive_anchors['attention_mask']
            )
            positive_emb = self.get_text_embedding(
                contrastive_positives['input_ids'],
                contrastive_positives['attention_mask']
            )
            negative_emb = self.get_text_embedding(
                contrastive_negatives['input_ids'],
                contrastive_negatives['attention_mask']
            )

            contrastive_loss_per_sample = self.contrastive_loss_fn(
                self._project(anchor_emb),
                self._project(positive_emb),
                self._project(negative_emb),
                reduction='none'
            )

            # Combine the two means instead of adding per-sample vectors: the
            # result is the same when both have one entry per sample, and it
            # stays well-defined when the number of triplets differs from the
            # batch size.
            contrastive_loss = contrastive_loss_per_sample.mean().to(generation_loss.device)
            total_loss = generation_loss + lambda_base * contrastive_loss

        generation_outputs.loss = total_loss
        return generation_outputs

    def generate(self, *args, **kwargs):
        """Forward to base model's generate method"""
        return self.base_model.generate(*args, **kwargs)

print("Exact classes imported from simple training")

# STEP 2: Robust model loading
print("\nSTEP 2: Robust model loading...")

# Configuration
# Hub id of the base checkpoint the LoRA adapters are applied to.
MODEL_BASE = "meta-llama/Meta-Llama-3-8B-Instruct"

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Configure data paths
# The local copy is preferred; the Drive copy is used when the local one is
# absent.
DATA_DIR = "/content/A3CG_DATASET/folds/fold_1"
DRIVE_DATA_DIR = "/content/drive/MyDrive/A3CG_Dataset/folds/fold_1"

if os.path.exists(DATA_DIR):
    print(f"Using data directory: {DATA_DIR}")
elif os.path.exists(DRIVE_DATA_DIR):
    DATA_DIR = DRIVE_DATA_DIR
    print(f"Using Drive data directory: {DATA_DIR}")
else:
    # Only a warning: DATA_DIR keeps pointing at the missing local directory
    # and execution continues.
    print("Data directory not found!")
    print("Please ensure data is available at:")
    print("- /content/A3CG_DATASET/folds/fold_1 OR")
    print("- /content/drive/MyDrive/A3CG_Dataset/folds/fold_1")

def load_trained_model_robust(model_path: str):
    """Load the fine-tuned model (LoRA adapters plus optional contrastive head).

    Args:
        model_path: Directory holding ``adapter_config.json``, the adapter
            weights (``.safetensors`` or ``.bin``) and, optionally, the
            tokenizer files and ``contrastive_head.pt``.

    Returns:
        A ``(model, tokenizer, loading_method)`` tuple. ``loading_method`` is
        ``"contrastive_simple_loaded"`` when the contrastive head was restored
        (``model`` is then a ``ContrastiveLoRAModel``) and ``"lora_only"`` when
        the head file is missing, incompatible or unreadable (``model`` is then
        the bare adapter model). The returned model is in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist, or if it lacks the
            adapter config or any weight file. The check runs before the base
            model is loaded, so a wrong directory fails immediately.
    """

    print(f"Analyzing trained model: {model_path}")

    # 1. Verify directory exists and contains required files
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    files_in_model = os.listdir(model_path)
    print(f"Files found: {files_in_model}")

    required_files = ['adapter_config.json']
    has_weights = any(f.endswith(('.safetensors', '.bin')) for f in files_in_model)

    missing_files = [f for f in required_files if f not in files_in_model]
    if missing_files:
        print(f"Missing files: {missing_files}")

    if not has_weights:
        print("No weight files found (.safetensors or .bin)")

    if missing_files or not has_weights:
        # The adapter load below cannot succeed without these files; stop
        # before spending time and memory on the base model.
        raise FileNotFoundError(
            f"Not a loadable adapter directory: {model_path} "
            f"(missing files: {missing_files}, weight files present: {has_weights})"
        )

    # 2. Load tokenizer
    print("Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        print("Tokenizer loaded from model")
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate.
        print(f"Fallback: tokenizer from base model ({e})")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3. Load base model with quantization
    print("Loading base model...")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )

    # 4. Load LoRA adapters
    print("Loading LoRA adapters...")

    lora_model = PeftModel.from_pretrained(
        base_model,
        model_path,
        torch_dtype=torch.bfloat16
    )
    # Inference only: make sure dropout layers are disabled.
    lora_model.eval()

    print(f"LoRA adapters loaded")

    # 5. Create contrastive model with exact simple training architecture
    print("Creating simple contrastive model...")

    hidden_size = base_model.config.hidden_size
    print(f"Detected hidden size: {hidden_size}")

    contrastive_model = ContrastiveLoRAModel(
        base_model=lora_model,
        hidden_size=hidden_size,
        contrastive_dim=256
    )
    # A freshly built nn.Module starts in training mode, which would leave the
    # head's Dropout active during evaluation.
    contrastive_model.eval()

    # 6. Load contrastive head with verification
    contrastive_head_path = os.path.join(model_path, "contrastive_head.pt")

    if os.path.exists(contrastive_head_path):
        print("Loading simple contrastive head...")

        try:
            # NOTE(review): torch.load unpickles the file; only load head files
            # produced by your own training run.
            state_dict = torch.load(contrastive_head_path, map_location='cpu')
            print(f"Contrastive weights loaded: {list(state_dict.keys())}")

            # Verify compatibility before loading
            model_state_keys = set(contrastive_model.contrastive_head.state_dict().keys())
            loaded_keys = set(state_dict.keys())

            if model_state_keys == loaded_keys:
                contrastive_model.contrastive_head.load_state_dict(state_dict)
                # Put the head on the same device as the backbone parameters.
                device = next(contrastive_model.base_model.parameters()).device
                contrastive_model.contrastive_head.to(device)
                print("Simple contrastive head loaded successfully")
                return contrastive_model, tokenizer, "contrastive_simple_loaded"
            else:
                print(f"Key incompatibility:")
                print(f"   Model: {model_state_keys}")
                print(f"   File: {loaded_keys}")
                print("Using model without contrastive head")
                return lora_model, tokenizer, "lora_only"

        except Exception as e:
            print(f"Error loading contrastive head: {e}")
            print("Using LoRA model only")
            return lora_model, tokenizer, "lora_only"
    else:
        print("No contrastive head found")
        return lora_model, tokenizer, "lora_only"

# Find and load model
print("Searching for trained simple model...")

# Candidate locations in priority order: local saves first, then Drive.
possible_paths = [
    "./Llama3-8B-lora-a3cg-contrastive-simple-final",
    "./Llama3-8B-lora-a3cg-contrastive-final",
    "/content/drive/MyDrive/A3CG_Llama3_8B_Contrastive_Simple_LoRA",
    "/content/drive/MyDrive/A3CG_Llama3_8B_Contrastive_LoRA"
]

# Add dynamic folders
for base_dir in ["/content/drive/MyDrive/A3CG_Llama3_8B_Models"]:
    if os.path.exists(base_dir):
        # os.listdir order is arbitrary. Reverse-sort so that, among folders
        # sharing a prefix, the one with the latest timestamp suffix is tried
        # first.
        for folder in sorted(os.listdir(base_dir), reverse=True):
            if "contrastive" in folder:
                possible_paths.append(os.path.join(base_dir, folder))

# Prefer directories that actually contain an adapter. A directory can exist
# without being loadable (for example a training output directory that only
# holds checkpoint sub-folders) and must not shadow a valid backup further
# down the list.
existing_paths = [path for path in possible_paths if os.path.exists(path)]
adapter_paths = [
    path for path in existing_paths
    if os.path.isfile(os.path.join(path, "adapter_config.json"))
]

model_path = None
if adapter_paths:
    model_path = adapter_paths[0]
elif existing_paths:
    # Previous behavior as a last resort: first existing candidate.
    model_path = existing_paths[0]
    print("WARNING: no candidate contains adapter_config.json; trying the first existing path")

if model_path:
    print(f"Model found: {model_path}")

if not model_path:
    print("No trained model found!")
    print("Paths checked:")
    for path in possible_paths:
        print(f"  - {path}")
    raise FileNotFoundError("No trained model found!")

# Load model
model, tokenizer, loading_method = load_trained_model_robust(model_path)

print(f"Model loaded with method: {loading_method}")

# STEP 3: Architecture verification
print("\nSTEP 3: Architecture verification...")

def verify_model_architecture(model, loading_method):
    """Check that the loaded model has the shape the simple training produced.

    Args:
        model: Loaded model object. For the contrastive loading methods it must
            expose ``contrastive_head`` and ``hidden_size``.
        loading_method: Tag describing how the model was loaded.

    Returns:
        ``True`` for ``"lora_only"``, and for a contrastive model whose head
        accepts a random ``(1, hidden_size)`` input. ``False`` when the head is
        missing, the head forward pass raises, or the tag is unknown.
    """
    print("Verifying simple model architecture...")

    if loading_method == "lora_only":
        print("LoRA only model (without contrastive)")
        return True

    if loading_method not in ("contrastive_simple_loaded", "contrastive_loaded"):
        print("Unknown loading method")
        return False

    print("Contrastive simple model with loaded head")

    if not hasattr(model, 'contrastive_head'):
        print("Missing contrastive head!")
        return False

    print(f"Contrastive head present")

    # Describe every sub-module of the head: sizes for linear-like layers,
    # the class name for everything else.
    sublayers = list(model.contrastive_head.children())
    print(f"Contrastive head architecture:")
    for idx, sublayer in enumerate(sublayers):
        if hasattr(sublayer, 'in_features') and hasattr(sublayer, 'out_features'):
            description = f"{sublayer.in_features} -> {sublayer.out_features}"
        else:
            description = type(sublayer).__name__
        print(f"   Layer {idx}: {description}")

    # Smoke test: push one random vector through the head on the model's device.
    target_device = next(model.parameters()).device
    probe = torch.randn(1, model.hidden_size).to(target_device)
    try:
        projected = model.contrastive_head(probe)
        print(f"Contrastive head test successful: {probe.shape} -> {projected.shape}")

        # The simple variant carries no ordinal loss module.
        if hasattr(model, 'ordinal_contrastive_loss_fn'):
            print("Full architecture detected (with ordinal loss)")
        else:
            print("Simple architecture confirmed (no ordinal loss)")

        return True
    except Exception as err:
        print(f"Contrastive head test error: {err}")
        return False

# Run the architecture check. A failed check is only reported here; the
# script continues with whatever model was loaded.
architecture_ok = verify_model_architecture(model, loading_method)

if not architecture_ok:
    print("ERROR: Architecture does not match training!")
    print("   The evaluated model will not be the correct model.")
else:
    print("Architecture verified - model matches simple training")

# STEP 4: Evaluation prompt generator
print("\nSTEP 4: Evaluation prompt generator configuration...")

class A3CGEvaluationPromptGenerator:
    """Build evaluation prompts for A3CG aspect-action extraction.

    Contrastive model types get a system prompt plus randomly drawn few-shot
    examples and are asked for tuple output; every other model type gets a
    single user message asking for JSON output. Both variants are rendered
    through the tokenizer's chat template.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``.
        model_type: Loading-method tag that selects the prompt variant.
    """

    def __init__(self, tokenizer, model_type: str = "contrastive"):
        self.tokenizer = tokenizer
        self.model_type = model_type

        self.system_prompt = """You are an expert in ESG analysis. Extract aspect-action pairs from sustainability statements.

CRITICAL INSTRUCTIONS:
- Extract EXACT terms from the input text
- Do not paraphrase or interpret creatively
- Use literal wording from the sentence
- Focus on specific terms rather than general concepts

DEFINITIONS:
- Aspect: A sustainability-related entity, goal, sub-area, or activity (use exact wording)
- Action: "implemented", "planning", or "indeterminate"

OUTPUT FORMAT: ("aspect1", "action1"), ("aspect2", "action2"), ...
If none: ("no aspect", "no action")"""

        # Pool the few-shot demonstrations are sampled from.
        self.few_shot_examples = [
            {
                "text": "We have implemented solar panels to reduce energy consumption in our facilities.",
                "output": '("solar panels", "implemented"), ("energy consumption", "implemented")'
            },
            {
                "text": "The company plans to improve workplace diversity initiatives next year.",
                "output": '("workplace diversity initiatives", "planning")'
            },
            {
                "text": "We are committed to enhancing our environmental management systems.",
                "output": '("environmental management systems", "planning")'
            },
            {
                "text": "Our recycling program has achieved a 50% waste reduction.",
                "output": '("recycling program", "implemented"), ("waste reduction", "implemented")'
            },
            {
                "text": "The board may consider sustainability investments where feasible.",
                "output": '("sustainability investments", "indeterminate")'
            }
        ]

    def get_few_shot_examples(self, n_examples: int = 3) -> str:
        """Return up to ``n_examples`` randomly sampled demonstrations as text.

        Each demonstration is rendered as an ``Example i:`` header followed by
        its ``Text:`` and ``Output:`` lines. The draw uses the global
        ``random`` state, so the result changes from call to call unless that
        state is seeded.
        """
        how_many = min(n_examples, len(self.few_shot_examples))
        chosen = random.sample(self.few_shot_examples, how_many)

        return "".join(
            f"\nExample {position}:\n"
            f"Text: {shot['text']}\n"
            f"Output: {shot['output']}\n"
            for position, shot in enumerate(chosen, 1)
        )

    def create_prompt(self, sentence: str) -> str:
        """Render the chat-template prompt for ``sentence``.

        Returns whatever ``tokenizer.apply_chat_template`` produces with
        ``tokenize=False`` and ``add_generation_prompt=True``.
        """
        few_shot_types = ("contrastive_simple_loaded", "contrastive_loaded", "contrastive")

        if self.model_type in few_shot_types:
            # Few-shot variant: system instructions + sampled examples + query.
            demonstrations = self.get_few_shot_examples(n_examples=3).strip()
            query = (
                f"{demonstrations}\n\n"
                f"Now extract from this text:\nText: {sentence}\n\n"
                f"Extract the aspect-action pairs:"
            )
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": query},
            ]
        else:
            # Standard variant: a single user turn requesting JSON output.
            request = (
                "Extract aspect-action pairs from the following ESG sentence.\n"
                "Format your response as JSON with \"aspect-action_pairs\" containing "
                "a list of objects with \"aspect\" and \"action\" fields.\n\n"
                f"Sentence: {sentence}\n\n"
                "Response:"
            )
            messages = [{"role": "user", "content": request}]

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

# Initialize prompt generator
# The loading method doubles as the prompt generator's model type, so the
# prompt variant follows how the model was actually loaded.
prompt_generator = A3CGEvaluationPromptGenerator(tokenizer, loading_method)

print(f"Prompt generator configured for {loading_method}")

# STEP 5: Generation and parsing functions
print("\nSTEP 5: Generation and parsing functions configuration...")

def generate_prediction_enhanced(model, tokenizer, sentence: str, model_type: str, max_length: int = 512) -> str:
    """Generate the raw model answer for one sentence.

    The prompt comes from the module-level ``prompt_generator``. Generation
    stops at the tokenizer's EOS token or, when the tokenizer knows it, at the
    Llama 3 end-of-turn token ``<|eot_id|>``.

    Args:
        model: Object exposing ``generate`` (and normally ``parameters``).
        tokenizer: Tokenizer used to encode the prompt and decode the answer.
        sentence: Input sentence to extract aspect-action pairs from.
        model_type: Loading-method tag; selects the generation parameters.
        max_length: Unused; kept for backward compatibility with callers.

    Returns:
        The decoded continuation only (prompt tokens removed), stripped.
    """

    prompt = prompt_generator.create_prompt(sentence)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        padding=True
    )

    # Move to device
    device = next(model.parameters()).device if hasattr(model, 'parameters') else 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Stop tokens: Llama 3 chat turns end with <|eot_id|>, which is not
    # necessarily the tokenizer's eos_token. Passing only eos_token_id can let
    # generation run past the end of the answer until max_new_tokens.
    stop_token_ids = [tokenizer.eos_token_id]
    to_ids = getattr(tokenizer, "convert_tokens_to_ids", None)
    eot_id = to_ids("<|eot_id|>") if callable(to_ids) else None
    if (
        isinstance(eot_id, int)
        and eot_id >= 0
        and eot_id != getattr(tokenizer, "unk_token_id", None)
        and eot_id not in stop_token_ids
    ):
        stop_token_ids.append(eot_id)

    # Generation parameters optimized for Llama 3-8B
    if model_type in ["contrastive_simple_loaded", "contrastive_loaded", "contrastive"]:
        gen_params = {
            "max_new_tokens": 150,
            "do_sample": True,
            "temperature": 0.1,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": stop_token_ids,
        }
    else:
        gen_params = {
            "max_new_tokens": 200,
            "do_sample": True,
            "temperature": 0.1,
            "top_p": 0.9,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": stop_token_ids,
        }

    # Generate
    with torch.no_grad():
        try:
            outputs = model.generate(**inputs, **gen_params)
        except Exception as e:
            print(f"WARNING: Generation error: {e}")
            # Fallback generation: greedy decoding. The attention mask is
            # still passed so padded positions are handled as in the main path.
            fallback_kwargs = {
                "max_new_tokens": 150,
                "do_sample": False,
                "pad_token_id": tokenizer.eos_token_id,
                "eos_token_id": stop_token_ids,
            }
            if "attention_mask" in inputs:
                fallback_kwargs["attention_mask"] = inputs["attention_mask"]
            outputs = model.generate(inputs['input_ids'], **fallback_kwargs)

    # Decode only the generated part
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return response.strip()

def _is_real_pair(aspect: str, action: str) -> bool:
    """Return True unless a field is empty or the pair is the "nothing found" sentinel."""
    # Compare case-insensitively: the model may answer ("No aspect", "No action").
    return bool(aspect and action) and aspect.lower() != "no aspect" and action.lower() != "no action"


def _parse_tuple_pairs(prediction: str) -> List[Tuple[str, str]]:
    """Extract ("aspect", "action") tuples from free text, lowercased."""
    pairs = []

    # Each quoted field must be closed by the quote character that opened it
    # and may contain the other kind, so an apostrophe inside a double-quoted
    # aspect ("company's emissions") no longer breaks the match.
    quoted_pattern = r'\(\s*(["\'])((?:(?!\1).)+)\1\s*,\s*(["\'])((?:(?!\3).)+)\3\s*\)'
    for match in re.finditer(quoted_pattern, prediction, re.DOTALL):
        aspect = match.group(2).strip()
        action = match.group(4).strip()
        if _is_real_pair(aspect, action):
            pairs.append((aspect.lower(), action.lower()))

    # If no tuples found, try simpler patterns (unquoted or loosely quoted)
    if not pairs:
        simple_pattern = r'\(\s*([^,\)]+)\s*,\s*([^,\)]+)\s*\)'
        for aspect, action in re.findall(simple_pattern, prediction):
            aspect = aspect.strip().strip('"\'')
            action = action.strip().strip('"\'')
            if _is_real_pair(aspect, action):
                pairs.append((aspect.lower(), action.lower()))

    return pairs


def _parse_json_pairs(prediction: str) -> List[Tuple[str, str]]:
    """Extract pairs from the first balanced JSON object in ``prediction``, lowercased."""
    pairs = []

    try:
        if '{' in prediction and '}' in prediction:
            # Locate the end of the first balanced {...} block by brace counting.
            json_start = prediction.find('{')
            brace_count = 0
            json_end = json_start

            for i in range(json_start, len(prediction)):
                if prediction[i] == '{':
                    brace_count += 1
                elif prediction[i] == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        json_end = i + 1
                        break

            if json_end > json_start:
                json_str = prediction[json_start:json_end]
                # Collapse whitespace before parsing.
                json_str = re.sub(r'[\n\r\t]', ' ', json_str)
                json_str = re.sub(r'\s+', ' ', json_str)

                data = json.loads(json_str)

                # Use the first recognized key that holds a list.
                for key in ["aspect-action_pairs", "aspect_action_pairs", "pairs", "results"]:
                    if key in data and isinstance(data[key], list):
                        for pair in data[key]:
                            if isinstance(pair, dict) and "aspect" in pair and "action" in pair:
                                aspect = str(pair["aspect"]).strip()
                                action = str(pair["action"]).strip()
                                if aspect and action:
                                    pairs.append((aspect.lower(), action.lower()))
                        break

    except json.JSONDecodeError:
        # Fallback regex for JSON-like patterns
        json_pattern = r'\{\s*["\']aspect["\']\s*:\s*["\']([^"\']+)["\']\s*,\s*["\']action["\']\s*:\s*["\']([^"\']+)["\']\s*\}'
        matches = re.findall(json_pattern, prediction, re.IGNORECASE)

        for aspect, action in matches:
            pairs.append((aspect.strip().lower(), action.strip().lower()))

    return pairs


def parse_prediction_enhanced(prediction: str, model_type: str) -> List[Tuple[str, str]]:
    """Parse a raw model answer into lowercased (aspect, action) pairs.

    Args:
        prediction: Text generated by the model.
        model_type: Loading-method tag. The contrastive types are parsed as
            ``("aspect", "action")`` tuples; every other type is parsed as JSON.

    Returns:
        The extracted pairs in order of appearance. For the tuple format, pairs
        whose aspect is "no aspect" or whose action is "no action" (in any
        letter case) are dropped. An empty list means nothing was found.
    """
    if model_type in ["contrastive_simple_loaded", "contrastive_loaded", "contrastive"]:
        # Parse tuple format: ("aspect", "action"), ("aspect", "action")
        try:
            return _parse_tuple_pairs(prediction)
        except Exception as e:
            print(f"WARNING: Tuple parsing error: {e}")
            return []

    # Standard JSON parsing for base models
    return _parse_json_pairs(prediction)

print("Generation and parsing functions configured")

# STEP 6: Evaluation functions - Exact Match only
# A predicted pair counts only when it is identical to a ground-truth pair.
print("\nSTEP 6: Evaluation functions configuration...")

def calculate_metrics_exact_match(predictions: List[List[Tuple[str, str]]],
                                 ground_truth: List[List[Tuple[str, str]]]) -> Dict:
    """Score predicted (aspect, action) pairs against gold pairs by strict tuple equality.

    Samples are paired positionally with ``zip``; within a sample, duplicate pairs
    collapse because the comparison is done on sets.

    Args:
        predictions: One list of (aspect, action) tuples per sample.
        ground_truth: One list of gold (aspect, action) tuples per sample.

    Returns:
        Dict with micro-averaged ``exact_precision`` / ``exact_recall`` /
        ``exact_f1_score``, the raw true/false positive/negative counts,
        sample-level ``exact_match_accuracy`` and ``partial_match_accuracy``
        (both divided by ``len(predictions)``), and the total number of predicted
        and gold pairs (counted before de-duplication). Any ratio whose
        denominator is zero is reported as 0.
    """

    print("Calculating Exact Match metrics (A3CG paper)...")

    def _ratio(numerator, denominator):
        # Guarded division: undefined ratios are reported as 0.
        return numerator / denominator if denominator > 0 else 0

    tp = 0
    fp = 0
    fn = 0
    n_pred_pairs = 0
    n_true_pairs = 0
    fully_correct = 0
    partly_correct = 0

    for sample_pred, sample_true in zip(predictions, ground_truth):
        # Totals use the raw list lengths, i.e. duplicates are counted here.
        n_pred_pairs += len(sample_pred)
        n_true_pairs += len(sample_true)

        predicted = set(sample_pred)
        expected = set(sample_true)
        hits = predicted & expected

        tp += len(hits)
        fp += len(predicted - expected)
        fn += len(expected - predicted)

        # Partial match: at least one pair is right.
        if hits:
            partly_correct += 1

        # Full match: identical, non-empty sets (empty-vs-empty does not count).
        if predicted and predicted == expected:
            fully_correct += 1

    n_samples = len(predictions)

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * (precision * recall), precision + recall)

    return {
        'exact_match_accuracy': _ratio(fully_correct, n_samples),
        'partial_match_accuracy': _ratio(partly_correct, n_samples),
        'exact_precision': precision,
        'exact_recall': recall,
        'exact_f1_score': f1,
        'exact_true_positives': tp,
        'exact_false_positives': fp,
        'exact_false_negatives': fn,
        'total_predictions': n_pred_pairs,
        'total_ground_truth': n_true_pairs,
    }

def load_test_data_flexible(file_path: str) -> Tuple[List[str], List[List[Tuple[str, str]]]]:
    """Load sentences and gold (aspect, action) pairs from a JSON test file.

    The file is expected to contain a JSON list of objects, each with a ``text``
    string and an ``aspects`` mapping of ``aspect -> action`` or
    ``aspect -> [action, ...]``. Items that are not dicts, or that lack either key,
    are skipped. An item whose ``aspects`` value is not a dict is kept with an
    empty pair list.

    Gold pairs are stripped AND lowercased. The prediction parser in this notebook
    lowercases every pair it returns, and evaluation compares tuples exactly, so
    without the same normalization here any gold pair containing an uppercase
    letter could never be matched.

    Args:
        file_path: Path to the JSON file.

    Returns:
        ``(sentences, ground_truth)`` where ``ground_truth[i]`` is the list of
        normalized (aspect, action) tuples for ``sentences[i]``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        AttributeError: If an aspect key or an action inside a list is not a string.
    """
    print(f"Loading: {os.path.basename(file_path)}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    sentences = []
    ground_truth = []

    if isinstance(data, list):
        for item in data:
            # Only records carrying both the sentence and its annotations are usable.
            if not isinstance(item, dict) or 'text' not in item or 'aspects' not in item:
                continue

            pairs = []
            aspects_dict = item['aspects']
            if isinstance(aspects_dict, dict):
                for aspect, actions in aspects_dict.items():
                    # A single action may be given as a bare string instead of a list.
                    if isinstance(actions, str):
                        actions = [actions]
                    elif not isinstance(actions, list):
                        continue

                    for action in actions:
                        # Same normalization as the prediction side: strip + lowercase.
                        pairs.append((aspect.strip().lower(), action.strip().lower()))

            sentences.append(item['text'])
            ground_truth.append(pairs)

    print(f"Loaded {len(sentences)} samples with {sum(len(gt) for gt in ground_truth)} total pairs")
    return sentences, ground_truth

# Status message only.
print("Exact match evaluation functions configured")

# STEP 7: Load test data
print("\nSTEP 7: Loading test data...")

# Test splits to evaluate; the dict key is the dataset name used in all later reporting.
# DATA_DIR is not defined in this section - presumably set in an earlier cell.
test_files = {
    'seen_test': f"{DATA_DIR}/seen_test.json",
    'unseen_test': f"{DATA_DIR}/unseen_test.json"
}

# Maps dataset name -> (sentences, ground_truth). A split is kept only if its file
# exists, loads without error, and yields at least one sample.
test_data = {}
for name, file_path in test_files.items():
    if os.path.exists(file_path):
        try:
            sentences, ground_truth = load_test_data_flexible(file_path)
            if len(sentences) > 0:
                test_data[name] = (sentences, ground_truth)
                print(f"{name}: {len(sentences)} samples")
        except Exception as e:
            # Best-effort: a split that fails to load is reported and skipped.
            print(f"ERROR loading {name}: {e}")
    else:
        print(f"File not found: {file_path}")

# Fallback so the remaining steps still run when no split could be loaded.
# NOTE(review): metrics computed on this single hand-written sample are not real
# results - confirm this fallback is wanted outside of smoke testing.
if not test_data:
    print("WARNING: No test files found - creating demo data")
    test_data['demo'] = (
        ["This company has implemented strong environmental policies to reduce carbon emissions."],
        [[("environmental policies", "implemented"), ("carbon emissions", "implemented")]]
    )

# STEP 8: Model evaluation
# model, tokenizer, model_path, loading_method, architecture_ok and MODEL_BASE are
# not defined in this section - presumably set by earlier cells.
print("\nSTEP 8: Simple contrastive model evaluation...")
print("=" * 60)
print(f"Model evaluated: {model_path}")
print(f"Loading method: {loading_method}")
print(f"Architecture verified: {'Yes' if architecture_ok else 'No'}")
print(f"Architecture: Generation + Contrastive only (without ordinal)")
print(f"Metrics: Exact Match (paper implementation)")
print("=" * 60)

# Per-dataset results, keyed by dataset name; consumed by STEPS 9-12.
results = {}

for dataset_name, (sentences, ground_truth) in test_data.items():
    print(f"\nEvaluation on {dataset_name}...")
    print(f"Base Model: {MODEL_BASE}")
    print(f"Model type: {loading_method}")
    print(f"Number of samples: {len(sentences)}")

    predictions = []
    start_time = time.time()

    # Evaluate on at most the first 270 samples of the split.
    # NOTE(review): 270 is a hard-coded cap - presumably the size of the test split;
    # confirm, or move it to a named constant.
    n_test = min(len(sentences), 270)
    test_sentences = sentences[:n_test]
    test_ground_truth = ground_truth[:n_test]

    print(f"Testing on {n_test} samples...")

    for i, sentence in enumerate(test_sentences):
        if i % 25 == 0:
            print(f"   Progress: {i}/{n_test} ({i/n_test*100:.1f}%)")

        try:
            # Generate prediction
            prediction_text = generate_prediction_enhanced(model, tokenizer, sentence, loading_method)

            # Parse prediction
            pred_pairs = parse_prediction_enhanced(prediction_text, loading_method)
            predictions.append(pred_pairs)

            # Debug first few predictions
            if i < 3:
                print(f"   Sample {i+1}: {len(pred_pairs)} pairs predicted")
                print(f"   Raw output: {prediction_text[:100]}...")
                print(f"   Parsed pairs: {pred_pairs}")

        except Exception as e:
            # Best-effort: a failing sample is recorded as "no pairs predicted" so that
            # predictions stays index-aligned with test_ground_truth.
            print(f"WARNING: Error sample {i}: {e}")
            predictions.append([])

    evaluation_time = time.time() - start_time

    # Calculate exact match metrics
    print(f"Calculating metrics...")
    exact_metrics = calculate_metrics_exact_match(predictions, test_ground_truth)

    # Save results
    # Only the first five predictions / gold lists are kept as examples.
    results[dataset_name] = {
        'exact_metrics': exact_metrics,
        'evaluation_time': evaluation_time,
        'samples_per_second': n_test / evaluation_time,
        'predictions': predictions[:5],
        'ground_truth': test_ground_truth[:5],
        'model_type': loading_method,
        'n_samples': n_test,
        'base_model': MODEL_BASE
    }

    print(f"Evaluation completed in {evaluation_time:.2f}s")
    print(f"Speed: {n_test/evaluation_time:.2f} samples/sec")

# STEP 9: Results display
print("\n" + "="*80)
print("EVALUATION RESULTS - SIMPLE CONTRASTIVE")
print("="*80)

# Captured once here; the same value is reused for the output file names in STEP 12.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

for dataset_name, result in results.items():
    exact_metrics = result['exact_metrics']

    print(f"\nRESULTS - {dataset_name.upper()}")
    print(f"Base model: {result['base_model']}")
    # Prints the global loading_method rather than the stored result['model_type'].
    print(f"Model type: {loading_method}")
    print(f"Architecture: Generation + Simple Contrastive (2 components)")
    print("-" * 60)

    # Each metric is shown both as a fraction and as a percentage.
    print(f"EXACT MATCH METRICS (A3CG paper implementation):")
    print(f"   Exact Match Accuracy:     {exact_metrics['exact_match_accuracy']:.4f} ({exact_metrics['exact_match_accuracy']*100:.2f}%)")
    print(f"   Exact Precision:          {exact_metrics['exact_precision']:.4f} ({exact_metrics['exact_precision']*100:.2f}%)")
    print(f"   Exact Recall:             {exact_metrics['exact_recall']:.4f} ({exact_metrics['exact_recall']*100:.2f}%)")
    print(f"   Exact F1-Score:           {exact_metrics['exact_f1_score']:.4f} ({exact_metrics['exact_f1_score']*100:.2f}%)")

    print(f"\nPERFORMANCE:")
    print(f"   Evaluation speed:         {result['samples_per_second']:.2f} samples/sec")

# STEP 10: Detailed examples
print(f"\n" + "="*80)
print("DETAILED EXAMPLES ANALYSIS")
print("="*80)

for dataset_name, result in results.items():
    print(f"\nDETAILED EXAMPLES - {dataset_name.upper()}")
    print("-" * 70)

    # Rebinds the module-level names sentences / ground_truth / predictions.
    # result['predictions'] holds the leading predictions of the split, so index i
    # lines up with index i of the full sentences / ground_truth lists.
    sentences, ground_truth = test_data[dataset_name]
    predictions = result['predictions']

    # Show at most three examples per dataset.
    for i in range(min(3, len(predictions))):
        print(f"\nExample {i+1}:")
        print(f"Sentence: {sentences[i][:120]}...")
        print(f"Ground truth: {ground_truth[i]}")
        print(f"Prediction: {predictions[i]}")

        # Exact match analysis
        # Plain tuple equality: any case or whitespace difference between a predicted
        # pair and a gold pair counts as a mismatch.
        pred_set = set(predictions[i])
        true_set = set(ground_truth[i])
        exact_matches = pred_set.intersection(true_set)

        print(f"Exact Match Analysis:")
        if exact_matches:
            print(f"   Exact matches: {exact_matches}")
        else:
            print(f"   No exact matches")

        if not predictions[i]:
            print(f"   WARNING: No prediction generated")

# STEP 11: Performance summary
print("\n" + "="*80)
print("PERFORMANCE SUMMARY LLAMA 3-8B SIMPLE CONTRASTIVE")
print("="*80)

# Calculate average metrics across datasets
# These are unweighted means of the per-dataset values: every dataset counts equally,
# whatever its sample count.
if results:
    avg_exact_f1 = np.mean([r['exact_metrics']['exact_f1_score'] for r in results.values()])
    avg_exact_precision = np.mean([r['exact_metrics']['exact_precision'] for r in results.values()])
    avg_exact_recall = np.mean([r['exact_metrics']['exact_recall'] for r in results.values()])

    print(f"BASE MODEL: {MODEL_BASE}")
    print(f"MODEL TYPE: {loading_method.upper()}")
    print(f"MODEL EVALUATED: {os.path.basename(model_path)}")
    print(f"ARCHITECTURE: Generation + Simple Contrastive (2 components)")
    print(f"ARCHITECTURE VERIFIED: {'Yes' if architecture_ok else 'No'}")
    print("-" * 50)

    print(f"EXACT MATCH METRICS (A3CG paper):")
    print(f"   F1-Score:        {avg_exact_f1:.4f} ({avg_exact_f1*100:.2f}%)")
    print(f"   Precision:       {avg_exact_precision:.4f} ({avg_exact_precision*100:.2f}%)")
    print(f"   Recall:          {avg_exact_recall:.4f} ({avg_exact_recall*100:.2f}%)")

    # Performance comparison with paper
    # The paper figure below is a hard-coded literal, not computed here.
    print(f"\nCOMPARISON WITH A3CG PAPER:")
    print(f"   Paper GRACE (best):          47.51% F1 (exact)")
    print(f"   Our Exact Match:             {avg_exact_f1*100:.2f}% F1")

    # NOTE(review): the 0.25 cut-off is well below the 47.51% paper figure printed
    # above - confirm that "comparable" is the intended wording at this threshold.
    if avg_exact_f1 >= 0.25:
        print(f"   Performance comparable to paper")
    else:
        print(f"   Performance below paper baseline")

# STEP 12: Save results
print(f"\nSTEP 12: Save results...")

# Create results directory
# The path lives under /content/drive, i.e. it relies on Google Drive being mounted.
results_dir = f"/content/drive/MyDrive/A3CG_Evaluation_Results"
os.makedirs(results_dir, exist_ok=True)

# Save complete results
# File names embed the timestamp captured in STEP 9, so each run writes new files.
results_file = f"{results_dir}/evaluation_results_simple_contrastive_{timestamp}.json"

# Prepare data for saving
save_data = {
    'timestamp': timestamp,
    'model_info': {
        'model_path': model_path,
        'base_model': MODEL_BASE,
        'loading_method': loading_method,
        'architecture_verified': architecture_ok,
        'architecture_type': 'simple_contrastive'
    },
    'evaluation_results': {}
}

# Copy the per-dataset metrics plus the stored example predictions / gold pairs.
for dataset_name, result in results.items():
    save_data['evaluation_results'][dataset_name] = {
        'exact_metrics': result['exact_metrics'],
        'evaluation_time': result['evaluation_time'],
        'samples_per_second': result['samples_per_second'],
        'n_samples': result['n_samples'],
        'examples': {
            'predictions': result['predictions'],
            'ground_truth': result['ground_truth']
        }
    }

# Save JSON
# The (aspect, action) tuples are written as JSON arrays by json.dump.
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2, ensure_ascii=False)

# Create comprehensive markdown report
report_file = f"{results_dir}/evaluation_report_simple_contrastive_{timestamp}.md"

with open(report_file, 'w', encoding='utf-8') as f:
    f.write(f"# A3CG Simple Contrastive Evaluation Report - Llama 3-8B\n\n")
    # The report date is taken at write time, so it can differ slightly from timestamp.
    f.write(f"**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

    f.write(f"## Model Configuration\n\n")
    f.write(f"- **Base model**: {MODEL_BASE}\n")
    f.write(f"- **Model path**: {model_path}\n")
    f.write(f"- **Loading method**: {loading_method}\n")
    f.write(f"- **Architecture**: Generation + Simple Contrastive (2 components, without ordinal)\n")
    f.write(f"- **Architecture verified**: {'Yes' if architecture_ok else 'No'}\n\n")

    f.write(f"## Evaluation Method\n\n")
    f.write(f"### Exact Match (A3CG Paper)\n")
    f.write(f"- Strict matching of aspect-action pairs\n")
    f.write(f"- Conservative approach, comparable to original paper\n\n")

    f.write(f"## Evaluation Results\n\n")

    # One markdown table per dataset with F1 / precision / recall.
    for dataset_name, result in results.items():
        exact_metrics = result['exact_metrics']

        f.write(f"### {dataset_name.upper()}\n\n")

        f.write(f"#### Exact Match (Paper)\n")
        f.write(f"| Metric | Value | Percentage |\n")
        f.write(f"|--------|-------|-----------|\n")
        f.write(f"| Exact F1 | {exact_metrics['exact_f1_score']:.4f} | {exact_metrics['exact_f1_score']*100:.2f}% |\n")
        f.write(f"| Exact Precision | {exact_metrics['exact_precision']:.4f} | {exact_metrics['exact_precision']*100:.2f}% |\n")
        f.write(f"| Exact Recall | {exact_metrics['exact_recall']:.4f} | {exact_metrics['exact_recall']*100:.2f}% |\n\n")

print(f"Complete results saved:")
print(f"   JSON: {results_file}")
print(f"   Report: {report_file}")

# CONCLUSION
# Final recap of the run: what was evaluated, the average exact-match F1, where the
# output files were written, and whether the loaded model matched expectations.
print("\n" + "="*80)
print("SIMPLE CONTRASTIVE EVALUATION COMPLETED!")
print("="*80)
print(f"End time: {time.strftime('%H:%M:%S')}")
print(f"Model evaluated: {os.path.basename(model_path)}")
print(f"Method: {loading_method}")
print(f"Architecture: Generation + Simple Contrastive (2 components)")
print(f"Architecture verified: {'Yes' if architecture_ok else 'No'}")
print(f"Datasets evaluated: {len(results)}")

# Reset first: the average is only defined when at least one dataset was evaluated,
# and a value left over from an earlier run must never be reported as this run's score.
avg_exact_f1 = None
if results:
    # Unweighted mean of the per-dataset F1 scores.
    avg_exact_f1 = np.mean([r['exact_metrics']['exact_f1_score'] for r in results.values()])

    print(f"F1-Score Exact (paper): {avg_exact_f1:.4f} ({avg_exact_f1*100:.2f}%)")

print(f"Results saved in: {results_dir}")
print("="*80)

# Verdict on how far the evaluated model corresponds to the trained architecture.
if not architecture_ok:
    print("WARNING: Architecture could not be fully verified.")
    print("   Results may not reflect trained model performance.")
elif loading_method not in ["contrastive_simple_loaded", "contrastive_loaded"]:
    print("WARNING: Contrastive head was not loaded.")
    print("   Evaluated model uses only LoRA adapters.")
else:
    print("SUCCESS: Complete simple contrastive model evaluated correctly!")

print(f"\nFINAL SUMMARY:")
# Guarded like the block above: without results there is no F1 to report, and
# formatting the undefined value would abort the cell.
if avg_exact_f1 is not None:
    print(f"   Exact Match F1 (paper comparable): {avg_exact_f1*100:.1f}%")
else:
    print("   Exact Match F1 (paper comparable): n/a (no datasets evaluated)")
print(f"   Simplified architecture: Generation + Contrastive (without ordinal)")

print("\nTo analyze results in detail, check:")
print(f"   {results_file}")
print(f"   {report_file}")