In [10]:
%load_ext autoreload
%autoreload 2

import torch
import os

from mingpt.utils import set_seed
from mingpt.trainer import Trainer, PrefixTrainer, LoRATrainer
from mingpt.model import GPT
from mingpt.data_tools import CustomDataset, eval, batch_end_callback, attention_visualization, label_batch

import seaborn as sns
import matplotlib.pyplot as plt
from typing import Optional, List

from minlora import get_lora_state_dict

# Seed the RNGs via mingpt.utils.set_seed so runs are repeatable. On a fresh kernel,
# run this cell first, before any dataset or model below is created.
set_seed(1234)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Passed to every dataset as `prefix_padding` and reused as the learned prefix length below.
prefix_size = 12

# Build a train and a test split for every task. The comprehension walks the splits
# train-then-test and keeps the mode order fixed, so any RNG draws made while
# constructing the datasets happen in the same sequence as explicit one-by-one calls.
DATASET_MODES = ("random", "ascending", "descending", "add1", "add2", "ascending_add1", "double_hist")
datasets_by_split_mode = {
    (split, mode): CustomDataset(split, mode=mode, prefix_padding=prefix_size)
    for split in ("train", "test")
    for mode in DATASET_MODES
}

train_dataset_random = datasets_by_split_mode["train", "random"]
train_dataset_ascending = datasets_by_split_mode["train", "ascending"]
train_dataset_descending = datasets_by_split_mode["train", "descending"]
train_dataset_add1 = datasets_by_split_mode["train", "add1"]
train_dataset_add2 = datasets_by_split_mode["train", "add2"]
train_dataset_ascending_add1 = datasets_by_split_mode["train", "ascending_add1"]
train_dataset_double_hist = datasets_by_split_mode["train", "double_hist"]

test_dataset_random = datasets_by_split_mode["test", "random"]
test_dataset_ascending = datasets_by_split_mode["test", "ascending"]
test_dataset_descending = datasets_by_split_mode["test", "descending"]
test_dataset_add1 = datasets_by_split_mode["test", "add1"]
test_dataset_add2 = datasets_by_split_mode["test", "add2"]
test_dataset_ascending_add1 = datasets_by_split_mode["test", "ascending_add1"]
test_dataset_double_hist = datasets_by_split_mode["test", "double_hist"]

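Let's load the pretrained model from `04_pretrained.pth`. If the cached weights are not found, the cell below pretrains the model from scratch on the `random` task (40,000 iterations) and saves them.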

In [5]:
# Build the GPT used throughout this notebook; vocab and block sizes come from the data.
model_config = GPT.get_default_config()
model_config.model_type = None
model_config.vocab_size = train_dataset_random.get_vocab_size()
model_config.block_size = train_dataset_random.get_block_size()
model_config.n_layer = 4
model_config.n_head = 4
model_config.n_embd = 256
model_config.batch_size = 512
model = GPT(model_config)
# Later cells read model.config (e.g. n_layer / n_embd for the prefix shape), so attach
# it on both the cached-weights path and the train-from-scratch path.
model.config = model_config

fname = '04_pretrained.pth'
if os.path.exists(fname):
    print("Loading weights from cache, won't train from scratch.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(fname, map_location=device))
    model.to(device)
else:
    # Pretrain from scratch on the `random` task, then cache the weights.
    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4
    train_config.max_iters = 40000
    train_config.num_workers = 0
    trainer = Trainer(train_config, model, train_dataset_random)
    trainer.set_callback('on_batch_end', batch_end_callback)
    trainer.run()
    device = trainer.device

    # save the model weights:
    torch.save(model.state_dict(), fname)

number of parameters: 3.17M
Loading weights from cache, won't train from scratch.


Check that the pretrained model has zero accuracy on the double-histogram (`double_hist`) task.

In [6]:
# Sanity check: before any fine-tuning, the pretrained model should fail on double_hist.
print("Pretrained performance on the double_hist dataset:")
_ = eval(
    model,
    max_batches=32,
    device=device,
    dataset=test_dataset_double_hist,
)

Pretrained performance on the double_hist dataset:
Final score: 0/3200 = 0.00% correct


Even a prefix trained specifically for this task (100,000 iterations) still achieves close to 0% accuracy.

In [8]:
# Prefix-tune the pretrained model on double_hist; the learned prefix is cached on disk.
fname = '04_prefix_double_hist.pth'
if os.path.exists(fname):
    # map_location puts the cached prefix on the current device, so a prefix saved
    # on a GPU still loads on a CPU-only machine (and vice versa).
    prefix = torch.load(fname, map_location=device)
    print("Prefix double_hist loaded from cache.")
else:
    # Learnable tensor of shape (n_layer, prefix_size, n_embd) -- presumably one block
    # of prefix_size vectors per transformer layer.
    prefix = torch.randn((model.config.n_layer, prefix_size, model.config.n_embd), requires_grad=True, device=device)
    train_config = Trainer.get_default_config()
    train_config.num_workers = 0
    train_config.max_iters = 100_000
    train_config.learning_rate = 5e-5
    trainer = PrefixTrainer(train_config, model, train_dataset_double_hist, prefix)
    trainer.set_callback('on_batch_end', batch_end_callback)
    trainer.run()
    torch.save(prefix, fname)
print("Performance on the double_hist dataset with prefix:")
_ = eval(model, dataset=test_dataset_double_hist, device=device, max_batches=32, prefixes=prefix)

Prefix double_histloaded from cache.
Performance on the double_hist dataset with prefix:
Final score: 10/3200 = 0.31% correct


However, a rank-1 LoRA update of the MLP weights, trained for just a tenth of the iterations (10,000 vs. 100,000), reaches high accuracy:

In [11]:
# Rank-1 LoRA fine-tuning on double_hist for 10k iterations -- a tenth of the 100k
# iterations given to prefix-tuning above.
# NOTE(review): LoRATrainer presumably injects LoRA parameters into `model` in place,
# so re-running this cell may stack adapters; restart the kernel before re-running.
train_config = Trainer.get_default_config()
train_config.max_iters = 10_000
train_config.learning_rate = 5e-3
train_config.num_workers = 0

trainer = LoRATrainer(train_config, model, train_dataset_double_hist, rank=1, device=device)
trainer.set_callback('on_batch_end', batch_end_callback)
trainer.run()

_ = eval(model, test_dataset_double_hist, device=device, max_batches=32)


running on device cuda
iter_dt  14.36ms; iter   1000: train loss 1.07757
iter_dt  14.93ms; iter   2000: train loss 1.02382
iter_dt  14.35ms; iter   3000: train loss 0.95937
iter_dt  16.79ms; iter   4000: train loss 0.84579
iter_dt  14.31ms; iter   5000: train loss 0.55840
iter_dt  16.43ms; iter   6000: train loss 0.18979
iter_dt  14.44ms; iter   7000: train loss 0.08735
iter_dt  14.87ms; iter   8000: train loss 0.07735
iter_dt  14.52ms; iter   9000: train loss 0.06246
iter_dt  16.38ms; iter  10000: train loss 0.08361
Final score: 2956/3200 = 92.38% correct


This is despite the two fine-tuning approaches having the same number of learnable parameters (verified below). The limited performance of prefix-tuning is therefore not simply due to its small parameter count, since LoRA with the same number of parameters solves the task. Instead, prefix-tuning (and prompting) suffer from unique structural limitations.

In [12]:
# Compare trainable-parameter budgets: LoRA adapters vs. the learned prefix tensor.
n_lora_params = sum(p.numel() for p in get_lora_state_dict(model).values())
n_prefix_params = torch.numel(prefix)
print(f"Number of LoRA parameters: {n_lora_params}")
print(f"Number of prefix parameters: {n_prefix_params}")
# The LoRA-vs-prefix comparison relies on equal budgets; fail loudly if a config change breaks that.
assert n_lora_params == n_prefix_params, "LoRA and prefix parameter counts differ"

Number of LoRAparameters: 12288
Number of prefix parameters: 12288 
