Code implemented from: https://github.com/lucidrains/graph-transformer-pytorch
- pip install einops
- pip install rotary-embedding-torch

In [1]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--device', type=str, default='cuda:2')
# check gpu status with: gpustat -cuFi 1

parser.add_argument('--hdim', type=int, default=64) #hidden dimension
parser.add_argument('--edge_dim', type=int, default=None)
parser.add_argument('--n_layers', type=int, default=3)

parser.add_argument('--lr', type=float, default=1e-3)

args = parser.parse_args([])

In [2]:
model_name = f'GraphTransformer_h{args.hdim}e{args.edge_dim}l{args.n_layers}_lr{args.lr}_test'

In [4]:
import pandas as pd
import numpy as np
import sys

import torch
import torch.nn as nn
from graph_transformer_pytorch import GraphTransformer
import torch_geometric

ModuleNotFoundError: No module named 'einops'

In [None]:
sys.path.append('../')
from utils_dm import EarlyStopper

### dataset

In [52]:
#from torch_geometric.utils.smiles import from_smiles
from geometric_utils import from_smiles

In [9]:
#whole_df = pd.read_csv('../../../2023-2/ETC/molecule_stability.csv')

In [53]:
#whole_data = [from_smiles(smi) for smi in whole_df.SMILES]
#len(whole_data)

3498

In [13]:
#torch.save(whole_data, '../../../2023-2/processed_data/graph_transformer/graph_transformer_whole_data.pt')

In [14]:
whole_data = torch.load('../../../2023-2/processed_data/graph_transformer/graph_transformer_whole_data.pt')

In [15]:
# build dataset
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader

class MyDataset(Dataset):
    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = torch.tensor(labels, dtype=torch.float32)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.labels[idx]

In [16]:
labels = whole_df[['MLM','HLM']].values
dataset = MyDataset(whole_data, labels)

In [24]:
# split dataset
test_ratio = 0.1
valid_ratio = 0.1

test_len = int(len(dataset)*test_ratio)
valid_len = int(len(dataset)*valid_ratio)
train_len = len(dataset) - valid_len - test_len
print(train_len, valid_len, test_len)

trainset,validset,testset = torch.utils.data.random_split(dataset, [train_len,valid_len,test_len],
                                      torch.Generator().manual_seed(42))
print(len(trainset), len(validset), len(testset))

2800 349 349
2800 349 349


In [25]:
# build dataloader
trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=False,
                        generator=torch.Generator().manual_seed(42))
