In [None]:
import os

import boto3
import pandas as pd
from datasets import load_dataset
from smart_open import open

In [None]:
def download_content(s3, blob_id, src_encoding):
    """Fetch one blob from the Software Heritage bucket and decode it to text.

    Args:
        s3: S3 client, forwarded to ``open`` through ``transport_params``.
        blob_id: Identifier appended to the ``s3://softwareheritage/content/`` prefix.
        src_encoding: Codec name used to decode the bytes that were read.

    Returns:
        A single-key dict of the form ``{"content": <decoded text>}``.
    """
    transport = {"client": s3}
    # The object is opened in binary mode with ".gz" compression requested.
    with open(
        f"s3://softwareheritage/content/{blob_id}",
        "rb",
        compression=".gz",
        transport_params=transport,
    ) as blob:
        raw_bytes = blob.read()
    return {"content": raw_bytes.decode(src_encoding)}

In [None]:
# Build the S3 client from credentials in the environment. The access key and
# secret are mandatory (KeyError if missing). The session token only exists for
# temporary credentials, so it is read with .get(): boto3 accepts None there and
# long-lived key pairs keep working.
session = boto3.Session(
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    aws_session_token=os.environ.get("AWS_SESSION_TOKEN"),
)
s3 = session.client("s3")

# streaming=True iterates the split lazily instead of materialising it first.
ds = load_dataset("bigcode/the-stack-v2-dedup", "Go", split="train", streaming=True)

num_entries = 1000
results = []
# zip() consults range() first, so iteration stops after `num_entries` rows
# without pulling an extra row from the stream.
for _, row in zip(range(num_entries), ds):
    entry = download_content(s3, row["blob_id"], row["src_encoding"])
    results.append(entry)

# Convert results to DataFrame (one "content" column, one row per downloaded blob)
df = pd.DataFrame(results)

In [None]:
# Preview the downloaded snippets as the cell output.
# Uncomment the next line to print a single snippet in full instead.
# print(df["content"][1])
df

In [None]:
import json
import os

from transformers import AutoTokenizer

from src.go_ast_tokenizer.dataset_builder import GoStyleChecker

# Set before the tokenizer is created so the tokenizers library sees it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Snippets longer than this many tokens are not style-checked.
TOKENS_LENGTH_CUTOFF = 2000

MODEL_ID = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

checker = GoStyleChecker()

num_skipped = 0  # snippets over the token cutoff
num_failed = 0  # snippets for which the style checker raised
labels = []

for snippet in df["content"]:
    num_tokens = len(tokenizer.encode(snippet))
    if num_tokens > TOKENS_LENGTH_CUTOFF:
        # Keep `labels` aligned with the DataFrame rows: over-long snippets
        # get an empty label list rather than being dropped.
        num_skipped += 1
        labels.append([])
        continue

    try:
        warnings = checker.check_style(snippet)
    except Exception:
        # Best effort: a snippet the checker cannot handle still gets an empty
        # label list, but the failure is counted so it is visible in the
        # summary instead of silently looking like a clean snippet.
        num_failed += 1
        warnings = []

    labels.append(warnings)

df["labels"] = labels

# NOTE: `open` here is the one imported from smart_open at the top of the
# notebook, not the builtin; it is used on a plain local path.
with open("output.jsonl", "w", encoding="utf-8") as outfile:
    # One JSON object per line (JSON Lines).
    for record in df.to_dict(orient="records"):
        json.dump(record, outfile)
        outfile.write("\n")

print(f"Number of skipped snippets: {num_skipped}")
print(f"Number of snippets whose style check failed: {num_failed}")