# 💡 Colab Demo: Predicting Bulk Modulus with an AI Agent (early version of [DREAMS](https://arxiv.org/abs/2507.14267))

> ⚠️ **Note:** Running Quantum ESPRESSO is not feasible in Google Colab.  
> For demonstration purposes, we use the **EMT calculator** from ASE as a lightweight substitute.
> You can later replace it with Quantum ESPRESSO or a machine-learned interatomic potential (MLIP) like MatterSim if running locally or on HPC.

This notebook shows how an agent:
- Builds atomic structures with varying Cu/Au concentrations
- Automatically generates DFT-like input settings
- Predicts bulk modulus using the equation of state (EOS) fit

To explore our up-to-date multi-agent framework DREAMS, please check out our repo: https://github.com/BattModels/material_agent

In [23]:
! pip install -qU "langchain[anthropic]"
! pip install langchain_community
! pip install ase
! pip install pymatgen



In [24]:

from getpass import getpass
from typing import List
from ase import Atoms, Atom
from langchain.agents import tool
from ase.build import bulk
from ase.data import reference_states, atomic_numbers
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from typing import Callable, List

from langchain.memory import ConversationBufferMemory
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)

from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain_core.pydantic_v1 import BaseModel
from langchain.agents import AgentExecutor
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.ext.matproj import MPRester
import os, sys, glob

from ase.calculators.emt import EMT
from ase.eos import calculate_eos
from ase.units import kJ
from ase.lattice.cubic import FaceCenteredCubic
from ase import Atoms
import io
from ase.io import write, read
import numpy as np
import ast

# from ase.calculators.espresso import Espresso, EspressoProfile
from getpass import getpass

In [25]:
# Anthropic API
# Anthropic API key, read interactively with getpass (input is not echoed),
# so the key is not written into the notebook source.
API_KEY = getpass("Enter your API key: ")

Enter your API key: ··········


In [26]:
# Materials Project API
# Materials Project API key, used by get_crystal_structure via MPRester;
# read with getpass so it is not echoed or stored in the notebook source.
mp_key = getpass("Enter your Materials Project API key: ")

Enter your Materials Project API key: ··········


In [27]:
## Env variables
# Export the API key and model choice as environment variables.
# ANTHROPIC_API_KEY and LANGSIM_MODEL are read when the chat model is built;
# LANGSIM_PROVIDER / LANGSIM_API_KEY are not read anywhere in the visible
# notebook (presumably kept for LangSim-style tooling — TODO confirm).
os.environ['ANTHROPIC_API_KEY'] = API_KEY
os.environ["LANGSIM_PROVIDER"] = "anthropic"
os.environ["LANGSIM_API_KEY"] = os.environ['ANTHROPIC_API_KEY']
os.environ["LANGSIM_MODEL"] = "claude-3-5-sonnet-20240620"

In [28]:
def print_nested_dict(myDict):
    """Pretty-print a two-level mapping.

    Each outer key is printed on its own line, followed by one tab-indented
    ``inner_key: value`` line per entry of the inner mapping.
    """
    for outer_key, inner in myDict.items():
        print(outer_key)
        rendered = (f"\t{inner_key}: {inner[inner_key]}" for inner_key in inner)
        for text in rendered:
            print(text)

In [29]:
# Chat model used by the agents below; model id and key come from the
# environment variables exported in the previous cell.
llm = ChatAnthropic(model=os.environ["LANGSIM_MODEL"], api_key=os.environ['ANTHROPIC_API_KEY'])

In [30]:
## Functions
class AtomsDict(BaseModel):
    """Serializable subset of an ``ase.Atoms`` object, used as a tool argument/return type.

    The field names match ``ase.Atoms`` keyword arguments, so ``Atoms(**d.dict())``
    rebuilds the structure (as done in ``get_kpoints``).
    """
    numbers: List[int]  # atomic number of each atom
    positions: List[List[float]]  # one [x, y, z] row per atom, as in Atoms.positions
    cell: List[List[float]]  # three lattice-vector rows
    pbc: List[bool]  # periodic-boundary flag per cell direction

