<a href="https://colab.research.google.com/github/TimStep/QTransformer/blob/main/QGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [60]:
!pip install -q pennylane transformer_lens

In [61]:
import math
import pennylane as qml
import torch
#from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import sklearn.metrics as metrics
from transformer_lens import HookedTransformer
from datasets import load_dataset

In [62]:
# Seed every RNG used in this notebook so runs are reproducible.
RANDOM_SEED = 5678
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

In [63]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

# Tokenizer

In [64]:
# Pretrained GPT-2 small: its tokenizer and weights are reused below.
# fold_ln=False keeps the LayerNorm weights (e.g. blocks.N.ln1.w) in the
# state dict so they can be copied by freeze_and_load.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)
gpt2_tokenizer = reference_gpt2.tokenizer

Loaded pretrained model gpt2-small into HookedTransformer


In [65]:
# Two prompts of different lengths; as the output shows, every row starts with
# token 50256 and the shorter prompt is right-padded with it.
reference_text = [
    "Jingle bells, jingle bells, jingle all the way",
    "Today I was walking home, when suddenly",
]
text_tokens = reference_gpt2.to_tokens(reference_text)
text_tokens

tensor([[50256,    41, 17697, 30987,    11,   474, 17697, 30987,    11,   474,
         17697,   477,   262,   835],
        [50256,  8888,   314,   373,  6155,  1363,    11,   618,  6451, 50256,
         50256, 50256, 50256, 50256]], device='cuda:0')

In [66]:
# Decode the padded second prompt to make the special tokens visible.
gpt2_tokenizer.decode(text_tokens[1].tolist())

'<|endoftext|>Today I was walking home, when suddenly<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

In [67]:
# A lone space is its own token in GPT-2's vocabulary.
gpt2_tokenizer(text=" ")

{'input_ids': [220], 'attention_mask': [1]}

# Config

In [68]:
@dataclass
class Config:
    """Hyper-parameters for the hybrid quantum/classical GPT model.

    The classical fields size the embeddings, blocks and unembedding; the
    quantum fields configure the variational circuits used by QAttention.
    """

    # classical params
    d_model: int = 768  # embedding size
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02  # std of the normal init used for the embeddings
    n_ctx: int = 1024  # context length
    n_heads: int = 12  # number of attention heads (QAttention itself uses a single head)
    n_layers: int = 12  # number of transformer blocks
    dropout: float = 0.1
    # quantum params: number of RY + CNOT-ring layers in each circuit
    query_depth: int = 1
    key_depth: int = 2
    value_depth: int = 3
    q_device: str = "lightning.qubit"  # PennyLane device name
    # Whether the unembedding should share the token-embedding matrix.
    # Annotated so it is a real dataclass field (settable via Config(tying=...));
    # declared last so the positional order of the other fields is unchanged.
    tying: bool = False

qgpt_cfg = Config()

# Embedding, MLP, Unembedding

In [69]:
class Embed(torch.nn.Module):
    """Learned token embeddings plus learned absolute position embeddings.

    ``wte`` has shape (d_vocab, d_model) and ``wpe`` has shape (n_ctx, d_model);
    both are initialised from a normal distribution with std ``cfg.init_range``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.wte = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        self.wpe = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))

        nn.init.normal_(self.wte, std=self.cfg.init_range)
        nn.init.normal_(self.wpe, std=self.cfg.init_range)

    def forward(self, tokens):
        """Embed a (batch, seq_len) tensor of token ids.

        Returns:
            (batch, seq_len, d_model) tensor: token embedding + position embedding.

        Raises:
            IndexError: if seq_len exceeds ``cfg.n_ctx``.
        """
        seq_len = tokens.shape[1]
        if seq_len > self.cfg.n_ctx:
            raise IndexError(
                f"sequence length {seq_len} exceeds context size n_ctx={self.cfg.n_ctx}"
            )

        tok_emb = self.wte[tokens]
        # Slicing avoids building a torch.arange index tensor on the CPU, which
        # would be copied to the parameter's device on every forward pass.
        pos_emb = self.wpe[:seq_len]
        return tok_emb + pos_emb

In [70]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times wider than the residual stream.
    """

    def __init__(self, cfg):
        super().__init__()
        hidden_size = 4 * cfg.d_model
        # Attribute names are relied upon by freeze_and_load when copying
        # pretrained weights, so they (and their registration order) stay fixed.
        self.w_in = nn.Linear(cfg.d_model, hidden_size)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(hidden_size, cfg.d_model)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.dropout(self.w_out(self.gelu(self.w_in(x))))

