In [None]:
import torch

# Hardware check: later cells load a 4-bit model onto CUDA device 0, so report
# what PyTorch can see before doing any heavy setup.
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 below, so the figure is decimal gigabytes;
    # print the unit so the number is not ambiguous.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No GPU detected. Go to Runtime > Change runtime type > GPU")

In [None]:
# Import of A3CG files

from google.colab import drive
import os
import shutil
import json

print("IMPORTING A3CG FILES")

# Mount Google Drive
print("Mounting Google Drive")
try:
    drive.mount('/content/drive')
    print("Drive mounted successfully")
except Exception as e:
    print(f"Error: {e}")
    # Abort this cell with an exception. exit() is avoided because in a
    # notebook it targets the kernel/runtime rather than just this cell.
    raise RuntimeError("Google Drive could not be mounted; cannot import A3CG files") from e

# Finding the fold_1 folder
print("\nSearching for the 'fold_1' folder...")

possible_path = "/content/drive/MyDrive/A3CG_Dataset/folds/fold_1"
drive_fold_path = None

# isdir (not just exists): the path is passed to os.listdir below.
if os.path.isdir(possible_path):
    drive_fold_path = possible_path
    print(f"Found: {possible_path}")
else:
    print("Folder 'fold_1' not found.")
    raise FileNotFoundError(f"Expected dataset folder not found: {possible_path}")

# List contents of the folder (sorted so the report is stable between runs)
print(f"\nContents of: {drive_fold_path}")
files_in_folder = sorted(os.listdir(drive_fold_path))
required_files = ["seen_train.json", "seen_val.json", "seen_test.json", "unseen_test.json"]

for file in files_in_folder:
    if file.endswith('.json'):
        file_path = os.path.join(drive_fold_path, file)
        size_kb = os.path.getsize(file_path) / 1024
        status = "[OK]" if file in required_files else "[INFO]"
        print(f"  {status} {file} ({size_kb:.1f} KB)")

# Create local structure
local_path = "/content/A3CG_DATASET/folds/fold_1"
os.makedirs(local_path, exist_ok=True)
print(f"\nDirectory created: {local_path}")

# Copy files
print("\nCopying files...")
copied_count = 0
failed_files = []  # required files that were missing or could not be copied/parsed

for filename in required_files:
    source = os.path.join(drive_fold_path, filename)
    dest = os.path.join(local_path, filename)

    if os.path.exists(source):
        try:
            shutil.copy2(source, dest)
            size_kb = os.path.getsize(dest) / 1024

            # Check for valid JSON
            with open(dest, 'r', encoding='utf-8') as f:
                data = json.load(f)

            print(f"  [OK] {filename}: {len(data)} samples ({size_kb:.1f} KB)")
            copied_count += 1

        except Exception as e:
            print(f"  [ERROR] {filename}: {e}")
            failed_files.append(filename)
    else:
        print(f"  [MISSING] {filename} not found")
        failed_files.append(filename)

# Final result
total_required = len(required_files)
print("\nRESULT:")
print(f"  Files copied: {copied_count}/{total_required}")

# The training cell refuses to run unless every required file is present,
# so only report success when all of them were copied.
if copied_count == total_required:
    print("SUCCESS! Files are ready.")
    print(f"Path: {local_path}")

    # Show final structure
    print("\nFinal Directory Structure:")
    dataset_root = "/content/A3CG_DATASET"
    for root, dirs, files in os.walk(dataset_root):
        dirs.sort()  # in-place sort makes os.walk descend in a stable order
        level = root.replace(dataset_root, '').count(os.sep)
        indent = '  ' * level
        folder_name = os.path.basename(root) or "A3CG_DATASET"
        print(f'{indent}- {folder_name}/')

        sub_indent = '  ' * (level + 1)
        for file in sorted(files):
            if file.endswith('.json'):
                file_path = os.path.join(root, file)
                size_kb = os.path.getsize(file_path) / 1024
                print(f'{sub_indent}- {file} ({size_kb:.1f} KB)')

    print("\nUse this path in your code:")
    print(f"{local_path}/")

else:
    print(f"FAILURE: Only {copied_count}/{total_required} files were copied.")
    print(f"Files not available: {', '.join(failed_files)}")
    print("Please verify that all required JSON files are in your Google Drive folder.")

In [None]:
# Hugging Face authentication.
# Create an access token at https://huggingface.co/settings/tokens before running.
from huggingface_hub import login, whoami

# Prompt for the token interactively.
login()

# Ask the Hub which account the stored token belongs to; any failure
# (including an unexpected response shape) is reported instead of raised.
try:
    print(f"Connected as: {whoami()['name']}")
except Exception as err:
    print(f"Authentication error: {err}")

In [None]:
# ===================
# INSTALLATION SCRIPT
# ===================

!pip install -q --upgrade transformers peft bitsandbytes accelerate torch datasets scikit-learn "numpy<2.0"

print("Installation completed.")

print("Installation successful with compatible versions.")
print("\nNEXT STEPS:")
print("1. RESTART THE RUNTIME (MANDATORY)")
print("   Runtime > Restart session")
print("2. Wait 30 seconds after restart")

print("\nWHY THIS RESTART IS CRITICAL:")
print("   - Prevents NumPy 1.x/2.x conflicts in memory")
print("   - Clears Python cache")
print("   - Ensures correct versions are loaded")

In [None]:
# CELL 1: DEPENDENCIES INSTALLATION
# Run this cell, then manually restart the runtime environment.

print("STEP 0: Installing and upgrading required packages...")

# Install PyTorch from official CUDA source (cu121) to ensure compatibility and fix torchvision::nms error
!pip install -q -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install other required packages
!pip install -q -U transformers peft bitsandbytes accelerate datasets
!pip install -q flash-attn --no-build-isolation

print("\nINSTALLATION COMPLETE. PLEASE RESTART THE RUNTIME NOW.")
print("(Runtime -> Restart session)")

In [None]:
# LORA + FEW-SHOT + ORDINAL CONTRASTIVE LEARNING WITH GRADNORM META-PARAMETERS

# STEP 1: Imports
print("\nSTEP 1: Package imports...")
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, Trainer, TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

try:
    import bitsandbytes
    print(f"  bitsandbytes: Import OK")
except Exception as e:
    print(f"  bitsandbytes error: {e}")

import json
import time
import os
import random
from datasets import Dataset
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from collections import Counter
import shutil
from datetime import datetime

# Turn off Weights & Biases reporting through its environment switches.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

# STEP 2: Google Drive mounting
print("\nSTEP 2: Google Drive mounting...")
from google.colab import drive
try:
    drive.mount('/content/drive')
    print("  Drive mounted successfully")
except Exception as e:
    # Best effort: the dataset is read from /content below, so a failed mount
    # is reported but does not stop the cell.
    print(f"Drive error: {e}")

# STEP 3: Data verification
print("\nSTEP 3: Data verification...")
dataset_base = "/content/A3CG_DATASET/folds/fold_1"

required_files = ["seen_train.json", "seen_val.json", "seen_test.json", "unseen_test.json"]
missing_files = []

for filename in required_files:
    filepath = os.path.join(dataset_base, filename)
    if os.path.exists(filepath):
        size_kb = os.path.getsize(filepath) / 1024
        print(f"  {filename}: {size_kb:.1f} KB")
    else:
        print(f"  {filename}: Missing")
        missing_files.append(filename)

files_ok = not missing_files

if not files_ok:
    print("WARNING: Data files missing!")
    print("Run the data import script first")
    # Stop the cell with an exception. exit() is avoided because in a notebook
    # it targets the kernel/runtime rather than just aborting this cell, and
    # the model loading below must not start without data.
    raise FileNotFoundError(
        f"Missing data files in {dataset_base}: {', '.join(missing_files)}"
    )

# STEP 4: Meta-Parameters Manager
print("\nSTEP 4: GradNorm Meta-Parameters Configuration...")

class MetaParametersManager(nn.Module):
    """Holds the learnable loss weights and temperatures used for loss balancing.

    Each meta-parameter is stored as an unconstrained raw value and mapped to a
    strictly positive number with ``softplus(raw) + epsilon``. The module also
    records the scalar history of the three component losses ('gen', 'ctr',
    'ord') so that relative loss ratios can be derived from them.

    Args:
        initial_lambda_base: Starting value of the contrastive loss weight.
        initial_lambda_ord: Starting value of the ordinal loss weight.
        initial_T_gen: Starting value of the generation temperature.
        initial_T_ctr: Starting value of the contrastive temperature.
        epsilon: Positive offset added after softplus. Every initial value
            must be greater than ``epsilon``.

    Raises:
        ValueError: If ``initial_value - epsilon`` is not a finite positive
            number for any meta-parameter (softplus cannot represent it).
    """

    _LOSS_KEYS = ('gen', 'ctr', 'ord')

    def __init__(self, initial_lambda_base=0.15, initial_lambda_ord=0.04,
                 initial_T_gen=2.5, initial_T_ctr=0.8, epsilon=1e-6):
        super().__init__()

        def safe_inverse_softplus(x, name):
            """Return r such that softplus(r) == x, for finite x > 0."""
            x = float(x)
            if not (np.isfinite(x) and x > 0.0):
                raise ValueError(
                    f"{name} must be finite and greater than epsilon={epsilon}, "
                    f"got {x + epsilon}"
                )
            # log(exp(x) - 1) rewritten as x + log(1 - exp(-x)): exp(x) is never
            # formed, so large x cannot overflow, and expm1 keeps precision for
            # small x.
            return float(x + np.log(-np.expm1(-x)))

        def make_raw_parameter(name, initial_value):
            raw_value = safe_inverse_softplus(initial_value - epsilon, name)
            return nn.Parameter(
                torch.tensor(raw_value, dtype=torch.float32),
                requires_grad=True
            )

        # Reparametrized meta-parameters for positivity constraints
        self.rho_base = make_raw_parameter('initial_lambda_base', initial_lambda_base)
        self.rho_ord = make_raw_parameter('initial_lambda_ord', initial_lambda_ord)
        self.tau_gen = make_raw_parameter('initial_T_gen', initial_T_gen)
        self.tau_ctr = make_raw_parameter('initial_T_ctr', initial_T_ctr)

        self.epsilon = epsilon

        # Loss tracking for GradNorm
        self.loss_history = {key: [] for key in self._LOSS_KEYS}
        self.initial_losses = {key: None for key in self._LOSS_KEYS}

    def get_meta_params(self):
        """Map the raw parameters to their positive values.

        Returns:
            Dict with tensors under 'lambda_base', 'lambda_ord', 'T_gen' and
            'T_ctr', each computed as ``softplus(raw) + epsilon``.
        """
        return {
            'lambda_base': F.softplus(self.rho_base) + self.epsilon,
            'lambda_ord': F.softplus(self.rho_ord) + self.epsilon,
            'T_gen': F.softplus(self.tau_gen) + self.epsilon,
            'T_ctr': F.softplus(self.tau_ctr) + self.epsilon
        }

    def update_loss_history(self, gen_loss, ctr_loss, ord_loss):
        """Append the latest component losses to the history as Python floats.

        The first call also records the values as the reference ("initial")
        losses used by ``compute_loss_ratios``. Tensors are detached first.
        """
        def to_float(x):
            if isinstance(x, torch.Tensor):
                return float(x.detach().cpu().item())
            return float(x)

        latest = {
            'gen': to_float(gen_loss),
            'ctr': to_float(ctr_loss),
            'ord': to_float(ord_loss)
        }

        for key in self._LOSS_KEYS:
            self.loss_history[key].append(latest[key])

        # Store initial losses (only on the first update)
        if self.initial_losses['gen'] is None:
            for key in self._LOSS_KEYS:
                self.initial_losses[key] = latest[key]

    def compute_loss_ratios(self, gamma=0.5):
        """Compute relative loss ratios from the most recent losses.

        Each latest loss is divided by its initial loss, then by the mean of
        those normalized losses, and raised to the power ``gamma``.

        Returns:
            Dict with keys 'gen', 'ctr', 'ord'. All ratios are 1.0 while no
            loss has been recorded yet.
        """
        if len(self.loss_history['gen']) == 0:
            return {key: 1.0 for key in self._LOSS_KEYS}

        # Current losses normalized by their initial value
        L_tilde = {
            key: self.loss_history[key][-1] / (self.initial_losses[key] + 1e-8)
            for key in self._LOSS_KEYS
        }

        # Average normalized loss
        avg_L_tilde = np.mean(list(L_tilde.values()))

        # Loss ratios
        return {
            key: (L_tilde[key] / (avg_L_tilde + 1e-8)) ** gamma
            for key in self._LOSS_KEYS
        }

# Progress marker: the MetaParametersManager class definition has been executed.
print("Meta-parameters manager configured")

# STEP 5: Ordinal Contrastive Learning Components
# (sample container, the two contrastive losses and the pair generator follow)
print("\nSTEP 5: Ordinal Contrastive Learning Components...")

@dataclass
class ContrastiveSample:
    """One anchor/positive/negative triplet used for contrastive learning."""
    # Anchor sentence and its aspect annotations.
    anchor_text: str
    anchor_aspects: Dict
    # Sample chosen as the positive for the anchor.
    positive_text: str
    positive_aspects: Dict
    # Sample chosen as the negative for the anchor.
    negative_text: str
    negative_aspects: Dict
    # Action label attached to the anchor; None when no label was provided.
    anchor_action_type: Optional[str] = None

