In [60]:
import sys
sys.path.append('..')
from datasets import load_dataset
from train import smiles2graph

In [67]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torch_geometric.data import Data, Batch
from model.mmcl_attr import MultiModalCLAttr
from transformers import AutoModel, AutoTokenizer
from model.airl3 import AIRL

device = 'cuda' if torch.cuda.is_available() else 'cpu'
cache_dir = '/home/ali.lawati/mol-incontext/data/pretrained_SciBERT'
text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=cache_dir)
text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=cache_dir).to(device)

def padarray(A, size, value=0):
    t = size - len(A)
    return np.pad(A, pad_width=(0, t), mode='constant', constant_values = value)

def preprocess_each_sentence(sentence, tokenizer, max_seq_len):
    text_input = tokenizer(
        sentence, truncation=True, max_length=max_seq_len,
        padding='max_length', return_tensors='np')
    
    input_ids = text_input['input_ids'].squeeze()
    attention_mask = text_input['attention_mask'].squeeze()

    sentence_tokens_ids = padarray(input_ids, max_seq_len)
    sentence_masks = padarray(attention_mask, max_seq_len)
    return [sentence_tokens_ids, sentence_masks]

def embed_text(text2latent, text_model, text_tokenizer, text_arr):
    description_tokens_ids, description_masks = prepare_text_tokens(device, text_arr, text_tokenizer, 500) 
    description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)
    description_repr = description_output["pooler_output"]
    description_repr = text2latent(description_repr)
    return description_repr


# This is for BERT
def prepare_text_tokens(device, description, tokenizer, max_seq_len):
    B = len(description)
    tokens_outputs = [preprocess_each_sentence(description[idx], tokenizer, max_seq_len) for idx in range(B)]
    tokens_ids = [o[0] for o in tokens_outputs]
    masks = [o[1] for o in tokens_outputs]
    tokens_ids = torch.Tensor(tokens_ids).long().to(device)
    masks = torch.Tensor(masks).bool().to(device)
    return tokens_ids, masks

def load_local_dataset(dataset_name = 'liupf/ChEBI-20-MM'):
    dataset = load_dataset(dataset_name)
    df_train = dataset['train'].to_pandas()
    df_valid = dataset['validation'].to_pandas()
    df_test = dataset['test'].to_pandas()
    return df_train, df_valid, df_test

def load_model(model_checkpoint = '/home/ali.lawati/mol-incontext/checkpoints/mmcl-300.pt'):
    model = MultiModalCLAttr(9, 32, 64, 9)  # Replace with your model class 
    model.load_state_dict(torch.load('/home/ali.lawati/mol-incontext/checkpoints/mmcl-300.pt', map_location=torch.device('cpu')))
    return model, model.text2latent

state_dim = 64
action_dim = 64
airl = AIRL(state_dim, action_dim)
model, text2latent = load_model()
df_train, df_valid, df_test = load_local_dataset()

val_graphs = [smiles2graph(smiles) for smiles in df_valid['SMILES']]
train_graphs = [smiles2graph(smiles) for smiles in df_train['SMILES']]
train_batch = Batch.from_data_list(train_graphs).to(device)
valid_batch  = Batch.from_data_list(val_graphs).to(device)
train_pool = model(train_batch.x, train_batch.edge_index, train_batch.batch, train_batch.edge_attr)
valid_pool = model(valid_batch.x, valid_batch.edge_index, valid_batch.batch, valid_batch.edge_attr)

  model.load_state_dict(torch.load('/home/ali.lawati/mol-incontext/checkpoints/mmcl-300.pt', map_location=torch.device('cpu')))


In [68]:
def create_expert_trajs(text_model, text_tokenizer, text2latent, states_all, actions_all, init_states_all, samples, demos, B=32, ML=2):
    
    actions_B = actions_all.unsqueeze(0).expand(B, -1, 64)
    traj_actions = actions_B[torch.arange(B).unsqueeze(1), demos]

    init_states = init_states_all[samples].T.reshape(B,-1)

    action_states = states_all[demos]

    action_states = np.concatenate((init_states, action_states), axis=1)
    action_states = np.char.add.accumulate(np.core.defchararray.add(action_states, ' # '), axis=1)
    action_states[:,0] = init_states[:, 0]

    embed_states =  torch.tensor(embed_text(text2latent, text_model, text_tokenizer, np.reshape(action_states, -1))).reshape(B,-1, 64)

    # Create triplets using tensor indexing
    triplet_states = embed_states[:,:-1,:]  # Select all states except the last
    triplet_actions = traj_actions       # Actions are the same
    triplet_next_states = embed_states[:,1:,:]  # Select all states except the first

    # Stack the triplets into a single tensor (optional)
    triplets = torch.stack((triplet_states, triplet_actions, triplet_next_states), dim=2)
    triplets = triplets.reshape(-1,3,64)
    triplets = triplets[torch.multinomial(torch.ones(triplets.shape[0]), B, replacement=False)]
    return torch.split(triplets,1, dim=1)