In [71]:
class Unembed(nn.Module):
    """Project the residual stream back to vocabulary logits.

    Args:
        cfg: config providing ``d_model`` and ``d_vocab``.
        tying: optional ``nn.Parameter`` of shape (d_vocab, d_model), e.g. the
            token-embedding matrix, to share with the output projection.
    """

    def __init__(self, cfg, tying=None):
        super().__init__()

        self.unembed = nn.Linear(cfg.d_model, cfg.d_vocab)
        # Compare against None explicitly: truth-testing a multi-element tensor
        # raises "Boolean value of Tensor with more than one value is ambiguous".
        if tying is not None:
            self.unembed.weight = tying

    def forward(self, x):
        """Map (..., d_model) activations to (..., d_vocab) logits."""
        return self.unembed(x)


# Quantum attention and transformer block

In [72]:
class QAttention(torch.nn.Module):
    """Single-head causal self-attention whose Q/K/V maps are quantum circuits.

    Each token vector is amplitude-encoded into ``n_qubits = ceil(log2(d_model))``
    qubits, passed through a variational circuit (RY rotations followed by a ring
    of CNOTs, repeated ``depth`` times) and read out as the Pauli-Z expectation
    of every qubit, giving an ``n_qubits``-dimensional query/key/value. A
    classical linear layer ``W_O`` maps the attention output back to ``d_model``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.qkv_depth = (cfg.query_depth, cfg.key_depth, cfg.value_depth)
        self.n_qubits = int(math.ceil(math.log2(cfg.d_model)))
        self.device = cfg.q_device

        # init device
        self.dev = qml.device(self.device, shots=None, wires=self.n_qubits)

        # init weights: one RY angle per (layer, qubit) for each circuit
        self.query_weights = nn.Parameter(torch.empty(self.qkv_depth[0], self.n_qubits))
        self.key_weights = nn.Parameter(torch.empty(self.qkv_depth[1], self.n_qubits))
        self.value_weights = nn.Parameter(torch.empty(self.qkv_depth[2], self.n_qubits))
        # output projection is kept classical
        self.W_O = nn.Linear(self.n_qubits, cfg.d_model)

        self.reset_weights()

        # init QNodes
        self.query_node = qml.QNode(self.queryCircuit, self.dev, interface="torch", diff_method="best")
        self.key_node = qml.QNode(self.keyCircuit, self.dev, interface="torch", diff_method="best")
        self.value_node = qml.QNode(self.valueCircuit, self.dev, interface="torch", diff_method="best")

    def queryCircuit(self, inputs, weights, depth):
        """Variational circuit shared by the query, key and value QNodes.

        Args:
            inputs: token vector(s) of length d_model — forward() passes a
                (B*T, d_model) batch, draw_circuit() a single (d_model,) vector.
                They are normalised and zero-padded to 2**n_qubits amplitudes.
            weights: (depth, n_qubits) tensor of RY angles.
            depth: number of RY + CNOT-ring layers to apply.

        Returns:
            List of ``n_qubits`` Pauli-Z expectation values.
        """
        # quantum embedding; detached because AmplitudeEmbedding does not support
        # differentiable inputs, so no gradient flows back into `inputs`
        inputs = inputs.detach()
        qml.AmplitudeEmbedding(inputs, range(self.n_qubits), normalize=True, pad_with=0)

        # VQC
        for j in range(depth):
            for i in range(self.n_qubits):
                qml.RY(weights[j, i], wires=[i])
            # ring of CNOTs between neighbouring qubits; the last wraps to the first
            for i in range(self.n_qubits):
                qml.CNOT(wires=[i, (i + 1) % self.n_qubits])

        return [qml.expval(qml.PauliZ(wires=[i])) for i in range(self.n_qubits)]

    def keyCircuit(self, inputs, weights, depth):
        """Key circuit; same structure as the query circuit with its own weights."""
        return self.queryCircuit(inputs, weights, depth)

    def valueCircuit(self, inputs, weights, depth):
        """Value circuit; same structure as the query circuit with its own weights."""
        return self.queryCircuit(inputs, weights, depth)

    @staticmethod
    def _to_single_head(expvals, batch_size, seq_len):
        """Turn per-qubit expectations into a (B, 1, T, n_qubits) attention input.

        ``expvals`` holds ``n_qubits`` tensors of shape (B*T,). The singleton
        dimension is the head axis that ``scaled_dot_product_attention`` expects.
        """
        out = torch.stack(list(expvals), dim=-1)  # (B*T, n_qubits)
        out = out.unflatten(0, (batch_size, seq_len))  # (B, T, n_qubits)
        return out.unsqueeze(1)  # (B, 1, T, n_qubits)

    def forward(self, x):
        """Apply causal quantum self-attention.

        Args:
            x: (B, seq_len, d_model) tensor.

        Returns:
            (B, seq_len, d_model) tensor produced by the output projection W_O.
        """
        B, T = x.shape[:2]  # saved for unflattening the circuit outputs
        # flatten all batches into one sequence, (B*T, d_model), for the QNodes
        flat = torch.flatten(x, start_dim=0, end_dim=1)

        q = self._to_single_head(self.query_node(flat, self.query_weights, self.qkv_depth[0]), B, T)
        k = self._to_single_head(self.key_node(flat, self.key_weights, self.qkv_depth[1]), B, T)
        v = self._to_single_head(self.value_node(flat, self.value_weights, self.qkv_depth[2]), B, T)

        att = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Remove only the head dimension. A bare squeeze() would also drop the
        # batch or sequence dimension when B == 1 or T == 1, producing a tensor
        # that broadcasts incorrectly against the residual stream.
        att = att.squeeze(1)  # (B, T, n_qubits)
        # Match W_O's dtype in case the simulator returns float64 expectations.
        return self.W_O(att.to(self.W_O.weight.dtype))

    def extra_repr(self):
        """Summarise the quantum configuration in the module's repr."""
        return f"n_qubits={self.n_qubits}, qkv_depth={self.qkv_depth}, q_device={self.device!r}"

    def reset_weights(self):
        """Draw all rotation angles uniformly from [0, 2*pi)."""
        nn.init.uniform_(self.query_weights, a=0, b=2 * torch.pi)
        nn.init.uniform_(self.key_weights, a=0, b=2 * torch.pi)
        nn.init.uniform_(self.value_weights, a=0, b=2 * torch.pi)

    def draw_circuit(self):
        """Print text diagrams of the query, key and value circuits for one random input."""
        sample_input = torch.randn((self.cfg.d_model,))
        circuits = (
            ("Query", self.query_node, self.query_weights, self.qkv_depth[0]),
            ("Key", self.key_node, self.key_weights, self.qkv_depth[1]),
            ("Value", self.value_node, self.value_weights, self.qkv_depth[2]),
        )
        for label, node, weights, depth in circuits:
            print(f"{label} circuit:")
            print(qml.draw(node)(sample_input, weights, depth))

    def draw_circuit_mpl(self):
        """Plot the query circuit with matplotlib for one random input."""
        # Use a d_model-sized vector, matching what forward() feeds the circuit.
        sample_input = torch.randn((self.cfg.d_model,))
        qml.draw_mpl(self.query_node)(sample_input, self.query_weights, self.qkv_depth[0])
        plt.title("Quantum Circuit")
        plt.show()

