In [96]:
import dostools
import importlib
import numpy as np
import pickle
import torch
import sys
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm
import matplotlib
import time
import scipy 
import ase
import ase.io
torch.set_default_dtype(torch.float64) 
# %matplotlib notebook
# matplotlib.rcParams['figure.figsize'] = (10, 10)

In [123]:
import dostools.datasets.data as data
import dostools.utils.utils as utils

with torch.no_grad():
#     sigma = 0.3
#     structures = data.load_structures(":")
#     n_structures = len(structures) #total number of structures
#     for structure in structures:#implement periodicity
#         structure.wrap(eps = 1e-12) 
#     n_atoms = np.zeros(n_structures, dtype = int) #stores number of atoms in each structures
#     for i in range(n_structures):
#         n_atoms[i] = len(structures[i])
        
    xdos = torch.tensor(data.load_xdos())
    
    total_dos3 = torch.load("./total_ldos3.pt")
    total_dos1 = torch.load("./total_ldos1.pt")
    
    surface_dos3 = torch.load("./surface_ldos3.pt")
    surface_dos1 = torch.load("./surface_ldos1.pt")
    
    surface_aligned_dos3 = torch.load("./surface_aligned_dos3.pt")
    surface_aligned_dos1 = torch.load("./surface_aligned_dos1.pt")
    
    bulk_dos3 = torch.load("./bulk_ldos3.pt")
    bulk_dos1 = torch.load("./bulk_ldos1.pt")
    
    total_aligned_dos3 = torch.load("./total_aligned_dos3.pt")
    total_aligned_dos1 = torch.load("./total_aligned_dos1.pt")
    
    surface_soap = torch.load("./surface_soap.pt")
    bulk_soap = torch.load("./bulk_soap.pt")
    total_soap = torch.load("./total_soap.pt")
    
    surface_kernel_30 = torch.load("./surface_kernel_30.pt")
    surface_kMM_30 = torch.load("./surface_kMM_30.pt")
    
    bulk_kernel_200 = torch.load("./bulk_kernel_200.pt")
    bulk_kMM_200 = torch.load("./bulk_kMM_200.pt")
    
    bulk_kernel_100 = torch.load("./bulk_kernel_100.pt")
    bulk_kMM_100 = torch.load("./bulk_kMM_100.pt")
    
    total_kernel_100 = torch.load("./total_kernel_100.pt")
    total_kMM_100 = torch.load("./total_kMM_100.pt")
    
    total_kernel_150 = torch.load("./total_kernel_150.pt")
    total_kMM_150 = torch.load("./total_kMM_150.pt")
    
    

In [121]:
# with torch.no_grad():

#     cutoff_index = torch.tensor(487)
#     xdos = xdos[:cutoff_index]

#     total_dos3 = torch.load("./total_ldos3.pt")[:,:cutoff_index]
#     total_dos1 = torch.load("./total_ldos1.pt")[:,:cutoff_index]

#     surface_dos3 = torch.load("./surface_ldos3.pt")[:,:cutoff_index]
#     surface_dos1 = torch.load("./surface_ldos1.pt")[:,:cutoff_index]

#     surface_aligned_dos3 = torch.load("./surface_aligned_dos3.pt")[:,:cutoff_index]
#     surface_aligned_dos1 = torch.load("./surface_aligned_dos1.pt")[:,:cutoff_index]

#     bulk_dos3 = torch.load("./bulk_ldos3.pt")[:,:cutoff_index]
#     bulk_dos1 = torch.load("./bulk_ldos1.pt")[:,:cutoff_index]

#     total_aligned_dos3 = torch.load("./total_aligned_dos3.pt")[:,:cutoff_index]
#     total_aligned_dos1 = torch.load("./total_aligned_dos1.pt")[:,:cutoff_index]

In [122]:
def generate_train_test_split(n_samples):
    n_structures = n_samples
    np.random.seed(0)
    n_train = int(0.8 * n_structures)
    train_index = np.arange(n_structures)
    np.random.shuffle(train_index)
    test_index = train_index[n_train:]
    train_index = train_index[:n_train]
    
    return train_index, test_index

