# Predicting Properties from Composition


SMACT includes a property prediction module that uses pretrained [ROOST](https://doi.org/10.1038/s41467-020-19964-7) (Representation Learning from Stoichiometry) models to predict material properties from chemical composition alone — no crystal structure required.

This notebook demonstrates how to:
1. Predict band gaps with a single function call
2. Use the `RoostPropertyPredictor` class for more control
3. Get uncertainty estimates on predictions
4. Query the model registry
5. Combine property prediction with SMACT screening

## Prerequisites

The property prediction module requires `torch` and `aviary-models`, which can be installed via:

```bash
pip install "smact[property_prediction]"
```

Or, if installing from source:

```bash
uv sync --extra property_prediction
```
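Before running the cells below, it can help to confirm that the optional dependencies are importable. The following is a minimal sketch using only the standard library. The helper name `missing_modules` is hypothetical, and the import names `torch` and `aviary` (for `aviary-models`) are assumptions rather than something this notebook states.

```python
from importlib.util import find_spec


def missing_modules(module_names):
    """Return the names in module_names that cannot be found for import."""
    return [name for name in module_names if find_spec(name) is None]


# Assumption: the packages import as "torch" and "aviary" (aviary-models).
missing = missing_modules(["torch", "aviary"])
if missing:
    print(f"Missing optional dependencies: {', '.join(missing)}")
else:
    print("Property prediction dependencies found.")
```

`find_spec` only locates a module without importing it, so the check is cheap even when `torch` is installed.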

## 1. Quick start — `predict_band_gap`

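The simplest way to predict band gaps is with the convenience function. It accepts a single composition string or a list of strings. As the first cell shows, the result is indexable even for a single input, which is why the value is read with `bg[0]`.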

In [None]:
from smact.property_prediction import predict_band_gap

# Predict band gap for a single material
bg = predict_band_gap("GaN")
print(f"GaN predicted band gap: {bg[0]:.2f} eV")

In [None]:
# Predict band gaps for multiple materials at once
compositions = ["NaCl", "TiO2", "GaN", "Si", "ZnO", "CdTe", "GaAs", "SrTiO3"]
predictions = predict_band_gap(compositions)

for comp, bg in zip(compositions, predictions):
    print(f"  {comp:10s} → {bg:.2f} eV")

## 2. Using `RoostPropertyPredictor`

For more control, use the `RoostPropertyPredictor` class directly. This allows you to:
- Reuse a loaded model across multiple prediction calls
- Access model metadata
- Request uncertainty estimates

In [None]:
from smact.property_prediction import RoostPropertyPredictor

# Create a predictor (loads the model once)
predictor = RoostPropertyPredictor(property_name="band_gap")

# Make predictions — the model stays loaded between calls
result_1 = predictor.predict(["NaCl", "KCl", "LiF"])
result_2 = predictor.predict(["Fe2O3", "Al2O3"])

print("Alkali halides:", result_1)
print("Oxides:", result_2)
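Because the predictor is reusable, a long candidate list can be fed through it in slices. The sketch below shows one way to do that. `predict_in_chunks` and `fake_predict` are hypothetical names introduced here, and the sketch assumes the prediction callable returns one value per input composition. Chunking may be unnecessary if the predictor already batches internally.

```python
def predict_in_chunks(predict_fn, compositions, chunk_size=256):
    """Call predict_fn on successive slices and concatenate the results."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    values = []
    for start in range(0, len(compositions), chunk_size):
        chunk = compositions[start:start + chunk_size]
        values.extend(predict_fn(chunk))
    return values


# Stand-in for predictor.predict so the sketch runs without a model:
# it returns the length of each formula string, not a band gap.
def fake_predict(chunk):
    return [float(len(formula)) for formula in chunk]


chunked = predict_in_chunks(
    fake_predict, ["NaCl", "KCl", "LiF", "Fe2O3"], chunk_size=2
)
print(chunked)
```

With a loaded model, the same helper would be called as `predict_in_chunks(predictor.predict, compositions)`, under the assumption stated above.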

## 3. Uncertainty quantification

The ROOST model is trained with a heteroscedastic (robust) loss, so each prediction comes with an **aleatoric uncertainty** estimate: a per-sample noise level that the model predicts alongside the band gap, rather than a single error bar shared by all compositions.
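To make "heteroscedastic loss" concrete, here is a minimal sketch of a Laplace-style negative log-likelihood in which the model outputs both a mean and a log standard deviation. This is an illustrative form of a robust loss. Whether the pretrained model was trained with exactly this expression is an assumption, and the function name and numbers are made up for the example.

```python
import math


def robust_l1_loss(target, mean, log_std):
    """Per-sample Laplace-style negative log-likelihood (up to a constant)."""
    return math.sqrt(2.0) * abs(target - mean) * math.exp(-log_std) + log_std


# Same 0.5 eV error, two different predicted noise levels (illustrative numbers).
confident = robust_l1_loss(target=3.0, mean=3.5, log_std=math.log(0.1))
cautious = robust_l1_loss(target=3.0, mean=3.5, log_std=math.log(0.5))
print(f"small predicted sigma:  {confident:.2f}")
print(f"larger predicted sigma: {cautious:.2f}")
```

The first term punishes a small predicted sigma when the error is large, while the `log_std` term punishes inflating sigma everywhere. That trade-off is what lets the predicted uncertainty vary from sample to sample.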

Pass `return_uncertainty=True` to get a `PredictionResult` object.

In [None]:
# Get predictions with uncertainty
result = predictor.predict(
    ["NaCl", "TiO2", "GaN", "Si", "ZnO", "CdTe"],
    return_uncertainty=True,
)

print(f"Result type: {type(result).__name__}")
print(f"Unit: {result.unit}")
print(f"Number of predictions: {len(result)}")
print()

for comp, pred, unc in zip(
    result.compositions, result.predictions, result.uncertainties
):
    print(f"  {comp:8s} → {pred:.2f} ± {unc:.2f} eV")

In [None]:
# PredictionResult can be converted to a dictionary for serialisation
import pprint

pprint.pprint(result.to_dict())
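If the dictionary is destined for a JSON file, any array-like values need to become plain lists first. The sketch below is one hedged way to do that. `to_jsonable` is a hypothetical helper, and the payload is a stand-in with illustrative keys, since the exact contents of `result.to_dict()` are not shown here.

```python
import json


def to_jsonable(obj):
    """Recursively convert array-like values into plain Python containers."""
    if hasattr(obj, "tolist"):  # NumPy arrays and scalars expose tolist()
        return obj.tolist()
    if isinstance(obj, dict):
        return {str(key): to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(item) for item in obj]
    return obj


# Hypothetical payload standing in for result.to_dict(); keys are illustrative.
payload = {"compositions": ("NaCl", "GaN"), "unit": "eV"}
text = json.dumps(to_jsonable(payload), indent=2)
print(text)
```

If `to_dict()` already returns only built-in types, `json.dumps(result.to_dict())` would work directly and the helper is not needed.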

## 4. Querying the model registry

The registry tracks which properties have pretrained models, what fidelity levels are available, and which models are installed.

In [None]:
from smact.property_prediction import (
    get_available_models,
    get_supported_properties,
)
from smact.property_prediction.registry import (
    get_property_description,
    get_property_unit,
)

# What properties can we predict?
properties = get_supported_properties()
print("Supported properties:")
for prop in properties:
    unit = get_property_unit(prop)
    desc = get_property_description(prop)
    print(f"  {prop}: {desc} [{unit}]")

print()

# What models are available?
models = get_available_models()
print("Available models:")
for model in models:
    print(f"  {model}")

## 5. Combining with SMACT screening

A powerful workflow is to first use SMACT's chemical filters to generate candidate compositions, then predict their properties. Here we screen the Zn-Sn-O system for charge-neutral ternary oxides and predict their band gaps.

In [None]:
import numpy as np

import smact
from smact.screening import smact_filter

# Screen for valid Zn-Sn-O compositions
elements = [smact.Element(sym) for sym in ["Zn", "Sn", "O"]]
valid = smact_filter(elements, threshold=4)

# Build composition strings from the filter output
compositions = []
for els, charges, stoichs in valid:
    formula = "".join(
        f"{el}{s}" if s > 1 else el for el, s in zip(els, stoichs)
    )
    compositions.append(formula)

# Remove duplicates (the same formula can appear with different charge assignments)
compositions = sorted(set(compositions))
print(f"Found {len(compositions)} unique compositions from SMACT filter")
print()

# Predict band gaps for all candidates
results = predictor.predict(compositions, return_uncertainty=True)

# Show results sorted by predicted band gap
order = np.argsort(results.predictions)
print(f"{'Composition':>15s}  {'Band gap (eV)':>14s}  {'Uncertainty':>12s}")
print("-" * 46)
for i in order:
    print(
        f"{compositions[i]:>15s}  {results.predictions[i]:>10.2f} eV  "
        f"± {results.uncertainties[i]:.2f} eV"
    )
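A natural next step after the table above is to keep only the candidates whose predicted band gap falls in a target window and whose uncertainty is acceptably small. The sketch below shows one way to do that. `select_candidates` is a hypothetical helper, and the thresholds and demo inputs are placeholders, not model outputs or recommended values.

```python
def select_candidates(compositions, predictions, uncertainties,
                      gap_min, gap_max, max_uncertainty):
    """Keep (formula, gap, uncertainty) triples inside the requested window."""
    selected = []
    for formula, gap, unc in zip(compositions, predictions, uncertainties):
        if gap_min <= gap <= gap_max and unc <= max_uncertainty:
            selected.append((formula, float(gap), float(unc)))
    # Sort by predicted band gap, smallest first
    return sorted(selected, key=lambda row: row[1])


# Placeholder inputs for illustration only; these are not model outputs.
demo = select_candidates(
    ["A", "B", "C"], [0.5, 1.5, 2.5], [0.1, 0.9, 0.2],
    gap_min=1.0, gap_max=3.0, max_uncertainty=0.5,
)
print(demo)
```

In this notebook the call would take `results.compositions`, `results.predictions`, and `results.uncertainties` as the first three arguments.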

## Notes

- **Model**: The default model is ROOST trained on ~103K Materials Project DFT band gaps (PBE functional). It achieves an MAE of 0.28 eV and R² of 0.93 on the test set.
- **Input**: Compositions should be standard chemical formulas (e.g., `"NaCl"`, `"TiO2"`, `"Ba2YCu3O7"`). The model is composition-only — polymorphs with the same formula will give the same prediction.
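- **Uncertainty**: Aleatoric uncertainties come from the heteroscedastic loss. Larger values indicate compositions for which the model expects larger errors. They are not a reliable out-of-distribution detector on their own: a composition far from the training data can still receive a small aleatoric uncertainty.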
- **Device**: Pass `device="cuda"` to `RoostPropertyPredictor` or `predict_band_gap` for GPU acceleration on large batches.
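
The MAE and R² quoted in the notes can be recomputed for any set of reference band gaps you trust. The following is a minimal sketch of the two metrics in plain Python. The function names are hypothetical and the numbers are placeholders chosen for easy arithmetic, not measured or predicted band gaps.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between reference and predicted values."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Placeholder values for illustration only; not real band gaps.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 4.0]
mae = mean_absolute_error(y_true, y_pred)
r2 = r_squared(y_true, y_pred)
print(f"MAE = {mae:.2f} eV, R2 = {r2:.2f}")
```

Since the model was trained on PBE band gaps, a comparison against experimental values would measure the functional's systematic error as well as the model's.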