# CodeQwen LLMs with LoRA

In [None]:
import math
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model
from datasets import Dataset as HFDataset
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, confusion_matrix,
    balanced_accuracy_score, precision_score, recall_score, jaccard_score,
    matthews_corrcoef, cohen_kappa_score, roc_auc_score, average_precision_score,
    log_loss, brier_score_loss
)
import pandas as pd
import numpy as np
import os
import logging
import time
from typing import List, Dict, Tuple
import psutil
import gc

# Module-wide logging at INFO level, using this module's name for the logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


class LoRATrainer:
    """LoRA fine-tuning and evaluation of CodeQwen1.5-7B-Chat on code-clone pairs.

    Each CSV row holds two functions and a label. ``preprocess`` turns a row into a
    causal-LM training sequence, ``train`` fits LoRA adapters with the HF ``Trainer``
    and ``evaluate_comprehensive`` computes classification and retrieval metrics.
    """

    def __init__(self, train_file, test_file, func1_col, func2_col, label_col, cache_dir=None):
        """Load the tokenizer, the fp16 base model and both CSV datasets.

        Model files are read from the local cache only (``local_files_only=True``) and
        the model is spread over the visible devices with ``device_map="auto"``.

        Args:
            train_file: Path of the training CSV.
            test_file: Path of the test CSV.
            func1_col: Column with the first function's source code.
            func2_col: Column with the second function's source code.
            label_col: Column with the clone label.
            cache_dir: Optional Hugging Face cache directory.
        """
        self.train_file = train_file
        self.test_file = test_file
        self.func1_col = func1_col
        self.func2_col = func2_col
        self.label_col = label_col
        self.cache_dir = cache_dir

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        logger.info(f"Number of available GPUs: {num_gpus}")
        for i in range(num_gpus):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")

        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True
        )
        # Fixed-length padding needs a pad token; reuse EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Loading base model with automatic device mapping...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True,
            low_cpu_mem_usage=True
        )

        self.print_model_distribution()

        logger.info("Loading datasets...")
        self.train_data = HFDataset.from_csv(train_file)
        self.test_data = HFDataset.from_csv(test_file)
        logger.info(f"Train samples: {len(self.train_data)}")
        logger.info(f"Test samples: {len(self.test_data)}")

        # NOTE(review): initialised here but not written to anywhere in this class.
        self.training_stats = {
            'epoch_times': [],
            'epoch_losses': [],
            'memory_usage': []
        }

    def print_model_distribution(self):
        """Log the module -> device map (when the model has one) and per-GPU memory."""
        logger.info("\n" + "="*80)
        logger.info("MODEL DEVICE DISTRIBUTION")
        logger.info("="*80)

        if hasattr(self.base_model, 'hf_device_map'):
            for module_name, module_device in self.base_model.hf_device_map.items():
                logger.info(f"{module_name}: {module_device}")

        logger.info("\n" + "="*80)
        logger.info("GPU MEMORY USAGE")
        logger.info("="*80)
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            logger.info(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        logger.info("="*80 + "\n")

    def preprocess(self, batch):
        """Tokenize ONE example (used with ``Dataset.map(batched=False)``).

        The prompt wraps both functions in ``<func1>``/``<func2>`` tags and ends with
        ``Label: <y>``. It is truncated/padded to 512 tokens.

        ``labels`` copies ``input_ids``, but padding positions (attention mask 0) are set
        to -100, the ignore index of PyTorch's cross-entropy. Otherwise, because the pad
        token may be EOS (see ``__init__``), short examples would spend most of their
        loss predicting padding.

        NOTE(review): truncation keeps the start of the text, so for long function
        pairs the trailing ``Label: <y>`` is cut off and never trained on.
        """
        c1 = batch[self.func1_col]
        c2 = batch[self.func2_col]
        y = batch[self.label_col]
        text = f"<func1>\n{c1}\n</func1>\n<func2>\n{c2}\n</func2>\nLabel: {y}"
        out = self.tokenizer(text, truncation=True, padding="max_length", max_length=512)
        out["labels"] = [
            token if keep else -100
            for token, keep in zip(out["input_ids"], out["attention_mask"])
        ]
        return out

    def count_parameters(self, model):
        """Return ``(total, trainable)`` parameter counts of ``model``."""
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return total_params, trainable_params

    def calculate_metrics(self, y_true, y_pred, y_prob=None):
        """Compute binary classification metrics with class 1 as the positive class.

        Args:
            y_true: Gold labels in {0, 1}.
            y_pred: Predicted labels in {0, 1}.
            y_prob: Optional predicted probabilities of class 1. When given, ROC AUC,
                PR AUC, log loss and Brier score are added if they can be computed.

        Returns:
            Dict of metric name -> value. ``'cm'`` is the 2x2 confusion matrix.
        """
        # Fixing the label set keeps the matrix 2x2 (four cells to unpack) even when
        # one class is missing from both y_true and y_pred.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        acc = accuracy_score(y_true, y_pred)
        bal_acc = balanced_accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)

        _, _, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)

        jacc = jaccard_score(y_true, y_pred, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)
        kappa = cohen_kappa_score(y_true, y_pred)

        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

        metrics = {
            'cm': cm,
            'acc': acc,
            'bal_acc': bal_acc,
            'prec': prec,
            'rec': rec,
            'f1': f1,
            'jacc': jacc,
            'mcc': mcc,
            'kappa': kappa,
            'specificity': specificity,
            'npv': npv,
            'fpr': fpr,
            'fnr': fnr
        }

        if y_prob is not None:
            try:
                roc_auc = roc_auc_score(y_true, y_prob)
                pr_auc = average_precision_score(y_true, y_prob)
                ll = log_loss(y_true, y_prob)
                brier = brier_score_loss(y_true, y_prob)
            except ValueError as err:
                # e.g. only one class present in y_true: skip these, keep the rest.
                logger.warning(f"Skipping probability-based metrics: {err}")
            else:
                metrics.update({
                    'roc_auc': roc_auc,
                    'pr_auc': pr_auc,
                    'log_loss': ll,
                    'brier': brier
                })

        return metrics

    def calculate_recall_at_k(self, embeddings, labels, k_values=None):
        """Compute leave-one-out retrieval metrics over cosine similarity of ``embeddings``.

        For each query row, every OTHER row is ranked by cosine similarity.
        ``recall@k`` is the fraction of queries with at least one same-label item in
        their top-k. ``mrr`` is the mean reciprocal rank of the first same-label item,
        or 0 when no other item shares the query's label.

        Args:
            embeddings: 2-D tensor with one row per sample.
            labels: Sequence of labels aligned with the rows of ``embeddings``.
            k_values: Cut-offs to report; defaults to ``[1, 3, 5, 10]``. A cut-off
                larger than the number of candidates means "all candidates".

        Returns:
            Dict with ``'recall@k'`` for each k and ``'mrr'``. Both are 0.0 when there
            are fewer than two samples.
        """
        if k_values is None:
            k_values = [1, 3, 5, 10]
        labels = list(labels)
        n = len(labels)
        hits = {k: 0 for k in k_values}
        reciprocal_ranks = []

        if n > 1:
            # Normalise then matmul: the same cosine similarity as broadcasting
            # F.cosine_similarity, without building an (n, n, dim) intermediate.
            # float32 avoids fp16 precision issues and missing CPU half kernels.
            normed = torch.nn.functional.normalize(embeddings.float(), dim=1)
            similarity = normed @ normed.t()

            for i in range(n):
                order = torch.argsort(similarity[i], descending=True)
                # Exclude the query itself from its own candidate list.
                candidates = order[order != i].tolist()
                query_label = labels[i]
                first_hit = next(
                    (rank for rank, j in enumerate(candidates) if labels[j] == query_label),
                    None,
                )
                if first_hit is None:
                    reciprocal_ranks.append(0.0)
                    continue
                reciprocal_ranks.append(1.0 / (first_hit + 1))
                for k in k_values:
                    if first_hit < k:
                        hits[k] += 1

        results = {f'recall@{k}': (hits[k] / n if n > 0 else 0.0) for k in k_values}
        results['mrr'] = float(np.mean(reciprocal_ranks)) if reciprocal_ranks else 0.0
        return results

    def train(self):
        """Apply LoRA (r=8 on q_proj/v_proj) to the base model and train it with ``Trainer``.

        Returns:
            ``(peft_model, trainer, total_training_time_seconds)``.
        """
        train_dataset = self.train_data.map(self.preprocess, batched=False)
        test_dataset = self.test_data.map(self.preprocess, batched=False)

        logger.info("\n" + "="*80)
        logger.info("APPLYING LORA CONFIGURATION")
        logger.info("="*80)

        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.0,
            target_modules=["q_proj", "v_proj"],
            task_type=TaskType.CAUSAL_LM
        )

        model = get_peft_model(self.base_model, peft_config)

        total_params, trainable_params = self.count_parameters(model)

        logger.info("\n" + "="*80)
        logger.info("PARAMETER STATISTICS")
        logger.info("="*80)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"Trainable percentage: {100 * trainable_params / total_params:.4f}%")
        logger.info("="*80 + "\n")

        self.print_model_distribution()

        training_args = TrainingArguments(
            output_dir="codeqwen1.5-clone-lora",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            gradient_accumulation_steps=4,
            num_train_epochs=5,
            learning_rate=2e-4,
            lr_scheduler_type="cosine",
            warmup_ratio=0.03,
            logging_steps=20,
            save_total_limit=2,
            eval_strategy="epoch",
            save_strategy="epoch",
            fp16=True,
            bf16=False,
            dataloader_num_workers=0,
            dataloader_pin_memory=True,
            gradient_checkpointing=False,
            optim="adamw_torch"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset
        )

        logger.info("="*80)
        logger.info("STARTING TRAINING")
        logger.info("="*80)

        training_start = time.time()

        # Process RSS delta (host memory only; GPU memory is logged separately).
        process = psutil.Process(os.getpid())
        mem_before = process.memory_info().rss / 1024 / 1024

        trainer.train()

        total_training_time = time.time() - training_start

        mem_after = process.memory_info().rss / 1024 / 1024
        mem_used = mem_after - mem_before

        logger.info(f"\n{'='*80}")
        logger.info("TRAINING COMPLETE")
        logger.info(f"{'='*80}")
        logger.info(f"Total Training Time: {total_training_time:.2f}s")
        logger.info(f"Memory Usage: {mem_used:.1f} MB")
        logger.info(f"{'='*80}\n")

        self.print_model_distribution()

        # NOTE(review): differs from the Trainer output_dir ("codeqwen1.5-clone-lora").
        model.save_pretrained("codeqwen1.5-lora-clone")

        torch.cuda.empty_cache()
        gc.collect()

        return model, trainer, total_training_time

    def evaluate_comprehensive(self, model, test_dataset):
        """Run the model over ``test_dataset`` one sample at a time and compute metrics.

        Each sample's last-layer hidden states are mean-pooled over non-padding
        positions to form an embedding. The prediction is 1 when the embedding's mean
        is positive, and the "probability" is the sigmoid of that mean.

        NOTE(review): this score is a heuristic, not a trained classifier output.
        Also, inputs built by ``preprocess`` contain the gold ``Label: <y>`` text, so
        the model sees the answer. Confirm this is the intended evaluation.

        Args:
            model: The (PEFT) causal LM.
            test_dataset: Dataset with ``input_ids``/``attention_mask`` per row,
                aligned row-for-row with ``self.test_file``.

        Returns:
            Metrics dict from ``calculate_metrics`` plus retrieval metrics,
            ``inference_time`` and ``samples_per_sec``.
        """
        model.eval()

        logger.info("="*80)
        logger.info("STARTING COMPREHENSIVE EVALUATION")
        logger.info("="*80)

        test_df = pd.read_csv(self.test_file)
        test_labels = test_df[self.label_col].tolist()

        all_preds = []
        all_probs = []
        all_embeddings = []

        total_samples = len(test_dataset)
        # Device placement does not change during inference; look it up once.
        input_device = next(model.parameters()).device

        inference_start = time.time()
        logger.info(f"Processing {total_samples} test samples...")

        with torch.no_grad():
            for idx in range(total_samples):
                sample = test_dataset[idx]

                input_ids = torch.tensor([sample['input_ids']], device=input_device)
                attention_mask = torch.tensor([sample['attention_mask']], device=input_device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

                # Mean-pool only real tokens; padding would otherwise dominate the mean.
                hidden_states = outputs.hidden_states[-1].float()
                mask = attention_mask.to(hidden_states.device).unsqueeze(-1).to(hidden_states.dtype)
                pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

                embedding = pooled.squeeze(0)
                all_embeddings.append(embedding.cpu())

                score = embedding.mean()
                all_preds.append(1 if score.item() > 0 else 0)
                all_probs.append(torch.sigmoid(score).item())

                if (idx + 1) % 100 == 0:
                    logger.info(f"Processed {idx + 1}/{total_samples} samples")
                    torch.cuda.empty_cache()

        inference_time = time.time() - inference_start
        samples_per_sec = total_samples / inference_time if inference_time > 0 else 0.0

        logger.info(f"Inference completed in {inference_time:.2f}s")
        logger.info(f"Throughput: {samples_per_sec:.2f} samples/sec")

        metrics = self.calculate_metrics(test_labels, all_preds, all_probs)

        if all_embeddings:
            recall_metrics = self.calculate_recall_at_k(torch.stack(all_embeddings), test_labels)
            metrics.update(recall_metrics)

        metrics['inference_time'] = inference_time
        metrics['samples_per_sec'] = samples_per_sec

        torch.cuda.empty_cache()
        gc.collect()

        return metrics


def print_results(metrics, dataset_name):
    """Pretty-print a metrics dict under a ``RESULTS: <dataset_name>`` banner.

    Only metrics present in ``metrics`` are printed, so dicts with a subset of keys
    (e.g. no Jaccard/Kappa, or recall@k for a different set of k) no longer raise
    ``KeyError``. For a dict with every known key, the output matches the fixed
    layout: confusion matrix, scalar metrics, retrieval metrics, then performance.

    Args:
        metrics: Dict of metric name -> value (``'cm'`` is printed verbatim).
        dataset_name: Label shown in the banner.
    """
    scalar_rows = (
        ('acc', 'Accuracy'),
        ('bal_acc', 'Balanced Accuracy'),
        ('prec', 'Precision'),
        ('rec', 'Recall'),
        ('f1', 'F1 Score'),
        ('jacc', 'Jaccard'),
        ('mcc', 'MCC'),
        ('kappa', "Cohen's Kappa"),
        ('specificity', 'Specificity'),
        ('npv', 'NPV'),
        ('fpr', 'FPR'),
        ('fnr', 'FNR'),
        ('roc_auc', 'ROC AUC'),
        ('pr_auc', 'PR AUC'),
        ('log_loss', 'Log Loss'),
        ('brier', 'Brier Score'),
    )

    print(f"\n{'='*80}")
    print(f"RESULTS: {dataset_name}")
    print(f"{'='*80}")
    if 'cm' in metrics:
        print(f"Confusion Matrix:\n{metrics['cm']}")
    for key, label in scalar_rows:
        if key in metrics:
            print(f"{label}: {metrics[key]:.4f}")

    # recall@k keys sorted numerically by k (1, 3, 5, 10, ...).
    recall_keys = sorted(
        (key for key in metrics if key.startswith('recall@')),
        key=lambda key: int(key.split('@', 1)[1]),
    )
    if recall_keys or 'mrr' in metrics:
        print(f"\nRetrieval Metrics:")
        for key in recall_keys:
            print(f"Recall@{key.split('@', 1)[1]}: {metrics[key]:.4f}")
        if 'mrr' in metrics:
            print(f"MRR: {metrics['mrr']:.4f}")

    if 'inference_time' in metrics:
        print(f"\nPerformance:")
        print(f"Inference Time: {metrics['inference_time']:.2f}s")
        if 'samples_per_sec' in metrics:
            print(f"Samples/sec: {metrics['samples_per_sec']:.0f}")

    print(f"{'='*80}\n")


def run_lora_experiment(train_file, test_file, func1_col, func2_col, label_col, cache_dir=None):
    """Train a LoRA clone detector end to end and report test-set metrics.

    Builds a ``LoRATrainer``, trains it, tokenizes the test split again, evaluates
    and prints the results.

    Returns:
        ``(lora_trainer, model, metrics)``.
    """
    banner = "=" * 80
    logger.info(banner)
    logger.info("CODE CLONE DETECTION WITH LORA")
    logger.info(banner)

    experiment = LoRATrainer(train_file, test_file, func1_col, func2_col, label_col, cache_dir=cache_dir)
    model, _trainer, _elapsed = experiment.train()

    logger.info(f"\n{banner}")
    logger.info("FINAL TEST EVALUATION")
    logger.info(f"{banner}")

    encoded_test = experiment.test_data.map(experiment.preprocess, batched=False)
    metrics = experiment.evaluate_comprehensive(model, encoded_test)

    print_results(metrics, "TEST SET")
    return experiment, model, metrics


if __name__ == "__main__":
    # Experiment configuration, bound as module-level names.
    # NOTE(review): test_file and cache_dir are absolute paths at the filesystem
    # root while train_file is relative; confirm these locations are intended.
    train_file, test_file = 'train.csv', '/test.csv'
    func1_col, func2_col, label_col = "func1", "func2", "label"
    cache_dir = '/hf_cache'

    try:
        lora_trainer, model, results = run_lora_experiment(
            train_file, test_file, func1_col, func2_col, label_col,
            cache_dir=cache_dir,
        )
        logger.info("LoRA experiment completed successfully!")
    except Exception as e:
        # Log a one-line summary, then re-raise so the traceback is not lost.
        logger.error(f"Error: {e}")
        raise

# Adapter tuning with CodeQwen

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, f1_score
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class Activation_Function_Class(nn.Module):
    """Wrap a named element-wise activation as an ``nn.Module``.

    Supported names (case-insensitive): ``relu``, ``tanh``, ``swish``
    (``x * sigmoid(x)``), ``gelu`` (tanh approximation) and ``leakyrelu``.

    Raises:
        ValueError: If ``hidden_act`` is not a supported name. The error is raised
            at construction rather than surfacing later as an AttributeError in
            ``forward``.
    """

    def __init__(self, hidden_act):
        super().__init__()
        name = hidden_act.lower()
        if name == "relu":
            self.f = nn.functional.relu
        elif name == "tanh":
            self.f = torch.tanh
        elif name == "swish":
            def swish(x):
                return x * torch.sigmoid(x)
            self.f = swish
        elif name == "gelu":
            def gelu_new(x):
                # tanh approximation of GELU
                return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
            self.f = gelu_new
        elif name == "leakyrelu":
            self.f = nn.functional.leaky_relu
        else:
            raise ValueError(
                f"Unsupported activation {hidden_act!r}; expected one of "
                "'relu', 'tanh', 'swish', 'gelu', 'leakyrelu'"
            )

    def forward(self, x):
        return self.f(x)


class Adapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    Computes ``up(act(down([LayerNorm](x)))) + x`` and returns the result in ``x``'s
    dtype.

    Args:
        input_size: Feature size of the input and output.
        down_sample: Bottleneck width; ``input_size // 2`` when None.
        non_linearity: Activation name, passed to ``Activation_Function_Class``.
        init_bert_weights: If true, re-initialise Linear weights from N(0, 0.02), zero
            the Linear biases and reset LayerNorm to weight 1 / bias 0.
        add_layer_norm_before: Prepend a LayerNorm to the down projection.
        residual_before_ln: Stored on the instance; not read anywhere in this class.
    """

    def __init__(
            self,
            input_size,
            down_sample=None,
            non_linearity="relu",
            init_bert_weights=True,
            add_layer_norm_before=True,
            residual_before_ln=True,
    ):
        super().__init__()
        self.input_size = input_size
        self.add_layer_norm_before = add_layer_norm_before
        self.residual_before_ln = residual_before_ln

        # Submodules are created in the same order as before, which keeps
        # parameter/state_dict ordering and RNG consumption unchanged.
        layers = []
        if add_layer_norm_before:
            self.adapter_norm_before = nn.LayerNorm(input_size)
            layers.append(self.adapter_norm_before)

        self.down_sample = input_size // 2 if down_sample is None else down_sample
        layers.append(nn.Linear(input_size, self.down_sample))

        self.non_linearity = Activation_Function_Class(non_linearity.lower())
        layers.append(self.non_linearity)

        self.adapter_down = nn.Sequential(*layers)
        self.adapter_up = nn.Linear(self.down_sample, input_size)

        if init_bert_weights:
            for part in (self.adapter_down, self.adapter_up):
                part.apply(self.init_bert_weights)

    def forward(self, x):
        with_residual = self.adapter_up(self.adapter_down(x)) + x
        return with_residual.to(x.dtype)

    @staticmethod
    def init_bert_weights(module):
        """BERT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()


class CodeCloneDataset(Dataset):
    """Map-style dataset over an in-memory list of pre-built examples.

    The list is stored as given (not copied), and items are returned unchanged.
    """

    def __init__(self, data_list):
        self.data = data_list

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class AdapterTuningCodeQwen(nn.Module):
    """Frozen fp16 causal LM + bottleneck adapter + MLP classification head.

    NOTE(review): one ``Adapter`` is created per transformer layer and placed on that
    layer's device, but ``forward`` applies only the LAST one, to the final hidden
    state returned by the frozen LLM. The other adapters never receive gradients,
    although they are marked trainable. Confirm whether per-layer injection (e.g.
    forward hooks on ``llm.model.layers``) was intended.

    NOTE(review): adapters and classifier are converted to fp16, so an optimizer
    updating them runs in pure fp16 without a master fp32 copy. This is prone to
    underflow/instability.
    """

    def __init__(self, model_name: str, adapter_size: int = 512, num_classes: int = 2, cache_dir: str = None):
        """Load the frozen backbone and build the adapters and classification head.

        Args:
            model_name: HF model id/path, loaded in fp16 with ``device_map='auto'``.
            adapter_size: Bottleneck width of each adapter.
            num_classes: Number of output classes of the head.
            cache_dir: Optional Hugging Face cache directory.
        """
        super().__init__()

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")

        # Device of the classifier head and pooled features. Falls back to CPU when
        # no GPU is visible instead of failing on a hard-coded 'cuda:0'.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        print("Loading base LLM model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )

        self.llm = base_model

        total_layers = len(base_model.model.layers)
        print(f"Model has {total_layers} transformer layers")

        # Record which device each decoder layer landed on, so each adapter can sit
        # next to its layer.
        self.layer_devices = []
        for i in range(total_layers):
            layer_device = next(base_model.model.layers[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Layer {i} -> {layer_device}")

        embed_device = next(base_model.model.embed_tokens.parameters()).device
        norm_device = next(base_model.model.norm.parameters()).device
        lm_head_device = next(base_model.lm_head.parameters()).device

        print(f"Embeddings -> {embed_device}")
        print(f"Norm -> {norm_device}")
        print(f"LM Head -> {lm_head_device}")

        self.embed_device = str(embed_device)
        self.final_device = str(norm_device)

        self.hidden_size = base_model.config.hidden_size
        self.num_classes = num_classes

        self.adapters = nn.ModuleList([
            Adapter(
                input_size=self.hidden_size,
                down_sample=adapter_size,
                non_linearity="gelu",
                init_bert_weights=True,
                add_layer_norm_before=True
            ) for _ in range(total_layers)
        ])

        for i, adapter in enumerate(self.adapters):
            adapter.to(self.layer_devices[i]).half()

        # fp16 head to match the dtype of the pooled fp16 hidden states.
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size // 4, num_classes)
        ).to(self.primary_device).half()

        for param in self.llm.parameters():
            param.requires_grad = False

        for adapter in self.adapters:
            for param in adapter.parameters():
                param.requires_grad = True

        for param in self.classifier.parameters():
            param.requires_grad = True

        print("Model initialization complete!")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids; moved to the embedding device.
            attention_mask: Optional mask. When given, pooling averages only the
                positions where it is non-zero.
            labels: Optional class indices. When given, the cross-entropy loss is
                computed.

        Returns:
            Dict with ``logits``, ``loss`` (None without labels) and
            ``hidden_states`` (adapter output, on ``primary_device``).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)

        # The backbone is frozen and only its final hidden state is consumed, so no
        # autograd graph is built through it.
        with torch.no_grad():
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        hidden_states = outputs.hidden_states[-1]

        # Only the adapter attached to the last layer is applied (see class NOTE).
        if len(self.adapters) > 0:
            last = len(self.adapters) - 1
            hidden_states = self.adapters[last](hidden_states.to(self.layer_devices[last]))

        hidden_states = hidden_states.to(self.primary_device)

        if attention_mask is not None:
            # Masked mean pooling in float32. In fp16 the 1e-9 clamp underflows to 0,
            # so an all-padding row gave 0/0 = NaN; float32 also avoids fp16 overflow
            # when summing many positions.
            mask = attention_mask.to(self.primary_device).unsqueeze(-1).float()
            summed = (hidden_states.float() * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled_output = (summed / counts).to(hidden_states.dtype)
        else:
            pooled_output = hidden_states.mean(dim=1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: AdapterTuningCodeQwen):
    """Print the model's frozen and trainable parameter counts and per-GPU memory use.

    "Trainable" counts the adapter and classifier parameters. The ratio is given
    relative to the frozen LLM parameter count.
    """
    def count(params):
        return sum(p.numel() for p in params)

    frozen = count(model.llm.parameters())
    adapter_count = count(model.adapters.parameters())
    head_count = count(model.classifier.parameters())
    trainable = adapter_count + head_count
    rule = '=' * 80

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(rule)
    print(f"Total Frozen LLM Parameters: {frozen:,}")
    print(f"Trainable Adapter Parameters: {adapter_count:,}")
    print(f"Trainable Classifier Parameters: {head_count:,}")
    print(f"Total Trainable Parameters: {trainable:,}")
    print(f"Trainable Ratio: {(trainable / frozen * 100):.4f}%")
    print(f"{rule}\n")

    print(rule)
    print("GPU MEMORY USAGE")
    print(rule)
    gib = 1024 ** 3
    for gpu in range(torch.cuda.device_count()):
        used = torch.cuda.memory_allocated(gpu) / gib
        held = torch.cuda.memory_reserved(gpu) / gib
        print(f"GPU {gpu}: Allocated: {used:.2f} GB, Reserved: {held:.2f} GB")
    print(f"{rule}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a clone-detection CSV and print its size and label distribution.

    ``func1_col`` and ``func2_col`` are accepted but not used here.

    Returns:
        The DataFrame, or None (after printing the error) if reading the file or
        counting labels fails.
    """
    try:
        frame = pd.read_csv(csv_path)
        print(f"Loaded dataset: {len(frame)} samples")
        print(f"Label distribution: {frame[label_col].value_counts().to_dict()}")
    except Exception as exc:
        print(f"Error loading dataset: {str(exc)}")
        return None
    return frame


def create_training_data(csv_path, tokenizer, max_length=512,
                         label_col='label', func1_col='func1', func2_col='func2'):
    """Tokenize every row of a clone-detection CSV into fixed-length tensors.

    Each row becomes a "Code1: ... / Code2: ... / Are these code clones?" prompt,
    padded/truncated to ``max_length`` tokens.

    Returns:
        List of dicts holding 1-D ``input_ids`` and ``attention_mask`` tensors and a
        0-d int64 ``labels`` tensor. Empty if the CSV could not be loaded.
    """
    df = load_dataset(csv_path, label_col, func1_col, func2_col)
    if df is None:
        return []

    examples = []
    for _, row in df.iterrows():
        first_code, second_code = str(row[func1_col]), str(row[func2_col])
        target = int(row[label_col])

        encoded = tokenizer(
            f"Code1: {first_code}\nCode2: {second_code}\nAre these code clones?",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

        # Drop the leading batch dimension added by return_tensors='pt'.
        examples.append(dict(
            input_ids=encoded['input_ids'].squeeze(0),
            attention_mask=encoded['attention_mask'].squeeze(0),
            labels=torch.tensor(target, dtype=torch.long),
        ))

    return examples


def get_memory_usage():
    """Return the resident set size of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=None):
    """Compute leave-one-out recall@k over dot-product similarity of ``embeddings``.

    For each query row, the other rows are ranked by dot product (not normalised).
    A query counts as a hit for k when at least one of its top-k neighbours shares
    its label.

    Args:
        embeddings: 2-D tensor with one row per sample.
        labels: Labels aligned with the rows, as a 1-D tensor or any sequence
            ``torch.as_tensor`` accepts.
        k_values: Cut-offs to report; defaults to ``[1, 5, 10]``. A cut-off larger
            than the number of other samples means "all other samples".

    Returns:
        Dict ``{'recall@k': fraction_of_queries_with_a_hit}``. Values are 0.0 for
        empty input.
    """
    if k_values is None:
        k_values = [1, 5, 10]
    labels = torch.as_tensor(labels)
    n = len(labels)
    if n == 0:
        return {f'recall@{k}': 0.0 for k in k_values}

    # Half-precision matmul is unsupported or inaccurate on some backends.
    emb = embeddings.float() if embeddings.dtype in (torch.float16, torch.bfloat16) else embeddings
    similarities = torch.mm(emb, emb.t())
    # A query must never retrieve itself.
    similarities.fill_diagonal_(-float('inf'))
    labels = labels.to(similarities.device)

    results = {}
    for k in k_values:
        top_idx = torch.topk(similarities, min(k, n - 1), dim=1).indices  # (n, k_eff)
        hit = (labels[top_idx] == labels.unsqueeze(1)).any(dim=1)
        results[f'recall@{k}'] = hit.sum().item() / n

    return results


def calculate_mrr(embeddings, labels):
    """Compute the mean reciprocal rank of the first same-label neighbour (dot-product similarity).

    Each query is ranked against every OTHER row. If no other row shares the
    query's label, that query contributes 0. Previously the query itself, sorted
    last at -inf, was counted as a match, giving 1/n.

    Args:
        embeddings: 2-D tensor with one row per sample.
        labels: Labels aligned with the rows (tensor or sequence).

    Returns:
        MRR as a float; 0.0 for empty input.
    """
    labels = torch.as_tensor(labels)
    n = len(labels)
    if n == 0:
        return 0.0

    emb = embeddings.float() if embeddings.dtype in (torch.float16, torch.bfloat16) else embeddings
    similarities = torch.mm(emb, emb.t())
    labels = labels.to(similarities.device)

    mrr_sum = 0.0
    for i in range(n):
        order = torch.argsort(similarities[i], descending=True)
        order = order[order != i]  # exclude the query itself
        matches = torch.nonzero(labels[order] == labels[i], as_tuple=False)
        if matches.numel() > 0:
            mrr_sum += 1.0 / (matches[0, 0].item() + 1)

    return mrr_sum / n


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary classification metrics (positive class = 1), plus optional retrieval metrics.

    Args:
        y_true: Gold labels in {0, 1}.
        y_pred: Predicted labels in {0, 1}.
        y_proba: Optional scores for class 1. ROC AUC and PR AUC are computed only
            when both classes occur in ``y_true``; otherwise they are 0.0.
        embeddings: Optional 2-D tensor (one row per sample). When given, adds
            ``recall@k`` and ``mrr`` from ``calculate_recall_at_k`` / ``calculate_mrr``.

    Returns:
        Dict with ``cm`` (2x2 confusion matrix) and scalar metrics.
    """
    # labels=[0, 1] always yields a 2x2 matrix, even if a class is absent, so the
    # four cells can be unpacked directly.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)

    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    roc_auc = pr_auc = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        try:
            roc_auc = roc_auc_score(y_true, y_proba)
            pr_auc = average_precision_score(y_true, y_proba)
        except ValueError as err:
            # e.g. malformed/mismatched scores: report 0.0 instead of failing.
            print(f"Could not compute ROC/PR AUC: {err}")
            roc_auc = pr_auc = 0.0

    results = {
        'cm': cm,
        'acc': acc,
        'bal_acc': bal_acc,
        'prec': prec,
        'rec': rec,
        'specificity': specificity,
        'f1': f1,
        'mcc': mcc,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc
    }

    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)

    return results


def collate_fn(batch):
    """Stack per-example tensors into batch tensors keyed by input_ids, attention_mask and labels.

    Any other keys in the examples are ignored.
    """
    return {
        key: torch.stack([example[key] for example in batch])
        for key in ('input_ids', 'attention_mask', 'labels')
    }


def train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2):
    """Train the adapter and classifier parameters of ``model``.

    Args:
        model: module exposing ``adapters`` (iterable of modules) and
            ``classifier``; called as ``model(input_ids, attention_mask=...,
            labels=...)`` and expected to return a dict with 'loss' and 'logits'.
        train_data: examples accepted by ``CodeCloneDataset``.
        num_epochs: number of passes over the data.
        learning_rate: AdamW learning rate (weight decay is fixed at 0.01).
        batch_size: examples per batch.

    Returns:
        List of per-epoch dicts with 'epoch', 'loss', 'acc', 'time',
        'tokens_per_sec' and 'memory_mb'.
    """
    optimizer = torch.optim.AdamW([
        {'params': model.adapters.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=learning_rate, weight_decay=0.01)

    # The set of clipped parameters never changes, so collect it once rather
    # than rebuilding the list on every batch.
    clip_params = [p for adapter in model.adapters for p in adapter.parameters()]
    clip_params += list(model.classifier.parameters())

    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model.train()

    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        total_tokens = 0

        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            total_tokens += input_ids.numel()

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
                optimizer.step()
                total_loss += loss.item()

                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)

            batch_count += 1

            if batch_count % 10 == 0:
                torch.cuda.empty_cache()

        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        # Guard against a zero-length timing window (e.g. an empty dataloader).
        tokens_per_sec = total_tokens / epoch_time if epoch_time > 0 else 0.0
        current_memory = get_memory_usage()

        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'time': epoch_time,
            'tokens_per_sec': tokens_per_sec,
            'memory_mb': current_memory
        })

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, "
              f"Time: {epoch_time:.2f}s, Tokens/sec: {tokens_per_sec:.2f}, Memory: {current_memory:.2f}MB")

        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")

    total_training_time = time.time() - start_time
    # default= keeps num_epochs == 0 from raising on an empty sequence.
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    avg_tokens_per_sec = np.mean([s['tokens_per_sec'] for s in epoch_stats]) if epoch_stats else 0.0

    print(f"\n{'='*80}")
    print("TRAINING SUMMARY")
    print(f"{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB")
    print(f"Average Tokens/sec: {avg_tokens_per_sec:.2f}")
    print(f"{'='*80}\n")

    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run ``model`` over ``test_data`` and compute classification/retrieval metrics.

    Args:
        model: called as ``model(input_ids, attention_mask=...)``; must return a
            dict with 'logits' (class logits, positive class in column 1) and
            'hidden_states'.
        test_data: examples accepted by ``CodeCloneDataset``.
        batch_size: examples per batch.

    Returns:
        Metrics dict from ``calculate_comprehensive_metrics`` plus
        'inference_time' (seconds) and 'tokens_per_sec'.

    Raises:
        ValueError: if ``test_data`` yields no batches.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []

    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    start_time = time.time()
    total_tokens = 0

    with torch.no_grad():
        for batch_idx, batch_data in enumerate(dataloader, start=1):
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            total_tokens += input_ids.numel()

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits']
            hidden_states = outputs['hidden_states']

            predictions = torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)

            if hidden_states.dim() == 3 and hidden_states.shape[:2] == attention_mask.shape:
                # (batch, seq, hidden) aligned with the mask: average real tokens
                # only, so padded positions do not dilute the embedding.
                mask = attention_mask.to(hidden_states.device, dtype=torch.float32).unsqueeze(-1)
                summed = (hidden_states.float() * mask).sum(dim=1)
                pooled_embeddings = summed / mask.sum(dim=1).clamp(min=1.0)
            else:
                # Layout not aligned with the mask: plain mean over dim 1.
                pooled_embeddings = hidden_states.float().mean(dim=1)
            pooled_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())
            all_embeddings.append(pooled_embeddings.cpu())

            # Periodically release cached GPU blocks (every 10 batches).
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()

    if not all_embeddings:
        raise ValueError("evaluate_model received no test batches")

    inference_time = time.time() - start_time
    tokens_per_sec = total_tokens / inference_time if inference_time > 0 else 0.0

    all_embeddings = torch.cat(all_embeddings, dim=0)

    metrics = calculate_comprehensive_metrics(
        all_labels,
        all_predictions,
        all_probabilities,
        all_embeddings
    )

    metrics['inference_time'] = inference_time
    metrics['tokens_per_sec'] = tokens_per_sec

    return metrics


def print_evaluation_results(metrics, dataset_name):
    """Print the metrics dict returned by ``evaluate_model`` as a report.

    Retrieval rows (Recall@1/5/10 and MRR) are included only when
    'recall@1' is present in ``metrics``.
    """
    bar = '=' * 80
    print(f"\n{bar}")
    print(f"EVALUATION RESULTS - {dataset_name}")
    print(bar)
    print(f"Confusion Matrix:\n{metrics['cm']}")

    rows = [
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall", 'rec'),
        ("Specificity", 'specificity'),
        ("F1 Score", 'f1'),
        ("MCC", 'mcc'),
        ("ROC AUC", 'roc_auc'),
        ("PR AUC", 'pr_auc'),
    ]
    if 'recall@1' in metrics:
        rows += [
            ("Recall@1", 'recall@1'),
            ("Recall@5", 'recall@5'),
            ("Recall@10", 'recall@10'),
            ("MRR", 'mrr'),
        ]
    for title, key in rows:
        print(f"{title}: {metrics[key]:.4f}")

    print(f"Inference Time: {metrics['inference_time']:.2f}s")
    print(f"Tokens/sec: {metrics['tokens_per_sec']:.2f}")
    print(f"{bar}\n")


def main():
    """Train and evaluate the adapter-tuned CodeQwen clone detector.

    Loads the tokenizer and ``AdapterTuningCodeQwen`` model, builds examples
    from ``train.csv``, trains with ``train_model``, then evaluates on
    ``test.csv`` when it yields any examples.
    """
    torch.cuda.empty_cache()
    gc.collect()

    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    cache_dir = '/hf_cache'
    model_name = "Qwen/CodeQwen1.5-7B-Chat"

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("\nLoading model with adapter tuning and pipeline parallelism across GPUs...")
    model = AdapterTuningCodeQwen(model_name, adapter_size=512, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{'='*80}")
    print("TRAINING PHASE")
    print(f"{'='*80}\n")

    train_csv = 'train.csv'
    train_data = create_training_data(
        train_csv,
        tokenizer,
        label_col='label',
        func1_col='func1',
        func2_col='func2'
    )
    print(f"Created {len(train_data)} training examples\n")

    if len(train_data) == 0:
        print("No training data available. Exiting.")
        return

    # Per-epoch stats are printed inside train_model; the returned list is not needed here.
    train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2)

    print(f"\n{'='*80}")
    print("TESTING PHASE")
    print(f"{'='*80}\n")

    test_csv = 'test.csv'
    test_data = create_training_data(
        test_csv,
        tokenizer,
        label_col='label',
        func1_col='func1',
        func2_col='func2'
    )
    print(f"Created {len(test_data)} test examples\n")

    if len(test_data) > 0:
        test_metrics = evaluate_model(model, test_data, batch_size=2)
        print_evaluation_results(test_metrics, "TEST SET")

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}\n")


if __name__ == '__main__':
    main()

# BitFit CodeQwen

In [None]:
import math
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from torch.optim import AdamW
from datasets import Dataset as HFDataset
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, confusion_matrix,
    balanced_accuracy_score, precision_score, recall_score, jaccard_score,
    matthews_corrcoef, cohen_kappa_score, roc_auc_score, average_precision_score,
    log_loss, brier_score_loss
)
import pandas as pd
import numpy as np
import os
import logging
import time
from typing import List, Dict, Tuple
import psutil
import gc

# INFO-level root logging so the trainer's progress messages are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


class BitFitClassificationHead(nn.Module):
    """Two-layer head that maps pooled hidden states to ``num_labels`` logits.

    Pipeline: dropout -> dense -> tanh -> dropout -> out_proj. Both linear
    layers are created directly in ``dtype``.
    """

    def __init__(self, hidden_size, num_labels=1, dtype=torch.float16):
        super().__init__()
        # Attribute names form the saved state_dict keys; keep them stable.
        self.dense = nn.Linear(hidden_size, hidden_size, dtype=dtype)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(hidden_size, num_labels, dtype=dtype)

    def forward(self, hidden_states):
        """Return logits of shape (..., num_labels) for (..., hidden_size) inputs."""
        activated = torch.tanh(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(activated))


class BitFitCloneTrainer:
    """BitFit fine-tuning of CodeQwen1.5-7B-Chat for binary code-clone detection.

    The base causal LM is frozen except for bias parameters selected through
    ``bias_terms``; a ``BitFitClassificationHead`` on the mean-pooled last
    hidden state produces one clone/non-clone logit per function pair.
    """

    # User-facing bias-term names -> substrings matched against parameter names.
    # A parameter is unfrozen only if its name contains the substring AND 'bias'.
    BIAS_TERMS_MAPPING = {
        'query': 'q_proj.bias',
        'key': 'k_proj.bias',
        'value': 'v_proj.bias',
        'output': 'o_proj.bias',
        'mlp': 'mlp',
        'layer_norm': 'ln',
        'all': 'bias'
    }

    def __init__(self, train_file, test_file, func1_col, func2_col, label_col, cache_dir=None, bias_terms=None):
        """Load the tokenizer, the base model and both CSV splits.

        Args:
            train_file: path of the training CSV.
            test_file: path of the test CSV.
            func1_col: column holding the first function's source.
            func2_col: column holding the second function's source.
            label_col: column holding the 0/1 clone label.
            cache_dir: Hugging Face cache dir; files must already be present
                (loaded with ``local_files_only=True``).
            bias_terms: keys of ``BIAS_TERMS_MAPPING`` to unfreeze; defaults to ['all'].
        """
        self.train_file = train_file
        self.test_file = test_file
        self.func1_col = func1_col
        self.func2_col = func2_col
        self.label_col = label_col
        self.cache_dir = cache_dir
        # Own copy, so neither a shared default nor later caller mutation leaks in.
        self.bias_terms = ['all'] if bias_terms is None else list(bias_terms)

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        logger.info(f"Number of available GPUs: {num_gpus}")
        for i in range(num_gpus):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")

        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Loading base model with automatic device mapping...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True,
            low_cpu_mem_usage=True
        )

        hidden_size = self.base_model.config.hidden_size
        logger.info(f"Model hidden size: {hidden_size}")

        # Built lazily in _setup_bitfit_parameters, once the model dtype/device are known.
        self.classification_head = None

        self.print_model_distribution()

        logger.info("Loading datasets...")
        self.train_data = HFDataset.from_csv(train_file)
        self.test_data = HFDataset.from_csv(test_file)
        logger.info(f"Train samples: {len(self.train_data)}")
        logger.info(f"Test samples: {len(self.test_data)}")

        self.training_stats = {
            'epoch_times': [],
            'epoch_losses': [],
            'memory_usage': []
        }

        self.optimizer = None
        self.scheduler = None

    def print_model_distribution(self):
        """Log the module-to-device map (when present) and per-GPU memory usage."""
        logger.info("\n" + "="*80)
        logger.info("MODEL DEVICE DISTRIBUTION")
        logger.info("="*80)

        if hasattr(self.base_model, 'hf_device_map'):
            for module_name, placement in self.base_model.hf_device_map.items():
                logger.info(f"{module_name}: {placement}")

        logger.info("\n" + "="*80)
        logger.info("GPU MEMORY USAGE")
        logger.info("="*80)
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            logger.info(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        logger.info("="*80 + "\n")

    def preprocess(self, example):
        """Tokenize one CSV row as a tagged function pair, padded/truncated to 512 tokens.

        Returns the tokenizer output with the row's label added under 'labels'.
        """
        c1 = example[self.func1_col]
        c2 = example[self.func2_col]
        y = example[self.label_col]
        text = f"<func1>\n{c1}\n</func1>\n<func2>\n{c2}\n</func2>"
        out = self.tokenizer(text, truncation=True, padding="max_length", max_length=512)
        out["labels"] = y
        return out

    def _setup_bitfit_parameters(self):
        """Freeze the base model, unfreeze the selected biases and build the head.

        Logs each unfrozen parameter and a trainable-vs-total summary.
        """
        logger.info("Setting up BitFit parameters...")

        for param in self.base_model.parameters():
            param.requires_grad = False

        trainable_bias_terms = self._get_trainable_bias_terms()

        trainable_params = 0
        total_params = 0
        bias_param_count = 0

        for name, param in self.base_model.named_parameters():
            total_params += param.numel()

            should_train = 'bias' in name and any(term in name for term in trainable_bias_terms)
            if should_train:
                param.requires_grad = True
                trainable_params += param.numel()
                bias_param_count += param.numel()
                logger.info(f"✓ Trainable: {name} - Shape: {param.shape} - Params: {param.numel():,}")

        hidden_size = self.base_model.config.hidden_size
        first_device = next(self.base_model.parameters()).device
        model_dtype = next(self.base_model.parameters()).dtype

        # The head shares the model dtype (fp16 here) so pooled states feed it directly.
        self.classification_head = BitFitClassificationHead(hidden_size, num_labels=1, dtype=model_dtype)
        self.classification_head = self.classification_head.to(first_device)

        head_params = sum(p.numel() for p in self.classification_head.parameters())
        trainable_params += head_params
        total_params += head_params

        logger.info(f"\n✓ Classification Head added: {head_params:,} parameters")
        for name, param in self.classification_head.named_parameters():
            logger.info(f"  ✓ {name} - Shape: {param.shape} - Params: {param.numel():,}")

        logger.info(f"\n{'='*80}")
        logger.info("BITFIT PARAMETER STATISTICS")
        logger.info(f"{'='*80}")
        logger.info(f"Base Model Total Parameters: {total_params - head_params:,}")
        logger.info(f"Trainable Bias Parameters: {bias_param_count:,}")
        logger.info(f"Classification Head Parameters: {head_params:,}")
        logger.info(f"Total Trainable Parameters: {trainable_params:,}")
        logger.info(f"Total Parameters: {total_params:,}")
        logger.info(f"Trainable Percentage: {100 * trainable_params / total_params:.6f}%")
        logger.info(f"{'='*80}\n")

    def _get_trainable_bias_terms(self) -> List[str]:
        """Translate ``self.bias_terms`` into parameter-name substrings; unknown terms are logged and skipped."""
        trainable_terms = []

        for term in self.bias_terms:
            if term in self.BIAS_TERMS_MAPPING:
                trainable_terms.append(self.BIAS_TERMS_MAPPING[term])
            else:
                logger.warning(f"Unknown bias term: {term}")

        return trainable_terms

    @staticmethod
    def _mean_pool(hidden_states, attention_mask):
        """Average hidden states over non-padding positions.

        Args:
            hidden_states: (batch, seq, hidden) last-layer states of the base model.
            attention_mask: (batch, seq) mask, 1 for real tokens and 0 for padding.

        Returns:
            (batch, hidden) tensor in the dtype of ``hidden_states``.
        """
        mask = attention_mask.to(device=hidden_states.device, dtype=torch.float32).unsqueeze(-1)
        # Accumulate in float32: a sum of up to 512 fp16 vectors may exceed the fp16 range.
        summed = (hidden_states.float() * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return (summed / counts).to(hidden_states.dtype)

    def calculate_metrics(self, y_true, y_pred, y_prob=None):
        """Compute confusion-matrix, threshold and (optionally) probabilistic metrics.

        Probabilistic metrics (ROC AUC, PR AUC, log loss, Brier) are added only
        when ``y_prob`` is given and sklearn can compute them.
        """
        # labels=[0, 1] keeps the matrix 2x2 (and the unpack valid) even when
        # only one class occurs in y_true and y_pred.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        acc = accuracy_score(y_true, y_pred)
        bal_acc = balanced_accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)

        _, _, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)

        jacc = jaccard_score(y_true, y_pred, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)
        kappa = cohen_kappa_score(y_true, y_pred)

        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

        metrics = {
            'cm': cm,
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp),
            'acc': acc,
            'bal_acc': bal_acc,
            'prec': prec,
            'rec': rec,
            'f1': f1,
            'jacc': jacc,
            'mcc': mcc,
            'kappa': kappa,
            'specificity': specificity,
            'npv': npv,
            'fpr': fpr,
            'fnr': fnr
        }

        if y_prob is not None:
            try:
                roc_auc = roc_auc_score(y_true, y_prob)
                pr_auc = average_precision_score(y_true, y_prob)
                ll = log_loss(y_true, y_prob)
                brier = brier_score_loss(y_true, y_prob)

                metrics.update({
                    'roc_auc': roc_auc,
                    'pr_auc': pr_auc,
                    'log_loss': ll,
                    'brier': brier
                })
            except Exception as e:
                logger.warning(f"Could not calculate probabilistic metrics: {e}")

        return metrics

    def calculate_recall_at_k(self, embeddings, labels, k_values=(1, 3, 5, 10)):
        """Label-based retrieval metrics over the embedding set.

        Each sample acts as a query, and every *other* sample is ranked by cosine
        similarity. recall@k is the fraction of queries where at least one of the
        top min(k, N-1) others shares the query's label. MRR averages 1/rank of
        the best-ranked same-label sample, or 0 when no other sample shares it.
        Ties are counted in the query's favour.

        Args:
            embeddings: (N, D) tensor of per-sample embeddings.
            labels: length-N sequence of integer labels.
            k_values: recall cut-offs.

        Returns:
            dict with one 'recall@k' per k and 'mrr' (Python floats).
        """
        num = len(labels)
        results = {}
        if num == 0:
            for k in k_values:
                results[f'recall@{k}'] = 0.0
            results['mrr'] = 0.0
            return results

        # Normalise once and use a matmul. Broadcasting cosine_similarity over
        # (N, 1, D) x (1, N, D) would materialise an N x N x D tensor.
        normed = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
        similarity = normed @ normed.T
        self_mask = torch.eye(num, dtype=torch.bool, device=similarity.device)
        # A query must never retrieve itself.
        similarity = similarity.masked_fill(self_mask, -float('inf'))

        label_tensor = torch.as_tensor(labels).to(similarity.device)
        same_label = (label_tensor.unsqueeze(1) == label_tensor.unsqueeze(0)) & ~self_mask
        has_match = same_label.any(dim=1)

        # The rank of the best same-label sample is 1 + the number of other
        # samples that are strictly more similar to the query.
        best_match = similarity.masked_fill(~same_label, -float('inf')).max(dim=1).values
        rank = (similarity > best_match.unsqueeze(1)).sum(dim=1) + 1

        for k in k_values:
            hits = has_match & (rank <= min(k, num - 1))
            results[f'recall@{k}'] = hits.float().mean().item()

        reciprocal = torch.where(has_match, 1.0 / rank.float(), torch.zeros(num, device=rank.device))
        results['mrr'] = reciprocal.mean().item()

        return results

    def collate_fn(self, batch):
        """Turn a list of preprocessed rows into batched tensors."""
        input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        attention_mask = torch.stack([torch.tensor(item['attention_mask']) for item in batch])
        labels = torch.tensor([item['labels'] for item in batch])
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

    def setup_optimizer(self, train_dataloader, num_epochs):
        """Create AdamW over unfrozen biases + head, with a linear warmup (3%) schedule."""
        trainable_params = [p for p in self.base_model.parameters() if p.requires_grad]
        trainable_params += list(self.classification_head.parameters())

        # NOTE(review): these parameters are fp16; AdamW's default eps (1e-8)
        # underflows to 0 in fp16 arithmetic, which can yield inf/NaN updates.
        # Consider a larger eps or fp32 master weights — confirm on real runs.
        self.optimizer = AdamW(
            trainable_params,
            lr=2e-4,
            weight_decay=0.01
        )

        total_steps = len(train_dataloader) * num_epochs
        warmup_steps = int(0.03 * total_steps)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        logger.info(f"Optimizer setup complete. Total training steps: {total_steps}")
        logger.info(f"Warmup steps: {warmup_steps}")

    def train_epoch(self, train_dataloader, epoch, num_epochs):
        """Run one optimisation pass over ``train_dataloader``.

        The base model stays in eval mode (dropout off). Gradients flow through
        it so the unfrozen bias terms are actually updated. This needs far more
        activation memory than a frozen forward pass; gradient checkpointing may
        be required on smaller GPUs.

        Returns:
            dict with the mean batch 'loss' for the epoch.
        """
        self.base_model.eval()
        self.classification_head.train()

        # Loop invariants: computed once per epoch instead of once per batch.
        first_device = next(self.base_model.parameters()).device
        head_device = next(self.classification_head.parameters()).device
        loss_fn = nn.BCEWithLogitsLoss()
        # Clip exactly what the optimizer updates; frozen params carry no grads.
        optimized_params = [p for group in self.optimizer.param_groups for p in group['params']]
        # BitFit requires backprop through the backbone to reach the bias terms.
        # Only skip autograd there when no bias term was selected.
        backbone_trainable = any(p.requires_grad for p in self.base_model.parameters())

        total_loss = 0.0
        num_batches = len(train_dataloader)

        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(first_device)
            attention_mask = batch['attention_mask'].to(first_device)
            labels = batch['labels'].to(head_device)

            with torch.set_grad_enabled(backbone_trainable):
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

            # With device_map="auto" the last layer may live on another GPU than the head.
            pooled = self._mean_pool(outputs.hidden_states[-1], attention_mask).to(head_device)

            logits = self.classification_head(pooled).squeeze(-1)
            # Loss in float32 for numerical stability; gradients flow back in fp16.
            loss = loss_fn(logits.float(), labels.float())

            self.optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(optimized_params, max_norm=1.0)

            self.optimizer.step()
            if self.scheduler:
                self.scheduler.step()

            total_loss += loss.item()

            if batch_idx % 20 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs} - "
                    f"Batch {batch_idx}/{len(train_dataloader)} - "
                    f"Loss: {loss.item():.4f} - "
                    f"LR: {current_lr:.2e}"
                )

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        return {'loss': avg_loss}

    def train(self):
        """Configure BitFit, train for 5 epochs and save model + head.

        Weights are written to ``codeqwen1.5-bitfit-clone/`` (base model via
        ``save_pretrained``, head as ``classification_head.pt``).

        Returns:
            (base_model, total training time in seconds).
        """
        # Only the training split is needed here; the test split is tokenized by the evaluation caller.
        train_dataset = self.train_data.map(self.preprocess, batched=False)

        logger.info("\n" + "="*80)
        logger.info("APPLYING BITFIT CONFIGURATION")
        logger.info("="*80)

        self._setup_bitfit_parameters()

        self.print_model_distribution()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=4,
            shuffle=True,
            collate_fn=self.collate_fn,
            num_workers=0,
            # Pinning only helps (and only works) with CUDA available.
            pin_memory=torch.cuda.is_available()
        )

        num_epochs = 5

        self.setup_optimizer(train_dataloader, num_epochs)

        logger.info("="*80)
        logger.info("STARTING TRAINING")
        logger.info(f"Total Epochs: {num_epochs}")
        logger.info("="*80)

        training_start = time.time()

        process = psutil.Process(os.getpid())
        mem_before = process.memory_info().rss / 1024 / 1024

        for epoch in range(num_epochs):
            epoch_start = time.time()

            logger.info(f"\n{'='*60}")
            logger.info(f"Epoch {epoch + 1}/{num_epochs}")
            logger.info(f"{'='*60}")

            train_metrics = self.train_epoch(train_dataloader, epoch, num_epochs)

            epoch_time = time.time() - epoch_start
            self.training_stats['epoch_times'].append(epoch_time)
            self.training_stats['epoch_losses'].append(train_metrics['loss'])

            logger.info(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
            logger.info(f"Training Loss: {train_metrics['loss']:.4f}")

            torch.cuda.empty_cache()
            gc.collect()

        total_training_time = time.time() - training_start

        mem_after = process.memory_info().rss / 1024 / 1024
        mem_used = mem_after - mem_before

        logger.info(f"\n{'='*80}")
        logger.info("TRAINING COMPLETE")
        logger.info(f"{'='*80}")
        logger.info(f"Total Training Time: {total_training_time:.2f}s")
        logger.info(f"Average Epoch Time: {np.mean(self.training_stats['epoch_times']):.2f}s")
        logger.info(f"Final Training Loss: {self.training_stats['epoch_losses'][-1]:.4f}")
        logger.info(f"Memory Usage: {mem_used:.1f} MB")
        logger.info(f"{'='*80}\n")

        self.print_model_distribution()

        os.makedirs("codeqwen1.5-bitfit-clone", exist_ok=True)
        self.base_model.save_pretrained("codeqwen1.5-bitfit-clone")
        torch.save(self.classification_head.state_dict(), "codeqwen1.5-bitfit-clone/classification_head.pt")
        logger.info("Models saved successfully to codeqwen1.5-bitfit-clone/")

        torch.cuda.empty_cache()
        gc.collect()

        return self.base_model, total_training_time

    def evaluate_comprehensive(self, model, test_dataset):
        """Score every test row and compute classification + retrieval metrics.

        Args:
            model: the base model returned by ``train``.
            test_dataset: preprocessed dataset with 'input_ids'/'attention_mask',
                in the same row order as ``self.test_file`` (labels are read from it).

        Returns:
            dict from ``calculate_metrics`` plus retrieval metrics,
            'inference_time' and 'samples_per_sec'.
        """
        model.eval()
        self.classification_head.eval()

        logger.info("="*80)
        logger.info("STARTING COMPREHENSIVE EVALUATION")
        logger.info("="*80)

        test_df = pd.read_csv(self.test_file)
        test_labels = test_df[self.label_col].tolist()

        all_preds = []
        all_probs = []
        all_embeddings = []

        inference_start = time.time()
        total_samples = len(test_dataset)

        logger.info(f"Processing {total_samples} test samples...")

        first_device = next(model.parameters()).device
        head_device = next(self.classification_head.parameters()).device

        with torch.no_grad():
            for idx in range(total_samples):
                sample = test_dataset[idx]

                input_ids = torch.tensor([sample['input_ids']]).to(first_device)
                attention_mask = torch.tensor([sample['attention_mask']]).to(first_device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

                pooled = self._mean_pool(outputs.hidden_states[-1], attention_mask)

                # float32 on CPU: the similarity maths for retrieval runs there.
                all_embeddings.append(pooled[0].float().cpu())

                logits = self.classification_head(pooled.to(head_device)).squeeze(-1)

                prob = torch.sigmoid(logits.float()).item()
                pred = 1 if prob > 0.5 else 0

                all_preds.append(pred)
                all_probs.append(prob)

                if (idx + 1) % 100 == 0:
                    logger.info(f"Processed {idx + 1}/{total_samples} samples")
                    torch.cuda.empty_cache()

        inference_time = time.time() - inference_start
        samples_per_sec = total_samples / inference_time if inference_time > 0 else 0.0

        logger.info(f"Inference completed in {inference_time:.2f}s")
        logger.info(f"Throughput: {samples_per_sec:.2f} samples/sec")

        metrics = self.calculate_metrics(test_labels, all_preds, all_probs)

        all_embeddings = torch.stack(all_embeddings)
        recall_metrics = self.calculate_recall_at_k(all_embeddings, test_labels)
        metrics.update(recall_metrics)

        metrics['inference_time'] = inference_time
        metrics['samples_per_sec'] = samples_per_sec

        torch.cuda.empty_cache()
        gc.collect()

        return metrics


def print_results(metrics, dataset_name):
    """Print the metrics dict from ``BitFitCloneTrainer.evaluate_comprehensive``.

    Confusion counts, classification and advanced metrics are always printed.
    The probabilistic, retrieval and performance sections appear only when
    their keys are present in ``metrics``.
    """
    rule = '=' * 80

    def banner(title):
        print(f"\n{rule}")
        print(title)
        print(rule)

    banner(f"RESULTS: {dataset_name}")
    print("\nConfusion Matrix:")
    print(f"{metrics['cm']}")

    print()
    for label, key in (
        ("True Negatives (TN)", 'tn'),
        ("False Positives (FP)", 'fp'),
        ("False Negatives (FN)", 'fn'),
        ("True Positives (TP)", 'tp'),
    ):
        print(f"{label}: {metrics[key]}")

    banner("CLASSIFICATION METRICS")
    for label, key in (
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall (Sensitivity)", 'rec'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
        ("Negative Predictive Value (NPV)", 'npv'),
        ("False Positive Rate (FPR)", 'fpr'),
        ("False Negative Rate (FNR)", 'fnr'),
    ):
        print(f"{label}: {metrics[key]:.4f}")

    banner("ADVANCED METRICS")
    for label, key in (
        ("Jaccard Score", 'jacc'),
        ("Matthews Correlation Coefficient (MCC)", 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    ):
        print(f"{label}: {metrics[key]:.4f}")

    if 'roc_auc' in metrics:
        banner("PROBABILISTIC METRICS")
        print(f"ROC AUC Score: {metrics['roc_auc']:.4f}")
    for label, key in (
        ("Precision-Recall AUC", 'pr_auc'),
        ("Log Loss", 'log_loss'),
        ("Brier Score", 'brier'),
    ):
        if key in metrics:
            print(f"{label}: {metrics[key]:.4f}")

    if 'recall@1' in metrics:
        banner("RETRIEVAL METRICS")
        for k in (1, 3, 5, 10):
            print(f"Recall@{k}: {metrics[f'recall@{k}']:.4f}")
        print(f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}")

    if 'inference_time' in metrics:
        banner("PERFORMANCE METRICS")
        print(f"Inference Time: {metrics['inference_time']:.2f}s")
        print(f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec")

    print(f"{rule}\n")


def run_bitfit_experiment(train_file, test_file, func1_col, func2_col, label_col, cache_dir=None, bias_terms=None):
    """Train a BitFit code-clone model and evaluate it on the test split.

    Args:
        train_file: Path to the training CSV, forwarded to ``BitFitCloneTrainer``.
        test_file: Path to the test CSV, forwarded to ``BitFitCloneTrainer``.
        func1_col, func2_col, label_col: Column names forwarded to ``BitFitCloneTrainer``.
        cache_dir: Hugging Face cache directory forwarded to ``BitFitCloneTrainer``.
        bias_terms: Forwarded unchanged to ``BitFitCloneTrainer``. ``None`` (the default) means
            ``['all']``, the previous default value; a fresh list is built per call so no
            mutable default object is shared between calls.

    Returns:
        ``(bitfit_trainer, model, metrics)`` where ``metrics`` is what
        ``bitfit_trainer.evaluate_comprehensive`` returns for the preprocessed test set.
    """
    if bias_terms is None:
        bias_terms = ['all']

    logger.info("="*80)
    logger.info("CODE CLONE DETECTION WITH BITFIT")
    logger.info("="*80)

    bitfit_trainer = BitFitCloneTrainer(
        train_file,
        test_file,
        func1_col,
        func2_col,
        label_col,
        cache_dir=cache_dir,
        bias_terms=bias_terms
    )

    # train() returns (model, training_time); the timing is not used here.
    model, training_time = bitfit_trainer.train()

    logger.info(f"\n{'='*80}")
    logger.info("FINAL TEST EVALUATION")
    logger.info(f"{'='*80}")

    # Preprocess the held-out split example by example before evaluation.
    test_dataset = bitfit_trainer.test_data.map(bitfit_trainer.preprocess, batched=False)
    metrics = bitfit_trainer.evaluate_comprehensive(model, test_dataset)

    print_results(metrics, "TEST SET")

    return bitfit_trainer, model, metrics


if __name__ == "__main__":
    # Experiment configuration. These remain module-level names on purpose:
    # notebook cells share one global namespace.
    train_file, test_file = '/train.csv', '/test.csv'
    func1_col, func2_col, label_col = "func1", "func2", "label"
    cache_dir = '/hf_cache'

    try:
        bitfit_trainer, model, results = run_bitfit_experiment(
            train_file, test_file, func1_col, func2_col, label_col,
            cache_dir=cache_dir,
            bias_terms=['all'],
        )
        logger.info("BitFit experiment completed successfully!")
    except Exception as e:
        # Log the message, dump the full traceback, then propagate so the failure is not masked.
        logger.error(f"Error: {e}")
        import traceback
        traceback.print_exc()
        raise

# GateRA with CodeQwen

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, f1_score
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader
from typing import Optional, List, Dict
from dataclasses import dataclass


@dataclass
class GateRAConfig:
    """Hyper-parameters for GateRA (gated LoRA) adapters.

    ``target_modules=None`` is replaced in ``__post_init__`` by the default list of
    projection-layer names; each instance gets its own fresh list.

    Raises:
        ValueError: if ``rank`` is not positive (the adapter scale is ``alpha / rank``).
    """
    rank: int = 16  # LoRA rank: inner dimension of the low-rank update
    alpha: float = 16.0  # adapters are scaled by alpha / rank
    dropout: float = 0.0  # dropout probability passed to every GateRA layer
    target_modules: Optional[List[str]] = None  # None -> default projection names (see __post_init__)
    entropy_reg_weight: float = 0.01  # weight of the gate-entropy term in the training loss

    def __post_init__(self):
        # Fail early instead of with a ZeroDivisionError when the layers are built.
        if self.rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {self.rank}")
        if self.target_modules is None:
            self.target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                                   "gate_proj", "up_proj", "down_proj"]


class GatingModule(nn.Module):
    """Scalar gate ``sigmoid(w . x + b)`` computed over the last dimension of the input.

    Both ``w`` and ``b`` start at zero, so every gate is exactly 0.5 until training moves them.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.gate_linear = nn.Linear(input_dim, 1, bias=True)
        for tensor in (self.gate_linear.weight, self.gate_linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (..., input_dim) -> (..., 1), squashed into (0, 1).
        return self.gate_linear(x).sigmoid()


class GateRALayer(nn.Module):
    """Gated low-rank adapter wrapped around a (frozen) linear layer.

    ``forward(x) = base_layer(x) + g(x) * (dropout(x) @ lora_A @ lora_B) * (alpha / rank)``
    where ``g`` is a ``GatingModule`` producing one sigmoid gate per position of ``x``.
    ``lora_B`` is initialised to zero, so a freshly built layer returns ``base_layer(x)``.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        rank: int,
        alpha: float,
        dropout: float,
        input_dim: int,
        output_dim: int
    ):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = alpha / rank
        self.dropout = dropout

        # Low-rank factors: (..., input_dim) @ A (input_dim, rank) @ B (rank, output_dim).
        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))

        self.gating_module = GatingModule(input_dim)

        self.dropout_layer = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()

        self.reset_parameters()

        # Gate values from the most recent forward pass; read by the model's entropy regulariser.
        self.last_gate_values = None

    def reset_parameters(self):
        """Kaiming-uniform init for A and zeros for B (the adapter starts as a no-op)."""
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_output = self.base_layer(x)

        # nn.Linear and matmul broadcast over all leading dimensions, so no flattening is
        # needed (and non-contiguous inputs, which .view() would reject, are fine).
        gate_values = self.gating_module(x)  # (..., 1)

        # Keep the autograd graph while training: the gate-entropy regulariser reads these
        # values, and a detached copy would give it no gradient at all. In eval mode store a
        # detached copy so no graph is retained.
        self.last_gate_values = gate_values if self.training else gate_values.detach()

        # Dropout on the adapter input, as in the standard LoRA formulation.
        lora_output = self.dropout_layer(x) @ self.lora_A @ self.lora_B

        return base_output + gate_values * lora_output * self.scaling


class CodeCloneDataset(Dataset):
    """Map-style Dataset over a pre-built list of examples, returned as stored."""

    def __init__(self, data_list):
        # Kept by reference, not copied.
        self.data = data_list

    def __getitem__(self, idx):
        examples = self.data
        return examples[idx]

    def __len__(self):
        examples = self.data
        return len(examples)


class GateRACodeQwen(nn.Module):
    """CodeQwen with GateRA adapters and a sequence-classification head.

    The causal LM is loaded in float16 with ``device_map='auto'`` (so it may be sharded across
    GPUs) and all of its original weights are frozen. Every ``nn.Linear`` whose attribute name
    contains one of ``gatera_config.target_modules`` is replaced in place by a ``GateRALayer``,
    and a two-layer MLP classifier is applied to the attention-masked mean of the final hidden
    states.

    NOTE(review): adapters and classifier are trained directly in float16 (no fp32 master
    weights / GradScaler); consider mixed precision if training is unstable.
    """

    def __init__(self, model_name: str, gatera_config: "GateRAConfig", num_classes: int = 2,
                 cache_dir: Optional[str] = None):
        super().__init__()

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")

        # Device hosting pooling, the classifier head and the loss. Falls back to CPU when no
        # GPU is visible instead of failing on a hard-coded 'cuda:0'.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        print("Loading base LLM model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )

        self.llm = base_model
        self.config = gatera_config
        self.gatera_layers = nn.ModuleList()
        self.gatera_layer_names = []

        decoder_layers = base_model.model.layers
        total_layers = len(decoder_layers)
        print(f"Model has {total_layers} transformer layers")

        # Record where device_map='auto' placed each decoder layer.
        self.layer_devices = []
        for i, decoder_layer in enumerate(decoder_layers):
            layer_device = next(decoder_layer.parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Layer {i} -> {layer_device}")

        embed_device = next(base_model.model.embed_tokens.parameters()).device
        norm_device = next(base_model.model.norm.parameters()).device
        lm_head_device = next(base_model.lm_head.parameters()).device

        print(f"Embeddings -> {embed_device}")
        print(f"Norm -> {norm_device}")
        print(f"LM Head -> {lm_head_device}")

        self.embed_device = str(embed_device)
        self.final_device = str(norm_device)

        self.hidden_size = base_model.config.hidden_size
        self.num_classes = num_classes

        # Freeze the original weights *before* injection so only adapter weights train.
        for param in self.llm.parameters():
            param.requires_grad = False

        print("\nInjecting GateRA layers...")
        self._inject_gatera_layers()

        # float16 head to match the LLM's hidden-state dtype. nn.Linear parameters require
        # grad by default and .half() preserves that flag.
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size // 4, num_classes)
        ).to(self.primary_device).half()

        print("Model initialization complete!")

    def _inject_gatera_layers(self):
        """Replace each targeted nn.Linear inside ``self.llm`` with a GateRALayer wrapping it."""
        layer_count = 0
        # Snapshot the module list first: the loop body mutates the module tree.
        for name, module in list(self.llm.named_modules()):
            if not isinstance(module, nn.Linear):
                continue

            layer_name = name.split('.')[-1]
            # Substring match against the configured target names (e.g. "q_proj").
            if not any(target in layer_name for target in self.config.target_modules):
                continue

            # Build the adapter on the same device as the layer it wraps.
            layer_device = next(module.parameters()).device

            gatera_layer = GateRALayer(
                base_layer=module,
                rank=self.config.rank,
                alpha=self.config.alpha,
                dropout=self.config.dropout,
                input_dim=module.in_features,
                output_dim=module.out_features
            ).to(layer_device).half()

            self.gatera_layers.append(gatera_layer)
            self.gatera_layer_names.append(name)

            parent_name = name.rpartition('.')[0]
            parent = self.llm.get_submodule(parent_name) if parent_name else self.llm
            setattr(parent, layer_name, gatera_layer)

            layer_count += 1
            if layer_count % 10 == 0:
                print(f"  Injected {layer_count} GateRA layers...")

        print(f"Total GateRA layers injected: {layer_count}")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids; dim 0 is the batch.
            attention_mask: Optional 0/1 mask shaped like ``input_ids``; zero positions are
                excluded from pooling.
            labels: Optional class indices; when given, a cross-entropy loss is returned.

        Returns:
            Dict with ``'logits'`` (batch, num_classes), ``'loss'`` (``None`` without labels)
            and ``'hidden_states'``: the decoder's final hidden states on ``primary_device``.
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)

        # Run only the decoder stack: the LM head's vocabulary-sized logits are never used
        # for classification, and use_cache=False avoids building a KV cache.
        outputs = self.llm.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state.to(self.primary_device)

        # Mean-pool in float32: an fp16 sum over hundreds of positions can overflow
        # (fp16 max is 65504), and a 1e-9 clamp underflows to 0 in fp16.
        hidden_fp32 = hidden_states.float()
        if attention_mask is not None:
            mask = attention_mask.to(self.primary_device).unsqueeze(-1).float()
            pooled = (hidden_fp32 * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden_fp32.mean(dim=1)

        classifier_dtype = next(self.classifier.parameters()).dtype
        logits = self.classifier(pooled.to(classifier_dtype))

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Cross-entropy on float32 logits for a numerically safer log-softmax.
            loss = nn.CrossEntropyLoss()(logits.float(), labels)

        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}

    def get_gate_entropy_loss(self):
        """Mean binary entropy of the most recent gate values, averaged over GateRA layers.

        Layers without recorded gates are skipped; returns a zero tensor if none have any.
        """
        # Computed in float32 with an eps representable on both sides of the clamp: in fp16
        # (and even fp32) 1 - 1e-8 rounds to 1.0, so a saturated gate gave 0 * log(0) = NaN.
        eps = 1e-6
        per_layer = []
        for layer in self.gatera_layers:
            gates = getattr(layer, 'last_gate_values', None)
            if gates is None:
                continue
            g = gates.float().clamp(eps, 1.0 - eps)
            entropy = -(g * torch.log(g) + (1 - g) * torch.log(1 - g))
            # Layers may live on different GPUs; reduce on the primary device.
            per_layer.append(entropy.mean().to(self.primary_device))

        if not per_layer:
            return torch.tensor(0.0, device=self.primary_device)
        return torch.stack(per_layer).mean()


def print_parameter_statistics(model: "GateRACodeQwen"):
    """Print frozen/trainable parameter counts for ``model`` and per-GPU memory usage.

    GateRA layers are spliced into ``model.llm``, so iterating ``model.llm.parameters()`` also
    yields the trainable adapter weights; only parameters with ``requires_grad=False`` are
    counted as frozen. The trainable ratio is taken over all (frozen + trainable) parameters.
    """
    frozen_params = sum(p.numel() for p in model.llm.parameters() if not p.requires_grad)

    gatera_params = sum(
        p.numel()
        for layer in model.gatera_layers
        for p in layer.parameters()
        if p.requires_grad
    )
    classifier_params = sum(p.numel() for p in model.classifier.parameters())
    trainable_params = gatera_params + classifier_params

    all_params = frozen_params + trainable_params
    ratio = trainable_params / all_params * 100 if all_params else 0.0

    print(f"\n{'='*80}")
    print(f"PARAMETER STATISTICS")
    print(f"{'='*80}")
    print(f"Total Frozen LLM Parameters: {frozen_params:,}")
    print(f"Trainable GateRA Parameters: {gatera_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Ratio: {ratio:.4f}%")
    print(f"{'='*80}\n")

    print(f"{'='*80}")
    print(f"GPU MEMORY USAGE")
    print(f"{'='*80}")
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")

def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Load a clone-detection CSV and check that the required columns exist.

    Args:
        csv_path: Path to the CSV file.
        label_col, func1_col, func2_col: Column names that must be present.

    Returns:
        The DataFrame, or ``None`` (after printing the reason) when the file cannot be read or
        parsed, or when any of the required columns is missing.
    """
    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers pandas' ParserError and
        # EmptyDataError as well as text-decoding errors.
        print(f"Error loading dataset: {str(e)}")
        return None

    missing = [col for col in (func1_col, func2_col, label_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing column(s) {missing}")
        return None

    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512,
                         label_col='label', func1_col='func1', func2_col='func2'):
    """Tokenise a clone-detection CSV into a list of fixed-length examples.

    Each row becomes three newline-joined lines, ``Code1: <func1>``, ``Code2: <func2>`` and
    ``Are these code clones?``, tokenised with ``padding='max_length'`` and truncation.

    NOTE(review): truncation applies to the concatenated prompt, so a long ``func1`` can push
    ``func2`` (and the question) out of the window entirely; consider a per-function budget.

    Returns:
        A list of dicts holding 1-D ``'input_ids'`` / ``'attention_mask'`` tensors and a
        0-dim int64 ``'labels'`` tensor, or ``[]`` when the CSV could not be loaded.
    """
    frame = load_dataset(csv_path, label_col, func1_col, func2_col)
    if frame is None:
        return []

    examples = []
    for _, record in frame.iterrows():
        first_code = str(record[func1_col])
        second_code = str(record[func2_col])
        target = int(record[label_col])

        prompt = "\n".join((
            f"Code1: {first_code}",
            f"Code2: {second_code}",
            "Are these code clones?",
        ))

        encoded = tokenizer(
            prompt,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

        # Drop the leading batch dimension of size 1 added by return_tensors='pt'.
        examples.append({
            'input_ids': encoded['input_ids'][0],
            'attention_mask': encoded['attention_mask'][0],
            'labels': torch.tensor(target, dtype=torch.long),
        })

    return examples


def get_memory_usage():
    """Return the resident set size of the current process in MiB (as a float)."""
    current = psutil.Process(os.getpid())
    rss_bytes = current.memory_info().rss
    return rss_bytes / (1024 * 1024)


def calculate_recall_at_k(embeddings, labels, k_values=(1, 5, 10)):
    """Fraction of items whose k most similar *other* items include one with the same label.

    Similarity is the dot product ``embeddings @ embeddings.T`` (cosine similarity when rows
    are L2-normalised). Each item is excluded from its own neighbour list.

    Args:
        embeddings: ``(n, d)`` tensor.
        labels: ``n`` labels (tensor, array or list).
        k_values: Cut-offs to report; any ``k >= n - 1`` behaves like ``n - 1``.

    Returns:
        ``{f'recall@{k}': float}`` for every ``k`` in ``k_values`` (all 0.0 when ``n == 0``).

    Item i is a hit at k iff its best same-label neighbour has rank <= k, where
    rank = 1 + number of other items strictly more similar; similarity ties are therefore
    resolved in the query's favour. All cut-offs share a single vectorised pass.
    """
    k_values = list(k_values)
    n = len(labels)
    if n == 0:
        return {f'recall@{k}': 0.0 for k in k_values}

    sims = torch.mm(embeddings, embeddings.t())
    labels = torch.as_tensor(labels, device=sims.device)

    is_self = torch.eye(n, dtype=torch.bool, device=sims.device)
    sims = sims.masked_fill(is_self, float('-inf'))
    same_label = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~is_self

    has_match = same_label.any(dim=1)
    best_same = sims.masked_fill(~same_label, float('-inf')).max(dim=1).values
    ranks = 1 + (sims > best_same.unsqueeze(1)).sum(dim=1)

    return {
        f'recall@{k}': int((has_match & (ranks <= k)).sum()) / n
        for k in k_values
    }


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour, by dot-product similarity.

    For each item, the other items are ranked by ``embeddings @ embeddings.T``. The item's
    reciprocal rank is ``1 / rank`` of the most similar other item sharing its label, or 0 when
    no other item shares it. The query itself is never counted as a match. Previously the
    item's own -inf entry stayed in the ranking, so a label-unique item scored 1/n instead of 0.

    Returns:
        The mean over all items as a float (0.0 for empty input). rank = 1 + number of other
        items strictly more similar, so similarity ties are resolved in the query's favour.
    """
    n = len(labels)
    if n == 0:
        return 0.0

    sims = torch.mm(embeddings, embeddings.t())
    labels = torch.as_tensor(labels, device=sims.device)

    is_self = torch.eye(n, dtype=torch.bool, device=sims.device)
    sims = sims.masked_fill(is_self, float('-inf'))
    same_label = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~is_self

    has_match = same_label.any(dim=1)
    best_same = sims.masked_fill(~same_label, float('-inf')).max(dim=1).values
    ranks = 1 + (sims > best_same.unsqueeze(1)).sum(dim=1)

    # Items without any same-label neighbour contribute 0.
    reciprocal = has_match.double() / ranks.double()
    return reciprocal.sum().item() / n


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Binary classification metrics, optionally with probabilistic and retrieval metrics.

    Args:
        y_true: Ground-truth labels in {0, 1}.
        y_pred: Predicted labels in {0, 1}.
        y_proba: Optional positive-class scores for ROC AUC / PR AUC. Both are reported as 0.0
            when scores are absent, when ``y_true`` has a single class, or when sklearn rejects
            the inputs.
        embeddings: Optional ``(n, d)`` tensor; adds ``recall@k`` and ``mrr`` using ``y_true``.

    Returns:
        Dict with keys 'cm' (2x2 confusion matrix, rows = true, cols = predicted, order [0, 1]),
        'acc', 'bal_acc', 'prec', 'rec', 'specificity', 'f1', 'mcc', 'roc_auc', 'pr_auc', plus the
        retrieval keys when ``embeddings`` is given.
    """
    # labels=[0, 1] always yields a 2x2 matrix, even when a class is absent from the data.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)

    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    roc_auc = pr_auc = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        try:
            roc_auc = roc_auc_score(y_true, y_proba)
            pr_auc = average_precision_score(y_true, y_proba)
        except ValueError:
            # Malformed score inputs: report 0.0 rather than abort the whole evaluation.
            roc_auc = pr_auc = 0.0

    results = {
        'cm': cm,
        'acc': acc,
        'bal_acc': bal_acc,
        'prec': prec,
        'rec': rec,
        'specificity': specificity,
        'f1': f1,
        'mcc': mcc,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc
    }

    if embeddings is not None:
        label_tensor = torch.tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)

    return results


def collate_fn(batch):
    """Stack a list of example dicts into a single batch dict.

    Every example must provide 'input_ids', 'attention_mask' and 'labels' tensors whose shapes
    match across the batch; each key maps to ``torch.stack`` of those tensors.
    """
    keys = ('input_ids', 'attention_mask', 'labels')
    return {key: torch.stack([example[key] for example in batch]) for key in keys}


def train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2, entropy_reg_weight=0.01):
    """Fine-tune the GateRA adapters and classifier head of ``model`` with AdamW.

    Only parameters with ``requires_grad`` inside ``model.gatera_layers`` and
    ``model.classifier`` are optimised. Each step minimises
    ``task_loss + entropy_reg_weight * model.get_gate_entropy_loss()`` with gradient-norm
    clipping at 1.0.

    Args:
        model: Model exposing ``gatera_layers``, ``classifier`` and ``get_gate_entropy_loss``
            and returning a dict with 'loss' and 'logits' from ``forward``.
        train_data: List of example dicts as produced by ``create_training_data``.
        num_epochs, learning_rate, batch_size, entropy_reg_weight: Training hyper-parameters.

    Returns:
        One dict per epoch with keys 'epoch', 'loss', 'task_loss', 'entropy_loss', 'acc',
        'time', 'tokens_per_sec' (padded positions / s) and 'memory_mb' (process RSS).
    """
    trainable_params = [
        p for layer in model.gatera_layers for p in layer.parameters() if p.requires_grad
    ]
    trainable_params += [p for p in model.classifier.parameters() if p.requires_grad]

    # AdamW's default eps (1e-8) rounds to 0 in float16. Any fp16 parameter whose gradient is
    # exactly zero would then get a 0/0 = NaN update. In a freshly initialised GateRA layer,
    # lora_B starts at zero, so lora_A and the gate receive zero gradients on the first step.
    # Use an eps that is representable in float16 in that case.
    has_fp16 = any(p.dtype == torch.float16 for p in trainable_params)
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=0.01,
        eps=1e-4 if has_fp16 else 1e-8,
    )

    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model.train()

    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        total_task_loss = 0.0
        total_entropy_loss = 0.0
        correct = 0
        total = 0
        batch_count = 0
        total_tokens = 0

        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            total_tokens += input_ids.numel()

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            task_loss = outputs['loss']

            # Check before combining: adding the entropy term to a None loss would raise.
            if task_loss is not None:
                entropy_loss = model.get_gate_entropy_loss()
                loss = task_loss + entropy_reg_weight * entropy_loss

                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()

                total_loss += loss.item()
                total_task_loss += task_loss.item()
                total_entropy_loss += entropy_loss.item()

                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)

            batch_count += 1

            # Periodically return unused cached CUDA memory to the driver.
            if batch_count % 10 == 0:
                torch.cuda.empty_cache()

        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        avg_task_loss = total_task_loss / batch_count if batch_count > 0 else 0
        avg_entropy_loss = total_entropy_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        tokens_per_sec = total_tokens / epoch_time if epoch_time > 0 else 0.0
        current_memory = get_memory_usage()

        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'task_loss': avg_task_loss,
            'entropy_loss': avg_entropy_loss,
            'acc': accuracy,
            'time': epoch_time,
            'tokens_per_sec': tokens_per_sec,
            'memory_mb': current_memory
        })

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} (Task: {avg_task_loss:.4f}, Entropy: {avg_entropy_loss:.4f}), "
              f"Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Tokens/sec: {tokens_per_sec:.2f}, Memory: {current_memory:.2f}MB")

        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")

    total_training_time = time.time() - start_time
    # default= keeps num_epochs=0 from crashing max()/np.mean() on an empty list.
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    avg_tokens_per_sec = np.mean([s['tokens_per_sec'] for s in epoch_stats]) if epoch_stats else 0.0

    print(f"\n{'='*80}")
    print(f"TRAINING SUMMARY")
    print(f"{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB")
    print(f"Average Tokens/sec: {avg_tokens_per_sec:.2f}")
    print(f"{'='*80}\n")

    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run ``model`` over pre-tokenised ``test_data`` and compute classification/retrieval metrics.

    Retrieval embeddings are the attention-masked mean of ``outputs['hidden_states']``,
    L2-normalised, in float32. Inputs are padded to a fixed length, so an unmasked mean would
    be dominated by padding positions. float32 also avoids CPU half-precision matmul in the
    retrieval metrics.

    Returns:
        The dict from ``calculate_comprehensive_metrics`` plus 'inference_time' (seconds) and
        'tokens_per_sec' (``input_ids.numel()`` per second, i.e. padding included).
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []

    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    start_time = time.time()
    total_tokens = 0

    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            total_tokens += input_ids.numel()

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits'].float()
            hidden_states = outputs['hidden_states']

            predictions = torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)

            # Masked mean over real tokens only, accumulated in float32.
            mask = attention_mask.to(hidden_states.device).unsqueeze(-1).float()
            summed = (hidden_states.float() * mask).sum(dim=1)
            pooled_embeddings = summed / mask.sum(dim=1).clamp(min=1e-9)
            pooled_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())
            all_embeddings.append(pooled_embeddings.cpu())

            if len(all_predictions) % 20 == 0:
                torch.cuda.empty_cache()

    inference_time = time.time() - start_time
    tokens_per_sec = total_tokens / inference_time if inference_time > 0 else 0.0

    embeddings = torch.cat(all_embeddings, dim=0) if all_embeddings else None

    metrics = calculate_comprehensive_metrics(
        all_labels,
        all_predictions,
        all_probabilities,
        embeddings
    )

    metrics['inference_time'] = inference_time
    metrics['tokens_per_sec'] = tokens_per_sec

    return metrics


