# TS-PEFT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math
from typing import Optional, List, Dict
from dataclasses import dataclass
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


@dataclass
class TSPEFTConfig:
    """Hyperparameters shared by every TS-PEFT adapter injected into the model."""

    # Rank of the low-rank LoRA factors.
    rank: int = 32
    # LoRA scaling numerator; the adapter output is multiplied by alpha / rank.
    alpha: float = 0.5
    # Dropout probability applied to the adapter input.
    dropout: float = 0.05
    # Scale factor used in the threshold gradient and the threshold step.
    s: float = 4e-5
    # Weight of the sparsity term in the threshold gradient.
    lambda_reg: float = 4.5e-5
    # Names of the projection modules to adapt; None selects the default list.
    target_modules: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in the default projection list when the caller supplied none."""
        if self.target_modules is not None:
            return
        # A fresh list per instance, so configs never share mutable state.
        self.target_modules = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]


class LoRALayer(nn.Module):
    """Low-rank adapter producing ``scaling * (dropout(x) @ A^T @ B^T)``.

    ``lora_A`` has shape (rank, in_features) and is Kaiming-uniform
    initialised; ``lora_B`` has shape (out_features, rank) and starts at zero,
    so a freshly built adapter outputs zeros. Both factors are bfloat16.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 4, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        down_projection = torch.zeros(rank, in_features, dtype=torch.bfloat16)
        up_projection = torch.zeros(out_features, rank, dtype=torch.bfloat16)
        self.lora_A = nn.Parameter(down_projection)
        self.lora_B = nn.Parameter(up_projection)

        # Dropout is only instantiated when it would actually do something.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``."""
        regularised = self.dropout(x)
        compressed = regularised @ self.lora_A.T
        expanded = compressed @ self.lora_B.T
        return expanded * self.scaling


class TSPEFTLayer(nn.Module):
    """Frozen linear layer plus a LoRA branch that is gated per position.

    For every position the ratio ``||lora_out|| / (||base_out|| + eps)`` is
    compared against a scalar threshold ``tau``; the LoRA contribution is added
    only where the ratio is at least ``tau``. ``tau`` is not trained by the
    optimizer: it is moved by :meth:`update_threshold` using Adam-style first
    and second moment estimates (``m``, ``v``, ``step``).

    Args:
        base_layer: Linear layer to wrap; its parameters are frozen.
        rank, alpha, dropout: Forwarded to the ``LoRALayer`` adapter.
        s: Scale applied to the threshold gradient and to the threshold step.
        lambda_reg: Weight of the sparsity term in the threshold gradient.
        beta1, beta2, eps: Moment decay rates and numerical-stability constant.
    """

    def __init__(self, base_layer: nn.Linear, rank: int = 4, alpha: float = 1.0, dropout: float = 0.0, s: float = 4e-5, lambda_reg: float = 1e-5, beta1: float = 0.9, beta2: float = 0.98, eps: float = 1e-8):
        super().__init__()
        self.base_layer = base_layer
        for param in self.base_layer.parameters():
            param.requires_grad = False
        self.lora = LoRALayer(base_layer.in_features, base_layer.out_features, rank=rank, alpha=alpha, dropout=dropout)
        self.s = s
        self.lambda_reg = lambda_reg
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Threshold state is kept as non-trainable parameters so it follows
        # the module across devices and is saved in the state dict.
        self.tau = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.m = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.v = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.step = nn.Parameter(torch.tensor(0, dtype=torch.long), requires_grad=False)

    def compute_relative_magnitude(self, base_output: torch.Tensor, lora_output: torch.Tensor) -> torch.Tensor:
        """Return ``||lora|| / (||base|| + eps)`` over the last dimension."""
        base_norm = torch.norm(base_output, p=2, dim=-1, keepdim=True)
        lora_norm = torch.norm(lora_output, p=2, dim=-1, keepdim=True)
        r_i = lora_norm / (base_norm + self.eps)
        return r_i.squeeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base output plus the gated LoRA output.

        In training mode the intermediate tensors are cached (detached) under
        ``_cache_for_backward`` for the next :meth:`update_threshold` call.
        """
        base_output = self.base_layer(x)
        lora_output = self.lora(x)
        r_i = self.compute_relative_magnitude(base_output, lora_output)
        gate = (r_i >= self.tau).float()
        # The gate is float32; multiplying it directly with a bfloat16 LoRA
        # output would promote the whole layer output to float32 and break the
        # next bfloat16 linear. Cast the gate to the activation dtype instead.
        gate_for_output = gate.unsqueeze(-1).to(lora_output.dtype)
        gated_output = base_output + gate_for_output * lora_output
        if self.training:
            # Only values are needed for the threshold update, so detach to
            # avoid keeping autograd references alive between steps.
            self._cache_for_backward = {
                'r_i': r_i.detach(),
                'gate': gate,
                'lora_output': lora_output.detach(),
                'base_output': base_output.detach(),
            }
        return gated_output

    def compute_threshold_gradient(self, grad_output: torch.Tensor) -> float:
        """Return the scalar threshold gradient for the cached forward pass.

        Returns 0.0 when no forward pass has been cached.
        """
        if not hasattr(self, '_cache_for_backward'):
            return 0.0
        cache = self._cache_for_backward
        gate = cache['gate']
        lora_output = cache['lora_output']
        # Per-position inner product of the supplied gradient with the LoRA output.
        mu_i = (grad_output * lora_output).sum(dim=-1)
        # Positions where the sign of mu_i agrees with the current gate decision.
        consistency_mask = ((mu_i >= 0).float() == gate).float()
        sparsity_mask = gate
        grad_loss = -self.s * (consistency_mask * mu_i).sum()
        grad_sparsity = -self.s * (sparsity_mask * self.lambda_reg).sum()
        g_k = grad_loss + grad_sparsity
        return g_k.item()

    def update_threshold(self, grad_output: torch.Tensor, lr: float = 1.0):
        """Apply one Adam-style step to ``tau`` and drop the forward cache.

        No-op in eval mode. ``tau`` is clamped to be non-negative.
        """
        if not self.training:
            return
        g_k = self.compute_threshold_gradient(grad_output)
        self.step.data += 1
        step_count = self.step.item()
        self.m.data = self.beta1 * self.m.data + (1 - self.beta1) * g_k
        self.v.data = self.beta2 * self.v.data + (1 - self.beta2) * (g_k ** 2)
        # Bias-corrected moment estimates.
        m_hat = self.m.data / (1 - self.beta1 ** step_count)
        v_hat = self.v.data / (1 - self.beta2 ** step_count)
        tau_update = lr * self.s * m_hat / (torch.sqrt(v_hat) + self.eps)
        self.tau.data = torch.clamp(self.tau.data + tau_update, min=0.0)
        if hasattr(self, '_cache_for_backward'):
            delattr(self, '_cache_for_backward')

class CodeTranslationDataset(Dataset):
    """Map-style dataset that serves already-preprocessed samples from memory."""

    def __init__(self, data_list):
        # The sequence is stored as-is; no copy and no per-item transformation.
        self.data = data_list

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        samples = self.data
        return samples[idx]

    def __len__(self):
        """Return the number of stored samples."""
        samples = self.data
        return len(samples)


class TSPEFTMistral(nn.Module):
    """Causal LM whose projection layers are wrapped with TS-PEFT adapters.

    The base model is loaded in bfloat16 with ``device_map='auto'``, all of its
    parameters are frozen, and the projections named in
    ``tspeft_config.target_modules`` are replaced by ``TSPEFTLayer`` wrappers.
    The wrappers are also registered in ``self.tspeft_layers`` under the key
    ``layer_{index}_{projection}``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        tspeft_config: Adapter hyperparameters.
        cache_dir: Optional cache directory passed to ``from_pretrained``.
    """

    # Parent attribute on each transformer layer -> projections beneath it,
    # listed in the order in which they are wrapped.
    _PROJECTION_GROUPS = (
        ('self_attn', ('q_proj', 'k_proj', 'v_proj', 'o_proj')),
        ('mlp', ('gate_proj', 'up_proj', 'down_proj')),
    )

    def __init__(self, model_name: str, tspeft_config: TSPEFTConfig, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing model across {num_gpus} GPUs")
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name} - {props.total_memory / 1024**3:.2f} GB")
        # Device that callers move input batches to.
        self.primary_device = 'cuda:0'
        print(f"\nLoading base model: {model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map='auto', trust_remote_code=True)
        self.llm = base_model
        self.tspeft_config = tspeft_config
        self.tspeft_layers = nn.ModuleDict()
        num_layers = len(base_model.model.layers)
        print(f"Model has {num_layers} transformer layers")
        # Freeze the base model first; adapters created afterwards stay trainable.
        for param in self.llm.parameters():
            param.requires_grad = False
        self._inject_tspeft_layers()
        print("Model initialization complete!\n")

    def _inject_tspeft_layers(self):
        """Wrap every configured projection of every layer in a ``TSPEFTLayer``."""
        cfg = self.tspeft_config
        layers = self.llm.model.layers
        print(f"Injecting TS-PEFT layers for {len(layers)} transformer layers...")

        requested = set(cfg.target_modules)
        supported = {name for _, names in self._PROJECTION_GROUPS for name in names}
        unsupported = sorted(requested - supported)
        if unsupported:
            print(f"  Warning: ignoring unsupported target modules: {unsupported}")

        trainable_params = 0
        for layer_idx, layer in enumerate(layers):
            for parent_name, projection_names in self._PROJECTION_GROUPS:
                parent = getattr(layer, parent_name, None)
                if parent is None:
                    continue
                for projection_name in projection_names:
                    if projection_name not in requested or not hasattr(parent, projection_name):
                        continue
                    module = getattr(parent, projection_name)
                    # Put the adapter on the device of the module it wraps.
                    module_device = next(module.parameters()).device
                    tspeft_layer = TSPEFTLayer(
                        base_layer=module,
                        rank=cfg.rank,
                        alpha=cfg.alpha,
                        dropout=cfg.dropout,
                        s=cfg.s,
                        lambda_reg=cfg.lambda_reg,
                    ).to(device=module_device)
                    setattr(parent, projection_name, tspeft_layer)
                    self.tspeft_layers[f"layer_{layer_idx}_{projection_name}"] = tspeft_layer
                    trainable_params += sum(p.numel() for p in tspeft_layer.lora.parameters())
        print(f"TS-PEFT training enabled: {trainable_params} trainable parameters")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and optionally compute the shifted LM loss.

        Returns:
            Dict with ``logits``, ``loss`` and ``hidden_states``. ``loss`` is
            ``None`` when no labels are given or when the loss is NaN/Inf.
        """
        outputs = self.llm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True)
        loss = None
        if labels is not None:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            # With device_map='auto' the logits may be on a different device
            # than the labels the caller supplied.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            if torch.isnan(loss) or torch.isinf(loss):
                loss = None
        return {'logits': outputs.logits, 'loss': loss, 'hidden_states': outputs.hidden_states}

    def update_thresholds(self, lr: float = 1.0):
        """Step the threshold of every adapter that cached a training forward.

        NOTE(review): an all-ones tensor is passed as the output gradient, so
        the threshold update does not see the real loss gradient — confirm
        this placeholder is intended.
        """
        for layer in self.tspeft_layers.values():
            if hasattr(layer, '_cache_for_backward') and layer.training:
                grad_output = torch.ones_like(layer._cache_for_backward['base_output'])
                layer.update_threshold(grad_output, lr)


