In [39]:
import os
import torch
from torch import nn
from peft import LoraConfig, TaskType, get_peft_model


class ModelForSourceCodeEmbedding(nn.Module):
    def __init__(self, model_name, normalize=True):
        super(ModelForSourceCodeEmbedding, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.average_pool(model_output, kwargs.get("attention_mask"))
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def average_pool(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return (torch.sum(token_embeddings * input_mask_expanded, 1) /
                torch.clamp(input_mask_expanded.sum(1), min=1e-9))

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model, name)

In [40]:
from transformers import AutoModel, AutoTokenizer


model_name = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ModelForSourceCodeEmbedding(model_name)

In [41]:
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=["key", "query", "value"],
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 442,368 || all params: 125,088,000 || trainable%: 0.3536


In [42]:
torch.cuda.empty_cache()

In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

In [44]:
def get_cosing(q1_embs, q2_embs):
    return torch.sum(q1_embs * q2_embs, axis=1)

def get_loss(cosine_score, labels):
    return torch.mean(torch.square(labels * (1 - cosine_score) + torch.clamp((1 - labels) * cosine_score, min=0.0)))

def threshold(x):
    return 1 if x > 0.5 else 0

In [45]:
from imblearn.over_sampling import RandomOverSampler
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm


train_cases = ["case-01", "case-02", "case-03", "case-04", "case-05"]
test_cases = ["case-06", "case-07"]

train_df = pd.DataFrame(data=None, columns=["original_code_file", "secondary_code_file", "label"])
train_df_index = 0

def add_to_train_df(original_code_file, secondary_code_file, label):
    global train_df_index
    train_df.loc[train_df_index] = [original_code_file, secondary_code_file, label]
    train_df_index += 1
    
def get_train_data(original_file_path, non_plagiarized_files_paths, plagiarized_files_paths):
    for non_plagiarized_file_path in non_plagiarized_files_paths:
        add_to_train_df(original_file_path, non_plagiarized_file_path, 0)
    for plagiarized_file_path in plagiarized_files_paths:
        add_to_train_df(original_file_path, plagiarized_file_path, 1)


# train-val cycle:
EPOCHS = 30

epochs_losses_train = []
epochs_losses_val = []
epochs_accuracies_train = []
epochs_accuracies_val = []

for epoch in range(EPOCHS):
    train_total_loss = 0
    train_correct = 0
    
    val_total_loss = 0
    val_correct = 0
    
    for case in train_cases:
        original_file_path = os.path.join("plagiarism_dataset", case, "original", os.listdir(f"plagiarism_dataset/{case}/original")[0])
        
        non_plagiarized_files_paths = [os.path.join("plagiarism_dataset", case, "non-plagiarized", file) for file in os.listdir(f"plagiarism_dataset/{case}/non-plagiarized")]
        non_plagiarized_files_paths = [os.path.join(file_path, os.listdir(file_path)[0]) for file_path in non_plagiarized_files_paths]
        
        plagiarized_files_paths = [os.path.join(dp, f) for dp, dn, filenames in os.walk(os.path.join("plagiarism_dataset", case, "plagiarized")) for f in filenames if os.path.splitext(f)[1] == '.java']
        
        get_train_data(original_file_path, non_plagiarized_files_paths, plagiarized_files_paths)
        train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)
        X_train, X_test, y_train, y_test = train_test_split(train_df[["original_code_file", "secondary_code_file"]], train_df["label"], test_size=0.2, random_state=42)
    
        ros = RandomOverSampler(random_state=42)
        X_train, y_train = ros.fit_resample(X_train, y_train)
        
        X_train = X_train.reset_index(drop=True)
        y_train = y_train.reset_index(drop=True)
        X_test = X_test.reset_index(drop=True)
        y_test = y_test.reset_index(drop=True)
    
        # train:
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
        loss = nn.CrossEntropyLoss()
        
        total_loss = 0
        correct = 0
        
        model.train()
        bar = tqdm(range(len(X_train)))
        for i in range(len(X_train)):
            original_code = open(X_train["original_code_file"][i], "r").read()
            secondary_code = open(X_train["secondary_code_file"][i], "r").read()
            
            q1 = tokenizer(original_code, return_tensors="pt", max_length=512, truncation=True)
            q2 = tokenizer(secondary_code, return_tensors="pt", max_length=512, truncation=True)
            
            q1 = {k: v.to(device) for k, v in q1.items()}
            q2 = {k: v.to(device) for k, v in q2.items()}
            
            label = torch.tensor([y_train[i]])
            label = label.to(device)
            
            optimizer.zero_grad()
            q1 = model(**q1)
            q2 = model(**q2)
            loss_value = get_loss(get_cosing(q1, q2).to(device), label)
            total_loss += loss_value.item()
            loss_value.backward()
            optimizer.step()
        
            prediction = threshold(get_cosing(q1, q2).item())
            if prediction == label:
                correct += 1
            bar.update(1)
        
        train_total_loss += total_loss / len(X_train)
        train_correct += correct / len(X_train)
    
        # print(f"Case: {case}, Loss train: {total_loss / len(X_train)}")
        # print(f"Case: {case}, Accuracy train: {correct / len(X_train)}")
        
        # val:
        total_loss = 0
        correct = 0
        
        model.eval()
        bar = tqdm(range(len(X_test)))
        for i in range(len(X_test)):
            # print(i)
            # print(X_test)
            original_code = open(X_test["original_code_file"][i], "r").read()
            secondary_code = open(X_test["secondary_code_file"][i], "r").read()
            
            q1 = tokenizer(original_code, return_tensors="pt", max_length=512, truncation=True)
            q2 = tokenizer(secondary_code, return_tensors="pt", max_length=512, truncation=True)
            
            q1 = {k: v.to(device) for k, v in q1.items()}
            q2 = {k: v.to(device) for k, v in q2.items()}
            
            label = torch.tensor([y_test[i]])
            label = label.to(device)
            
            q1 = model(**q1)
            q2 = model(**q2)
            loss_value = get_loss(get_cosing(q1, q2).to(device), label)
            total_loss += loss_value.item()
            
            prediction = threshold(get_cosing(q1, q2).item())
            if prediction == label:
                correct += 1
            bar.update(1)
        
        val_total_loss += total_loss / len(X_test)
        val_correct += correct / len(X_test)
        
        # print(f"Case: {case}, Loss val: {total_loss / len(X_test)}")
        # print(f"Case: {case}, Accuracy val: {correct / len(X_test)}")
    
    epochs_losses_train.append(train_total_loss / 5)
    print(f"Epoch: {epoch}, Loss train: {train_total_loss / 5}")
    epochs_accuracies_train.append(train_correct / 5)
    print(f"Epoch: {epoch}, Accuracy train: {train_correct / 5}")
    
    epochs_losses_val.append(val_total_loss / 5)
    print(f"Epoch: {epoch}, Loss val: {val_total_loss / 5}")
    epochs_accuracies_val.append(val_correct / 5)
    print(f"Epoch: {epoch}, Accuracy val: {val_correct / 5}")

 19%|█▉        | 29/150 [00:33<02:19,  1.15s/it]
