In [1]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.optim as optim

from collections import Counter, OrderedDict, defaultdict
from datetime import datetime
from torch import nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, DataParallel
from torch_scatter import scatter
from torch.utils.tensorboard import SummaryWriter

from DataClasses import lmdb_dataset, Dataset, choose_dataloader
from ModelFunctions import my_reshape, multigpu_available, predict, train, evaluate, inference

import tqdm

In [2]:
import sys
# Extend the import path so the model modules imported further down
# (spinconv, spinconv_with_embeds_single, dimenet_plus_plus) can be found;
# presumably they live in these directories — confirm against the repo layout.
# NOTE(review): os.path.expanduser is a no-op here (neither path starts with '~');
# both entries are relative, so they resolve against the current working directory.
# NOTE(review): these appends run after cell 1 already imported DataClasses and
# ModelFunctions, so those two modules must be importable without them.
sys.path.append(os.path.expanduser('../ocpmodels/models'))
sys.path.append(os.path.expanduser('../../ocp_airi'))

In [3]:
# Per-item hook: the dataset calls this each time it yields a system.
# This variant repackages a fixed subset of fields into a torch_geometric `Data`
# object; it does not build edge_index or edge features.
# NOTE(review): `preprocessing` is redefined in the next cell (identity version),
# so this definition is shadowed and never used by the dataset below.
def preprocessing(system):
    """Return a ``Data`` object holding selected fields of ``system``.

    Only 'pos', 'atomic_numbers', 'cell', 'natoms', 'sid' and 'y_relaxed' are
    copied (in that order); every other field of ``system`` is dropped.
    """
    selected_fields = ('pos', 'atomic_numbers', 'cell', 'natoms', 'sid', 'y_relaxed')
    return Data(**{field: system[field] for field in selected_fields})

In [4]:
# Per-item hook: the dataset calls this each time it yields a system.
# Identity variant: systems are handed through untouched. This definition
# replaces the field-selecting `preprocessing` from the previous cell.
def preprocessing(system):
    """Return ``system`` unchanged (no per-item preprocessing)."""
    passthrough = system
    return passthrough

In [5]:
# Periodic table fetched over the network on every run. make_composition maps an
# atomic number Z to m_table.iloc[Z - 1]['Symbol'], which assumes rows are ordered
# by atomic number starting at 1 — TODO confirm for this CSV.
# NOTE(review): consider caching this file locally so the notebook runs offline.
m_table = pd.read_csv('https://gist.githubusercontent.com/GoodmanSciences/c2dd862cd38f21b0ad36b8f96b4bf1ee/raw/1d92663004489a5b6926e944c1b3d9ec5c40900e/Periodic%2520Table%2520of%2520Elements.csv')

In [6]:
# Which split to score: 'train', or a validation split such as 'val_id'
# (the two commented lines build that name).
# subsample = 'id'
# subsample =  'val_' + subsample
subsample = 'train'
model_type = 'id'   # selects the model cell below: 'id', 'electronegativity' or 'dimenetpp'

In [7]:
# Directory containing the checkpoints loaded below. Keep the trailing slash:
# checkpoint paths are built by plain string concatenation (folder + checkpoint).
folder = '../checkpoints_for_analysis/'

In [8]:
# --- config ---
batch_size = 64
num_workers = 0  # passed to choose_dataloader; presumably 0 = load in the main process

features_cols = ['feature_1']  # passed to Dataset; NOTE(review): looks like a placeholder — confirm it is used

target_col = 'y_relaxed'
lr = 0.001  # NOTE(review): only referenced by the commented-out optimizer; unused here
epochs = 30  # NOTE(review): unused in this inference-only notebook

In [9]:
# (disabled) make newly created tensors default to CUDA:
# if torch.cuda.is_available():
#     torch.set_default_tensor_type('torch.cuda.FloatTensor')
#     print('cuda')

In [10]:
# Pick the compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [11]:
# Build the dataset to score and its batch iterator.
# NOTE(review): despite the "val_"/"valid_" names, this reads whichever split
# `subsample` selects (currently 'train').
val_dataset_file_path = os.path.expanduser(f"../../ocp_datasets/data/is2re/all/{subsample}/data.lmdb")

valid_set = Dataset(val_dataset_file_path, features_cols, target_col, preprocessing=preprocessing)
# NOTE(review): the row-position lookups further down assume this loader does not
# shuffle — confirm in choose_dataloader.
valid_generator = choose_dataloader(valid_set, batch_size=batch_size, num_workers=num_workers)

In [12]:
# Optionally summarize the training split. Nothing in this notebook defines
# `train_dataset_file_path`; set it in the config cell to get the summary.
# An explicit check replaces a bare `except: pass`, which would also have hidden
# genuine errors raised by lmdb_dataset / describe().
train_path = globals().get('train_dataset_file_path')
if train_path is not None:
    lmdb_dataset(train_path).describe()

