## Imports

In [1]:
import dostools
import importlib
import numpy as np
import pickle
import torch
import sys
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm
torch.set_default_dtype(torch.float64) 
sys.modules['dostools.src'] = dostools

In [2]:
import dostools.datasets.data as data
import dostools.utils.utils as utils

# Configuration of the train/test split and of the PCA target.
TRAIN_FRACTION = 0.8  # fraction of structures assigned to the training set
SPLIT_SEED = 0        # seed for the NumPy global RNG used to shuffle the indices
N_PC = 10             # number of principal components requested from build_pc

with torch.no_grad():
    structures = data.load_structures(":")
    n_structures = len(structures) #total number of structures
    for structure in structures:#implement periodicity
        structure.wrap(eps = 1e-12)
    #stores number of atoms in each structure
    n_atoms = np.array([len(structure) for structure in structures], dtype = int)

    # The split is derived from the number of structures that were actually
    # loaded, so the indices can never disagree with the data (previously the
    # count was hard-coded and then overwritten by len(structures)).
    np.random.seed(SPLIT_SEED)
    n_train = int(TRAIN_FRACTION * n_structures)
    shuffled_index = np.arange(n_structures)
    np.random.shuffle(shuffled_index)
    test_index = shuffled_index[n_train:]
    train_index = shuffled_index[:n_train]

    xdos = torch.tensor(data.load_xdos())
    ldos = torch.tensor(data.load_ldos())
    ldos *= 2  # NOTE(review): presumably a spin-degeneracy factor - confirm

    print ("ldos shape is {}".format(ldos.shape))
    mean_dos_per_atom = ldos[train_index].mean(axis = 0) #only calculated for train set to prevent data leakage
    print ("mean dos shape is {}".format(mean_dos_per_atom.shape))

    # Regression targets: mean-subtracted DOS (pointwise), its cumulative sum
    # along the energy axis, and its coefficients on the principal components
    # built from the training set.
    y_pw = ldos - mean_dos_per_atom
    y_lcdf = torch.cumsum(y_pw, dim = 1)
    _, pc_vectors = utils.build_pc(ldos[train_index], mean_dos_per_atom[None,:], n_pc = N_PC)
    y_pc = utils.build_coeffs(ldos - mean_dos_per_atom[None,:], pc_vectors)
    Silicon = data.load_features()
    kMM = data.load_kMM()

ldos shape is torch.Size([1039, 778])
mean dos shape is torch.Size([778])
Variance covered with 10 PCs is = 0.9871211778950163


## Evaluator

In [3]:
import dostools.evaluation.evaluation as evaluation
importlib.reload(evaluation)
import dostools.models.training as training
importlib.reload(training)

# Targets keyed by the name of their representation. Note that 'pw' is the raw
# ldos, while 'lcdf' and 'pc' are built from the mean-subtracted DOS.
targets = dict(pw = ldos, lcdf = y_lcdf, pc = y_pc)
evaluator = evaluation.Evaluator(targets, xdos, mean_dos_per_atom)

## Dataset and DataLoader

In [11]:
# NOTE(review): this rebinds the name `data`, which the data-loading cell bound
# to dostools.datasets.data; re-running that cell after this one will resolve
# `data` to this module instead.
import dostools.datasets.dataset as data
from torch.utils.data import TensorDataset, DataLoader
import dostools.consistency.consistency as consistency

device = 'cpu'
# Pinned memory is only requested when the device string is exactly "cuda:0".
kwargs = {}
if device == "cuda:0":
    kwargs = {"pin_memory":True}

# Targets used for training. This is the same tensor object as y_pw (no copy),
# so the optional in-place shift below would also modify y_pw if enabled.
y_shifted = y_pw
#y_shifted[:100] = consistency.shifted_ldos(y_shifted[:100], xdos, torch.zeros(100)-10)


def _structure_dataset(feature_key, index):
    """Pair the rows `index` of Silicon.Features[feature_key] with the matching targets, both cast to double."""
    return TensorDataset(Silicon.Features[feature_key][index].double(), y_shifted[index].double())


#Dataset
train_data_soap = _structure_dataset("structure_avedescriptors", train_index)
train_data_kernel = _structure_dataset("structure_avekerneldescriptors", train_index)

test_data_soap = _structure_dataset("structure_avedescriptors", test_index)
test_data_kernel = _structure_dataset("structure_avekerneldescriptors", test_index)

#Dataloader: one batch holds the whole training set (batch_size = n_train)
train_dataloader_soap = DataLoader(train_data_soap, batch_size = n_train, shuffle = True, **kwargs)
train_dataloader_kernel = DataLoader(train_data_kernel, batch_size = n_train, shuffle = True, **kwargs)



