# Raman spectrum model

This repository contains Python code for predicting Raman spectra with an AI model. The model was fine-tuned from the Position-based Equivariant Graph Neural Network (POS-EGNN), a foundation model for chemistry and materials science. The only input required for inference is the crystallographic information of the target material.

POS-EGNN served as the pre-trained foundation for this work. It was originally trained on 150,000 materials from the Materials Project Trajectory (MPtrj) dataset to predict energies, forces, and stress, and was then fine-tuned for Raman spectrum prediction. The resulting Raman model was trained on approximately 5,400 different materials, combining experimental spectra from the Raman Open Database (ROD) with density functional theory (DFT) Raman spectra from the Computational Raman Database (CRD).

To perform inference, the input must include the crystallographic information of the material: (i) atomic positions, (ii) the unit cell (lattice), and (iii) atomic numbers. The model's output is a unit vector of 4,000 elements, where each element is a Raman intensity value. The predicted spectrum has a frequency resolution of 1 cm⁻¹ and covers 0 to 4000 cm⁻¹, with element *i* corresponding to *i* cm⁻¹.
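As a sketch of how to read the output vector, assuming element *i* holds the intensity at *i* cm⁻¹ (the same mapping the notebook's plotting cell uses via `np.arange(4000)`), the NumPy-only helper below lists the strongest local maxima. The function `top_peaks` is hypothetical and not part of the repository, and the synthetic spectrum only stands in for real model output.

```python
import numpy as np

N_BINS = 4000  # length of the model's output vector
# Assumed mapping: element i <-> i cm^-1 (1 cm^-1 resolution starting at 0)
WAVENUMBERS = np.arange(N_BINS)


def top_peaks(spectrum, k=3):
    """Return the wavenumbers (cm^-1) of the k most intense local maxima.

    Hypothetical helper, not part of the repository.
    """
    spectrum = np.asarray(spectrum, dtype=float)
    if spectrum.shape != (N_BINS,):
        raise ValueError(f"expected shape ({N_BINS},), got {spectrum.shape}")
    inner = spectrum[1:-1]
    # A point is a local maximum if it is higher than its left neighbour
    # and at least as high as its right neighbour
    is_peak = (inner > spectrum[:-2]) & (inner >= spectrum[2:])
    peak_idx = np.nonzero(is_peak)[0] + 1  # shift back to full-array indices
    # Sort the peaks by intensity, strongest first
    strongest = peak_idx[np.argsort(spectrum[peak_idx])[::-1]]
    return WAVENUMBERS[strongest[:k]]


# Synthetic spectrum with two Gaussian bands, standing in for model output
synthetic_x = WAVENUMBERS.astype(float)
synthetic_spectrum = (np.exp(-((synthetic_x - 330) / 8) ** 2)
                      + 0.4 * np.exp(-((synthetic_x - 570) / 10) ** 2))
synthetic_spectrum /= np.linalg.norm(synthetic_spectrum)  # scale to unit norm
print(top_peaks(synthetic_spectrum, k=2))
```

The peak search is deliberately simple (no smoothing or prominence threshold); for noisy spectra a dedicated routine such as `scipy.signal.find_peaks` is usually a better choice.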

Code (GitHub): https://github.com/IBM/materials/tree/main/models/pos_egnn

Model (HuggingFace): https://huggingface.co/ibm-research/materials.pos-egnn/blob/main/pos-egnn_ft-raman.v3.ckpt

## Getting Started 
Make sure Python 3.12 is installed.

Create a project folder.

Copy the `Morningstar` folder, `requirements.txt`, and `inference.ipynb` from the GitHub repository into the project folder.

Then, follow the steps below to replicate our environment and install the necessary libraries:
- `python3.12 -m venv env`
- `source env/bin/activate`
- `pip install -r requirements.txt`
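After installing the requirements, one way to confirm that the libraries imported by the notebook are available is to look them up with `importlib.util.find_spec`, which locates a module without importing it. This is an optional sketch: the module list is copied from the notebook's import cell, and `missing_modules` is a hypothetical helper name.

```python
import importlib.util

# Top-level third-party modules imported by inference.ipynb
NOTEBOOK_MODULES = [
    "wget",
    "huggingface_hub",
    "pymatgen",
    "numpy",
    "matplotlib",
    "torch",
    "torch_geometric",
]


def missing_modules(names):
    """Return the module names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(NOTEBOOK_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All notebook dependencies found.")
```

Run it with the virtual environment activated so that it inspects the same interpreter the notebook will use.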

## Example
Run `inference.ipynb` cell by cell to predict a Raman spectrum with the model. Feel free to adapt it to your needs.


```python
print("** Loading libraries **")
import wget
from huggingface_hub import hf_hub_download
from pymatgen.core import Structure
import numpy as np
import matplotlib.pyplot as plt
from model.model import Model
import torch
from torch_geometric.data import Data
import warnings
warnings.filterwarnings("ignore")
```

```python
# Download a CIF file for the inference

# Example:
# Material name: Montroydite (mercury oxide)
# Chemical formula: HgO
# Database: Crystallography Open Database (COD)
# Link: https://www.crystallography.net/cod/9012530.html
# Reference: COD ID = 9012530
# Raman spectrum for reference: https://solsa.crystallography.net/rod/1000171.html

cif_url = "https://www.crystallography.net/cod/9012530.cif"
# wget.download does not overwrite an existing file: it saves under a new
# name such as 'file (1).cif' and returns that name, so keep its return value
cif_file = wget.download(cif_url, 'file.cif')
print('\nCIF downloaded:', cif_file)
```
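To try another material, the download step can be pointed at any COD entry. The sketch below builds the CIF URL from a COD ID, following the `https://www.crystallography.net/cod/<ID>.cif` pattern of the example, and downloads it with the standard-library `urllib.request.urlretrieve` instead of `wget`; unlike `wget.download`, `urlretrieve` overwrites an existing file with the same name. `cod_cif_url` and `download_cod_cif` are hypothetical helpers, not part of the repository.

```python
import urllib.request

COD_BASE_URL = "https://www.crystallography.net/cod"


def cod_cif_url(cod_id):
    """Build the CIF download URL for a Crystallography Open Database entry."""
    cod_id = int(cod_id)  # COD IDs are integers, e.g. 9012530
    return f"{COD_BASE_URL}/{cod_id}.cif"


def download_cod_cif(cod_id, path=None):
    """Download the CIF for `cod_id` to `path` (default: '<cod_id>.cif').

    urlretrieve overwrites `path` if it already exists.
    Returns the local file path.
    """
    path = path or f"{int(cod_id)}.cif"
    local_path, _headers = urllib.request.urlretrieve(cod_cif_url(cod_id), path)
    return local_path


# Usage (requires network access):
# cif_file = download_cod_cif(9012530)  # Montroydite, HgO
```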

```python
# Extract the material structure with the pymatgen library
# z: atomic numbers of the atoms in the cell
# pos: Cartesian atomic positions (Å)
# box: lattice matrix (3x3, one lattice vector per row, Å)

structure = Structure.from_file(cif_file)
box = structure.lattice.matrix
pos = structure.cart_coords
z = structure.atomic_numbers
```
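The `box` and `pos` arrays are linked through the lattice: pymatgen stores the three lattice vectors as the rows of `lattice.matrix`, so Cartesian coordinates equal fractional coordinates multiplied by that matrix (`structure.frac_coords @ box` should reproduce `structure.cart_coords`). A NumPy-only sketch of that relation, using a hypothetical cubic cell rather than the HgO structure; the `demo_` names avoid overwriting the notebook's `box` and `pos`.

```python
import numpy as np


def frac_to_cart(frac_coords, lattice_matrix):
    """Convert fractional to Cartesian coordinates.

    Assumes the pymatgen convention: each *row* of `lattice_matrix`
    is one lattice vector (a, b, c), in angstroms.
    """
    frac_coords = np.atleast_2d(np.asarray(frac_coords, dtype=float))
    lattice_matrix = np.asarray(lattice_matrix, dtype=float)
    return frac_coords @ lattice_matrix


# Hypothetical cubic cell (a = 4.0 Å) with one atom at the origin
# and one at the body centre
demo_box = np.diag([4.0, 4.0, 4.0])
demo_frac = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]
demo_pos = frac_to_cart(demo_frac, demo_box)
print(demo_pos)  # the second atom sits at (2, 2, 2) Å
```

The row-vector convention matters for non-orthogonal cells: multiplying by the transposed matrix would silently give wrong positions for monoclinic or triclinic structures.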

```python
# Input data representation for the inference, in PyTorch Geometric format

data = Data()
data.pos = torch.tensor(pos, dtype=torch.float)     # (N, 3) Cartesian positions
data.z = torch.tensor(z, dtype=torch.long)          # (N,) atomic numbers
data.box = torch.tensor(box, dtype=torch.float)     # (3, 3) lattice matrix
data.batch = torch.zeros(len(z), dtype=torch.long)  # all atoms belong to graph 0
data.num_graphs = 1                                 # a single structure
```
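`data.batch` assigns every atom to a graph index; with a single structure every entry is 0 and `num_graphs` is 1. The notebook only covers the single-structure case, so whether the fine-tuned model accepts several structures per call is an assumption here. If it does, the batch vector would follow the usual PyTorch Geometric convention sketched below, where `make_batch_vector` is a hypothetical helper.

```python
import numpy as np


def make_batch_vector(atoms_per_structure):
    """Map each atom to the index of the structure (graph) it belongs to.

    Follows the PyTorch Geometric convention: atoms of graph 0 come first,
    then the atoms of graph 1, and so on.
    """
    counts = np.asarray(atoms_per_structure, dtype=int)
    if np.any(counts <= 0):
        raise ValueError("every structure needs at least one atom")
    return np.repeat(np.arange(len(counts)), counts)


# Hypothetical single 6-atom structure: all zeros, like the notebook cell
print(make_batch_vector([6]))
# Two hypothetical structures with 2 and 3 atoms
print(make_batch_vector([2, 3]))
```

To use the result as `data.batch`, convert it with `torch.as_tensor(..., dtype=torch.long)`.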

```python
# Download the model
# AI model for Raman spectrum inference from crystallographic information

model_file = hf_hub_download(repo_id="ibm-research/materials.pos-egnn", filename="pos-egnn_ft-raman.v3.ckpt")
print('model downloaded')
```

```python
# Load the model checkpoint and switch to evaluation mode for inference
model = Model.load_from_checkpoint(model_file, strict=False)
model.eval()
```

```python
# Perform model inference
model.decoders["Spectra"].set_context_state("RAMAN_")
out = model(data)
result = model.decoders["Spectra"](out)
# Take the spectrum of the first (and only) structure in the batch
predicted_spectrum = result['spectrum'].detach().cpu().numpy()[0]
```
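To compare the prediction with a measured spectrum, such as the ROD reference linked in the download cell, the reference usually has to be resampled onto the model's 1 cm⁻¹ grid first. The sketch below assumes the reference is already loaded as two arrays (wavenumbers and intensities; parsing the ROD file is not shown), interpolates it with `np.interp`, and scores the match with cosine similarity. The metric and the helper names are illustrative choices, not the evaluation used to train or validate the model.

```python
import numpy as np

# Model output grid, assuming element i <-> i cm^-1
GRID = np.arange(4000, dtype=float)


def resample_to_grid(wavenumbers, intensities, grid=GRID):
    """Linearly interpolate a spectrum onto `grid`; zero outside the measured range."""
    wavenumbers = np.asarray(wavenumbers, dtype=float)
    intensities = np.asarray(intensities, dtype=float)
    order = np.argsort(wavenumbers)  # np.interp needs increasing x values
    return np.interp(grid, wavenumbers[order], intensities[order], left=0.0, right=0.0)


def cosine_similarity(a, b):
    """Cosine of the angle between two spectra (1.0 means identical shape)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom > 0 else 0.0


# Synthetic stand-in for a measured spectrum sampled every 0.5 cm^-1
ref_x = np.arange(100.0, 1200.0, 0.5)
ref_y = np.exp(-((ref_x - 330.0) / 8.0) ** 2)
ref_on_grid = resample_to_grid(ref_x, ref_y)

# In the notebook, compare against the model output instead:
# score = cosine_similarity(predicted_spectrum, ref_on_grid)
print(cosine_similarity(ref_on_grid, ref_on_grid))
```

Cosine similarity ignores overall intensity scale, which suits spectra reported in arbitrary units, but it does not tolerate small peak shifts; broadening both spectra before comparing is a common refinement.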

```python
# Plot the Raman spectrum
# Spectrum resolution: 1 cm^-1
# Spectrum range: 0 to 4000 cm^-1
x = np.arange(4000)
formula = structure.reduced_formula

figure, axis = plt.subplots(figsize=(7, 5))
axis.set_title("Predicted Raman Spectrum", fontsize=18)
axis.set_xlabel("Frequency ($cm^{-1}$)", fontsize=18)
axis.set_ylabel("Intensity (Arbitrary Unit)", fontsize=18)
axis.tick_params(axis='x', labelrotation=90)
plt.yticks(fontsize=16, fontstyle='italic')
plt.xticks(fontsize=16, fontstyle='italic')
axis.plot(x, predicted_spectrum, color='blue', label=formula)
axis.legend(fontsize=16, loc='upper right')
plt.tight_layout()
plt.show()  # needed when running outside Jupyter
```
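To keep the prediction for later analysis or for plotting in other tools, the wavenumber and intensity columns can be written to a CSV file with `np.savetxt`. A minimal sketch: the hypothetical `save_spectrum` and `load_spectrum` helpers and the file names are illustrative, and in the notebook you would pass `x` and `predicted_spectrum` from the cells above.

```python
import os
import tempfile

import numpy as np


def save_spectrum(path, wavenumbers, intensities):
    """Write a two-column CSV (wavenumber in cm^-1, intensity) with a header row."""
    table = np.column_stack([wavenumbers, intensities])
    np.savetxt(path, table, delimiter=",", header="wavenumber_cm-1,intensity", comments="")


def load_spectrum(path):
    """Read back a CSV written by save_spectrum; returns (wavenumbers, intensities)."""
    table = np.loadtxt(path, delimiter=",", skiprows=1)
    return table[:, 0], table[:, 1]


# In the notebook: save_spectrum(f"{formula}_predicted_raman.csv", x, predicted_spectrum)

# Demo with a tiny synthetic spectrum written to a temporary directory
demo_path = os.path.join(tempfile.mkdtemp(), "demo_spectrum.csv")
save_spectrum(demo_path, np.arange(5), [0.0, 0.1, 0.5, 0.2, 0.0])
demo_x, demo_y = load_spectrum(demo_path)
print(demo_x, demo_y)
```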