In [69]:
actions_all = torch.cat((train_pool, valid_pool)).detach()
init_states_all = np.concatenate((df_valid['SMILES'].values, df_train['SMILES'].values)) # I'm doing this opposite since the prompt files use valid dataset, and so i can index it directly
states_all = np.concatenate((df_train['description'].values, df_valid['description'].values))

In [15]:
import numpy as np
scores = np.round(np.load(f"/home/ali.lawati/mol-incontext/input/embed/mmcl_attr-chebi-2-epochs300-new-test.mistral-7B.scores.npy"), 1)
rewards = np.round(np.load("/home/ali.lawati/mol-incontext/input/embed/mmcl_attr-chebi-2-epochs300-new-test.mistral-7B.rewards.npy"), 1)

# Finding indices of low scores < 0.3
low_score_indices = np.where(scores < 0.3)[0]

# Finding corresponding rewards for low scores
low_score_rewards = rewards[low_score_indices]

# Finding indices of high scores > 0.7
high_score_indices = np.where(scores > 0.7)[0]

# Finding corresponding rewards for high scores
high_score_rewards = rewards[high_score_indices]

print("Low score indices:", scores[low_score_indices])
print("Corresponding rewards for low scores:", low_score_rewards)

print("High score indices:", scores[high_score_indices])
print("Corresponding rewards for high scores:", np.round(high_score_rewards,0))

np.mean(low_score_rewards)

Low score indices: [0.1 0.2 0.  0.2 0.2 0.1 0.1 0.2 0.2 0.  0.2 0.1 0.2 0.2 0.2 0.2 0.2 0.2
 0.2 0.1 0.1 0.1 0.1 0.1 0.2 0.2 0.2 0.2 0.2 0.  0.2 0.1 0.1 0.2 0.1 0.
 0.2 0.  0.2 0.2 0.2 0.  0.2 0.2 0.2 0.2 0.  0.2 0.  0.1 0.2 0.1 0.2 0.2
 0.2 0.2 0.2 0.2 0.1 0.1 0.1 0.  0.2 0.1 0.1 0.2 0.2 0.2 0.2 0.1 0.2 0.2
 0.2 0.  0.2 0.1 0.2 0.  0.1 0.  0.1 0.1 0.2 0.1 0.  0.2 0.2 0.2 0.2 0.1
 0.  0.1 0.1 0.1 0.  0.2 0.2 0.1 0.1 0.2 0.1 0.2 0.1 0.2 0.2 0.2 0.  0.2
 0.2 0.2 0.2 0.1 0.  0.2 0.1 0.2 0.2 0.1 0.1 0.2 0.1 0.2 0.  0.2 0.1 0.1
 0.2 0.2 0.1 0.  0.2 0.2 0.1 0.1 0.2 0.2 0.1 0.2 0.2 0.2 0.2 0.1 0.2 0.2
 0.2 0.1 0.1 0.2 0.  0.2 0.  0.1 0.2 0.2 0.  0.  0.1 0.2 0.  0.1 0.2 0.1
 0.  0.1 0.1 0.  0.2 0.2 0.1 0.2 0.2 0.2 0.2 0.2 0.  0.  0.2 0.  0.  0.1
 0.2 0.2 0.  0.2 0.1 0.2 0.2 0.2 0.2 0.2 0.1 0.2 0.2 0.  0.1 0.  0.1 0.2
 0.2 0.  0.  0.2 0.1 0.  0.2 0.1 0.1 0.1 0.1 0.  0.  0.1 0.1 0.2 0.2 0.2
 0.2 0.2 0.1 0.2 0.1 0.1 0.1 0.2 0.  0.1 0.2 0.2 0.2 0.1 0.1 0.2 0.2 0.2
 0.1 0.  0.2 0.2 0.2 0.2 0.1 0.2 

np.float64(3.9076923076923076)