In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
#import torchviz
import time
import copy

import matplotlib.pyplot as plt

import ase.io

from utils.dataset import AtomisticDataset, create_dataloader
from utils.soap import PowerSpectrum
from utils.alchemical import AlchemicalCombine
from utils.linear import LinearModel

torch.set_default_dtype(torch.float64)

In [3]:
# Report whether PyTorch can see a usable CUDA device in this environment.
torch.cuda.is_available()

False

In [4]:
# Number of structures used for testing / training.
n_test = 100
n_train = 300

# Read every structure of the file (":" selects all frames) and the matching
# target values, one per structure.
frames = ase.io.read("data/elpasolites_10590.xyz", ":")
energies = torch.tensor(np.loadtxt("data/elpasolites_10590_evpa.dat"))

# frames = ase.io.read("../equistore-examples/data/molecule_conformers_dftb.xyz", ":")
# energies = torch.tensor([frame.info["energy"] for frame in frames])

# Sanity checks: targets must line up with the structures, and the training
# slice (taken from the start) must not overlap the test slice (taken from
# the end), otherwise the reported test error would be contaminated.
assert len(frames) == len(energies), (
    f"got {len(frames)} frames but {len(energies)} energies"
)
assert n_train + n_test <= len(frames), (
    f"n_train + n_test = {n_train + n_test} exceeds the {len(frames)} available frames"
)

train_frames = frames[:n_train]
test_frames = frames[-n_test:]

train_energies = energies[:n_train]
test_energies = energies[-n_test:]

print(f"using {n_train} training frames")

using 300 training frames


In [5]:
# Gather every atomic number that occurs anywhere in the dataset, then turn
# them into plain Python integers.
species_seen = set()
for frame in frames:
    for atomic_number in frame.numbers:
        species_seen.add(atomic_number)

global_species = [int(atomic_number) for atomic_number in species_seen]

# HYPERS_FROM_PAPER = {
#     "interaction_cutoff": 5.0,
#     "max_angular": 9,
#     "max_radial": 12,
#     "gaussian_sigma_constant": 0.3,
#     "gaussian_sigma_type": "Constant",
#     "cutoff_smooth_width": 0.5,
#     "radial_basis": "GTO",
#     "compute_gradients": False,
#     "expansion_by_species_method": "user defined",
#     "global_species": global_species,
# }

# Reduced hyper-parameter set handed to `AtomisticDataset` below; the basis
# (max_angular / max_radial) is smaller than in the commented-out set above.
HYPERS_SMALL = dict(
    cutoff=5.0,
    max_angular=3,
    max_radial=4,
    atomic_gaussian_width=0.3,
    cutoff_function={"ShiftedCosine": {"width": 0.5}},
    radial_basis={"Gto": {}},
    gradients=False,
)

# Optimization loop

In [6]:
# Device used by the dataloaders and the model.
# device = "cuda"
device = "cpu"

# Build the train and test datasets with the same hyper-parameters; only the
# frames and their target energies differ between the two.
train_dataset, test_dataset = (
    AtomisticDataset(split_frames, HYPERS_SMALL, split_energies)
    for split_frames, split_energies in (
        (train_frames, train_energies),
        (test_frames, test_energies),
    )
)

In [7]:
# Options shared by the two loaders that must visit the data in a fixed order.
_ordered_options = dict(shuffle=False, device=device)

# Shuffled loader used by the optimization loop.
train_dataloader = create_dataloader(
    train_dataset, batch_size=512, shuffle=True, device=device
)

# Same training data, but delivered as one single batch.
train_dataloader_no_batch = create_dataloader(
    train_dataset, batch_size=len(train_dataset), **_ordered_options
)

# Loader over the held-out test set.
test_dataloader = create_dataloader(
    test_dataset, batch_size=100, **_ordered_options
)

