In [1]:
import os
import sys

FS_MOL_CHECKOUT_PATH = os.path.abspath('../')

os.chdir(FS_MOL_CHECKOUT_PATH)
sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [2]:
from torch_geometric.data import InMemoryDataset
from torch.nn import functional as F
from torch import nn
import torch
from torch_geometric.nn.conv import PNAConv, RGCNConv
from torch_geometric.nn.models import GAT, PNA
from torch_geometric.nn.aggr import SumAggregation
from fs_mol.modules.graph_readout import CombinedGraphReadout
from fs_mol.clip_like import FingerprintEncoder
from torch_geometric.utils import to_undirected
from fs_mol.modules.pyg_gnn import PyG_GraphFeatureExtractor
from fs_mol.modules.graph_feature_extractor import GraphFeatureExtractorConfig
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from dataclasses import dataclass

In [3]:
class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, raw_file_path, dest_file_name, transform=None, pre_transform=None, pre_filter=None):
        self.dest = dest_file_name
        self.raw_file_path = raw_file_path
        super().__init__(root, transform, pre_transform, pre_filter)
        
        self.data, self.slices = torch.load(self.processed_paths[0], map_location=torch.device('cuda'))
        
    @property
    def raw_file_names(self):
        return [self.raw_file_path]
    
    @property
    def processed_file_names(self):
        return [self.dest]
    
    def process(self):
        data_list = torch.load(self.raw_file_names[0])
        
        print(data_list)
        
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        
        data, slices = self.collate(data_list)
        
        torch.save((data, slices), self.processed_paths[0])

train_dataset = MyOwnDataset(root='/', raw_file_path='/FS-MOL/train_raw_mols.pt', dest_file_name='/FS-MOL/train_preprocessed_mols.pt')
valid_dataset = MyOwnDataset(root='/', raw_file_path='/FS-MOL/valid_raw_mols.pt', dest_file_name='/FS-MOL/valid_preprocessed_mols.pt')     

In [4]:
print(f'Train Dataset Size: {len(train_dataset)}')
print(f'Valid Dataset Size: {len(valid_dataset)}')

Train Dataset Size: 216827
Valid Dataset Size: 14735


In [5]:
@dataclass
class TrainConfig:
    epochs: int = 10
    batch_size: int = 32
    graph_encoder_num_layers: int = 5
    graph_encoder_hidden_dim: int = 80
    graph_encoder_out_dim: int = 256
    graph_encoder_heads: int = 4
    graph_encoder_edge_dim: int = 1
    graph_encoder_dropout: float = 0.1
    graph_encoder_mlp_hidden_dim: int = 512
    
    fingerprint_encoder_hidden_dim: int = 1024
    fingerprint_encoder_output_dim: int = 512
    fingerprint_encoder_dropout: int = 0.1
    
config = TrainConfig()

config.graph_encoder_num_layers

5

In [6]:
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=False)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, drop_last=False)

batch = next(iter(train_loader))

In [7]:
# deg_hist = PNAConv.get_degree_histogram(loader=loader)


In [8]:
# class PNA_GraphEncoder(nn.Module):
#     def __init__(self, deg_hist) -> None:
#         super().__init__()
        
#         self.batch_size = 32
        
#         self.gnn = PNA(
#             in_channels=32,
#             hidden_channels=128,
#             num_layers=10,
#             out_channels=128,
#             edge_dim=1,
#             aggregators=['sum', 'mean', 'max', 'std'],
#             scalers=['amplification', 'attenuation'],
#             deg=deg_hist
#         )
        
#         self.readout = CombinedGraphReadout(
#             node_dim=128,
#             out_dim=512,
#             num_heads=12,
#             head_dim=64,
#         )
        
#     def forward(self, batch):
#         edge_index, edge_attr = to_undirected(batch.edge_index, batch.edge_attr, 32)
#         node_features = self.gnn(batch.x, edge_index, edge_attr=edge_attr)
#         return self.readout(node_features, batch.batch, self.batch_size)
    
# model = PNA_GraphEncoder(deg_hist=deg_hist).cuda()

# model(batch).shape

In [9]:
from torch_geometric.nn.norm import LayerNorm

