In [None]:
! wget https://github.com/pfebrer/sisl/releases/download/tutorials-tag/sisl_playground_files.zip

In [None]:
! unzip sisl_playground_files.zip

In [None]:
! pip install sisl[viz]

In [None]:
from pathlib import Path

# Sisl imports
import sisl
import sisl.viz
from sisl.viz import merge_plots

# To quickly plot the hamiltonian matrix
import plotly.express as px

# Disabled workaround: sisl's logger was observed to interfere with Colab's
# logger; silencing sisl nodes would look like this.
# sisl.nodes.Node.context["log_level"] = "CRITICAL"

# Read the graphene unit-cell run described by graphene_uc/RUN.fdf: first its
# geometry, then its Hamiltonian converted to a scipy CSR sparse matrix.
path_uc = Path("graphene_uc")
fdf_uc = sisl.get_sile(path_uc.joinpath("RUN.fdf"))
geometry = fdf_uc.read_geometry()
true_h = fdf_uc.read_hamiltonian().tocsr()

In [None]:
import numpy as np
import pandas as pd
import torch

# To load plotly templates for sisl visualization
import sisl.viz

from e3nn import o3

from graph2mat import (
    BasisConfiguration,
    PointBasis,
    BasisTableWithEdges,
    MatrixDataProcessor,
)
from graph2mat.bindings.torch import TorchBasisMatrixDataset, TorchBasisMatrixData

from graph2mat.bindings.e3nn import E3nnGraph2Mat

from graph2mat.tools.viz import plot_basis_matrix

In [None]:
# --- Toy basis ---
# Point type "A" carries a single scalar function ("0e").
# Point type "B" carries two scalars plus one odd l=1 set ("2x0e + 1o"),
# i.e. 2 + 3 = 5 functions.
# NOTE(review): `R` is presumably each type's cutoff radius, used to decide
# which points interact (TODO confirm in graph2mat docs).
point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical")
point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical")

basis = [point_1, point_2]

# The basis table built from the unique point types (including edge info).
table = BasisTableWithEdges(basis)

# The data processor: the target matrices are treated as symmetric.
processor = MatrixDataProcessor(
    basis_table=table, symmetric_matrix=True, sub_point_matrix=False
)

# `positions` is intentionally not defined here: the cell that builds the
# configurations defines it, and a second copy here would only be shadowed.

# The shape of the node features: one scalar ("0e") plus one odd l=1
# vector ("1o") per node, i.e. 4 numbers per node.
node_feats_irreps = o3.Irreps("0e + 1o")


# The fake environment representation function that we will use
# to compute node features (a stand-in for a real descriptor).
def get_environment_representation(data, irreps):
    """Function that mocks a true calculation of an environment representation.

    Computes a random array and then ensures that the numbers obey our particular
    system's symmetries.

    The symmetry constraints assume the nodes come in consecutive groups of 3,
    one group per 3-point chain configuration (as when several such configurations
    are batched). Every constraint is therefore applied with a stride of 3.
    NOTE(review): `data.num_nodes` is assumed to be a multiple of 3 — confirm
    for other batch layouts.

    Parameters
    ----------
    data:
        Object exposing ``num_nodes`` (e.g. a graph2mat data or batch object).
    irreps:
        e3nn irreps of the node features. The column indices used below assume
        the layout "0e + 1o": column 0 is the scalar and columns 1-3 are the
        l=1 components in YZX order, so column 3 is X.

    Returns
    -------
    The tensor produced by ``irreps.randn(data.num_nodes, -1)``, with the
    symmetry constraints written into it.

    Notes
    -----
    Reseeds torch's global RNG with 0 on every call. Results are reproducible,
    but any RNG state set before the call is discarded.
    """
    import torch

    torch.manual_seed(0)

    node_features = irreps.randn(data.num_nodes, -1)
    # The middle point of each chain sees the same in -X and +X directions,
    # therefore its X component must be 0. This uses the same stride-3
    # indexing as below so that every configuration in a batch gets the
    # constraint, not only the first one.
    # In principle the +/- YZ are also equivalent, but let's say that there
    # is something breaking the symmetry to make the numbers more interesting.
    node_features[1::3, 3] = 0
    # Make the two end points of each chain have equivalent features except
    # in the X direction, where the features are opposite.
    node_features[2::3, :3] = node_features[0::3, :3]
    node_features[2::3, 3] = -node_features[0::3, 3]
    return node_features


# Matrix readout model (graph2mat's e3nn binding). It consumes node features
# shaped like `node_feats_irreps` and produces a symmetric matrix in `basis`.
model = E3nnGraph2Mat(
    symmetric=True,
    irreps={"node_feats_irreps": node_feats_irreps},
    unique_basis=basis,
)

In [None]:
positions = np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]])


def make_toy_config(point_types, matrix=None):
    """Build a non-periodic toy configuration on the shared three positions.

    Parameters
    ----------
    point_types : list of str
        Point type ("A" or "B") for each of the three positions.
    matrix : optional
        Target matrix to attach, or None for no target. If given, its size
        must match the basis implied by `point_types`.
    """
    return BasisConfiguration(
        point_types=point_types,
        positions=positions,
        basis=basis,
        cell=np.eye(3) * 100,
        pbc=(False, False, False),
        matrix=matrix,
    )


# NOTE: `true_h` is deliberately NOT attached as the target matrix here. It is
# the graphene Hamiltonian from graphene_uc/RUN.fdf, converted to a scipy CSR
# matrix, so its shape follows graphene's orbitals (and supercell). The toy
# basis instead implies a 7x7 matrix for A-B-A (1 + 5 + 1 functions) and an
# 11x11 one for B-A-B (5 + 1 + 5).
# Also tried: BasisConfiguration.from_matrix(matrix=true_h, geometry=geometry,
# labels=True), which raises AttributeError: BasisConfiguration has no
# attribute 'from_matrix'.
config1 = make_toy_config(["A", "B", "A"])
config2 = make_toy_config(["B", "A", "B"])

configs = [config1, config2]

# Create the dataset
dataset = TorchBasisMatrixDataset(configs, data_processor=processor)