def print_system_info():
    """Print a banner with the timestamp, PyTorch/CUDA details and RAM figures."""
    rule = '=' * 80
    gib = 1024**3

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(f"{rule}")

    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"Date/Time: {timestamp}")
    print(f"PyTorch Version: {torch.__version__}")

    cuda_ready = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident set size of this Python process.
    resident_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Initial RAM Usage: {resident_bytes / gib:.2f} GB")

    system_memory = psutil.virtual_memory()
    print(f"Total System RAM: {system_memory.total / gib:.2f} GB")
    print(f"Available RAM: {system_memory.available / gib:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts, adapter hyperparameters and per-GPU memory.

    ``model`` must expose ``llm``, ``parameters()`` and ``tspeft_config``.
    """
    rule = '=' * 80
    gib = 1024**3

    total_params = 0
    for weight in model.llm.parameters():
        total_params += weight.numel()
    trainable_params = 0
    for weight in model.parameters():
        if weight.requires_grad:
            trainable_params += weight.numel()

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(f"{rule}")
    print(f"Total LLM Parameters: {total_params:,}")
    print(f"Trainable TS-PEFT Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    if trainable_params > 0:
        print(f"Parameter Efficiency: {total_params / trainable_params:.2f}x reduction")

    config = model.tspeft_config
    print(f"TS-PEFT Rank: {config.rank}")
    print(f"TS-PEFT Alpha: {config.alpha}")
    print(f"TS-PEFT s: {config.s}")
    print(f"TS-PEFT lambda_reg: {config.lambda_reg}")
    print(f"{rule}\n")

    print(f"{rule}")
    print("GPU MEMORY ALLOCATION")
    print(f"{rule}")
    for gpu_index in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(gpu_index) / gib
        reserved = torch.cuda.memory_reserved(gpu_index) / gib
        total = torch.cuda.get_device_properties(gpu_index).total_memory / gib
        print(f"GPU {gpu_index}: {allocated:.2f} GB allocated / {reserved:.2f} GB reserved / {total:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Read a CSV of Java/C# pairs and tokenize them into training samples.

    Each usable row becomes a dict with ``input_ids``, ``attention_mask`` and
    ``labels`` (1-D tensors of length ``max_length``) plus the raw
    ``java_code`` and ``csharp_code`` strings. Labels are ``-100`` on prompt
    tokens and on padding, so only the C# completion contributes to the loss.

    Args:
        csv_path: Path of the CSV file.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape (1, length).
        max_length: Padded/truncated sequence length.
        java_col: Column holding the Java source.
        csharp_col: Column holding the C# target.

    Returns:
        List of sample dicts. Rows with either side shorter than five
        characters, or that raise during processing, are skipped.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Dataset loaded: {len(df)} samples\n")
    data_list = []
    skipped = 0
    reported_failure = False
    print("Processing dataset...")
    for row_number, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()
            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue
            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"
            inputs = tokenizer(full_text, max_length=max_length, truncation=True, padding='max_length', return_tensors='pt')
            attention_mask = inputs['attention_mask']
            labels = inputs['input_ids'].clone()
            prompt_tokens = tokenizer(prompt, return_tensors='pt')['input_ids']
            prompt_length = prompt_tokens.shape[1]
            # With left padding the prompt starts after the pad run, so offset
            # the prompt mask by the position of the first real token.
            real_positions = attention_mask[0].nonzero(as_tuple=True)[0]
            first_real = int(real_positions[0]) if real_positions.numel() > 0 else 0
            labels[0, :first_real + prompt_length] = -100
            # Never train on padding, whichever side it is on.
            labels[attention_mask == 0] = -100
            data_list.append({'input_ids': inputs['input_ids'].squeeze(0), 'attention_mask': attention_mask.squeeze(0), 'labels': labels.squeeze(0), 'java_code': java_code, 'csharp_code': csharp_code})
            if row_number % 1000 == 0:
                print(f"  Processed {row_number}/{len(df)} samples")
        except Exception as e:
            # Best-effort loading: keep going, but surface the first failure
            # so a systematic problem (e.g. a missing column) is visible.
            skipped += 1
            if not reported_failure:
                reported_failure = True
                print(f"  Warning: row {row_number} could not be processed - {e!r}")
            continue
    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Return a sentence-level BLEU score over whitespace tokens.

    Uses uniform weights over n-gram orders 1..``max_n``, clipped n-gram
    precision, and a brevity penalty. No smoothing is applied, so the score is
    0.0 as soon as any order has zero precision or the hypothesis is empty.
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    if not hyp_tokens:
        return 0.0

    def ngram_counts(tokens, order):
        # Multiset of all contiguous n-grams of the given order.
        last_start = len(tokens) - order + 1
        return Counter(tuple(tokens[start:start + order]) for start in range(last_start))

    precisions = []
    for order in range(1, max_n + 1):
        ref_counts = ngram_counts(ref_tokens, order)
        hyp_counts = ngram_counts(hyp_tokens, order)
        # Counter intersection clips each n-gram to its reference count.
        clipped = sum((ref_counts & hyp_counts).values())
        candidate_total = sum(hyp_counts.values())
        precisions.append(clipped / candidate_total if candidate_total > 0 else 0.0)

    # A single zero precision would make the geometric mean zero.
    if not precisions or min(precisions) == 0:
        return 0.0

    weight = 1.0 / max_n
    log_mean = sum(weight * math.log(p) for p in precisions)
    brevity_penalty = min(1.0, math.exp(1 - len(ref_tokens) / len(hyp_tokens)))
    return brevity_penalty * math.exp(log_mean)


def calculate_meteor_score(reference, hypothesis):
    """Return a recall-weighted F-mean over lower-cased unique tokens.

    This is a simplified stand-in for METEOR: it computes
    ``10PR / (9P + R)`` on the token *sets* of both strings, with no
    stemming, synonym matching or fragmentation penalty.
    """
    reference_vocab = set(reference.lower().split())
    hypothesis_vocab = set(hypothesis.lower().split())
    if not hypothesis_vocab:
        return 0.0

    shared = len(reference_vocab & hypothesis_vocab)
    # No shared token means precision and recall are both zero.
    if shared == 0:
        return 0.0

    precision = shared / len(hypothesis_vocab)
    recall = shared / len(reference_vocab)
    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """Return the ROUGE-L F1 score over whitespace tokens.

    The score is the harmonic mean of LCS-based precision and recall, where
    LCS is the longest common subsequence of the two token lists.
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    if not ref_tokens or not hyp_tokens:
        return 0.0

    # Classic LCS recurrence, keeping only the previous row of the table.
    previous_row = [0] * (len(hyp_tokens) + 1)
    for ref_token in ref_tokens:
        current_row = [0]
        for column, hyp_token in enumerate(hyp_tokens, start=1):
            if ref_token == hyp_token:
                current_row.append(previous_row[column - 1] + 1)
            else:
                current_row.append(max(previous_row[column], current_row[column - 1]))
        previous_row = current_row
    lcs_length = previous_row[-1]

    # An empty LCS means precision and recall are both zero.
    if lcs_length == 0:
        return 0.0
    precision = lcs_length / len(hyp_tokens)
    recall = lcs_length / len(ref_tokens)
    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 when both strings are equal after stripping outer whitespace, else 0.0."""
    identical = reference.strip() == hypothesis.strip()
    if identical:
        return 1.0
    return 0.0


def calculate_chrf_score(reference, hypothesis, n=6):
    """Return the F1 overlap of character n-grams of a single order ``n``.

    This is a simplified chrF: only one n-gram order is used and precision
    and recall are weighted equally. Strings shorter than ``n`` characters
    have no n-grams, which yields 0.0.
    """
    def ngram_multiset(text):
        # Every contiguous window of n characters, with multiplicity.
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    reference_grams = ngram_multiset(reference)
    hypothesis_grams = ngram_multiset(hypothesis)
    if not hypothesis_grams:
        return 0.0

    overlap = sum((reference_grams & hypothesis_grams).values())
    hypothesis_total = sum(hypothesis_grams.values())
    reference_total = sum(reference_grams.values())
    precision = overlap / hypothesis_total if hypothesis_total > 0 else 0.0
    recall = overlap / reference_total if reference_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count occurrences of a fixed set of regex-detected constructs in ``code``.

    Returns a dict mapping each feature name to its number of regex matches.
    The patterns are purely textual, so matches inside comments or string
    literals are counted as well.
    """
    feature_patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    features = {}
    for feature_name, pattern in feature_patterns:
        features[feature_name] = len(re.findall(pattern, code))
    return features


def calculate_syntax_similarity(ref_features, hyp_features):
    """Return the mean per-feature similarity between two feature-count dicts.

    For each reference feature the similarity is ``1 - |ref - hyp| / max``.
    Features whose count is zero on both sides are ignored; features missing
    from the hypothesis count as zero. Returns 0.0 when either dict is empty
    or no feature is present on either side.
    """
    if not ref_features or not hyp_features:
        return 0.0

    per_feature = []
    for feature_name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(feature_name, 0)
        larger = max(ref_count, hyp_count)
        if larger <= 0:
            # Absent on both sides: carries no information.
            continue
        per_feature.append(1.0 - abs(ref_count - hyp_count) / larger)

    if not per_feature:
        return 0.0
    return sum(per_feature) / len(per_feature)


def calculate_code_similarity_metrics(reference, hypothesis):
    """Compute every per-sample similarity metric for one reference/hypothesis pair.

    Returns a dict with the keys ``bleu_1``, ``bleu_2``, ``bleu_4``,
    ``meteor``, ``rouge_l``, ``exact_match``, ``chrf`` and
    ``syntax_similarity``.
    """
    metrics = {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
    }
    # Structural comparison based on regex-counted constructs.
    reference_profile = extract_syntax_features(reference)
    hypothesis_profile = extract_syntax_features(hypothesis)
    metrics['syntax_similarity'] = calculate_syntax_similarity(reference_profile, hypothesis_profile)
    return metrics


def check_compilation(code, language='csharp'):
    """Return True when ``code`` compiles as a member of a wrapper C# class.

    The snippet is placed inside ``namespace TempCompile { public class
    TempClass { ... } }`` and compiled as a library with ``csc``. Source and
    output live in a private temporary directory that is always removed,
    including when the compiler is missing or times out.

    Args:
        code: C# source expected to be valid inside a class body.
        language: Only ``'csharp'`` is supported; anything else returns False.

    Returns:
        True if ``csc`` exits with status 0, otherwise False. False is also
        returned when ``csc`` cannot be started or exceeds the 5 second
        timeout.
    """
    if language != 'csharp':
        return False
    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"
    try:
        with tempfile.TemporaryDirectory() as work_dir:
            source_path = os.path.join(work_dir, 'snippet.cs')
            # Pin the output path so the DLL is created inside the temporary
            # directory and removed together with it.
            output_path = os.path.join(work_dir, 'snippet.dll')
            with open(source_path, 'w', encoding='utf-8') as handle:
                handle.write(wrapper)
            result = subprocess.run(['csc', '/nologo', '/t:library', '/noconfig', f'/out:{output_path}', source_path], capture_output=True, timeout=5, text=True)
            return result.returncode == 0
    except (OSError, subprocess.SubprocessError, UnicodeError):
        # Compiler not installed, timeout, or un-encodable source text: the
        # snippet is reported as not compiling.
        return False


def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension; the raw source
    strings are gathered into plain lists in batch order.
    """
    tensor_fields = ('input_ids', 'attention_mask', 'labels')
    text_fields = ('java_code', 'csharp_code')

    collated = {}
    for field in tensor_fields:
        collated[field] = torch.stack([sample[field] for sample in batch])
    for field in text_fields:
        collated[field] = [sample[field] for sample in batch]
    return collated

def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2, warmup_steps=100):
    """Fine-tune the trainable parameters of ``model`` on ``train_data``.

    Uses AdamW with linear warm-up followed by a constant learning rate,
    gradient-norm clipping at 1.0, and a threshold update
    (``model.update_thresholds``) after every optimizer step. A batch whose
    loss is ``None``/NaN/Inf is skipped, and a batch that raises is reported
    and skipped, so one bad batch does not abort the run.

    Args:
        model: Module exposing ``primary_device``, ``update_thresholds`` and a
            forward returning a dict with a ``'loss'`` entry.
        train_data: List of sample dicts as produced by ``load_and_process_data``.
        num_epochs: Number of passes over the data.
        learning_rate: Peak AdamW learning rate.
        batch_size: Samples per batch.
        warmup_steps: Scheduler steps over which the learning rate ramps up.
    """
    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=False)
    # The loader keeps the final partial batch, so report its real length
    # rather than a floor division.
    batches_per_epoch = len(dataloader)

    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Warmup Steps: {warmup_steps}")
    print(f"Total Training Samples: {len(train_data)}")
    print(f"Total Batches per Epoch: {batches_per_epoch}")
    print(f"{'='*80}\n")

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)

    def lr_lambda(current_step: int):
        # Linear ramp from 0 to 1 over the warm-up, then constant.
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0

    scheduler = LambdaLR(optimizer, lr_lambda)
    model.train()
    start_time = time.time()
    global_step = 0
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)
        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].to(model.primary_device)
                attention_mask = batch['attention_mask'].to(model.primary_device)
                labels = batch['labels'].to(model.primary_device)
                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                    optimizer.step()
                    # Thresholds are updated outside the optimizer.
                    model.update_thresholds(lr=1.0)
                    scheduler.step()
                    total_loss += loss.item()
                    valid_batches += 1
                    global_step += 1
                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{batches_per_epoch} - Avg Loss: {avg_loss:.4f}")
                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{batches_per_epoch}")
        print(f"  Time: {epoch_time:.2f}s")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    # Guard against num_epochs == 0.
    print(f"Average Time per Epoch: {total_time/max(num_epochs, 1):.2f}s")
    print(f"{'='*80}\n")


def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation for every test sample and aggregate metrics.

    Samples are processed one at a time with beam search (4 beams). A sample
    that raises is reported and scored 0.0 on every metric.

    Args:
        model: Module exposing ``llm.generate`` and ``primary_device``.
        tokenizer: Tokenizer matching the model.
        test_data: List of dicts with ``java_code`` and ``csharp_code``.
        batch_size: Accepted for interface compatibility; not used.
        max_new_tokens: Generation budget per sample.

    Returns:
        Dict with the mean (and, except for ``exact_match``, the standard
        deviation under ``<name>_std``) of every similarity metric, plus
        ``compilation_rate`` (percent), ``inference_time`` (seconds),
        ``throughput``, ``tokens_per_sec``, ``total_tokens`` and
        ``avg_tokens_per_sample``. Token figures count newly generated tokens
        only.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")
    model.eval()
    all_metrics = {'bleu_1': [], 'bleu_2': [], 'bleu_4': [], 'meteor': [], 'rouge_l': [], 'exact_match': [], 'chrf': [], 'syntax_similarity': []}
    references = []
    hypotheses = []
    compilation_success = 0
    start_time = time.time()
    total_tokens = 0
    print("Generating translations...")
    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']
                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                inputs = {k: v.to(model.primary_device) for k, v in inputs.items()}
                outputs = model.llm.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=4, early_stopping=True, do_sample=False, temperature=1.0, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
                # The returned sequence starts with the prompt tokens. Slice
                # them off by token count: cutting the decoded text by
                # len(prompt) misaligns when the prompt was truncated or does
                # not decode back to the identical string.
                prompt_token_count = inputs['input_ids'].shape[1]
                generated_ids = outputs[0][prompt_token_count:]
                hypothesis = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
                references.append(csharp_code)
                hypotheses.append(hypothesis)
                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                for key in all_metrics:
                    all_metrics[key].append(metrics[key])
                if check_compilation(hypothesis):
                    compilation_success += 1
                total_tokens += len(generated_ids)
                if (idx + 1) % 50 == 0:
                    print(f"  Processed {idx + 1}/{len(test_data)} samples")
                if (idx + 1) % 100 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                references.append("")
                hypotheses.append("")
                continue
    inference_time = time.time() - start_time
    sample_count = len(test_data)

    results = {}
    for name, scores in all_metrics.items():
        # An empty test set yields 0.0 instead of NaN plus a numpy warning.
        results[name] = np.mean(scores) if scores else 0.0
        if name != 'exact_match':
            results[f'{name}_std'] = np.std(scores) if scores else 0.0
    results['compilation_rate'] = (compilation_success / sample_count) * 100 if sample_count else 0.0
    results['inference_time'] = inference_time
    results['throughput'] = sample_count / inference_time if inference_time > 0 else 0.0
    results['tokens_per_sec'] = total_tokens / inference_time if inference_time > 0 else 0.0
    results['total_tokens'] = total_tokens
    results['avg_tokens_per_sample'] = total_tokens / sample_count if sample_count else 0.0
    return results


def print_evaluation_results(results):
    """Pretty-print the dict returned by ``evaluate_model`` in four sections."""
    rule = '=' * 80

    def open_section(title, first=False):
        # Every section but the first is preceded by a blank line.
        print(rule if first else f"\n{rule}")
        print(title)
        print(rule)

    def show_with_spread(label, key):
        # Mean with its standard deviation, both to four decimals.
        print(f"{label}: {results[key]:.4f} (±{results[key + '_std']:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")

    open_section("N-GRAM BASED METRICS", first=True)
    for label, key in (("BLEU-1", 'bleu_1'), ("BLEU-2", 'bleu_2'), ("BLEU-4", 'bleu_4'), ("chrF", 'chrf')):
        show_with_spread(label, key)

    open_section("SEMANTIC SIMILARITY METRICS")
    for label, key in (("METEOR", 'meteor'), ("ROUGE-L", 'rouge_l')):
        show_with_spread(label, key)

    open_section("CODE-SPECIFIC METRICS")
    show_with_spread("Syntax Similarity", 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")

    open_section("PERFORMANCE METRICS")
    seconds = results['inference_time']
    print(f"Total Inference Time: {seconds:.2f}s ({seconds/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the TS-PEFT pipeline end to end: build the model, fine-tune, evaluate, report.

    Returns early (after printing an error) when the training or the test CSV
    yields no usable samples.
    """
    rule = '=' * 80
    base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"

    print(f"\n{rule}")
    print("JAVA TO C# CODE TRANSLATION")
    print("Using TS-PEFT-Enhanced Mistral-7B")
    print(f"{rule}\n")
    print_system_info()

    # Start from a clean allocator state before the large model load.
    torch.cuda.empty_cache()
    gc.collect()

    cache_dir = '/hf_cache'
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    tspeft_config = TSPEFTConfig(
        rank=32,
        alpha=0.5,
        dropout=0.05,
        s=4e-5,
        lambda_reg=4.5e-5,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    print("Initializing model...")
    model = TSPEFTMistral(base_model_id, tspeft_config, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{rule}")
    print("LOADING TRAINING DATA")
    print(f"{rule}\n")
    train_csv = 'train.csv'
    train_data = load_and_process_data(train_csv, tokenizer, java_col='java', csharp_col='C#')
    if len(train_data) == 0:
        print("ERROR: No training data loaded. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2, warmup_steps=100)

    print(f"\n{rule}")
    print("LOADING TEST DATA")
    print(f"{rule}\n")
    test_csv = '/test.csv'
    test_data = load_and_process_data(test_csv, tokenizer, java_col='java', csharp_col='C#')
    if len(test_data) == 0:
        print("ERROR: No test data loaded. Exiting.")
        return

    results = evaluate_model(model, tokenizer, test_data, batch_size=1)
    print_evaluation_results(results)

    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(rule)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")


# Entry point: run the TS-PEFT pipeline only when this code is executed as the
# main module (not when it is imported).
if __name__ == '__main__':
    main()

# GateRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math
from typing import Optional, List, Dict
from dataclasses import dataclass
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


@dataclass
class GateRAConfig:
    """Hyperparameters for the GateRA adapters used in this file.

    ``target_modules`` defaults to ``None`` and is expanded to the full list of
    projection names in ``__post_init__``, so every instance receives its own
    list object.
    """
    # Rank of the low-rank update (inner dimension of lora_A / lora_B).
    rank: int = 16
    # Scaling numerator; the adapter layers scale their update by alpha / rank.
    alpha: float = 16.0
    # Dropout probability handed to each adapter layer (0.0 disables dropout).
    dropout: float = 0.0
    # Projection names to adapt; None selects all seven attention/MLP projections.
    target_modules: Optional[List[str]] = None
    # Weight of the gate-entropy term added to the task loss.
    entropy_reg_weight: float = 0.01

    def __post_init__(self):
        if self.target_modules is None:
            self.target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


class GatingModule(nn.Module):
    """Scalar gate: a single linear unit over the last input dimension, then a sigmoid.

    Weight and bias are zero-initialised, so a freshly built gate outputs
    sigmoid(0) = 0.5 for every input.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.gate_linear = nn.Linear(input_dim, 1, bias=True)
        for param in (self.gate_linear.weight, self.gate_linear.bias):
            nn.init.zeros_(param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of the linear projection; the last dimension becomes 1."""
        return torch.sigmoid(self.gate_linear(x))


class GateRALayer(nn.Module):
    """Wrap a base layer with a gated low-rank update.

    The output is ``base_layer(x) + gate(x) * dropout(x @ lora_A @ lora_B) * scaling``
    where ``scaling = alpha / rank`` and ``gate`` is a ``GatingModule`` producing one
    value per input row. ``lora_B`` starts at zero, so the wrapper initially
    reproduces the base layer exactly.

    The adapter parameters are created with the floating-point dtype and the device
    of ``base_layer.weight`` when that attribute exists; otherwise PyTorch defaults
    are used. Without this, a bfloat16 / GPU base layer would be paired with
    float32 / CPU adapter weights and the matmul in ``forward`` would fail.

    Args:
        base_layer: Module whose output is augmented; it is called unchanged.
        rank: Inner dimension of the low-rank factors.
        alpha: Scaling numerator (``scaling = alpha / rank``).
        dropout: Dropout probability applied to the low-rank output; ``0.0`` disables it.
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the base layer's output.
    """

    def __init__(self, base_layer: nn.Module, rank: int, alpha: float, dropout: float, input_dim: int, output_dim: int):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = alpha / rank
        self.dropout = dropout
        factory_kwargs = self._adapter_factory_kwargs(base_layer)
        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank, **factory_kwargs))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim, **factory_kwargs))
        self.gating_module = GatingModule(input_dim)
        if factory_kwargs:
            self.gating_module.to(**factory_kwargs)
        self.dropout_layer = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.reset_parameters()

    @staticmethod
    def _adapter_factory_kwargs(base_layer: nn.Module) -> dict:
        """Return the ``device``/``dtype`` kwargs the adapter should copy from the base weight."""
        weight = getattr(base_layer, 'weight', None)
        kwargs = {}
        if isinstance(weight, torch.Tensor):
            # A 'meta' weight has no storage to co-locate with; keep the default device then.
            if weight.device.type != 'meta':
                kwargs['device'] = weight.device
            # Integer (e.g. quantised) weights cannot host trainable adapter parameters.
            if weight.dtype.is_floating_point:
                kwargs['dtype'] = weight.dtype
        return kwargs

    def reset_parameters(self):
        """Kaiming-initialise ``lora_A`` and zero ``lora_B`` so the initial update is zero."""
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base output plus the gated, scaled low-rank update.

        Works for any number of leading dimensions: the matmuls act on the last
        dimension and the gate of shape ``(..., 1)`` broadcasts over the output features.
        """
        base_output = self.base_layer(x)
        # Run the adapter in its own parameter dtype (a no-op cast when it already matches).
        adapter_input = x.to(self.lora_A.dtype)
        gate_values = self.gating_module(adapter_input)
        lora_output = self.dropout_layer(adapter_input @ self.lora_A @ self.lora_B)
        modulated_output = gate_values * lora_output * self.scaling
        return base_output + modulated_output.to(base_output.dtype)


class CodeTranslationDataset(Dataset):
    """Map-style dataset that serves items of an in-memory sequence unchanged."""

    def __init__(self, data_list):
        # Keep a reference to the caller's sequence; nothing is copied or transformed.
        self.data = data_list

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]


class GateRAMistral(nn.Module):
    """Frozen causal LM whose projection layers are wrapped with GateRA adapters.

    The base model is loaded in bfloat16 with ``device_map='auto'``, all of its
    parameters are frozen, and each projection listed in
    ``gatera_config.target_modules`` is replaced in place by a ``GateRALayer``
    wrapping the original module. The wrappers are also registered in
    ``self.gatera_layers`` under ``"layer_<idx>_<proj>"`` keys.
    """

    # Projection name -> attribute of the decoder layer that owns it. The dict order is
    # the per-layer injection order, and therefore the order adapter weights are initialised in.
    _PROJECTION_PARENTS = {
        'q_proj': 'self_attn',
        'k_proj': 'self_attn',
        'v_proj': 'self_attn',
        'o_proj': 'self_attn',
        'gate_proj': 'mlp',
        'up_proj': 'mlp',
        'down_proj': 'mlp',
    }

    def __init__(self, model_name: str, gatera_config: GateRAConfig, cache_dir: Optional[str] = None):
        """Load ``model_name``, freeze it and inject the adapters.

        Args:
            model_name: Identifier or path passed to ``AutoModelForCausalLM.from_pretrained``.
            gatera_config: Adapter hyperparameters (rank, alpha, dropout, target modules,
                entropy regularisation weight).
            cache_dir: Optional cache directory forwarded to ``from_pretrained``.
        """
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing model across {num_gpus} GPUs")
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name} - {props.total_memory / 1024**3:.2f} GB")
        self.primary_device = 'cuda:0'
        print(f"\nLoading base model: {model_name}")
        # NOTE(review): trust_remote_code=True lets the checkpoint repository run its own
        # Python code at load time — confirm it is actually required for this model.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True
        )
        self.llm = base_model
        self.gatera_config = gatera_config
        self.gatera_layers = nn.ModuleDict()
        num_layers = len(base_model.model.layers)
        print(f"Model has {num_layers} transformer layers")
        # Freeze the base model first so only the adapters created below stay trainable.
        for param in self.llm.parameters():
            param.requires_grad = False
        self._inject_gatera_layers()
        print("Model initialization complete!\n")

    def _inject_gatera_layers(self):
        """Replace the configured projections of every decoder layer with ``GateRALayer`` wrappers.

        Only names present in ``gatera_config.target_modules`` are wrapped (all
        supported projections when that is ``None``); names this class does not know
        are reported and ignored. Projections a layer does not have are skipped.
        """
        config = self.gatera_config
        requested = getattr(config, 'target_modules', None)
        if requested is None:
            requested = list(self._PROJECTION_PARENTS)
        unknown = [name for name in requested if name not in self._PROJECTION_PARENTS]
        if unknown:
            print(f"Warning: ignoring unsupported target modules: {unknown}")

        decoder_layers = self.llm.model.layers
        print(f"Injecting GateRA layers for {len(decoder_layers)} transformer layers...")
        trainable_params = 0
        for layer_idx, layer in enumerate(decoder_layers):
            for proj_name, parent_name in self._PROJECTION_PARENTS.items():
                if proj_name not in requested:
                    continue
                parent = getattr(layer, parent_name, None)
                module = getattr(parent, proj_name, None)
                if module is None:
                    continue
                gatera_layer = GateRALayer(
                    base_layer=module,
                    rank=config.rank,
                    alpha=config.alpha,
                    dropout=config.dropout,
                    input_dim=module.in_features,
                    output_dim=module.out_features
                )
                setattr(parent, proj_name, gatera_layer)
                self.gatera_layers[f"layer_{layer_idx}_{proj_name}"] = gatera_layer
                trainable_params += gatera_layer.lora_A.numel() + gatera_layer.lora_B.numel()
                trainable_params += sum(p.numel() for p in gatera_layer.gating_module.parameters())
        print(f"GateRA training enabled: {trainable_params} trainable parameters")

    def _gate_entropy_penalty(self, num_probe_rows: int, out_device) -> torch.Tensor:
        """Return the mean binary entropy of all gates, evaluated on Gaussian probe inputs.

        NOTE(review): the gates are probed with ``torch.randn`` inputs rather than with
        the activations they saw during this forward pass — confirm this is intended.
        """
        # Entropy is computed in float32: in bfloat16/float32 the value 1 - 1e-8 rounds to
        # 1.0, so the previous upper clamp was a no-op and a saturated gate gave 0 * log(0) = NaN.
        eps = 1e-6
        terms = []
        for layer in self.gatera_layers.values():
            gate = getattr(layer, 'gating_module', None)
            if gate is None:
                continue
            # Build the probe where the gate's own weights live so the call cannot hit a
            # device or dtype mismatch.
            gate_param = next(gate.parameters())
            probe = torch.randn(
                num_probe_rows,
                layer.lora_A.shape[0],
                device=gate_param.device,
                dtype=gate_param.dtype
            )
            gate_vals = gate(probe).float().clamp(eps, 1.0 - eps)
            entropy = -gate_vals * torch.log(gate_vals) - (1 - gate_vals) * torch.log(1 - gate_vals)
            terms.append(entropy.mean().to(out_device))
        if not terms:
            return torch.tensor(0.0, device=out_device)
        return torch.stack(terms).mean()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the adapted model and, when ``labels`` is given, compute the training loss.

        The loss is next-token cross-entropy (positions labelled ``-100`` are ignored)
        plus ``entropy_reg_weight`` times the mean gate entropy.

        Returns:
            Dict with ``'logits'``, ``'loss'`` and ``'hidden_states'``. ``'loss'`` is
            ``None`` when no labels were passed or when the loss is NaN/Inf.
        """
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        loss = None
        if labels is not None:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            # The logits may live on a different device than the caller's labels when the
            # model is spread over several devices.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            task_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            avg_entropy_loss = self._gate_entropy_penalty(
                input_ids.shape[0] * input_ids.shape[1], task_loss.device
            )
            loss = task_loss + self.gatera_config.entropy_reg_weight * avg_entropy_loss
            if torch.isnan(loss) or torch.isinf(loss):
                loss = None
        return {'logits': outputs.logits, 'loss': loss, 'hidden_states': outputs.hidden_states}


