# GateRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
import math
from torch.utils.data import Dataset, DataLoader
from typing import Optional, List, Dict
from dataclasses import dataclass


@dataclass
class GateRAConfig:
    """Hyper-parameters for the gated low-rank adapters (GateRA).

    Raises:
        ValueError: if ``rank`` is not a positive integer.
    """

    # Inner dimension of the low-rank factors.
    rank: int = 16
    # Scaling numerator; the adapter layer in this file uses ``alpha / rank``.
    alpha: float = 16.0
    # Dropout probability handed to each adapter layer.
    dropout: float = 0.0
    # Name fragments selecting which linear layers get an adapter.
    # ``None`` means "use the default list" (filled in by ``__post_init__``).
    target_modules: Optional[List[str]] = None
    # Weight of the gate-entropy term added to the task loss.
    entropy_reg_weight: float = 0.01

    def __post_init__(self):
        # ``alpha / rank`` is computed downstream; fail early with a clear
        # message instead of a late ZeroDivisionError.
        if self.rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {self.rank!r}")
        if self.target_modules is None:
            self.target_modules = ["q", "k", "v", "o", "wi", "wo"]
        else:
            # Copy so the config never aliases (or is mutated through) the
            # caller's sequence; this also accepts tuples.
            self.target_modules = list(self.target_modules)


class GatingModule(nn.Module):
    """Per-row scalar gate: a single linear unit followed by a sigmoid.

    Both the weight and the bias of the linear unit are zero-initialised, so
    every gate value is exactly 0.5 until the parameters are updated.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.gate_linear = nn.Linear(input_dim, 1, bias=True)
        for tensor in (self.gate_linear.weight, self.gate_linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return gate values in (0, 1); the last dimension of ``x`` becomes 1."""
        return torch.sigmoid(self.gate_linear(x))


class GateRALayer(nn.Module):
    """Wrap a linear layer with an input-gated low-rank (LoRA-style) update.

    The forward pass computes::

        base_layer(x) + gate(x) * dropout(x @ lora_A @ lora_B) * (alpha / rank)

    ``lora_B`` starts at zero, so a freshly built layer reproduces
    ``base_layer`` exactly. The adapter branch is evaluated in float32 and
    cast back to the dtype of the base output.

    Args:
        base_layer: the wrapped layer; its ``weight``, ``bias``,
            ``in_features`` and ``out_features`` are re-exposed as properties.
        rank: inner dimension of the low-rank factors.
        alpha: scaling numerator (the update is scaled by ``alpha / rank``).
        dropout: dropout probability applied to the low-rank output.
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        rank: int,
        alpha: float,
        dropout: float,
        input_dim: int,
        output_dim: int
    ):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = alpha / rank
        self.dropout = dropout

        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))

        self.gating_module = GatingModule(input_dim)

        self.dropout_layer = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()

        self.reset_parameters()

    @property
    def weight(self):
        return self.base_layer.weight

    @property
    def bias(self):
        return getattr(self.base_layer, 'bias', None)

    @property
    def in_features(self):
        return self.base_layer.in_features

    @property
    def out_features(self):
        return self.base_layer.out_features

    def reset_parameters(self):
        """Kaiming-uniform init for ``lora_A``; zeros for ``lora_B``."""
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base layer plus the gated low-rank update.

        Accepts any number of leading dimensions; the last dimension must be
        ``input_dim``.
        """
        base_output = self.base_layer(x)

        # Collapse all leading dimensions into one. ``reshape`` (unlike
        # ``view``) also handles non-contiguous inputs.
        leading_shape = x.shape[:-1]
        x_flat = x.reshape(-1, x.shape[-1])

        # Keep the adapter branch out of CUDA autocast so it runs in float32.
        # ``torch.autocast`` replaces the deprecated ``torch.cuda.amp.autocast``.
        with torch.autocast(device_type="cuda", enabled=False):
            x_flat_fp32 = x_flat.float()
            gate_values = self.gating_module(x_flat_fp32)

            lora_output = torch.matmul(x_flat_fp32, self.lora_A.float())
            lora_output = torch.matmul(lora_output, self.lora_B.float())
            lora_output = self.dropout_layer(lora_output)

            modulated_output = (gate_values * lora_output * self.scaling).to(base_output.dtype)

        # Restore the leading dimensions; the feature size is given explicitly
        # so that empty batches do not make the reshape ambiguous.
        modulated_output = modulated_output.reshape(*leading_shape, modulated_output.shape[-1])

        return base_output + modulated_output
    

class CodeCloneDataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory sequence."""

    def __init__(self, data_list):
        # The sequence is stored by reference, not copied.
        self.data = data_list

    def __len__(self) -> int:
        """Number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        example = self.data[idx]
        return example




