In [3]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import torch.optim as optim


from datetime import datetime
from torch import nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torch.utils.tensorboard import SummaryWriter

from DataClasses import lmdb_dataset, Dataset
from ModelFunctions import train, evaluate, inference

In [None]:
# Called every time the dataset yields an element (one system).
# Builds the atom feature matrix, the edge list (edge_index) and the edge feature
# matrix for that system; each network may need its own version of this function.
def preprocessing(
    system,
    atom_keys=("atomic_numbers", "tags", "voronoi_volumes", "spherical_domain_radii"),
    edge_keys=("distances_new", "contact_solid_angles"),
):
    """Convert one raw system into a ``torch_geometric`` ``Data`` object.

    The previous version concatenated ``atom_features`` / ``edge_features``,
    which were never defined (NameError on every call); the feature lists are
    now gathered from ``system`` using ``atom_keys`` and ``edge_keys``.

    Args:
        system: mapping from feature name to tensor-like value for one system.
        atom_keys: keys stacked column-wise into ``x``.
            NOTE(review): which keys are per-atom is inferred from the names in
            ``features_cols`` -- confirm against the dataset.
        edge_keys: keys stacked column-wise into ``edge_attr`` (same caveat).

    Returns:
        ``Data(x, edge_index, edge_attr)`` with every tensor moved to the
        module-level ``device``.
    """
    def _stack_columns(keys):
        # Cast to float and give each feature a column axis so 1-D and 2-D
        # features can be concatenated side by side.
        columns = []
        for key in keys:
            feature = torch.as_tensor(system[key]).float()
            columns.append(feature.reshape(feature.shape[0], -1))
        return torch.cat(columns, dim=1)

    atom_embeds = _stack_columns(atom_keys)

    edge_index = torch.as_tensor(system['edge_index_new']).long()

    edges_embeds = _stack_columns(edge_keys)

    return Data(x=atom_embeds.to(device), edge_index=edge_index.to(device), edge_attr=edges_embeds.to(device))

$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right)
$$

$\gamma$ реализуется в методе `update`, $\square$ задаётся параметром `aggr`, а $\phi$ реализуется в методе `message`; в этом примере $\square$ — суммирование (`aggr='add'`).

In [6]:
class DistanceBlock(torch.nn.Module):
    """Project per-edge distance features to ``out_channels`` with one linear layer.

    Args:
        in_channels: size of the last dimension of the (optionally expanded)
            distance features fed to ``fc1``.
        out_channels: output feature size.
        max_num_elements: stored on the module; not used in the computation.
        scalar_max: stored on the module; not used in the computation.
        distance_expansion: optional callable applied to the raw distances
            before ``fc1`` (e.g. a basis expansion). ``None`` uses them as-is.
        scale_distances: stored only, so the 6-argument constructor form used
            elsewhere in the notebook is accepted; not used yet.
    """

    def __init__(self, in_channels, out_channels, max_num_elements, scalar_max,
                 distance_expansion=None, scale_distances=False):
        super().__init__()  # was `super.__init__()`, which raises TypeError
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_num_elements = max_num_elements
        self.scalar_max = scalar_max
        self.distance_expansion = distance_expansion
        self.scale_distances = scale_distances

        self.fc1 = nn.Linear(self.in_channels, self.out_channels)

    def forward(self, lmdb_element, source_element=None, target_element=None):
        """Return ``fc1`` applied to the element's distances.

        Args:
            lmdb_element: either a mapping holding a ``"distances_new"`` entry
                or the distance tensor itself.
            source_element, target_element: accepted for call compatibility;
                unused.
        """
        # Read from the argument -- the old code indexed the `lmdb_dataset` class.
        if torch.is_tensor(lmdb_element):
            distances = lmdb_element
        else:
            distances = lmdb_element["distances_new"]
        if self.distance_expansion is not None:
            distances = self.distance_expansion(distances)
        # nn.Linear requires the input dtype to match its weights.
        distances = distances.to(self.fc1.weight.dtype)
        return self.fc1(distances)



In [7]:
class ProjectLatLongSphere(torch.nn.Module):
    """Bilinearly splat feature rows onto one latitude/longitude grid per output item.

    Args:
        sphere_size_lat: number of latitude bins.
        sphere_size_long: number of longitude bins.
    """

    def __init__(self, sphere_size_lat, sphere_size_long):
        super().__init__()
        self.sphere_size_lat = sphere_size_lat
        self.sphere_size_long = sphere_size_long

    def forward(self, x, length, index, delta, source_edge_index):
        """Accumulate weighted rows of ``x`` into ``length`` sphere grids.

        Args:
            x: 2-D feature tensor; its second dimension is the channel count.
            length: number of output grids.
            index: four rows of flat cell indices into a
                ``length * lat * long`` buffer, one row per bilinear corner.
            delta: four rows of bilinear weights matching ``index``.
            source_edge_index: rows of ``x`` contributing to each column of
                ``index`` / ``delta``.

        Returns:
            Tensor of shape ``(length, channels, lat, long)``.
        """
        channels = len(x[0])
        cells = self.sphere_size_lat * self.sphere_size_long
        grid = torch.zeros(length * cells, channels, device=x.device)
        values = x[source_edge_index]

        # Bilinear splatting: add each of the four corners in turn.
        for corner in range(4):
            grid.index_add_(0, index[corner], values * delta[corner].view(-1, 1))

        # (length * cells, C) -> (length, C, cells) -> (length, C, lat, long)
        grid = grid.view(length, cells, channels).transpose(1, 2).contiguous()
        return grid.view(length, channels, self.sphere_size_lat, self.sphere_size_long)