def print_system_info():
    """Print a short environment report: timestamp, PyTorch/CUDA details and RAM usage."""
    rule = '=' * 80
    gib = 1024**3

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(rule)
    print(f"Date/Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"PyTorch Version: {torch.__version__}")

    cuda_ready = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident memory of this process, then the machine-wide figures.
    own_memory = psutil.Process(os.getpid()).memory_info()
    print(f"Initial RAM Usage: {own_memory.rss / gib:.2f} GB")
    system_memory = psutil.virtual_memory()
    print(f"Total System RAM: {system_memory.total / gib:.2f} GB")
    print(f"Available RAM: {system_memory.available / gib:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts, the GateRA settings and per-GPU memory figures.

    Args:
        model: Object exposing ``llm`` (its parameters give the total count),
            ``parameters()`` (those with ``requires_grad`` give the trainable count)
            and ``gatera_config`` with ``rank``, ``alpha`` and ``entropy_reg_weight``.
    """
    rule = '=' * 80
    gib = 1024**3
    config = model.gatera_config

    total_params = sum(param.numel() for param in model.llm.parameters())
    trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(rule)
    print(f"Total LLM Parameters: {total_params:,}")
    print(f"Trainable GateRA Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    if trainable_params > 0:
        print(f"Parameter Efficiency: {total_params / trainable_params:.2f}x reduction")
    print(f"GateRA Rank: {config.rank}")
    print(f"GateRA Alpha: {config.alpha}")
    print(f"Entropy Regularization Weight: {config.entropy_reg_weight}")
    print(f"{rule}\n")

    print(rule)
    print("GPU MEMORY ALLOCATION")
    print(rule)
    for gpu_index in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(gpu_index) / gib
        reserved = torch.cuda.memory_reserved(gpu_index) / gib
        total = torch.cuda.get_device_properties(gpu_index).total_memory / gib
        print(f"GPU {gpu_index}: {allocated:.2f} GB allocated / {reserved:.2f} GB reserved / {total:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Read Java/C# pairs from a CSV and tokenise them into training samples.

    Each usable row becomes ``"[INST] Translate this Java code to C#:\\n<java>\\n[/INST] <csharp>"``,
    tokenised to exactly ``max_length`` positions (truncated or padded). Labels are a
    copy of the input ids with everything up to and including the prompt set to
    ``-100``.

    The prompt is located relative to the first attended token, because the
    tokenizer may pad on the left: masking the first ``prompt_length`` positions
    blindly would then hide padding and leave the prompt itself supervised.

    NOTE(review): padding that follows the sequence (right padding) keeps its token
    ids as labels, exactly as before — confirm whether it should be masked as well.

    Args:
        csv_path: Path of the CSV file.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Fixed sequence length of every sample.
        java_col: Name of the column holding the Java source.
        csharp_col: Name of the column holding the C# target.

    Returns:
        List of dicts with 1-D ``input_ids``/``attention_mask``/``labels`` tensors and
        the raw ``java_code``/``csharp_code`` strings. Rows with either side shorter
        than 5 characters, or rows that fail to process, are skipped.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Dataset loaded: {len(df)} samples\n")
    data_list = []
    skipped = 0
    reported_failures = 0
    max_reported_failures = 5
    print("Processing dataset...")
    for position, (idx, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()
            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue
            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"
            inputs = tokenizer(
                full_text,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            labels = input_ids.clone()
            prompt_tokens = tokenizer(prompt, return_tensors='pt')['input_ids']
            prompt_length = prompt_tokens.shape[1]
            # Index of the first real (attended) token; 0 unless the tokenizer padded on the left.
            attended = attention_mask[0].nonzero().flatten()
            first_real = int(attended[0]) if attended.numel() > 0 else 0
            labels[0, :first_real + prompt_length] = -100
            data_list.append({
                'input_ids': input_ids.squeeze(0),
                'attention_mask': attention_mask.squeeze(0),
                'labels': labels.squeeze(0),
                'java_code': java_code,
                'csharp_code': csharp_code
            })
            if position % 1000 == 0:
                print(f"  Processed {position}/{len(df)} samples")
        except Exception as e:
            # Best effort: a bad row is skipped, but the first few causes are surfaced
            # so a systematic failure (wrong column name, tokenizer error) is visible.
            skipped += 1
            if reported_failures < max_reported_failures:
                print(f"  Warning: row {idx} skipped - {type(e).__name__}: {e}")
                reported_failures += 1
            continue
    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Sentence-level BLEU over whitespace tokens, without smoothing.

    Combines the clipped n-gram precisions for n = 1..max_n with uniform weights
    (geometric mean) and multiplies by the brevity penalty. Returns 0.0 when the
    hypothesis is empty or when any of the precisions is zero.
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    if not hyp_tokens:
        return 0.0

    def ngram_counts(tokens, order):
        # Multiset of all contiguous n-grams of the given order.
        return Counter(tuple(tokens[start:start + order]) for start in range(len(tokens) - order + 1))

    precisions = []
    for order in range(1, max_n + 1):
        hyp_counts = ngram_counts(hyp_tokens, order)
        clipped_matches = sum((ngram_counts(ref_tokens, order) & hyp_counts).values())
        candidate_total = sum(hyp_counts.values())
        precisions.append(clipped_matches / candidate_total if candidate_total > 0 else 0.0)

    # No orders requested, or one zero precision, makes the score zero.
    if not precisions or min(precisions) == 0:
        return 0.0

    weight = 1.0 / max_n
    geometric_mean = math.exp(sum([weight * math.log(p) for p in precisions]))
    brevity_penalty = min(1.0, math.exp(1 - len(ref_tokens) / len(hyp_tokens)))
    return brevity_penalty * geometric_mean


def calculate_meteor_score(reference, hypothesis):
    """Simplified METEOR-style score over lower-cased unique whitespace tokens.

    Precision and recall are computed on the two token sets and combined as
    ``10PR / (9P + R)``, a recall-weighted harmonic mean. Returns 0.0 for an
    empty hypothesis or when nothing overlaps.
    """
    ref_vocab = set(reference.lower().split())
    hyp_vocab = set(hypothesis.lower().split())
    if not hyp_vocab:
        return 0.0

    shared = len(ref_vocab & hyp_vocab)
    precision = shared / len(hyp_vocab)
    recall = shared / len(ref_vocab) if ref_vocab else 0.0
    if precision + recall == 0:
        return 0.0
    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """ROUGE-L F1: based on the longest common subsequence of whitespace tokens.

    Returns 0.0 when either side is empty or when the sequences share no token.
    """
    ref_tokens = reference.split()
    hyp_tokens = hypothesis.split()
    if not ref_tokens or not hyp_tokens:
        return 0.0

    # LCS by dynamic programming; only the previous table row is needed at each step.
    previous_row = [0] * (len(hyp_tokens) + 1)
    for ref_token in ref_tokens:
        current_row = [0]
        for col, hyp_token in enumerate(hyp_tokens, start=1):
            if ref_token == hyp_token:
                current_row.append(previous_row[col - 1] + 1)
            else:
                current_row.append(max(previous_row[col], current_row[col - 1]))
        previous_row = current_row
    lcs_length = previous_row[-1]

    precision = lcs_length / len(hyp_tokens)
    recall = lcs_length / len(ref_tokens)
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 when both strings are equal after stripping outer whitespace, else 0.0."""
    return float(reference.strip() == hypothesis.strip())


def calculate_chrf_score(reference, hypothesis, n=6):
    """Character n-gram F1 for a single n-gram order ``n``.

    Counts clipped overlaps between the character n-grams of both strings and
    returns the balanced F-score. Returns 0.0 when the hypothesis has no n-gram
    of that order or when nothing overlaps.
    """
    def char_ngram_counts(text):
        # Every contiguous substring of length n, with multiplicity.
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    ref_counts = char_ngram_counts(reference)
    hyp_counts = char_ngram_counts(hypothesis)
    if not hyp_counts:
        return 0.0

    overlap = sum((ref_counts & hyp_counts).values())
    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = overlap / hyp_total if hyp_total > 0 else 0.0
    recall = overlap / ref_total if ref_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count regex matches for a fixed set of code constructs in ``code``.

    Returns:
        Dict mapping each feature name to the number of matches of its pattern.
    """
    # (feature name, pattern) pairs; the resulting dict keeps this order.
    feature_patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    return {name: len(re.findall(pattern, code)) for name, pattern in feature_patterns}


def calculate_syntax_similarity(ref_features, hyp_features):
    """Average per-feature agreement between two feature-count dicts.

    For every key of ``ref_features`` where at least one side has a positive
    count, the agreement is ``1 - |ref - hyp| / max(ref, hyp)``; a key missing
    from ``hyp_features`` counts as 0. Returns 0.0 when either dict is empty or
    when no feature is present on either side.
    """
    if not ref_features or not hyp_features:
        return 0.0

    agreements = []
    for name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(name, 0)
        largest = max(ref_count, hyp_count)
        if largest > 0:
            agreements.append(1.0 - abs(ref_count - hyp_count) / largest)

    return sum(agreements) / len(agreements) if agreements else 0.0


def calculate_code_similarity_metrics(reference, hypothesis):
    """Score one hypothesis against its reference with every metric defined in this file.

    Returns:
        Dict with the keys ``bleu_1``, ``bleu_2``, ``bleu_4``, ``meteor``,
        ``rouge_l``, ``exact_match``, ``chrf`` and ``syntax_similarity``.
    """
    return {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
        # Structural agreement compares regex-derived construct counts of both sides.
        'syntax_similarity': calculate_syntax_similarity(
            extract_syntax_features(reference),
            extract_syntax_features(hypothesis),
        ),
    }


def check_compilation(code, language='csharp'):
    """Return True when ``code`` compiles as the body of a C# class using ``csc``.

    The snippet is wrapped in a namespace/class skeleton, written to a scratch
    directory and compiled as a library with a 5 second timeout. The scratch
    directory (source and any produced assembly) is removed afterwards, including
    when the compiler is missing or times out.

    Args:
        code: C# member declarations to place inside the wrapper class.
        language: Only ``'csharp'`` is supported; any other value yields False.

    Returns:
        True if ``csc`` exits with code 0, otherwise False (compile error, compiler
        not installed, timeout, or an I/O problem).
    """
    if language != 'csharp':
        return False

    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"
    success = False
    try:
        with tempfile.TemporaryDirectory() as work_dir:
            source_path = os.path.join(work_dir, 'snippet.cs')
            output_path = os.path.join(work_dir, 'snippet.dll')
            with open(source_path, 'w', encoding='utf-8') as source_file:
                source_file.write(wrapper)
            # An explicit output path keeps the assembly inside the scratch directory
            # regardless of the compiler's default output location.
            result = subprocess.run(
                ['csc', '/nologo', '/t:library', '/noconfig', f'-out:{output_path}', source_path],
                capture_output=True,
                timeout=5,
                text=True
            )
            success = result.returncode == 0
    except (OSError, subprocess.SubprocessError, UnicodeError):
        # Best effort by design: a missing compiler, a timeout, undecodable text or a
        # scratch-directory problem all mean "could not confirm compilation". A result
        # obtained before a cleanup failure is kept.
        pass
    return success


def collate_fn(batch):
    """Merge a list of samples into one batch dict.

    The tensor fields are stacked along a new leading dimension; the raw source
    strings are collected into plain lists in sample order.
    """
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ('input_ids', 'attention_mask', 'labels')
    }
    for field in ('java_code', 'csharp_code'):
        collated[field] = [sample[field] for sample in batch]
    return collated


def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2, warmup_steps=100):
    """Fine-tune the trainable parameters of ``model`` on ``train_data``.

    Uses AdamW (weight decay 0.01) with linear warm-up to ``learning_rate`` over
    ``warmup_steps`` optimizer steps followed by a constant rate, gradient-norm
    clipping at 1.0 and shuffled batches. Batches moved to the GPU with ``.cuda()``.
    A batch whose loss is ``None``/NaN/Inf is skipped, and a batch that raises is
    reported and skipped, so one bad batch does not end the run.

    Args:
        model: Module called as ``model(input_ids, attention_mask=..., labels=...)`` and
            returning a dict with a ``'loss'`` entry.
        train_data: Sequence of sample dicts as consumed by ``collate_fn``.
        num_epochs: Number of passes over the data.
        learning_rate: Peak learning rate.
        batch_size: Samples per batch.
        warmup_steps: Number of warm-up steps.
    """
    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Warmup Steps: {warmup_steps}")
    print(f"Total Training Samples: {len(train_data)}")
    # The DataLoader below keeps the final partial batch, so round up rather than down.
    print(f"Total Batches per Epoch: {math.ceil(len(train_data) / batch_size)}")
    print(f"{'='*80}\n")
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)

    def lr_lambda(current_step: int):
        # Linear ramp from 0 to 1 during warm-up, then a constant multiplier.
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0

    scheduler = LambdaLR(optimizer, lr_lambda)
    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True
    )
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)
        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].cuda()
                attention_mask = batch['attention_mask'].cuda()
                labels = batch['labels'].cuda()
                # Clearing here also discards partial gradients left by a batch that failed mid-backward.
                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                    optimizer.step()
                    scheduler.step()
                    total_loss += loss.item()
                    valid_batches += 1
                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{len(dataloader)} - Avg Loss: {avg_loss:.4f}")
                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                # Deliberate best effort: report the failing batch and carry on.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{len(dataloader)}")
        print(f"  Time: {epoch_time:.2f}s")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    avg_epoch_time = total_time / num_epochs if num_epochs > 0 else 0.0
    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    print(f"Average Time per Epoch: {avg_epoch_time:.2f}s")
    print(f"{'='*80}\n")


def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation for every test sample and aggregate the metrics.

    Each sample is translated with 4-beam search via ``model.llm.generate``. The
    hypothesis is decoded from the newly generated tokens only (everything after
    the prompt's token positions), so it does not depend on the decoded text
    reproducing the prompt character for character. Token statistics likewise
    count generated tokens only. A sample that raises is reported and scored 0.0
    on every metric.

    Args:
        model: Wrapper exposing the generation-capable model as ``model.llm``.
        tokenizer: Tokenizer used for the prompt and for decoding.
        test_data: Sequence of dicts with ``'java_code'`` and ``'csharp_code'``.
        batch_size: Accepted for interface compatibility; samples are processed one at a time.
        max_new_tokens: Generation budget per sample.

    Returns:
        Dict with the mean (and, except for exact match, std) of each metric, the
        compilation rate in percent, and timing/token counters. Ratios are 0.0 when
        there are no samples or no elapsed time.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")
    model.eval()
    all_metrics = {
        'bleu_1': [], 'bleu_2': [], 'bleu_4': [],
        'meteor': [], 'rouge_l': [], 'exact_match': [],
        'chrf': [], 'syntax_similarity': []
    }
    compilation_success = 0
    start_time = time.time()
    total_tokens = 0
    print("Generating translations...")
    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']
                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                inputs = {k: v.cuda() for k, v in inputs.items()}
                outputs = model.llm.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
                # The returned sequence starts with the prompt tokens; keep only the continuation.
                prompt_token_count = inputs['input_ids'].shape[1]
                generated_ids = outputs[0][prompt_token_count:]
                hypothesis = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                compiled = check_compilation(hypothesis)
                # Record results only once everything for this sample has succeeded, so the
                # failure path below can never add a second entry for the same sample.
                for key in all_metrics:
                    all_metrics[key].append(metrics[key])
                if compiled:
                    compilation_success += 1
                total_tokens += len(generated_ids)
                if (idx + 1) % 50 == 0:
                    print(f"  Processed {idx + 1}/{len(test_data)} samples")
                if (idx + 1) % 100 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                continue
    inference_time = time.time() - start_time
    num_samples = len(test_data)

    def mean_of(key):
        return np.mean(all_metrics[key]) if all_metrics[key] else 0.0

    def std_of(key):
        return np.std(all_metrics[key]) if all_metrics[key] else 0.0

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    results = {
        'bleu_1': mean_of('bleu_1'),
        'bleu_1_std': std_of('bleu_1'),
        'bleu_2': mean_of('bleu_2'),
        'bleu_2_std': std_of('bleu_2'),
        'bleu_4': mean_of('bleu_4'),
        'bleu_4_std': std_of('bleu_4'),
        'meteor': mean_of('meteor'),
        'meteor_std': std_of('meteor'),
        'rouge_l': mean_of('rouge_l'),
        'rouge_l_std': std_of('rouge_l'),
        'exact_match': mean_of('exact_match'),
        'chrf': mean_of('chrf'),
        'chrf_std': std_of('chrf'),
        'syntax_similarity': mean_of('syntax_similarity'),
        'syntax_similarity_std': std_of('syntax_similarity'),
        'compilation_rate': ratio(compilation_success, num_samples) * 100,
        'inference_time': inference_time,
        'throughput': ratio(num_samples, inference_time),
        'tokens_per_sec': ratio(total_tokens, inference_time),
        'total_tokens': total_tokens,
        'avg_tokens_per_sample': ratio(total_tokens, num_samples)
    }
    return results


def print_evaluation_results(results):
    """Print the aggregated evaluation metrics as a sectioned console report.

    Args:
        results: Mapping that provides the metric means, their ``*_std``
            companions and the timing / token counters read below.
    """
    rule = '=' * 80

    def section(title):
        # Header used by every section after the first: blank line, rule, title, rule.
        print(f"\n{rule}")
        print(title)
        print(rule)

    def mean_with_std(label, key):
        print(f"{label}: {results[key]:.4f} (±{results[key + '_std']:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")
    print(rule)
    print("N-GRAM BASED METRICS")
    print(rule)
    mean_with_std("BLEU-1", 'bleu_1')
    mean_with_std("BLEU-2", 'bleu_2')
    mean_with_std("BLEU-4", 'bleu_4')
    mean_with_std("chrF", 'chrf')
    section("SEMANTIC SIMILARITY METRICS")
    mean_with_std("METEOR", 'meteor')
    mean_with_std("ROUGE-L", 'rouge_l')
    section("CODE-SPECIFIC METRICS")
    mean_with_std("Syntax Similarity", 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")
    section("PERFORMANCE METRICS")
    elapsed = results['inference_time']
    print(f"Total Inference Time: {elapsed:.2f}s ({elapsed/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the GateRA experiment: load tokenizer and model, train, evaluate, report.

    Returns early (after printing an error) when either CSV yields no samples.
    """
    rule = '=' * 80
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    csv_columns = dict(java_col='java', csharp_col='C#')

    def banner(*titles):
        # Blank line, rule, one line per title, rule, blank line.
        print(f"\n{rule}")
        for title in titles:
            print(title)
        print(f"{rule}\n")

    banner("JAVA TO C# CODE TRANSLATION", "Using GateRA-Enhanced Mistral-7B")
    print_system_info()
    torch.cuda.empty_cache()
    gc.collect()

    cache_dir = 'hf_cache'
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fall back to the EOS token when the tokenizer ships without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    gatera_config = GateRAConfig(
        rank=16,
        alpha=16.0,
        dropout=0.0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        entropy_reg_weight=0.01,
    )
    print("Initializing model...")
    model = GateRAMistral(model_name, gatera_config, cache_dir=cache_dir)
    print_parameter_statistics(model)

    banner("LOADING TRAINING DATA")
    train_data = load_and_process_data('/train.csv', tokenizer, **csv_columns)
    if len(train_data) == 0:
        print("ERROR: No training data loaded. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2, warmup_steps=100)

    banner("LOADING TEST DATA")
    test_data = load_and_process_data('/test.csv', tokenizer, **csv_columns)
    if len(test_data) == 0:
        print("ERROR: No test data loaded. Exiting.")
        return
    print_evaluation_results(evaluate_model(model, tokenizer, test_data, batch_size=1))

    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(rule)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")


# Script entry point: run the GateRA pipeline only when this code is executed directly.
if __name__ == '__main__':
    main()

# LoRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math


class LoRAConfig:
    """Hyper-parameters shared by every LoRA adapter.

    Attributes:
        r: Rank of the low-rank update.
        lora_alpha: Numerator of the scaling factor.
        lora_dropout: Dropout probability applied to the adapter input.
        scaling: ``lora_alpha / r``, the multiplier applied to the adapter output.
    """

    def __init__(self, r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.1):
        self.r, self.lora_alpha, self.lora_dropout = r, lora_alpha, lora_dropout
        # Derived once here so layers never recompute the ratio.
        self.scaling = self.lora_alpha / self.r


class LoRALayer(nn.Module):
    """Low-rank adapter evaluated together with a frozen weight matrix.

    ``forward`` returns ``x @ W^T + scaling * dropout(x) @ A @ B`` where ``W`` is
    the weight handed in by the caller, ``A`` has shape ``(in_features, r)`` and
    ``B`` has shape ``(r, out_features)``. No bias term is applied.
    """

    def __init__(self, in_features: int, out_features: int, config: LoRAConfig):
        super().__init__()
        rank = config.r
        self.r = rank
        self.lora_alpha = config.lora_alpha
        self.scaling = config.scaling
        self.lora_dropout = nn.Dropout(p=config.lora_dropout)
        self.lora_A = nn.Parameter(torch.zeros((in_features, rank)))
        self.lora_B = nn.Parameter(torch.zeros((rank, out_features)))
        # A gets a Kaiming-uniform draw and B stays at zero, so the adapter
        # output is exactly zero until B has been updated.
        # NOTE(review): kaiming_uniform_ derives fan_in from dimension 1, which
        # for this (in_features, r) layout is r -- confirm that is intended.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x, original_weight):
        """Combine the frozen projection of ``x`` with the scaled low-rank update."""
        frozen_out = F.linear(x, original_weight)
        low_rank = self.lora_dropout(x) @ self.lora_A @ self.lora_B
        return frozen_out + low_rank * self.scaling


class CodeTranslationDataset(Dataset):
    """Thin ``Dataset`` adapter over an in-memory sequence of preprocessed samples."""

    def __init__(self, data_list):
        # Only a reference is kept; the sequence is not copied.
        self.data = data_list

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]


class LoRAMistral(nn.Module):
    """Frozen causal language model with trainable LoRA adapters on q_proj / v_proj.

    The base model is loaded in bfloat16 with ``device_map='auto'``, every base
    parameter is frozen, and the ``forward`` of each attention ``q_proj`` and
    ``v_proj`` is replaced by a closure that routes through a ``LoRALayer``.
    The adapters live in ``self.lora_layers`` under the keys
    ``layer_{i}_q`` / ``layer_{i}_v``.

    Args:
        model_name: Hugging Face model identifier or local path.
        lora_config: Adapter hyper-parameters.
        cache_dir: Optional download cache directory passed to ``from_pretrained``.
    """

    def __init__(self, model_name: str, lora_config: LoRAConfig, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing model across {num_gpus} GPUs")
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name} - {props.total_memory / 1024**3:.2f} GB")

        self.primary_device = 'cuda:0'

        print(f"\nLoading base model: {model_name}")
        # NOTE(review): trust_remote_code=True lets the model repository run its
        # own Python code at load time; confirm this is required for the
        # checkpoints used here.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True
        )

        self.llm = base_model
        self.lora_config = lora_config

        num_layers = len(base_model.model.layers)
        print(f"Model has {num_layers} transformer layers")

        # Freeze the whole base model; only the adapters registered below train.
        for param in self.llm.parameters():
            param.requires_grad = False

        self.lora_layers = nn.ModuleDict()
        self._inject_lora_layers()

        print("Model initialization complete!\n")

    @staticmethod
    def _make_lora_forward(lora_layer, orig_weight, orig_bias=None):
        """Build the replacement ``forward`` for one projection.

        A factory is used so each closure binds its own adapter, weight and bias
        instead of the loop variables of the caller.
        """
        def new_forward(x):
            out = lora_layer(x, orig_weight)
            # LoRALayer applies the weight only; re-add the projection's bias so
            # models whose q/v projections carry one are not silently altered.
            if orig_bias is not None:
                out = out + orig_bias
            return out
        return new_forward

    def _inject_lora_layers(self):
        """Create one adapter per q_proj / v_proj and patch the projection's forward."""
        num_layers = len(self.llm.model.layers)

        print(f"Injecting LoRA layers (r={self.lora_config.r}) into {num_layers} layers...")

        for layer_idx, layer in enumerate(self.llm.model.layers):
            # Place each adapter on the device that holds its transformer block.
            device = next(layer.parameters()).device
            projections = (('q', layer.self_attn.q_proj), ('v', layer.self_attn.v_proj))

            for tag, proj in projections:
                # Size the adapter from the projection itself rather than from
                # config.hidden_size, so input and output widths always match.
                adapter = LoRALayer(proj.in_features, proj.out_features, self.lora_config)
                adapter = adapter.to(device).to(torch.bfloat16)
                self.lora_layers[f'layer_{layer_idx}_{tag}'] = adapter

                frozen_bias = proj.bias.detach() if proj.bias is not None else None
                proj.forward = self._make_lora_forward(adapter, proj.weight.detach(), frozen_bias)

        print("LoRA injection complete")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the patched model and optionally compute the next-token loss.

        Returns:
            Dict with ``logits``, ``loss`` and ``hidden_states``. ``loss`` is
            ``None`` when no labels are given or when the computed loss is NaN
            or infinite. Hidden states are always requested from the base model.
        """
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t + 1.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            # Align devices explicitly; with a sharded model the logits need not
            # live on the device the caller put the labels on.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            if torch.isnan(loss) or torch.isinf(loss):
                loss = None

        return {'logits': outputs.logits, 'loss': loss, 'hidden_states': outputs.hidden_states}


def print_system_info():
    """Print a banner with the current time, PyTorch / CUDA details and RAM usage."""
    rule = '=' * 80
    gib = 1024**3

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(rule)
    print(f"Date/Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"PyTorch Version: {torch.__version__}")

    cuda_ready = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident set size of the current Python process.
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Initial RAM Usage: {rss_bytes / gib:.2f} GB")

    system_ram = psutil.virtual_memory()
    print(f"Total System RAM: {system_ram.total / gib:.2f} GB")
    print(f"Available RAM: {system_ram.available / gib:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts for the base model and adapters, then per-GPU memory.

    Args:
        model: Object exposing ``llm`` and ``lora_layers`` (both with
            ``parameters()``) and ``lora_config`` (with ``r``, ``lora_alpha``
            and ``scaling``).
    """
    rule = '=' * 80
    gib = 1024**3

    total_params = sum(weight.numel() for weight in model.llm.parameters())
    lora_params = sum(weight.numel() for weight in model.lora_layers.parameters())
    config = model.lora_config

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(rule)
    print(f"Total LLM Parameters: {total_params:,}")
    print(f"Trainable LoRA Parameters: {lora_params:,}")
    # The adapters are the only trainable parameters, so both ratios use lora_params.
    print(f"Trainable Percentage: {(lora_params / total_params * 100):.4f}%")
    print(f"Parameter Efficiency: {total_params / lora_params:.2f}x reduction")
    print(f"LoRA Rank (r): {config.r}")
    print(f"LoRA Alpha: {config.lora_alpha}")
    print(f"LoRA Scaling Factor: {config.scaling}")
    print(f"{rule}\n")

    print(rule)
    print("GPU MEMORY ALLOCATION")
    print(rule)
    for gpu in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(gpu) / gib
        reserved = torch.cuda.memory_reserved(gpu) / gib
        total = torch.cuda.get_device_properties(gpu).total_memory / gib
        print(f"GPU {gpu}: {allocated:.2f} GB allocated / {reserved:.2f} GB reserved / {total:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Read a Java/C# parallel CSV and build fixed-length training samples.

    Each usable row becomes the text ``"[INST] ...java... [/INST] <C#>"``,
    tokenized and padded/truncated to ``max_length``. Labels are a copy of the
    input ids in which the prompt tokens and every padding position are set to
    -100, so the loss covers only the C# answer. The prompt is located relative
    to the first non-padding token, which keeps the mask correct for tokenizers
    that pad on the left as well as on the right.

    Args:
        csv_path: Path of the CSV file.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors with a leading batch dimension.
        max_length: Sequence length after padding / truncation.
        java_col: Name of the column holding the Java source.
        csharp_col: Name of the column holding the C# target.

    Returns:
        List of dicts with ``input_ids``, ``attention_mask``, ``labels``
        (1-D tensors) plus the raw ``java_code`` and ``csharp_code`` strings.
        Rows with fewer than 5 characters on either side, and rows whose
        processing raises, are skipped and counted.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Dataset loaded: {len(df)} samples\n")

    data_list = []
    skipped = 0

    print("Processing dataset...")
    # Count rows by position: the DataFrame index label need not be an integer.
    for row_number, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()

            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue

            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"

            inputs = tokenizer(
                full_text,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            input_ids = inputs['input_ids'].squeeze(0)
            attention_mask = inputs['attention_mask'].squeeze(0)

            prompt_length = tokenizer(prompt, return_tensors='pt')['input_ids'].shape[1]

            labels = input_ids.clone()
            # Never train on padding (the pad token may be the EOS token).
            labels[attention_mask == 0] = -100
            # The prompt starts at the first real token, which is not index 0
            # when the tokenizer pads on the left.
            mask_values = attention_mask.tolist()
            first_real = mask_values.index(1) if 1 in mask_values else 0
            labels[first_real:first_real + prompt_length] = -100

            data_list.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'java_code': java_code,
                'csharp_code': csharp_code
            })

            if row_number % 1000 == 0:
                print(f"  Processed {row_number}/{len(df)} samples")

        except Exception as e:
            # Best effort: one bad row must not abort the whole load, but say why it was dropped.
            skipped += 1
            print(f"  Warning: row {row_number} skipped - {e}")
            continue

    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Sentence-level BLEU on whitespace tokens with uniform n-gram weights.

    Args:
        reference: Reference text.
        hypothesis: Candidate text.
        max_n: Highest n-gram order included.

    Returns:
        Brevity penalty times the geometric mean of the clipped n-gram
        precisions; 0.0 when the hypothesis is empty or any order has no match
        (no smoothing is applied).
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not hyp_words:
        return 0.0

    def ngram_counts(words, size):
        # Zipping `size` shifted views of the list yields every contiguous n-gram.
        return Counter(zip(*(words[offset:] for offset in range(size))))

    precisions = []
    for size in range(1, max_n + 1):
        ref_counts = ngram_counts(ref_words, size)
        hyp_counts = ngram_counts(hyp_words, size)
        # Counter intersection clips each hypothesis n-gram to its reference count.
        overlap = sum((ref_counts & hyp_counts).values())
        candidates = sum(hyp_counts.values())
        precisions.append(overlap / candidates if candidates > 0 else 0.0)

    # A single zero precision makes the geometric mean zero.
    if not precisions or any(p <= 0 for p in precisions):
        return 0.0

    weight = 1.0 / max_n
    geometric_mean = math.exp(sum(weight * math.log(p) for p in precisions))
    brevity_penalty = min(1.0, math.exp(1 - len(ref_words) / len(hyp_words)))
    return brevity_penalty * geometric_mean


def calculate_meteor_score(reference, hypothesis):
    """Recall-weighted unigram F-mean over the sets of lower-cased tokens.

    Only the ``10PR / (9P + R)`` harmonic mean is computed on unique tokens;
    no stemming, synonym matching or fragmentation penalty is applied.

    Returns:
        The F-mean, or 0.0 when the hypothesis is empty or nothing overlaps.
    """
    ref_vocab = set(reference.lower().split())
    hyp_vocab = set(hypothesis.lower().split())
    if not hyp_vocab:
        return 0.0

    shared = len(ref_vocab & hyp_vocab)
    precision = shared / len(hyp_vocab)
    recall = shared / len(ref_vocab) if ref_vocab else 0.0
    if precision + recall == 0:
        return 0.0

    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """ROUGE-L F1 based on the longest common subsequence of whitespace tokens.

    Returns:
        The F1 of LCS precision and recall, or 0.0 when either text has no
        tokens or the texts share no token.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words or not hyp_words:
        return 0.0

    # LCS dynamic programme, one table row per reference token; only the
    # previous row is needed to fill the current one.
    previous = [0] * (len(hyp_words) + 1)
    for ref_word in ref_words:
        current = [0]
        for col, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                current.append(previous[col - 1] + 1)
            else:
                current.append(max(previous[col], current[col - 1]))
        previous = current
    lcs_length = previous[-1]

    precision = lcs_length / len(hyp_words)
    recall = lcs_length / len(ref_words)
    if precision + recall == 0:
        return 0.0

    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 when both texts are equal after stripping surrounding whitespace, else 0.0."""
    if reference.strip() == hypothesis.strip():
        return 1.0
    return 0.0


def calculate_chrf_score(reference, hypothesis, n=6):
    """Character n-gram F1 for a single n-gram order.

    Args:
        reference: Reference text.
        hypothesis: Candidate text.
        n: Character n-gram length.

    Returns:
        F1 of clipped character n-gram precision and recall; 0.0 when the
        hypothesis yields no n-gram or nothing matches.
    """
    def char_ngram_counts(text):
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    ref_counts = char_ngram_counts(reference)
    hyp_counts = char_ngram_counts(hypothesis)
    if not hyp_counts:
        return 0.0

    # Counter intersection keeps the smaller count of each shared n-gram.
    matched = sum((ref_counts & hyp_counts).values())
    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = matched / hyp_total if hyp_total > 0 else 0.0
    recall = matched / ref_total if ref_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0

    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count regex matches for a fixed set of syntactic constructs in ``code``.

    Returns:
        Dict mapping each feature name to the number of non-overlapping matches
        of its pattern.
    """
    patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    return {name: len(re.findall(pattern, code)) for name, pattern in patterns}


def calculate_syntax_similarity(ref_features, hyp_features):
    """Average per-feature agreement between two feature-count dictionaries.

    For every reference feature where at least one side has a positive count,
    the agreement is ``1 - |ref - hyp| / max(ref, hyp)``. Features that are
    zero on both sides are left out of the average.

    Returns:
        Mean agreement, or 0.0 when either dict is empty or no feature counts.
    """
    if not (ref_features and hyp_features):
        return 0.0

    scores = []
    for name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(name, 0)
        largest = max(ref_count, hyp_count)
        if largest <= 0:
            continue
        scores.append(1.0 - abs(ref_count - hyp_count) / largest)

    if not scores:
        return 0.0
    return sum(scores) / len(scores)


def calculate_code_similarity_metrics(reference, hypothesis):
    """Compute every text and syntax similarity metric for one reference/hypothesis pair.

    Returns:
        Dict with the keys ``bleu_1``, ``bleu_2``, ``bleu_4``, ``meteor``,
        ``rouge_l``, ``exact_match``, ``chrf`` and ``syntax_similarity``.
    """
    return {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
        # Structural agreement compares construct counts extracted from both texts.
        'syntax_similarity': calculate_syntax_similarity(
            extract_syntax_features(reference),
            extract_syntax_features(hypothesis),
        ),
    }


def check_compilation(code, language='csharp'):
    """Report whether ``code`` compiles when wrapped in a C# class.

    The snippet is placed inside ``namespace TempCompile { public class
    TempClass { ... } }`` and compiled as a library with ``csc``. Source and
    compiler output live in a private temporary directory that is removed in
    every case, including when the compiler is missing or times out.

    Args:
        code: C# member-level source to test.
        language: Only ``'csharp'`` is supported; any other value returns False.

    Returns:
        True when ``csc`` exits with status 0, otherwise False (compile error,
        compiler not installed, timeout after 5 seconds, or I/O failure).
    """
    if language != 'csharp':
        return False

    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"

    success = False
    try:
        with tempfile.TemporaryDirectory() as work_dir:
            source_path = os.path.join(work_dir, 'snippet.cs')
            with open(source_path, 'w') as f:
                f.write(wrapper)

            # Run inside the scratch directory so any compiler output is
            # created there and removed together with the source file.
            result = subprocess.run(
                ['csc', '/nologo', '/t:library', '/noconfig', source_path],
                capture_output=True,
                timeout=5,
                text=True,
                cwd=work_dir
            )
            success = result.returncode == 0
    except (OSError, subprocess.SubprocessError, ValueError):
        # Best effort: a missing compiler, a timeout, an encoding problem or a
        # failed cleanup must not abort the evaluation. `success` keeps the
        # compiler verdict if the failure happened after the run.
        pass

    return success


def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension; the raw source
    strings are gathered into plain lists.
    """
    collated = {
        name: torch.stack([sample[name] for sample in batch])
        for name in ('input_ids', 'attention_mask', 'labels')
    }
    collated['java_code'] = [sample['java_code'] for sample in batch]
    collated['csharp_code'] = [sample['csharp_code'] for sample in batch]
    return collated


def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2):
    """Train the LoRA adapters of ``model`` with AdamW.

    Only parameters of ``model.lora_layers`` that require gradients are
    optimized; gradients are clipped to a norm of 1.0. A batch whose loss is
    missing, NaN or infinite is skipped, and a batch that raises is reported
    and skipped, so one bad batch does not end the run.

    Args:
        model: Module whose forward returns a dict with a ``'loss'`` entry and
            which exposes ``lora_layers``.
        train_data: List of sample dicts as consumed by ``collate_fn``.
        num_epochs: Number of passes over the data.
        learning_rate: AdamW learning rate.
        batch_size: Samples per optimisation step.
    """
    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Total Training Samples: {len(train_data)}")
    # The loader keeps the final partial batch, so round up to match len(dataloader).
    print(f"Total Batches per Epoch: {math.ceil(len(train_data) / batch_size)}")
    print(f"{'='*80}\n")

    lora_params = [p for p in model.lora_layers.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(lora_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)

    # Same target as .cuda() when a GPU is present; falls back to CPU otherwise.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=use_cuda
    )

    model.train()

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)

        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()

                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']

                # Step only on a finite loss; anything else leaves the weights untouched.
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
                    optimizer.step()

                    total_loss += loss.item()
                    valid_batches += 1

                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{len(dataloader)} - Avg Loss: {avg_loss:.4f}")

                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: report the failing batch and keep training.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue

        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{len(dataloader)}")
        print(f"  Time: {epoch_time:.2f}s")

        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")

        torch.cuda.empty_cache()

    total_time = time.time() - start_time

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    print(f"Average Time per Epoch: {total_time/num_epochs:.2f}s")
    print(f"{'='*80}\n")