In [8]:
def loss_optimizer(predicted, actual, regularizer, weights):
    """Regularized sum-of-squares loss minimized by the optimizer.

    The data term is the squared 2-norm of ``predicted - actual`` (both
    flattened). The penalty term is the squared norm of ``weights`` scaled by
    ``regularizer`` divided by the standard deviation of the notebook-level
    ``train_energies`` (i.e. of the whole training set, not of the current
    batch).

    :param predicted: model predictions
    :param actual: reference values, same number of elements as ``predicted``
    :param regularizer: regularization strength (``0`` disables the penalty)
    :param weights: model weights entering the penalty term
    :return: scalar tensor containing the loss
    """
    residual = predicted.flatten() - actual.flatten()
    data_term = torch.linalg.norm(residual) ** 2

    # scale the regularizer by the spread of the full set of training energies
    energy_scale = torch.std(train_energies.flatten())
    penalty = regularizer / energy_scale * torch.linalg.norm(weights) ** 2

    return data_term + penalty

def loss_mae(predicted, actual):
    """Mean absolute error between ``predicted`` and ``actual``.

    Both inputs are flattened first, so they only need to contain the same
    number of elements.

    :return: scalar tensor containing the mean absolute error
    """
    errors = predicted.flatten() - actual.flatten()
    return errors.abs().mean()

In [9]:
class MixedSpeciesLinearModel(torch.nn.Module):
    """Linear energy model on a power spectrum built from combined species.

    The input spherical expansion is passed through ``AlchemicalCombine``
    (constructed from ``species`` and ``n_pseudo_species``), then through
    ``PowerSpectrum``, and the result is fed to a normalized ``LinearModel``.
    """

    def __init__(
        self,
        species,
        n_pseudo_species,
        regularizer,
        optimizable_weights,
        random_initial_weights,
    ):
        """Create the three sub-modules.

        :param species: species forwarded to ``AlchemicalCombine``
        :param n_pseudo_species: second argument of ``AlchemicalCombine``
            (presumably the number of pseudo species to combine into)
        :param regularizer: forwarded to ``LinearModel``
        :param optimizable_weights: forwarded to ``LinearModel`` and stored on
            the instance so that calling code can branch on it
        :param random_initial_weights: forwarded to ``LinearModel`` and stored
            on the instance
        """
        super().__init__()

        self.alchemical = AlchemicalCombine(species, n_pseudo_species)
        self.power_spectrum = PowerSpectrum()
        self.model = LinearModel(
            normalize=True,
            regularizer=regularizer,
            optimizable_weights=optimizable_weights,
            random_initial_weights=random_initial_weights,
        )

        # kept as plain attributes: the training code reads these flags
        self.optimizable_weights = optimizable_weights
        self.random_initial_weights = random_initial_weights

    def _featurize(self, spherical_expansion):
        """Run the alchemical combination followed by the power spectrum."""
        mixed = self.alchemical(spherical_expansion)
        return self.power_spectrum(mixed)

    def forward(self, spherical_expansion):
        """Predict energies for ``spherical_expansion`` (forces are not requested)."""
        features = self._featurize(spherical_expansion)
        energies, _ = self.model(features, with_forces=False)
        return energies

    def initialize_parameters(self, spherical_expansion, energies):
        """Initialize the linear model from the given features and ``energies``."""
        features = self._featurize(spherical_expansion)
        self.model.initialize_parameters(features, energies)
        

In [10]:
# Tunable settings of the model.
N_PSEUDO_SPECIES = 4
REGULARIZER = 1e-2

# Keyword options of the model; the regularizer is passed wrapped in a list.
_model_options = dict(
    n_pseudo_species=N_PSEUDO_SPECIES,
    regularizer=[REGULARIZER],
    optimizable_weights=True,
    random_initial_weights=True,
)

mixed_species_model = MixedSpeciesLinearModel(global_species, **_model_options)

In [11]:
mixed_species_model.to(device=device)

# Initialize the model parameters. No gradients are needed for this step.
with torch.no_grad():
    # The un-batched loader is used so that the initial fit sees all training
    # frames at once, ensuring the support points come from the full dataset.
    for spherical_expansions, energies in train_dataloader_no_batch:
        mixed_species_model.initialize_parameters(spherical_expansions, energies)

# The torch loss only carries a regularization term when the weights are
# optimized by torch.
torch_loss_regularizer = REGULARIZER if mixed_species_model.optimizable_weights else 0

