<a href="https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/atomgpt_example_huggingface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AtomGPT Structure Generation/Inference example: https://pubs.acs.org/doi/10.1021/acs.jpclett.4c01126
## Author: kamal.choudhary@nist.gov

In [1]:
# Colab only: install a conda (Mambaforge) environment into the runtime.
# condacolab.install() restarts the kernel (see the output below), so run this
# cell on its own before executing the rest of the notebook.
!pip install -q condacolab
import condacolab
condacolab.install()

⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:17
🔁 Restarting kernel...


# Installation

In [1]:
%%time
import os
os.chdir('/content')
!rm -rf Software
os.makedirs('/content/Software')
os.chdir('/content/Software')
if not os.path.exists('atomgpt'):
  !rm -rf atomgpt
  !git clone https://github.com/usnistgov/atomgpt.git
  os.chdir('atomgpt')
  !pip install -qqq -r dev-requirements.txt
  !pip install -q -e .


Cloning into 'atomgpt'...
remote: Enumerating objects: 566, done.[K
remote: Counting objects: 100% (67/67), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 566 (delta 19), reused 10 (delta 1), pack-reused 499[K
Receiving objects: 100% (566/566), 66.30 MiB | 36.56 MiB/s, done.
Resolving deltas: 100% (254/254), done.
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.2/162.2 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.0/51.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.4/109.4 kB[0m [31m9.0 MB/s[0m eta [36m0:00

To learn how to train the model on a custom dataset, see https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/atomgpt_example.ipynb

An example prompt

In [2]:
# Example prompt: chemical formula, target Tc_supercon value, then the
# generation instruction. The text (including the double space before
# "Tc_supercon") is kept verbatim — presumably it mirrors the fine-tuning
# prompt format; confirm before editing.
prompt_example = (
    "The chemical formula is MgB2 The  Tc_supercon is 36.483. "
    "Generate atomic structure description with lattice lengths, angles, "
    "coordinates and atom types."
)

Load model

In [3]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from jarvis.core.atoms import Atoms
from jarvis.core.lattice import Lattice
from tqdm import tqdm
from jarvis.io.vasp.inputs import Poscar

# Hugging Face Hub IDs, defined once so the adapter, the base model and the
# tokenizer cannot drift apart through a typo in one of several copies.
ADAPTER_ID = "knc6/atomgpt_mistral_tc_supercon"  # PEFT adapter for the Tc_supercon task
BASE_MODEL_ID = "unsloth/mistral-7b-bnb-4bit"  # bitsandbytes 4-bit Mistral-7B base

# NOTE(review): `config` is not used below; `config.base_model_name_or_path`
# could replace BASE_MODEL_ID if it matches — verify before switching.
config = PeftConfig.from_pretrained(ADAPTER_ID)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Alpaca-style prompt template filled by gen_atoms (instruction, input, output).
# Kept byte-for-byte, including the ".." in the header — presumably it has to
# match the template used during fine-tuning.
alpaca_prompt = """Below is a description of a superconductor material..

### Instruction:
{}

### Input:
{}

### Output:
{}"""


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


adapter_config.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


model.safetensors:   0%|          | 0.00/4.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/155 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

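Helper functions: build the prompt and generate a structure (`gen_atoms`), parse the model's text output into a JARVIS `Atoms` object (`text2atoms`), and relax a structure with an ASE calculator (`general_relaxer`).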

In [4]:
def text2atoms(response):
    """Parse AtomGPT's generated text into a JARVIS ``Atoms`` object.

    Expected layout of ``response`` (split on newlines):

    * line 0  - ignored (the remainder of the "### Output:" line, normally empty)
    * line 1  - three lattice lengths
    * line 2  - three lattice angles
    * line 3+ - one site per line: ``Element x y z`` (fractional coordinates)

    Blank or whitespace-only lines in the site section (for example a trailing
    newline left before the end-of-sequence token) are skipped instead of
    crashing on an empty token list.

    Args:
        response: Generated text that followed the "### Output:" marker.

    Returns:
        jarvis.core.atoms.Atoms built from the parsed lattice and sites.

    Raises:
        IndexError: if the lattice lines are missing or a site line has fewer
            than four fields.
        ValueError: if a numeric field cannot be converted to float.
    """
    lines = response.split("\n")
    lat_lengths = np.array(lines[1].split(), dtype="float")
    lat_angles = np.array(lines[2].split(), dtype="float")
    lat = Lattice.from_parameters(
        lat_lengths[0],
        lat_lengths[1],
        lat_lengths[2],
        lat_angles[0],
        lat_angles[1],
        lat_angles[2],
    )
    elements = []
    coords = []
    for line in lines[3:]:
        tokens = line.split()
        if not tokens:
            # Blank line (e.g. trailing newline) - nothing to parse.
            continue
        elements.append(tokens[0])
        coords.append([float(tokens[1]), float(tokens[2]), float(tokens[3])])
    atoms = Atoms(
        coords=coords,
        elements=elements,
        lattice_mat=lat.lattice(),
        cartesian=False,  # coordinates above are fractional
    )
    return atoms

def gen_atoms(prompt="", max_new_tokens=512, model="", tokenizer=""):
    """Generate a crystal structure for ``prompt`` with the fine-tuned LLM.

    The prompt is inserted into the module-level ``alpaca_prompt`` template,
    the model generates a continuation, and the text after the "### Output:"
    marker is parsed by ``text2atoms``.

    Args:
        prompt: Structure description, e.g. formula and target Tc (see
            ``prompt_example``).
        max_new_tokens: Upper bound on generated tokens; output truncated by
            this limit may fail to parse.
        model: Causal LM exposing ``generate`` (here the PEFT-wrapped model).
        tokenizer: Tokenizer matching ``model``.

    Returns:
        jarvis.core.atoms.Atoms parsed from the generated text.
    """
    text = alpaca_prompt.format(
        "Below is a description of a superconductor material.",  # instruction
        prompt,  # input
        "",  # output - leave this blank for generation!
    )
    # Put the inputs on the model's own device instead of assuming "cuda".
    device = getattr(model, "device", "cuda")
    inputs = tokenizer([text], return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=max_new_tokens, use_cache=True
    )
    decoded = tokenizer.batch_decode(outputs)[0]
    # Keep only the generated part that follows the "### Output:" marker.
    response = decoded.split("# Output:")[1]
    # Remove the end-of-sequence token as a literal substring. str.strip('</s>')
    # treats its argument as a *set of characters* and would also eat any
    # leading/trailing '<', '/', 's' or '>' belonging to the real output.
    eos = getattr(tokenizer, "eos_token", None) or "</s>"
    response = response.replace(eos, "")
    atoms = text2atoms(response)
    return atoms

def general_relaxer(atoms="", calculator="", fmax=0.05, steps=150):
    """Relax positions and cell of a JARVIS structure with an ASE calculator.

    The structure is wrapped in an ``ExpCellFilter`` so the cell is optimized
    together with the atomic positions, then minimized with FIRE.

    Args:
        atoms: jarvis.core.atoms.Atoms to relax.
        calculator: ASE-compatible calculator (e.g. the ALIGNN-FF ``calc``).
        fmax: Force convergence threshold passed to ``FIRE.run``.
        steps: Maximum number of FIRE steps; the run can stop here before
            reaching ``fmax``.

    Returns:
        jarvis.core.atoms.Atoms of the final structure.
    """
    # Local imports: the helper must not depend on a later cell having been
    # executed first to bring these names into the global namespace.
    from ase.optimize.fire import FIRE
    from jarvis.core.atoms import ase_to_atoms

    try:
        # Cell filters live in ase.filters in newer ASE releases (the
        # ase.constraints location is deprecated there).
        from ase.filters import ExpCellFilter
    except ImportError:  # older ASE without ase.filters
        from ase.constraints import ExpCellFilter

    ase_atoms = atoms.ase_converter()
    ase_atoms.calc = calculator
    cell_filter = ExpCellFilter(ase_atoms)

    dyn = FIRE(cell_filter)
    dyn.run(fmax=fmax, steps=steps)
    return ase_to_atoms(cell_filter.atoms)

In [5]:
# Generate a structure for the MgB2 example prompt and print it
# (the printed form is the POSCAR-style text shown below).
atoms = gen_atoms(
    prompt=prompt_example,
    model=model,
    tokenizer=tokenizer,
)
print(atoms)


System
1.0
3.07 0.0 0.0
0.0 3.07 0.0
0.0 0.0 3.51
Mg B 
1 2 
direct
0.0 0.0 0.0 Mg
0.667 0.333 0.5 B
0.333 0.667 0.5 B



In [8]:
# Second example: same target Tc, different formula (MgCB2).
# A separate variable is used so this cell does not overwrite `prompt_example`;
# otherwise re-running the MgB2 cell above would silently generate MgCB2.
prompt_mgcb2 = "The chemical formula is MgCB2 The  Tc_supercon is 36.483. Generate atomic structure description with lattice lengths, angles, coordinates and atom types."
atoms = gen_atoms(prompt=prompt_mgcb2, model=model, tokenizer=tokenizer)
print(atoms)


System
1.0
3.1 0.0 0.0
0.0 3.1 0.0
0.0 0.0 3.51
Mg C B 
1 1 2 
direct
0.0 0.0 0.0 Mg
0.5 0.5 0.5 C
0.5 0.0 0.25 B
0.0 0.5 0.75 B



In [None]:
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path
from tqdm import tqdm
from ase.constraints import ExpCellFilter
from sklearn.metrics import mean_absolute_error
import time
from jarvis.core.atoms import ase_to_atoms
from ase.optimize.fire import FIRE

# ALIGNN-FF calculator built from the default pretrained weights directory
# bundled with the alignn package (its path is printed below). It can be
# passed as `calculator` to general_relaxer.
model_path = default_path()
calc = AlignnAtomwiseCalculator(
    path=model_path,
    stress_wt=0.3,
)

dir_path /usr/local/lib/python3.10/site-packages/alignn/ff/alignnff_wt10
model_path /usr/local/lib/python3.10/site-packages/alignn/ff/alignnff_wt10


Structures generated by AtomGPT may be high in energy. To optimize (relax) them, the ALIGNN-FF force field can be used, e.g. by passing `calc` to `general_relaxer`. See a full example in the notebook [here](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Structure_Relaxation_Phonons_Interface.ipynb).

In [None]:
# Record the conda environment used for this run (reproducibility aid).
!conda env export

name: base
channels:
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - archspec=0.2.2=pyhd8ed1ab_0
  - boltons=23.1.1=pyhd8ed1ab_0
  - brotli-python=1.1.0=py310hc6cd4ac_1
  - bzip2=1.0.8=hd590300_5
  - c-ares=1.24.0=hd590300_0
  - ca-certificates=2023.11.17=hbcca054_0
  - cffi=1.16.0=py310h2fee648_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - conda=23.11.0=py310hff52083_1
  - conda-libmamba-solver=23.12.0=pyhd8ed1ab_0
  - conda-package-handling=2.2.0=pyh38be061_0
  - conda-package-streaming=0.9.0=pyhd8ed1ab_0
  - distro=1.8.0=pyhd8ed1ab_0
  - fmt=10.1.1=h00ab1b0_1
  - icu=73.2=h59595ed_0
  - jsonpatch=1.33=pyhd8ed1ab_0
  - jsonpointer=2.4=py310hff52083_3
  - keyutils=1.6.1=h166bdaf_0
  - krb5=1.21.2=h659d440_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libarchive=3.7.2=h2aa1ff5_1
  - libcurl=8.5.0=hca28451_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=hd590300_2
  - libffi=3.4.2=h7f98852_5
  - libgcc-n

In [None]:
# Record the exact pip package versions (including the atomgpt commit) used for this run.
!pip freeze

accelerate==0.31.0
aiohttp==3.9.5
aiosignal==1.3.1
alignn==2024.4.20
annotated-types==0.7.0
archspec @ file:///home/conda/feedstock_root/build_artifacts/archspec_1699370045702/work
ase==3.23.0
async-timeout==4.0.3
-e git+https://github.com/usnistgov/atomgpt.git@90303df9b53bb77de8fad2e41ecfccf2d35fabf8#egg=atomgpt
attrs==23.2.0
autopep8==2.3.1
bitsandbytes==0.43.1
black==24.4.2
boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1703154663129/work
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1695989787169/work
certifi==2024.6.2
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001684923/work
chardet==3.0.4
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
click==8.1.7
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
conda @ file:///home/conda/feedstock_root/build_artifacts/conda_1701731572133/work
conda-libmamba-solver @ file:///ho

In [None]:
# NOTE(review): repeats the earlier `!conda env export` cell verbatim; one of
# the two can likely be removed.
!conda env export

name: base
channels:
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - archspec=0.2.2=pyhd8ed1ab_0
  - boltons=23.1.1=pyhd8ed1ab_0
  - brotli-python=1.1.0=py310hc6cd4ac_1
  - bzip2=1.0.8=hd590300_5
  - c-ares=1.24.0=hd590300_0
  - ca-certificates=2023.11.17=hbcca054_0
  - cffi=1.16.0=py310h2fee648_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - conda=23.11.0=py310hff52083_1
  - conda-libmamba-solver=23.12.0=pyhd8ed1ab_0
  - conda-package-handling=2.2.0=pyh38be061_0
  - conda-package-streaming=0.9.0=pyhd8ed1ab_0
  - distro=1.8.0=pyhd8ed1ab_0
  - fmt=10.1.1=h00ab1b0_1
  - icu=73.2=h59595ed_0
  - jsonpatch=1.33=pyhd8ed1ab_0
  - jsonpointer=2.4=py310hff52083_3
  - keyutils=1.6.1=h166bdaf_0
  - krb5=1.21.2=h659d440_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libarchive=3.7.2=h2aa1ff5_1
  - libcurl=8.5.0=hca28451_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=hd590300_2
  - libffi=3.4.2=h7f98852_5
  - libgcc-n

In [None]:
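# Archived conda/pip environment spec (commented out); uncomment the block,
# including the final `conda env update` line, to recreate the environment.
# env="""name: base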
# channels:
#   - xformers
#   - pytorch
#   - nvidia
#   - conda-forge
#   - defaults
# dependencies:
#   - _libgcc_mutex=0.1=conda_forge
#   - _openmp_mutex=4.5=2_gnu
#   - blas=1.0=mkl
#   - bzip2=1.0.8=h7f98852_4
#   - ca-certificates=2024.2.2=hbcca054_0
#   - cairo=1.18.0=h3faef2a_0
#   - cffi=1.16.0=py39h7a31438_0
#   - cuda-cudart=12.1.105=0
#   - cuda-cupti=12.1.105=0
#   - cuda-libraries=12.1.0=0
#   - cuda-nvrtc=12.1.105=0
#   - cuda-nvtx=12.1.105=0
#   - cuda-opencl=12.4.99=0
#   - cuda-runtime=12.1.0=0
#   - cudatoolkit=11.7.0=hd8887f6_10
#   - expat=2.5.0=hcb278e6_1
#   - filelock=3.15.4=pyhd8ed1ab_0
#   - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
#   - font-ttf-inconsolata=3.000=h77eed37_0
#   - font-ttf-source-code-pro=2.038=h77eed37_0
#   - font-ttf-ubuntu=0.83=hab24e00_0
#   - fontconfig=2.14.2=h14ed4e7_0
#   - fonts-conda-ecosystem=1=0
#   - fonts-conda-forge=1=0
#   - freetype=2.12.1=h267a509_2
#   - gettext=0.21.1=h27087fc_0
#   - gmp=6.3.0=h59595ed_1
#   - gmpy2=2.1.2=py39h376b7d2_1
#   - icu=73.2=h59595ed_0
#   - intel-openmp=2022.1.0=h9e868ea_3769
#   - jinja2=3.1.4=pyhd8ed1ab_0
#   - ld_impl_linux-64=2.40=h41732ed_0
#   - libblas=3.9.0=16_linux64_mkl
#   - libcblas=3.9.0=16_linux64_mkl
#   - libcublas=12.1.0.26=0
#   - libcufft=11.0.2.4=0
#   - libcufile=1.9.0.20=0
#   - libcurand=10.3.5.119=0
#   - libcusolver=11.4.4.55=0
#   - libcusparse=12.0.2.55=0
#   - libexpat=2.5.0=hcb278e6_1
#   - libffi=3.4.2=h7f98852_5
#   - libgcc-ng=13.2.0=h807b86a_2
#   - libgfortran-ng=13.2.0=h69a702a_5
#   - libgfortran5=13.2.0=ha4646dd_5
#   - libglib=2.78.0=hebfc3b9_0
#   - libgomp=13.2.0=h807b86a_2
#   - libiconv=1.17=h166bdaf_0
#   - liblapack=3.9.0=16_linux64_mkl
#   - libnpp=12.0.2.50=0
#   - libnsl=2.0.0=h7f98852_0
#   - libnvjitlink=12.1.105=0
#   - libnvjpeg=12.1.1.14=0
#   - libopenblas=0.3.26=pthreads_h413a1c8_0
#   - libpng=1.6.39=h753d276_0
#   - libprotobuf=3.21.12=hfc55251_2
#   - libsqlite=3.43.0=h2797004_0
#   - libstdcxx-ng=13.2.0=h7e041cc_2
#   - libuuid=2.38.1=h0b41bf4_0
#   - libxcb=1.15=h0b41bf4_0
#   - libxml2=2.11.5=h232c23b_1
#   - libzlib=1.2.13=hd590300_5
#   - llvm-openmp=15.0.7=h0cdce71_0
#   - markupsafe=2.1.5=py39hd1e30aa_0
#   - mkl=2022.1.0=hc2b9512_224
#   - mpc=1.3.1=hfe3b2da_0
#   - mpfr=4.2.1=h9458935_0
#   - mpmath=1.3.0=pyhd8ed1ab_0
#   - ncurses=6.4=hcb278e6_0
#   - networkx=3.2.1=pyhd8ed1ab_0
#   - ninja=1.11.1=h924138e_0
#   - openbabel=3.1.1=py39h421517d_8
#   - openssl=3.2.1=hd590300_1
#   - pcre2=10.40=hc3806b6_0
#   - pip=23.2.1=pyhd8ed1ab_0
#   - pixman=0.42.2=h59595ed_0
#   - pthread-stubs=0.4=h36c2ea0_1001
#   - pycparser=2.22=pyhd8ed1ab_0
#   - python=3.9.18=h0755675_0_cpython
#   - python_abi=3.9=4_cp39
#   - pytorch=2.2.2=py3.9_cuda12.1_cudnn8.9.2_0
#   - pytorch-cuda=12.1=ha16c6d3_5
#   - pytorch-mutex=1.0=cuda
#   - pyyaml=6.0.1=py39hd1e30aa_1
#   - readline=8.2=h8228510_1
#   - setuptools=68.2.2=pyhd8ed1ab_0
#   - sleef=3.5.1=h9b69904_2
#   - sympy=1.12=pypyh9d50eac_103
#   - tk=8.6.13=h2797004_0
#   - torchtriton=2.2.0=py39
#   - typing_extensions=4.10.0=pyha770c72_0
#   - wheel=0.43.0=pyhd8ed1ab_1
#   - xformers=0.0.25.post1=py39_cu12.1.0_pyt2.2.2
#   - xorg-kbproto=1.0.7=h7f98852_1002
#   - xorg-libice=1.1.1=hd590300_0
#   - xorg-libsm=1.2.4=h7391055_0
#   - xorg-libx11=1.8.7=h8ee46fc_0
#   - xorg-libxau=1.0.11=hd590300_0
#   - xorg-libxdmcp=1.1.3=h7f98852_0
#   - xorg-libxext=1.3.4=h0b41bf4_2
#   - xorg-libxrender=0.9.11=hd590300_0
#   - xorg-renderproto=0.11.1=h7f98852_1002
#   - xorg-xextproto=7.3.0=h0b41bf4_1003
#   - xorg-xproto=7.0.31=h7f98852_1007
#   - xz=5.2.6=h166bdaf_0
#   - yaml=0.2.5=h7f98852_2
#   - zlib=1.2.13=hd590300_5
#   - pip:
#       - accelerate==0.31.0
#       - aiohttp==3.9.5
#       - aiosignal==1.3.1
#       - alignn==2024.4.20
#       - annotated-types==0.7.0
#       - ase==3.23.0
#       - async-timeout==4.0.3
#       - attrs==23.2.0
#       - autopep8==2.3.1
#       - bitsandbytes==0.43.1
#       - black==24.4.2
#       - certifi==2024.6.2
#       - chardet==3.0.4
#       - charset-normalizer==3.3.2
#       - click==8.1.7
#       - contourpy==1.2.1
#       - cycler==0.12.1
#       - datasets==2.20.0
#       - dgl==1.1.1
#       - dill==0.3.8
#       - docstring-parser==0.16
#       - eval-type-backport==0.2.0
#       - flake8==7.1.0
#       - fonttools==4.53.0
#       - frozenlist==1.4.1
#       - fsspec==2024.5.0
#       - huggingface-hub==0.23.4
#       - idna==3.7
#       - importlib-resources==6.4.0
#       - jarvis-tools==2024.4.30
#       - joblib==1.4.2
#       - kiwisolver==1.4.5
#       - lmdb==1.4.1
#       - markdown-it-py==3.0.0
#       - matplotlib==3.9.0
#       - mccabe==0.7.0
#       - mdurl==0.1.2
#       - multidict==4.7.6
#       - multiprocess==0.70.16
#       - mypy-extensions==1.0.0
#       - numpy==1.26.4
#       - packaging==24.1
#       - pandas==2.2.2
#       - pathspec==0.12.1
#       - peft==0.11.1
#       - pillow==10.3.0
#       - platformdirs==4.2.2
#       - psutil==6.0.0
#       - pyarrow==16.1.0
#       - pyarrow-hotfix==0.6
#       - pycodestyle==2.12.0
#       - pydantic==2.7.4
#       - pydantic-core==2.18.4
#       - pydantic-settings==2.3.3
#       - pydocstyle==6.3.0
#       - pyflakes==3.2.0
#       - pygments==2.18.0
#       - pyparsing==2.4.7
#       - python-dateutil==2.9.0.post0
#       - python-dotenv==1.0.1
#       - pytz==2024.1
#       - regex==2024.5.15
#       - requests==2.32.3
#       - rich==13.7.1
#       - safetensors==0.4.3
#       - scikit-learn==1.5.0
#       - scipy==1.13.1
#       - sentencepiece==0.2.0
#       - shtab==1.7.1
#       - six==1.16.0
#       - snowballstemmer==2.2.0
#       - spglib==2.4.0
#       - threadpoolctl==3.5.0
#       - tokenizers==0.19.1
#       - tomli==2.0.1
#       - toolz==0.12.1
#       - torchdata==0.7.1
#       - tqdm==4.66.4
#       - transformers==4.41.2
#       - trl==0.8.6
#       - tyro==0.8.4
#       - tzdata==2024.1
#       - urllib3==2.2.2
#       - xmltodict==0.13.0
#       - xxhash==3.4.1
#       - yarl==1.9.4
#       - zipp==3.19.2
# """
# with open(f'/content/conda.yaml', 'w') as f:
#     f.write(env)
# # !conda env update --name base -f conda.yaml