class ContrastiveLoss(nn.Module):
    """Temperature-scaled contrastive loss over (anchor, positive, negative) triplets.

    For each sample the loss is ``-log(e^{s+/T} / (e^{s+/T} + e^{s-/T}))`` where
    ``s+`` / ``s-`` are the cosine similarities of the anchor to the positive
    and to the negative embedding.

    Args:
        temperature: Positive scaling factor ``T`` applied to the similarities.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature

    def forward(self, anchor_emb, positive_emb, negative_emb, reduction='mean'):
        """Compute the contrastive loss.

        Args:
            anchor_emb, positive_emb, negative_emb: Embedding tensors compared
                along ``dim=1`` (one row per sample).
            reduction: 'mean' for a scalar, 'none' for one loss per sample.

        Returns:
            The mean loss or the per-sample loss tensor.

        Raises:
            ValueError: If ``reduction`` is neither 'mean' nor 'none'.
        """
        if reduction not in ('mean', 'none'):
            raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

        # Normalize embeddings
        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
        positive_emb = F.normalize(positive_emb, p=2, dim=1)
        negative_emb = F.normalize(negative_emb, p=2, dim=1)

        # Compute similarities
        pos_sim = F.cosine_similarity(anchor_emb, positive_emb, dim=1)
        neg_sim = F.cosine_similarity(anchor_emb, negative_emb, dim=1)

        # -log(e^a / (e^a + e^b)) == log(1 + e^(b - a)) == softplus(b - a).
        # Working on the similarity gap avoids exponentiating sim / T directly,
        # which overflows to inf (and then NaN) for small temperatures.
        loss_per_sample = F.softplus((neg_sim - pos_sim) / self.temperature)

        if reduction == 'mean':
            return loss_per_sample.mean()
        return loss_per_sample

class OrdinalContrastiveLoss(nn.Module):
    """Hinge loss on cosine distances whose margin depends on the anchor's action.

    The per-sample loss is ``relu(d(anchor, positive) - d(anchor, negative) + margin)``
    with ``d = 1 - cosine_similarity``. Anchors labelled 'planning' use
    ``planning_margin``; every other label uses ``base_margin``.
    """

    def __init__(self, base_margin: float = 0.10, planning_margin: float = 0.15):
        super().__init__()
        self.base_margin = base_margin
        self.planning_margin = planning_margin

        # Margin looked up per anchor action; 'planning' gets its own value.
        self.action_margins = {
            'implemented': base_margin,
            'indeterminate': base_margin,
            'planning': planning_margin
        }

    def forward(self, anchor_emb, positive_emb, negative_emb, anchor_actions=None, reduction='mean'):
        """Return the margin loss, averaged ('mean') or per sample ('none').

        ``anchor_actions`` is an optional iterable of action labels, one per
        sample; labels missing from ``action_margins`` (or no labels at all)
        fall back to ``base_margin``. Any other ``reduction`` raises ValueError.
        """
        unit_anchor = F.normalize(anchor_emb, p=2, dim=1)
        unit_positive = F.normalize(positive_emb, p=2, dim=1)
        unit_negative = F.normalize(negative_emb, p=2, dim=1)

        # Cosine distance of the anchor to each partner.
        positive_distance = 1 - F.cosine_similarity(unit_anchor, unit_positive, dim=1)
        negative_distance = 1 - F.cosine_similarity(unit_anchor, unit_negative, dim=1)

        # One margin per sample, created on the embeddings' device and dtype.
        if anchor_actions is None:
            margins = torch.full(
                (unit_anchor.size(0),),
                self.base_margin,
                device=unit_anchor.device,
                dtype=unit_anchor.dtype,
            )
        else:
            margin_values = [
                self.action_margins.get(label, self.base_margin)
                for label in anchor_actions
            ]
            margins = torch.tensor(
                margin_values,
                device=unit_anchor.device,
                dtype=unit_anchor.dtype,
            )

        hinge = F.relu(positive_distance - negative_distance + margins)

        if reduction == 'none':
            return hinge
        if reduction == 'mean':
            return hinge.mean()
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

class ContrastiveDataGenerator:
    """Builds anchor/positive/negative triplets from action-labelled samples.

    Samples are dicts with a 'text' entry and an optional 'aspects' mapping of
    aspect -> list of action strings. Each sample is assigned one main action
    ('implemented', 'planning' or 'indeterminate'); positives and negatives
    are then drawn from other action groups according to fixed rules.
    """

    def __init__(self):
        self.action_groups = {
            'implemented': [],
            'planning': [],
            'indeterminate': []
        }
        # Defined ordinal order
        self.ordinal_order = ['indeterminate', 'planning', 'implemented']

    def _main_action(self, aspects: Dict) -> str:
        """Return the most frequent action among ``aspects``.

        Labels are lower-cased and stripped; labels outside the three known
        actions count as 'indeterminate'. Ties resolve in the fixed order
        implemented, planning, indeterminate. Samples without any action are
        'indeterminate'.
        """
        action_counts = {'implemented': 0, 'planning': 0, 'indeterminate': 0}

        for actions in aspects.values():
            for action in actions:
                action_lower = action.lower().strip()
                if action_lower in action_counts:
                    action_counts[action_lower] += 1
                else:
                    action_counts['indeterminate'] += 1

        if sum(action_counts.values()) == 0:
            return 'indeterminate'
        return max(action_counts, key=action_counts.get)

    def group_samples_by_action(self, data: List[Dict], reset: bool = True):
        """Group samples by their predominant action type.

        Args:
            data: Samples to distribute into ``self.action_groups``.
            reset: When True (default) the groups are emptied first, so that
                samples from a previous call (for example another dataset
                split) are not mixed into the candidate pools. Pass False to
                accumulate on top of the existing groups.
        """
        if reset:
            self.action_groups = {action: [] for action in self.action_groups}

        for sample in data:
            main_action = self._main_action(sample.get('aspects', {}))
            self.action_groups[main_action].append(sample)

        print(f" Grouped samples:")
        for action, samples in self.action_groups.items():
            print(f"  {action}: {len(samples)} samples")

    def create_ordinal_directional_pairs(self, data: List[Dict], n_pairs: Optional[int] = None) -> "List[ContrastiveSample]":
        """Create ORDINAL DIRECTIONAL contrastive pairs with explicit rules.

        Args:
            data: Samples to draw anchors, positives and negatives from.
            n_pairs: Number of anchors to draw; defaults to ``len(data)``.

        Returns:
            A list of ``ContrastiveSample``. It can be shorter than
            ``n_pairs`` because anchors without aspects are skipped, and it is
            empty when ``data`` is empty.
        """
        if n_pairs is None:
            n_pairs = len(data)

        self.group_samples_by_action(data)
        contrastive_samples = []

        print(f"  Creating {n_pairs} ORDINAL DIRECTIONAL pairs...")

        # Directional construction rules: which group supplies the positive
        # and the negative for an anchor of a given main action.
        ordinal_rules = {
            'implemented': {'positive': 'planning', 'negative': 'indeterminate'},
            'indeterminate': {'positive': 'planning', 'negative': 'implemented'},
            'planning': {'positive': 'implemented', 'negative': 'indeterminate'}
        }

        # random.choice cannot draw from an empty list; with no data there is
        # nothing to pair, so skip the loop entirely.
        iterations = n_pairs if data else 0

        for i in range(iterations):
            if i % 100 == 0:
                print(f" Progress: {i}/{n_pairs}")

            # Select anchor
            anchor = random.choice(data)
            anchor_aspects = anchor.get('aspects', {})

            if not anchor_aspects:
                continue

            # Same rule as the grouping step, so the anchor's label always
            # matches the group the anchor itself was placed in.
            anchor_main_action = self._main_action(anchor_aspects)

            # Apply directional ordinal rules
            rules = ordinal_rules[anchor_main_action]
            positive_action_type = rules['positive']
            negative_action_type = rules['negative']

            # Find positive and negative candidates
            positive_candidates = [s for s in self.action_groups[positive_action_type]
                                 if s['text'] != anchor['text']]
            negative_candidates = [s for s in self.action_groups[negative_action_type]
                                 if s['text'] != anchor['text']]

            # Selection with fallback: any other sample, else the anchor itself
            if positive_candidates:
                positive = random.choice(positive_candidates)
            else:
                other_samples = [s for s in data if s['text'] != anchor['text']]
                positive = random.choice(other_samples) if other_samples else anchor

            if negative_candidates:
                negative = random.choice(negative_candidates)
            else:
                other_samples = [s for s in data
                               if s['text'] != anchor['text'] and s['text'] != positive['text']]
                negative = random.choice(other_samples) if other_samples else anchor

            contrastive_samples.append(ContrastiveSample(
                anchor_text=anchor['text'],
                anchor_aspects=anchor_aspects,
                positive_text=positive['text'],
                positive_aspects=positive.get('aspects', {}),
                negative_text=negative['text'],
                negative_aspects=negative.get('aspects', {}),
                anchor_action_type=anchor_main_action
            ))

        print(f"  Generated {len(contrastive_samples)} ORDINAL DIRECTIONAL samples")
        return contrastive_samples

# Progress marker: the contrastive classes above have been defined.
print("Ordinal Contrastive components configured")

# STEP 6: GradNorm Enhanced Model
print("\nSTEP 6: Enhanced model configuration with GradNorm...")

# 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Checkpoint identifier passed to both from_pretrained calls below.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

print(f"Loading tokenizer...")
# The tokenizer pads on the left.
# NOTE(review): left padding is the usual choice for generation; confirm it is
# also what the training collator expects.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left"
)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading LLaMA-3 model...")
# The whole model is placed on CUDA device 0 (device_map={'': 0}).
# NOTE(review): attn_implementation="flash_attention_2" relies on the flash-attn
# package installed in the dependency cell; presumably it also needs a GPU that
# supports it — confirm for the target runtime.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={'': 0},
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

# LoRA preparation
# (PEFT helper applied to the quantized model before adapters are attached)
base_model = prepare_model_for_kbit_training(base_model)

# LoRA configuration
# Rank-32 adapters (alpha 64, dropout 0.1, no bias training) on the q/k/v/o and
# gate/up/down projection modules, configured for a causal-LM task.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none",
    lora_dropout=0.1,
    task_type=TaskType.CAUSAL_LM,
)

# Attach the LoRA adapters; `model` is the PEFT-wrapped model from here on.
model = get_peft_model(base_model, lora_config)

class GradNormContrastiveLoRAModel(nn.Module):
    """Causal-LM wrapper adding contrastive objectives and per-sample loss gating.

    ``base_model`` produces the generation loss. Optional anchor/positive/
    negative batches are mean-pooled from the backbone's last hidden state,
    projected by ``contrastive_head`` and scored with ``ContrastiveLoss`` and
    ``OrdinalContrastiveLoss``. The three per-sample losses are combined with
    softmax gating weights computed from the detached loss values and the
    weights/temperatures held by ``MetaParametersManager``.

    Args:
        base_model: Wrapped model. It is called directly for generation, its
            ``.base_model`` attribute is called for hidden states, and its
            ``named_parameters()`` are scanned for LoRA weights.
        hidden_size: Input width of the projection head; must match the width
            of the backbone's last hidden state.
        contrastive_dim: Output width of the projection head.

    NOTE(review): ``forward`` stores *detached* losses for
    ``update_meta_parameters``. Detached tensors carry no graph to the
    trainable parameters, so ``compute_component_gradients`` ends up using its
    constant fallbacks and the meta-parameters receive no gradient. This is
    now reported once at runtime; making GradNorm effective needs the
    un-detached losses (and their graph) at update time.
    """

    def __init__(self, base_model, hidden_size=4096, contrastive_dim=256):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.contrastive_dim = contrastive_dim

        # Contrastive projection head
        self.contrastive_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, contrastive_dim)
        )

        # Loss functions
        self.contrastive_loss_fn = ContrastiveLoss(temperature=0.1)
        self.ordinal_contrastive_loss_fn = OrdinalContrastiveLoss(
            base_margin=0.10,
            planning_margin=0.15
        )

        # Meta-parameters manager
        self.meta_params = MetaParametersManager()

        # Monitoring
        self.step_count = 0
        self.meta_optimizer = None
        # Set once the "meta-parameters got no gradient" warning was printed.
        self._meta_noop_warned = False

    def get_text_embedding(self, input_ids, attention_mask):
        """Return the attention-masked mean of the last hidden state.

        Runs under ``torch.no_grad()``, so no gradient flows into the backbone
        through the returned embedding.

        Args:
            input_ids: Token ids passed to ``base_model.base_model``.
            attention_mask: Mask with one entry per token; masked-out
                positions are excluded from the mean.

        Returns:
            One pooled vector per row. Rows whose mask is entirely zero yield
            a zero vector instead of NaN.
        """
        with torch.no_grad():
            outputs = self.base_model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            last_hidden = outputs.hidden_states[-1]
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_hidden = torch.sum(last_hidden * mask_expanded, dim=1)
            # clamp keeps an all-padding row from dividing by zero
            token_counts = torch.sum(mask_expanded, dim=1).clamp(min=1)

            return sum_hidden / token_counts

    @staticmethod
    def _fallback_norm(value, device):
        """Constant stand-in for a gradient norm that cannot be computed."""
        return torch.tensor(value, device=device, dtype=torch.float32, requires_grad=True)

    @staticmethod
    def _grad_norm_or_fallback(loss, params, fallback_value, device):
        """L2 norm of d(loss)/d(params), or the constant fallback.

        The fallback is returned when autograd cannot differentiate ``loss``
        (RuntimeError, e.g. the loss has no graph) or when none of ``params``
        is part of the loss graph.
        """
        try:
            grads = torch.autograd.grad(
                outputs=loss,
                inputs=params,
                retain_graph=True,
                create_graph=True,
                allow_unused=True
            )
        except RuntimeError:
            grads = ()

        grad_squares = [torch.sum(g ** 2) for g in grads if g is not None]
        if grad_squares:
            return torch.sqrt(torch.stack(grad_squares).sum() + 1e-8)
        return GradNormContrastiveLoRAModel._fallback_norm(fallback_value, device)

    def compute_component_gradients(self, generation_loss_per_sample, contrastive_loss_per_sample,
                                   ordinal_loss_per_sample, meta_params):
        """Compute gradient norms for each loss component.

        Args:
            generation_loss_per_sample, contrastive_loss_per_sample,
            ordinal_loss_per_sample: Per-sample loss tensors; each is averaged.
            meta_params: Dict from ``MetaParametersManager.get_meta_params``;
                'lambda_base' and 'lambda_ord' scale the contrastive and
                ordinal losses before differentiation.

        Returns:
            ``(gradients, L_gen, L_ctr, L_ord)`` where ``gradients`` maps
            'gen', 'ctr', 'ord' to a scalar tensor (a real gradient norm with
            respect to the LoRA and projection-head parameters, or a constant
            fallback) and the ``L_*`` values are the mean losses.
        """

        lambda_base = meta_params['lambda_base']
        lambda_ord = meta_params['lambda_ord']

        L_gen = generation_loss_per_sample.mean()
        L_ctr = contrastive_loss_per_sample.mean()
        L_ord = ordinal_loss_per_sample.mean()

        device = L_gen.device

        # Get trainable parameters (LoRA + contrastive head)
        trainable_params = [
            param
            for name, param in self.base_model.named_parameters()
            if param.requires_grad and any(lora_key in name.lower() for lora_key in ['lora_a', 'lora_b'])
        ]
        trainable_params.extend(
            param for param in self.contrastive_head.parameters() if param.requires_grad
        )

        if not trainable_params:
            return {
                'gen': self._fallback_norm(1.0, device),
                'ctr': self._fallback_norm(0.8, device),
                'ord': self._fallback_norm(0.6, device)
            }, L_gen, L_ctr, L_ord

        gradients = {
            'gen': self._grad_norm_or_fallback(L_gen, trainable_params, 1e-3, device)
        }

        # Contrastive and ordinal components are scaled by their lambda and
        # only differentiated when the loss is not (numerically) zero.
        scaled_components = (
            ('ctr', lambda_base, L_ctr, 1e-4),
            ('ord', lambda_ord, L_ord, 1e-5),
        )
        for key, weight, component_loss, fallback_value in scaled_components:
            if component_loss.item() > 1e-6:
                gradients[key] = self._grad_norm_or_fallback(
                    weight * component_loss, trainable_params, fallback_value, device
                )
            else:
                gradients[key] = self._fallback_norm(fallback_value, device)

        return gradients, L_gen, L_ctr, L_ord

    def forward(self, input_ids, attention_mask, labels=None,
                contrastive_anchors=None, contrastive_positives=None, contrastive_negatives=None,
                anchor_actions=None, train_meta_params=False):
        """Run the backbone and return its output with the combined loss.

        Args:
            input_ids, attention_mask, labels: Passed to ``base_model``. With
                ``labels`` a per-sample token-averaged cross-entropy is
                computed (label -100 is ignored); without, it is zero.
            contrastive_anchors, contrastive_positives, contrastive_negatives:
                Optional mappings with 'input_ids' and 'attention_mask'. The
                contrastive losses are only computed when all three are given.
                Their per-sample losses are combined elementwise with the
                generation loss, so presumably they must share its batch size
                — verify against the data collator.
            anchor_actions: Optional action labels forwarded to the ordinal loss.
            train_meta_params: When True and more than 10 forward calls have
                happened, detached losses and gating weights are stored in
                ``self._stored_losses`` for ``update_meta_parameters``.

        Returns:
            The backbone's output object with ``.loss`` replaced by the mean
            of the gated per-sample loss.
        """

        # Get current meta-parameters (detached to avoid gradient conflicts)
        with torch.no_grad():
            meta_params = self.meta_params.get_meta_params()
            # Detach to prevent gradient conflicts
            lambda_base = meta_params['lambda_base'].detach()
            lambda_ord = meta_params['lambda_ord'].detach()
            T_gen = meta_params['T_gen'].detach()
            T_ctr = meta_params['T_ctr'].detach()

        # Generation loss per sample
        generation_outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        if labels is not None:
            # Shift so that position t predicts token t + 1.
            logits = generation_outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            token_losses = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            token_losses = token_losses.view(shift_labels.size())
            # Average only over positions that carry a real label.
            mask = (shift_labels != -100).float()
            generation_loss_per_sample = (token_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
        else:
            generation_loss_per_sample = torch.zeros(input_ids.size(0), device=input_ids.device)

        # Initialize contrastive losses
        contrastive_loss_per_sample = torch.zeros_like(generation_loss_per_sample)
        ordinal_loss_per_sample = torch.zeros_like(generation_loss_per_sample)

        # Contrastive losses if data available
        if contrastive_anchors is not None and contrastive_positives is not None and contrastive_negatives is not None:
            # Get embeddings
            # NOTE(review): the pooled embeddings keep the backbone's dtype
            # while the head is created with default-dtype weights — confirm
            # the training loop (e.g. autocast) makes the two compatible.
            anchor_emb = self.get_text_embedding(
                contrastive_anchors['input_ids'],
                contrastive_anchors['attention_mask']
            )
            positive_emb = self.get_text_embedding(
                contrastive_positives['input_ids'],
                contrastive_positives['attention_mask']
            )
            negative_emb = self.get_text_embedding(
                contrastive_negatives['input_ids'],
                contrastive_negatives['attention_mask']
            )

            # Project to contrastive space
            anchor_proj = self.contrastive_head(anchor_emb)
            positive_proj = self.contrastive_head(positive_emb)
            negative_proj = self.contrastive_head(negative_emb)

            # Compute contrastive losses per sample
            contrastive_loss_per_sample = self.contrastive_loss_fn(
                anchor_proj, positive_proj, negative_proj, reduction='none'
            )

            # Ordinal loss with asymmetric margins
            ordinal_loss_per_sample = self.ordinal_contrastive_loss_fn(
                anchor_proj, positive_proj, negative_proj,
                anchor_actions=anchor_actions,
                reduction='none'
            )

        # Compute gating weights per sample (completely detached)
        with torch.no_grad():
            gen_loss_detached = generation_loss_per_sample.detach()
            ctr_loss_detached = contrastive_loss_per_sample.detach()
            ord_loss_detached = ordinal_loss_per_sample.detach()

            # Gating logits per sample
            a_gen = gen_loss_detached / T_gen
            a_ctr = -gen_loss_detached / T_gen + ctr_loss_detached / T_ctr
            a_ord = -gen_loss_detached / T_gen + ord_loss_detached / T_ctr

            # Softmax weights per sample
            logits_stack = torch.stack([a_gen, a_ctr, a_ord], dim=-1)
            w = torch.softmax(logits_stack, dim=-1)

            # Stability: clamp and renormalize
            w_min = 0.05
            w = w.clamp_min(w_min)
            w = w / w.sum(dim=-1, keepdim=True)

            w_gen = w[:, 0]
            w_ctr = w[:, 1]
            w_ord = w[:, 2]

        # Combined loss per sample - using detached lambda values
        scaled_ctr_loss = lambda_base * contrastive_loss_per_sample
        scaled_ord_loss = lambda_ord * ordinal_loss_per_sample

        total_loss_per_sample = (
            w_gen * generation_loss_per_sample +
            w_ctr * scaled_ctr_loss +
            w_ord * scaled_ord_loss
        )

        # Main loss for model parameters
        main_loss = total_loss_per_sample.mean()

        # Store losses for meta-parameter optimization (separate from main forward pass)
        if train_meta_params and self.step_count > 10:
            # Store component losses for later meta-parameter update
            self._stored_losses = {
                'generation_loss_per_sample': generation_loss_per_sample.detach(),
                'contrastive_loss_per_sample': contrastive_loss_per_sample.detach(),
                'ordinal_loss_per_sample': ordinal_loss_per_sample.detach(),
                'w_gen': w_gen,
                'w_ctr': w_ctr,
                'w_ord': w_ord
            }

        self.step_count += 1
        generation_outputs.loss = main_loss
        return generation_outputs

    def update_meta_parameters(self):
        """Update the meta-parameters from the losses stored by ``forward``.

        Does nothing when no losses are stored. Otherwise builds a GradNorm
        style objective (absolute gap between each component's gradient norm
        and its target, plus a small entropy term on the gating weights) and
        takes one Adam step on ``self.meta_params``. Failures are printed, not
        raised, and the stored losses are always removed afterwards
        (``_stored_losses`` exists as an attribute only between a storing
        ``forward`` call and this method).
        """
        if not hasattr(self, '_stored_losses'):
            return

        try:
            stored = self._stored_losses

            # Only proceed if we have meaningful contrastive data
            if (stored['contrastive_loss_per_sample'].sum().item() > 1e-6 and
                stored['ordinal_loss_per_sample'].sum().item() > 1e-6):

                # Get fresh meta-parameters for gradient computation
                meta_params = self.meta_params.get_meta_params()

                # Gradient norms per component, computed from the stored losses
                gradients, L_gen, L_ctr, L_ord = self.compute_component_gradients(
                    stored['generation_loss_per_sample'],
                    stored['contrastive_loss_per_sample'],
                    stored['ordinal_loss_per_sample'],
                    meta_params
                )

                # Update loss history
                self.meta_params.update_loss_history(L_gen, L_ctr, L_ord)

                # Compute loss ratios
                ratios = self.meta_params.compute_loss_ratios(gamma=0.5)

                # Convert ratios to tensors
                device = L_gen.device
                ratio_gen = torch.tensor(ratios['gen'], device=device, dtype=torch.float32)
                ratio_ctr = torch.tensor(ratios['ctr'], device=device, dtype=torch.float32)
                ratio_ord = torch.tensor(ratios['ord'], device=device, dtype=torch.float32)

                # Target gradient norms
                G_avg = (gradients['gen'] + gradients['ctr'] + gradients['ord']) / 3.0

                targets = {
                    'gen': G_avg * ratio_gen,
                    'ctr': G_avg * ratio_ctr,
                    'ord': G_avg * ratio_ord
                }

                # GradNorm loss
                gradnorm_components = []
                for k in ['gen', 'ctr', 'ord']:
                    if gradients[k].requires_grad and targets[k].requires_grad:
                        grad_diff = torch.abs(gradients[k] - targets[k])
                        gradnorm_components.append(grad_diff)

                if gradnorm_components:
                    gradnorm_loss = torch.stack(gradnorm_components).sum()

                    # Entropy regularizer
                    w_batch = torch.stack([stored['w_gen'], stored['w_ctr'], stored['w_ord']], dim=1)
                    log_weights = torch.log(w_batch + 1e-8)
                    entropy_per_sample = -(w_batch * log_weights).sum(dim=1)
                    avg_entropy = entropy_per_sample.mean()
                    entropy_reg = 0.01 * avg_entropy

                    # Total meta-objective
                    meta_loss = gradnorm_loss + entropy_reg

                    # Initialize meta-optimizer
                    if self.meta_optimizer is None:
                        self.meta_optimizer = torch.optim.Adam(self.meta_params.parameters(), lr=1e-4)

                    # Update meta-parameters
                    self.meta_optimizer.zero_grad()
                    meta_loss.backward()

                    # Make a silent no-op visible: if no meta-parameter got a
                    # gradient, the optimizer step below changes nothing.
                    meta_has_grad = any(
                        p.grad is not None for p in self.meta_params.parameters()
                    )
                    if not meta_has_grad and not getattr(self, '_meta_noop_warned', False):
                        print("GradNorm warning: meta_loss does not depend on the "
                              "meta-parameters; the meta-optimizer step has no effect.")
                        self._meta_noop_warned = True

                    torch.nn.utils.clip_grad_norm_(self.meta_params.parameters(), max_norm=1.0)
                    self.meta_optimizer.step()

                    if self.step_count % 20 == 0:
                        print(f"GradNorm - Step {self.step_count}: meta_loss={meta_loss.item():.4f}")

        except Exception as e:
            print(f"Meta-parameter update failed: {e}")

        finally:
            # Clean up stored losses
            if hasattr(self, '_stored_losses'):
                delattr(self, '_stored_losses')

# Wrap model with GradNorm capabilities
# `model` is rebound here: from this point it is the GradNormContrastiveLoRAModel
# wrapper, and the PEFT model is reachable as `model.base_model`.
model = GradNormContrastiveLoRAModel(model)
print("GradNorm Contrastive LoRA model configured")

# STEP 7: Data Processor
print("\nSTEP 7: Enhanced Few-Shot + Ordinal Contrastive processor...")

class A3CGOrdinalFewShotDataProcessor:
    """Builds few-shot chat prompts and ordinal contrastive samples for A3CG data.

    The processor renders each record through the tokenizer's chat template
    (system prompt + randomly chosen few-shot examples + the sentence to
    analyse) and delegates contrastive triplet creation to
    ``ContrastiveDataGenerator``.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and set up the prompt text and few-shot pool.

        Args:
            tokenizer: object exposing ``apply_chat_template(messages,
                tokenize=..., add_generation_prompt=...)``.
        """
        self.tokenizer = tokenizer
        self.system_prompt = """You are an expert in ESG analysis. Extract aspect-action pairs from sustainability statements.

