In [1]:
import pickle

import numpy as np
import pandas as pd
import torch

import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch.optim as optim

from sklearn.model_selection import train_test_split
from torch import nn
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_scatter import scatter
from torch.utils.tensorboard import SummaryWriter

In [2]:
#делаем из данных матрицу векторов-атомов, список рёбер (edge_index) и матрицу векторов-рёбер
def simple_preprocessing(batch):
    atom_numbers = batch['atomic_numbers'].long()
    atom_embeds = F.one_hot(atom_numbers, num_classes=100)
    edge_index = torch.Tensor(batch['edge_index_new']).long()
    distances = torch.Tensor(batch['distances_new'])
    angles = torch.Tensor(batch['contact_solid_angles'])
    distances = torch.reshape(distances, (distances.shape[0], 1))
    angles = torch.reshape(angles, (angles.shape[0], 1))                         
    edges_embeds = torch.cat((distances, angles), 1)
    
    
    return Data(x=atom_embeds, edge_index=edge_index, edge_attr=edges_embeds)

In [3]:
#датасет, который умеет возвращать эелемент и собственную длину
class Dataset(Dataset):

    def __init__(self, data, features_fields, target_field, type_='train', preprocessing=simple_preprocessing):
        
        self.data = data[features_fields]
        self.length = len(data)
        self.target = torch.Tensor(data[target_field].values)
        self.type_ = type_
        self.preprocessing = preprocessing

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        
        system = self.preprocessing(self.data.iloc[index])
        
        if self.type_ == 'train':
            y = self.target[index]
            
            return system, y

$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)
$$

Гамма лежит в апдейт, квадратик в aggr, а фи в месседж; в этом примере гамма и фи -- умножение на матрицу после конкатенации, а квадратик -- суммирование

In [4]:
class GConv(MessagePassing):
    def __init__(self, dim_atom=100, dim_edge=2, out_channels=2):
        super(GConv, self).__init__(aggr='add')  # "Add" aggregation
        self.phi_output = 3
        self.lin_phi = torch.nn.Linear(dim_atom*2+dim_edge, self.phi_output, bias=False)
        self.lin_gamma = torch.nn.Linear(dim_atom + self.phi_output, out_channels, bias=False)

    def forward(self, batch):
        x = batch.x
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        
        # x has shape [N -- количество атомов в системе(батче), in_channels -- размерность вектора-атома]
        # edge_index has shape [2, E] -- каждое ребро задаётся парой вершин

        # Start propagating messages. 
    
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)  #не совсем понял что такое сайз

    def message(self, x, x_i, x_j, edge_attr):
        concatenated = torch.cat((x_i, x_j, edge_attr), 1)
        return self.lin_phi(concatenated)
        
    def update(self, aggr_out, x):
                
        concatenated = torch.cat((x, aggr_out), 1)

        return self.lin_gamma(concatenated)

In [5]:
#собственно нейросеть
class ConvNN(nn.Module):
    
    def __init__(self):
        
        super().__init__()          
        self.conv = GConv()
        self.lin = torch.nn.Linear(2, 1, bias=False)
        
    def forward(self, batch):
        convoluted_1 = self.conv(batch)
        scattered = scatter(convoluted_1, batch.batch, dim=0, reduce='sum')
        summed = scattered
        energy = self.lin(summed)
        
        return energy

In [6]:
#train -- ходим по батчам из итератора, обнуляем градиенты, предсказываем у, считаем лосс, считаем градиенты, делаем шаг оптимайзера, записываем лосс
def train(model, iterator, optimizer, criterion, print_every=50):
    
    epoch_loss = 0
    
    model.train()

    for i, (systems, ys) in enumerate(iterator):
        
        optimizer.zero_grad()
        predictions = model(systems).squeeze()
        loss = criterion(predictions.float(), ys.to(device).float())
        loss.backward()     
        
        optimizer.step()      
        
        epoch_loss += loss.item()  
        
        if not (i+1) % print_every:
            print(i)
            print(f'Loss: {epoch_loss/i}')
        
    return epoch_loss / len(iterator)