def generate_biased_train_test_split(n_samples):
    #Assumes 100 amorphous structures at the end
    n_structures = n_samples
    amorph_train = np.arange(n_samples-100, n_samples,1)
    np.random.seed(0)
    np.random.shuffle(amorph_train)
    
    amorph_test = amorph_train[:80]
    amorph_train = amorph_train[80:]

    n_structures = n_samples - 100
    np.random.seed(0)
    n_train = int(0.8 * n_samples)-20
    remaining_train_index = np.arange(n_structures)
    np.random.shuffle(remaining_train_index)

    remaining_test_index = remaining_train_index[n_train:]
    remaining_train_index = remaining_train_index[:n_train]

    biased_train_index = np.concatenate([remaining_train_index, amorph_train])
    biased_test_index = np.concatenate([remaining_test_index, amorph_test])
    
    return biased_train_index, biased_test_index

def generate_surface_holdout_split(n_samples):
    #Assumes that we are using the 110 surfaces for test which are located at 673 + 31st-57th index
    #26 structures
    
    n_test = int(0.2 * n_samples) - 26
    n_train = n_samples - n_test
    
    remaining_indexes = np.concatenate([np.arange(673+31), np.arange(673+57,n_samples,1)])
    indexes_110 = np.arange(673+31, 673+57,1)
    np.random.seed(0)
    
    np.random.shuffle(remaining_indexes)
    
    remaining_test_index = remaining_indexes[n_train:]
    remaining_train_index = remaining_indexes[:n_train]
    
    total_train_index = remaining_train_index
    total_test_index = np.concatenate([remaining_test_index, indexes_110])
    
    return total_train_index, total_test_index
    
def surface_holdout(n_samples):
    test_index = np.arange(31,57,1)
    train_index = np.concatenate([np.arange(31), np.arange(57, n_samples)])
    
    return train_index, test_index

n_surfaces = 154
n_bulkstructures = 773
n_total_structures = 773 + 154


surface_train_index, surface_test_index = generate_train_test_split(n_surfaces)
bulk_train_index, bulk_test_index = generate_train_test_split(n_bulkstructures)
total_train_index, total_test_index = generate_train_test_split(n_total_structures)
surface_holdout_train_index, surface_holdout_test_index = surface_holdout(n_surfaces)
bulk_biased_train_index, bulk_biased_test_index = generate_biased_train_test_split(n_bulkstructures)
total_biased_train_index, total_biased_test_index = generate_biased_train_test_split(n_total_structures)
holdout_train_index, holdout_test_index = generate_surface_holdout_split(n_total_structures)

In [227]:
# cat = torch.rand(1000).reshape(20, 5, 10)
# dog = torch.rand(1000).reshape(20, 5, 10)
mouse = torch.sum((cat - dog) **2, axis = 2)
mouse =  (mouse.reshape(20,5))

a, b = torch.min(mouse, dim = 0)

In [238]:
def t_get_each_mse(predictions, true, xdos = None):
    #takes a 3d array for predictions and true
    if xdos is not None:
        mse = torch.trapezoid((predictions - true)**2, xdos, axis = 2)
    else:
        mse = torch.sum((predictions - true)**2 , axis = 2)
        
    return mse

def t_get_rmse(a, b, xdos=None, perc=False): #account for the fact that DOS is continuous but we are training them pointwise
    """ computes  Root Mean Squared Error (RMSE) of array properties (DOS/aofd).
         a=pred, b=target, xdos, perc: if False return RMSE else return %RMSE"""
    #MIGHT NOT WORK FOR PC
    if xdos is not None:
        if len(a.size()) > 1:
            rmse = torch.sqrt((torch.trapezoid((a - b)**2, xdos, axis=1)).mean())
        else:
            rmse = torch.sqrt((torch.trapezoid((a - b)**2, xdos, axis=0)).mean())
        if not perc:
            return rmse
        else:
            mean = b.mean(axis = 0)
            std = torch.sqrt((torch.trapezoid((b - mean)**2, xdos, axis=1)).mean())
            return (100 * rmse / std)
    else:
        if len(a.size()) > 1:
            rmse = torch.sqrt(((a - b)**2).mean(dim =0))
        else:
            rmse = torch.sqrt(((a - b)**2).mean())
        if not perc:
            return torch.mean(rmse, 0)
        else:
            return torch.mean(100 * (rmse / b.std(dim = 0,unbiased=True)), 0)
        