def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation for every test sample and aggregate the metrics.

    Each sample is prompted individually and decoded with 4-beam search. Only
    the tokens produced after the prompt are decoded and counted, so the
    hypothesis never contains prompt text and the token statistics describe
    generated tokens only. A sample that raises is reported and scored 0.0 on
    every metric.

    Args:
        model: Wrapper exposing the underlying generator as ``model.llm``.
        tokenizer: Tokenizer matching the model.
        test_data: List of dicts with ``java_code`` and ``csharp_code``.
        batch_size: Accepted for interface compatibility; samples are processed
            one at a time.
        max_new_tokens: Generation budget per sample.

    Returns:
        Dict of metric means and standard deviations, the compilation rate in
        percent, and timing / token throughput figures.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")

    model.eval()

    all_metrics = {
        'bleu_1': [], 'bleu_2': [], 'bleu_4': [],
        'meteor': [], 'rouge_l': [], 'exact_match': [],
        'chrf': [], 'syntax_similarity': []
    }

    references = []
    hypotheses = []
    compilation_success = 0

    # Same target as .cuda() when a GPU is present; falls back to CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    start_time = time.time()
    total_tokens = 0

    print("Generating translations...")

    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']

                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"

                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prompt_token_count = inputs['input_ids'].shape[1]

                outputs = model.llm.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )

                # The returned sequence starts with the prompt tokens. Slice by
                # token position instead of by len(prompt) characters, which
                # breaks when the prompt was truncated or decodes differently.
                generated_ids = outputs[0][prompt_token_count:]
                hypothesis = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                references.append(csharp_code)
                hypotheses.append(hypothesis)

                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                for key in all_metrics:
                    all_metrics[key].append(metrics[key])

                if check_compilation(hypothesis):
                    compilation_success += 1

                total_tokens += len(generated_ids)

                if (idx + 1) % 50 == 0:
                    print(f"  Processed {idx + 1}/{len(test_data)} samples")

                if (idx + 1) % 100 == 0:
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: a failed sample counts as a zero on every metric.
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                references.append("")
                hypotheses.append("")
                continue

    inference_time = time.time() - start_time

    results = {
        'bleu_1': np.mean(all_metrics['bleu_1']),
        'bleu_1_std': np.std(all_metrics['bleu_1']),
        'bleu_2': np.mean(all_metrics['bleu_2']),
        'bleu_2_std': np.std(all_metrics['bleu_2']),
        'bleu_4': np.mean(all_metrics['bleu_4']),
        'bleu_4_std': np.std(all_metrics['bleu_4']),
        'meteor': np.mean(all_metrics['meteor']),
        'meteor_std': np.std(all_metrics['meteor']),
        'rouge_l': np.mean(all_metrics['rouge_l']),
        'rouge_l_std': np.std(all_metrics['rouge_l']),
        'exact_match': np.mean(all_metrics['exact_match']),
        'chrf': np.mean(all_metrics['chrf']),
        'chrf_std': np.std(all_metrics['chrf']),
        'syntax_similarity': np.mean(all_metrics['syntax_similarity']),
        'syntax_similarity_std': np.std(all_metrics['syntax_similarity']),
        'compilation_rate': (compilation_success / len(test_data)) * 100,
        'inference_time': inference_time,
        'throughput': len(test_data) / inference_time,
        'tokens_per_sec': total_tokens / inference_time,
        'total_tokens': total_tokens,
        'avg_tokens_per_sample': total_tokens / len(test_data)
    }

    return results