class GateRACodeT5(nn.Module):
    """Frozen seq2seq LLM with injected GateRA adapters and a classifier head.

    The base model is loaded with ``device_map='auto'`` and all of its
    original weights are frozen. Matching ``nn.Linear`` layers are then
    replaced by ``GateRALayer`` wrappers, and a small MLP classifier is applied
    to the mean of the last decoder hidden state.

    Args:
        model_name: Hugging Face model id or path of the seq2seq model.
        gatera_config: adapter hyper-parameters.
        num_classes: number of output classes of the classifier head.
        cache_dir: optional Hugging Face cache directory.
    """

    def __init__(self, model_name: str, gatera_config: GateRAConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        # Fall back to CPU so the classifier can still be placed when no GPU
        # is visible.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )
        self.llm = base_model
        if hasattr(self.llm, 'gradient_checkpointing_enable'):
            self.llm.gradient_checkpointing_enable()
            # Every base weight is frozen below, so the embedding output does
            # not require grad. With re-entrant activation checkpointing a
            # block whose inputs carry no grad produces outputs without a
            # graph, and the adapters inside it would receive no gradient.
            # Making the embedding output require grad keeps the graph alive.
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
        self.gatera_config = gatera_config

        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")

        # Record where each transformer block was placed (encoder blocks
        # first, then decoder blocks); every fifth placement is printed.
        self.layer_devices = []
        for stack_name, blocks in (("Encoder", base_model.encoder.block), ("Decoder", base_model.decoder.block)):
            for i, block in enumerate(blocks):
                layer_device = next(block.parameters()).device
                self.layer_devices.append(str(layer_device))
                if i % 5 == 0:
                    print(f"  {stack_name} Layer {i} -> {layer_device}")

        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)

        # Freeze the base model BEFORE injecting adapters, so that only the
        # newly created adapter parameters keep requires_grad=True.
        for param in self.llm.parameters():
            param.requires_grad = False
        self.gatera_layers = nn.ModuleDict()
        self._inject_gatera_layers()

        self.num_classes = num_classes
        hidden_size = base_model.config.d_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_classes)
        ).to(self.primary_device).float()
        for param in self.classifier.parameters():
            param.requires_grad = True
        print("Model initialization complete!")

    def _inject_gatera_layers(self):
        """Replace every targeted ``nn.Linear`` in the base model with a ``GateRALayer``.

        A layer is targeted when any entry of ``target_modules`` is a
        substring of the layer's own attribute name. Each wrapper is also
        registered in ``self.gatera_layers`` under ``gatera_layer_<n>``.
        """
        trainable_params = 0
        layer_counter = 0
        # Snapshot the module list first: the loop replaces attributes on the
        # very tree that ``named_modules()`` walks.
        for name, module in list(self.llm.named_modules()):
            if not isinstance(module, nn.Linear):
                continue

            parent_name, _, layer_name = name.rpartition('.')
            # NOTE(review): substring match, so the target "o" also matches
            # e.g. "wo" - confirm this is the intended selection rule.
            if not any(target in layer_name for target in self.gatera_config.target_modules):
                continue

            gatera_layer = GateRALayer(
                base_layer=module,
                rank=self.gatera_config.rank,
                alpha=self.gatera_config.alpha,
                dropout=self.gatera_config.dropout,
                input_dim=module.in_features,
                output_dim=module.out_features
            )

            # Keep the adapter on the same device as the layer it wraps.
            layer_device = next(module.parameters()).device
            gatera_layer = gatera_layer.to(layer_device)

            self.gatera_layers[f"gatera_layer_{layer_counter}"] = gatera_layer
            layer_counter += 1

            parent = self.llm
            if parent_name:
                for part in parent_name.split('.'):
                    parent = getattr(parent, part)
            setattr(parent, layer_name, gatera_layer)

            trainable_params += gatera_layer.lora_A.numel() + gatera_layer.lora_B.numel()
            trainable_params += sum(p.numel() for p in gatera_layer.gating_module.parameters())

        print(f"GateRA enabled with {trainable_params:,} trainable parameters across {len(self.gatera_layers)} layers")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch.

        ``input_ids`` is fed to the base model as both encoder and decoder
        input. The last decoder hidden state is mean-pooled over the sequence
        dimension and passed to the classifier.

        Returns:
            dict with ``logits``, ``loss`` (``None`` without labels),
            ``entropy_loss`` (``None`` without labels) and ``hidden_states``
            (the last decoder hidden state on the primary device).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True
        )
        # ``.to`` is a no-op when the tensor already lives on the target device.
        hidden_states = outputs.decoder_hidden_states[-1].to(self.primary_device)
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)

        loss = None
        entropy_loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            task_loss = nn.CrossEntropyLoss()(logits, labels)

            entropy_loss = torch.tensor(0.0, device=task_loss.device)
            batch_size = input_ids.shape[0]

            # NOTE(review): the entropy term is evaluated on random noise fed
            # to the FIRST adapter's gate only, not on the gate activations of
            # the actual batch - confirm whether this is the intended
            # regulariser.
            if len(self.gatera_layers) > 0:
                sample_layer = next(iter(self.gatera_layers.values()))
                if hasattr(sample_layer, 'gating_module') and hasattr(sample_layer, 'lora_A'):
                    try:
                        dummy_input = torch.randn(
                            min(batch_size, 2),
                            sample_layer.lora_A.shape[0],
                            device=sample_layer.lora_A.device,
                            dtype=torch.float32
                        )
                        gate_vals = sample_layer.gating_module(dummy_input)
                        eps = 1e-8
                        gate_vals = torch.clamp(gate_vals, eps, 1.0 - eps)
                        entropy = -gate_vals * torch.log(gate_vals) - (1 - gate_vals) * torch.log(1 - gate_vals)
                        # The sampled adapter may live on another device than
                        # the classifier; align before adding the losses.
                        entropy_loss = entropy.mean().to(task_loss.device)
                    except RuntimeError as exc:
                        # Best effort: train on the task loss alone, but say so.
                        print(f"Warning: gate entropy regularizer skipped: {exc}")
                        entropy_loss = torch.tensor(0.0, device=task_loss.device)

            loss = task_loss + self.gatera_config.entropy_reg_weight * entropy_loss

        return {'logits': logits, 'loss': loss, 'entropy_loss': entropy_loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: "GateRACodeT5"):
    """Print frozen/trainable parameter counts and per-GPU memory usage.

    The adapter layers live inside ``model.llm``, so the frozen count only
    includes LLM parameters with ``requires_grad=False``; otherwise the
    trainable adapter weights would be reported as frozen too. The trainable
    percentage is relative to all parameters (frozen + trainable).
    """
    frozen_params = sum(p.numel() for p in model.llm.parameters() if not p.requires_grad)
    gatera_params = sum(p.numel() for p in model.gatera_layers.parameters() if p.requires_grad)
    classifier_params = sum(p.numel() for p in model.classifier.parameters())
    trainable_params = gatera_params + classifier_params
    all_params = frozen_params + trainable_params
    # Guard the ratio so a parameter-less model prints 0% instead of crashing.
    trainable_pct = trainable_params / all_params * 100 if all_params else 0.0
    print(f"\n{'='*80}\nPARAMETER STATISTICS\n{'='*80}")
    print(f"Total Frozen LLM Parameters: {frozen_params:,}")
    print(f"Trainable GateRA Parameters: {gatera_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {trainable_pct:.4f}%")
    print(f"GateRA Rank: {model.gatera_config.rank}")
    print(f"GateRA Alpha: {model.gatera_config.alpha}")
    print(f"Entropy Regularization Weight: {model.gatera_config.entropy_reg_weight}")
    print(f"{'='*80}\n")
    print(f"{'='*80}\nGPU MEMORY USAGE\n{'='*80}")
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a CSV file and check that it has the three required columns.

    Args:
        csv_path: path of the CSV file.
        label_col: name of the label column.
        func1_col: name of the column with the first code snippet.
        func2_col: name of the column with the second code snippet.

    Returns:
        The loaded ``DataFrame``, or ``None`` when the file cannot be read or
        a required column is missing. The reason is printed in both cases.
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Deliberately best-effort: callers treat ``None`` as "no data".
        print(f"Error loading dataset: {str(e)}")
        return None

    # Validate all three columns up front; otherwise a missing snippet column
    # would only surface later as a KeyError while rows are being tokenised.
    missing = [col for col in (label_col, func1_col, func2_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing column(s) {missing} in {csv_path}")
        return None

    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512, label_col='label', func1_col='func1', func2_col='func2'):
    """Turn a CSV of code pairs into a list of tokenised examples.

    Each row becomes one prompt containing both snippets, which is tokenised
    with padding/truncation to ``max_length``.

    Returns:
        A list of dicts with ``input_ids``, ``attention_mask`` (both with the
        leading batch dimension squeezed away) and ``labels`` (a scalar long
        tensor). The list is empty when the CSV could not be loaded.
    """
    frame = load_dataset(csv_path, label_col, func1_col, func2_col)
    if frame is None:
        return []

    def build_example(record):
        # Read all three fields before tokenising, so a malformed label is
        # reported before any tokenizer work happens for that row.
        first_snippet = str(record[func1_col])
        second_snippet = str(record[func2_col])
        target = int(record[label_col])
        tokens = tokenizer(
            f"Code1: {first_snippet}\nCode2: {second_snippet}\nAre these code clones?",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length
        )
        return {
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'labels': torch.tensor(target, dtype=torch.long)
        }

    return [build_example(record) for _, record in frame.iterrows()]


def get_memory_usage():
    """Return the resident set size of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Compute recall@k for label retrieval over dot-product similarity.

    For every item, the ``k`` most similar OTHER items are retrieved; the
    item counts as a hit when at least one of them shares its label.

    Args:
        embeddings: 2-D tensor with one row per item.
        labels: 1-D tensor (or sequence) of labels, one per row.
        k_values: cut-offs to evaluate; each is capped at ``len(labels) - 1``.

    Returns:
        dict mapping ``'recall@<k>'`` to the fraction of items with a hit
        (0.0 for every ``k`` when there are no items).
    """
    num_items = len(labels)
    if num_items == 0:
        return {f'recall@{k}': 0.0 for k in k_values}

    similarities = torch.mm(embeddings, embeddings.t())
    # An item must never retrieve itself.
    similarities.fill_diagonal_(-float('inf'))
    labels = torch.as_tensor(labels, device=similarities.device)
    own_labels = labels.unsqueeze(1)

    results = {}
    for k in k_values:
        effective_k = min(k, num_items - 1)
        # Row-wise top-k for all queries at once instead of a Python loop.
        neighbour_idx = torch.topk(similarities, effective_k, dim=1).indices
        hits = (labels[neighbour_idx] == own_labels).any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / num_items
    return results


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour.

    For every item, all OTHER items are ranked by dot-product similarity
    (most similar first). The item contributes ``1 / rank`` of the first
    neighbour sharing its label, or 0 when no other item has that label.

    The query item itself is excluded from the ranking. Leaving it in (at
    the last position) would credit ``1 / N`` to queries that have no real
    match, and would score a single-item set as a perfect 1.0.

    Args:
        embeddings: 2-D tensor with one row per item.
        labels: 1-D tensor (or sequence) of labels, one per row.

    Returns:
        The mean reciprocal rank as a Python float (0.0 with fewer than two
        items, since there is nothing to retrieve).
    """
    num_items = len(labels)
    if num_items < 2:
        return 0.0

    similarities = torch.mm(embeddings, embeddings.t())
    similarities.fill_diagonal_(-float('inf'))
    labels = torch.as_tensor(labels, device=similarities.device)

    order = torch.argsort(similarities, dim=1, descending=True)
    query_ids = torch.arange(num_items, device=order.device).unsqueeze(1)
    is_candidate = order != query_ids
    # 1-based rank counted over non-self candidates only, so the position of
    # the query itself in the sorted list never shifts anyone's rank.
    ranks = is_candidate.cumsum(dim=1)
    relevant = is_candidate & (labels[order] == labels.unsqueeze(1))

    # ``num_items`` exceeds every real rank (max is num_items - 1), so it
    # serves as the "no match" sentinel for the row-wise minimum.
    first_rank = torch.where(relevant, ranks, torch.full_like(ranks, num_items)).min(dim=1).values
    found = first_rank < num_items
    reciprocal = torch.where(
        found,
        1.0 / first_rank.to(torch.float64),
        torch.zeros_like(first_rank, dtype=torch.float64),
    )
    return reciprocal.sum().item() / num_items


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification metrics, plus optional retrieval metrics.

    Args:
        y_true: ground-truth labels (0/1), any array-like.
        y_pred: predicted labels (0/1), any array-like.
        y_proba: optional positive-class probabilities, any array-like.
        embeddings: optional 2-D tensor with one row per sample; when given,
            recall@k and MRR are added to the result.

    Returns:
        dict with the confusion matrix (``cm``, ``tn``, ``fp``, ``fn``,
        ``tp``), threshold metrics, and the probabilistic metrics
        ``roc_auc``, ``pr_auc``, ``log_loss`` and ``brier``. The latter are
        0.0 when ``y_proba`` is missing, only one class is present, or they
        cannot be computed (a warning is printed in that last case).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # ``labels=[0, 1]`` pins the matrix to 2x2 even if a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers may pass a plain list; ``1 - y_proba`` needs an array.
        # Without this conversion the TypeError used to be swallowed and all
        # four probabilistic metrics were silently reported as 0.0.
        y_proba = np.asarray(y_proba, dtype=float)
        try:
            roc_auc = roc_auc_score(y_true, y_proba)
            pr_auc = average_precision_score(y_true, y_proba)
            logloss = log_loss(y_true, np.column_stack([1 - y_proba, y_proba]))
            brier = brier_score_loss(y_true, y_proba)
        except Exception as exc:
            # Best effort: keep the remaining metrics, but report the problem.
            print(f"Warning: probabilistic metrics could not be computed: {exc}")
            roc_auc = pr_auc = logloss = brier = 0.0

    results = {
        'cm': cm, 'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        'acc': acc, 'bal_acc': bal_acc, 'prec': prec, 'rec': rec,
        'specificity': specificity, 'npv': npv, 'fpr': fpr, 'fnr': fnr,
        'f1': f1, 'mcc': mcc, 'jacc': jacc, 'kappa': kappa,
        'roc_auc': roc_auc, 'pr_auc': pr_auc, 'log_loss': logloss, 'brier': brier
    }
    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)
    return results


def collate_fn(batch):
    """Stack the per-example tensors of a batch, field by field.

    Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``,
    each stacked along a new leading batch dimension.
    """
    field_names = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([example[field] for example in batch])
        for field in field_names
    }


def train_model(model, train_data, num_epochs=3, learning_rate=1e-3, batch_size=2):
    """Train the adapter and classifier parameters of ``model`` with AdamW.

    Only ``model.gatera_layers`` parameters that require grad and the
    ``model.classifier`` parameters are optimised; gradients are clipped to a
    global norm of 1.0.

    Args:
        model: module exposing ``gatera_layers`` and ``classifier`` and
            returning a dict with ``loss``, ``logits`` and optionally
            ``entropy_loss`` from its forward pass.
        train_data: list of examples as consumed by ``collate_fn``.
        num_epochs: number of passes over ``train_data``.
        learning_rate: AdamW learning rate.
        batch_size: mini-batch size.

    Returns:
        A list with one dict per epoch (``epoch``, ``loss``, ``entropy_loss``,
        ``acc``, ``time``, ``memory_mb``).
    """
    gatera_params = [p for p in model.gatera_layers.parameters() if p.requires_grad]
    classifier_params = list(model.classifier.parameters())
    # Built once; the same list is reused for gradient clipping every step.
    clipped_params = gatera_params + classifier_params
    optimizer = AdamW(
        [{'params': gatera_params}, {'params': classifier_params}],
        lr=learning_rate,
        weight_decay=0.01
    )
    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model.train()
    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        total_entropy_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(clipped_params, 1.0)
                optimizer.step()
                total_loss += loss.item()
                if outputs.get('entropy_loss') is not None:
                    total_entropy_loss += outputs['entropy_loss'].item()
                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)

                # Drop references to the step's tensors before asking the
                # caching allocator to release unused blocks.
                del outputs, loss, predictions
                torch.cuda.empty_cache()

            batch_count += 1
            if batch_count % 5 == 0:
                torch.cuda.empty_cache()
                gc.collect()
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        avg_entropy = total_entropy_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()
        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'entropy_loss': avg_entropy,
            'acc': accuracy,
            'time': epoch_time,
            'memory_mb': current_memory
        })
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Entropy: {avg_entropy:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
    total_training_time = time.time() - start_time
    # ``default`` keeps the summary working when no epoch ran (num_epochs=0).
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    print(f"\n{'='*80}\nTRAINING SUMMARY\n{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB\n{'='*80}\n")
    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run inference over ``test_data`` and compute the evaluation metrics.

    Predictions, labels and positive-class probabilities are handed to
    ``calculate_comprehensive_metrics`` as NumPy arrays. Plain Python lists
    break the ``1 - y_proba`` arithmetic there, which caused all probabilistic
    metrics to be reported as 0.0.

    Args:
        model: module whose forward returns a dict with ``logits`` and
            ``hidden_states``.
        test_data: list of examples as consumed by ``collate_fn``.
        batch_size: mini-batch size.

    Returns:
        The metrics dict, extended with ``inference_time`` (seconds) and
        ``samples_per_sec``.
    """
    model.eval()
    prediction_chunks = []
    label_chunks = []
    probability_chunks = []
    embedding_chunks = []
    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    samples_seen = 0
    start_time = time.time()
    with torch.no_grad():
        for batch_data in dataloader:
            labels = batch_data['labels']
            outputs = model(batch_data['input_ids'], attention_mask=batch_data['attention_mask'])
            logits = outputs['logits']
            probabilities = F.softmax(logits, dim=-1)
            # Mean-pool over the sequence dimension, then L2-normalise so the
            # retrieval metrics operate on cosine similarity.
            pooled_embeddings = F.normalize(outputs['hidden_states'].mean(dim=1), p=2, dim=1)
            prediction_chunks.append(torch.argmax(logits, dim=-1).cpu())
            label_chunks.append(labels.cpu())
            probability_chunks.append(probabilities[:, 1].cpu())
            embedding_chunks.append(pooled_embeddings.cpu())
            samples_seen += labels.size(0)
            if samples_seen % 20 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    samples_per_sec = len(test_data) / inference_time if inference_time > 0 else 0.0
    all_embeddings = torch.cat(embedding_chunks, dim=0)
    all_labels = torch.cat(label_chunks).numpy()
    all_predictions = torch.cat(prediction_chunks).numpy()
    # float64 keeps ``p`` and ``1 - p`` summing to one for the log-loss.
    all_probabilities = torch.cat(probability_chunks).double().numpy()
    metrics = calculate_comprehensive_metrics(all_labels, all_predictions, all_probabilities, all_embeddings)
    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec
    return metrics


def print_results(metrics, dataset_name):
    """Pretty-print an evaluation report for ``dataset_name``.

    The confusion matrix, classification and advanced sections are always
    printed. The probabilistic, retrieval and performance sections are
    printed only when their keys are present in ``metrics``.
    """
    bar = '=' * 80

    def banner(title):
        # Blank line, rule, title, rule: the opening of every section.
        return f"\n{bar}\n{title}\n{bar}"

    def rows(pairs):
        # One "<label>: <value>" line per pair, four decimals each.
        return "\n".join(f"{label}: {metrics[key]:.4f}" for label, key in pairs)

    heading = banner(f"RESULTS: {dataset_name}")
    matrix_text = f"{metrics['cm']}"
    counts = "\n".join(
        f"{label}: {metrics[key]}"
        for label, key in (
            ("True Negatives (TN)", 'tn'),
            ("False Positives (FP)", 'fp'),
            ("False Negatives (FN)", 'fn'),
            ("True Positives (TP)", 'tp'),
        )
    )
    print(f"{heading}\n\nConfusion Matrix:\n{matrix_text}\n\n{counts}\n{banner('CLASSIFICATION METRICS')}")

    print(rows((
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall (Sensitivity)", 'rec'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
        ("Negative Predictive Value (NPV)", 'npv'),
        ("False Positive Rate (FPR)", 'fpr'),
        ("False Negative Rate (FNR)", 'fnr'),
    )))

    print(banner("ADVANCED METRICS") + "\n" + rows((
        ("Jaccard Score", 'jacc'),
        ("Matthews Correlation Coefficient (MCC)", 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    )))

    if 'roc_auc' in metrics:
        print(banner("PROBABILISTIC METRICS") + "\n" + rows((("ROC AUC Score", 'roc_auc'),)))
    for label, key in (
        ("Precision-Recall AUC", 'pr_auc'),
        ("Log Loss", 'log_loss'),
        ("Brier Score", 'brier'),
    ):
        if key in metrics:
            print(rows(((label, key),)))

    if 'recall@1' in metrics:
        print(banner("RETRIEVAL METRICS") + "\n" + rows((
            ("Recall@1", 'recall@1'),
            ("Recall@3", 'recall@3'),
            ("Recall@5", 'recall@5'),
            ("Recall@10", 'recall@10'),
            ("Mean Reciprocal Rank (MRR)", 'mrr'),
        )))

    if 'inference_time' in metrics:
        timing = f"Inference Time: {metrics['inference_time']:.2f}s\nThroughput: {metrics['samples_per_sec']:.2f} samples/sec"
        print(banner("PERFORMANCE METRICS") + "\n" + timing)

    print(f"{bar}\n")


def main():
    """Train the GateRA model on ``/train.csv`` and evaluate it on ``/test.csv``.

    Lists the visible GPUs, loads the tokenizer and the adapted model, prints
    parameter statistics, trains for five epochs and, if test data is
    available, prints the test metrics.
    """
    rule = '=' * 80

    def announce(title):
        # Section banner framed by blank lines.
        print(f"\n{rule}\n{title}\n{rule}\n")

    torch.cuda.empty_cache()
    gc.collect()

    gpu_total = torch.cuda.device_count()
    print(f"Number of available GPUs: {gpu_total}")
    for gpu_index in range(gpu_total):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    checkpoint = "Salesforce/codet5p-770m"
    cache_dir = '/hf_cache'
    columns = dict(label_col='label', func1_col='func1', func2_col='func2')

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        # Padding to max_length needs a pad token; reuse EOS when absent.
        tokenizer.pad_token = tokenizer.eos_token

    gatera_config = GateRAConfig(rank=16, alpha=16.0, dropout=0.0, entropy_reg_weight=0.01)
    print("\nLoading model with GateRA (Gated LoRA) and pipeline parallelism across GPUs...")
    model = GateRACodeT5(checkpoint, gatera_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    announce("TRAINING PHASE")
    train_data = create_training_data('/train.csv', tokenizer, **columns)
    print(f"Created {len(train_data)} training examples\n")
    if len(train_data) == 0:
        print("No training data available. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=1e-3, batch_size=4)

    announce("TESTING PHASE")
    test_data = create_training_data('/test.csv', tokenizer, **columns)
    print(f"Created {len(test_data)} test examples\n")
    if len(test_data) > 0:
        test_metrics = evaluate_model(model, test_data, batch_size=4)
        print_results(test_metrics, "TEST SET")

    announce("TRAINING COMPLETE")


if __name__ == '__main__':
    main()

# Prefix tuning for CodeT5+

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class PrefixTuningConfig:
    """Plain container for the prefix-tuning hyper-parameters.

    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(
        self,
        prefix_length: int = 20,
        num_layers: int = 24,
        hidden_size: int = 1024,
        num_heads: int = 16,
        head_dim: int = 64,
        reparam_dim: int = 512,
        dropout: float = 0.1
    ):
        # Prefix geometry: how many prefix positions, for how many layers.
        self.prefix_length, self.num_layers = prefix_length, num_layers
        # Model geometry used to size the projected key/value tensors.
        self.hidden_size, self.num_heads, self.head_dim = hidden_size, num_heads, head_dim
        # Re-parameterisation MLP settings.
        self.reparam_dim, self.dropout = reparam_dim, dropout


class PrefixEncoder(nn.Module):
    """Produce per-layer prefix key/value tensors from learned prefix tokens.

    A learned tensor of shape ``(num_layers, prefix_length, reparam_dim)`` is
    passed through a shared MLP (Linear -> Tanh -> Linear -> Dropout) whose
    output is split into a key half and a value half.
    """

    def __init__(self, config: PrefixTuningConfig):
        super().__init__()
        self.prefix_length = config.prefix_length
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        token_shape = (config.num_layers, config.prefix_length, config.reparam_dim)
        self.prefix_tokens = nn.Parameter(torch.randn(*token_shape))

        # Keys and values are emitted side by side, hence the factor of two.
        kv_width = 2 * config.num_heads * config.head_dim
        self.reparam_mlp = nn.Sequential(
            nn.Linear(config.reparam_dim, config.hidden_size),
            nn.Tanh(),
            nn.Linear(config.hidden_size, kv_width),
            nn.Dropout(config.dropout)
        )

    def forward(self, batch_size: int):
        """Return one ``(key, value)`` pair per layer.

        Each tensor has shape ``(batch_size, num_heads, prefix_length, head_dim)``.
        """
        batched_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1, -1, -1)
        split_shape = (batch_size, self.prefix_length, 2, self.num_heads, self.head_dim)

        per_layer = []
        for layer_idx in range(self.num_layers):
            projected = self.reparam_mlp(batched_tokens.select(1, layer_idx)).view(*split_shape)
            # Index 0 along the "2" axis holds keys, index 1 holds values;
            # move the head axis in front of the prefix axis.
            key, value = (
                half.transpose(1, 2).contiguous()
                for half in (projected.select(2, 0), projected.select(2, 1))
            )
            per_layer.append((key, value))

        return per_layer


class CodeCloneDataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory sequence."""

    def __init__(self, data_list):
        # The sequence is stored by reference, not copied.
        self.data = data_list

    def __len__(self) -> int:
        """Number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        example = self.data[idx]
        return example


class PrefixTuningCodeT5(nn.Module):
    """Frozen seq2seq LLM combined with a prefix encoder and a classifier head.

    The base model is loaded with ``device_map='auto'`` and frozen. A
    ``PrefixEncoder`` and an MLP classifier are created on the primary device
    and marked trainable. Forward hooks are registered on every encoder and
    decoder block (see ``_register_prefix_hooks``).
    """

    def __init__(self, model_name: str, prefix_config: PrefixTuningConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        
        torch.cuda.empty_cache()
        gc.collect()
        
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        
        # Device that hosts the prefix encoder and the classifier.
        self.primary_device = 'cuda:0'
        
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )
        
        self.llm = base_model
        self.prefix_config = prefix_config
        
        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")
        
        # Device of every block, encoder blocks first, then decoder blocks;
        # the hooks index this list with the same combined numbering.
        self.layer_devices = []
        for i in range(encoder_layers):
            layer_device = next(base_model.encoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Encoder Layer {i} -> {layer_device}")
        
        for i in range(decoder_layers):
            layer_device = next(base_model.decoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Decoder Layer {i} -> {layer_device}")
        
        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)
        
        self.prefix_encoder = PrefixEncoder(prefix_config).to(self.primary_device).float()
        self.num_classes = num_classes
        
        # NOTE(review): the classifier width comes from prefix_config.hidden_size,
        # not from the loaded model's config - presumably they must match; verify.
        self.classifier = nn.Sequential(
            nn.Linear(prefix_config.hidden_size, prefix_config.hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(prefix_config.hidden_size // 4, num_classes)
        ).to(self.primary_device).float()
        
        # Freeze the base model; only the prefix encoder and the classifier
        # are flagged as trainable.
        for param in self.llm.parameters():
            param.requires_grad = False
        
        for param in self.prefix_encoder.parameters():
            param.requires_grad = True
            
        for param in self.classifier.parameters():
            param.requires_grad = True
        
        self._register_prefix_hooks()
        
        print("Model initialization complete!")
    
    def _register_prefix_hooks(self):
        """Register a forward hook on every encoder and decoder block.

        While ``self.prefix_kvs`` is set, each hook prepends that layer's
        prefix key/value (along dim 2) to the key/value pair found at
        ``output[1]`` of the block's output tuple, when such a pair exists.
        """
        # Set by forward() for the duration of one base-model call.
        self.prefix_kvs = None
        
        def create_hook(layer_idx):
            # Factory so that each hook closes over its own layer_idx.
            def hook(module, args, kwargs, output):
                # NOTE(review): this is a forward (post-call) hook, so it only
                # rewrites what the block returns; the block's own attention
                # has already run by the time it fires. Confirm that the
                # prefix actually influences the hidden states.
                if self.prefix_kvs is not None and layer_idx < len(self.prefix_kvs):
                    prefix_key, prefix_value = self.prefix_kvs[layer_idx]
                    
                    # Prefer the device of the existing key tensor; otherwise
                    # fall back to the device recorded for this block.
                    device = output[1][0].device if isinstance(output, tuple) and len(output) > 1 and output[1] is not None else self.layer_devices[layer_idx]
                    
                    prefix_key = prefix_key.to(device)
                    prefix_value = prefix_value.to(device)
                    
                    if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                        # NOTE(review): assumes output[1] is the block's
                        # key/value tuple - TODO confirm against the
                        # installed transformers version.
                        orig_key_value = output[1]
                        if isinstance(orig_key_value, tuple) and len(orig_key_value) >= 2:
                            orig_key = orig_key_value[0]
                            orig_value = orig_key_value[1]
                            
                            # Prepend the prefix along dim 2.
                            new_key = torch.cat([prefix_key, orig_key], dim=2)
                            new_value = torch.cat([prefix_value, orig_value], dim=2)
                            
                            # Rebuild both tuples, keeping any extra entries.
                            new_key_value = (new_key, new_value) + orig_key_value[2:] if len(orig_key_value) > 2 else (new_key, new_value)
                            output = (output[0], new_key_value) + output[2:] if len(output) > 2 else (output[0], new_key_value)
                
                return output
            return hook
        
        # Encoder blocks use indices [0, encoder_layers); decoder blocks
        # continue from encoder_layers.
        encoder_layers = len(self.llm.encoder.block)
        for layer_idx, layer in enumerate(self.llm.encoder.block):
            layer.register_forward_hook(create_hook(layer_idx), with_kwargs=True)
        
        for layer_idx, layer in enumerate(self.llm.decoder.block):
            layer.register_forward_hook(create_hook(encoder_layers + layer_idx), with_kwargs=True)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch.

        ``input_ids`` is used as both encoder and decoder input. The last
        decoder hidden state is mean-pooled over dim 1 and classified.

        Returns:
            dict with ``logits``, ``loss`` (``None`` without labels) and
            ``hidden_states`` (the last decoder hidden state on the primary
            device).
        """
        batch_size = input_ids.shape[0]
        
        input_ids = input_ids.to(self.embed_device)
        
        # Made visible to the hooks for the duration of the base-model call.
        self.prefix_kvs = self.prefix_encoder(batch_size)
        
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
            # NOTE(review): the mask is lengthened by prefix_length while
            # input_ids keeps its original length - confirm the base model
            # accepts a mask longer than its input.
            prefix_attention = torch.ones(
                batch_size, self.prefix_config.prefix_length, 
                dtype=attention_mask.dtype, device=attention_mask.device
            )
            attention_mask = torch.cat([prefix_attention, attention_mask], dim=1)
        
        # NOTE(review): the base model runs under no_grad, so no autograd
        # graph is built for this call and the loss below can only
        # back-propagate into the classifier, not into prefix_encoder.
        # Confirm this is intended.
        with torch.no_grad():
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=input_ids,
                output_hidden_states=True,
                return_dict=True
            )
        
        hidden_states = outputs.decoder_hidden_states[-1]
        
        if str(hidden_states.device) != self.primary_device:
            hidden_states = hidden_states.to(self.primary_device)
        
        pooled_output = hidden_states.mean(dim=1)
        
        logits = self.classifier(pooled_output)
        
        loss = None
        if labels is not None:
            if labels.device != logits.device:
                labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        
        # Clear the per-call prefix so the hooks become inert again.
        self.prefix_kvs = None
        
        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: PrefixTuningCodeT5):
    """Print parameter counts for the LLM, prefix encoder and classifier, then per-GPU memory.

    The percentage shown is (prefix + classifier parameters) relative to the
    parameter count of ``model.llm``.
    """
    def count_parameters(module):
        return sum(p.numel() for p in module.parameters())

    total_params = count_parameters(model.llm)
    prefix_params = count_parameters(model.prefix_encoder)
    classifier_params = count_parameters(model.classifier)
    trainable_params = prefix_params + classifier_params

    print(f"\n{'='*80}")
    print(f"PARAMETER STATISTICS")
    print(f"{'='*80}")
    print(f"Total Frozen LLM Parameters: {total_params:,}")
    print(f"Trainable Prefix Parameters: {prefix_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    print(f"{'='*80}\n")

    print(f"{'='*80}")
    print(f"GPU MEMORY USAGE")
    print(f"{'='*80}")
    # Byte counts reported by torch.cuda are divided by 1024**3 for display in GB.
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a CSV into a DataFrame and print its size and label distribution.

    Best-effort loader: any failure (unreadable file, missing label column,
    ...) is printed and reported as ``None`` instead of being raised.
    ``func1_col`` and ``func2_col`` are accepted but not used here.
    """
    try:
        frame = pd.read_csv(csv_path)
        print(f"Loaded dataset: {len(frame)} samples")
        label_counts = frame[label_col].value_counts().to_dict()
        print(f"Label distribution: {label_counts}")
    except Exception as exc:
        print(f"Error loading dataset: {str(exc)}")
        return None
    return frame


def create_training_data(csv_path, tokenizer, max_length=512, 
                        label_col='label', func1_col='func1', func2_col='func2'):
    """Turn a CSV of function pairs into tokenized classification examples.

    Each row becomes one dict with ``input_ids``, ``attention_mask`` (both
    with the tokenizer's leading batch dimension squeezed away) and
    ``labels`` (a scalar long tensor). Returns ``[]`` when the CSV could not
    be loaded.
    """
    df = load_dataset(csv_path, label_col, func1_col, func2_col)
    if df is None:
        return []

    def build_example(row):
        func1 = str(row[func1_col])
        func2 = str(row[func2_col])
        label = int(row[label_col])
        # Both functions go into a single prompt, padded/truncated to max_length.
        encoding = tokenizer(
            f"Code1: {func1}\nCode2: {func2}\nAre these code clones?",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

    return [build_example(row) for _, row in df.iterrows()]


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Fraction of items with at least one same-label item among their top-k neighbours.

    Similarity is the dot product between embedding rows; each item is
    excluded from its own neighbour list.

    Args:
        embeddings: 2-D float tensor, one row per item.
        labels: Per-item labels; a tensor or anything ``torch.as_tensor`` accepts.
        k_values: Iterable of neighbourhood sizes. Each k is capped at
            ``len(labels) - 1``.

    Returns:
        dict mapping ``'recall@{k}'`` to a float in [0, 1] (0.0 for empty input).
    """
    labels = torch.as_tensor(labels)
    num_items = len(labels)

    similarities = torch.mm(embeddings, embeddings.t())
    # Mask self-similarity once instead of cloning every row for every k.
    similarities.fill_diagonal_(-float('inf'))

    results = {}
    for k in k_values:
        if num_items == 0:
            results[f'recall@{k}'] = 0.0
            continue
        neighbours = torch.topk(similarities, min(k, num_items - 1), dim=1).indices
        # Row i is a hit when any of its neighbours carries label i.
        hits = (labels[neighbours] == labels.unsqueeze(1)).any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / num_items

    return results


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour of each item.

    Neighbours are ranked by dot-product similarity. The query item is
    removed from its own ranking, so an item with no other same-label item
    contributes 0 (it is not allowed to "retrieve itself" at the last rank).

    Args:
        embeddings: 2-D float tensor, one row per item.
        labels: Per-item labels; a tensor or anything ``torch.as_tensor`` accepts.

    Returns:
        The MRR as a float; 0.0 for empty input.
    """
    labels = torch.as_tensor(labels)
    num_items = len(labels)
    if num_items == 0:
        return 0.0

    similarities = torch.mm(embeddings, embeddings.t())
    similarities.fill_diagonal_(-float('inf'))

    mrr_sum = 0.0
    for i in range(num_items):
        ranking = torch.argsort(similarities[i], descending=True)
        # Drop the query itself; the remaining ranks are 1..n-1.
        ranking = ranking[ranking != i]
        match_positions = (labels[ranking] == labels[i]).nonzero()
        if match_positions.numel() > 0:
            mrr_sum += 1.0 / (match_positions[0].item() + 1)

    return mrr_sum / num_items


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification metrics, plus optional probabilistic and retrieval metrics.

    Args:
        y_true: Ground-truth labels (0/1); list or array.
        y_pred: Predicted labels (0/1); list or array.
        y_proba: Optional probability of class 1 per sample; list or array.
            Probabilistic metrics are computed only when this is given and
            ``y_true`` contains more than one class; otherwise they are 0.0.
        embeddings: Optional 2-D tensor of per-sample embeddings; when given,
            ``recall@k`` and ``mrr`` entries are added.

    Returns:
        dict with the confusion matrix (``cm``), its four cells, the
        threshold metrics, ``roc_auc``/``pr_auc``/``log_loss``/``brier`` and,
        if embeddings were supplied, the retrieval metrics.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # labels=[0, 1] fixes the matrix at 2x2 even when a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    def safe_metric(name, metric_fn, *args):
        # One failing metric must not zero out the others; report it instead
        # of hiding it, and fall back to 0.0 as before.
        try:
            return metric_fn(*args)
        except (ValueError, TypeError) as exc:
            print(f"Warning: could not compute {name}: {exc}")
            return 0.0

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers pass plain lists; `1 - y_proba` below needs an ndarray.
        y_proba = np.asarray(y_proba, dtype=float)
        roc_auc = safe_metric('roc_auc', roc_auc_score, y_true, y_proba)
        pr_auc = safe_metric('pr_auc', average_precision_score, y_true, y_proba)
        logloss = safe_metric(
            'log_loss', log_loss, y_true, np.column_stack([1 - y_proba, y_proba])
        )
        brier = safe_metric('brier', brier_score_loss, y_true, y_proba)

    results = {
        'cm': cm,
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn),
        'tp': int(tp),
        'acc': acc,
        'bal_acc': bal_acc,
        'prec': prec,
        'rec': rec,
        'specificity': specificity,
        'npv': npv,
        'fpr': fpr,
        'fnr': fnr,
        'f1': f1,
        'mcc': mcc,
        'jacc': jacc,
        'kappa': kappa,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'log_loss': logloss,
        'brier': brier
    }

    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)

    return results


def collate_fn(batch):
    """Stack the per-sample tensors of a batch along a new leading dimension."""
    fields = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in fields
    }


def train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2):
    """Train the prefix encoder and classifier of ``model`` with AdamW.

    Only ``model.prefix_encoder`` and ``model.classifier`` parameters are
    handed to the optimizer; gradients are clipped to a total norm of 1.0.

    Args:
        model: Module exposing ``prefix_encoder`` and ``classifier`` whose
            forward returns a dict with ``loss`` and ``logits``.
        train_data: List of example dicts (``input_ids``, ``attention_mask``,
            ``labels``).
        num_epochs: Number of passes over ``train_data``.
        learning_rate: AdamW learning rate.
        batch_size: Mini-batch size.

    Returns:
        List with one dict per epoch: ``epoch``, ``loss``, ``acc``, ``time``
        (seconds) and ``memory_mb`` (process RSS after the epoch).
    """
    prefix_params = list(model.prefix_encoder.parameters())
    classifier_params = list(model.classifier.parameters())
    # Built once; the clipping call below reuses it on every step.
    trainable_params = prefix_params + classifier_params

    optimizer = torch.optim.AdamW([
        {'params': prefix_params},
        {'params': classifier_params}
    ], lr=learning_rate, weight_decay=0.01)

    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    model.train()

    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0

        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                total_loss += loss.item()

                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)

            batch_count += 1

            # Periodically hand cached GPU blocks back to the allocator.
            if batch_count % 10 == 0:
                torch.cuda.empty_cache()

        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()

        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'time': epoch_time,
            'memory_mb': current_memory
        })

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")

        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")

    total_training_time = time.time() - start_time
    # `default` keeps the summary working when no epoch ran (num_epochs == 0).
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory

    print(f"\n{'='*80}")
    print(f"TRAINING SUMMARY")
    print(f"{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB")
    print(f"{'='*80}\n")

    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run inference over ``test_data`` and compute classification/retrieval metrics.

    Args:
        model: Module whose forward returns a dict with ``logits`` and
            ``hidden_states``.
        test_data: List of example dicts (``input_ids``, ``attention_mask``,
            ``labels``).
        batch_size: Mini-batch size for inference.

    Returns:
        The dict from ``calculate_comprehensive_metrics`` extended with
        ``inference_time`` (seconds) and ``samples_per_sec``.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []

    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    start_time = time.time()

    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits']
            hidden_states = outputs['hidden_states']

            predictions = torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)

            # L2-normalised mean-pooled hidden states, so the dot products
            # used by the retrieval metrics are cosine similarities.
            pooled_embeddings = hidden_states.mean(dim=1)
            pooled_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())
            all_embeddings.append(pooled_embeddings.cpu())

            if len(all_predictions) % 20 == 0:
                torch.cuda.empty_cache()

    inference_time = time.time() - start_time
    samples_per_sec = len(test_data) / inference_time if inference_time > 0 else 0.0

    all_embeddings = torch.cat(all_embeddings, dim=0)

    # Hand over real arrays: with plain lists, `1 - y_proba` inside the metrics
    # function raises and the probabilistic metrics silently come back as 0.0.
    metrics = calculate_comprehensive_metrics(
        np.asarray(all_labels),
        np.asarray(all_predictions),
        np.asarray(all_probabilities, dtype=float),
        all_embeddings
    )

    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec

    return metrics


def print_results(metrics, dataset_name):
    """Pretty-print a metrics dict as a sectioned report.

    The confusion-matrix, classification and advanced sections are always
    printed; probabilistic, retrieval and performance lines appear only when
    their keys are present in ``metrics``.
    """
    def banner(title):
        print(f"\n{'='*80}")
        print(title)
        print(f"{'='*80}")

    def show(label, key):
        print(f"{label}: {metrics[key]:.4f}")

    banner(f"RESULTS: {dataset_name}")
    print(f"\nConfusion Matrix:")
    print(f"{metrics['cm']}")
    print(f"\nTrue Negatives (TN): {metrics['tn']}")
    print(f"False Positives (FP): {metrics['fp']}")
    print(f"False Negatives (FN): {metrics['fn']}")
    print(f"True Positives (TP): {metrics['tp']}")

    banner("CLASSIFICATION METRICS")
    for label, key in (
        ("Accuracy", 'acc'),
        ("Balanced Accuracy", 'bal_acc'),
        ("Precision", 'prec'),
        ("Recall (Sensitivity)", 'rec'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
        ("Negative Predictive Value (NPV)", 'npv'),
        ("False Positive Rate (FPR)", 'fpr'),
        ("False Negative Rate (FNR)", 'fnr'),
    ):
        show(label, key)

    banner("ADVANCED METRICS")
    for label, key in (
        ("Jaccard Score", 'jacc'),
        ("Matthews Correlation Coefficient (MCC)", 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    ):
        show(label, key)

    # The section header is tied to 'roc_auc'; the other lines are independent.
    if 'roc_auc' in metrics:
        banner("PROBABILISTIC METRICS")
        show("ROC AUC Score", 'roc_auc')
    for label, key in (
        ("Precision-Recall AUC", 'pr_auc'),
        ("Log Loss", 'log_loss'),
        ("Brier Score", 'brier'),
    ):
        if key in metrics:
            show(label, key)

    if 'recall@1' in metrics:
        banner("RETRIEVAL METRICS")
        for k in (1, 3, 5, 10):
            show(f"Recall@{k}", f"recall@{k}")
        show("Mean Reciprocal Rank (MRR)", 'mrr')

    if 'inference_time' in metrics:
        banner("PERFORMANCE METRICS")
        print(f"Inference Time: {metrics['inference_time']:.2f}s")
        print(f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec")

    print(f"{'='*80}\n")


def main():
    """Run the prefix-tuning experiment end to end: load, train, evaluate, report."""
    torch.cuda.empty_cache()
    gc.collect()

    def announce(title):
        print(f"\n{'='*80}")
        print(title)
        print(f"{'='*80}\n")

    def build_examples(csv_path):
        return create_training_data(
            csv_path,
            tokenizer,
            label_col='label',
            func1_col='func1',
            func2_col='func2'
        )

    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for gpu_index in range(num_gpus):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    cache_dir = '/hf_cache'

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-770m", cache_dir=cache_dir)
    # Padding to max_length needs a pad token; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    prefix_config = PrefixTuningConfig(
        prefix_length=20,
        num_layers=24,
        hidden_size=1024,
        num_heads=16,
        head_dim=64,
        reparam_dim=512
    )

    print("\nLoading model with pipeline parallelism across GPUs...")
    model = PrefixTuningCodeT5("Salesforce/codet5p-770m", prefix_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    announce("TRAINING PHASE")
    train_data = build_examples('/train.csv')
    print(f"Created {len(train_data)} training examples\n")

    if len(train_data) == 0:
        print("No training data available. Exiting.")
        return

    train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2)

    announce("TESTING PHASE")
    test_data = build_examples('/test.csv')
    print(f"Created {len(test_data)} test examples\n")

    if len(test_data) > 0:
        test_metrics = evaluate_model(model, test_data, batch_size=2)
        print_results(test_metrics, "TEST SET")

    announce("TRAINING COMPLETE")


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

# Adapter clone detection

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class AdapterConfig:
    """Hyper-parameters of a bottleneck adapter.

    ``adapter_size`` (the bottleneck width) is derived as
    ``hidden_size // reduction_factor``.
    """

    def __init__(self, hidden_size: int = 1024, reduction_factor: int = 16, non_linearity: str = "gelu", dropout: float = 0.1):
        bottleneck_width = hidden_size // reduction_factor
        self.hidden_size = hidden_size
        self.adapter_size = bottleneck_width
        self.reduction_factor = reduction_factor
        self.non_linearity = non_linearity
        self.dropout = dropout


class AdapterModule(nn.Module):
    """Residual bottleneck adapter: down-project, non-linearity, dropout, up-project, add input.

    Both projections start with N(0, 1e-3) weights and zero biases, so the
    module is close to the identity at initialisation.
    """

    def __init__(self, config: AdapterConfig):
        super().__init__()
        self.down_project = nn.Linear(config.hidden_size, config.adapter_size)
        self.up_project = nn.Linear(config.adapter_size, config.hidden_size)

        # Unrecognised names fall back to GELU.
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "swish": nn.SiLU}
        self.activation = activation_classes.get(config.non_linearity, nn.GELU)()
        self.dropout = nn.Dropout(config.dropout)

        for projection in (self.down_project, self.up_project):
            nn.init.normal_(projection.weight, std=1e-3)
            nn.init.zeros_(projection.bias)

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus the adapter's bottleneck transformation of it."""
        bottleneck = self.down_project(hidden_states)
        bottleneck = self.dropout(self.activation(bottleneck))
        return self.up_project(bottleneck) + hidden_states


class CodeCloneDataset(Dataset):
    """Thin ``Dataset`` wrapper around an in-memory list of examples."""

    def __init__(self, data_list):
        self.data = data_list

    def __len__(self):
        """Number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.data[idx]


class AdapterCodeT5(nn.Module):
    """Frozen seq2seq LLM with trainable adapters on every block and a classification head.

    The base model is loaded with ``device_map='auto'`` and all of its
    parameters are frozen. Each encoder/decoder block gets two
    ``AdapterModule`` instances placed on that block's device; the block's
    ``forward`` is wrapped so both adapters are applied, one after the other,
    to the block's output hidden states. A small MLP classifier on
    ``primary_device`` consumes the mean-pooled last decoder hidden states.
    """

    def __init__(self, model_name: str, adapter_config: AdapterConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        # Fall back to CPU so the classifier is not pinned to a CUDA device
        # that does not exist.
        self.primary_device = 'cuda:0' if num_gpus > 0 else 'cpu'
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto',
        )
        self.llm = base_model
        self.adapter_config = adapter_config

        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")

        # Device of every block, encoder blocks first, then decoder blocks.
        self.layer_devices = []
        for stack_name, blocks in (("Encoder", base_model.encoder.block), ("Decoder", base_model.decoder.block)):
            for i, block in enumerate(blocks):
                layer_device = next(block.parameters()).device
                self.layer_devices.append(str(layer_device))
                if i % 5 == 0:
                    print(f"  {stack_name} Layer {i} -> {layer_device}")

        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)

        # Freeze the base model; only adapters and the classifier are trained.
        for param in self.llm.parameters():
            param.requires_grad = False

        self.encoder_adapters_attn = nn.ModuleList()
        self.encoder_adapters_ffn = nn.ModuleList()
        self.decoder_adapters_attn = nn.ModuleList()
        self.decoder_adapters_ffn = nn.ModuleList()
        self._inject_adapter_modules()

        self.num_classes = num_classes
        hidden_size = base_model.config.d_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_classes),
        ).to(self.primary_device).float()
        for param in self.classifier.parameters():
            param.requires_grad = True
        print("Model initialization complete!")

    def _adapter_stacks(self):
        """Return (blocks, attn_adapters, ffn_adapters) for the encoder, then the decoder."""
        return (
            (self.llm.encoder.block, self.encoder_adapters_attn, self.encoder_adapters_ffn),
            (self.llm.decoder.block, self.decoder_adapters_attn, self.decoder_adapters_ffn),
        )

    def _inject_adapter_modules(self):
        """Create two adapters per block on that block's device, then wrap the block forwards."""
        for blocks, attn_adapters, ffn_adapters in self._adapter_stacks():
            for layer in blocks:
                device = next(layer.parameters()).device
                attn_adapters.append(AdapterModule(self.adapter_config).to(device))
                ffn_adapters.append(AdapterModule(self.adapter_config).to(device))
        self._register_forward_hooks()

    @staticmethod
    def _wrap_block_forward(orig_fwd, attn_adp, ffn_adp):
        """Return a forward that runs ``orig_fwd`` and then both adapters on its first output."""
        def new_forward(*args, **kwargs):
            outputs = orig_fwd(*args, **kwargs)
            # Only tuple outputs are rewritten; anything else passes through untouched.
            if isinstance(outputs, tuple):
                hidden_states = ffn_adp(attn_adp(outputs[0]))
                outputs = (hidden_states,) + outputs[1:]
            return outputs
        return new_forward

    def _register_forward_hooks(self):
        """Replace every block's ``forward`` with an adapter-applying wrapper.

        Not idempotent: calling it twice wraps the blocks twice.
        """
        for blocks, attn_adapters, ffn_adapters in self._adapter_stacks():
            for layer, attn_adapter, ffn_adapter in zip(blocks, attn_adapters, ffn_adapters):
                layer.forward = self._wrap_block_forward(layer.forward, attn_adapter, ffn_adapter)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch.

        Args:
            input_ids: Token ids, passed to the LLM as ``input_ids`` and as
                ``decoder_input_ids``.
            attention_mask: Optional mask for ``input_ids``.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed.

        Returns:
            dict with ``logits``, ``loss`` (``None`` without labels) and
            ``hidden_states`` (last decoder hidden states on ``primary_device``).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = outputs.decoder_hidden_states[-1]
        if str(hidden_states.device) != self.primary_device:
            hidden_states = hidden_states.to(self.primary_device)
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if labels.device != logits.device:
                labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: AdapterCodeT5):
    """Print parameter counts for the LLM, adapters and classifier, then per-GPU memory.

    The percentage shown is (adapter + classifier parameters) relative to the
    parameter count of ``model.llm``.
    """
    def count_parameters(module):
        return sum(p.numel() for p in module.parameters())

    adapter_groups = (
        model.encoder_adapters_attn,
        model.encoder_adapters_ffn,
        model.decoder_adapters_attn,
        model.decoder_adapters_ffn,
    )
    total_params = count_parameters(model.llm)
    adapter_params = sum(count_parameters(group) for group in adapter_groups)
    classifier_params = count_parameters(model.classifier)
    trainable_params = adapter_params + classifier_params

    print(f"\n{'='*80}\nPARAMETER STATISTICS\n{'='*80}")
    print(f"Total Frozen LLM Parameters: {total_params:,}")
    print(f"Trainable Adapter Parameters: {adapter_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%\n{'='*80}\n")
    print(f"{'='*80}\nGPU MEMORY USAGE\n{'='*80}")
    # Byte counts reported by torch.cuda are divided by 1024**3 for display in GB.
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a CSV into a DataFrame and print its size and label distribution.

    Best-effort loader: any failure (unreadable file, missing label column,
    ...) is printed and reported as ``None`` instead of being raised.
    ``func1_col`` and ``func2_col`` are accepted but not used here.
    """
    try:
        frame = pd.read_csv(csv_path)
        print(f"Loaded dataset: {len(frame)} samples")
        label_counts = frame[label_col].value_counts().to_dict()
        print(f"Label distribution: {label_counts}")
    except Exception as exc:
        print(f"Error loading dataset: {str(exc)}")
        return None
    return frame


def create_training_data(csv_path, tokenizer, max_length=512, label_col='label', func1_col='func1', func2_col='func2'):
    """Turn a CSV of function pairs into tokenized classification examples.

    Each row becomes one dict with ``input_ids``, ``attention_mask`` (both
    with the tokenizer's leading batch dimension squeezed away) and
    ``labels`` (a scalar long tensor). Returns ``[]`` when the CSV could not
    be loaded.
    """
    df = load_dataset(csv_path, label_col, func1_col, func2_col)
    if df is None:
        return []

    def build_example(row):
        func1 = str(row[func1_col])
        func2 = str(row[func2_col])
        label = int(row[label_col])
        # Both functions go into a single prompt, padded/truncated to max_length.
        encoding = tokenizer(
            f"Code1: {func1}\nCode2: {func2}\nAre these code clones?",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long),
        }

    return [build_example(row) for _, row in df.iterrows()]


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Fraction of items with at least one same-label item among their top-k neighbours.

    Similarity is the dot product between embedding rows; each item is
    excluded from its own neighbour list.

    Args:
        embeddings: 2-D float tensor, one row per item.
        labels: Per-item labels; a tensor or anything ``torch.as_tensor`` accepts.
        k_values: Iterable of neighbourhood sizes. Each k is capped at
            ``len(labels) - 1``.

    Returns:
        dict mapping ``'recall@{k}'`` to a float in [0, 1] (0.0 for empty input).
    """
    labels = torch.as_tensor(labels)
    num_items = len(labels)
    similarities = torch.mm(embeddings, embeddings.t())
    # Mask self-similarity once instead of cloning every row for every k.
    similarities.fill_diagonal_(-float('inf'))
    results = {}
    for k in k_values:
        if num_items == 0:
            results[f'recall@{k}'] = 0.0
            continue
        neighbours = torch.topk(similarities, min(k, num_items - 1), dim=1).indices
        # Row i is a hit when any of its neighbours carries label i.
        hits = (labels[neighbours] == labels.unsqueeze(1)).any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / num_items
    return results


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour of each item.

    Neighbours are ranked by dot-product similarity. The query item is
    removed from its own ranking, so an item with no other same-label item
    contributes 0 (it is not allowed to "retrieve itself" at the last rank).

    Args:
        embeddings: 2-D float tensor, one row per item.
        labels: Per-item labels; a tensor or anything ``torch.as_tensor`` accepts.

    Returns:
        The MRR as a float; 0.0 for empty input.
    """
    labels = torch.as_tensor(labels)
    num_items = len(labels)
    if num_items == 0:
        return 0.0
    similarities = torch.mm(embeddings, embeddings.t())
    similarities.fill_diagonal_(-float('inf'))
    mrr_sum = 0.0
    for i in range(num_items):
        ranking = torch.argsort(similarities[i], descending=True)
        # Drop the query itself; the remaining ranks are 1..n-1.
        ranking = ranking[ranking != i]
        match_positions = (labels[ranking] == labels[i]).nonzero()
        if match_positions.numel() > 0:
            mrr_sum += 1.0 / (match_positions[0].item() + 1)
    return mrr_sum / num_items


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification metrics, plus optional probabilistic and retrieval metrics.

    Args:
        y_true: Ground-truth labels (0/1); list or array.
        y_pred: Predicted labels (0/1); list or array.
        y_proba: Optional probability of class 1 per sample; list or array.
            Probabilistic metrics are computed only when this is given and
            ``y_true`` contains more than one class; otherwise they are 0.0.
        embeddings: Optional 2-D tensor of per-sample embeddings; when given,
            ``recall@k`` and ``mrr`` entries are added.

    Returns:
        dict with the confusion matrix (``cm``), its four cells, the
        threshold metrics, ``roc_auc``/``pr_auc``/``log_loss``/``brier`` and,
        if embeddings were supplied, the retrieval metrics.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # labels=[0, 1] fixes the matrix at 2x2 even when a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    def safe_metric(name, metric_fn, *args):
        # One failing metric must not zero out the others; report it instead
        # of hiding it, and fall back to 0.0 as before.
        try:
            return metric_fn(*args)
        except (ValueError, TypeError) as exc:
            print(f"Warning: could not compute {name}: {exc}")
            return 0.0

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers pass plain lists; `1 - y_proba` below needs an ndarray.
        y_proba = np.asarray(y_proba, dtype=float)
        roc_auc = safe_metric('roc_auc', roc_auc_score, y_true, y_proba)
        pr_auc = safe_metric('pr_auc', average_precision_score, y_true, y_proba)
        logloss = safe_metric('log_loss', log_loss, y_true, np.column_stack([1 - y_proba, y_proba]))
        brier = safe_metric('brier', brier_score_loss, y_true, y_proba)

    results = {
        'cm': cm, 'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        'acc': acc, 'bal_acc': bal_acc, 'prec': prec, 'rec': rec,
        'specificity': specificity, 'npv': npv, 'fpr': fpr, 'fnr': fnr,
        'f1': f1, 'mcc': mcc, 'jacc': jacc, 'kappa': kappa,
        'roc_auc': roc_auc, 'pr_auc': pr_auc, 'log_loss': logloss, 'brier': brier,
    }
    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)
    return results


def collate_fn(batch):
    """Stack the per-sample tensors of a batch along a new leading dimension."""
    fields = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in fields
    }


def train_model(model, train_data, num_epochs=5, learning_rate=1e-4, batch_size=2):
    """Train the adapter modules and classifier head of ``model``.

    Only the parameters of ``model.encoder_adapters_attn``,
    ``model.encoder_adapters_ffn``, ``model.decoder_adapters_attn``,
    ``model.decoder_adapters_ffn`` and ``model.classifier`` are handed to the
    optimizer (AdamW, weight decay 0.01) and gradient-clipped to norm 1.0.

    Args:
        model: Module exposing the adapter containers and ``classifier``
            listed above. Calling it as
            ``model(input_ids, attention_mask=..., labels=...)`` must return a
            dict with ``'loss'`` and ``'logits'`` entries.
        train_data: Sequence of per-sample dicts accepted by ``collate_fn``.
        num_epochs: Number of passes over ``train_data``.
        learning_rate: AdamW learning rate.
        batch_size: Samples per optimisation step.

    Returns:
        A list with one dict per epoch holding ``epoch``, ``loss`` (mean over
        batches), ``acc``, ``time`` (seconds) and ``memory_mb`` (value
        reported by ``get_memory_usage()``).
    """
    adapter_params = (
        list(model.encoder_adapters_attn.parameters())
        + list(model.encoder_adapters_ffn.parameters())
        + list(model.decoder_adapters_attn.parameters())
        + list(model.decoder_adapters_ffn.parameters())
    )
    classifier_params = list(model.classifier.parameters())
    # Built once: the same list is reused for gradient clipping on every step.
    trainable_params = adapter_params + classifier_params
    optimizer = torch.optim.AdamW(
        [{'params': adapter_params}, {'params': classifier_params}],
        lr=learning_rate,
        weight_decay=0.01
    )
    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model.train()
    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                total_loss += loss.item()
                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    # Logits may live on a different device than the CPU batch labels.
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)
            batch_count += 1
            if batch_count % 10 == 0:
                torch.cuda.empty_cache()
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()
        epoch_stats.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': accuracy, 'time': epoch_time, 'memory_mb': current_memory})
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
    total_training_time = time.time() - start_time
    # ``default`` keeps the summary working when no epoch ran (num_epochs <= 0).
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    print(f"\n{'='*80}\nTRAINING SUMMARY\n{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB\n{'='*80}\n")
    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run ``model`` over ``test_data`` and return its evaluation metrics.

    The model is switched to eval mode and called without labels; its output
    dict must provide ``'logits'`` and ``'hidden_states'``. Predictions,
    class-1 softmax probabilities and L2-normalised mean-pooled hidden states
    are gathered on the CPU and passed to ``calculate_comprehensive_metrics``.
    The returned dict additionally carries ``inference_time`` (seconds) and
    ``samples_per_sec``.
    """
    model.eval()
    loader = DataLoader(
        CodeCloneDataset(test_data),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    pred_list, label_list, proba_list, embedding_chunks = [], [], [], []
    started = time.time()
    with torch.no_grad():
        for batch in loader:
            token_ids = batch['input_ids']
            mask = batch['attention_mask']
            gold = batch['labels']
            result = model(token_ids, attention_mask=mask)
            logits = result['logits']
            # Mean over dim 1 of the returned hidden states, then unit-normalise.
            sentence_vecs = F.normalize(result['hidden_states'].mean(dim=1), p=2, dim=1)
            pred_list.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            label_list.extend(gold.cpu().numpy())
            proba_list.extend(F.softmax(logits, dim=-1)[:, 1].cpu().numpy())
            embedding_chunks.append(sentence_vecs.cpu())
            if len(pred_list) % 20 == 0:
                torch.cuda.empty_cache()
    elapsed = time.time() - started
    throughput = len(test_data) / elapsed
    stacked_embeddings = torch.cat(embedding_chunks, dim=0)
    metrics = calculate_comprehensive_metrics(label_list, pred_list, proba_list, stacked_embeddings)
    metrics['inference_time'] = elapsed
    metrics['samples_per_sec'] = throughput
    return metrics


def print_results(metrics, dataset_name):
    """Print a formatted report of ``metrics`` under the heading ``dataset_name``.

    The confusion-matrix, classification and advanced sections are always
    printed and require their keys to be present. The probabilistic,
    retrieval and performance sections are printed only when their keys
    (``roc_auc``/``pr_auc``/``log_loss``/``brier``, ``recall@1``,
    ``inference_time``) exist in ``metrics``.
    """
    rule = '=' * 80

    def emit(*lines):
        # One call per section: all values are formatted before anything is written.
        print(*lines, sep='\n')

    emit(
        '', rule, f"RESULTS: {dataset_name}", rule,
        '', 'Confusion Matrix:', f"{metrics['cm']}",
        '',
        f"True Negatives (TN): {metrics['tn']}",
        f"False Positives (FP): {metrics['fp']}",
        f"False Negatives (FN): {metrics['fn']}",
        f"True Positives (TP): {metrics['tp']}",
        '', rule, 'CLASSIFICATION METRICS', rule,
    )
    emit(
        f"Accuracy: {metrics['acc']:.4f}",
        f"Balanced Accuracy: {metrics['bal_acc']:.4f}",
        f"Precision: {metrics['prec']:.4f}",
        f"Recall (Sensitivity): {metrics['rec']:.4f}",
        f"F1 Score: {metrics['f1']:.4f}",
        f"Specificity: {metrics['specificity']:.4f}",
        f"Negative Predictive Value (NPV): {metrics['npv']:.4f}",
        f"False Positive Rate (FPR): {metrics['fpr']:.4f}",
        f"False Negative Rate (FNR): {metrics['fnr']:.4f}",
    )
    emit(
        '', rule, 'ADVANCED METRICS', rule,
        f"Jaccard Score: {metrics['jacc']:.4f}",
        f"Matthews Correlation Coefficient (MCC): {metrics['mcc']:.4f}",
        f"Cohen's Kappa: {metrics['kappa']:.4f}",
    )
    if 'roc_auc' in metrics:
        emit('', rule, 'PROBABILISTIC METRICS', rule, f"ROC AUC Score: {metrics['roc_auc']:.4f}")
    if 'pr_auc' in metrics:
        emit(f"Precision-Recall AUC: {metrics['pr_auc']:.4f}")
    if 'log_loss' in metrics:
        emit(f"Log Loss: {metrics['log_loss']:.4f}")
    if 'brier' in metrics:
        emit(f"Brier Score: {metrics['brier']:.4f}")
    if 'recall@1' in metrics:
        emit(
            '', rule, 'RETRIEVAL METRICS', rule,
            f"Recall@1: {metrics['recall@1']:.4f}",
            f"Recall@3: {metrics['recall@3']:.4f}",
            f"Recall@5: {metrics['recall@5']:.4f}",
            f"Recall@10: {metrics['recall@10']:.4f}",
            f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}",
        )
    if 'inference_time' in metrics:
        emit(
            '', rule, 'PERFORMANCE METRICS', rule,
            f"Inference Time: {metrics['inference_time']:.2f}s",
            f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec",
        )
    emit(rule, '')


def main():
    """Run the adapter-tuning experiment: build the model, train, then test.

    Reads training examples from ``train.csv`` and test examples from
    ``/test.csv``; returns early when no training examples could be created.
    """
    checkpoint = "Salesforce/codet5p-770m"
    cache_dir = '/hf_cache'
    column_names = {'label_col': 'label', 'func1_col': 'func1', 'func2_col': 'func2'}
    rule = '=' * 80

    torch.cuda.empty_cache()
    gc.collect()
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for gpu_index in range(num_gpus):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    adapter_config = AdapterConfig(hidden_size=1024, reduction_factor=16, non_linearity="gelu", dropout=0.1)
    print("\nLoading model with Adapter Tuning and pipeline parallelism across GPUs...")
    model = AdapterCodeT5(checkpoint, adapter_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{rule}\nTRAINING PHASE\n{rule}\n")
    train_examples = create_training_data('train.csv', tokenizer, **column_names)
    print(f"Created {len(train_examples)} training examples\n")
    if len(train_examples) == 0:
        print("No training data available. Exiting.")
        return
    training_history = train_model(model, train_examples, num_epochs=5, learning_rate=1e-4, batch_size=2)

    print(f"\n{rule}\nTESTING PHASE\n{rule}\n")
    test_examples = create_training_data('/test.csv', tokenizer, **column_names)
    print(f"Created {len(test_examples)} test examples\n")
    if len(test_examples) > 0:
        test_metrics = evaluate_model(model, test_examples, batch_size=2)
        print_results(test_metrics, "TEST SET")
    print(f"\n{rule}\nTRAINING COMPLETE\n{rule}\n")


if __name__ == '__main__':
    main()

# LoRA

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class LoRAConfig:
    """Hyper-parameters for the LoRA layers.

    Attributes set in ``__init__``: ``r`` (rank), ``lora_alpha``,
    ``lora_dropout`` and the derived ``scaling = lora_alpha / r``.
    """

    def __init__(self, r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.1):
        # Derived multiplier applied to the low-rank update.
        self.scaling = lora_alpha / r
        self.r, self.lora_alpha, self.lora_dropout = r, lora_alpha, lora_dropout


class LoRALayer(nn.Module):
    """Low-rank update added to the output of a frozen linear projection.

    Holds ``lora_A`` with shape ``(in_features, r)`` and ``lora_B`` with shape
    ``(r, out_features)``. ``lora_B`` starts at zero, so the layer initially
    reproduces ``F.linear(x, original_weight)`` exactly.
    """

    def __init__(self, in_features: int, out_features: int, config: LoRAConfig):
        super().__init__()
        rank = config.r
        self.r = rank
        self.lora_alpha = config.lora_alpha
        self.scaling = config.scaling
        self.lora_dropout = nn.Dropout(p=config.lora_dropout)
        # NOTE(review): with this (in_features, r) layout kaiming_uniform_
        # derives fan-in from the second dimension (r) — confirm intended.
        self.lora_A = nn.Parameter(torch.empty((in_features, rank)))
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
        self.lora_B = nn.Parameter(torch.zeros((rank, out_features)))

    def forward(self, x, original_weight):
        """Return ``x @ W^T + scaling * (dropout(x) @ A @ B)``.

        ``original_weight`` is the frozen projection weight, used through
        ``F.linear`` without a bias term.
        """
        frozen_out = F.linear(x, original_weight)
        projected_down = F.linear(self.lora_dropout(x), self.lora_A.T)
        projected_up = F.linear(projected_down, self.lora_B.T)
        return frozen_out + projected_up * self.scaling


class CodeCloneDataset(Dataset):
    """Thin ``Dataset`` wrapper around an in-memory sequence of examples."""

    def __init__(self, data_list):
        # Examples are stored as given; nothing is copied or transformed.
        self.data = data_list

    def __len__(self):
        """Return the number of stored examples."""
        examples = self.data
        return len(examples)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        examples = self.data
        return examples[idx]


class LoRACodeT5(nn.Module):
    """Frozen seq2seq backbone with LoRA on self-attention q/v and a classifier head.

    The backbone is loaded through ``AutoModelForSeq2SeqLM`` with
    ``device_map='auto'`` and every backbone parameter is frozen. For each
    encoder and decoder block, the ``q`` and ``v`` projections of
    ``layer[0].SelfAttention`` get a ``LoRALayer`` registered in
    ``self.lora_layers`` (keys ``'<encoder|decoder>_<idx>_<q|v>'``), and the
    projection's ``forward`` is replaced so that the low-rank update is added
    to the frozen projection. A small MLP classifier consumes the mean over
    dim 1 of the last decoder hidden state.
    """

    def __init__(self, model_name: str, lora_config: LoRAConfig, num_classes: int = 2, cache_dir: str = None):
        """Load and freeze the backbone, inject LoRA layers and build the classifier.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            lora_config: LoRA hyper-parameters shared by every injected layer.
            num_classes: Output size of the classifier head.
            cache_dir: Optional cache directory passed to ``from_pretrained``.
        """
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        # Fall back to the CPU so the classifier can still be built without CUDA.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )
        self.llm = base_model
        self.lora_config = lora_config
        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")
        # Record where device_map placed each block (encoder blocks first).
        self.layer_devices = []
        for stack_label, blocks in (('Encoder', base_model.encoder.block), ('Decoder', base_model.decoder.block)):
            for i, block in enumerate(blocks):
                layer_device = next(block.parameters()).device
                self.layer_devices.append(str(layer_device))
                if i % 5 == 0:
                    print(f"  {stack_label} Layer {i} -> {layer_device}")
        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)
        for param in self.llm.parameters():
            param.requires_grad = False
        self.lora_layers = nn.ModuleDict()
        self._inject_lora_layers()
        self.num_classes = num_classes
        hidden_size = base_model.config.d_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_classes)
        ).to(self.primary_device).float()
        for param in self.classifier.parameters():
            param.requires_grad = True
        print("Model initialization complete!")

    @staticmethod
    def _make_lora_forward(lora_layer, frozen_weight):
        """Return a ``forward`` replacement that routes through ``lora_layer``."""
        def lora_forward(x):
            return lora_layer(x, frozen_weight)
        return lora_forward

    def _attach_lora(self, key, projection, device):
        """Register a LoRA layer for ``projection`` under ``key`` and patch its forward.

        The LoRA shapes come from the projection itself, so projections whose
        output width differs from ``d_model`` are handled as well.
        NOTE(review): the patched forward applies only ``projection.weight``;
        a bias on the projection would be ignored — presumably these
        projections are bias-free, confirm for the checkpoint in use.
        """
        adapter = LoRALayer(projection.in_features, projection.out_features, self.lora_config).to(device)
        self.lora_layers[key] = adapter
        projection.forward = self._make_lora_forward(adapter, projection.weight)

    def _inject_lora_layers(self):
        """Attach LoRA layers to the self-attention q and v projections of every block."""
        for stack_name, stack in (('encoder', self.llm.encoder), ('decoder', self.llm.decoder)):
            for layer_idx, layer in enumerate(stack.block):
                # Keep each LoRA layer on the same device as the block it adapts.
                device = next(layer.parameters()).device
                self_attn = layer.layer[0].SelfAttention
                for proj_name in ('q', 'v'):
                    self._attach_lora(
                        f'{stack_name}_{layer_idx}_{proj_name}',
                        getattr(self_attn, proj_name),
                        device
                    )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token sequences.

        ``input_ids`` is also fed to the decoder as ``decoder_input_ids``.

        Returns:
            dict with ``'logits'``, ``'loss'`` (cross-entropy, or ``None``
            when ``labels`` is ``None``) and ``'hidden_states'`` (the last
            decoder hidden state, moved to ``self.primary_device``).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True
        )
        # ``.to`` is a no-op when the tensor already lives on the target device.
        hidden_states = outputs.decoder_hidden_states[-1].to(self.primary_device)
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: LoRACodeT5):
    """Print parameter counts, LoRA settings and per-GPU memory for ``model``.

    The trainable percentage is the LoRA plus classifier parameter count
    divided by the parameter count of ``model.llm``.
    """
    def count_parameters(module):
        return sum(p.numel() for p in module.parameters())

    total_params = count_parameters(model.llm)
    lora_params = count_parameters(model.lora_layers)
    classifier_params = count_parameters(model.classifier)
    trainable_params = lora_params + classifier_params
    rule = '=' * 80
    config = model.lora_config

    print(f"\n{rule}\nPARAMETER STATISTICS\n{rule}")
    print(f"Total Frozen LLM Parameters: {total_params:,}")
    print(f"Trainable LoRA Parameters: {lora_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    print(f"LoRA Rank: {config.r}")
    print(f"LoRA Alpha: {config.lora_alpha}")
    print(f"LoRA Scaling: {config.scaling}")
    print(f"{rule}\n")

    print(f"{rule}\nGPU MEMORY USAGE\n{rule}")
    bytes_per_gib = 1024**3
    for gpu_index in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(gpu_index) / bytes_per_gib
        reserved = torch.cuda.memory_reserved(gpu_index) / bytes_per_gib
        print(f"GPU {gpu_index}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{rule}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read ``csv_path`` into a DataFrame and check the required columns.

    Args:
        csv_path: Path of the CSV file to read.
        label_col: Name of the label column.
        func1_col: Name of the column holding the first code snippet.
        func2_col: Name of the column holding the second code snippet.

    Returns:
        The loaded DataFrame, or ``None`` (after printing an error) when the
        file cannot be read or any of the three columns is missing.
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Best-effort loader: report the problem and let the caller decide.
        print(f"Error loading dataset: {str(e)}")
        return None
    # Check all three columns here so a bad schema is reported once, up front,
    # rather than as a KeyError while rows are being tokenized.
    missing = [col for col in (label_col, func1_col, func2_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing required column(s) {missing} in {csv_path}")
        return None
    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512, label_col='label', func1_col='func1', func2_col='func2'):
    """Turn each CSV row into a tokenized classification example.

    Every row becomes one prompt containing both code snippets, tokenized with
    padding and truncation to ``max_length``. Returns a list of dicts with
    ``input_ids``, ``attention_mask`` and ``labels`` tensors, or an empty list
    when ``load_dataset`` returns ``None``.
    """
    frame = load_dataset(csv_path, label_col, func1_col, func2_col)
    if frame is None:
        return []

    def encode_row(row):
        first_snippet = str(row[func1_col])
        second_snippet = str(row[func2_col])
        target = int(row[label_col])
        prompt = f"Code1: {first_snippet}\nCode2: {second_snippet}\nAre these code clones?"
        tokens = tokenizer(
            prompt,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length
        )
        # The tokenizer output has a leading batch axis of size 1; drop it.
        return {
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'labels': torch.tensor(target, dtype=torch.long)
        }

    return [encode_row(row) for _, row in frame.iterrows()]


def get_memory_usage():
    """Return the resident set size of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Compute label-retrieval recall@k over the rows of ``embeddings``.

    Each row is used as a query against all other rows, ranked by dot-product
    similarity. A query counts as a hit for ``k`` when at least one of its
    ``k`` most similar rows (the query itself excluded) carries the same label.

    Args:
        embeddings: 2-D tensor with one row per sample.
        labels: 1-D tensor (or sequence) with one label per sample.
        k_values: Iterable of cut-offs; each is clamped to ``n - 1``.

    Returns:
        dict mapping ``'recall@<k>'`` to the fraction of queries with a hit;
        ``0.0`` for every ``k`` when there are no samples or no other row to
        retrieve.
    """
    labels = torch.as_tensor(labels)
    n = len(labels)
    if n == 0:
        return {f'recall@{k}': 0.0 for k in k_values}
    similarities = torch.mm(embeddings, embeddings.t())
    # A sample must never retrieve itself.
    similarities.fill_diagonal_(-float('inf'))
    labels = labels.to(similarities.device)
    results = {}
    for k in k_values:
        effective_k = min(k, n - 1)
        if effective_k < 1:
            # Nothing can be retrieved (single sample or non-positive k).
            results[f'recall@{k}'] = 0.0
            continue
        neighbor_idx = torch.topk(similarities, effective_k, dim=1).indices
        hits = (labels[neighbor_idx] == labels.unsqueeze(1)).any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / n
    return results


def calculate_mrr(embeddings, labels):
    """Compute the mean reciprocal rank of the first same-label neighbour.

    Each row of ``embeddings`` is used as a query; all other rows are ranked
    by descending dot-product similarity, and the query scores
    ``1 / rank`` of the first row sharing its label. The query itself is
    excluded from the ranking, so a query with no other same-label sample
    scores ``0``.

    Args:
        embeddings: 2-D tensor with one row per sample.
        labels: 1-D tensor (or sequence) with one label per sample.

    Returns:
        The mean of the per-query reciprocal ranks as a float; ``0.0`` when
        there are no samples.
    """
    labels = torch.as_tensor(labels)
    n = len(labels)
    if n == 0:
        return 0.0
    similarities = torch.mm(embeddings, embeddings.t())
    similarities.fill_diagonal_(-float('inf'))
    labels = labels.to(similarities.device)
    order = torch.argsort(similarities, dim=1, descending=True)
    query_ids = torch.arange(n, device=similarities.device).unsqueeze(1)
    # Relevant = same label as the query and not the query itself.
    relevant = ((labels[order] == labels.unsqueeze(1)) & (order != query_ids)).cpu()
    reciprocal_ranks = 1.0 / torch.arange(1, n + 1, dtype=torch.float64)
    # The largest reciprocal rank among relevant positions belongs to the
    # first relevant hit; rows without any relevant hit contribute 0.
    best = (relevant.to(torch.float64) * reciprocal_ranks).max(dim=1).values
    return best.sum().item() / n


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification, probabilistic and retrieval metrics.

    Args:
        y_true: Sequence of ground-truth labels (0 or 1).
        y_pred: Sequence of predicted labels (0 or 1).
        y_proba: Optional sequence of predicted probabilities for class 1.
        embeddings: Optional 2-D tensor with one row per sample; when given,
            ``recall@k`` and ``mrr`` entries are added.

    Returns:
        dict with the confusion matrix (``cm``) and its four cells, the
        classification scores, and ``roc_auc``/``pr_auc``/``log_loss``/
        ``brier``. The last four are ``0.0`` when ``y_proba`` is ``None``,
        when ``y_true`` contains a single class, or when their computation
        fails.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # labels=[0, 1] guarantees a 2x2 matrix even if a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)
    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers may pass a plain list; ``1 - proba`` below needs an array.
        proba = np.asarray(y_proba, dtype=np.float64)
        try:
            roc_auc = roc_auc_score(y_true, proba)
            pr_auc = average_precision_score(y_true, proba)
            logloss = log_loss(y_true, np.column_stack([1 - proba, proba]))
            brier = brier_score_loss(y_true, proba)
        except (ValueError, TypeError) as exc:
            # Keep the report usable, but say why the scores are missing.
            print(f"Warning: probabilistic metrics could not be computed ({exc}); reporting 0.0")
            roc_auc = pr_auc = logloss = brier = 0.0
    results = {
        'cm': cm, 'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        'acc': acc, 'bal_acc': bal_acc, 'prec': prec, 'rec': rec,
        'specificity': specificity, 'npv': npv, 'fpr': fpr, 'fnr': fnr,
        'f1': f1, 'mcc': mcc, 'jacc': jacc, 'kappa': kappa,
        'roc_auc': roc_auc, 'pr_auc': pr_auc, 'log_loss': logloss, 'brier': brier
    }
    if embeddings is not None:
        label_tensor = torch.as_tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)
    return results


def collate_fn(batch):
    """Batch a list of example dicts by stacking their tensors key by key.

    Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``, each
    produced by ``torch.stack`` over the corresponding per-example tensors.
    """
    def stack_field(field):
        return torch.stack([example[field] for example in batch])

    collated = {}
    collated['input_ids'] = stack_field('input_ids')
    collated['attention_mask'] = stack_field('attention_mask')
    collated['labels'] = stack_field('labels')
    return collated


def train_model(model, train_data, num_epochs=15, learning_rate=1e-4, batch_size=2):
    """Train the LoRA layers and classifier head of ``model`` on ``train_data``.

    Only the parameters of ``model.lora_layers`` and ``model.classifier`` are
    handed to the optimizer (AdamW, weight decay 0.01) and gradient-clipped to
    norm 1.0.

    Args:
        model: Module exposing ``lora_layers`` and ``classifier``. Calling it
            as ``model(input_ids, attention_mask=..., labels=...)`` must
            return a dict with ``'loss'`` and ``'logits'`` entries.
        train_data: Sequence of per-sample dicts accepted by ``collate_fn``.
        num_epochs: Number of passes over ``train_data``.
        learning_rate: AdamW learning rate.
        batch_size: Samples per optimisation step.

    Returns:
        A list with one dict per epoch holding ``epoch``, ``loss`` (mean over
        batches), ``acc``, ``time`` (seconds) and ``memory_mb`` (process
        memory reported by ``get_memory_usage()``).
    """
    lora_params = list(model.lora_layers.parameters())
    classifier_params = list(model.classifier.parameters())
    # Built once: the same list is reused for gradient clipping on every step.
    trainable_params = lora_params + classifier_params
    optimizer = torch.optim.AdamW(
        [{'params': lora_params}, {'params': classifier_params}],
        lr=learning_rate,
        weight_decay=0.01
    )
    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model.train()
    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                total_loss += loss.item()
                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    # Logits may live on a different device than the CPU batch labels.
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)
            batch_count += 1
            if batch_count % 10 == 0:
                torch.cuda.empty_cache()
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()
        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'time': epoch_time,
            'memory_mb': current_memory
        })
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
    total_training_time = time.time() - start_time
    # ``default`` keeps the summary working when no epoch ran (num_epochs <= 0).
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    print(f"\n{'='*80}\nTRAINING SUMMARY\n{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB\n{'='*80}\n")
    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Evaluate ``model`` on ``test_data`` and return a metrics dict.

    Puts the model in eval mode, runs it batch by batch without labels or
    gradients, and collects predictions, class-1 softmax probabilities and
    L2-normalised mean-pooled hidden states on the CPU. The result of
    ``calculate_comprehensive_metrics`` is extended with ``inference_time``
    (seconds) and ``samples_per_sec``.
    """
    model.eval()
    collected = {'preds': [], 'labels': [], 'probas': [], 'embeddings': []}
    eval_loader = DataLoader(
        CodeCloneDataset(test_data), batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )
    tick = time.time()
    with torch.no_grad():
        for batch_data in eval_loader:
            ids = batch_data['input_ids']
            mask = batch_data['attention_mask']
            targets = batch_data['labels']
            outputs = model(ids, attention_mask=mask)
            logits, states = outputs['logits'], outputs['hidden_states']
            hard_preds = torch.argmax(logits, dim=-1)
            soft_preds = F.softmax(logits, dim=-1)
            pooled = F.normalize(states.mean(dim=1), p=2, dim=1)
            collected['preds'].extend(hard_preds.cpu().numpy())
            collected['labels'].extend(targets.cpu().numpy())
            # Column 1 holds the probability assigned to class 1.
            collected['probas'].extend(soft_preds[:, 1].cpu().numpy())
            collected['embeddings'].append(pooled.cpu())
            if len(collected['preds']) % 20 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - tick
    samples_per_sec = len(test_data) / inference_time
    embedding_matrix = torch.cat(collected['embeddings'], dim=0)
    metrics = calculate_comprehensive_metrics(
        collected['labels'], collected['preds'], collected['probas'], embedding_matrix
    )
    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec
    return metrics


def print_results(metrics, dataset_name):
    """Write a sectioned, human-readable report of ``metrics`` to stdout.

    ``dataset_name`` appears in the report heading. The confusion-matrix,
    classification and advanced sections need their keys to be present; the
    probabilistic, retrieval and performance sections are emitted only when
    their keys exist in ``metrics``.
    """
    bar = '=' * 80

    def header(title):
        return f"\n{bar}\n{title}\n{bar}"

    def rows(pairs, spec='.4f'):
        return '\n'.join(f"{label}: {format(metrics[key], spec)}" for label, key in pairs)

    counts = [
        ('True Negatives (TN)', 'tn'),
        ('False Positives (FP)', 'fp'),
        ('False Negatives (FN)', 'fn'),
        ('True Positives (TP)', 'tp'),
    ]
    classification = [
        ('Accuracy', 'acc'),
        ('Balanced Accuracy', 'bal_acc'),
        ('Precision', 'prec'),
        ('Recall (Sensitivity)', 'rec'),
        ('F1 Score', 'f1'),
        ('Specificity', 'specificity'),
        ('Negative Predictive Value (NPV)', 'npv'),
        ('False Positive Rate (FPR)', 'fpr'),
        ('False Negative Rate (FNR)', 'fnr'),
    ]
    advanced = [
        ('Jaccard Score', 'jacc'),
        ('Matthews Correlation Coefficient (MCC)', 'mcc'),
        ("Cohen's Kappa", 'kappa'),
    ]
    retrieval = [
        ('Recall@1', 'recall@1'),
        ('Recall@3', 'recall@3'),
        ('Recall@5', 'recall@5'),
        ('Recall@10', 'recall@10'),
        ('Mean Reciprocal Rank (MRR)', 'mrr'),
    ]

    print(
        header(f"RESULTS: {dataset_name}")
        + f"\n\nConfusion Matrix:\n{metrics['cm']}\n\n"
        + rows(counts, spec='')
        + '\n'
        + header('CLASSIFICATION METRICS')
    )
    print(rows(classification))
    print(header('ADVANCED METRICS') + '\n' + rows(advanced))
    if 'roc_auc' in metrics:
        print(header('PROBABILISTIC METRICS') + '\n' + rows([('ROC AUC Score', 'roc_auc')]))
    # The remaining probabilistic scores are gated independently by their own key.
    for label, key in (('Precision-Recall AUC', 'pr_auc'), ('Log Loss', 'log_loss'), ('Brier Score', 'brier')):
        if key in metrics:
            print(rows([(label, key)]))
    if 'recall@1' in metrics:
        print(header('RETRIEVAL METRICS') + '\n' + rows(retrieval))
    if 'inference_time' in metrics:
        print(
            header('PERFORMANCE METRICS')
            + f"\nInference Time: {metrics['inference_time']:.2f}s"
            + f"\nThroughput: {metrics['samples_per_sec']:.2f} samples/sec"
        )
    print(bar + '\n')


def main():
    """Run the LoRA experiment end to end: build the model, train, then test.

    Reads training examples from ``/train.csv`` and test examples from
    ``/test.csv``; returns early when no training examples could be created.
    """
    checkpoint = "Salesforce/codet5p-770m"
    cache_dir = '/hf_cache'
    column_names = {'label_col': 'label', 'func1_col': 'func1', 'func2_col': 'func2'}
    rule = '=' * 80

    torch.cuda.empty_cache()
    gc.collect()
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for gpu_index in range(num_gpus):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoRAConfig(r=8, lora_alpha=16, lora_dropout=0.1)
    print("\nLoading model with LoRA (rank=8) and pipeline parallelism across GPUs...")
    model = LoRACodeT5(checkpoint, lora_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{rule}\nTRAINING PHASE\n{rule}\n")
    train_examples = create_training_data('/train.csv', tokenizer, **column_names)
    print(f"Created {len(train_examples)} training examples\n")
    if len(train_examples) == 0:
        print("No training data available. Exiting.")
        return
    training_history = train_model(model, train_examples, num_epochs=5, learning_rate=1e-4, batch_size=2)

    print(f"\n{rule}\nTESTING PHASE\n{rule}\n")
    test_examples = create_training_data('/test.csv', tokenizer, **column_names)
    print(f"Created {len(test_examples)} test examples\n")
    if len(test_examples) > 0:
        test_metrics = evaluate_model(model, test_examples, batch_size=2)
        print_results(test_metrics, "TEST SET")
    print(f"\n{rule}\nTRAINING COMPLETE\n{rule}\n")


if __name__ == '__main__':
    main()

# bitfit

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
from torch.utils.data import Dataset, DataLoader


class BitFitConfig:
    """Configuration holder for BitFit; stores the chosen ``bias_type``."""

    def __init__(self, bias_type: str = 'all'):
        # Stored verbatim; no validation happens here.
        setattr(self, 'bias_type', bias_type)


class CodeCloneDataset(Dataset):
    """Map-style dataset backed by an already-built sequence of examples."""

    def __init__(self, data_list):
        # The sequence is kept by reference, exactly as passed in.
        self.data = data_list

    def __len__(self):
        """Number of examples held by the dataset."""
        stored = self.data
        return len(stored)

    def __getitem__(self, idx):
        """Look up the example at ``idx`` in the underlying sequence."""
        stored = self.data
        return stored[idx]


class BitFitCodeT5(nn.Module):
    """Seq2seq backbone plus MLP head, fine-tuned BitFit-style.

    The backbone is loaded with ``device_map='auto'``, all of its parameters
    are frozen, and the parameters whose name contains ``bias`` are re-enabled
    according to ``bitfit_config.bias_type``. A small classifier head is
    applied to the mean-pooled last decoder hidden state.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        bitfit_config: Object exposing a ``bias_type`` attribute.
        num_classes: Output size of the classifier head.
        cache_dir: Optional cache directory forwarded to ``from_pretrained``.
    """

    def __init__(self, model_name: str, bitfit_config: BitFitConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        # The head lives on the first GPU when one exists; fall back to the CPU
        # so that construction does not crash on CUDA-less machines.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
            device_map='auto'
        )
        self.llm = base_model
        self.bitfit_config = bitfit_config
        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")
        # Record (and sample-print) the device each block was placed on.
        self.layer_devices = []
        for i in range(encoder_layers):
            layer_device = next(base_model.encoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Encoder Layer {i} -> {layer_device}")
        for i in range(decoder_layers):
            layer_device = next(base_model.decoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Decoder Layer {i} -> {layer_device}")
        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)
        # Freeze the whole backbone first, then selectively re-enable biases.
        for param in self.llm.parameters():
            param.requires_grad = False
        self._enable_bias_training()
        self.num_classes = num_classes
        hidden_size = base_model.config.d_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_classes)
        ).to(self.primary_device).float()
        for param in self.classifier.parameters():
            param.requires_grad = True
        print("Model initialization complete!")

    def _enable_bias_training(self):
        """Unfreeze backbone parameters selected by ``bitfit_config.bias_type``.

        ``'all'`` enables every parameter whose lower-cased name contains
        ``bias``; ``'query_mlp'`` additionally requires one of ``q.bias``,
        ``query.bias`` or ``dense.bias`` in the name. Any other value enables
        nothing, which is reported with a warning.
        """
        bias_type = self.bitfit_config.bias_type
        trainable_bias_count = 0
        for name, param in self.llm.named_parameters():
            lowered = name.lower()
            # NOTE(review): selection is a substring match on the parameter
            # name, so anything with "bias" in its name qualifies -- confirm
            # this picks the intended tensors for the chosen checkpoint.
            if 'bias' not in lowered:
                continue
            if bias_type == 'all':
                selected = True
            elif bias_type == 'query_mlp':
                selected = any(x in lowered for x in ['q.bias', 'query.bias', 'dense.bias'])
            else:
                selected = False
            if selected:
                param.requires_grad = True
                trainable_bias_count += param.numel()
        if trainable_bias_count == 0:
            # Make the "nothing matched" case visible instead of silently
            # training only the classifier head.
            print(
                f"Warning: no bias parameters were enabled for "
                f"bias_type={bias_type!r}; only the classifier head will train"
            )
        print(f"BitFit enabled with {trainable_bias_count:,} trainable bias parameters")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token ids.

        Args:
            input_ids: Token ids; also reused as ``decoder_input_ids``.
            attention_mask: Optional encoder attention mask.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed.

        Returns:
            Dict with ``'logits'``, ``'loss'`` (``None`` without labels) and
            ``'hidden_states'`` (last decoder hidden state on the primary
            device).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
        outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = outputs.decoder_hidden_states[-1]
        if str(hidden_states.device) != self.primary_device:
            hidden_states = hidden_states.to(self.primary_device)
        # NOTE(review): the mean runs over every position along dim 1;
        # attention_mask is not applied to the pooling.
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if labels.device != logits.device:
                labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}


def print_parameter_statistics(model: BitFitCodeT5):
    """Print backbone/head parameter counts followed by per-GPU memory usage.

    Args:
        model: Wrapper exposing ``llm``, ``classifier`` and ``bitfit_config``.
    """
    backbone = list(model.llm.parameters())
    total_params = sum(p.numel() for p in backbone)
    # Backbone parameters left trainable (the BitFit-selected ones).
    bias_params = sum(p.numel() for p in backbone if p.requires_grad)
    classifier_params = sum(p.numel() for p in model.classifier.parameters())
    trainable_params = bias_params + classifier_params

    print(f"\n{'='*80}\nPARAMETER STATISTICS\n{'='*80}")
    print(f"Total Frozen LLM Parameters: {total_params:,}")
    print(f"Trainable Bias Parameters: {bias_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    print(f"BitFit Bias Type: {model.bitfit_config.bias_type}")
    print(f"{'='*80}\n")

    print(f"{'='*80}\nGPU MEMORY USAGE\n{'='*80}")
    for i in range(torch.cuda.device_count()):
        # Convert byte counts to GiB for display.
        allocated, reserved = (
            query(i) / 1024**3
            for query in (torch.cuda.memory_allocated, torch.cuda.memory_reserved)
        )
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a CSV into a DataFrame and verify the columns callers will need.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``.
        label_col: Name of the label column.
        func1_col: Name of the first code-snippet column.
        func2_col: Name of the second code-snippet column.

    Returns:
        The loaded DataFrame, or ``None`` (after printing the reason) when the
        file cannot be read or one of the three columns is absent.
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Best-effort loader: report and let the caller handle ``None``.
        print(f"Error loading dataset: {str(e)}")
        return None
    # Check all three columns up front so a bad header is reported here rather
    # than as a KeyError while rows are being processed later.
    missing = [col for col in (label_col, func1_col, func2_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing required column(s) {missing}")
        return None
    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512, label_col='label', func1_col='func1', func2_col='func2'):
    """Turn a CSV of code pairs into a list of tokenized example dicts.

    Args:
        csv_path: CSV location handed to ``load_dataset``.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        max_length: Padding/truncation length passed to the tokenizer.
        label_col, func1_col, func2_col: Column names to read from each row.

    Returns:
        A list of dicts with ``'input_ids'``, ``'attention_mask'`` and
        ``'labels'``; empty when the dataset could not be loaded.
    """
    frame = load_dataset(csv_path, label_col, func1_col, func2_col)
    if frame is None:
        return []

    def build_example(row):
        # Values are stringified/cast before tokenizing, in this order.
        func1, func2 = str(row[func1_col]), str(row[func2_col])
        label = int(row[label_col])
        combined_text = f"Code1: {func1}\nCode2: {func2}\nAre these code clones?"
        encoded = tokenizer(
            combined_text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        return {
            # The tokenizer output carries a leading batch axis of size one.
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long),
        }

    return [build_example(row) for _, row in frame.iterrows()]


def get_memory_usage():
    """Return the current process's resident set size in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Compute label-based recall@k over pairwise dot-product similarities.

    For every row, the ``k`` most similar *other* rows are retrieved; the row
    counts as a hit when at least one of them carries the same label.

    Args:
        embeddings: 2-D float tensor, one row per sample.
        labels: 1-D tensor (or array-like) of labels aligned with the rows.
        k_values: Cut-offs to evaluate; each is capped at ``len(labels) - 1``.

    Returns:
        Dict mapping ``'recall@<k>'`` to the fraction of rows with a hit.
        Every value is ``0.0`` when fewer than two samples are given.
    """
    labels = torch.as_tensor(labels)
    k_values = list(k_values)
    n = len(labels)
    if n < 2 or not k_values:
        # No neighbours exist to retrieve.
        return {f'recall@{k}': 0.0 for k in k_values}

    similarities = torch.mm(embeddings, embeddings.t())
    # Exclude each sample from its own neighbour list.
    self_mask = torch.eye(n, dtype=torch.bool, device=similarities.device)
    similarities = similarities.masked_fill(self_mask, -float('inf'))

    # Rank once at the widest cut-off; smaller k values reuse a prefix.
    widest = min(max(k_values), n - 1)
    neighbour_idx = torch.topk(similarities, widest, dim=1).indices
    labels = labels.to(neighbour_idx.device)
    hits = labels[neighbour_idx] == labels.unsqueeze(1)

    results = {}
    for k in k_values:
        top = min(k, n - 1)
        found = hits[:, :top].any(dim=1).sum().item()
        results[f'recall@{k}'] = found / n
    return results


def calculate_mrr(embeddings, labels):
    """Compute the mean reciprocal rank of the first same-label neighbour.

    Rows are ranked by dot-product similarity to each query row. The query
    itself is never counted as a match, so a sample whose label is unique
    contributes a reciprocal rank of 0.

    Args:
        embeddings: 2-D float tensor, one row per sample.
        labels: 1-D tensor (or array-like) of labels aligned with the rows.

    Returns:
        Mean reciprocal rank as a Python float; ``0.0`` when fewer than two
        samples are given.
    """
    labels = torch.as_tensor(labels)
    n = len(labels)
    if n < 2:
        return 0.0

    similarities = torch.mm(embeddings, embeddings.t())
    self_mask = torch.eye(n, dtype=torch.bool, device=similarities.device)
    similarities = similarities.masked_fill(self_mask, -float('inf'))
    order = torch.argsort(similarities, dim=1, descending=True)

    labels = labels.to(order.device)
    query_idx = torch.arange(n, device=order.device).unsqueeze(1)
    # A neighbour is relevant when it shares the label and is not the query.
    relevant = (labels[order] == labels.unsqueeze(1)) & (order != query_idx)
    has_hit = relevant.any(dim=1)
    # argmax returns the first maximal index, i.e. the best-ranked match.
    first_rank = relevant.int().argmax(dim=1) + 1
    reciprocal = torch.where(
        has_hit,
        1.0 / first_rank.double(),
        torch.zeros(n, dtype=torch.double, device=order.device),
    )
    return reciprocal.sum().item() / n


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification, probabilistic and retrieval metrics.

    Args:
        y_true: Ground-truth labels (0/1), list or array.
        y_pred: Predicted labels (0/1), list or array.
        y_proba: Optional positive-class probabilities, list or array.
        embeddings: Optional 2-D tensor of sample embeddings; when given,
            recall@k and MRR are added to the result.

    Returns:
        Dict of metric name to value. The probabilistic metrics are ``0.0``
        when ``y_proba`` is missing, ``y_true`` holds a single class, or their
        computation fails (a warning is printed in the last case).
    """
    def ratio(numerator, denominator):
        # Guarded division used for every confusion-matrix-derived rate.
        return numerator / denominator if denominator > 0 else 0.0

    # labels=[0, 1] guarantees a 2x2 matrix, so it can be unpacked directly.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = ratio(tp, tp + fp)
    rec = ratio(tp, tp + fn)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = ratio(tn, tn + fp)
    npv = ratio(tn, tn + fn)
    fpr = ratio(fp, fp + tn)
    fnr = ratio(fn, fn + tp)
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers may pass a plain list; ``1 - list`` is a TypeError, so
        # convert to an array before doing arithmetic on the probabilities.
        proba = np.asarray(y_proba, dtype=np.float64)
        try:
            roc_auc = roc_auc_score(y_true, proba)
            pr_auc = average_precision_score(y_true, proba)
            logloss = log_loss(y_true, np.column_stack([1.0 - proba, proba]))
            brier = brier_score_loss(y_true, proba)
        except Exception as exc:
            # Stay best-effort, but say why the values are zero.
            print(f"Warning: probabilistic metrics could not be computed: {exc}")
            roc_auc = pr_auc = logloss = brier = 0.0

    results = {
        'cm': cm, 'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        'acc': acc, 'bal_acc': bal_acc, 'prec': prec, 'rec': rec,
        'specificity': specificity, 'npv': npv, 'fpr': fpr, 'fnr': fnr,
        'f1': f1, 'mcc': mcc, 'jacc': jacc, 'kappa': kappa,
        'roc_auc': roc_auc, 'pr_auc': pr_auc, 'log_loss': logloss, 'brier': brier
    }
    if embeddings is not None:
        label_tensor = torch.tensor(y_true)
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)
    return results


