In [1]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.optim as optim

from collections import Counter
from datetime import datetime
from torch import nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, DataParallel
from torch_scatter import scatter
from torch.utils.tensorboard import SummaryWriter

from DataClasses import lmdb_dataset, Dataset, choose_dataloader
from ModelFunctions import my_reshape, multigpu_available, predict, train, evaluate, inference

import tqdm

In [2]:
import sys

# Make the local model sources importable. These are relative paths, so they
# resolve against the kernel's working directory (expanduser is a no-op without '~').
for extra_source_dir in ('../ocpmodels/models', '../../ocp_airi'):
    sys.path.append(os.path.expanduser(extra_source_dir))

from spinconv import spinconv

In [3]:
def preprocessing(system):
    """Turn one system yielded by the dataset into a torch_geometric ``Data`` object.

    Passed to ``Dataset`` as its ``preprocessing`` hook (per the original note, it is
    invoked for every system the dataset returns). Only the fields listed below are
    copied; anything else in ``system`` is dropped, and a missing field raises KeyError.

    NOTE(review): no edge_index / edge features are built here, although the original
    comment mentioned them; presumably the model builds the graph itself
    (spinconv is created with otf_graph=True) - TODO confirm. Other networks may need
    a different field list.
    """
    wanted_fields = ('pos', 'atomic_numbers', 'cell', 'natoms', 'sid', 'y_relaxed')
    return Data(**{field: system[field] for field in wanted_fields})

In [4]:
# Periodic table, fetched over the network on every run. get_composition() reads row
# Z-1 for atomic number Z, so this assumes the rows are ordered by atomic number from 1.
PERIODIC_TABLE_URL = 'https://gist.githubusercontent.com/GoodmanSciences/c2dd862cd38f21b0ad36b8f96b4bf1ee/raw/1d92663004489a5b6926e944c1b3d9ec5c40900e/Periodic%2520Table%2520of%2520Elements.csv'
m_table = pd.read_csv(PERIODIC_TABLE_URL)

In [5]:
# --- run configuration ---
subsample = 'ood_ads'  # validation split suffix, read from val_<subsample>/data.lmdb
batch_size, num_workers = 63, 0

features_cols = ['feature_1']
target_col = 'y_relaxed'

# Only lr feeds into the (unused-for-inference) optimizer; epochs is not used below.
lr, epochs = 0.001, 30

In [6]:
# # make newly created tensors default to CUDA
# if torch.cuda.is_available():
#     torch.set_default_tensor_type('torch.cuda.FloatTensor')
#     print('cuda')

In [7]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [8]:
# Build the validation dataset and the batch iterator over it.
val_dataset_file_path = os.path.expanduser(
    f"../../ocp_datasets/data/is2re/all/val_{subsample}/data.lmdb"
)