In [8]:
class SpinConvBlock(torch.nn.Module):
    """Spherical convolution over messages ("fullconv" variant).

    Each message is bilinearly splatted onto a ``sphere_size_lat x
    sphere_size_long`` grid by ``ProjectLatLongSphere``. The grid is then
    convolved along longitude with circular padding, average-pooled over
    longitude, and group-normalised.

    Args:
        in_hidden_channels: feature size of the incoming messages; also used to
            derive the number of groups (``in_hidden_channels // 8``).
        mid_hidden_channels: output feature size.
        sphere_size_lat, sphere_size_long: grid resolution.
        act: activation module/callable applied after the convolution.
        lmax: stored only; not used by the "fullconv" path.
        sphere_message: sphere representation; only ``"fullconv"`` is
            implemented.

    Raises:
        ValueError: if ``sphere_message`` is not ``"fullconv"``.
    """

    def __init__(
        self,
        in_hidden_channels,
        mid_hidden_channels,
        sphere_size_lat,
        sphere_size_long,
        act,
        lmax,
        sphere_message="fullconv",
    ):
        super().__init__()  # was `super(self)`, which raises TypeError
        self.in_hidden_channels = in_hidden_channels
        self.mid_hidden_channels = mid_hidden_channels
        self.sphere_size_lat = sphere_size_lat
        self.sphere_size_long = sphere_size_long
        # forward() used to read self.sphere_message without it ever being set.
        self.sphere_message = sphere_message
        self.act = act
        self.lmax = lmax
        self.num_groups = self.in_hidden_channels // 8

        if self.sphere_message != "fullconv":
            raise ValueError(
                f"Unsupported sphere_message {sphere_message!r}; only 'fullconv' is implemented"
            )

        self.ProjectLatLongSphere = ProjectLatLongSphere(
            sphere_size_lat, sphere_size_long
        )

        padding = self.sphere_size_long // 2
        self.conv1 = nn.Conv1d(
            self.in_hidden_channels * self.sphere_size_lat,
            self.mid_hidden_channels,
            self.sphere_size_long,
            groups=self.num_groups,
            padding=padding,
            padding_mode="circular",
        )
        self.pool = nn.AvgPool1d(sphere_size_long)

        self.GroupNorm = nn.GroupNorm(
            self.num_groups, self.mid_hidden_channels
        )

    def forward(self, x, out_size, proj_index, proj_delta, proj_src_index):
        """Return ``(out_size, mid_hidden_channels)`` features after the sphere convolution."""
        x = self.ProjectLatLongSphere(
            x, out_size, proj_index, proj_delta, proj_src_index
        )

        # Fold latitude into the channel axis so the 1-D conv runs along longitude.
        x = x.view(
            -1,
            self.in_hidden_channels * self.sphere_size_lat,
            self.sphere_size_long,
        )
        x = self.conv1(x)
        x = self.act(x)
        # With an even sphere_size_long the circular padding yields one extra
        # position; keep only the first sphere_size_long before pooling.
        x = self.pool(x[:, :, 0 : self.sphere_size_long])
        x = x.view(out_size, -1)

        x = self.GroupNorm(x)

        return x


In [None]:
class EmbeddingBlock(torch.nn.Module):  # TODO: double-check all variable names
    """Mix features with a softmax-weighted basis selected by the element pair.

    The source/target element ids are embedded, concatenated, and mapped to
    ``embedded_hidden`` softmax weights. The features ``x`` go through
    ``fc1 -> act -> fc2 -> act``, which yields ``embedded_hidden`` candidate
    vectors of size ``fc2_embedded_hidden``. These vectors are summed with the
    softmax weights and projected by ``fc3``.

    Args:
        embedding_dim: size of each element embedding.
        embedded_hidden: number of basis vectors (softmax weights).
        element_number: number of element ids held by the embedding table.
        fc1_input: feature size of ``x``.
        fc1_to_fc2_hidden: hidden size between ``fc1`` and ``fc2``.
        fc2_embedded_hidden: size of each basis vector.
        fc3_out: output feature size.
        activation: activation module/callable.
    """

    def __init__(self, embedding_dim, embedded_hidden, element_number, fc1_input,  fc1_to_fc2_hidden, fc2_embedded_hidden, fc3_out, activation):
        super().__init__()
        self.max_number_of_elements = element_number

        # Attribute name `embeding` kept so existing state_dict keys still match.
        self.embeding = nn.Embedding(num_embeddings=self.max_number_of_elements, embedding_dim=embedding_dim)
        # Previously initialised non-existent `source_embed` / `target_embed` (AttributeError).
        nn.init.uniform_(self.embeding.weight.data, -0.0001, 0.0001)

        self.embedding_dim = embedding_dim
        self.embedding_hidden = embedded_hidden
        self.fc_embedding = nn.Linear(2 * self.embedding_dim, embedded_hidden)

        self.fc2_embedded_hidden = fc2_embedded_hidden

        self.fc1 = nn.Linear(fc1_input, fc1_to_fc2_hidden)
        # One block of fc2_embedded_hidden features per basis vector. The old
        # size (fc2_embedded_hidden * embedding_dim) did not match the view in forward.
        self.fc2 = nn.Linear(fc1_to_fc2_hidden, fc2_embedded_hidden * self.embedding_hidden)

        self.fc3 = nn.Linear(fc2_embedded_hidden, fc3_out)

        self.softmax = nn.Softmax(dim=1)

        self.activation = activation

    def forward(self, batch, source_element=None, target_element=None):
        """Return mixed features of size ``fc3_out``, one row per row of the input features.

        Two call forms are supported:

        * ``forward(x, source_element, target_element)``: ``x`` is a float
          feature tensor and the element tensors hold one id per row of ``x``.
        * ``forward(batch)``: ``batch`` maps ``'x'`` to per-atom element ids,
          ``'edge_index'`` to a ``[2, E]`` edge list and ``'edge_attr'`` to
          per-edge features. NOTE(review): assumes ``batch['x']`` is a 1-D
          tensor of element ids -- confirm against the preprocessing.
        """
        if source_element is None or target_element is None:
            atom_ids = torch.as_tensor(batch['x']).long()
            edge_index = torch.as_tensor(batch['edge_index']).long()
            # Features must be floating point for fc1 (they used to be cast to long).
            x = torch.as_tensor(batch['edge_attr']).to(self.fc1.weight.dtype)
            source_element = atom_ids[edge_index[0]]
            target_element = atom_ids[edge_index[1]]
        else:
            x = batch

        source_embedding = self.embeding(source_element)
        target_embedding = self.embeding(target_element)
        embedding = torch.cat([source_embedding, target_embedding], dim=1)
        embedding = self.fc_embedding(embedding)
        embedding = self.softmax(embedding)

        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        x = self.activation(x)

        # (rows, basis, features) weighted by (rows, basis, 1), then summed over the basis.
        x = (
            x.view(-1, self.embedding_hidden, self.fc2_embedded_hidden)
        ) * (embedding.view(-1, self.embedding_hidden, 1))

        x = torch.sum(x, dim=1)
        x = self.fc3(x)
        return x