if not mixed_species_model.optimizable_weights:
    # We can not use batches if we are training with linear algebra: all
    # training frames need to be available at once.
    assert train_dataloader.batch_size >= len(train_frames)

In [12]:
lr = 0.1
# optimizer = torch.optim.AdamW(
#     mixed_species_model.parameters(),
#     lr=lr, weight_decay=0.0
# )

optimizer = torch.optim.LBFGS(mixed_species_model.parameters(), lr=lr)

# Loss of every optimizer step, appended to by the training loop.
all_losses = []

# The log file name encodes the model class, the number of pseudo species,
# the training set size and which weight options are enabled.
filename = f"{mixed_species_model.__class__.__name__}-{N_PSEUDO_SPECIES}-mixed-{n_train}-train"
for enabled, suffix in (
    (mixed_species_model.optimizable_weights, "-opt-weights"),
    (mixed_species_model.random_initial_weights, "-random-weights"),
):
    if enabled:
        filename += suffix

# Left open on purpose: the training loop writes one line per epoch to it.
output = open(f"{filename}.dat", "w")
output.write("# epoch  train_loss  test_mae\n")

# Running epoch counter, kept across repeated executions of the training cell.
n_epochs_total = 0

In [21]:
# Run 5 more epochs. `n_epochs_total`, `all_losses` and `output` live outside
# of this cell, so executing it again continues the same log.
for epoch in range(5):
    epoch_start = time.time()

    # if UPDATE_SUPPORT_POINTS:
    #     # to update the support points, we need to have all training data at
    #     # once in memory
    #     for spherical_expansions, species, slices, _ in train_dataloader_no_batch:
    #         assert len(slices) == len(train_frames)
    #         # use `select_again=True` to re-select the same number of support
    #         # points. this might make convergence slower, but maybe able to
    #         # reach a lower final loss?
    #         #
    #         # with `select_again=False`, the environments selected in the first
    #         # fit above are used as support points
    #         mixed_species_model.update_support_points(
    #             spherical_expansions, species, slices, select_again=False
    #         )

    for spherical_expansions, energies in train_dataloader:
        def single_step():
            """Closure for the optimizer: recompute the loss and its gradients.

            LBFGS may evaluate this closure several times within one call to
            ``optimizer.step``.
            """
            optimizer.zero_grad()

            if not mixed_species_model.optimizable_weights:
                # weights are not optimized by torch: refit them on this batch
                mixed_species_model.initialize_parameters(spherical_expansions, energies)

            predicted = mixed_species_model(spherical_expansions)

            loss = loss_optimizer(
                predicted,
                energies,
                torch_loss_regularizer,
                mixed_species_model.model.weights
            )
            # NOTE(review): retain_graph=True is kept from the original code;
            # confirm whether the graph really has to outlive this backward pass.
            loss.backward(retain_graph=True)

            return loss

        loss = optimizer.step(single_step)
        all_losses.append(loss.item())

    epoch_time = time.time() - epoch_start
    if epoch % 1 == 0:
        # evaluate on the test set (currently after every epoch)
        with torch.no_grad():
            predicted = []
            reference = []
            for spherical_expansions, energies in test_dataloader:
                reference.append(energies)
                predicted.append(mixed_species_model(spherical_expansions))

            reference = torch.vstack(reference)
            predicted = torch.vstack(predicted)
            # Bring both tensors to the CPU before comparing them: moving only
            # the predictions fails with a device mismatch when the references
            # live on another device (e.g. device = "cuda").
            mae = loss_mae(predicted.cpu(), reference.cpu())

            # `loss` is the value returned by the last optimizer step of this epoch
            output.write(f"{n_epochs_total} {loss} {mae}\n")
            output.flush()

        print(f"epoch {n_epochs_total} took {epoch_time:.4}s, optimizer loss={loss.item():.4}, test mae={mae:.4}")

    n_epochs_total += 1

epoch 0 took 15.44s, optimizer loss=713.1, test mae=0.7252


KeyboardInterrupt: 