100%|██████████| 62/62 [00:07<00:00,  7.53it/s]
100%|██████████| 62/62 [00:07<00:00,  8.12it/s]

 18%|█▊        | 2/11 [00:00<00:00, 10.87it/s][A
 36%|███▋      | 4/11 [00:00<00:00, 11.65it/s][A
 55%|█████▍    | 6/11 [00:00<00:00, 12.26it/s][A
 73%|███████▎  | 8/11 [00:00<00:00, 11.59it/s][A
100%|██████████| 11/11 [00:01<00:00, 10.33it/s][A
 99%|█████████▉| 149/150 [00:16<00:00,  8.93it/s]
100%|██████████| 150/150 [00:16<00:00,  8.82it/s]

  8%|▊         | 2/25 [00:00<00:01, 18.28it/s][A
 16%|█▌        | 4/25 [00:00<00:01, 18.26it/s][A
 28%|██▊       | 7/25 [00:00<00:00, 19.84it/s][A
 36%|███▌      | 9/25 [00:00<00:00, 18.29it/s][A
 48%|████▊     | 12/25 [00:00<00:00, 19.57it/s][A
 56%|█████▌    | 14/25 [00:00<00:00, 19.16it/s][A
 64%|██████▍   | 16/25 [00:00<00:00, 18.10it/s][A
 72%|███████▏  | 18/25 [00:00<00:00, 18.15it/s][A
 80%|████████  | 20/25 [00:01<00:00, 17.09it/s][A
 88%|████████▊ | 22/25 [00:01<00:00, 16.74it/s]

Epoch: 0, Loss train: 0.24139840061505877
Epoch: 0, Accuracy train: 0.6883978073020978
Epoch: 0, Loss val: 0.5054344456601892
Epoch: 0, Accuracy val: 0.4973986013986014


100%|██████████| 66/66 [00:04<00:00, 14.01it/s]
100%|██████████| 462/462 [00:58<00:00,  5.75it/s]
100%|██████████| 462/462 [00:58<00:00,  7.92it/s]

  3%|▎         | 2/77 [00:00<00:05, 13.70it/s][A
  5%|▌         | 4/77 [00:00<00:04, 15.95it/s][A
  8%|▊         | 6/77 [00:00<00:04, 14.49it/s][A
 10%|█         | 8/77 [00:00<00:05, 12.06it/s][A
 13%|█▎        | 10/77 [00:00<00:04, 13.71it/s][A
 16%|█▌        | 12/77 [00:00<00:04, 14.70it/s][A
 18%|█▊        | 14/77 [00:00<00:04, 15.37it/s][A
 21%|██        | 16/77 [00:01<00:04, 14.35it/s][A
 23%|██▎       | 18/77 [00:01<00:03, 15.59it/s][A
 26%|██▌       | 20/77 [00:01<00:03, 15.16it/s][A
 29%|██▊       | 22/77 [00:01<00:04, 13.02it/s][A
 31%|███       | 24/77 [00:01<00:03, 13.86it/s][A
 34%|███▍      | 26/77 [00:01<00:03, 14.60it/s][A
 36%|███▋      | 28/77 [00:01<00:03, 13.70it/s][A
 39%|███▉      | 30/77 [00:02<00:03, 14.73it/s][A
 42%|████▏     | 32/77 [00:02<00:03, 14.72it/s][A
 44%|████▍     | 34/77 [00:02<00:02, 15

Epoch: 1, Loss train: 0.11417198436405154
Epoch: 1, Accuracy train: 0.8458112704647516
Epoch: 1, Loss val: 0.7935368486062767
Epoch: 1, Accuracy val: 0.3601776754319127


100%|██████████| 132/132 [00:09<00:00, 13.94it/s]
100%|██████████| 862/862 [01:52<00:00,  8.26it/s]
100%|██████████| 862/862 [01:52<00:00,  7.66it/s]

  1%|▏         | 2/143 [00:00<00:07, 18.18it/s][A
  3%|▎         | 4/143 [00:00<00:07, 17.92it/s][A
  4%|▍         | 6/143 [00:00<00:08, 15.88it/s][A
  6%|▌         | 8/143 [00:00<00:07, 17.05it/s][A
  7%|▋         | 10/143 [00:00<00:10, 13.17it/s][A
  8%|▊         | 12/143 [00:00<00:09, 14.15it/s][A
 10%|▉         | 14/143 [00:00<00:08, 14.96it/s][A
 11%|█         | 16/143 [00:01<00:07, 16.14it/s][A
 13%|█▎        | 18/143 [00:01<00:08, 15.40it/s][A
 14%|█▍        | 20/143 [00:01<00:07, 15.92it/s][A
 15%|█▌        | 22/143 [00:01<00:07, 16.98it/s][A
 17%|█▋        | 24/143 [00:01<00:07, 16.71it/s][A
 18%|█▊        | 26/143 [00:01<00:06, 17.00it/s][A
 20%|█▉        | 28/143 [00:01<00:07, 16.16it/s][A
 21%|██        | 30/143 [00:01<00:06, 16.19it/s][A
 22%|██▏       | 32/143 [00:01<00:06, 16.78it/s][A
 24%|██▍       | 34/1

Epoch: 2, Loss train: 0.039634991843282716
Epoch: 2, Accuracy train: 0.9453858216842412
Epoch: 2, Loss val: 0.3639562928115437
Epoch: 2, Accuracy val: 0.6397424008951784


100%|██████████| 197/197 [00:12<00:00, 15.20it/s]
100%|██████████| 1274/1274 [02:47<00:00,  8.28it/s]
100%|██████████| 1274/1274 [02:47<00:00,  7.63it/s]

  1%|          | 2/208 [00:00<00:11, 18.69it/s][A
  2%|▏         | 4/208 [00:00<00:13, 14.85it/s][A
  3%|▎         | 6/208 [00:00<00:12, 15.70it/s][A
  4%|▍         | 8/208 [00:00<00:12, 16.63it/s][A
  5%|▍         | 10/208 [00:00<00:11, 17.19it/s][A
  6%|▌         | 12/208 [00:00<00:12, 15.39it/s][A
  7%|▋         | 14/208 [00:00<00:11, 16.20it/s][A
  8%|▊         | 16/208 [00:01<00:12, 15.51it/s][A
  9%|▉         | 19/208 [00:01<00:11, 16.82it/s][A
 10%|█         | 21/208 [00:01<00:10, 17.20it/s][A
 11%|█         | 23/208 [00:01<00:10, 17.08it/s][A
 12%|█▏        | 25/208 [00:01<00:11, 15.56it/s][A
 13%|█▎        | 27/208 [00:01<00:11, 16.27it/s][A
 14%|█▍        | 30/208 [00:01<00:09, 17.91it/s][A
 15%|█▌        | 32/208 [00:01<00:09, 17.95it/s][A
 16%|█▋        | 34/208 [00:02<00:11, 15.60it/s][A
 17%|█▋        | 

Epoch: 3, Loss train: 0.009366657899625071
Epoch: 3, Accuracy train: 0.9889281641459078
Epoch: 3, Loss val: 0.055276946299074356
Epoch: 3, Accuracy val: 0.9399263756294225


100%|██████████| 263/263 [00:17<00:00, 14.73it/s]
100%|█████████▉| 1679/1680 [03:44<00:00,  6.35it/s]
100%|██████████| 1680/1680 [03:44<00:00,  7.49it/s]

  1%|          | 2/274 [00:00<00:14, 18.29it/s][A
  1%|▏         | 4/274 [00:00<00:14, 18.28it/s][A
  2%|▏         | 6/274 [00:00<00:14, 18.28it/s][A
  3%|▎         | 8/274 [00:00<00:14, 18.55it/s][A
  4%|▎         | 10/274 [00:00<00:14, 18.45it/s][A
  4%|▍         | 12/274 [00:00<00:14, 17.52it/s][A
  5%|▌         | 14/274 [00:00<00:15, 17.01it/s][A
  6%|▌         | 16/274 [00:00<00:16, 15.30it/s][A
  7%|▋         | 18/274 [00:01<00:17, 14.42it/s][A
  7%|▋         | 20/274 [00:01<00:17, 14.36it/s][A
  8%|▊         | 22/274 [00:01<00:16, 14.92it/s][A
  9%|▉         | 24/274 [00:01<00:15, 15.82it/s][A
  9%|▉         | 26/274 [00:01<00:16, 15.11it/s][A
 10%|█         | 28/274 [00:01<00:16, 15.26it/s][A
 11%|█         | 30/274 [00:01<00:17, 14.01it/s][A
 12%|█▏        | 32/274 [00:02<00:16, 14.62it/s][A
 12%|█▏        | 

Epoch: 4, Loss train: 0.0026833836506269864
Epoch: 4, Accuracy train: 0.9969225156486875
Epoch: 4, Loss val: 0.015417607521910157
Epoch: 4, Accuracy val: 0.983318850654617


100%|██████████| 328/328 [00:21<00:00, 15.10it/s]
100%|█████████▉| 2089/2090 [04:32<00:00,  9.34it/s]
100%|██████████| 2090/2090 [04:32<00:00,  7.67it/s]

  1%|          | 2/339 [00:00<00:21, 16.00it/s][A
  1%|          | 4/339 [00:00<00:24, 13.95it/s][A
  2%|▏         | 6/339 [00:00<00:21, 15.62it/s][A
  2%|▏         | 8/339 [00:00<00:21, 15.62it/s][A
  3%|▎         | 11/339 [00:00<00:18, 17.81it/s][A
  4%|▍         | 13/339 [00:00<00:18, 17.95it/s][A
  4%|▍         | 15/339 [00:00<00:21, 14.74it/s][A
  5%|▌         | 17/339 [00:01<00:21, 15.33it/s][A
  6%|▌         | 19/339 [00:01<00:22, 14.14it/s][A
  6%|▋         | 22/339 [00:01<00:19, 16.21it/s][A
  7%|▋         | 24/339 [00:01<00:19, 16.08it/s][A
  8%|▊         | 26/339 [00:01<00:20, 14.98it/s][A
  8%|▊         | 28/339 [00:01<00:20, 15.06it/s][A
  9%|▉         | 30/339 [00:01<00:19, 16.00it/s][A
  9%|▉         | 32/339 [00:02<00:18, 16.95it/s][A
 10%|█         | 34/339 [00:02<00:17, 17.30it/s][A
 11%|█         | 

Epoch: 5, Loss train: 0.0010704681724491246
Epoch: 5, Accuracy train: 0.9988233173054798
Epoch: 5, Loss val: 0.0004401655004672751
Epoch: 5, Accuracy val: 1.0


100%|██████████| 394/394 [00:26<00:00, 15.09it/s]
100%|██████████| 2484/2484 [05:29<00:00,  9.18it/s]
100%|██████████| 2484/2484 [05:29<00:00,  7.55it/s]

  0%|          | 2/405 [00:00<00:22, 18.28it/s][A
  1%|          | 4/405 [00:00<00:27, 14.43it/s][A
  1%|▏         | 6/405 [00:00<00:24, 15.98it/s][A
  2%|▏         | 8/405 [00:00<00:26, 15.12it/s][A
  2%|▏         | 10/405 [00:00<00:25, 15.78it/s][A
  3%|▎         | 12/405 [00:00<00:27, 14.20it/s][A
  4%|▎         | 15/405 [00:00<00:24, 16.00it/s][A
  4%|▍         | 17/405 [00:01<00:24, 15.96it/s][A
  5%|▍         | 19/405 [00:01<00:23, 16.12it/s][A
  5%|▌         | 21/405 [00:01<00:24, 15.47it/s][A
  6%|▌         | 23/405 [00:01<00:24, 15.73it/s][A
  6%|▌         | 25/405 [00:01<00:23, 15.85it/s][A
  7%|▋         | 28/405 [00:01<00:21, 17.20it/s][A
  7%|▋         | 30/405 [00:01<00:21, 17.50it/s][A
  8%|▊         | 32/405 [00:01<00:20, 17.95it/s][A
  8%|▊         | 34/405 [00:02<00:21, 17.01it/s][A
  9%|▉         | 

KeyboardInterrupt: 