In [1]:
from geomloss import SamplesLoss

In [2]:
import numpy as np
import pandas as pd
import sys
import torch.nn as nn
import torch.nn.functional as F

import torch
import madgrad
import torch.autograd as auto
sys.path.append("./set_transformer")
from modules import ISAB, PMA, SAB
import matplotlib.pyplot as plt
from sklearn.preprocessing import RobustScaler


In [3]:
from sklearn.model_selection import train_test_split
def build_dataset():
    """Load the GPAW and VASP data frames and split them into train/test arrays.

    Both pickles are read from ``data/``. Positions, atomic numbers, energies
    and forces of the two sets are concatenated along the first axis (GPAW rows
    first, VASP rows second). Energies are rescaled with a ``RobustScaler``
    fitted on the complete, un-split energy column, and every VASP row is
    given an all-zero force array.

    Returns:
        A 2-tuple ``(splits, lattice)``. ``splits`` is the list
        ``[positions_train, positions_test, types_train, types_test,
        energies_train, energies_test, forces_train, forces_test]`` obtained
        from ``train_test_split(test_size=0.25, random_state=1421)``;
        ``lattice`` is the float32 cell array of the first GPAW structure.

    Raises:
        AssertionError: if the lattice matrix of the first VASP structure is
            not element-wise equal to the cell of the first GPAW structure.
    """
    gpaw_frame = pd.read_pickle("data/gpaw_forces_dataframe.pickle")
    vasp_frame = pd.read_pickle("data/VASP_MoS2_defects.pickle")

    # Sanity check: the first structure of each data set uses the same cell.
    first_vasp_lattice = vasp_frame.structures.iloc[0].lattice._matrix
    first_gpaw_cell = gpaw_frame.structure.iloc[0].cell.array
    assert (first_vasp_lattice == first_gpaw_cell).all()

    # Cartesian coordinates, one leading sample axis per structure.
    gpaw_positions = [atoms.positions[np.newaxis, :, :] for atoms in gpaw_frame.structure]
    vasp_positions = [struct.cart_coords[np.newaxis, :, :] for struct in vasp_frame.structures]
    positions = np.concatenate(gpaw_positions + vasp_positions, axis=0).astype(np.float32)

    # NOTE(review): the scaler is fitted before the train/test split and is not
    # returned, so test energies influence the scaling and callers cannot
    # invert it -- confirm this is intended.
    raw_energies = np.concatenate([gpaw_frame.energy.values, vasp_frame.energy.values])
    energies = raw_energies.astype(np.float32).reshape(-1, 1)
    energies = RobustScaler().fit(energies).transform(energies)

    # GPAW rows carry their stored forces; VASP rows are filled with zeros.
    gpaw_forces = [sample_forces[np.newaxis, :, :] for sample_forces in gpaw_frame.forces]
    vasp_forces = np.zeros((len(vasp_frame), positions.shape[1], positions.shape[2]),
                           dtype=np.float32)
    forces = np.concatenate(gpaw_forces + [vasp_forces], axis=0).astype(np.float32)

    gpaw_types = [atoms.get_atomic_numbers()[np.newaxis, :] for atoms in gpaw_frame.structure]
    vasp_types = [np.array(struct.atomic_numbers)[np.newaxis, :] for struct in vasp_frame.structures]
    types = np.concatenate(gpaw_types + vasp_types, axis=0).astype(np.int32)

    split_parts = train_test_split(positions, types, energies, forces,
                                   test_size=0.25, random_state=1421)
    splits = [np.array(part) for part in split_parts]
    lattice = gpaw_frame.structure.iloc[0].cell.array.astype(np.float32)
    return splits, lattice

