In [3]:
import os
import sys
sys.path.append("../")

import random
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
tqdm.pandas()

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import joblib
from IPython.display import clear_output
from pymatgen.core import Structure

from torch.utils.data import Dataset, random_split
from transformers import get_cosine_schedule_with_warmup
from diffusers import DDPMScheduler

from src.model.models import CrystalUNetModel
from src.generation.diffusion_generation_loops import train
from src.py_utils.crystal_dataset import CrystalDataset
from src.losses import diffusion_generation_loss, l1_loss
from src.py_utils.comparator import PymatgenComparator
from src.py_utils.sampler import get_dataloaders_pairs, filter_polymorphs, get_balanced_dataloaders_pairs, get_balanced_dataloaders_non_pairs
from src.py_utils.stratified_splitter import train_test_split_with_chemical_balance
from src.utils import seed_everything

In [4]:
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    """Hyperparameters and runtime settings for the diffusion training run.

    Every setting carries a type annotation so that ``@dataclass`` registers it
    as a real field. Un-annotated class attributes are ignored by the decorator:
    they cannot be overridden through the constructor and do not show up in
    ``repr`` or equality comparisons.
    """

    # Model
    # Declared first on purpose: these two were the only dataclass fields before,
    # so keeping them at the front preserves positional construction such as
    # ``TrainingConfig(128, 7)``.
    model_channels: int = 128
    num_res_blocks: int = 7
    attention_resolutions: tuple = (1, 2, 4, 8)

    # Data
    max_nsites: int = 64
    max_elems: int = 4
    min_elems: int = 2

    # Loss
    coords_loss_coef: float = 0.5
    lattice_loss_coef: float = 0.5

    # Noise Scheduler
    num_train_timesteps: int = 1_000
    num_inference_steps: int = 100
    beta_start: float = 0.0001
    beta_end: float = 0.02
    beta_schedule: str = "squaredcos_cap_v2"

    # Training
    batch_size: int = 256
    epochs: int = 500
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    num_workers: int = 1

    # Accelerator
    gradient_accumulation_steps: int = 1
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision

    device: str = "cuda"
    random_state: int = 42


# Build the run configuration using the defaults declared on TrainingConfig.
config = TrainingConfig()
# Pass the configured seed to the project's seed_everything helper
# (presumably seeds the Python / NumPy / torch RNGs — see src.utils to confirm).
seed_everything(config.random_state)

### Data initialization

In [5]:
# Root directory of the FTCP data. Set the FTCP_DATA_DIR environment variable to
# point the notebook at another machine's copy instead of editing this cell; the
# original location stays as the fallback.
# Joining with "" guarantees a trailing separator, so `PATH + "sub/dir"` style
# concatenation keeps working whether or not the override ends with one.
PATH = os.path.join(
    os.environ.get("FTCP_DATA_DIR", "/home/lazarev/MaterialsDesign/FTCP_data/"),
    "",
)
# Name of the dataset directory under <PATH>/datasets/.
tag = "aflow_database_nsites_4_60"
# CSV file holding the structures dataframe for the selected dataset.
dataset_path = os.path.join(PATH, "datasets", tag, "dataframe.csv")

In [6]:
# Read the whole structures table from the CSV at dataset_path into a DataFrame.
dataset_df = pd.read_csv(dataset_path)
# Last expression of the cell: displays (n_rows, n_columns) of the loaded frame.
dataset_df.shape

(3043398, 19)

In [7]:
# Split the dataset with the project's chemically balanced splitter, requesting
# a 5% test share; verbose=True makes the helper print its split diagnostics.
# NOTE(review): no seed is passed here — presumably the split relies on the
# global seed set via seed_everything; confirm that it is reproducible.
train_formulas, test_formulas = train_test_split_with_chemical_balance(
    dataset_df, 
    test_size=0.05,
    verbose=True
)

# Last expression of the cell: displays the sizes of the two returned collections.
len(train_formulas), len(test_formulas)

Train/test structures df size ratio : 15.10869
Elements absolute difference: 0.12579


(446726, 41345)

In [8]:
# Build datasets and dataloaders for the train/test formula split using the
# batch size and worker count from the config.
# NOTE(review): the unpack order (test first, then train) is taken on trust —
# verify it against the return statement of get_balanced_dataloaders_non_pairs.
# NOTE(review): 'train_good_nsites_balaned' looks like a misspelling of
# "...balanced"; it is a runtime key, so confirm the exact spelling the sampler
# expects before changing it.
test_dataset, train_dataset, test_dataloader, train_dataloader = get_balanced_dataloaders_non_pairs(
    dataset_df,
    train_formulas=train_formulas, 
    test_formulas=test_formulas,
    avg_structures_per_group=1,
    sampling_strategy='train_good_nsites_balaned',
    top_k_good=2,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    min_polymorphs=0.5
)

