In [1]:
import os
import sys
import json
import random
import pathlib
from collections import Counter

sys.path.append("/home/arnaik/OracleProject")
random.seed(42) # seed for deterministic behavior

from src.datautils import MetaLinterDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!nvidia-smi

Sun Feb  9 19:00:57 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:81:00.0 Off |                    0 |
| N/A   29C    P0             52W /  255W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
cd "/home/arnaik/OracleProject"

/home/arnaik/OracleProject


In [5]:
dataset = MetaLinterDataset("ruff", "./data/ruff_results/")

57it [00:53,  2.90it/s]

In [10]:
def balance_neutral_and_flagged_files(
        data: list,
        neutral_file_to_flagged_file_ratio: float=1.0,
    ):
    # iterate over data and create a list of neutral files and flagged files.
    neutral_files = []
    flagged_files = []
    for rec in data:
        response = rec['messages'][1]['content']
        if response.strip() == "NO VIOLATIONS FOUND": neutral_files.append(rec)
        else: flagged_files.append(rec)
    
    # balance the amount of neutral and modified files.
    num_neutral_files = min(int(neutral_file_to_flagged_file_ratio*len(flagged_files)), len(neutral_files))
    neutral_files = random.sample(neutral_files, k=num_neutral_files)
    data = neutral_files + flagged_files 
    data = random.sample(data, k=len(data)) # shuffle the data around.

    return data
    

In [11]:
train_idiom_mix = [
    ["F405", "F501", "F502", "F601", "F621"],
    ["E402", "E701", "E721", "E741", "E743"],
    ["N801", "N802", "N803", "N804", "N805"],
    ["N806", "N807", "N811", "N812", "N813"],
    ["UP001", "UP003", "UP004", "UP005", "UP006"],
    ["UP007", "UP008", "UP009", "UP010", "UP011"],
    ["UP044", "UP045", "UP046", "UP047", "UP040"],
    ["ERA001", "C901", "I001", "I002", "BLE001"],
    ["B002", "B003", "B004", "B005", "B006"],
    ["B007", "B008", "B009", "B010", "B012"],
]
test_idiom_mix = [
    ["F406", "F403", "F503", "F602", "F622"],
    ["E401", "E702", "E722", "E731", "E742"],
    ["ERA001", "C901", "I001", "I002", "BLE001"],
    ["ANN001", "ANN002", "ANN003", "ANN201", "ANN202"],
    ["ASYNC100", "ASYNC105", "ASYNC109", "ASYNC110", "ASYNC115"],
    ["ASYNC116", "ASYNC210", "ASYNC220", "ASYNC221", "ASYNC222"],
    ["ASYNC230", "ASYNC251", "ANN204", "ANN205", "ANN206"],
    ["S102", "S103", "S104", "S105", "S106"],
    ["S107", "S108", "S110", "S112", "S113"],
    ["S201", "S202", "S301", "S302", "S303"],
]

In [8]:
all_train_data = []
all_test_data = []

for idiom_mix in train_idiom_mix:
    mix_data = dataset.generate_data_mix(idiom_mix, max_code_lines=200)
    print(idiom_mix, len([rec for rec in mix_data if rec['messages'][1]['content'] != 'NO VIOLATIONS FOUND']))
    all_train_data.extend(mix_data)
    
for idiom_mix in test_idiom_mix:
    mix_data = dataset.generate_data_mix(idiom_mix, max_code_lines=200)
    print(idiom_mix, len([rec for rec in mix_data if rec['messages'][1]['content'] != 'NO VIOLATIONS FOUND']))
    all_test_data.extend(mix_data)
    
print(len(all_train_data))
print(len(all_test_data))