In [4]:
def fixup_initialization(args, module=None):
    """Rescale selected weights of ``module`` in place.

    The original body referred to an undefined ``self`` and therefore raised
    ``NameError`` on every call; the target is now passed explicitly.

    When ``args.Tfixup`` is truthy, parameters whose name is exactly
    ``"fc1.weight"``, ``"fc2.weight"`` or ``"self_attn.out_proj.weight"`` are
    multiplied by ``0.67 * args.encoder_layers ** (-1/4)``, and
    ``"self_attn.v_proj.weight"`` additionally by ``sqrt(2)``. All other
    entries of the state dict are reloaded unchanged.

    Args:
        args: object exposing ``encoder_layers`` (number) and ``Tfixup`` (flag).
        module: the ``torch.nn.Module`` whose parameters are rescaled. Names are
            matched exactly against ``module.named_parameters()``.
            NOTE(review): the exact-name match presumably expects a single
            layer rather than a whole model with prefixed names -- confirm.

    Raises:
        TypeError: if ``module`` is not supplied.
    """
    if module is None:
        raise TypeError("fixup_initialization() requires the module to rescale: "
                        "call fixup_initialization(args, module)")

    scale = 0.67 * args.encoder_layers ** (-1.0 / 4.0)

    # Start from the current state so every key that is not rescaled is kept.
    new_state = dict(module.state_dict())

    if args.Tfixup:
        scaled_names = ("fc1.weight", "fc2.weight", "self_attn.out_proj.weight")
        for name, param in module.named_parameters():
            if name in scaled_names:
                new_state[name] = scale * param.detach()
            elif name == "self_attn.v_proj.weight":
                new_state[name] = scale * (param.detach() * (2 ** 0.5))

    module.load_state_dict(new_state)
    

# temp_state_dict = embed_tokens.state_dict()
# temp_state_dict["weight"] = (9 * args.encoder_layers) ** (- 1 / 4) * temp_state_dict["weight"]
# embed_tokens.load_state_dict(temp_state_dict)

