In [20]:
import pickle 
import yaml
import pandas as pd
from PrepareData import prepare_data


import torch
from torch import nn, optim, Tensor
from torch.nn import functional as F
import pickle 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import seaborn as sns
from architecture import CLIP
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32


In [21]:
from train_utils import load_model

In [22]:
def make_deterministic(random_seed = 0):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    
make_deterministic(0)

In [23]:
config = yaml.safe_load(open('/fs-computility/ai4chem/luxinyu.p/Spectra2Structure/checkpoints/qm9s_raman/config.yaml', 'r'))
logs = pickle.load(open('/fs-computility/ai4chem/luxinyu.p/Spectra2Structure/checkpoints/qm9s_raman/logs.pickle', 'rb'))
for key in logs:
    if "best" in key:
        print(key, logs[key])

best_epoch 484
best_clip_epoch 478
best_recon_epoch 484
best_total_loss 3.4158136367797853
best_clip_loss 3.365497016906738
best_recon_loss 0.04973456114530563


In [24]:
# config['data']['batch_size'] = 128

In [25]:
model = load_model(config['train']['checkpoint_dir'], type="best_latest")
model.eval()
dataloaders, max_charge, num_species, scaler = prepare_data(config)

129816it [00:06, 19906.25it/s]
129816it [00:06, 19930.91it/s]

