#### Silence Warnings (mostly HF internal warnings)

In [1]:
# Set PYTHONWARNINGS=ignore in the kernel's environment via the %env magic
# (presumably so processes launched later inherit it — TODO confirm).
%env PYTHONWARNINGS=ignore

import warnings
# Install an "ignore" filter so warnings raised in this kernel are not shown.
warnings.filterwarnings("ignore")



In [2]:
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from nemo_curator import DomainClassifier, QualityClassifier
from nemo_curator.datasets import DocumentDataset

In [3]:
# Create a local Dask-CUDA cluster and connect a distributed client to it.
# The RMM options are forwarded to LocalCUDACluster as given.
cluster = LocalCUDACluster(
    rmm_async=True,
    rmm_pool_size="1GB",
)
client = Client(cluster)

# Data File Paths 

In [12]:
# Location of the input JSON files that will be read and classified.
input_file_path = "/home/nfs/syurick/LLM_domain_classifier_inference/4360_results_jsonl_dir/"
# Location the classified results are written to. It is used as an output
# *directory* for JSON files, so the name deliberately carries no `.parquet`
# suffix (the previous name suggested a single Parquet file).
output_file_path = "/raid/vjawa/classifier_output_jsonl_dir"

# Create a Classifier

In [5]:
# Selects which classifier is built below; the accepted values are
# "DomainClassifier" and "QualityClassifier".
classifier_type = "DomainClassifier"

In [6]:
%%time

input_dataset = DocumentDataset.read_json(
    input_file_path, backend="cudf", add_filename=True
)

if classifier_type == "DomainClassifier":
    domain_labels = [
    "Adult",
    "Arts_and_Entertainment",
    "Autos_and_Vehicles",
    "Beauty_and_Fitness",
    "Books_and_Literature",
    "Business_and_Industrial",
    "Computers_and_Electronics",
    "Finance",
    "Food_and_Drink",
    "Games",
    "Health",
    "Hobbies_and_Leisure",
    "Home_and_Garden",
    "Internet_and_Telecom",
    "Jobs_and_Education",
    "Law_and_Government",
    "News",
    "Online_Communities",
    "People_and_Society",
    "Pets_and_Animals",
    "Real_Estate",
    "Science",
    "Sensitive_Subjects",
    "Shopping",
    "Sports",
    "Travel_and_Transportation",
    ]
    model_file_name = "/home/nfs/syurick/LLM_domain_classifier_inference/" + \
                      "GoogleDebertaAgree_v3b_bce_maxlen512_bs64_noRef_best.pth"
    classifier = DomainClassifier(
        model_file_name=model_file_name,
        labels=domain_labels,
        batch_size=1024,
    )
elif classifier_type == "QualityClassifier":
    quality_labels = ["High", "Medium", "Low"]
    model_file_name = "/home/nfs/syurick/LLM_quality_classifier_inference/" + \
                      "quality_rnd3_2014val1070_10ep_2xhigh_1024_fold4_last-001.pth"


    classifier = QualityClassifier(
        model_file_name=model_file_name,
        labels=quality_labels,
        batch_size=1024,
    )
else:
    raise ValueError("Invalid classifier type")

Reading 16 files
CPU times: user 6.68 s, sys: 5.1 s, total: 11.8 s
Wall time: 8.08 s


# Run the actual Classifier

In [13]:
%%time 

result_dataset = classifier(dataset=input_dataset)
result_dataset.df = result_dataset.df.rename(columns={"labels": f"{classifier_type}_prediction"})
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

Starting domain classifier inference


GPU: 0, Part: 5: 100%|██████████| 938/938 [00:05<00:00, 164.89it/s]]
GPU: 0, Part: 6: 100%|██████████| 938/938 [00:05<00:00, 161.22it/s]
GPU: 0, Part: 9: 100%|██████████| 937/937 [00:05<00:00, 157.10it/s]
GPU: 0, Part: 3: 100%|██████████| 938/938 [00:06<00:00, 153.23it/s]
GPU: 0, Part: 8: 100%|██████████| 937/937 [00:05<00:00, 156.58it/s]
GPU: 0, Part: 4: 100%|██████████| 938/938 [00:06<00:00, 151.96it/s]
GPU: 0, Part: 0: 100%|██████████| 938/938 [00:05<00:00, 163.98it/s]
GPU: 0, Part: 1: 100%|██████████| 938/938 [00:05<00:00, 162.96it/s]
GPU: 0, Part: 7: 100%|██████████| 937/937 [00:06<00:00, 154.01it/s]
GPU: 0, Part: 2: 100%|██████████| 938/938 [00:05<00:00, 162.85it/s]
GPU: 0, Part: 12: 100%|██████████| 937/937 [00:05<00:00, 160.30it/s]
GPU: 0, Part: 15: 100%|██████████| 937/937 [00:05<00:00, 161.63it/s]
GPU: 0, Part: 13: 100%|██████████| 937/937 [00:05<00:00, 157.16it/s]
GPU: 0, Part: 14: 100%|██████████| 937/937 [00:05<00:00, 157.38it/s]
GPU: 0, Part: 11: 100%|██████████| 937/937 

Writing to disk complete for 16 partitions
CPU times: user 1.38 s, sys: 1.34 s, total: 2.72 s
Wall time: 8.72 s


GPU: 0, Part: 10: 100%|██████████| 937/937 [00:06<00:00, 155.39it/s]


#### Verify the file was written correctly

In [14]:
# Read the written results back with the same reader settings used for the input.
output_dataset = DocumentDataset.read_json(
    output_file_path,
    backend="cudf",
    add_filename=True,
)
# Display the first two rows as a quick visual check.
output_dataset.df.head(2)

Reading 16 files


Unnamed: 0,DomainClassifier_prediction,adlr_id,filename,id,pred,source_id,split_id,text,url
0,Online_Communities,cc-2022-40-0431053204,00.jsonl,a8083fe4-525d-4888-8513-b91f43bd8ee1,Online_Communities,crawl-data-CC-MAIN-2022-40-segments-1664030336...,lambada-0003225258-0000,Having been a community leader—and member—for ...,https://lisalarter.com/7-tips-for-building-ste...
1,Finance,cc-2022-40-0510168267,00.jsonl,559febdc-cb7f-4217-897a-c8dac325123b,Finance,crawl-data-CC-MAIN-2022-40-segments-1664030337...,lambada-0003918122-0000,Zelle is a way of sending money to almost anyo...,https://oregonmassageandwellnessclinic.com/app...


##### Clean up the output file

In [15]:
!rm -rf $output_file_path