In [1]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.optim as optim


from datetime import datetime
from torch import nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torch.utils.tensorboard import SummaryWriter

from DataClasses import lmdb_dataset, Dataset
from ModelFunctions import train, evaluate, inference

In [2]:
def my_reshape(tensor):
    """Turn a length-n tensor into an (n, 1) column tensor.

    The row count is taken from ``tensor.shape[0]``; ``reshape`` raises if the
    tensor's total number of elements is not equal to that row count.
    """
    rows = tensor.shape[0]
    return tensor.reshape(rows, 1)

In [3]:
def preprocessing(system, num_tag_classes=3, num_elements=100):
    """Convert one raw system record into a ``torch_geometric`` ``Data`` graph.

    Node features ``x`` (float, one row per atom) are the concatenation of
      * a one-hot encoding of ``system['tags']`` (``num_tag_classes`` columns),
      * a one-hot encoding of ``system['atomic_numbers']`` (``num_elements`` columns),
      * the atom's Voronoi volume (1 column),
    i.e. 3 + 100 + 1 = 104 columns with the default arguments.

    Edge features ``edge_attr`` have two columns: ``distances_new`` and
    ``contact_solid_angles``. ``edge_index`` comes from ``system['edge_index_new']``.
    All tensors are placed on the notebook-level ``device``.

    Args:
        system: mapping of per-system arrays (presumably torch tensors, since
            ``.long()``/``.float()`` are called on them -- TODO confirm).
        num_tag_classes: number of one-hot classes for ``tags``; every tag must be
            smaller than this value.
        num_elements: number of one-hot classes for atomic numbers; every atomic
            number must be smaller than this value.

    Returns:
        ``Data(x=..., edge_index=..., edge_attr=...)``.
    """
    # F.one_hot returns int64. Cast explicitly so that concatenating with the float
    # Voronoi-volume column does not depend on torch.cat's implicit type promotion.
    tags = F.one_hot(system['tags'].long().to(device), num_classes=num_tag_classes).float()
    atom_numbers = F.one_hot(
        system['atomic_numbers'].long().to(device), num_classes=num_elements
    ).float()
    voronoi_volumes = my_reshape(system['voronoi_volumes'].float().to(device))
    # NOTE: 'spherical_domain_radii' is listed in features_cols but is not used as a feature.
    atom_embeds = torch.cat((tags, atom_numbers, voronoi_volumes), dim=1)

    edge_index = system['edge_index_new'].long().to(device)
    distances = my_reshape(system['distances_new'].float().to(device))
    angles = my_reshape(system['contact_solid_angles'].float().to(device))
    edges_embeds = torch.cat((distances, angles), dim=1)

    # Every tensor above is already on `device`, so no further .to(device) is needed.
    return Data(x=atom_embeds, edge_index=edge_index, edge_attr=edges_embeds)

In [4]:
class GaussianSmearing(nn.Module):
    """Expand scalar values onto a fixed grid of Gaussian basis functions.

    The centres are ``num_gaussians`` evenly spaced points in ``[start, stop]``,
    stored as the ``offset`` buffer. Every Gaussian uses the width implied by the
    spacing between neighbouring centres. Input of any shape is flattened, so the
    output has shape ``(input.numel(), num_gaussians)``.
    """

    def __init__(self, start=0.0, stop=8.0, num_gaussians=150):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        # exp(coeff * d**2) with coeff = -1 / (2 * spacing**2)
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer('offset', centers)

    def forward(self, dist):
        deltas = dist.view(-1, 1) - self.offset.view(1, -1)
        return (self.coeff * deltas ** 2).exp()

In [5]:
class ShiftedSoftplus(nn.Module):
    """Softplus activation shifted down by log(2), so that it is zero at x = 0."""

    def __init__(self):
        super().__init__()
        # log(2) evaluated by torch in float32, stored as a plain Python float.
        self.shift = float(torch.tensor(2.0).log())

    def forward(self, x):
        return F.softplus(x).sub(self.shift)