CRITICAL INSTRUCTIONS:
- Extract EXACT terms from the input text
- Do not paraphrase or interpret creatively
- Use literal wording from the sentence
- Focus on specific terms rather than general concepts

DEFINITIONS:
- Aspect: A sustainability-related entity, goal, sub-area, or activity (use exact wording)
- Action: "implemented", "planning", or "indeterminate"

OUTPUT FORMAT: ("aspect1", "action1"), ("aspect2", "action2"), ...
If none: ("no aspect", "no action")"""

        # Pool the few-shot examples are drawn from for every prompt.
        self.few_shot_examples = [
            {
                "text": "We have implemented solar panels to reduce energy consumption in our facilities.",
                "output": '("solar panels", "implemented"), ("energy consumption", "implemented")'
            },
            {
                "text": "The company plans to improve workplace diversity initiatives next year.",
                "output": '("workplace diversity initiatives", "planning")'
            },
            {
                "text": "We are committed to enhancing our environmental management systems.",
                "output": '("environmental management systems", "planning")'
            },
            {
                "text": "Our recycling program has achieved a 50% waste reduction.",
                "output": '("recycling program", "implemented"), ("waste reduction", "implemented")'
            },
            {
                "text": "The board may consider sustainability investments where feasible.",
                "output": '("sustainability investments", "indeterminate")'
            }
        ]

        self.contrastive_generator = ContrastiveDataGenerator()

    def get_few_shot_examples(self, n_examples: int = 3) -> str:
        """Randomly select up to ``n_examples`` examples and format them as text.

        Returns:
            A string with one ``Example i`` / ``Text`` / ``Output`` section per
            selected example (empty string when nothing is selected).
        """
        selected = random.sample(self.few_shot_examples, min(n_examples, len(self.few_shot_examples)))
        examples_text = ""
        for i, example in enumerate(selected, 1):
            examples_text += f"\nExample {i}:\n"
            examples_text += f"Text: {example['text']}\n"
            examples_text += f"Output: {example['output']}\n"
        return examples_text

    def create_prompt(self, text: str, aspects_dict: Dict = None) -> str:
        """Render a chat-template prompt for one sentence.

        Args:
            text: the sentence to extract aspect-action pairs from.
            aspects_dict: mapping ``aspect -> iterable of actions`` holding the
                gold answer. ``None`` means inference: no assistant turn is
                added and the generation prompt is appended. Any dict,
                including an empty one, means training: the assistant turn is
                added, and an empty dict yields the
                ``("no aspect", "no action")`` target.

        Returns:
            Whatever ``tokenizer.apply_chat_template(..., tokenize=False)``
            returns for the assembled messages.
        """
        few_shot_text = self.get_few_shot_examples(n_examples=3)
        user_content = (
            f"{few_shot_text.strip()}\n\n"
            f"Now extract from this text:\n"
            f"Text: {text}\n\n"
            f"Extract the aspect-action pairs:"
        )

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content}
        ]

        # Distinguish "no labels supplied" (None) from "labelled as empty" ({}).
        # A plain truthiness test would drop the target for empty-label records
        # while still omitting the generation prompt.
        is_inference = aspects_dict is None

        if not is_inference:
            output_pairs = [
                f'("{aspect}", "{action}")'
                for aspect, actions in aspects_dict.items()
                for action in actions
            ]
            expected_output = ', '.join(output_pairs) if output_pairs else '("no aspect", "no action")'
            messages.append({"role": "assistant", "content": expected_output})

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=is_inference
        )

    def load_data(self, file_path: str) -> List[Dict]:
        """Read a UTF-8 JSON file and return its parsed content."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def prepare_dataset(self, data: List[Dict]) -> "Tuple[Dataset, List[ContrastiveSample]]":
        """Prepare dataset with generation and ordinal contrastive samples.

        Args:
            data: records with a ``'text'`` key and an optional ``'aspects'``
                mapping; a missing or ``None`` mapping is treated as
                "no aspects" so the record still gets a training target.

        Returns:
            ``(dataset, contrastive_samples)`` where ``dataset`` has a single
            ``"text"`` column of rendered prompts and ``contrastive_samples``
            is the generator's output for ``len(data) // 2`` requested pairs.
        """
        print(f"  Preparing {len(data)} samples...")

        prompts = []
        for i, item in enumerate(data):
            if i % 100 == 0:
                print(f"    Progress: {i}/{len(data)}")
            # `or {}` keeps training mode even if the stored value is None.
            prompt = self.create_prompt(item['text'], item.get('aspects') or {})
            prompts.append(prompt)

        print("  Generating ordinal contrastive samples...")
        contrastive_samples = self.contrastive_generator.create_ordinal_directional_pairs(data, n_pairs=len(data)//2)

        return Dataset.from_dict({"text": prompts}), contrastive_samples

print("Enhanced processor configured")

# STEP 8: Data Preparation
# Loads the seen-train / seen-val JSON files from `dataset_base` (defined
# earlier in this cell) and turns each split into (prompt dataset,
# contrastive samples).
print("\nSTEP 8: Data preparation...")

processor = A3CGOrdinalFewShotDataProcessor(tokenizer=tokenizer)

train_data = processor.load_data(f"{dataset_base}/seen_train.json")
val_data = processor.load_data(f"{dataset_base}/seen_val.json")

print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")

# prepare_dataset returns a pair; the second element is the contrastive list.
# Only `train_contrastive` is handed to the data collator further down;
# `val_contrastive` is not referenced again in the visible part of this cell.
train_dataset, train_contrastive = processor.prepare_dataset(train_data)
val_dataset, val_contrastive = processor.prepare_dataset(val_data)

# Tokenization
def tokenize_function(examples):
    """Tokenize the ``"text"`` entries of a batch with the module-level tokenizer.

    Sequences are truncated to 2048 tokens and left unpadded (padding is done
    later, per batch); plain Python lists are returned rather than tensors.
    """
    encode_options = {
        "truncation": True,
        "padding": False,
        "max_length": 2048,
        "return_tensors": None,
    }
    return tokenizer(examples["text"], **encode_options)

print("Tokenizing data...")
# Batched map; every pre-existing column (here the "text" prompts) is dropped
# so only the tokenizer outputs remain in the datasets. Both names are rebound
# to the tokenized versions.
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=val_dataset.column_names)

