In [1]:
import os
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import gc
import multiprocessing as mp
from tqdm import tqdm
from datasets import load_dataset
import polars as pl
import pandas as pd
import re


In [2]:
import nltk
# Fetch the 'punkt_tab' NLTK resource; this is a no-op when the package is already
# up to date locally. Presumably required by word_tokenize below — confirm against
# the installed NLTK version.
nltk.download('punkt_tab')

from nltk import word_tokenize


[nltk_data] Downloading package punkt_tab to
[nltk_data]     /cluster/home/andstorh/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [3]:
# Relative path to the PatchFinder baseline data directory.
# NOTE(review): not referenced by any of the cells visible in this notebook.
DATA_DIR = '../../../data/baselines/PatchFinder'

In [4]:
# Load the CVE records plus the two commit configurations ("patches" and "non_patches").
# NOTE(review): the CVE repo id is spelled 'cvevc' while the commit repo id is spelled
# 'cvcvc' — confirm both spellings are intended.
ds_cve = load_dataset('fals3/cvevc_cve')
ds_patches = load_dataset('fals3/cvcvc_commits', "patches")
ds_nonpatches = load_dataset('fals3/cvcvc_commits', "non_patches")

Resolving data files:   0%|          | 0/468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/540 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/793 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/540 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/793 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/540 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/793 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/182 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/213 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/285 [00:00<?, ?it/s]

In [5]:
import re