In [6]:
class CFconv(MessagePassing):
    """Continuous-filter convolution over atoms with sum aggregation.

    Each edge's two features, column 0 (distance) and column 1 (contact solid angle),
    are expanded with Gaussian bases and concatenated. They are then fed through a
    two-layer filter network (``blocks``). The resulting filter, projected by
    ``lin_phi``, multiplies the source-node features ``x_j`` element-wise, and these
    messages are summed at each target node.

    Args:
        dim_hidden: width of the node features and of the generated filters.
        dim_edge: accepted for interface compatibility but not used; the filter
            input width comes from the Gaussian expansions.
        num_rbf: number of Gaussians for the distance expansion; defaults to
            ``smearing['rbf']`` from the notebook config.
        num_sa_bins: number of Gaussians for the solid-angle expansion; defaults to
            ``smearing['sa_bins']`` from the notebook config.
    """

    def __init__(self, dim_hidden, dim_edge, num_rbf=None, num_sa_bins=None):
        super(CFconv, self).__init__(aggr='add')
        # Fall back to the notebook-level `smearing` config when sizes are not given.
        if num_rbf is None:
            num_rbf = smearing['rbf']
        if num_sa_bins is None:
            num_sa_bins = smearing['sa_bins']
        self.rbf = GaussianSmearing(num_gaussians=num_rbf)
        # Put contact solid angles into Gaussian bins whose centres span [0, 50].
        self.sa_bins = GaussianSmearing(start=0.0, stop=50.0, num_gaussians=num_sa_bins)
        self.blocks = nn.Sequential(
            nn.Linear(num_rbf + num_sa_bins, dim_hidden, bias=True),
            ShiftedSoftplus(),
            nn.Linear(dim_hidden, dim_hidden, bias=True),
            ShiftedSoftplus(),
        )
        self.lin_phi = torch.nn.Linear(dim_hidden, dim_hidden, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise both filter-network weights and zero both of their biases."""
        torch.nn.init.xavier_uniform_(self.blocks[0].weight)
        torch.nn.init.zeros_(self.blocks[0].bias)
        torch.nn.init.xavier_uniform_(self.blocks[2].weight)
        # The second Linear layer's bias must be reset too, not blocks[0]'s again.
        torch.nn.init.zeros_(self.blocks[2].bias)

    def forward(self, batch):
        """Run one convolution; reads ``x``, ``edge_index`` and ``edge_attr`` from ``batch``."""
        x = batch['x']
        edge_index = batch['edge_index']
        raw_edge_attr = batch['edge_attr']
        rbf_dist = self.rbf(raw_edge_attr[:, 0])
        bins_angles = self.sa_bins(raw_edge_attr[:, 1])
        # Per-edge filter computed from the expanded distance and solid angle.
        edge_attr = self.blocks(torch.cat((rbf_dist, bins_angles), 1))
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

    def message(self, x_j, edge_attr):
        # Only source-node features are needed. Leaving `x`/`x_i` out of the signature
        # keeps PyG from gathering unused target-node features for every edge.
        return x_j * self.lin_phi(edge_attr)

    def update(self, aggr_out):
        # Identity update: the summed messages are the layer output.
        return aggr_out

In [7]:
class Interaction(nn.Module):
    """SchNet-style interaction block with a residual connection.

    x -> Linear -> CFconv -> Linear -> ShiftedSoftplus -> Linear = v; returns x + v.

    NOTE: ``forward`` replaces ``batch['x']`` with the output of the first Linear
    layer, because CFconv reads its node features from ``batch``. Callers are
    expected to re-assign ``batch['x']`` from the return value, as ConvNN does.
    """

    def __init__(self, dim_hidden, dim_edge):
        super().__init__()
        self.atom_wise_64_1 = nn.Linear(dim_hidden, dim_hidden, bias=True)
        self.cfconv = CFconv(dim_hidden, dim_edge)
        self.atom_wise_64_2 = nn.Linear(dim_hidden, dim_hidden, bias=True)
        self.shifted_softplus = ShiftedSoftplus()
        self.atom_wise_64_3 = nn.Linear(dim_hidden, dim_hidden, bias=True)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the three atom-wise Linear weights and zero their biases."""
        torch.nn.init.xavier_uniform_(self.atom_wise_64_1.weight)
        torch.nn.init.zeros_(self.atom_wise_64_1.bias)
        torch.nn.init.xavier_uniform_(self.atom_wise_64_2.weight)
        torch.nn.init.zeros_(self.atom_wise_64_2.bias)
        torch.nn.init.xavier_uniform_(self.atom_wise_64_3.weight)
        torch.nn.init.zeros_(self.atom_wise_64_3.bias)

    def forward(self, batch):
        """Return the updated node features ``x + v`` for the graph(s) in ``batch``."""
        # Keep a plain reference for the skip connection. Neither clone nor detach is
        # needed: the assignment below rebinds batch['x'] instead of modifying this
        # tensor, and detaching would stop gradients flowing through the residual path.
        x_input = batch['x']
        batch['x'] = self.atom_wise_64_1(x_input)
        conved = self.cfconv(batch)
        conved = self.atom_wise_64_2(conved)
        v = self.atom_wise_64_3(self.shifted_softplus(conved))
        return x_input + v

In [8]:
class ConvNN(nn.Module):
    """SchNet-like model that predicts one scalar (energy) per graph.

    Pipeline: Linear atom embedding -> three ``Interaction`` blocks ->
    Linear(dim_hidden, 32) -> ShiftedSoftplus -> Linear(32, 1) per atom ->
    sum of the per-atom values within each graph.

    Args:
        dim_atom: width of the input node features. NOTE: the default of 103 does not
            match the 104 columns produced by ``preprocessing``; the notebook passes
            the width explicitly.
        dim_edge: width of the input edge features (forwarded to the interaction blocks).
        dim_hidden: width of the hidden node representation.
    """

    def __init__(self, dim_atom=103, dim_edge=2, dim_hidden=64):
        super().__init__()
        self.embedding = nn.Linear(dim_atom, dim_hidden)
        # Registered as interaction1, interaction2, interaction3, in that order.
        for idx in range(1, 4):
            setattr(self, f'interaction{idx}', Interaction(dim_hidden, dim_edge))
        self.shifted_softplus = ShiftedSoftplus()
        self.atom_wise_32 = nn.Linear(dim_hidden, 32, bias=True)
        self.atom_wise_1 = nn.Linear(32, 1, bias=True)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embedding and both output-layer weights.

        The biases of the two output layers are zeroed. The embedding bias keeps
        PyTorch's default initialisation, and the interaction blocks initialise
        themselves.
        """
        nn.init.xavier_uniform_(self.embedding.weight)
        for layer in (self.atom_wise_32, self.atom_wise_1):
            nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, batch):
        """Return summed per-atom outputs with one row per graph index in ``batch['batch']``.

        ``batch`` must provide ``x``, ``edge_index``, ``edge_attr`` and ``batch``
        (presumably the graph id of every atom). ``batch['x']`` is replaced with the
        hidden node features as a side effect.
        NOTE(review): the result has shape (num_graphs, 1); confirm that the loss
        target is shaped to match.
        """
        batch['x'] = self.embedding(batch['x'])
        for interaction in (self.interaction1, self.interaction2, self.interaction3):
            batch['x'] = interaction(batch)
        hidden = self.shifted_softplus(self.atom_wise_32(batch['x']))
        per_atom = self.atom_wise_1(hidden)
        # Sum the contributions of the atoms that belong to each graph.
        return scatter(per_atom, batch['batch'], dim=0, reduce='sum')

In [9]:
# --- run configuration -----------------------------------------------------

# DataLoader settings
batch_size = 50
num_workers = 0

# Per-system arrays requested from the dataset. 'spherical_domain_radii' is
# currently not used by `preprocessing` (its use there is commented out).
features_cols = [
    'atomic_numbers',
    'edge_index_new',
    'distances_new',
    'contact_solid_angles',
    'tags',
    'voronoi_volumes',
    'spherical_domain_radii',
]

target_col = 'y_relaxed'

# optimisation
lr = 0.05
epochs = 10

# Number of Gaussian basis functions for distances ('rbf') and solid angles ('sa_bins').
smearing = {'rbf': 150, 'sa_bins': 150}

In [10]:
# When a GPU is present, make newly created float tensors default to CUDA.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print('cuda')

In [11]:
# Device that the model, loss and preprocessed tensors are moved to (GPU when available).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cpu


In [12]:
# Training data: dataset built from the .lmdb file below, plus a batched loader over it.
# NOTE(review): no shuffle=True, so batches come in the same order every epoch. Consider
# shuffling, but check compatibility with the CUDA default tensor type set earlier.
train_dataset_file_path = os.path.expanduser(
    "../../ocp_datasets/data/is2re/10k/train/data_mod.lmdb"
)
training_set = Dataset(
    train_dataset_file_path,
    features_cols,
    target_col,
    preprocessing=preprocessing,
)
training_generator = DataLoader(
    training_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [13]:
# Validation data: dataset built from the .lmdb file below, with the same preprocessing
# and batch settings as training.
val_dataset_file_path = os.path.expanduser(
    "../../ocp_datasets/data/is2re/all/val_ood_both/data_mod.lmdb"
)
valid_set = Dataset(
    val_dataset_file_path,
    features_cols,
    target_col,
    preprocessing=preprocessing,
)
valid_generator = DataLoader(
    valid_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [14]:
# Best-effort summary of the training LMDB. A failure here must not stop the notebook,
# but it should be visible, and Ctrl-C (KeyboardInterrupt) must still interrupt the kernel.
try:
    lmdb_dataset(train_dataset_file_path).describe()
except Exception as exc:
    print(f'lmdb_dataset(...).describe() failed: {exc!r}')

item: 0


In [15]:
# Sanity check: shape of the edge-feature matrix of element 0 of the first dataset item
# (the preprocessed graph); expected (num_edges, 2), i.e. distance and contact solid angle.
training_set[0][0]['edge_attr'].shape

torch.Size([1214, 2])

In [16]:
# Model: input widths are read from one preprocessed training sample, so they always
# match what `preprocessing` produces (3 + 100 + 1 = 104 node features, 2 edge features).
# The sample is fetched once instead of being preprocessed twice.
sample_graph = training_set[0][0]
model = ConvNN(dim_atom=sample_graph['x'].shape[1], dim_edge=sample_graph['edge_attr'].shape[1])

# Move model and loss to the target device (CUDA if available) *before* building the
# optimizer, as the torch.optim documentation recommends, so that the optimizer is
# constructed over the parameters that will actually be trained.
model = model.to(device)
criterion = nn.L1Loss().to(device)

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [17]:
# Run identifier (e.g. 2021-09-15-13-51-48); also used as the TensorBoard run sub-directory.
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
print(timestamp)

2021-09-15-13-51-48


In [18]:
# TensorBoard writer: each run logs into its own `<log root>/<timestamp>` sub-directory.
# NOTE(review): an earlier note said the log folder has to be created by hand on the
# first run; SummaryWriter may create it itself -- confirm before relying on either.

# Alternative log roots used in other environments:
#   server: "../../ocp_results/logs/tensorboard/out_base_model"
#   colab:  "/content/drive/MyDrive/ocp_results/logs/tensorboard/out_base_model"

# User-specific log root
log_file_path = "../logs/tensorboard_airi"

writer = SummaryWriter(f"{log_file_path}/{timestamp}")

In [19]:
%%time
# Hyper-parameters of this run, recorded as text in TensorBoard under the run timestamp.
logfile_str = dict(
    train_dataset_file_path=train_dataset_file_path,
    val_dataset_file_path=val_dataset_file_path,
    features_cols=features_cols,
    target_col=target_col,
    batch_size=batch_size,
    num_workers=num_workers,
    epochs=epochs,
    lr=lr,
    smearing=smearing,
)

# Model graph: trace the model on element 0 of the first item yielded by the training
# loader (presumably the batched graph -- TODO confirm). Iterating the batch object is
# assumed to yield (key, value) pairs; materialising them with list() first keeps dict()
# from using any keys()/mapping protocol the object exposes.
trace_system = dict(list(next(iter(training_generator))[0]))
writer.add_graph(model, trace_system)
writer.add_text(timestamp, str(logfile_str))

CPU times: user 5.47 s, sys: 1.19 s, total: 6.66 s
Wall time: 4.03 s


## Training

In [20]:
%%time
# Per-epoch results returned by train() and evaluate().
loss = []
loss_eval = []

print(timestamp)
print(f'Start training model {model}')
for i in range(epochs):
    # One pass over the training loader, then one pass over the validation loader.
    loss.append(
        train(model, training_generator, optimizer, criterion,
              epoch=i, writer=writer, device=device)
    )
    loss_eval.append(
        evaluate(model, valid_generator, criterion,
                 epoch=i, writer=writer, device=device)
    )

2021-09-15-13-51-48
Start training model ConvNN(
  (embedding): Linear(in_features=104, out_features=64, bias=True)
  (interaction1): Interaction(
    (atom_wise_64_1): Linear(in_features=64, out_features=64, bias=True)
    (cfconv): CFconv(
      (rbf): GaussianSmearing()
      (sa_bins): GaussianSmearing()
      (blocks): Sequential(
        (0): Linear(in_features=300, out_features=64, bias=True)
        (1): ShiftedSoftplus()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): ShiftedSoftplus()
      )
      (lin_phi): Linear(in_features=64, out_features=64, bias=False)
    )
    (atom_wise_64_2): Linear(in_features=64, out_features=64, bias=True)
    (shifted_softplus): ShiftedSoftplus()
    (atom_wise_64_3): Linear(in_features=64, out_features=64, bias=True)
  )
  (interaction2): Interaction(
    (atom_wise_64_1): Linear(in_features=64, out_features=64, bias=True)
    (cfconv): CFconv(
      (rbf): GaussianSmearing()
      (sa_bins): GaussianSmearing()
  

KeyboardInterrupt: 