@tool
def get_crystal_structure(chemical_formula: str) -> str:
    """Returns both the crystal structure (in ase Atoms) and the information (in dictionary) of a chemcial compounds"""
    # NOTE: the docstring above is the tool description shown to the LLM.
    # Actual return value: (AtomsDict, summary-document dict) for the
    # lowest-energy-per-atom entry, or an explanatory string if none exist.
    with MPRester(mp_key) as mpr:
        docs = mpr.materials.summary.search(
            formula=chemical_formula,
            fields=["material_id", "formula_pretty", "structure", "is_magnetic",
                    "ordering", "energy_above_hull", "energy_per_atom"],
        )

    if len(docs) == 0:
        return "No crystal structure known."

    # Lowest energy per atom; min() keeps the first of equal entries, like
    # a stable sort followed by [0].
    best = min(docs, key=lambda doc: doc.energy_per_atom)
    atoms = AseAtomsAdaptor.get_atoms(best.structure)
    # Build the AtomsDict from exactly the fields it declares. atoms.todict()
    # can also contain entries such as 'info' (a dict) or 'constraint'
    # (a list), which have no .tolist() and would crash a blanket conversion.
    atoms_dict = AtomsDict(
        numbers=atoms.numbers.tolist(),
        positions=atoms.positions.tolist(),
        cell=np.asarray(atoms.get_cell()).tolist(),
        pbc=atoms.pbc.tolist(),
    )
    return atoms_dict, best.dict()


@tool
def load_csv(file_path: str, csv_args: dict = None) -> str:
    """Load a csv file and return its rows as a list of LangChain Document objects (one per row). csv_args is optional and is passed to csv.DictReader."""
    # The docstring is the tool description the LLM sees; it previously
    # claimed a pandas dataframe is returned, which CSVLoader does not do.
    from langchain_community.document_loaders.csv_loader import CSVLoader
    loader = CSVLoader(file_path=file_path, csv_args=csv_args)
    return loader.load()

@tool
def get_kpoints(atom_dict: AtomsDict, k_point_distance: str) -> str:
    """Returns the kpoints of a given ase atoms object and user specified k_point_distance (k_point_distance could be fine, normal, coarse or very fine, default is normal)"""
    import numpy as np

    # Map the qualitative label to a k-point spacing (1/Angstrom). Matching
    # is case-insensitive; the branch order is the same as before.
    label = k_point_distance.lower()
    if 'fine' in label and 'very' not in label:
        kspacing = 0.2
    elif 'very' in label and 'fine' in label:
        kspacing = 0.1
    elif 'coarse' in label:
        return [1, 1, 1]
    else:
        kspacing = 0.3

    atoms = Atoms(**atom_dict.dict())
    # ASE's reciprocal() excludes the 2*pi factor, so N_i = ceil(2*pi*|b_i| / kspacing).
    # This is the reciprocal-lattice form; 2*pi/|a_i| only equals 2*pi*|b_i|
    # for orthogonal cells and under-samples skewed ones.
    reciprocal = np.asarray(atoms.cell.reciprocal())
    lengths = np.linalg.norm(reciprocal, axis=1)
    # Plain Python ints so the tool output string reads "[4, 4, 4]".
    return [int(np.ceil(2 * np.pi * length / kspacing)) for length in lengths]

@tool
def get_optimized_lattice(atom_dict: AtomsDict) -> str:
    "Returns the optimized lattice of a given ase atoms object using Equation of State"
    from ase.eos import calculate_eos
    from ase.optimize import BFGS
    import numpy as np
    from ase.calculators.emt import EMT

    atoms = Atoms(**atom_dict.dict())
    atoms.calc = EMT()
    # Relax internal positions at fixed cell.
    BFGS(atoms).run(fmax=0.01)

    # calculate_eos samples volumes around the current cell and restores the
    # original cell/positions afterwards, so `atoms` still has volume v0.
    v0 = atoms.get_volume()
    eos = calculate_eos(atoms, trajectory=str(atoms.symbols) + '.traj', npoints=5, eps=0.04)
    v_eq, _energy, _bulk_modulus = eos.fit()

    # Lattice vectors scale with the cube root of the volume ratio. Going
    # from v0 to v_eq needs (v_eq / v0)**(1/3); the inverse would move the
    # cell away from equilibrium.
    scale = (v_eq / v0) ** (1 / 3)
    atoms.set_cell(atoms.get_cell() * scale, scale_atoms=True)

    # Serialize only the fields AtomsDict declares (todict() can hold extras
    # without .tolist()).
    return AtomsDict(
        numbers=atoms.numbers.tolist(),
        positions=atoms.positions.tolist(),
        cell=np.asarray(atoms.get_cell()).tolist(),
        pbc=atoms.pbc.tolist(),
    )

