In [16]:
import os
import sys
sys.path.append("../")

import random
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
tqdm.pandas()

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import joblib
from IPython.display import clear_output
from pymatgen.core import Structure

from torch.utils.data import Dataset, random_split
from transformers import get_cosine_schedule_with_warmup
from diffusers import DDPMScheduler

from src.model.models import CrystalUNetModelX0Condition
from src.modification.diffusion_modification_loops import train
from src.py_utils.crystal_dataset import CrystalDataset
from src.losses import diffusion_modification_loss, l1_loss
from src.py_utils.comparator import PymatgenComparator
from src.py_utils.sampler import get_dataloaders_pairs, filter_polymorphs, get_balanced_dataloaders_pairs, get_balanced_dataloaders_non_pairs, filter_polymorphs
from src.py_utils.stratified_splitter import train_test_split_with_chemical_balance
from src.utils import seed_everything

In [2]:
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings for this training notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Without an annotation an assignment is only a class
    attribute: it is missing from the generated ``__init__`` (so
    ``TrainingConfig(batch_size=64)`` raises ``TypeError``) and from
    ``repr``/``==``.

    ``model_channels`` and ``num_res_blocks`` stay the first two fields because
    they were the only parameters of the previously generated ``__init__``;
    keeping them first preserves any positional construction.
    """

    # Model
    model_channels: int = 128
    num_res_blocks: int = 7
    attention_resolutions: tuple = (1, 2, 4, 8)

    # Data
    max_nsites: int = 64
    max_elems: int = 4
    min_elems: int = 2

    # Loss
    coords_loss_coef: float = 0.5
    lattice_loss_coef: float = 0.5

    # Noise Scheduler
    num_train_timesteps: int = 1_000
    num_inference_steps: int = 100
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "squaredcos_cap_v2"

    # Training
    batch_size: int = 256
    epochs: int = 500
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    num_workers: int = 4

    # Accelerator
    gradient_accumulation_steps: int = 1
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision

    device: str = "cuda"
    random_state: int = 42


# Build the run configuration with its default values; later cells read
# hyperparameters from this object.
config = TrainingConfig()
# Seed through the project helper using the configured random_state.
# (presumably seeds the Python / NumPy / torch RNGs — confirm in src.utils)
seed_everything(config.random_state)

### Data initialization

In [3]:
# Root directory of the FTCP data. Set the FTCP_DATA_DIR environment variable
# to run the notebook on another machine; the fallback is the original
# location. Joining with "" keeps the trailing separator PATH always had.
PATH = os.path.join(
    os.environ.get("FTCP_DATA_DIR", "/home/lazarev/MaterialsDesign/FTCP_data/"),
    "",
)
# Name of the dataset sub-directory under <PATH>/datasets/.
tag = "aflow_database_nsites_4_60"
dataset_path = os.path.join(PATH, "datasets", tag, "dataframe.csv")

# Load the full structures table and display its (rows, columns) shape.
dataset_df = pd.read_csv(dataset_path)
dataset_df.shape

(3043398, 19)

In [4]:
# NOTE(review): disabled manual filter that kept only formulas with more than
# one row. It looks superseded by the filter_polymorphs(min_polymorphs=2) call
# in the next cell — confirm, then delete this cell.
# num_polymorphs = dataset_df.groupby("pretty_formula").count()["auid"]
# needed_formulas = num_polymorphs[num_polymorphs > 1].index
# dataset_df = dataset_df[dataset_df["pretty_formula"].isin(needed_formulas)]
# dataset_df.shape

In [5]:
# Settings forwarded to the project's polymorph filter.
# (presumably: keep formulas with at least `min_polymorphs` structures and
# energies inside [min_energy, max_energy] — confirm in src.py_utils.sampler)
polymorph_filter_kwargs = dict(
    min_polymorphs=2,
    min_energy=-5,
    max_energy=5,
)

# Replaces dataset_df with the filtered frame.
dataset_df = filter_polymorphs(dataset_df, **polymorph_filter_kwargs)

In [6]:
def _largest_sorted_gap(energies):
    """Return the largest difference between neighbouring values after sorting."""
    ordered = np.sort(energies)
    return np.diff(ordered).max()


# One value per formula: the largest gap between consecutive sorted
# formation enthalpies of that formula's structures.
# NOTE(review): the variable is called "min_energy_deltas" but the reduction
# is a max — confirm which of the two is intended.
min_energy_deltas = (
    dataset_df
    .groupby("pretty_formula")["enthalpy_formation_cell"]
    .apply(_largest_sorted_gap)
)

# Threshold compared against the per-formula deltas in the next cell.
energy_noise = 0.01

In [7]:
# Formulas whose per-formula energy delta is below the noise threshold.
is_below_noise = min_energy_deltas < energy_noise
low_energy_diff_groups = min_energy_deltas[is_below_noise].index

# Drop every row that belongs to one of those formulas, then renumber rows.
in_low_delta_group = dataset_df["pretty_formula"].isin(low_energy_diff_groups)
dataset_df = dataset_df[~in_low_delta_group].reset_index(drop=True)

In [8]:
# Split with the project helper, requesting test_size=0.2; the two return
# values are unpacked as the train and test formula collections.
# verbose=True appears to produce the size-ratio / element-difference report
# printed below — confirm in src.py_utils.stratified_splitter.
train_formulas, test_formulas = train_test_split_with_chemical_balance(
    dataset_df, test_size=0.2, verbose=True
)

Train/test structures df size ratio : 2.95568
Elements absolute difference: 0.05082


In [9]:
# Build the paired train/test datasets and their dataloaders from the split.
# The worker count comes from the shared config instead of a duplicated
# literal, so changing TrainingConfig.num_workers takes effect here as well.
# NOTE(review): the helper's return values are unpacked test-first — confirm
# this matches the order get_balanced_dataloaders_pairs actually returns.
# NOTE(review): config.batch_size is not passed; the batch size presumably
# comes from the helper's default — confirm.
# NOTE(review): "train_good_nsites_balaned" looks like a typo of "balanced",
# but it may be the exact key the helper expects, so it is left unchanged.
test_dataset, train_dataset, test_dataloader, train_dataloader = get_balanced_dataloaders_pairs(
    dataset_df,
    train_formulas=train_formulas,
    test_formulas=test_formulas,
    num_workers=config.num_workers,
    avg_pairs_per_group=(1, 1),
    sampling_strategy="train_good_nsites_balaned",
    top_k_good=2,
    apply_energy_noising=False,
)

shape before processing: (2877196, 19)
shape after processing: (2244705, 19)


Converting lattice: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1628389/1628389 [03:21<00:00, 8074.86it/s]
Converting lattice: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 616316/616316 [01:15<00:00, 8154.67it/s]


### Model training

In [19]:
# Number of input features per site (atomic coordinates); the model's output
# width must be equal to it.
coord_features = 3
# Width of the conditioning vector; 256 is the size of an elements condition.
condition_features = 1 + 256 + 256 + 256

model = CrystalUNetModelX0Condition(
    dims=1,  # use the 1D variant of the U-Net
    in_channels=coord_features,
    out_channels=coord_features,
    condition_dims=condition_features,
    model_channels=config.model_channels,  # inner model features
    num_res_blocks=config.num_res_blocks,
    attention_resolutions=config.attention_resolutions,
)
model.to(config.device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# The cosine schedule spans the whole run: one scheduler step per batch for
# every epoch, after a warmup of config.lr_warmup_steps steps.
steps_per_epoch = len(train_dataloader)
total_steps = int(steps_per_epoch * config.epochs)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_steps,
)

In [11]:
from accelerate import Accelerator

# Accelerator configured from the shared training config
# (mixed-precision mode and gradient-accumulation step count).
accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision=config.mixed_precision,
)

# Hand the dataloaders, model, optimizer and LR scheduler to the accelerator
# and rebind each name to the object it returns, in the same order.
prepared = accelerator.prepare(
    train_dataloader, test_dataloader, model, optimizer, scheduler
)
train_dataloader, test_dataloader, model, optimizer, scheduler = prepared

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [12]:
# Sanity check: display the configured number of diffusion training timesteps.
config.num_train_timesteps

1000

In [13]:
# Noise-schedule settings taken from the shared training config.
# NOTE(review): with beta_schedule="squaredcos_cap_v2" diffusers derives the
# betas from a cosine alpha-bar curve, so beta_start / beta_end are likely
# ignored — confirm against the DDPMScheduler documentation.
noise_schedule_kwargs = dict(
    num_train_timesteps=config.num_train_timesteps,
    beta_start=config.beta_start,
    beta_end=config.beta_end,
    beta_schedule=config.beta_schedule,
    clip_sample=False,
)
ddpm_scheduler = DDPMScheduler(**noise_schedule_kwargs)

# Select the timesteps used at inference time.
ddpm_scheduler.set_timesteps(num_inference_steps=config.num_inference_steps)

# Display the resulting scheduler configuration.
ddpm_scheduler

DDPMScheduler {
  "_class_name": "DDPMScheduler",
  "_diffusers_version": "0.23.1",
  "beta_end": 0.02,
  "beta_schedule": "squaredcos_cap_v2",
  "beta_start": 0.0001,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null,
  "variance_type": "fixed_small"
}

In [None]:
# Structure comparator handed to the training loop; it reads the element
# table from the repository's data folder.
structure_comparator = PymatgenComparator(elm_str_path='../src/data/element.pkl')

# Launch training with the accelerator-prepared objects.
train(
    # model and optimisation
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    accelerator=accelerator,
    device=config.device,
    # diffusion process
    noise_scheduler=ddpm_scheduler,
    # objective and evaluation
    loss_function=diffusion_modification_loss,
    metric_function=l1_loss,
    comparator=structure_comparator,
    coords_loss_coef=config.coords_loss_coef,
    lattice_loss_coef=config.lattice_loss_coef,
    lattice_size=3,
    # data and schedule
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=config.epochs,
    eval_every_n=5,
)