# LoRA

The goal of this practical is to adapt the code of [minGPT](https://github.com/karpathy/minGPT/) form [Karpathy](https://karpathy.ai/) in order to incorporate Low Rank Adaptation (LoRA) for fine-tuning.

![](https://miro.medium.com/v2/resize:fit:720/format:webp/1*D_i25E9dTd_5HMa45zITSg.png)

This [blog](https://r4j4n.github.io/blogs/posts/lora/) by [Rajan Ghimire](https://r4j4n.github.io/blogs/about/) is a nice introduction to LoRA.

In [5]:
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

## Building a custom Linear module

methods
- [`forward`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.forward)
- [`train`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train)
- [`eval`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval)
- [`reset_parameters`](https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/modules/linear.py#L50)

In [222]:
class LoRALinear(nn.Linear):

    def __init__(self,
                 # nn.Linear parameters
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 # LoRA parameters
                 lora_rank: int = 0,
                 lora_alpha: float = 0.0,
                ) -> None:
        nn.Linear.__init__(
            self,
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )

        # LoRA stuff
        self.has_weights_merged = False
        if lora_rank > 0:
            self.lora_scaling = lora_alpha / lora_rank
            self.lora_A = nn.Parameter(torch.empty((lora_rank, self.in_features), device=device, dtype=dtype))
            self.lora_B = nn.Parameter(torch.empty((self.out_features, lora_rank), device=device, dtype=dtype))

            self.lora_A.requires_grad = False
            self.lora_B.requires_grad = False

            self.reset_parameters_lora()


    def reset_parameters_lora(self) -> None:
        ###
        # your code here
        
        # using a guassian distribution to initialize the weights
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) # Same as nn.Linear
        # using zero for B
        nn.init.zeros_(self.lora_B)
        ###

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = nn.Linear.forward(self, input)
        ###
        # your code here
        
        # pass through A
        dx = torch.matmul(input, self.lora_A.T)
        # pass through B
        dx = torch.matmul(dx, self.lora_B.T)
        # scaling
        dx = dx * self.lora_scaling
        # add to x
        x = x + dx    
        ###
        return x

    def train(self, mode: bool = True) -> "LoRALinear":
        nn.Linear.train(self, mode)
        ###
        # your code here

        # set the requires_grad to True
        self.lora_A.requires_grad = True
        self.lora_B.requires_grad = True
        ###
        return self

    def eval(self) -> "LoRALinear":
        nn.Linear.eval(self)
        ###
        # your code here

        # set the requires_grad to False
        self.lora_A.requires_grad = False
        self.lora_B.requires_grad = False
        ###
        return self

In [223]:
ln = LoRALinear(in_features=3,out_features=4, lora_rank = 8, lora_alpha = 32)

In [224]:
ln.weight

Parameter containing:
tensor([[ 0.1673,  0.1379,  0.3174],
        [-0.3197, -0.3685,  0.0887],
        [-0.2866, -0.4369, -0.0785],
        [-0.3525,  0.1554,  0.3265]], requires_grad=True)

In [225]:
ln.bias

Parameter containing:
tensor([ 0.5289, -0.4530, -0.5749,  0.1643], requires_grad=True)

In [226]:
for p in ln.parameters():
    print(p)

Parameter containing:
tensor([[ 0.1673,  0.1379,  0.3174],
        [-0.3197, -0.3685,  0.0887],
        [-0.2866, -0.4369, -0.0785],
        [-0.3525,  0.1554,  0.3265]], requires_grad=True)
Parameter containing:
tensor([ 0.5289, -0.4530, -0.5749,  0.1643], requires_grad=True)
Parameter containing:
tensor([[-0.4028, -0.5322,  0.1801],
        [-0.2230,  0.3775, -0.2045],
        [-0.4025,  0.2486,  0.1134],
        [ 0.3117,  0.1060, -0.2012],
        [ 0.4751, -0.4945, -0.5722],
        [ 0.4838, -0.1583, -0.2806],
        [ 0.5514, -0.3610, -0.1109],
        [-0.5553, -0.1461,  0.0150]])
Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [227]:
bs = 5
x = torch.randn((bs, 3))
y = ln(x)

In [228]:
y2 = x@ln.weight.T + ln.bias

In [229]:
torch.isclose(y,y2)

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

In [230]:
ln.train()

LoRALinear(in_features=3, out_features=4, bias=True)

In [231]:
y3 = ln(x)
torch.isclose(y3,y2)

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

In [232]:
ln.eval()
y3 = ln(x)
torch.isclose(y3,y2)

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

In [233]:
def get_lora_model(model: nn.Module) -> nn.Module:
    for name, param in model.named_parameters():
        if "lora" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model

In [234]:
ln_lora = get_lora_model(ln)

In [235]:
for p in ln_lora.parameters():
    print(p)

Parameter containing:
tensor([[ 0.1673,  0.1379,  0.3174],
        [-0.3197, -0.3685,  0.0887],
        [-0.2866, -0.4369, -0.0785],
        [-0.3525,  0.1554,  0.3265]])
Parameter containing:
tensor([ 0.5289, -0.4530, -0.5749,  0.1643])
Parameter containing:
tensor([[-0.4028, -0.5322,  0.1801],
        [-0.2230,  0.3775, -0.2045],
        [-0.4025,  0.2486,  0.1134],
        [ 0.3117,  0.1060, -0.2012],
        [ 0.4751, -0.4945, -0.5722],
        [ 0.4838, -0.1583, -0.2806],
        [ 0.5514, -0.3610, -0.1109],
        [-0.5553, -0.1461,  0.0150]], requires_grad=True)
Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)


## Use the LoRA layer in the building blocks of minGPT

In [236]:
from mingpt.model import CausalSelfAttention

class CausalSelfAttention_LoRA(CausalSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # minor modifications
        self.c_attn = LoRALinear(
            in_features=config.n_embd,
            out_features=3 * config.n_embd,
            lora_rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
        )
        # output projection
        self.c_proj = LoRALinear(
            in_features=config.n_embd,
            out_features=config.n_embd,
            lora_rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
        )

In [237]:
from mingpt.model import Block, NewGELU

class Block_LoRA(Block):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__(config)
        # minor modification
        self.attn = CausalSelfAttention_LoRA(config)

Same thing for the GPT module and you can simplify the configuration of the optimizer for the LoRA module

In [238]:
from mingpt.model import GPT

class GPT_LoRA(GPT):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block_LoRA(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.config = config
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        ###
        # your code here
        
        
        ###
        
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

## Learning to sort

We use the [demo](https://github.com/karpathy/minGPT/blob/master/demo.ipynb) to check that our code is running fine!

In [239]:
@dataclass
class Config:
    n_head = 3
    n_embd = 15
    block_size = 11
    # dropout hyperparameters
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    # LoRA
    lora_rank = 8
    lora_alpha = 32

# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = 3
model_config.block_size = 100
model_config.lora_rank = 8
model_config.lora_alpha = 32

model = GPT_LoRA(model_config)

number of parameters: 0.09M


In [240]:
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import set_seed
set_seed(3407)
import pickle

class SortDataset(Dataset):
    """ 
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits
    
    def __len__(self):
        return 10000 # ...
    
    def get_vocab_size(self):
        return self.num_digits
    
    def get_block_size(self):
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    def __getitem__(self, idx):
        
        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that 
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rate
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unqiue digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok
        
        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

In [241]:
# print an example instance of the dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')

In [242]:
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = 24 #train_dataset.get_block_size()
model_config.lora_rank = 8
model_config.lora_alpha = 32
model_config.lora_dropout = 0
model = GPT_LoRA(model_config)

number of parameters: 0.09M


In [243]:
# create a Trainer object
from mingpt.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
train_config.max_iters = 1000#2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

running on device cuda


In [244]:
def batch_end_callback(trainer):
    if trainer.iter_num % 100 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)

trainer.run()

AssertionError: parameters {'transformer.h.0.attn.c_attn.lora_B', 'transformer.h.2.attn.c_proj.lora_B', 'transformer.h.1.attn.c_attn.lora_B', 'transformer.h.2.attn.c_proj.lora_A', 'transformer.h.0.attn.c_attn.lora_A', 'transformer.h.2.attn.c_attn.lora_B', 'transformer.h.1.attn.c_proj.lora_A', 'transformer.h.1.attn.c_proj.lora_B', 'transformer.h.2.attn.c_attn.lora_A', 'transformer.h.0.attn.c_proj.lora_B', 'transformer.h.1.attn.c_attn.lora_A', 'transformer.h.0.attn.c_proj.lora_A'} were not separated into either decay/no_decay set!

In [None]:
# now let's perform some evaluation
model.eval();
dataset = {'train':train_dataset, 'test':test_dataset}
def eval_split(trainer, split, max_batches, dataset=dataset):
    dataset = dataset[split]
    n = dataset.length # naugy direct access shrug
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, -n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence
        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense
                mistakes_printed_already += 1
                print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b+1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()

# run a lot of examples from both train and test through the model and verify the output correctness
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50)
    test_score  = eval_split(trainer, 'test',  max_batches=50)

Now we modifiy the distribution of the dataset a little bit and use LoRA to fine-tune.

In [None]:
train_dataset2 = SortDataset('train',length=10)
test_dataset2 = SortDataset('test',length=10)

In [None]:
dataset2 = {'train':train_dataset2, 'test':test_dataset2}
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50, dataset=dataset2)
    test_score  = eval_split(trainer, 'test',  max_batches=50, dataset=dataset2)

In [None]:
# your code here for training with LoRA

In [None]:
model.eval();
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50, dataset=dataset2)
    test_score  = eval_split(trainer, 'test',  max_batches=50, dataset=dataset2)
