In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torchviz
import time
import copy

import matplotlib.pyplot as plt

import ase.io

from utils.soap import compute_spherical_expansion_librascal, PowerSpectrum
from utils.gap import train_sparse_gap_model, train_per_species_sparse_gap_model, train_full_gap_model
from utils.alchemical import AlchemicalCombine

torch.set_default_dtype(torch.float64)

In [None]:
# Number of structures (and matching energies) to load, and the fraction of
# them used for training. Kept as named constants so the frame selection and
# the energy slice cannot drift apart.
N_FRAMES = 300
TRAIN_FRACTION = 0.8

frames = ase.io.read("data/elpasolites_10590.xyz", f":{N_FRAMES}")
energies = torch.tensor(np.loadtxt("data/elpasolites_10590_evpa.dat")[:N_FRAMES])

# every structure needs exactly one reference energy
assert len(frames) == len(energies), (
    f"got {len(frames)} frames but {len(energies)} energies"
)

n_train = int(TRAIN_FRACTION * len(frames))

# the split is positional: first `n_train` entries for training, the rest for testing
train_frames, test_frames = frames[:n_train], frames[n_train:]
train_energies, test_energies = energies[:n_train], energies[n_train:]

print(f"using {n_train} training frames")

In [None]:
# Atomic numbers found anywhere in the dataset, converted to plain Python ints.
global_species = [
    int(number)
    for number in set().union(*(frame.numbers for frame in frames))
]

# Spherical expansion hyper-parameters (small basis).
HYPERS = dict(
    interaction_cutoff=5.0,
    max_angular=4,
    max_radial=4,
    gaussian_sigma_constant=0.3,
    gaussian_sigma_type="Constant",
    cutoff_smooth_width=0.5,
    radial_basis="GTO",
    compute_gradients=False,
    expansion_by_species_method="user defined",
    global_species=global_species,
)

# Same settings as HYPERS with a larger angular/radial basis.
HYPERS_MJW = {
    **HYPERS,
    "max_angular": 9,
    "max_radial": 12,
    # ?? central atom weight ≠ 1 ??
}

In [None]:
# train_spherical_expansions, train_slices = compute_spherical_expansion_librascal(train_frames, HYPERS)
# test_spherical_expansions, test_slices = compute_spherical_expansion_librascal(test_frames, HYPERS)

# train_species = torch.hstack([torch.tensor(frame.numbers) for frame in train_frames])
# test_species = torch.hstack([torch.tensor(frame.numbers) for frame in test_frames])

# Validate utilities

In [None]:
# import utils.gap

# def structure_sum(kernel):
#     return utils.gap.common.SumStructureKernel.apply(kernel, test_slices, train_slices)

# rand_kernel = torch.rand((len(test_slices), len(train_slices)), requires_grad=True)
# torch.autograd.gradcheck(structure_sum, rand_kernel, fast_mode=True)

# Baseline: GAP model without species combination

In [None]:
def evaluate_and_plot_model(model, name, file):
    """Plot predicted against reference energies for the training and test sets.

    The data is taken from the notebook-level variables
    ``train_spherical_expansions``, ``train_species``, ``train_slices``,
    ``train_energies`` and their ``test_*`` counterparts; they must be defined
    before this function is called.

    Parameters
    ----------
    model
        Callable evaluated as ``model(spherical_expansions, species, slices)``
        and returning the predicted energies.
    name
        Text used as the figure super-title.
    file
        Destination handed to ``fig.savefig``.
    """
    mse = torch.nn.MSELoss()

    panels = [
        (
            'Training set',
            train_spherical_expansions,
            train_species,
            train_slices,
            train_energies,
        ),
        (
            'Test set',
            test_spherical_expansions,
            test_species,
            test_slices,
            test_energies,
        ),
    ]

    fig, ax = plt.subplots(1, 2, figsize=(12, 5))

    for axis, (label, expansions, species, slices, reference) in zip(ax, panels):
        predicted = model(expansions, species, slices)

        # %RMSE: the *root* of the mean squared error, expressed as a
        # percentage of the standard deviation of the reference energies.
        percent_rmse = (
            100 * torch.sqrt(mse(predicted.squeeze(), reference)) / reference.std()
        )

        axis.scatter(reference, predicted.detach().numpy())

        # y = x line: where a perfect prediction would fall
        diagonal = np.linspace(reference.min(), reference.max(), 20)
        axis.plot(diagonal, diagonal, color='r')

        axis.set_title(f'{label} — loss = {percent_rmse:.3} %RMSE')
        axis.set_xlabel('DFT')
        axis.set_ylabel('Predicted')

    fig.suptitle(name)
    fig.savefig(file, bbox_inches="tight")

