In [1]:
import time

# Download atom3dutils.py in same directory!
import atom3dutils
from atom3dutils import get_datasets, get_metrics

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchmetrics
import torch_scatter
import torch_geometric.data.batch as tg_batch
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.transforms import RadiusGraph, Compose, BaseTransform, Distance, Cartesian, RandomRotate
from torch_geometric.loader import DataLoader
import torch_geometric as tg

from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn import Gate

import lightning.pytorch as lp
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
lp.seed_everything(42, workers=True)

Global seed set to 42


42

# PSA
This notebook is adapted from [this Colab notebook](https://colab.research.google.com/drive/1tTG0rJjLhHHnO7eXtMXmh0mRrGY95r0J#scrollTo=oKvMIMMr3gWg).

In [2]:
def balanced_irreps(hidden_features, lmax):
    """Spread a feature budget evenly over the irreps of order 0..lmax.

    Each of the ``lmax + 1`` orders gets an equal share of ``hidden_features``;
    order ``l`` then receives as many copies as fit into that share, since one
    copy occupies ``2l + 1`` dimensions. Whatever is left of the budget after
    rounding down is filled with extra scalar (``0e``) channels.

    Args:
        hidden_features: total number of feature dimensions to distribute.
        lmax: highest spherical-harmonic order to include.

    Returns:
        An ``Irreps`` object describing the resulting feature space.
    """
    share_per_order = int(hidden_features / (lmax + 1))

    # One "<multiplicity>x<irrep>" term per order; a copy of order l costs 2l+1 dims.
    terms = [
        str(int(share_per_order / (2 * order + 1))) + "x" + str(ir)
        for order, (_, ir) in enumerate(Irreps.spherical_harmonics(lmax))
    ]
    irreps = Irreps("+".join(terms))

    # Integer rounding can leave part of the budget unused: top up with scalars.
    leftover = hidden_features - irreps.dim
    if leftover > 0:
        irreps = (Irreps("{}x0e".format(leftover)) + irreps).simplify()

    return irreps

def compute_gate_irreps(irreps_out):
    """Derive the three irreps groups needed to gate ``irreps_out``.

    Args:
        irreps_out: iterable of ``(mul, ir)`` pairs (an ``Irreps`` object).

    Returns:
        A tuple ``(irreps_scalars, irreps_gated, irreps_gates)`` where
        ``irreps_scalars`` holds the ``l == 0`` entries, ``irreps_gated`` the
        ``l > 0`` entries, and ``irreps_gates`` one ``0e`` channel per gated
        copy (simplified).
    """
    scalar_terms, gated_terms = [], []
    for mul, ir in irreps_out:
        if ir.l == 0:
            scalar_terms.append((mul, ir))
        elif ir.l > 0:
            gated_terms.append((mul, ir))

    irreps_scalars = Irreps(scalar_terms)
    irreps_gated = Irreps(gated_terms)
    # Every gated copy needs exactly one scalar gate value.
    irreps_gates = Irreps([(mul, "0e") for mul, _ in gated_terms]).simplify()

    return irreps_scalars, irreps_gated, irreps_gates

class Convolution(nn.Module):
    """SE(3) equivariant convolution, parameterised by a radial network.

    A fully connected tensor product combines the features with the
    spherical-harmonic embedding of the relative positions. The tensor product
    holds no weights of its own (``internal_weights=False``,
    ``shared_weights=False``); ``RadialNet`` produces them from the distances.
    """

    def __init__(self, irreps_in1, irreps_in2, irreps_out):
        super().__init__()
        self.irreps_in1, self.irreps_in2, self.irreps_out = irreps_in1, irreps_in2, irreps_out

        tp_options = dict(
            irrep_normalization="component",
            path_normalization="element",
            internal_weights=False,
            shared_weights=False,
        )
        self.tp = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, **tp_options)

        # The radial network must emit one value per tensor-product weight.
        self.radial_net = RadialNet(self.tp.weight_numel)

    def forward(self, x, rel_pos_sh, distance):
        """Apply the distance-conditioned tensor product.

        Args:
            x: features of shape [E, irreps_in1.dim]
            rel_pos_sh: spherical harmonics of shape [E, irreps_in2.dim]
            distance: distances of shape [E, 1]

        Returns:
            The tensor-product output for every row of the inputs.
        """
        return self.tp(x, rel_pos_sh, self.radial_net(distance))