shape before processing: (3041403, 19)
shape after processing: (3041309, 19)


Converting lattice: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2852477/2852477 [05:56<00:00, 8009.88it/s]
Converting lattice: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41344/41344 [00:05<00:00, 7854.07it/s]


### Model training

In [9]:
# Denoising network: a conditional 1D U-Net configured from TrainingConfig.
model = CrystalUNetModel(
    dims=1,  # use the 1D variant of the U-Net
    in_channels=3,  # number of input features (atomic coordinates)
    out_channels=3,  # must equal the number of input features (atomic coordinates)
    condition_dims=1 + 256 + 256,  # number of condition features; 256 is the size of the elements condition
    model_channels=config.model_channels,  # inner model features
    num_res_blocks=config.num_res_blocks,
    attention_resolutions=config.attention_resolutions,
)

model.to(config.device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Length of the LR schedule: batches per epoch times number of epochs.
# Assumes the LR scheduler is stepped once per training batch — confirm in train().
total_steps = int(len(train_dataloader) * config.epochs)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,  # warm-up length comes from the config
    num_training_steps=total_steps,
)

In [10]:
from accelerate import Accelerator

# Accelerator configured with the mixed-precision mode and gradient-accumulation
# step count declared in TrainingConfig.
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps, 
)

# Hand the dataloaders, model, optimizer and LR scheduler to the accelerator and
# rebind the names to the objects it returns (same order as passed in).
train_dataloader, test_dataloader, model, optimizer, scheduler = accelerator.prepare(
    train_dataloader, test_dataloader, model, optimizer, scheduler
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [11]:
# DDPM noise scheduler built from the noise-schedule settings in TrainingConfig;
# sample clipping is disabled.
# NOTE(review): with beta_schedule="squaredcos_cap_v2" diffusers derives the betas
# from a cosine alpha-bar curve, so beta_start / beta_end are presumably ignored —
# confirm against the DDPMScheduler documentation.
ddpm_scheduler = DDPMScheduler(
    num_train_timesteps=config.num_train_timesteps,
    beta_start=config.beta_start,
    beta_end=config.beta_end,
    beta_schedule=config.beta_schedule,
    clip_sample=False,

)
# Configure the number of steps used when sampling (config.num_inference_steps).
ddpm_scheduler.set_timesteps(
    num_inference_steps=config.num_inference_steps
)

# Last expression of the cell: displays the scheduler's configuration.
ddpm_scheduler

DDPMScheduler {
  "_class_name": "DDPMScheduler",
  "_diffusers_version": "0.23.1",
  "beta_end": 0.02,
  "beta_schedule": "squaredcos_cap_v2",
  "beta_start": 0.0001,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null,
  "variance_type": "fixed_small"
}

In [None]:
# Launch the diffusion training loop with the objects prepared above, the loss
# weights and epoch count from TrainingConfig, and evaluation requested via
# eval_every_n=5.
# NOTE(review): eval_dataloader receives `test_dataset`, not the `test_dataloader`
# that was passed through accelerator.prepare (which is otherwise unused here).
# Confirm against train() whether a dataset is really expected; if it iterates
# batches, this should be `test_dataloader`.
# NOTE(review): elm_str_path is relative, so it depends on the notebook's working
# directory.
train(
    model=model,
    optimizer=optimizer,
    noise_scheduler=ddpm_scheduler,
    loss_function=diffusion_generation_loss,
    metric_function=l1_loss,
    comparator=PymatgenComparator(elm_str_path='../src/data/element.pkl'),
    coords_loss_coef=config.coords_loss_coef,
    lattice_loss_coef=config.lattice_loss_coef,
    epochs=config.epochs,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataset,
    scheduler=scheduler,
    accelerator=accelerator,
    lattice_size=3,
    device=config.device,
    eval_every_n=5,
)

  0%|                                                                                                                                                                                 | 0/500 [00:00<?, ?it/s]
  0%|                                                                                                                                                                               | 0/11143 [00:00<?, ?it/s][A
  0%|                                                                                                                                                                    | 1/11143 [00:25<77:36:50, 25.08s/it][A
  0%|                                                                                                                                                                    | 2/11143 [00:25<32:28:35, 10.49s/it][A
  0%|                                                                                                                                                              