@tool
def append_hubbard_u_card(input_file: str, u_values: dict) -> str:
    "Returns the input file with the appended Hubbard U card that is consistent for QE7.2 input format. The u_values is a dictionary with key formatted as element-orbital and the U value as the value"
    # The card is: a blank separator line, the HUBBARD header, one
    # "U <label> <value>" row per entry (2 decimals), then a trailing blank line.
    card_rows = ['', 'HUBBARD ortho-atomic']
    card_rows.extend(f'U {label} {u_values[label]:.2f}' for label in u_values)
    card_rows.append('')
    return input_file + '\n'.join(card_rows) + '\n'

class DFTAgent:
    """Tool-calling agent that asks the LLM for Quantum ESPRESSO input settings.

    The agent can call ``get_kpoints``, ``load_csv`` and ``get_crystal_structure``.
    In this Colab demo ``query`` only prints the LLM answer and returns an
    empty dict; the calculation itself runs with EMT, not Quantum ESPRESSO.
    """

    def __init__(self, model: ChatAnthropic) -> None:
        self.model = model

        tools = [get_kpoints, load_csv, get_crystal_structure]

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    # "You are very powerful assistant, but don't know current events.",  # This initial query fails when car is provided as input rather than gold
                    "You are very powerful compututation material scientist that produces high-quality quantum espresso input files for density functional theory calculations, but don't know current events. \
                    For each query vailidate that it contains chemical elements from the periodic table and otherwise cancel.\
                    Always generate conventional cell with ibrav=0 and do not use celldm and angstrom at the same time.\
                    Please include CONTROL, SYSTEM, ELECTRONS, ATOMIC_SPECIES, K_POINTS, ATOMIC_POSITIONS, and CELL. \
                    Use the right smearing based on the material.\
                    If not specified, use normal for k point spacing.\
                    If the system involves hubbard U correction, specify starting magnetization in SYSTEM card and hubbard U parameters in HUBBARD card, and use the pre-defined hubbard correction tool.\
                    The conv_thr should scale with number of atoms in the system.\
                    Do not use the get crystal structure tool unless you are explicitly told to do so.\
                    Do not ask for the user's permission to proceed.\
                    Please make sure that the input is the most optimal. \
                    The input dictionary should be readable by ase.Espresso.\
                    ",
                ),
                ("user", "{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )

        # Bind tools to the model given to this agent, not a notebook global,
        # so the `model` argument is actually honoured.
        llm_with_tools = self.model.bind_tools(tools)

        # Pipeline: input + formatted tool-call history -> prompt -> model ->
        # parser that turns the reply into tool calls or a final answer.
        agent = (
            {
                "input": lambda x: x["input"],
                "agent_scratchpad": lambda x: format_to_openai_tool_messages(
                    x["intermediate_steps"]
                ),
            }
            | prompt
            | llm_with_tools
            | OpenAIToolsAgentOutputParser()
        )

        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, max_iterations=1000)

    @staticmethod
    def _final_text(stream_chunks):
        """Return the final answer text from a list of ``AgentExecutor.stream`` chunks.

        The last chunk holds the answer under ``'output'``. For Anthropic this
        is a list of content blocks (text in the first block); a plain string
        is returned unchanged.
        """
        output = stream_chunks[-1]['output']
        if isinstance(output, str):
            return output
        return output[0]['text']

    def query(self, atoms):
        """Ask the agent for QE input settings for ``atoms`` (an ``ase.Atoms``).

        Demo version for EMT: the LLM answer is printed but not parsed, and
        an empty dict is returned so downstream code keeps running.
        """
        question = f"""
        Please generate Quantum Espresso input settings in a dictionary for the given ase atoms dictionary: {atoms.todict()}.
      """

        chunks = list(self.agent_executor.stream({'input': question}))

        answer = self._final_text(chunks)
        print("🔍 LLM-generated input (not parsed):")
        print(answer)

        return {}

    @staticmethod
    def display_result(out: str) -> str:
        """Print the first ``` fenced block of the final answer in ``out``.

        ``out`` is the list of stream chunks. Literal ``\\n`` sequences are
        expanded first. If the answer has no fence, the whole text is printed.
        Static because it needs no instance state; it works both as
        ``DFTAgent.display_result(out)`` and ``agent.display_result(out)``.
        """
        text = DFTAgent._final_text(out).replace('\\n', '\n')
        parts = text.split('```')
        print(parts[1] if len(parts) > 1 else text)