class RadialNet(nn.Module):
    """Small MLP that maps a distance to a vector of ``num_weights`` values.

    The distance is first expanded in a Bessel radial basis (PyG's DimeNet
    ``BesselBasisLayer``) and then passed through a two-layer perceptron with
    a SiLU non-linearity.

    Args:
        num_weights: size of the output vector produced per distance.
        num_basis: number of radial basis functions (default 10).
        cutoff: cutoff radius handed to the Bessel basis (default 4).
            NOTE(review): this should presumably match the radius used to
            build the graph edges - confirm against the dataset transforms.
        hidden_features: width of the hidden layer (default 16).
    """

    def __init__(self, num_weights, num_basis=10, cutoff=4, hidden_features=16):
        super().__init__()

        basis = tg.nn.models.dimenet.BesselBasisLayer(num_basis, cutoff=cutoff)

        self.net = nn.Sequential(basis,
                                nn.Linear(num_basis, hidden_features),
                                nn.SiLU(),
                                nn.Linear(hidden_features, num_weights))

    def forward(self, dist):
        """Return the network output for ``dist``.

        The trailing dimension of ``dist`` is squeezed away before the basis
        expansion, so a ``[E, 1]`` input is processed as ``[E]``.
        """
        return self.net(dist.squeeze(-1))


class ConvLayerSE3(tg.nn.MessagePassing):
    """Message-passing layer: equivariant convolution followed by a gate.

    Messages are computed by ``Convolution`` from the sender features, the
    spherical harmonics of the edge vector and the edge length, and are summed
    at the receiving node (``aggr="add"``).

    Args:
        irreps_in1: irreps of the incoming node features.
        irreps_in2: irreps of the edge spherical harmonics.
        irreps_out: irreps of the features returned by the layer.
        activation: when True the aggregated messages go through an e3nn
            ``Gate`` (SiLU on scalars, sigmoid gates on the ``l > 0`` parts);
            when False they are returned unchanged.
    """

    def __init__(self, irreps_in1, irreps_in2, irreps_out, activation=True):
        super().__init__(aggr="add")

        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2
        self.irreps_out = irreps_out

        irreps_scalars, irreps_gated, irreps_gates = compute_gate_irreps(irreps_out)

        # The extra gate scalars are only consumed by the Gate. Without an
        # activation they would leak into the output and the layer would no
        # longer return `irreps_out`, so they are only requested when gating.
        if activation:
            conv_irreps_out = irreps_gates + irreps_out
        else:
            conv_irreps_out = irreps_out
        self.conv = Convolution(irreps_in1, irreps_in2, conv_irreps_out)

        if activation:
            self.gate = Gate(irreps_scalars, [nn.SiLU()], irreps_gates, [nn.Sigmoid()], irreps_gated)
        else:
            self.gate = nn.Identity()

    def forward(self, edge_index, x, rel_pos_sh, dist):
        """Propagate messages along ``edge_index`` and apply the gate.

        Args:
            edge_index: edge list passed to ``propagate``.
            x: node features.
            rel_pos_sh: per-edge spherical harmonics of the relative position.
            dist: per-edge distances.

        Returns:
            The gated, aggregated node features.
        """
        x = self.propagate(edge_index, x=x, rel_pos_sh=rel_pos_sh, dist=dist)
        x = self.gate(x)
        return x

    def message(self, x_i, x_j, rel_pos_sh, dist):
        """Compute the message for each edge from the features ``x_j``.

        ``x_i`` is accepted for signature compatibility but is not used.
        """
        return self.conv(x_j, rel_pos_sh, dist)

