In [1]:
import getpass
import json
import os
import pickle
import time

from tqdm import tqdm

valfile = "jsonl/downstream_val.jsonl"
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Every non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a stray empty line at the end of the file)
    are skipped. Previously they raised ``json.JSONDecodeError``.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        list: One parsed object per non-blank line, in file order.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
# Evaluation splits: every file is JSON Lines and is read with the same helper.
# The discomat split previously duplicated the reader inline.
val_ds = load_jsonl(valfile)
doping_test = load_jsonl('jsonl/doping_test.jsonl')
mof1_test = load_jsonl('jsonl/mof1_test.jsonl')
mof2_test = load_jsonl('jsonl/mof2_test.jsonl')
discomat_test = load_jsonl("jsonl/discomat_dense_test.jsonl")

# Fail fast on an empty split instead of silently writing empty result pickles
# after the (long) generation runs below.
for split_name, split in [("val", val_ds), ("doping", doping_test), ("mof1", mof1_test),
                          ("mof2", mof2_test), ("discomat", discomat_test)]:
    assert split, f"{split_name} split loaded no records"


In [None]:
import google.generativeai as genai
# Never hard-code the API key in the notebook: it would end up in version
# control and in every shared copy. Read it from the environment, or prompt
# interactively if the variable is not set.
key = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Gemini API key: ")
genai.configure(api_key=key)

def get_prompt(input_dict):
    """Split one example into its system prompt and its question.

    Args:
        input_dict: Mapping holding at least the keys ``'system'`` and
            ``'question'``.

    Returns:
        tuple: ``(system_prompt, question)``, both taken verbatim from
        ``input_dict``.
    """
    return input_dict['system'], input_dict['question']

# Using the flash-8b model as that is of comparable size.
MODEL_NAME = "gemini-1.5-flash-8b"
model = genai.GenerativeModel(MODEL_NAME)

# Decoding settings are identical for every request, so build them once
# instead of on each call.
GENERATION_CONFIG = genai.types.GenerationConfig(
    # Only one candidate for now.
    candidate_count=1,
    max_output_tokens=1000,
    temperature=0.01,
)


def get_output(input_dict, additional = "", max_retries=0, retry_delay=2.0):
    """Query the Gemini model for one example and return the reply text.

    The request is a single text prompt shaped as
    ``"system\\n<system prompt><additional>\\nquestion\\n<question>"``.

    Args:
        input_dict: Example holding ``'system'`` and ``'question'`` entries
            (unpacked via ``get_prompt``).
        additional: Text appended verbatim to the system prompt. No separator
            is inserted, so callers must include any leading space or punctuation.
        max_retries: Number of extra attempts when ``generate_content`` raises.
            The default ``0`` makes a single attempt and re-raises immediately,
            as the original implementation did.
        retry_delay: Seconds to wait before the first retry. The wait doubles
            on each subsequent retry.

    Returns:
        str: ``response.text`` of the model's reply.

    Raises:
        The last exception from ``generate_content`` once retries are used up.
        Exceptions raised by ``response.text`` itself are not retried.
    """
    # Named `system_prompt` (not `sys`) to avoid shadowing the stdlib module.
    system_prompt, prompt = get_prompt(input_dict)
    contents = "system\n" + system_prompt + additional + "\nquestion\n" + prompt

    # Guard against negative input: there is always at least one attempt.
    attempts = max(0, int(max_retries)) + 1
    for attempt in range(attempts):
        try:
            response = model.generate_content(
                contents,
                generation_config=GENERATION_CONFIG,
            )
            break
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(retry_delay * 2 ** attempt)
    return response.text

In [5]:
# Smoke test: send one validation example and display the raw model reply.
sample_example = val_ds[2]
get_output(sample_example)

'2223 : dopant\n'

## Downstream run

In [None]:
import pickle
import time
from tqdm import tqdm

downstream_outputs = []
exceptions = []
outname = "gemini_1.5_flash_8b"
start_time = time.time()

# The final results pickle, plus a separate checkpoint pickle refreshed every
# CHECKPOINT_EVERY examples so an interrupted run keeps its progress.
# The checkpoint used to be written to the bare, extension-less `outname`.
CHECKPOINT_EVERY = 1000
downstream_path = outname + "_downstream.pkl"
checkpoint_path = outname + "_downstream_checkpoint.pkl"