# Full GAP

In [None]:
class FullGap(torch.nn.Module):
    """GAP model trained with ``train_full_gap_model`` on power spectrum features.

    ``fit`` has to be called before the module is evaluated: until then
    ``self.model`` is ``None``.
    """

    def __init__(self, zeta, lambdas):
        super().__init__()
        self.power_spectrum = PowerSpectrum()

        # hyper-parameters forwarded unchanged to `train_full_gap_model`
        self.zeta, self.lambdas = zeta, lambdas

        # trained model, created by `fit`
        self.model = None

    def fit(self, spherical_expansion, all_species, structures_slices, energies):
        """Train the underlying model on the given structures and store it."""
        features = self.power_spectrum(spherical_expansion)
        hyper_parameters = {"zeta": self.zeta, "lambdas": self.lambdas}

        self.model = train_full_gap_model(
            features, all_species, structures_slices, energies, **hyper_parameters
        )

    def forward(self, spherical_expansion, all_species, structures_slices):
        """Evaluate the trained model on the power spectrum of the input."""
        features = self.power_spectrum(spherical_expansion)
        return self.model(features, all_species, structures_slices)

In [None]:
# full_gap = FullGap(zeta=2, lambdas=[1e-6, 1e-6])
# full_gap.fit(train_spherical_expansions, [], train_slices, train_energies)

# evaluate_and_plot_model(
#     full_gap, 
#     f"Full GAP model",
#     "full-gap-model.pdf",
# )

# Sparse GAP

In [None]:
class SparseGap(torch.nn.Module):
    """GAP model trained with ``train_sparse_gap_model`` on power spectrum features.

    ``fit`` has to be called before the module is evaluated: until then
    ``self.model`` is ``None``.
    """

    def __init__(self, n_support, zeta, lambdas):
        super().__init__()
        self.power_spectrum = PowerSpectrum()

        # settings forwarded unchanged to `train_sparse_gap_model`
        self.n_support = n_support
        self.zeta, self.lambdas = zeta, lambdas

        # trained model, created by `fit`
        self.model = None

    def fit(self, spherical_expansion, all_species, structures_slices, energies):
        """Train the underlying sparse model and store it."""
        features = self.power_spectrum(spherical_expansion)
        hyper_parameters = {"zeta": self.zeta, "lambdas": self.lambdas}

        self.model = train_sparse_gap_model(
            features,
            all_species,
            structures_slices,
            energies,
            self.n_support,
            **hyper_parameters,
        )

    def forward(self, spherical_expansion, all_species, structures_slices):
        """Evaluate the trained model on the power spectrum of the input."""
        features = self.power_spectrum(spherical_expansion)
        return self.model(features, all_species, structures_slices)

In [None]:
# n_support = 100

# sparse_gap = SparseGap(n_support=n_support, zeta=2, lambdas=[1e-6, 1e-6])
# sparse_gap.fit(train_spherical_expansions, [], train_slices, train_energies)

# evaluate_and_plot_model(
#     sparse_gap, 
#     f"Sparse GAP model — {n_support} GAP support point",
#     "sparse-gap-model.pdf",
# )

### Sparse GAP, one model per central atom species

In [None]:
class PerSpeciesSparseGap(torch.nn.Module):
    """GAP model trained with ``train_per_species_sparse_gap_model``.

    ``fit`` has to be called before the module is evaluated: until then
    ``self.model`` is ``None``.
    """

    def __init__(self, n_support, zeta, lambdas):
        super().__init__()
        self.power_spectrum = PowerSpectrum()

        # settings forwarded unchanged to `train_per_species_sparse_gap_model`
        self.n_support = n_support
        self.zeta, self.lambdas = zeta, lambdas

        # trained model, created by `fit`
        self.model = None

    def fit(self, spherical_expansion, all_species, structures_slices, energies):
        """Train the underlying per-species sparse model and store it."""
        features = self.power_spectrum(spherical_expansion)
        hyper_parameters = {"zeta": self.zeta, "lambdas": self.lambdas}

        self.model = train_per_species_sparse_gap_model(
            features,
            all_species,
            structures_slices,
            energies,
            self.n_support,
            **hyper_parameters,
        )

    def forward(self, spherical_expansion, all_species, structures_slices):
        """Evaluate the trained model on the power spectrum of the input."""
        features = self.power_spectrum(spherical_expansion)
        return self.model(features, all_species, structures_slices)