In [None]:
class MessageBlock(torch.nn.Module):
    """One interaction step that updates messages.

    The messages are projected onto a lat/long sphere and convolved
    (``SpinConvBlock``), then mixed with element-pair embeddings
    (``embeddingblock1``). The result is added to a two-layer transform of the
    distance features ``x_dist``, activated, and passed through a second
    embedding block.

    Args:
        in_hidden_channels: feature size of the incoming messages ``x``.
        out_hidden_channels: feature size of the returned messages.
        mid_hidden_channels: internal feature size; also the size of ``x_dist``.
        embedding_size: size of each element embedding.
        sphere_size_lat, sphere_size_long: sphere grid resolution.
        max_num_elements: number of element ids held by the embedding tables.
        act: activation module/callable.
        lmax: forwarded to ``SpinConvBlock``.
    """

    def __init__(self, in_hidden_channels, out_hidden_channels, mid_hidden_channels,
                 embedding_size, sphere_size_lat, sphere_size_long, max_num_elements, act, lmax):
        super().__init__()
        self.in_hidden_channels = in_hidden_channels
        self.out_hidden_channels = out_hidden_channels
        self.mid_hidden_channels = mid_hidden_channels
        self.act = act
        self.lmax = lmax
        self.embedding_size = embedding_size
        self.sphere_size_lat = sphere_size_lat
        self.sphere_size_long = sphere_size_long
        self.max_num_elements = max_num_elements
        self.num_embedding_basis = 8

        self.spinconvblock = SpinConvBlock(
            self.in_hidden_channels, self.mid_hidden_channels,
            self.sphere_size_lat, self.sphere_size_long, self.act,
            self.lmax
        )

        # element_number is a required EmbeddingBlock argument and was missing here.
        self.embeddingblock1 = EmbeddingBlock(
            embedding_dim=self.embedding_size,
            embedded_hidden=self.num_embedding_basis,
            element_number=self.max_num_elements,
            fc1_input=self.mid_hidden_channels,
            fc1_to_fc2_hidden=self.mid_hidden_channels,
            fc2_embedded_hidden=self.mid_hidden_channels,
            fc3_out=self.mid_hidden_channels,
            activation=self.act,
        )

        self.embeddingblock2 = EmbeddingBlock(
            embedding_dim=self.embedding_size,
            embedded_hidden=self.num_embedding_basis,
            element_number=self.max_num_elements,
            fc1_input=self.mid_hidden_channels,
            fc1_to_fc2_hidden=self.mid_hidden_channels,
            fc2_embedded_hidden=self.mid_hidden_channels,
            fc3_out=self.out_hidden_channels,  # comma was missing (SyntaxError)
            activation=self.act,
        )

        self.distfc1 = nn.Linear(self.mid_hidden_channels, self.mid_hidden_channels)
        self.distfc2 = nn.Linear(self.mid_hidden_channels, self.mid_hidden_channels)

    def forward(self, x, x_dist, source_element, target_element, proj_index, proj_delta, proj_src_index):
        """Return updated messages of size ``out_hidden_channels``.

        Args:
            x: incoming messages, one row per message (presumably one per edge -- TODO confirm).
            x_dist: distance features with ``mid_hidden_channels`` columns; after
                two linear layers they are added element-wise to the mixed messages.
            source_element, target_element: element ids used by the embedding blocks.
            proj_index, proj_delta, proj_src_index: sphere-projection indices and
                bilinear weights, passed through to ``SpinConvBlock``.
        """
        out_size = len(x)  # was over-indented (IndentationError)

        x = self.spinconvblock(
            x, out_size, proj_index, proj_delta, proj_src_index
        )

        x = self.embeddingblock1(x, source_element, target_element)

        x_dist = self.distfc1(x_dist)
        x_dist = self.act(x_dist)
        x_dist = self.distfc2(x_dist)
        x = x + x_dist

        x = self.act(x)
        x = self.embeddingblock2(x, source_element, target_element)

        return x


