In [1]:
import pandas as pd
import numpy as np
import torch

import time

import sys

import re

from tqdm import tqdm

import faiss

from transformers import BertTokenizerFast, BertModel, T5TokenizerFast, T5Model
from datasets import load_dataset

from pprint import pprint
import io

import logging
logging.basicConfig(level=logging.INFO)

import matplotlib.pyplot as plt

from helper import stream

import psycopg2

# Hard-coded record count of dataset/arxiv_data.json (TODO: confirm against the file).
# Only used to print the total number of batches in the encoding loop's progress line.
len_dataset = 2326839

# Run the encoder on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda', index=0)

In [2]:
# Pretrained T5-small encoder/decoder; only its encoder is used to embed abstracts.
model = T5Model.from_pretrained('t5-small')
model = model.to(device)
# Matching fast tokenizer for the same checkpoint.
tokenizer = T5TokenizerFast.from_pretrained('t5-small')

In [3]:
dataset = load_dataset('json', data_files='dataset/arxiv_data.json', split='train', streaming=True)

In [4]:
dataset

<datasets.iterable_dataset.IterableDataset at 0x26591560b90>

In [5]:
next(iter(dataset))

{'id': '0704.0001',
 'submitter': 'Pavel Nadolsky',
 'authors': "C. Bal\\'azs, E. L. Berger, P. M. Nadolsky, C.-P. Yuan",
 'title': 'Calculation of prompt diphoton production cross sections at Tevatron and\n  LHC energies',
 'comments': '37 pages, 15 figures; published version',
 'journal-ref': 'Phys.Rev.D76:013009,2007',
 'doi': '10.1103/PhysRevD.76.013009',
 'report-no': 'ANL-HEP-PR-07-12',
 'categories': 'hep-ph',
 'license': None,
 'abstract': '  A fully differential calculation in perturbative quantum chromodynamics is\npresented for the production of massive photon pairs at hadron colliders. All\nnext-to-leading order perturbative contributions from quark-antiquark,\ngluon-(anti)quark, and gluon-gluon subprocesses are included, as well as\nall-orders resummation of initial-state gluon radiation valid at\nnext-to-next-to-leading logarithmic accuracy. The region of phase space is\nspecified in which the calculation is most reliable. Good agreement is\ndemonstrated with data from th

In [6]:
def tokenize_dataset(data):
    """Tokenize a batch of abstracts for the T5 encoder.

    `data['abstract']` is the batch of abstract strings handed over by
    `dataset.map(batched=True)`. Sequences are truncated to 256 tokens and padded
    to the longest sequence in the batch; PyTorch tensors are returned.
    """
    abstracts = data['abstract']
    encoded = tokenizer(
        abstracts,
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    return encoded


# Tokenize lazily in map-batches of 512 and drop every metadata column, keeping
# only 'abstract' plus the tokenizer outputs.
tokenized_dataset = dataset.map(
    tokenize_dataset,
    batched=True,
    batch_size=512,
    remove_columns=[
        'id', 'submitter', 'authors', 'title', 'comments', 'journal-ref',
        'doi', 'report-no', 'categories', 'license', 'versions',
        'update_date', 'authors_parsed',
    ],
)

In [7]:
dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=128)

In [8]:
print(len(next(iter(dataloader))['abstract']))

128


In [9]:
def save_to_disk(data, filename):
    """Write `data` to an uncompressed .npz archive at `filename`.

    The array is stored under the key 'arr_0' (numpy's default name for the first
    positional array), which is the key the loading cells read back.
    """
    np.savez(filename, arr_0=data)

In [10]:
model.eval()

T5Model(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=

In [11]:
%%time
model.eval()
embeddings = []
i = 0
with torch.no_grad():
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        start = time.time()
        
        outputs = model.encoder(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        embeddings.append(hidden_states.mean(dim=1).cpu().numpy()) # average the 256 vectors
        
        end = time.time()
        i+= 1
        
        if i%100 == 0:
            print(f'Previous batch took {end - start:.2f} seconds\tBatch: {i}/{int(np.ceil(len_dataset/128))}\tEmbedding Shape: {embeddings[-1].shape}')
            
        
        if i % 1000 == 0:
            embeddings = np.array(embeddings)
            save_to_disk(embeddings, f'T5_embeddings/embeddings_{i}.npz')
            embeddings = embeddings.tolist()
            embeddings = []
            
if len(embeddings) > 0:
    embeddings = np.array(embeddings)
    save_to_disk(embeddings, f'T5_embeddings/embeddings_{i}.npz')

Previous batch took 0.12 seconds	Batch: 100/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 200/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 300/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 400/18179	Embedding Shape: (128, 512)
Previous batch took 0.13 seconds	Batch: 500/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 600/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 700/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 800/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 900/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 1000/18179	Embedding Shape: (128, 512)
Previous batch took 0.13 seconds	Batch: 1100/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 1200/18179	Embedding Shape: (128, 512)
Previous batch took 0.12 seconds	Batch: 1300/18179	Embedding 

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (179,) + inhomogeneous part.

In [13]:
# Load a single shard for quick experiments. The context manager closes the
# .npz archive; indexing it with ['arr_0'] already reads the array into memory.
with np.load('T5_embeddings/embeddings_1000.npz') as npz:
    embeddings = npz['arr_0']
embeddings.shape

(1000, 128, 512)

In [None]:
def vec_to_sql_string(vector):
    """Render a numpy array as a bracketed list literal, e.g. '[0.1, 0.2]'.

    This is the text form passed to the `vector` column in INSERT statements.
    """
    as_list = vector.tolist()
    return f"{as_list}"

def from_sql_to_list(string):
    """Parse bracketed vector literals (e.g. '[0.1,0.2]') into a float64 tensor on `device`.

    Every non-nested `[...]` group found in `string` becomes one row. The text
    comes from the database, so it is parsed with ast.literal_eval, which accepts
    only Python literals, rather than eval, which would execute arbitrary code.
    """
    import ast

    rows = [ast.literal_eval(chunk) for chunk in re.findall(r'\[.*?\]', string)]
    return torch.tensor(np.array(rows), device=device, dtype=torch.float64)

def connect_to_db(host=None, port=None, database=None, user=None, password=None):
    """Open an autocommit connection to the Postgres vector database.

    Each setting is resolved in this order: explicit argument, then the matching
    libpq environment variable (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD),
    then the value this notebook previously hard-coded. Calling with no arguments
    and no env vars therefore behaves as before.

    Ensures the pgvector extension exists. Returns (cursor, connection).
    """
    import os

    def _setting(value, env_var, default):
        # explicit argument wins; otherwise environment; otherwise legacy default
        return value if value is not None else os.environ.get(env_var, default)

    conn = psycopg2.connect(
        host=_setting(host, "PGHOST", "localhost"),
        port=_setting(port, "PGPORT", 5432),
        database=_setting(database, "PGDATABASE", "vector_database"),
        user=_setting(user, "PGUSER", "postgres"),
        # TODO(security): remove the literal fallback once PGPASSWORD is set;
        # credentials should not live in the notebook.
        password=_setting(password, "PGPASSWORD", "admin"),
    )
    conn.autocommit = True

    cur = conn.cursor()
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
    return cur, conn


cur, conn = connect_to_db()
def save_batch(batch, start):
    """Insert each row of `batch` into article_embeddings with consecutive ids.

    Row k of `batch` is stored with article_id ``start + k``. Uses the
    module-level cursor `cur`.

    Previously the loop indexed ``batch[i]`` with the global id ``i``, which begins
    at `start`. That read the wrong rows, or raised IndexError, whenever
    ``start != 0``. Now the row and its id are taken together.
    """
    for article_id, embedding in enumerate(tqdm(batch), start=start):
        cur.execute(
            "INSERT INTO article_embeddings VALUES (%s, %s)",
            (article_id, vec_to_sql_string(embedding)),
        )

In [None]:
# (Re)create the embeddings table. IF EXISTS lets this run on a fresh database.
# The vector width must equal the T5-small embedding size, 512 (the encoding loop
# prints (128, 512) per batch); the earlier vector(768) rejected these vectors.
cur.execute("drop table if exists article_embeddings;")
cur.execute("create table article_embeddings(article_ID int primary key, embedding vector(512));")
conn.commit()  # no-op under autocommit; kept in case autocommit is turned off

In [14]:
embeddings = embeddings.reshape(-1, 512)
embeddings.shape

(128000, 512)

In [15]:
print(embeddings[0])
embeddings[0].shape

[ 3.28093544e-02  6.75026774e-02 -2.95573473e-02 -1.31904539e-02
 -8.76235217e-02  8.76239315e-02 -1.80753209e-02 -1.09374686e-03
 -9.44832116e-02  4.24410217e-02  9.34034772e-03 -8.47868621e-02
 -3.98606472e-02  1.32916262e-03 -2.34508477e-02 -1.07465126e-03
  9.97326337e-03  3.27166542e-02  2.37016007e-02 -5.53505942e-02
 -2.12609209e-02 -4.41676192e-03 -6.82403147e-02 -1.53395534e-01
  1.80014856e-02  1.19663961e-02 -2.06725523e-02 -8.29009414e-02
  1.00044971e-02 -2.83104852e-02  2.85374969e-02  2.63073388e-02
  1.40376966e-02  2.98258886e-02 -7.50582516e-02 -2.36231964e-02
 -1.50535414e-02  1.74066275e-02  4.39158753e-02  5.58604859e-02
 -9.32354759e-03 -2.78842282e-02  2.80379742e-01 -5.09586558e-02
  1.45210132e-01  7.65275955e-02  7.03355074e-02 -3.42450887e-02
  7.22508878e-03 -2.05006786e-02  2.51101758e-02 -4.99866083e-02
  9.30405874e-03  1.26727685e-01 -8.21842253e-03 -7.64768058e-03
 -1.80531219e-02 -1.04181282e-03  2.18757689e-02 -5.31628309e-03
 -1.49713242e-02  1.09140

(512,)

In [17]:
index = faiss.IndexFlatL2(512)
index.add(embeddings)
print(index.ntotal)

128000


In [18]:
index

<faiss.swigfaiss.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x00000267CA986AF0> >

In [19]:
D, I = index.search(embeddings[:1], 4)

In [20]:
D.shape

(1, 4)

In [21]:
I.shape

(1, 4)

In [22]:
# D: distances from the query (row 0 of `embeddings`) to its 4 nearest neighbours,
#    in ascending order. The first entry is 0.0 because the query finds itself.
#    (Faiss documents IndexFlatL2 results as squared L2 distances — confirm.)
# I: row indices of those neighbours in `embeddings`; I[0][0] == 0 is the query itself.
print(D)
print(I)

# get the vectors for the first 4 nearest neighbors
embeddings[I[0]]

[[0.         0.19465011 0.19810846 0.20966665]]
[[     0  25322  65998 112867]]


array([[ 0.03280935,  0.06750268, -0.02955735, ..., -0.03240348,
        -0.01664723,  0.00773852],
       [ 0.01696098,  0.10538723, -0.0293349 , ..., -0.03645744,
        -0.02133687,  0.0337696 ],
       [ 0.00078729,  0.06435277, -0.02840437, ..., -0.05428283,
        -0.02219047,  0.02471812],
       [-0.00676448,  0.08924416, -0.01654493, ..., -0.03034954,
        -0.01020308,  0.0657523 ]], dtype=float32)

In [24]:
def load_all_embeddings(directory='T5_embeddings', dim=512):
    """Load every saved embedding shard and stack them into one 2-D array.

    Shards are the `embeddings_<batch>.npz` files written by the encoding loop,
    each holding its array under the key 'arr_0'. They are discovered on disk
    instead of assuming exactly embeddings_1000 ... embeddings_18000, which
    silently dropped the final partial shard. Files are ordered by their numeric
    batch suffix, not as strings (a string sort would put 10000 before 2000), so
    the row order follows the order in which batches were encoded.

    Args:
        directory: folder containing the shard files.
        dim: embedding width; each shard is reshaped to (-1, dim), so both
            3-D (batches, batch_size, dim) and 2-D (rows, dim) shards work.

    Returns:
        np.ndarray of shape (total_rows, dim).

    Raises:
        FileNotFoundError: if `directory` contains no embeddings_<n>.npz files.
    """
    import glob
    import os

    def _batch_number(path):
        # numeric suffix of 'embeddings_<n>.npz', or None for unrelated files
        match = re.fullmatch(r'embeddings_(\d+)\.npz', os.path.basename(path))
        return int(match.group(1)) if match else None

    shards = [p for p in glob.glob(os.path.join(directory, 'embeddings_*.npz'))
              if _batch_number(p) is not None]
    if not shards:
        raise FileNotFoundError(f'no embeddings_<n>.npz files found in {directory!r}')
    shards.sort(key=_batch_number)

    e_list = []
    for path in shards:
        # the context manager closes the archive; np.load otherwise leaves it open
        with np.load(path) as npz:
            e_list.append(npz['arr_0'].reshape(-1, dim))
    return np.concatenate(e_list, axis=0)

# Stack the saved shards into a single (n_rows, 512) array to feed the Faiss index.
embeddings_array = load_all_embeddings()

In [25]:
embeddings_array.shape

(2304000, 512)

In [27]:
%%time
index = faiss.IndexFlatL2(512)
index.add(embeddings_array)
print(index.ntotal)

2304000
CPU times: total: 219 ms
Wall time: 1.47 s


In [29]:
faiss.write_index(index, 'Indexes/T5_embeddings.index')

In [30]:
loaded_index = faiss.read_index('Indexes/T5_embeddings.index')

In [31]:
# The reloaded index must hold exactly one vector per loaded embedding row. This
# replaces the magic number 2304000, which encoded a row count that was missing
# the final partial shard of the dataset.
assert loaded_index.ntotal == embeddings_array.shape[0], (loaded_index.ntotal, embeddings_array.shape)

In [32]:
D, I = loaded_index.search(embeddings_array[:1], 4)

In [33]:
pprint(D)
pprint(I)

embeddings_array[I[0]]

array([[0.        , 0.18301383, 0.1832258 , 0.18500452]], dtype=float32)
array([[      0,  585399, 1387869,  336090]], dtype=int64)


array([[ 0.03280935,  0.06750268, -0.02955735, ..., -0.03240348,
        -0.01664723,  0.00773852],
       [ 0.00789752,  0.09153082, -0.04614971, ..., -0.03393538,
        -0.03743657,  0.01426536],
       [-0.00270015,  0.08274905, -0.00822058, ..., -0.01871685,
        -0.03101622, -0.00574923],
       [ 0.02184535,  0.09049194, -0.03042706, ..., -0.02963874,
        -0.02268996,  0.05477721]], dtype=float32)