In [None]:
# n_support = {
#     species: 5 for species in global_species
# }

# per_species_sparse_model = PerSpeciesSparseGap(n_support, zeta=2, lambdas=[1e-6, 1e-6])
# per_species_sparse_model.fit(train_spherical_expansions, train_species, train_slices, train_energies)

# evaluate_and_plot_model(
#     per_species_sparse_model, 
#     f"Basic sparse model — {sum(n_support.values())} GAP support point",
#     "basic-sparse-model.pdf",
# )

In [None]:
# # plot computational graph on a smaller dataset
# small_hypers = copy.deepcopy(HYPERS)
# small_hypers["max_angular"] = 1
# small_hypers["max_radial"] = 1
# small_hypers["global_species"] = [6, 1]

# small_train_frames = methane_frames[2:]
# small_test_frames = methane_frames[:2]

# small_train_energies = torch.tensor([f.info["energy"] for f in small_train_frames])

# small_train_spherical_expansions, small_train_slices = compute_spherical_expansion_librascal(
#     small_train_frames, small_hypers
# )
# small_train_species = torch.hstack([torch.tensor(frame.numbers) for frame in small_train_frames])

# small_test_spherical_expansions, small_test_slices = compute_spherical_expansion_librascal(
#     small_test_frames, small_hypers
# )
# small_test_species = torch.hstack([torch.tensor(frame.numbers) for frame in small_test_frames])

# small_n_support = {1: 10, 6: 10}

# small_model = BasicSparseGap(small_n_support, zeta=2, lambdas=[1e-6, 1e-6])
# small_model.fit(
#     small_train_spherical_expansions, 
#     small_train_species, 
#     small_train_slices, 
#     small_train_energies
# )

# torchviz_params = {}
# for l, sph in small_test_spherical_expansions.items():
#     sph.requires_grad_(True)
#     torchviz_params[f"sph l={l}"] = sph

# for s, w in small_model.model.weights.items():
#     w.requires_grad_(True)
#     torchviz_params[f"weight species={s}"] = w

# result = small_model(small_test_spherical_expansions, small_test_species, small_test_slices)

# torchviz.make_dot(result, params=torchviz_params)

# Combining species

In [None]:
class MixedSpeciesFullGap(torch.nn.Module):
    """Full GAP model applied after combining species with ``AlchemicalCombine``.

    The spherical expansion first goes through ``AlchemicalCombine``, then the
    power spectrum of the combined expansion is used as input to the model
    produced by ``train_full_gap_model``. ``fit`` has to be called before the
    module is evaluated: until then ``self.model`` is ``None``.
    """

    def __init__(self, species, n_pseudo_species, zeta, lambdas, optimizable_weights):
        super().__init__()
        self.power_spectrum = PowerSpectrum()
        self.alchemical = AlchemicalCombine(species, n_pseudo_species)

        # settings forwarded unchanged to `train_full_gap_model`
        self.zeta, self.lambdas = zeta, lambdas
        self.optimizable_weights = optimizable_weights

        # trained model, created by `fit`
        self.model = None

    def _features(self, spherical_expansion):
        # species combination first, power spectrum second
        return self.power_spectrum(self.alchemical(spherical_expansion))

    def fit(self, spherical_expansion, all_species, structures_slices, energies):
        """Train the underlying model on the combined-species features and store it."""
        hyper_parameters = {
            "zeta": self.zeta,
            "lambdas": self.lambdas,
            "optimizable_weights": self.optimizable_weights,
        }

        self.model = train_full_gap_model(
            self._features(spherical_expansion),
            all_species,
            structures_slices,
            energies,
            **hyper_parameters,
        )

    def forward(self, spherical_expansion, all_species, structures_slices):
        """Evaluate the trained model on the combined-species features."""
        return self.model(
            self._features(spherical_expansion), all_species, structures_slices
        )


