In [None]:
from pathlib import Path
from tqdm.auto import tqdm
import shutil
import pandas as pd
import ray
from toolkit_run.ray.server import LabRayToolkitServer
import json
from collections import defaultdict

In [None]:
# Create the project's Ray toolkit server handle and show its dashboard URL as the
# cell output. NOTE(review): presumably this starts or connects to the lab Ray
# cluster — confirm against LabRayToolkitServer.
server = LabRayToolkitServer()
server.dashboard_url

In [None]:
# Ask the toolkit server to scale the cluster to 60.
# NOTE(review): the unit (presumably worker nodes) is not visible here — confirm.
server.scale_cluster(60)

In [None]:
# NOTE(review): this cell sits before every Ray task submitted further down, so a
# top-to-bottom "Run All" would call shutdown() before any work is scheduled.
# Presumably it is meant to be run manually once the whole notebook is done — confirm.
server.shutdown()

## Create unique near-dedup file list per repository and commit

In [None]:
# Load the repo -> bucket assignment (JSON object keyed by repository name).
repo_to_bucket_file = Path('/data/hf_repos/the-stak-repo-level/meta_data/java/repo_to_bucket.json')
repo_2_bucket = json.loads(repo_to_bucket_file.read_text())

In [None]:
# Collect every parquet shard found directly inside the near-dedup Java directory.
near_dedup_java_dir = Path('/data/hf_repos/the_stack_v1_1_near_dedup_parquet/data/java/')
paths = [shard for shard in near_dedup_java_dir.glob('*.parquet')]

In [None]:
@ray.remote(scheduling_strategy="SPREAD")
def get_java_repos_paths(bucket_filename):
    """Return repo name, head commit and file path for every row of one shard.

    Only the three required columns are requested from the parquet reader, so
    the shard's other columns are never loaded into the worker's memory.

    Args:
        bucket_filename: Path of one parquet shard of the near-dedup dataset.

    Returns:
        pandas.DataFrame with the columns ``max_stars_repo_name``,
        ``max_stars_repo_head_hexsha`` and ``max_stars_repo_path``, in that order.
    """
    wanted_columns = [
        'max_stars_repo_name',
        'max_stars_repo_head_hexsha',
        'max_stars_repo_path',
    ]
    df = pd.read_parquet(bucket_filename, columns=wanted_columns)
    # Re-select so the column order is pinned regardless of the parquet engine.
    return df[wanted_columns]

In [None]:
# Submit one remote read per shard; `res` holds the resulting Ray object refs.
res = [get_java_repos_paths.remote(shard_path) for shard_path in paths]

In [None]:
# Block until every submitted task has finished and replace the object refs
# with the values the tasks returned.
res = ray.get(res)

In [None]:
# Stack the per-shard frames into one frame (the per-shard row indices are kept as-is).
res = pd.concat(res)

In [None]:
# Show the combined frame as the cell output.
res

In [None]:
# Group the rows by (repository name, head commit hash).
groupes = res.groupby(['max_stars_repo_name', 'max_stars_repo_head_hexsha'])

In [None]:
# Nested mapping: repository name -> head commit hash -> list of file paths.
near_dedup_unique_files_by_repo_commit = defaultdict(lambda: defaultdict(list))
for (repo_name, commit_sha), group in tqdm(groupes):
    repo_files = group['max_stars_repo_path'].tolist()
    near_dedup_unique_files_by_repo_commit[repo_name][commit_sha] = repo_files

In [None]:
# Invert repo -> bucket into bucket -> [repos assigned to that bucket].
bucket_to_repo = defaultdict(list)
for repo_name, bucket_id in repo_2_bucket.items():
    bucket_to_repo[bucket_id].append(repo_name)

In [None]:
# For every bucket, write a JSON file mapping repo -> commit -> file list for the
# repos of that bucket that have at least one commit in the near-dedup data.
for bucket, repos in tqdm(bucket_to_repo.items()):
    data = {}
    for repo in repos:
        commits = near_dedup_unique_files_by_repo_commit[repo]
        if commits:
            data[repo] = dict(commits)
    bucket_dir = Path(f'/data/hf_repos/the-stak-repo-level/meta_data/java/{bucket}')
    bucket_dir.mkdir(parents=True, exist_ok=True)
    out_file = bucket_dir / 'near_dedup_unique_files_by_repo_commit.json'
    with out_file.open('wt') as f:
        json.dump(data, f)

## Select subset of 1K repos

In [None]:
# Count rows per (repo, head commit); after count() the 'max_stars_repo_path'
# column holds the number of non-null file paths in each group.
res1 = res.groupby(['max_stars_repo_name', 'max_stars_repo_head_hexsha']).count()

In [None]:
# Keep only (repo, commit) pairs with at least 20 files, then turn the group keys
# back into ordinary columns.
res1  = res1[res1['max_stars_repo_path'] >= 20].reset_index()

In [None]:
# Show the filtered frame as the cell output.
res1

