In [1]:
import dostools
import importlib
import numpy as np
import pickle
import torch
import sys
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm
import matplotlib
import time
import scipy 
import copy
import ase
import ase.io
torch.set_default_dtype(torch.float64) 
# %matplotlib notebook
# matplotlib.rcParams['figure.figsize'] = (10, 10)

In [2]:
import dostools.datasets.data as data
import dostools.utils.utils as utils

with torch.no_grad():
#     sigma = 0.3
#     structures = data.load_structures(":")
#     n_structures = len(structures) #total number of structures
#     for structure in structures:#implement periodicity
#         structure.wrap(eps = 1e-12) 
#     n_atoms = np.zeros(n_structures, dtype = int) #stores number of atoms in each structures
#     for i in range(n_structures):
#         n_atoms[i] = len(structures[i])
        
    xdos = torch.tensor(data.load_xdos())
    
    total_dos3 = torch.load("./total_ldos3.pt")
    total_dos1 = torch.load("./total_ldos1.pt")
    
    surface_dos3 = torch.load("./surface_ldos3.pt")
    surface_dos1 = torch.load("./surface_ldos1.pt")
    
    surface_aligned_dos3 = torch.load("./surface_aligned_dos3.pt")
    surface_aligned_dos1 = torch.load("./surface_aligned_dos1.pt")
    
    bulk_dos3 = torch.load("./bulk_ldos3.pt")
    bulk_dos1 = torch.load("./bulk_ldos1.pt")
    
    total_aligned_dos3 = torch.load("./total_aligned_dos3.pt")
    total_aligned_dos1 = torch.load("./total_aligned_dos1.pt")
    
    surface_soap = torch.load("./surface_soap.pt")
    bulk_soap = torch.load("./bulk_soap.pt")
    total_soap = torch.load("./total_soap.pt")
    
    surface_kernel_30 = torch.load("./surface_kernel_30.pt")
    surface_kMM_30 = torch.load("./surface_kMM_30.pt")
    
    bulk_kernel_200 = torch.load("./bulk_kernel_200.pt")
    bulk_kMM_200 = torch.load("./bulk_kMM_200.pt")
    
    bulk_kernel_100 = torch.load("./bulk_kernel_100.pt")
    bulk_kMM_100 = torch.load("./bulk_kMM_100.pt")
    
    total_kernel_100 = torch.load("./total_kernel_100.pt")
    total_kMM_100 = torch.load("./total_kMM_100.pt")
    
    total_kernel_150 = torch.load("./total_kernel_150.pt")
    total_kMM_150 = torch.load("./total_kMM_150.pt")
    
    

In [3]:
def generate_train_test_split(n_samples):
    n_structures = n_samples
    np.random.seed(0)
    n_train = int(0.8 * n_structures)
    train_index = np.arange(n_structures)
    np.random.shuffle(train_index)
    test_index = train_index[n_train:]
    train_index = train_index[:n_train]
    
    return train_index, test_index

def generate_biased_train_test_split(n_samples):
    #Assumes 100 amorphous structures at the end
    n_structures = n_samples
    amorph_train = np.arange(n_samples-100, n_samples,1)
    np.random.seed(0)
    np.random.shuffle(amorph_train)
    
    amorph_test = amorph_train[:80]
    amorph_train = amorph_train[80:]

    n_structures = n_samples - 100
    np.random.seed(0)
    n_train = int(0.8 * n_samples)-20
    remaining_train_index = np.arange(n_structures)
    np.random.shuffle(remaining_train_index)

    remaining_test_index = remaining_train_index[n_train:]
    remaining_train_index = remaining_train_index[:n_train]

    biased_train_index = np.concatenate([remaining_train_index, amorph_train])
    biased_test_index = np.concatenate([remaining_test_index, amorph_test])
    
    return biased_train_index, biased_test_index

def generate_surface_holdout_split(n_samples):
    #Assumes that we are using the 110 surfaces for test which are located at 673 + 31st-57th index
    #26 structures
    
    n_test = int(0.2 * n_samples) - 26
    n_train = n_samples - n_test
    
    remaining_indexes = np.concatenate([np.arange(673+31), np.arange(673+57,n_samples,1)])
    indexes_110 = np.arange(673+31, 673+57,1)
    np.random.seed(0)
    
    np.random.shuffle(remaining_indexes)
    
    remaining_test_index = remaining_indexes[n_train:]
    remaining_train_index = remaining_indexes[:n_train]
    
    total_train_index = remaining_train_index
    total_test_index = np.concatenate([remaining_test_index, indexes_110])
    
    return total_train_index, total_test_index
    
def surface_holdout(n_samples):
    test_index = np.arange(31,57,1)
    train_index = np.concatenate([np.arange(31), np.arange(57, n_samples)])
    
    return train_index, test_index

n_surfaces = 154
n_bulkstructures = 773
n_total_structures = 773 + 154