validloader = DataLoader(validset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

# Build model

In [26]:
from torch_scatter import scatter_mean
from torch_geometric.utils import to_dense_batch, to_dense_adj

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class GraphTransformer_MTL(nn.Module):
    def __init__(self, dim, depth, edge_dim = None, tasks=2, with_feedforwards = True, gated_residual = True, rel_pos_emb = False):
        super().__init__()
        edge_dim = default(edge_dim, dim)
        self.encoder = GraphTransformer(dim = dim, depth = depth, edge_dim=edge_dim, 
                                        with_feedforwards=with_feedforwards, gated_residual=gated_residual, 
                                        rel_pos_emb=rel_pos_emb)
        
        self.node_feature_encoder = nn.Linear(9, dim)
        self.edge_feature_encoder = nn.Linear(3, edge_dim)
        self.pred_heads = nn.ModuleList()
        for _ in range(tasks):  # tasks = 2 
            self.pred_heads.append(
                nn.Sequential(
                    nn.Linear(dim, dim),
                    nn.ReLU(),
                    nn.Linear(dim, dim),
                    nn.ReLU(),
                    nn.Linear(dim, 1)
                )
            )
    
    def encode_features(self, batch):
        z_x = self.node_feature_encoder(batch.x.float())
        z_e = self.edge_feature_encoder(batch.edge_attr.float())
        nodes, mask = to_dense_batch(z_x, batch.batch)
        edges = to_dense_adj(batch.edge_index, batch.batch, edge_attr=z_e)
        return nodes, edges, mask

    def forward(self, batch):
        nodes, edges, mask = self.encode_features(batch)
        nodes, edges = self.encoder(nodes, edges, mask = mask)

        res = scatter_mean(nodes[mask], batch.batch, dim=0)
        
        preds = []
        for head in self.pred_heads:
            out = head(res)
            preds.append(out)
        return torch.cat(preds, dim = -1)

In [35]:
model = GraphTransformer_MTL(dim=args.hdim, depth=args.n_layers, edge_dim=args.edge_dim).to(args.device)

### Train

In [42]:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()    # regression 
early_stopper = EarlyStopper(patience=20,printfunc=print,verbose=True,path=f'ckpts/{model_name}.pt')

In [43]:
def train(model, trainloader, args, optimizer=optimizer, criterion=criterion):
    model.train()
    train_loss = 0
    for batch, label in trainloader:
        batch = batch.to(args.device)
        label = label.to(args.device)

        optimizer.zero_grad()
        pred = model(batch)
        
        loss1 = criterion(pred[:,0].squeeze(), label[:,0].squeeze())
        loss2 = criterion(pred[:,1].squeeze(), label[:,1].squeeze())
        loss = (loss1 + loss2)/2

        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss/len(trainloader)

In [44]:
def eval(model, loader, args, return_output=False, criterion=criterion):
    model.eval()
    preds = []
    ys = []
    with torch.no_grad():
        for batch, label in loader:
            batch = batch.to(args.device)
            label = label.to(args.device)
            pred = model(batch)
            preds.append(pred)
            ys.append(label)
    preds = torch.cat(preds, dim=0)
    ys = torch.cat(ys, dim=0)
    
    loss1 = criterion(preds[:,0].squeeze(), ys[:,0].squeeze())
    loss2 = criterion(preds[:,1].squeeze(), ys[:,1].squeeze())
    loss = (loss1 + loss2)/2

    if return_output:
        return loss.item(), preds, ys
    else:
        return loss.item()

In [45]:
epoch = 0
while True:
    epoch+=1
    train_loss = train(model,trainloader,args)**0.5 # RMSE: root MSE
    valid_loss = eval(model,validloader,args)**0.5 # RMSE: root MSE
    print(f'[Epoch{epoch}] train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}')
    early_stopper(valid_loss,model)
    if early_stopper.early_stop:
        print('early stopping')
        break

[Epoch1] train_loss: 35.9834, valid_loss: 36.2212
[Epoch2] train_loss: 35.8342, valid_loss: 36.5837
EarlyStopping counter: 1/20
[Epoch3] train_loss: 36.4064, valid_loss: 36.0178
[Epoch4] train_loss: 36.0012, valid_loss: 36.7574
EarlyStopping counter: 1/20
[Epoch5] train_loss: 35.6902, valid_loss: 35.6540
[Epoch6] train_loss: 35.4527, valid_loss: 35.8211
EarlyStopping counter: 1/20
[Epoch7] train_loss: 35.1027, valid_loss: 35.1435
[Epoch8] train_loss: 35.0703, valid_loss: 36.3097
EarlyStopping counter: 1/20
[Epoch9] train_loss: 34.9416, valid_loss: 34.9781
[Epoch10] train_loss: 34.6726, valid_loss: 36.1143
EarlyStopping counter: 1/20
[Epoch11] train_loss: 34.4224, valid_loss: 35.6266
EarlyStopping counter: 2/20
[Epoch12] train_loss: 34.4940, valid_loss: 35.0928
EarlyStopping counter: 3/20
[Epoch13] train_loss: 34.5630, valid_loss: 34.8735
[Epoch14] train_loss: 34.1196, valid_loss: 34.3184
[Epoch15] train_loss: 33.9026, valid_loss: 35.4201
EarlyStopping counter: 1/20
[Epoch16] train_loss

### Validate

In [55]:
model.load_state_dict(torch.load(early_stopper.path,map_location=args.device))
model.eval()
print(f'loaded best model "{early_stopper.path}", valid loss: {early_stopper.val_loss_min:.4f}')

loaded best model "ckpts/GraphTransformer_h64eNonel3_lr0.001_test.pt", valid loss: 33.2596


In [59]:
test_loss=eval(model,testloader,args)**0.5
print(f'Final test loss: {test_loss:.4f}')

Final test loss: 32.6390