def convert_to_unified_0(diff: str) -> str:
    """
    Strip context lines from a git diff, approximating `git diff --unified=0`.

    The result is an approximation, not a byte-exact `--unified=0` diff:
    hunks are not split around the removed context, and every recognised hunk
    header is rewritten to `@@ -<old_start>,0 +<new_start>,0 @@` (line counts
    fixed to 0, trailing section heading dropped).

    Args:
        diff: Text of a git diff / `git show`, possibly covering several files.

    Returns:
        The diff with file headers and hunk headers kept, and inside each hunk
        only the lines starting with '+' or '-'. Lines that appear outside any
        hunk (commit metadata, per-file mode/rename lines) are passed through.
    """
    header_prefixes = ("diff --git", "index", "---", "+++")
    hunk_header = re.compile(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@")

    output_lines = []
    inside_hunk = False

    for line in diff.splitlines():
        if line.startswith(header_prefixes):
            if line.startswith("diff --git"):
                # A new file section begins: leave hunk mode so that this file's
                # metadata lines (e.g. "new file mode ...") are not mistaken for
                # context lines of the previous file's hunk and dropped.
                inside_hunk = False
            output_lines.append(line)
        elif line.startswith("@@"):
            inside_hunk = True
            match = hunk_header.match(line)
            if match:
                # Only the start positions are reused; the counts are emitted as 0.
                old_start, _, new_start, _ = match.groups()
                output_lines.append(f"@@ -{old_start},0 +{new_start},0 @@")
            else:
                # Unrecognised hunk header format: keep it untouched.
                output_lines.append(line)
        elif inside_hunk:
            # Inside a hunk keep only additions/removals; context lines and
            # "\ No newline at end of file" markers are discarded.
            if line.startswith(("+", "-")):
                output_lines.append(line)
        else:
            output_lines.append(line)

    return "\n".join(output_lines)

In [6]:
import re

def format_git_show_minimal(git_show_string):
    """
    Reduce a git show/diff string to its per-file 'diff --git' lines and hunks.

    For every file section the 'diff --git' line is kept, everything between it
    and the section's first '@@' line is discarded, and all lines from that
    first '@@' line onwards are kept. Text that precedes the first
    'diff --git' line is discarded as well.

    Args:
        git_show_string: The git show diff string, possibly covering several files.

    Returns:
        The reduced diff with surrounding whitespace stripped, or an empty
        string when the input contains no 'diff --git' line.
    """
    sections = []      # finished per-file chunks, already joined
    pending = None     # lines of the file section being collected, if any
    hunk_seen = False  # True once the current section reached its first '@@'

    for text in git_show_string.splitlines():
        if text.startswith("diff --git"):
            # Flush the previous file section before opening a new one.
            if pending is not None:
                sections.append("\n".join(pending))
            pending = [text]
            hunk_seen = False
            continue

        if pending is None:
            # Still in the preamble before any file section.
            continue

        if text.startswith("@@"):
            hunk_seen = True
        if hunk_seen:
            pending.append(text)

    if pending is not None:
        sections.append("\n".join(pending))

    return "\n".join(sections).strip()


In [7]:
# Number of worker processes handed to the Dataset map/filter calls below via num_proc.
num_cpus = 10

In [8]:
# Add a "desc_token" column: the CVE description passed through word_tokenize,
# with the resulting tokens re-joined by single spaces.
ds_cve = ds_cve.map(
    lambda record: {"desc_token": ' '.join(word_tokenize(record["desc"]))},
    batched=False,
    num_proc=num_cpus,
)
# Display the resulting DatasetDict.
ds_cve

DatasetDict({
    train: Dataset({
        features: ['cve', 'published_date', 'desc', 'commit_urls', 'commits', 'desc_token'],
        num_rows: 8970
    })
    test: Dataset({
        features: ['cve', 'published_date', 'desc', 'commit_urls', 'commits', 'desc_token'],
        num_rows: 1365
    })
    validation: Dataset({
        features: ['cve', 'published_date', 'desc', 'commit_urls', 'commits', 'desc_token'],
        num_rows: 1608
    })
})

In [9]:
# Keep only commits whose diff text is at most 45510 characters long.
# NOTE(review): originally labelled "Remove binaries"; the check is purely on diff length,
# presumably as a proxy for binary/oversized diffs — confirm where the threshold comes from.
ds_patches = ds_patches.filter(lambda x: len(x['diff']) <= 45510, batched=False, num_proc=num_cpus)

In [10]:
# Add two space-joined token columns per patch commit:
#   diff_token - the diff with context lines removed (convert_to_unified_0) and the
#                per-file lines before the first hunk removed (format_git_show_minimal),
#                cut to its first 1000 lines, then word-tokenized;
#   msg_token  - the word-tokenized commit message.
ds_patches = ds_patches.map(
    lambda commit: {
        "diff_token": ' '.join(
            word_tokenize(
                ''.join(
                    format_git_show_minimal(convert_to_unified_0(commit["diff"]))
                    .splitlines(keepends=True)[:1000]
                )
            )
        ),
        "msg_token": ' '.join(word_tokenize(commit["commit_message"])),
    },
    batched=False,
    num_proc=num_cpus,
)

Map (num_proc=10):   0%|          | 0/11249 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/1375 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/1381 [00:00<?, ? examples/s]

In [11]:
# Drop the raw text columns; only the tokenized columns are carried forward.
# NOTE(review): re-running this cell on its own presumably fails, since the columns
# are already gone after the first run — re-run the preceding cells first.
ds_patches = ds_patches.remove_columns(["commit_message", "diff"])
# Display the resulting DatasetDict.
ds_patches

DatasetDict({
    train: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 11249
    })
    test: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 1375
    })
    validation: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 1381
    })
})

In [12]:
# Keep only commits whose diff text is at most 45510 characters long (same cut-off as for the patches).
# NOTE(review): originally labelled "Remove binaries"; the check is purely on diff length,
# presumably as a proxy for binary/oversized diffs — confirm where the threshold comes from.
ds_nonpatches = ds_nonpatches.filter(lambda x: len(x['diff']) <= 45510, batched=False, num_proc=num_cpus)

In [13]:
# Add two space-joined token columns per non-patch commit:
#   diff_token - the diff with context lines removed (convert_to_unified_0) and the
#                per-file lines before the first hunk removed (format_git_show_minimal),
#                cut to its first 1000 lines, then word-tokenized;
#   msg_token  - the word-tokenized commit message.
ds_nonpatches = ds_nonpatches.map(
    lambda commit: {
        "diff_token": ' '.join(
            word_tokenize(
                "".join(
                    format_git_show_minimal(convert_to_unified_0(commit["diff"]))
                    .splitlines(keepends=True)[:1000]
                )
            )
        ),
        "msg_token": ' '.join(word_tokenize(commit["commit_message"])),
    },
    batched=False,
    num_proc=num_cpus,
)