def collate_fn(batch):
    """Stack per-example tensors into batch tensors, field by field.

    Args:
        batch: Sequence of dicts holding ``'input_ids'``, ``'attention_mask'``
            and ``'labels'`` tensors.

    Returns:
        Dict with the same three keys, each stacked along a new leading axis.
    """
    fields = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([example[field] for example in batch])
        for field in fields
    }


def train_model(model, train_data, num_epochs=5, learning_rate=1e-3, batch_size=2):
    """Train the unfrozen backbone parameters and the classifier head.

    Args:
        model: Wrapper exposing ``llm`` and ``classifier`` and returning a
            dict with ``'loss'`` and ``'logits'`` from its forward pass.
        train_data: List of example dicts consumable by ``collate_fn``.
        num_epochs: Number of passes over the data.
        learning_rate: AdamW learning rate for both parameter groups.
        batch_size: Mini-batch size.

    Returns:
        A list with one dict per epoch (``epoch``, ``loss``, ``acc``, ``time``,
        ``memory_mb``); empty when ``num_epochs`` is 0.
    """
    bias_params = [p for p in model.llm.parameters() if p.requires_grad]
    classifier_params = list(model.classifier.parameters())
    optimizer = torch.optim.AdamW(
        [{'params': bias_params}, {'params': classifier_params}],
        lr=learning_rate,
        weight_decay=0.01
    )
    # Built once; the set of clipped parameters does not change per step.
    clip_params = bias_params + classifier_params
    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model.train()
    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
                optimizer.step()
                total_loss += loss.item()
                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)
            batch_count += 1
            if batch_count % 10 == 0:
                # Periodically release cached GPU blocks.
                torch.cuda.empty_cache()
        epoch_time = time.time() - epoch_start
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        current_memory = get_memory_usage()
        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'time': epoch_time,
            'memory_mb': current_memory
        })
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
    total_training_time = time.time() - start_time
    # ``default`` keeps the summary working when no epoch was run.
    peak_memory = max((stat['memory_mb'] for stat in epoch_stats), default=initial_memory)
    memory_increase = peak_memory - initial_memory
    print(f"\n{'='*80}\nTRAINING SUMMARY\n{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB\n{'='*80}\n")
    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run inference over ``test_data`` and compute the full metric set.

    Args:
        model: Wrapper whose forward pass returns ``'logits'`` and
            ``'hidden_states'``.
        test_data: Non-empty list of example dicts consumable by
            ``collate_fn``.
        batch_size: Mini-batch size for inference.

    Returns:
        The dict from ``calculate_comprehensive_metrics`` extended with
        ``'inference_time'`` and ``'samples_per_sec'``.

    Raises:
        ValueError: If ``test_data`` is empty.
    """
    if len(test_data) == 0:
        # Fail with a clear message instead of an opaque torch.cat error.
        raise ValueError("evaluate_model requires at least one test example")
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []
    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    start_time = time.time()
    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits']
            hidden_states = outputs['hidden_states']
            predictions = torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)
            # L2-normalised mean of the hidden states, used for retrieval metrics.
            pooled_embeddings = hidden_states.mean(dim=1)
            pooled_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())
            all_embeddings.append(pooled_embeddings.cpu())
            if len(all_predictions) % 20 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    samples_per_sec = len(test_data) / inference_time if inference_time > 0 else 0.0
    all_embeddings = torch.cat(all_embeddings, dim=0)
    # Hand arrays (not lists) to the metric code: it does arithmetic such as
    # ``1 - y_proba``, which fails on a plain list.
    metrics = calculate_comprehensive_metrics(
        np.asarray(all_labels),
        np.asarray(all_predictions),
        np.asarray(all_probabilities, dtype=np.float64),
        all_embeddings,
    )
    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec
    return metrics


def print_results(metrics, dataset_name):
    """Pretty-print an evaluation report.

    Args:
        metrics: Dict of metric values. The confusion-matrix, classification
            and advanced entries are required; the probabilistic, retrieval
            and timing sections are printed only when their keys are present.
        dataset_name: Label shown in the report header.
    """
    bar = '=' * 80

    def show(*lines):
        # One print call per section, lines separated by newlines.
        print("\n".join(lines))

    def scored(label, key, digits=4):
        return f"{label}: {metrics[key]:.{digits}f}"

    show(
        "", bar, f"RESULTS: {dataset_name}", bar, "",
        "Confusion Matrix:", f"{metrics['cm']}", "",
        f"True Negatives (TN): {metrics['tn']}",
        f"False Positives (FP): {metrics['fp']}",
        f"False Negatives (FN): {metrics['fn']}",
        f"True Positives (TP): {metrics['tp']}",
        "", bar, "CLASSIFICATION METRICS", bar,
    )
    show(
        scored("Accuracy", 'acc'),
        scored("Balanced Accuracy", 'bal_acc'),
        scored("Precision", 'prec'),
        scored("Recall (Sensitivity)", 'rec'),
        scored("F1 Score", 'f1'),
        scored("Specificity", 'specificity'),
        scored("Negative Predictive Value (NPV)", 'npv'),
        scored("False Positive Rate (FPR)", 'fpr'),
        scored("False Negative Rate (FNR)", 'fnr'),
    )
    show(
        "", bar, "ADVANCED METRICS", bar,
        scored("Jaccard Score", 'jacc'),
        scored("Matthews Correlation Coefficient (MCC)", 'mcc'),
        scored("Cohen's Kappa", 'kappa'),
    )
    if 'roc_auc' in metrics:
        show("", bar, "PROBABILISTIC METRICS", bar, scored("ROC AUC Score", 'roc_auc'))
    for label, key in (
        ("Precision-Recall AUC", 'pr_auc'),
        ("Log Loss", 'log_loss'),
        ("Brier Score", 'brier'),
    ):
        if key in metrics:
            show(scored(label, key))
    if 'recall@1' in metrics:
        show(
            "", bar, "RETRIEVAL METRICS", bar,
            scored("Recall@1", 'recall@1'),
            scored("Recall@3", 'recall@3'),
            scored("Recall@5", 'recall@5'),
            scored("Recall@10", 'recall@10'),
            scored("Mean Reciprocal Rank (MRR)", 'mrr'),
        )
    if 'inference_time' in metrics:
        show(
            "", bar, "PERFORMANCE METRICS", bar,
            f"Inference Time: {metrics['inference_time']:.2f}s",
            f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec",
        )
    print(f"{bar}\n")


def main():
    """Run the BitFit experiment end to end: load, train, then evaluate."""
    torch.cuda.empty_cache()
    gc.collect()

    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    cache_dir = '/hf_cache'
    checkpoint = "Salesforce/codet5p-770m"
    column_names = dict(label_col='label', func1_col='func1', func2_col='func2')

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    bitfit_config = BitFitConfig(bias_type='all')
    print("\nLoading model with BitFit (bias-only fine-tuning) and pipeline parallelism across GPUs...")
    model = BitFitCodeT5(checkpoint, bitfit_config, num_classes=2, cache_dir=cache_dir)
    print_parameter_statistics(model)

    print(f"\n{'='*80}\nTRAINING PHASE\n{'='*80}\n")
    train_data = create_training_data('/train.csv', tokenizer, **column_names)
    print(f"Created {len(train_data)} training examples\n")
    if not train_data:
        print("No training data available. Exiting.")
        return
    train_model(model, train_data, num_epochs=5, learning_rate=1e-3, batch_size=8)

    print(f"\n{'='*80}\nTESTING PHASE\n{'='*80}\n")
    test_data = create_training_data('/test.csv', tokenizer, **column_names)
    print(f"Created {len(test_data)} test examples\n")
    if test_data:
        test_metrics = evaluate_model(model, test_data, batch_size=8)
        print_results(test_metrics, "TEST SET")
    print(f"\n{'='*80}\nTRAINING COMPLETE\n{'='*80}\n")


# Script entry point: call main() only when this code is executed directly
# (its __name__ is '__main__'), not when it is imported.
if __name__ == '__main__':
    main()

# TS-PEFT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score, confusion_matrix, balanced_accuracy_score, 
    matthews_corrcoef, roc_auc_score, average_precision_score, 
    f1_score, jaccard_score, cohen_kappa_score, log_loss, brier_score_loss
)
import time
import psutil
import os
import gc
import math
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dataclasses import dataclass


@dataclass
class TSPEFTConfig:
    """Hyper-parameters for the TS-PEFT adapters.

    ``rank``, ``alpha`` and ``dropout`` configure the low-rank branch; ``s``,
    ``lambda_reg``, ``beta1``, ``beta2`` and ``eps`` are handed to each
    ``TSPEFTLayer`` for its threshold update. ``target_modules`` lists the name
    fragments used to pick which linear layers are wrapped.
    """

    rank: int = 16
    alpha: float = 0.5
    dropout: float = 0.05
    s: float = 4e-5
    lambda_reg: float = 4.5e-5
    target_modules: List[str] = None
    beta1: float = 0.9
    beta2: float = 0.98
    eps: float = 1e-8

    def __post_init__(self):
        # ``None`` stands for "use the default list"; a fresh list is created
        # per instance so configurations never share it.
        if self.target_modules is not None:
            return
        self.target_modules = ["q", "k", "v", "o", "wi", "wo"]


class LoRALayer(nn.Module):
    """Low-rank additive branch: ``dropout(x) @ A^T @ B^T`` scaled by ``alpha / rank``.

    ``lora_A`` has shape ``(rank, in_features)`` and is Kaiming-uniform
    initialised; ``lora_B`` has shape ``(out_features, rank)`` and starts at
    zero, so a fresh layer outputs zeros.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 4, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        down_proj = torch.zeros(rank, in_features)
        nn.init.kaiming_uniform_(down_proj, a=math.sqrt(5))
        self.lora_A = nn.Parameter(down_proj)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # Dropout is skipped entirely (identity) for a non-positive rate.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``."""
        projected_down = torch.matmul(self.dropout(x), self.lora_A.t())
        projected_up = torch.matmul(projected_down, self.lora_B.t())
        return projected_up * self.scaling