def print_evaluation_results(metrics, dataset_name):
    """Pretty-print an ``evaluate_model`` metrics dict under a titled banner.

    Retrieval rows (Recall@1/5/10, MRR) are printed only when ``'recall@1'`` is present.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"EVALUATION RESULTS - {dataset_name}")
    print(rule)
    print(f"Confusion Matrix:\n{metrics['cm']}")

    rows = [
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall", 'rec'),
        ("Specificity", 'specificity'),
        ("F1 Score", 'f1'),
        ("MCC", 'mcc'),
        ("ROC AUC", 'roc_auc'),
        ("PR AUC", 'pr_auc'),
    ]
    if 'recall@1' in metrics:
        rows += [
            ("Recall@1", 'recall@1'),
            ("Recall@5", 'recall@5'),
            ("Recall@10", 'recall@10'),
            ("MRR", 'mrr'),
        ]
    for label, key in rows:
        print(f"{label}: {metrics[key]:.4f}")

    print(f"Inference Time: {metrics['inference_time']:.2f}s")
    print(f"Tokens/sec: {metrics['tokens_per_sec']:.2f}")
    print(f"{rule}\n")


def main(train_csv='/train.csv', test_csv='/test.csv', cache_dir='/hf_cache'):
    """Train GateRA adapters on CodeQwen for clone detection, then evaluate on the test split.

    Args:
        train_csv: Training CSV with columns ``func1``, ``func2``, ``label``.
        test_csv: Test CSV with the same columns.
        cache_dir: Hugging Face cache directory for the tokenizer and model.

    The defaults are the paths previously hard-coded here, so ``main()`` is unchanged for
    existing callers.
    """
    model_name = "Qwen/CodeQwen1.5-7B-Chat"

    torch.cuda.empty_cache()
    gc.collect()

    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenise the training set *before* loading the 7B model, so a missing or empty CSV
    # exits immediately instead of after a multi-GPU model load.
    train_data = create_training_data(
        train_csv,
        tokenizer,
        label_col='label',
        func1_col='func1',
        func2_col='func2'
    )
    print(f"Created {len(train_data)} training examples\n")

    if not train_data:
        print("No training data available. Exiting.")
        return

    gatera_config = GateRAConfig(
        rank=16,
        alpha=16.0,
        dropout=0.0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        entropy_reg_weight=0.01
    )

    print("\nLoading model with GateRA PEFT and pipeline parallelism across GPUs...")
    model = GateRACodeQwen(model_name, gatera_config=gatera_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{'='*80}")
    print("TRAINING PHASE")
    print(f"{'='*80}\n")

    train_model(
        model,
        train_data,
        num_epochs=5,
        learning_rate=1e-4,
        batch_size=2,
        entropy_reg_weight=gatera_config.entropy_reg_weight
    )

    print(f"\n{'='*80}")
    print("TESTING PHASE")
    print(f"{'='*80}\n")

    test_data = create_training_data(
        test_csv,
        tokenizer,
        label_col='label',
        func1_col='func1',
        func2_col='func2'
    )
    print(f"Created {len(test_data)} test examples\n")

    if test_data:
        test_metrics = evaluate_model(model, test_data, batch_size=2)
        print_evaluation_results(test_metrics, "TEST SET")

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE")
    print(f"{'='*80}\n")


if __name__ == "__main__":
    # Script / notebook-cell entry point for the GateRA experiment.
    main()

# CodeQwen prefix tuning

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class PrefixTuningConfig:
    """Plain container for prefix-tuning hyper-parameters.

    Attributes mirror the constructor arguments one-to-one: prefix length, number of layers
    that receive prefixes, model hidden size, attention heads and per-head dimension, the
    reparameterisation width and the MLP dropout probability.
    """

    def __init__(
        self,
        prefix_length: int = 20,
        num_layers: int = 32,
        hidden_size: int = 4096,
        num_heads: int = 32,
        head_dim: int = 128,
        reparam_dim: int = 512,
        dropout: float = 0.1
    ):
        # Assigned left to right, so attribute order matches the parameter order.
        (
            self.prefix_length,
            self.num_layers,
            self.hidden_size,
            self.num_heads,
            self.head_dim,
            self.reparam_dim,
            self.dropout,
        ) = (prefix_length, num_layers, hidden_size, num_heads, head_dim, reparam_dim, dropout)


class PrefixEncoder(nn.Module):
    """Produce per-layer (key, value) prefix tensors for prefix tuning.

    A learnable table of ``num_layers x prefix_length`` vectors of width ``reparam_dim`` is
    passed through an MLP (Linear -> Tanh -> Linear -> Dropout) that emits
    ``2 * num_heads * head_dim`` values per prefix position; these are split into keys and values.
    """

    def __init__(self, config: PrefixTuningConfig):
        super().__init__()
        # Geometry copied from the config.
        self.prefix_length = config.prefix_length
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        # Raw prefix embeddings: (num_layers, prefix_length, reparam_dim).
        self.prefix_tokens = nn.Parameter(
            torch.randn(config.num_layers, config.prefix_length, config.reparam_dim)
        )

        kv_width = 2 * config.num_heads * config.head_dim
        self.reparam_mlp = nn.Sequential(
            nn.Linear(config.reparam_dim, config.hidden_size),
            nn.Tanh(),
            nn.Linear(config.hidden_size, kv_width),
            nn.Dropout(config.dropout),
        )

    def forward(self, batch_size: int):
        """Return ``num_layers`` contiguous ``(key, value)`` pairs, each shaped
        ``(batch_size, num_heads, prefix_length, head_dim)``."""
        batched = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1, -1, -1)
        return [self._project_layer(batched[:, idx], batch_size) for idx in range(self.num_layers)]

    def _project_layer(self, layer_prefix, batch_size):
        """Map one layer's prefix embeddings through the MLP and split into key/value."""
        kv = self.reparam_mlp(layer_prefix).view(
            batch_size, self.prefix_length, 2, self.num_heads, self.head_dim
        )
        key, value = kv.unbind(dim=2)
        return key.transpose(1, 2).contiguous(), value.transpose(1, 2).contiguous()