class MixedSpeciesSparseGap(torch.nn.Module):
    """Sparse GAP model applied after combining species with ``AlchemicalCombine``.

    The spherical expansion first goes through ``AlchemicalCombine``, then the
    power spectrum of the combined expansion is used as input to the model
    produced by ``train_sparse_gap_model``. ``fit`` has to be called before the
    module is evaluated: until then ``self.model`` is ``None``.
    """

    def __init__(self, species, n_pseudo_species, n_support, zeta, lambdas, optimizable_weights):
        super().__init__()
        self.power_spectrum = PowerSpectrum()
        self.alchemical = AlchemicalCombine(species, n_pseudo_species)

        # settings forwarded unchanged to `train_sparse_gap_model`
        self.n_support = n_support
        self.zeta, self.lambdas = zeta, lambdas
        self.optimizable_weights = optimizable_weights

        # trained model, created by `fit`
        self.model = None

    def _features(self, spherical_expansion):
        # species combination first, power spectrum second
        return self.power_spectrum(self.alchemical(spherical_expansion))

    def fit(self, spherical_expansion, all_species, structures_slices, energies):
        """Train the underlying sparse model on the combined-species features."""
        hyper_parameters = {
            "zeta": self.zeta,
            "lambdas": self.lambdas,
            "optimizable_weights": self.optimizable_weights,
        }

        self.model = train_sparse_gap_model(
            self._features(spherical_expansion),
            all_species,
            structures_slices,
            energies,
            self.n_support,
            **hyper_parameters,
        )

    def forward(self, spherical_expansion, all_species, structures_slices):
        """Evaluate the trained model on the combined-species features."""
        return self.model(
            self._features(spherical_expansion), all_species, structures_slices
        )

In [None]:
class AtomisticDataset(torch.utils.data.Dataset):
    """Per-structure spherical expansions, atomic numbers and energies.

    The spherical expansion of every frame is computed once, at construction,
    by calling ``compute_spherical_expansion_librascal`` on that single frame.
    Items are ``(spherical_expansion, species, energy)`` tuples.
    """

    def __init__(self, frames, hypers, energies):
        # one call per frame; the slices returned alongside are not needed here
        per_frame_results = (
            compute_spherical_expansion_librascal([frame], hypers)
            for frame in frames
        )
        self.spherical_expansions = [
            expansion for expansion, _ in per_frame_results
        ]

        self.energies = energies
        self.species = [torch.tensor(frame.numbers) for frame in frames]

    def __len__(self):
        return len(self.spherical_expansions)

    def __getitem__(self, idx):
        return (
            self.spherical_expansions[idx],
            self.species[idx],
            self.energies[idx],
        )

In [None]:
def collate_data_cpu(data):
    """Merge a list of dataset items into a single batch.

    Each item of ``data`` is indexed as ``(spherical_expansion, species,
    energy)``, where ``spherical_expansion`` is a dict of tensors. The keys of
    the first item define the keys of the batch.

    Returns ``(spherical_expansion, species, slices, energies)``: the
    expansions stacked vertically key by key, the species concatenated, one
    ``slice`` per item giving its position in the concatenated species, and
    the energies stacked vertically.
    """
    keys = data[0][0].keys()

    spherical_expansion = {}
    for lambda_ in keys:
        blocks = [sample[0][lambda_] for sample in data]
        spherical_expansion[lambda_] = torch.vstack(blocks)

    species = torch.hstack([sample[1] for sample in data])
    energies = torch.vstack([sample[2] for sample in data])

    # consecutive, non-overlapping ranges: one per item, as long as its species
    slices = []
    offset = 0
    for sample in data:
        n_entries = sample[1].shape[0]
        slices.append(slice(offset, offset + n_entries))
        offset += n_entries

    return spherical_expansion, species, slices, energies