## Loss Function

In [6]:
import dostools.src.consistency.consistency as consistency
import dostools.src.loss.loss as loss
importlib.reload(loss)
importlib.reload(consistency)

def t_get_BF_shift_index_mse(prediction, true, shift_range, xdos = None, perc = False):
    """
    Brute-force search, per sample, for the candidate shift that minimises the error

    For every row of `prediction`, the row is replicated once per entry of
    `shift_range`, shifted with `consistency.shifted_ldos`, and scored against
    the matching row of `true` with `loss.t_get_each_mse`. The smallest score
    and its position are kept.

    Args:
        prediction (tensor): predicted curves, iterated along the first axis
        true (tensor): reference curves, indexed along the first axis
        shift_range (tensor): 1D tensor of candidate shifts
        xdos (tensor, optional): energy grid forwarded to shifted_ldos, required
        perc (bool, optional): accepted but not used by this function

    Returns:
        tuple: (mean over samples of the per-sample minimum error,
                tensor holding, for each sample, the position of the minimum
                as returned by torch.min along dimension 0, stored as floats)

    Raises:
        ValueError: if xdos is None
    """
    if xdos is None:
        raise ValueError("xdos not defined")

    n_samples = true.shape[0]
    n_shifts = shift_range.shape[0]
    best_error = torch.zeros(n_samples)
    best_position = torch.zeros(n_samples)

    for row, single_pred in enumerate(prediction):
        # One copy of the prediction per candidate shift
        candidates = consistency.shifted_ldos(single_pred.repeat(n_shifts, 1), xdos, shift_range)
        reference = true[row].repeat(n_shifts, 1)
        per_shift_error = loss.t_get_each_mse(candidates, reference)
        best_error[row], best_position[row] = torch.min(per_shift_error, 0)

    return torch.mean(best_error, 0), best_position


## Evaluation Function

In [7]:
def t_get_opt_BF_shift_rmse(prediction, true, opt_shift, xdos = None, perc = False):
    """
    RMSE between the predictions, shifted by the given per-sample shifts, and the targets

    The predictions are shifted with `consistency.shifted_ldos`, the squared
    error is integrated over `xdos` with the trapezoidal rule along dimension 1,
    and the square roots are averaged over the samples.

    Args:
        prediction (tensor): predicted curves, one per row
        true (tensor): reference curves, same shape as prediction
        opt_shift (tensor): shifts forwarded to shifted_ldos
        xdos (tensor, optional): energy grid used for shifting and integration, required
        perc (bool, optional): if True, return the error as a percentage of the
            spread of `true` around its mean over samples; if False, return
            the plain RMSE

    Returns:
        tensor: scalar RMSE, or 100 * RMSE / std when perc is True

    Raises:
        ValueError: if xdos is None
    """
    if xdos is None:
        raise ValueError("xdos not defined")

    shifted_preds = consistency.shifted_ldos(prediction, xdos, opt_shift)
    rmse = torch.sqrt(torch.trapezoid((shifted_preds - true)**2, xdos, dim = 1)).mean()
    if not perc:
        # Previously this path returned a variable that was never assigned.
        return rmse

    # Normalise by the integrated deviation of the targets from their mean curve
    mean = true.mean(dim = 0)
    std = torch.sqrt(torch.trapezoid((true - mean)**2, xdos, dim = 1)).mean()
    return 100 * rmse / std

## Alignment Model

In [16]:
import dostools.src.models.models as models
import dostools.src.models.training as training
import dostools.src.models.architectures as architecture
import dostools.src.loss.loss as loss
import torch.nn as nn

importlib.reload(models)
importlib.reload(training)
importlib.reload(architecture)
importlib.reload(loss)