In [73]:
class TransformerBlock(torch.nn.Module):
    """Pre-LayerNorm transformer block.

    Quantum attention and then the MLP are each applied to a LayerNorm-ed copy
    of the residual stream and their outputs added back onto it.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        eps = cfg.layer_norm_eps
        self.ln1 = nn.LayerNorm(cfg.d_model, eps=eps)
        self.ln2 = nn.LayerNorm(cfg.d_model, eps=eps)
        self.attn = QAttention(cfg)
        self.mlp = MLP(cfg)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

# Utils: loading pretrained GPT-2 weights

In [74]:
def freeze_and_load(model, reference_model, n_unfreeze=0):
    """Copy pretrained GPT-2 weights into ``model`` and freeze most of them.

    Loads the token/position embeddings, the unembedding, the final LayerNorm
    and, for every block, both LayerNorms and the MLP from ``reference_model``'s
    state dict (transformer_lens key names such as ``'embed.W_E'``). Embeddings,
    unembedding and final LayerNorm are always frozen. Block LayerNorms and MLPs
    are frozen for all but the last ``n_unfreeze`` blocks. Attention weights are
    neither loaded nor frozen.

    Args:
        model: QGPT-like model to load into (modified in place).
        reference_model: model whose ``state_dict()`` uses transformer_lens keys.
        n_unfreeze: number of final blocks whose LayerNorm/MLP weights stay trainable.
    """
    # state_dict() assembles a fresh dict on every call, so fetch it once.
    ref = reference_model.state_dict()
    n_frozen = model.cfg.n_layers - n_unfreeze

    with torch.no_grad():
        # token (wte) and position (wpe) embeddings
        model.embed.wte.copy_(ref['embed.W_E'])
        model.embed.wpe.copy_(ref['pos_embed.W_pos'])
        model.embed.wte.requires_grad_(False)
        model.embed.wpe.requires_grad_(False)

        # unembedding: W_U is stored as (d_model, d_vocab) whereas nn.Linear keeps
        # its weight as (out_features, in_features), hence the transpose
        model.unembed.unembed.weight.copy_(ref['unembed.W_U'].T)
        model.unembed.unembed.bias.copy_(ref['unembed.b_U'])
        model.unembed.unembed.requires_grad_(False)

        # final LayerNorm
        model.ln_final.weight.copy_(ref['ln_final.w'])
        model.ln_final.bias.copy_(ref['ln_final.b'])
        model.ln_final.requires_grad_(False)

        for n, block in enumerate(model.blocks):
            prefix = f'blocks.{n}.'

            # LayerNorms
            block.ln1.weight.copy_(ref[prefix + 'ln1.w'])
            block.ln1.bias.copy_(ref[prefix + 'ln1.b'])
            block.ln2.weight.copy_(ref[prefix + 'ln2.w'])
            block.ln2.bias.copy_(ref[prefix + 'ln2.b'])
            # MLP: W_in / W_out are stored as (in, out), so transpose for nn.Linear
            block.mlp.w_in.weight.copy_(ref[prefix + 'mlp.W_in'].T)
            block.mlp.w_in.bias.copy_(ref[prefix + 'mlp.b_in'])
            block.mlp.w_out.weight.copy_(ref[prefix + 'mlp.W_out'].T)
            block.mlp.w_out.bias.copy_(ref[prefix + 'mlp.b_out'])

            # freeze every block except the last n_unfreeze
            if n < n_frozen:
                for module in (block.ln1, block.ln2, block.mlp.w_in, block.mlp.w_out):
                    module.requires_grad_(False)


# Full Transformer

In [75]:
class QGPT(nn.Module):
    """GPT-2-shaped language model whose attention layers are QAttention modules.

    Maps a (batch, seq_len) tensor of token ids to (batch, seq_len, d_vocab) logits.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        stacked_blocks = [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        self.blocks = nn.ModuleList(stacked_blocks)
        self.ln_final = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        hidden = self.embed(tokens)
        for layer in self.blocks:
            hidden = layer(hidden)
        normed = self.ln_final(hidden)
        return self.unembed(normed)

In [76]:
# Build the hybrid model with randomly initialised weights.
qgpt = QGPT(cfg=qgpt_cfg)

In [77]:
# Count only the parameters the optimizer will update.
trainable = filter(lambda param: param.requires_grad, qgpt.parameters())
params = sum(param.numel() for param in trainable)
print(f'Number of trainable parameters: {params}')

Number of trainable parameters: 134841121


In [78]:
# Load pretrained GPT-2 weights and freeze all of them (only attention trains).
freeze_and_load(qgpt, reference_gpt2, n_unfreeze=0)

In [79]:
# Re-count after freezing: only the quantum attention parameters remain trainable.
trainable = filter(lambda param: param.requires_grad, qgpt.parameters())
params = sum(param.numel() for param in trainable)
print(f'Number of trainable parameters: {params}')

Number of trainable parameters: 102096


In [80]:
# Move the model in place; the returned module is displayed as its repr.
qgpt.to(device=device)

QGPT(
  (embed): Embed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): QAttention(
        (W_O): Linear(in_features=10, out_features=768, bias=True)
      )
      (mlp): MLP(
        (w_in): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GELU(approximate='none')
        (w_out): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (unembed): Unembed(
    (unembed): Linear(in_features=768, out_features=50257, bias=True)
  )
)

In [81]:
# Smoke test: logits should be (batch, seq_len, d_vocab). No autograd graph is
# needed just to check the shape.
with torch.no_grad():
    smoke_logits = qgpt(text_tokens)
smoke_logits.shape

torch.Size([2, 14, 50257])

# Data

In [82]:
# 10k-document sample of the Pile; only the raw text column is kept.
dataset = load_dataset("NeelNanda/pile-10k", split="train").remove_columns(["meta"])

README.md:   0%|          | 0.00/373 [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/921 [00:00<?, ?B/s]

(…)-00000-of-00001-4746b8785c874cc7.parquet:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [83]:
# Number of full n_ctx-token windows in the first document (the trailing partial
# window is dropped). The tokenizer's length warning is expected here: the
# document is split into context-sized windows manually.
first_doc_tokens = torch.tensor(gpt2_tokenizer(dataset[0]['text'])['input_ids'])
torch.stack(first_doc_tokens.split(qgpt_cfg.n_ctx)[:-1]).shape

Token indices sequence length is longer than the specified maximum sequence length for this model (3180 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([3, 1024])

In [84]:
class TextDataset(Dataset):
    """Next-token-prediction dataset cut from a single text string.

    The text is tokenized once and split into non-overlapping windows of
    ``n_ctx`` tokens. ``X[i]`` is window ``i`` and ``y[i]`` is the same window
    shifted left by one token (the token each position should predict). A
    window is kept only if its final target token exists, so a trailing partial
    window is dropped; texts with at most ``n_ctx`` tokens give an empty dataset.

    Args:
        text_string: raw text to tokenize.
        tokenizer: callable returning a mapping with an ``'input_ids'`` list.
        n_ctx: window (context) length.
    """

    def __init__(self, text_string, tokenizer, n_ctx):
        tokens = tokenizer(text_string)['input_ids']
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        # Each window needs n_ctx inputs plus one more token for its last target.
        n_windows = max(0, (len(self.tokens) - 1) // n_ctx)
        span = n_windows * n_ctx
        self.X = self.tokens[:span].view(n_windows, n_ctx)
        self.y = self.tokens[1:span + 1].view(n_windows, n_ctx)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        return self.X[index], self.y[index]


In [87]:
# Train on windows taken from the first Pile document only.
trainDS = TextDataset(dataset[0]['text'], gpt2_tokenizer, qgpt_cfg.n_ctx)
trainDL = DataLoader(trainDS, batch_size=3, shuffle=True, pin_memory=True)

In [88]:
# One (input, target) pair: the target is the input shifted left by one token.
X, y = trainDS[0]
(X, y)

(tensor([1026,  318, 1760,  ..., 1088,  838, 4201]),
 tensor([ 318, 1760,   11,  ...,  838, 4201,  329]))

In [89]:
# Peek at one shuffled batch, [inputs, targets], each (<=batch_size, n_ctx),
# using rich display instead of printing every batch in the loader.
next(iter(trainDL))

[tensor([[1180, 2628,   11,  ...,  284,  564,  250],
        [1026,  318, 1760,  ..., 1088,  838, 4201],
        [ 329,  257, 7480,  ..., 6626,  656, 1811]]), tensor([[2628,   11,  422,  ...,  564,  250,  447],
        [ 318, 1760,   11,  ...,  838, 4201,  329],
        [ 257, 7480,  284,  ...,  656, 1811, 1180]])]


# Train

In [90]:
@dataclass
class TrainingArgs:
    """Optimisation settings read by TransformerTrainer."""

    epochs: int = 10
    max_steps_per_epoch: int = 100  # limit on training steps in each epoch
    lr: float = 0.01
    weight_decay: float = 0.01
    betas: tuple = (0.9, 0.99)  # AdamW (beta1, beta2)

qgpt_training_args = TrainingArgs()

In [93]:
class TransformerTrainer:
    """Minimal training loop for a causal language model.

    Args:
        args: TrainingArgs-like object with ``epochs``, ``max_steps_per_epoch``,
            ``lr``, ``betas`` and ``weight_decay``.
        model: module mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab) logits.
        train_loader: iterable of ``(tokens, targets)`` batches.
        test_loader: optional iterable of batches whose first element is a token
            tensor; used to report next-token accuracy after every epoch. When
            omitted, accuracy is reported as nan.
    """

    def __init__(self, args, model, train_loader, test_loader=None):
        self.model = model
        self.args = args
        self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
        self.step = 0
        self.train_loader = train_loader
        self.test_loader = test_loader

    def _device(self):
        """Device holding the model's parameters; batches are moved there."""
        return next(self.model.parameters()).device

    def training_step(self, batch):
        """Run one optimisation step on ``(tokens, targets)`` and return the loss tensor."""
        # forward pass
        tokens, targets = batch
        dev = self._device()
        tokens = tokens.to(dev)
        targets = targets.to(dev)
        logits = self.model(tokens)

        # backward pass; targets equal to -1 are ignored
        self.optimizer.zero_grad()
        loss = nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)
        loss.backward()
        self.optimizer.step()
        self.step += 1

        return loss

    @torch.no_grad()
    def validation_step(self, batch):
        """Return a flat bool tensor marking each correct next-token prediction."""
        tokens = batch[0].to(self._device())
        logits = self.model(tokens)[:, :-1]
        predicted_tokens = logits.argmax(dim=-1)
        return (predicted_tokens == tokens[:, 1:]).flatten()

    def evaluate(self):
        """Next-token accuracy over ``test_loader`` (nan if there is no test data).

        The model runs in eval mode (dropout off) and its previous mode is restored.
        """
        if self.test_loader is None:
            return float("nan")
        was_training = self.model.training
        self.model.eval()
        try:
            batch_results = [self.validation_step(batch) for batch in self.test_loader]
        finally:
            self.model.train(was_training)
        if not batch_results:
            return float("nan")
        return torch.cat(batch_results).float().mean().item()

    def train(self):
        """Train for ``args.epochs`` epochs of at most ``args.max_steps_per_epoch`` steps.

        Accuracy on ``test_loader`` is computed after every epoch and shown in the
        progress bar. Returns the last computed accuracy.
        """
        accuracy = np.nan
        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        try:
            for epoch in range(self.args.epochs):
                self.model.train()
                for i, batch in enumerate(self.train_loader):
                    # stop before exceeding the per-epoch budget so the step
                    # count matches the progress bar total
                    if i >= self.args.max_steps_per_epoch:
                        break
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss.item():.3f}, accuracy: {accuracy:.2f}")

                accuracy = self.evaluate()
        finally:
            progress_bar.close()

        return accuracy


In [None]:
qgpt_trainer = TransformerTrainer(
    args=qgpt_training_args,
    model=qgpt,
    train_loader=trainDL,
)
qgpt_trainer.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

# Inference: sampling from QGPT and the GPT-2 reference

In [None]:
class Sampler:
    """Autoregressive text generation from a causal language model.

    Args:
        cfg: config providing ``n_ctx``; before each forward pass the running
            sequence is cropped to its last ``n_ctx`` tokens.
        model: nn.Module mapping (batch, seq_len) token ids to
            (batch, seq_len, vocab) logits.
        tokenizer: object with a ``decode`` method.

    NOTE: new tokens are appended after the last position of every row, so a
    right-padded batch (as produced for prompts of different lengths) continues
    after its padding; sample such prompts one at a time.
    """

    def __init__(self, cfg, model, tokenizer):
        self.cfg = cfg
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(self, tokens, n_new_tokens, temperature=1.0, top_k=None):
        """Append ``n_new_tokens`` generated tokens to every row of ``tokens``.

        Args:
            tokens: (batch, seq_len) tensor of prompt token ids.
            n_new_tokens: number of tokens to generate.
            temperature: softmax temperature; 0 selects greedy (argmax) decoding.
            top_k: if given, sample only among the ``top_k`` most likely tokens.

        Returns:
            (batch, seq_len + n_new_tokens) tensor of token ids.

        The model is put in eval mode (disabling dropout) while generating and
        restored to its previous mode afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            for _ in range(n_new_tokens):
                # crop sequence at context size if required
                context = tokens[:, -self.cfg.n_ctx:]
                # logits for the last position only
                logits = self.model(context)[:, -1, :]
                if temperature == 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = nn.functional.softmax(logits, dim=-1)
                    # sample the next token
                    next_token = torch.multinomial(probs, num_samples=1)
                # add the new token to the sequence
                tokens = torch.cat((tokens, next_token), dim=1)
        finally:
            self.model.train(was_training)

        return tokens

    def sample(self, tokens, n_new_tokens, temperature=1, top_k=None):
        """Generate continuations and decode every row to a string."""
        token_sequences = self.generate(tokens, n_new_tokens, temperature, top_k)
        return [self.tokenizer.decode(sequence) for sequence in token_sequences]

In [None]:
qgpt_sampler = Sampler(cfg=qgpt_cfg, model=qgpt, tokenizer=gpt2_tokenizer)

In [None]:
# Continue both prompts with the trained QGPT.
qgpt_sampler.sample(text_tokens, n_new_tokens=10, temperature=0.5)

In [None]:
# The reference GPT-2 shares the tokenizer and n_ctx with QGPT's config.
gpt2_sampler = Sampler(cfg=qgpt_cfg, model=reference_gpt2, tokenizer=gpt2_tokenizer)

In [None]:
# Baseline: continue the same prompts with the pretrained GPT-2 reference.
gpt2_sampler.sample(text_tokens, 10, temperature=0.5)