surface_train_index, surface_test_index = generate_train_test_split(n_surfaces)
bulk_train_index, bulk_test_index = generate_train_test_split(n_bulkstructures)
total_train_index, total_test_index = generate_train_test_split(n_total_structures)
surface_holdout_train_index, surface_holdout_test_index = surface_holdout(n_surfaces)
bulk_biased_train_index, bulk_biased_test_index = generate_biased_train_test_split(n_bulkstructures)
total_biased_train_index, total_biased_test_index = generate_biased_train_test_split(n_total_structures)
holdout_train_index, holdout_test_index = generate_surface_holdout_split(n_total_structures)

In [46]:
from scipy.signal import convolve, correlate, correlation_lags
def find_optimal_discrete_shift(prediction, true):
    if true.shape == prediction.shape and len(prediction.shape) == 2:
        shift = []
        for i in range(true.shape[0]):
            corr = correlate(true[i], prediction[i], mode='full')
            shift_i = np.argmax(corr) - len(true[i]) + 1   
            shift.append(shift_i)
        
        
    elif true.shape == prediction.shape and len(prediction.shape) == 1:
        corr = correlate(true, prediction, mode='full')
        shift = np.argmax(corr) - len(true) + 1   
    else:
        print ("input shapes are not the same")
        raise Exception
    return shift


In [67]:
#Generate shifted data
def shifted_ldos_discrete(ldos, xdos, shift): 
    shifted_ldos = torch.zeros_like(ldos)
    if len(ldos.shape) > 1:
        xdos_shift = torch.round(shift).int()
        print (xdos_shift)
        for i in range(len(ldos)):
            if xdos_shift[i] > 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i,:-1*xdos_shift[i]], (xdos_shift[i],0))
            elif xdos_shift[i] < 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i,(-1*xdos_shift[i]):], (0,(-1*xdos_shift[i])))
            else:
                shifted_ldos[i] = ldos[i]
    else:        
        xdos_shift = int(torch.round(shift))
        if xdos_shift > 0:
            shifted_ldos = torch.nn.functional.pad(ldos[:-1*xdos_shift], (xdos_shift,0))
        elif xdos_shift < 0:
            shifted_ldos = torch.nn.functional.pad(ldos[(-1*xdos_shift):], (0,(-1*xdos_shift)))
        else:
            shifted_ldos = ldos
    return shifted_ldos



In [48]:
total_dos1.shape

torch.Size([927, 778])

In [49]:
# alignment = torch.rand(927)
# trueshift = torch.round(alignment/(xdos[1] - xdos[0])).int()
test_ldos0 = shifted_ldos(total_dos1, xdos, torch.ones(931))
optshift = find_optimal_discrete_shift(np.array(test_ldos0),np.array(total_dos1))

In [68]:
from dostools.loss import loss

def normal_reg_train_L(feat, target, train_index, test_index, regularization, n_epochs, lr):
    
    patience = 20
    index = train_index
    t_index = test_index
    features = torch.hstack([feat, torch.ones(feat.shape[0]).view(-1,1)])
    Features = features[index]
    t_Features = features[t_index]
    n_col = Features.shape[1]
    Target = target[index]
    t_Target = target[t_index]
    reg = regularization * torch.eye(n_col)
    reg[-1, -1] = 0
    reg_features = torch.vstack([Features, reg])
    reg_target = torch.vstack([Target, torch.zeros(n_col,Target.shape[1])])
    


    weights = torch.nn.Parameter(torch.rand(Features.shape[1], Target.shape[1])- 0.5)
    opt = torch.optim.LBFGS([weights], lr = lr, line_search_fn = "strong_wolfe", tolerance_grad = 1e-20, tolerance_change = 1-20, history_size = 200)
    pbar = tqdm(range(n_epochs))
    current_rmse = torch.tensor(100)
    pred_loss = torch.tensor(100)
    prev_loss = torch.tensor(100)
    best_mse = torch.tensor(100)
    trigger = 0
    for epoch in pbar:
        pbar.set_description(f"Epoch: {epoch}")
        pbar.set_postfix(pred_loss = pred_loss.item(), lowest_mse = best_mse.item(), trigger = trigger)
        def closure():
            opt.zero_grad()
            pred_i = reg_features @ weights
            opt_shift = find_optimal_discrete_shift(np.array(pred_i[:len(index)]),np.array(reg_target[:len(index)]))
            pred_i[:len(index)] = shifted_ldos(pred_i[:len(index)], xdos, opt_shift)
            loss_i = loss.t_get_mse(pred_i, reg_target)
            loss_i.backward()
            return loss_i
        opt.step(closure)

        with torch.no_grad():
            preds = Features @ weights
            epoch_rmse = loss.t_get_rmse(preds, Target, xdos, perc = True)
            epoch_mse = loss.t_get_mse(preds, Target, xdos)


            pred_loss = epoch_rmse

            if epoch_mse < best_mse:
                best_mse = epoch_mse
                best_state = weights.clone()

            if epoch_mse < prev_loss * ( 1 + 1e-3):
                trigger =0
            else:
                trigger +=1 
                if trigger >= patience:
                    weights = best_state
                    opt = torch.optim.Adam([weights], lr = opt.param_groups[0]['lr'], weight_decay = 0)

            epoch_mse = prev_loss


    

    final_preds = Features @ best_state 
    final_t_preds = t_Features @ best_state

    loss_dos = loss.t_get_rmse(final_preds, Target, xdos, perc = True)
    test_loss_dos = loss.t_get_rmse(final_t_preds, t_Target, xdos, perc = True)
    return best_state, loss_dos, test_loss_dos