In [None]:
# One entry per distinct (repo, commit) pair; `.values` yields numpy rows, so each
# list element is a 2-element array [repo name, head commit hash].
repos_hashes = list(res1[['max_stars_repo_name', 'max_stars_repo_head_hexsha']].drop_duplicates().values)

In [None]:
# Draw 1000 (repo, commit) pairs without replacement. The seed is fixed right
# before sampling so the same selection is produced for the same `repos_hashes`.
# random.sample raises ValueError if fewer than 1000 pairs are available.
import random
random.seed(42)
repos_commit_hashes_1K = random.sample(repos_hashes, k=1000)

In [None]:
# Turn every sampled numpy row into a plain [repo, commit] list (in place) so the
# selection can be serialised, then persist it as JSON.
repos_commit_hashes_1K[:] = [list(pair) for pair in repos_commit_hashes_1K]
with open('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_repos_commit_hashes.json', 'wt') as out_file:
    json.dump(repos_commit_hashes_1K, out_file)

In [None]:
# Restrict the repo -> commit -> files mapping to the sampled pairs and save it.
near_dedup_unique_files_by_repo_commit_1K = defaultdict(lambda: defaultdict(list))
for pair in repos_commit_hashes_1K:
    repo_name, commit_sha = pair[0], pair[1]
    selected_files = near_dedup_unique_files_by_repo_commit[repo_name][commit_sha]
    near_dedup_unique_files_by_repo_commit_1K[repo_name][commit_sha] = selected_files
with open('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_near_dedup_unique_files_by_repo_commit.json', 'wt') as out_file:
    json.dump(near_dedup_unique_files_by_repo_commit_1K, out_file)

In [None]:
# Bucket assignment for just the sampled repositories, saved next to the other 1K files.
repo_2_bucket_1K = {pair[0]: repo_2_bucket[pair[0]] for pair in repos_commit_hashes_1K}
with open('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_repo_2_bucket.json', 'wt') as out_file:
    json.dump(repo_2_bucket_1K, out_file)

## Build file list for selected repos

In [None]:
# Root directories of the two local dataset checkouts used below:
# the metadata dataset and the v1.1 content dataset.
the_stack_meta_path = Path('/data/hf_repos/the-stack-metadata')
the_stack_path = Path('/data/hf_repos/the-stack-v1.1')

In [None]:
# Reload the sampled repo -> bucket mapping from disk so this section can start
# from the saved JSON file.
repo_2_bucket_1K_file = Path('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_repo_2_bucket.json')
repo_2_bucket_1K = json.loads(repo_2_bucket_1K_file.read_text())

In [None]:
# Reload the sampled [repo, commit] pairs from the saved JSON file.
repos_commit_hashes_1K_file = Path('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_repos_commit_hashes.json')
repos_commit_hashes_1K = json.loads(repos_commit_hashes_1K_file.read_text())

In [None]:
# (repo name, head commit, destination bucket) triple for every sampled pair.
repos_commit_hashes_dst_buckets_1K = [
    (pair[0], pair[1], repo_2_bucket_1K[pair[0]])
    for pair in repos_commit_hashes_1K
]

In [None]:
# Same triples as a DataFrame; the column names 'name' and 'head_hexsha' are the
# join keys used when merging against the ri.parquet metadata.
df_repos_commit_hashes_dst_buckets_1K = pd.DataFrame(
    data=repos_commit_hashes_dst_buckets_1K, columns=['name', 'head_hexsha', 'dst_bucket']
)

In [None]:
# Show the frame of sampled repos as the cell output.
df_repos_commit_hashes_dst_buckets_1K

In [None]:
# For every metadata bucket, look up the sampled (repo, commit) pairs in that
# bucket's ri.parquet and remember the matching rows together with the bucket name.
# Buckets without any match are skipped.
files_info = []
ri_files = list((the_stack_meta_path / 'data').glob('*/ri.parquet'))
for fn in tqdm(ri_files):
    repo_info = pd.read_parquet(fn)
    joined = df_repos_commit_hashes_dst_buckets_1K.merge(
        repo_info, how='left', on=['name', 'head_hexsha']
    )
    # A null 'id' means the pair was not found in this bucket.
    matched = joined.loc[joined['id'].notna(), ['name', 'head_hexsha', 'dst_bucket', 'id']]
    if len(matched) > 0:
        files_info.append((fn.parent.name, matched))


In [None]:
@ray.remote(scheduling_strategy="SPREAD")
def get_files(src_root, info):
    """List the live, non-empty Java files of the selected repos in one metadata bucket.

    Args:
        src_root: Directory that contains one sub-directory per metadata bucket.
        info: ``(bucket_name, repos)`` pair, where ``repos`` is a DataFrame with
            the columns ``name``, ``head_hexsha``, ``dst_bucket`` and ``id``.

    Returns:
        pandas.DataFrame with the columns ``hexsha``, ``path``, ``name``,
        ``head_hexsha`` and ``dst_bucket``: one row per file of
        ``<bucket>/fi.parquet`` whose ``ri_id`` matches a repo ``id``.
    """
    src_root = Path(src_root)
    bucket = info[0]
    res = info[1]
    df_fi = pd.read_parquet(src_root / bucket / 'fi.parquet')
    df_fi = df_fi[(df_fi['size'] > 0) & (df_fi['is_deleted'] == False) & (df_fi['lang_ex'] == 'Java')]

    merged = df_fi.merge(res, left_on='ri_id', right_on='id', how='left')
    output_columns = ['hexsha', 'path', 'name', 'head_hexsha', 'dst_bucket']
    # Select the output columns BEFORE dropping nulls: a bare dropna() on the merged
    # frame also discards matched files that merely have a null in some unrelated
    # fi.parquet column. Files without a matching repo still go (their 'name' is null).
    return merged[output_columns].dropna()

