In [1]:
import torch
from transformers import GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Model, GPT2Config
from faiss_attention import * 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GPT2FaissBlock(GPT2Block):
    def __init__(self, config):
        super().__init__(config)
        self.attn = FAISSAttention(config)

class GPT2FaissModel(GPT2Model):
    def __init__(self, config):
        super().__init__(config)
        for i in range(len(self.h)):
            self.h[i] = GPT2FaissBlock(config)


In [3]:
model_name = "gpt2"
model = GPT2Model.from_pretrained(model_name)

In [4]:
config = GPT2Config.from_pretrained("gpt2")

In [5]:
custom_model = GPT2FaissModel(config)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_text = "My name is Anthony Hsu and I am from Tucson Arizona. I like Asian food the most."
inputs = tokenizer(input_text, return_tensors="pt")

In [6]:
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [7]:
custom_model

GPT2FaissModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2FaissBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): FAISSAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [8]:
outputs = model(**inputs)
print(outputs)

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.0334, -0.0433, -0.2827,  ..., -0.1524,  0.0162, -0.1180],
         [-0.5585,  0.0883, -0.7683,  ...,  0.6417, -0.1568,  0.2390],
         [ 0.0399, -0.0287,  0.1253,  ...,  0.2588, -0.1078,  0.6273],
         ...,
         [-0.4435, -0.2309, -0.8796,  ...,  0.1454, -0.0781, -0.1260],
         [ 0.5223,  0.0613, -1.5569,  ...,  0.3271, -0.1860,  0.3285],
         [ 0.1979,  0.1766, -0.3105,  ...,  0.2982, -0.3473,  0.1629]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[-1.0961,  1.8475,  0.8989,  ..., -1.3003, -0.7141,  1.1528],
          [-2.8246,  2.6349,  1.2125,  ..., -0.9360, -0.7353,  1.5514],
          [-1.9516,  2.4045,  1.9535,  ..., -1.6259, -2.7266,  2.5379],
          ...,
          [-2.6458,  2.1116,  2.2836,  ..., -0.7238, -1.7342,  2.0893],
          [-2.7420,  1.1224,  2.1848,  ..., -2.1276, -2.1284,  0.7462],
          [-2.6352,  2.3528,  2.2304,  ..., -0.2091, -2.0992,  1.6247]],

In [9]:
# Forward pass through the custom model
outputs = custom_model(**inputs)
print(outputs)

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.0991, -0.1415,  0.0910,  ..., -0.0793, -1.3604, -0.3952],
         [ 1.5787, -0.1135,  0.5435,  ..., -1.3379, -2.0873,  0.3842],
         [ 1.0581, -0.7977,  0.0196,  ..., -0.9184, -2.3383,  0.5916],
         ...,
         [ 0.4253, -1.2391, -0.2684,  ...,  0.1148, -1.2241,  0.2189],
         [-0.3329, -0.8296, -0.6341,  ..., -0.8617, -1.1323, -0.4224],
         [ 1.0290, -1.3304, -0.9650,  ..., -0.7179, -0.6320, -0.8087]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[ 0.3994, -0.2384, -0.8961,  ...,  0.8724,  0.1836, -0.5150],
          [-0.9705,  0.9019, -0.4713,  ..., -0.1061, -0.7327,  0.4053],
          [ 0.1455,  0.5162, -0.2417,  ..., -0.2850, -0.0565, -0.2497],
          ...,
          [-1.0522, -0.2864,  0.6264,  ...,  0.1014, -0.0280,  0.5325],
          [ 0.3871, -0.1346,  0.2734,  ...,  0.2250,  0.8547, -0.2662],
          [ 0.0497,  0.5319, -0.1264,  ..., -0.1402, -0.4538, -0.2900]],

In [10]:
from torch.nn.functional import cosine_similarity

# Get outputs
regular_outputs = model(**inputs).last_hidden_state
custom_outputs = custom_model(**inputs).last_hidden_state

# Compute similarity
cos_sim = cosine_similarity(regular_outputs.flatten(), custom_outputs.flatten(), dim=0)
print(f"Cosine Similarity: {cos_sim.item()}")

# Compute Mean Squared Error
mse = torch.mean((regular_outputs - custom_outputs) ** 2).item()
print(f"Mean Squared Error: {mse}")


Cosine Similarity: 0.0075410520657896996
Mean Squared Error: 56.18490982055664


In [11]:
import time

# Measure regular model time
start_time = time.time()
for _ in range(10):
    outputs = model(**inputs)
regular_model_time = (time.time() - start_time) / 10
print(f"Regular Model Inference Time: {regular_model_time:.4f} seconds")

# Measure custom FAISS model time
start_time = time.time()
for _ in range(10):
    outputs = custom_model(**inputs)
custom_model_time = (time.time() - start_time) / 10
print(f"Custom Model Inference Time: {custom_model_time:.4f} seconds")


Regular Model Inference Time: 0.0431 seconds
Custom Model Inference Time: 0.0549 seconds
