In [14]:
import pandas as pd 
import numpy as np 
import torch 
import ast
import matplotlib.pyplot as plt

train_dict = torch.load("data/train_pv_xrd.pt")
val_dict = torch.load("data/val_pv_xrd.pt")
test_dict = torch.load("data/test_pv_xrd.pt")

train_df = pd.read_csv("data/train.csv")
val_df = pd.read_csv("data/val.csv")
test_df = pd.read_csv("data/test.csv")

train_df['atomic_numbers'] = train_df['atomic_numbers'].apply(ast.literal_eval)
val_df['atomic_numbers'] = val_df['atomic_numbers'].apply(ast.literal_eval)
test_df['atomic_numbers'] = test_df['atomic_numbers'].apply(ast.literal_eval)


In [15]:
MAX_ATOM = 80
def filter_df(df):
    """remove any compounds with elements with atomic numbers equal to or greater than max atom """
        
    atomic_num_list = [np.array(sublist) for sublist in list(df['atomic_numbers'])]
    indices_to_exclude = [i for i, val in enumerate(atomic_num_list) if np.any(val > MAX_ATOM-1)]
    df = df.drop(indices_to_exclude)
    return df 

def convert_to_tensor(pseudo_voight_dict, df): 
    list_of_pseudo_voights = []
    for key in df['material_id']:
        #key += "_0"
        value = pseudo_voight_dict[key]
        list_of_pseudo_voights.append(value)

    tensor_of_pseudo_voights = torch.tensor(torch.stack(list_of_pseudo_voights))
    return(tensor_of_pseudo_voights)

def create_padded_seqs(filtered_df): 
    atomic_num_list = [np.array(sublist) for sublist in list(filtered_df['atomic_numbers'])]
    all_atom_types = [np.concatenate([vec, np.zeros(25 - len(vec))]) for vec in atomic_num_list]
    all_atom_types = torch.tensor(np.stack(all_atom_types)).long()
    training_data_onehot = torch.nn.functional.one_hot(all_atom_types, num_classes=MAX_ATOM).float()

    return training_data_onehot

In [16]:
train_df = filter_df(train_df)
val_df = filter_df(val_df)
test_df = filter_df(test_df)

data_dict = {
    'train': (train_dict, train_df),
    'val': (val_dict, val_df), 
    'test': (test_dict, test_df)
}

training_sgs = torch.tensor(train_df['spacegroup.number'].values, dtype=torch.long)
val_sgs = torch.tensor(val_df['spacegroup.number'].values, dtype=torch.long)
test_sgs = torch.tensor(test_df['spacegroup.number'].values, dtype=torch.long)

training_pvs = convert_to_tensor(train_dict, train_df)
val_pvs = convert_to_tensor(val_dict, val_df)
test_pvs = convert_to_tensor(test_dict, test_df)

training_comps = create_padded_seqs(train_df)
val_comps = create_padded_seqs(val_df)
test_comps = create_padded_seqs(test_df)

  tensor_of_pseudo_voights = torch.tensor(torch.stack(list_of_pseudo_voights))


In [17]:
torch.save(training_comps, "data/train_compositionseq.pt")
torch.save(val_comps, "data/val_compositionseq.pt")
torch.save(test_comps, "data/test_compositionseq.pt")
torch.save(training_pvs, "data/train_pvs.pt")
torch.save(val_pvs, "data/val_pvs.pt")
torch.save(test_pvs, "data/test_pvs.pt")
torch.save(training_sgs, "data/train_sgs.pt")
torch.save(val_sgs, "data/val_sgs.pt")
torch.save(test_sgs, "data/test_sgs.pt")