['F405', 'F501', 'F502', 'F601', 'F621'] 6473
['E402', 'E701', 'E721', 'E741', 'E743'] 9479
['N801', 'N802', 'N803', 'N804', 'N805'] 20563
['N806', 'N807', 'N811', 'N812', 'N813'] 12979
['UP001', 'UP003', 'UP004', 'UP005', 'UP006'] 8062
['UP007', 'UP008', 'UP009', 'UP010', 'UP011'] 25068
['UP044', 'UP045', 'UP046', 'UP047', 'UP040'] 4
['ERA001', 'C901', 'I001', 'I002', 'BLE001'] 89505
['B002', 'B003', 'B004', 'B005', 'B006'] 1026
['B007', 'B008', 'B009', 'B010', 'B012'] 6207
['F406', 'F403', 'F503', 'F602', 'F622'] 7530
['E401', 'E702', 'E722', 'E731', 'E742'] 7230
['ERA001', 'C901', 'I001', 'I002', 'BLE001'] 89505
['ANN001', 'ANN002', 'ANN003', 'ANN201', 'ANN202'] 13491
['ASYNC100', 'ASYNC105', 'ASYNC109', 'ASYNC110', 'ASYNC115'] 14
['ASYNC116', 'ASYNC210', 'ASYNC220', 'ASYNC221', 'ASYNC222'] 34
['ASYNC230', 'ASYNC251', 'ANN204', 'ANN205', 'ANN206'] 23752
['S102', 'S103', 'S104', 'S105', 'S106'] 4515
['S107', 'S108', 'S110', 'S112', 'S113'] 3721
['S201', 'S202', 'S301', 'S302', 'S303'

In [9]:
!ls

README.md	    experiments			 plots	    vllm_env.yaml
access_tokens.json  filter_codereviewer_data.py  ruff.toml
alignment-handbook  handbook.yml		 scripts
data		    peft_requirements.txt	 src


In [None]:
# mix_data = dataset.generate_data_mix(['ERA001'])
# print(len(mix_data))
# mix_data[2]['messages'][1]['content']
# print(len([rec for rec in mix_data if rec['messages'][1]['content'] != 'NO VIOLATIONS FOUND']))

In [12]:
random.seed(42)
from collections import defaultdict

def balance_neutral_and_flagged_files(
        data: list,
        neutral_file_to_flagged_file_ratio: float=1.0,
    ):
    # iterate over data and create a list of neutral files and flagged files.
    neutral_files = []
    flagged_files = []
    for rec in data:
        response = rec['messages'][1]['content']
        if response.strip() == "NO VIOLATIONS FOUND": neutral_files.append(rec)
        else: flagged_files.append(rec)
    
    # balance the amount of neutral and modified files.
    num_neutral_files = min(int(neutral_file_to_flagged_file_ratio*len(flagged_files)), len(neutral_files))
    neutral_files = random.sample(neutral_files, k=num_neutral_files)
    data = neutral_files + flagged_files 
    data = random.sample(data, k=len(data)) # shuffle the data around.

    return data

def impose_idiom_mix_ceilings(data, ceiling: int=5000):
    """reduce size of data stratified by the idiom mix and violation or no violation category"""
    cateogry_to_data = defaultdict(lambda: [])
    for rec in data:
        violation_present = "yes" if rec['messages'][1]['content'] == "NO VIOLATIONS FOUND" else "no"
        category_to_data[rec['source']+"-"+vioaltion_present].append(rec)
    category_to_data = dict(category_to_data)
    final_data = []
    for category, data_subset in category_to_data.items():
        selected_data = random.sample(data_subset, k=min(len(data_subset), ceiling))
        final_data.extend(selected_data)
        print(category, len(selected_data))
    print(len(final_data))

    return final_data

def split_train_and_test_data(train_data, test_data):
    train_ids = set()
    test_ids = set()
    id_to_data = {}

    for rec in train_data:
        train_ids.add(rec['id'])
        id_to_data[rec['id']] = rec
    for rec in test_data:
        test_ids.add(rec['id'])
        id_to_data[rec['id']] = rec
    
    common_ids = train_ids.intersection(test_ids)
    train_only_ids = train_ids.difference(test_ids)
    test_only_ids = test_ids.difference(train_ids)

    train_only_data = impose_idiom_mix_ceilings([id_to_data[ID] for ID in train_only_ids], ceiling=5000)
    test_only_data = impose_idiom_mix_ceilings([id_to_data[ID] for ID in test_only_ids], ceiling=500)

    print(len(common_ids))
    print(len(train_only_data))
    print(len(test_only_data))

results = split_train_and_test_data(all_train_data, all_test_data)

141670
1275030
1275030


In [None]:
neutral_file_to_flagged_file_ratio = 1.0
train_data = balance_neutral_and_flagged_files(all_train_data, neutral_file_to_flagged_file_ratio)
test_data = balance_neutral_and_flagged_files(all_test_data, neutral_file_to_flagged_file_ratio)

with open("./data/ruff_meta_linting/train_v2.json", "w") as f:
    print(f"train data len: {len(train_data)}")
    print(dict(Counter([rec['source'] for rec in train_data if rec['messages'][1]['content'] != 'NO VIOLATIONS FOUND']).most_common()))
    json.dump(train_data, f, indent=4)
with open("./data/ruff_meta_linting/test_v2.json", "w") as f:
    print(f"test data len: {len(test_data)}")
    print(dict(Counter([rec['source'] for rec in test_data if rec['messages'][1]['content'] != 'NO VIOLATIONS FOUND']).most_common()))
    json.dump(test_data, f, indent=4)