In [None]:
@registry.register_model("spinconv")
class spinconv(BaseModel):
    """SpinConv-style model ported into this notebook (work in progress).

    Per-edge messages are initialised from distance features and element
    embeddings. They are refined by ``MessageBlock`` interactions that project
    neighbouring edges onto a lat/long sphere, then read out as one energy per
    system.

    NOTE(review): this class cannot run as written. These names are not
    defined or imported in the visible notebook cells:
    ``registry``, ``BaseModel``, ``conditional_grad``, ``time``, ``Swish``,
    ``GaussianSmearing``, ``ModuleList``, ``ForceOutputBlock``, ``PI`` and
    ``math``. ``forward`` uses ``edge_index``, ``edge_distance`` and
    ``edge_distance_vec`` without computing them. ``_forward_helper``
    returns ``forces``, which is never assigned.
    """
    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,
        use_pbc=True,
        regress_forces=True,
        otf_graph=False,
        hidden_channels=32,
        mid_hidden_channels=200,
        num_interactions=1,
        num_basis_functions=200,
        basis_width_scalar=1.0,
        max_num_neighbors=20,
        sphere_size_lat=15,
        sphere_size_long=9,
        cutoff=10.0,
        distance_block_scalar_max=2.0,
        max_num_elements=90,
        embedding_size=32,
        show_timing_info=False,
        sphere_message="fullconv",  # message block sphere representation
        output_message="fullconv",  # output block sphere representation
        lmax=False,
        force_estimator="random",
        model_ref_number=0,
        readout="add",
        num_rand_rotations=5,
        scale_distances=True,
    ):
        """Store hyper-parameters and build the sub-blocks used by the forward pass."""
        super(spinconv, self).__init__()

        self.num_targets = num_targets
        self.num_random_rotations = num_rand_rotations
        self.regress_forces = regress_forces
        self.use_pbc = use_pbc
        self.cutoff = cutoff
        self.otf_graph = otf_graph
        self.show_timing_info = show_timing_info
        self.max_num_elements = max_num_elements
        self.mid_hidden_channels = mid_hidden_channels
        self.sphere_size_lat = sphere_size_lat
        self.sphere_size_long = sphere_size_long
        self.num_atoms = 0
        self.hidden_channels = hidden_channels
        self.embedding_size = embedding_size
        self.max_num_neighbors = max_num_neighbors
        self.sphere_message = sphere_message
        self.output_message = output_message
        self.force_estimator = force_estimator
        self.num_basis_functions = num_basis_functions
        self.distance_block_scalar_max = distance_block_scalar_max
        self.grad_forces = False
        self.num_embedding_basis = 8
        self.lmax = lmax
        self.scale_distances = scale_distances
        self.basis_width_scalar = basis_width_scalar

        # Spherical-harmonic message types require lmax to be set.
        if self.sphere_message in ["spharm", "rotspharmroll", "rotspharmwd"]:
            assert self.lmax, "lmax must be defined for spherical harmonics"
        if self.output_message in ["spharm", "rotspharmroll", "rotspharmwd"]:
            assert self.lmax, "lmax must be defined for spherical harmonics"

        # variables used for display purposes
        # NOTE(review): `time` is not imported in this notebook.
        self.counter = 0
        self.start_time = time.time()
        self.total_time = 0
        self.model_ref_number = model_ref_number

        # Gradient-based force estimation also switches _project2D_init to
        # building its bilinear weights under torch.no_grad().
        if self.force_estimator == "grad":
            self.grad_forces = True

        # NOTE(review): Swish, GaussianSmearing, ModuleList and ForceOutputBlock
        # used below are not defined in this notebook.
        # self.act = ShiftedSoftplus()
        self.act = Swish()

        self.distance_expansion_forces = GaussianSmearing(
            0.0,
            cutoff,
            num_basis_functions,
            basis_width_scalar,
        )

        # Weights for message initialization
        # NOTE(review): 7 positional arguments, but EmbeddingBlock.__init__ in this
        # notebook takes 8 (embedding_dim, embedded_hidden, element_number,
        # fc1_input, fc1_to_fc2_hidden, fc2_embedded_hidden, fc3_out, activation);
        # the argument order here does not match that signature either.
        self.embeddingblock2 = EmbeddingBlock(
            self.mid_hidden_channels,
            self.hidden_channels,
            self.mid_hidden_channels,
            self.embedding_size,
            self.num_embedding_basis,
            self.max_num_elements,
            self.act,
        )
        self.distfc1 = nn.Linear(
            self.mid_hidden_channels, self.mid_hidden_channels
        )
        self.distfc2 = nn.Linear(
            self.mid_hidden_channels, self.mid_hidden_channels
        )

        # NOTE(review): confirm this 6-argument call matches DistanceBlock's
        # constructor signature.
        self.dist_block = DistanceBlock(
            self.num_basis_functions,
            self.mid_hidden_channels,
            self.max_num_elements,
            self.distance_block_scalar_max,
            self.distance_expansion_forces,
            self.scale_distances,
        )

        # NOTE(review): 10 positional arguments, but MessageBlock.__init__ takes 9
        # and has no sphere_message parameter, so this call raises TypeError.
        self.message_blocks = ModuleList()
        for _ in range(num_interactions):
            block = MessageBlock(
                hidden_channels,
                hidden_channels,
                mid_hidden_channels,
                embedding_size,
                self.sphere_size_lat,
                self.sphere_size_long,
                self.max_num_elements,
                self.sphere_message,
                self.act,
                self.lmax,
            )
            self.message_blocks.append(block)

        # NOTE(review): same 7-vs-8 argument mismatch with EmbeddingBlock as above.
        self.energyembeddingblock = EmbeddingBlock(
            hidden_channels,
            1,
            mid_hidden_channels,
            embedding_size,
            8,
            self.max_num_elements,
            self.act,
        )

        if force_estimator == "random":
            self.force_output_block = ForceOutputBlock(
                hidden_channels,
                2,
                mid_hidden_channels,
                embedding_size,
                self.sphere_size_lat,
                self.sphere_size_long,
                self.max_num_elements,
                self.output_message,
                self.act,
                self.lmax,
            )

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        """Run the model on a batch and optionally print CUDA memory statistics."""
        self.device = data.pos.device
        # Total atom count of the batch; read later by _project2D_edges_init.
        self.num_atoms = len(data.batch)
        self.batch_size = len(data.natoms)
        # NOTE(review): edge_index, edge_distance and edge_distance_vec are not
        # defined in this method (the graph-construction step is missing), so
        # this call raises NameError.
        outputs = self._forward_helper(
            data, edge_index, edge_distance, edge_distance_vec
        )
        if self.show_timing_info is True:
            torch.cuda.synchronize()
            print(
                "Memory: {}\t{}\t{}".format(
                    len(edge_index[0]),
                    torch.cuda.memory_allocated()
                    / (1000 * len(edge_index[0])),
                    torch.cuda.max_memory_allocated() / 1000000,
                )
            )

        return outputs

    # restructure forward helper for conditional grad
    def _forward_helper(
        self, data, edge_index, edge_distance, edge_distance_vec
    ):
        """Compute per-system energies (and forces) from a precomputed edge graph.

        Args:
            data: batch providing ``atomic_numbers``, ``batch``, ``num_nodes`` and ``pos``.
            edge_index: two-row edge list (row 0 = source atom, row 1 = target atom).
            edge_distance: per-edge distances passed to ``dist_block``.
            edge_distance_vec: per-edge displacement vectors (3 components).

        Returns:
            ``energy`` if ``regress_forces`` is false, else ``(energy, forces)``.
            NOTE(review): ``forces`` is never computed here.
        """
        ###############################################################
        # Initialize messages
        ###############################################################
        # TODO: replace both the source and target embeddings with our own
        source_element = data.atomic_numbers[edge_index[0, :]].long()
        target_element = data.atomic_numbers[edge_index[1, :]].long()

        # NOTE(review): confirm dist_block and embeddingblock2 accept
        # (features, source_element, target_element).
        x_dist = self.dist_block(edge_distance, source_element, target_element)

        x = x_dist
        x = self.distfc1(x)
        x = self.act(x)
        x = self.distfc2(x)
        x = self.act(x)
        x = self.embeddingblock2(x, source_element, target_element)

        ###############################################################
        # Update messages using block interactions
        ###############################################################

        edge_rot_mat = self._init_edge_rot_mat(
            data, edge_index, edge_distance_vec
        )
        (
            proj_edges_index,
            proj_edges_delta,
            proj_edges_src_index,
        ) = self._project2D_edges_init(
            edge_rot_mat, edge_index, edge_distance_vec
        )

        # The first block replaces the initial messages; later blocks are residual.
        for block_index, interaction in enumerate(self.message_blocks):
            x_out = interaction(
                x,
                x_dist,
                source_element,
                target_element,
                proj_edges_index,
                proj_edges_delta,
                proj_edges_src_index,
            )

            if block_index > 0:
                x = x + x_out
            else:
                x = x_out

        ###############################################################
        # Decoder
        # Compute the forces and energies from the messages
        ###############################################################
        assert self.force_estimator in ["random", "grad"]

        # Sum edge messages onto their target atoms, scaled by a fixed factor
        # derived from max_num_neighbors.
        energy = scatter(x, edge_index[1], dim=0, dim_size=data.num_nodes) / (
            self.max_num_neighbors / 2.0 + 1.0
        )
        atomic_numbers = data.atomic_numbers.long()
        energy = self.energyembeddingblock(
            energy, atomic_numbers, atomic_numbers
        )
        # Per-atom energies summed into one value per system in the batch.
        energy = scatter(energy, data.batch, dim=0)

        if not self.regress_forces:
            return energy
        else:
            # NOTE(review): `forces` is never assigned; this branch raises NameError.
            return energy, forces

    def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
        """Build one 3x3 rotation matrix per edge.

        Row 0 is the unit edge vector. Row 2 is the normalised cross product of
        that vector with the cosine-cutoff-weighted sum of edge vectors arriving
        at the edge's target atom (offset by 0.0001). Row 1 is the normalised
        cross product of rows 0 and 2. Near-zero vectors only print
        diagnostics; nothing is raised.

        NOTE(review): `PI` is not defined in this notebook.
        """
        device = data.pos.device
        num_atoms = len(data.batch)

        edge_vec_0 = edge_distance_vec
        edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0 ** 2, dim=1))

        if torch.min(edge_vec_0_distance) < 0.0001:
            print(
                "Error edge_vec_0_distance: {}".format(
                    torch.min(edge_vec_0_distance)
                )
            )
            (minval, minidx) = torch.min(edge_vec_0_distance, 0)
            print(
                "Error edge_vec_0_distance: {} {} {} {} {}".format(
                    minidx,
                    edge_index[0, minidx],
                    edge_index[1, minidx],
                    data.pos[edge_index[0, minidx]],
                    data.pos[edge_index[1, minidx]],
                )
            )

        # Weighted sum (despite the name) of incoming edge vectors per target atom.
        avg_vector = torch.zeros(num_atoms, 3, device=device)
        weight = 0.5 * (
            torch.cos(edge_vec_0_distance * PI / self.cutoff) + 1.0
        )
        avg_vector.index_add_(
            0, edge_index[1, :], edge_vec_0 * weight.view(-1, 1).expand(-1, 3)
        )

        edge_vec_2 = avg_vector[edge_index[1, :]] + 0.0001
        edge_vec_2_distance = torch.sqrt(torch.sum(edge_vec_2 ** 2, dim=1))

        if torch.min(edge_vec_2_distance) < 0.000001:
            print(
                "Error edge_vec_2_distance: {}".format(
                    torch.min(edge_vec_2_distance)
                )
            )

        norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
        norm_0_2 = edge_vec_2 / (edge_vec_2_distance.view(-1, 1))
        norm_z = torch.cross(norm_x, norm_0_2, dim=1)
        norm_z = norm_z / (
            torch.sqrt(torch.sum(norm_z ** 2, dim=1, keepdim=True)) + 0.0000001
        )
        norm_y = torch.cross(norm_x, norm_z, dim=1)
        norm_y = norm_y / (
            torch.sqrt(torch.sum(norm_y ** 2, dim=1, keepdim=True)) + 0.0000001
        )

        norm_x = norm_x.view(-1, 3, 1)
        norm_y = norm_y.view(-1, 3, 1)
        norm_z = norm_z.view(-1, 3, 1)

        # Columns x, y, z; the transpose turns them into rows.
        edge_rot_mat_inv = torch.cat([norm_x, norm_y, norm_z], dim=2)
        edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)

        return edge_rot_mat

    def _project2D_edges_init(self, rot_mat, edge_index, edge_distance_vec):
        """Pair each edge with the edges that end at its source atom, then project.

        For every edge t, the candidate edges s are those whose target atom
        equals t's source atom. They are looked up through a
        ``[num_atoms, max_neighbors]`` table padded with -1. Padding entries are
        dropped, and the (s, t) pairs are passed to ``_project2D_init``.

        Relies on ``self.num_atoms`` having been set by ``forward``.
        """
        # Side effect: changes torch print options globally.
        torch.set_printoptions(sci_mode=False)
        length = len(edge_distance_vec)
        device = edge_distance_vec.device

        # Assuming the edges are consecutive based on the target index
        target_node_index, neigh_count = torch.unique_consecutive(
            edge_index[1], return_counts=True
        )
        max_neighbors = torch.max(neigh_count)
        target_neigh_count = torch.zeros(self.num_atoms, device=device).long()
        target_neigh_count.index_copy_(
            0, target_node_index.long(), neigh_count
        )

        # Position of each edge within its target atom's group of edges.
        index_offset = (
            torch.cumsum(target_neigh_count, dim=0) - target_neigh_count
        )
        neigh_index = torch.arange(length, device=device)
        neigh_index = neigh_index - index_offset[edge_index[1]]

        edge_map_index = edge_index[1] * max_neighbors + neigh_index
        target_lookup = (
            torch.zeros(self.num_atoms * max_neighbors, device=device) - 1
        ).long()
        target_lookup.index_copy_(
            0,
            edge_map_index.long(),
            torch.arange(length, device=device).long(),
        )
        target_lookup = target_lookup.view(self.num_atoms, max_neighbors)

        # target_lookup - For each target node, a list of edge indices
        # target_neigh_count - number of neighbors for each target node
        source_edge = target_lookup[edge_index[0]]
        target_edge = (
            torch.arange(length, device=device)
            .long()
            .view(-1, 1)
            .repeat(1, max_neighbors)
        )

        source_edge = source_edge.view(-1)
        target_edge = target_edge.view(-1)

        # Drop the -1 padding entries.
        mask_unused = source_edge.ge(0)
        source_edge = torch.masked_select(source_edge, mask_unused)
        target_edge = torch.masked_select(target_edge, mask_unused)

        return self._project2D_init(
            source_edge, target_edge, rot_mat, edge_distance_vec
        )

    def _project2D_init(
        self, source_edge, target_edge, rot_mat, edge_distance_vec
        ):
        """Compute bilinear splat indices and weights for (source_edge, target_edge) pairs.

        The unit direction of each source edge is rotated into the target
        edge's frame (``rot_mat[target_edge]``). Longitude is taken from atan2
        of the rotated y/z components and latitude from asin of the x
        component. Both are scaled to the sphere grid and split into a floor
        index plus a fractional part.

        Returns:
            ``(index, delta, source_edge)``. ``index`` and ``delta`` are stacked
            ``[4, M]`` tensors for the corners (Y0,X0), (Y0,X1), (Y1,X0),
            (Y1,X1). The flat index is
            ``target_edge * lat * long + y * long + x``.

        NOTE(review): `math` is not imported in this notebook.
        """
        edge_distance_norm = F.normalize(edge_distance_vec)
        source_edge_offset = edge_distance_norm[source_edge]

        source_edge_offset_rot = torch.bmm(
            rot_mat[target_edge], source_edge_offset.view(-1, 3, 1)
        )

        source_edge_X = torch.atan2(
            source_edge_offset_rot[:, 1], source_edge_offset_rot[:, 2]
        ).view(-1)

        # source_edge_X ranges from -pi to pi
        source_edge_X = (source_edge_X + math.pi) / (2.0 * math.pi)

        # source_edge_Y ranges from -1 to 1
        source_edge_Y = source_edge_offset_rot[:, 0].view(-1)
        source_edge_Y = torch.clamp(source_edge_Y, min=-1.0, max=1.0)
        source_edge_Y = (source_edge_Y.asin() + (math.pi / 2.0)) / (
            math.pi
        )  # bin by angle
        # source_edge_Y = (source_edge_Y + 1.0) / 2.0 # bin by sin
        # Squeeze latitude from [0, 1] into [0.005, 0.995].
        source_edge_Y = 0.99 * (source_edge_Y) + 0.005

        source_edge_X = source_edge_X * self.sphere_size_long
        source_edge_Y = source_edge_Y * (
            self.sphere_size_lat - 1.0
        )  # not circular so pad by one

        source_edge_X_0 = torch.floor(source_edge_X).long()
        source_edge_X_del = source_edge_X - source_edge_X_0
        source_edge_X_0 = source_edge_X_0 % self.sphere_size_long
        source_edge_X_1 = (source_edge_X_0 + 1) % self.sphere_size_long

        source_edge_Y_0 = torch.floor(source_edge_Y).long()
        source_edge_Y_del = source_edge_Y - source_edge_Y_0
        source_edge_Y_0 = source_edge_Y_0 % self.sphere_size_lat
        source_edge_Y_1 = (source_edge_Y_0 + 1) % self.sphere_size_lat

        # Compute the values needed to bilinearly splat the values onto the spheres
        index_0_0 = (
            target_edge * self.sphere_size_lat * self.sphere_size_long
            + source_edge_Y_0 * self.sphere_size_long
            + source_edge_X_0
        )
        index_0_1 = (
            target_edge * self.sphere_size_lat * self.sphere_size_long
            + source_edge_Y_0 * self.sphere_size_long
            + source_edge_X_1
        )
        index_1_0 = (
            target_edge * self.sphere_size_lat * self.sphere_size_long
            + source_edge_Y_1 * self.sphere_size_long
            + source_edge_X_0
        )
        index_1_1 = (
            target_edge * self.sphere_size_lat * self.sphere_size_long
            + source_edge_Y_1 * self.sphere_size_long
            + source_edge_X_1
        )

        delta_0_0 = (1.0 - source_edge_X_del) * (1.0 - source_edge_Y_del)
        delta_0_1 = (source_edge_X_del) * (1.0 - source_edge_Y_del)
        delta_1_0 = (1.0 - source_edge_X_del) * (source_edge_Y_del)
        delta_1_1 = (source_edge_X_del) * (source_edge_Y_del)

        index_0_0 = index_0_0.view(1, -1)
        index_0_1 = index_0_1.view(1, -1)
        index_1_0 = index_1_0.view(1, -1)
        index_1_1 = index_1_1.view(1, -1)

        # NaNs otherwise
        # NOTE(review): with grad_forces the views are taken under no_grad,
        # presumably so gradients do not flow through the bilinear weights --
        # confirm the resulting tensors are actually detached.
        if self.grad_forces:
            with torch.no_grad():
                delta_0_0 = delta_0_0.view(1, -1)
                delta_0_1 = delta_0_1.view(1, -1)
                delta_1_0 = delta_1_0.view(1, -1)
                delta_1_1 = delta_1_1.view(1, -1)
        else:
            delta_0_0 = delta_0_0.view(1, -1)
            delta_0_1 = delta_0_1.view(1, -1)
            delta_1_0 = delta_1_0.view(1, -1)
            delta_1_1 = delta_1_1.view(1, -1)

        return (
            torch.cat([index_0_0, index_0_1, index_1_0, index_1_1]),
            torch.cat([delta_0_0, delta_0_1, delta_1_0, delta_1_1]),
            source_edge,
        )


