<a href="https://colab.research.google.com/github/TimStep/QTransformer/blob/main/QGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q pennylane transformer_lens

In [2]:
import math
import pennylane as qml
import torch
#from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import sklearn.metrics as metrics
from transformer_lens import HookedTransformer

In [3]:
# Fix every RNG the notebook draws from so weight initialisation is reproducible.
RANDOM_SEED = 5678
np.random.seed(RANDOM_SEED)
torch.manual_seed(seed=RANDOM_SEED)
torch.cuda.manual_seed(seed=RANDOM_SEED)

In [4]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

# Reference GPT-2 model (tokenizer and pretrained weights)

In [None]:
# NOTE: loading has failed at times while https://huggingface.co/openai-community/gpt2
# was unavailable; if this cell errors, retry once the hub entry is reachable again.
# LayerNorm folding and weight centering are switched off, presumably so the raw
# GPT-2 LayerNorm / unembedding tensors (read later by freeze_and_load via keys
# such as 'ln_final.w' and 'unembed.b_U') keep their original values.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
# Two identical prompts give a batch of size 2.
reference_text = ["Today we are going to implement a Transformer from scratch!"] * 2
text_tokens = reference_gpt2.to_tokens(reference_text).to(device)

# Config

In [7]:
@dataclass
class Config:
    """Hyper-parameters shared by every QGPT module.

    The classical defaults match GPT-2 small, whose pretrained embedding,
    LayerNorm, MLP and unembedding weights are copied in by
    ``freeze_and_load``. The quantum fields configure the variational circuits
    used by ``QAttention``.
    """

    # classical params
    d_model: int = 768  # embedding size
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02  # std of the normal init used for the embeddings
    n_ctx: int = 1024  # context length
    n_heads: int = 12  # number of attention heads (QAttention itself uses one head)
    n_layers: int = 12  # number of transformer blocks
    dropout: float = 0.1
    # quantum params
    query_depth: int = 1  # number of RY + CNOT layers in the query circuit
    key_depth: int = 2  # ... in the key circuit
    value_depth: int = 3  # ... in the value circuit
    q_device: str = "lightning.qubit"  # PennyLane device name
    # Share the token-embedding matrix with the unembedding projection.
    # Annotated so it is a real dataclass field (``Config(tying=True)`` works);
    # declared last so positional construction of the other fields is unchanged.
    tying: bool = False


cfg = Config()

# Embedding, MLP and unembedding layers