def normal_reg_train_Ad(feat, target, train_index, test_index, regularization, n_epochs, batch_size, lr):
    index = train_index
    t_index = test_index

    features = torch.hstack([feat, torch.ones(feat.shape[0]).view(-1,1)])

    Sampler = torch.utils.data.RandomSampler(index, replacement = False)
    Batcher = torch.utils.data.BatchSampler(Sampler, batch_size, False)

    Features = features[index]
    t_Features = features[t_index]
    n_col = Features.shape[1]


    Target = target[index]
    t_Target = target[t_index]


    # reg_features = torch.vstack([Features, reg])
    # reg_target = torch.vstack([Target, torch.zeros(n_col,Target.shape[1])])


    reg = regularization * torch.eye(n_col)
    reg[-1, -1] = 0


    weights = torch.nn.Parameter((torch.rand(Features.shape[1], Target.shape[1])- 0.5))
    opt = torch.optim.Adam([weights], lr = lr, weight_decay = 0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor = 0.1, patience = 500, threshold = 1e-7, min_lr = 1e-8)

    pbar = tqdm(range(n_epochs))

    current_rmse = torch.tensor(100)
    pred_loss = torch.tensor(100)
    prev_loss = torch.tensor(100)
    best_mse = torch.tensor(100)
    trigger = 0
    for epoch in pbar:
        pbar.set_description(f"Epoch: {epoch}")
        pbar.set_postfix(pred_loss = pred_loss.item(), lowest_mse = best_mse.item(), trigger = trigger)
        for i_batch in Batcher:
            def closure():
                opt.zero_grad()
                reg_features_i = torch.vstack([Features[i_batch], reg])
                target_i = torch.vstack([Target[i_batch], torch.zeros(n_col, Target.shape[1])])
                pred_i = reg_features_i @ weights
                opt_shift = find_optimal_discrete_shift(np.array(pred_i[:len(i_batch)].detach()),np.array(target_i[:len(i_batch)].detach()))
                print (opt_shift)
                pred_i[:len(i_batch)] = shifted_ldos(pred_i[:len(i_batch)], xdos, torch.tensor(opt_shift))
                loss_i = loss.t_get_mse(pred_i, target_i)
                loss_i.backward()
                return loss_i
            opt.step(closure)

        with torch.no_grad():
            preds = Features @ weights
            epoch_rmse = loss.t_get_rmse(preds, Target, xdos, perc = True)
            epoch_mse = loss.t_get_mse(preds, Target, xdos)


            pred_loss = epoch_rmse

            if epoch_mse < best_mse:
                best_mse = epoch_mse
                best_state = weights.clone()

            if epoch_mse < prev_loss * ( 1 + 1e-3):
                trigger =0
            else:
                trigger +=1 
                if trigger >= patience:
                    weights = best_state
                    opt = torch.optim.Adam([weights], lr = opt.param_groups[0]['lr'], weight_decay = 0)

            epoch_mse = prev_loss

            scheduler.step(epoch_mse)

            if Batcher.batch_size > 1024:
                break

            if opt.param_groups[0]['lr'] < 1e-4:
                Batcher.batch_size *= 2
                opt.param_groups[0]['lr'] = lr
                print ("The batch_size is now: ", Batcher.batch_size)

    

    final_preds = Features @ best_state 
    final_t_preds = t_Features @ best_state

    loss_dos = loss.t_get_rmse(final_preds, Target, xdos, perc = True)
    test_loss_dos = loss.t_get_rmse(final_t_preds, t_Target, xdos, perc = True)
    return best_state, loss_dos, test_loss_dos
        

In [69]:
weights, loss_dos, test_loss_dos = normal_reg_train_Ad(surface_soap, surface_dos3, surface_train_index, surface_test_index, 1e-2, 10000, 16, 1e-3)
print ("Adam Unbiased")
print ("The train error is {:.4} for SOAP".format(loss_dos))
print ("The test error is {:.4} for SOAP".format(test_loss_dos))

Epoch: 0:   0%|                                                                                                                                | 0/10000 [00:00<?, ?it/s, lowest_mse=100, pred_loss=100, trigger=0]

[143, 141, 152, 141, 145, 156, 154, 144, 143, 143, 156, 142, 158, 157, 143, 142]





RuntimeError: The expanded size of the tensor (778) must match the existing size (2856) at non-singleton dimension 0.  Target sizes: [778].  Tensor sizes: [2856]