class TSPEFTLayer(nn.Module):
    """Frozen linear layer plus a LoRA branch gated by a learned threshold.

    For every position the ratio ``r = ||lora(x)|| / (||base(x)|| + eps)`` is
    computed over the last axis; the LoRA output is added only where
    ``r >= tau``. ``tau`` is a buffer updated outside autograd by
    ``update_threshold`` with an Adam-style moment estimate.

    Args:
        base_layer: Linear layer to wrap; its parameters are frozen.
        rank, alpha, dropout: Forwarded to ``LoRALayer``.
        s: Scale applied to the threshold gradient and to the threshold step.
        lambda_reg: Weight of the sparsity term in the threshold gradient.
        beta1, beta2: Moment decay rates for the threshold update.
        eps: Small constant guarding the divisions.
    """

    def __init__(self, base_layer: nn.Linear, rank: int = 4, alpha: float = 1.0, dropout: float = 0.0,
                 s: float = 4e-5, lambda_reg: float = 1e-5, beta1: float = 0.9, beta2: float = 0.98, eps: float = 1e-8):
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.requires_grad_(False)
        self.lora = LoRALayer(base_layer.in_features, base_layer.out_features, rank=rank, alpha=alpha, dropout=dropout)
        self.s = s
        self.lambda_reg = lambda_reg
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Threshold and optimiser state are buffers: they follow .to()/state_dict
        # but are not touched by the parameter optimiser.
        self.register_buffer('tau', torch.tensor(0.0))
        self.register_buffer('m', torch.tensor(0.0))
        self.register_buffer('v', torch.tensor(0.0))
        self.register_buffer('step', torch.tensor(0))

    # The properties below mirror the wrapped layer's attributes.
    @property
    def weight(self):
        return self.base_layer.weight

    @property
    def bias(self):
        return self.base_layer.bias if hasattr(self.base_layer, 'bias') else None

    @property
    def in_features(self):
        return self.base_layer.in_features

    @property
    def out_features(self):
        return self.base_layer.out_features

    def compute_relative_magnitude(self, base_output: torch.Tensor, lora_output: torch.Tensor) -> torch.Tensor:
        """Return ``||lora_output|| / (||base_output|| + eps)`` over the last axis."""
        base_norm = torch.norm(base_output, p=2, dim=-1, keepdim=True)
        lora_norm = torch.norm(lora_output, p=2, dim=-1, keepdim=True)
        r_i = lora_norm / (base_norm + self.eps)
        return r_i.squeeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``base(x) + gate * lora(x)``; cache statistics while training."""
        base_output = self.base_layer(x)
        lora_output = self.lora(x)
        r_i = self.compute_relative_magnitude(base_output, lora_output)
        # Hard 0/1 gate; the comparison carries no gradient.
        gate = (r_i >= self.tau).float()
        if self.training:
            # Only values are read back later, so store detached tensors and
            # avoid keeping autograd history alive through the cache.
            self._cache_for_backward = {
                'r_i': r_i.detach(),
                'gate': gate,
                'lora_output': lora_output.detach(),
                'base_output': base_output.detach(),
            }
        return base_output + gate.unsqueeze(-1) * lora_output

    def compute_threshold_gradient(self, grad_output: torch.Tensor) -> float:
        """Return the scalar threshold gradient from the cached forward pass.

        Args:
            grad_output: Tensor shaped like the layer output, multiplied
                element-wise with the cached LoRA output.

        Returns:
            ``0.0`` when no forward pass is cached; otherwise the sum of the
            loss term and the sparsity term, each scaled by ``-s``.
        """
        if not hasattr(self, '_cache_for_backward'):
            return 0.0
        cache = self._cache_for_backward
        gate = cache['gate']
        lora_output = cache['lora_output']
        mu_i = (grad_output * lora_output).sum(dim=-1)
        # Positions where the sign of mu agrees with the current gate state.
        consistency_mask = ((mu_i >= 0).float() == gate).float()
        sparsity_mask = gate
        grad_loss = -self.s * (consistency_mask * mu_i).sum()
        grad_sparsity = -self.s * (sparsity_mask * self.lambda_reg).sum()
        g_k = grad_loss + grad_sparsity
        return g_k.item()

    def update_threshold(self, grad_output: torch.Tensor, lr: float = 1.0):
        """Apply one bias-corrected moment update to ``tau`` and drop the cache.

        Does nothing in eval mode. ``tau`` is clamped to be non-negative.
        """
        if not self.training:
            return
        g_k = self.compute_threshold_gradient(grad_output)
        self.step += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * g_k
        self.v = self.beta2 * self.v + (1 - self.beta2) * (g_k ** 2)
        m_hat = self.m / (1 - self.beta1 ** self.step.item())
        v_hat = self.v / (1 - self.beta2 ** self.step.item())
        tau_update = lr * self.s * m_hat / (torch.sqrt(v_hat) + self.eps)
        self.tau = torch.clamp(self.tau + tau_update, min=0.0)
        if hasattr(self, '_cache_for_backward'):
            delattr(self, '_cache_for_backward')


class CodeCloneDataset(Dataset):
    """Map-style dataset that serves already-built examples from a sequence."""

    def __init__(self, data_list):
        # The sequence is stored by reference, not copied.
        self.data = data_list

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.data[idx]


class TSPEFTCodeT5(nn.Module):
    """Seq2seq backbone with TS-PEFT adapters and an MLP classification head.

    The backbone is loaded with ``device_map='auto'`` and frozen. Linear layers
    whose final name component contains one of
    ``tspeft_config.target_modules`` are replaced by ``TSPEFTLayer`` wrappers.
    The head classifies the mean-pooled last decoder hidden state.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        tspeft_config: Adapter hyper-parameters.
        num_classes: Output size of the classifier head.
        cache_dir: Optional cache directory forwarded to ``from_pretrained``.
    """

    def __init__(self, model_name: str, tspeft_config: TSPEFTConfig, num_classes: int = 2, cache_dir: str = None):
        super().__init__()
        torch.cuda.empty_cache()
        gc.collect()
        num_gpus = torch.cuda.device_count()
        print(f"Initializing model across {num_gpus} GPUs")
        # The head lives on the first GPU when one exists; fall back to the CPU
        # so that construction does not crash on CUDA-less machines.
        self.primary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print("Loading base LLM model...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch.float32, cache_dir=cache_dir,
            low_cpu_mem_usage=True, device_map='auto'
        )
        self.llm = base_model
        if hasattr(self.llm, 'gradient_checkpointing_enable'):
            self.llm.gradient_checkpointing_enable()
            # With the backbone (embeddings included) frozen, re-entrant
            # checkpointing sees no input that requires grad and cuts the graph,
            # leaving the adapters without gradients. Making the embedding
            # output require grad keeps the graph connected.
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
        self.tspeft_config = tspeft_config
        encoder_layers = len(base_model.encoder.block)
        decoder_layers = len(base_model.decoder.block)
        total_layers = encoder_layers + decoder_layers
        print(f"Model has {encoder_layers} encoder layers and {decoder_layers} decoder layers (total: {total_layers})")
        # Record (and sample-print) the device each block was placed on.
        self.layer_devices = []
        for i in range(encoder_layers):
            layer_device = next(base_model.encoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Encoder Layer {i} -> {layer_device}")
        for i in range(decoder_layers):
            layer_device = next(base_model.decoder.block[i].parameters()).device
            self.layer_devices.append(str(layer_device))
            if i % 5 == 0:
                print(f"  Decoder Layer {i} -> {layer_device}")
        embed_device = next(base_model.encoder.embed_tokens.parameters()).device
        final_device = next(base_model.decoder.final_layer_norm.parameters()).device
        print(f"Embeddings -> {embed_device}")
        print(f"Final Layer -> {final_device}")
        self.embed_device = str(embed_device)
        self.final_device = str(final_device)
        # Freeze the backbone before adapters are added; adapter parameters are
        # created afterwards and therefore stay trainable.
        for param in self.llm.parameters():
            param.requires_grad = False
        self.tspeft_layers = nn.ModuleDict()
        self._inject_tspeft_layers()
        self.num_classes = num_classes
        hidden_size = base_model.config.d_model
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 4, num_classes)
        ).to(self.primary_device).float()
        for param in self.classifier.parameters():
            param.requires_grad = True
        print("Model initialization complete!")

    def _inject_tspeft_layers(self):
        """Wrap every targeted ``nn.Linear`` of the backbone in a ``TSPEFTLayer``.

        Each wrapper is placed on its base layer's device, installed in the
        backbone in place of the original module, and also registered in
        ``self.tspeft_layers`` under ``tspeft_layer_<n>``.
        """
        # Collect the targets first so the module tree is not modified while
        # ``named_modules()`` is still walking it.
        targets = []
        for name, module in self.llm.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            parent_name, _, layer_name = name.rpartition('.')
            # Substring match on the last name component.
            if any(target in layer_name for target in self.tspeft_config.target_modules):
                targets.append((parent_name, layer_name, module))

        trainable_params = 0
        for layer_counter, (parent_name, layer_name, module) in enumerate(targets):
            tspeft_layer = TSPEFTLayer(
                base_layer=module,
                rank=self.tspeft_config.rank,
                alpha=self.tspeft_config.alpha,
                dropout=self.tspeft_config.dropout,
                s=self.tspeft_config.s,
                lambda_reg=self.tspeft_config.lambda_reg,
                beta1=self.tspeft_config.beta1,
                beta2=self.tspeft_config.beta2,
                eps=self.tspeft_config.eps
            )
            layer_device = next(module.parameters()).device
            tspeft_layer = tspeft_layer.to(layer_device)
            self.tspeft_layers[f"tspeft_layer_{layer_counter}"] = tspeft_layer
            # Walk down to the parent module and swap the child in place.
            parent = self.llm
            if parent_name:
                for part in parent_name.split('.'):
                    parent = getattr(parent, part)
            setattr(parent, layer_name, tspeft_layer)
            trainable_params += tspeft_layer.lora.lora_A.numel() + tspeft_layer.lora.lora_B.numel()
        print(f"TS-PEFT enabled with {trainable_params:,} trainable parameters across {len(self.tspeft_layers)} layers")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token ids.

        Args:
            input_ids: Token ids; also reused as ``decoder_input_ids``.
            attention_mask: Optional encoder attention mask.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed.

        Returns:
            Dict with ``'logits'``, ``'loss'`` (``None`` without labels) and
            ``'hidden_states'`` (last decoder hidden state on the primary
            device).
        """
        input_ids = input_ids.to(self.embed_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.embed_device)
        outputs = self.llm(
            input_ids=input_ids, attention_mask=attention_mask,
            decoder_input_ids=input_ids, output_hidden_states=True, return_dict=True
        )
        hidden_states = outputs.decoder_hidden_states[-1]
        if str(hidden_states.device) != self.primary_device:
            hidden_states = hidden_states.to(self.primary_device)
        # NOTE(review): the mean runs over every position along dim 1;
        # attention_mask is not applied to the pooling.
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if labels.device != logits.device:
                labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        return {'logits': logits, 'loss': loss, 'hidden_states': hidden_states}

    def update_thresholds(self, lr: float = 1.0):
        """Update ``tau`` on every adapter that holds a cached training pass."""
        for layer in self.tspeft_layers.values():
            if hasattr(layer, '_cache_for_backward') and layer.training:
                # NOTE(review): a tensor of ones stands in for the upstream
                # gradient here -- confirm this approximation is intended.
                grad_output = torch.ones_like(layer._cache_for_backward['base_output'])
                layer.update_threshold(grad_output, lr)

    def get_sparsity_stats(self) -> Dict[str, float]:
        """Return, per adapter with a cached pass, the percentage of closed gates."""
        stats = {}
        for name, layer in self.tspeft_layers.items():
            if hasattr(layer, '_cache_for_backward'):
                gate = layer._cache_for_backward['gate']
                sparsity = (1 - gate.float().mean()).item() * 100
                stats[name] = sparsity
        return stats


