### Imports

In [1]:
import random

import pandas as pd

from finetune_recovery import utils

### Load datasets

In [2]:
# Index CSVs to combine; each one is read from <repo root>/data/lora-index/.
source_files = [
    "weight-diff-20250512-1.7b-5000-f1.00-s42.csv",
    "weight-diff-20250512-4b-5000-f1.00-s42.csv",
    "weight-diff-20250512-8b-5000-f1.00-s42.csv",
    "weight-diff-20250514-gemma-1b-f1.00-s42.csv",
    "weight-diff-20250514-gemma-4b-f1.00-s42.csv",
]

# Map each source filename to its loaded DataFrame, preserving list order.
dfs = {}
for csv_name in source_files:
    csv_path = utils.get_repo_root() / "data" / "lora-index" / csv_name
    dfs[csv_name] = pd.read_csv(csv_path)

### Get intersection of topics

In [3]:
# Distinct topics present in each loaded dataset, in the same order as `dfs`.
topic_sets = [set(df.topic) for df in dfs.values()]

# Topics that appear in every dataset. `set.intersection` returns a new set,
# so none of the per-dataset sets in `topic_sets` is modified (binding
# `topic_sets[0]` and applying `&=` would shrink that first set in place).
shared_topics = set.intersection(*topic_sets)

print(len(shared_topics))

4760


### Create splits

In [4]:
# Draw the test topics with a dedicated, fixed-seed generator. The population
# is sorted first so the draw does not depend on set iteration order.
split_rng = random.Random(42)
candidate_topics = sorted(shared_topics)
test_topics = split_rng.sample(candidate_topics, k=100)

# Every shared topic that was not drawn for test goes to train (kept sorted).
held_out_topics = set(test_topics)
train_topics = sorted(
    topic for topic in shared_topics if topic not in held_out_topics
)

print(len(test_topics), len(train_topics))

100 4660


### Output new files

In [5]:
# `train_topics` and `test_topics` are lists; build set views once so each
# per-row membership check below is a hash lookup instead of a linear scan.
train_topic_set = set(train_topics)
test_topic_set = set(test_topics)

# Loop-invariant output directory (same directory the inputs were read from).
output_dir = utils.get_repo_root() / "data" / "lora-index"

orig_file_name: str
for orig_file_name, df in dfs.items():
    # Label every row by its topic: "train" or "test" when the topic is in the
    # corresponding split, otherwise "extra" (topic is in neither split).
    # This adds/overwrites the "split" column on the frame held in `dfs`.
    df["split"] = [
        "train"
        if topic in train_topic_set
        else ("test" if topic in test_topic_set else "extra")
        for topic in df.topic
    ]

    # Swap the "-f1.00-s42.csv" suffix for "-conf-2025-s42.csv" to name the
    # output file.
    # NOTE(review): `to_csv` is called with its default `index=True`, so the
    # row index is written as an extra unnamed first column. Left as-is to
    # keep the output format unchanged -- confirm whether consumers want it.
    output_name = orig_file_name.removesuffix("-f1.00-s42.csv") + "-conf-2025-s42.csv"
    df.to_csv(output_dir / output_name)