In [None]:

# NOTE: this cell used to create the model, optimizer and loss and log the model
# graph. It ran before `training_set`, `lr`, `training_generator`, `writer` and
# `timestamp` are defined, so it failed on a fresh kernel (Restart & Run All).
# It also hard-coded 'cpu'. The cells after the datasets and the TensorBoard
# writer are created perform the same steps, in a valid order and on `device`.

In [None]:
class GConv(MessagePassing):
    """Minimal message-passing layer with sum aggregation.

    Implements x_i' = gamma(x_i, sum_{j in N(i)} phi(x_i, x_j, e_ji)) with the
    placeholder choices phi = x_j and gamma = the aggregated sum. Override
    ``message`` / ``update`` to build a real layer.
    """

    def __init__(self):
        super().__init__(aggr='add')  # "Add" aggregation

    def forward(self, batch):
        """Run one round of message passing over ``batch``.

        Args:
            batch: mapping-like object with
                'x' -- node features, shape [N, in_channels] (N atoms in the system/batch);
                'edge_index' -- shape [2, E], each edge given by a pair of node indices;
                'edge_attr' -- per-edge features.
        """
        x = batch['x']
        edge_index = batch['edge_index']
        edge_attr = batch['edge_attr']

        # Start propagating messages.
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

    def message(self, x, x_i, x_j, edge_attr):
        # phi placeholder: pass the neighbour's features through unchanged.
        # The old stub returned None, which breaks aggregation inside propagate().
        return x_j

    def update(self, aggr_out, x, edge_attr, edge_index):
        # gamma placeholder: keep the aggregated messages as the new node features.
        return aggr_out