In [13]:
# --- model: SpinConv, 'id' variant ---
if model_type == 'id':
    from spinconv import spinconv

    # Build the architecture, then take the weights from a saved checkpoint.
    model = spinconv(None, None, 1, otf_graph=True, regress_forces=False)
    checkpoint = '2021-09-25-11-23-25_epoch_15_.pickle'
    # The file holds a whole pickled model object (not just a state_dict); presumably
    # it was saved while wrapped in DataParallel, hence the `.module` access.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files;
    # newer torch versions default to weights_only=True and would reject this file.
    model_chp = torch.load(folder + checkpoint,  map_location=device)
    state_dict = model_chp.module.state_dict()

In [14]:
# --- model: SpinConv with a custom embedding ('electronegativity' variant) ---
if model_type == 'electronegativity':
    from spinconv_with_embeds_single import spinconv
    
    # NOTE(review): custom_embedding_value=[33] is an undocumented magic value;
    # presumably it must match the training configuration — confirm and document.
    model = spinconv(None, None, 1, otf_graph=True, regress_forces=False, custom_embedding_value=[33])
    checkpoint = '2021-10-05-14-31-19_epoch_13_.pickle'
    # Whole pickled model object; the weights are read from its `.module` attribute.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    model_chp = torch.load(folder + checkpoint,  map_location=device)
    state_dict = model_chp.module.state_dict()

In [15]:
# --- model: DimeNet++ ---
if model_type == 'dimenetpp':
    from dimenet_plus_plus import DimeNetPlusPlusWrap

    # Architecture hyperparameters. Presumably these must match the configuration
    # the checkpoint was trained with, otherwise load_state_dict will fail.
    kwargs = {
        'hidden_channels': 256,
        'out_emb_channels': 192,
        'num_blocks': 3,
        'cutoff': 6.0,
        'num_radial': 6,
        'num_spherical': 7,
        'num_before_skip': 1,
        'num_after_skip': 2,
        'num_output_layers': 3,
        'regress_forces': False,
        'use_pbc': True,
    }
    model = DimeNetPlusPlusWrap(None, None, 1, **kwargs, otf_graph=True)
    checkpoint = 'dimenetpp_all.pt'
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    model_chp = torch.load(folder + checkpoint,  map_location=device)
    state_dict = model_chp["state_dict"]
    # Saved keys carry a doubled wrapper prefix ("module.module."). Strip it only
    # where it is present; slicing a fixed 14 characters off every key would
    # silently corrupt any key saved without that prefix.
    wrapper_prefix = 'module.module.'
    new_state_dict = OrderedDict(
        (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k, v)
        for k, v in state_dict.items()
    )
    state_dict = new_state_dict

In [16]:
# Load the weights into the bare model *before* any DataParallel wrapping. The
# wrapper's state_dict keys gain a "module." prefix, so loading the unprefixed
# `state_dict` into an already-wrapped model fails on multi-GPU machines.
load_result = model.load_state_dict(state_dict)

if multigpu_available():
    model = DataParallel(model)

# Move to CUDA if it is available (`device` was chosen earlier).
model = model.to(device)

# Display the load report (e.g. "<All keys matched successfully>").
load_result

<All keys matched successfully>

In [17]:
# Batch iterator consumed by the inference loop below.
iterator = valid_generator

In [None]:
# Run inference over the whole iterator and collect one (sid, prediction, target)
# row per system. Per-batch chunks are gathered in a list and concatenated once:
# concatenating onto a growing accumulator every batch copies all earlier rows on
# every iteration (quadratic in the number of batches).
multigpu_mode = multigpu_available()

model.eval()

chunks = []
with torch.no_grad():
    for batch in tqdm.tqdm(iterator):
        predictions, sids = predict(model, batch, multigpu_mode, device, inference=True)
        sids = my_reshape(sids)
        predictions = my_reshape(predictions)
        y_true = my_reshape(batch['y_relaxed'])
        mini_submit = torch.cat((sids, predictions, y_true), dim=1)
        chunks.append(mini_submit)

# The leading empty tensor reproduces the former accumulator's seed value (same
# dtype promotion and device), and yields an empty tensor if there were no batches.
result = torch.cat([torch.tensor([]).to(device)] + chunks)

  2%|██▉                                                                                                                                                                               | 117/7193 [00:42<42:16,  2.79it/s]

In [None]:
# One row per system: sid, model prediction, target (y_relaxed).
# NOTE(review): `result` is a torch tensor (possibly on CUDA), not a NumPy array;
# pandas appears to build object columns whose cells are 0-d tensors, which is why
# the next cell calls `.item()` on every cell. Building from result.cpu().numpy()
# would give numeric columns directly, but the next cell would then have to drop
# its `.item()` calls.
df = pd.DataFrame(result, columns=['sid', 'pred', 'target'])