class CodeCloneDataset(Dataset):
    """Minimal map-style Dataset: indexes straight into an in-memory list of examples."""

    def __init__(self, data_list):
        # Stored by reference; no copy is made.
        self.data = data_list

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


class PrefixTuningCodeQwen(nn.Module):
    """Frozen CodeQwen causal LM with learned key/value prefixes and a classification head.

    ``PrefixEncoder`` produces one (key, value) prefix per transformer layer. The
    prefixes are handed to the LM as ``past_key_values`` together with an attention
    mask extended by ``prefix_length`` ones. Every real token can therefore attend to
    the learned virtual tokens in every layer, and position ids start after the prefix.
    The LM weights are frozen (``requires_grad=False``). Its forward pass still runs
    with autograd enabled so gradients reach the prefix encoder. The classifier reads
    a padding-aware mean of the final (normed) hidden states.
    """

    def __init__(self, model_name: str, prefix_config: PrefixTuningConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        import copy

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")

        # Device for the trainable modules; the frozen LM is sharded by device_map='auto'.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        print("Loading base LLM model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )

        self.llm = base_model
        self.llm_dtype = base_model.dtype

        total_layers = len(base_model.model.layers)
        print(f"Model has {total_layers} transformer layers")

        # The prefix is concatenated with each layer's KV cache, so its geometry
        # (layer count, KV heads, head_dim) must match the model exactly. GQA models
        # have fewer KV heads than attention heads. Adjust a copy, never the caller's config.
        model_cfg = base_model.config
        kv_heads = getattr(model_cfg, 'num_key_value_heads', None) or model_cfg.num_attention_heads
        head_dim = getattr(model_cfg, 'head_dim', None) or model_cfg.hidden_size // model_cfg.num_attention_heads
        prefix_config = copy.copy(prefix_config)
        for attr, expected in (('num_layers', total_layers), ('num_heads', kv_heads), ('head_dim', head_dim)):
            current = getattr(prefix_config, attr)
            if current != expected:
                print(f"Adjusting prefix_config.{attr} from {current} to {expected} to match the model")
                setattr(prefix_config, attr, expected)
        self.prefix_config = prefix_config

        self.layer_devices = []
        for i in range(total_layers):
            layer_device = next(base_model.model.layers[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Layer {i} -> {layer_device}")

        embed_device = next(base_model.model.embed_tokens.parameters()).device
        norm_device = next(base_model.model.norm.parameters()).device

        print(f"Embeddings -> {embed_device}")
        print(f"Norm -> {norm_device}")

        self.embed_device = str(embed_device)
        self.final_device = str(norm_device)

        # Trainable modules stay in float32. AdamW on fp16 parameters is unstable:
        # eps=1e-8 rounds to 0 in fp16 and small squared gradients underflow.
        self.prefix_encoder = PrefixEncoder(prefix_config).to(self.primary_device)
        self.num_classes = num_classes

        model_hidden = model_cfg.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(model_hidden, model_hidden // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(model_hidden // 4, num_classes)
        ).to(self.primary_device)

        for param in self.llm.parameters():
            param.requires_grad = False

        for module in (self.prefix_encoder, self.classifier):
            for param in module.parameters():
                param.requires_grad = True

        print("Model initialization complete!")

    def _build_prefix_cache(self, batch_size):
        """Run the prefix encoder and pack its output into a transformers ``DynamicCache``.

        Each layer's key/value is moved to that layer's device and cast to the LM dtype,
        so the attention layers can concatenate their own keys/values onto it.
        """
        from transformers import DynamicCache

        legacy = tuple(
            (
                key.to(device=self.layer_devices[layer_idx], dtype=self.llm_dtype),
                value.to(device=self.layer_devices[layer_idx], dtype=self.llm_dtype),
            )
            for layer_idx, (key, value) in enumerate(self.prefix_encoder(batch_size))
        )
        return DynamicCache.from_legacy_cache(legacy)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) mask, 1 for real tokens and 0 for padding.
            labels: optional (batch,) class indices; when given, cross-entropy loss is returned.

        Returns:
            dict with 'logits' (batch, num_classes), 'loss' (scalar or None) and
            'hidden_states' (batch, seq, hidden): the final-layer states on ``primary_device``.
        """
        batch_size = input_ids.shape[0]
        input_ids = input_ids.to(self.embed_device)

        if attention_mask is None:
            token_mask = torch.ones_like(input_ids)
        else:
            token_mask = attention_mask.to(self.embed_device)

        # The mask must cover prefix + tokens because the prefix is supplied as past keys.
        prefix_mask = torch.ones(
            batch_size, self.prefix_config.prefix_length,
            dtype=token_mask.dtype, device=token_mask.device
        )
        full_mask = torch.cat([prefix_mask, token_mask], dim=1)

        # No torch.no_grad(): LM weights are frozen, but gradients must flow through its
        # activations into the prefix encoder. Calling the inner decoder (self.llm.model)
        # skips the LM head and its (batch, seq, vocab) logits, which are not needed here.
        outputs = self.llm.model(
            input_ids=input_ids,
            attention_mask=full_mask,
            past_key_values=self._build_prefix_cache(batch_size),
            use_cache=True,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        if str(hidden_states.device) != self.primary_device:
            hidden_states = hidden_states.to(self.primary_device)

        # Mean over real tokens only; padding positions would otherwise dilute the pooled vector.
        pool_mask = token_mask.to(device=hidden_states.device, dtype=torch.float32).unsqueeze(-1)
        pooled_output = (hidden_states.float() * pool_mask).sum(dim=1) / pool_mask.sum(dim=1).clamp(min=1.0)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = nn.functional.cross_entropy(logits, labels)

        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: PrefixTuningCodeQwen):
    """Print frozen vs. trainable parameter counts, then per-GPU memory usage in GB."""
    def n_params(module):
        return sum(p.numel() for p in module.parameters())

    frozen = n_params(model.llm)
    prefix = n_params(model.prefix_encoder)
    head = n_params(model.classifier)
    trainable = prefix + head
    rule = '=' * 80

    print(f"\n{rule}")
    print("PARAMETER STATISTICS")
    print(rule)
    print(f"Total Frozen LLM Parameters: {frozen:,}")
    print(f"Trainable Prefix Parameters: {prefix:,}")
    print(f"Trainable Classifier Parameters: {head:,}")
    print(f"Total Trainable Parameters: {trainable:,}")
    print(f"Trainable Percentage: {(trainable / frozen * 100):.4f}%")
    print(f"{rule}\n")

    print(rule)
    print("GPU MEMORY USAGE")
    print(rule)
    gib = 1024 ** 3
    for gpu in range(torch.cuda.device_count()):
        print(
            f"GPU {gpu}: Allocated: {torch.cuda.memory_allocated(gpu) / gib:.2f} GB, "
            f"Reserved: {torch.cuda.memory_reserved(gpu) / gib:.2f} GB"
        )
    print(f"{rule}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a clone-pair CSV and check that the expected columns exist.

    Args:
        csv_path: path to the CSV file.
        label_col: name of the 0/1 label column.
        func1_col: name of the first code-snippet column.
        func2_col: name of the second code-snippet column.

    Returns:
        The DataFrame, or None (after printing a message) when the file cannot be
        read or parsed, or when one of the three columns is missing.
    """
    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:
        # FileNotFoundError/PermissionError are OSErrors; pandas' ParserError and
        # EmptyDataError, and UnicodeDecodeError, are ValueError subclasses.
        print(f"Error loading dataset: {str(e)}")
        return None

    # Validate every column the caller will read, so problems surface here as a
    # clear message rather than as a KeyError later during tokenization.
    missing = [col for col in (func1_col, func2_col, label_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing column(s) {missing}")
        return None

    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512,
                         label_col='label', func1_col='func1', func2_col='func2'):
    """Load a clone-pair CSV and tokenize every pair into fixed-length tensors.

    Each row becomes the prompt "Code1: ...\\nCode2: ...\\nAre these code clones?",
    padded/truncated to ``max_length`` tokens.

    Returns:
        list of dicts with 'input_ids' and 'attention_mask' (1-D tensors of length
        ``max_length``) and 'labels' (0-dim long tensor). The list is empty when the
        CSV could not be loaded.
    """
    df = load_dataset(csv_path, label_col, func1_col, func2_col)
    if df is None:
        return []

    examples = []
    for _, row in df.iterrows():
        first_code, second_code = str(row[func1_col]), str(row[func2_col])
        target = int(row[label_col])

        encoded = tokenizer(
            f"Code1: {first_code}\nCode2: {second_code}\nAre these code clones?",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

        # Drop the leading batch dimension of size 1 added by return_tensors='pt'.
        examples.append({
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'labels': torch.tensor(target, dtype=torch.long),
        })

    return examples


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Fraction of queries whose top-k neighbours contain at least one same-label item.

    Similarity is the dot product ``embeddings @ embeddings.T``, which is cosine
    similarity for L2-normalised rows. A query is never its own neighbour, and each
    ``k`` is capped at ``n - 1``.

    Args:
        embeddings: (n, d) tensor (any float dtype; computed in float32).
        labels: length-n tensor or sequence of class labels.
        k_values: cut-offs to evaluate (tuple default avoids a shared mutable list).

    Returns:
        dict mapping ``'recall@k'`` to a float in [0, 1]; all 0.0 when n < 2.
    """
    k_values = list(k_values)
    labels = torch.as_tensor(labels)
    n = labels.shape[0]
    if n == 0 or not k_values:
        return {f'recall@{k}': 0.0 for k in k_values}

    # float32: CPU matmul on fp16 is slow (unsupported on older torch) and imprecise.
    emb = embeddings.float()
    similarities = emb @ emb.t()
    similarities.fill_diagonal_(float('-inf'))  # exclude each query from its own results
    labels = labels.to(similarities.device)

    max_k = max(0, min(max(k_values), n - 1))
    if max_k == 0:
        return {f'recall@{k}': 0.0 for k in k_values}

    # One top-k over the whole matrix instead of one per (query, k) pair.
    top_idx = torch.topk(similarities, max_k, dim=1).indices      # (n, max_k)
    same_label = labels[top_idx] == labels.unsqueeze(1)             # (n, max_k)

    results = {}
    for k in k_values:
        cutoff = max(0, min(k, n - 1))
        hits = same_label[:, :cutoff].any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / n
    return results


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour, excluding the query itself.

    Similarity is the dot product ``embeddings @ embeddings.T``. For each query,
    candidates are ranked by descending similarity with the query removed, and the
    reciprocal of the 1-based rank of the first candidate sharing the query's label
    is taken (0 if none exists). Previously the query itself could be counted as a
    match, so a single sample scored 1.0.

    Args:
        embeddings: (n, d) tensor (any float dtype; computed in float32).
        labels: length-n tensor or sequence of class labels.

    Returns:
        MRR as a float in [0, 1]; 0.0 for empty input.
    """
    labels = torch.as_tensor(labels)
    n = labels.shape[0]
    if n == 0:
        return 0.0

    emb = embeddings.float()
    similarities = emb @ emb.t()
    similarities.fill_diagonal_(float('-inf'))
    labels = labels.to(similarities.device)

    order = torch.argsort(similarities, dim=1, descending=True)       # (n, n)
    rows = torch.arange(n, device=order.device)
    is_self = order == rows.unsqueeze(1)
    hits = (labels[order] == labels.unsqueeze(1)) & ~is_self
    # 1-based rank among non-self candidates at every position.
    ranks = torch.cumsum((~is_self).long(), dim=1).clamp(min=1)
    first_hit = hits.float().argmax(dim=1)                             # first True per row
    has_hit = hits.any(dim=1)

    reciprocal = torch.where(
        has_hit,
        1.0 / ranks[rows, first_hit].double(),
        torch.zeros(n, dtype=torch.float64, device=order.device),
    )
    return reciprocal.sum().item() / n


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary classification, probabilistic and optional retrieval metrics.

    Args:
        y_true: ground-truth labels in {0, 1} (list, array or tensor-convertible).
        y_pred: predicted labels in {0, 1}.
        y_proba: optional predicted probability of class 1 per sample.
        embeddings: optional (n, d) tensor; adds recall@k and MRR.

    Returns:
        dict of metrics. 'roc_auc', 'pr_auc', 'log_loss' and 'brier' are 0.0 when
        y_proba is None, y_true contains a single class, or sklearn rejects the input.
    """
    # Normalise to arrays: callers pass Python lists, and `1 - y_proba` on a list
    # raised TypeError, which used to zero every probabilistic metric silently.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    # labels=[0, 1] always yields a 2x2 matrix, even if one class never occurs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (int(v) for v in cm.ravel())

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    # F1 straight from the confusion matrix (equal to sklearn's f1_score with
    # zero_division=0), so this does not depend on f1_score being imported.
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0.0
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        proba = np.asarray(y_proba, dtype=np.float64).ravel()
        try:
            roc_auc = roc_auc_score(y_true, proba)
            pr_auc = average_precision_score(y_true, proba)
            logloss = log_loss(y_true, np.column_stack([1 - proba, proba]), labels=[0, 1])
            brier = brier_score_loss(y_true, proba)
        except ValueError as exc:
            # Report instead of swallowing silently; keep the all-zero fallback.
            print(f"Warning: probabilistic metrics unavailable: {exc}")
            roc_auc = pr_auc = logloss = brier = 0.0

    results = {
        'cm': cm,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'acc': acc,
        'bal_acc': bal_acc,
        'prec': prec,
        'rec': rec,
        'specificity': specificity,
        'npv': npv,
        'fpr': fpr,
        'fnr': fnr,
        'f1': f1,
        'mcc': mcc,
        'jacc': jacc,
        'kappa': kappa,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'log_loss': logloss,
        'brier': brier
    }

    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)

    return results


def collate_fn(batch):
    """Stack per-example 'input_ids', 'attention_mask' and 'labels' into batch tensors."""
    fields = ('input_ids', 'attention_mask', 'labels')
    return {field: torch.stack([example[field] for example in batch]) for field in fields}


def train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2):
    """Train ``model.prefix_encoder`` and ``model.classifier`` with AdamW.

    Args:
        model: a module exposing ``prefix_encoder``/``classifier`` submodules whose
            forward returns a dict with 'loss' and 'logits'.
        train_data: list of example dicts (see create_training_data).
        num_epochs: number of passes over ``train_data``; 0 is allowed.
        learning_rate: AdamW learning rate (weight decay 0.01).
        batch_size: examples per optimisation step.

    Returns:
        list of per-epoch dicts with keys 'epoch', 'loss', 'acc', 'time', 'memory_mb'.
    """
    optimizer = torch.optim.AdamW([
        {'params': model.prefix_encoder.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=learning_rate, weight_decay=0.01)
    # Built once: the same parameters are gradient-clipped on every step.
    clip_params = list(model.prefix_encoder.parameters()) + list(model.classifier.parameters())

    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model.train()

    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0

        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
                optimizer.step()
                total_loss += loss.item()

                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)

            batch_count += 1

            # Periodically release cached allocator blocks on long epochs.
            if batch_count % 10 == 0:
                torch.cuda.empty_cache()

        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()

        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'time': epoch_time,
            'memory_mb': current_memory
        })

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")

        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")

    total_training_time = time.time() - start_time
    # default= keeps num_epochs=0 from crashing on max() of an empty sequence.
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory

    print(f"\n{'='*80}")
    print(f"TRAINING SUMMARY")
    print(f"{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB")
    print(f"{'='*80}\n")

    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run inference on ``test_data`` and compute classification and retrieval metrics.

    Embeddings for the retrieval metrics are the mean of the returned hidden states
    over non-padding positions, L2-normalised.

    Args:
        model: module whose forward returns a dict with 'logits' and 'hidden_states'.
        test_data: non-empty list of example dicts (see create_training_data).
        batch_size: examples per forward pass.

    Returns:
        the dict from calculate_comprehensive_metrics plus 'inference_time' (seconds)
        and 'samples_per_sec'.

    Raises:
        ValueError: if ``test_data`` is empty.
    """
    if len(test_data) == 0:
        raise ValueError("evaluate_model requires at least one test example")

    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []

    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    start_time = time.time()

    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            outputs = model(input_ids, attention_mask=attention_mask)
            # Upcast: softmax and normalisation in fp16 lose precision.
            logits = outputs['logits'].float()
            hidden_states = outputs['hidden_states'].float()

            predictions = torch.argmax(logits, dim=-1)
            probabilities = torch.softmax(logits, dim=-1)

            mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype).unsqueeze(-1)
            if mask.shape[1] == hidden_states.shape[1]:
                # Average over real tokens only so padding does not dominate the embedding.
                pooled_embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            else:
                # Lengths differ (e.g. states include extra positions): plain mean.
                pooled_embeddings = hidden_states.mean(dim=1)
            # Fully qualified: the `F` alias is not imported in this file.
            pooled_embeddings = torch.nn.functional.normalize(pooled_embeddings, p=2, dim=1)

            all_predictions.extend(predictions.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probabilities.extend(probabilities[:, 1].cpu().tolist())
            all_embeddings.append(pooled_embeddings.cpu())

            if len(all_predictions) % 20 == 0:
                torch.cuda.empty_cache()

    inference_time = time.time() - start_time
    samples_per_sec = len(test_data) / inference_time if inference_time > 0 else float('inf')

    all_embeddings = torch.cat(all_embeddings, dim=0)

    # Arrays (not lists) so array arithmetic on the probabilities works downstream.
    metrics = calculate_comprehensive_metrics(
        np.asarray(all_labels),
        np.asarray(all_predictions),
        np.asarray(all_probabilities, dtype=np.float64),
        all_embeddings
    )

    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec

    return metrics


def print_results(metrics, dataset_name):
    """Pretty-print a metrics dict such as the one returned by evaluate_model.

    Required keys: 'cm', 'tn', 'fp', 'fn', 'tp', 'acc', 'bal_acc', 'prec', 'rec',
    'f1', 'specificity', 'npv', 'fpr', 'fnr', 'jacc', 'mcc', 'kappa'.

    Optional sections print only the keys actually present:
    - probabilistic: 'roc_auc', 'pr_auc', 'log_loss', 'brier';
    - retrieval: every 'recall@<k>' in ascending k, then 'mrr';
    - performance: 'inference_time', 'samples_per_sec'.
    """
    rule = '=' * 80

    def header(title):
        print(f"\n{rule}")
        print(title)
        print(rule)

    header(f"RESULTS: {dataset_name}")
    print("\nConfusion Matrix:")
    print(f"{metrics['cm']}")
    print(f"\nTrue Negatives (TN): {metrics['tn']}")
    print(f"False Positives (FP): {metrics['fp']}")
    print(f"False Negatives (FN): {metrics['fn']}")
    print(f"True Positives (TP): {metrics['tp']}")

    header("CLASSIFICATION METRICS")
    for label, key in (
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall (Sensitivity)", 'rec'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
        ("Negative Predictive Value (NPV)", 'npv'),
        ("False Positive Rate (FPR)", 'fpr'),
        ("False Negative Rate (FNR)", 'fnr'),
    ):
        print(f"{label}: {metrics[key]:.4f}")

    header("ADVANCED METRICS")
    for label, key in (
        ("Jaccard Score", 'jacc'),
        ("Matthews Correlation Coefficient (MCC)", 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    ):
        print(f"{label}: {metrics[key]:.4f}")

    probabilistic = [
        (label, key)
        for label, key in (
            ("ROC AUC Score", 'roc_auc'),
            ("Precision-Recall AUC", 'pr_auc'),
            ("Log Loss", 'log_loss'),
            ("Brier Score", 'brier'),
        )
        if key in metrics
    ]
    # One header for the whole group, whichever subset is present.
    if probabilistic:
        header("PROBABILISTIC METRICS")
        for label, key in probabilistic:
            print(f"{label}: {metrics[key]:.4f}")

    prefix = 'recall@'
    recall_keys = sorted(
        (k for k in metrics if isinstance(k, str) and k.startswith(prefix) and k[len(prefix):].isdigit()),
        key=lambda k: int(k[len(prefix):]),
    )
    # Print whatever cut-offs exist instead of assuming exactly 1/3/5/10.
    if recall_keys or 'mrr' in metrics:
        header("RETRIEVAL METRICS")
        for key in recall_keys:
            print(f"Recall@{key[len(prefix):]}: {metrics[key]:.4f}")
        if 'mrr' in metrics:
            print(f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}")

    if 'inference_time' in metrics:
        header("PERFORMANCE METRICS")
        print(f"Inference Time: {metrics['inference_time']:.2f}s")
        if 'samples_per_sec' in metrics:
            print(f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec")

    print(f"{rule}\n")


def main():
    """Prefix-tune CodeQwen1.5-7B-Chat for code-clone detection, then evaluate it.

    Reads '/train.csv' and '/test.csv' (columns func1, func2, label) and uses
    '/hf_cache' as the Hugging Face cache directory.
    """
    torch.cuda.empty_cache()
    gc.collect()

    def banner(title):
        # Three-line section banner shared by every phase header.
        print(f"\n{'='*80}")
        print(title)
        print(f"{'='*80}\n")

    gpu_total = torch.cuda.device_count()
    print(f"Number of available GPUs: {gpu_total}")
    for gpu_id in range(gpu_total):
        print(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")

    model_id = "Qwen/CodeQwen1.5-7B-Chat"
    cache_dir = '/hf_cache'
    columns = {'label_col': 'label', 'func1_col': 'func1', 'func2_col': 'func2'}

    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    prefix_config = PrefixTuningConfig(
        prefix_length=20,
        num_layers=32,
        hidden_size=4096,
        num_heads=32,
        head_dim=128,
        reparam_dim=512,
    )

    print("\nLoading model with pipeline parallelism across GPUs...")
    model = PrefixTuningCodeQwen(model_id, prefix_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    banner("TRAINING PHASE")
    train_data = create_training_data('/train.csv', tokenizer, **columns)
    print(f"Created {len(train_data)} training examples\n")

    if not train_data:
        print("No training data available. Exiting.")
        return

    train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2)

    banner("TESTING PHASE")
    test_data = create_training_data('/test.csv', tokenizer, **columns)
    print(f"Created {len(test_data)} test examples\n")

    if test_data:
        print_results(evaluate_model(model, test_data, batch_size=2), "TEST SET")

    banner("TRAINING COMPLETE")


# Entry point: run the full train/evaluate pipeline when this code is executed
# as the main module (presumably also when the notebook cell is run).
if __name__ == '__main__':
    main()

# TS-PEFT


In [None]:
import math
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from torch.optim import AdamW
from datasets import Dataset as HFDataset
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, confusion_matrix,
    balanced_accuracy_score, precision_score, recall_score, jaccard_score,
    matthews_corrcoef, cohen_kappa_score, roc_auc_score, average_precision_score,
    log_loss, brier_score_loss
)
import pandas as pd
import numpy as np
import os
import logging
import time
from typing import List, Dict, Tuple
import psutil
import gc

# INFO-level logging; `logger` is used by the trainer class defined below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Speed-oriented flags: let cuDNN benchmark and pick algorithms, and allow TF32
# for matmuls/convolutions on GPUs that support it (faster, slightly less precise).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

class LoRALayer(nn.Module):
    """Low-rank adapter branch computing ``dropout(x) @ A^T @ B^T * (alpha / rank)``.

    ``lora_A`` has shape (rank, in_features) and is Kaiming-uniform initialised.
    ``lora_B`` has shape (out_features, rank) and starts at zero, so the branch
    outputs zeros until ``lora_B`` has been trained.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return result * self.scaling


class TSPEFTLayer(nn.Module):
    """Frozen linear layer plus a LoRA branch gated per position by a learned threshold.

    For each position, ``r_i = ||lora(x)|| / (||base(x)|| + eps)`` is compared with
    the scalar threshold ``tau``. The LoRA output is added only where ``r_i >= tau``.
    ``tau`` is not trained by autograd: ``update_threshold`` moves it with an
    Adam-style step on the surrogate gradient from ``compute_threshold_gradient``.

    The threshold state (``tau``, ``m``, ``v``) is kept in float32 whatever the base
    dtype. In float16, ``eps=1e-8`` rounds to 0, so a zero surrogate gradient yields
    ``0 / 0 = NaN``, and a NaN ``tau`` disables the gate permanently.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
        s: float = 4e-5,
        lambda_reg: float = 1e-5,
        beta1: float = 0.9,
        beta2: float = 0.98,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.requires_grad_(False)

        base_dtype = base_layer.weight.dtype
        base_device = base_layer.weight.device
        if base_device.type == 'meta':
            # Offloaded weights have no storage; keep the adapter and state on CPU.
            base_device = torch.device('cpu')

        self.lora = LoRALayer(
            base_layer.in_features,
            base_layer.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
        )
        # Put the adapter in the base layer's dtype and on its device, so forward()
        # mixes tensors of one dtype/device.
        self.lora.to(device=base_device, dtype=base_dtype)

        self.s = s
        self.lambda_reg = lambda_reg
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

        state = {'dtype': torch.float32, 'device': base_device}
        self.register_buffer('tau', torch.tensor(0.0, **state))
        self.register_buffer('m', torch.tensor(0.0, **state))
        self.register_buffer('v', torch.tensor(0.0, **state))
        self.register_buffer('step', torch.tensor(0, device=base_device))

    def compute_relative_magnitude(
        self,
        base_output: torch.Tensor,
        lora_output: torch.Tensor
    ) -> torch.Tensor:
        """Per-position ratio ``||lora_output|| / (||base_output|| + eps)`` in float32."""
        base_norm = torch.norm(base_output.float(), p=2, dim=-1, keepdim=True)
        lora_norm = torch.norm(lora_output.float(), p=2, dim=-1, keepdim=True)
        r_i = lora_norm / (base_norm + self.eps)
        return r_i.squeeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``base(x) + gate * lora(x)`` in the base layer's dtype.

        In training mode, the detached per-position statistics are cached for the next
        ``update_threshold`` call.
        """
        x = x.to(dtype=self.base_layer.weight.dtype)

        base_output = self.base_layer(x)
        lora_output = self.lora(x)

        r_i = self.compute_relative_magnitude(base_output, lora_output)
        # Cast the 0/1 gate to the activation dtype; a float32 gate would promote the
        # layer output to float32 and break downstream fp16/bf16 layers.
        gate = (r_i >= self.tau).to(lora_output.dtype)
        gated_output = base_output + gate.unsqueeze(-1) * lora_output

        if self.training:
            # Detached: only the values are needed, and holding the graph wastes memory.
            self._cache_for_backward = {
                'r_i': r_i.detach(),
                'gate': gate.detach(),
                'lora_output': lora_output.detach(),
                'base_output': base_output.detach(),
            }

        return gated_output

    def compute_threshold_gradient(self, grad_output: torch.Tensor) -> float:
        """Surrogate gradient ``g_k`` for ``tau`` from the cached forward statistics.

        ``grad_output`` is expected to have the shape of this layer's output. Returns
        0.0 when no forward pass has been cached.
        """
        cache = getattr(self, '_cache_for_backward', None)
        if cache is None:
            return 0.0

        with torch.no_grad():
            gate = cache['gate'].float()
            lora_output = cache['lora_output'].float()
            grad = grad_output.to(device=lora_output.device, dtype=torch.float32)

            mu_i = (grad * lora_output).sum(dim=-1)

            consistency_mask = ((mu_i >= 0).float() == gate).float()
            sparsity_mask = gate

            grad_loss = -self.s * (consistency_mask * mu_i).sum()
            grad_sparsity = -self.s * (sparsity_mask * self.lambda_reg).sum()

            return (grad_loss + grad_sparsity).item()

    def update_threshold(self, grad_output: torch.Tensor, lr: float = 1.0):
        """Apply one Adam-style update to ``tau`` (clamped at >= 0); no-op in eval mode."""
        if not self.training:
            return

        g_k = self.compute_threshold_gradient(grad_output)

        with torch.no_grad():
            self.step += 1
            step = self.step.item()

            self.m.mul_(self.beta1).add_((1 - self.beta1) * g_k)
            self.v.mul_(self.beta2).add_((1 - self.beta2) * (g_k ** 2))

            m_hat = self.m / (1 - self.beta1 ** step)
            v_hat = self.v / (1 - self.beta2 ** step)

            tau_update = lr * self.s * m_hat / (torch.sqrt(v_hat) + self.eps)
            self.tau.copy_(torch.clamp(self.tau + tau_update, min=0.0))

        if hasattr(self, '_cache_for_backward'):
            delattr(self, '_cache_for_backward')


class ClassificationHead(nn.Module):
    """Two-layer head: dropout -> dense -> tanh -> dropout -> projection to ``num_labels``.

    Both linear layers are created directly in ``dtype`` (float16 by default).
    """

    def __init__(self, hidden_size, num_labels=1, dtype=torch.float16):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size, dtype=dtype)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(hidden_size, num_labels, dtype=dtype)

    def forward(self, hidden_states):
        """Map (..., hidden_size) features to (..., num_labels) scores."""
        activated = torch.tanh(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(activated))


class TSPEFTCloneTrainer:
    """Fine-tune CodeQwen1.5-7B-Chat for binary code-clone detection with TS-PEFT.

    Target ``nn.Linear`` projections of the base model are replaced by
    ``TSPEFTLayer`` wrappers, whose ``lora`` sub-modules are trained. A
    ``ClassificationHead`` is trained on the attention-masked mean of the last
    hidden state of each tokenized ``<func1>...</func2>`` pair.
    """

    # Micro-batches whose gradients are accumulated before each optimizer step.
    # Shared by ``setup_optimizer`` (scheduler length) and ``train_epoch``.
    accumulation_steps = 4

    def __init__(self, train_file, test_file, func1_col, func2_col, label_col, cache_dir=None,
                 rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5):
        """Load the tokenizer, the fp16 base model and the train/test CSV datasets.

        Args:
            train_file: Path of the training CSV.
            test_file: Path of the test CSV.
            func1_col: Column holding the first function's source code.
            func2_col: Column holding the second function's source code.
            label_col: Column holding the clone label.
            cache_dir: Hugging Face cache directory; models are loaded with
                ``local_files_only=True``, so they must already be cached.
            rank, alpha, dropout, s, lambda_reg: Forwarded to every ``TSPEFTLayer``.
        """
        self.train_file = train_file
        self.test_file = test_file
        self.func1_col = func1_col
        self.func2_col = func2_col
        self.label_col = label_col
        self.cache_dir = cache_dir
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout
        self.s = s
        self.lambda_reg = lambda_reg

        torch.cuda.empty_cache()
        gc.collect()

        num_gpus = torch.cuda.device_count()
        logger.info(f"Number of available GPUs: {num_gpus}")
        for i in range(num_gpus):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")

        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Loading base model with automatic device mapping...")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/CodeQwen1.5-7B-Chat",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            cache_dir=cache_dir,
            local_files_only=True,
            low_cpu_mem_usage=True
        )

        hidden_size = self.base_model.config.hidden_size
        logger.info(f"Model hidden size: {hidden_size}")

        self.classification_head = None
        self.ts_peft_layers = {}

        self.print_model_distribution()

        logger.info("Loading datasets...")
        self.train_data = HFDataset.from_csv(train_file)
        self.test_data = HFDataset.from_csv(test_file)
        logger.info(f"Train samples: {len(self.train_data)}")
        logger.info(f"Test samples: {len(self.test_data)}")

        self.training_stats = {
            'epoch_times': [],
            'epoch_losses': [],
            'memory_usage': []
        }

        self.optimizer = None
        self.scheduler = None

    def print_model_distribution(self):
        """Log the model's ``hf_device_map`` (when present) and per-GPU allocated/reserved memory in GB."""
        logger.info("\n" + "="*80)
        logger.info("MODEL DEVICE DISTRIBUTION")
        logger.info("="*80)

        if hasattr(self.base_model, 'hf_device_map'):
            for module_name, placement in self.base_model.hf_device_map.items():
                logger.info(f"{module_name}: {placement}")

        logger.info("\n" + "="*80)
        logger.info("GPU MEMORY USAGE")
        logger.info("="*80)
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            logger.info(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
        logger.info("="*80 + "\n")

    def preprocess(self, example):
        """Tokenize one CSV row as a tagged function pair, padded/truncated to 512 tokens.

        The row's label is copied into the ``labels`` field of the returned encoding.
        """
        c1 = example[self.func1_col]
        c2 = example[self.func2_col]
        y = example[self.label_col]
        text = f"<func1>\n{c1}\n</func1>\n<func2>\n{c2}\n</func2>"
        out = self.tokenizer(text, truncation=True, padding="max_length", max_length=512)
        out["labels"] = y
        return out

    @staticmethod
    def _mean_pool(hidden_states, attention_mask):
        """Average ``hidden_states`` over the sequence axis, counting only unmasked positions.

        Inputs are padded to ``max_length``, so a plain ``mean(dim=1)`` would be
        dominated by padding positions for short functions. The sum is computed
        in float32 so fp16 activations cannot overflow. The result is cast back
        to ``hidden_states.dtype``.

        Args:
            hidden_states: ``(batch, seq, hidden)`` tensor.
            attention_mask: ``(batch, seq)`` tensor with 1 for real tokens, 0 for padding.

        Returns:
            ``(batch, hidden)`` pooled tensor on ``hidden_states.device``.
        """
        mask = attention_mask.to(device=hidden_states.device, dtype=torch.float32).unsqueeze(-1)
        summed = (hidden_states.float() * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # avoid 0/0 on an all-padding row
        return (summed / counts).to(hidden_states.dtype)

    def _setup_tspeft_parameters(self):
        """Freeze the base model, wrap target projections in ``TSPEFTLayer`` and attach the head.

        Fills ``self.ts_peft_layers``, keyed by module path with dots replaced
        by underscores. Sets ``self.classification_head`` on the device and
        dtype of the model's first parameter, then logs parameter statistics.
        """
        logger.info("Setting up TS-PEFT parameters...")

        for param in self.base_model.parameters():
            param.requires_grad = False

        target_modules = {'q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'}

        # Counted before wrapping so the base total does not include LoRA weights.
        base_params = sum(p.numel() for p in self.base_model.parameters())

        # Snapshot matches first: the loop below swaps modules inside their
        # parents, which should not happen while named_modules() is still walking.
        # Match on the leaf name exactly, not on a substring of the full path.
        targets = [
            (name, module)
            for name, module in self.base_model.named_modules()
            if isinstance(module, nn.Linear) and name.rpartition('.')[2] in target_modules
        ]

        lora_param_count = 0
        for name, module in targets:
            parent_name, _, child_name = name.rpartition('.')
            parent = self.base_model.get_submodule(parent_name)  # '' -> the model itself

            reference_param = next(module.parameters())
            module_device = reference_param.device
            module_dtype = reference_param.dtype

            ts_layer = TSPEFTLayer(
                module,
                rank=self.rank,
                alpha=self.alpha,
                dropout=self.dropout,
                s=self.s,
                lambda_reg=self.lambda_reg,
            )
            ts_layer = ts_layer.to(device=module_device, dtype=module_dtype)

            setattr(parent, child_name, ts_layer)
            self.ts_peft_layers[name.replace('.', '_')] = ts_layer

            lora_params = sum(p.numel() for p in ts_layer.lora.parameters())
            lora_param_count += lora_params

            logger.info(f"✓ TS-PEFT Layer: {name} on {module_device} - LoRA Params: {lora_params:,}")

        hidden_size = self.base_model.config.hidden_size
        reference_param = next(self.base_model.parameters())
        self.classification_head = ClassificationHead(
            hidden_size, num_labels=1, dtype=reference_param.dtype
        ).to(reference_param.device)

        head_params = sum(p.numel() for p in self.classification_head.parameters())
        trainable_params = lora_param_count + head_params
        total_params = base_params + trainable_params

        logger.info(f"\n✓ Classification Head added: {head_params:,} parameters")
        for name, param in self.classification_head.named_parameters():
            logger.info(f"  ✓ {name} - Shape: {param.shape} - Params: {param.numel():,}")

        logger.info(f"\n{'='*80}")
        logger.info("TS-PEFT PARAMETER STATISTICS")
        logger.info(f"{'='*80}")
        logger.info(f"Base Model Total Parameters: {base_params:,}")
        logger.info(f"Trainable LoRA Parameters: {lora_param_count:,}")
        logger.info(f"Classification Head Parameters: {head_params:,}")
        logger.info(f"Total Trainable Parameters: {trainable_params:,}")
        logger.info(f"Total Parameters: {total_params:,}")
        logger.info(f"Trainable Percentage: {100 * trainable_params / total_params:.6f}%")
        logger.info(f"Total TS-PEFT Layers: {len(self.ts_peft_layers)}")
        logger.info(f"{'='*80}\n")

    def calculate_metrics(self, y_true, y_pred, y_prob=None):
        """Compute binary classification metrics with label 1 as the positive class.

        The confusion matrix is always built over labels ``[0, 1]``, so it is
        2x2 even when one class is absent from both ``y_true`` and ``y_pred``.
        Probabilistic metrics (ROC AUC, PR AUC, log loss, Brier) are added only
        when ``y_prob`` is given and they can be computed. Otherwise a warning
        is logged.

        Returns:
            Dict of metric name -> value, including the raw matrix under ``'cm'``.
        """
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        acc = accuracy_score(y_true, y_pred)
        bal_acc = balanced_accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)

        _, _, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)

        jacc = jaccard_score(y_true, y_pred, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)
        kappa = cohen_kappa_score(y_true, y_pred)

        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

        metrics = {
            'cm': cm,
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp),
            'acc': acc,
            'bal_acc': bal_acc,
            'prec': prec,
            'rec': rec,
            'f1': f1,
            'jacc': jacc,
            'mcc': mcc,
            'kappa': kappa,
            'specificity': specificity,
            'npv': npv,
            'fpr': fpr,
            'fnr': fnr
        }

        if y_prob is not None:
            try:
                metrics.update({
                    'roc_auc': roc_auc_score(y_true, y_prob),
                    'pr_auc': average_precision_score(y_true, y_prob),
                    'log_loss': log_loss(y_true, y_prob),
                    'brier': brier_score_loss(y_true, y_prob)
                })
            except Exception as e:
                logger.warning(f"Could not calculate probabilistic metrics: {e}")

        return metrics

    def calculate_recall_at_k(self, embeddings, labels, k_values=None):
        """Label-based retrieval metrics over pooled embeddings.

        Each query row is compared against every *other* row by cosine
        similarity. A neighbour counts as a hit when it carries the query's label.

        Args:
            embeddings: 2-D tensor, one row per sample.
            labels: Sequence of labels aligned with the rows of ``embeddings``.
            k_values: Cut-offs for recall@k (default ``[1, 3, 5, 10]``). Each k
                is capped at ``len(labels) - 1``.

        Returns:
            Dict with ``'recall@k'`` for every k, plus ``'mrr'``: the mean
            reciprocal rank of the first same-label neighbour, where a query
            with no such neighbour scores 0.
        """
        if k_values is None:
            k_values = [1, 3, 5, 10]

        num_items = len(labels)
        if num_items == 0:
            results = {f'recall@{k}': 0.0 for k in k_values}
            results['mrr'] = 0.0
            return results

        # Normalise once and take a matrix product. Broadcasting
        # cosine_similarity over (n,1,d) x (1,n,d) materialises an n*n*d tensor.
        unit = torch.nn.functional.normalize(embeddings.float(), dim=-1)
        similarity = unit @ unit.T
        similarity.fill_diagonal_(-float('inf'))
        ranking = torch.argsort(similarity, dim=1, descending=True).cpu().numpy()

        label_arr = np.asarray(labels)
        # same_label[i, r] is True when the r-th nearest neighbour of query i
        # (the query itself excluded) has the same label as i.
        same_label = np.stack([
            label_arr[order[order != i]] == label_arr[i]
            for i, order in enumerate(ranking)
        ])
        num_neighbours = same_label.shape[1]

        results = {}
        for k in k_values:
            hits = same_label[:, :min(k, num_neighbours)].any(axis=1)
            results[f'recall@{k}'] = float(hits.mean())

        if num_neighbours == 0:
            results['mrr'] = 0.0
        else:
            has_hit = same_label.any(axis=1)
            first_rank = same_label.argmax(axis=1) + 1
            results['mrr'] = float(np.where(has_hit, 1.0 / first_rank, 0.0).mean())

        return results

    def collate_fn(self, batch):
        """Stack per-sample ``input_ids`` / ``attention_mask`` lists and labels into batch tensors."""
        input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        attention_mask = torch.stack([torch.tensor(item['attention_mask']) for item in batch])
        labels = torch.tensor([item['labels'] for item in batch])
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

    def _trainable_parameters(self):
        """Return the LoRA parameters of every TS-PEFT layer followed by the head's parameters."""
        params = [p for layer in self.ts_peft_layers.values() for p in layer.lora.parameters()]
        params.extend(self.classification_head.parameters())
        return params

    def setup_optimizer(self, train_dataloader, num_epochs):
        """Create the AdamW optimizer and a linear warmup/decay scheduler.

        The scheduler is stepped once per optimizer step, i.e. once per
        accumulation group. The last group of an epoch may be partial. Its
        length is therefore ``ceil(len(train_dataloader) / accumulation_steps)``
        per epoch, and warmup covers the first 3% of those steps.
        """
        self.optimizer = AdamW(
            self._trainable_parameters(),
            lr=2e-4,
            weight_decay=0.01
        )

        self.optimizer.zero_grad()

        steps_per_epoch = math.ceil(len(train_dataloader) / self.accumulation_steps)
        total_steps = steps_per_epoch * num_epochs
        warmup_steps = int(0.03 * total_steps)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        logger.info(f"Optimizer setup complete. Total training steps: {total_steps}")
        logger.info(f"Warmup steps: {warmup_steps}")

    def update_thresholds(self, lr: float = 1.0):
        """Call ``update_threshold`` on every training-mode layer that holds a forward cache.

        ``grad_output`` is a tensor of ones shaped like the cached ``base_output``.
        """
        for layer in self.ts_peft_layers.values():
            if hasattr(layer, '_cache_for_backward') and layer.training:
                grad_output = torch.ones_like(
                    layer._cache_for_backward['base_output']
                )
                layer.update_threshold(grad_output, lr)

    def get_sparsity_stats(self):
        """Return ``{layer_key: percentage of gate entries equal to 0}`` for layers with a forward cache."""
        stats = {}
        for name, layer in self.ts_peft_layers.items():
            if hasattr(layer, '_cache_for_backward'):
                gate = layer._cache_for_backward['gate']
                sparsity = (1 - gate.float().mean()).item() * 100
                stats[name] = sparsity
        return stats

    def _optimizer_step(self, trainable_params):
        """Finish one accumulation group.

        Updates the TS-PEFT thresholds, clips gradients to norm 1.0, then steps
        the optimizer and scheduler and clears the gradients.
        """
        self.update_thresholds(lr=1.0)
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
        if self.scheduler:
            self.scheduler.step()

    def train_epoch(self, train_dataloader, epoch, num_epochs):
        """Run one epoch of training with gradient accumulation.

        Gradients are accumulated over ``self.accumulation_steps`` micro-batches
        before each optimizer step. The last, possibly partial, group of the
        epoch is also stepped, so no gradients carry over into the next epoch.

        Returns:
            ``{'loss': mean un-scaled BCE loss over the epoch's batches}``.
        """
        self.base_model.train()
        self.classification_head.train()

        accumulation_steps = self.accumulation_steps
        num_batches = len(train_dataloader)
        first_device = next(self.base_model.parameters()).device
        # With device_map="auto" the last hidden state may live on a different
        # GPU than the head, so pooled features and labels are moved explicitly.
        head_device = next(self.classification_head.parameters()).device
        loss_fn = nn.BCEWithLogitsLoss()
        trainable_params = self._trainable_parameters()

        # NOTE(review): parameters are fp16 and no GradScaler is used, so very
        # small gradients may underflow -- consider fp32 LoRA/head weights.
        total_loss = 0.0
        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(first_device)
            attention_mask = batch['attention_mask'].to(first_device)
            labels = batch['labels'].to(head_device)

            with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

                pooled = self._mean_pool(outputs.hidden_states[-1], attention_mask).to(head_device)
                logits = self.classification_head(pooled).squeeze(-1)

                # Scaled so the accumulated gradient is an average over the group.
                loss = loss_fn(logits, labels.float()) / accumulation_steps

            loss.backward()

            group_done = (batch_idx + 1) % accumulation_steps == 0
            if group_done or batch_idx + 1 == num_batches:
                self._optimizer_step(trainable_params)

            batch_loss = loss.item() * accumulation_steps
            total_loss += batch_loss

            if batch_idx % 20 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                sparsity_stats = self.get_sparsity_stats()
                avg_sparsity = sum(sparsity_stats.values()) / len(sparsity_stats) if sparsity_stats else 0.0
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs} - "
                    f"Batch {batch_idx}/{len(train_dataloader)} - "
                    f"Loss: {batch_loss:.4f} - "
                    f"LR: {current_lr:.2e} - "
                    f"Sparsity: {avg_sparsity:.2f}%"
                )

                torch.cuda.empty_cache()
                gc.collect()

        avg_loss = total_loss / num_batches if num_batches else 0.0
        return {'loss': avg_loss}

    def train(self, num_epochs=5, batch_size=4):
        """Apply TS-PEFT, train for ``num_epochs`` epochs, then save the model and head.

        Args:
            num_epochs: Number of passes over the training set (default 5).
            batch_size: Micro-batch size of the training DataLoader (default 4).

        Returns:
            ``(base_model, total_training_time_in_seconds)``.

        Raises:
            ValueError: If ``num_epochs`` is smaller than 1.
        """
        if num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

        train_dataset = self.train_data.map(self.preprocess, batched=False)

        logger.info("\n" + "="*80)
        logger.info("APPLYING TS-PEFT CONFIGURATION")
        logger.info("="*80)

        self._setup_tspeft_parameters()

        self.print_model_distribution()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
            num_workers=0,
            pin_memory=True
        )

        self.setup_optimizer(train_dataloader, num_epochs)

        logger.info("="*80)
        logger.info("STARTING TRAINING")
        logger.info(f"Total Epochs: {num_epochs}")
        logger.info("="*80)

        training_start = time.time()

        process = psutil.Process(os.getpid())
        mem_before = process.memory_info().rss / 1024 / 1024

        for epoch in range(num_epochs):
            epoch_start = time.time()

            logger.info(f"\n{'='*60}")
            logger.info(f"Epoch {epoch + 1}/{num_epochs}")
            logger.info(f"{'='*60}")

            train_metrics = self.train_epoch(train_dataloader, epoch, num_epochs)

            epoch_time = time.time() - epoch_start
            self.training_stats['epoch_times'].append(epoch_time)
            self.training_stats['epoch_losses'].append(train_metrics['loss'])

            logger.info(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
            logger.info(f"Training Loss: {train_metrics['loss']:.4f}")

            torch.cuda.empty_cache()
            gc.collect()

        total_training_time = time.time() - training_start

        mem_after = process.memory_info().rss / 1024 / 1024
        mem_used = mem_after - mem_before

        logger.info(f"\n{'='*80}")
        logger.info("TRAINING COMPLETE")
        logger.info(f"{'='*80}")
        logger.info(f"Total Training Time: {total_training_time:.2f}s")
        logger.info(f"Average Epoch Time: {np.mean(self.training_stats['epoch_times']):.2f}s")
        logger.info(f"Final Training Loss: {self.training_stats['epoch_losses'][-1]:.4f}")
        logger.info(f"Memory Usage: {mem_used:.1f} MB")
        logger.info(f"{'='*80}\n")

        self.print_model_distribution()

        os.makedirs("codeqwen1.5-tspeft-clone", exist_ok=True)
        self.base_model.save_pretrained("codeqwen1.5-tspeft-clone")
        torch.save(self.classification_head.state_dict(), "codeqwen1.5-tspeft-clone/classification_head.pt")
        logger.info("Models saved successfully to codeqwen1.5-tspeft-clone/")

        torch.cuda.empty_cache()
        gc.collect()

        return self.base_model, total_training_time

    def evaluate_comprehensive(self, model, test_dataset):
        """Score ``test_dataset`` one sample at a time and compute all metrics.

        Args:
            model: The (TS-PEFT-wrapped) causal LM.
            test_dataset: Dataset already passed through ``preprocess``.

        Returns:
            Classification metrics from ``calculate_metrics`` (threshold 0.5 on
            the sigmoid probability). Also includes label-based retrieval
            metrics over the pooled embeddings, plus ``inference_time`` and
            ``samples_per_sec``.
        """
        model.eval()
        self.classification_head.eval()

        logger.info("="*80)
        logger.info("STARTING COMPREHENSIVE EVALUATION")
        logger.info("="*80)

        # Take labels from the dataset being scored so they stay aligned with
        # it; fall back to the CSV only if the dataset has no 'labels' column.
        if 'labels' in test_dataset.column_names:
            test_labels = list(test_dataset['labels'])
        else:
            test_df = pd.read_csv(self.test_file)
            test_labels = test_df[self.label_col].tolist()

        all_preds = []
        all_probs = []
        all_embeddings = []

        first_device = next(model.parameters()).device
        head_device = next(self.classification_head.parameters()).device

        inference_start = time.time()
        total_samples = len(test_dataset)

        logger.info(f"Processing {total_samples} test samples...")

        with torch.no_grad():
            for idx in range(total_samples):
                sample = test_dataset[idx]

                input_ids = torch.tensor([sample['input_ids']], device=first_device)
                attention_mask = torch.tensor([sample['attention_mask']], device=first_device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

                pooled = self._mean_pool(outputs.hidden_states[-1], attention_mask).to(head_device)
                all_embeddings.append(pooled.squeeze(0).cpu())

                logits = self.classification_head(pooled).squeeze(-1)
                prob = torch.sigmoid(logits).item()

                all_probs.append(prob)
                all_preds.append(1 if prob > 0.5 else 0)

                if (idx + 1) % 100 == 0:
                    logger.info(f"Processed {idx + 1}/{total_samples} samples")
                    torch.cuda.empty_cache()

        inference_time = time.time() - inference_start
        samples_per_sec = total_samples / inference_time if inference_time > 0 else 0.0

        logger.info(f"Inference completed in {inference_time:.2f}s")
        logger.info(f"Throughput: {samples_per_sec:.2f} samples/sec")

        metrics = self.calculate_metrics(test_labels, all_preds, all_probs)

        if all_embeddings:
            recall_metrics = self.calculate_recall_at_k(torch.stack(all_embeddings), test_labels)
            metrics.update(recall_metrics)

        metrics['inference_time'] = inference_time
        metrics['samples_per_sec'] = samples_per_sec

        torch.cuda.empty_cache()
        gc.collect()

        return metrics


def print_results(metrics, dataset_name):
    """Print a sectioned, human-readable report of an evaluation metrics dict.

    The confusion-matrix, classification and advanced sections are always
    printed and require their keys. The probabilistic, retrieval and
    performance entries are printed only when their keys are present.

    Args:
        metrics: Dict as produced by the trainer's evaluation.
        dataset_name: Label shown in the report title.
    """
    rule = '=' * 80

    def banner(title):
        # Blank line, separator, title, separator.
        print(f"\n{rule}")
        print(title)
        print(rule)

    banner(f"RESULTS: {dataset_name}")
    print("\nConfusion Matrix:")
    print(f"{metrics['cm']}")
    print(f"\nTrue Negatives (TN): {metrics['tn']}")
    for caption, key in (
        ("False Positives (FP)", 'fp'),
        ("False Negatives (FN)", 'fn'),
        ("True Positives (TP)", 'tp'),
    ):
        print(f"{caption}: {metrics[key]}")

    banner("CLASSIFICATION METRICS")
    for caption, key in (
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall (Sensitivity)", 'rec'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
        ("Negative Predictive Value (NPV)", 'npv'),
        ("False Positive Rate (FPR)", 'fpr'),
        ("False Negative Rate (FNR)", 'fnr'),
    ):
        print(f"{caption}: {metrics[key]:.4f}")

    banner("ADVANCED METRICS")
    for caption, key in (
        ("Jaccard Score", 'jacc'),
        ("Matthews Correlation Coefficient (MCC)", 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    ):
        print(f"{caption}: {metrics[key]:.4f}")

    if 'roc_auc' in metrics:
        banner("PROBABILISTIC METRICS")
        print(f"ROC AUC Score: {metrics['roc_auc']:.4f}")
    for caption, key in (
        ("Precision-Recall AUC", 'pr_auc'),
        ("Log Loss", 'log_loss'),
        ("Brier Score", 'brier'),
    ):
        if key in metrics:
            print(f"{caption}: {metrics[key]:.4f}")

    if 'recall@1' in metrics:
        banner("RETRIEVAL METRICS")
        for k in (1, 3, 5, 10):
            print(f"Recall@{k}: {metrics[f'recall@{k}']:.4f}")
        print(f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}")

    if 'inference_time' in metrics:
        banner("PERFORMANCE METRICS")
        print(f"Inference Time: {metrics['inference_time']:.2f}s")
        print(f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec")

    print(f"{rule}\n")


def run_tspeft_experiment(train_file, test_file, func1_col, func2_col, label_col, cache_dir=None,
                          rank=32, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5):
    """Train a TS-PEFT clone detector end to end and evaluate it on the test split.

    The trainer is built from the given files/columns and hyperparameters,
    then trained. The test split is preprocessed and scored, and the metrics
    are printed.

    Returns:
        ``(trainer, model, metrics)`` -- the ``TSPEFTCloneTrainer``, the trained
        model returned by its ``train()``, and the evaluation metrics dict.
    """
    bar = "=" * 80
    logger.info(bar)
    logger.info("CODE CLONE DETECTION WITH TS-PEFT")
    logger.info(bar)

    hyperparams = dict(
        cache_dir=cache_dir,
        rank=rank,
        alpha=alpha,
        dropout=dropout,
        s=s,
        lambda_reg=lambda_reg,
    )
    trainer = TSPEFTCloneTrainer(train_file, test_file, func1_col, func2_col, label_col, **hyperparams)

    model, _elapsed = trainer.train()

    logger.info(f"\n{bar}")
    logger.info("FINAL TEST EVALUATION")
    logger.info(f"{bar}")

    encoded_test = trainer.test_data.map(trainer.preprocess, batched=False)
    metrics = trainer.evaluate_comprehensive(model, encoded_test)

    print_results(metrics, "TEST SET")

    return trainer, model, metrics


if __name__ == "__main__":
    # Input CSVs, their column names and the local Hugging Face cache.
    train_file = '/train.csv'
    test_file = '/test.csv'
    func1_col = "func1"
    func2_col = "func2"
    label_col = "label"
    cache_dir = '/hf_cache'

    try:
        tspeft_trainer, model, results = run_tspeft_experiment(
            train_file,
            test_file,
            func1_col,
            func2_col,
            label_col,
            cache_dir=cache_dir,
            rank=32,
            alpha=0.5,
            dropout=0.05,
            s=4e-5,
            lambda_reg=4.5e-5
        )
        logger.info("TS-PEFT experiment completed successfully!")

    except Exception as e:
        # logger.exception records the message together with the full
        # traceback through the logging system; the error is then re-raised.
        logger.exception("Error: %s", e)
        raise