In [4]:
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import nest_asyncio
import uvicorn
import requests

# Allow nested asyncio event loops: the Jupyter kernel already runs a loop, and
# uvicorn.run(...) must be able to start its own from inside a notebook cell.
nest_asyncio.apply()

ModuleNotFoundError: No module named 'fastapi'

In [None]:
import os
import sys
import json
import torch
import argparse
import numpy as np
import warnings
# Ignore only warnings whose message starts with this text (presumably emitted by
# spglib/pymatgen symmetry analysis — confirm); other warnings are unaffected by this line.
warnings.filterwarnings(action='ignore', message='Too many lattice symmetries was found')

from pymatgen.ext.matproj import MPRester

# IOs
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.analysis.phase_diagram import PhaseDiagram, GrandPotentialPhaseDiagram
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.core.periodic_table import Element, Species
from pymatgen.core.composition import Composition
from pymatgen.analysis.structure_analyzer import oxide_type
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry

# House code: project-local relaxer wrappers are imported from ./script
# (resolved relative to the notebook's current working directory).
sys.path.append('./script')
from relaxer import TrajectoryObserver, M3gnetRelaxer, ChgnetRelaxer, MaceRelaxer


# Materials Project client shared by the helper functions in this notebook.
# The API key comes from the MP_API_KEY environment variable instead of being
# hard-coded in the notebook source.
# NOTE(review): a key was previously committed here in plain text — revoke and rotate it.
# If MP_API_KEY is unset, api_key=None is passed and MPRester presumably falls back to
# its own configuration (e.g. pymatgen's PMG_MAPI_KEY setting) — confirm for your version.
mpr = MPRester(api_key=os.environ.get("MP_API_KEY"))

In [1]:
def has_common_element(list1, list2):
    """Return True if ``list1`` and ``list2`` share at least one item.

    ``list1`` is turned into a set, so its items (and the items of ``list2`` that are
    looked up in it) must be hashable. ``list2`` may be any iterable.
    """
    lookup = set(list1)
    return any(candidate in lookup for candidate in list2)


def read_json(fjson):
    """Load and return the JSON document stored at path ``fjson``."""
    with open(fjson) as handle:
        payload = json.load(handle)
    return payload


def write_json(d, fjson):
    """Serialize ``d`` as JSON into path ``fjson`` (overwriting it). Returns None."""
    with open(fjson, 'w') as handle:
        json.dump(d, handle)


def get_ehull(mpr, vasp_entry, mp_entries, compatibility):
    """Energy above the convex hull of ``vasp_entry`` relative to reference entries.

    Args:
        mpr: unused; kept so existing call sites keep working.
        vasp_entry: entry to evaluate. It is added to the phase diagram as given (this
            function does not run ``compatibility`` on it).
        mp_entries: reference entries for the same chemical system.
        compatibility: pymatgen compatibility scheme applied to ``mp_entries`` before
            the phase diagram is built.

    Returns:
        The energy above hull from ``PhaseDiagram.get_decomp_and_e_above_hull`` with
        ``allow_negative=True``, so entries below the reference hull give negative
        values (pymatgen reports this in eV/atom).
    """
    reference_entries = compatibility.process_entries(mp_entries)

    phase_diagram = PhaseDiagram(reference_entries + [vasp_entry])
    _decomposition, e_above_hull = phase_diagram.get_decomp_and_e_above_hull(
        vasp_entry, allow_negative=True)
    return e_above_hull


def get_test_structure(mpid='mp-510462'):
    """Fetch a reference structure from the Materials Project by material id.

    Args:
        mpid: Materials Project material id to look up (default ``'mp-510462'``).

    Returns:
        The ``structure`` of element ``[1]`` of the entries returned by
        ``mpr.get_entry_by_material_id``.
        NOTE(review): which calculation/functional that index corresponds to is not
        visible here, and an IndexError is raised if fewer than two entries come back —
        confirm this is the intended entry.
    """
    # Look up the id passed by the caller (wrapped in a one-element list, as before).
    entries = mpr.get_entry_by_material_id(material_id=[mpid])
    return entries[1].structure