Normalizing each spectrum individually



  all_species = torch.cat([torch.tensor(dataset['charges']).unique()


SMILES WILL BE RANDOMIZED
SMILES WILL BE RANDOMIZED
SMILES WILL BE RANDOMIZED


In [26]:
val_ids = pickle.load(open('/fs-computility/ai4chem/luxinyu.p/Spectra2Structure/checkpoints/qm9s_raman/val_ids.pickle', 'rb'))

In [27]:
all_ids = []
with torch.no_grad():
    for i, data in tqdm(enumerate(dataloaders['val'])):    
        data = {k: v.to(device) for k, v in data.items()}
        all_ids.append(data['index'].detach().cpu())
all_ids = torch.cat(all_ids, 0)

15it [00:00, 39.09it/s]


In [28]:
assert((all_ids.sort()[0] == val_ids.sort()[0]).sum() == all_ids.shape[0])

In [29]:
model = load_model(config['train']['checkpoint_dir'], type="best_total")
model = model.eval()
model.to(device)


  model.load_state_dict(torch.load(model_path))


DataParallel(
  (module): CLIP(
    (Molecule_Encoder): EGNN(
      (embedding): Linear(in_features=15, out_features=256, bias=True)
      (embedding_out): Linear(in_features=256, out_features=15, bias=True)
      (gcl_0): E_GCL_mask(
        (edge_mlp): Sequential(
          (0): Linear(in_features=513, out_features=256, bias=True)
          (1): SiLU()
          (2): Linear(in_features=256, out_features=256, bias=True)
          (3): SiLU()
        )
        (node_mlp): Sequential(
          (0): Linear(in_features=527, out_features=256, bias=True)
          (1): SiLU()
          (2): Linear(in_features=256, out_features=256, bias=True)
        )
        (att_mlp): Sequential(
          (0): Linear(in_features=256, out_features=1, bias=True)
          (1): Sigmoid()
        )
        (act_fn): SiLU()
      )
      (gcl_1): E_GCL_mask(
        (edge_mlp): Sequential(
          (0): Linear(in_features=513, out_features=256, bias=True)
          (1): SiLU()
          (2): Linear(in_feat

In [30]:
# data = next(iter(dataloaders['val']))
# data = {k: v.to(device) for k, v in data.items()}
# mol_latents, spec_latents, smile_preds, logit_scale, ids = model(data)

In [31]:
# from train_utils import calculate_decoder_accuracy
# acc = calculate_decoder_accuracy(model, dataloaders, k=1)

# Faster Sampling algorithm for greedy and random

In [32]:
PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4

SMI_MAX_SIZE = 71

class Sampler():
    def __init__(self, model, vocab):
        self.model = model
        self.vocab = vocab
        self.max_len = 70
    def sample(self, embed, greedy_decode=False):
        samples = []
        # print(embed.shape)
        with torch.no_grad():
            embed = embed.to(device)
            smiles_seq = torch.full((embed.shape[0], 1), SOS).long().to(device)
            
            for i in range(SMI_MAX_SIZE):
                logits = self.model.forward(embed, smiles_seq)
                # print("logits", logits.shape)
                probs = F.softmax(logits[:,-1], dim= -1)
                # print(probs.shape)
                # break
                if greedy_decode:
                    pred_id = torch.argmax(probs, dim= -1)
                    pred_id = pred_id.unsqueeze(1)
                else:
                    pred_id = torch.multinomial(probs, num_samples=1)
                # print(pred_id.shape)
                # break
                smiles_seq = torch.cat([smiles_seq, pred_id], dim=1)
                
            for i in range(len(smiles_seq)):
                smile = self.vocab.from_seq(smiles_seq[i].cpu().numpy())
                final_smile = ""
                for char in smile[1:]: # first is start token
                    if char == "<eos>" :
                        break
                    final_smile += char
                try:
                    final_smile = Chem.CanonSmiles(final_smile)
                except:
                    pass
                samples.append(final_smile)
                
        return samples
  
    

In [33]:
def calculate_decoder_accuracy(model, dataloaders, k=1, greedy_decode=True):
    all_samples = []
    og_samples = []
    sampler = Sampler(model.module.smiles_decoder, model.module.vocab)
    
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloaders['val'])):
            data = {k: v.to(device) for k, v in data.items()}
            spec_latents = model.module.forward_spec(data)
            generated_samples = sampler.sample(spec_latents, greedy_decode=greedy_decode)
            all_samples += generated_samples
            
            for og in data['smiles']:
                og_smile = ""
                chars = model.module.vocab.from_seq(og)
                for char in chars:
                    if char != "<pad>" and char != "<eos>" and char != "<sos>" and char != "<unk>":
                        og_smile += char
                try:
                    og_smile = Chem.CanonSmiles(og_smile)
                except:
                    og_smile = None
                og_samples.append(og_smile)
          
    # Calculate accuracy after processing all batches
    total = len(og_samples)
    if total == 0:
        return 0.0  # avoid division by zero
    
    hits = sum(1 for og, gen in zip(og_samples, all_samples) if og == gen)
    accuracy = hits / total
    
    print(f"Final Accuracy: {accuracy:.4f}")
    return accuracy

In [34]:
acc = calculate_decoder_accuracy(model, dataloaders, k=1, greedy_decode=True)

15it [01:03,  4.22s/it]

Final Accuracy: 0.6578





# Beam Search Sampler

In [35]:
import torch
PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4

class BeamSampler():
    def __init__(self, model, vocab, beam_size=5):
        self.model = model
        self.vocab = vocab
        self.max_len = 70
        self.beam_size = beam_size

    def sample(self, embed, greedy_decode=False):
        embed = embed.unsqueeze(0).to(device) # shape (1, 512) 
        self.model.eval()
        sample_tensor = torch.zeros((1,self.max_len), dtype=torch.int64).to(device)
        sample_tensor[0,0] = SOS
        beams = [(sample_tensor, 0.0)]

        for i in range(self.max_len -1):
            
            new_beams = []
            for beam in beams:
                
                vec, score = beam[0], beam[1]
                tensor = vec[:,:i+1]
                logits = self.model.forward(embed, tensor)[:,-1,:]
                probabilities = F.log_softmax(logits, dim=1)
                topk_values, topk_indices = torch.topk(probabilities, self.beam_size, dim=-1)
                for value, ind in zip(topk_values[0], topk_indices[0]):
                    new_vec = vec.clone()
                    new_vec[0,i+1] = ind
                    new_score = score + value
                    new_beams.append((new_vec, new_score))     
     
            beams = sorted(new_beams, key=lambda x:x[1].item(), reverse=True)[:self.beam_size]

        sampled_smiles = [] 
        for beam in beams:
            smiles = ""
            chars = self.vocab.from_seq(beam[0][0])
            for char in chars:
                if char != "<pad>" and char != "<eos>" and char != "<sos>" and char != "<unk>":
                    smiles += char
            sampled_smiles.append(smiles)
                
        return sampled_smiles
        

In [36]:
# sampler = BeamSampler(model.module.smiles_decoder, model.module.vocab)
# sample_smiles = sampler.sample(spec_latents[0], greedy_decode=False)
# print("sample_smile", sample_smile)

In [37]:
def calculate_decoder_accuracy( model, dataloaders, beam_size=5):
    with torch.no_grad():
        pred_smiles_list = []
        og_smiles_list = []
        count = 0
        sampler = BeamSampler(model.module.smiles_decoder, model.module.vocab, beam_size=beam_size)
        
        for i, data in tqdm(enumerate(dataloaders['val'])):
            data = {k: v.to(device) for k, v in data.items()}
            spec_latents = model.module.forward_spec(data)
            for spec, og in zip(spec_latents, data['smiles'] ):
                ls = [sampler.sample(spec)[0]] # checking only the top beam
                # ls = sampler.sample(spec) # checking all beams
                generated_smiles = []
                for smi in ls:
                    try:
                        generated_smiles.append(Chem.CanonSmiles(smi))
                    except:
                        pass
                og_smile = ""
                chars = model.module.vocab.from_seq(og)
                for char in chars:
                    if char != "<pad>" and char != "<eos>" and char != "<sos>" and char != "<unk>":
                        og_smile += char
                try:
                    og_smile = Chem.CanonSmiles(og_smile)
                except:
                    og_smile=None
                    
                if og_smile is not None and og_smile in generated_smiles:
                    count += 1
                
                og_smiles_list.append(og_smile)
                pred_smiles_list.append(generated_smiles)
            print("No of Hits : ",count / len(og_smiles_list))
            # if sampling takes too long, we can stop after sampling around 5000 molecules    
            if i == 5:
                break
        
        return count / len(og_smiles_list)

In [38]:
acc = calculate_decoder_accuracy(model, dataloaders, beam_size=1)

1it [08:17, 497.09s/it]

No of Hits :  0.65


2it [12:27, 351.75s/it]

No of Hits :  0.66875


3it [16:05, 290.76s/it]

No of Hits :  0.6491666666666667


4it [19:48, 263.98s/it]

No of Hits :  0.655


5it [23:23, 246.56s/it]

No of Hits :  0.6535


5it [27:17, 327.48s/it]

No of Hits :  0.6533333333333333





In [23]:
acc = calculate_decoder_accuracy(model, dataloaders, beam_size=3)



No of Hits :  0.495




No of Hits :  0.455




No of Hits :  0.465




No of Hits :  0.445




No of Hits :  0.452


5it [33:20, 400.13s/it]

No of Hits :  0.44666666666666666





In [24]:
acc = calculate_decoder_accuracy(model, dataloaders, beam_size=5)





No of Hits :  0.465




No of Hits :  0.455




No of Hits :  0.4716666666666667




No of Hits :  0.46125




No of Hits :  0.47


5it [58:25, 701.04s/it]

No of Hits :  0.46





In [25]:
acc = calculate_decoder_accuracy(model, dataloaders, beam_size=10)



No of Hits :  0.5




No of Hits :  0.505




No of Hits :  0.52




No of Hits :  0.48875




No of Hits :  0.484


5it [2:09:25, 1553.16s/it]

No of Hits :  0.4841666666666667