In [None]:
class ConvNN(nn.Module):
    """Model skeleton wrapping a single ``GConv`` layer.

    Args:
        dim_atom: atom feature size (``x.shape[1]``); stored for building layers.
        dim_edge: edge feature size (``edge_attr.shape[1]``); stored likewise.
    """

    def __init__(self, dim_atom=None, dim_edge=None):
        super().__init__()
        # The training cells call ConvNN(dim_atom=..., dim_edge=...); the old
        # zero-argument __init__ raised TypeError there.
        self.dim_atom = dim_atom
        self.dim_edge = dim_edge
        self.conv = GConv()

    def forward(self, batch):
        # The readout from node features to the prediction target is not written
        # yet. Fail loudly instead of silently returning None into the loss.
        raise NotImplementedError("ConvNN.forward is not implemented yet")

In [None]:
# Run configuration
batch_size = 50   # systems per DataLoader batch
num_workers = 0   # DataLoader worker processes (0 = load in the main process)

features_cols = [
    'atomic_numbers',
    'edge_index_new',
    'distances_new',
    'contact_solid_angles',
    'tags',
    'voronoi_volumes',
    'spherical_domain_radii',
]

target_col = 'y_relaxed'
lr = 1e-3
epochs = 20

In [None]:
# Make newly created tensors default to the GPU when one is available.
if torch.cuda.is_available():
    if hasattr(torch, "set_default_device"):
        # torch.set_default_tensor_type is deprecated (PyTorch 2.1+). The default
        # float dtype is left untouched, so only the device needs to be set.
        torch.set_default_device('cuda')
    else:
        # Older PyTorch without set_default_device.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    print('cuda')