print("Data preparation completed")

# STEP 9: Data Collator
# Banner only; the collator class is defined right below.
print("\nSTEP 9: Enhanced data collator...")

class ContrastiveDataCollator:
    """Enhanced data collator with ordinal directional support.

    Pads the language-modelling features into a batch, derives ``labels`` when
    the features do not carry them, and — with probability
    ``contrastive_prob`` — attaches tokenized anchor / positive / negative
    texts drawn from ``contrastive_samples``.
    """

    def __init__(self, tokenizer, contrastive_samples: list, contrastive_prob: float = 0.3,
                 contrastive_max_length: int = 512):
        """
        Args:
            tokenizer: tokenizer used both for ``pad`` and for encoding the
                contrastive texts.
            contrastive_samples: objects exposing ``anchor_text``,
                ``positive_text``, ``negative_text`` and ``anchor_action_type``.
            contrastive_prob: per-batch probability of attaching contrastive data.
            contrastive_max_length: truncation length for the contrastive texts.
        """
        self.tokenizer = tokenizer
        self.contrastive_samples = contrastive_samples
        self.contrastive_prob = contrastive_prob
        self.contrastive_max_length = contrastive_max_length

    def _encode_texts(self, texts: list):
        """Tokenize a list of contrastive texts into padded PyTorch tensors."""
        return self.tokenizer(
            texts, truncation=True, padding=True,
            max_length=self.contrastive_max_length, return_tensors="pt"
        )

    def __call__(self, features: list) -> dict:
        """Collate ``features`` into one batch.

        Returns:
            The padded batch with ``labels`` and, on some calls, the extra keys
            ``contrastive_anchors``, ``contrastive_positives``,
            ``contrastive_negatives`` (tokenizer outputs) and
            ``anchor_actions`` (list of action strings).
        """
        batch = self.tokenizer.pad(features, return_tensors="pt", padding=True)

        if "labels" not in features[0]:
            labels = batch["input_ids"].clone()
            # Padding positions must not contribute to the LM loss. Use the
            # attention mask rather than the pad token id: the pad token may be
            # the same id as a real end-of-sequence token.
            if "attention_mask" in batch:
                labels[batch["attention_mask"] == 0] = -100
            batch["labels"] = labels

        # Inject contrastive samples
        if random.random() < self.contrastive_prob and self.contrastive_samples:
            batch_size = batch['input_ids'].shape[0]
            # NOTE(review): when fewer samples than batch_size are available the
            # contrastive batch is smaller than the LM batch — confirm the model
            # can combine per-sample losses of different lengths.
            selected_contrastive = random.sample(
                self.contrastive_samples,
                min(batch_size, len(self.contrastive_samples))
            )

            batch['contrastive_anchors'] = self._encode_texts(
                [cs.anchor_text for cs in selected_contrastive]
            )
            batch['contrastive_positives'] = self._encode_texts(
                [cs.positive_text for cs in selected_contrastive]
            )
            batch['contrastive_negatives'] = self._encode_texts(
                [cs.negative_text for cs in selected_contrastive]
            )
            batch['anchor_actions'] = [cs.anchor_action_type for cs in selected_contrastive]

        return batch

# Collator used by the trainer: contrastive triplets come from the training
# split only, are attached to a batch with probability 0.3, and are truncated
# to 512 tokens.
data_collator = ContrastiveDataCollator(
    tokenizer=tokenizer,
    contrastive_samples=train_contrastive,
    contrastive_prob=0.3,
    contrastive_max_length=512
)

print("Data collator configured")

# STEP 10: Training Callback
# Banner only; the callback class is defined right below.
print("\nSTEP 10: Training callback...")

class GradNormMonitorCallback(TrainerCallback):
    """Trainer callback that reports GPU memory, elapsed time and meta-parameters."""

    def __init__(self):
        # Wall-clock reference points; `last_check` is refreshed on every report.
        self.start_time = time.time()
        self.last_check = time.time()

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Every 10th global step: log GPU usage; every 50th: log meta-parameters."""
        now = time.time()
        step = state.global_step

        # Nothing to report between reporting steps.
        if step % 10 != 0:
            return

        elapsed = now - self.start_time

        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            print(f"Step {step}: GPU {allocated:.1f}/{reserved:.1f} GB, Time: {elapsed/60:.1f}min")

            # `reserved` is in GB; dividing by the byte total and multiplying
            # by 1e9 gives the reserved fraction of device memory.
            reserved_fraction = reserved / torch.cuda.get_device_properties(0).total_memory * 1e9
            if reserved_fraction > 0.9:
                torch.cuda.empty_cache()

        if hasattr(model, 'meta_params') and step % 50 == 0:
            meta_params = model.meta_params.get_meta_params()
            print(f"  Meta-params - lambda_base: {meta_params['lambda_base']:.4f}, "
                  f"lambda_ord: {meta_params['lambda_ord']:.4f}")

        self.last_check = now

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Log the time elapsed since construction and release cached GPU memory."""
        minutes_elapsed = (time.time() - self.start_time) / 60
        print(f"Epoch {state.epoch} completed in {minutes_elapsed:.1f} minutes")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Single callback instance; its timer starts here, at construction time, not
# when training begins.
memory_callback = GradNormMonitorCallback()

# STEP 11: Custom Trainer
# Banner only; the trainer subclass is defined right below.
print("\nSTEP 11: Custom Trainer...")