def t_get_mse(a, b, xdos = None, perc = False):
    if xdos is not None:
        if len(a.size()) > 1:
            mse = (torch.trapezoid((a - b)**2, xdos, axis=1)).mean()
        else:
            mse = (torch.trapezoid((a - b)**2, xdos, axis=0)).mean()
        if not perc:
            return mse
        else:
            mean = b.mean(axis = 0)
            std = torch.trapezoid((b - mean)**2, xdos, axis=1).mean()
            return (100 * mse / std)
    else:
        if len(a.size()) > 1:
            mse = ((a - b)**2).mean(dim = 1)
        else:
            mse = ((a - b)**2).mean()
        if len(mse.shape) > 1:
            raise ValueError('Loss became 2D')
        if not perc:
            return torch.mean(mse, 0)
        else:
            return torch.mean(100 * (mse / b.std(dim=0, unbiased = True)),0)

In [277]:
alignment = torch.rand(927)
trueshift = torch.round(alignment/(xdos[1] - xdos[0])).int()
test_ldos0 = shifted_ldos_discrete(total_dos1, torch.ones(931))
shift_range = torch.arange(2 * full_range + 1) - full_range
l, s = t_get_BF_shift_index_mse(test_ldos0[:,:cutoff_index], total_dos1, shift_range, cutoff_index)

In [264]:
def t_get_BF_shift_index_mse(prediction, true, shift_range, cutoff_index):
    #shifts target instead of prediction
    if len(prediction.shape) > 1:
        shifted_true = shifted_ldos_discrete(true.repeat(shift_range.shape[0], 1, 1), shift_range)[:,:,:cutoff_index]
        full_loss = t_get_each_mse(prediction.repeat(shift_range.shape[0], 1, 1), shifted_true)
        full_loss = full_loss.reshape(shift_range.shape[0], -1)
        min_loss, index = torch.min(full_loss, dim = 0)
        
    else:
        shifted_true = shifted_ldos_discrete(true.repeat(shift_range.shape[0], 1), shift_range)[:,:,:cutoff_index]
        full_loss = t_get_each_mse(prediction.repeat(shift_range.shape[0], 1), shifted_true)
        min_loss, index = torch.min(full_loss, dim = 0)
        min_loss = full_loss[index]
    return min_loss, index

In [202]:
def shifted_ldos_discrete(ldos, shift): 
    shifted_ldos = torch.zeros_like(ldos)
    if len(ldos.shape) == 3:
        for i in range(len(ldos)):
            if shift[i] > 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i, :, :-1*shift[i]], (shift[i], 0))
            elif shift[i] < 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i, :, (-1*shift[i]):], (0, (-1*shift[i])))
            else:
                shifted_ldos[i] = ldos[i]
    elif len(ldos.shape) == 2:
        xdos_shift = torch.round(shift).int()
        for i in range(len(ldos)):
            if xdos_shift[i] > 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i,:-1*xdos_shift[i]], (xdos_shift[i],0))
            elif xdos_shift[i] < 0:
                shifted_ldos[i] = torch.nn.functional.pad(ldos[i,(-1*xdos_shift[i]):], (0,(-1*xdos_shift[i])))
            else:
                shifted_ldos[i] = ldos[i]
    else:        
        xdos_shift = int(torch.round(shift))
        if xdos_shift > 0:
            shifted_ldos = torch.nn.functional.pad(ldos[:-1*xdos_shift], (xdos_shift,0))
        elif xdos_shift < 0:
            shifted_ldos = torch.nn.functional.pad(ldos[(-1*xdos_shift):], (0,(-1*xdos_shift)))
        else:
            shifted_ldos = ldos
    return shifted_ldos