class ConvModel(nn.Module):
    """Stack of ``ConvLayerSE3`` layers with an embedding in front and mean pooling behind.

    Args:
        irreps_in: irreps of the embedded node features (its ``dim`` is the embedding size).
        irreps_hidden: irreps used between the layers.
        irreps_edge: irreps of the spherical harmonics computed on the edges.
        irreps_out: irreps produced by the last layer.
        depth: number of layers; the first and last layer are always built,
            plus ``depth - 2`` hidden-to-hidden layers in between.
        max_z: number of distinct values ``graph.z`` can take (embedding rows).
    """

    def __init__(self, irreps_in, irreps_hidden, irreps_edge, irreps_out, depth, max_z:int=atom3dutils._NUM_ATOM_TYPES):
        super().__init__()

        self.irreps_in = irreps_in
        self.irreps_hidden = irreps_hidden
        self.irreps_edge = irreps_edge
        self.irreps_out = irreps_out

        self.embedder = nn.Embedding(max_z, irreps_in.dim)

        input_layer = ConvLayerSE3(irreps_in, irreps_edge, irreps_hidden)
        hidden_layers = [
            ConvLayerSE3(irreps_hidden, irreps_edge, irreps_hidden)
            for _ in range(depth - 2)
        ]
        # The last layer is left un-gated.
        output_layer = ConvLayerSE3(irreps_hidden, irreps_edge, irreps_out, activation=False)
        self.layers = nn.ModuleList([input_layer, *hidden_layers, output_layer])

    def forward(self, graph):
        """Return one pooled prediction per graph in the batch.

        Reads ``edge_index``, ``z``, ``pos`` and ``batch`` from ``graph``.
        """
        edge_index = graph.edge_index
        atom_types = graph.z
        positions = graph.pos
        graph_index = graph.batch

        # Edge geometry shared by every layer: vector from source to target,
        # its spherical-harmonic embedding and its length.
        source, target = edge_index[0], edge_index[1]
        edge_vec = positions[target] - positions[source]
        rel_pos_sh = o3.spherical_harmonics(self.irreps_edge, edge_vec, normalize=True)
        dist = torch.linalg.vector_norm(edge_vec, dim=-1, keepdim=True)

        # Look up a learned feature vector per atom type, then convolve.
        x = self.embedder(atom_types)
        for layer in self.layers:
            x = layer(edge_index, x, rel_pos_sh, dist)

        # Drop the trailing feature axis when the output is 1-dimensional.
        x = x.squeeze(-1)

        # TODO: add dense layers

        # Average the node values of each graph.
        return tg.nn.global_mean_pool(x, graph_index)

In [4]:
class Atom3D(lp.LightningModule):
    """Lightning wrapper that trains ``model`` with an MSE loss on ``batch.label``.

    Args:
        model: network called with the whole graph batch.
        metrics: mapping from metric name to a callable ``func(prediction, label)``;
            evaluated during testing and logged as ``test/<name>``.
        lr: learning rate handed to Adam.
        *args, **kwargs: forwarded to ``LightningModule.__init__``.
    """

    def __init__(
        self,
        model:nn.Module,
        metrics:dict[str,torchmetrics.Metric],
        lr:float=1e-4,
        *args,
        **kwargs
    ):

        super().__init__(*args, **kwargs)
        self.model = model
        self.lr = lr

        self.metrics = metrics
        self.loss_fn = nn.MSELoss()

        # Predictions / labels gathered over one test epoch (kept on the CPU).
        self._test_outputs = []
        self._test_labels = []

    @staticmethod
    def _num_graphs(batch):
        """Number of graphs in ``batch``, or None to let Lightning infer it."""
        return getattr(batch, "num_graphs", None)

    def forward(self, batch:tg_batch.Batch):
        """Run the wrapped model on a graph batch."""
        return self.model(batch)

    def training_step(self, batch:tg_batch.Batch, batch_idx:int):
        """Compute and log the training loss for one batch."""
        out = self(batch)
        # NOTE(review): assumes `out` and `batch.label` have the same shape;
        # otherwise MSELoss broadcasts - confirm against the dataset.
        loss = self.loss_fn(out, batch.label)

        # Pass the batch size explicitly: for graph batches Lightning cannot
        # reliably infer it, and it weights the epoch-level average.
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self._num_graphs(batch))

        return loss

    def validation_step(self, batch:tg_batch.Batch, batch_idx:int):
        """Compute and log the validation loss for one batch."""
        out = self(batch)
        loss = self.loss_fn(out, batch.label)

        self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self._num_graphs(batch))

        return loss

    def on_test_epoch_start(self):
        """Forget predictions collected by a previous test run."""
        self._test_outputs.clear()
        self._test_labels.clear()

    def test_step(self, batch:tg_batch.Batch, batch_idx:int):
        """Store the predictions of one batch and return its loss.

        The metrics are evaluated in ``on_test_epoch_end`` on all
        predictions at once: rank/correlation metrics computed on a handful of
        samples and then averaged over batches are not the metric of the test set.
        """
        out = self(batch)

        out = out.detach().cpu()
        label = batch.label.detach().cpu()

        self._test_outputs.append(out)
        self._test_labels.append(label)

        return self.loss_fn(out, label)

    def on_test_epoch_end(self):
        """Evaluate every metric on the full set of test predictions and log it."""
        if not self._test_outputs:
            return

        out = torch.cat([torch.atleast_1d(t) for t in self._test_outputs], dim=0)
        label = torch.cat([torch.atleast_1d(t) for t in self._test_labels], dim=0)

        results = {f'test/{key}': func(out, label) for key, func in self.metrics.items()}
        self.log_dict(results, on_epoch=True, logger=True)

        self._test_outputs.clear()
        self._test_labels.clear()

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        return {
            'optimizer': optimizer
        }

