In [1]:
#!nvidia-smi | grep 300W

In [2]:
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [3]:
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
import time
import copy

In [4]:
wandb.require("service")

In [5]:
# SECURITY: an API key must never be committed inside a notebook. The key that
# used to be hard-coded on this line is exposed and should be revoked.
# Credentials are read from the WANDB_API_KEY environment variable instead;
# when it is unset, wandb falls back to its own lookup (e.g. ~/.netrc / prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nick1899/.netrc


In [6]:
tqdm.pandas()

### Load config

In [7]:
import yaml

# Read the experiment configuration. The context manager guarantees the file
# handle is closed (the previous bare open() leaked it).
# NOTE(review): FullLoader is kept to preserve parsing behaviour; for a config
# made only of plain scalars/lists/mappings yaml.safe_load would be the
# stricter choice — confirm before switching.
with open("config-graphormer.yaml", "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
print(config)

{'batch_size': 32, 'warm_up': 2, 'epochs': 40, 'load_model': 'None', 'save_every_n_epochs': 1, 'fp16_precision': False, 'init_lr': 0.0001, 'weight_decay': '1e-5', 'gpu': 'cuda:0', 'graph_model_type': 'gcn', 'model': {'num_layer': 5, 'emb_dim': 300, 'feat_dim': 768, 'drop_ratio': 0, 'pool': 'mean'}, 'aug': 'node', 'dataset': {'num_workers': 12, 'valid_size': 0.1, 'data_path': 'data/pubchem-10m-clean.txt'}, 'loss': {'temperature': 0.1, 'use_cosine_similarity': True}, 'loss_params': {'alpha': 1, 'beta': 1.5, 'gamma': 3}}


In [8]:
# Optional manual override of the batch size read from the YAML file:
#config['batch_size'] = 16
# NOTE(review): this creates a *top-level* 'num_workers' entry, while the loaded
# YAML keeps the loader setting under config['dataset']['num_workers']
# (12 in the printed config) — confirm which key the data loading code reads.
config['num_workers'] = 1
print('batch_size =', config['batch_size'])
# Optional manual override of the target device:
#config['gpu'] = 0
# Module-level shortcut used further down when building the batch iterator.
batch_size = config['batch_size']

batch_size = 32


In [9]:
# Resolve the device first so that the log line reports the device actually in
# use: without CUDA the code falls back to CPU, and printing config['gpu']
# would then be misleading.
device = torch.device(config['gpu']) if torch.cuda.is_available() else torch.device('cpu')
print('running on device:', device)
device

running on device: cuda:0


device(type='cuda', index=0)

In [10]:
def _save_config_file(config, log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

### Load and Split Dataset

In [11]:
# dataframe = pd.read_csv("cleared_pubchem10m-ecfp1.csv", usecols = ['smiles', 'ecfp1'])
dataframe = pd.read_csv("cleared_pubchem10m-ecfp1.csv")

In [12]:
dataframe = dataframe.sample(1000)

In [13]:
# dataframe = dataframe.sample(3200*4)
dataframe = dataframe.reset_index(drop=True)

In [14]:
dataframe['ecfp1'].iloc[0]

"[['2246728737', '3217380708', '3218693969', '3218693969', '3217380708', '2245384272', '2245273601', '847961216', '2246699815', '864942730', '2245900962', '2245900962', '3388977530', '2246728737', '3217380708', '3218693969', '3218693969', '3218693969', '3218693969', '3218693969', '3217380708', '3218693969', '3218693969', '3218693969', '3218693969', '3218693969', '3217380708', '3218693969', '3218693969', '3218693969', '3218693969', '3218693969', '3218693969', '3218693969']]"

In [15]:
# The CSV stores every fingerprint as the string repr of a nested list, so
# pandas loads the column as plain strings that must be parsed back.
def preprocess_data_dataset(df, column):
    """Replace each stringified fingerprint in ``df[column]`` by a space-joined string.

    Every cell is expected to hold the text of a nested list such as
    ``"[['123', '456']]"``; the first inner list is joined with single spaces
    and written back into the same cell. ``df`` is modified in place.

    The cell text is parsed with ``ast.literal_eval`` rather than ``eval`` so
    that file content can never execute code. Rows are addressed by position
    for both the read and the write, so the function also works when the index
    is not ``0..n-1`` (the previous label-based write appended new rows then).

    Note: not idempotent — a second call fails because the cells no longer
    contain list literals.
    """
    import ast  # local import keeps this cell self-contained

    column_position = df.columns.get_loc(column)
    for row in tqdm(range(len(df))):
        str_ints = ast.literal_eval(df.iat[row, column_position])[0]  # change for ecfp2 and so on
        df.iat[row, column_position] = ' '.join(str_ints)

In [16]:
preprocess_data_dataset(dataframe, 'ecfp1')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3185.21it/s]


In [17]:
#dataframe = dataframe.rename(columns={'smiles': 'Smiles'})

In [18]:
dataframe

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,smiles,ecfp1
0,236909,236909,Cc1ccc(CC(NC(=O)C#C[Si](C)(c2ccccc2)c2ccccc2)c...,2246728737 3217380708 3218693969 3218693969 32...
1,6991312,6991312,COc1ccc(CNc2ccc(C(=O)N3CCCCCC3)cn2)cc1,2246728737 864674487 3217380708 3218693969 321...
2,4650989,4650989,C=CCNC(=O)CC(C(=O)[O-])c1ccncc1,2246997334 2246703798 2245384272 847961216 224...
3,1607453,1607453,COC(=O)c1ccc(NC2(C(F)(F)F)CC2)nn1,2246728737 864674487 2246699815 864942730 3217...
4,6424995,6424995,CC(C)n1ccc(NC(=O)NCc2ccn(C3CCCC3)n2)n1,2246728737 2245273601 2246728737 2092489639 32...
...,...,...,...,...
995,1604060,1604060,COc1cccc(NC(=O)CSc2nnc(NN=Cc3cccc(Cl)c3)n2N)c1,2246728737 864674487 3217380708 3218693969 321...
996,9605734,9605734,CCCN(Cc1ccc(N)cc1)c1nc(C)cc(C)n1,2246728737 2245384272 2245384272 848128881 224...
997,7607186,7607186,COc1ccc(c2ncncc2CC(=O)[O-])cc1C,2246728737 864674487 3217380708 3218693969 321...
998,3774350,3774350,CC1C[NH+](CC(O)COCc2ccccc2Cl)CC(C)O1,2246728737 2976033787 2968968094 2143075994 22...


In [19]:
#dataframe = dataframe.drop(columns=['ecfp2', 'ecfp3', 'Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa'])

In [20]:
# dataframe.head()

### Create Molecule Dataset
##### It will generate torch_geometric.data.Data objects for both bert and GIN/GCN models.

In [21]:
from rdkit import Chem

# Feature vocabularies. The integer fed to the models is the *position* of a
# value inside these lists (looked up with list.index), so the order matters;
# a value that is missing from a list makes list.index raise ValueError.

# Atomic numbers 1..118.
ATOM_LIST = list(range(1,119))
# RDKit chirality tags recognised for atoms.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
# RDKit bond types recognised for edges.
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
# RDKit bond directions recognised for edges.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [22]:
from transformers import AutoTokenizer

model_name_bert = 'molberto_ecfp0_2M'


tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
        # Creating a new column by applying the function
#         self.dataset['graph'] = self.dataset['Smiles'].apply(self.get_graph_from_smiles)


  from .autonotebook import tqdm as notebook_tqdm


In [23]:
import random
import math
from copy import deepcopy
from torch_geometric.data import Data, Dataset


In [24]:
import transformers
print(transformers.__file__)

/home/nick1899/anaconda3/envs/mol/lib/python3.9/site-packages/transformers/__init__.py


In [25]:
#from transformers.models.graphormer.collating_graphormer import GraphormerDataCollator
import transformers
#from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from transformers import GraphormerForGraphClassification
import torch
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
from transformers import AdamW, get_scheduler
#import transformers.models.graphormer.collating_graphormer

# Graphormer collator

In [27]:
from transformers.utils import is_cython_available
if is_cython_available():

    import pyximport

    pyximport.install(setup_args={"include_dirs": np.get_include()})
    
    from transformers.models.graphormer import algos_graphormer as algos_graphormer

In [28]:
from typing import Any, Dict, List, Mapping

import numpy as np
import torch

from transformers.utils import is_cython_available, requires_backends


if is_cython_available():
    import pyximport

    pyximport.install(setup_args={"include_dirs": np.get_include()})
    import sys
    sys.path.append('algos_graphormer.so')
    import algos_graphormer


def convert_to_single_emb(x, offset: int = 512):
    """Shift every feature column into its own disjoint index range.

    Column ``k`` of ``x`` is increased by ``1 + k * offset``; an input with a
    single axis is treated as having one feature column. The input is not
    modified — a new array is returned.
    """
    if len(x.shape) > 1:
        column_count = x.shape[1]
    else:
        column_count = 1
    per_column_shift = np.arange(0, column_count * offset, offset, dtype=np.int64) + 1
    return x + per_column_shift


def preprocess_item(item, keep_features=True):
    """Add the Graphormer model inputs to a single graph ``item`` (in place).

    Args:
        item: Mapping that provides ``edge_index`` and ``num_nodes``, optionally
            ``edge_attr`` and ``node_feat``, and either ``labels`` or ``y``.
        keep_features: When False (or when a feature key is absent) every
            node / edge receives the same constant feature ``1``.

    Returns:
        The same ``item`` object, extended with ``input_nodes``, ``attn_bias``,
        ``attn_edge_type``, ``spatial_pos``, ``in_degree``, ``out_degree``,
        ``input_edges`` and, if it was missing, ``labels``.
    """
    # Presumably raises when the Cython backend is unavailable — see transformers.utils.
    requires_backends(preprocess_item, ["cython"])

    if keep_features and "edge_attr" in item.keys():  # edge_attr
        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
    else:
        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all

    if keep_features and "node_feat" in item.keys():  # input_nodes
        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
    else:
        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all

    edge_index = np.asarray(item["edge_index"], dtype=np.int64)

    # Note: a further +1 is applied when the value is stored in ``item`` below.
    input_nodes = convert_to_single_emb(node_feature) + 1
    num_nodes = item["num_nodes"]

    # Edge features given as a flat vector are promoted to a single feature column.
    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr[:, None]
    # Dense [num_nodes, num_nodes, n_edge_features] table; 0 where no edge exists.
    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
    attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_emb(edge_attr) + 1

    # node adj matrix [num_nodes, num_nodes] bool
    adj = np.zeros([num_nodes, num_nodes], dtype=bool)
    adj[edge_index[0], edge_index[1]] = True

    # Compiled helpers; the exact layout of ``path`` / ``input_edges`` is defined by algos_graphormer.
    shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
    max_dist = np.amax(shortest_path_result)

    input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)
    attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single)  # with graph token

    # combine
    item["input_nodes"] = input_nodes + 1  # we shift all indices by one for padding
    item["attn_bias"] = attn_bias
    item["attn_edge_type"] = attn_edge_type
    item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1  # we shift all indices by one for padding
    item["in_degree"] = np.sum(adj, axis=1).reshape(-1) + 1  # we shift all indices by one for padding
    item["out_degree"] = item["in_degree"]  # for undirected graph
    item["input_edges"] = input_edges + 1  # we shift all indices by one for padding
    if "labels" not in item:
        item["labels"] = item["y"]

    return item


class GraphormerDataCollator:
    """Pad a list of preprocessed graph items into one dense batch of tensors.

    Each feature must provide ``attn_bias``, ``attn_edge_type``,
    ``spatial_pos``, ``in_degree``, ``input_nodes``, ``input_edges`` and
    ``labels`` (the keys written by ``preprocess_item``). Smaller graphs are
    zero-padded up to the largest graph of the batch.

    The incoming feature dicts are left untouched: tensors are built as local
    copies. (Previously the tensors and the ``-inf`` attention bias were
    written back into the caller's dicts, which corrupts items that are cached
    and collated more than once.)

    Args:
        spatial_pos_max: Node pairs whose ``spatial_pos`` is at least this
            value get an attention bias of ``-inf``.
        on_the_fly_processing: Run ``preprocess_item`` on every feature first.

    Raises:
        ImportError: If Cython is not available.
    """

    # Keys that are converted to tensors and padded to the batch maximum.
    _PADDED_KEYS = ("attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges")

    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
        if not is_cython_available():
            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")

        self.spatial_pos_max = spatial_pos_max
        self.on_the_fly_processing = on_the_fly_processing

    @staticmethod
    def _tensor_copy(value):
        """Return an independent tensor built from ``value`` (tensor, array or nested list)."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) would emit a UserWarning; clone explicitly instead.
            return value.detach().clone()
        return torch.tensor(value)

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        """Collate ``features`` into a dict of padded tensors plus ``labels``."""
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        batch = {}

        # Batch-wide maxima / feature widths that determine the padded shapes.
        max_node_num = max(len(i["input_nodes"]) for i in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        max_dist = max(len(i["input_edges"][0][0]) for i in features)
        edge_input_size = len(features[0]["input_edges"][0][0][0])
        batch_size = len(features)

        batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float)
        batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
        batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
        batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)
        batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
        batch["input_edges"] = torch.zeros(
            batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
        )

        for ix, f in enumerate(features):
            # Local tensor copies: nothing is written back into ``f``.
            t = {k: self._tensor_copy(f[k]) for k in self._PADDED_KEYS}

            # Block attention between node pairs that are too far apart; row/column 0
            # of attn_bias belongs to the graph token and is left untouched.
            attn_bias = t["attn_bias"]
            attn_bias[1:, 1:][t["spatial_pos"] >= self.spatial_pos_max] = float("-inf")

            batch["attn_bias"][ix, : attn_bias.shape[0], : attn_bias.shape[1]] = attn_bias
            batch["attn_edge_type"][ix, : t["attn_edge_type"].shape[0], : t["attn_edge_type"].shape[1], :] = t[
                "attn_edge_type"
            ]
            batch["spatial_pos"][ix, : t["spatial_pos"].shape[0], : t["spatial_pos"].shape[1]] = t["spatial_pos"]
            batch["in_degree"][ix, : t["in_degree"].shape[0]] = t["in_degree"]
            batch["input_nodes"][ix, : t["input_nodes"].shape[0], :] = t["input_nodes"]
            batch["input_edges"][
                ix, : t["input_edges"].shape[0], : t["input_edges"].shape[1], : t["input_edges"].shape[2], :
            ] = t["input_edges"]

        batch["out_degree"] = batch["in_degree"]

        sample = features[0]["labels"]
        if len(sample) == 1:  # one task (regression and binary classification are handled identically)
            batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
        else:  # multi task classification, left to float to keep the NaNs
            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))

        return batch


# Data preprocessing

In [29]:
def getshortest_path(datapoint):
    """Run Floyd-Warshall once for a raw graph item so the result can be cached.

    ``datapoint`` must provide ``node_feat`` (one entry per node) and
    ``edge_index`` (row 0 = source nodes, row 1 = target nodes).

    Returns a dict with ``max_dist`` (largest entry of the distance matrix
    returned by ``algos_graphormer.floyd_warshall``) and ``path`` (the second
    value returned by that call).
    """
    node_count = len(datapoint['node_feat'])
    sources = datapoint['edge_index'][0]
    targets = datapoint['edge_index'][1]

    # Dense boolean adjacency matrix.
    adjacency = np.zeros([node_count, node_count], dtype=bool)
    adjacency[sources, targets] = True

    distances, predecessors = algos_graphormer.floyd_warshall(adjacency)
    return {"max_dist": np.amax(distances), "path": predecessors}


In [30]:
 
def processItemForGraphormer(graph, yi):
    """Convert a graph object into a preprocessed Graphormer item.

    ``graph`` must expose ``x``, ``edge_index`` and ``edge_attr`` objects with a
    ``tolist()`` method; ``yi`` is stored under ``'y'``. After
    ``preprocess_item`` has run, pristine copies of ``attn_edge_type`` and
    ``input_nodes`` are kept under ``*_ORIG`` keys so that masking can always
    restart from the unmasked values.
    """
    raw_item = {
        "node_feat": graph.x.tolist(),
        "edge_index": graph.edge_index.tolist(),
        "edge_attr": graph.edge_attr.tolist(),
        "num_nodes": len(graph.x),
        'y': yi,
    }
    processed = preprocess_item(raw_item)
    for key in ('attn_edge_type', 'input_nodes'):
        # "+ 0" yields a fresh array that is independent of the working copy.
        processed[key + '_ORIG'] = np.array(processed[key]) + 0
    return processed
    


In [31]:
from torch_geometric.data import Data, Dataset
class MoleculeDataset(Dataset):
    """Precompute tokenised fingerprints and graph inputs for every molecule.

    The constructor adds the columns ``tokens``, ``graph``, ``graphormerdata``,
    ``graphormerdataRAW`` and ``shortest_path`` to the DataFrame it receives.
    The frame is modified *in place* on purpose: ``CustomBatchDataset`` in this
    notebook reads these columns from the very same DataFrame object.

    Args:
        dataset: DataFrame with an ``ecfp1`` column (space-separated
            fingerprint string) and a ``smiles`` column. ``tqdm.pandas()`` must
            have been called, because ``progress_apply`` is used.
        tokenizer: Tokenizer applied to the ``ecfp1`` strings.
        node_mask_percent: Fraction of atoms masked per augmented view.
        edge_mask_percent: Fraction of bonds masked per augmented view.

    Raises:
        ValueError: If RDKit cannot parse one of the SMILES strings.
    """

    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.15, edge_mask_percent=0.2):
        # Zero-argument super(): the previous ``super(Dataset, self)`` looked up
        # the global name ``Dataset`` at call time, and this notebook rebinds
        # that name later (``from datasets import Dataset``), which would make
        # the call fail with a TypeError.
        super().__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        # ``model_max_length`` is the tokenizer attribute; the former
        # ``model_max_len`` assignment only created an unused attribute.
        self.tokenizer.model_max_length = 512

        # Constant placeholder target attached to every graph item.
        self.yi = torch.ones(768, dtype=torch.long)

        self.dataset['tokens'] = self.dataset['ecfp1'].progress_apply(self.tokenize)
        self.dataset['graph'] = self.dataset['smiles'].progress_apply(self.get_graph_from_smiles)

        # Fully preprocessed items (including *_ORIG copies used for masking).
        self.dataset['graphormerdata'] = self.dataset['graph'].progress_apply(
            lambda graph: processItemForGraphormer(graph, self.yi)
        )

        # Raw (not preprocessed) items; only used to cache the shortest paths below.
        self.dataset['graphormerdataRAW'] = self.dataset['graph'].progress_apply(self._raw_graph_item)

        self.dataset['shortest_path'] = self.dataset['graphormerdataRAW'].progress_apply(getshortest_path)

        # Feature values written into masked atoms / bonds.
        self.maskedGraphAtom = torch.tensor([[len(ATOM_LIST), 0]], dtype=torch.long)
        self.edgeGraphMask = torch.tensor([len(BOND_LIST) + 1, len(BONDDIR_LIST)], dtype=torch.long)

    def _raw_graph_item(self, graph):
        """Return the plain-list dict describing ``graph`` (input format of ``preprocess_item``)."""
        return {
            "node_feat": graph.x.tolist(),
            "edge_index": graph.edge_index.tolist(),
            "edge_attr": graph.edge_attr.tolist(),
            "num_nodes": len(graph.x),
            'y': self.yi,
        }

    def get_graph_from_smiles(self, smiles):
        """Build a ``Data`` graph (atoms = nodes, bonds = two directed edges) from SMILES.

        Node features are ``[ATOM_LIST index, CHIRALITY_LIST index]``; edge
        features are ``[BOND_LIST index, BONDDIR_LIST index]``.

        Raises:
            ValueError: If RDKit cannot parse ``smiles``. (The earlier fallback
                returned a tuple, which the very next preprocessing step could
                not handle.)
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

        atom_rows = [
            [ATOM_LIST.index(atom.GetAtomicNum()), CHIRALITY_LIST.index(atom.GetChiralTag())]
            for atom in mol.GetAtoms()
        ]
        node_feat = torch.tensor(atom_rows, dtype=torch.long).view(-1, 2)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            # Every bond is stored as two consecutive directed edges (2k, 2k+1).
            row += [start, end]
            col += [end, start]
            edge_feat += [feat, list(feat)]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # view(-1, 2) keeps a consistent (n_edges, 2) shape for bond-less molecules.
        edge_attr = torch.tensor(edge_feat, dtype=torch.long).view(-1, 2)
        return Data(x=node_feat, edge_index=edge_index, edge_attr=edge_attr)

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        """Return a ``Data`` copy with random atoms and bonds replaced by mask values.

        ``N`` / ``M`` are the atom and bond counts. At least one atom is always
        masked. The inputs are not modified (``edge_attr`` used to be masked in
        place, leaking the augmentation into the source graph).
        """
        num_mask_nodes = max(1, math.floor(self.node_mask_percent * N))
        num_mask_edges = max(0, math.floor(self.edge_mask_percent * M))

        mask_nodes = random.sample(range(N), num_mask_nodes)
        mask_edges_single = random.sample(range(M), num_mask_edges)
        # Bond k owns the directed edges 2k and 2k+1.
        mask_edges = [2 * i for i in mask_edges_single] + [2 * i + 1 for i in mask_edges_single]

        node_feat_new = deepcopy(node_feat)
        node_feat_new[mask_nodes] = self.maskedGraphAtom

        edge_attr_new = deepcopy(edge_attr)
        edge_attr_new[mask_edges] = self.edgeGraphMask

        return Data(x=node_feat_new, edge_index=edge_index, edge_attr=edge_attr_new)

    def tokenize(self, item):
        """Tokenise one fingerprint string to exactly 512 tokens.

        Returns:
            Tuple ``(input_ids, attention_mask, labels)`` of tensors, where
            ``labels`` is a separate tensor holding the same ids as ``input_ids``.
        """
        sample = self.tokenizer(item, truncation=True, max_length=512, padding='max_length')
        return (torch.tensor(sample.input_ids),
                torch.tensor(sample.attention_mask),
                torch.tensor(sample.input_ids)
               )

    def mlm(self, tensor):
        """Replace a random ~15% of the non-special tokens by the mask id 4 (in place).

        Token ids 0, 1 and 2 (<s>, <pad>, </s>) are never masked. A boolean mask
        is used for the assignment, so tensors of any rank are handled (the
        former flattened-index version was only correct for 1-D input).
        """
        rand = torch.rand(tensor.shape)
        mask_arr = (rand < .15) * (tensor != 0) * (tensor != 1) * (tensor != 2)
        tensor[mask_arr] = 4
        return tensor

    def apply_mlm(self, sample):
        """Build a masked-LM ``Data`` item (masked ``input_ids``, original ids as ``labels``)."""
        labels = torch.tensor(sample.input_ids)
        attention_mask = torch.tensor(sample.attention_mask)
        input_ids = self.mlm(labels.detach().clone())
        return Data(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def __getitem__(self, index):
        """Return ``(graph, tokens)`` for the row at *position* ``index``."""
        # Positional access keeps __getitem__ consistent with __len__ even when
        # the DataFrame index is not 0..n-1.
        bert = self.dataset['tokens'].iloc[index]
        graph = self.dataset['graph'].iloc[index]

        return graph, bert

    def __len__(self):
        """Number of molecules (rows of the DataFrame)."""
        return len(self.dataset)

    def get(self):
        """Placeholder for the base-class hook; intentionally does nothing."""
        pass

    def len(self):
        """Placeholder for the base-class hook; intentionally does nothing."""
        pass

In [32]:
st = time.time()
dataset = MoleculeDataset(dataframe, tokenizer)
time.time()-st

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2325.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1611.44it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 100.22it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 40838.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 295.30it/s]


14.459450960159302

In [33]:
#st = time.time()
#for j in range(len(dataset)):
    #dd = dataset[j]
#time.time()-st, len(dataset)

In [34]:
#from transformers.models.graphormer.collating_graphormer import GraphormerDataCollator
import transformers
#from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator




In [35]:
from torch_geometric.data import Batch

In [36]:

from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Random train / validation split over row positions of `dataset`.
num_train = len(dataset)
indices = list(range(num_train))
# In-place shuffle with NumPy's global RNG; no seed is set in this cell, so the
# split differs between runs unless the RNG was seeded earlier.
np.random.shuffle(indices)

# The first `valid_size` fraction of the shuffled positions is held out for validation.
split = int(np.floor(config['dataset']['valid_size'] * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# Samplers over the two index lists (the DataLoader that would use them is commented out below).
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# config['dataset']['num_workers'] = 1
# train_dataloader = DataLoader(
#     dataset, batch_size=config['batch_size'], sampler=train_sampler,
#     num_workers=config['dataset']['num_workers'], drop_last=True
# )

# epoch_counter = 0
# train_tqdm = tqdm(train_dataloader, unit="batch")
# train_tqdm.set_description(f'Epoch {epoch_counter}')

In [37]:
import torch.nn.functional as F
def pad_to_shape(tensor, target_shape):
    """Right-pad ``tensor`` (``F.pad`` default fill) until it has ``target_shape``.

    Raises:
        ValueError: If the ranks differ, or if any target dimension is smaller
            than the corresponding tensor dimension.
    """
    source_shape = tensor.shape
    rank = len(source_shape)

    if rank != len(target_shape):
        raise ValueError(f"Tensor has {rank} dimensions but target shape has {len(target_shape)} dimensions.")

    # F.pad expects (left, right) pairs starting with the LAST dimension.
    pad_spec = []
    for dim in reversed(range(rank)):
        missing = target_shape[dim] - source_shape[dim]
        if missing < 0:
            raise ValueError(f"Target shape at dimension {dim} is smaller than the tensor shape.")
        pad_spec += [0, missing]  # nothing on the left, the deficit on the right

    return F.pad(tensor, pad_spec)

In [38]:
'''import pyximport

pyximport.install(setup_args={"include_dirs": np.get_include()})

from transformers.models.graphormer.collating_graphormer import algos_graphormer

'''

'import pyximport\n\npyximport.install(setup_args={"include_dirs": np.get_include()})\n\nfrom transformers.models.graphormer.collating_graphormer import algos_graphormer\n\n'

In [39]:
dataset.node_mask_percent

0.15

In [85]:
import torch
from torch.utils.data import IterableDataset, DataLoader
import copy

            


class CustomBatchDataset(IterableDataset):
    """Iterate over ready-made batches of masked tokens plus two masked graph views.

    Every batch is the tuple ``(input_ids, attention_mask, labels, graphs)``.
    ``graphs`` holds ``2 * batch_size`` graph items: position ``k`` and position
    ``k + batch_size`` are two independently masked views of the same molecule.
    Incomplete trailing batches are dropped.

    Args:
        dataframe: DataFrame carrying the ``tokens``, ``graph``,
            ``graphormerdata`` and ``shortest_path`` columns.
        dataset: Object providing ``node_mask_percent``, ``edge_mask_percent``,
            ``maskedGraphAtom`` and ``edgeGraphMask``.
        train_idx: Row ids used in training mode (a list; shuffled in place on
            every pass).
        valid_idx: Row ids used in validation mode.
        batch_size: Number of molecules per batch.
    """

    def setTrain(self,train = True):
        """Switch between the training (shuffled) and the validation index set."""
        if train:
            self.sample_id = self.train_idx
            self.train = True
        else:
            self.sample_id = self.valid_idx
            self.train = False

    def __init__(self, dataframe,dataset, train_idx,valid_idx, batch_size):
        # Constant placeholder targets. Sized from the ``batch_size`` argument
        # (previously taken from the global ``config``, which could disagree).
        self.y = torch.ones(batch_size * 2, 768, dtype=torch.long)
        self.yi = torch.ones(768, dtype=torch.long)
        self.dataset = dataset
        self.dataframe = dataframe
        self.sample_id = train_idx
        self.valid_idx = valid_idx
        self.train_idx = train_idx
        self.batch_size = batch_size
        self.train = True

        # Stack the per-row token tuples once so a batch is a single index operation.
        self.input_ids = torch.stack([e[0] for e in dataframe['tokens']])
        self.attention_mask = torch.stack([e[1] for e in dataframe['tokens']])
        self.labels = torch.stack([e[2] for e in dataframe['tokens']])

        self.graphs = [e for e in dataframe['graph']]

    def __iter__(self):
        """
        Custom iterator that yields batches of data and labels.
        """
        # Only whole batches are produced.
        total_samples = (len(self.sample_id)//self.batch_size)*self.batch_size
        if self.train:
            np.random.shuffle(self.sample_id)

        # NOTE(review): both mask values are written into arrays that
        # preprocess_item has already offset-encoded, so the raw values may
        # coincide with the encoding of a real atom / bond type — confirm this
        # is intended.
        node_mask_value = int(self.dataset.maskedGraphAtom[0][0])
        edge_mask_value = np.asarray(self.dataset.edgeGraphMask)

        for start in range(0, total_samples, self.batch_size):
            batch_ids = list(self.sample_id[start : start + self.batch_size])

            # Indexing with a list copies, so the stored token ids stay unmasked.
            inp_Idx = self.input_ids[batch_ids]
            rand = torch.rand(inp_Idx.shape)
            # Mask ~15% of the tokens, never the special ids 0, 1 and 2; 4 is the mask id.
            mask_arr = (rand < .15) * (inp_Idx != 0) * (inp_Idx != 1) * (inp_Idx != 2)
            inp_Idx[mask_arr] = 4
            atte = self.attention_mask[batch_ids]
            labe = self.labels[batch_ids]

            view_ids = batch_ids + batch_ids
            # One shallow copy per view. Without it both views of a molecule
            # are the SAME cached dict, so the second masking overwrites the
            # first (the two "augmentations" end up identical) and the cache
            # itself is modified.
            graphdataProcessed = [dict(self.dataframe.graphormerdata[i]) for i in view_ids]
            shortestPath = [self.dataframe.shortest_path[i] for i in view_ids]

            # Node masking: choose the atoms for every view, then overwrite
            # feature column 0 in a fresh copy of the pristine node features.
            selId = [
                random.sample(
                    range(len(g["node_feat"])),
                    max(1, math.floor(self.dataset.node_mask_percent * len(g["node_feat"])))
                )
                for g in graphdataProcessed
            ]
            for g, selidi in zip(graphdataProcessed, selId):
                g["input_nodes"] = g["input_nodes_ORIG"] + 0
                for s in selidi:
                    g["input_nodes"][s, 0] = node_mask_value

            # Edge masking: bond k owns the directed edges 2k and 2k+1.
            eselId = [
                random.sample(
                    range(len(g['edge_index'][0]) // 2),
                    max(0, math.floor(self.dataset.edge_mask_percent * (len(g['edge_index'][0]) // 2)))
                )
                for g in graphdataProcessed
            ]
            for g, eid in zip(graphdataProcessed, eselId):
                g['attn_edge_type'] = np.array(g['attn_edge_type_ORIG']) + 0
                for e in eid:
                    for k in (2 * e, 2 * e + 1):
                        fn = g['edge_index'][0][k]
                        tn = g['edge_index'][1][k]
                        g['attn_edge_type'][fn, tn, :] = edge_mask_value

            # Rebuild the edge inputs from the masked edge types, reusing the
            # cached shortest-path results.
            for g, sp in zip(graphdataProcessed, shortestPath):
                input_edges = algos_graphormer.gen_edge_input(sp['max_dist'], sp['path'], g['attn_edge_type'])
                g['input_edges'] = input_edges + 1

            yield inp_Idx, atte, labe, graphdataProcessed

    def __len__(self):
        """Number of full batches in the currently active split."""
        if self.train:
            return len(self.train_idx)//self.batch_size
        else:
            return len(self.valid_idx)//self.batch_size


custom_batch_dataset = CustomBatchDataset(dataframe, dataset, train_idx,valid_idx, batch_size)

#LST=[]
#j=0
#st = time.time()
#for batch_idx, (inp_Idx, atte, labe, graphdataProcessed   ) in tqdm(enumerate(custom_batch_dataset)):
    #j+=1
#     print(graphdataProcessed[0]['input_edges'].shape)
    #if j > 1:
        #break
      
   
#(time.time()-st) ,0.7877202033996582, 0.512915849685669


### Create Transformer Model

In [65]:
from transformers import GraphormerForGraphClassification
from torch.utils.data import DataLoader
from transformers.models.graphormer.collating_graphormer import GraphormerDataCollator
import transformers
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
from datasets import Dataset, DatasetDict

model_name_base = 'graphormer-base-pcqm4mv1'
model_name = 'clefourrier/graphormer-base-pcqm4mv1'

In [66]:
class NTXentLoss(torch.nn.Module):
    """Contrastive loss over two batches of paired embeddings.

    For every one of the ``2 * batch_size`` stacked embeddings, its partner
    from the other view is the positive and all remaining embeddings are
    negatives. The batch size is fixed at construction time because the
    negative-selection mask is precomputed.

    NOTE(review): the logits are ``log(|similarity| + 1e-4) / temperature``
    rather than the plain ``similarity / temperature`` of the usual NT-Xent
    formulation — confirm this is intentional.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity callable (cosine or dot product)."""
        if not use_cosine_similarity:
            return self._dot_simililarity
        self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
        return self._cosine_simililarity

    def _get_correlated_mask(self):
        """Boolean (2N, 2N) mask that is True exactly at the negative pairs."""
        size = 2 * self.batch_size
        # 1 on the main diagonal (self pairs) and on both +-N diagonals (positive pairs).
        excluded = (
            np.eye(size)
            + np.eye(size, size, k=-self.batch_size)
            + np.eye(size, size, k=self.batch_size)
        )
        keep = (1 - torch.from_numpy(excluded)).type(torch.bool)
        return keep.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # result shape: (N, 2N)
        return torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # result shape: (N, 2N)
        return self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))

    def forward(self, zis, zjs):
        """Return the mean loss for the paired embedding batches ``zis`` / ``zjs``."""
        pair_count = 2 * self.batch_size
        stacked = torch.cat([zjs, zis], dim=0)
        similarity = self.similarity_function(stacked, stacked)

        # The two +-N off-diagonals hold the similarity of every embedding to its partner.
        upper = torch.diag(similarity, self.batch_size)
        lower = torch.diag(similarity, -self.batch_size)
        positives = torch.cat([upper, lower]).view(pair_count, 1)
        negatives = similarity[self.mask_samples_from_same_repr].view(pair_count, -1)

        # Column 0 holds the positive, hence the all-zero target labels below.
        logits = torch.cat((positives, negatives), dim=1)
        logits = torch.log(logits.abs() + 0.0001) / self.temperature

        labels = torch.zeros(pair_count).to(self.device).long()
        return self.criterion(logits, labels) / pair_count

In [67]:
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
from transformers import GraphormerForGraphClassification
import torch
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
from transformers import AdamW, get_scheduler


class GraphormerDataCollator_():
    """Collator wrapper that drops single-node graphs before collating.

    NOTE(review): dropping items changes the number of graphs in the batch;
    confirm that downstream code tolerates a shorter batch.
    """

    def __init__(self):
        self.data_collator = GraphormerDataCollator()

    def __call__(self, features):
        """Collate every feature whose ``num_nodes`` is not 1.

        A new list is built instead of calling ``features.remove`` while
        iterating over ``features``: removing during iteration skips the
        element that follows each removed one (so consecutive single-node
        graphs slipped through) and it also modified the caller's list.
        """
        kept = [mol for mol in features if mol['num_nodes'] != 1]
        return self.data_collator(kept)

In [68]:
wandb.init(
    project="efcp_transformer",
    name="RobertaForMaskedLM + Graphormer-speed-up-1m",
    config=config
)

0,1
bert_loss/train,███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
bimodal_loss/train,▇▅▅▇▄▅▅▆▄▅█▃▇▆▅▅▅▅▃▅▄▄▂▇▃▄▃▃▃▅▅▃▄▄▃▅▁▄
graph_loss/train,████▇▆▆▆▅▅▄▅▄▃▃▃▃▂▂▂▃▂▃▂▂▂▂▁▁▁▁▂▁▂▁▁▁▁
loss/train,█▆▆█▆▅▆▇▅▅█▄▇▆▅▅▅▅▃▅▄▄▂▆▃▄▃▃▃▅▅▃▄▃▃▄▁▃

0,1
bert_loss/train,5.06958
bimodal_loss/train,26.57698
graph_loss/train,0.18155
loss/train,85.07286


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [69]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig
from torch import nn

class MolecularBertGraph(torch.nn.Module):
    """Two-branch model: a RoBERTa masked-LM and a Graphormer graph encoder.

    Each forward pass produces three losses: the masked-LM loss of the text
    branch, an NT-Xent loss computed inside the graph branch, and an NT-Xent
    loss between the projected text and graph embeddings.

    Relies on module-level names defined elsewhere in the notebook:
    `config`, `device`, `model_name`, `NTXentLoss`, `pad_to_shape` and
    `GraphormerDataCollator_`.
    """

    def __init__(self):
        super().__init__()
        self.batch_size = config['batch_size']

        roberta_config = RobertaConfig(
            vocab_size=30_522,
            max_position_embeddings=514,
            hidden_size=768,
            num_attention_heads=12,
            num_hidden_layers=6,
            type_vocab_size=1
        )
        self.bert = RobertaForMaskedLM(roberta_config)

        self.data_collator = GraphormerDataCollator_()

        self.graph_model = GraphormerForGraphClassification.from_pretrained(
            model_name,
            num_classes=1,
            # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
            ignore_mismatched_sizes=True,
        ).to(device)
        # Drop the classification head so `logits` carries the encoder
        # representation instead of a class score.
        self.graph_model.classifier = nn.Identity()

        # Merges the two concatenated graph hidden-state halves back to 768 dims.
        self.out_graph_linear = torch.nn.Linear(768 * 2, 768, bias=True)

        # Projection head for the graph branch: Linear -> BN -> ReLU -> Linear -> BN.
        self.out_graph_projection1 = torch.nn.Linear(768, 768, bias=True)
        self.bn1_graph = nn.BatchNorm1d(768)
        self.out_graph_projection2 = torch.nn.Linear(768, 768, bias=True)
        self.bn2_graph = nn.BatchNorm1d(768)

        # Projection head for the text branch, same layout as the graph one.
        self.out_bert_projection1 = torch.nn.Linear(768, 768, bias=True)
        self.bn1_bert = nn.BatchNorm1d(768)
        self.out_bert_projection2 = torch.nn.Linear(768, 768, bias=True)
        self.bn2_bert = nn.BatchNorm1d(768)

        # contrastive loss for MolCLR
        self.nt_xent_criterion = NTXentLoss(device, self.batch_size, **config['loss'])
        # cosine distance as loss between models
        self.cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    def forward(self, inp_Idx, atte, labe, graphdataProcessed):
        """Compute the three training losses and both projected embeddings.

        Args:
            inp_Idx: token ids, passed to the RoBERTa model as `input_ids`.
            atte: attention mask for `inp_Idx`.
            labe: masked-LM labels.
            graphdataProcessed: sequence of per-graph dicts; see `graph_step`.

        Returns:
            Tuple `(bert_loss, graph_loss, bimodal_loss, graph_embedding,
            bert_embedding)`; the two embeddings are the outputs of the
            respective projection heads.
        """
        bert_output = self.bert(input_ids=inp_Idx,
                                attention_mask=atte,
                                labels=labe, output_hidden_states=True)
        bert_loss = bert_output.loss
        # Take the CLS (first-token) row of the LAST encoder layer.
        # hidden_states[0] is the output of the embedding layer, i.e. before
        # any attention layer has mixed in the rest of the sequence, so its
        # first-token row carries no information about the input molecule.
        bert_emb = bert_output.hidden_states[-1][:, 0, :]

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graphdataProcessed)

        # Concatenate both halves on the feature axis, project back to 768
        # dims, then average over axis 0 to get one vector per graph.
        graph_emb = self.out_graph_linear(
            torch.cat((hidden_states_1, hidden_states_2), dim=-1)
        ).mean(axis=0)

        # graph projections:
        graph_emb_projected1 = self.out_graph_projection1(graph_emb)
        graph_emb_projected_bn1 = self.bn1_graph(graph_emb_projected1)
        graph_emb_projected2 = self.out_graph_projection2(torch.nn.functional.relu(graph_emb_projected_bn1))
        graph_emb_projected_bn2 = self.bn2_graph(graph_emb_projected2)

        # bert projections:
        bert_emb_projected1 = self.out_bert_projection1(bert_emb)
        bert_emb_projected_bn1 = self.bn1_bert(bert_emb_projected1)
        bert_emb_projected2 = self.out_bert_projection2(torch.nn.functional.relu(bert_emb_projected_bn1))
        bert_emb_projected_bn2 = self.bn2_bert(bert_emb_projected2)

        bimodal_loss = self.nt_xent_criterion(bert_emb_projected_bn2, graph_emb_projected_bn2)
        return bert_loss, graph_loss, bimodal_loss, graph_emb_projected_bn2, bert_emb_projected_bn2

    def graph_step(self, graphdataProcessed):
        """Pad and batch the graphs, run Graphormer, and compute the graph NT-Xent loss.

        Args:
            graphdataProcessed: sequence of dicts holding the Graphormer input
                keys listed below. The output is split at `self.batch_size`,
                so it is expected to hold 2 * batch_size graphs — presumably
                two views of each molecule; verify against the dataset.

        Returns:
            Tuple `(loss, ris, rjs)`: the NT-Xent loss between the two halves
            of the normalized `logits`, and the two matching slices of the
            first entry of `hidden_states`.
        """
        batch = {}
        for k in ['attn_bias', 'attn_edge_type', 'spatial_pos', 'in_degree', 'input_nodes', 'input_edges', 'out_degree', 'labels']:
            # Largest extent along every axis across the batch; each graph is
            # padded to it so the tensors can be stacked.
            shp = np.max([np.array(e[k]).shape for e in graphdataProcessed], 0)
            batch[k] = torch.stack([pad_to_shape(torch.tensor(e[k]), shp) for e in graphdataProcessed])

        input_batch = {k: v.to(device) for k, v in batch.items()}
        outputs = self.graph_model(**input_batch)

        # get the representations and the projections
        zis = outputs.logits[:self.batch_size]
        zjs = outputs.logits[self.batch_size:]

        # The batch slice is taken on axis 1 here, so hidden_states[0] is
        # assumed to be laid out (nodes, batch, dim) — TODO confirm.
        # NOTE(review): index 0 is the first recorded hidden state; confirm
        # that this, and not the last one, is the intended representation.
        ris = outputs.hidden_states[0][:, 0:self.batch_size, :].to(device)
        rjs = outputs.hidden_states[0][:, self.batch_size:self.batch_size * 2, :].to(device)

        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs
        
#for batch_idx, (inp_Idx, atte, labe, graphdataProcessed  ) in enumerate(custom_batch_dataset):
    #break

# Build the joint model, then move it to the selected device.
model = MolecularBertGraph()
model = model.to(device)
#out = model(inp_Idx.to(device), atte.to(device), labe.to(device), graphdataProcessed)


#st = time.time()

#out = model(inp_Idx.to(device), atte.to(device), labe.to(device), graphdataProcessed)

#print((time.time()-st)*8000/60/60)


Some weights of the model checkpoint at clefourrier/graphormer-base-pcqm4mv1 were not used when initializing GraphormerForGraphClassification: ['classifier.classifier.weight']
- This IS expected if you are initializing GraphormerForGraphClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GraphormerForGraphClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [46]:

# st = time.time()
# for batch_idx, (inp_Idx, atte, labe, graphdataProcessed ) in tqdm(enumerate(custom_batch_dataset)):
#     optimizer.zero_grad()
    
#     bert_loss, graph_loss, bimodal_loss, emb1, emb2 = model(inp_Idx.to(device), atte.to(device), labe.to(device), graphdataProcessed)

#     loss =   bert_loss +   graph_loss +   bimodal_loss
#     loss.backward()
#     optimizer.step()

# (time.time()-st) 

In [47]:
#sum([w.numel() for w in model.parameters()])/(1024**3)*4

### Define utils

In [70]:
num_epoch = config['epochs']

# The loaded config holds weight_decay as the string '1e-5'; float() converts
# it without evaluating arbitrary config text the way eval() would.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['init_lr'],
    weight_decay=float(config['weight_decay']),
)
# Cosine decay spans only the epochs after the warm-up period; the training
# loop decides when scheduler.step() starts being called.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config['epochs'] - config['warm_up'],
    eta_min=0, last_epoch=-1
)
num_epoch

40

### Training (with validation)

In [71]:
# Weights of the masked-LM (alpha), graph (beta) and bimodal (gamma) terms in
# the total training loss.
alpha, beta, gamma = (
    config['loss_params'][weight_name] for weight_name in ('alpha', 'beta', 'gamma')
)
alpha, beta, gamma

(1, 1.5, 3)

In [72]:
# Module-level epoch index: train_loop() reads it for the progress-bar label,
# and the main loop rebinds it as its loop variable.
epoch_counter = 0
import matplotlib.pyplot as plt

In [86]:
# Length reported by the batch dataset; compare with the batch count shown by
# the training progress bar.
len(custom_batch_dataset)

29

In [87]:
def train_loop():
    """Run one training epoch over `custom_batch_dataset` and return mean losses.

    Reads module-level state: `custom_batch_dataset`, `model`, `optimizer`,
    `device`, the loss weights `alpha`/`beta`/`gamma`, and `epoch_counter`
    (used only for the progress-bar label).

    Returns:
        Tuple of four floats `(bert_loss, graph_loss, bimodal_loss, loss)`,
        each averaged over the batches that were actually processed. All four
        are NaN if the dataset yielded no batch.
    """
    custom_batch_dataset.setTrain(True)
    train_tqdm = tqdm(custom_batch_dataset, unit="batch")
    train_tqdm.set_description(f'Epoch {epoch_counter}')
    model.train()

    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0.0, 0.0, 0.0, 0.0
    batches_done = 0

    for inp_Idx, atte, labe, graphdataProcessed in train_tqdm:
        optimizer.zero_grad()

        bert_loss, graph_loss, bimodal_loss, _, _ = model(
            inp_Idx.to(device), atte.to(device), labe.to(device), graphdataProcessed
        )
        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
        loss.backward()
        optimizer.step()

        # Convert to plain floats once; used for both the running sums and
        # the log entry.
        bert_value = bert_loss.item()
        graph_value = graph_loss.item()
        bimodal_value = bimodal_loss.item()
        loss_value = loss.item()

        bert_loss_sum += bert_value
        graph_model_loss_sum += graph_value
        bimodal_loss_sum += bimodal_value
        loss_sum += loss_value
        batches_done += 1

        # One log call per batch: each wandb.log() call without an explicit
        # step advances the run's step, so logging the four metrics
        # separately put them on four different steps.
        wandb.log({
            "bert_loss/train": bert_value,
            "graph_loss/train": graph_value,
            "bimodal_loss/train": bimodal_value,
            "loss/train": loss_value,
        })

    if batches_done == 0:
        nan = float('nan')
        return nan, nan, nan, nan

    # Average over the batches that were iterated rather than len(dataset),
    # which can differ from the number of batches the iterator yields.
    return (
        bert_loss_sum / batches_done,
        graph_model_loss_sum / batches_done,
        bimodal_loss_sum / batches_done,
        loss_sum / batches_done,
    )


In [88]:
def eval_loop():
    """Evaluate the model on `custom_batch_dataset` and return mean losses.

    Puts the model in eval mode and calls `custom_batch_dataset.setTrain(False)`
    before iterating. Reads module-level `model`, `device` and the loss
    weights `alpha`/`beta`/`gamma`.

    A batch that raises an ordinary exception is skipped (best effort), but
    the skips are counted and reported with a warning instead of being
    dropped silently. KeyboardInterrupt and SystemExit are not intercepted.

    Returns:
        Tuple of four floats `(bert_loss, graph_loss, bimodal_loss, loss)`,
        each averaged over the batches that were evaluated successfully. All
        four are NaN if no batch could be evaluated.
    """
    import warnings

    model.eval()
    custom_batch_dataset.setTrain(False)

    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0.0, 0.0, 0.0, 0.0
    evaluated, skipped, first_error = 0, 0, None

    for inp_Idx, atte, labe, graphdataProcessed in custom_batch_dataset:
        try:
            with torch.no_grad():
                bert_loss, graph_loss, bimodal_loss, _, _ = model(
                    inp_Idx.to(device), atte.to(device), labe.to(device), graphdataProcessed
                )
                loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
            batch_values = (bert_loss.item(), graph_loss.item(), bimodal_loss.item(), loss.item())
        except Exception as err:
            skipped += 1
            if first_error is None:
                first_error = err
            continue

        # Accumulate only after the whole batch succeeded, so a failing batch
        # never contributes a partial update to the sums.
        bert_loss_sum += batch_values[0]
        graph_model_loss_sum += batch_values[1]
        bimodal_loss_sum += batch_values[2]
        loss_sum += batch_values[3]
        evaluated += 1

    if skipped:
        warnings.warn(
            f"eval_loop skipped {skipped} batch(es); first error: {first_error!r}"
        )

    if evaluated == 0:
        # Nothing was evaluated: report NaN rather than a misleading 0.0 loss.
        nan = float('nan')
        return nan, nan, nan, nan

    # Divide by the number of evaluated batches; dividing by len(dataset)
    # would shrink the averages whenever a batch was skipped.
    return (
        bert_loss_sum / evaluated,
        graph_model_loss_sum / evaluated,
        bimodal_loss_sum / evaluated,
        loss_sum / evaluated,
    )

## Main loop

In [89]:
from datetime import datetime

# Each run gets its own timestamped folder under ckpts/; the config used for
# the run is written into it before training starts.
dir_name = datetime.now().strftime('%b%d_%H-%M-%S_graphormer_edge_masking')
model_checkpoints_folder = 'ckpts'
log_dir = os.path.join(model_checkpoints_folder, dir_name)
_save_config_file(config=config, log_dir=log_dir)

In [90]:
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch):
    bert_loss, graph_loss, bimodal_loss, loss = train_loop()

    # Log epoch means without `step=`: train_loop() logs once or more per
    # batch, so the run's step is already far beyond epoch_counter and W&B
    # ignores entries whose step is lower than the current one. The epoch is
    # logged as a field, and the keys get an `_epoch` suffix so they do not
    # collide with the per-batch "…/train" series.
    wandb.log({
        "epoch": epoch_counter,
        "bert_loss/train_epoch": bert_loss,
        "graph_loss/train_epoch": graph_loss,
        "bimodal_loss/train_epoch": bimodal_loss,
        "loss/train_epoch": loss,
    })

    bert_loss, graph_loss, bimodal_loss, loss = eval_loop()

    wandb.log({
        "epoch": epoch_counter,
        "bert_loss/eval": bert_loss,
        "graph_loss/eval": graph_loss,
        "bimodal_loss/eval": bimodal_loss,
        "loss/eval": loss,
    })

    # Keep the weights with the lowest evaluation loss seen so far.
    if loss < best_valid_loss:
        best_valid_loss = loss
        torch.save(model.state_dict(), os.path.join(log_dir, 'model.pth'))

    # Periodic snapshot, named after the zero-based epoch index.
    if (epoch_counter + 1) % config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(log_dir, 'model_{}.pth'.format(str(epoch_counter))))

    # warmup for the first few epochs: the cosine schedule only starts
    # stepping once `warm_up` epochs have passed.
    if epoch_counter >= config['warm_up']:
        scheduler.step()

  batch[k] = torch.stack([pad_to_shape(torch.tensor(e[k]), shp) for e in graphdataProcessed])
Epoch 0:  97%|██████████████████████████████████████████████████████████████████████████████████████████████▌   | 28/29 [02:28<00:05,  5.32s/batch]
Epoch 1:  97%|██████████████████████████████████████████████████████████████████████████████████████████████▌   | 28/29 [02:15<00:04,  4.85s/batch]
Epoch 2:   7%|██████▊                                                                                            | 2/29 [00:14<03:09,  7.03s/batch]


KeyboardInterrupt: 

In [None]:
# End the W&B run opened by wandb.init() earlier in the notebook.
wandb.finish()

In [None]:
# Switch the dataset back with setTrain(True); eval_loop() leaves it at
# setTrain(False).
custom_batch_dataset.setTrain(True)