In [1]:
import sys
sys.path.append('..')

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from datasets import load_dataset
from util.model import smiles2graph

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
from model.mmcl import MultiModalCLGAE, train as train_mmcl
from util.dataset import GraphTextDataset
from util.scibert import get_batched_text_outputs, get_tokenizer

from itertools import chain
from util.prompt import create_cot_prompt, create_incontext_prompt2
from util.measure import measure
import torch_geometric.utils.smiles as smiles

device = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset_name = 'liupf/ChEBI-20-MM'

dataset = load_dataset(dataset_name)
df_train = dataset['train'].to_pandas()
df_valid = dataset['validation'].to_pandas()
df_test = dataset['test'].to_pandas()

def smiles2graph(smiles_str):
    data = smiles.from_smiles(smiles_str)
    data.edge_attr = data.edge_attr.float()
    data.x = data.x.float()
    return Data(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)

In [4]:
df_train.iloc[0]

CID                                                    129626631
SMILES         CCCCC[C@@H]1O[C@@H]1/C=C/C(O)C/C=C\C/C=C\CCCC(...
description    The molecule is an epoxy(hydroxy)icosatrienoat...
polararea                                                   72.9
xlogp                                                        4.6
inchi          InChI=1S/C20H32O4/c1-2-3-9-13-18-19(24-18)16-1...
iupacname      (5Z,8Z,12E)-11-hydroxy-13-[(2R,3S)-3-pentyloxi...
SELFIES        [C][C][C][C][C][C@@H1][O][C@@H1][Ring1][Ring1]...
Name: 0, dtype: object

In [5]:
smiles2graph(df_train.iloc[0]['SMILES'])

Data(x=[24, 9], edge_index=[2, 48], edge_attr=[48, 3])

In [7]:
max_seq_len = 512
batch_size = 256
text_tokenizer, text_model = get_tokenizer()

df_train = df_train[0:1000]
df_valid = df_valid[0:500]
df_test = df_test[0:500]

train_graphs = [smiles2graph(smiles) for smiles in df_train['SMILES']]
val_graphs = [smiles2graph(smiles) for smiles in df_valid['SMILES']]
test_graphs = [smiles2graph(smiles) for smiles in df_test['SMILES']]


In [8]:
from model.mmcl_attr import MultiModalCLAttr

model = MultiModalCLAttr(9, 32, 64, 9)  # Replace with your model class
model.load_state_dict(torch.load('/home/ali.lawati/mol-incontext/checkpoints/mmcl-300.pt', map_location=torch.device('cpu')))

  model.load_state_dict(torch.load('/home/ali.lawati/mol-incontext/checkpoints/mmcl-300.pt', map_location=torch.device('cpu')))


<All keys matched successfully>

In [9]:
with torch.no_grad():
    train_batch = Batch.from_data_list(train_graphs)#.to(device)
    test_batch  = Batch.from_data_list(test_graphs)#.to(device)
    train_pool = model(train_batch.x, train_batch.edge_index, train_batch.batch, train_batch.edge_attr)
    test_pool = model(test_batch.x, test_batch.edge_index, test_batch.batch, test_batch.edge_attr)

In [20]:
train_pool = model(train_batch.x, train_batch.edge_index, train_batch.batch)

In [22]:
train_pool.shape


torch.Size([1000, 64])

In [None]:

model.text2latent

In [11]:
import argparse
import os
import torch
import numpy as np
 

parser = argparse.ArgumentParser()

parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)

parser.add_argument("--dataspace_path", type=str, default="./data")
parser.add_argument("--SSL_emb_dim", type=int, default=256)
parser.add_argument("--max_seq_len", type=int, default=512)

args = parser.parse_args("")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModel, AutoTokenizer
pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')
text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)
text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)

In [12]:
# This is for BERT
def padarray(A, size, value=0):
    t = size - len(A)
    return np.pad(A, pad_width=(0, t), mode='constant', constant_values = value)

def preprocess_each_sentence(sentence, tokenizer, max_seq_len):
    text_input = tokenizer(
        sentence, truncation=True, max_length=max_seq_len,
        padding='max_length', return_tensors='np')
    
    input_ids = text_input['input_ids'].squeeze()
    attention_mask = text_input['attention_mask'].squeeze()

    sentence_tokens_ids = padarray(input_ids, max_seq_len)
    sentence_masks = padarray(attention_mask, max_seq_len)
    return [sentence_tokens_ids, sentence_masks]


# This is for BERT
def prepare_text_tokens(device, description, tokenizer, max_seq_len):
    B = len(description)
    tokens_outputs = [preprocess_each_sentence(description[idx], tokenizer, max_seq_len) for idx in range(B)]
    tokens_ids = [o[0] for o in tokens_outputs]
    masks = [o[1] for o in tokens_outputs]
    tokens_ids = torch.Tensor(tokens_ids).long().to(device)
    masks = torch.Tensor(masks).bool().to(device)
    return tokens_ids, masks

In [14]:
text_dim = 768
text2latent = torch.nn.Linear(text_dim, args.SSL_emb_dim).to(device)

description_tokens_ids, description_masks = prepare_text_tokens(device, ['The molecule is a branched amino tetrasaccharide consisting of N-acetyl-beta-D-glucosamine having two alpha-L-fucosyl residues at the 3- and 6-positions as well as an N-acetyl-beta-D-glucosaminyl residue at the 4-position. It has a role as a carbohydrate allergen. It is a glucosamine oligosaccharide and an amino tetrasaccharide. It derives from an alpha-L-Fucp-(1->3)-[alpha-L-Fucp-(1->6)]-beta-D-GlcpNAc'],text_tokenizer, 500) 

description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)
description_repr = description_output["pooler_output"]
description_repr = text2latent("The molecule is an 11-oxo steroid that is corticosterone in which the hydroxy substituent at the 11beta position has been oxidised to give the corresponding ketone. It has a role as a human metabolite and a mouse metabolite. It is a 21-hydroxy steroid, a 3-oxo-Delta(4) steroid, a 20-oxo steroid, an 11-oxo steroid, a corticosteroid and a primary alpha-hydroxy ketone. It derives from a corticosterone.")

TypeError: linear(): argument 'input' (position 1) must be Tensor, not str