In [1]:
%pip install git+https://github.com/muhrin/mrs-tutorial.git@16ef89ae60a20b997d674312027bc19ad08ef168

Collecting git+https://github.com/muhrin/mrs-tutorial.git
  Cloning https://github.com/muhrin/mrs-tutorial.git to /tmp/pip-req-build-guqagh47
  Running command git clone --filter=blob:none --quiet https://github.com/muhrin/mrs-tutorial.git /tmp/pip-req-build-guqagh47
  Resolved https://github.com/muhrin/mrs-tutorial.git to commit 16ef89ae60a20b997d674312027bc19ad08ef168
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ase (from e3nn-invstutorial==0.1.0)
  Downloading ase-3.23.0-py3-none-any.whl.metadata (3.8 kB)
Collecting RISE (from e3nn-invstutorial==0.1.0)
  Downloading rise-5.7.1-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting fqdn (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->notebook->e3nn-invstutorial==0.1.0)
  Downloading fqdn-1.5.1-py3-none-any.whl.metadata (1.4 kB)
Collecting isoduration (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->notebook->e3nn-invstutorial==0.1.0)
  Do

In [2]:
%pip install ase==3.23.0



In [3]:
import ase

In [4]:
import json
import random
from functools import partial

import ase
from ase import build, visualize, io
import e3nn.io
from e3nn import o3, io
from e3nn_invstutorial import radial_spherical_tensor, orthonormal_radial_basis
import ipywidgets
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas
from plotly import subplots
import plotly.graph_objects as go
import plotly.express as px
from scipy.spatial.transform import Rotation
import sympy
import torch

In [5]:
def view(atoms: ase.Atoms, centre=True):
    """Render ``atoms`` with ASE's ``x3d`` viewer and return the viewer.

    Args:
        atoms: the structure to display (e.g. the result of ``build.molecule``).
        centre: currently unused; accepted so existing callers keep working.
            Re-centring the view is not implemented.

    Returns:
        Whatever ``ase.visualize.view(..., viewer='x3d')`` returns; in a
        notebook this is presumably an object with a rich HTML repr
        (TODO confirm).
    """
    return visualize.view(atoms, viewer='x3d')

def show_array(positions, calc=None):
    """Interactively show how a position-based representation reacts to
    permutation, translation and rotation.

    ``calc(positions)`` is drawn as a heatmap. Sliders then swap atom ``swap``
    with atom 0, translate every atom along x by ``xtrans``, and rotate the
    result by ``xrot`` degrees about the x axis. After each change the heatmap
    is redrawn from ``calc`` of the transformed positions, and those positions
    are printed.

    Args:
        positions: atomic positions, indexed as ``pos[:, 0]`` and passed to
            ``Rotation.apply``, so presumably shape (n_atoms, 3) — TODO confirm.
        calc: optional function mapping positions to a 2D array to display.
            Defaults to the identity, which shows the raw positions.

    Returns:
        The plotly ``FigureWidget`` that holds the heatmap.
    """
    if calc is None:
        calc = lambda x: x

    # Transpose here just as ``update`` does. ``interact`` calls ``update``
    # once immediately, so without this the first figure would be drawn in
    # the opposite orientation to every later redraw.
    fig = px.imshow(calc(positions).T, color_continuous_scale='RdBu', zmin=-5, zmax=5)
    widget = go.FigureWidget(fig)

    @ipywidgets.interact(
        xrot=(0, 360, 1.),
        xtrans=(-10, 10, 1.),
        swap=(0, len(positions) - 1, 1)
    )
    def update(xrot=0, xtrans=0, swap=0):
        with widget.batch_update():
            # Work on a float copy. The caller's array is never mutated, and
            # the in-place float translation below cannot fail on integer input.
            pos = np.array(positions, dtype=float)

            # Permute: exchange atom ``swap`` with atom 0.
            if swap:
                pos[[swap, 0]] = pos[[0, swap]]

            # Translate along x.
            if xtrans:
                pos[:, 0] += xtrans

            # Rotate about x. A scalar angle yields a single Rotation
            # rather than a stack of length one.
            if xrot:
                pos = Rotation.from_euler('x', xrot, degrees=True).apply(pos)

            print(pos)
            widget.data[0].z = calc(pos).T

    return widget

In [6]:
def s2_grid(n_beta=40, n_alpha=80):
    """Sample the unit sphere on a regular (beta, alpha) angle grid.

    Args:
        n_beta: number of polar-angle samples spanning [0, pi] (default 40).
        n_alpha: number of azimuthal samples spanning [0, 2*pi] (default 80).

    Returns:
        The points produced by ``o3.angles_to_xyz`` for the grid. The last
        axis holds x, y, z, as indexed by ``trace``; the shape is presumably
        (n_beta, n_alpha, 3).
    """
    betas = torch.linspace(0, math.pi, n_beta)
    alphas = torch.linspace(0, 2 * math.pi, n_alpha)
    # 'ij' is torch's historical default. Passing it explicitly keeps the same
    # grid and silences the deprecation warning raised when it is omitted.
    beta, alpha = torch.meshgrid(betas, alphas, indexing='ij')
    return o3.angles_to_xyz(alpha, beta)