In [84]:
# alignment = torch.rand(927)
# trueshift = torch.round(alignment/(xdos[1] - xdos[0])).int()
# test_ldos0 = shifted_ldos(total_dos1, xdos, torch.ones(931))
# optshift = find_optimal_discrete_shift(np.array(test_ldos0),np.array(total_dos1))

In [271]:
def normal_reg_train_L(cutoff_index, shift_range, feat, target, train_index, test_index, regularization, n_epochs, lr):
    
    patience = 20
    index = train_index
    t_index = test_index
    features = torch.hstack([feat, torch.ones(feat.shape[0]).view(-1,1)])
    Features = features[index]
    t_Features = features[t_index]
    n_col = Features.shape[1]
    Target = target[index]
    t_Target = target[t_index]
    reg = regularization * torch.eye(n_col)
    reg[-1, -1] = 0
    reg_features = torch.vstack([Features, reg])
    reg_target = torch.vstack([Target, torch.zeros(n_col,Target.shape[1])])
    
    alignment = torch.zeros(len(index))

    weights = torch.nn.Parameter(torch.rand(Features.shape[1], cutoff_index)- 0.5)
    
    opt = torch.optim.LBFGS([weights], lr = lr, line_search_fn = "strong_wolfe", tolerance_grad = 1e-20, tolerance_change = 1-20, history_size = 200)
    pbar = tqdm(range(n_epochs))
    current_rmse = torch.tensor(100)
    pred_loss = torch.tensor(100)
    prev_loss = torch.tensor(100)
    best_mse = torch.tensor(100)
    trigger = 0
    for epoch in pbar:
        pbar.set_description(f"Epoch: {epoch}")
        pbar.set_postfix(pred_loss = pred_loss.item(), lowest_mse = best_mse.item(), trigger = trigger)
        def closure():
            print ("--------HEY------")
            opt.zero_grad()
            pred_i = reg_features @ weights
            loss_i, jitter = t_get_BF_shift_index_mse(pred_i[:len(index)], reg_target[:len(index)], shift_range, cutoff_index)
            print (loss_i) #FIND OUT WHY THE LOSS_i is not the same as the bottom one
            print (torch.mean(loss_i))
            #change this code such that we dont use the loss but we get the jitter and shift it again
            jitter -= shift_range[0]
            reg_target_i = reg_target.clone()
            reg_target_i[:len(index)] = shifted_ldos_discrete(reg_target_i[:len(index)], jitter)
            loss_i = t_get_mse(pred_i, reg_target_i[:, :cutoff_index])
            print (loss_i)
            
            loss_i.backward()
            return loss_i
        opt.step(closure)

        with torch.no_grad():
            preds = Features @ weights
            target = Target.clone()