In [5]:
# MODEL
num_features = 16      # <- multiplicity of every irrep (input scalars and each hidden order); impact still to be tested
l_max = 1              # <- highest order in the hidden irreps; per the original note it has to be >0 and cost grows very fast
depth = 3              # <- number of conv layers; first and last are always built, so use a value >1

# TRAINING
epochs = 1
lr = 1e-3              # <- learning rate passed to the Adam optimizer
batch_size = 2         # <- graphs per batch; choose the largest that does not crash
num_workers = 4        # <- DataLoader worker processes (also silences the loader warning)

# DATA
datadir='/media/jip/T7/DL02/data/' # <- TODO: machine-specific absolute path, change to wherever the data lives
task='LBA'             # <- one of 'PPI','RSR','PSR','MSP','LEP','LBA','SMP'
smp_idx=3              # <- range 0-19 (incl), only matters if task=='SMP'
lba_split=60           # <- 30 or 60, only matters if task=='LBA'

# LOGGING
logdir='./runs/'       # <- TensorBoard log root: tensorboard --logdir=...
modeldir='./models/'   # <- checkpoints (top-2 and last) are saved here

# TESTING
# If `test` is set to a checkpoint path: skip training and only test that checkpoint.
# Otherwise: train first, then test the best checkpoint of this run.
test=None
# Example:
# test='/home/jip/Desktop/DL02/repo/models/LBA-lba_split60-epochepoch28-metricval_metric0.00e00.ckpt'

In [6]:
# Node features start out as `num_features` scalars.
irreps_in = (Irreps("1x0e") * num_features).simplify()

# Hidden features: `num_features` copies of every order up to `l_max`, grouped by irrep.
hidden_unsorted = Irreps.spherical_harmonics(l_max) * num_features
irreps_hidden = hidden_unsorted.sort()[0].simplify()

# Edges carry a single vector (l=1, odd); the model returns one scalar.
irreps_edge = Irreps("1x1o")
irreps_out = Irreps("1x0e")

for label, value in [
    ("Input irreps", irreps_in),
    ("Hidden irreps", irreps_hidden),
    ("Edge irreps", irreps_edge),
    ("Output irreps", irreps_out),
]:
    print(label, value)
print("Dim hidden irreps:", irreps_hidden.dim)

model = ConvModel(irreps_in, irreps_hidden, irreps_edge, irreps_out, depth)

# Show the embedding plus the first and last convolution as a sanity check.
print()
for module in (model.embedder, model.layers[0].conv, model.layers[-1].conv):
    print(module)

Input irreps 16x0e
Hidden irreps 16x0e+16x1o+16x2e
Edge irreps 1x1o
Output irreps 1x0e
Dim hidden irreps: 144

Embedding(9, 16)
Convolution(
  (tp): FullyConnectedTensorProduct(16x0e x 1x1o -> 48x0e+16x1o+16x2e | 256 paths | 256 weights)
  (radial_net): RadialNet(
    (net): Sequential(
      (0): BesselBasisLayer(
        (envelope): Envelope()
      )
      (1): Linear(in_features=10, out_features=16, bias=True)
      (2): SiLU()
      (3): Linear(in_features=16, out_features=256, bias=True)
    )
  )
)
Convolution(
  (tp): FullyConnectedTensorProduct(16x0e+16x1o+16x2e x 1x1o -> 1x0e | 16 paths | 16 weights)
  (radial_net): RadialNet(
    (net): Sequential(
      (0): BesselBasisLayer(
        (envelope): Envelope()
      )
      (1): Linear(in_features=10, out_features=16, bias=True)
      (2): SiLU()
      (3): Linear(in_features=16, out_features=16, bias=True)
    )
  )
)


