# Demo: Part 2

This notebook is part 2 of a 3-part demonstration of the model, tested on a Windows 11 machine.

### Instructions:

1. **Setup**: Follow the instructions in [README.md](../README.md) (also given in [demo_part1.ipynb](../demo_part1.ipynb)) to:
   - Set up the required [Conda environments](../README.md#setup).
   - Download all necessary [data and models](../README.md#download-models-and-data).
   - Prepare the codebase of the [submodule](../README.md#prepare-codebase).

2. **Preparation**: Complete the steps in [demo_part1.ipynb](../demo_part1.ipynb) before running this notebook.  

3. **Environment**: Activate the `retrieval` Conda virtual environment to run this notebook. Refer to [README.md](../README.md) for detailed setup guidance.  

4. **Next Step**: Once you finish this notebook, proceed to [demo_part3.ipynb](../demo_part3.ipynb).

In [2]:
#base import
import os
from pathlib import Path
import sys
import logging
import numpy as np
import torch
from argparse import Namespace

In [3]:
# imports from RetrievalModel
from mips import MIPS, augment_query, l2_to_ip
from retriever import ProjEncoder, DataLoader 
from utils import move_to_device, asynchronous_load
from data import Vocab, BOS, EOS

In [4]:
# Configure root logging first (timestamp - level - logger name - message),
# then create the logger used by the rest of this notebook.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
logger = logging.getLogger(__name__)

In [5]:
# Step 1: Define directories
# Parent of the current working directory, treated as the repository root folder.
current_dir = os.path.dirname(os.getcwd())

print(f"Parent Directory: {current_dir}")

Parent Directory: c:\Users\fahim\Documents\startup\Chemistry_RetroSub\RetroSub_Research


In [6]:
# Step 2: Define constants
# Number of retrieved candidates to keep per query.
top = 20

# Both locations are given relative to the working directory and resolved to
# absolute paths.
data_dir = Path("../data/uspto_full").resolve()
retrieval_model_dir = Path("../ckpts/uspto_full/dual_encoder/epoch116_batch349999_acc0.79").resolve()

print(data_dir)

C:\Users\fahim\Documents\startup\Chemistry_RetroSub\RetroSub_Research\data\uspto_full


In [7]:
# Step 3: Define arguments as a namespace
# Every setting the cells below read through `args` is collected here.
# The demo file paths are built with os.path.join instead of hard-coded "\\"
# separators, so the same cell works on Windows, Linux and macOS.
demo_data_dir = os.path.join(current_dir, "demo_data")

args = Namespace(
    # Tab-separated file read by the data-collection cell: "<query>\t<target>" per line.
    input_file=os.path.join(demo_data_dir, "test_input_dual_encoder.txt"),
    # File the search cell writes its results to.
    output_file=os.path.join(demo_data_dir, f"test_input_dual_encoder.top{top}.txt"),
    ckpt_path=retrieval_model_dir / "query_encoder",
    args_path=retrieval_model_dir / "args",
    vocab_path=data_dir / "retrieval/src.vocab",
    # Candidate file, one candidate per line.
    index_file=data_dir / "candidates.txt",
    # Kept as a plain string (the other paths are Path objects).
    index_path=str(retrieval_model_dir / "mips_index"),
    # Number of candidates written per query.
    topk=top,
    # When False, candidates identical to the query's own target are skipped.
    allow_hit=True,
    batch_size=1024,
    nprobe=64
)

print(args.input_file)
print(args.index_path)

c:\Users\fahim\Documents\startup\Chemistry_RetroSub\RetroSub_Research\demo_data\test_input_dual_encoder.txt
C:\Users\fahim\Documents\startup\Chemistry_RetroSub\RetroSub_Research\ckpts\uspto_full\dual_encoder\epoch116_batch349999_acc0.79\mips_index


The cells below reproduce the main function from `search_index.py`:

In [8]:
logger.info('Loading model...')

# NOTE(review): torch.load relies on pickle; only load argument/checkpoint
# files that come from a trusted source.
model_args = torch.load(args.args_path)
vocab = Vocab(args.vocab_path, 0, [BOS, EOS])

# Build the query encoder from the saved arguments and checkpoint.
model = ProjEncoder.from_pretrained(vocab, model_args, args.ckpt_path)

# GPU placement (model.to(device) / model.cuda() / torch.nn.DataParallel) is
# not applied in this demo; the model is only switched to evaluation mode.
model.eval()

logger.info('Model loaded.')

12/10/2024 19:19:59 - INFO - __main__ - Loading model...
12/10/2024 19:19:59 - INFO - __main__ - Model loaded.


In [9]:

logger.info('Collecting data...')

# Candidate texts: one stripped entry per line of the index file, in file
# order (the search cell looks candidates up by their position in this list).
with open(args.index_file) as index_handle:
    data_r = [raw_line.strip() for raw_line in index_handle]

# Each input line holds "<query>\t<target>"; keep them in two parallel lists.
data_q, data_qr = [], []
with open(args.input_file, 'r') as input_handle:
    for raw_line in input_handle:
        query_text, target_text = raw_line.strip().split('\t')
        data_q.append(query_text)
        data_qr.append(target_text)

logger.info('Collected %d instances', len(data_q))

12/10/2024 19:20:02 - INFO - __main__ - Collecting data...
12/10/2024 19:20:03 - INFO - __main__ - Collected 1 instances


In [10]:
# Aliases under which the search cell reads the queries, targets and candidates.
textq, textqr, textr = data_q, data_qr, data_r
data_loader = DataLoader(data_q, vocab, args.batch_size)

# Load the pre-built search index.
mips = MIPS.from_built(args.index_path, nprobe=args.nprobe)

# max_norm.pt is stored in the same directory as the index file.
# map_location='cpu' keeps the load working on machines without CUDA even if
# the value was saved from a GPU, and matches the CPU-side numpy arrays it is
# combined with in the search cell.
# NOTE(review): torch.load relies on pickle; only load trusted files.
max_norm_path = os.path.join(os.path.dirname(args.index_path), 'max_norm.pt')
max_norm = torch.load(max_norm_path, map_location='cpu')

# The index is left on CPU in this demo (mips.to_gpu() is not called).
logger.info('Data loader and MIPS done')

12/10/2024 19:20:06 - INFO - __main__ - Data loader and MIPS done


In [None]:
logger.info('Start search')
cur, tot = 0, len(data_q)

# Every output row is: query, target, then topk (candidate, score) pairs.
row_width = 2 + 2 * args.topk

with open(args.output_file, 'w') as fo:
    for batch in asynchronous_load(data_loader):
        # Encode the batch of queries without tracking gradients.
        with torch.no_grad():
            q = torch.from_numpy(batch).contiguous().t()
            bsz = q.size(0)  # rows after the transpose; used to advance `cur`
            vecsq = model(q, batch_first=True).detach().cpu().numpy()

        vecsq = augment_query(vecsq)
        # Ask for one extra neighbour so a skipped hit can still leave topk entries.
        D, I = mips.search(vecsq, args.topk + 1)
        # NOTE(review): presumably converts L2 distances into normalised
        # inner-product scores - confirm against mips.l2_to_ip.
        D = l2_to_ip(D, vecsq, max_norm) / (max_norm * max_norm)

        for offset, (neighbor_ids, neighbor_scores) in enumerate(zip(I, D)):
            query_text = textq[cur + offset]
            target_text = textqr[cur + offset]
            item = [query_text, target_text]
            for cand_id, score in zip(neighbor_ids, neighbor_scores):
                cand_text = textr[cand_id]
                # Skip candidates equal to the target unless hits are allowed.
                if not args.allow_hit and cand_text == target_text:
                    continue
                item.extend((cand_text, str(float(score))))
            # Trim to the fixed row width and make sure enough candidates were found.
            item = item[:row_width]
            assert len(item) == row_width
            fo.write('\t'.join(item) + '\n')

        cur += bsz
        logger.info('finished %d / %d', cur, tot)

12/10/2024 19:20:10 - INFO - __main__ - Start search
12/10/2024 19:20:11 - INFO - __main__ - finished 1 / 1


['C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( - c 2 c c c ( S ( C ) ( = O ) = O ) c n 2 ) c 1 C', 'TGT_PLACEHOLDER', 'C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( O S ( = O ) ( = O ) C ( F ) ( F ) F ) c 1 C . C S ( = O ) ( = O ) c 1 c c c ( B ( O ) O ) c c 1', '0.9396535754203796', 'C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( O S ( = O ) ( = O ) C ( F ) ( F ) F ) c 1 C . C c 1 c c c c c 1 S ( = O ) ( = O ) c 1 c c c ( B 2 O C ( C ) ( C ) C ( C ) ( C ) O 2 ) c c 1', '0.9394717216491699', 'C C S ( = O ) ( = O ) c 1 c c c ( Br ) n c 1 . C O C ( = O ) C c 1 c c ( O ) c 2 c c ( F ) c c c 2 c 1 C', '0.9388871192932129', 'C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( - c 2 c c c ( S ( C ) ( = O ) = O ) c n 2 ) c 1 C', '0.9384120106697083', 'C C S ( = O ) ( = O ) c 1 c c c ( Br ) n c 1 . C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( O ) c 1 F', '0.9382520914077759', 'C O C ( = O ) C c 1 c c 2 c c c ( F ) c c 2 c ( O S ( = O ) ( = O ) C ( F ) ( F ) F ) c 1 C . C c 1 c c c ( C B 2