class GradNormTrainer(Trainer):
    """Trainer subclass that feeds contrastive inputs to the model and triggers
    the model's meta-parameter update after each training step."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the wrapped model and return its combined loss.

        Contrastive entries are optional in ``inputs``; missing ones are passed
        as ``None``. Meta-parameter training is requested from the model once
        the global step exceeds 20.
        """
        optional_keys = (
            'labels',
            'contrastive_anchors',
            'contrastive_positives',
            'contrastive_negatives',
            'anchor_actions',
        )
        optional_inputs = {key: inputs.get(key) for key in optional_keys}

        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            train_meta_params=self.state.global_step > 20,
            **optional_inputs
        )

        if return_outputs:
            return (outputs.loss, outputs)
        return outputs.loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Standard training step followed by a best-effort meta-parameter update."""
        # Regular forward/backward handled by the parent class.
        step_loss = super().training_step(model, inputs, num_items_in_batch)

        # Meta-parameters are left untouched for the first 20 global steps.
        if not self.state.global_step > 20:
            return step_loss

        try:
            model.update_meta_parameters()
        except Exception as e:
            # A failed meta update is reported but never aborts training.
            print(f"Meta-parameter update failed in training step: {e}")

        return step_loss

# STEP 12: Training Configuration
# Builds the TrainingArguments and the GradNormTrainer used for the run.
print("\nSTEP 12: Training configuration...")

# Batch size 1 with 16 accumulation steps per device; checkpoints and
# evaluations both happen every 100 steps, and the checkpoint with the best
# eval_loss is reloaded at the end. Checkpoints go straight to Google Drive.
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/A3CG_GradNorm_Ordinal_Models",
    num_train_epochs=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    optim="paged_adamw_32bit",
    save_steps=100,
    logging_steps=10,
    learning_rate=3e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=True,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    group_by_length=True,
    lr_scheduler_type="cosine",
    eval_steps=100,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Must stay False: the collator adds contrastive_* / anchor_actions keys
    # that the trainer's compute_loss reads from `inputs`.
    remove_unused_columns=False,
    report_to="none",
)

trainer = GradNormTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Periodic GPU / meta-parameter logging.
trainer.add_callback(memory_callback)

print("Trainer configured")

# STEP 13: Training Execution
# Prints a run summary, then trains. The ordinal rules and margins printed
# here are hard-coded text, not read back from the model.
print("\nSTEP 13: STARTING GRADNORM ORDINAL DIRECTIONAL TRAINING...")
print("=" * 80)

print(f"Start time: {time.strftime('%H:%M:%S')}")
print(f"Base Model: {model_name}")
print(f"Configuration: LoRA + Ordinal Directional Contrastive + Few-Shot + GradNorm")
print(f"Batch: {training_args.per_device_train_batch_size}x{training_args.gradient_accumulation_steps}")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Ordinal Rules:")
print(f"  - implemented -> planning (pos) vs indeterminate (neg)")
print(f"  - indeterminate -> planning (pos) vs implemented (neg)")
print(f"  - planning -> implemented (pos) vs indeterminate (neg)")
print(f"Margins: Planning=0.15, Base=0.1")
print("=" * 80)

start_time = time.time()

try:
    trainer.train()

    training_time = time.time() - start_time
    print(f"\nGRADNORM ORDINAL DIRECTIONAL TRAINING COMPLETED!")
    print(f"Total time: {training_time/3600:.1f}h ({training_time/60:.1f}min)")

    # Print final meta-parameters
    final_meta_params = model.meta_params.get_meta_params()
    print(f"\nFinal Meta-Parameters:")
    for k, v in final_meta_params.items():
        print(f"  {k}: {v.item():.6f}")

except Exception as e:
    # Report, free cached GPU memory, then re-raise so the cell stops here and
    # the save/backup steps below are not run on a failed training.
    print(f"\nERROR DURING TRAINING: {e}")
    torch.cuda.empty_cache()
    raise e

# STEP 14: Model Saving
# Writes the wrapped base model, the tokenizer, the contrastive head, the
# meta-parameters and a JSON summary into a local folder.
model_output_path = f"./Llama3-8B-gradnorm-ordinal-lora-final"
print(f"\nSTEP 14: Saving model to {model_output_path}...")

try:
    # save_pretrained is called on the wrapper's inner `base_model`; the extra
    # modules owned by the wrapper are saved separately below.
    trainer.model.base_model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)

    # Save additional components
    torch.save(trainer.model.contrastive_head.state_dict(), f"{model_output_path}/contrastive_head.pt")
    torch.save(trainer.model.meta_params.state_dict(), f"{model_output_path}/meta_parameters.pt")

    # Save configurations
    # Tensors are converted to plain floats so the dict is JSON-serialisable.
    final_meta_params = trainer.model.meta_params.get_meta_params()
    config = {
        'final_meta_params': {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in final_meta_params.items()},
        'ordinal_config': {
            'base_margin': trainer.model.ordinal_contrastive_loss_fn.base_margin,
            'planning_margin': trainer.model.ordinal_contrastive_loss_fn.planning_margin,
            'ordinal_rules': {
                'implemented': {'positive': 'planning', 'negative': 'indeterminate'},
                'indeterminate': {'positive': 'planning', 'negative': 'implemented'},
                'planning': {'positive': 'implemented', 'negative': 'indeterminate'}
            }
        }
    }

    # NOTE(review): the evaluation loader later in this notebook looks for a
    # file named "ordinal_config.json"; this file is "config.json" with the
    # margins nested under 'ordinal_config' — confirm the two sides agree.
    with open(f"{model_output_path}/config.json", 'w') as f:
        json.dump(config, f, indent=2)

    print(f"Model saved successfully")

except Exception as e:
    # Save failures are only printed; execution continues with the backup step.
    print(f"Save error: {e}")

# STEP 15: Backup to Drive
# Copies the local model folder to a timestamped folder on Google Drive.
print("\nSTEP 15: Backup to Google Drive...")

timestamp = datetime.now().strftime("%Y%m%d_%H%M")

try:
    drive_folder = f"/content/drive/MyDrive/A3CG_GradNorm_Ordinal_Models"
    os.makedirs(drive_folder, exist_ok=True)

    drive_model_path = f"{drive_folder}/gradnorm-ordinal-{timestamp}"
    shutil.copytree(model_output_path, drive_model_path)

    # Sums the sizes of the top-level entries only (no recursion into
    # sub-directories).
    total_size_mb = sum(os.path.getsize(os.path.join(drive_model_path, f))
                       for f in os.listdir(drive_model_path)) / (1024*1024)

    print(f"Model backed up to: {drive_model_path}")
    print(f"Total size: {total_size_mb:.1f} MB")

except Exception as e:
    print(f"Backup error: {e}")

# STEP 16: Final Summary
# NOTE(review): this banner is printed unconditionally, including when the
# save or backup step above printed an error.
print("\nSTEP 16: Training completed successfully!")
print("=" * 80)
print(f"End time: {time.strftime('%H:%M:%S')}")
print(f"Model location: {model_output_path}")
print("=" * 80)

In [None]:
# =====================================
# GRADNORM + ORDINAL CONTRASTIVE MODEL EVALUATION SCRIPT
# Compatible with GradNorm Meta-Parameters + Ordinal Directional Learning
# =====================================

import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import re
import pandas as pd
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from difflib import SequenceMatcher
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import PeftModel, PeftConfig
import warnings
warnings.filterwarnings('ignore')

# Start-of-run banner for the evaluation cell, with the wall-clock start time.
print("GRADNORM + ORDINAL CONTRASTIVE MODEL EVALUATION SCRIPT")
print("Compatible with GradNorm Meta-Parameters + Ordinal Directional Learning")
print("=" * 70)
print(f"Start time: {time.strftime('%H:%M:%S')}")
print("=" * 70)

# Step 1: Exact GradNorm Model Classes
# The classes below are re-declared in this cell so that the saved state
# dicts can be loaded without running the training cell first.
print("Step 1: Import GradNorm + Ordinal Contrastive classes...")

# Meta-Parameters Manager (exactly as in training)
class MetaParametersManager(nn.Module):
    """Holds the four learnable meta-parameters in unconstrained form.

    Each meta-parameter is stored as a raw scalar (``rho_base``, ``rho_ord``,
    ``tau_gen``, ``tau_ctr``) and exposed through
    ``softplus(raw) + epsilon``, which keeps the effective value strictly
    positive. Attribute names and registration order are part of the saved
    state-dict format and must not change.
    """

    def __init__(self, initial_lambda_base=0.15, initial_lambda_ord=0.04,
                 initial_T_gen=2.5, initial_T_ctr=0.8, epsilon=1e-6):
        """
        Args:
            initial_lambda_base: starting value reported as ``lambda_base``.
            initial_lambda_ord: starting value reported as ``lambda_ord``.
            initial_T_gen: starting value reported as ``T_gen``.
            initial_T_ctr: starting value reported as ``T_ctr``.
            epsilon: positive offset added after the softplus.

        Raises:
            ValueError: if any initial value is not greater than ``epsilon``
                (softplus cannot produce a non-positive number).
        """
        super().__init__()

        def safe_inverse_softplus(x):
            # log(exp(x) - 1) rewritten as x + log(1 - exp(-x)): the naive form
            # overflows to inf for large x and loses precision for tiny x.
            x = float(x)
            if x <= 0.0:
                raise ValueError(
                    f"inverse softplus needs a positive input, got {x}; "
                    "initial meta-parameter values must exceed epsilon"
                )
            return x + float(np.log(-np.expm1(-x)))

        def make_raw_parameter(initial_value):
            # Raw value whose softplus, plus epsilon, equals `initial_value`.
            return nn.Parameter(
                torch.tensor(safe_inverse_softplus(initial_value - epsilon),
                             dtype=torch.float32),
                requires_grad=True
            )

        self.rho_base = make_raw_parameter(initial_lambda_base)
        self.rho_ord = make_raw_parameter(initial_lambda_ord)
        self.tau_gen = make_raw_parameter(initial_T_gen)
        self.tau_ctr = make_raw_parameter(initial_T_ctr)

        self.epsilon = epsilon
        # Bookkeeping containers keyed by loss name; they are only initialised
        # here and not used by any method of this class.
        self.loss_history = {'gen': [], 'ctr': [], 'ord': []}
        self.initial_losses = {'gen': None, 'ctr': None, 'ord': None}

    def get_meta_params(self):
        """Return the effective (positive) meta-parameters as 0-dim tensors.

        Returns:
            dict with keys ``'lambda_base'``, ``'lambda_ord'``, ``'T_gen'`` and
            ``'T_ctr'``; each value is ``softplus(raw) + epsilon`` and stays
            attached to the autograd graph.
        """
        lambda_base = F.softplus(self.rho_base) + self.epsilon
        lambda_ord = F.softplus(self.rho_ord) + self.epsilon
        T_gen = F.softplus(self.tau_gen) + self.epsilon
        T_ctr = F.softplus(self.tau_ctr) + self.epsilon

        return {
            'lambda_base': lambda_base,
            'lambda_ord': lambda_ord,
            'T_gen': T_gen,
            'T_ctr': T_ctr
        }

@dataclass
class ContrastiveSample:
    """One contrastive triplet: an anchor with a positive and a negative example.

    Each of the three members carries its raw text and an aspects dict.
    NOTE(review): the aspects dicts presumably map aspect -> actions as in the
    dataset records — confirm against the generator that builds these samples.
    """
    anchor_text: str
    anchor_aspects: Dict
    positive_text: str
    positive_aspects: Dict
    negative_text: str
    negative_aspects: Dict
    # Action label of the anchor; defaults to None when not supplied.
    anchor_action_type: str = None

class ContrastiveLoss(nn.Module):
    """Temperature-scaled contrastive loss over (anchor, positive, negative) rows.

    For every row the loss is ``-log(e_pos / (e_pos + e_neg + 1e-8))`` where
    ``e_x = exp(cosine_similarity(anchor, x) / temperature)``.
    """

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor_emb, positive_emb, negative_emb, reduction='mean'):
        """Compute the loss for 2-D embedding batches (rows are samples).

        Args:
            reduction: ``'mean'`` for a scalar, ``'none'`` for one value per row.

        Raises:
            ValueError: for any other ``reduction``.
        """
        # L2-normalise all three inputs along the feature dimension.
        anchor_unit, positive_unit, negative_unit = (
            F.normalize(embedding, p=2, dim=1)
            for embedding in (anchor_emb, positive_emb, negative_emb)
        )

        similarity_to_positive = F.cosine_similarity(anchor_unit, positive_unit, dim=1)
        similarity_to_negative = F.cosine_similarity(anchor_unit, negative_unit, dim=1)

        numerator = torch.exp(similarity_to_positive / self.temperature)
        denominator = numerator + torch.exp(similarity_to_negative / self.temperature) + 1e-8

        per_row_loss = -torch.log(numerator / denominator)

        if reduction == 'none':
            return per_row_loss
        if reduction == 'mean':
            return per_row_loss.mean()
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

class OrdinalContrastiveLoss(nn.Module):
    """Margin-based triplet loss on cosine distance with per-action margins.

    The per-row loss is ``relu(d(anchor, positive) - d(anchor, negative) + margin)``
    with ``d = 1 - cosine_similarity``. Anchors labelled ``'planning'`` use
    ``planning_margin``; every other action uses ``base_margin``.
    """

    def __init__(self, base_margin: float = 0.10, planning_margin: float = 0.15):
        super().__init__()
        self.base_margin = base_margin
        self.planning_margin = planning_margin

        # Margin lookup by anchor action label.
        self.action_margins = dict(
            implemented=base_margin,
            indeterminate=base_margin,
            planning=planning_margin,
        )

    def forward(self, anchor_emb, positive_emb, negative_emb, anchor_actions=None, reduction='mean'):
        """Compute the loss for 2-D embedding batches (rows are samples).

        Args:
            anchor_actions: optional iterable of action labels, one per row;
                unknown labels fall back to ``base_margin``. When ``None``,
                every row uses ``base_margin``.
            reduction: ``'mean'`` for a scalar, ``'none'`` for one value per row.

        Raises:
            ValueError: for any other ``reduction``.
        """
        anchor_unit = F.normalize(anchor_emb, p=2, dim=1)
        positive_unit = F.normalize(positive_emb, p=2, dim=1)
        negative_unit = F.normalize(negative_emb, p=2, dim=1)

        positive_distance = 1 - F.cosine_similarity(anchor_unit, positive_unit, dim=1)
        negative_distance = 1 - F.cosine_similarity(anchor_unit, negative_unit, dim=1)

        if anchor_actions is None:
            margins = torch.full((anchor_unit.size(0),), self.base_margin,
                                 device=anchor_unit.device, dtype=anchor_unit.dtype)
        else:
            margin_values = [
                self.action_margins.get(action, self.base_margin)
                for action in anchor_actions
            ]
            margins = torch.tensor(margin_values, device=anchor_unit.device, dtype=anchor_unit.dtype)

        per_row_loss = F.relu(positive_distance - negative_distance + margins)

        if reduction == 'none':
            return per_row_loss
        if reduction == 'mean':
            return per_row_loss.mean()
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction}")

class GradNormContrastiveLoRAModel(nn.Module):
    """GradNorm model with meta-parameter optimization

    Wraps a causal language model and adds a contrastive projection head, two
    contrastive loss functions and a MetaParametersManager. ``forward`` returns
    the wrapped model's output object with its ``loss`` replaced by a
    per-sample weighted mix of generation, contrastive and ordinal losses;
    ``generate`` is forwarded to the wrapped model unchanged.
    """

    def __init__(self, base_model, hidden_size=4096, contrastive_dim=256):
        """
        Args:
            base_model: the wrapped language model; it must be callable with
                ``input_ids`` / ``attention_mask`` / ``labels`` and expose an
                inner ``base_model`` attribute plus ``generate``.
            hidden_size: input width of the contrastive head.
            contrastive_dim: output width of the contrastive head.
        """
        super().__init__()
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.contrastive_dim = contrastive_dim

        # Contrastive projection head
        # Two linear layers (hidden_size -> hidden_size // 2 -> contrastive_dim)
        # with ReLU and Dropout(0.1) in between.
        self.contrastive_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, contrastive_dim)
        )

        # Loss functions with asymmetric margins
        self.contrastive_loss_fn = ContrastiveLoss(temperature=0.1)
        self.ordinal_contrastive_loss_fn = OrdinalContrastiveLoss(
            base_margin=0.10,
            planning_margin=0.15
        )

        # Meta-parameters manager
        self.meta_params = MetaParametersManager()

        # Initialised here but not read or updated by any method of this class.
        self.step_count = 0
        self.meta_optimizer = None

    def get_text_embedding(self, input_ids, attention_mask):
        """Return one embedding per sequence: the attention-masked mean of the
        last hidden state.

        The whole computation runs under ``torch.no_grad()``, so no gradient
        flows back into the wrapped model through these embeddings.
        NOTE(review): a row whose attention mask is all zeros would divide by
        zero here — confirm inputs always contain at least one real token.
        """
        with torch.no_grad():
            # Calls the inner `base_model` of the wrapped model, not the
            # wrapped model itself.
            outputs = self.base_model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            last_hidden = outputs.hidden_states[-1]
            # Broadcast the mask over the hidden dimension, then average only
            # over positions the mask keeps.
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_hidden = torch.sum(last_hidden * mask_expanded, dim=1)
            sum_mask = torch.sum(mask_expanded, dim=1)
            mean_hidden = sum_hidden / sum_mask

            return mean_hidden

    def forward(self, input_ids, attention_mask, labels=None,
                contrastive_anchors=None, contrastive_positives=None, contrastive_negatives=None,
                anchor_actions=None, train_meta_params=False):
        """Run the wrapped model and replace its loss with the combined loss.

        Args:
            input_ids, attention_mask, labels: passed to the wrapped model;
                without ``labels`` the generation loss is all zeros.
            contrastive_anchors, contrastive_positives, contrastive_negatives:
                optional mappings with ``'input_ids'`` and ``'attention_mask'``;
                the contrastive and ordinal losses are computed only when all
                three are given, otherwise they are zeros.
            anchor_actions: optional action labels forwarded to the ordinal loss.
            train_meta_params: accepted for interface compatibility; this
                implementation does not read it.

        Returns:
            The wrapped model's output object, with ``loss`` overwritten by the
            mean of the per-sample combined loss.
        """

        # Get current meta-parameters (detached for inference)
        with torch.no_grad():
            meta_params = self.meta_params.get_meta_params()
            lambda_base = meta_params['lambda_base'].detach()
            lambda_ord = meta_params['lambda_ord'].detach()
            T_gen = meta_params['T_gen'].detach()
            T_ctr = meta_params['T_ctr'].detach()

        # Generation loss
        generation_outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        if labels is not None:
            # Recompute the next-token cross-entropy per sample: shift logits
            # and labels by one position, then average the token losses over
            # positions whose label is not -100.
            logits = generation_outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(reduction='none')
            token_losses = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            token_losses = token_losses.view(shift_labels.size())
            mask = (shift_labels != -100).float()
            generation_loss_per_sample = (token_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
        else:
            generation_loss_per_sample = torch.zeros(input_ids.size(0), device=input_ids.device)

        # Defaults when no contrastive data is supplied.
        contrastive_loss_per_sample = torch.zeros_like(generation_loss_per_sample)
        ordinal_loss_per_sample = torch.zeros_like(generation_loss_per_sample)

        # Contrastive losses if data available
        if contrastive_anchors is not None and contrastive_positives is not None and contrastive_negatives is not None:
            anchor_emb = self.get_text_embedding(
                contrastive_anchors['input_ids'],
                contrastive_anchors['attention_mask']
            )
            positive_emb = self.get_text_embedding(
                contrastive_positives['input_ids'],
                contrastive_positives['attention_mask']
            )
            negative_emb = self.get_text_embedding(
                contrastive_negatives['input_ids'],
                contrastive_negatives['attention_mask']
            )

            # Project the pooled embeddings into the contrastive space.
            anchor_proj = self.contrastive_head(anchor_emb)
            positive_proj = self.contrastive_head(positive_emb)
            negative_proj = self.contrastive_head(negative_emb)

            contrastive_loss_per_sample = self.contrastive_loss_fn(
                anchor_proj, positive_proj, negative_proj, reduction='none'
            )

            ordinal_loss_per_sample = self.ordinal_contrastive_loss_fn(
                anchor_proj, positive_proj, negative_proj,
                anchor_actions=anchor_actions,
                reduction='none'
            )

        # Gating weights per sample
        # Computed without gradient from detached losses: three logits per
        # sample go through a softmax, are floored at w_min and renormalised
        # to sum to 1.
        with torch.no_grad():
            gen_loss_detached = generation_loss_per_sample.detach()
            ctr_loss_detached = contrastive_loss_per_sample.detach()
            ord_loss_detached = ordinal_loss_per_sample.detach()

            # Note: the ordinal logit is scaled by T_ctr as well; there is no
            # separate temperature for it.
            a_gen = gen_loss_detached / T_gen
            a_ctr = -gen_loss_detached / T_gen + ctr_loss_detached / T_ctr
            a_ord = -gen_loss_detached / T_gen + ord_loss_detached / T_ctr

            logits_stack = torch.stack([a_gen, a_ctr, a_ord], dim=-1)
            w = torch.softmax(logits_stack, dim=-1)

            w_min = 0.05
            w = w.clamp_min(w_min)
            w = w / w.sum(dim=-1, keepdim=True)

            w_gen = w[:, 0]
            w_ctr = w[:, 1]
            w_ord = w[:, 2]

        # Combined loss
        # The contrastive terms are additionally scaled by the lambda
        # meta-parameters before the gating weights are applied.
        scaled_ctr_loss = lambda_base * contrastive_loss_per_sample
        scaled_ord_loss = lambda_ord * ordinal_loss_per_sample

        total_loss_per_sample = (
            w_gen * generation_loss_per_sample +
            w_ctr * scaled_ctr_loss +
            w_ord * scaled_ord_loss
        )

        main_loss = total_loss_per_sample.mean()
        # Overwrite the loss computed by the wrapped model.
        generation_outputs.loss = main_loss
        return generation_outputs

    def generate(self, *args, **kwargs):
        """Forward text generation to the wrapped model unchanged."""
        return self.base_model.generate(*args, **kwargs)

print("GradNorm + Ordinal Contrastive classes imported")

# Step 2: GradNorm Model Loading
print("\nStep 2: Loading trained GradNorm model...")

# Hub identifier of the base model the adapters are loaded onto.
MODEL_BASE = "meta-llama/Meta-Llama-3-8B-Instruct"

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Configure paths
# The local copy is preferred; the Drive copy is the fallback.
DATA_DIR = "/content/A3CG_DATASET/folds/fold_1"
DRIVE_DATA_DIR = "/content/drive/MyDrive/A3CG_Dataset/folds/fold_1"

if os.path.exists(DATA_DIR):
    print(f"Using data directory: {DATA_DIR}")
elif os.path.exists(DRIVE_DATA_DIR):
    DATA_DIR = DRIVE_DATA_DIR
    print(f"Using Drive data directory: {DATA_DIR}")
else:
    # NOTE(review): only a message is printed here; DATA_DIR keeps pointing at
    # the missing local path and execution continues.
    print("Data directory not found!")

def load_gradnorm_model_robust(model_path: str):
    """Robust loading of GradNorm + Ordinal model.

    Loads the tokenizer, the 4-bit quantized base model (``MODEL_BASE``), the
    LoRA adapters stored in ``model_path`` and, when present, the contrastive
    head, the meta-parameters and the ordinal configuration. Missing or
    unreadable optional components are reported and skipped rather than
    treated as fatal.

    Args:
        model_path: directory produced by the training cell.

    Returns:
        ``(model, tokenizer, loading_status, ordinal_config)`` where
        ``loading_status`` is ``"gradnorm_complete"`` (head and meta-parameters
        loaded), ``"contrastive_only"`` (head only) or ``"lora_only"``, and
        ``ordinal_config`` is a dict or ``None``. The model is returned in
        evaluation mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """

    print(f"Analyzing GradNorm model: {model_path}")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    files_in_model = os.listdir(model_path)
    print(f"Files found: {files_in_model}")

    # Load tokenizer
    print("Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        print("Tokenizer loaded from model")
    except Exception:
        # Narrowed from a bare `except:` so that KeyboardInterrupt / SystemExit
        # are no longer swallowed by the fallback.
        print("Fallback: tokenizer from base model")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    print("Loading base model...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_BASE,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )

    # Load LoRA adapters
    print("Loading LoRA adapters...")
    lora_model = PeftModel.from_pretrained(
        base_model,
        model_path,
        torch_dtype=torch.bfloat16
    )

    # Create GradNorm model
    print("Creating GradNorm model...")
    hidden_size = base_model.config.hidden_size
    print(f"Detected hidden size: {hidden_size}")

    gradnorm_model = GradNormContrastiveLoRAModel(
        base_model=lora_model,
        hidden_size=hidden_size,
        contrastive_dim=256
    )

    # Load contrastive head
    contrastive_head_path = os.path.join(model_path, "contrastive_head.pt")
    contrastive_loaded = False

    if os.path.exists(contrastive_head_path):
        print("Loading contrastive head...")
        try:
            state_dict = torch.load(contrastive_head_path, map_location='cpu')
            print(f"Contrastive weights loaded: {list(state_dict.keys())}")

            gradnorm_model.contrastive_head.load_state_dict(state_dict)
            # Move the head next to the base model's weights.
            device = next(gradnorm_model.base_model.parameters()).device
            gradnorm_model.contrastive_head.to(device)
            print("Contrastive head loaded successfully")
            contrastive_loaded = True
        except Exception as e:
            print(f"Error loading contrastive head: {e}")
    else:
        print("No contrastive head found")

    # Load GradNorm meta-parameters
    meta_params_path = os.path.join(model_path, "meta_parameters.pt")
    meta_loaded = False

    if os.path.exists(meta_params_path):
        print("Loading GradNorm meta-parameters...")
        try:
            meta_state_dict = torch.load(meta_params_path, map_location='cpu')
            print(f"Meta-parameters loaded: {list(meta_state_dict.keys())}")

            gradnorm_model.meta_params.load_state_dict(meta_state_dict)
            device = next(gradnorm_model.base_model.parameters()).device
            gradnorm_model.meta_params.to(device)

            # Display final values
            final_params = gradnorm_model.meta_params.get_meta_params()
            print("Meta-parameter values:")
            for k, v in final_params.items():
                print(f"  {k}: {v.item():.6f}")

            meta_loaded = True
        except Exception as e:
            print(f"Error loading meta-parameters: {e}")
    else:
        print("No meta-parameters found")

    # Load ordinal configuration
    # Preferred source is a dedicated ordinal_config.json. The training cell
    # stores the same information inside config.json under the
    # 'ordinal_config' key, so that file is used as a fallback.
    ordinal_config_path = os.path.join(model_path, "ordinal_config.json")
    training_config_path = os.path.join(model_path, "config.json")
    ordinal_config = None

    try:
        if os.path.exists(ordinal_config_path):
            print("Loading ordinal configuration...")
            with open(ordinal_config_path, 'r') as f:
                ordinal_config = json.load(f)
        elif os.path.exists(training_config_path):
            with open(training_config_path, 'r') as f:
                training_config = json.load(f)
            nested_config = (
                training_config.get('ordinal_config')
                if isinstance(training_config, dict) else None
            )
            if isinstance(nested_config, dict):
                print("Loading ordinal configuration...")
                ordinal_config = nested_config

        if ordinal_config is not None:
            print("Ordinal configuration loaded")
            print(f"  Margins: Planning={ordinal_config.get('planning_margin', 0.15)}, Base={ordinal_config.get('base_margin', 0.10)}")
    except Exception as e:
        print(f"Error loading ordinal config: {e}")

    # Determine loading status
    if contrastive_loaded and meta_loaded:
        loading_status = "gradnorm_complete"
    elif contrastive_loaded:
        loading_status = "contrastive_only"
    else:
        loading_status = "lora_only"

    # This loader serves evaluation only: switch the wrapper (including the
    # contrastive head's Dropout) to inference behaviour.
    gradnorm_model.eval()

    return gradnorm_model, tokenizer, loading_status, ordinal_config

# Find and load GradNorm model
print("Searching for trained GradNorm model...")

# Candidate model directories, in priority order:
#   1. the local output folder of the training cell,
#   2. the timestamped backups on Drive, newest first,
#   3. the Drive folder itself (only usable if an adapter was saved directly
#      into it).
# The Drive folder also holds trainer checkpoints and backups, so it must not
# be tried before its own sub-folders.
possible_paths = ["./Llama3-8B-gradnorm-ordinal-lora-final"]

# Add dynamic GradNorm folders
for base_dir in ["/content/drive/MyDrive/A3CG_GradNorm_Ordinal_Models"]:
    if os.path.isdir(base_dir):
        # Backup names end in a YYYYmmdd_HHMM timestamp, so a reverse
        # lexicographic sort puts the most recent backup first.
        backup_folders = sorted(
            (folder for folder in os.listdir(base_dir)
             if "gradnorm" in folder.lower() or "ordinal" in folder.lower()),
            reverse=True,
        )
        possible_paths.extend(os.path.join(base_dir, folder) for folder in backup_folders)
    possible_paths.append(base_dir)

# A usable candidate must contain the adapter configuration that the LoRA
# loader needs; mere existence of the directory is not enough.
model_path = None
for path in possible_paths:
    if os.path.isfile(os.path.join(path, "adapter_config.json")):
        model_path = path
        print(f"GradNorm model found: {path}")
        break

if not model_path:
    print("No GradNorm model found!")
    print("Paths checked:")
    for path in possible_paths:
        print(f"  - {path}")
    raise FileNotFoundError("No GradNorm model found!")

# Load model
model, tokenizer, loading_status, ordinal_config = load_gradnorm_model_robust(model_path)

print(f"Model loaded with status: {loading_status}")

# Step 3: GradNorm Architecture Verification
print("\nStep 3: GradNorm architecture verification...")

def verify_gradnorm_architecture(model, loading_status, ordinal_config):
    """Check that the loaded model carries the components its status implies.

    Args:
        model: Loaded model object. For a "gradnorm_complete" status it is
            expected to expose `meta_params`, `contrastive_head` and
            `ordinal_contrastive_loss_fn`.
        loading_status: Status string returned by the loader
            ("gradnorm_complete", "contrastive_only" or anything else, which
            is treated as LoRA-only).
        ordinal_config: Accepted for interface compatibility; not used here.

    Returns:
        True when nothing expected is missing. For "gradnorm_complete" this
        means all three components are present; the other statuses have no
        checks and always return True. False when a "gradnorm_complete" model
        lacks at least one component.
    """

    print("Verifying GradNorm architecture...")

    if loading_status == "contrastive_only":
        print("Contrastive model without meta-parameters")
        return True

    if loading_status != "gradnorm_complete":
        print("LoRA model only")
        return True

    print("Complete GradNorm model with meta-parameters")
    missing = []

    # Verify meta-parameters
    if hasattr(model, 'meta_params'):
        print("Meta-parameters present")
        final_params = model.meta_params.get_meta_params()
        print("Current meta-parameter values:")
        for k, v in final_params.items():
            # Values may be tensors or plain numbers; only unwrap when possible
            value = v.item() if hasattr(v, 'item') else v
            print(f"  {k}: {value:.6f}")
    else:
        missing.append('meta_params')

    # Verify contrastive head
    if hasattr(model, 'contrastive_head'):
        print("Contrastive head present")
    else:
        missing.append('contrastive_head')

    # Verify ordinal margins
    if hasattr(model, 'ordinal_contrastive_loss_fn'):
        print("Ordinal loss with asymmetric margins present")
        print(f"  Planning margin: {model.ordinal_contrastive_loss_fn.planning_margin}")
        print(f"  Base margin: {model.ordinal_contrastive_loss_fn.base_margin}")
    else:
        missing.append('ordinal_contrastive_loss_fn')

    if missing:
        # Report the gap so the caller's failure branch is actually reachable
        print(f"Missing GradNorm components: {', '.join(missing)}")
        return False

    return True

# Run the architecture check once and report its verdict
architecture_ok = verify_gradnorm_architecture(model, loading_status, ordinal_config)
print("GradNorm architecture verified" if architecture_ok else "Problem with GradNorm architecture")

# Step 4: Compatible Prompt Generator
print("\nStep 4: Prompt generator configuration...")

class A3CGGradNormEvaluationPromptGenerator:
    """Builds chat-template prompts for aspect-action extraction at evaluation time.

    For the "gradnorm_complete" and "contrastive_only" model types the prompt
    has a system message plus a user message with three randomly drawn
    few-shot examples. Any other model type gets a single short user message.
    """

    def __init__(self, tokenizer, model_type: str = "gradnorm_complete"):
        # The tokenizer must provide `apply_chat_template`
        self.tokenizer = tokenizer
        self.model_type = model_type

        self.system_prompt = """You are an expert in ESG analysis. Extract aspect-action pairs from sustainability statements.

