In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

# ===== 1. Load dataset =====
dataset = load_dataset("project-droid/DroidCollection")

# ===== 2. Select the train / dev / test splits =====
train_dataset, dev_dataset, test_dataset = (dataset[split_name] for split_name in ("train", "dev", "test"))

# ===== 3. Remap string labels to 0/1 =====
def remap_labels(example):
    """Binarize one row's ``"Label"`` field in place.

    ``"HUMAN_GENERATED"`` becomes 0; every other value becomes 1. The same
    (mutated) example dict is returned, which is what ``Dataset.map`` expects.
    """
    example["Label"] = int(example["Label"] != "HUMAN_GENERATED")
    return example

# Binarize the labels of every split with remap_labels.
train_dataset, dev_dataset, test_dataset = (
    split.map(remap_labels) for split in (train_dataset, dev_dataset, test_dataset)
)

# ===== 4. Load tokenizer =====
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

# ===== 5. Tokenization function =====
def tokenize_fn(examples):
    """Tokenize a batch of code snippets into fixed-length ModernBERT inputs.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a dict of
    column lists with at least ``"Code"`` (source strings) and ``"Label"``.
    Every sequence is truncated/padded to exactly 512 tokens.

    Returns the tokenizer's batch encoding with the ``"Label"`` list copied in.
    """
    batch = tokenizer(
        examples["Code"],
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    # Carry the labels through explicitly next to the token fields.
    batch["Label"] = examples["Label"]
    return batch

# ===== 6. Tokenize datasets and rename the label column to what Trainer expects =====
train_dataset, dev_dataset, test_dataset = (
    split.map(tokenize_fn, batched=True).rename_column("Label", "labels")
    for split in (train_dataset, dev_dataset, test_dataset)
)

# Expose only the fields TLModel.forward consumes, as torch tensors.
columns = ["input_ids", "attention_mask", "labels"]
train_dataset.set_format(type="torch", columns=columns)
dev_dataset.set_format(type="torch", columns=columns)
test_dataset.set_format(type="torch", columns=columns)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput

# -------------------------------
# Model Definition
# -------------------------------

# Number of output classes: remap_labels maps "HUMAN_GENERATED" -> 0 and every other label -> 1.
NUM_CLASSES = 2
TEXT_EMBEDDING_DIM = 1024  # ModernBERT-large hidden size; used as the input width of TLModel's projection layer

import numpy as np

def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and support-weighted F1 for a classification eval.

    Args:
        eval_pred: either a ``(predictions, label_ids)`` tuple or an object exposing
            ``.predictions`` and ``.label_ids`` (e.g. ``transformers.EvalPrediction``).
            ``predictions`` may itself be a tuple/list whose first element holds the
            logits (as happens when the model also returns extra outputs).

    Returns:
        dict with float ``accuracy``, ``macro_f1`` and ``weighted_f1``.

    Notes:
        * Positions whose label is -100 (or any other negative id) are ignored.
        * Macro-F1 averages only over classes that occur in the labels or the
          predictions (the scikit-learn convention), so the score no longer depends
          on *which* class happens to be missing from an evaluation split.
        * With nothing left to score, all three metrics are 0.0.
    """
    if isinstance(eval_pred, tuple):
        preds, labels = eval_pred
    else:
        preds, labels = eval_pred.predictions, eval_pred.label_ids

    # Models returning (logits, hidden, ...) -> keep only the logits.
    if isinstance(preds, (tuple, list)):
        preds = preds[0]
    preds = np.asarray(preds)

    # Logits -> class ids; 1-D scores are treated as binary logits thresholded at 0.
    if preds.ndim > 1:
        y_pred = np.argmax(preds, axis=-1)
    else:
        y_pred = (preds > 0).astype(int)

    y_pred = np.asarray(y_pred).reshape(-1).astype(np.int64)
    y_true = np.asarray(labels).reshape(-1).astype(np.int64)

    # Drop ignored (-100) and any other negative ids: they cannot index the matrix.
    keep = (y_true >= 0) & (y_pred >= 0)
    y_true, y_pred = y_true[keep], y_pred[keep]

    if y_true.size == 0:
        return {"accuracy": 0.0, "macro_f1": 0.0, "weighted_f1": 0.0}

    num_classes = int(max(y_true.max(), y_pred.max())) + 1

    # Confusion matrix in one vectorized pass: row = true class, column = predicted class.
    cm = np.bincount(
        y_true * num_classes + y_pred, minlength=num_classes * num_classes
    ).reshape(num_classes, num_classes)

    total = cm.sum()
    tp = np.diag(cm).astype(float)
    support = cm.sum(axis=1)     # true occurrences per class (tp + fn)
    predicted = cm.sum(axis=0)   # predicted occurrences per class (tp + fp)

    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted != 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support != 0)
    f1 = np.divide(
        2 * precision * recall,
        precision + recall,
        out=np.zeros_like(tp),
        where=(precision + recall) != 0,
    )

    # Only classes seen in labels or predictions take part in the macro average.
    present = (support + predicted) > 0
    macro_f1 = float(f1[present].mean())
    weighted_f1 = float((f1 * support).sum() / total)
    accuracy = float(np.trace(cm) / total)

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
    }

class TLModel(nn.Module):
    """Text encoder + projection head trained with cross-entropy plus a triplet loss.

    Token states from ``text_encoder`` are mean-pooled over non-padding tokens,
    projected to ``projection_dim`` with a ReLU, and classified by a linear layer.
    When labels are given, the loss is cross-entropy plus ``triplet_weight`` times a
    batch-hard triplet loss on the projected embeddings.

    Args:
        text_encoder: encoder module whose output exposes ``last_hidden_state``
            (assumed ``(batch, seq_len, hidden)``, as the pooling over dim 1 relies on).
        projection_dim: width of the projection layer.
        num_classes: number of output logits.
        class_weights: optional 1-D tensor of per-class cross-entropy weights.
        triplet_weight: multiplier of the triplet term in the total loss.
        triplet_margin: margin of ``nn.TripletMarginLoss``.
    """

    def __init__(self, text_encoder, projection_dim=128, num_classes=NUM_CLASSES, class_weights=None,
                 triplet_weight=0.1, triplet_margin=1.0):
        super().__init__()
        self.text_encoder = text_encoder
        self.num_classes = num_classes
        self.class_weights = class_weights
        self.triplet_weight = triplet_weight

        # Take the encoder width from its config when available so the head fits
        # encoders other than ModernBERT-large; otherwise fall back to the constant.
        hidden_size = getattr(getattr(text_encoder, "config", None), "hidden_size", None)
        if hidden_size is None:
            hidden_size = TEXT_EMBEDDING_DIM

        # Project embeddings down, then classify.
        self.text_projection = nn.Linear(hidden_size, projection_dim)
        self.classifier = nn.Linear(projection_dim, num_classes)

        # CrossEntropyLoss stores ``weight`` as a registered buffer, so it follows the
        # module through ``.to(device)``; no need to pin it to a device here.
        self.ce_loss_fn = nn.CrossEntropyLoss(weight=class_weights)
        self.triplet_loss_fn = nn.TripletMarginLoss(margin=triplet_margin, p=2)

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Mean-pool ``hidden`` over dim 1, ignoring positions where ``attention_mask`` is 0."""
        if attention_mask is None:
            return hidden.mean(dim=1)
        # Inputs are padded (e.g. to max_length), so padding must not enter the average.
        mask = attention_mask.unsqueeze(-1).to(torch.float32)
        summed = (hidden.to(torch.float32) * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # guard against all-padding rows
        return (summed / counts).to(hidden.dtype)

    def _batch_hard_triplet_loss(self, projected, labels):
        """Batch-hard triplet loss on ``projected``; ``None`` if no valid triplet exists.

        For each anchor, the positive is the farthest same-label sample (excluding the
        anchor) and the negative is the closest different-label sample. Anchors lacking
        either are skipped.
        """
        with torch.no_grad():
            # Distances only select indices (argmax/argmin carry no gradient), so no graph.
            emb = projected.detach().float()
            dists = torch.cdist(emb, emb, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")
            same = labels.unsqueeze(0) == labels.unsqueeze(1)
            self_pair = torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
            pos_mask = same & ~self_pair
            neg_mask = ~same
            hardest_pos = dists.masked_fill(~pos_mask, float("-inf")).argmax(dim=1)
            hardest_neg = dists.masked_fill(~neg_mask, float("inf")).argmin(dim=1)
            anchors = (pos_mask.any(dim=1) & neg_mask.any(dim=1)).nonzero(as_tuple=True)[0]

        if anchors.numel() == 0:
            return None
        return self.triplet_loss_fn(
            projected[anchors],
            projected[hardest_pos[anchors]],
            projected[hardest_neg[anchors]],
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Encode, pool, project and classify a batch.

        Args:
            input_ids: token ids passed straight to ``text_encoder``.
            attention_mask: assumed ``(batch, seq_len)`` with 1 for real tokens and 0 for
                padding (Hugging Face convention); also used for pooling.
            labels: optional class ids, assumed ``(batch,)`` integers.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` without labels),
            ``logits`` and ``hidden_states`` set to the projected embeddings (a tensor,
            not the usual tuple of layer states).
        """
        hidden = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        embeddings = self._masked_mean(hidden, attention_mask)
        projected = F.relu(self.text_projection(embeddings))
        logits = self.classifier(projected)

        loss = None
        if labels is not None:
            ce_loss = self.ce_loss_fn(logits, labels)
            triplet_loss = self._batch_hard_triplet_loss(projected, labels)
            if triplet_loss is None:
                loss = ce_loss
            else:
                loss = ce_loss + self.triplet_weight * triplet_loss

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=projected,
        )



text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-large")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

model = TLModel(text_encoder=text_encoder)

training_args = TrainingArguments(
    output_dir="./droiddetect_model",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=5e-5,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=1,
    # The original `fp16=true` raised NameError. fp16 mixed precision also needs a
    # CUDA device, so enable it only when one is present.
    fp16=torch.cuda.is_available(),
)

# NOTE: no evaluation strategy is set above, so compute_metrics only runs when
# trainer.evaluate() is called explicitly, and no "best" checkpoint is selected.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model("./droiddetect_model/final")
torch.save(model.state_dict(), "code_plagiarism.bin")

print("✅ Training complete. Final model saved.")