In [22]:
class SetTransformer(nn.Module):
    """Encoder/decoder model assembled from ``SAB``, ``ISAB`` and ``PMA`` blocks.

    The encoder is ``SAB -> SAB -> ISAB``; the decoder is a ``PMA`` block
    followed by a three-layer MLP with ``LeakyReLU`` activations.

    Args:
        dim_input: feature width passed to the first ``SAB``.
        num_outputs: third positional argument of ``PMA``.
        dim_output: width of the final linear layer.
        num_inds: fourth positional argument of ``ISAB``.
        dim_hidden: hidden width used by every block and MLP layer.
        num_heads: head count passed to every attention block.
        ln: forwarded as ``ln=`` to every attention block.
    """

    def __init__(
        self,
        dim_input=4,
        num_outputs=1,
        dim_output=1,
        num_inds=32,
        dim_hidden=128,
        num_heads=16,
        ln=False,
    ):
        super().__init__()
        # Further ISAB blocks can be appended here for a deeper encoder.
        self.enc = nn.Sequential(
            SAB(dim_input, dim_hidden, num_heads, ln=ln),
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            nn.Linear(dim_hidden, dim_hidden),
            nn.LeakyReLU(),
            nn.Linear(dim_hidden, dim_hidden),
            nn.LeakyReLU(),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        """Run encoder and decoder and drop singleton non-batch dimensions.

        A bare ``squeeze()`` would also remove the leading dimension when it
        has size 1, turning a batch of one sample into a 0-d tensor. Only
        dimensions after the first are squeezed here, so the result for
        batches larger than one is unchanged.
        """
        out = self.dec(self.enc(X))
        # Walk from the last dimension backwards so indices stay valid.
        for dim in range(out.dim() - 1, 0, -1):
            if out.size(dim) == 1:
                out = out.squeeze(dim)
        return out

In [19]:
# Unpack the eight train/test arrays and the lattice returned by build_dataset().
(
    positions, test_positions,
    types, test_types,
    energies, test_energies,
    forces, test_forces,
), lattice = build_dataset()

# 1.0 where an atom has the same atomic number as the first atom of the first
# training structure, 0.0 otherwise.
species = np.equal(types, types[0, 0]).astype(np.float32)
energy_mean = energies.mean()
print(f'positions.shape = {positions.shape}')
print(f'<E> = {energy_mean}')

positions.shape = (616, 105, 3)
<E> = -0.7444171905517578


In [20]:
# Model input per atom: the three coordinates followed by the species flag.
# requires_grad is set so the energy can later be differentiated w.r.t. the inputs.
inputs_t = torch.tensor(
    np.concatenate((positions, species[..., np.newaxis]), axis=-1),
    requires_grad=True,
).cuda()
# Regression targets, moved to the GPU as-is.
energies_t, forces_t = (torch.tensor(target).cuda() for target in (energies, forces))

In [23]:
# Build the network on the GPU, then restore whatever checkpoint entries match.
model = SetTransformer(num_inds=16, num_heads=4, dim_hidden=256)
model = model.cuda()
# NOTE(review): strict=False hides an architecture mismatch -- the recorded output
# lists every `enc.2.*` ISAB key as missing and `enc.3.*`/`enc.4.*` as unexpected,
# so those weights are not restored from the file.
model.load_state_dict(torch.load('data/0.263_ev.pth'), strict=False)

_IncompatibleKeys(missing_keys=['enc.2.I', 'enc.2.mab0.fc_q.weight', 'enc.2.mab0.fc_q.bias', 'enc.2.mab0.fc_k.weight', 'enc.2.mab0.fc_k.bias', 'enc.2.mab0.fc_v.weight', 'enc.2.mab0.fc_v.bias', 'enc.2.mab0.fc_o.weight', 'enc.2.mab0.fc_o.bias', 'enc.2.mab1.fc_q.weight', 'enc.2.mab1.fc_q.bias', 'enc.2.mab1.fc_k.weight', 'enc.2.mab1.fc_k.bias', 'enc.2.mab1.fc_v.weight', 'enc.2.mab1.fc_v.bias', 'enc.2.mab1.fc_o.weight', 'enc.2.mab1.fc_o.bias'], unexpected_keys=['enc.3.I', 'enc.3.mab0.fc_q.weight', 'enc.3.mab0.fc_q.bias', 'enc.3.mab0.fc_k.weight', 'enc.3.mab0.fc_k.bias', 'enc.3.mab0.fc_v.weight', 'enc.3.mab0.fc_v.bias', 'enc.3.mab0.fc_o.weight', 'enc.3.mab0.fc_o.bias', 'enc.3.mab1.fc_q.weight', 'enc.3.mab1.fc_q.bias', 'enc.3.mab1.fc_k.weight', 'enc.3.mab1.fc_k.bias', 'enc.3.mab1.fc_v.weight', 'enc.3.mab1.fc_v.bias', 'enc.3.mab1.fc_o.weight', 'enc.3.mab1.fc_o.bias', 'enc.4.I', 'enc.4.mab0.fc_q.weight', 'enc.4.mab0.fc_q.bias', 'enc.4.mab0.fc_k.weight', 'enc.4.mab0.fc_k.bias', 'enc.4.mab0.fc_v.

In [None]:
# Adam with a step decay: the learning rate is multiplied by 0.9 every 1000 steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)

# Weight of the force term in the combined loss.
force_coefficient = 0.1
# Per-iteration MSE histories (energy and force).
losses_en, losses_f = [], []

model.train()
for iteration in range(10000):
    # Full-batch forward pass.
    preds = model(inputs_t)
    energies_t = energies_t.reshape(-1)
    mse_en = F.mse_loss(preds, energies_t)

    # Forces are the negative gradient of the summed predictions w.r.t. the
    # inputs; create_graph keeps the gradient differentiable so the force loss
    # can be back-propagated into the model weights.
    (energy_grad,) = auto.grad(preds.sum(), inputs_t, retain_graph=True, create_graph=True)
    pred_forces = -energy_grad
    # Only the first three input channels (the coordinates) are compared.
    mse_f = F.mse_loss(pred_forces[..., :3], forces_t)

    loss = mse_en + force_coefficient * mse_f

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    losses_en.append(mse_en.item())
    losses_f.append(mse_f.item())

    # Single-line progress report, overwritten in place via the carriage return.
    print(
        f"{iteration}  LR: {scheduler.get_last_lr()[0]:.6f}  loss: {loss.item()} "
        f"Energy RMSE: {np.sqrt(losses_en[-1]):.3f} eV "
        f"        Force RMSE: {np.sqrt(losses_f[-1]):.3f} eV",
        end='\r',
    )
    

4437  LR: 0.000007  loss: 0.18036767840385437 Energy RMSE: 0.407 eV         Force RMSE: 0.381 eV

In [None]:
# # Save model
# torch.save(model.state_dict(), 'data/0.263_ev.pth')
# # Just in case...
# torch.save(optimizer.state_dict(), 'data/optim_0.263_ev.pth')

In [None]:
# The training loop records `losses_en` and `losses_f`; there is no notebook-level
# `losses` list, so plot the two recorded histories instead.
fig, ax = plt.subplots()
ax.plot(losses_en, label="Energy MSE")
ax.plot(losses_f, label="Force MSE")
ax.set_xlabel("Iteration")
ax.set_ylabel("MSE loss")
ax.legend();

In [None]:
# Detach the last training predictions from autograd and copy them to a NumPy array.
preds_np = preds.detach().cpu().numpy()

In [None]:
# Parity plot on the training set: scaled target energies (x) vs. predictions (y).
fig, ax = plt.subplots()
ax.scatter(energies, preds_np)
ax.set(
    xlabel="Normed DFT total energy, eV",
    ylabel="Normed predicted total energy, eV",
);

In [None]:
# Targets (default colour) and predictions (red) plotted against the sample index,
# markers only.
plt.plot(energies, linestyle='None', marker='.')
plt.plot(preds_np, linestyle='None', marker='.', color='r')

In [None]:
# testing
(positions, test_positions, \
types, test_types, \
energies, test_energies, \
forces, test_forces), lattice = build_dataset()
species = (test_types==test_types[0,0]).astype(np.float32)
energy_mean_test = np.mean(test_energies)
print('positions.shape = {}'.format(test_positions.shape))
print('<E> = {}'.format(energy_mean_test))


inputs_test = torch.tensor(np.concatenate([test_positions, np.expand_dims(species, -1)], axis=2), requires_grad=True).cuda()
energies_test = torch.tensor(test_energies)
preds_test = model(inputs_test)

print(f'Test rmse: {torch.nn.MSELoss()(energies_test.reshape(-1), preds_test.cpu()).sqrt().item()} ev')

In [None]:
# Parity plot on the test set: scaled target energies (x) vs. predictions (y).
preds_test_ = preds_test.detach().cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(energies_test, preds_test_)
ax.set(
    xlabel="Normed DFT total energy, eV",
    ylabel="Normed predicted total energy, eV",
);


In [None]:
# Test targets (default colour) and predictions (red) against the sample index,
# markers only.
plt.plot(energies_test, linestyle='None', marker='.')
plt.plot(preds_test_, linestyle='None', marker='.', color='r')

In [None]:
# Gradient-enabled GPU copy of the test coordinates.
test_positions_ = torch.tensor(test_positions, requires_grad=True).to("cuda")

In [None]:
# `preds_test` was computed from `inputs_test`, so that is the tensor to
# differentiate against; `test_positions_` never entered the graph and autograd
# refuses to differentiate with respect to it. The last input channel is the
# species flag, so only the first three (coordinate) channels are kept.
force = - auto.grad(preds_test.sum(), inputs_test)[0][..., :3]
# NumPy copy under the name the force-error and force-plot cells expect.
force_ = force.detach().cpu().numpy()

In [None]:
# Root-mean-square difference between stored and predicted test forces, over all entries.
np.sqrt(np.mean(np.square(test_forces - force_)))

In [None]:
fig, ax = plt.subplots()
ax.plot(test_forces[..., 0], force_[..., 0], 'b.', alpha=0.5)
ax.plot(test_forces[..., 0], force_[..., 0], 'g.', alpha=0.5)
ax.plot(test_forces[..., 0], force_[..., 0], 'r.', alpha=0.5)

ax.set_ylabel("Force")
ax.set_xlabel("Predicted force");


In [None]:
#Hyperparameter search...

In [None]:
import optuna
import logging

In [None]:
# Training tensors for the search. The species flags are recomputed from `types`
# so this cell does not depend on whatever `species` currently holds (the testing
# cell rebinds that name), and the energies are flattened once here instead of
# being reshaped through a `global` statement inside the objective.
species_train = (types == types[0, 0]).astype(np.float32)
inputs_t = torch.as_tensor(
    np.concatenate([positions, np.expand_dims(species_train, -1)], axis=2),
    dtype=torch.float32,
).cuda()
energies_t = torch.as_tensor(energies, dtype=torch.float32).reshape(-1).cuda()


def objective(trial):
    """Optuna objective: train a fresh model briefly and return its train RMSE.

    A learning rate is sampled uniformly from ``[1e-6, 1e-3]``, a new
    ``SetTransformer`` is trained for 10 full-batch MADGRAD steps on an
    MSE + L1 energy loss, and the square root of the last MSE is returned.
    The RMSE of every step is reported to the trial so the study's pruner can
    stop unpromising trials early.

    Args:
        trial: the ``optuna`` trial supplying the ``lr`` suggestion.

    Returns:
        float: square root of the MSE from the final training step.

    Raises:
        optuna.TrialPruned: when the study's pruner asks to stop the trial.
    """
    # suggest_float is the non-deprecated spelling of suggest_uniform.
    lr = trial.suggest_float('lr', 1e-6, 1e-3)

    model = SetTransformer(dim_hidden=256, num_heads=16).cuda()
    optimizer = madgrad.MADGRAD(model.parameters(), lr=lr)

    mse = torch.nn.MSELoss()
    l1 = torch.nn.L1Loss()
    rmse = float('nan')
    model.train()
    for iteration in range(10):
        preds = model(inputs_t)
        l2 = mse(preds, energies_t)
        loss = l2 + l1(preds, energies_t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        rmse = float(np.sqrt(l2.item()))
        print(f"{iteration}  Train RMSE {rmse:.3f} eV", end='\r')

        # Without report/should_prune the pruner passed to the study never acts.
        trial.report(rmse, iteration)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return rmse
    

In [None]:
# Median-based pruner handed to the study below.
pruner: optuna.pruners.BasePruner = optuna.pruners.MedianPruner()

# Send INFO-level (and above) log records to search.log.
logging.basicConfig(level=logging.INFO, filename='search.log')

def print_best_callback(study, trial):
    """Append the study's current best value and parameters to search.log and print them.

    Each call writes exactly one newline-terminated line; previously the entry
    was written without a line break, so successive entries ran together.

    Args:
        study: object exposing ``best_value`` and ``best_trial.params``.
        trial: the trial that just finished (unused).
    """
    message = f"Best value: {study.best_value}, Best params: {study.best_trial.params}"
    with open('search.log', 'a+') as file:
        file.write(message + "\n")
    print(message)


# Minimise the objective over 300 sequential trials, logging the best result
# after each one.
study = optuna.create_study(pruner=pruner, direction="minimize")
study.optimize(objective, n_trials=300, n_jobs=1, callbacks=[print_best_callback])

print(f"Number of finished trials: {len(study.trials)}")

# Summary of the best trial found.
print("Best trial:")
trial = study.best_trial

print(f"  Value: {trial.value}")

print("  Params: ")
for param_name, param_value in trial.params.items():
    print(f"    {param_name}: {param_value}")

In [None]:
# Scratch cell: displays a 0-d tensor holding +inf; the result is not assigned to a name.
torch.tensor(float('inf'))

In [None]:
# Scratch cell: evaluates to the bound method object without calling it.
study.optimize