# Tutorial 4: Geometric Pretraining

## Step 1. Load Packages and Set Random Seeds and Device

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader

from Geom3D.models import SchNet, GeoSSL_DDM
from Geom3D.datasets import Molecule3DDataset
from Geom3D.dataloaders import AtomTupleExtractor, DataLoaderAtomTuple

import sys
sys.path.insert(0, "../examples_3D")

from tqdm import tqdm

# Seed every random number generator this notebook draws from, so that a
# fresh "Restart Kernel -> Run All" repeats the same run.
seed = 42
np.random.seed(seed)
# torch.manual_seed seeds the CPU generator; seeding only the CUDA generators
# leaves CPU-side sampling and weight initialisation unseeded.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Always a torch.device, whichever backend is selected, so later cells see
# one consistent type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader

from Geom3D.models import SchNet, GeoSSL_DDM
from Geom3D.datasets import Molecule3DDataset
from Geom3D.dataloaders import AtomTupleExtractor, DataLoaderAtomTuple

import sys
sys.path.insert(0, "../examples_3D")

from tqdm import tqdm

# This cell repeats the seeding/device setup of the first cell; re-running it
# resets the generators to the same starting state.
seed = 42
np.random.seed(seed)
# torch.manual_seed seeds the CPU generator; seeding only the CUDA generators
# leaves CPU-side sampling and weight initialisation unseeded.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Always a torch.device, whichever backend is selected, so later cells see
# one consistent type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Step 2. Set Task, Dataset and Dataloader

In [3]:
# Keep the dataset *name* in its own variable so `dataset` only ever refers to
# the dataset object (the original rebound `dataset` from str to dataset).
dataset_name = "QM9"

data_root = f"../data/{dataset_name}"
# Transform applied to each sample; the loop in Step 6 reads
# `batch.super_edge_index` from the batches this pipeline yields.
transform = AtomTupleExtractor(ratio=0.1, option="combination")

dataset = Molecule3DDataset(data_root, dataset=dataset_name, transform=transform)

batch_size = 128
num_workers = 0  # load batches in the main process
loader = DataLoaderAtomTuple(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

Dataset: QM9
Data: Data(x=[2359210], edge_index=[2, 4883516], edge_attr=[4883516, 3], positions=[2359210, 3], id=[130831], y=[1700803])


## Step 3. Set Model

In [4]:
# --- Task sizes ---
node_class = 119
edge_class = 5  # not passed to the SchNet constructor below
num_tasks = 1

# --- SchNet hyperparameters ---
emb_dim = 128
SchNet_num_filters = 128
SchNet_num_interactions = 6
SchNet_num_gaussians = 51
SchNet_cutoff = 10
SchNet_readout = "mean"

# Collect the encoder arguments in one place, then build the encoder and move
# it to the selected device.
schnet_kwargs = dict(
    hidden_channels=emb_dim,
    num_filters=SchNet_num_filters,
    num_interactions=SchNet_num_interactions,
    num_gaussians=SchNet_num_gaussians,
    cutoff=SchNet_cutoff,
    readout=SchNet_readout,
    node_class=node_class,
)
model = SchNet(**schnet_kwargs).to(device)

# Linear head mapping an emb_dim-wide vector to num_tasks outputs.
graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks)
graph_pred_linear = graph_pred_linear.to(device)

## Step 4. Set Pretraining Model

In [5]:
emb_dim = 128

# Settings shared by both pretraining heads.
GeoSSL_sigma_begin = 10
GeoSSL_sigma_end = 0.01
GeoSSL_num_noise_level = 10
GeoSSL_noise_type = "symmetry"
GeoSSL_anneal_power = 2

ddm_kwargs = dict(
    sigma_begin=GeoSSL_sigma_begin,
    sigma_end=GeoSSL_sigma_end,
    num_noise_level=GeoSSL_num_noise_level,
    noise_type=GeoSSL_noise_type,
    anneal_power=GeoSSL_anneal_power,
)

# Two separately constructed heads with identical configuration; Step 6 uses
# one per view (original / perturbed positions).
GeoSSL_DDM_01 = GeoSSL_DDM(emb_dim, **ddm_kwargs).to(device)
GeoSSL_DDM_02 = GeoSSL_DDM(emb_dim, **ddm_kwargs).to(device)

## Step 5. Set Optimizer

In [6]:
lr = 5e-4
decay = 0
criterion = nn.BCEWithLogitsLoss()

# One parameter group per module, all sharing the same learning rate.
trainable_modules = (model, graph_pred_linear, GeoSSL_DDM_01, GeoSSL_DDM_02)
model_param_group = [
    {"params": module.parameters(), "lr": lr} for module in trainable_modules
]
optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=decay)

## Step 6. Start Pretraining

In [7]:
def perturb(x, positions, mu, sigma):
    """Return ``x`` unchanged together with a Gaussian-jittered copy of ``positions``.

    Args:
        x: Passed through as-is (the same object is returned).
        positions: Tensor to jitter; it is not modified in place.
        mu: Mean of the added Gaussian noise.
        sigma: Standard deviation of the added Gaussian noise.

    Returns:
        Tuple ``(x, positions + noise)`` where ``noise`` has the shape of
        ``positions`` and lives on the same device.
    """
    x_perturb = x

    # Sample the noise directly on the device that holds `positions` instead of
    # sampling on the CPU and copying it over on every call.
    noise = torch.normal(mu, sigma, size=positions.size(), device=positions.device)
    positions_perturb = positions + noise

    return x_perturb, positions_perturb

def _edge_distance(positions, edge_index):
    """Return the Euclidean distance between the two endpoints of every index pair.

    ``edge_index[0]`` and ``edge_index[1]`` select rows of ``positions``; the
    result has one row per pair and a single column.
    """
    u_pos = torch.index_select(positions, dim=0, index=edge_index[0])
    v_pos = torch.index_select(positions, dim=0, index=edge_index[1])
    return torch.sqrt(torch.sum((u_pos - v_pos) ** 2, dim=1)).unsqueeze(1)  # (num_edge, 1)


# Mean / standard deviation of the Gaussian jitter applied to the positions.
mu, sigma = 0, 0.3

epochs = 1
# Honour `epochs`: the original assigned it but always made exactly one pass.
for epoch in range(epochs):
    for batch in tqdm(loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        batch = batch.to(device)

        positions = batch.positions

        # View 01 is the original geometry, view 02 the jittered geometry.
        x_01 = batch.x
        positions_01 = positions
        x_02, positions_02 = perturb(x_01, positions, mu, sigma)

        _, molecule_3D_repr_01 = model(x_01, positions_01, batch.batch, return_latent=True)
        _, molecule_3D_repr_02 = model(x_02, positions_02, batch.batch, return_latent=True)

        super_edge_index = batch.super_edge_index

        distance_01 = _edge_distance(positions_01, super_edge_index)
        distance_02 = _edge_distance(positions_02, super_edge_index)

        # Each head receives the representation of one view and the distances
        # of the *other* view.
        loss_01 = GeoSSL_DDM_01(batch, molecule_3D_repr_01, distance_02)
        loss_02 = GeoSSL_DDM_02(batch, molecule_3D_repr_02, distance_01)

        loss = (loss_01 + loss_02) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1023/1023 [00:54<00:00, 18.62it/s]