#             target = shifted_ldos_discrete(target, -1 * alignment)
            epoch_mse, jitter = t_get_BF_shift_index_mse(preds, target, shift_range, cutoff_index)
            jitter -= shift_range[0]
            alignment = -1 * jitter
            
            epoch_mse = torch.mean(epoch_mse)
            target = Target.clone()
            target = shifted_ldos_discrete(target, -1 * jitter)
            
            epoch_rmse = t_get_rmse(preds, target[:,:cutoff_index], xdos[:cutoff_index], perc = True)
            
            
            pred_loss = torch.mean(epoch_rmse)

            if epoch_mse < best_mse:
                best_mse = epoch_mse
                best_state = weights.clone()

            if epoch_mse < prev_loss * ( 1 + 1e-3):
                trigger =0
            else:
                trigger +=1 
                if trigger >= patience:
                    weights = best_state
                    opt = torch.optim.Adam([weights], lr = opt.param_groups[0]['lr'], weight_decay = 0)

            epoch_mse = prev_loss


    
    with torch.no_grad():
        print ("NOT FIXED YET")
        final_preds = Features @ best_state 
        final_t_preds = t_Features @ best_state

        shifted_true = find_optimal_discrete_shift(np.array(final_preds),np.array(Target))
        final_preds = shifted_ldos_discrete(final_preds, xdos, torch.tensor(opt_shift_train))
        opt_shift_test = find_optimal_discrete_shift(np.array(final_t_preds),np.array(t_Target))
        final_t_preds = shifted_ldos_discrete(final_t_preds, xdos, torch.tensor(opt_shift_test))

        loss_dos = loss.t_get_rmse(final_preds, Target[:,:cutoff_index], xdos[:cutoff_index], perc = True)
        test_loss_dos = loss.t_get_rmse(final_t_preds, t_Target[:,:cutoff_index], xdos[:cutoff_index], perc = True)
        return best_state, loss_dos, test_loss_dos
        

def normal_reg_train_Ad(cutoff_index, shift_range, feat, target, train_index, test_index, regularization, n_epochs, batch_size, lr):
    patience = 20
    index = train_index
    t_index = test_index

    features = torch.hstack([feat, torch.ones(feat.shape[0]).view(-1,1)])

    Sampler = torch.utils.data.RandomSampler(index, replacement = False)
    Batcher = torch.utils.data.BatchSampler(Sampler, batch_size, False)

    Features = features[index]
    t_Features = features[t_index]
    n_col = Features.shape[1]


    Target = target[index]
    t_Target = target[t_index]


    # reg_features = torch.vstack([Features, reg])
    # reg_target = torch.vstack([Target, torch.zeros(n_col,Target.shape[1])])


    reg = regularization * torch.eye(n_col)
    reg[-1, -1] = 0
    
#     alignment = torch.zeros(len(index))

    weights = torch.nn.Parameter((torch.rand(Features.shape[1], Target.shape[1])- 0.5))
    opt = torch.optim.Adam([weights], lr = lr, weight_decay = 0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor = 0.1, patience = 500, threshold = 1e-7, min_lr = 1e-8)

    pbar = tqdm(range(n_epochs))

    current_rmse = torch.tensor(100)
    pred_loss = torch.tensor(100)
    prev_loss = torch.tensor(100)
    best_mse = torch.tensor(100)
    trigger = 0
    for epoch in pbar:
        pbar.set_description(f"Epoch: {epoch}")
        pbar.set_postfix(pred_loss = pred_loss.item(), lowest_mse = best_mse.item(), trigger = trigger)
        for i_batch in Batcher:
            def closure():
                opt.zero_grad()
                reg_features_i = torch.vstack([Features[i_batch], reg])
                pred_i = reg_features_i @ weights
