### **Getting started with CrystaLLM**
To set up the local environment, follow the **'Getting Started'** section of the README.md.

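We need a clean environment with the dependencies installed and the package installed in editable mode. Note that `conda create` only creates the environment: the notebook kernel must be running inside it for the `%pip` installs below to land there.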

In [None]:
!conda create --yes --quiet --name ft_crystallm_venv python=3.10

In [None]:
%pip install -r requirements.txt

In [None]:
%pip install torch==2.0.1

In [None]:
%pip install -e .

### **Formatting your training data**

To illustrate the data processing steps for the package, we'll use the **JARVIS-DFT** dataset (https://www.nature.com/articles/s41524-020-00440-1). It contains a variety of 3D materials computed with the OptB88vdW and TBmBJ methods.

In [None]:
# pip install jarvis-tools to get the JARVIS-DFT dataset
%pip install jarvis-tools

We want to generate a folder of .cif files. The preferred format is the one produced by pymatgen's CifWriter, since it was used to generate the original model's training data. To get it, build a pymatgen structure for each material and write it out as a CIF file. Here, let's make a small dataset of 1,000 structures.

In [None]:
# Download the JARVIS-DFT dataset and write CIFs for 1K random samples
!python finetune_example_scripts/generate_jarvis_data.py \
    --output_folder 'finetune_example/data_formatting/jarvis_data'

Now we preprocess the files as described in the **'Using Your Own CIF Files'** section of the README.md.

In [7]:
# convert the data to a single tar.gz file
!python bin/prepare_custom.py \
    finetune_example/data_formatting/jarvis_data/ \
    finetune_example/data_formatting/jarvis_data.tar.gz

preparing CIF files...: 100%|███████████████| 1000/1000 [00:23<00:00, 41.68it/s]
prepared CIF files have been saved to finetune_example/data_formatting/jarvis_data.tar.gz


In [None]:
# convert the data to pickle format
!python bin/tar_to_pickle.py \
    finetune_example/data_formatting/jarvis_data.tar.gz \
    finetune_example/data_formatting/jarvis_data.pkl.gz

loading data from finetune_example/data_formatting/jarvis_data.tar.gz...
extracting files...: 100%|███████████████| 1000/1000 [00:00<00:00, 86537.59it/s]
saving data to finetune_example/data_formatting/jarvis_data.pkl.gz...
conversion complete!


In [9]:
# deduplicate the data
!python bin/deduplicate.py \
    finetune_example/data_formatting/jarvis_data.pkl.gz \
    --out finetune_example/data_formatting/jarvis_data_dedup.pkl.gz

loading data from finetune_example/data_formatting/jarvis_data.pkl.gz...
number of CIFs to deduplicate: 1,000
100%|███████████████████████████████████| 1000/1000 [00:00<00:00, 114956.53it/s]
number of entries to write: 996
saving data to finetune_example/data_formatting/jarvis_data_dedup.pkl.gz...
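
As a rough illustration of what deduplication involves, the sketch below keeps one record per (formula, space group) pair and prefers the one with the smallest cell volume. Both the key and the tie-break rule are assumptions made for this example, and the record fields are hypothetical; check `bin/deduplicate.py` for the criteria actually applied.

```python
from typing import Dict, List, Tuple


def deduplicate(records: List[dict]) -> List[dict]:
    """Keep one record per (formula, spacegroup), preferring the smallest volume.

    The key and the 'smallest volume wins' rule are illustrative assumptions.
    """
    best: Dict[Tuple[str, str], dict] = {}
    for rec in records:
        key = (rec["formula"], rec["spacegroup"])
        if key not in best or rec["volume"] < best[key]["volume"]:
            best[key] = rec
    return list(best.values())


# Hypothetical records; real entries would be built from the CIF contents.
records = [
    {"id": "a", "formula": "Na1Cl1", "spacegroup": "Fm-3m", "volume": 45.2},
    {"id": "b", "formula": "Na1Cl1", "spacegroup": "Fm-3m", "volume": 44.9},
    {"id": "c", "formula": "Na1Cl1", "spacegroup": "Pm-3m", "volume": 40.1},
]
unique = deduplicate(records)
print(len(records), "->", len(unique))
```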


In [10]:
# preprocess the data: reformat the CIF files to match the model's requirements
!python bin/preprocess.py \
    finetune_example/data_formatting/jarvis_data_dedup.pkl.gz \
    --out finetune_example/data_formatting/jarvis_data_prep.pkl.gz \
    --workers 4

loading data from finetune_example/data_formatting/jarvis_data_dedup.pkl.gz...
100%|███████████████████████████████████████| 996/996 [00:00<00:00, 2260.71it/s]
number of CIFs: 996
saving data to finetune_example/data_formatting/jarvis_data_prep.pkl.gz...


In [12]:
# split the data into train, val and test sets
!python bin/split.py \
    finetune_example/data_formatting/jarvis_data_prep.pkl.gz \
    --train_out finetune_example/data_formatting/jarvis_data_train.pkl.gz \
    --val_out finetune_example/data_formatting/jarvis_data_val.pkl.gz \
    --test_out finetune_example/data_formatting/jarvis_data_test.pkl.gz \
    --test_size 0.05

loading data from finetune_example/data_formatting/jarvis_data_prep.pkl.gz...
splitting dataset...
number of CIFs in train set: 851
number of CIFs in validation set: 95
number of CIFs in test set: 50
writing train set...
writing validation set...
writing test set...
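
The split sizes above can be reproduced with a little arithmetic. This is a sketch under two assumptions that are not stated in the output: the validation fraction is 0.10 of what remains after the test split, and fractional sizes are rounded up. `split_counts` is a hypothetical helper, not part of the package.

```python
import math


def split_counts(n_total: int, test_size: float, val_size: float):
    """Return (train, val, test) counts for a two-stage split.

    Assumes the test set is carved out first and the validation set is
    then taken from the remainder, rounding each size up.
    """
    n_test = math.ceil(n_total * test_size)
    n_rest = n_total - n_test
    n_val = math.ceil(n_rest * val_size)
    return n_rest - n_val, n_val, n_test


train, val, test = split_counts(996, test_size=0.05, val_size=0.10)
print(train, val, test)  # 851 95 50
```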


In [13]:
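# tokenize the data
# note: the output recorded below came from a run that passed the train file as --val_fname,
# so its val statistics duplicate the train statistics
!python bin/tokenize_cifs.py \
    --train_fname finetune_example/data_formatting/jarvis_data_train.pkl.gz \
    --val_fname finetune_example/data_formatting/jarvis_data_val.pkl.gz \
    --out_dir finetune_example/finetuning/jarvis_train_val/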

loading data from finetune_example/data_formatting/jarvis_data_train.pkl.gz...
loading data from finetune_example/data_formatting/jarvis_data_train.pkl.gz...
preparing files...: 100%|█████████████████| 851/851 [00:00<00:00, 133738.72it/s]
preparing files...: 100%|█████████████████| 851/851 [00:00<00:00, 143152.03it/s]
tokenizing...: 100%|████████████████████████| 851/851 [00:00<00:00, 3344.84it/s]
train min tokenized length: 194
train max tokenized length: 1,737
train mean tokenized length: 380.57 +/- 152.87
train total unk counts: 0
tokenizing...: 100%|████████████████████████| 851/851 [00:00<00:00, 1964.74it/s]
val min tokenized length: 194
val max tokenized length: 1,737
val mean tokenized length: 380.57 +/- 152.87
val total unk counts: 0
concatenating train tokens...: 100%|██████| 851/851 [00:00<00:00, 582656.33it/s]
concatenating val tokens...: 100%|████████| 851/851 [00:00<00:00, 677682.31it/s]
encoding...
train has 323,867 tokens
val has 323,867 tokens
vocab size: 371
exporting 

In [None]:
# identify the start index of each CIF in the tokenized training data (slightly improves model performance)
!python bin/identify_starts.py \
    --dataset_fname finetune_example/finetuning/jarvis_train_val/jarvis_train_val.tar.gz \
    --out_fname finetune_example/finetuning/jarvis_train_val/starts.pkl

identifying starts...: 100%|████████| 323867/323867 [00:00<00:00, 505502.96it/s]
writing start indices...
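
Conceptually, identifying the starts means scanning the token stream for the token that opens each CIF. The sketch below assumes every CIF begins with a dedicated `data_` token; the token id and the toy sequence are hypothetical, and the real script works on the tokenized dataset files.

```python
from typing import List


def find_starts(token_ids: List[int], start_id: int) -> List[int]:
    """Return the positions at which a new CIF begins in the token stream."""
    return [i for i, tok in enumerate(token_ids) if tok == start_id]


DATA_ID = 0  # hypothetical id of the token that opens every CIF
tokens = [0, 5, 7, 9, 0, 4, 4, 0, 8]  # three toy CIFs concatenated
starts = find_starts(tokens, DATA_ID)
print(starts)  # [0, 4, 7]
```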


### **Finetuning model on new data**

To fine-tune the model on the new data, we first need to download a pre-trained model. Both the small and the large model are available; here we use the small one, but the procedure is the same for either. The download follows the **'Using a Pre-trained Model'** section of the README.md.

'The config folder in this project contains a number of model configuration .yaml files. A corresponding .tar.gz model file exists for each .yaml file in that directory that begins with crystallm_, which can be downloaded.' (from README.md)

In [None]:
# download the pretrained model (small version)
!python bin/download.py \
    crystallm_v1_small.tar.gz \
    --out finetune_example/finetuning/

# large model can be downloaded with: python bin/download.py crystallm_v1_large.tar.gz

In [25]:
# decompress the model
!tar xvf finetune_example/finetuning/crystallm_v1_small.tar.gz -C finetune_example/finetuning/

crystallm_v1_small/
crystallm_v1_small/ckpt.pt


Next we write a config file for fine-tuning. Most parameters can be adjusted as for any deep learning model, but two must be set as follows: ***init_from: 'resume'*** and ***out_dir: 'finetune_example/finetuning/crystallm_v1_small'***.

_Notes_: 
- The model saves to the same directory it loads from. If a fine-tuned model has been saved on top of the original one, re-decompress the archive as above before fine-tuning from the base pre-trained model again (here crystallm_v1_small/ckpt.pt).
- The **max_iters** argument is the number of iterations the model should train for _in total_, including those already completed by the pre-trained model. Look up max_iters in the config file of the model being fine-tuned, then add the number of extra iterations you want.
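
Because max_iters counts total iterations, it is worth checking what learning rate the resumed run will actually see. The sketch below assumes a nanoGPT-style schedule (linear warmup followed by cosine decay) evaluated on the absolute iteration counter; whether `bin/train.py` behaves this way on resume is an assumption to verify. `lr_at` is a hypothetical helper, and its defaults mirror the example config in this notebook.

```python
import math


def lr_at(it, learning_rate=1e-4, min_lr=1e-5, warmup_iters=100, lr_decay_iters=101_000):
    """Warmup + cosine decay on the absolute iteration counter (assumed schedule)."""
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)


base_iters = 100_000   # iterations already completed by the pre-trained model
extra_iters = 1_000    # additional fine-tuning iterations we want
max_iters = base_iters + extra_iters

for it in (base_iters, base_iters + 500, max_iters):
    print(it, f"{lr_at(it, lr_decay_iters=max_iters):.3e}")
```

Under these assumptions the learning rate stays close to min_lr for the whole fine-tuning run, because the cosine decay is almost complete by iteration 100,000. If the training script treats resumed schedules differently, this will not hold.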

In [None]:
# Example config
'''
out_dir: 'finetune_example/finetuning/crystallm_v1_small'
eval_interval: 100  # how often to evaluate against the validation set
eval_iters_train: 80
eval_iters_val: 80
log_interval: 50  # how often to print to the console (1 = every iteration)
init_from: 'resume' # for fine-tuning

# logging
always_save_checkpoint: True  # if set to False, will only save .ckpt if the model improves
validate: True  # whether to validate with a validation set

# data and batching
dataset: 'finetune_example/finetuning/jarvis_train_val'
batch_size: 8  # reduce if running out of memory
gradient_accumulation_steps: 4  # batch_size * gradient_accumulation_steps = effective batch size

# preserve the pre-trained model's parameters
block_size: 1024
n_layer: 8
n_head: 8
n_embd: 512
dropout: 0.1

# editable training parameters (these may need adjustment to specific dataset)
learning_rate: 1e-4  # start from lower as the model is already pre-trained
decay_lr: True
max_iters: 101000  # number of iterations to train for (BASE MODEL IS AT 100K, so add how many more you want to finetune for)
lr_decay_iters: 101000  # set to max_iters
min_lr: 1e-5 # minimum learning rate (learning_rate/10 usually)
warmup_iters: 100  # number of iterations to warm up for
beta2: 0.99  # adam parameters
'''
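
To get a feel for how much training these settings amount to, the sketch below estimates the tokens processed per iteration and the number of passes over the training tokens. It assumes each micro-batch holds batch_size sequences of block_size tokens (nanoGPT-style batching), which should be checked against `bin/train.py`; the token count is the train total reported by the tokenization step.

```python
batch_size = 8
gradient_accumulation_steps = 4
block_size = 1024
train_tokens = 323_867             # "train has 323,867 tokens" from tokenization
extra_iters = 101_000 - 100_000    # fine-tuning iterations beyond the base model

# tokens seen in one optimizer step (assumed batching scheme)
tokens_per_iter = batch_size * gradient_accumulation_steps * block_size
# approximate number of passes over the training tokens during fine-tuning
passes = tokens_per_iter * extra_iters / train_tokens

print(tokens_per_iter)   # 32768
print(round(passes, 1))  # 101.2
```

With a dataset this small, 1,000 extra iterations already revisit the data many times, so watching the validation loss for overfitting is worthwhile.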

In [26]:
!python bin/train.py --config=finetune_example_scripts/jarvis_finetune.yaml

Using configuration:
out_dir: finetune_example/finetuning/crystallm_v1_small
eval_interval: 100
log_interval: 50
eval_iters_train: 80
eval_iters_val: 80
eval_only: false
always_save_checkpoint: true
init_from: resume
dataset: finetune_example/finetuning/jarvis_train_val
gradient_accumulation_steps: 4
batch_size: 8
block_size: 1024
n_layer: 8
n_head: 8
n_embd: 512
dropout: 0.1
bias: false
learning_rate: 0.0001
max_iters: 101000
weight_decay: 0.1
beta1: 0.9
beta2: 0.99
grad_clip: 1.0
decay_lr: true
warmup_iters: 100
lr_decay_iters: 101000
min_lr: 1.0e-05
device: cuda
dtype: bfloat16
compile: true
underrep_p: 0.0
validate: true

Creating finetune_example/finetuning/crystallm_v1_small...
Reading start indices from finetune_example/finetuning/jarvis_train_val/starts.pkl...
Found vocab_size = 371 (inside finetune_example/finetuning/jarvis_train_val/meta.pkl)
Resuming training from finetune_example/finetuning/crystallm_v1_small...
number of parameters: 25.36M
Compiling the model (takes a ~min

In [27]:
# rename the finetuned model to avoid confusion
!mv finetune_example/finetuning/crystallm_v1_small finetune_example/finetuning/crystallm_ft_jarvis

### **Evaluating model performance**

Now that the model has been fine-tuned, let's evaluate it. The evaluation set can be any CIF dataset, as long as it has been fully processed as in the first section of the notebook:
- In the first example we use the _'finetune_example/data_formatting/jarvis_data_test.pkl.gz'_ file generated during the preprocessing steps. We prompt with the reduced formula only, but the space group can optionally be provided.
- In the second example we generate new structures without specifying the formula ('ab initio').


In [28]:
!python bin/make_prompts.py \
    finetune_example/data_formatting/jarvis_data_test.pkl.gz \
    -o finetune_example/evaluation/prompts_jarvis_test.tar.gz

# optionally add
# --with-spacegroup (to include spacegroup in the prompts)

loading data from finetune_example/data_formatting/jarvis_data_test.pkl.gz...
preparing prompts...: 100%|███████████████████| 50/50 [00:00<00:00, 8741.05it/s]


In [None]:
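# let's extract the prompts for the test set (tar's -C option needs the target directory to exist)
!mkdir -p finetune_example/evaluation/prompts_jarvis_test
!tar -xvf finetune_example/evaluation/prompts_jarvis_test.tar.gz -C finetune_example/evaluation/prompts_jarvis_test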

In [32]:
# let's read one of the prompts
with open('finetune_example/evaluation/prompts_jarvis_test/JVASP-4720.txt') as f:
    print(f.read())

data_Pr1Al2Ni3
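
The prompts can also be inspected without unpacking the archive to disk. The sketch below uses Python's standard `tarfile` module; `read_tar_texts` is a hypothetical helper, and the demo builds a tiny stand-in archive in a temporary directory rather than reading the real prompts file.

```python
import io
import os
import tarfile
import tempfile


def read_tar_texts(path):
    """Return {member name: text} for every regular file in a tar archive."""
    texts = {}
    with tarfile.open(path, "r:*") as tar:  # "r:*" auto-detects compression
        for member in tar.getmembers():
            if member.isfile():
                texts[member.name] = tar.extractfile(member).read().decode("utf-8")
    return texts


# Demo on a stand-in archive holding one prompt.
with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "prompts_demo.tar.gz")
    payload = b"data_Pr1Al2Ni3\n"
    with tarfile.open(archive, "w:gz") as tar:
        info = tarfile.TarInfo(name="JVASP-4720.txt")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    prompts = read_tar_texts(archive)

print(prompts)
```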



When generating structures, we can set several generation parameters. A few are discussed here: 
- **Top-k:** how many of the most probable tokens are considered when sampling the next token. A higher top-k gives more creative but less coherent outputs; a lower top-k gives the opposite.
- **Temperature:** also controls the creativity of the model. A low temperature (below 1) puts more probability mass on the most likely tokens, whereas a high temperature (above 1) makes predictions more diverse by spreading the probability more evenly over the possible next tokens.
- **Num-gens:** the number of structures generated per prompt. It gives the model several attempts to match the 'true' CIF file, or sets how many new structures to come up with.
- **Prompts:** to generate random materials, omit the **'--prompts'** argument so that the model comes up with a composition itself.
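
To make the effect of top-k and temperature concrete, here is a minimal pure-Python sketch of the usual recipe: divide the logits by the temperature, discard everything outside the k largest, then apply a softmax. It illustrates the general technique rather than the exact code in `bin/generate_cifs.py`, and the logits are made up.

```python
import math
import random
from typing import List


def top_k_temperature_probs(logits: List[float], top_k: int, temperature: float) -> List[float]:
    """Next-token probabilities after temperature scaling and top-k filtering."""
    scaled = [x / temperature for x in logits]
    k = min(top_k, len(scaled))
    cutoff = sorted(scaled, reverse=True)[k - 1]  # k-th largest scaled logit
    # tokens below the cutoff get zero probability (ties at the cutoff are kept)
    masked = [x if x >= cutoff else float("-inf") for x in scaled]
    peak = max(masked)
    exps = [math.exp(x - peak) for x in masked]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def sample(probs: List[float], rng: random.Random) -> int:
    """Draw one token index according to the given probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
cold = top_k_temperature_probs(logits, top_k=2, temperature=0.8)
warm = top_k_temperature_probs(logits, top_k=2, temperature=1.5)
print([round(p, 3) for p in cold])
print([round(p, 3) for p in warm])
print(sample(cold, random.Random(0)))
```

With top_k=2 only the two best tokens can ever be sampled, and the lower temperature puts more of the probability on the best one.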

In [33]:
# let's generate CIFs for the test-set prompts
!python bin/generate_cifs.py \
    --model finetune_example/finetuning/crystallm_ft_jarvis \
    --prompts finetune_example/evaluation/prompts_jarvis_test.tar.gz \
    --out finetune_example/evaluation/gen_jarvis_test.tar.gz \
    --device cuda \
    --num-gens 20 \
    --top-k 10 \
    --temperature 0.8

extracting prompts...: 100%|█████████████████| 50/50 [00:00<00:00, 60999.19it/s]
generating CIFs from prompts...:   0%|                   | 0/50 [00:00<?, ?it/s]initializing model from finetune_example/finetuning/crystallm_ft_jarvis on cuda:0...
initializing model from finetune_example/finetuning/crystallm_ft_jarvis on cuda:1...
number of parameters: 25.36M
number of parameters: 25.36M
generating CIFs from prompts...: 100%|██████████| 50/50 [04:11<00:00,  5.03s/it]
writing CIF files to finetune_example/evaluation/gen_jarvis_test.tar.gz...: 100%


In [34]:
# let's generate CIFs without specifying any formulae (ab initio)
!python bin/generate_cifs.py \
    --model finetune_example/finetuning/crystallm_ft_jarvis \
    --out finetune_example/evaluation/gen_jarvis_ab_initio.tar.gz \
    --device cuda \
    --num-gens 20 \
    --top-k 10 \
    --temperature 0.8

generating CIFs from prompts...:   0%|                   | 0/20 [00:00<?, ?it/s]initializing model from finetune_example/finetuning/crystallm_ft_jarvis on cuda:1...
initializing model from finetune_example/finetuning/crystallm_ft_jarvis on cuda:0...
number of parameters: 25.36M
number of parameters: 25.36M
generating CIFs from prompts...: 100%|██████████| 20/20 [00:09<00:00,  2.11it/s]
writing CIF files to finetune_example/evaluation/gen_jarvis_ab_initio.tar.gz...:


In [None]:
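# let's extract the generated CIFs for the test set (tar's -C option needs the target directory to exist)
!mkdir -p finetune_example/evaluation/gen_jarvis_test
!tar -xvf finetune_example/evaluation/gen_jarvis_test.tar.gz -C finetune_example/evaluation/gen_jarvis_test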

In [39]:
# let's look at one of the generated CIFs for the test set
with open('finetune_example/evaluation/gen_jarvis_test/JVASP-4720__1.cif') as f:
    print(f.read())

data_Pr1Al2Ni3
loop_
_atom_type_symbol
_atom_type_electronegativity
_atom_type_radius
_atom_type_ionic_radius
Pr 1.1300 1.8500 1.0600
Al 1.6100 1.2500 0.6750
Ni 1.9100 1.3500 0.7400
_symmetry_space_group_name_H-M P6/mmm
_cell_length_a 5.3231
_cell_length_b 5.3231
_cell_length_c 3.8552
_cell_angle_alpha 90.0000
_cell_angle_beta 90.0000
_cell_angle_gamma 120.0000
_symmetry_Int_Tables_number 191
_chemical_formula_structural PrAl2Ni3
_chemical_formula_sum 'Pr1 Al2 Ni3'
_cell_volume 94.5961
_cell_formula_units_Z 1
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Pr Pr0 1 0.0000 0.0000 0.0000 1.0
Al Al1 2 0.3333 0.6667 0.0000 1.0
Ni Ni2 3 0.0000 0.5000 0.5000 1.0
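
A quick way to sanity-check a generated CIF is to recompute the cell volume from the generated lattice parameters and compare it with the generated `_cell_volume`. The sketch below uses the standard triclinic volume formula on an excerpt of the CIF above; `cif_float` and `cell_volume` are hypothetical helpers written for this example.

```python
import math
import re


def cif_float(cif_text: str, tag: str) -> float:
    """Read the numeric value that follows a single-valued CIF tag."""
    match = re.search(rf"^{re.escape(tag)}\s+(\S+)", cif_text, flags=re.MULTILINE)
    if match is None:
        raise KeyError(tag)
    return float(match.group(1))


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit cell volume from lengths (angstrom) and angles (degrees)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    return a * b * c * math.sqrt(1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)


# Excerpt of the generated CIF printed above.
cif_text = """data_Pr1Al2Ni3
_cell_length_a 5.3231
_cell_length_b 5.3231
_cell_length_c 3.8552
_cell_angle_alpha 90.0000
_cell_angle_beta 90.0000
_cell_angle_gamma 120.0000
_cell_volume 94.5961
"""

tags = ["_cell_length_a", "_cell_length_b", "_cell_length_c",
        "_cell_angle_alpha", "_cell_angle_beta", "_cell_angle_gamma"]
computed = cell_volume(*(cif_float(cif_text, t) for t in tags))
stated = cif_float(cif_text, "_cell_volume")
print(f"computed {computed:.4f}, stated {stated:.4f}")
```

The two values agree to within about 0.01 cubic angstrom here, which suggests the generated volume is consistent with the generated cell.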




In [None]:
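# let's extract the generated CIFs for ab initio generation (tar's -C option needs the target directory to exist)
!mkdir -p finetune_example/evaluation/gen_jarvis_ab_initio
!tar -xvf finetune_example/evaluation/gen_jarvis_ab_initio.tar.gz -C finetune_example/evaluation/gen_jarvis_ab_initio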

In [41]:
# let's look at one of the CIFs generated ab initio
with open('finetune_example/evaluation/gen_jarvis_ab_initio/1__1.cif') as f:
    print(f.read())

data_Li12Mn6V18O48
loop_
_atom_type_symbol
_atom_type_electronegativity
_atom_type_radius
_atom_type_ionic_radius
Li 0.9800 1.4500 0.9000
Mn 1.5500 1.4000 0.6483
V 1.6300 1.3500 0.7775
O 3.4400 0.6000 1.2600
_symmetry_space_group_name_H-M R-3m
_cell_length_a 5.7953
_cell_length_b 5.7953
_cell_length_c 28.5364
_cell_angle_alpha 90.0000
_cell_angle_beta 90.0000
_cell_angle_gamma 120.0000
_symmetry_Int_Tables_number 166
_chemical_formula_structural Li2MnV3O8
_chemical_formula_sum 'Li12 Mn6 V18 O48'
_cell_volume 829.9055
_cell_formula_units_Z 6
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Li Li0 6 0.0000 0.0000 0.1849 1.0
Li Li1 6 0.0000 0.0000 0.3110 1.0
Mn Mn2 3 -0.0000 -0.0000 0.5000 1.0
Mn Mn3 3 0.0000 0.0000 0.0000 1.0
V V4 18 0.0109 0.5054 0.2493 1.0
O O5 18 0.0244 0.5122 0.1252 1.0
O O6 18 0.0343 0.51

Now let's evaluate the generated CIFs. The results include:
- The fraction of generated CIF files where the printed space group is consistent with the generated structure
- The fraction of generated CIF files that are consistent in terms of atom site multiplicity
- The average bond length reasonableness score, and the fraction of generated CIF files that have reasonable bond lengths
- The fraction of generated CIF files that are valid
- The longest valid generated tokenized length
- The average valid generated tokenized length

A few sanity thresholds can be specified during evaluation (see the arguments of bin/evaluate_cifs.py), including:
- The smallest or largest cell length allowable
- The smallest or largest cell angle allowable
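
As an illustration of the atom site multiplicity check listed above, the sketch below compares the element counts in `_chemical_formula_sum` with the multiplicities summed over the atom site rows. It is a simplified stand-in that ignores occupancies, not the implementation in `bin/evaluate_cifs.py`, and the helper names are hypothetical.

```python
import re
from collections import Counter


def formula_sum_counts(cif_text: str) -> Counter:
    """Element counts from the _chemical_formula_sum tag, e.g. 'Pr1 Al2 Ni3'."""
    match = re.search(r"^_chemical_formula_sum\s+'([^']+)'", cif_text, re.MULTILINE)
    if match is None:
        raise ValueError("no _chemical_formula_sum tag found")
    counts = Counter()
    for element, number in re.findall(r"([A-Z][a-z]?)(\d+)", match.group(1)):
        counts[element] += int(number)
    return counts


def site_multiplicity_counts(cif_text: str) -> Counter:
    """Sum _atom_site_symmetry_multiplicity per element over the atom site rows."""
    headers, counts = [], Counter()
    for raw in cif_text.splitlines():
        line = raw.strip()
        if line.startswith("_atom_site_"):
            headers.append(line)       # column names of the atom site loop
        elif line == "loop_" or line.startswith("_"):
            headers = []               # a different loop or tag begins
        elif headers and line:
            row = dict(zip(headers, line.split()))
            element = row["_atom_site_type_symbol"]
            counts[element] += int(row["_atom_site_symmetry_multiplicity"])
    return counts


def multiplicity_consistent(cif_text: str) -> bool:
    return formula_sum_counts(cif_text) == site_multiplicity_counts(cif_text)


# Excerpt of the generated Pr1Al2Ni3 CIF shown earlier in the notebook.
cif_text = """data_Pr1Al2Ni3
_chemical_formula_sum 'Pr1 Al2 Ni3'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Pr Pr0 1 0.0000 0.0000 0.0000 1.0
Al Al1 2 0.3333 0.6667 0.0000 1.0
Ni Ni2 3 0.0000 0.5000 0.5000 1.0
"""
print(multiplicity_consistent(cif_text))  # True
```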

In [3]:
# let's evaluate the generated CIFs for the test set
!python bin/evaluate_cifs.py \
    finetune_example/evaluation/gen_jarvis_test.tar.gz \
    -o finetune_example/results/jarvis_test_results.csv

extracting generated CIFs...: 100%|█████| 1000/1000 [00:00<00:00, 144158.93it/s]
100%|███████████████████████████████████████| 1000/1000 [04:06<00:00,  4.05it/s]
space group consistent: 970/1000 (0.970)
atom site multiplicity consistent: 966/1000 (0.966)
avg. bond length reasonableness score: 0.9660 ± 0.1270
bond lengths reasonable: 842/1000 (0.842)
num valid: 837/1000 (0.84)
longest valid generated length: 577
avg. valid generated length: 333.119 ± 55.803


In [None]:
# let's evaluate the CIFs generated ab initio
!python bin/evaluate_cifs.py \
    finetune_example/evaluation/gen_jarvis_ab_initio.tar.gz \
    -o finetune_example/results/jarvis_ab_initio_results.csv

extracting generated CIFs...: 100%|██████████| 20/20 [00:00<00:00, 37449.14it/s]
  0%|                                                    | 0/20 [00:00<?, ?it/s]ERROR: 'NoneType' object is not subscriptable
 25%|███████████                                 | 5/20 [00:02<00:06,  2.15it/s]ERROR: 
 40%|█████████████████▌                          | 8/20 [00:02<00:02,  4.61it/s]ERROR: 'NoneType' object is not subscriptable
100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.99it/s]
space group consistent: 17/20 (0.850)
atom site multiplicity consistent: 19/20 (0.950)
avg. bond length reasonableness score: 0.9676 ± 0.0909
bond lengths reasonable: 15/20 (0.750)
num valid: 15/20 (0.75)
longest valid generated length: 874
avg. valid generated length: 441.200 ± 162.956


Depending on the application, it may be useful to save all valid CIFs to a new folder so that they can later be converted into standard-format CIFs. This can be done with **'finetune_example_scripts/evaluate_cifs_custom.py'**. The custom script keeps the same logic as above, but also saves the valid CIFs to a new folder if **'--save_valid_dir'** is specified.

In [13]:
# let's evaluate the CIFs generated ab initio and save the valid ones
!python finetune_example_scripts/evaluate_cifs_custom.py \
    finetune_example/evaluation/gen_jarvis_ab_initio.tar.gz \
    -o finetune_example/results/jarvis_ab_initio_results.csv \
    --save_valid_dir finetune_example/results/jarvis_ab_initio_valid

extracting generated CIFs...: 100%|██████████| 20/20 [00:00<00:00, 65484.84it/s]
100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.96it/s]
space group consistent: 17/20 (0.850)
atom site multiplicity consistent: 19/20 (0.950)
avg. bond length reasonableness score: 0.9676 ± 0.0909
bond lengths reasonable: 15/20 (0.750)
num valid: 15/20 (0.75)
longest valid generated length: 874
avg. valid generated length: 441.200 ± 162.956


To convert the valid generated CIFs into usable, standard-format CIF files, we can process them with bin/postprocess.py. Let's do this for the valid ab initio generations.

In [14]:
!python bin/postprocess.py \
    finetune_example/results/jarvis_ab_initio_valid \
    finetune_example/results/jarvis_ab_initio_postprocessed

processed: Li10Mn2Co4O16.cif
processed: Yb4B16Rh4.cif
processed: Li4V3Cr2O10.cif
processed: Na2Nd6Ti4Sb4O28.cif
processed: Ba4Ga4Se8.cif
processed: Li4Ti4Cr4O16.cif
processed: Li8Fe4P8O32.cif
processed: Mg12Ti2Sb2.cif
processed: K2Ag1Hg1.cif
processed: Mn1Cu3S4.cif
processed: Mn1Sn3.cif
processed: Be1Cu1As1.cif
processed: Ca4Mn4Zn4.cif
processed: Rb2Sc2Te2O2.cif
processed: K8Sb4Au4Cl24.cif


In [15]:
# let's check one of the postprocessed CIFs from the ab initio generation
with open('finetune_example/results/jarvis_ab_initio_postprocessed/Ba4Ga4Se8.cif') as f:
    print(f.read())

data_Ba4Ga4Se8
_symmetry_space_group_name_H-M Pnma
_cell_length_a 7.1261
_cell_length_b 4.7277
_cell_length_c 13.2755
_cell_angle_alpha 90.0000
_cell_angle_beta 90.0000
_cell_angle_gamma 90.0000
_symmetry_Int_Tables_number 62
_chemical_formula_structural BaGaSe2
_chemical_formula_sum 'Ba4 Ga4 Se8'
_cell_volume 449.9548
_cell_formula_units_Z 4
loop_
 _symmetry_equiv_pos_site_id
 _symmetry_equiv_pos_as_xyz
  1  '-x, y+1/2, -z'
  2  '-x, -y, -z'
  3  '-x+1/2, -y, z+1/2'
  4  'x+1/2, -y+1/2, -z+1/2'
  5  'x, y, z'
  6  'x-1/2, y, -z-1/2'
  7  '-x-1/2, y-1/2, z-1/2'
  8  'x, -y-1/2, z'
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ba Ba0 4 0.2116 0.2500 0.1344 1.0
Ga Ga1 4 0.2376 0.2500 0.8763 1.0
Se Se2 4 0.0347 0.2500 0.7377 1.0
Se Se3 4 0.2489 0.7500 0.9374 1.0




These materials can then be used for downstream tasks. A few features were not covered in this notebook, such as Monte Carlo Tree Search decoding, single-formula prompt generation and extracting learned embeddings. These can all be found in the **README.md**.

For any additional queries, feel free to e-mail cyprien.bone.24@ucl.ac.uk or k.t.butler@ucl.ac.uk