def print_evaluation_results(results):
    """Print the aggregated evaluation metrics as a sectioned console report.

    Args:
        results: Mapping that provides the metric means, their ``*_std``
            companions and the timing / token counters read below.
    """
    rule = '=' * 80

    def section(title):
        # Header used by every section after the first: blank line, rule, title, rule.
        print(f"\n{rule}")
        print(title)
        print(rule)

    def mean_with_std(label, key):
        print(f"{label}: {results[key]:.4f} (±{results[key + '_std']:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")

    print(rule)
    print("N-GRAM BASED METRICS")
    print(rule)
    mean_with_std("BLEU-1", 'bleu_1')
    mean_with_std("BLEU-2", 'bleu_2')
    mean_with_std("BLEU-4", 'bleu_4')
    mean_with_std("chrF", 'chrf')

    section("SEMANTIC SIMILARITY METRICS")
    mean_with_std("METEOR", 'meteor')
    mean_with_std("ROUGE-L", 'rouge_l')

    section("CODE-SPECIFIC METRICS")
    mean_with_std("Syntax Similarity", 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")

    section("PERFORMANCE METRICS")
    elapsed = results['inference_time']
    print(f"Total Inference Time: {elapsed:.2f}s ({elapsed/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the LoRA experiment: load tokenizer and model, train, evaluate, report.

    Returns early (after printing an error) when either CSV yields no samples.
    """
    rule = '=' * 80
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    csv_columns = dict(java_col='java', csharp_col='C#')

    def banner(*titles):
        # Blank line, rule, one line per title, rule, blank line.
        print(f"\n{rule}")
        for title in titles:
            print(title)
        print(f"{rule}\n")

    banner("JAVA TO C# CODE TRANSLATION", "Using LoRA-Enhanced Mistral-7B")

    print_system_info()

    torch.cuda.empty_cache()
    gc.collect()

    cache_dir = '/hf_cache'

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fall back to the EOS token when the tokenizer ships without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    lora_config = LoRAConfig(r=8, lora_alpha=16, lora_dropout=0.1)

    print("Initializing model...")
    model = LoRAMistral(model_name, lora_config, cache_dir=cache_dir)

    print_parameter_statistics(model)

    banner("LOADING TRAINING DATA")
    train_data = load_and_process_data('/train.csv', tokenizer, **csv_columns)
    if len(train_data) == 0:
        print("ERROR: No training data loaded. Exiting.")
        return

    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2)

    banner("LOADING TEST DATA")
    test_data = load_and_process_data('/test.csv', tokenizer, **csv_columns)
    if len(test_data) == 0:
        print("ERROR: No test data loaded. Exiting.")
        return

    print_evaluation_results(evaluate_model(model, tokenizer, test_data, batch_size=1))

    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(rule)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")


# Script entry point: run the LoRA pipeline only when this code is executed directly.
if __name__ == '__main__':
    main()

# BitFit

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math


class BitFitConfig:
    """Configuration holder for BitFit fine-tuning.

    Attributes:
        bias_terms: Selector stored verbatim from the constructor argument
            (default ``'all'``).
    """

    def __init__(self, bias_terms='all'):
        # NOTE(review): the constructor does not validate the value; how it is
        # interpreted is up to the code that reads it.
        self.bias_terms = bias_terms


class CodeTranslationDataset(Dataset):
    """Thin ``Dataset`` adapter over an in-memory sequence of preprocessed samples."""

    def __init__(self, data_list):
        # Only a reference is kept; the sequence is not copied.
        self.data = data_list

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]


class BitFitMistral(nn.Module):
    """Causal language model wrapper that trains only bias terms (BitFit).

    The constructor loads the model with ``AutoModelForCausalLM`` in
    bfloat16 using ``device_map='auto'``, freezes every parameter, and then
    adds (where missing) and unfreezes a bias on each attention projection
    (q/k/v/o), each MLP projection (gate/up/down) and on ``lm_head``.
    """

    # Projection names looked up on every decoder layer.
    _ATTENTION_PROJECTIONS = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    _MLP_PROJECTIONS = ('gate_proj', 'up_proj', 'down_proj')

    def __init__(self, model_name: str, bitfit_config: "BitFitConfig", cache_dir: str = None):
        """Load ``model_name``, freeze it and enable bias-only training.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            bitfit_config: Configuration object, stored on ``self.bitfit_config``.
            cache_dir: Optional cache directory passed to ``from_pretrained``.
        """
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing model across {num_gpus} GPUs")
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name} - {props.total_memory / 1024**3:.2f} GB")
        self.primary_device = 'cuda:0'
        print(f"\nLoading base model: {model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True
        )
        self.llm = base_model
        self.bitfit_config = bitfit_config
        num_layers = len(base_model.model.layers)
        print(f"Model has {num_layers} transformer layers")
        # Freeze everything first; only biases are re-enabled below.
        for param in self.llm.parameters():
            param.requires_grad = False
        self._add_and_enable_bias_training()
        print("Model initialization complete!\n")

    @staticmethod
    def _enable_bias(linear) -> int:
        """Give ``linear`` a trainable bias and return the bias element count.

        A zero bias is created on the weight's device when the layer has
        none. It uses the weight's dtype when that is floating point and
        falls back to bfloat16 otherwise, so the bias never mismatches the
        layer it belongs to.
        """
        if linear.bias is None:
            weight = linear.weight
            dtype = weight.dtype if weight.dtype.is_floating_point else torch.bfloat16
            linear.bias = nn.Parameter(
                torch.zeros(linear.out_features, dtype=dtype, device=weight.device)
            )
        linear.bias.requires_grad = True
        return linear.bias.numel()

    def _add_and_enable_bias_training(self):
        """Add missing biases to the targeted projections and unfreeze them."""
        decoder_layers = self.llm.model.layers
        num_layers = len(decoder_layers)
        print(f"Adding and enabling bias terms for {num_layers} layers...")
        trainable_biases = 0
        for layer in decoder_layers:
            groups = (
                (getattr(layer, 'self_attn', None), self._ATTENTION_PROJECTIONS),
                (getattr(layer, 'mlp', None), self._MLP_PROJECTIONS),
            )
            for parent, projection_names in groups:
                if parent is None:
                    continue
                for projection_name in projection_names:
                    if hasattr(parent, projection_name):
                        trainable_biases += self._enable_bias(getattr(parent, projection_name))
        if hasattr(self.llm, 'lm_head'):
            trainable_biases += self._enable_bias(self.llm.lm_head)
        print(f"BitFit training enabled: {trainable_biases} bias parameters")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids passed to the wrapped model.
            attention_mask: Optional mask passed through unchanged.
            labels: Optional targets; positions equal to -100 are ignored.

        Returns:
            Dict with ``'logits'``, ``'loss'`` and ``'hidden_states'``.
            ``'loss'`` is ``None`` when no labels are given or when the
            computed loss is NaN/Inf.
        """
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        loss = None
        if labels is not None:
            # Position t predicts token t+1, hence the one-step shift.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            # Keep labels on the logits' device so the loss never mixes devices.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # A batch with no supervised position yields NaN; report it as "no loss".
            if not torch.isfinite(loss):
                loss = None
        return {'logits': outputs.logits, 'loss': loss, 'hidden_states': outputs.hidden_states}


def print_system_info():
    """Print date, PyTorch/CUDA details and RAM figures to stdout."""
    rule = '=' * 80
    gib = 1024 ** 3

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(rule)
    print(f"Date/Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"PyTorch Version: {torch.__version__}")

    cuda_ready = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident memory of the current process.
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Initial RAM Usage: {rss_bytes / gib:.2f} GB")

    # Machine-wide memory figures.
    system_memory = psutil.virtual_memory()
    print(f"Total System RAM: {system_memory.total / gib:.2f} GB")
    print(f"Available RAM: {system_memory.available / gib:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts for ``model.llm`` and per-GPU memory usage.

    Args:
        model: Object exposing ``llm`` (with ``parameters()``) and
            ``bitfit_config.bias_terms``.
    """
    rule = '=' * 80
    gib = 1024 ** 3

    # Single pass over the parameters, tallying both counters.
    total_params = 0
    trainable_params = 0
    for param in model.llm.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(rule)
    print(f"Total LLM Parameters: {total_params:,}")
    print(f"Trainable Bias Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    if trainable_params > 0:
        print(f"Parameter Efficiency: {total_params / trainable_params:.2f}x reduction")
    print(f"BitFit Mode: {model.bitfit_config.bias_terms}")
    print(f"{rule}\n")

    print(rule)
    print("GPU MEMORY ALLOCATION")
    print(rule)
    for gpu_index in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(gpu_index) / gib
        reserved = torch.cuda.memory_reserved(gpu_index) / gib
        total = torch.cuda.get_device_properties(gpu_index).total_memory / gib
        print(f"GPU {gpu_index}: {allocated:.2f} GB allocated / {reserved:.2f} GB reserved / {total:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Read a CSV of Java/C# pairs and tokenize each pair for training.

    Each usable row becomes a dict with ``input_ids``, ``attention_mask``
    and ``labels`` (1-D tensors of length ``max_length``) plus the raw
    ``java_code`` and ``csharp_code`` strings.

    Label masking (-100 is the ignore value used by the loss):
      * padding positions are masked, so padding never contributes to the loss;
      * the first ``prompt_length`` non-padding positions are masked, which
        is correct for either padding side;
      * when the pad token is also the EOS token and padding sits on the
        right, the first padding slot is kept as an EOS target so that an
        end-of-sequence signal is still supervised.

    Args:
        csv_path: Path of the CSV file to read.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape (1, length).
        max_length: Fixed sequence length (truncate and pad to this size).
        java_col: Column holding the Java source.
        csharp_col: Column holding the C# target.

    Returns:
        List of sample dicts. Rows whose Java or C# text is shorter than
        5 characters, or that raise while being processed, are skipped.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    total_rows = len(df)
    print(f"Dataset loaded: {total_rows} samples\n")
    data_list = []
    skipped = 0
    pad_id = getattr(tokenizer, 'pad_token_id', None)
    eos_id = getattr(tokenizer, 'eos_token_id', None)
    pad_is_eos = pad_id is not None and pad_id == eos_id
    print("Processing dataset...")
    # Count rows by position: the DataFrame index label is not guaranteed
    # to be an integer, so it must not be used for progress arithmetic.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()
            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue
            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"
            inputs = tokenizer(
                full_text,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            input_ids = inputs['input_ids'].squeeze(0)
            attention_mask = inputs['attention_mask'].squeeze(0)
            prompt_length = tokenizer(prompt, return_tensors='pt')['input_ids'].shape[1]

            labels = input_ids.clone()
            real_positions = attention_mask.nonzero(as_tuple=True)[0]
            labels[attention_mask == 0] = -100
            labels[real_positions[:prompt_length]] = -100
            if pad_is_eos and real_positions.numel() > 0:
                first_pad = int(real_positions[-1]) + 1
                # Only right padding leaves a slot after the last real token;
                # that slot already holds the pad (== EOS) id in input_ids.
                if first_pad < labels.numel():
                    labels[first_pad] = eos_id

            sample = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'java_code': java_code,
                'csharp_code': csharp_code
            }
        except Exception as exc:
            # Best effort: one bad row must not abort the load, but say why.
            skipped += 1
            print(f"  Warning: row {position} skipped - {exc}")
            continue
        data_list.append(sample)
        if position % 1000 == 0:
            print(f"  Processed {position}/{total_rows} samples")
    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Compute a sentence-level BLEU-style score on whitespace tokens.

    Clipped n-gram precisions for orders 1..``max_n`` are combined with
    equal weights as a geometric mean and multiplied by a brevity penalty.
    No smoothing is applied, so any zero precision gives a score of 0.0.

    Args:
        reference: Reference text.
        hypothesis: Candidate text.
        max_n: Highest n-gram order included.

    Returns:
        Score as a float; 0.0 for an empty hypothesis.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not hyp_words:
        return 0.0

    def ngram_counts(words, order):
        # Multiset of all contiguous n-grams of the given order.
        return Counter(tuple(words[start:start + order]) for start in range(len(words) - order + 1))

    precisions = []
    for order in range(1, max_n + 1):
        ref_counts = ngram_counts(ref_words, order)
        hyp_counts = ngram_counts(hyp_words, order)
        overlap = sum((ref_counts & hyp_counts).values())
        candidates = sum(hyp_counts.values())
        precisions.append(overlap / candidates if candidates > 0 else 0.0)

    # Precisions are never negative, so a minimum of zero means at least
    # one order had no match (or there were no orders at all).
    if min(precisions, default=0) == 0:
        return 0.0

    weight = 1.0 / max_n
    log_mean = sum(weight * math.log(p) for p in precisions)
    brevity_penalty = min(1.0, math.exp(1 - len(ref_words) / len(hyp_words)))
    return brevity_penalty * math.exp(log_mean)


def calculate_meteor_score(reference, hypothesis):
    """Compute a simplified METEOR-style score from unique lowercase tokens.

    Precision and recall are taken over the sets of whitespace tokens and
    combined as ``10 * P * R / (9 * P + R)``.

    Returns:
        Score as a float; 0.0 when the hypothesis has no tokens or nothing
        overlaps.
    """
    ref_vocab = set(reference.lower().split())
    hyp_vocab = set(hypothesis.lower().split())
    if not hyp_vocab:
        return 0.0

    shared = len(ref_vocab & hyp_vocab)
    precision = shared / len(hyp_vocab)
    recall = shared / len(ref_vocab) if ref_vocab else 0.0
    if precision + recall == 0:
        return 0.0
    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """Compute the ROUGE-L F1 score on whitespace tokens.

    The longest common subsequence (LCS) length is found by dynamic
    programming; precision is LCS / hypothesis length and recall is
    LCS / reference length.

    Returns:
        F1 as a float; 0.0 when either text has no tokens or the LCS is empty.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words or not hyp_words:
        return 0.0

    # Only the previous DP row is needed to build the next one.
    prev_row = [0] * (len(hyp_words) + 1)
    for ref_word in ref_words:
        curr_row = [0]
        for col, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                curr_row.append(prev_row[col - 1] + 1)
            else:
                curr_row.append(max(prev_row[col], curr_row[col - 1]))
        prev_row = curr_row
    lcs_length = prev_row[-1]

    precision = lcs_length / len(hyp_words)
    recall = lcs_length / len(ref_words)
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 when both texts are equal after stripping edge whitespace, else 0.0."""
    is_identical = reference.strip() == hypothesis.strip()
    if is_identical:
        return 1.0
    return 0.0


def calculate_chrf_score(reference, hypothesis, n=6):
    """Compute a character n-gram F1 score for a single order ``n``.

    Args:
        reference: Reference text.
        hypothesis: Candidate text.
        n: Character n-gram length.

    Returns:
        F1 of clipped n-gram overlap; 0.0 when the hypothesis yields no
        n-grams or nothing overlaps.
    """
    def char_ngram_counts(text):
        # Multiset of every length-n character window in ``text``.
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    ref_counts = char_ngram_counts(reference)
    hyp_counts = char_ngram_counts(hypothesis)
    if not hyp_counts:
        return 0.0

    overlap = sum((ref_counts & hyp_counts).values())
    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = overlap / hyp_total if hyp_total > 0 else 0.0
    recall = overlap / ref_total if ref_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count regex matches for a fixed set of syntax constructs in ``code``.

    Returns:
        Dict mapping each feature name to its number of matches.
    """
    # (feature name, pattern) pairs; the order fixes the dict's key order.
    feature_patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    return {name: len(re.findall(pattern, code)) for name, pattern in feature_patterns}


def calculate_syntax_similarity(ref_features, hyp_features):
    """Average per-feature similarity between two feature-count dicts.

    For every key of ``ref_features`` where at least one side has a
    positive count, the similarity is ``1 - |ref - hyp| / max(ref, hyp)``.
    Keys where both counts are zero are left out of the average.

    Returns:
        Mean similarity as a float; 0.0 when either dict is empty or no
        key has a positive count.
    """
    if not (ref_features and hyp_features):
        return 0.0

    scores = []
    for name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(name, 0)
        larger = max(ref_count, hyp_count)
        if larger > 0:
            scores.append(1.0 - abs(ref_count - hyp_count) / larger)

    return sum(scores) / len(scores) if scores else 0.0


def calculate_code_similarity_metrics(reference, hypothesis):
    """Compute every similarity metric for one reference/hypothesis pair.

    Returns:
        Dict with keys ``bleu_1``, ``bleu_2``, ``bleu_4``, ``meteor``,
        ``rouge_l``, ``exact_match``, ``chrf`` and ``syntax_similarity``.
    """
    metrics = {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
    }
    # Syntax similarity compares regex-based construct counts of both texts.
    metrics['syntax_similarity'] = calculate_syntax_similarity(
        extract_syntax_features(reference),
        extract_syntax_features(hypothesis),
    )
    return metrics


def check_compilation(code, language='csharp'):
    """Return True when ``code`` compiles as C# with the ``csc`` compiler.

    The snippet is wrapped in a namespace and class, written to a source
    file inside a private temporary directory, and compiled as a library
    with that directory as the working directory. The directory and
    everything in it is removed on every path, including when ``csc`` is
    missing or the compile times out.

    Args:
        code: Source text placed inside the wrapper class body.
        language: Only ``'csharp'`` is supported; any other value returns False.

    Returns:
        True if the compiler exits with status 0. False on compile errors,
        a missing compiler, a timeout (5 seconds), or an unsupported language.
    """
    if language != 'csharp':
        return False

    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"
    try:
        with tempfile.TemporaryDirectory() as workdir:
            source_path = os.path.join(workdir, 'snippet.cs')
            with open(source_path, 'w', encoding='utf-8') as source_file:
                source_file.write(wrapper)
            # cwd=workdir keeps compiler output inside the directory that
            # is deleted when this block exits.
            result = subprocess.run(
                ['csc', '/nologo', '/t:library', '/noconfig', source_path],
                capture_output=True,
                timeout=5,
                text=True,
                cwd=workdir
            )
            return result.returncode == 0
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: compiler not installed or filesystem trouble;
        # SubprocessError: timeout; ValueError: text that cannot be encoded.
        return False


def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension; the raw code
    strings are gathered into lists in the same order as ``batch``.
    """
    tensor_fields = ('input_ids', 'attention_mask', 'labels')
    text_fields = ('java_code', 'csharp_code')

    collated = {}
    for field in tensor_fields:
        collated[field] = torch.stack([sample[field] for sample in batch])
    for field in text_fields:
        collated[field] = [sample[field] for sample in batch]
    return collated


def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2):
    """Train the parameters of ``model.llm`` that require gradients.

    Uses AdamW (weight decay 0.01) with gradient-norm clipping at 1.0.
    Batches whose loss is missing or non-finite are not stepped, and a
    batch that raises is reported and skipped so one failure does not end
    the run.

    Args:
        model: Module exposing ``llm`` and returning a dict with ``'loss'``.
        train_data: List of sample dicts as consumed by ``collate_fn``.
        num_epochs: Number of passes over the data.
        learning_rate: AdamW learning rate.
        batch_size: Samples per batch.
    """
    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True
    )
    # Report the loader's own length: it includes the final partial batch,
    # which floor division of the sample count would leave out.
    num_batches = len(dataloader)
    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Total Training Samples: {len(train_data)}")
    print(f"Total Batches per Epoch: {num_batches}")
    print(f"{'='*80}\n")
    bias_params = [p for p in model.llm.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(bias_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)
        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].cuda()
                attention_mask = batch['attention_mask'].cuda()
                labels = batch['labels'].cuda()
                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']
                # Step only on a usable loss; the model reports None otherwise.
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(bias_params, max_norm=1.0)
                    optimizer.step()
                    total_loss += loss.item()
                    valid_batches += 1
                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{num_batches} - Avg Loss: {avg_loss:.4f}")
                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                # Best effort: report the failing batch and keep training.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{num_batches}")
        print(f"  Time: {epoch_time:.2f}s")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    # Guard the average so a zero-epoch call does not divide by zero.
    print(f"Average Time per Epoch: {total_time / max(num_epochs, 1):.2f}s")
    print(f"{'='*80}\n")


def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation per test sample and aggregate metrics.

    Samples are processed one at a time with 4-beam search. The hypothesis
    is decoded from the newly generated token ids only, so it does not
    depend on the decoded text reproducing the prompt character for
    character. Token statistics likewise count only generated tokens.
    A sample that raises is reported and scored 0.0 on every metric.

    Args:
        model: Object exposing ``llm.generate``.
        tokenizer: Tokenizer used for encoding prompts and decoding output.
        test_data: List of dicts with ``'java_code'`` and ``'csharp_code'``.
        batch_size: Accepted for interface compatibility; not used.
        max_new_tokens: Generation budget per sample.

    Returns:
        Dict of metric means and standard deviations plus
        ``compilation_rate`` (percent), ``inference_time`` (seconds),
        ``throughput``, ``tokens_per_sec``, ``total_tokens`` and
        ``avg_tokens_per_sample``.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")
    model.eval()
    all_metrics = {
        'bleu_1': [], 'bleu_2': [], 'bleu_4': [],
        'meteor': [], 'rouge_l': [], 'exact_match': [],
        'chrf': [], 'syntax_similarity': []
    }
    references = []
    hypotheses = []
    compilation_success = 0
    start_time = time.time()
    total_tokens = 0
    print("Generating translations...")
    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']
                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                inputs = {k: v.cuda() for k, v in inputs.items()}
                prompt_token_count = inputs['input_ids'].shape[1]
                outputs = model.llm.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
                # The returned sequence starts with the prompt ids; keep
                # only what was generated after them.
                generated_ids = outputs[0][prompt_token_count:]
                hypothesis = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                compiled = check_compilation(hypothesis)
            except Exception as e:
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                references.append("")
                hypotheses.append("")
                continue
            # Record results only after every step succeeded, so a failure
            # can never leave a sample recorded twice.
            references.append(csharp_code)
            hypotheses.append(hypothesis)
            for key in all_metrics:
                all_metrics[key].append(metrics[key])
            if compiled:
                compilation_success += 1
            total_tokens += len(generated_ids)
            if (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(test_data)} samples")
            if (idx + 1) % 100 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    num_samples = len(test_data)

    def ratio(numerator, denominator):
        # Avoid ZeroDivisionError on an empty test set or a zero-length run.
        return numerator / denominator if denominator else 0.0

    results = {
        'bleu_1': np.mean(all_metrics['bleu_1']),
        'bleu_1_std': np.std(all_metrics['bleu_1']),
        'bleu_2': np.mean(all_metrics['bleu_2']),
        'bleu_2_std': np.std(all_metrics['bleu_2']),
        'bleu_4': np.mean(all_metrics['bleu_4']),
        'bleu_4_std': np.std(all_metrics['bleu_4']),
        'meteor': np.mean(all_metrics['meteor']),
        'meteor_std': np.std(all_metrics['meteor']),
        'rouge_l': np.mean(all_metrics['rouge_l']),
        'rouge_l_std': np.std(all_metrics['rouge_l']),
        'exact_match': np.mean(all_metrics['exact_match']),
        'chrf': np.mean(all_metrics['chrf']),
        'chrf_std': np.std(all_metrics['chrf']),
        'syntax_similarity': np.mean(all_metrics['syntax_similarity']),
        'syntax_similarity_std': np.std(all_metrics['syntax_similarity']),
        'compilation_rate': ratio(compilation_success, num_samples) * 100,
        'inference_time': inference_time,
        'throughput': ratio(num_samples, inference_time),
        'tokens_per_sec': ratio(total_tokens, inference_time),
        'total_tokens': total_tokens,
        'avg_tokens_per_sample': ratio(total_tokens, num_samples)
    }
    return results


def print_evaluation_results(results):
    """Print the evaluation summary produced by ``evaluate_model`` to stdout.

    Args:
        results: Dict holding the metric means, their ``*_std`` companions
            and the timing/throughput figures read below.
    """
    rule = '=' * 80

    def section(title):
        # Blank line, rule, title, rule.
        print(f"\n{rule}")
        print(title)
        print(rule)

    def mean_with_std(label, key):
        print(f"{label}: {results[key]:.4f} (±{results[key + '_std']:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")

    print(rule)
    print("N-GRAM BASED METRICS")
    print(rule)
    for label, key in (('BLEU-1', 'bleu_1'), ('BLEU-2', 'bleu_2'), ('BLEU-4', 'bleu_4'), ('chrF', 'chrf')):
        mean_with_std(label, key)

    section("SEMANTIC SIMILARITY METRICS")
    mean_with_std('METEOR', 'meteor')
    mean_with_std('ROUGE-L', 'rouge_l')

    section("CODE-SPECIFIC METRICS")
    mean_with_std('Syntax Similarity', 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")

    section("PERFORMANCE METRICS")
    elapsed = results['inference_time']
    print(f"Total Inference Time: {elapsed:.2f}s ({elapsed/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the BitFit pipeline: load, train on the train CSV, evaluate on the test CSV.

    Paths (``/hf_cache``, ``/train.csv``, ``/test.csv``) and the model name
    are fixed inside this function. The function returns early, after
    printing an error, when either CSV yields no usable samples.
    """
    rule = '=' * 80
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"

    def stage_banner(title):
        # Blank line, rule, title, rule, blank line.
        print(f"\n{rule}")
        print(title)
        print(f"{rule}\n")

    print(f"\n{rule}")
    print("JAVA TO C# CODE TRANSLATION")
    print("Using BitFit-Enhanced Mistral-7B")
    print(f"{rule}\n")
    print_system_info()
    torch.cuda.empty_cache()
    gc.collect()
    cache_dir = '/hf_cache'

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    bitfit_config = BitFitConfig(bias_terms='all')
    print("Initializing model...")
    model = BitFitMistral(model_name, bitfit_config, cache_dir=cache_dir)
    print_parameter_statistics(model)

    stage_banner("LOADING TRAINING DATA")
    train_csv = '/train.csv'
    train_data = load_and_process_data(train_csv, tokenizer, java_col='java', csharp_col='C#')
    if len(train_data) == 0:
        print("ERROR: No training data loaded. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2)

    stage_banner("LOADING TEST DATA")
    test_csv = '/test.csv'
    test_data = load_and_process_data(test_csv, tokenizer, java_col='java', csharp_col='C#')
    if len(test_data) == 0:
        print("ERROR: No test data loaded. Exiting.")
        return
    results = evaluate_model(model, tokenizer, test_data, batch_size=1)
    print_evaluation_results(results)

    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(rule)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")


# Script entry point: run the BitFit pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()

# Prefix

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math


class PrefixTuningConfig:
    """Settings holder for the trainable prefix.

    Attributes:
        prefix_length: Number of prefix positions.
        num_layers: Layer count, stored as given.
        hidden_size: Width of each prefix vector.
        num_attention_heads: Head count, stored as given.
    """

    def __init__(self, prefix_length=10, num_layers=32, hidden_size=4096, num_attention_heads=32):
        # All values are stored verbatim; nothing is validated here.
        self.prefix_length, self.num_layers = prefix_length, num_layers
        self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads


class CodeTranslationDataset(Dataset):
    """Dataset wrapper around an in-memory list of pre-processed samples."""

    def __init__(self, data_list):
        # The list is referenced, not copied.
        self.data = data_list

    def __len__(self):
        """Return the number of stored samples."""
        items = self.data
        return len(items)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``."""
        items = self.data
        return items[idx]


class PrefixEncoder(nn.Module):
    """Trainable prefix vectors passed through a two-layer tanh MLP.

    Args:
        config: Object providing ``prefix_length``, ``num_layers`` and
            ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_length = config.prefix_length
        self.num_layers = config.num_layers
        self.hidden_size = config.hidden_size

        width = self.hidden_size
        # One learnable vector per prefix position, drawn from a standard normal.
        self.prefix_tokens = nn.Parameter(torch.randn(self.prefix_length, width))
        self.prefix_mlp = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        )

    def forward(self, batch_size):
        """Return prefix embeddings of shape (batch_size, prefix_length, hidden_size) in bfloat16."""
        # Add a batch axis and broadcast the same prefix to every batch row.
        batched = self.prefix_tokens[None, :, :].expand(batch_size, -1, -1)
        projected = self.prefix_mlp(batched)
        return projected.to(torch.bfloat16)

class PrefixTuningMistral(nn.Module):
    """Frozen causal LM with trainable prefix embeddings prepended to the input.

    The constructor loads the model with ``AutoModelForCausalLM`` in
    bfloat16 using ``device_map='auto'``, freezes all of its parameters and
    creates a ``PrefixEncoder`` whose output is concatenated in front of
    the token embeddings on every forward pass.
    """

    def __init__(self, model_name: str, prefix_config: "PrefixTuningConfig", cache_dir: str = None):
        """Load ``model_name``, freeze it and build the prefix encoder.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            prefix_config: Configuration stored on ``self.prefix_config``
                and used to build the prefix encoder.
            cache_dir: Optional cache directory passed to ``from_pretrained``.
        """
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"\nInitializing model across {num_gpus} GPUs")
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name} - {props.total_memory / 1024**3:.2f} GB")
        self.primary_device = 'cuda:0'
        print(f"\nLoading base model: {model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True
        )
        self.llm = base_model
        self.prefix_config = prefix_config
        num_layers = len(base_model.model.layers)
        print(f"Model has {num_layers} transformer layers")
        for param in self.llm.parameters():
            param.requires_grad = False
        # The encoder is kept in bfloat16 on the primary device to match
        # the dtype the base model is loaded in.
        self.prefix_encoder = PrefixEncoder(prefix_config).to(device=self.primary_device, dtype=torch.bfloat16)
        print(f"Prefix encoder initialized with {self.prefix_config.prefix_length} prefix tokens")
        print("Model initialization complete!\n")

    def get_prompt_embedding(self, input_ids):
        """Look up the wrapped model's input embeddings for ``input_ids``."""
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        return inputs_embeds

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model on prefix + token embeddings and optionally compute the loss.

        Args:
            input_ids: Token ids of shape (batch, seq).
            attention_mask: Optional mask for ``input_ids``; all ones
                (integer) when omitted. It is extended with ones for the
                prefix positions.
            labels: Optional targets aligned with ``input_ids``; positions
                equal to -100 are ignored.

        Returns:
            Dict with ``'logits'`` (including the prefix positions),
            ``'loss'`` and ``'hidden_states'``. ``'loss'`` is ``None`` when
            no labels are given or the computed loss is NaN/Inf.
        """
        batch_size = input_ids.size(0)
        prefix_length = self.prefix_config.prefix_length

        inputs_embeds = self.get_prompt_embedding(input_ids)
        # Align the prefix with the token embeddings before concatenating.
        prefix_embeds = self.prefix_encoder(batch_size).to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)

        if attention_mask is None:
            # Integer mask so both halves of the concatenation share a dtype.
            attention_mask = torch.ones(
                batch_size, input_ids.size(1), device=input_ids.device, dtype=torch.long
            )
        prefix_attention_mask = torch.ones(
            batch_size, prefix_length, device=attention_mask.device, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        loss = None
        if labels is not None:
            # Drop the prefix positions so logits line up with the labels,
            # then shift by one: position t predicts token t+1.
            logits = outputs.logits[:, prefix_length:, :]
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # A batch with no supervised position yields NaN; report it as "no loss".
            if not torch.isfinite(loss):
                loss = None

        return {'logits': outputs.logits, 'loss': loss, 'hidden_states': outputs.hidden_states}


def print_system_info():
    """Print date, PyTorch/CUDA details and RAM figures to stdout."""
    rule = '=' * 80
    bytes_per_gib = 1024 ** 3

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(rule)
    print(f"Date/Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"PyTorch Version: {torch.__version__}")

    has_cuda = torch.cuda.is_available()
    print(f"CUDA Available: {has_cuda}")
    if has_cuda:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident memory of this process.
    resident_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Initial RAM Usage: {resident_bytes / bytes_per_gib:.2f} GB")

    # Machine-wide memory figures.
    machine_memory = psutil.virtual_memory()
    print(f"Total System RAM: {machine_memory.total / bytes_per_gib:.2f} GB")
    print(f"Available RAM: {machine_memory.available / bytes_per_gib:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts for the base LLM and the prefix encoder, then GPU memory.

    Args:
        model: Object exposing ``llm`` and ``prefix_encoder`` (both providing
            ``parameters()``) and ``prefix_config.prefix_length``.
    """
    rule = '=' * 80
    bytes_per_unit = 1024**3

    llm_count = 0
    for weight in model.llm.parameters():
        llm_count += weight.numel()
    prefix_count = 0
    for weight in model.prefix_encoder.parameters():
        prefix_count += weight.numel()

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(f"{rule}")
    print(f"Total LLM Parameters: {llm_count:,}")
    print(f"Trainable Prefix Parameters: {prefix_count:,}")
    # The percentage is taken over LLM + prefix parameters combined.
    print(f"Trainable Percentage: {(prefix_count / (llm_count + prefix_count) * 100):.4f}%")
    if prefix_count > 0:
        print(f"Parameter Efficiency: {llm_count / prefix_count:.2f}x reduction")
    print(f"Prefix Length: {model.prefix_config.prefix_length}")
    print(f"{rule}\n")

    print(f"{rule}")
    print("GPU MEMORY ALLOCATION")
    print(f"{rule}")
    for gpu_index in range(torch.cuda.device_count()):
        in_use = torch.cuda.memory_allocated(gpu_index) / bytes_per_unit
        cached = torch.cuda.memory_reserved(gpu_index) / bytes_per_unit
        capacity = torch.cuda.get_device_properties(gpu_index).total_memory / bytes_per_unit
        print(f"GPU {gpu_index}: {in_use:.2f} GB allocated / {cached:.2f} GB reserved / {capacity:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Load a Java/C# parallel corpus from CSV and tokenize it for causal-LM training.

    Every usable row becomes ``"[INST] ... [/INST] <C# code>"``, tokenized and
    padded/truncated to ``max_length``. The labels are a copy of the input ids
    in which every position that must not contribute to the loss is ``-100``:
    the prompt tokens and all padding positions (attention mask == 0).

    Args:
        csv_path: Path of the CSV file to read.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape ``(1, seq_len)``.
        max_length: Sequence length every sample is padded/truncated to.
        java_col: Name of the column holding the Java source.
        csharp_col: Name of the column holding the C# target.

    Returns:
        A list of dicts with keys ``input_ids``, ``attention_mask``, ``labels``
        (1-D tensors) and ``java_code``, ``csharp_code`` (stripped strings).
        Rows whose Java or C# text is shorter than 5 characters, or whose
        processing raises, are skipped and counted.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Dataset loaded: {len(df)} samples\n")
    data_list = []
    skipped = 0
    first_error = None
    print("Processing dataset...")
    # Count rows by position so progress reporting does not depend on the
    # DataFrame index being a 0-based integer range.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()
            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue
            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"
            inputs = tokenizer(
                full_text,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            attention = inputs['attention_mask']
            labels = inputs['input_ids'].clone()
            prompt_length = tokenizer(prompt, return_tensors='pt')['input_ids'].shape[1]
            # The prompt starts at the first attended position: index 0 with
            # right padding, after the pad run when the tokenizer pads left.
            first_real = int(attention[0].argmax())
            labels[0, first_real:first_real + prompt_length] = -100
            # Padding must never be a training target. Mask by attention mask
            # rather than by token id, because the pad id may equal the EOS id.
            labels[attention == 0] = -100
            data_list.append({
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': attention.squeeze(0),
                'labels': labels.squeeze(0),
                'java_code': java_code,
                'csharp_code': csharp_code
            })
        except Exception as e:
            # Best-effort loading: a bad row is skipped, but the first failure
            # is remembered so systematic problems are not hidden.
            skipped += 1
            if first_error is None:
                first_error = e
            continue
        if position % 1000 == 0:
            print(f"  Processed {position}/{len(df)} samples")
    if first_error is not None:
        print(f"  Warning: first row-processing error - {first_error!r}")
    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Compute an unsmoothed sentence-level BLEU score on whitespace tokens.

    N-gram precisions for orders ``1..max_n`` use clipped counts (Counter
    intersection), are combined with uniform weights as a geometric mean and
    multiplied by the brevity penalty ``min(1, exp(1 - ref_len / hyp_len))``.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the hypothesis is empty or any
        n-gram order has zero precision.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not hyp_words:
        return 0.0

    def count_ngrams(words, order):
        return Counter(tuple(words[pos:pos + order]) for pos in range(len(words) - order + 1))

    precisions = []
    for order in range(1, max_n + 1):
        hyp_counts = count_ngrams(hyp_words, order)
        overlap = count_ngrams(ref_words, order) & hyp_counts
        hyp_total = sum(hyp_counts.values())
        precisions.append(sum(overlap.values()) / hyp_total if hyp_total > 0 else 0.0)

    # Without smoothing a single zero precision zeroes the geometric mean.
    if not precisions or min(precisions) <= 0:
        return 0.0

    weight = 1.0 / max_n
    log_average = sum(weight * math.log(p) for p in precisions)
    length_ratio = len(ref_words) / len(hyp_words)
    penalty = min(1.0, math.exp(1 - length_ratio))
    return penalty * math.exp(log_average)


def calculate_meteor_score(reference, hypothesis):
    """Return a simplified, METEOR-style recall-weighted F-score.

    Both strings are lower-cased and reduced to their *sets* of whitespace
    tokens; the score is ``10 * P * R / (9 * P + R)`` over the shared tokens.
    No stemming, synonym matching or fragmentation penalty is applied.
    """
    ref_vocab = set(reference.lower().split())
    hyp_vocab = set(hypothesis.lower().split())
    if not hyp_vocab:
        return 0.0

    shared = len(ref_vocab & hyp_vocab)
    precision = shared / len(hyp_vocab)
    recall = shared / len(ref_vocab) if ref_vocab else 0.0
    if precision + recall == 0:
        return 0.0
    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """Return the ROUGE-L F1 score over whitespace tokens.

    The longest common subsequence (LCS) length is computed by dynamic
    programming; precision is ``LCS / len(hypothesis)``, recall is
    ``LCS / len(reference)`` and the result is their harmonic mean.
    Returns ``0.0`` when either side has no tokens.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words or not hyp_words:
        return 0.0

    # Only the previous DP row is needed to fill in the current one.
    prev_row = [0] * (len(hyp_words) + 1)
    for ref_word in ref_words:
        row = [0]
        for col, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                row.append(prev_row[col - 1] + 1)
            else:
                row.append(max(prev_row[col], row[col - 1]))
        prev_row = row
    lcs = prev_row[-1]

    precision = lcs / len(hyp_words)
    recall = lcs / len(ref_words)
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 if both strings are equal after stripping outer whitespace, else 0.0."""
    if reference.strip() == hypothesis.strip():
        return 1.0
    return 0.0


def calculate_chrf_score(reference, hypothesis, n=6):
    """Return a character n-gram F1 score for a single n-gram order.

    Character n-grams of length ``n`` are counted on both strings, matched
    with clipped counts, and combined as the harmonic mean of precision and
    recall. Returns ``0.0`` when the hypothesis yields no n-grams.
    """
    def ngram_counts(text):
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    ref_counts = ngram_counts(reference)
    hyp_counts = ngram_counts(hypothesis)
    if not hyp_counts:
        return 0.0

    overlap = sum((ref_counts & hyp_counts).values())
    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = overlap / hyp_total if hyp_total > 0 else 0.0
    recall = overlap / ref_total if ref_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count coarse syntactic constructs in ``code`` using regular expressions.

    Returns:
        A dict mapping a feature name to the number of regex matches found.
        The counts are purely textual heuristics, not the result of parsing.
    """
    feature_patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    counts = {}
    for feature, pattern in feature_patterns:
        counts[feature] = len(re.findall(pattern, code))
    return counts


def calculate_syntax_similarity(ref_features, hyp_features):
    """Average per-feature similarity between two feature-count dicts.

    For each key of ``ref_features`` the similarity is
    ``1 - |ref - hyp| / max(ref, hyp)``; features that are zero on both sides
    are ignored. Returns ``0.0`` if either dict is empty or no feature counts.
    """
    if not (ref_features and hyp_features):
        return 0.0

    per_feature = []
    for name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(name, 0)
        larger = max(ref_count, hyp_count)
        if larger > 0:
            per_feature.append(1.0 - abs(ref_count - hyp_count) / larger)

    if not per_feature:
        return 0.0
    return sum(per_feature) / len(per_feature)


def calculate_code_similarity_metrics(reference, hypothesis):
    """Run every per-sample similarity metric on one reference/hypothesis pair.

    Returns:
        A dict with the keys ``bleu_1``, ``bleu_2``, ``bleu_4``, ``meteor``,
        ``rouge_l``, ``exact_match``, ``chrf`` and ``syntax_similarity``.
    """
    return {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
        # Syntax similarity compares regex-derived construct counts.
        'syntax_similarity': calculate_syntax_similarity(
            extract_syntax_features(reference),
            extract_syntax_features(hypothesis),
        ),
    }


def check_compilation(code, language='csharp'):
    """Return True if ``code`` compiles as a C# class member with ``csc``.

    The snippet is wrapped in a namespace/class skeleton, written to a
    temporary ``.cs`` file and compiled as a library. This is a best-effort
    check: any failure to run the compiler (not installed, timeout, I/O or
    encoding error) is reported as ``False`` rather than raised.

    Args:
        code: C# source placed inside the body of a generated class.
        language: Only ``'csharp'`` is supported; other values return False.

    Returns:
        ``True`` when the compiler exits with status 0, otherwise ``False``.
    """
    if language != 'csharp':
        return False

    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"
    source_path = None
    dll_path = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.cs', delete=False, encoding='utf-8') as f:
            source_path = f.name
            f.write(wrapper)
        # Name the output explicitly so the artifact location is known and can
        # be removed afterwards instead of relying on the compiler's default.
        dll_path = os.path.splitext(source_path)[0] + '.dll'
        result = subprocess.run(
            ['csc', '/nologo', '/t:library', '/noconfig', f'/out:{dll_path}', source_path],
            capture_output=True,
            timeout=5,
            text=True
        )
        return result.returncode == 0
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: compiler missing or file I/O failed; SubprocessError: timeout;
        # ValueError: text could not be encoded/decoded.
        return False
    finally:
        # Always clean up, including when the compiler could not be started.
        for path in (source_path, dll_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                pass


def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension; the raw source
    strings are gathered into plain Python lists in sample order.
    """
    collated = {}
    for tensor_key in ('input_ids', 'attention_mask', 'labels'):
        collated[tensor_key] = torch.stack([sample[tensor_key] for sample in batch])
    for text_key in ('java_code', 'csharp_code'):
        collated[text_key] = [sample[text_key] for sample in batch]
    return collated


def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2):
    """Train the prefix encoder of ``model`` on pre-tokenized samples.

    Only ``model.prefix_encoder`` parameters are handed to the AdamW optimizer.
    Batches are moved to the default CUDA device with ``.cuda()``, so a GPU is
    required. A batch whose loss is ``None``, NaN or infinite is skipped, and
    an exception inside a batch is logged and the batch abandoned.

    Args:
        model: Module exposing ``prefix_encoder`` and returning a dict with a
            ``'loss'`` entry when called with ids, attention mask and labels.
        train_data: Sequence of sample dicts accepted by ``collate_fn``.
        num_epochs: Number of passes over ``train_data``.
        learning_rate: AdamW learning rate.
        batch_size: Samples per optimisation step.
    """
    # The DataLoader keeps the final partial batch, so round up.
    batches_per_epoch = math.ceil(len(train_data) / batch_size)
    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Total Training Samples: {len(train_data)}")
    print(f"Total Batches per Epoch: {batches_per_epoch}")
    print(f"{'='*80}\n")
    prefix_params = list(model.prefix_encoder.parameters())
    optimizer = torch.optim.AdamW(prefix_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)
    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True
    )
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)
        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].cuda()
                attention_mask = batch['attention_mask'].cuda()
                labels = batch['labels'].cuda()
                # Clearing gradients first also discards anything left over
                # from a previous batch that failed part-way through.
                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(prefix_params, max_norm=1.0)
                    optimizer.step()
                    total_loss += loss.item()
                    valid_batches += 1
                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{len(dataloader)} - Avg Loss: {avg_loss:.4f}")
                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                # Deliberately best-effort: one bad batch should not abort the run.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{len(dataloader)}")
        print(f"  Time: {epoch_time:.2f}s")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    # Avoid dividing by zero when called with num_epochs == 0.
    time_per_epoch = total_time / num_epochs if num_epochs > 0 else 0.0
    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    print(f"Average Time per Epoch: {time_per_epoch:.2f}s")
    print(f"{'='*80}\n")

