In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Seed every RNG module imported above so "Restart & Run All" reproduces the outputs.
# torch.manual_seed stays last so the cell still displays the returned Generator.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2287fbf4590>

In [2]:
# Test Model
class TestModel(torch.nn.Module):
    """Tiny language-model stand-in: embedding -> linear -> LM head.

    Args:
        hidden_size: width of the embedding and of the hidden linear layer.
        vocab_size: number of token ids; sets both the embedding table size and
            the number of output logits. Defaults to 10.
    """

    def __init__(self, hidden_size, vocab_size=10):
        super().__init__()
        # Submodules are created in a fixed order so a given seed yields the same weights.
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.linear = torch.nn.Linear(hidden_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids):
        """Map integer token ids of shape (*) to logits of shape (*, vocab_size)."""
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x

In [3]:
hidden_size = 1024  # width shared by the embedding and the hidden linear layer
model = TestModel(hidden_size=hidden_size)
model  # show the module tree

TestModel(
  (embedding): Embedding(10, 1024)
  (linear): Linear(in_features=1024, out_features=1024, bias=True)
  (lm_head): Linear(in_features=1024, out_features=10, bias=True)
)

In [4]:
# Dummy prompt: one sequence (batch size 1) holding every token id 0..9 in order.
input_ids = torch.arange(10).unsqueeze(0)  # int64, shape (1, 10)

# Toy detokenizer: token id i maps to detokenizer[i].
# The vocabulary only consists of 10 words (different colors).
detokenizer = "red orange yellow green blue indigo violet magenta marigold chartreuse".split()

In [5]:
# Token Generation function
def generate_token(model, **kwargs):
    """Greedily decode one next token per sequence.

    Runs ``model(**kwargs)`` without gradient tracking, takes the logits at the
    last sequence position, and returns the highest-scoring token of each batch
    row as a string looked up in the module-level ``detokenizer`` list.
    """
    with torch.no_grad():
        scores = model(**kwargs)
    # (batch, seq, vocab) -> ids of the best-scoring token at the final step, one per row
    best_ids = scores[:, -1].argmax(dim=-1).tolist()
    return [detokenizer[idx] for idx in best_ids]

In [6]:
# Greedy-decode one token from the dummy prompt with the unmodified model.
next_token = generate_token(model=model, input_ids=input_ids)
next_token[0]  # batch size is 1, so this is the only prediction

'red'

In [7]:
# Dummy activations for the hidden linear layer: (batch_size=1, sequence_length=8, hidden_size).
# Reuse hidden_size instead of repeating 1024 so the shapes stay in sync with the model.
X = torch.randn(1, 8, hidden_size)

In [8]:
# LoRA factors for a (hidden_size, hidden_size) weight:
#   A: (hidden_size, rank)
#   B: (rank, hidden_size)
lora_rank = 2
lora_a = torch.randn(hidden_size, lora_rank)
lora_b = torch.randn(lora_rank, hidden_size)

In [9]:
# Weight of the hidden nn.Linear; PyTorch stores it as (out_features, in_features).
W = model.linear.weight
W.size()

torch.Size([1024, 1024])

In [10]:
# The low-rank product A @ B has the same shape as the full weight W.
W2 = lora_a @ lora_b
print(W2.shape)

# Parameter budget: elements in A and B versus elements in the full (hidden_size, hidden_size) W.
lora_numel = lora_a.numel() + lora_b.numel()
base_numel = W.numel()
print("|A+B| / |W|:", lora_numel / base_numel)

# Base path: nn.Linear computes X @ W.T + bias (not X @ W).
base_output = model.linear(X)  # X: dummy input tensor

# LoRA path: mathematically X @ (A @ B) == X @ W2, but `@` is left-associative, so this
# evaluates (X @ A) @ B and never needs the (hidden_size, hidden_size) product.
lora_output = X @ lora_a @ lora_b

# Wrapped-layer output is the sum of both paths.
total_output = base_output + lora_output

# Both paths must already agree in shape; otherwise `+` would silently broadcast.
assert lora_output.shape == base_output.shape
total_output.shape

torch.Size([1024, 1024])
|A+B| / |W|: 0.00390625


torch.Size([1, 8, 1024])

In [11]:
# Lora Model
class LoraLayer(torch.nn.Module):
    """Wrap a linear layer and add a trainable low-rank (LoRA) update.

    Computes ``base_layer(x) + x @ lora_a @ lora_b`` where ``lora_a`` has shape
    ``(in_features, r)`` and ``lora_b`` has shape ``(r, out_features)``.

    Args:
        base_layer: module whose ``weight`` follows the ``torch.nn.Linear``
            layout ``(out_features, in_features)``.
        r: rank of the low-rank update.

    Note:
        Both factors are drawn from a standard normal, so the wrapped layer's
        output differs from the base layer's immediately. Standard LoRA
        initializes ``lora_b`` to zero so the wrapped layer starts out identical
        to the base layer; that is not done here.
    """

    def __init__(self, base_layer, r):
        super().__init__()
        self.base_layer = base_layer

        # nn.Linear stores weight as (out_features, in_features).
        d_out, d_in = self.base_layer.weight.shape
        # Registered as Parameters so they appear in parameters()/state_dict()
        # and move with the module on .to(device).
        self.lora_a = torch.nn.Parameter(torch.randn(d_in, r))
        self.lora_b = torch.nn.Parameter(torch.randn(r, d_out))

    def forward(self, x):
        """Return base output plus the low-rank correction; x's last dim must equal in_features."""
        y1 = self.base_layer(x)
        y2 = x @ self.lora_a @ self.lora_b
        return y1 + y2
    
# Wrap the linear layer of our test model with a rank-2 LoRA adapter.
# Unwrap first if this cell was already run: a LoraLayer has no `.weight`, so
# wrapping it again would raise instead of re-creating the adapter.
base_linear = model.linear.base_layer if isinstance(model.linear, LoraLayer) else model.linear
lora_layer = LoraLayer(base_linear, r=2)
print(lora_layer(X).shape)

model.linear = lora_layer
model

torch.Size([1, 8, 1024])


TestModel(
  (embedding): Embedding(10, 1024)
  (linear): LoraLayer(
    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (lm_head): Linear(in_features=1024, out_features=10, bias=True)
)

In [12]:
# Same greedy decoding, now routed through the LoRA-wrapped linear layer.
next_token = generate_token(model=model, input_ids=input_ids)
next_token[0]  # single-sequence batch: show its prediction

'green'