def trace(r, f, c, radial_abs=True):
    """Build the keyword arguments for one ``go.Surface`` of a spherical signal.

    Args:
        r: direction vectors (e.g. from ``s2_grid()``); the last axis holds x, y, z.
        f: signal values at those points. They are used as the surface colour
            and, when ``radial_abs`` is true, as the radius via ``|f|``.
        c: (x, y, z) offset added to every point to position the surface.
        radial_abs: scale each point by ``|f|`` when true; otherwise leave the
            points unscaled.

    Returns:
        dict with keys 'x', 'y', 'z' and 'surfacecolor'.
    """
    radius = f.abs() if radial_abs else 1
    x, y, z = (radius * r[..., k] + c[k] for k in range(3))
    return {'x': x, 'y': y, 'z': z, 'surfacecolor': f}

def plot(data, radial_abs=True, layout=None):
    """Plot one or more spherical signals side by side as 3D surfaces.

    Args:
        data: tensor whose last axis indexes the signals. Each ``data[..., i]``
            must match the ``s2_grid()`` point layout, so presumably the shape
            is (40, 80, n) — TODO confirm.
        radial_abs: if true, each surface is deformed radially by ``|f|``;
            otherwise the unscaled grid points are used.
        layout: plotly layout for the figure. When omitted, a notebook-global
            ``layout`` is used if one is defined (the previous behaviour);
            otherwise plotly's default layout is used.
    """
    if layout is None:
        # Previously this name was read implicitly as a global, which raised
        # NameError on a fresh kernel when no global ``layout`` existed.
        layout = globals().get('layout')

    grid = s2_grid()
    n = data.shape[-1]
    # Centre the n surfaces on the origin, 2 units apart along x.
    surfaces = [
        trace(grid, data[..., i], torch.tensor([2.0 * i - (n - 1.0), 0.0, 0.0]), radial_abs=radial_abs)
        for i in range(n)
    ]
    # Use one symmetric colour range for all surfaces so they are comparable.
    cmax = max((s['surfacecolor'].abs().max().item() for s in surfaces), default=0.0)
    fig = go.Figure(
        data=[go.Surface(**s, colorscale='RdBu', cmin=-cmax, cmax=cmax) for s in surfaces],
        layout=layout,
    )
    fig.show()
    
def plot_sphere(r, axis=None):
    """Scatter-plot points (e.g. from ``s2_grid()``) in 3D with fixed [-1, 1] axes.

    Args:
        r: tensor or array whose last axis holds x, y, z coordinates.
        axis: extra plotly axis properties applied to all three scene axes.
            When omitted, a notebook-global ``axis`` dict is used if one is
            defined (the previous behaviour); otherwise no extra properties
            are applied. Each axis range is always forced to [-1, 1].
    """
    if axis is None:
        # Previously this name was read implicitly as a global, which raised
        # NameError on a fresh kernel when no global ``axis`` existed.
        axis = globals().get('axis', {})
    # Merge the dicts rather than calling ``dict(**axis, range=...)``. That way
    # an ``axis`` that already sets 'range' is overridden instead of raising
    # TypeError.
    axis_props = {**axis, 'range': [-1, 1]}

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=r[..., 0].flatten(),
                y=r[..., 1].flatten(),
                z=r[..., 2].flatten(),
                mode='markers',
                marker=dict(size=1),
            )
        ],
        layout=dict(
            width=500,
            height=300,
            scene=dict(
                xaxis=dict(axis_props),
                yaxis=dict(axis_props),
                zaxis=dict(axis_props),
                aspectmode='manual',
                aspectratio=dict(x=3, y=3, z=3),
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=0, y=-5, z=5),
                    projection=dict(type='orthographic'),
                ),
            ),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
            margin=dict(l=0, r=0, t=0, b=0),
        ),
    )
    fig.show()

In [7]:
# Build acetone (CH3COCH3) with ASE and display it in the x3d viewer.
molecule = build.molecule(name='CH3COCH3')
view(atoms=molecule, centre=True)

In [8]:
# Raw Cartesian positions of the molecule, with interactive symmetry sliders.
show_array(positions=molecule.positions, calc=None)

interactive(children=(FloatSlider(value=0.0, description='xrot', max=360.0, step=1.0), FloatSlider(value=0.0, …

FigureWidget({
    'data': [{'coloraxis': 'coloraxis',
              'hovertemplate': 'x: %{x}<br>y: %{y}<br>color: %{z}<extra></extra>',
              'name': '0',
              'type': 'heatmap',
              'uid': '86f6934a-7372-4ed5-8fd7-5ecaae29691b',
              'xaxis': 'x',
              'yaxis': 'y',
              'z': array([[ 0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
                           -0.881086,  0.881086,  0.881086, -0.881086],
                          [ 0.      ,  0.      ,  1.28549 , -1.28549 ,  2.134917, -2.134917,
                            1.331548,  1.331548, -1.331548, -1.331548],
                          [ 1.405591,  0.17906 , -0.616342, -0.616342,  0.066535,  0.066535,
                           -1.264013, -1.264013, -1.264013, -1.264013]])}],
    'layout': {'coloraxis': {'cmax': 5,
                             'cmin': -5,
                             'colorscale': [[0.0, 'rgb(103,0,31)'], [0.1,
                               