def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation for every test sample and aggregate metrics.

    For each sample the prompt is embedded, the learned prefix embeddings are
    prepended, and ``model.llm.generate`` runs beam search (4 beams) on the
    resulting ``inputs_embeds``. Inputs are moved with ``.cuda()``, so a GPU
    is required. A sample that raises is logged and scored 0.0 on every metric.

    Args:
        model: Prefix-tuned model exposing ``llm``, ``prefix_encoder``,
            ``get_prompt_embedding`` and ``prefix_config.prefix_length``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        test_data: Sequence of dicts with ``java_code`` and ``csharp_code``.
        batch_size: Accepted for interface compatibility; samples are
            generated one at a time and this value is not used.
        max_new_tokens: Generation budget per sample.

    Returns:
        A dict of mean (and, for most metrics, standard deviation) scores plus
        the compilation rate in percent and timing/throughput figures.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")
    model.eval()
    all_metrics = {
        'bleu_1': [], 'bleu_2': [], 'bleu_4': [],
        'meteor': [], 'rouge_l': [], 'exact_match': [],
        'chrf': [], 'syntax_similarity': []
    }
    references = []
    hypotheses = []
    compilation_success = 0
    start_time = time.time()
    total_tokens = 0
    print("Generating translations...")
    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']
                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                input_ids = inputs['input_ids'].cuda()
                attention_mask = inputs['attention_mask'].cuda()

                batch_size_gen = input_ids.size(0)
                prefix_embeds = model.prefix_encoder(batch_size_gen).to(input_ids.device)
                inputs_embeds = model.get_prompt_embedding(input_ids)
                inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)

                # The attention mask must cover the prefix positions as well.
                prefix_attention_mask = torch.ones(batch_size_gen, model.prefix_config.prefix_length, device=input_ids.device, dtype=attention_mask.dtype)
                full_attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)

                outputs = model.llm.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=full_attention_mask,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )

                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                # When generate() is driven by inputs_embeds alone, the returned
                # ids normally contain only the new tokens (see the transformers
                # generation docs), so blindly dropping len(prompt) characters
                # would cut into the translation. Strip the prompt only if the
                # decoded text really starts with it.
                if generated_text.startswith(prompt):
                    hypothesis = generated_text[len(prompt):].strip()
                else:
                    hypothesis = generated_text.strip()

                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                compiled = check_compilation(hypothesis)
                generated_count = len(outputs[0])
            except Exception as e:
                # Best-effort evaluation: a failed sample counts as all zeros.
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                references.append("")
                hypotheses.append("")
                continue

            # Bookkeeping happens only after every step succeeded, so a sample
            # is recorded exactly once.
            references.append(csharp_code)
            hypotheses.append(hypothesis)
            for key in all_metrics:
                all_metrics[key].append(metrics[key])
            if compiled:
                compilation_success += 1
            total_tokens += generated_count
            if (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(test_data)} samples")
            if (idx + 1) % 100 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    num_samples = len(test_data)

    def _mean(values):
        # np.mean of an empty list is NaN (with a warning); report 0.0 instead.
        return np.mean(values) if values else 0.0

    def _std(values):
        return np.std(values) if values else 0.0

    def _ratio(numerator, denominator):
        # Guards the empty-test-set and zero-elapsed-time cases.
        return numerator / denominator if denominator else 0.0

    results = {
        'bleu_1': _mean(all_metrics['bleu_1']),
        'bleu_1_std': _std(all_metrics['bleu_1']),
        'bleu_2': _mean(all_metrics['bleu_2']),
        'bleu_2_std': _std(all_metrics['bleu_2']),
        'bleu_4': _mean(all_metrics['bleu_4']),
        'bleu_4_std': _std(all_metrics['bleu_4']),
        'meteor': _mean(all_metrics['meteor']),
        'meteor_std': _std(all_metrics['meteor']),
        'rouge_l': _mean(all_metrics['rouge_l']),
        'rouge_l_std': _std(all_metrics['rouge_l']),
        'exact_match': _mean(all_metrics['exact_match']),
        'chrf': _mean(all_metrics['chrf']),
        'chrf_std': _std(all_metrics['chrf']),
        'syntax_similarity': _mean(all_metrics['syntax_similarity']),
        'syntax_similarity_std': _std(all_metrics['syntax_similarity']),
        'compilation_rate': _ratio(compilation_success, num_samples) * 100,
        'inference_time': inference_time,
        'throughput': _ratio(num_samples, inference_time),
        'tokens_per_sec': _ratio(total_tokens, inference_time),
        'total_tokens': total_tokens,
        'avg_tokens_per_sample': _ratio(total_tokens, num_samples)
    }
    return results


