# In [None]:
#!/usr/bin/env python
# coding: utf-8

import pandas as pd
import time
import re
import os
from datasets import load_dataset
from transformers import pipeline

# Show every row when a DataFrame is displayed (no truncation).
pd.set_option('display.max_rows', None)

# Both plain and TLS traffic go through the same proxy endpoint
# (proxy.alcf.anl.gov:3128); the environment must be populated before any
# download from the Hugging Face Hub is attempted below.
http_proxy_url = https_proxy_url = "http://proxy.alcf.anl.gov:3128"
os.environ.update(
    {
        'http_proxy': http_proxy_url,
        'https_proxy': https_proxy_url,
        # Expose the first four GPUs to this process.
        # NOTE(review): the pipeline below is created without a `device`
        # argument -- confirm the model actually ends up on a GPU.
        'CUDA_VISIBLE_DEVICES': "0,1,2,3",
    }
)

# Corpus to run inference on; only the 'train' split is used.
dataset = load_dataset("Flamenco43/cleaned_html_nano_papers_31k")
sample = dataset['train']

# Extractive question-answering pipeline used for every question asked below.
pipe = pipeline("question-answering", model="Flamenco43/Synthetic_NanoQA")

def print_span(text, start_match, end_match):
    start_match_escaped = re.escape(start_match)
    end_match_escaped = re.escape(end_match)
    pattern = f"{start_match_escaped}(.*?){end_match_escaped}"
    matches = re.findall(pattern, text, re.DOTALL)
    for match in matches:
        return f"{start_match}{match}{end_match}"

complete_start_time = time.time()

# Inference begins
#
# For every row of the dataset the paragraph text is cut into 512-element
# chunks.  Each chunk is asked a chain of three questions (each question is
# built from the previous answer), and numeric quantities found in the chunk
# are paired with the final answer via print_span().  One JSON file is written
# per dataset row.
batch_size = 32
total_rows = len(sample)
num_batches = (total_rows + batch_size - 1) // batch_size  # ceiling division

# A number (optionally with a decimal part) immediately followed by one of the
# listed unit suffixes, case-insensitively.  Compiled once: it never changes.
pattern = re.compile(r'\d+(?:\.\d+)?(?:%|wt %|nm|s(?!t)| nm| c|c)', re.IGNORECASE)

# The first question is the same for every chunk.
q1 = "what are the nanoparticles?"

# DataFrame.to_json does not create missing directories.
output_dir = 'QA_output'
os.makedirs(output_dir, exist_ok=True)

for batch_num in range(num_batches):
    start_idx = batch_num * batch_size
    end_idx = min(start_idx + batch_size, total_rows)
    # Slicing a Hugging Face Dataset yields a dict mapping each column name to
    # a list of values, so len(batch_samples) would be the number of COLUMNS.
    # The row count of the batch is therefore derived from the slice bounds.
    batch_samples = sample[start_idx:end_idx]

    for n in range(end_idx - start_idx):
        paragraph = batch_samples['paragraphs'][n]
        chunks = [paragraph[i:i + 512] for i in range(0, len(paragraph), 512)]

        start_time = time.time()
        squad_q1_outputs = []
        squad_q2_outputs = []
        squad_q3_outputs = []
        list_of_regex_lists = []

        for chunk in chunks:
            # Question chain: entity -> its properties -> entity having
            # those properties (used as a cross-check of the first answer).
            q1answ = pipe(q1, chunk)
            q2 = "what are the properties of " + q1answ['answer'] + "?"
            q2answ = pipe(q2, chunk)
            q3 = "what are the nanoparticles with properties " + q2answ['answer'] + "?"
            q3answ = pipe(q3, chunk)
            squad_q1_outputs.append(q1answ['answer'])
            squad_q2_outputs.append(q2answ['answer'])
            squad_q3_outputs.append(q3answ['answer'])

            matches = pattern.findall(chunk)

            # Spans are searched in the SAME chunk that produced both the
            # answer and the quantity matches: first "answer ... quantity",
            # then "quantity ... answer".  print_span() returns None when no
            # span exists, so the list may contain None entries.
            regex_list = [print_span(chunk, q3answ['answer'], match) for match in matches]
            regex_list += [print_span(chunk, match, q3answ['answer']) for match in matches]
            list_of_regex_lists.append(regex_list)

        runtime = time.time() - start_time

        # One record per chunk; the scalar values (runtime, DOI, title) are
        # repeated on every record of the row.
        data = {
            'np_check1': squad_q1_outputs,
            'np_check2': squad_q3_outputs,
            'properties': squad_q2_outputs,
            "regex_outputs": list_of_regex_lists,
            "runtime": runtime,
            "dois": batch_samples['dois'][n],
            "title": batch_samples['titles'][n],
        }
        df = pd.DataFrame(data)
        # Files are numbered by the row's absolute position in the dataset.
        filename = output_dir + '/row' + str(start_idx + n) + '.json'
        df.to_json(filename, orient='records')

complete_runtime = time.time() - complete_start_time
print("complete runtime", complete_runtime)