In [None]:
import pandas as pd
import os
import datasets

import pyarrow as pa
import pyarrow.parquet as pq

from collections import defaultdict

from tqdm import tqdm

In [None]:
pmid_to_cc = pd.read_csv('../assets/pmid/PMC-ids.csv')
pmid_to_cc2 = pd.read_csv('../assets/pmid/oa_non_comm_use_pdf.csv')
pmid_to_cc3 = pd.read_csv('../assets/pmid/oa_file_list.csv')
pmid_to_cc4 = pd.read_csv('../assets/pmid/oa_comm_use_file_list.csv')

In [None]:
pmid_to_cc.columns

In [None]:
pmid_to_cc2.columns

In [None]:
pmid_to_cc3.columns

In [None]:
pmid_to_cc4.columns

In [None]:
# Combine all mappings into a single DataFrame
pmid_to_cc_combined = pd.concat([pmid_to_cc2, pmid_to_cc3, pmid_to_cc4], ignore_index=True)
pmid_to_cc_combined.drop_duplicates(subset=['Accession ID'], inplace=True)

In [None]:
pmid_to_cc_combined.dropna(subset=['PMID'], inplace=True)

In [None]:
pmid_to_cc_combined.groupby('License').size() / len(pmid_to_cc_combined)

In [None]:
pmid_to_cc_combined.PMID = pmid_to_cc_combined.PMID.astype(str).str.strip()

In [None]:
# hf dataset: UMCU/PubmedAbstracts_Dutch_v1
# fiedlds: pmid, year, txt_line, text

""" 
I want to load the hf dataset streaming and add the cc license info to each record based on the pmid.
If there is no cc license info for a pmid, I want to add 'Copyright reserved' as license.
"""

def add_cc_license(example):
    pmid = example['pmid']
    cc_info = pmid_to_cc_combined[pmid_to_cc_combined['PMID'] == str(pmid).strip()]
    if not cc_info.empty:
        example['License'] = cc_info['License'].values[0]
    else:
        example['License'] = 'Copyright reserved'
    return example

dataset = datasets.load_dataset('UMCU/PubmedAbstracts_Dutch_v1', split='train', streaming=True)
dataset_with_license = dataset.map(add_cc_license)
for record in dataset_with_license.take(5):
    print(record)

In [None]:
# I want to streaming save dataset_with_license to parquet
output_dir = '../assets/pmid/pmid_with_cc_license_parquet'
if not os.path.exists(output_dir):
    os.makedirs(output_dir) 

# Batch size
batch_size = 256
writer = None
license_counts = defaultdict(int)

for batch in tqdm(dataset_with_license.iter(batch_size=batch_size), desc="Writing to Parquet"):
    # Convert batch to Arrow Table
    table = pa.Table.from_pydict(batch)
    
    batch_license_counts = table.to_pandas()['License'].value_counts().to_dict()
    for license, count in batch_license_counts.items():
        license_counts[license] += count

    if writer is None:
        writer = pq.ParquetWriter(os.path.join(output_dir, "dataset_with_license.parquet"), table.schema)

    writer.write_table(table)

writer.close()


In [16]:
license_counts

defaultdict(int, {'Copyright reserved': 7168})