In [1]:
import torch
from transformers import GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers import GPT2Model, GPT2Config
from faiss_attention import * 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GPT2FaissBlock(GPT2Block):
    """GPT-2 transformer block whose self-attention module is ``FAISSAttention``.

    Everything except ``self.attn`` (layer norms, MLP) is inherited unchanged
    from ``GPT2Block``.

    Args:
        config: GPT-2 configuration, forwarded to ``GPT2Block`` and to
            ``FAISSAttention``.
        layer_idx: Position of this block in the stack, forwarded to
            ``GPT2Block`` exactly as ``GPT2Model`` does for stock blocks.
            Optional, so ``GPT2FaissBlock(config)`` keeps working.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        # Replace the attention module GPT2Block just built (GPT2SdpaAttention
        # for this config) with the FAISS-backed implementation.
        # NOTE(review): FAISSAttention only receives `config`, so it does not
        # know `layer_idx`; confirm its signature and forward it if supported.
        self.attn = FAISSAttention(config)


class GPT2FaissModel(GPT2Model):
    """``GPT2Model`` in which every transformer block uses ``FAISSAttention``."""

    def __init__(self, config):
        super().__init__(config)
        # Swap each stock block for a FAISS block, keeping the same layer index
        # that GPT2Model assigned to the block it replaces.
        for layer_idx in range(len(self.h)):
            self.h[layer_idx] = GPT2FaissBlock(config, layer_idx=layer_idx)
        # GPT2Model.__init__ already ran post_init() (GPT-2 weight init) on the
        # original blocks; run it again so the replacement blocks get the same
        # initialisation instead of PyTorch defaults. When loading via
        # from_pretrained the checkpoint weights overwrite this anyway.
        self.post_init()


In [3]:
# Reference model: the stock pretrained GPT-2 checkpoint from the Hugging Face Hub.
model_name = "gpt2"
model = GPT2Model.from_pretrained(pretrained_model_name_or_path=model_name)

In [4]:
# Same checkpoint as `model`; reuse `model_name` instead of repeating the literal.
config = GPT2Config.from_pretrained(model_name)

In [6]:
# Attention backend the stock GPT2Block selects from this config; GPT2FaissBlock
# overwrites that module with FAISSAttention regardless of this value.
getattr(config, "_attn_implementation")

'sdpa'

In [5]:
# Build the FAISS-attention variant from the same pretrained checkpoint as `model`,
# so the two forward passes below are comparable. A bare `GPT2FaissModel(config)`
# would have random weights and be in training mode (dropout active), making its
# outputs nondeterministic; `from_pretrained` loads the weights and returns the
# model in eval mode.
custom_model = GPT2FaissModel.from_pretrained(model_name, config=config)

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
input_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(input_text, return_tensors="pt")

In [8]:
# Show the stock GPT-2 architecture (attention = GPT2SdpaAttention) for comparison with `custom_model`.
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# Show the FAISS variant: each block should list FAISSAttention where the stock model has GPT2SdpaAttention.
custom_model

GPT2FaissModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2FaissBlock(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): FAISSAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [None]:
# Reference forward pass through the stock pretrained GPT-2. Inference only, so
# skip autograd bookkeeping. Stored under its own name because the next cell
# assigns `outputs`, which would otherwise discard this result.
with torch.no_grad():
    baseline_outputs = model(**inputs)
baseline_outputs.last_hidden_state.shape

In [7]:
# Forward pass through the custom model
# Forward pass through the FAISS-attention model. Inference only, so no autograd
# graph is built; display the hidden-state shape rather than dumping every tensor.
with torch.no_grad():
    outputs = custom_model(**inputs)
outputs.last_hidden_state.shape

: 

: 

In [9]:
import faiss