In [31]:
def _strip_qe_comment(line):
    """Return ``line`` without a trailing ``!`` comment; a ``!`` inside quotes is kept."""
    quote = None
    for index, char in enumerate(line):
        if quote:
            if char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char == '!':
            return line[:index]
    return line


def _split_qe_assignments(line):
    """Split a namelist line into ``(key, raw_value)`` pairs.

    Fortran namelists allow several comma-separated assignments on one line.
    Commas inside quoted strings are not separators. Pieces without ``=``
    are ignored.
    """
    pieces, current, quote = [], [], None
    for char in line:
        if quote:
            if char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char == ',':
            pieces.append(''.join(current))
            current = []
            continue
        current.append(char)
    pieces.append(''.join(current))

    pairs = []
    for piece in pieces:
        if '=' in piece:
            key, value = piece.split('=', 1)
            pairs.append((key.strip(), value.strip()))
    return pairs


def _to_qe_float(text):
    """Parse a real written as ``1.0``, ``1e-8`` or Fortran ``1.0d-8``; None if not numeric."""
    for candidate in (text, text.replace('d', 'e').replace('D', 'E')):
        try:
            return float(candidate)
        except ValueError:
            continue
    return None


def _convert_qe_value(value):
    """Convert a raw namelist value to bool, int, float or an unquoted str."""
    lowered = value.lower()
    if lowered in ('.true.', '.false.', '.t.', '.f.'):
        return lowered in ('.true.', '.t.')
    # Strip quotes: ase.Espresso adds its own when writing string values.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
        return value[1:-1]
    try:
        return int(value)  # also handles negative ints such as -1
    except ValueError:
        pass
    number = _to_qe_float(value)
    return value if number is None else number


def parse_qe_input_string(input_string):
    """Parse Quantum ESPRESSO ``pw.x`` input text into a nested dict.

    Returns a dict with one sub-dict per namelist (``control``, ``system``,
    ``electrons``, ``ions``, ``cell``) mapping keys to converted values.
    It also holds ``atomic_species`` (``{symbol: {'mass', 'pseudopotential'}}``)
    and ``hubbard`` (``{label: U}``, from 3-column rows such as
    ``U Cu-3d 5.0``). Other cards (ATOMIC_POSITIONS, K_POINTS,
    CELL_PARAMETERS, ...) are recognised only so that their rows are not
    mistaken for species or Hubbard entries.
    """
    sections = ['control', 'system', 'electrons', 'ions', 'cell']
    # Card keywords; any of them ends the previous card.
    cards = {
        'atomic_species', 'atomic_positions', 'k_points', 'cell_parameters',
        'occupations', 'constraints', 'atomic_velocities', 'atomic_forces',
        'additional_k_points', 'solvents', 'hubbard',
    }
    input_data = {section: {} for section in sections}
    input_data['atomic_species'] = {}
    input_data['hubbard'] = {}

    current_section = None  # namelist being read, if any
    current_card = None     # card being read, if any

    for raw_line in input_string.strip().splitlines():
        line = raw_line.strip()
        if line.startswith('#'):
            continue
        line = _strip_qe_comment(line).strip()
        if not line:
            continue

        if line.startswith('&'):
            name = line[1:].strip().lower()
            current_section = name if name in sections else None
            current_card = None
            continue
        if line == '/':
            current_section = None
            continue

        if current_section is not None:
            for key, value in _split_qe_assignments(line):
                input_data[current_section][key] = _convert_qe_value(value)
            continue

        # Card header, e.g. "HUBBARD ortho-atomic", "HUBBARD (ortho-atomic)",
        # "K_POINTS {automatic}" or "ATOMIC_POSITIONS{angstrom}".
        keyword = line.split()[0].split('{')[0].split('(')[0].lower()
        if keyword in cards:
            current_card = keyword
            continue

        parts = line.split()
        if current_card == 'atomic_species' and len(parts) == 3:
            mass = _to_qe_float(parts[1])
            if mass is not None:
                input_data['atomic_species'][parts[0]] = {
                    'mass': mass,
                    'pseudopotential': parts[2]
                }
        elif current_card == 'hubbard' and len(parts) == 3:
            u_value = _to_qe_float(parts[2])
            if u_value is not None:
                input_data['hubbard'][parts[1]] = u_value

    return input_data