In [7]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    
    model.train(False)
    model.eval()  
    
    with torch.no_grad():
        for systems, ys in iterator:   

            predictions = model(systems).squeeze()
            loss = criterion(predictions.float(), ys.to(device).float())        

            epoch_loss += loss.item()  
            
    return epoch_loss / len(iterator)

In [8]:
def inferens(model, iterator):
    y = torch.tensor([])

    model.train(False)
    model.eval()  
    
    with torch.no_grad():
        for systems, ys in iterator:   
          predictions = model(systemhs).squeeze()
          y = torch.cat((y, predictions))
      
    return y

## DATA

In [9]:
%%time
with open('/Users/humonen/Downloads/structures_train.pkl','rb') as f:
    data_ori = pickle.load(f)

CPU times: user 4.55 s, sys: 1.23 s, total: 5.77 s
Wall time: 5.83 s


In [10]:
%%time
#сливаем новые фичи и фичи из Data
for system in data_ori:
    for key in system['data']:
        system[key[0]] = key[1]
    del system['data']

CPU times: user 136 ms, sys: 2.95 ms, total: 139 ms
Wall time: 139 ms


In [11]:
%%time
df = pd.DataFrame(data_ori)

CPU times: user 58 ms, sys: 2.69 ms, total: 60.7 ms
Wall time: 59.2 ms


In [12]:
data_ori=[]

In [13]:
df