class GAT_GraphEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gnn = GAT(
            32,
            hidden_channels=config.graph_encoder_hidden_dim,
            num_layers=config.graph_encoder_num_layers,
            out_channels=config.graph_encoder_out_dim,
            heads=config.graph_encoder_heads,
            v2=True,
            edge_dim=config.graph_encoder_edge_dim,
            dropout=config.graph_encoder_dropout,
            add_self_loops=True,
        )
        
        self.aggr = SumAggregation()
        
        mlp_hidden_dim = config.fingerprint_encoder_hidden_dim
        mlp_output_dim = config.fingerprint_encoder_output_dim
        
        self.mlp = nn.Sequential(
            nn.Linear(config.graph_encoder_out_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(mlp_hidden_dim),
            nn.Linear(mlp_hidden_dim, mlp_output_dim)
        )
        
    def forward(self, batch):
        node_feats = self.gnn(batch.x, batch.edge_index, edge_attr=batch.edge_attr.to(torch.float32))
        graph_feats = self.aggr(node_feats, batch.batch)
        
        return self.mlp(graph_feats)
        
model = GAT_GraphEncoder().to('cuda')

model(batch)

tensor([[-0.6943,  0.4627,  0.3599,  ...,  1.2139,  0.2754, -1.0201],
        [-0.5942,  0.3890,  0.3243,  ...,  1.2098,  0.3400, -0.9228],
        [-0.6145,  0.3334,  0.1615,  ...,  1.1944,  0.3325, -0.8099],
        ...,
        [-0.5830,  0.3726,  0.3106,  ...,  1.2075,  0.2881, -1.0368],
        [-0.6587,  0.4870,  0.2075,  ...,  1.1493,  0.1682, -0.9762],
        [-0.5897,  0.3647,  0.3321,  ...,  1.2491,  0.3384, -0.9824]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [10]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from wandb import wandb

valid_step = 0

class ClipLike(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.temp = 0.1
        # self.graph_encoder = PyG_GraphFeatureExtractor(GraphFeatureExtractorConfig())
        self.graph_encoder = GAT_GraphEncoder()
        self.fingerprint_encoder = FingerprintEncoder(2048, config.fingerprint_encoder_hidden_dim, config.fingerprint_encoder_output_dim, config.fingerprint_encoder_dropout)
    
    def training_step(self, batch, batch_idx):
        encoded_graphs = self.graph_encoder(batch)
        encoded_fingerprints = self.fingerprint_encoder(batch.fingerprint.reshape(-1, 2048).to(torch.float32))
        # TODO: Check the OpenAI's Codebase on CLIP and make sure this is right.
        logits = encoded_graphs @ encoded_fingerprints.T / self.temp
        
        current_batch_size = encoded_graphs.shape[0]
        
        targets = torch.eye(current_batch_size).cuda() / self.temp
        
        loss = F.cross_entropy(logits, targets, reduction="none")
        
        loss = loss.mean()
        self.log('train_loss', loss, batch_size=current_batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        encoded_graphs = self.graph_encoder(batch)
        encoded_fingerprints = self.fingerprint_encoder(batch.fingerprint.reshape(-1, 2048).to(torch.float32))
        # TODO: Check the OpenAI's Codebase on CLIP and make sure this is right.
        logits = encoded_graphs @ encoded_fingerprints.T / self.temp
        
        current_batch_size = encoded_graphs.shape[0]
        
        targets = torch.eye(current_batch_size).cuda() / self.temp
        
        loss = F.cross_entropy(logits, targets, reduction="none")
        
        loss = loss.mean()
        self.log('valid_loss', loss, batch_size=current_batch_size)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
wandb.init(config=config)
model = ClipLike()
wandb.watch(model, log='all')
    
trainer = pl.Trainer(logger=WandbLogger(), accelerator='gpu', devices=1, max_epochs=100)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mathecoder[0m ([33mdest[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668033449999106, max=1.0…

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | graph_encoder       | GAT_GraphEncoder   | 1.0 M 
1 | fingerprint_encoder | FingerprintEncoder | 2.6 M 
-----------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.513    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]