# CrystaLLM Demo

This notebook demonstrates how to use the trained CrystaLLM model to generate crystal structures (CIF files) from chemical formulas.

### Setup

In [None]:
import os
import sys
import torch
from pymatgen.core import Structure

# Add the current working directory (expected to be the repository root) to
# the import path. The membership check keeps sys.path from accumulating
# duplicate entries when this cell is re-run.
if os.path.abspath('.') not in sys.path:
    sys.path.append(os.path.abspath('.'))

# Report the runtime environment.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

### Configuration

In [None]:
# Output directory of the trained model and the dtype string; both are
# forwarded to the sampling script by generate_cif.
model_dir = "out_crystallm_v1_from_scratch"
dtype = "float16"

# Prefer the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using model from: {model_dir}")
print(f"Device: {device}")

### Generation Function

In [None]:
import subprocess

def _find_data_block_start(text):
    """Return the index where a CIF ``data_`` header starts in *text*, or -1.

    A header at the beginning of a line is preferred, so that an incidental
    "data_" earlier in the text (for example inside a log line) is not
    mistaken for the start of the CIF. If no line begins with "data_", fall
    back to the first occurrence anywhere in the text.
    """
    if text.startswith("data_"):
        return 0
    line_start = text.find("\ndata_")
    if line_start != -1:
        return line_start + 1  # skip the newline itself
    return text.find("data_")


def _extract_cifs(output):
    """Split sampler console output on its separator and return the CIF texts."""
    cifs = []
    for part in output.split('---------------'):
        start = _find_data_block_start(part)
        if start != -1:
            # Drop anything printed before the header and surrounding whitespace.
            cifs.append(part[start:].strip())
    return cifs


def generate_cif(formula, num_samples=1, max_new_tokens=2000):
    """Generate CIF text for a chemical formula by running ``bin/sample.py``.

    The sampling script is launched as a subprocess with the current Python
    interpreter, prompted with ``data_<formula>`` followed by a newline. The
    module-level ``model_dir``, ``device`` and ``dtype`` settings are
    forwarded to it, and its console output is parsed for CIF blocks.

    Args:
        formula: Chemical formula used to build the prompt (e.g. "NaCl").
        num_samples: Value passed to the script as ``num_samples``.
        max_new_tokens: Value passed to the script as ``max_new_tokens``.

    Returns:
        A list of CIF strings extracted from the script's stdout. The list
        is empty if the script exits with a non-zero status (its stderr is
        printed) or if no ``data_`` block is found in the output.
    """
    print(f"Generating {num_samples} structure(s) for {formula}...")

    prompt = f"data_{formula}\n"

    # Arguments are passed as a list (no shell), in key=value form.
    # The script path is relative, so this assumes the working directory is
    # the repository root.
    cmd = [
        sys.executable, "bin/sample.py",
        f"out_dir={model_dir}",
        f"start={prompt}",
        f"num_samples={num_samples}",
        f"max_new_tokens={max_new_tokens}",
        f"device={device}",
        f"dtype={dtype}",
        "target=console"
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print("Error:", result.stderr)
        return []

    return _extract_cifs(result.stdout)

### Generate & Visualize

In [None]:
# Formula to generate a structure for; change this to any formula like SiO2, MgO, Fe2O3
formula = "NaCl"

generated_cifs = generate_cif(formula)

if not generated_cifs:
    print("No structure generated.")
else:
    # Only the first returned sample is inspected.
    cif_text = generated_cifs[0]
    print("Generated CIF:")
    print(cif_text[:300] + "...\n")

    # Parse with pymatgen as a validity check; any failure while parsing or
    # querying the structure is reported as an invalid CIF.
    try:
        s = Structure.from_str(cif_text, fmt="cif")
        print("Valid Structure!")
        print(f"Formula: {s.composition.reduced_formula}")
        print(f"Spacegroup: {s.get_space_group_info()}")
        print(f"Density: {s.density:.2f} g/cm3")
    except Exception as e:
        print(f"Invalid CIF: {e}")