Unnamed: 0,id,covalent_radii,dipole_polarizability,electron_affinity,electronegativity,voloroi_volumes,voronoi_surface_areas,spherical_domain_radii,distances_new,contact_solid_angles,...,edge_index,fixed,force,natoms,pos,pos_relaxed,sid,tags,y_init,y_relaxed
0,0,"[1.21, 1.21, 1.21, 1.21, 1.21, 1.21, 1.21, 1.2...","[57.8, 57.8, 57.8, 57.8, 57.8, 57.8, 57.8, 57....","[0.43283, 0.43283, 0.43283, 0.43283, 0.43283, ...","[1.61, 1.61, 1.61, 1.61, 1.61, 1.61, 1.61, 1.6...","[89.3147203199472, 29.342061161368797, 15.7675...","[222.09641997716597, 75.76059098993872, 24.516...","[2.77296770148975, 1.913377530475639, 1.555579...","[4.704193115234375, 4.704193115234375, 2.84391...","[0.45406896877695196, 0.45406896877695196, 9.4...",...,"[[tensor(69), tensor(83), tensor(75), tensor(3...","[tensor(1.), tensor(0.), tensor(1.), tensor(0....","[[tensor(0.0767), tensor(0.0324), tensor(0.578...",86,"[[tensor(7.0256), tensor(0.), tensor(12.7346)]...","[[tensor(7.0256), tensor(0.), tensor(12.7346)]...",2472718,"[tensor(0), tensor(1), tensor(0), tensor(1), t...",6.282501,-0.025550
1,1,"[1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.7...","[112.0, 112.0, 112.0, 112.0, 112.0, 112.0, 112...","[0.426, 0.426, 0.426, 0.426, 0.426, 0.426, 0.4...","[1.33, 1.33, 1.33, 1.33, 1.33, 1.33, 1.33, 1.3...","[13.858382002979626, 13.858382236327781, 95.59...","[35.868790697767324, 35.88634058092103, 377.46...","[1.4900744671796733, 1.4900744755429902, 2.836...","[3.6807498931884766, 3.6807498931884766, 3.680...","[0.006754217151977511, 0.006754217151977511, 0...",...,"[[tensor(55), tensor(63), tensor(52), tensor(4...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[[tensor(-0.4060), tensor(-0.1663), tensor(-0....",85,"[[tensor(5.3127), tensor(11.3843), tensor(12.7...","[[tensor(5.3127), tensor(11.3843), tensor(12.7...",1747243,"[tensor(0), tensor(0), tensor(0), tensor(0), t...",5.972082,-1.837069
2,2,"[1.45, 1.45, 1.45, 1.45, 1.45, 1.45, 1.45, 1.4...","[55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55....","[1.3019999999999998, 1.3019999999999998, 1.301...","[1.93, 1.93, 1.93, 1.93, 1.93, 1.93, 1.93, 1.9...","[169.90046242317467, 16.909261919823788, 16.90...","[306.2641627020174, 34.77532226895127, 34.6920...","[3.435858874632096, 1.592252773326388, 1.59225...","[3.995115280151367, 3.995115280151367, 4.51798...","[7.348130132770407, 7.348130132770407, 3.46495...",...,"[[tensor(33), tensor(18), tensor(22), tensor(3...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[[tensor(0.0028), tensor(0.0790), tensor(0.005...",66,"[[tensor(1.9976), tensor(0.), tensor(12.5561)]...","[[tensor(1.9976), tensor(0.), tensor(12.5561)]...",1824821,"[tensor(0), tensor(0), tensor(0), tensor(0), t...",0.686031,-0.152283
3,3,"[1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.7...","[103.0, 103.0, 103.0, 103.0, 103.0, 103.0, 103...","[0.013999999999999999, 0.013999999999999999, 0...","[1.3, 1.3, 1.3, 1.3, 1.3, 1.3, 1.3, 1.3, 1.3, ...","[35.108055970210486, 22.470958553752013, 25.92...","[73.35361598401914, 33.916565450356856, 44.848...","[2.0312937125432513, 1.7505626701983967, 1.835...","[3.111886501312256, 3.111886501312256, 5.22778...","[15.71836643737836, 15.71836643737836, 0.27114...",...,"[[tensor(9), tensor(53), tensor(45), tensor(36...","[tensor(1.), tensor(1.), tensor(0.), tensor(1....","[[tensor(-0.2585), tensor(-0.0456), tensor(0.0...",62,"[[tensor(9.3824), tensor(12.5032), tensor(14.5...","[[tensor(9.3824), tensor(12.5032), tensor(14.5...",1324413,"[tensor(0), tensor(0), tensor(1), tensor(0), t...",2.222443,-6.227071
4,4,"[1.47, 1.47, 1.47, 1.47, 1.47, 1.47, 1.47, 1.4...","[79.0, 79.0, 79.0, 79.0, 79.0, 79.0, 79.0, 79....","[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.5...","[2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, ...","[38.83699650086656, 14.355720997013236, 14.355...","[165.27906821064923, 34.537843861740484, 34.53...","[2.1008046766541075, 1.5076902742588287, 1.507...","[2.7617592811584473, 2.7617592811584473, 2.761...","[11.216767936903567, 11.216767936903567, 11.21...",...,"[[tensor(18), tensor(16), tensor(20), tensor(2...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[[tensor(0.0334), tensor(0.5330), tensor(1.227...",67,"[[tensor(5.5235), tensor(3.1890), tensor(15.79...","[[tensor(5.5235), tensor(3.1890), tensor(15.79...",1219873,"[tensor(0), tensor(0), tensor(0), tensor(0), t...",1.410713,-2.453840
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,"[1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, ...","[100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100...","[0.079, 0.079, 0.079, 0.079, 0.079, 0.079, 0.0...","[1.54, 1.54, 1.54, 1.54, 1.54, 1.54, 1.54, 1.5...","[14.936614852990052, 80.28670000683655, 14.123...","[21.662270008214907, 404.12627076188403, 30.05...","[1.5277578519180595, 2.6761985606158483, 1.499...","[2.687147378921509, 2.687147378921509, 3.10285...","[9.841924288121811, 9.841924288121811, 3.54409...",...,"[[tensor(32), tensor(37), tensor(33), tensor(2...","[tensor(1.), tensor(1.), tensor(0.), tensor(1....","[[tensor(0.0137), tensor(0.0397), tensor(-1.05...",54,"[[tensor(1.5514), tensor(2.1940), tensor(14.26...","[[tensor(1.5514), tensor(2.1940), tensor(14.26...",1278372,"[tensor(0), tensor(0), tensor(1), tensor(0), t...",1.932556,-3.993815
9996,9996,"[1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, 1.6, ...","[100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100...","[0.079, 0.079, 0.079, 0.079, 0.079, 0.079, 0.0...","[1.54, 1.54, 1.54, 1.54, 1.54, 1.54, 1.54, 1.5...","[18.577081819341252, 126.7628524297316, 119.67...","[42.58788956062807, 205.99971783560937, 201.42...","[1.642969980072988, 3.116265593029686, 3.05709...","[2.8137028217315674, 2.8137028217315674, 4.001...","[11.146685090517916, 11.146685090517916, 0.054...",...,"[[tensor(25), tensor(1), tensor(29), tensor(5)...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[[tensor(0.0020), tensor(-0.4878), tensor(-0.3...",62,"[[tensor(0.), tensor(4.0017), tensor(12.5153)]...","[[tensor(0.), tensor(4.0017), tensor(12.5153)]...",992060,"[tensor(0), tensor(0), tensor(0), tensor(0), t...",1.960140,-3.572109
9997,9997,"[1.39, 1.39, 1.39, 1.39, 1.39, 1.39, 1.39, 1.3...","[53.0, 53.0, 53.0, 53.0, 53.0, 53.0, 53.0, 53....","[1.112067, 1.112067, 1.112067, 1.112067, 1.112...","[1.96, 1.96, 1.96, 1.96, 1.96, 1.96, 1.96, 1.9...","[25.954856203575492, 56.779363411229504, 25.95...","[45.61463641664654, 157.54634500490403, 45.672...","[1.8367216249953886, 2.3843351611718333, 1.836...","[4.141856670379639, 4.141856670379639, 2.77277...","[2.673288609842866, 2.673288609842866, 13.9762...",...,"[[tensor(65), tensor(73), tensor(38), tensor(4...","[tensor(1.), tensor(1.), tensor(1.), tensor(0....","[[tensor(-0.2149), tensor(0.3061), tensor(-0.3...",79,"[[tensor(7.4179), tensor(7.6097), tensor(20.61...","[[tensor(7.4179), tensor(7.6097), tensor(20.61...",1314291,"[tensor(0), tensor(0), tensor(0), tensor(1), t...",1.902864,-1.113694
9998,9998,"[1.22, 1.22, 1.22, 1.22, 1.22, 1.22, 1.22, 1.2...","[50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50....","[0.43, 0.43, 0.43, 0.43, 0.43, 0.43, 0.43, 0.4...","[1.81, 1.81, 1.81, 1.81, 1.81, 1.81, 1.81, 1.8...","[16.90193321730596, 16.209973370978155, 26.806...","[33.43475488612941, 24.71501572541707, 106.786...","[1.5920227054652445, 1.5699936554395928, 1.856...","[2.9440391063690186, 2.9440391063690186, 3.947...","[9.090653719734787, 9.090653719734787, 0.08500...",...,"[[tensor(76), tensor(72), tensor(78), tensor(7...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[[tensor(-0.1841), tensor(-0.0496), tensor(-0....",97,"[[tensor(0.2012), tensor(1.3441), tensor(14.43...","[[tensor(0.2012), tensor(1.3441), tensor(14.43...",1972741,"[tensor(0), tensor(0), tensor(0), tensor(0), t...",7.348472,-3.752950