In [32]:
## Functions
# NOTE(review): this cell re-declares AtomsDict with the same fields as the
# earlier definition. Tools decorated earlier keep the first class in their
# argument schemas, while functions that look up the global name
# ``AtomsDict`` at call time get this one. Consider deleting this duplicate.
class AtomsDict(BaseModel):
    """Serializable subset of an ``ase.Atoms`` object (numbers, positions, cell, pbc)."""
    numbers: List[int]  # atomic number of each atom
    positions: List[List[float]]  # one [x, y, z] row per atom, as in Atoms.positions
    cell: List[List[float]]  # three lattice-vector rows
    pbc: List[bool]  # periodic-boundary flag per cell direction


class simpleDFTAgent:
    """Tool-free helper that asks the LLM once for a Quantum ESPRESSO input file."""

    def __init__(self, model: ChatAnthropic) -> None:
        self.model = model

        self.system_message = "You are an expert computational materials scientist that produces high-quality quantum espresso (latest version) input files for density functional theory calculations, but don't know current events. \
                    By default, use PBE QE ussp.F.UPF (https://www.physics.rutgers.edu/gbrv/) and scf. \
                    Generate cell parameters based on information given by the user (usually in extxyz format).\
                    Also generate appropriate kpts based on cell parameters.\
                    if system is magnetic then generate spin-polarized input and add separate hubbard correction card with best correction for that composition from literature. \
                    If the user specifies high accuracy, use SCAN exchange correlation functional in input_dft, else which is best suitable. \
                    "

    def display_result(self, out: str) -> str:
        """Print the first ``` fenced block of the final answer in ``out`` (stream chunks).

        Literal ``\\n`` sequences are expanded first. If the answer has no
        fence, the whole text is printed instead of raising IndexError.
        """
        output = out[-1]['output']
        # Anthropic returns a list of content blocks; accept a plain string too.
        text = output if isinstance(output, str) else output[0]['text']
        text = text.replace('\\n', '\n')
        parts = text.split('```')
        print(parts[1] if len(parts) > 1 else text)

    def query(self, input: str):
        """Send ``input`` with the expert system prompt; return the reply content."""
        # Wrap the prompt in SystemMessage. LangChain converts a bare str in
        # the message list into a HumanMessage, so without the wrapper the
        # instructions are sent as user text instead of the system prompt.
        message = self.model.invoke(
            [
                SystemMessage(content=self.system_message),
                HumanMessage(content=input),
            ]
        )
        return message.content



In [33]:
@tool
def get_bulk_modulus(concentration: float) -> str:
    """Returns the bulk modulus in GPa of an fcc Cu-Au alloy, where concentration is the fraction of Au atoms between 0 and 1 (0 = pure Cu, 1 = pure Au)"""
    # The docstring is the tool description the LLM sees. It must say what
    # `concentration` means: the run transcript shows the model reading 0 as
    # pure Au, but the code below substitutes Au (Z=79) into a Cu lattice.
    if not 0.0 <= concentration <= 1.0:
        raise ValueError(f"concentration must be between 0 and 1, got {concentration}")

    cu_lattice = 3.58  # Angstrom, starting Cu lattice constant used in this demo
    au_lattice = 4.08  # Angstrom, fcc Au lattice constant (approx. experimental value)

    atoms = FaceCenteredCubic("Cu", latticeconstant=cu_lattice)
    atoms *= (1, 1, 2)  # 4-atom conventional cell doubled -> 8 atoms
    # Round, not truncate, to the nearest achievable number of Au atoms.
    n_au = int(round(concentration * len(atoms)))
    au_indices = np.random.choice(len(atoms), n_au, replace=False)
    atoms.numbers[au_indices] = 79

    # Vegard's law: interpolate the lattice constant linearly with the Au
    # fraction actually present, then rescale the Cu-based cell.
    x_au = n_au / len(atoms)
    lattice = cu_lattice + x_au * (au_lattice - cu_lattice)
    atoms.set_cell(atoms.cell * (lattice / cu_lattice), scale_atoms=True)

    # EMT stands in for Quantum ESPRESSO in Colab. For a real DFT run,
    # attach an ase Espresso calculator here (pseudopotentials, profile and
    # input_data, e.g. from DFTAgent) instead of EMT.
    atoms.calc = EMT()
    eos = calculate_eos(atoms, trajectory='dummy.traj')
    v, e, B = eos.fit()
    return B / kJ * 1.0e24  # eV/Angstrom^3 -> GPa