In [8]:
class Embed(torch.nn.Module):
    """Token embedding plus learned positional embedding (GPT-2 style).

    ``wte`` is (d_vocab, d_model) and ``wpe`` is (n_ctx, d_model), so
    pretrained GPT-2 ``W_E`` / ``W_pos`` can be copied into them directly.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.wte = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        self.wpe = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))

        nn.init.normal_(self.wte, std=self.cfg.init_range)
        nn.init.normal_(self.wpe, std=self.cfg.init_range)

    def forward(self, tokens):
        """Embed integer ``tokens`` of shape (batch, seq_len).

        Returns:
            (batch, seq_len, d_model) tensor: each token's embedding plus the
            positional embedding of its index.

        Raises:
            ValueError: if ``seq_len`` exceeds the context length ``n_ctx``.
        """
        seq_len = tokens.shape[1]
        n_ctx = self.wpe.shape[0]
        if seq_len > n_ctx:
            raise ValueError(f"sequence length {seq_len} exceeds context length {n_ctx}")

        tok_emb = self.wte[tokens]
        # Slicing keeps positions on wpe's own device and avoids allocating a
        # fresh torch.arange index tensor (on the default device) every call.
        pos_emb = self.wpe[:seq_len]
        return tok_emb + pos_emb

In [9]:
class MLP(nn.Module):
    """GPT-2 feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    The hidden width is fixed at ``4 * d_model`` to match GPT-2's
    ``W_in`` / ``W_out``, which freeze_and_load copies into ``w_in`` / ``w_out``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.w_in = nn.Linear(cfg.d_model, 4 * cfg.d_model)
        # GPT-2 was trained with the tanh approximation of GELU ("gelu_new");
        # the exact-erf default would make the loaded pretrained MLP deviate.
        self.gelu = nn.GELU(approximate="tanh")
        self.w_out = nn.Linear(4 * cfg.d_model, cfg.d_model)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Apply the MLP over the last axis; the output has the same shape as ``x``."""
        x = self.w_in(x)
        x = self.gelu(x)
        x = self.w_out(x)
        return self.dropout(x)

In [10]:
class Unembed(nn.Module):
    """Project (..., d_model) residual vectors to (..., d_vocab) logits."""

    def __init__(self, cfg, tying=None):
        """
        Args:
            cfg: config providing ``d_model`` and ``d_vocab``.
            tying: optional ``nn.Parameter`` of shape (d_vocab, d_model), e.g.
                the token-embedding matrix, to share as the projection weight.
                ``None`` (or ``False``) keeps an independent weight.

        Raises:
            TypeError: if ``tying`` is truthy but not a tensor.
        """
        super().__init__()

        self.unembed = nn.Linear(cfg.d_model, cfg.d_vocab)
        # Test the type explicitly: ``if tying:`` on a multi-element tensor
        # raises "Boolean value of Tensor ... is ambiguous".
        if isinstance(tying, torch.Tensor):
            self.unembed.weight = tying
        elif tying:
            raise TypeError("tying must be an nn.Parameter of shape (d_vocab, d_model) or None")

    def forward(self, x):
        """Return logits ``x @ W.T + b`` of shape (..., d_vocab)."""
        return self.unembed(x)


# Attention and transformer block

In [23]:
class QAttention(torch.nn.Module):
    """Single-head causal self-attention whose Q, K and V come from quantum circuits.

    Each token vector is amplitude-encoded onto ``n_qubits = ceil(log2(d_model))``
    wires, passed through a variational circuit of ``depth`` layers (one RY per
    wire followed by a ring of CNOTs), and read out as the PauliZ expectation of
    every wire. Query, key and value circuits share this layout but have their
    own weights and depths. The (batch, seq_len, n_qubits) Q/K/V go through
    PyTorch's causal scaled-dot-product attention, and the classical linear
    layer ``W_O`` maps the result back to d_model.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.qkv_depth = (cfg.query_depth, cfg.key_depth, cfg.value_depth)
        # amplitude embedding needs 2**n_qubits >= d_model (the rest is zero-padded)
        self.n_qubits = int(math.ceil(math.log2(cfg.d_model)))
        self.device = cfg.q_device

        # analytic (shots=None) simulator shared by the three circuits
        self.dev = qml.device(self.device, shots=None, wires=self.n_qubits)

        # one RY angle per (layer, wire) for each circuit
        self.query_weights = nn.Parameter(torch.empty(self.qkv_depth[0], self.n_qubits))
        self.key_weights = nn.Parameter(torch.empty(self.qkv_depth[1], self.n_qubits))
        self.value_weights = nn.Parameter(torch.empty(self.qkv_depth[2], self.n_qubits))
        # output projection is kept classical
        self.W_O = nn.Linear(self.n_qubits, cfg.d_model)

        self.reset_weights()

        self.query_node = qml.QNode(self.queryCircuit, self.dev, interface="torch", diff_method="best")
        self.key_node = qml.QNode(self.keyCircuit, self.dev, interface="torch", diff_method="best")
        self.value_node = qml.QNode(self.valueCircuit, self.dev, interface="torch", diff_method="best")

    def queryCircuit(self, inputs, weights, depth):
        """Variational circuit (also used verbatim for keys and values).

        Args:
            inputs: token vector(s) whose last axis has at most 2**n_qubits
                entries; normalised and zero-padded by AmplitudeEmbedding.
            weights: (depth, n_qubits) RY angles.
            depth: number of RY + CNOT-ring layers.

        Returns:
            list of n_qubits PauliZ expectation values, one per wire.
        """
        # AmplitudeEmbedding does not support differentiable inputs, so no
        # gradient flows back into the residual stream through Q/K/V.
        inputs = inputs.detach()
        qml.AmplitudeEmbedding(inputs, range(self.n_qubits), normalize=True, pad_with=0)

        for j in range(depth):
            for i in range(self.n_qubits):
                qml.RY(weights[j, i], wires=[i])
            # ring of CNOTs: wire i controls wire (i + 1) mod n_qubits
            for i in range(self.n_qubits):
                qml.CNOT(wires=[i, (i + 1) % self.n_qubits])

        return [qml.expval(qml.PauliZ(wires=[i])) for i in range(self.n_qubits)]

    def keyCircuit(self, inputs, weights, depth):
        return self.queryCircuit(inputs, weights, depth)

    def valueCircuit(self, inputs, weights, depth):
        return self.queryCircuit(inputs, weights, depth)

    def _project(self, node, flat_x, weights, depth, batch, seq_len):
        """Run one QNode on flattened tokens and reshape to (batch, 1, seq_len, n_qubits)."""
        out = torch.stack(node(flat_x, weights, depth), dim=-1)  # (batch*seq_len, n_qubits)
        # Simulator results may use a different float dtype than the model;
        # match the input so W_O's matmul dtypes agree (no-op otherwise).
        out = out.to(dtype=flat_x.dtype)
        # torch attention expects dim 1 to be the number of heads (always one here)
        return out.unflatten(0, (batch, seq_len)).unsqueeze(1)

    def forward(self, x):
        """Causal quantum self-attention over ``x`` of shape (batch, seq_len, d_model).

        Returns:
            (batch, seq_len, d_model) tensor.
        """
        batch, seq_len = x.shape[:2]
        # QNodes broadcast over a leading batch axis, so fold (batch, seq_len) into one
        flat_x = torch.flatten(x, start_dim=0, end_dim=1)

        q = self._project(self.query_node, flat_x, self.query_weights, self.qkv_depth[0], batch, seq_len)
        k = self._project(self.key_node, flat_x, self.key_weights, self.qkv_depth[1], batch, seq_len)
        v = self._project(self.value_node, flat_x, self.value_weights, self.qkv_depth[2], batch, seq_len)

        att = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Drop only the head axis. A bare squeeze() also removes the batch or
        # sequence axis when batch == 1 or seq_len == 1, which breaks the
        # residual add in the caller (e.g. (B, 1, d) + (B, d) -> (B, B, d)).
        att = att.squeeze(1)  # (batch, seq_len, n_qubits)

        return self.W_O(att)

    def extra_repr(self):
        """Summary shown by ``print(module)``."""
        return (
            f"n_qubits={self.n_qubits}, query_depth={self.qkv_depth[0]}, "
            f"key_depth={self.qkv_depth[1]}, value_depth={self.qkv_depth[2]}, "
            f"q_device={self.device!r}"
        )

    def reset_weights(self):
        """Draw all RY angles uniformly from [0, 2*pi]."""
        nn.init.uniform_(self.query_weights, a=0, b=2 * torch.pi)
        nn.init.uniform_(self.key_weights, a=0, b=2 * torch.pi)
        nn.init.uniform_(self.value_weights, a=0, b=2 * torch.pi)

    def draw_circuit(self):
        """Return text diagrams ``(query, key, value)`` of the three circuits.

        A random d_model-sized vector is used as the sample input.
        """
        sample_input = torch.randn((self.cfg.d_model,))

        query_diagram = qml.draw(self.query_node)(sample_input, self.query_weights, self.qkv_depth[0])
        key_diagram = qml.draw(self.key_node)(sample_input, self.key_weights, self.qkv_depth[1])
        value_diagram = qml.draw(self.value_node)(sample_input, self.value_weights, self.qkv_depth[2])

        return query_diagram, key_diagram, value_diagram

    def draw_circuit_mpl(self):
        """Plot the query circuit with matplotlib on a random sample input."""
        # n_qubits entries suffice: AmplitudeEmbedding zero-pads the rest
        sample_input = torch.randn((self.n_qubits,))

        qml.draw_mpl(self.query_node)(sample_input, self.query_weights, self.qkv_depth[0])
        plt.title("Quantum Circuit")
        plt.show()