In [14]:
df.columns

Index(['id', 'covalent_radii', 'dipole_polarizability', 'electron_affinity',
       'electronegativity', 'voloroi_volumes', 'voronoi_surface_areas',
       'spherical_domain_radii', 'distances_new', 'contact_solid_angles',
       'edge_index_new', 'atomic_numbers', 'cell', 'cell_offsets', 'distances',
       'edge_index', 'fixed', 'force', 'natoms', 'pos', 'pos_relaxed', 'sid',
       'tags', 'y_init', 'y_relaxed'],
      dtype='object')

In [15]:
#делим на обучующую и валидационную выборки
df_train, df_val = train_test_split(df, test_size=0.15)
df = []

In [16]:
#сбрасываем индексы
df_train = df_train.reset_index()
df_val = df_val.reset_index()

In [17]:
batch_size = 64
num_workers = 0

In [18]:
# features_cols = ['voloroi_volumes', 'voronoi_surface_areas', 'electronegativity', 
#                  'dipole_polarizability', 'edge_index_new', 'distances_new', 'contact_solid_angles']

features_cols = ['atomic_numbers', 'edge_index_new', 'distances_new', 'contact_solid_angles']
target_col = 'y_relaxed'

In [19]:
#инициализируем тренировочный датасети и тренировочный итератор
training_set = Dataset(df_train, features_cols, target_col)
training_generator = DataLoader(training_set, batch_size=batch_size, num_workers=num_workers)