for idx, example in enumerate(tqdm(val_ds)):
    try:
        out = get_output(example, additional = ". if JSON format is asked then do not output anything aside from the json string")
        downstream_outputs.append((idx, out))
    except Exception as e:
        print("failed at ", idx, " because\n", e)
        exceptions.append((idx, e))
    # Checkpoint after each completed chunk (items 1000, 2000, ...). The old
    # `idx % 1000 == 0` test also fired at idx 0, saving a file with at most one item.
    if (idx + 1) % CHECKPOINT_EVERY == 0:
        with open(checkpoint_path, "wb") as f:
            pickle.dump(downstream_outputs, f)

with open(downstream_path, "wb") as f:
    pickle.dump(downstream_outputs, f)

print(f"downstream: {len(downstream_outputs)} succeeded, {len(exceptions)} failed "
      f"in {time.time() - start_time:.0f}s")

 19%|██████████████▌                                                             | 2355/12342 [45:50<3:24:28,  1.23s/it]

## SIE runs: doping, MOF1 and MOF2 test splits

In [None]:
doping_outputs = []
# Failures for this split only. Indices restart at 0 for every split, so
# mixing them into the downstream run's shared `exceptions` list made them
# ambiguous, and it also tied this cell to that cell having been run.
doping_exceptions = []
for idx, example in enumerate(tqdm(doping_test)):
    try:
        # NOTE(review): `additional` is appended to the system prompt with no
        # separator (the downstream suffix starts with ". "). Confirm this is
        # intended; changing it would change the evaluated prompt.
        out = get_output(example, additional = "do not output any additional text")
        doping_outputs.append((idx, out))
    except Exception as e:
        print("failed at ", idx, " because\n", e)
        doping_exceptions.append((idx, e))

# `outname` is defined in the downstream-run cell above.
with open(outname + "_doping.pkl", "wb") as f:
    pickle.dump(doping_outputs, f)

In [None]:
mof1_outputs = []
# Failures for this split only. Indices restart at 0 for every split, so
# mixing them into the downstream run's shared `exceptions` list made them
# ambiguous, and it also tied this cell to that cell having been run.
mof1_exceptions = []
for idx, example in enumerate(tqdm(mof1_test)):
    try:
        # NOTE(review): `additional` is appended to the system prompt with no
        # separator (the downstream suffix starts with ". "). Confirm this is
        # intended; changing it would change the evaluated prompt.
        out = get_output(example, additional = "do not output any additional text")
        mof1_outputs.append((idx, out))
    except Exception as e:
        print("failed at ", idx, " because\n", e)
        mof1_exceptions.append((idx, e))

# `outname` is defined in the downstream-run cell above.
with open(outname + "_mof1.pkl", "wb") as f:
    pickle.dump(mof1_outputs, f)

In [None]:
# with open("../llamat2_chat_train_2epochs_full_ex_mof1_test.pkl", "rb") as f:
#     check= pickle.load(f)

In [None]:
mof2_outputs = []
# Failures for this split only. Indices restart at 0 for every split, so
# mixing them into the downstream run's shared `exceptions` list made them
# ambiguous, and it also tied this cell to that cell having been run.
mof2_exceptions = []
for idx, example in enumerate(tqdm(mof2_test)):
    try:
        # NOTE(review): `additional` is appended to the system prompt with no
        # separator (the downstream suffix starts with ". "). Confirm this is
        # intended; changing it would change the evaluated prompt.
        out = get_output(example, additional = "do not output any additional text")
        mof2_outputs.append((idx, out))
    except Exception as e:
        print("failed at ", idx, " because\n", e)
        mof2_exceptions.append((idx, e))

# `outname` is defined in the downstream-run cell above.
with open(outname + "_mof2.pkl", "wb") as f:
    pickle.dump(mof2_outputs, f)

## Discomat run

In [None]:
discomat_outputs = []
# Failures for this split only. Indices restart at 0 for every split, so
# mixing them into the downstream run's shared `exceptions` list made them
# ambiguous, and it also tied this cell to that cell having been run.
discomat_exceptions = []
for idx, example in enumerate(tqdm(discomat_test)):
    try:
        # NOTE(review): `additional` is appended to the system prompt with no
        # separator (the downstream suffix starts with ". "). Confirm this is
        # intended; changing it would change the evaluated prompt.
        out = get_output(example, additional = "do not output any additional text")
        discomat_outputs.append((idx, out))
    except Exception as e:
        print("failed at ", idx, " because\n", e)
        discomat_exceptions.append((idx, e))

# `outname` is defined in the downstream-run cell above.
with open(outname + "_discomat.pkl", "wb") as f:
    pickle.dump(discomat_outputs, f)