### Requirements and imports

In [1]:
import os
import sys
sys.path.append("../")

import random
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
tqdm.pandas()

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import joblib
from IPython.display import clear_output
from pymatgen.core import Structure

from torch.utils.data import Dataset, random_split, DataLoader
from transformers import get_cosine_schedule_with_warmup
from diffusers import DDIMScheduler

from src.model.models import CrystalUNetModel
from src.generation.generation import generate_diffusion
from src.inference.inference_data_generation import generate_inference_dataset
from src.utils import seed_everything

In [2]:
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    """Settings for sampling crystal structures with a trained CrystalUNetModel.

    Every attribute is type-annotated so that it is a real dataclass field: it
    appears in ``repr``/``dataclasses.fields`` and can be overridden when the
    config is built, e.g. ``InferenceConfig(batch_size=64)``. (Unannotated class
    attributes are not dataclass fields and cannot be passed to the constructor.)
    """

    # Data
    max_nsites: int = 64
    max_elems: int = 4
    min_elems: int = 2

    # Model
    model_channels: int = 128
    num_res_blocks: int = 7
    attention_resolutions: tuple = (1, 2, 4, 8)  # tuple: immutable, so a safe field default

    # Noise scheduler
    num_train_timesteps: int = 1_000
    num_inference_steps: int = 100
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "squaredcos_cap_v2"
    clip_sample: bool = False

    # Data loading
    batch_size: int = 256
    num_workers: int = 1

    # Accelerator
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision

    device: str = "cuda"
    random_state: int = 42


# Build the default configuration used by every later cell, then seed from it.
# seed_everything is a project helper; presumably it seeds the Python/NumPy/torch RNGs
# so the random starting samples drawn in the generation loop are reproducible.
config = InferenceConfig()
seed_everything(config.random_state)

In [3]:
# NOTE(review): PATH is not referenced by any cell in this notebook (data is read from
# "../src/data/" instead) — confirm it is still needed.
PATH = "../FTCP_data/"

In [4]:
# Generation request: one target formula, a list of space groups, and a sweep of target
# formation energies beginning at `start` and moving by `step`. For n = 20 the output
# below has 21 energies per space group (19 space groups x 21 = 399 rows).
formula = "Ta1W1B6"
spgs = [
    6, 8, 10, 12, 25, 35, 44, 47, 65, 71,
    99, 119, 123, 129, 139, 160, 166, 216, 225,
]
start = -0.3993 - 1
step = -0.01
n = 20

# `inferece_dataset` (sic) keeps its original name because later cells refer to it.
df, inferece_dataset = generate_inference_dataset(
    formula, spgs, step, start, n, return_df=True, data_path="../src/data/"
)

df, inferece_dataset, len(inferece_dataset)

(    pretty_formula  spacegroup_relax  enthalpy_formation_atom
 0          Ta1W1B6                 6                  -1.3993
 1          Ta1W1B6                 6                  -1.4093
 2          Ta1W1B6                 6                  -1.4193
 3          Ta1W1B6                 6                  -1.4293
 4          Ta1W1B6                 6                  -1.4393
 ..             ...               ...                      ...
 394        Ta1W1B6               225                  -1.5593
 395        Ta1W1B6               225                  -1.5693
 396        Ta1W1B6               225                  -1.5793
 397        Ta1W1B6               225                  -1.5893
 398        Ta1W1B6               225                  -1.5993
 
 [399 rows x 3 columns],
 <src.inference.inference_data_generation.InferenceCrystalDataset at 0x7f441d115350>,
 399)

In [5]:
# shuffle=False keeps generated structures in the same order as the rows of `df`,
# which the cell that attaches the CIFs to `df` relies on.
dataloader = DataLoader(
    inferece_dataset,
    shuffle=False,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)
len(dataloader)

2

### Model setup and structure generation

In [6]:
# 1D U-Net over the per-row sample. The condition vector has 1 + 256 + 256 = 513 entries;
# one block of 256 is the space-group condition, the single value is presumably the
# target energy (TODO: document what the other 256 features are).
model = CrystalUNetModel(
    dims=1,  # 1D U-Net
    in_channels=3,  # must equal the number of input features per row
    out_channels=3,  # must equal in_channels (same features are predicted back)
    model_channels=config.model_channels,  # base channel width inside the network
    num_res_blocks=config.num_res_blocks,  # presumably residual blocks per resolution level
    attention_resolutions=config.attention_resolutions,
    condition_dims=1 + 256 + 256,
)
model.to(config.device)

