In [192]:
import pandas as pd
import numpy as np
import os
from openai import OpenAI
import openai
import ast
import glob
import json

from tqdm import tqdm

In [None]:
client = OpenAI(
    # This is the default and can be omitted
    api_key="your_openai_api",
)

In [194]:
def get_gpt_api(prompt, model="gpt-4"):
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=model,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return str(e)

In [195]:
prompt_lookup = {
                "formation_energy_per_atom": "The formation energy per atom is",
                "band_gap": "The band gap is",
                "pretty_formula": "The chemical formula is",
                "e_above_hull": "The energy above the convex hull is",
                "elements": "The elements are",
                "spacegroup.number": "The spacegroup number is",
            }

def get_template(row):
    property_prompt = ""
    for key, value in prompt_lookup.items():
        if key == "elements":
            property_prompt += f"{value} {', '.join(ast.literal_eval(row[key]))}. "
        elif key in ["formation_energy_per_atom", "band_gap", "e_above_hull"]:
            property_prompt += f"{value} {round(float(row[key]), 4)}. "
        else:
            property_prompt += f"{value} {row[key]}. "
    
    desc_prompt = row["description"]

    return desc_prompt, property_prompt


In [196]:
# Read embeddings.csv
embeddings_path = '/data/rech/dingqian/intel/crystal-llm-retrieval/embeddings.csv'
embeddings_df = pd.read_csv(embeddings_path, index_col=0)

# Read train_desc.csv
data_path = "../data/basic/train_desc.csv"
train_df = pd.read_csv(data_path, index_col=1)

In [197]:
train_df.head()

Unnamed: 0_level_0,Unnamed: 0,formation_energy_per_atom,band_gap,pretty_formula,e_above_hull,elements,cif,spacegroup.number,description
material_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
mp-1221227,37228,-1.63746,0.2133,Na3MnCoNiO6,0.043001,"['Co', 'Mn', 'Na', 'Ni', 'O']",# generated using pymatgen\ndata_Na3MnCoNiO6\n...,8,Na3MnCoNiO6 is Caswellsilverite-derived struct...
mp-974729,19480,-0.314759,0.0,Nd(Al2Cu)4,0.0,"['Al', 'Cu', 'Nd']",# generated using pymatgen\ndata_Nd(Al2Cu)4\n_...,139,Al8Cu4Nd crystallizes in the monoclinic C2/m s...
mp-1185360,29624,-0.193761,0.0,LiMnIr2,0.018075,"['Ir', 'Li', 'Mn']",# generated using pymatgen\ndata_LiMnIr2\n_sym...,225,LiMnIr2 is Heusler structured and crystallizes...
mp-1188861,38633,-0.584694,3.8556,LiCSN,0.048847,"['C', 'Li', 'N', 'S']",# generated using pymatgen\ndata_LiCSN\n_symme...,62,
mp-677272,10889,-2.474759,0.4707,La2EuS4,0.0,"['Eu', 'La', 'S']",# generated using pymatgen\ndata_La2EuS4\n_sym...,122,EuLa2S4 crystallizes in the tetragonal I-42d s...


In [198]:
embeddings_df.head()