In [24]:
class TransformerBlock(torch.nn.Module):
    """Pre-LayerNorm transformer block: quantum attention, then MLP, each added residually.

    The sub-module names ``ln1``, ``ln2``, ``attn`` and ``mlp`` are the ones
    freeze_and_load addresses when copying pretrained weights.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        width, eps = cfg.d_model, cfg.layer_norm_eps
        # Construction order (ln1, ln2, attn, mlp) fixes parameter registration
        # and random-initialisation order, so it is kept as is.
        self.ln1 = nn.LayerNorm(width, eps=eps)
        self.ln2 = nn.LayerNorm(width, eps=eps)
        self.attn = QAttention(cfg)
        self.mlp = MLP(cfg)

    def forward(self, x):
        """Return ``x`` updated by the attention and MLP residual branches."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

# Full Transformer

In [25]:
class QGPT(nn.Module):
    """GPT-2-shaped language model whose attention layers are quantum (QAttention).

    tokens (batch, seq_len) -> Embed -> n_layers x TransformerBlock
    -> final LayerNorm -> Unembed -> logits (batch, seq_len, d_vocab).
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.unembed = Unembed(cfg)
        # Honour cfg.tying, which was previously ignored: share the
        # (d_vocab, d_model) token-embedding matrix as the output projection
        # weight. getattr keeps configs without a ``tying`` attribute working.
        if getattr(cfg, "tying", False):
            self.unembed.unembed.weight = self.embed.wte

    def forward(self, tokens):
        """Map (batch, seq_len) token ids to (batch, seq_len, d_vocab) logits."""
        residual = self.embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        return self.unembed(self.ln_final(residual))

In [26]:
def freeze_and_load(model, reference_model):
    """Copy GPT-2's classical weights into ``model`` and freeze them.

    Loads the token/position embeddings, the unembedding, the final LayerNorm
    and, for every block, both LayerNorms and the MLP. It then calls
    ``requires_grad_(False)`` on each loaded tensor or module. Attention
    weights (the quantum circuits and ``W_O``) are not touched.

    Args:
        model: QGPT-shaped module exposing ``embed``, ``unembed``,
            ``ln_final`` and ``blocks``.
        reference_model: object whose ``state_dict()`` has TransformerLens-style
            keys (``embed.W_E``, ``blocks.{n}.mlp.W_in``, ...).
    """
    # Build the state dict once; it used to be rebuilt for every single key.
    sd = reference_model.state_dict()

    with torch.no_grad():
        # token and positional embedding matrices
        model.embed.wte.copy_(sd["embed.W_E"])
        model.embed.wpe.copy_(sd["pos_embed.W_pos"])
        model.embed.wte.requires_grad_(False)
        model.embed.wpe.requires_grad_(False)

        # unembedding: W_U is stored (d_model, d_vocab), while nn.Linear keeps
        # its weight as (out_features, in_features), hence the transpose
        model.unembed.unembed.weight.copy_(sd["unembed.W_U"].T)
        model.unembed.unembed.bias.copy_(sd["unembed.b_U"])
        model.unembed.unembed.requires_grad_(False)

        # final LayerNorm
        model.ln_final.weight.copy_(sd["ln_final.w"])
        model.ln_final.bias.copy_(sd["ln_final.b"])
        model.ln_final.requires_grad_(False)

        for n, block in enumerate(model.blocks):
            prefix = f"blocks.{n}."

            # LayerNorms
            block.ln1.weight.copy_(sd[prefix + "ln1.w"])
            block.ln1.bias.copy_(sd[prefix + "ln1.b"])
            block.ln1.requires_grad_(False)
            block.ln2.weight.copy_(sd[prefix + "ln2.w"])
            block.ln2.bias.copy_(sd[prefix + "ln2.b"])
            block.ln2.requires_grad_(False)

            # MLP: pretrained W_in / W_out are stored (in, out), transpose for nn.Linear
            block.mlp.w_in.weight.copy_(sd[prefix + "mlp.W_in"].T)
            block.mlp.w_in.bias.copy_(sd[prefix + "mlp.b_in"])
            block.mlp.w_in.requires_grad_(False)
            block.mlp.w_out.weight.copy_(sd[prefix + "mlp.W_out"].T)
            block.mlp.w_out.bias.copy_(sd[prefix + "mlp.b_out"])
            block.mlp.w_out.requires_grad_(False)


In [27]:
# Instantiate the hybrid model from the notebook-wide config.
qgpt = QGPT(cfg=cfg)

In [28]:
# Count only the parameters an optimiser would update.
trainable = [param for param in qgpt.parameters() if param.requires_grad]
params = sum(param.numel() for param in trainable)
print(f'Number of trainable parameters: {params}')

Number of trainable parameters: 134841121


In [29]:
# Load GPT-2's classical weights into qgpt and freeze them.
freeze_and_load(model=qgpt, reference_model=reference_gpt2)

In [30]:
# Re-count after freezing: only attention parameters should remain trainable.
trainable = [param for param in qgpt.parameters() if param.requires_grad]
params = sum(param.numel() for param in trainable)
print(f'Number of trainable parameters: {params}')

Number of trainable parameters: 102096


In [31]:
# Smoke test of the forward pass. Only the output shape is needed, so skip
# building the autograd graph (which would run through every quantum circuit).
with torch.no_grad():
    logits = qgpt(text_tokens)
logits.shape

torch.Size([2, 13, 50257])