CrystalUNetModel(
  (model): UNetModel(
    (time_embed): Sequential(
      (0): Linear(in_features=128, out_features=512, bias=True)
      (1): SiLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
    (label_emb): Linear(in_features=513, out_features=512, bias=True)
    (input_blocks): ModuleList(
      (0): TimestepEmbedSequential(
        (0): Conv1d(3, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      )
      (1-7): 7 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
          )
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=512, out_features=128, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): D

In [89]:
# Load trained weights; replace the placeholder with the checkpoint path.
# map_location puts the tensors on config.device no matter which device the checkpoint
# was saved from (e.g. a GPU checkpoint on a machine configured with device="cpu").
# torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(torch.load("<PATH TO YOU WEIGHTS>", map_location=config.device))

<All keys matched successfully>

In [7]:
# DDIM sampler whose noise schedule comes from the config; sampling uses
# config.num_inference_steps steps out of the num_train_timesteps schedule.
# NOTE(review): in diffusers the "squaredcos_cap_v2" schedule derives betas from a
# cosine alpha-bar and does not use beta_start/beta_end — confirm that is intended.
ddim_scheduler = DDIMScheduler(
    beta_schedule=config.beta_schedule,
    num_train_timesteps=config.num_train_timesteps,
    beta_start=config.beta_start,
    beta_end=config.beta_end,
    clip_sample=config.clip_sample,
)
ddim_scheduler.set_timesteps(config.num_inference_steps)

ddim_scheduler

DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.23.1",
  "beta_end": 0.02,
  "beta_schedule": "squaredcos_cap_v2",
  "beta_start": 0.0001,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": true,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}

In [8]:
from accelerate import Accelerator

# Wrap model and dataloader for the configured precision ("fp16" -> automatic mixed
# precision); the prepared dataloader places batches on the accelerator's device.
accelerator = Accelerator(mixed_precision=config.mixed_precision)
dataloader, model = accelerator.prepare(dataloader, model)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [9]:
# Number of trailing rows of each generated sample that hold the lattice.
lattice_size = 3

# Per-batch results, concatenated into arrays in the next cell. Everything stored here
# is moved to the CPU first so GPU memory does not grow with the number of batches.
atoms_generated = []
lattice_generated = []
n_sites_dataset = []
conditions = []
elements_dataset = []
x1_energies = []

model.eval()
for batch in tqdm(dataloader):
    element_matrix = batch["element_matrix"].to(config.device)
    elemental_property_matrix = batch["elemental_property_matrix"].to(config.device)
    spg = batch["spg"].to(config.device)
    n_sites = batch["n_sites"]

    # The target energy is both recorded (x1_energy) and used as the scalar condition `y`.
    x1_energy = batch["energy"]
    condition = x1_energy.to(config.device)

    # Starting sample: (batch, 64, 3). Drawn with the CPU RNG and then moved, so seeded
    # runs reproduce. NOTE(review): torch.rand is uniform on [0, 1); DDIM sampling
    # usually starts from Gaussian noise (torch.randn) — confirm what generate_diffusion expects.
    x_0_coords = torch.rand((element_matrix.shape[0], 64, 3)).to(config.device)

    with torch.no_grad():
        output = generate_diffusion(
            model=model,
            x_0=x_0_coords,
            elements=torch.cat([element_matrix, elemental_property_matrix], dim=-1),
            y=condition,
            spg=spg,
            noise_scheduler=ddim_scheduler,
        ).cpu()

    # Last `lattice_size` rows are the lattice. NOTE(review): row -4 is excluded from both
    # slices — presumably a separator/padding row in the sample layout; confirm.
    coords_pred = output[:, :-4]
    lattice_pred = output[:, -lattice_size:]

    atoms_generated.append(coords_pred)
    lattice_generated.append(lattice_pred)
    n_sites_dataset.append(n_sites.cpu())
    conditions.append(condition.cpu())
    elements_dataset.append(element_matrix.cpu())
    x1_energies.append(x1_energy.cpu())

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:19<00:00,  9.72s/it]


In [10]:
def _stack_to_numpy(chunks, join=torch.cat):
    """Join a list of per-batch tensors with `join` and return a detached CPU numpy array.

    ``torch.vstack`` is passed explicitly for the results that were stacked that way;
    the remaining results are joined with ``torch.cat``.
    """
    return join(chunks).detach().cpu().numpy()


# NOTE: the accumulator lists are rebound to arrays here, so this cell is not re-runnable
# without re-running the generation loop first.
atoms_generated = _stack_to_numpy(atoms_generated, torch.vstack)
lattice_generated = _stack_to_numpy(lattice_generated, torch.vstack)
n_sites_dataset = _stack_to_numpy(n_sites_dataset)
conditions_dataset = _stack_to_numpy(conditions)

elements_dataset = _stack_to_numpy(elements_dataset, torch.vstack)
x1_energies = _stack_to_numpy(x1_energies)

In [11]:
import joblib
from joblib import Parallel, delayed
from tqdm.auto import tqdm
from pymatgen.core import Structure

# Lookup table from element index to element symbol (indexed with an arg-max in
# form_up_structure). joblib.load unpickles the file, so it must come from a trusted source.
elm_str = np.asarray(joblib.load("../src/data/element.pkl"))


def form_up_structure(one_hot_vectors, coordinates_input, lattice):
    """Assemble a pymatgen ``Structure`` from per-site element scores, coordinates and a lattice.

    Args:
        one_hot_vectors: per-site element encoding; the arg-max along axis 1 picks the
            element, which is looked up in the module-level ``elm_str`` table.
        coordinates_input: per-site coordinates, passed to ``Structure`` as ``coords``
            (pymatgen treats them as fractional unless ``coords_are_cartesian=True``).
        lattice: any lattice specification ``Structure`` accepts.

    Returns:
        The assembled pymatgen ``Structure``.
    """
    species = elm_str[np.argmax(one_hot_vectors, axis=1)]
    return Structure(lattice=lattice, species=species, coords=coordinates_input)


indexes_to_make = np.arange(0, len(atoms_generated))
n_jobs = -1  # joblib: use all available CPUs

# One job per generated sample; only the first n_sites rows of the element matrix and
# coordinates are passed (the remaining rows are presumably padding).
structure_jobs = (
    delayed(form_up_structure)(
        elements_dataset[idx, : n_sites_dataset[idx]],
        atoms_generated[idx][: n_sites_dataset[idx]],
        lattice_generated[idx],
    )
    for idx in tqdm(indexes_to_make)
)
pred_structures = Parallel(n_jobs=n_jobs)(structure_jobs)

  0%|          | 0/399 [00:00<?, ?it/s]

In [12]:
# Serialize each generated structure to CIF text and show the last one as a sample.
cifs = list(map(lambda struct: struct.to(fmt="cif"), pred_structures))
print(cifs[-1])

# generated using pymatgen
data_TaB6W
_symmetry_space_group_name_H-M   'P 1'
_cell_length_a   84.44050025
_cell_length_b   90.96421129
_cell_length_c   112.80858442
_cell_angle_alpha   19.73437124
_cell_angle_beta   25.54271838
_cell_angle_gamma   40.66772462
_symmetry_Int_Tables_number   1
_chemical_formula_structural   TaB6W
_chemical_formula_sum   'Ta1 B6 W1'
_cell_volume   98635.82983580
_cell_formula_units_Z   1
loop_
 _symmetry_equiv_pos_site_id
 _symmetry_equiv_pos_as_xyz
  1  'x, y, z'
loop_
 _atom_site_type_symbol
 _atom_site_label
 _atom_site_symmetry_multiplicity
 _atom_site_fract_x
 _atom_site_fract_y
 _atom_site_fract_z
 _atom_site_occupancy
  Ta  Ta0  1  28.53800774  68.27568817  24.53317833  1
  W  W1  1  28.09791565  49.11555481  27.25648117  1
  B  B2  1  68.66297150  7.88978481  60.80498123  1
  B  B3  1  43.98616028  26.24205208  42.04769516  1
  B  B4  1  49.02616882  53.97908020  12.62097168  1
  B  B5  1  5.96825743  13.09373665  65.89971161  1
  B  B6  1  24.7681

In [13]:
# Attach each CIF to its generation request. Rows line up because the DataLoader uses
# shuffle=False and joblib.Parallel returns results in submission order.
df = df.assign(cif=cifs)
df

Unnamed: 0,pretty_formula,spacegroup_relax,enthalpy_formation_atom,cif
0,Ta1W1B6,6,-1.3993,# generated using pymatgen\ndata_TaB6W\n_symme...
1,Ta1W1B6,6,-1.4093,# generated using pymatgen\ndata_TaB6W\n_symme...
2,Ta1W1B6,6,-1.4193,# generated using pymatgen\ndata_TaB6W\n_symme...
3,Ta1W1B6,6,-1.4293,# generated using pymatgen\ndata_TaB6W\n_symme...
4,Ta1W1B6,6,-1.4393,# generated using pymatgen\ndata_TaB6W\n_symme...
...,...,...,...,...
394,Ta1W1B6,225,-1.5593,# generated using pymatgen\ndata_TaB6W\n_symme...
395,Ta1W1B6,225,-1.5693,# generated using pymatgen\ndata_TaB6W\n_symme...
396,Ta1W1B6,225,-1.5793,# generated using pymatgen\ndata_TaB6W\n_symme...
397,Ta1W1B6,225,-1.5893,# generated using pymatgen\ndata_TaB6W\n_symme...


In [98]:
# Save the generation requests together with their CIFs; replace the placeholder with
# the desired output file name. (Plain string: the former f-string had no placeholders.)
df.to_csv("<FILENAME>.csv", index=False)