In [7]:
# Metric callables for the selected task, keyed by name.
metrics:dict[str,callable] = get_metrics(task)
print("Test metrics:", list(metrics))

# One dataset per split ('train', 'valid', 'test').
datasets:dict[str,any] = get_datasets(task=task, smp_idx=smp_idx, lba_split=lba_split, data_dir=datadir)

# Same loader settings for every split; only the training data is shuffled.
dataloaders:dict[str,tg.loader.DataLoader] = {
    split: tg.loader.DataLoader(
        datasets[split],
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(split == "train"),
    )
    for split in ("train", "valid", "test")
}

Test metrics: ['pearson', 'kendall', 'spearman', 'rmse']


In [8]:
# Run name: the task, plus the sub-setting that identifies the dataset variant.
_name = str(task)
if task == 'SMP':
    _name += f'-smp_idx={smp_idx}'
elif task == 'LBA':
    _name += f'-lba_split={lba_split}'

# Timestamp used as the TensorBoard version (a directory name). Only portable
# strftime directives are used and no ':' is emitted: '%T' is platform
# specific and colons are not valid in directory names on every system.
_version = time.strftime('%Y%b%d-%H%M%S')

logger = TensorBoardLogger(
    save_dir=logdir,
    name=_name,
    version=_version,
)

# Keep the two checkpoints with the lowest validation loss, plus the last one.
# The monitored key must be exactly the name the LightningModule logs
# ("val/loss"). Because that name contains a '/', automatic metric-name
# insertion is disabled and the labels are written into the filename by hand;
# otherwise the '/' would end up in the checkpoint path.
checkpoint_callback = ModelCheckpoint(
    dirpath=modeldir,
    save_top_k=2,
    monitor="val/loss",
    mode='min',
    save_on_train_epoch_end=True,
    filename=_name + "-epoch={epoch:02d}-val_loss={val/loss:.2e}",
    auto_insert_metric_name=False,
    save_last=True,
)

In [9]:
# Wrap the network in the Lightning module that owns the loss, metrics and optimizer.
plmodule = Atom3D(model=model, metrics=metrics, lr=lr)

In [10]:
# Trainer wired to the TensorBoard logger and the checkpoint callback defined above.
trainer = lp.Trainer(
    callbacks=[checkpoint_callback],
    default_root_dir=modeldir,
    logger=logger,
    max_epochs=epochs,
)

# Training is skipped when a checkpoint to evaluate was given in `test`.
should_train = not test
if should_train:
    trainer.fit(
        plmodule,
        train_dataloaders=dataloaders["train"],
        val_dataloaders=dataloaders["valid"],
    )
    print("Best model:", checkpoint_callback.best_model_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | model   | ConvModel | 31.4 K
1 | loss_fn | MSELoss   | 0     
--------------------------------------
31.4 K    Trainable params
0         Non-trainable params
31.4 K    Total params
0.126     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 1782/1782 [03:01<00:00,  9.82it/s, v_num=2:10, train/loss_step=10.10, val/loss_step=16.80, val/loss_epoch=45.10, train/loss_epoch=45.90]

TypeError: 'method' object is not subscriptable

In [None]:
# Prefer the best checkpoint of this run; fall back to the one given in `test`.
ckpt_to_test = checkpoint_callback.best_model_path or test

if ckpt_to_test:
    results = trainer.test(plmodule, dataloaders['test'], ckpt_path=ckpt_to_test)
else:
    print("Could not find a model to test")

Restoring states from the checkpoint path at /home/jip/Desktop/DL02/repo/models/LBA-lba_split60-epochepoch28-metricval_metric0.00e00.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jip/Desktop/DL02/repo/models/LBA-lba_split60-epochepoch28-metricval_metric0.00e00.ckpt


Testing DataLoader 0:  97%|█████████▋| 220/226 [00:12<00:00, 17.08it/s]



Testing DataLoader 0: 100%|██████████| 226/226 [00:13<00:00, 17.06it/s]