In [None]:
# set device: prefer the GPU whenever PyTorch can see one
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Training dataset and its batch iterator.
# NOTE(review): batches are not shuffled; if shuffling is added, check that it
# works together with the CUDA default tensor type set earlier.
train_dataset_file_path = os.path.expanduser("../../ocp_datasets/data/is2re/10k/train/data_mod2.lmdb")

training_set = Dataset(
    train_dataset_file_path,
    features_cols,
    target_col,
    preprocessing=preprocessing,
)
training_generator = DataLoader(
    training_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Validation dataset and its batch iterator.
val_dataset_file_path = os.path.expanduser("../../ocp_datasets/data/is2re/all/val_ood_both/data_mod2.lmdb")

valid_set = Dataset(
    val_dataset_file_path,
    features_cols,
    target_col,
    preprocessing=preprocessing,
)
valid_generator = DataLoader(
    valid_set,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Print a summary of the training LMDB. This is best-effort: a failure is
# reported instead of stopping the notebook.
try:
    lmdb_dataset(train_dataset_file_path).describe()
except Exception as exc:  # was a bare `except: pass`, which also hid KeyboardInterrupt
    print(f"Could not describe {train_dataset_file_path}: {exc!r}")

In [None]:
# Model: input sizes come from the first preprocessed training system.
first_system = training_set[0][0]
model = ConvNN(dim_atom=first_system['x'].shape[1], dim_edge=first_system['edge_attr'].shape[1])

# Optimizer and loss.
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.L1Loss()

# Move both to the GPU when one is available.
model = model.to(device)
criterion = criterion.to(device)

In [None]:
# Run identifier, also used as the TensorBoard sub-folder name.
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

print(timestamp)

In [None]:
# TensorBoard writer: each run logs to its own timestamped sub-folder.
# SummaryWriter creates missing log directories itself, so no manual mkdir is needed.

# server
#log_folder_path = "../../ocp_results/logs/tensorboard/out_base_model"

# colab
# log_folder_path = "/content/drive/MyDrive/ocp_results/logs/tensorboard/out_base_model"

# user_specific
log_file_path = "../logs/tensorboard_airi"

writer = SummaryWriter(os.path.join(log_file_path, timestamp))

In [None]:
%%time
logfile_str = {
    "train_dataset_file_path": train_dataset_file_path,
    "val_dataset_file_path": val_dataset_file_path,
    "features_cols": features_cols,
    "target_col": target_col,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr
}

#граф модели
trace_system = dict(list(next(iter(training_generator))[0]))
writer.add_graph(model, trace_system)
writer.add_text(timestamp, str(logfile_str))

## Training

In [None]:
%%time
loss = []
loss_eval = []

print(timestamp)
print(f'Start training model {str(model)}')
for i in range(epochs):
    loss.append(train(model, training_generator, optimizer, criterion, epoch=i, writer=writer, device=device))
    loss_eval.append(evaluate(model, valid_generator, criterion, epoch=i, writer=writer, device=device))