def get_chemsys_mpdata(structure):
    """Collect MP reference entries, POTCAR specs and Hubbard U values for a structure.

    Args:
        structure: pymatgen ``Structure``; ``structure.species`` is read, so it must be
            ordered.

    Returns:
        ``(mp_data, potcar_spec, hubbards)``:
            mp_data: entries returned by ``mpr.get_entries_in_chemsys`` for the
                structure's element set.
            potcar_spec: unique items of the entries' ``parameters['potcar_spec']``,
                gathered in entry order. Gathering stops after the first entry at which
                the count equals the number of distinct elements.
            hubbards: ``{element: U}`` for every element of the structure (0 for
                elements without a U value below). Empty unless the structure contains
                at least one U element AND oxygen or fluorine.
    """
    # Hubbard U values (eV) used by the Materials Project for these elements.
    U_els = {'Co': 3.32, 'Cr': 3.7, 'Fe': 5.3, 'Mn': 3.9,
             'Mo': 4.38, 'Ni': 6.2, 'V': 4.2, 'W': 6.2}
    # MP applies GGA+U only to oxides and fluorides; other systems are plain GGA.
    u_anions = {'O', 'F'}

    # Distinct element symbols. Collecting names in a set collapses species such as
    # Fe2+ and Fe3+ into a single 'Fe', so the chemsys string has no duplicates.
    species = sorted({sp.name for sp in structure.species})

    chemical_space = '-'.join(species)
    mp_data = mpr.get_entries_in_chemsys(chemical_space)

    hubbards = {}
    if has_common_element(list(U_els), species) and has_common_element(u_anions, species):
        hubbards = {el: U_els.get(el, 0) for el in species}

    potcar_spec = []
    for entry in mp_data:
        for spec in entry.parameters['potcar_spec']:
            if spec not in potcar_spec:
                potcar_spec.append(spec)
        if len(potcar_spec) == len(species):
            break

    return mp_data, potcar_spec, hubbards

def get_relaxation_result(structure, relaxer, fmax=0.5):
    """Relax ``structure`` and compute its energy above the Materials Project hull.

    Steps: fetch MP reference data for the structure's chemical system, relax the
    structure with ``relaxer``, wrap the relaxed structure and energy in a
    ``ComputedStructureEntry`` carrying MP-style ``parameters``, apply
    ``MaterialsProject2020Compatibility`` to it, and evaluate it with ``get_ehull``.

    Args:
        structure: pymatgen ``Structure`` to relax.
        relaxer: object with a ``relax(atoms, fmax=...)`` method returning a dict that
            has at least ``'final_structure'`` and ``'final_energy'`` keys
            (e.g. ``MaceRelaxer()``).
        fmax: force threshold forwarded to ``relaxer.relax``. Defaults to 0.5, the
            value previously hard-coded here (units presumably eV/Å — confirm in relaxer).

    Returns:
        ``(result, ehull)``: the relaxer's result dict, augmented with a
        ``'parameters'`` key, and the energy above hull returned by ``get_ehull``.

    Raises:
        ValueError: if the Materials Project returns no entries for the chemical system,
            or if ``MaterialsProject2020Compatibility`` rejects the relaxed entry.
    """
    mp_data, potcar_spec, hubbards = get_chemsys_mpdata(structure)
    if not mp_data:
        raise ValueError("No Materials Project entries found for this chemical system.")

    atoms = structure.to_ase_atoms()
    result = relaxer.relax(atoms, fmax=fmax)

    # Copy before editing. Assigning mp_data[0].parameters directly and then setting keys
    # would also rewrite that MP reference entry's parameters, which get_ehull later
    # re-processes with the same compatibility scheme.
    parameters = dict(mp_data[0].parameters)
    parameters['potcar_spec'] = potcar_spec
    parameters['hubbards'] = hubbards
    result['parameters'] = parameters

    # Rebuild a plain Structure from the relaxed lattice, species and fractional
    # coordinates (site properties of the relaxer's structure are not carried over).
    final_structure = result['final_structure']
    relaxed = Structure(lattice=final_structure.lattice.matrix,
                        species=list(final_structure.species),
                        coords=[site.frac_coords for site in final_structure.sites])

    gga_entry = ComputedStructureEntry(structure=relaxed,
                                       energy=result['final_energy'],
                                       entry_id='llm_generation',
                                       composition=relaxed.composition.remove_charges(),
                                       parameters=result['parameters']
                                       )

    compat = MaterialsProject2020Compatibility()
    processed_entry = compat.process_entry(gga_entry)
    # process_entry returns None for an incompatible entry; fail here with a clear
    # message instead of an obscure error inside PhaseDiagram.
    if processed_entry is None:
        raise ValueError(
            "MaterialsProject2020Compatibility rejected the relaxed entry "
            "(check potcar_spec / hubbards / run_type in its parameters).")
    ehull = get_ehull(mpr, processed_entry, mp_data, compat)

    return result, ehull



