In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


# Libraries

In [None]:
!pip install pymatgen pandas numpy mp_api torch_geometric matgl py3Dmol mlflow

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 169, in exc_logging_wrapper
    status = run_func(*args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/req_command.py", line 242, in wrapper
    return func(self, options, args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/commands/install.py", line 377, in run
    requirement_set = resolver.resolve(
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/resolution/resolvelib/resolver.py", line 92, in resolve
    result = self._result = resolver.resolve(
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/resolvelib/resolvers.py", line 546, in resolve
    state = resolution.resolve(requirements, max_rounds=max_rounds)
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/resolvelib/resolvers.py", line 443, in resolve
    newly_unsatisfied_names = {
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor

In [None]:
import pandas as pd
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis.local_env import CrystalNN
from emmet.core.summary import HasProps
import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
import torch
import torch_geometric
from torch_geometric.data import Dataset
import numpy as np
import os
import py3Dmol
from tqdm import tqdm

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset Utilities

In [None]:
def getGraphFromStructure(struct):
  r = torch_geometric.utils.convert.from_dgl(struct[0])

  r.edge_index = torch.unique(r.edge_index.t(), dim=0)

  r.edge_index = r.edge_index.t()

  del r.pbc_offset

  cart_coords = torch.matmul(r.frac_coords, struct[1][0])

  r.frac_coords = cart_coords

  numAtoms = len(r.node_type)

  r.x = torch.zeros((numAtoms, 4))

  for i in range(numAtoms):
    r.x[i] = torch.tensor([r.node_type[i], r.frac_coords[i][0], r.frac_coords[i][1], r.frac_coords[i][2]])


  del r.node_type, r.frac_coords

  return r

def filter_elements(t):
    mask = t[0] < t[1]  # Create a mask where the opposite condition is true (we want to keep these)
    filtered_tensor = t[:, mask]  # Filter columns based on mask
    return filtered_tensor

def visualizeCrystal(r):
  cart_coords = r.x  # This is a tensor

  edge_index = r.edge_index.t()

  color_mapping = {
      # Alkali metals (Group 1, excluding Hydrogen)
      'Li': 'violet', 'Na': 'violet', 'K': 'violet', 'Rb': 'violet', 'Cs': 'violet', 'Fr': 'violet',
      # Alkaline earth metals (Group 2)
      'Be': 'indigo', 'Mg': 'indigo', 'Ca': 'indigo', 'Sr': 'indigo', 'Ba': 'indigo', 'Ra': 'indigo',
      # Transition metals (Groups 3-12)
      'Sc': 'blue', 'Ti': 'blue', 'V': 'blue', 'Cr': 'blue', 'Mn': 'blue', 'Fe': 'blue',
      'Co': 'blue', 'Ni': 'blue', 'Cu': 'blue', 'Zn': 'blue', 'Y': 'blue',
      'Zr': 'blue', 'Nb': 'blue', 'Mo': 'blue', 'Tc': 'blue', 'Ru': 'blue', 'Rh': 'blue',
      'Pd': 'blue', 'Ag': 'blue', 'Cd': 'blue', 'Hf': 'blue', 'Ta': 'blue', 'W': 'blue',
      'Re': 'blue', 'Os': 'blue', 'Ir': 'blue', 'Pt': 'blue', 'Au': 'blue', 'Hg': 'blue',
      'Rf': 'blue', 'Db': 'blue', 'Sg': 'blue', 'Bh': 'blue', 'Hs': 'blue', 'Mt': 'blue',
      # Post-transition metals
      'Al': 'green', 'Ga': 'green', 'In': 'green', 'Sn': 'green', 'Tl': 'green', 'Pb': 'green', 'Bi': 'green',
      # Metalloids
      'B': 'yellowgreen', 'Si': 'yellowgreen', 'Ge': 'yellowgreen', 'As': 'yellowgreen', 'Sb': 'yellowgreen', 'Te': 'yellowgreen', 'Po': 'yellowgreen',
      # Nonmetals
      'H': 'white', 'C': 'black', 'N': 'blue', 'O': 'red', 'P': 'orange', 'S': 'yellow', 'Se': 'yellow',
      # Halogens (Group 17)
      'F': 'cyan', 'Cl': 'cyan', 'Br': 'cyan', 'I': 'cyan', 'At': 'cyan',
      # Noble gases (Group 18)
      'He': 'magenta', 'Ne': 'magenta', 'Ar': 'magenta', 'Kr': 'magenta', 'Xe': 'magenta', 'Rn': 'magenta',
      # Lanthanides
      'La': 'lightblue', 'Ce': 'lightblue', 'Pr': 'lightblue', 'Nd': 'lightblue', 'Pm': 'lightblue',
      'Sm': 'lightblue', 'Eu': 'lightblue', 'Gd': 'lightblue', 'Tb': 'lightblue', 'Dy': 'lightblue',
      'Ho': 'lightblue', 'Er': 'lightblue', 'Tm': 'lightblue', 'Yb': 'lightblue', 'Lu': 'lightblue',
      # Actinides
      'Ac': 'lightgreen', 'Th': 'lightgreen', 'Pa': 'lightgreen', 'U': 'lightgreen', 'Np': 'lightgreen',
      'Pu': 'lightgreen', 'Am': 'lightgreen', 'Cm': 'lightgreen', 'Bk': 'lightgreen', 'Cf': 'lightgreen',
      'Es': 'lightgreen', 'Fm': 'lightgreen', 'Md': 'lightgreen', 'No': 'lightgreen', 'Lr': 'lightgreen',

  }

  # Start viewer
  view = py3Dmol.view(width=800, height=400)

  # Add atoms
  for atom_type, x, y, z in cart_coords:
    atom_color = color_mapping.get(elemList[int(atom_type.item())])  # Use default color if atom_type not in mapping
    view.addSphere({'center': {'x': x.item(), 'y': y.item(), 'z': z.item()}, 'radius': 0.5, 'color': atom_color})

  # Add bonds - assuming direct connections without considering pbc_offset
  for start, end in edge_index:
      start_coords = cart_coords[start][1:]
      end_coords = cart_coords[end][1:]
      view.addCylinder({'start': {'x': start_coords[0].item(), 'y': start_coords[1].item(), 'z': start_coords[2].item()},
                        'end': {'x': end_coords[0].item(), 'y': end_coords[1].item(), 'z': end_coords[2].item()},
                        'radius': 0.1, 'color': 'gray'})

  # Show the structure
  view.zoomTo()
  view.show()

# Dataset

In [None]:
API_KEY = '0zNsIrPXwPmHXOl8AyT2C96LX2yVIVzI'

In [None]:
mpr = MPRester(API_KEY)

results = mpr.materials.elasticity.search(fields=["material_id", "structure", 'bulk_modulus','young_modulus', 'shear_modulus', 'homogeneous_poisson'])
data = [result.dict() for result in results]  # Convert result objects to dictionaries

Retrieving ElasticityDoc documents:   0%|          | 0/12392 [00:00<?, ?it/s]

In [None]:
structureList = [Structure.from_dict(item['structure']) for item in data]

elemList = get_element_list(structureList)

In [None]:
from torch_geometric.utils import to_dense_adj

s = Structure2Graph(elemList, 4)

extracted_data = []

# Iterate over each dictionary in your list
for item in data:
    # Extract values, including from the subdictionary
    row = {
        'Structure': s.get_graph(Structure.from_dict(item['structure'])),
        'Shear Modulus': item['shear_modulus']['vrh'],
        'Bulk Modulus': item['bulk_modulus']['vrh']
    }
    # Append the extracted data to your list
    extracted_data.append(row)

df = pd.DataFrame(extracted_data)

trainingDf = pd.DataFrame(columns = ["Graph", "Structure", "Shear Modulus"])

trainingDf["Structure"] = df["Structure"]
trainingDf["Shear Modulus"] = df["Shear Modulus"]

filtered_df = trainingDf[trainingDf['Shear Modulus'] <= 700]
filtered_df.reset_index(drop=True, inplace=True)

filtered_df['Graph'] = filtered_df['Structure'].apply(getGraphFromStructure)

SM_mean = filtered_df["Shear Modulus"].mean()
SM_std = filtered_df["Shear Modulus"].std()

filtered_df["Norm_SM"]=(filtered_df["Shear Modulus"]-SM_mean)/SM_std

for i in range(len(filtered_df["Graph"])):
  filtered_df["Graph"][i].SM = filtered_df["Norm_SM"][i]
  filtered_df["Graph"][i].edge_index = filter_elements(filtered_df["Graph"][i].edge_index)

dataset = []

for graph in filtered_df["Graph"]:
  if(graph.x.size()[0]<=15 and graph.x.size()[0]>2 and graph.x.size()[0] == torch.squeeze(to_dense_adj(graph.edge_index)).size()[0]):
    dataset.append(graph)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['Graph'] = filtered_df['Structure'].apply(getGraphFromStructure)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df["Norm_SM"]=(filtered_df["Shear Modulus"]-SM_mean)/SM_std


In [None]:
from torch_geometric.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data_list):
        super(CustomDataset, self).__init__()
        self.data_list = data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

In [None]:
MaterialsDataset = CustomDataset(dataset)

# Training Utilities

In [None]:
from torch_geometric.utils import to_dense_adj
import torch

device = DEVICE

def count_parameters(model):
    """
    Counts the number of parameters for a Pytorch model
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def kl_loss(mu=None, logstd=None):
    """
    Closed formula of the KL divergence for normal distributions
    """
    MAX_LOGSTD = 10
    logstd =  logstd.clamp(max=MAX_LOGSTD)
    kl_div = -0.5 * torch.mean(torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))

    # Limit numeric errors
    kl_div = kl_div.clamp(max=1000)
    return kl_div

def slice_graph_targets(graph_id, edge_targets, node_targets, batch_index):
    """
    Slices out the upper triangular part of an adjacency matrix for
    a single graph from a large adjacency matrix for a full batch.
    For the node features the corresponding section in the batch is sliced out.
    --------
    graph_id: The ID of the graph (in the batch index) to slice
    edge_targets: A dense adjacency matrix for the whole batch
    node_targets: A tensor of node labels for the whole batch
    batch_index: The node to graph map for the batch
    """
    # Create mask for nodes of this graph id
    graph_mask = torch.eq(batch_index, graph_id)
    # Row slice and column slice batch targets to get graph edge targets
    graph_edge_targets = edge_targets[graph_mask][:, graph_mask]
    # Get triangular upper part of adjacency matrix for targets
    size = graph_edge_targets.shape[0]
    if size > 1:  # Ensure there are at least 2 nodes to form an upper triangle
        triu_indices = torch.triu_indices(size, size, offset=1)
        # Direct indexing using triu_indices
        graph_edge_targets = graph_edge_targets[triu_indices[0], triu_indices[1]]
    else:
        graph_edge_targets = torch.empty((0,), dtype=edge_targets.dtype, device=edge_targets.device)

    # Slice node targets
    graph_node_targets = node_targets[graph_mask]
    return graph_edge_targets, graph_node_targets

def slice_graph_predictions(triu_logits, node_logits, graph_triu_size, triu_start_point, graph_size, node_start_point):
    """
    Slices out the corresponding section from a list of batch triu values.
    Given a start point and the size of a graph's triu, simply slices
    the section from the batch list.
    -------
    triu_logits: A batch of triu predictions of different graphs
    node_logits: A batch of node predictions with fixed size MAX_GRAPH_SIZE
    graph_triu_size: Size of the triu of the graph to slice
    triu_start_point: Index of the first node of this graph in the triu batch
    graph_size: Max graph size
    node_start_point: Index of the first node of this graph in the nodes batch
    """
    # Slice edge logits
    graph_logits_triu = torch.squeeze(
                    triu_logits[triu_start_point:triu_start_point + graph_triu_size]
                    )
    # Slice node logits
    graph_node_logits = torch.squeeze(
                    node_logits[node_start_point:node_start_point + graph_size]
                    )
    return graph_logits_triu, graph_node_logits

def to_one_hot(x, options):
    """
    Converts a tensor of values to a one-hot vector
    based on the entries in options.
    """
    return torch.nn.functional.one_hot(x.long(), len(options))

def squared_difference(input, target):
    return (input - target) ** 2


def triu_to_dense(triu_values, num_nodes):
    """
    Converts a triangular upper part of a matrix as flat vector
    to a squared adjacency matrix with a specific size (num_nodes).
    """
    dense_adj = torch.zeros((num_nodes, num_nodes)).to(device).float()
    triu_indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
    tril_indices = torch.tril_indices(num_nodes, num_nodes, offset=-1)
    dense_adj[triu_indices[0], triu_indices[1]] = triu_values
    dense_adj[tril_indices[0], tril_indices[1]] = triu_values
    return dense_adj

def calculate_node_edge_pair_loss(node_tar, edge_tar, node_pred, edge_pred):
    """
    Calculates a loss based on the sum of node-edge pairs.
    node_tar:  [nodes, supported atoms]
    node_pred: [max nodes, supported atoms + 1]
    edge_tar:  [triu values for target nodes, supported edges]
    edge_pred: [triu values for predicted nodes, supported edges]

    """
    # Recover full 3d adjacency matrix for edge predictions
    edge_pred_mat = triu_to_dense(edge_pred[:,1].float(), node_pred.shape[0]) # [num nodes, num nodes]

    # Recover full 3d adjacency matrix for edge targets
    edge_tar_mat = triu_to_dense(edge_tar[:,1].float(), node_tar.shape[0]) # [num nodes, num nodes]

    # --- The two output matrices tell us how many edges are connected with each of the atom types
    # Multiply each of the edge types with the atom types for the predictions
    node_edge_preds = torch.empty((MAX_MAT_SIZE, len(elemList)), dtype=torch.float, device=device)
    node_edge_preds = torch.matmul(edge_pred_mat, node_pred[:, :88])

    # Multiply each of the edge types with the atom types for the targets
    node_edge_tar = torch.empty((node_tar.shape[0], len(elemList)), dtype=torch.float, device=device)
    node_edge_tar = torch.matmul(edge_tar_mat, node_tar.float().squeeze())

    # Reduce to matrix with [num atom types, num edge types]
    node_edge_pred_matrix = torch.sum(node_edge_preds, dim=0)
    node_edge_tar_matrix = torch.sum(node_edge_tar, dim=0)

    node_edge_loss = torch.mean(sum(squared_difference(node_edge_pred_matrix, node_edge_tar_matrix.float())))

    return node_edge_loss


def approximate_recon_loss(node_targets, node_preds, triu_targets, triu_preds):
    atom_targets = node_targets[:,:1]
    coord_targets = node_targets[:,1:]

    # Convert targets to one hot
    onehot_node_targets = to_one_hot(atom_targets, elemList ) #+ ["None"]
    onehot_triu_targets = to_one_hot(triu_targets, ["None", "Edge"])

    # Reshape node predictions
    node_matrix_shape = (MAX_MAT_SIZE, (len(elemList) + 1 + 3))
    node_preds_matrix = node_preds.reshape(node_matrix_shape)

    # Reshape triu predictions
    edge_matrix_shape = (int((MAX_MAT_SIZE * (MAX_MAT_SIZE - 1))/2), 2)
    triu_preds_matrix = triu_preds.reshape(edge_matrix_shape)

    # Apply sum on labels per (node/edge) type and discard "none" types
    node_preds_reduced = torch.sum(node_preds_matrix[:, :88], 0)
    node_targets_reduced = torch.sum(onehot_node_targets, 0)
    triu_preds_reduced = torch.sum(triu_preds_matrix[:, 1:], 0)
    triu_targets_reduced = torch.sum(onehot_triu_targets[:, 1:], 0)

    # Calculate node-sum loss and edge-sum loss
    node_loss = torch.sum(squared_difference(node_preds_reduced, node_targets_reduced.float()))
    edge_loss = torch.sum(squared_difference(triu_preds_reduced, triu_targets_reduced.float()))

    # Calculate coordinate loss
    coord_target_matrix = torch.zeros(88, 3).to(device)

    for node in node_targets:
      coord_target_matrix[node[0].int()][0] += node[1]
      coord_target_matrix[node[0].int()][1] += node[2]
      coord_target_matrix[node[0].int()][2] += node[3]

    atom_preds = torch.argmax(node_preds_matrix[:, :-3], dim=1)
    atom_preds = atom_preds.unsqueeze(1)
    node_coords = node_preds_matrix[:, -3:]
    node_preds_coord_matrix = torch.cat((atom_preds, node_coords), dim=1)

    coord_pred_matrix = torch.zeros(88, 3).to(device)

    for node in node_preds_coord_matrix:
      if node[0].int() != 88:
        coord_pred_matrix[node[0].int()][0] += node[1]
        coord_pred_matrix[node[0].int()][1] += node[2]
        coord_pred_matrix[node[0].int()][2] += node[3]

    coord_loss = torch.sum(squared_difference(coord_pred_matrix.float(), coord_target_matrix.float()))


    # Calculate node-edge-sum loss
    # Forces the model to properly arrange the matrices
    node_edge_loss = calculate_node_edge_pair_loss(onehot_node_targets,
                                      onehot_triu_targets,
                                      node_preds_matrix[:, :89],
                                      triu_preds_matrix)

    approx_loss =   node_loss + coord_loss + edge_loss + node_edge_loss
    return approx_loss


def gvae_loss(triu_logits, node_logits, edge_index, node_types, \
              mu, logvar, batch_index, kl_beta):
    """
    Calculates the loss for the graph variational autoencoder,
    consiting of a node loss, an edge loss and the KL divergence.
    """
    # Convert target edge index to dense adjacency matrix
    batch_edge_targets = torch.squeeze(to_dense_adj(edge_index))

    # For this model we always have the same (fixed) output dimension
    graph_size = MAX_MAT_SIZE*(len(elemList) + 1+3)
    graph_triu_size = int((MAX_MAT_SIZE * (MAX_MAT_SIZE - 1)) / 2) * 2

    # Reconstruction loss per graph
    batch_recon_loss = []
    triu_indices_counter = 0
    graph_size_counter = 0

    # Loop over graphs in this batch
    for graph_id in torch.unique(batch_index):
            # Get upper triangular targets for this graph from the whole batch
            triu_targets, node_targets = slice_graph_targets(graph_id,
                                                            batch_edge_targets,
                                                            node_types,
                                                            batch_index)

            # Get upper triangular predictions for this graph from the whole batch
            triu_preds, node_preds = slice_graph_predictions(triu_logits,
                                                            node_logits,
                                                            graph_triu_size,
                                                            triu_indices_counter,
                                                            graph_size,
                                                            graph_size_counter)

            # Update counter to the index of the next (upper-triu) graph
            triu_indices_counter = triu_indices_counter + graph_triu_size
            graph_size_counter = graph_size_counter + graph_size

            # Calculate losses
            recon_loss = approximate_recon_loss(node_targets,
                                                node_preds,
                                                triu_targets,
                                                triu_preds)
            batch_recon_loss.append(recon_loss)

    # Take average of all losses
    num_graphs = torch.unique(batch_index).shape[0]
    batch_recon_loss = torch.true_divide(sum(batch_recon_loss),  num_graphs)

    # KL Divergence
    kl_divergence = kl_loss(mu, logvar)

    return batch_recon_loss + kl_beta * kl_divergence, kl_divergence

# Model 1

In [None]:
import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import Set2Set
from torch_geometric.nn import BatchNorm
from tqdm import tqdm

MAX_MAT_SIZE = 15
NUM_ATOMS = len(elemList)

class GVAE(nn.Module):
    def __init__(self, feature_size=4):
        super(GVAE, self).__init__()
        self.encoder_embedding_size = 64
        self.latent_embedding_size = 128
        self.num_atom_types = NUM_ATOMS
        self.max_num_atoms = MAX_MAT_SIZE
        self.decoder_hidden_neurons = 512
        self.device = DEVICE

        # Encoder layers
        self.conv1 = TransformerConv(feature_size,
                                    self.encoder_embedding_size,
                                    heads=4,
                                    concat=False,
                                    beta=True)
        self.bn1 = BatchNorm(self.encoder_embedding_size)
        self.conv2 = TransformerConv(self.encoder_embedding_size,
                                    self.encoder_embedding_size,
                                    heads=4,
                                    concat=False,
                                    beta=True)
        self.bn2 = BatchNorm(self.encoder_embedding_size)
        self.conv3 = TransformerConv(self.encoder_embedding_size,
                                    self.encoder_embedding_size,
                                    heads=4,
                                    concat=False,
                                    beta=True)
        self.bn3 = BatchNorm(self.encoder_embedding_size)
        self.conv4 = TransformerConv(self.encoder_embedding_size,
                                    self.encoder_embedding_size,
                                    heads=4,
                                    concat=False,
                                    beta=True)

        # Pooling layers
        self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4)

        self.intermediate_linear = Linear(self.encoder_embedding_size * 2, 127)

        # Latent transform layers
        self.mu_transform = Linear(self.latent_embedding_size,
                                            self.latent_embedding_size)
        self.logvar_transform = Linear(self.latent_embedding_size,
                                            self.latent_embedding_size)

        # Decoder layers
        # --- Shared layers
        self.linear_1 = Linear(self.latent_embedding_size, self.decoder_hidden_neurons)
        self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)

        # --- Atom decoding (outputs a matrix: (max_num_atoms) * (# atom_types + "none"-type + x-coord + y-coord + z-coord))
        atom_output_dim = self.max_num_atoms*(self.num_atom_types + 1 + 3)
        self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)

        # --- Edge decoding (outputs a triu tensor: (max_num_atoms*(max_num_atoms-1)/2))
        edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * 2)
        self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)


    def encode(self, x, edge_index, shear_modulus, batch_index):
        # GNN layers
        x = self.conv1(x, edge_index).relu()
        x = self.bn1(x)
        x = self.conv2(x, edge_index).relu()
        x = self.bn2(x)
        x = self.conv3(x, edge_index).relu()
        x = self.bn3(x)
        x = self.conv4(x, edge_index).relu()

        # Pool to global representation
        x = self.pooling(x, batch_index)

        #Reduce size to add Shear Modulus
        x = self.intermediate_linear(x)

        shear_modulus = shear_modulus.unsqueeze(-1)

        # Concatenate normalized shear modulus to make it 128 elements
        x = torch.cat((x, shear_modulus), dim=1)

        # Latent transform layers
        mu = self.mu_transform(x)
        logvar = self.logvar_transform(x)
        return mu, logvar

    def decode_graph(self, graph_z):
        """
        Decodes a latent vector into a continuous graph representation
        consisting of node types and edge types.
        """
        # Pass through shared layers
        z = self.linear_1(graph_z).relu()
        z = self.linear_2(z).relu()
        # Decode atom types
        atom_logits = self.atom_decode(z)
        # Decode edge types
        edge_logits = self.edge_decode(z)

        return atom_logits, edge_logits


    def decode(self, z, batch_index):
        node_logits = []
        triu_logits = []
        # Iterate over molecules in batch
        for graph_id in torch.unique(batch_index):
            # Get latent vector for this graph
            graph_z = z[graph_id]

            # Recover graph from latent vector
            atom_logits, edge_logits = self.decode_graph(graph_z)

            # Store per graph results
            node_logits.append(atom_logits)
            triu_logits.append(edge_logits)

        # Concatenate all outputs of the batch
        node_logits = torch.cat(node_logits)
        triu_logits = torch.cat(triu_logits)
        return triu_logits, node_logits


    def reparameterize(self, mu, logvar):
        if self.training:
            # Get standard deviation
            std = torch.exp(logvar)
            # Returns random numbers from a normal distribution
            eps = torch.randn_like(std)
            # Return sampled values
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, edge_index, shear_modulus, batch_index):
      # Encode the molecule
      mu, logvar = self.encode(x, edge_index, shear_modulus, batch_index)
      # Sample latent vector (per atom)
      z = self.reparameterize(mu, logvar)
      # Decode latent vector into original molecule
      triu_logits, node_logits = self.decode(z, batch_index)
      return triu_logits, node_logits, mu, logvar



    def sample_graphs(self, desired_shear_modulus, num=100):
      print("Sampling materials ... ")

      device = self.device

      desired_shear_modulus = (desired_shear_modulus - SM_mean) / SM_std
      desired_shear_modulus = torch.tensor([desired_shear_modulus], dtype=torch.float32, device=device)

      mats = []

      # Sample materials and check if they are valid
      for _ in tqdm(range(num)):
          # Sample latent space
          z = torch.randn(1, self.latent_embedding_size - 1, device=device)

          # Adjust dimensions of desired_shear_modulus
          desired_shear_modulus_unsqueezed = desired_shear_modulus.unsqueeze(1)

          # Concatenate normalized shear modulus to make it 128 elements
          z = torch.cat((z, desired_shear_modulus_unsqueezed), dim=1)

          # Get model output (this could also be batched)
          dummy_batch_index = torch.tensor([0], dtype=torch.int32, device=device)
          t, n = self.decode(z, dummy_batch_index)

          node_matrix_shape = (MAX_MAT_SIZE, (NUM_ATOMS + 1 + 3))
          node_preds_matrix = n.view(node_matrix_shape)
          node_preds = torch.argmax(node_preds_matrix[:, :-3], dim=1)
          node_coords = node_preds_matrix[:, -3:]

          node_preds_reshaped = node_preds.to(node_coords.dtype).unsqueeze(1)
          node_features = torch.cat((node_preds_reshaped, node_coords), dim=1)

          edge_matrix_shape = (int((MAX_MAT_SIZE * (MAX_MAT_SIZE - 1)) / 2), 2)
          triu_preds_matrix = t.view(edge_matrix_shape)
          triu_preds = torch.argmax(triu_preds_matrix, dim=1)

          edges = torch.tensor([[0, 0]], device=device)

          index = 0
          for i in range(15):
              for j in range(i+1, 15):
                  if triu_preds[index] == 1 and node_preds[i] != 88 and node_preds[j] != 88:
                      edge = torch.tensor([[i, j]], device=device)
                      edges = torch.cat((edges, edge), dim=0)
                  index += 1

          edges = edges[1:].t()

          index_to_remove = torch.where(node_preds == 88)[0]
          for ind in reversed(index_to_remove):
              mask = torch.arange(node_features.size(0), device=device) != ind

              # Apply the mask
              node_features = node_features[mask]
              edges = edges[(edges[:, 0] != ind) & (edges[:, 1] != ind)]

              # Decrement indices of nodes after the removed node
              edges[edges >= ind] -= 1

          gen_graph = torch_geometric.data.Data(x=node_features, edge_index=edges)
          mats.append(gen_graph)

      return mats

def save_checkpoint(checkpoint):
    torch.save(checkpoint, '/content/drive/My Drive/checkpoint.pth.tar')

def load_checkpoint(checkpoint):
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Model 2

In [None]:
import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn import TransformerConv, Set2Set, BatchNorm
from torch_geometric.data import DataLoader

MAX_MAT_SIZE = 15
NUM_ATOMS = len(elemList)  # Define elemList appropriately

class GraphGenerator(nn.Module):
    def __init__(self, feature_size=4):
        super(GraphGenerator, self).__init__()
        self.encoder_embedding_size = 64
        self.latent_embedding_size = 128
        self.num_atom_types = NUM_ATOMS
        self.max_num_atoms = MAX_MAT_SIZE
        self.decoder_hidden_neurons = 512

        # Latent space dimension includes shear modulus
        self.latent_dim = self.latent_embedding_size + 1

        # Generator layers
        self.linear_1 = Linear(self.latent_dim, self.decoder_hidden_neurons)
        self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)

        # Atom decoding
        atom_output_dim = self.max_num_atoms * (self.num_atom_types + 1 + 3)
        self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)

        # Edge decoding
        edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * 2)
        self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)

    def forward(self, z):
        z = z.relu()
        z = self.linear_1(z).relu()
        z = self.linear_2(z).relu()

        atom_logits = self.atom_decode(z)
        edge_logits = self.edge_decode(z)

        return atom_logits, edge_logits


class GraphDiscriminator(nn.Module):
    def __init__(self, feature_size=4):
        super(GraphDiscriminator, self).__init__()
        self.encoder_embedding_size = 64

        # Encoder layers
        self.conv1 = TransformerConv(feature_size,
                                     self.encoder_embedding_size,
                                     heads=4, concat=False, beta=True)
        self.bn1 = BatchNorm(self.encoder_embedding_size)

        # Pooling
        self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4)

        # Classification layer
        self.classifier = Linear(self.encoder_embedding_size * 2, 1)

    def forward(self, x, edge_index, batch_index):
        x = self.conv1(x, edge_index).relu()
        x = self.bn1(x)
        x = self.pooling(x, batch_index)

        return torch.sigmoid(self.classifier(x))


def train_gan(generator, discriminator, data_loader, device=DEVICE):
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.001)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.001)
    criterion = nn.BCELoss()

    for epoch in range(50):  # Example epoch count
        for data in data_loader:
            # Train Discriminator
            optimizer_D.zero_grad()
            real_data = data.to(device)
            real_output = discriminator(real_data.x, real_data.edge_index, real_data.batch)
            real_label = torch.ones(real_output.shape[0], 1, device=device)
            loss_D_real = criterion(real_output, real_label)

            z = torch.randn(real_data.num_graphs, generator.latent_dim, device=device)
            generated_atoms, generated_edges = generator(z)
            fake_output = discriminator(generated_atoms, generated_edges, real_data.batch)
            fake_label = torch.zeros(fake_output.shape[0], 1, device=device)
            loss_D_fake = criterion(fake_output, fake_label)

            loss_D = (loss_D_real + loss_D_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            z = torch.randn(real_data.num_graphs, generator.latent_dim, device=device)
            generated_atoms, generated_edges = generator(z)
            fake_output = discriminator(generated_atoms, generated_edges, real_data.batch)
            loss_G = criterion(fake_output, real_label)
            loss_G.backward()
            optimizer_G.step()

            print(f"Epoch {epoch}, Loss D: {loss_D.item()}, Loss G: {loss_G.item()}")

train_dataset = MaterialsDataset
test_dataset = MaterialsDataset[:1000]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

generator = GraphGenerator().to(device)
discriminator = GraphDiscriminator().to(device)
train_gan(generator, discriminator, train_loader)


#Train

In [None]:
import mlflow.pytorch

mlflow.set_tracking_uri("/content/drive/My Drive/mlruns")
mlflow.set_experiment("GVAE Test 1")

2024/04/14 11:52:42 INFO mlflow.tracking.fluent: Experiment with name 'GVAE Test 1' does not exist. Creating a new experiment.


<Experiment: artifact_location='/content/drive/My Drive/mlruns/358276506967947079', creation_time=1713095562385, experiment_id='358276506967947079', last_update_time=1713095562385, lifecycle_stage='active', name='GVAE Test 1', tags={}>

In [None]:
import torch
from torch_geometric.data import DataLoader
from tqdm import tqdm
import numpy as np


# Load data
train_dataset = MaterialsDataset
test_dataset = MaterialsDataset[:1000]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# Load model
model = GVAE()
model = model.to(device)
print("Model parameters: ", count_parameters(model))

# Define loss and optimizer
loss_fn = gvae_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
kl_beta = 0.5

# Train function
def run_one_epoch(data_loader, type, epoch, kl_beta):
    # Store per batch loss and accuracy
    all_losses = []
    all_kldivs = []

    # Iterate over data loader
    for _, batch in enumerate(tqdm(data_loader)):
            # Use GPU
            batch.to(device)
            # Reset gradients
            optimizer.zero_grad()
            # Call model
            triu_logits, node_logits, mu, logvar = model(batch.x.float(),
                                                        batch.edge_index,
                                                        batch.SM.float(),
                                                        batch.batch)
            # Calculate loss and backpropagate
            loss, kl_div = loss_fn(triu_logits, node_logits,
                                   batch.edge_index,
                                   batch.x.float(), mu, logvar,
                                   batch.batch, kl_beta)
            if type == "Train":
                loss.backward()
                optimizer.step()
            # Store loss and metrics
            all_losses.append(loss.detach().cpu().numpy())
            #all_accs.append(acc)
            all_kldivs.append(kl_div.detach().cpu().numpy())

    print(f"{type} epoch {epoch} loss: ", np.array(all_losses).mean())
    mlflow.log_metric(key=f"{type} Epoch Loss", value=float(np.array(all_losses).mean()), step=epoch)
    mlflow.log_metric(key=f"{type} KL Divergence", value=float(np.array(all_kldivs).mean()), step=epoch)
    mlflow.pytorch.log_model(model, "model")

# Run training
with mlflow.start_run() as run:
    for epoch in range(101):
        checkpoint = {
          "epoch": epoch,
          "model_state_dict": model.state_dict(),
          "optimizer_state_dict": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)
        model.train()
        run_one_epoch(train_loader, type="Train", epoch=epoch, kl_beta=kl_beta)

    mlflow.pytorch.log_model(model, "model")

Model parameters:  1410997


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 0 loss:  2063.0513


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 1 loss:  1940.612


100%|██████████| 281/281 [02:50<00:00,  1.64it/s]


Train epoch 2 loss:  1895.513


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 3 loss:  1866.9177


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 4 loss:  1845.5089


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 5 loss:  1862.4941


100%|██████████| 281/281 [02:50<00:00,  1.64it/s]


Train epoch 6 loss:  1842.4934


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 7 loss:  1841.8601


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 8 loss:  1843.5116


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 9 loss:  1823.0919


100%|██████████| 281/281 [02:50<00:00,  1.64it/s]


Train epoch 10 loss:  1813.157


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 11 loss:  1807.0413


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 12 loss:  1806.432


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 13 loss:  1790.1262


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 14 loss:  1799.1813


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 15 loss:  1768.7185


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 16 loss:  1763.5181


100%|██████████| 281/281 [02:54<00:00,  1.61it/s]


Train epoch 17 loss:  1755.2106


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 18 loss:  1775.2706


100%|██████████| 281/281 [02:52<00:00,  1.62it/s]


Train epoch 19 loss:  1760.8936


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 20 loss:  1723.6572


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 21 loss:  1710.602


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 22 loss:  1746.1686


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 23 loss:  1737.8912


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 24 loss:  1739.263


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 25 loss:  1727.3728


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 26 loss:  1707.0173


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 27 loss:  1700.5098


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 28 loss:  1695.9379


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 29 loss:  1719.8492


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 30 loss:  1688.3771


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 31 loss:  1694.9526


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 32 loss:  1668.474


100%|██████████| 281/281 [02:52<00:00,  1.63it/s]


Train epoch 33 loss:  1676.007


100%|██████████| 281/281 [02:52<00:00,  1.62it/s]


Train epoch 34 loss:  1691.2124


100%|██████████| 281/281 [02:53<00:00,  1.62it/s]


Train epoch 35 loss:  1687.4193


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 36 loss:  1658.9464


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 37 loss:  1629.5095


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 38 loss:  1640.0118


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 39 loss:  1626.1003


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 40 loss:  1618.9067


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 41 loss:  1652.3176


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 42 loss:  1610.0421


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 43 loss:  1619.4457


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 44 loss:  1612.8055


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 45 loss:  1626.3525


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 46 loss:  1602.4462


100%|██████████| 281/281 [02:51<00:00,  1.63it/s]


Train epoch 47 loss:  1609.0509


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 48 loss:  1603.3956


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 49 loss:  1633.6957


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 50 loss:  1821.2126


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 51 loss:  1692.2834


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 52 loss:  1652.8826


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 53 loss:  1634.3031


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 54 loss:  1644.1969


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 55 loss:  1616.6044


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 56 loss:  1601.2933


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 57 loss:  1670.2982


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 58 loss:  1650.2189


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 59 loss:  1613.2683


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 60 loss:  1592.4819


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 61 loss:  1605.1704


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 62 loss:  1616.9994


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 63 loss:  1626.7356


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 64 loss:  1610.54


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 65 loss:  1590.0468


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 66 loss:  1640.3972


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 67 loss:  1615.5465


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 68 loss:  1598.0852


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 69 loss:  1593.8967


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 70 loss:  1602.9915


100%|██████████| 281/281 [02:51<00:00,  1.64it/s]


Train epoch 71 loss:  1573.0181


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 72 loss:  1576.3795


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 73 loss:  1565.5748


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 74 loss:  1548.3594


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 75 loss:  1547.2473


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 76 loss:  1552.0181


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 77 loss:  1535.3228


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 78 loss:  1565.3156


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 79 loss:  1542.9753


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 80 loss:  1554.9264


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 81 loss:  1550.0183


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 82 loss:  1564.2134


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 83 loss:  1562.236


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 84 loss:  1535.8615


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 85 loss:  1535.1948


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 86 loss:  1551.7281


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 87 loss:  1540.0792


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 88 loss:  1530.8796


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 89 loss:  1505.6655


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 90 loss:  1509.6753


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 91 loss:  1497.8151


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 92 loss:  1517.118


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 93 loss:  1507.3774


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 94 loss:  1494.2728


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 95 loss:  1492.814


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 96 loss:  1612.9478


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 97 loss:  1655.1714


100%|██████████| 281/281 [02:49<00:00,  1.66it/s]


Train epoch 98 loss:  1614.27


100%|██████████| 281/281 [02:49<00:00,  1.65it/s]


Train epoch 99 loss:  1587.1713


100%|██████████| 281/281 [02:50<00:00,  1.65it/s]


Train epoch 100 loss:  1563.2291




In [None]:
graphs = model.sample_graphs(10)

Sampling materials ... 


100%|██████████| 100/100 [00:01<00:00, 66.35it/s]


In [None]:
graphs[99].x

tensor([[ 0.0000e+00,  1.2604e-02, -1.3791e-01, -1.7877e-01],
        [ 2.0000e+00, -2.3322e-01, -3.6316e-01, -3.7851e-01],
        [ 6.0000e+00,  6.4429e-01,  2.5776e-01,  6.6989e-01],
        [ 6.0000e+00,  5.1944e-02, -7.1272e-02,  5.5665e-01],
        [ 6.0000e+00, -1.4892e-01, -2.0007e-01,  3.1569e-01],
        [ 3.0000e+00,  5.6126e-01,  4.4179e-01,  1.1305e+00],
        [ 1.2000e+01,  4.9334e-01,  2.0947e-01,  4.7138e-01],
        [ 6.0000e+00, -1.7799e-01, -2.3646e-01, -4.3765e-02],
        [ 2.0000e+00,  1.2516e+00,  7.1839e-01,  1.4773e+00],
        [ 3.0000e+00,  7.8170e-01,  3.7863e-01,  1.1208e+00],
        [ 1.4000e+01,  1.1005e-01, -5.5989e-04,  1.4810e-01],
        [ 4.1000e+01, -2.7557e-01, -2.7897e-02, -4.6296e-01],
        [ 2.0000e+00,  2.2705e-01, -4.0406e-02,  5.3485e-01],
        [ 6.0000e+00, -7.5113e-02, -4.1200e-01, -2.6448e-01],
        [ 1.4000e+01,  4.4899e-01,  8.3423e-02,  1.1045e+00]], device='cuda:0',
       grad_fn=<CatBackward0>)

In [None]:
graphs[0].x

tensor([[ 0.0000e+00, -2.8010e-02, -1.4131e-01, -1.9392e-01],
        [ 2.0000e+00, -2.9739e-01, -3.5196e-01, -4.0671e-01],
        [ 6.0000e+00,  7.0862e-01,  3.4209e-01,  7.5355e-01],
        [ 6.0000e+00,  1.2874e-03, -8.5675e-02,  5.0737e-01],
        [ 6.0000e+00,  1.2034e-02, -8.5902e-02,  5.0244e-01],
        [ 3.0000e+00,  4.7669e-01,  3.5995e-01,  1.0089e+00],
        [ 1.2000e+01,  4.7923e-01,  2.1651e-01,  4.5014e-01],
        [ 6.0000e+00,  5.4597e-02, -8.2491e-02,  2.0096e-01],
        [ 2.0000e+00,  1.1745e+00,  7.3806e-01,  1.4354e+00],
        [ 3.0000e+00,  6.6360e-01,  3.3894e-01,  1.0221e+00],
        [ 3.3000e+01,  3.5276e-02, -6.9160e-02,  9.8445e-02],
        [ 4.1000e+01, -1.7538e-01,  4.7847e-02, -3.6828e-01],
        [ 2.0000e+00,  9.2369e-02, -1.0684e-01,  4.2744e-01],
        [ 6.0000e+00,  3.6169e-01, -7.2626e-02,  2.1622e-01],
        [ 6.0000e+00,  3.7993e-01,  6.6950e-02,  1.0800e+00]], device='cuda:0',
       grad_fn=<CatBackward0>)

In [None]:
visualizeCrystal(graphs[99])