def collate_data_gpu(data, device='cuda'):
    """Collate a batch with ``collate_data_cpu`` and move its tensors to ``device``.

    Parameters
    ----------
    data
        List of dataset items, passed unchanged to ``collate_data_cpu``.
    device
        Target of the ``.to(device=...)`` calls. Defaults to ``'cuda'``, so
        using this function directly as a ``collate_fn`` behaves as before.

    Returns
    -------
    ``(spherical_expansion, species, slices, energies)`` with every tensor on
    ``device``. ``slices`` is a list of Python ``slice`` objects and is
    returned as-is.
    """
    spherical_expansion, species, slices, energies = collate_data_cpu(data)

    spherical_expansion = {
        lambda_: se.to(device=device) for lambda_, se in spherical_expansion.items()
    }

    return (
        spherical_expansion,
        species.to(device=device),
        slices,
        energies.to(device=device),
    )

# Optimization loop using GPU

In [None]:
# One dataset per split, both built with the HYPERS_MJW settings.
train_dataset, test_dataset = (
    AtomisticDataset(split_frames, HYPERS_MJW, split_energies)
    for split_frames, split_energies in (
        (train_frames, train_energies),
        (test_frames, test_energies),
    )
)

In [None]:
# number of pseudo species handed to AlchemicalCombine
N_PSEUDO_SPECIES = 4

# Full kernel, optimize everything with gradients
mixed_species_model = MixedSpeciesFullGap(
    species=global_species,
    n_pseudo_species=N_PSEUDO_SPECIES,
    zeta=1,
    lambdas=[1e-3],
    optimizable_weights=True,
)

# # Full kernel, optimize species with gradients, weights with linear algebra
# mixed_species_model = MixedSpeciesFullGap(
#     global_species, 
#     n_pseudo_species=N_PSEUDO_SPECIES, 
#     zeta=1, 
#     lambdas=[1e-1, 1e-6],
#     optimizable_weights=False,
# )

# # Sparse kernel, optimize species with gradients, weights with linear algebra
# mixed_species_model = MixedSpeciesSparseGap(
#     global_species, 
#     n_support=100,
#     n_pseudo_species=N_PSEUDO_SPECIES, 
#     zeta=1, 
#     lambdas=[1e-1, 1e-6],
#     optimizable_weights=False,
# )

# # Sparse kernel, optimize everything with gradients
# # TODO: this fails since sparse points are not re-selected at each step
# mixed_species_model = MixedSpeciesSparseGap(
#     global_species, 
#     n_support=100,
#     n_pseudo_species=N_PSEUDO_SPECIES, 
#     zeta=1, 
#     lambdas=[1e-3, 1e-6],
#     optimizable_weights=True,
# )

In [None]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available instead
# of failing on the first `.to(device="cuda")` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

mixed_species_model.to(device=device)

# batches are assembled on the CPU and, when training on GPU, moved there
collate_fn = collate_data_gpu if device == "cuda" else collate_data_cpu

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    collate_fn=collate_fn,
)

# one structure per batch, in dataset order, so predictions line up with
# `test_dataset.energies`
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)


In [None]:
 def loss_optimizer(predicted, actual, regularizer, weights):
    """Loss minimized by the optimizer: squared error plus a weight penalty.

    The first term is the squared norm of ``predicted - actual``. The second
    is the squared norm of ``weights`` scaled by ``regularizer`` divided by
    the standard deviation of the notebook-level ``train_energies``.
    """
    squared_error = torch.linalg.norm(predicted - actual) ** 2

    # regularize the loss, full dataset std
    penalty = regularizer / torch.std(train_energies) * torch.linalg.norm(weights) ** 2

    # TODO alternative: batch std
    # penalty = regularizer / torch.std(actual) * torch.linalg.norm(weights) ** 2

    return squared_error + penalty

def loss_mae(predicted, actual):
    """Mean absolute error between ``predicted`` and ``actual``.

    When the two tensors hold the same number of elements but have different
    shapes (for example a ``(n, 1)`` column of predictions against ``(n,)``
    references), ``actual`` is reshaped to match ``predicted`` first.
    Subtracting them directly would broadcast to ``(n, n)`` and average over
    every prediction/reference pair. Other shape combinations are left to the
    usual broadcasting rules.
    """
    if predicted.shape != actual.shape and predicted.numel() == actual.numel():
        actual = actual.reshape(predicted.shape)

    return torch.mean(torch.abs(predicted - actual))

In [None]:
# loss_fn = torch.nn.MSELoss()