In [20]:
#инициализируем валидационный датасет и валидационный итератор
valid_set = Dataset(df_val, features_cols, target_col)
valid_generator = DataLoader(valid_set, batch_size=batch_size, num_workers=num_workers)

In [21]:
df_train = []
df_val = []

## MODEL

In [22]:
#tensorboard writer
writer = SummaryWriter('/Users/humonen/Documents/our_base_model/tensorboard_logs')

In [23]:
#хочется уметь рисовать граф модели
if False:
    trace_system = []
    writer.add_graph(CGConv, trace_system)

In [24]:
writer.close()

In [25]:
#чтобы тензор по умолчанию заводился на куде
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    print('cuda')

In [26]:
#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)

cpu


In [27]:
#model
model = ConvNN()

#optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.L1Loss()

#переносим на куду если она есть
model = model.to(device)
criterion = criterion.to(device)

## Training

In [28]:
%%time
loss = []
loss_eval = []
epochs = 20
print(f'Start training model {str(model)}')
for i in range(epochs):
    print(i)
    loss.append(train(model, training_generator, optimizer, criterion))
    loss_eval.append(evaluate(model, valid_generator, criterion))

Start training model ConvNN(
  (conv): GConv(
    (lin_phi): Linear(in_features=202, out_features=3, bias=False)
    (lin_gamma): Linear(in_features=103, out_features=2, bias=False)
  )
  (lin): Linear(in_features=2, out_features=1, bias=False)
)
0
49
Loss: 4.720358162510152
99
Loss: 3.495723149993203
1
49
Loss: 1.785052550082304
99
Loss: 1.7416471254946007
2
49
Loss: 1.635594020084459
99
Loss: 1.6261963326521593
3
49
Loss: 1.603816175947384
99
Loss: 1.599812544957556
4
49
Loss: 1.594362769808088
99
Loss: 1.588231946482803
5
49
Loss: 1.590753202535668
99
Loss: 1.5826570481965037
6
49
Loss: 1.583629540034703
99
Loss: 1.5754588324614245
7
49
Loss: 1.5793631028155892
99
Loss: 1.5715223322011003
8
49
Loss: 1.5708349237636643
99
Loss: 1.5635347149588845
9
49
Loss: 1.5653838868043861
99
Loss: 1.5601755886366873
10
49
Loss: 1.5593294878395236
99
Loss: 1.5534798605273468
11
49
Loss: 1.553244218534353
99
Loss: 1.5479283128121886
12
49
Loss: 1.5475720386115872
99
Loss: 1.5406298637390137
13
49
L

In [29]:
loss_eval

[1.782741591334343,
 1.5935758848985035,
 1.5522674173116684,
 1.5353226512670517,
 1.5284293740987778,
 1.52466714878877,
 1.5180447647968929,
 1.5070415536562602,
 1.5027457525332768,
 1.498314768075943,
 1.4961438278357189,
 1.491486723224322,
 1.4905264029900234,
 1.4870769033829372,
 1.486108124256134,
 1.4845530142386754,
 1.4766766677300136,
 1.4762283116579056,
 1.4700354288021724,
 1.4672463486591976]