def display_result(out: str) -> str:
    """Print the first ``` fenced block of the agent's final answer.

    ``out`` is the list of chunks from ``AgentExecutor.stream``; the last
    chunk holds the answer under ``'output'``. That is either a list of
    content blocks (text in the first) or a plain string. Literal ``\\n``
    sequences are expanded first. If the answer contains no fence, the whole
    text is printed instead of raising IndexError.
    """
    output = out[-1]['output']
    text = output if isinstance(output, str) else output[0]['text']
    text = text.replace('\\n', '\n')
    parts = text.split('```')
    print(parts[1] if len(parts) > 1 else text)

In [34]:
# get_bulk_modulus(0.25)


In [35]:
# The only tool the bulk-modulus agent may call.
tools = [get_bulk_modulus]

# System prompt: restrict queries to Cu/Au and iterate on composition (up to
# 10 trials) until the computed bulk modulus meets the user's target/tolerance.
# NOTE(review): the prompt does not say which element get_bulk_modulus's
# `concentration` refers to; the model relies on the tool description for that.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are very powerful assistant that performs bulk modulus calculations on atomistic level, but don't know current events. \
            For each query vailidate that the chemical elements only contains Copper and Gold and otherwise cancel. \
            Get the structure from supplied function. Use Atomic positions in Angstroms. \
            If the composition is not pure gold or pure copper, use the supplied function to generate mixed metal structure.\
            Calculate bulk modulus of both single metal and mixed metal from the supplied function.\
            You should try identifying if either Cu or Au meets the desired bulk modulus, if not, \
            try changing the concentration of Cu and Au until reaches 10 trials or meets the user input bulk modulus requirement.\
            From each calculation, validate that the desired bulk modulus is strictly following user input bulk modulus, otherwise cancel.\
            Also, is user specified a acceptable error range, for each calculation if the resulting bulk modulus is within that range, stop immediately.\
            ",
        ),
        ("user", "{input}"),
        # Tool calls and their results from earlier iterations go here.
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

In [36]:
# Let the chat model emit structured tool calls for get_bulk_modulus.
llm_with_tools = llm.bind_tools(tools)

# Agent pipeline: user input + formatted tool-call history -> prompt ->
# tool-enabled model -> parser that yields tool invocations or a final answer.
agent = (
    {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_to_openai_tool_messages(
            x["intermediate_steps"]
        ),
    }
    | prompt
    | llm_with_tools
    | OpenAIToolsAgentOutputParser()
)

In [37]:
# verbose=True prints every tool invocation; max_iterations is left at the
# AgentExecutor default (the system prompt separately asks for <= 10 trials).
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)


# Run the search and keep every streamed chunk; the last chunk holds the answer.
lst = list(agent_executor.stream({"input": "Can you find which concentration of Cu and Au will give a bulk modulus around 170 GPa with an error of plus or minus 5?"}))



[1m> Entering new None chain...[0m
[32;1m[1;3m
Invoking: `get_bulk_modulus` with `{'concentration': 0}`
responded: [{'text': "Certainly! I'd be happy to help you find the concentration of Cu and Au that will give a bulk modulus around 170 GPa with an error range of ±5 GPa. To do this, we'll use the `get_bulk_modulus` function and adjust the concentration of Cu and Au until we find a result within the desired range.\n\nLet's start by checking the bulk modulus of pure Cu and pure Au, and then we'll adjust the concentrations accordingly.", 'type': 'text', 'index': 0}, {'id': 'toolu_016jLWvUtMSQ7d4awVMwNcBo', 'input': {}, 'name': 'get_bulk_modulus', 'type': 'tool_use', 'index': 1, 'partial_json': '{"concentration": 0}'}]

[0m[36;1m[1;3m134.45229027802856[0m[32;1m[1;3m
Invoking: `get_bulk_modulus` with `{'concentration': 1}`
responded: [{'text': "This result (134.45 GPa) corresponds to pure gold (Au), as the concentration parameter is 0 (meaning 0% Cu, 100% Au).\n\nNow, let's ch