#                 shifted_target_i = shifted_ldos_discrete(Target[i_batch].clone(), xdos, -1 * alignment)
                loss_i, jitter = t_get_BF_shift_index_mse(pred_i[:len(index)], reg_target_n[:len(index)], shift_range, cutoff_index)
                target_i = torch.vstack([shifted_target_i, torch.zeros(n_col, Target.shape[1])])
                loss_i = loss.t_get_mse(pred_i, target_i)
                loss_i.backward()
                return loss_i
            opt.step(closure)

        with torch.no_grad():
            preds = Features @ weights
            opt_shift = find_optimal_discrete_shift(np.array(preds),np.array(Target))
            preds = shifted_ldos_discrete(preds, xdos, torch.tensor(opt_shift))
            epoch_rmse = loss.t_get_rmse(preds, Target, xdos, perc = True)
            epoch_mse = loss.t_get_mse(preds, Target, xdos)


            pred_loss = torch.mean(epoch_rmse)

            if pred_loss < best_mse:
                best_mse = epoch_mse
                best_state = weights.clone()

            if epoch_mse < prev_loss * ( 1 + 1e-3):
                trigger =0
            else:
                trigger +=1 
                if trigger >= patience:
                    weights = best_state
                    opt = torch.optim.Adam([weights], lr = opt.param_groups[0]['lr'], weight_decay = 0)

            prev_loss = pred_loss

            scheduler.step(epoch_mse)

            if Batcher.batch_size > 1024:
                break

            if opt.param_groups[0]['lr'] < 1e-4:
                Batcher.batch_size *= 2
                opt.param_groups[0]['lr'] = lr
                print ("The batch_size is now: ", Batcher.batch_size)

    
    with torch.no_grad():
        final_preds = Features @ best_state 
        final_t_preds = t_Features @ best_state

        opt_shift_train = find_optimal_discrete_shift(np.array(final_preds),np.array(Target))
        final_preds = shifted_ldos_discrete(final_preds, xdos, torch.tensor(opt_shift_train))
        opt_shift_test = find_optimal_discrete_shift(np.array(final_t_preds),np.array(t_Target))
        final_t_preds = shifted_ldos_discrete(final_t_preds, xdos, torch.tensor(opt_shift_test))

        loss_dos = loss.t_get_rmse(final_preds, Target, xdos, perc = True)
        test_loss_dos = loss.t_get_rmse(final_t_preds, t_Target, xdos, perc = True)
        return best_state, loss_dos, test_loss_dos
        

In [251]:
len(xdos)

778

In [258]:
cutoff_index = torch.tensor(487)
full_range = 200
shift_range = torch.arange(2 * full_range + 1) - full_range

In [217]:
weights, loss_dos, test_loss_dos = normal_reg_train_Ad(total_soap, total_aligned_dos3,
                                                       total_train_index, total_test_index,
                                                       1e-2, 10000, 16, 1e-3)
print ("Adam Unbiased")
print ("The train error is {:.4} for SOAP".format(loss_dos))
print ("The test error is {:.4} for SOAP".format(test_loss_dos))

Epoch: 0:   0%|                                                                                                                                | 0/10000 [00:00<?, ?it/s, lowest_mse=100, pred_loss=100, trigger=0]


TypeError: shifted_ldos_discrete() takes 2 positional arguments but 3 were given

In [272]:
U_L_weights3 , loss_dos, test_loss_dos = normal_reg_train_L(cutoff_index, shift_range, surface_soap, surface_dos3, surface_train_index, surface_test_index,
                                                            1e-2, 6, 1)
print ("LBFGS Unbiased")
print ("The train error is {:.4} for SOAP".format(loss_dos))
print ("The test error is {:.4} for SOAP".format(test_loss_dos))

Epoch: 0:   0%|                                                                                                                                    | 0/6 [00:00<?, ?it/s, lowest_mse=100, pred_loss=100, trigger=0]

--------HEY------
torch.Size([123])
tensor(125.4775, grad_fn=<MeanBackward0>)
tensor(0.0510, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(125.1251, grad_fn=<MeanBackward0>)
tensor(0.0509, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(121.9773, grad_fn=<MeanBackward0>)
tensor(0.0495, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(92.8213, grad_fn=<MeanBackward0>)
tensor(0.0369, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(6.2189, grad_fn=<MeanBackward0>)
tensor(8.8324e-05, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(6.2186, grad_fn=<MeanBackward0>)
tensor(8.8190e-05, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(6.2158, grad_fn=<MeanBackward0>)
tensor(8.6987e-05, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(6.1891, grad_fn=<MeanBackward0>)
tensor(7.5529e-05, grad_fn=<MeanBackward1>)
--------HEY------
torch.Size([123])
tensor(6.0513

Epoch: 0:   0%|                                                                                                                                    | 0/6 [00:03<?, ?it/s, lowest_mse=100, pred_loss=100, trigger=0]

torch.Size([123])
tensor(6.0437, grad_fn=<MeanBackward0>)
tensor(1.1336e-05, grad_fn=<MeanBackward1>)
--------HEY------





KeyboardInterrupt: 

In [154]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f74e219bb20>