In [None]:
# Submit one file-listing task per metadata bucket that contains sampled repos.
res = [get_files.remote(the_stack_meta_path/'data', bucket_info) for bucket_info in files_info]

In [None]:
# Poll every 10 seconds until no task is pending; each round prints the number
# of tasks that are still NOT finished out of the total.
while True:
    finished, pending = ray.wait(res, num_returns=len(res), fetch_local=False, timeout=10)
    print(len(pending), 'of', len(res))
    if not pending:
        break

In [None]:
# Replace the object refs with the per-bucket frames the tasks returned.
res = ray.get(res)

In [None]:
# Combine the per-bucket file lists into a single frame.
res = pd.concat(res)

In [None]:
# Persist the file list; the recreation tasks read it back from this path.
res.to_parquet('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_file_list.parquet')

In [None]:
# Show the file list as the cell output.
res

## Recreate those 1K repos from the file list

In [None]:
@ray.remote(scheduling_strategy="SPREAD")
def recreate_from_the_stack_bucket(bucket_fn, dst_root, file_list_fn):
    """Write the listed files whose blobs are stored in one content shard.

    Every row of the file list whose ``hexsha`` occurs in the shard is written to
    ``dst_root/<dst_bucket>/<name>/<head_hexsha>/<path>`` with the shard's
    ``content`` for that blob. A blob referenced by several rows is written once
    per row.

    Args:
        bucket_fn: Parquet shard providing the ``hexsha`` and ``content`` columns.
        dst_root: Root directory of the recreated repositories.
        file_list_fn: Parquet file list with the columns ``hexsha``, ``path``,
            ``name``, ``head_hexsha`` and ``dst_bucket``.

    Returns:
        None. The function works through its file-system side effects.
    """
    dst_root = Path(dst_root)
    file_list = pd.read_parquet(file_list_fn)[
        ['hexsha', 'dst_bucket', 'name', 'head_hexsha', 'path']
    ]
    # Only these two shard columns are used, so do not load the others.
    data = pd.read_parquet(bucket_fn, columns=['hexsha', 'content'])
    data = data[data['hexsha'].isin(set(file_list['hexsha']))]
    # A single join pairs every blob with all of its destinations, instead of
    # rescanning the whole file list once per blob row.
    targets = data.merge(file_list, on='hexsha', how='inner')
    for target in targets.itertuples(index=False):
        dst_path = (
            dst_root / str(int(target.dst_bucket)) / target.name /
            target.head_hexsha / target.path
        )
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding: the default depends on the worker's locale and can
        # fail or differ between nodes for non-ASCII source files.
        dst_path.write_text(target.content, encoding='utf-8')

In [None]:
# All Java parquet shards of the v1.1 content dataset (replaces the earlier `paths`).
paths = list((the_stack_path / 'data' / 'java').glob('*.parquet'))

In [None]:
# Submit one recreation task per content shard; all tasks share the same
# destination root and the same saved file list.
res = [
    recreate_from_the_stack_bucket.remote(
        shard_path,
        '/data/hf_repos/the-stak-repo-level/data/java',
        '/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_file_list.parquet'
    )
    for shard_path in paths
]

In [None]:
# Poll every 10 seconds until no task is pending; each round prints the number
# of tasks that are still NOT finished out of the total.
while True:
    finished, pending = ray.wait(res, num_returns=len(res), fetch_local=False, timeout=10)
    print(len(pending), 'of', len(res))
    if not pending:
        break

In [None]:
# Fetch the task results; this also re-raises any exception a task hit.
res = ray.get(res)

## Test results

In [None]:
# Validate against the file list that was actually written and used for the
# recreation ('1K_20plus_file_list.parquet'); the previous name
# '1K_file_list.parquet' is not produced anywhere in this notebook.
file_list = pd.read_parquet('/data/hf_repos/the-stak-repo-level/meta_data/java/1K_20plus_file_list.parquet')


In [None]:
# Collect every expected destination path that does not exist as a regular file.
dst_root = Path('/data/hf_repos/the-stak-repo-level/data/java')
missing = []
for _, entry in tqdm(file_list.iterrows(), total=len(file_list)):
    expected_path = dst_root.joinpath(
        str(int(entry['dst_bucket'])),
        entry['name'],
        entry['head_hexsha'],
        entry['path'],
    )
    if expected_path.is_file():
        continue
    missing.append(expected_path)

In [None]:
# Number of listed files that were not recreated; 0 means the check passed.
len(missing)