def print_parameter_statistics(model: TSPEFTCodeT5):
    """Print parameter counts, the TS-PEFT settings and per-GPU memory usage.

    Args:
        model: Wrapper exposing ``llm``, ``tspeft_layers``, ``classifier`` and
            ``tspeft_config``.
    """
    def count(params, trainable_only=False):
        return sum(p.numel() for p in params if p.requires_grad or not trainable_only)

    total_params = count(model.llm.parameters())
    tspeft_params = count(model.tspeft_layers.parameters(), trainable_only=True)
    classifier_params = count(model.classifier.parameters())
    trainable_params = tspeft_params + classifier_params

    print(f"\n{'='*80}\nPARAMETER STATISTICS\n{'='*80}")
    print(f"Total Frozen LLM Parameters: {total_params:,}")
    print(f"Trainable TS-PEFT Parameters: {tspeft_params:,}")
    print(f"Trainable Classifier Parameters: {classifier_params:,}")
    print(f"Total Trainable Parameters: {trainable_params:,}")
    print(f"Trainable Percentage: {(trainable_params / total_params * 100):.4f}%")
    print(f"TS-PEFT Rank: {model.tspeft_config.rank}")
    print(f"TS-PEFT Alpha: {model.tspeft_config.alpha}")
    print(f"TS-PEFT s: {model.tspeft_config.s}")
    print(f"TS-PEFT lambda_reg: {model.tspeft_config.lambda_reg}")
    print(f"{'='*80}\n")

    print(f"{'='*80}\nGPU MEMORY USAGE\n{'='*80}")
    for i in range(torch.cuda.device_count()):
        # Convert byte counts to GiB for display.
        allocated, reserved = (
            query(i) / 1024**3
            for query in (torch.cuda.memory_allocated, torch.cuda.memory_reserved)
        )
        print(f"GPU {i}: Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    print(f"{'='*80}\n")


def load_dataset(csv_path, label_col, func1_col, func2_col):
    """Read a CSV into a DataFrame and verify the columns callers will need.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``.
        label_col: Name of the label column.
        func1_col: Name of the first code-snippet column.
        func2_col: Name of the second code-snippet column.

    Returns:
        The loaded DataFrame, or ``None`` (after printing the reason) when the
        file cannot be read or one of the three columns is absent.
    """
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Best-effort loader: report and let the caller handle ``None``.
        print(f"Error loading dataset: {str(e)}")
        return None
    # Check all three columns up front so a bad header is reported here rather
    # than as a KeyError while rows are being processed later.
    missing = [col for col in (label_col, func1_col, func2_col) if col not in df.columns]
    if missing:
        print(f"Error loading dataset: missing required column(s) {missing}")
        return None
    print(f"Loaded dataset: {len(df)} samples")
    print(f"Label distribution: {df[label_col].value_counts().to_dict()}")
    return df


def create_training_data(csv_path, tokenizer, max_length=512, label_col='label', func1_col='func1', func2_col='func2'):
    """Turn each CSV row into a tokenised clone-detection example.

    Every row is rendered into a single prompt containing both snippets and
    tokenised with padding and truncation to ``max_length``.

    Args:
        csv_path: CSV file passed to ``load_dataset``.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors when asked for ``return_tensors='pt'``.
        max_length: Padding/truncation length handed to the tokenizer.
        label_col: Name of the label column.
        func1_col: Name of the first snippet column.
        func2_col: Name of the second snippet column.

    Returns:
        A list of dicts with ``input_ids``, ``attention_mask`` and ``labels``
        tensors; an empty list when the dataset could not be loaded.
    """
    frame = load_dataset(csv_path, label_col, func1_col, func2_col)
    if frame is None:
        return []

    def _encode(record):
        # Build the prompt and tokenise one row into a training example.
        first_snippet = str(record[func1_col])
        second_snippet = str(record[func2_col])
        target = int(record[label_col])
        prompt = f"Code1: {first_snippet}\nCode2: {second_snippet}\nAre these code clones?"
        tokens = tokenizer(
            prompt,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        # The tokenizer output carries a leading batch axis of size one;
        # squeeze(0) drops it so examples can be stacked later.
        return {
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'labels': torch.tensor(target, dtype=torch.long),
        }

    return [_encode(record) for _, record in frame.iterrows()]


def get_memory_usage():
    """Return the resident set size of the current process, in units of 1024*1024 bytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def calculate_recall_at_k(embeddings, labels, k_values=(1, 3, 5, 10)):
    """Fraction of samples with a same-label item among their k nearest neighbours.

    Similarity is the dot product between embedding rows. Each sample is
    excluded from its own neighbour list, and ``k`` is capped at
    ``len(labels) - 1``.

    Args:
        embeddings: 2-D float tensor with one row per sample.
        labels: 1-D sequence/tensor of labels aligned with ``embeddings``.
        k_values: Iterable of neighbourhood sizes to evaluate.

    Returns:
        Dict mapping ``'recall@{k}'`` to a float in ``[0, 1]``; every value
        is ``0.0`` when there are no samples.
    """
    num_samples = len(labels)
    if num_samples == 0:
        return {f'recall@{k}': 0.0 for k in k_values}

    similarities = torch.mm(embeddings, embeddings.t())
    # mm returns a fresh tensor, so masking the diagonal in place does not
    # touch the caller's embeddings.
    similarities.fill_diagonal_(-float('inf'))

    labels = torch.as_tensor(labels, device=similarities.device)
    # same_label[i, j] is True when samples i and j share a label.
    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)

    results = {}
    for k in k_values:
        neighbours = min(k, num_samples - 1)
        if neighbours == 0:
            # No other sample to retrieve, so nothing can be a hit.
            results[f'recall@{k}'] = 0.0
            continue
        top_indices = torch.topk(similarities, neighbours, dim=1).indices
        hits = same_label.gather(1, top_indices).any(dim=1)
        results[f'recall@{k}'] = hits.sum().item() / num_samples
    return results


def calculate_mrr(embeddings, labels):
    """Mean reciprocal rank of the first same-label neighbour of each sample.

    Neighbours are ranked by dot-product similarity, best first. A sample is
    never counted as its own match: a sample with no other item sharing its
    label contributes ``0`` to the mean.

    Args:
        embeddings: 2-D float tensor with one row per sample.
        labels: 1-D sequence/tensor of labels aligned with ``embeddings``.

    Returns:
        The mean reciprocal rank as a Python float (``0.0`` for empty input).
    """
    num_samples = len(labels)
    if num_samples == 0:
        return 0.0

    similarities = torch.mm(embeddings, embeddings.t())
    # Push each sample to the end of its own ranking.
    similarities.fill_diagonal_(-float('inf'))
    order = torch.argsort(similarities, dim=1, descending=True)

    labels = torch.as_tensor(labels, device=order.device)
    query_ids = torch.arange(num_samples, device=order.device).unsqueeze(1)
    # A ranked item is relevant when it shares the query's label and is not
    # the query itself. Without the second condition the query would always
    # match at the last rank.
    relevant = (labels[order] == labels.unsqueeze(1)) & (order != query_ids)

    ranks = torch.arange(1, num_samples + 1, dtype=torch.float64, device=order.device)
    # The largest 1/rank over relevant positions is the reciprocal rank of
    # the first hit; rows with no relevant item yield 0.
    reciprocal_ranks = (relevant.to(torch.float64) / ranks).max(dim=1).values
    return reciprocal_ranks.sum().item() / num_samples


def calculate_comprehensive_metrics(y_true, y_pred, y_proba=None, embeddings=None):
    """Compute binary-classification metrics, plus optional probabilistic and retrieval ones.

    Args:
        y_true: Ground-truth labels (0/1), as a list or array.
        y_pred: Predicted labels (0/1), as a list or array.
        y_proba: Optional positive-class probabilities, as a list or array.
            The probabilistic metrics are only computed when this is given
            and ``y_true`` contains more than one distinct label.
        embeddings: Optional 2-D tensor of per-sample embeddings; when given,
            ``recall@k`` and ``mrr`` entries are added.

    Returns:
        Dict with the confusion matrix (``cm``) and its four cells, the
        threshold metrics, and ``roc_auc`` / ``pr_auc`` / ``log_loss`` /
        ``brier`` (``0.0`` when they could not be computed).
    """
    # labels=[0, 1] pins the matrix to 2x2 even when a class is absent, so
    # ravel() always unpacks into exactly four cells.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    jacc = jaccard_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    roc_auc = pr_auc = logloss = brier = 0.0
    if y_proba is not None and len(np.unique(y_true)) > 1:
        # Callers may pass a plain list; `1 - list` is a TypeError, so
        # convert to an array before doing arithmetic on it.
        proba = np.asarray(y_proba, dtype=np.float64)
        try:
            roc_auc = roc_auc_score(y_true, proba)
            pr_auc = average_precision_score(y_true, proba)
            logloss = log_loss(y_true, np.column_stack([1 - proba, proba]))
            brier = brier_score_loss(y_true, proba)
        except Exception as exc:
            # Best effort: keep the classification metrics, but say why the
            # probabilistic ones are missing instead of zeroing them silently.
            print(f"Warning: probabilistic metrics could not be computed: {exc}")
            roc_auc = pr_auc = logloss = brier = 0.0

    results = {
        'cm': cm, 'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        'acc': acc, 'bal_acc': bal_acc, 'prec': prec, 'rec': rec,
        'specificity': specificity, 'npv': npv, 'fpr': fpr, 'fnr': fnr,
        'f1': f1, 'mcc': mcc, 'jacc': jacc, 'kappa': kappa,
        'roc_auc': roc_auc, 'pr_auc': pr_auc, 'log_loss': logloss, 'brier': brier
    }
    if embeddings is not None:
        # Build the label tensor once and share it between both metrics.
        label_tensor = torch.as_tensor(np.asarray(y_true))
        results.update(calculate_recall_at_k(embeddings, label_tensor))
        results['mrr'] = calculate_mrr(embeddings, label_tensor)
    return results


def collate_fn(batch):
    """Stack per-example tensors into batch tensors.

    Args:
        batch: Sequence of dicts, each with ``input_ids``, ``attention_mask``
            and ``labels`` tensors of matching shapes across examples.

    Returns:
        Dict with the same three keys, each stacked along a new first axis.
    """
    fields = ('input_ids', 'attention_mask', 'labels')
    return {
        field: torch.stack([example[field] for example in batch])
        for field in fields
    }


def train_model(model, train_data, num_epochs=5, learning_rate=1e-3, batch_size=2):
    """Fine-tune the adapter layers and classifier head, logging per-epoch statistics.

    Args:
        model: Object exposing ``tspeft_layers``, ``classifier``,
            ``update_thresholds`` and ``get_sparsity_stats``; calling it must
            return a mapping with ``'loss'`` and ``'logits'`` entries.
        train_data: Examples handed to ``CodeCloneDataset``.
        num_epochs: Number of passes over ``train_data``.
        learning_rate: AdamW learning rate shared by both parameter groups.
        batch_size: Mini-batch size of the shuffled ``DataLoader``.

    Returns:
        One dict per epoch with keys ``epoch``, ``loss``, ``acc``,
        ``sparsity``, ``time`` and ``memory_mb``.

    Note:
        The peak-memory summary takes ``max`` over the per-epoch records, so
        ``num_epochs=0`` raises ``ValueError`` once the (empty) loop finishes.
    """
    # Only adapter parameters that require gradients are optimised, together
    # with every classifier parameter.
    tspeft_params = [p for p in model.tspeft_layers.parameters() if p.requires_grad]
    optimizer = AdamW(
        [{'params': tspeft_params}, {'params': model.classifier.parameters()}], 
        lr=learning_rate, weight_decay=0.01
    )
    dataset = CodeCloneDataset(train_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    model.train()
    epoch_stats = []
    start_time = time.time()
    initial_memory = get_memory_usage()
    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for batch_data in dataloader:
            # Batches are passed to the model as produced by the loader; no
            # device transfer happens here.
            # NOTE(review): presumably the model moves inputs itself - confirm.
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss']
            if loss is not None:
                loss.backward()
                # Model-specific threshold update, run after backward and
                # before clipping and the optimizer step; its internals live
                # outside this function.
                # NOTE(review): presumably consumes the gradients just
                # computed - confirm before reordering.
                model.update_thresholds(lr=1.0)
                # Clip the joint gradient norm of adapter + classifier
                # parameters to 1.0.
                torch.nn.utils.clip_grad_norm_(tspeft_params + list(model.classifier.parameters()), 1.0)
                optimizer.step()
                total_loss += loss.item()
                with torch.no_grad():
                    predictions = torch.argmax(outputs['logits'], dim=-1)
                    # Labels are moved to the logits' device for the comparison.
                    correct += (predictions == labels.to(predictions.device)).sum().item()
                    total += labels.size(0)
                # Drop references to the batch tensors, then release cached
                # CUDA memory after every processed batch.
                del outputs, loss, predictions
                torch.cuda.empty_cache()
            batch_count += 1
            # Every fifth batch also triggers a Python garbage collection.
            if batch_count % 5 == 0:
                torch.cuda.empty_cache()
                gc.collect()
        epoch_time = time.time() - epoch_start
        # Averaged over all batches seen, including any whose loss was None.
        avg_loss = total_loss / batch_count if batch_count > 0 else 0
        accuracy = correct / total if total > 0 else 0
        # Process memory is sampled once, at the end of the epoch.
        current_memory = get_memory_usage()
        # Mean of the per-layer sparsity percentages reported by the model.
        sparsity_stats = model.get_sparsity_stats()
        avg_sparsity = sum(sparsity_stats.values()) / len(sparsity_stats) if sparsity_stats else 0.0
        epoch_stats.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'acc': accuracy,
            'sparsity': avg_sparsity,
            'time': epoch_time,
            'memory_mb': current_memory
        })
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Sparsity: {avg_sparsity:.2f}%, Time: {epoch_time:.2f}s, Memory: {current_memory:.2f}MB")
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            print(f"  GPU {i} Memory: {allocated:.2f} GB")
    total_training_time = time.time() - start_time
    # "Peak" is the largest end-of-epoch sample, not a true high-water mark.
    peak_memory = max([stat['memory_mb'] for stat in epoch_stats])
    memory_increase = peak_memory - initial_memory
    print(f"\n{'='*80}\nTRAINING SUMMARY\n{'='*80}")
    print(f"Total Training Time: {total_training_time:.2f}s")
    print(f"Peak Memory Usage: {peak_memory:.2f}MB")
    print(f"Memory Increase: {memory_increase:.2f}MB\n{'='*80}\n")
    return epoch_stats


def evaluate_model(model, test_data, batch_size=2):
    """Run inference over ``test_data`` and compute classification and retrieval metrics.

    Args:
        model: Callable returning a mapping with ``'logits'`` and
            ``'hidden_states'`` entries; it is switched to eval mode.
        test_data: Examples handed to ``CodeCloneDataset``.
        batch_size: Mini-batch size of the (unshuffled) ``DataLoader``.

    Returns:
        The dict produced by ``calculate_comprehensive_metrics`` with
        ``inference_time`` (seconds) and ``samples_per_sec`` added.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_embeddings = []
    dataset = CodeCloneDataset(test_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    start_time = time.time()
    with torch.no_grad():
        for batch_data in dataloader:
            input_ids = batch_data['input_ids']
            attention_mask = batch_data['attention_mask']
            labels = batch_data['labels']
            outputs = model(input_ids, attention_mask=attention_mask)
            # Cast to float32 so reduced-precision outputs (half/bfloat16)
            # can be converted to NumPy and multiplied on the CPU.
            logits = outputs['logits'].float()
            hidden_states = outputs['hidden_states'].float()
            predictions = torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)
            # Mean over dim=1 covers every position along that axis; the
            # attention mask is not applied.
            # NOTE(review): assumes dim=1 is the sequence axis, in which case
            # padding positions are averaged in - confirm this is intended.
            pooled_embeddings = hidden_states.mean(dim=1)
            pooled_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())
            all_embeddings.append(pooled_embeddings.cpu())
            if len(all_predictions) % 20 == 0:
                torch.cuda.empty_cache()
    inference_time = time.time() - start_time
    # Guard against a zero elapsed time on very small inputs.
    samples_per_sec = len(test_data) / inference_time if inference_time > 0 else 0.0
    all_embeddings = torch.cat(all_embeddings, dim=0)
    # Hand over arrays, not Python lists: the metrics code does arithmetic on
    # the probabilities (``1 - y_proba``), which fails on a list and used to
    # zero out the ROC AUC, PR AUC, log loss and Brier score.
    metrics = calculate_comprehensive_metrics(
        np.asarray(all_labels),
        np.asarray(all_predictions),
        np.asarray(all_probabilities, dtype=np.float64),
        all_embeddings,
    )
    metrics['inference_time'] = inference_time
    metrics['samples_per_sec'] = samples_per_sec
    return metrics


def print_results(metrics, dataset_name):
    """Pretty-print an evaluation report for one dataset.

    The confusion matrix, classification metrics and advanced metrics are
    always printed. The probabilistic, retrieval and performance sections
    appear only when their keys are present in ``metrics``.

    Args:
        metrics: Metric dict as produced by ``calculate_comprehensive_metrics``
            (optionally extended with ``inference_time`` / ``samples_per_sec``).
        dataset_name: Label shown in the report header.
    """
    bar = '=' * 80

    def _section(title):
        # Blank line, rule, title, rule.
        return ["", bar, title, bar]

    # Each list is fully built before printing, so a missing key raises
    # before any part of that block reaches stdout.
    header = _section(f"RESULTS: {dataset_name}") + [
        "",
        "Confusion Matrix:",
        f"{metrics['cm']}",
        "",
        f"True Negatives (TN): {metrics['tn']}",
        f"False Positives (FP): {metrics['fp']}",
        f"False Negatives (FN): {metrics['fn']}",
        f"True Positives (TP): {metrics['tp']}",
    ] + _section("CLASSIFICATION METRICS")
    print("\n".join(header))

    classification = [
        f"Accuracy: {metrics['acc']:.4f}",
        f"Balanced Accuracy: {metrics['bal_acc']:.4f}",
        f"Precision: {metrics['prec']:.4f}",
        f"Recall (Sensitivity): {metrics['rec']:.4f}",
        f"F1 Score: {metrics['f1']:.4f}",
        f"Specificity: {metrics['specificity']:.4f}",
        f"Negative Predictive Value (NPV): {metrics['npv']:.4f}",
        f"False Positive Rate (FPR): {metrics['fpr']:.4f}",
        f"False Negative Rate (FNR): {metrics['fnr']:.4f}",
    ]
    print("\n".join(classification))

    advanced = _section("ADVANCED METRICS") + [
        f"Jaccard Score: {metrics['jacc']:.4f}",
        f"Matthews Correlation Coefficient (MCC): {metrics['mcc']:.4f}",
        f"Cohen's Kappa: {metrics['kappa']:.4f}",
    ]
    print("\n".join(advanced))

    if 'roc_auc' in metrics:
        probabilistic = _section("PROBABILISTIC METRICS") + [
            f"ROC AUC Score: {metrics['roc_auc']:.4f}",
        ]
        print("\n".join(probabilistic))
    if 'pr_auc' in metrics:
        print(f"Precision-Recall AUC: {metrics['pr_auc']:.4f}")
    if 'log_loss' in metrics:
        print(f"Log Loss: {metrics['log_loss']:.4f}")
    if 'brier' in metrics:
        print(f"Brier Score: {metrics['brier']:.4f}")

    if 'recall@1' in metrics:
        retrieval = _section("RETRIEVAL METRICS") + [
            f"Recall@1: {metrics['recall@1']:.4f}",
            f"Recall@3: {metrics['recall@3']:.4f}",
            f"Recall@5: {metrics['recall@5']:.4f}",
            f"Recall@10: {metrics['recall@10']:.4f}",
            f"Mean Reciprocal Rank (MRR): {metrics['mrr']:.4f}",
        ]
        print("\n".join(retrieval))

    if 'inference_time' in metrics:
        performance = _section("PERFORMANCE METRICS") + [
            f"Inference Time: {metrics['inference_time']:.2f}s",
            f"Throughput: {metrics['samples_per_sec']:.2f} samples/sec",
        ]
        print("\n".join(performance))

    print(f"{bar}\n")


def main():
    """Train a TS-PEFT CodeT5+ clone classifier on ``/train.csv`` and report on ``/test.csv``.

    Steps, in order: clear caches and list the visible GPUs, load the
    tokenizer, build the model and print its parameter statistics, train,
    then evaluate and print the test report. The function returns early when
    no training examples could be created.
    """
    torch.cuda.empty_cache()
    gc.collect()

    gpu_count = torch.cuda.device_count()
    print(f"Number of available GPUs: {gpu_count}")
    for gpu_index in range(gpu_count):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    rule = '=' * 80
    checkpoint = "Salesforce/codet5p-770m"
    hf_cache = '/hf_cache'
    columns = {'label_col': 'label', 'func1_col': 'func1', 'func2_col': 'func2'}

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=hf_cache)
    # Fall back to the EOS token so padded batches can be built.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    adapter_config = TSPEFTConfig(rank=16, alpha=0.5, dropout=0.05, s=4e-5, lambda_reg=4.5e-5)
    print("\nLoading model with TS-PEFT and pipeline parallelism across GPUs...")
    model = TSPEFTCodeT5(checkpoint, adapter_config, num_classes=2, cache_dir=hf_cache)
    print_parameter_statistics(model)

    print(f"\n{rule}\nTRAINING PHASE\n{rule}\n")
    train_examples = create_training_data('/train.csv', tokenizer, **columns)
    print(f"Created {len(train_examples)} training examples\n")
    if not train_examples:
        print("No training data available. Exiting.")
        return
    train_model(model, train_examples, num_epochs=5, learning_rate=1e-3, batch_size=4)

    print(f"\n{rule}\nTESTING PHASE\n{rule}\n")
    test_examples = create_training_data('/test.csv', tokenizer, **columns)
    print(f"Created {len(test_examples)} test examples\n")
    if test_examples:
        report = evaluate_model(model, test_examples, batch_size=4)
        print_results(report, "TEST SET")

    print(f"\n{rule}\nTRAINING COMPLETE\n{rule}\n")


# Run the full train/evaluate pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()