def print_evaluation_results(results):
    """Print the aggregated evaluation metrics as a sectioned text report.

    Args:
        results: Mapping holding the metric means, their ``*_std`` companions
            and the timing/throughput entries read below.
    """
    rule = '=' * 80

    def section(title):
        # Sections after the first are preceded by a blank line.
        print(f"\n{rule}")
        print(title)
        print(f"{rule}")

    def mean_and_spread(label, key):
        print(f"{label}: {results[key]:.4f} (±{results[key + '_std']:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")

    print(f"{rule}")
    print("N-GRAM BASED METRICS")
    print(f"{rule}")
    for label, key in (("BLEU-1", 'bleu_1'), ("BLEU-2", 'bleu_2'), ("BLEU-4", 'bleu_4'), ("chrF", 'chrf')):
        mean_and_spread(label, key)

    section("SEMANTIC SIMILARITY METRICS")
    mean_and_spread("METEOR", 'meteor')
    mean_and_spread("ROUGE-L", 'rouge_l')

    section("CODE-SPECIFIC METRICS")
    mean_and_spread("Syntax Similarity", 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")

    section("PERFORMANCE METRICS")
    elapsed = results['inference_time']
    print(f"Total Inference Time: {elapsed:.2f}s ({elapsed/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the prefix-tuning experiment end to end.

    Steps: print system information, load the tokenizer, build the
    prefix-tuned model, train on ``/train.csv`` and evaluate on ``/test.csv``.
    Returns early with an error message if either dataset yields no samples.
    """
    rule = '=' * 80
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    cache_dir = '/hf_cache'

    def banner(*titles):
        print(f"\n{rule}")
        for title in titles:
            print(title)
        print(f"{rule}\n")

    banner("JAVA TO C# CODE TRANSLATION", "Using Prefix Tuning-Enhanced Mistral-7B")
    print_system_info()
    torch.cuda.empty_cache()
    gc.collect()

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    prefix_config = PrefixTuningConfig(
        prefix_length=10,
        num_layers=32,
        hidden_size=4096,
        num_attention_heads=32
    )
    print("Initializing model...")
    model = PrefixTuningMistral(model_name, prefix_config, cache_dir=cache_dir)
    print_parameter_statistics(model)

    banner("LOADING TRAINING DATA")
    train_data = load_and_process_data('/train.csv', tokenizer, java_col='java', csharp_col='C#')
    if not train_data:
        print("ERROR: No training data loaded. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=6)

    banner("LOADING TEST DATA")
    test_data = load_and_process_data('/test.csv', tokenizer, java_col='java', csharp_col='C#')
    if not test_data:
        print("ERROR: No test data loaded. Exiting.")
        return
    results = evaluate_model(model, tokenizer, test_data, batch_size=4)
    print_evaluation_results(results)

    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(f"{rule}")
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")


# Run the experiment only when this code is executed directly as a script.
if __name__ == '__main__':
    main()

# Adapter

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import subprocess
import tempfile
from datetime import datetime
import math


class AdapterConfig:
    """Hyper-parameters for the bottleneck adapters.

    Args:
        adapter_size: Width of the adapter bottleneck.
        num_layers: Number of transformer layers the configuration describes.
        hidden_size: Hidden width of the model the adapters attach to.
        adapter_dropout: Dropout probability used inside each adapter.
            Defaults to 0.1, the value that was previously hard-coded.
    """

    def __init__(self, adapter_size=64, num_layers=32, hidden_size=4096, adapter_dropout=0.1):
        self.adapter_size = adapter_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.adapter_dropout = adapter_dropout


class CodeTranslationDataset(Dataset):
    """Map-style dataset serving already-processed samples from memory."""

    def __init__(self, data_list):
        # The sequence is stored by reference; samples are not copied.
        self.data = data_list

    def __len__(self):
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample


class AdapterLayer(nn.Module):
    """Residual bottleneck adapter: down-project, ReLU, up-project, add input.

    The up-projection is initialised to zero so that a freshly created adapter
    is an exact identity mapping. The wrapped frozen model therefore starts
    from its pretrained behaviour instead of being perturbed by random adapter
    weights at every layer.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        adapter_size: Width of the bottleneck.
        dropout: Dropout probability applied after the activation and again
            after the up-projection.
    """

    def __init__(self, hidden_size, adapter_size, dropout=0.1):
        super().__init__()
        self.down_project = nn.Linear(hidden_size, adapter_size)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(adapter_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Zero output projection => forward(x) == x at initialisation. The
        # up-projection still receives gradients, so training proceeds.
        nn.init.zeros_(self.up_project.weight)
        nn.init.zeros_(self.up_project.bias)

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus the adapter's bottleneck update."""
        residual = hidden_states
        hidden_states = self.down_project(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.up_project(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states + residual


class AdapterTuningMistral(nn.Module):
    """Frozen causal LM with one trainable adapter hooked onto each decoder layer.

    The base model is loaded in bfloat16 with ``device_map='auto'`` and all of
    its parameters are frozen. One ``AdapterLayer`` per entry of
    ``llm.model.layers`` is created, moved to that layer's device and applied
    to the layer's output through a forward hook.
    """

    def __init__(self, model_name: str, adapter_config: AdapterConfig, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()

        gpu_count = torch.cuda.device_count()
        print(f"\nInitializing model across {gpu_count} GPUs")
        for gpu_index in range(gpu_count):
            gpu = torch.cuda.get_device_properties(gpu_index)
            print(f"  GPU {gpu_index}: {gpu.name} - {gpu.total_memory / 1024**3:.2f} GB")
        self.primary_device = 'cuda:0'

        print(f"\nLoading base model: {model_name}")
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True
        )
        self.adapter_config = adapter_config
        layer_count = len(self.llm.model.layers)
        print(f"Model has {layer_count} transformer layers")

        # Only the adapters are meant to train; freeze the whole base model.
        for frozen in self.llm.parameters():
            frozen.requires_grad = False

        self.adapters = nn.ModuleList([
            AdapterLayer(
                adapter_config.hidden_size,
                adapter_config.adapter_size,
                adapter_config.adapter_dropout
            )
            for _ in range(layer_count)
        ])

        # Put each adapter on the device of the layer it follows, falling back
        # to the primary device when the layer is not listed in the device map.
        placement = getattr(self.llm, 'hf_device_map', {})
        for layer_index, adapter in enumerate(self.adapters):
            target_device = placement.get(f'model.layers.{layer_index}', self.primary_device)
            adapter.to(target_device).to(torch.bfloat16)

        self._register_forward_hooks()
        print(f"Adapters initialized with size {self.adapter_config.adapter_size}")
        print("Model initialization complete!\n")

    def _register_forward_hooks(self):
        """Attach a forward hook to every decoder layer that applies its adapter."""

        def make_hook(adapter_idx):
            # A factory is used so each hook captures its own layer index.
            def apply_adapter(module, input, output):
                is_tuple = isinstance(output, tuple)
                hidden_states = output[0] if is_tuple else output
                adapter = self.adapters[adapter_idx]
                adapter_device = next(adapter.parameters()).device
                adapted = adapter(hidden_states.to(adapter_device))
                if is_tuple:
                    # Keep any additional layer outputs untouched.
                    return (adapted,) + output[1:]
                return adapted
            return apply_adapter

        for layer_index, decoder_layer in enumerate(self.llm.model.layers):
            decoder_layer.register_forward_hook(make_hook(layer_index))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped LM (adapters fire via hooks) and return its output object."""
        return self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )


def print_system_info():
    """Print a banner describing the runtime environment.

    Reports the current timestamp, the PyTorch version, CUDA availability
    (plus CUDA version and GPU count when CUDA is available), the resident
    memory of this process and the machine's total/available RAM.
    """
    rule = '=' * 80
    bytes_per_unit = 1024**3
    cuda_ready = torch.cuda.is_available()

    print(f"\n{rule}")
    print("SYSTEM INFORMATION")
    print(f"{rule}")
    print(f"Date/Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"PyTorch Version: {torch.__version__}")
    print(f"CUDA Available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")

    # Resident set size of the current Python process.
    resident_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Initial RAM Usage: {resident_bytes / bytes_per_unit:.2f} GB")

    host_memory = psutil.virtual_memory()
    print(f"Total System RAM: {host_memory.total / bytes_per_unit:.2f} GB")
    print(f"Available RAM: {host_memory.available / bytes_per_unit:.2f} GB")
    print(f"{rule}\n")


def print_parameter_statistics(model):
    """Print parameter counts for the base LLM and the adapters, then GPU memory.

    Args:
        model: Object exposing ``llm`` and ``adapters`` (both providing
            ``parameters()``; ``adapters`` also supports ``len``) and
            ``adapter_config.adapter_size``.
    """
    rule = '=' * 80
    bytes_per_unit = 1024**3

    llm_count = 0
    for weight in model.llm.parameters():
        llm_count += weight.numel()
    adapter_count = 0
    for weight in model.adapters.parameters():
        adapter_count += weight.numel()

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(f"{rule}")
    print(f"Total LLM Parameters: {llm_count:,}")
    print(f"Trainable Adapter Parameters: {adapter_count:,}")
    # The percentage is taken over LLM + adapter parameters combined.
    print(f"Trainable Percentage: {(adapter_count / (llm_count + adapter_count) * 100):.4f}%")
    if adapter_count > 0:
        print(f"Parameter Efficiency: {llm_count / adapter_count:.2f}x reduction")
    print(f"Adapter Size: {model.adapter_config.adapter_size}")
    print(f"Number of Adapter Layers: {len(model.adapters)}")
    print(f"{rule}\n")

    print(f"{rule}")
    print("GPU MEMORY ALLOCATION")
    print(f"{rule}")
    for gpu_index in range(torch.cuda.device_count()):
        in_use = torch.cuda.memory_allocated(gpu_index) / bytes_per_unit
        cached = torch.cuda.memory_reserved(gpu_index) / bytes_per_unit
        capacity = torch.cuda.get_device_properties(gpu_index).total_memory / bytes_per_unit
        print(f"GPU {gpu_index}: {in_use:.2f} GB allocated / {cached:.2f} GB reserved / {capacity:.2f} GB total")
    print(f"{rule}\n")


def load_and_process_data(csv_path, tokenizer, max_length=384, java_col='java', csharp_col='C#'):
    """Load a Java/C# parallel corpus from CSV and tokenize it for causal-LM training.

    Every usable row becomes ``"[INST] ... [/INST] <C# code>"``, tokenized and
    padded/truncated to ``max_length``. The labels are a copy of the input ids
    in which every position that must not contribute to the loss is ``-100``:
    the prompt tokens and all padding positions (attention mask == 0).

    Args:
        csv_path: Path of the CSV file to read.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape ``(1, seq_len)``.
        max_length: Sequence length every sample is padded/truncated to.
        java_col: Name of the column holding the Java source.
        csharp_col: Name of the column holding the C# target.

    Returns:
        A list of dicts with keys ``input_ids``, ``attention_mask``, ``labels``
        (1-D tensors) and ``java_code``, ``csharp_code`` (stripped strings).
        Rows whose Java or C# text is shorter than 5 characters, or whose
        processing raises, are skipped and counted.
    """
    print(f"Loading dataset from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Dataset loaded: {len(df)} samples\n")
    data_list = []
    skipped = 0
    first_error = None
    print("Processing dataset...")
    # Count rows by position so progress reporting does not depend on the
    # DataFrame index being a 0-based integer range.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        try:
            java_code = str(row[java_col]).strip()
            csharp_code = str(row[csharp_col]).strip()
            if len(java_code) < 5 or len(csharp_code) < 5:
                skipped += 1
                continue
            prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
            full_text = f"{prompt} {csharp_code}"
            inputs = tokenizer(
                full_text,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            attention = inputs['attention_mask']
            labels = inputs['input_ids'].clone()
            prompt_length = tokenizer(prompt, return_tensors='pt')['input_ids'].shape[1]
            # The prompt starts at the first attended position: index 0 with
            # right padding, after the pad run when the tokenizer pads left.
            first_real = int(attention[0].argmax())
            labels[0, first_real:first_real + prompt_length] = -100
            # Padding must never be a training target. Mask by attention mask
            # rather than by token id, because the pad id may equal the EOS id.
            labels[attention == 0] = -100
            data_list.append({
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': attention.squeeze(0),
                'labels': labels.squeeze(0),
                'java_code': java_code,
                'csharp_code': csharp_code
            })
        except Exception as e:
            # Best-effort loading: a bad row is skipped, but the first failure
            # is remembered so systematic problems are not hidden.
            skipped += 1
            if first_error is None:
                first_error = e
            continue
        if position % 1000 == 0:
            print(f"  Processed {position}/{len(df)} samples")
    if first_error is not None:
        print(f"  Warning: first row-processing error - {first_error!r}")
    print(f"Processing complete: {len(data_list)} valid samples, {skipped} skipped\n")
    return data_list


def calculate_bleu_score(reference, hypothesis, max_n=4):
    """Compute an unsmoothed sentence-level BLEU score on whitespace tokens.

    N-gram precisions for orders ``1..max_n`` use clipped counts (Counter
    intersection), are combined with uniform weights as a geometric mean and
    multiplied by the brevity penalty ``min(1, exp(1 - ref_len / hyp_len))``.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the hypothesis is empty or any
        n-gram order has zero precision.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not hyp_words:
        return 0.0

    def count_ngrams(words, order):
        return Counter(tuple(words[pos:pos + order]) for pos in range(len(words) - order + 1))

    precisions = []
    for order in range(1, max_n + 1):
        hyp_counts = count_ngrams(hyp_words, order)
        overlap = count_ngrams(ref_words, order) & hyp_counts
        hyp_total = sum(hyp_counts.values())
        precisions.append(sum(overlap.values()) / hyp_total if hyp_total > 0 else 0.0)

    # Without smoothing a single zero precision zeroes the geometric mean.
    if not precisions or min(precisions) <= 0:
        return 0.0

    weight = 1.0 / max_n
    log_average = sum(weight * math.log(p) for p in precisions)
    length_ratio = len(ref_words) / len(hyp_words)
    penalty = min(1.0, math.exp(1 - length_ratio))
    return penalty * math.exp(log_average)


def calculate_meteor_score(reference, hypothesis):
    """Return a simplified, METEOR-style recall-weighted F-score.

    Both strings are lower-cased and reduced to their *sets* of whitespace
    tokens; the score is ``10 * P * R / (9 * P + R)`` over the shared tokens.
    No stemming, synonym matching or fragmentation penalty is applied.
    """
    ref_vocab = set(reference.lower().split())
    hyp_vocab = set(hypothesis.lower().split())
    if not hyp_vocab:
        return 0.0

    shared = len(ref_vocab & hyp_vocab)
    precision = shared / len(hyp_vocab)
    recall = shared / len(ref_vocab) if ref_vocab else 0.0
    if precision + recall == 0:
        return 0.0
    return (10 * precision * recall) / (9 * precision + recall)


def calculate_rouge_l(reference, hypothesis):
    """Return the ROUGE-L F1 score over whitespace tokens.

    The longest common subsequence (LCS) length is computed by dynamic
    programming; precision is ``LCS / len(hypothesis)``, recall is
    ``LCS / len(reference)`` and the result is their harmonic mean.
    Returns ``0.0`` when either side has no tokens.
    """
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words or not hyp_words:
        return 0.0

    # Only the previous DP row is needed to fill in the current one.
    prev_row = [0] * (len(hyp_words) + 1)
    for ref_word in ref_words:
        row = [0]
        for col, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                row.append(prev_row[col - 1] + 1)
            else:
                row.append(max(prev_row[col], row[col - 1]))
        prev_row = row
    lcs = prev_row[-1]

    precision = lcs / len(hyp_words)
    recall = lcs / len(ref_words)
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def calculate_exact_match(reference, hypothesis):
    """Return 1.0 if both strings are equal after stripping outer whitespace, else 0.0."""
    if reference.strip() == hypothesis.strip():
        return 1.0
    return 0.0


def calculate_chrf_score(reference, hypothesis, n=6):
    """Return a character n-gram F1 score for a single n-gram order.

    Character n-grams of length ``n`` are counted on both strings, matched
    with clipped counts, and combined as the harmonic mean of precision and
    recall. Returns ``0.0`` when the hypothesis yields no n-grams.
    """
    def ngram_counts(text):
        return Counter(text[start:start + n] for start in range(len(text) - n + 1))

    ref_counts = ngram_counts(reference)
    hyp_counts = ngram_counts(hypothesis)
    if not hyp_counts:
        return 0.0

    overlap = sum((ref_counts & hyp_counts).values())
    hyp_total = sum(hyp_counts.values())
    ref_total = sum(ref_counts.values())
    precision = overlap / hyp_total if hyp_total > 0 else 0.0
    recall = overlap / ref_total if ref_total > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return (2 * precision * recall) / (precision + recall)


def extract_syntax_features(code):
    """Count coarse syntactic constructs in `code` with regular expressions.

    The counts are regex hits, not parser output: each entry is the number
    of non-overlapping matches of the corresponding pattern.

    Args:
        code: Source text to scan.

    Returns:
        A dict mapping feature name ('methods', 'classes', 'if_statements',
        'for_loops', 'while_loops', 'return_statements', 'variables',
        'try_catch') to its match count.
    """
    patterns = (
        ('methods', r'\b(public|private|protected|static)\s+\w+\s+\w+\s*\('),
        ('classes', r'\bclass\s+\w+'),
        ('if_statements', r'\bif\s*\('),
        ('for_loops', r'\bfor\s*\('),
        ('while_loops', r'\bwhile\s*\('),
        ('return_statements', r'\breturn\b'),
        ('variables', r'\b\w+\s*=\s*'),
        ('try_catch', r'\btry\s*\{'),
    )
    return {name: len(re.findall(pattern, code)) for name, pattern in patterns}


def calculate_syntax_similarity(ref_features, hyp_features):
    """Return the mean per-feature similarity between two feature-count dicts.

    For every key of `ref_features`, similarity is
    ``1 - |ref - hyp| / max(ref, hyp)``; a key missing from `hyp_features`
    counts as 0. Features whose larger count is not positive are left out
    of the average.

    Args:
        ref_features: Feature counts for the reference.
        hyp_features: Feature counts for the hypothesis.

    Returns:
        The average similarity over the contributing features, or 0.0 when
        either dict is empty or no feature contributes.
    """
    if not (ref_features and hyp_features):
        return 0.0

    scores = []
    for name, ref_count in ref_features.items():
        hyp_count = hyp_features.get(name, 0)
        larger = max(ref_count, hyp_count)
        if larger > 0:
            scores.append(1.0 - abs(ref_count - hyp_count) / larger)

    if not scores:
        return 0.0
    return sum(scores) / len(scores)


def calculate_code_similarity_metrics(reference, hypothesis):
    """Score `hypothesis` against `reference` with every similarity metric.

    Args:
        reference: Ground-truth code.
        hypothesis: Generated code.

    Returns:
        A dict with the keys 'bleu_1', 'bleu_2', 'bleu_4', 'meteor',
        'rouge_l', 'exact_match', 'chrf' and 'syntax_similarity', each
        holding the value returned by the corresponding helper.
    """
    return {
        'bleu_1': calculate_bleu_score(reference, hypothesis, max_n=1),
        'bleu_2': calculate_bleu_score(reference, hypothesis, max_n=2),
        'bleu_4': calculate_bleu_score(reference, hypothesis, max_n=4),
        'meteor': calculate_meteor_score(reference, hypothesis),
        'rouge_l': calculate_rouge_l(reference, hypothesis),
        'exact_match': calculate_exact_match(reference, hypothesis),
        'chrf': calculate_chrf_score(reference, hypothesis),
        # Syntax similarity compares regex-derived construct counts of both sides.
        'syntax_similarity': calculate_syntax_similarity(
            extract_syntax_features(reference),
            extract_syntax_features(hypothesis),
        ),
    }


def check_compilation(code, language='csharp'):
    """Return True if `code` compiles with `csc` as the body of a C# class.

    The snippet is wrapped in a throwaway namespace and class, written into
    a private temporary directory and compiled as a library with a
    5 second timeout. The check is best-effort: it never raises for an
    unavailable or slow compiler.

    Args:
        code: C# member declarations to place inside the wrapper class.
        language: Only 'csharp' is supported; any other value yields False.

    Returns:
        True when the compiler exits with status 0. False when compilation
        fails, the compiler cannot be started or times out, the source
        cannot be written, or `language` is unsupported.
    """
    if language != 'csharp':
        return False

    wrapper = f"using System;\nusing System.Collections.Generic;\nusing System.Linq;\n\nnamespace TempCompile {{\n    public class TempClass {{\n        {code}\n    }}\n}}"
    try:
        # Source and compiler output share one private directory, so both are
        # removed on every exit path, including a missing compiler or a timeout.
        with tempfile.TemporaryDirectory() as work_dir:
            source_path = os.path.join(work_dir, 'snippet.cs')
            output_path = os.path.join(work_dir, 'snippet.dll')
            with open(source_path, 'w', encoding='utf-8') as source_file:
                source_file.write(wrapper)
            result = subprocess.run(
                ['csc', '/nologo', '/t:library', '/noconfig',
                 f'/out:{output_path}', source_path],
                capture_output=True,
                timeout=5,
                text=True
            )
            return result.returncode == 0
    except (OSError, subprocess.SubprocessError, UnicodeError):
        # OSError: csc not installed or file-system failure.
        # SubprocessError: includes TimeoutExpired.
        # UnicodeError: snippet or compiler output could not be encoded/decoded.
        return False


def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Tensor fields ('input_ids', 'attention_mask', 'labels') are stacked
    along a new leading dimension with `torch.stack`; the source strings
    ('java_code', 'csharp_code') are collected into plain lists.

    Args:
        batch: Sequence of sample dicts carrying the five keys above.

    Returns:
        A dict with the same five keys holding the batched values.
    """
    tensor_fields = ('input_ids', 'attention_mask', 'labels')
    text_fields = ('java_code', 'csharp_code')

    collated = {
        name: torch.stack([sample[name] for sample in batch])
        for name in tensor_fields
    }
    for name in text_fields:
        collated[name] = [sample[name] for sample in batch]
    return collated


def train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=2):
    """Train the adapter parameters of `model` on `train_data`.

    Only the parameters under `model.adapters` are given to the AdamW
    optimizer and gradient-clipped (max norm 1.0). Batches are moved to the
    GPU with `.cuda()`, so a CUDA device is required. A batch whose loss is
    None, NaN or infinite is skipped without an optimizer step, and a batch
    that raises is reported and skipped instead of aborting the run.

    Args:
        model: Module exposing `adapters`; calling it with `input_ids`,
            `attention_mask` and `labels` must return an object with `.loss`.
        train_data: Sized collection accepted by `CodeTranslationDataset`.
        num_epochs: Number of passes over the data.
        learning_rate: AdamW learning rate.
        batch_size: Samples per batch.

    Returns:
        None. Progress and per-epoch summaries are printed to stdout.
    """
    # The DataLoader below keeps the final partial batch (no drop_last), so
    # round up to agree with the len(dataloader) figures printed later.
    # Assumes the dataset has one item per train_data entry — TODO confirm.
    batches_per_epoch = (len(train_data) + batch_size - 1) // batch_size
    print(f"\n{'='*80}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*80}")
    print(f"Number of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"Batch Size: {batch_size}")
    print(f"Total Training Samples: {len(train_data)}")
    print(f"Total Batches per Epoch: {batches_per_epoch}")
    print(f"{'='*80}\n")
    adapter_params = list(model.adapters.parameters())
    optimizer = torch.optim.AdamW(adapter_params, lr=learning_rate, weight_decay=0.01, eps=1e-8)
    dataset = CodeTranslationDataset(train_data)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True
    )
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        valid_batches = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 80)
        for batch_idx, batch in enumerate(dataloader):
            try:
                input_ids = batch['input_ids'].cuda()
                attention_mask = batch['attention_mask'].cuda()
                labels = batch['labels'].cuda()
                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                # Skip the update entirely for a missing or non-finite loss so
                # it cannot corrupt the adapter weights or the running average.
                if loss is not None and not torch.isnan(loss) and not torch.isinf(loss):
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(adapter_params, max_norm=1.0)
                    optimizer.step()
                    total_loss += loss.item()
                    valid_batches += 1
                    if (batch_idx + 1) % 100 == 0:
                        avg_loss = total_loss / valid_batches
                        print(f"  Batch {batch_idx + 1}/{len(dataloader)} - Avg Loss: {avg_loss:.4f}")
                if (batch_idx + 1) % 500 == 0:
                    torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort training: report the failing batch and move on.
                print(f"  Warning: Batch {batch_idx + 1} failed - {str(e)}")
                continue
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0.0
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Valid Batches: {valid_batches}/{len(dataloader)}")
        print(f"  Time: {epoch_time:.2f}s")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    # Guard the per-epoch average so num_epochs=0 does not divide by zero.
    avg_epoch_time = total_time / num_epochs if num_epochs > 0 else 0.0
    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Total Time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
    print(f"Average Time per Epoch: {avg_epoch_time:.2f}s")
    print(f"{'='*80}\n")


def evaluate_model(model, tokenizer, test_data, batch_size=1, max_new_tokens=256):
    """Generate a C# translation for each test item and aggregate the metrics.

    Each item is prompted individually, decoded with beam search (4 beams)
    through `model.llm.generate`, scored with
    `calculate_code_similarity_metrics` and checked with
    `check_compilation`. Inputs are moved to the GPU with `.cuda()`, so a
    CUDA device is required. A sample that raises is reported and scored
    0.0 on every metric.

    Args:
        model: Object exposing `eval()` and an `llm` attribute with a
            `generate` method.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        test_data: Sized iterable of dicts with 'java_code' and 'csharp_code'.
        batch_size: Accepted for interface compatibility; generation runs one
            sample at a time.
        max_new_tokens: Generation budget per sample.

    Returns:
        A dict with the mean of every metric, a '<metric>_std' entry for all
        metrics except 'exact_match', plus 'compilation_rate' (percent),
        'inference_time' (seconds), 'throughput' (samples/sec),
        'tokens_per_sec', 'total_tokens' and 'avg_tokens_per_sample'. Token
        figures count newly generated tokens only. With no test items the
        mean/std entries are NaN (NumPy's result for an empty list) and the
        ratio entries are 0.0.
    """
    print(f"\n{'='*80}")
    print("EVALUATION PHASE")
    print(f"{'='*80}")
    print(f"Test Samples: {len(test_data)}")
    print(f"Max Generation Length: {max_new_tokens}")
    print(f"{'='*80}\n")
    model.eval()
    all_metrics = {
        'bleu_1': [], 'bleu_2': [], 'bleu_4': [],
        'meteor': [], 'rouge_l': [], 'exact_match': [],
        'chrf': [], 'syntax_similarity': []
    }
    compilation_success = 0
    start_time = time.time()
    total_tokens = 0
    print("Generating translations...")
    with torch.no_grad():
        for idx, item in enumerate(test_data):
            try:
                java_code = item['java_code']
                csharp_code = item['csharp_code']
                prompt = f"[INST] Translate this Java code to C#:\n{java_code}\n[/INST]"
                inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=384)
                input_ids = inputs['input_ids'].cuda()
                attention_mask = inputs['attention_mask'].cuda()
                outputs = model.llm.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
                # Drop the prompt by token count, not by len(prompt) characters:
                # the decoded prompt differs from the raw string once it has
                # been truncated to max_length or normalized by the tokenizer.
                new_token_ids = outputs[0][input_ids.shape[1]:]
                hypothesis = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
                metrics = calculate_code_similarity_metrics(csharp_code, hypothesis)
                sample_scores = {key: metrics[key] for key in all_metrics}
                compiled = check_compilation(hypothesis)
            except Exception as e:
                print(f"  Warning: Sample {idx + 1} failed - {str(e)}")
                for key in all_metrics:
                    all_metrics[key].append(0.0)
                continue
            # Record results only after every step succeeded, so each sample
            # contributes exactly one entry per metric.
            for key, value in sample_scores.items():
                all_metrics[key].append(value)
            if compiled:
                compilation_success += 1
            total_tokens += len(new_token_ids)
            if (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(test_data)} samples")
            if (idx + 1) % 100 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    num_samples = len(test_data)

    def ratio(numerator, denominator):
        # Avoid ZeroDivisionError for an empty test set or a zero-length timer.
        return numerator / denominator if denominator > 0 else 0.0

    results = {
        'bleu_1': np.mean(all_metrics['bleu_1']),
        'bleu_1_std': np.std(all_metrics['bleu_1']),
        'bleu_2': np.mean(all_metrics['bleu_2']),
        'bleu_2_std': np.std(all_metrics['bleu_2']),
        'bleu_4': np.mean(all_metrics['bleu_4']),
        'bleu_4_std': np.std(all_metrics['bleu_4']),
        'meteor': np.mean(all_metrics['meteor']),
        'meteor_std': np.std(all_metrics['meteor']),
        'rouge_l': np.mean(all_metrics['rouge_l']),
        'rouge_l_std': np.std(all_metrics['rouge_l']),
        'exact_match': np.mean(all_metrics['exact_match']),
        'chrf': np.mean(all_metrics['chrf']),
        'chrf_std': np.std(all_metrics['chrf']),
        'syntax_similarity': np.mean(all_metrics['syntax_similarity']),
        'syntax_similarity_std': np.std(all_metrics['syntax_similarity']),
        'compilation_rate': ratio(compilation_success, num_samples) * 100,
        'inference_time': inference_time,
        'throughput': ratio(num_samples, inference_time),
        'tokens_per_sec': ratio(total_tokens, inference_time),
        'total_tokens': total_tokens,
        'avg_tokens_per_sample': ratio(total_tokens, num_samples)
    }
    return results


def print_evaluation_results(results):
    """Print the aggregated evaluation report to stdout.

    Args:
        results: Dict holding the metric means, their '<metric>_std'
            companions, and the timing/token entries read below (the keys
            produced by `evaluate_model`).
    """
    rule = '=' * 80

    def section(title, leading_blank=True):
        # Section header: optional blank line, rule, title, rule.
        print(f"\n{rule}" if leading_blank else rule)
        print(title)
        print(rule)

    def with_spread(label, key):
        # "<label>: mean (±std)" with four decimals each.
        mean = results[key]
        spread = results[key + '_std']
        print(f"{label}: {mean:.4f} (±{spread:.4f})")

    print(f"\n{rule}")
    print("EVALUATION RESULTS")
    print(f"{rule}\n")

    section("N-GRAM BASED METRICS", leading_blank=False)
    for label, key in (("BLEU-1", 'bleu_1'), ("BLEU-2", 'bleu_2'),
                       ("BLEU-4", 'bleu_4'), ("chrF", 'chrf')):
        with_spread(label, key)

    section("SEMANTIC SIMILARITY METRICS")
    with_spread("METEOR", 'meteor')
    with_spread("ROUGE-L", 'rouge_l')

    section("CODE-SPECIFIC METRICS")
    with_spread("Syntax Similarity", 'syntax_similarity')
    print(f"Exact Match: {results['exact_match']:.4f}")
    print(f"Compilation Success Rate: {results['compilation_rate']:.2f}%")

    section("PERFORMANCE METRICS")
    elapsed = results['inference_time']
    print(f"Total Inference Time: {elapsed:.2f}s ({elapsed/60:.2f} min)")
    print(f"Throughput: {results['throughput']:.2f} samples/sec")
    print(f"Token Generation Rate: {results['tokens_per_sec']:.2f} tokens/sec")
    print(f"Total Tokens Generated: {results['total_tokens']}")
    print(f"Average Tokens per Sample: {results['avg_tokens_per_sample']:.2f}")
    print(f"{rule}\n")


def main():
    """Run the Java-to-C# pipeline: load tokenizer and model, train, evaluate.

    Reads training data from '/train.csv' and test data from '/test.csv'
    (columns 'java' and 'C#') and caches Hugging Face downloads under
    '/hf_cache'. Returns early, after printing an error, when either
    dataset comes back empty.
    """
    rule = '=' * 80
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"

    def announce(*titles):
        # Banner: blank line, rule, one line per title, rule, blank line.
        print(f"\n{rule}")
        for title in titles:
            print(title)
        print(f"{rule}\n")

    announce("JAVA TO C# CODE TRANSLATION", "Using Adapter Tuning-Enhanced Mistral-7B")
    print_system_info()
    torch.cuda.empty_cache()
    gc.collect()

    cache_dir = '/hf_cache'
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded\n")

    adapter_config = AdapterConfig(adapter_size=64, num_layers=32, hidden_size=4096)
    print("Initializing model...")
    model = AdapterTuningMistral(model_name, adapter_config, cache_dir=cache_dir)
    print_parameter_statistics(model)

    announce("LOADING TRAINING DATA")
    train_data = load_and_process_data('/train.csv', tokenizer, java_col='java', csharp_col='C#')
    if len(train_data) == 0:
        print("ERROR: No training data loaded. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=2e-5, batch_size=6)

    announce("LOADING TEST DATA")
    test_data = load_and_process_data('/test.csv', tokenizer, java_col='java', csharp_col='C#')
    if len(test_data) == 0:
        print("ERROR: No test data loaded. Exiting.")
        return
    results = evaluate_model(model, tokenizer, test_data, batch_size=4)
    print_evaluation_results(results)

    finished_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"\n{rule}")
    print("EXECUTION COMPLETE")
    print(rule)
    print(f"Timestamp: {finished_at}")
    print(f"{rule}\n")
    
# Run the full pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()