Unnamed: 0_level_0,description,query,desc_embedding,query_embedding,second_most_similar_material_id,most_similar_material_id
material_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
mp-1221227,Na 3 Mn Co Ni O 6 is Ca swellsilverite-derived...,Be low is a description of a bulk material. Th...,[ 0.04090341 -0.01577177 0.06744143 -0.067098...,[ 0.02537145 -0.01533251 0.05748841 -0.070425...,mp-1222510,mp-1221227
mp-974729,Al 8 Cu 4 Nd crystallizes in the monoclinic C ...,Be low is a description of a bulk material. Th...,[ 3.77487834e-03 1.55593464e-02 -1.49153247e-...,[ 0.00705677 0.00390458 -0.00846322 -0.004971...,mp-1092251,mp-974729
mp-1185360,Li Mn Ir 2 is He usler structured and crystall...,Be low is a description of a bulk material. Th...,[-3.41129303e-02 4.88284323e-03 4.57347259e-...,[-1.50608942e-02 3.76177672e-03 3.40886451e-...,mp-864950,mp-865088
mp-677272,Eu La 2 S 4 crystallizes in the tetragonal I -...,Be low is a description of a bulk material. Th...,[-0.01615838 0.01042143 -0.04585846 -0.037169...,[-1.11538060e-02 1.12731121e-02 -3.59769352e-...,mp-1222958,mp-677272
mp-1104517,Yb 3 Pt 2 Ga 9 crystallizes in the orthorhombi...,Be low is a description of a bulk material. Th...,[ 0.03487201 -0.01587198 0.02627973 -0.031895...,[ 2.03714818e-02 -1.58217978e-02 1.39079597e-...,mp-1215613,mp-1104517


In [199]:
def get_prompt(formula_A, formula_B, desc_prompt_A, property_prompt_A, desc_prompt_B, property_prompt_B):

    prompt = f"I have two materials {formula_A} and {formula_B}.\n\n"
    prompt += f"The description of material {formula_A}: " + desc_prompt_A + '\n\n'
    # prompt += f"The property of material {formula_A}: " + property_prompt_A + '\n\n'
    prompt += f"The description of material {formula_B}: " + desc_prompt_B + '\n\n'
    # prompt += f"The property of material {formula_B}: " + property_prompt_B + '\n\n'
    prompt += f"Based on the descriptions and properties of the two materials above, "
    prompt += f"can you summarize how can we transit material {formula_A} to material {formula_B} in one paragraph? \n\n"
    prompt += f"Please do not include any hint of formula {formula_B} in your answer as we have no prior knowledge of formula {formula_B}. "

    prompt += f"\n\nThe answer should begin with: To transit from {formula_A} to a new material,"

    return prompt

In [200]:
def get_answer(row):
    prompt = get_prompt(row)
    return prompt, get_gpt_api(prompt)

In [201]:
batch_list = []

for index, row in tqdm(embeddings_df.iterrows()):
    material_id_A = row.name
    material_id_B = row["most_similar_material_id"] if row["most_similar_material_id"] != material_id_A else row["second_most_similar_material_id"]
    
    material_A = train_df.loc[material_id_A]
    material_B = train_df.loc[material_id_B]

    desc_prompt_A, property_prompt_A = get_template(material_A)
    desc_prompt_B, property_prompt_B = get_template(material_B)

    formula_A = material_A["pretty_formula"]
    formula_B = material_B["pretty_formula"]

    prompt = get_prompt(formula_A, formula_B, desc_prompt_A, property_prompt_A, desc_prompt_B, property_prompt_B)

    batch_dict = {
        "custom_id": material_id_A,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            "model": "gpt-4-turbo",
            "max_tokens": 2048
        },
    }

    batch_list.append(batch_dict)

342it [00:00, 3419.40it/s]

22592it [00:04, 4535.46it/s]


In [202]:
answer = get_gpt_api(prompt, model="gpt-4-turbo")
answer

'To transit from SrSc2O4 to a new material, a significant restructuring of the crystal lattice and substitutions at the cation and anion sites would be necessary. The transition begins by replacing Sr and Sc with Li and Ni, respectively. The orthorhombic crystal structure should be maintained, albeit in a different space group to accommodate the altered coordination environments and chemical properties of Li and Ni compared to Sr and Sc. Li typically coordinates with oxygen in a different manner than Sr, often favoring a more compact octahedral coordination, as does Ni when substituting for Sc. Additionally, the oxygen sublattice would have to be reorganized to fit the new requirements for bond angles and lengths presented by the presence of Li and Ni. This transition would involve careful consideration of ionic sizes and charges to balance the structural stability and stoichiometry of the new compound within the framework of the orthorhombic system. This would not be a direct substitu

In [203]:
# convert to jsonl file

# save the batch_list to the jsonl file per 40000 rows
K = 40000
for i in range(0, len(batch_list), K):
    jsonl_file = f"batch_{i}.jsonl"
    with open(jsonl_file, 'w') as f:
        for batch in batch_list[i:i+K]:
            f.write(json.dumps(batch) + '\n')



In [204]:
batch_input_file = client.files.create(
  file=open("batch_0.jsonl", "rb"),
  purpose="batch"
)

batch_input_file_id = batch_input_file.id

response = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
      "description": "Full"
    }
)

status = response.status
status

'validating'

In [211]:
response

Batch(id='batch_lqhq9Cyye4F6IygKCKQD69Qd', completion_window='24h', created_at=1718954988, endpoint='/v1/chat/completions', input_file_id='file-RJQ6q4M8IxJHyfMOVXDpfnuC', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1719041388, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'Full'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))

In [226]:
retrieve_response = client.batches.retrieve(response.id)
retrieve_response

Batch(id='batch_lqhq9Cyye4F6IygKCKQD69Qd', completion_window='24h', created_at=1718954988, endpoint='/v1/chat/completions', input_file_id='file-RJQ6q4M8IxJHyfMOVXDpfnuC', object='batch', status='completed', cancelled_at=None, cancelling_at=None, completed_at=1718958940, error_file_id='file-IvUc0y37T5xGim2yef4bD5CE', errors=None, expired_at=None, expires_at=1719041388, failed_at=None, finalizing_at=1718957934, in_progress_at=1718954994, metadata={'description': 'Full'}, output_file_id='file-2xT1rRwUmIM2dHjvBS5uk5GP', request_counts=BatchRequestCounts(completed=5560, failed=17032, total=22592))

In [239]:
content = client.files.retrieve_content(retrieve_response.output_file_id)
error_content = client.files.retrieve_content(retrieve_response.error_file_id)

  content = client.files.retrieve_content(retrieve_response.output_file_id)
  error_content = client.files.retrieve_content(retrieve_response.error_file_id)


In [244]:
# save the content to the json file
with open("batch_0_output.json", "w") as f:
    f.write(content)

In [245]:
# save the error content to the json file
with open("batch_0_error.json", "w") as f:
    f.write(error_content)

In [246]:
# find the custom_id in the error content and reconstrut the batch_list

# read lines from the error content

with open("batch_0_error.json", "r") as f:
    lines = f.readlines()

error_custom_ids = []

for line in lines:
    error_custom_ids.append(json.loads(line)["custom_id"])

In [249]:
batch_list_error = []

for batch_dict in batch_list:
    if batch_dict["custom_id"] in error_custom_ids:
        batch_list_error.append(batch_dict)

In [251]:
# save the batch_list to the jsonl file per 40000 rows
K = 40000
for i in range(0, len(batch_list_error), K):
    jsonl_file = f"batch_{i}_error.jsonl"
    with open(jsonl_file, 'w') as f:
        for batch in batch_list_error[i:i+K]:
            f.write(json.dumps(batch) + '\n')

In [252]:
batch_input_file = client.files.create(
  file=open("batch_0_error.jsonl", "rb"),
  purpose="batch"
)

batch_input_file_id = batch_input_file.id

response = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
      "description": "Full - error"
    }
)

status = response.status
status

'validating'

In [253]:
response

Batch(id='batch_g4wdMMTw1ATmO94kmtuy3sHR', completion_window='24h', created_at=1718960665, endpoint='/v1/chat/completions', input_file_id='file-Wpl3KEB34dBUPM7OtLNZc6fK', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1719047065, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'Full - error'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))

In [275]:
# retrieve
retrieve_response = client.batches.retrieve(response.id)
retrieve_response.status

'completed'

In [276]:
content = client.files.retrieve_content(retrieve_response.output_file_id)

  content = client.files.retrieve_content(retrieve_response.output_file_id)


In [277]:
# save content 
with open("batch_0_error_output.json", "w") as f:
    f.write(content)

In [278]:
# merge the output files into one

output_files = glob.glob("batch_*_output.json")
output_filename = "output.json"
with open(output_filename, "w") as outfile:
    for f in output_files:
        with open(f, "r") as infile:
            outfile.write(infile.read())

['batch_0_error_output.json', 'batch_0_output.json']