CRITICAL INSTRUCTIONS:
- Extract EXACT terms from the input text
- Do not paraphrase or interpret creatively
- Use literal wording from the sentence
- Focus on specific terms rather than general concepts

DEFINITIONS:
- Aspect: A sustainability-related entity, goal, sub-area, or activity (use exact wording)
- Action: "implemented", "planning", or "indeterminate"

OUTPUT FORMAT: ("aspect1", "action1"), ("aspect2", "action2"), ...
If none: ("no aspect", "no action")"""

        # (input sentence, expected output) demonstrations
        demonstrations = (
            (
                "We have implemented solar panels to reduce energy consumption in our facilities.",
                '("solar panels", "implemented"), ("energy consumption", "implemented")',
            ),
            (
                "The company plans to improve workplace diversity initiatives next year.",
                '("workplace diversity initiatives", "planning")',
            ),
            (
                "We are committed to enhancing our environmental management systems.",
                '("environmental management systems", "planning")',
            ),
            (
                "Our recycling program has achieved a 50% waste reduction.",
                '("recycling program", "implemented"), ("waste reduction", "implemented")',
            ),
            (
                "The board may consider sustainability investments where feasible.",
                '("sustainability investments", "indeterminate")',
            ),
        )
        self.few_shot_examples = [
            {"text": shot_text, "output": shot_output}
            for shot_text, shot_output in demonstrations
        ]

    def get_few_shot_examples(self, n_examples: int = 3) -> str:
        """Return up to `n_examples` randomly sampled demonstrations as numbered text."""
        sample_size = min(n_examples, len(self.few_shot_examples))
        chosen = random.sample(self.few_shot_examples, sample_size)
        return "".join(
            f"\nExample {idx}:\nText: {shot['text']}\nOutput: {shot['output']}\n"
            for idx, shot in enumerate(chosen, 1)
        )

    def create_prompt(self, sentence: str) -> str:
        """Render the chat prompt string for one input sentence."""
        if self.model_type in ("gradnorm_complete", "contrastive_only"):
            # Few-shot variant: system instructions + demonstrations + query
            shots = self.get_few_shot_examples(n_examples=3).strip()
            user_content = (
                f"{shots}\n\n"
                f"Now extract from this text:\nText: {sentence}\n\n"
                f"Extract the aspect-action pairs:"
            )
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_content},
            ]
        else:
            # Minimal variant: a single bare instruction
            messages = [
                {"role": "user", "content": f"Extract aspect-action pairs from: {sentence}"}
            ]

        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

# Build the prompt generator for the status reported by the loader
prompt_generator = A3CGGradNormEvaluationPromptGenerator(tokenizer, model_type=loading_status)

print("Prompt generator configured for {}".format(loading_status))

# Step 5: Evaluation Functions
print("\nStep 5: Evaluation functions configuration...")

def generate_prediction_gradnorm(model, tokenizer, sentence: str, model_type: str, max_length: int = 512) -> str:
    """Generate the model's raw text answer for one sentence.

    The prompt comes from the module-level `prompt_generator`. It is tokenized
    with truncation at 1024 tokens and moved to the model's device. Generation
    first uses low-temperature sampling; if that raises, it retries greedily
    with the input ids only.

    Note: `model_type` and `max_length` are part of the signature but are not
    read by this function.

    Returns:
        The decoded newly generated tokens (prompt excluded, special tokens
        skipped), stripped of surrounding whitespace.
    """

    encoded = tokenizer(
        prompt_generator.create_prompt(sentence),
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        padding=True
    )

    # Prefer the device the model's parameters live on; otherwise fall back
    # to CUDA when available, else CPU
    if hasattr(model, 'parameters'):
        target_device = next(model.parameters()).device
    elif torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    eos_id = tokenizer.eos_token_id
    sampling_kwargs = dict(
        max_new_tokens=150,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=eos_id,
        eos_token_id=eos_id,
    )

    with torch.no_grad():
        try:
            outputs = model.generate(**encoded, **sampling_kwargs)
        except Exception as e:
            # Retry without sampling, passing only the input ids
            print(f"WARNING: Generation error: {e}")
            outputs = model.generate(
                encoded['input_ids'],
                max_new_tokens=150,
                do_sample=False,
                pad_token_id=eos_id,
                eos_token_id=eos_id
            )

    # Drop the prompt tokens so only the continuation is decoded
    prompt_length = encoded['input_ids'].shape[1]
    continuation = outputs[0][prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()

def parse_prediction_enhanced(prediction: str, model_type: str) -> List[Tuple[str, str]]:
    """Parse generated text into lowercase (aspect, action) pairs.

    Quoted tuples of the form ("aspect", "action") are extracted first. If
    none of them yields a usable pair, a looser pattern that also accepts
    unquoted tuples is tried. Placeholder pairs ("no aspect" / "no action")
    are dropped regardless of letter case.

    Args:
        prediction: Raw model output.
        model_type: Accepted for interface compatibility; not used here.

    Returns:
        List of (aspect, action) tuples, lowercased and stripped, in order of
        appearance. Empty when nothing could be parsed; parsing errors are
        reported with a warning and do not raise.
    """
    pairs = []

    def _collect(candidates):
        # Normalise first, then filter, so "No Aspect" is treated like "no aspect"
        for aspect, action in candidates:
            aspect = aspect.strip().strip('"\'').strip().lower()
            action = action.strip().strip('"\'').strip().lower()
            if aspect and action and aspect != "no aspect" and action != "no action":
                pairs.append((aspect, action))

    try:
        # Tuple format: ("aspect", "action")
        tuple_pattern = r'\(\s*["\']([^"\']+)["\']\s*,\s*["\']([^"\']+)["\']\s*\)'
        _collect(re.findall(tuple_pattern, prediction))

        # If no tuples found, try simple pattern
        if not pairs:
            simple_pattern = r'\(\s*([^,\)]+)\s*,\s*([^,\)]+)\s*\)'
            _collect(re.findall(simple_pattern, prediction))

    except Exception as e:
        print(f"WARNING: Parsing error: {e}")

    return pairs

# Metric functions (identical to original script)
def calculate_metrics_exact_match(predictions: List[List[Tuple[str, str]]],
                                 ground_truth: List[List[Tuple[str, str]]]) -> Dict:
    """Compute exact-match metrics over paired prediction / gold lists.

    Per sample, predicted and gold pairs are compared as sets. True positives,
    false positives and false negatives are summed over all samples and
    turned into micro precision / recall / F1. A sample counts as an exact
    match only when both sets are equal and non-empty, and as a partial match
    when they share at least one pair. Every ratio falls back to 0 when its
    denominator is 0.

    Returns:
        Dict with accuracies, precision / recall / F1, the raw TP / FP / FN
        counts and the total (non-deduplicated) pair counts.
    """

    def _safe_div(numerator, denominator):
        # Mirrors the "x / y if y > 0 else 0" convention used for every ratio
        return numerator / denominator if denominator > 0 else 0

    exact_hits = partial_hits = 0
    pred_total = gold_total = 0
    tp = fp = fn = 0

    for predicted, expected in zip(predictions, ground_truth):
        pred_total += len(predicted)
        gold_total += len(expected)

        predicted_set, expected_set = set(predicted), set(expected)
        overlap = predicted_set & expected_set

        tp += len(overlap)
        fp += len(predicted_set - expected_set)
        fn += len(expected_set - predicted_set)

        # Two empty sets are deliberately not an exact match
        if len(predicted_set) > 0 and predicted_set == expected_set:
            exact_hits += 1

        if len(overlap) > 0:
            partial_hits += 1

    sample_count = len(predictions)
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * (precision * recall), precision + recall)

    return {
        'exact_match_accuracy': _safe_div(exact_hits, sample_count),
        'partial_match_accuracy': _safe_div(partial_hits, sample_count),
        'exact_precision': precision,
        'exact_recall': recall,
        'exact_f1_score': f1,
        'exact_true_positives': tp,
        'exact_false_positives': fp,
        'exact_false_negatives': fn,
        'total_predictions': pred_total,
        'total_ground_truth': gold_total,
    }

def load_test_data_flexible(file_path: str) -> Tuple[List[str], List[List[Tuple[str, str]]]]:
    """Load test sentences and their gold (aspect, action) pairs from JSON.

    The file is expected to hold a list of dicts, each with a 'text' string
    and an 'aspects' mapping of aspect -> action (a string) or aspect -> list
    of actions. Gold pairs are stripped and lowercased, the same normalisation
    `parse_prediction_enhanced` applies to predictions, so that exact-match
    comparison is not defeated by letter case.

    Records without both 'text' and 'aspects' are skipped and counted in a
    warning. Non-string actions are ignored. A record whose 'aspects' is not a
    dict is kept with an empty pair list.

    Args:
        file_path: Path to the JSON file.

    Returns:
        (sentences, ground_truth), two lists of equal length.
    """
    print(f"Loading: {os.path.basename(file_path)}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    sentences = []
    ground_truth = []
    skipped = 0

    # Any top-level value other than a list yields no samples
    records = data if isinstance(data, list) else []

    for item in records:
        if not (isinstance(item, dict) and 'text' in item and 'aspects' in item):
            skipped += 1
            continue

        pairs = []
        aspects_dict = item['aspects']
        if isinstance(aspects_dict, dict):
            for aspect, actions in aspects_dict.items():
                # Treat a single action string like a one-element list
                if isinstance(actions, str):
                    actions = [actions]
                elif not isinstance(actions, list):
                    continue
                for action in actions:
                    if isinstance(action, str):
                        pairs.append((aspect.strip().lower(), action.strip().lower()))

        sentences.append(item['text'])
        ground_truth.append(pairs)

    if skipped:
        print(f"WARNING: Skipped {skipped} records without 'text'/'aspects'")

    print(f"Loaded {len(sentences)} samples")
    return sentences, ground_truth

print("Evaluation functions configured")

# Step 6: Test Data Loading
print("\nStep 6: Loading test data...")

# Dataset name -> JSON file under DATA_DIR
test_files = dict(
    seen_test=f"{DATA_DIR}/seen_test.json",
    unseen_test=f"{DATA_DIR}/unseen_test.json",
)

# Only datasets that load successfully and are non-empty are kept
test_data = {}
for name, file_path in test_files.items():
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        continue

    try:
        sentences, ground_truth = load_test_data_flexible(file_path)
        if sentences:
            test_data[name] = (sentences, ground_truth)
            print(f"{name}: {len(sentences)} samples")
    except Exception as e:
        print(f"ERROR loading {name}: {e}")

# Fall back to a single hand-written sample so the rest of the cell can run
if len(test_data) == 0:
    print("WARNING: No test files found - creating demo data")
    test_data['demo'] = (
        ["This company has implemented strong environmental policies."],
        [[("environmental policies", "implemented")]]
    )

# Step 7: GradNorm Model Evaluation
print("\nStep 7: GradNorm model evaluation...")
print("=" * 60)
print(f"Model evaluated: {model_path}")
print(f"Loading status: {loading_status}")
print(f"Architecture verified: {'Yes' if architecture_ok else 'No'}")

# Display meta-parameters if available
if loading_status == "gradnorm_complete" and hasattr(model, 'meta_params'):
    print("\nGradNorm Meta-parameters:")
    final_params = model.meta_params.get_meta_params()
    for k, v in final_params.items():
        # Values may be tensors or plain numbers (the save step makes the
        # same distinction), so only unwrap tensors
        print(f"  {k}: {(v.item() if isinstance(v, torch.Tensor) else v):.6f}")

    # The verification step treats the ordinal loss as optional, so guard the
    # access here too instead of crashing before the evaluation starts
    if hasattr(model, 'ordinal_contrastive_loss_fn'):
        print(f"Ordinal margins: Planning={model.ordinal_contrastive_loss_fn.planning_margin}, Base={model.ordinal_contrastive_loss_fn.base_margin}")
    else:
        print("Ordinal margins: unavailable (model has no ordinal_contrastive_loss_fn)")

print("=" * 60)

# Per-dataset evaluation results, keyed by dataset name
results = {}

for dataset_name, (sentences, ground_truth) in test_data.items():
    print(
        f"\nEvaluation on {dataset_name}...\n"
        f"Base Model: {MODEL_BASE}\n"
        f"Model type: {loading_status}\n"
        f"Number of samples: {len(sentences)}"
    )

    predictions = []
    start_time = time.time()

    # Evaluate at most 150 samples per dataset (adjust as needed)
    n_test = min(len(sentences), 150)
    test_sentences, test_ground_truth = sentences[:n_test], ground_truth[:n_test]

    print(f"Testing on {n_test} samples...")

    for i, sentence in enumerate(test_sentences):
        # Progress line every 25 samples
        if not i % 25:
            print(f"   Progress: {i}/{n_test} ({i/n_test*100:.1f}%)")

        try:
            prediction_text = generate_prediction_gradnorm(model, tokenizer, sentence, loading_status)
            pred_pairs = parse_prediction_enhanced(prediction_text, loading_status)
            predictions.append(pred_pairs)

            # Show the first three samples in detail for a quick sanity check
            if i < 3:
                print(f"   Sample {i+1}: {len(pred_pairs)} pairs predicted")
                print(f"   Raw output: {prediction_text[:100]}...")
                print(f"   Parsed pairs: {pred_pairs}")

        except Exception as e:
            # A failed sample counts as an empty prediction
            print(f"WARNING: Error sample {i}: {e}")
            predictions.append([])

    evaluation_time = time.time() - start_time

    # Calculate metrics
    print("Calculating metrics...")
    exact_metrics = calculate_metrics_exact_match(predictions, test_ground_truth)

    # Keep metrics, timing and a five-sample excerpt for the saved report
    results[dataset_name] = dict(
        exact_metrics=exact_metrics,
        evaluation_time=evaluation_time,
        samples_per_second=n_test / evaluation_time,
        predictions=predictions[:5],
        ground_truth=test_ground_truth[:5],
        model_type=loading_status,
        n_samples=n_test,
        base_model=MODEL_BASE,
    )

    print(f"Evaluation completed in {evaluation_time:.2f}s")
    print(f"Speed: {n_test/evaluation_time:.2f} samples/sec")

# Step 8: Results Display
print("\n" + "="*80, "GRADNORM + ORDINAL CONTRASTIVE EVALUATION RESULTS", "="*80, sep="\n")

# Timestamp for this run (reused when naming the saved results file)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# (printed label including alignment padding, key in the metrics dict)
metric_rows = (
    ("   Exact Match Accuracy:     ", 'exact_match_accuracy'),
    ("   Exact Precision:          ", 'exact_precision'),
    ("   Exact Recall:             ", 'exact_recall'),
    ("   Exact F1-Score:           ", 'exact_f1_score'),
)

for dataset_name, result in results.items():
    exact_metrics = result['exact_metrics']

    print(f"\nRESULTS - {dataset_name.upper()}")
    print(f"Base model: {result['base_model']}")
    print(f"Model type: {loading_status}")
    print("-" * 60)

    print("EXACT MATCH METRICS:")
    for metric_label, metric_key in metric_rows:
        metric_value = exact_metrics[metric_key]
        # Each metric is shown as a fraction and as a percentage
        print(f"{metric_label}{metric_value:.4f} ({metric_value*100:.2f}%)")

    print("\nPERFORMANCE:")
    print(f"   Evaluation speed:         {result['samples_per_second']:.2f} samples/sec")

# Step 9: Meta-Parameters Analysis
if loading_status == "gradnorm_complete" and hasattr(model, 'meta_params'):
    print(f"\n" + "="*80)
    print("GRADNORM META-PARAMETERS ANALYSIS")
    print("="*80)

    final_params = model.meta_params.get_meta_params()
    # Unwrap tensors to plain numbers once, the same way the save step does
    meta_values = {
        k: v.item() if isinstance(v, torch.Tensor) else v
        for k, v in final_params.items()
    }

    print("Optimized meta-parameters:")
    print(f"  lambda_base (contrastive): {meta_values['lambda_base']:.6f}")
    print(f"  lambda_ord (ordinal): {meta_values['lambda_ord']:.6f}")
    print(f"  T_gen (generation): {meta_values['T_gen']:.6f}")
    print(f"  T_ctr (contrastive): {meta_values['T_ctr']:.6f}")

    print("\nOrdinal configuration:")
    # The verification step treats the ordinal loss as optional; guard it here too
    if hasattr(model, 'ordinal_contrastive_loss_fn'):
        print(f"  Planning margin: {model.ordinal_contrastive_loss_fn.planning_margin}")
        print(f"  Base margin: {model.ordinal_contrastive_loss_fn.base_margin}")
    else:
        print("  Not available (model has no ordinal_contrastive_loss_fn)")

    # Balance analysis; a ratio is None when its denominator is 0
    lambda_ratio = (
        meta_values['lambda_ord'] / meta_values['lambda_base']
        if meta_values['lambda_base'] != 0 else None
    )
    temp_ratio = (
        meta_values['T_ctr'] / meta_values['T_gen']
        if meta_values['T_gen'] != 0 else None
    )

    print(f"\nBalance analysis:")
    if lambda_ratio is None:
        print("  Ratio lambda_ord/lambda_base: undefined (lambda_base is 0)")
    else:
        print(f"  Ratio lambda_ord/lambda_base: {lambda_ratio:.3f}")
    if temp_ratio is None:
        print("  Ratio T_ctr/T_gen: undefined (T_gen is 0)")
    else:
        print(f"  Ratio T_ctr/T_gen: {temp_ratio:.3f}")

    if lambda_ratio is None:
        print("  Note: balance cannot be assessed because lambda_base is 0")
    elif lambda_ratio < 0.5:
        print(f"  Note: lambda_ord relatively low - ordinal objective has less impact")
    elif lambda_ratio > 1.5:
        print(f"  Note: lambda_ord relatively high - ordinal objective dominates")
    else:
        print(f"  Note: Good balance between contrastive objectives")

# Step 10: Results Saving
print("\nStep 10: Saving results...")

# Results go to a fixed Drive folder, one timestamped JSON file per run
results_dir = "/content/drive/MyDrive/A3CG_GradNorm_Evaluation_Results"
os.makedirs(results_dir, exist_ok=True)

results_file = f"{results_dir}/gradnorm_evaluation_results_{timestamp}.json"

save_data = dict(
    timestamp=timestamp,
    model_info=dict(
        model_path=model_path,
        base_model=MODEL_BASE,
        loading_status=loading_status,
        architecture_verified=architecture_ok,
    ),
    evaluation_results=results,
)

# Attach the meta-parameters when the complete model was loaded
if loading_status == "gradnorm_complete" and hasattr(model, 'meta_params'):
    final_params = model.meta_params.get_meta_params()
    save_data['meta_parameters'] = {}
    for k, v in final_params.items():
        # Tensors are not JSON-serialisable; store their scalar value
        save_data['meta_parameters'][k] = v.item() if isinstance(v, torch.Tensor) else v

# Write the report
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2, ensure_ascii=False)

print(f"Results saved: {results_file}")

# Conclusion: closing banner and run summary
print("\n" + "="*80, "GRADNORM EVALUATION COMPLETED!", "="*80, sep="\n")
print(f"End time: {time.strftime('%H:%M:%S')}")
print(f"Model evaluated: {os.path.basename(model_path)}")
print(f"Status: {loading_status}")

if results:
    # Unweighted mean of the per-dataset exact F1 scores
    f1_scores = [r['exact_metrics']['exact_f1_score'] for r in results.values()]
    avg_f1 = np.mean(f1_scores)
    print(f"Average F1-Score: {avg_f1:.4f} ({avg_f1*100:.2f}%)")

print(f"Results saved in: {results_dir}")

# Final line depends on how much of the model was loaded
closing_messages = {
    "gradnorm_complete": "SUCCESS: Complete GradNorm model evaluated with meta-parameters!",
    "contrastive_only": "SUCCESS: Contrastive model evaluated (without meta-parameters)",
}
print(closing_messages.get(loading_status, "WARNING: Only LoRA adapters were evaluated"))

print("="*80)