class AlignmentLinearModel(nn.Module):
    """
    Bias-free linear model trained together with a per-sample alignment shift

    With the "LBFGS" optimizer, every loss evaluation searches the module-level
    `shift_range` for the shift that best aligns each prediction with its
    target (via `t_get_BF_shift_index_mse`) and stores the result in
    `self.alignment`. With the "Adam" optimizer, the model is trained with the
    loss function passed to `fit` and `self.alignment` is left untouched.
    """

    def __init__(self, inputSize, outputSize, train_size, xdos, reg, opt, device):
        """
        Args:
            inputSize (int): number of input features
            outputSize (int): number of outputs of the linear layer
            train_size (int): length of the initial (all-zero) alignment vector
            xdos (tensor): energy grid, forwarded to the loss functions
            reg (float): regularisation strength (L2 penalty for LBFGS, weight_decay for Adam)
            opt (str): optimizer name, "Adam" or "LBFGS"
            device (str): device on which the model, reg and alignment are placed
        """
        super(AlignmentLinearModel, self).__init__()
        self.linear = nn.Linear(inputSize, outputSize, bias = False)
        self.xdos = xdos
        self.opt = opt
        self.device = device
        self.reg = torch.tensor(reg, requires_grad = False).to(self.device)
        self.alignment = torch.zeros(train_size, device = self.device)
        self.to(self.device)

    def forward(self, x):
        """
        Performs the transformations to the features based on the model
        
        Args:
            x (tensor): input features
        
        Returns:
            tensor: output
        """
        out = self.linear(x)
        return out

    def _build_optimizer(self, lr):
        """Create a fresh optimizer of the kind named by self.opt; raises ValueError for an unknown name."""
        if self.opt == "Adam":
            return torch.optim.Adam(self.parameters(), lr = lr, weight_decay = self.reg.item())
        if self.opt == "LBFGS":
            return torch.optim.LBFGS(self.parameters(), lr = lr)
        raise ValueError("Unsupported optimizer '{}', expected 'Adam' or 'LBFGS'".format(self.opt))

    @staticmethod
    def _build_scheduler(opt, patience):
        """Create the plateau scheduler (factor 0.1) bound to the given optimizer."""
        return torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor = 0.1, patience = patience)#0.5)

    def fit(self, traindata_loader, valdata_loader, loss, lr ,n_epochs):
        """
        Fits the model based on the training data, early stopping is based on performance on training data (or validation data)
        Returns the loss history 
        
        Args:
            traindata_loader (DataLoader): Train dataloader
            valdata_loader (DataLoader): Validation dataloader, or None to monitor the training loss
            loss (function): Loss function. Used for Adam training and for validation;
                the LBFGS training loss is computed with t_get_BF_shift_index_mse instead
            lr (float): Learning rate
            n_epochs (int): Max number of epochs
        
        Returns:
            list: Loss history of the training process (lowest loss so far, sampled
                every 10 epochs for LBFGS and every 1000 epochs for Adam)

        Raises:
            ValueError: if self.opt is neither "Adam" nor "LBFGS"
        """
        # Fails early with a clear error for an unknown optimizer name.
        opt = self._build_optimizer(lr)
        # threshold: epochs without improvement tolerated before early stopping
        # scheduler_threshold: patience of the learning-rate scheduler
        # tol: minimum decrease of the monitored loss that counts as improvement
        if self.opt == "Adam":
            if valdata_loader is not None:
                threshold = 1000
                scheduler_threshold = 100
            else:
                threshold = 1000
                scheduler_threshold = 1000
            tol = 1e-4
        else:
            if valdata_loader is not None:
                threshold = 2000
                scheduler_threshold = 2000
            else:
                threshold = 30
                scheduler_threshold = 5
            tol = 1e-2
        scheduler = self._build_scheduler(opt, scheduler_threshold)
        best_state = copy.deepcopy(self.state_dict())
        lowest_loss = torch.tensor(9999)
        pred_loss = torch.tensor(0)
        trigger = 0
        loss_history =[]
        pbar = tqdm(range(n_epochs))
        
        for epoch in pbar:
            pbar.set_description(f"Epoch: {epoch}")
            if valdata_loader is not None:
                pbar.set_postfix(val_loss = lowest_loss.item(), trigger = trigger, train_loss = pred_loss.item())
            else:
                pbar.set_postfix(pred_loss = pred_loss.item(), lowest_loss = lowest_loss.item(), trigger = trigger)

            for x_data, y_data in traindata_loader:
                opt.zero_grad()
                x_data, y_data = x_data.to(self.device), y_data.to(self.device)
                if self.opt == "LBFGS":
                    def closure(predictions = False):
                        """
                        Function is necessary for LBFGS, returns the total loss of the model
                        
                        Args:
                            predictions (bool, optional): Returns prediction loss if true, returns total loss if False
                        
                        Returns:
                            tensor: Loss
                        """
                        opt.zero_grad()
                        _pred = self.forward(x_data)
                        # NOTE(review): shift_range is read from module scope, it is not an argument of fit
                        _pred_loss, self.alignment = t_get_BF_shift_index_mse(_pred, y_data, shift_range, self.xdos)#, self.xdos), perc = True)       
                        _pred_loss *= 1e7
                        # Convert the index of the best shift into a shift in xdos units.
                        # NOTE(review): the offset 20 assumes shift_range starts at -20 grid steps
                        self.alignment = (self.alignment - 20) * (self.xdos[1] - self.xdos[0])
                        _pred_loss = torch.nan_to_num(_pred_loss, nan=lowest_loss.item(), posinf = lowest_loss.item(), neginf = lowest_loss.item())                 
                        _reg_loss = torch.sum(torch.pow(self.linear.weight,2))
                        _reg_loss *= self.reg.item()
                        _new_loss = _pred_loss + _reg_loss
                        _new_loss.backward()
                        if predictions:
                            return _pred_loss
                        return _new_loss
                    opt.step(closure)
                    with torch.no_grad():
                        # Monitored loss: percentage RMSE at the alignment found by the last closure call
                        pred = self.forward(x_data)
                        pred_loss = t_get_opt_BF_shift_rmse(pred, y_data, self.alignment, self.xdos, perc = True)
                        reg_loss = torch.sum(torch.pow(self.linear.weight,2))
                        reg_loss *= self.reg.item()
                        new_loss = pred_loss + reg_loss
                    if pred_loss >100000 or (pred_loss.isnan().any()) :
                        print ("Optimizer shows weird behaviour, reinitializing at previous best_State")
                        self.load_state_dict(best_state)
                        opt = self._build_optimizer(lr)
                        # Rebind the scheduler, otherwise it keeps adjusting the discarded optimizer
                        scheduler = self._build_scheduler(opt, scheduler_threshold)
                    if epoch %10 == 1:
                        loss_history.append(lowest_loss.item())
                        scheduler.step(new_loss)
                elif self.opt == "Adam":
                    pred = self.forward(x_data)
                    pred_loss = loss(pred, y_data)#, self.xdos, perc = True)
                    # Detached so that the early-stopping reference does not keep the autograd graph alive
                    new_loss = pred_loss.detach()
                    pred_loss.backward()
                    opt.step()
                    if pred_loss >100000 or (pred_loss.isnan().any()) :
                        print ("Optimizer shows weird behaviour, reinitializing at previous best_State")
                        self.load_state_dict(best_state)
                        opt = self._build_optimizer(lr)
                        scheduler = self._build_scheduler(opt, scheduler_threshold)
                    # NOTE(review): the scheduler is never stepped in the Adam branch
                    if epoch %1000 == 1:
                        loss_history.append(lowest_loss.item())

            with torch.no_grad():
                if valdata_loader is not None:
                    # With a validation loader, the monitored loss is the summed validation loss
                    new_loss = torch.zeros(1, requires_grad = False).to(self.device)
                    for x_val, y_val in valdata_loader:
                        x_val, y_val = x_val.to(self.device), y_val.to(self.device)
                        val_pred = self.forward(x_val)
                        new_loss += loss(val_pred, y_val, self.xdos, perc = False)

                if lowest_loss - new_loss > tol: #threshold to stop training
                    best_state = copy.deepcopy(self.state_dict())
                    lowest_loss = new_loss
                    trigger = 0

                else:
                    trigger +=1
                    
                    if trigger > threshold:
                        self.load_state_dict(best_state)
                        print ("Implemented early stopping with lowest_loss: {}".format(lowest_loss))
                        return loss_history
        return loss_history

In [None]:
# Candidate alignment shifts: integer multiples of the energy-grid spacing.
# MAX_SHIFT_STEPS must stay at 20 as long as AlignmentLinearModel.fit converts
# the index of the best shift with a hard-coded offset of 20.
MAX_SHIFT_STEPS = 20
xdos_step = xdos[1] - xdos[0]
shift_range = torch.arange(-MAX_SHIFT_STEPS, MAX_SHIFT_STEPS) * xdos_step

# Layer sizes are read from the training tensors instead of being hard-coded.
soap_features, soap_targets = train_data_soap.tensors
M_soap = AlignmentLinearModel(soap_features.shape[-1], soap_targets.shape[-1], n_train, xdos, 1e-11, "LBFGS", "cpu")
# No validation loader and no loss function: the LBFGS path uses its own alignment loss.
loss_history = M_soap.fit(train_dataloader_soap,None, None, 1, 20000)

Epoch: 201:   1%|█                                                                                                         | 201/20000 [41:19:34<5112:13:37, 929.54s/it, lowest_loss=19, pred_loss=18.9, trigger=1]

In [None]:
# Alignment shifts found during training, divided by the grid spacing so that they read as numbers of grid steps
M_soap.alignment/xdos_step