In [14]:
# Drop the raw text columns; only the tokenized columns are carried forward.
# NOTE(review): re-running this cell on its own presumably fails, since the columns
# are already gone after the first run — re-run the preceding cells first.
ds_nonpatches = ds_nonpatches.remove_columns(["commit_message", "diff"])
# Display the resulting DatasetDict.
ds_nonpatches

DatasetDict({
    test: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 2131395
    })
    validation: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 2467950
    })
    train: Dataset({
        features: ['commit_id', 'owner', 'repo', 'label', 'diff_token', 'msg_token'],
        num_rows: 3624262
    })
})

In [15]:
# Make sure the output directory for the tokenized parquet files exists (shell escape).
# save_datasetdict_to_parquet also creates its out_dir, so this is only a safeguard.
!mkdir -p "tmp/tokenized"

In [16]:
import os

def save_datasetdict_to_parquet(ds_dict, name: str, out_dir: str):
    """
    Write every split of a dataset mapping to its own parquet file.

    Args:
        ds_dict: Mapping of split name to a dataset object exposing `to_parquet(path)`.
        name: Prefix of the file names; files are called `<name>_<split>.parquet`.
        out_dir: Target directory; created (including parents) when missing.

    Returns:
        None. One confirmation line is printed per file written.
    """
    os.makedirs(out_dir, exist_ok=True)

    for split_name, split_ds in ds_dict.items():
        file_name = f"{name}_{split_name}.parquet"
        out_path = os.path.join(out_dir, file_name)
        split_ds.to_parquet(out_path)
        print(f"✅ Saved: {out_path}")

# Write the CVE and patch datasets, one parquet file per split, into tmp/tokenized.
save_datasetdict_to_parquet(ds_cve, "cve", "tmp/tokenized")
save_datasetdict_to_parquet(ds_patches, "patches", "tmp/tokenized")
# The non-patch commits are not written here (call left disabled); the following cell
# writes them partitioned per owner/repo instead.
#save_datasetdict_to_parquet(ds_nonpatches, "nonpatches", "tmp/tokenized")

Creating parquet from Arrow format:   0%|          | 0/9 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/cve_train.parquet


Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/cve_test.parquet


Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/cve_validation.parquet


Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/patches_train.parquet


Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/patches_test.parquet


Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

✅ Saved: tmp/tokenized/patches_validation.parquet


In [17]:
import polars as pl
from tqdm import tqdm
import os


# Write the non-patch commits of every split as one parquet file per (owner, repo) pair
# under tmp/owner_repo_groups/<split>/.
for split, dataset in tqdm(ds_nonpatches.items(), total=len(ds_nonpatches), desc="Partitioning splits"):
    # The target directory depends only on the split, so create it once per split
    # rather than once per group.
    output_path = f"tmp/owner_repo_groups/{split}"
    os.makedirs(output_path, exist_ok=True)

    # Convert to Polars DataFrame
    df = dataset.to_polars()
    groups = df.group_by(["owner", "repo"])
    # Number of distinct (owner, repo) pairs; only used as the progress-bar total.
    num_groups = groups.len().shape[0]

    for name, data in tqdm(groups, total=num_groups):
        # `name` holds the group key values (owner, repo).
        # NOTE(review): joining with "_" is ambiguous if an owner can itself contain "_"
        # ("a_b"/"c" and "a"/"b_c" would map to the same file and overwrite each other) —
        # confirm owners never contain underscores before relying on these file names.
        safe_name = "_".join(name)
        data.write_parquet(os.path.join(output_path, f"{safe_name}.parquet"))

Partitioning splits:   0%|                                                                                                                                                                                                   | 0/3 [00:00<?, ?it/s]
  0%|                                                                                                                                                                                                                     | 0/1022 [00:00<?, ?it/s][A
  0%|▊                                                                                                                                                                                                            | 4/1022 [00:00<00:30, 33.81it/s][A
  1%|██▍                                                                                                                                                                                                         | 12/1022 [00:00<00:23, 43.10it/s][A
  2%|███▍      

In [19]:
# Completion marker: printed once all preceding cells have finished.
print("DONE")

DONE