In [None]:
# Unwrap every cell to a plain Python number (cells expose `.item()` because the
# frame was built from a tensor), make `sid` an integer, and keep each row's
# position in its own column for the LMDB lookups in get_composition/get_fixed.
for col_name in ('sid', 'pred', 'target'):
    df[col_name] = df[col_name].map(lambda cell: cell.item())

df['sid'] = df['sid'].map(int)
df['index'] = df.index

In [None]:
# Raw LMDB view of the same split. get_composition/get_fixed index it by the row
# position stored in df['index'] and assert that the sids agree, i.e. they assume
# inference visited systems in dataset order.
dataset = lmdb_dataset(val_dataset_file_path)

In [None]:
# Summary of the split being analysed (output format defined by lmdb_dataset.describe).
dataset.describe()

In [None]:
def make_composition(composition_dict, table=None):
    """Flatten an ``{atomic_number: count}`` mapping into ``[symbol, count, ...]``.

    Args:
        composition_dict: mapping from atomic number (1-based int) to atom count.
            The mapping's iteration order is preserved in the output.
        table: periodic-table DataFrame with a 'Symbol' column whose row ``Z - 1``
            describes element ``Z``. Defaults to the module-level ``m_table``.

    Returns:
        list[str]: alternating element symbols and stringified counts, e.g.
        ``['Pt', '12', 'O', '1']``; an empty list for an empty mapping.

    Raises:
        ValueError: if an atomic number is < 1. Negative positions would
            otherwise wrap around in ``iloc`` and silently return the wrong element.
    """
    if table is None:
        table = m_table
    composition = []
    for number, count in composition_dict.items():
        if number < 1:
            raise ValueError(f"atomic number must be >= 1, got {number}")
        # Positional lookup: assumes table rows are sorted by atomic number from 1.
        composition.append(table.iloc[number - 1]['Symbol'])
        composition.append(str(count))
    return composition

In [None]:
def get_composition(row):
    """Describe one system's composition as ``'<catalyst>__<adsorbate>'``.

    Fetches the LMDB record at position ``row['index']`` from ``dataset``, checks
    that its sid matches ``row['sid']``, and counts atoms per atomic number split
    by tag. Tags 0 and 1 go to the catalyst side and every other tag to the
    adsorbate side; presumably this follows the OC20 convention
    (0 = sub-surface, 1 = surface, 2 = adsorbate) — confirm. Each side is
    rendered with make_composition and its tokens joined with '_'.
    """
    record = dataset[int(row['index'])]
    assert int(record['sid']) == int(row['sid'])

    catalyst_counts = defaultdict(int)
    adsorbate_counts = defaultdict(int)
    numbers = record['atomic_numbers'].long().tolist()
    tag_values = record['tags'].long().tolist()
    for z, tag in zip(numbers, tag_values):
        bucket = catalyst_counts if tag in (0, 1) else adsorbate_counts
        bucket[z] += 1

    catalyst_part = '_'.join(make_composition(catalyst_counts))
    adsorbate_part = '_'.join(make_composition(adsorbate_counts))
    return '__'.join((catalyst_part, adsorbate_part))

In [None]:
def get_fixed(row):
    """Return the fraction of the system's atoms that carry tag 0.

    Fetches the LMDB record at position ``row['index']`` from ``dataset``, checks
    its sid against ``row['sid']``, and divides the tag-0 atom count by the total
    number of atoms. NOTE(review): tag 0 presumably marks fixed (sub-surface)
    atoms, as the name suggests — confirm against the dataset docs.
    """
    record = dataset[int(row['index'])]
    assert int(record['sid']) == int(row['sid'])
    tag_list = record['tags'].long().tolist()
    n_tag_zero = sum(1 for tag in tag_list if tag == 0)
    return n_tag_zero / len(tag_list)

In [None]:
%%time
# Row-wise apply: one LMDB read per row, so this scales with the split size.
df['composition'] = df.apply(get_composition, axis=1)

In [None]:
%%time
# A second full pass over the LMDB records (get_composition already read them);
# NOTE(review): both columns could be computed from a single read per row.
df['fixed'] = df.apply(get_fixed, axis=1)

In [None]:
# Mean absolute error of the predictions against `target` (the y_relaxed column).
(df['pred'] - df['target']).abs().mean()

In [None]:
# Show the results frame (pandas truncates long frames in the rendered output).
df

In [None]:
# Persist per-system results, one file per model type and split.
# NOTE(review): to_csv also writes the DataFrame index, which duplicates the
# 'index' column added earlier; consider index=False if consumers allow it.
df.to_csv(f'../results_csv/{model_type}_{subsample}.csv')