In [2]:
# The request sent by the user is the structure as a POSCAR-format string; it is parsed
# directly in memory, without a temporary file.
def data_preprocessing(request):
    """Parse a POSCAR-format string into a pymatgen ``Structure``.

    Parsing happens in memory rather than via a fixed ``./POSCAR`` scratch file. Calls
    therefore cannot clobber each other's input or a real POSCAR in the working directory.

    Args:
        request: full text of a POSCAR file.

    Returns:
        The parsed pymatgen ``Structure``.

    Raises:
        Exceptions raised by pymatgen for malformed POSCAR text propagate unchanged.
    """
    return Structure.from_str(request, fmt="poscar")

In [3]:
def inference(structure, relaxer=None):
    """Relax a structure and package its trajectory and energy above hull.

    Args:
        structure: pymatgen ``Structure`` to relax.
        relaxer: relaxer passed to ``get_relaxation_result``. When None, a new
            ``MaceRelaxer()`` is created for this call (this loads the model each time,
            so pass a shared instance when calling repeatedly).

    Returns:
        dict with:
            'trajectory': ``Structure.as_dict()`` for every frame in the relaxation
                trajectory's ``atoms_trajectory``.
            'ehull': energy above hull multiplied by 1000 (meV/atom, given that
                pymatgen reports e_above_hull in eV/atom).
    """
    # get_relaxation_result requires a relaxer; without one this call raised TypeError.
    if relaxer is None:
        relaxer = MaceRelaxer()
    result, ehull = get_relaxation_result(structure, relaxer)

    adaptor = AseAtomsAdaptor()
    trajectory = [adaptor.get_structure(atoms).as_dict()
                  for atoms in result['trajectory'].atoms_trajectory]

    return {'trajectory': trajectory, 'ehull': ehull * 1000}

In [5]:
# ASGI application; the /ping and /predict routes in this cell are registered on it.
app = FastAPI()

# load the model once when the cell runs; inference() in this cell reads this
# module-level `relaxer` instead of creating a new one per request.
relaxer = MaceRelaxer()

# JSON request body for POST /predict.
class PoscarRequest(BaseModel):
    poscar: str  # full POSCAR file contents as one string; parsed by data_preprocessing

def data_preprocessing(request):
    """Parse a POSCAR-format string into a pymatgen ``Structure``.

    Parsing happens in memory rather than via a fixed ``./POSCAR`` scratch file. Calls
    therefore cannot clobber each other's input or a real POSCAR in the working directory.

    Args:
        request: full text of a POSCAR file (``PoscarRequest.poscar``).

    Returns:
        The parsed pymatgen ``Structure``.

    Raises:
        Exceptions raised by pymatgen for malformed POSCAR text propagate unchanged.
    """
    return Structure.from_str(request, fmt="poscar")

def inference(structure):
    """Relax ``structure`` with the module-level ``relaxer`` and package the results.

    Returns:
        dict with 'trajectory' (``Structure.as_dict()`` for every frame of the
        relaxation trajectory's ``atoms_trajectory``) and 'ehull' (energy above hull
        multiplied by 1000, i.e. meV/atom given pymatgen's eV/atom).
    """
    relaxation, e_hull = get_relaxation_result(structure, relaxer)

    frames = relaxation['trajectory'].atoms_trajectory
    adaptor = AseAtomsAdaptor()
    total = {'trajectory': [adaptor.get_structure(frame).as_dict() for frame in frames]}
    total['ehull'] = e_hull * 1000

    return total

@app.get("/ping")
async def ping():
    # Liveness check: always responds with the JSON body {"message": "pong"}.
    payload = {"message": "pong"}
    return JSONResponse(payload)

@app.post("/predict")
async def predict(data: PoscarRequest):
    """Relax the POSCAR structure in the request and return its trajectory and E_hull.

    Any failure is reported as HTTP 500 with the exception message as ``detail``.
    """
    # NOTE(review): parsing and relaxation are synchronous and CPU-bound, so they block
    # the event loop (including /ping) while a request is processed.
    try:
        # Preprocess the data
        structure = data_preprocessing(data.poscar)

        # Run inference
        result = inference(structure)

        return JSONResponse(content=result)
    except Exception as e:
        # Chain the original exception so its traceback is preserved for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # data_preprocessing may leave a ./POSCAR scratch file behind; remove it on
        # every exit path (success or failure).
        if os.path.exists("POSCAR"):
            os.remove("POSCAR")

if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so running this cell starts the server
    # and blocks the cell until it is stopped. host="0.0.0.0" listens on all network
    # interfaces — NOTE(review): confirm the service should be reachable from other hosts.
    uvicorn.run(app, host="0.0.0.0", port=8015)

NameError: name 'FastAPI' is not defined