valid_set = Dataset(
    val_dataset_file_path,
    features_cols,
    target_col,
    preprocessing=preprocessing,
)
valid_generator = choose_dataloader(
    valid_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [9]:
# Print a summary of the split this notebook actually evaluates.
# The original cell referenced `train_dataset_file_path`, which is never defined in
# this notebook, so its bare `except: pass` silently swallowed a NameError and the
# cell did nothing. Failures are still non-fatal, but they are now reported.
try:
    lmdb_dataset(val_dataset_file_path).describe()
except Exception as err:  # best-effort summary: must not stop the notebook
    print(f'could not describe {val_dataset_file_path}: {err!r}')

In [10]:
#model
model = spinconv(None, None, 1, otf_graph=True, regress_forces=False)
checkpoint = '2021-09-25-11-23-25_epoch_15_.pickle'
folder = '../checkpoints_for_analysis/'
# SECURITY NOTE: the checkpoint is a pickled nn.Module and torch.load unpickles it,
# which can execute arbitrary code - only load checkpoints you produced yourself.
model_chp = torch.load(os.path.join(folder, checkpoint), map_location=device)
# The checkpoint was saved wrapped in DataParallel (it exposes `.module`); fall back
# to the object itself so plain single-device checkpoints load as well.
state_dict = getattr(model_chp, 'module', model_chp).state_dict()

# Load the weights into the *bare* model before any DataParallel wrapping. The
# unwrapped state_dict has no 'module.' key prefix, so loading it after wrapping
# fails with missing/unexpected keys whenever multiple GPUs are available.
load_report = model.load_state_dict(state_dict)

# Move to CUDA if it is available, then wrap for multi-GPU inference.
model = model.to(device)
if multigpu_available():
    model = DataParallel(model)

# Optimizer and loss are not used by the inference cells below; they are kept so the
# names stay defined for anyone reusing this cell.
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.L1Loss().to(device)

# Show the strict-load report (e.g. "<All keys matched successfully>").
load_report

<All keys matched successfully>

In [11]:
# Loader consumed by the inference loop below; point this at another loader to predict on a different split.
iterator = valid_generator

In [12]:
# Run the model over `iterator` and collect rows of (sid, prediction, target).
multigpu_mode = multigpu_available()

model.eval()

# Keep one 3-column chunk per batch and concatenate once at the end. Growing
# `result` with torch.cat inside the loop re-copies everything collected so far on
# every batch, which is quadratic in the number of batches.
chunks = []
with torch.no_grad():
    for batch in tqdm.tqdm(iterator):
        predictions, sids = predict(model, batch, multigpu_mode, device, inference=True)
        sids = my_reshape(sids)
        predictions = my_reshape(predictions)
        y_true = my_reshape(batch['y_relaxed'])
        chunks.append(torch.cat((sids, predictions, y_true), dim=1))

# Same value as before for a non-empty loader; an empty loader yields the same empty tensor.
result = torch.cat(chunks) if chunks else torch.tensor([]).to(device)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 397/397 [02:15<00:00,  2.92it/s]


In [13]:
# Move the collected 3-column tensor to host memory and hand pandas a NumPy array.
# Passing the (possibly CUDA) tensor directly makes pandas iterate it element by
# element and store object columns of 0-d tensors.
df = pd.DataFrame(result.detach().cpu().numpy(), columns=['sid', 'pred', 'target'])

In [14]:
# Unwrap every cell into a plain Python number (each cell may be a 0-d tensor / NumPy scalar).
for col_name in ('sid', 'pred', 'target'):
    df[col_name] = [cell.item() for cell in df[col_name]]

# System ids are integral; also keep each row's position as an explicit column.
df['sid'] = df['sid'].map(int)
df['index'] = df.index

In [15]:
# Raw LMDB for the same validation split; get_composition() looks systems up in it by row position.
dataset = lmdb_dataset(val_dataset_file_path)

In [16]:
def get_composition(row):
    """Return a composition string such as 'Ti_16_Hg_16_C_1_H_1_O_1' for one df row.

    Looks up the system at position ``row['index']`` in the global LMDB ``dataset`` and
    checks that its sid matches ``row['sid']`` (AssertionError otherwise). Elements are
    listed in order of first appearance in ``atomic_numbers``, each followed by its
    count. Symbols come from ``m_table`` row Z-1, i.e. assumes the table is sorted by
    atomic number starting at 1.
    """
    position = int(row['index'])
    record = dataset[position]
    # Guard against the prediction table and the LMDB being out of order.
    assert int(record['sid']) == int(row['sid'])

    element_counts = Counter(record['atomic_numbers'].long().tolist())
    parts = []
    for atomic_number, count in element_counts.items():
        parts.extend((m_table.iloc[atomic_number - 1]['Symbol'], str(count)))
    return '_'.join(parts)

In [17]:
%%time
df['composition'] = df.apply(get_composition, axis=1)

CPU times: user 26.3 s, sys: 225 ms, total: 26.5 s
Wall time: 26.5 s


In [18]:
# Show the prediction table (pandas truncates long frames to head/tail rows).
df

Unnamed: 0,sid,pred,target,index,composition
0,1930510,-1.911062,-2.826376,0,Ti_16_Hg_16_C_1_H_1_O_1
1,1729668,0.758232,1.307682,1,K_64_Zn_32_S_64_C_1_H_1_O_1
2,710795,-1.281191,-1.009515,2,Rh_72_W_24_C_1_H_1
3,1904318,-0.128700,-1.053289,3,Cs_32_N_1_H_2
4,455640,3.736803,3.714117,4,As_28_Rh_48_N_1_O_2_H_1
...,...,...,...,...,...
24956,1670661,-1.600841,-1.797060,24956,Ga_16_Mn_16_Zr_16_N_1_H_2
24957,714182,0.417107,-1.844538,24957,S_24_Si_12_C_1_H_1_O_1
24958,1728610,-1.767314,-0.299679,24958,N_48_Sr_48_C_1_H_1
24959,614554,-5.785323,-4.340534,24959,Ta_96_S_16_N_2_C_2_H_8


In [19]:
# Persist the per-system predictions.
# NOTE(review): the default index is written as well, which duplicates the 'index'
# column (it reappears as 'Unnamed: 0' on re-read). The 'id_val_' prefix is also used
# regardless of `subsample` (here 'ood_ads'); confirm this naming is intended.
output_csv_path = f'id_val_{subsample}.csv'
df.to_csv(output_csv_path)