N_EPOCHS = 500
REPORT_EVERY = 10  # evaluate on the test set every this many epochs

# the weight penalty only makes sense when the weights are optimized by
# gradient descent; otherwise `fit` recomputes them at every step
regularizer = 1e-3 if mixed_species_model.optimizable_weights else 0

all_losses = []

# initialize the weights to an OK value using fit
if mixed_species_model.optimizable_weights:
    for spherical_expansions, species, slices, batch_energies in train_dataloader:
        mixed_species_model.fit(spherical_expansions, species, slices, batch_energies)
        break # only use the first batch

# The optimizer is created only now: `fit` replaces `mixed_species_model.model`,
# so an optimizer built before the first `fit` would only know about the
# parameters that existed at that point and would never update the ones
# carried by the fitted model.
# NOTE(review): this presumes the model returned by the `train_*_gap_model`
# helpers registers its weights as parameters when `optimizable_weights=True`
# — confirm in utils.gap.
optimizer = torch.optim.AdamW(mixed_species_model.parameters(), lr=0.1)
# optimizer = torch.optim.LBFGS(mixed_species_model.parameters(), lr=1)

for epoch in range(N_EPOCHS):
    epoch_start = time.time()
    for spherical_expansions, species, slices, batch_energies in train_dataloader:
        # closure form, so that the same loop also works with optimizers
        # which re-evaluate the loss (see the LBFGS alternative above)
        def single_step():
            optimizer.zero_grad()

            if not mixed_species_model.optimizable_weights:
                # weights come from linear algebra: refit on the current batch
                mixed_species_model.fit(spherical_expansions, species, slices, batch_energies)

            predicted = mixed_species_model(spherical_expansions, species, slices)

            loss = loss_optimizer(
                predicted, batch_energies, regularizer, mixed_species_model.model.weights
            )
            loss.backward()

            return loss

        loss = optimizer.step(single_step)
        all_losses.append(loss.item())

    epoch_time = time.time() - epoch_start
    if epoch % REPORT_EVERY == 0:
        # evaluation only: no need to record the autograd graph
        with torch.no_grad():
            predicted = torch.vstack([
                mixed_species_model(spherical_expansions, species, slices)
                for spherical_expansions, species, slices, _ in test_dataloader
            ])

        # `vstack` gives one row per structure while the reference energies are
        # one-dimensional: squeeze before comparing, otherwise the difference
        # broadcasts to a (n_test, n_test) matrix and the MAE is wrong
        mae = loss_mae(predicted.cpu().squeeze(), test_dataset.energies)

        print(f"epoch {epoch} took {epoch_time:.4}s, optimizer loss={loss.item():.4}, test mae={mae:.4}")

In [None]:
# `all_losses` holds one value per optimizer step; show it on a log scale
fig, ax = plt.subplots()
ax.semilogy(all_losses)
ax.set_xlabel("optimizer step")
ax.set_ylabel("optimizer loss")
ax.set_title("Training loss");

In [None]:
species_combining_matrix = mixed_species_model.alchemical.combining_matrix.detach().cpu().numpy()

# first two columns of the combining matrix plotted against each other, one point per row
fig, ax = plt.subplots()
ax.scatter(species_combining_matrix[:, 0], species_combining_matrix[:, 1])
ax.set_xlabel("combining matrix, column 0")
ax.set_ylabel("combining matrix, column 1")
ax.set_title("Species combining matrix");

In [None]:
# Final predictions on the test set. This is pure inference, so gradient
# tracking is disabled; the reference energies yielded by the loader are not
# needed here (and are not bound to a name, which would overwrite the
# dataset-level `energies`).
with torch.no_grad():
    predicted = torch.vstack([
        mixed_species_model(spherical_expansions, species, slices)
        for spherical_expansions, species, slices, _ in test_dataloader
    ])

In [None]:
# parity plot on the test set, labelled like the panels of `evaluate_and_plot_model`
fig, ax = plt.subplots()
ax.scatter(test_dataset.energies, predicted.cpu().detach())
ax.set_xlabel("DFT")
ax.set_ylabel("Predicted")
ax.set_title("Test set");

In [None]:
# mean absolute error on the test set; predictions are squeezed to line up
# with the reference energies
(test_dataset.energies - predicted.cpu().squeeze()).abs().mean()