In [3]:
%load_ext autoreload
%autoreload 2

import torch
import os

from mingpt.utils import set_seed
from mingpt.trainer import Trainer, PrefixTrainer
from mingpt.model import GPT
from mingpt.data_tools import CustomDataset, eval, batch_end_callback, attention_visualization, label_batch
import mingpt.data_tools as tasks

import seaborn as sns
import matplotlib.pyplot as plt
from typing import Optional, List

set_seed(1234)

We saw that prefix-tuning cannot learn a new task.
However, we hypothesize that if the model has been pre-trained on various tasks, prefix-tuning can elucidate one of them.
In this notebook we demonstrate that this is the case.
We will pre-train a model on four different tasks: sort in ascending and descending order, or add one or two to each element of the list.
The model will recieve no indication of which task it needs to solve so it will learn to put approximately equal probability on all 4 completions, resulting in about 25% accuracyt.
However, by learning prefixes, we can constrain the output distribution to only one of the tasks, thus demonstrating that while prefix-tuning cannot learn a completely new task, it can specialize a model to one of the tasks it has already seen.
We will also show that the exact same method fails to learn a different new task (double histogram) which requires a novel attention pattern.

First let's prepare our datasets:

In [4]:
prefix_size = 0

print("> Sample from the ascending dataset:")
x, y = CustomDataset('train', tasks=[tasks.SortDescending()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the descending dataset:")
x, y = CustomDataset('train', tasks=[tasks.SortAscending()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the InverseBinary dataset:")
x, y = CustomDataset('train', tasks=[tasks.InverseBinary()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the Add1 dataset:")
x, y = CustomDataset('train', tasks=[tasks.Add1()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the Add2 dataset:")
x, y = CustomDataset('train', tasks=[tasks.Add1()+tasks.Add1()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the DoubleHistogram dataset:")
x, y = CustomDataset('train', tasks=[tasks.DoubleHistogram()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the Modulo dataset:")
x, y = CustomDataset('train', tasks=[tasks.Modulo()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the LessThan dataset:")
x, y = CustomDataset('train', tasks=[tasks.LessThan()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the MoreThanEqual dataset:")
x, y = CustomDataset('train', tasks=[tasks.MoreThanEqual()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the Divisible dataset:")
x, y = CustomDataset('train', tasks=[tasks.Divisible()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the NotDivisible dataset:")
x, y = CustomDataset('train', tasks=[tasks.NotDivisible()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the FilterAtLeastNTimes dataset:")
x, y = CustomDataset('train', tasks=[tasks.FilterAtLeastNTimes()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the ascending + add dataset:")
x, y = CustomDataset('train', tasks=[tasks.SortAscending()+tasks.Add1()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the add + less than dataset:")
x, y = CustomDataset('train', tasks=[tasks.Add1()+tasks.LessThan()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the less than + add dataset:")
x, y = CustomDataset('train', tasks=[tasks.LessThan()+tasks.Add1()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the less than + ascending dataset:")
x, y = CustomDataset('train', tasks=[tasks.LessThan()+tasks.SortAscending()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

print("> Sample from the ascending + divisible dataset:")
x, y = CustomDataset('train', tasks=[tasks.LessThan()+tasks.SortAscending()], prefix_padding=prefix_size, num_digits=10)[0]
print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))


print("> Samples from the random training dataset:")
training_tasks = [
    tasks.SortAscending(),
    tasks.SortDescending(),
    tasks.Add1(),
    tasks.Modulo(),
    tasks.LessThan(),
    tasks.Divisible(),
    tasks.InverseBinary(),
]

testing_tasks = [
    tasks.MoreThanEqual(),
    tasks.NotDivisible(),
    tasks.DoubleHistogram(),
    tasks.FilterAtLeastNTimes(),
    tasks.SortAscending()+tasks.Add1(),
    tasks.Add1() + tasks.LessThan(),
    tasks.LessThan() + tasks.Add1(),
    tasks.LessThan() + tasks.SortAscending(),
    tasks.Divisible() + tasks.Add1(),
]

train_dataset_random = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=training_tasks)
for i in range(10):
    x, y = train_dataset_random[i]
    print("".join(map(lambda x: f"{x:>4}", x.tolist()[:10])) + " ->" + "".join(map(lambda x: f"{x:>4}", y.tolist()[9:])))

> Sample from the ascending dataset:
   6   2   7   6   7   5   3   6   6  10 ->  10   7   7   6   6   6   6   5   3   2
> Sample from the descending dataset:
   1   5  10   3   2  10   3   3  10   5 ->   1   2   3   3   3   5   5  10  10  10
> Sample from the InverseBinary dataset:
   1   0   0   0   1   1   0   0   0   0 ->   0   1   1   1   0   0   1   1   1   1
> Sample from the Add1 dataset:
   2   8   8   2   3   4   3   2   3   8 ->   3   9   9   3   4   5   4   3   4   9
> Sample from the Add2 dataset:
   8   7   2   9   2   9   2   7  10   4 ->  10   9   4  11   4  11   4   9  12   6
> Sample from the DoubleHistogram dataset:
   5  10   6   3   8   7   7   8   2   8 ->   1   1   1   1   3   2   2   3   1   3
> Sample from the Modulo dataset:
   5   7   8   8   3   9   6   8   3   3 ->   0   2   3   3   3   4   1   3   3   3
> Sample from the LessThan dataset:
   5   4   7   3   4   2   2   1   3   6 ->   0   4   0   3   4   2   2   1   3   0
> Sample from the MoreThanEqual dat

In [5]:
prefix_size = 10
train_dataset_random = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=training_tasks)
test_dataset_random = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=training_tasks)

train_dataset_SortAscending = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.SortAscending()])
train_dataset_SortDescending = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.SortDescending()])
train_dataset_Add1 = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Add1()])
train_dataset_Modulo = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Modulo()])
train_dataset_LessThan = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.LessThan()])
train_dataset_Divisible = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Divisible()])
train_dataset_InverseBinary = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.InverseBinary()])

test_dataset_SortAscending = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.SortAscending()])
test_dataset_SortDescending = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.SortDescending()])
test_dataset_Add1 = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Add1()])
test_dataset_Modulo = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Modulo()])
test_dataset_LessThan = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.LessThan()])
test_dataset_Divisible = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.Divisible()])
test_dataset_InverseBinary = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[tasks.InverseBinary()])

Let's pre-train the model on `train_dataset_random` and check its accuracy:

In [6]:
# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = None
model_config.vocab_size = train_dataset_random.get_vocab_size()
model_config.block_size = train_dataset_random.get_block_size()
model_config.n_layer = 12
model_config.n_head = 8
model_config.n_embd = 512
model_config.batch_size = 512
model = GPT(model_config)

fname = '06_pretrained.pth'
if os.path.exists(fname):
    print("Loading weights from cache, won't train from scratch.")
    model.load_state_dict(torch.load(fname))
    model.config = model_config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
else:
    # create a Trainer object
    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4
    train_config.max_iters = 100000
    train_config.num_workers = 0
    trainer = Trainer(train_config, model, train_dataset_random)
    trainer.set_callback('on_batch_end', batch_end_callback)
    trainer.run()
    device = trainer.device

    # save the model weights:
    torch.save(model.state_dict(), fname)

number of parameters: 37.85M
Loading weights from cache, won't train from scratch.


In [7]:
print("Performance on a randomly labeled dataset:")
_ = eval(model, dataset=test_dataset_random, device=device, max_batches=32)
print("Performance on the ascending dataset:")
_ = eval(model, dataset=test_dataset_SortAscending, device=device, max_batches=32)
print("Performance on the descending dataset:")
_ = eval(model, dataset=test_dataset_SortDescending, device=device, max_batches=32)
print("Performance on the add1 dataset:")
_ = eval(model, dataset=test_dataset_Add1, device=device, max_batches=32)
print("Performance on the Divisible dataset:")
_ = eval(model, dataset=test_dataset_Divisible, device=device, max_batches=32)
print("Performance on the LessThan dataset:")
_ = eval(model, dataset=test_dataset_LessThan, device=device, max_batches=32)
print("Performance on the Modulo dataset:")
_ = eval(model, dataset=test_dataset_Modulo, device=device, max_batches=32)
print("Performance on the InverseBinary dataset:")
_ = eval(model, dataset=test_dataset_InverseBinary, device=device, max_batches=32)

Performance on a randomly labeled dataset:
Final score: 1202/3200 = 37.56% correct
Performance on the ascending dataset:
Final score: 342/3200 = 10.69% correct
Performance on the descending dataset:
Final score: 666/3200 = 20.81% correct
Performance on the add1 dataset:
Final score: 799/3200 = 24.97% correct
Performance on the Divisible dataset:
Final score: 817/3200 = 25.53% correct
Performance on the LessThan dataset:
Final score: 147/3200 = 4.59% correct
Performance on the Modulo dataset:
Final score: 2360/3200 = 73.75% correct
Performance on the InverseBinary dataset:
Final score: 3200/3200 = 100.00% correct


As there is randomness involved, the model might not have 25% accuracy for all tasks, but it still does have about 25% accuracy overall.

Let's train a separate prefix for each task now:

In [9]:
prefixes = dict()

tasks = training_tasks+testing_tasks

for task, iterations, lr in zip(
    tasks,
    [20_000]*len(tasks),
    [5e-5]*len(tasks)
):
    task_name = str(task)
    fname = f'06_prefix_{task_name}.pth'
    train_dataset = CustomDataset('train', prefix_padding=prefix_size, num_digits=10, tasks=[task])
    test_dataset = CustomDataset('test', prefix_padding=prefix_size, num_digits=10, tasks=[task])
    if os.path.exists(fname):
        prefixes[task_name] = torch.load(fname)
        print(f"Prefix {task} loaded from cache.")
    else:
        print(f"TRAINING A PREFIX FOR THE {task_name.upper()} TASK:")
        prefixes[task_name]  = torch.randn((model.config.n_layer,prefix_size, model.config.n_embd), requires_grad=True, device=device)
        train_config = Trainer.get_default_config()
        train_config.num_workers = 0
        train_config.max_iters = iterations
        train_config.learning_rate = lr
        trainer = PrefixTrainer(train_config, model, train_dataset, prefixes[task_name])
        trainer.set_callback('on_batch_end', batch_end_callback)
        trainer.run()
        torch.save(prefixes[task_name], fname)
    _ = eval(model, dataset=test_dataset, device=device, max_batches=32, prefixes=prefixes[task_name])
    print()


Prefix SortAscending loaded from cache.
Final score: 3198/3200 = 99.94% correct

Prefix SortDescending loaded from cache.
Final score: 3196/3200 = 99.88% correct

Prefix Add1 loaded from cache.
Final score: 3200/3200 = 100.00% correct

Prefix Modulo loaded from cache.
Final score: 3200/3200 = 100.00% correct

Prefix LessThan loaded from cache.
Final score: 3200/3200 = 100.00% correct

Prefix Divisible loaded from cache.
Final score: 3200/3200 = 100.00% correct

Prefix InverseBinary loaded from cache.
Final score: 3200/3200 = 100.00% correct

TRAINING A PREFIX FOR THE MORETHANEQUAL TASK:


KeyError: 'train_dataset_MoreThanEqual'

With the task-specific prefixes, the accuracy gets close to 100% for all four tasks, as before.
Furthermore, we obtained high accuracy for the task of ascending and incrementing by one even though the pretrained model had never seen it before. 
This was likely successful because of this new task being a composition of two pretraining tasks.
In contrast, double histogram, another new task but one that is not a composition of pretraining tasks, cannot be solved with prefix tuning.
This further illustrates that prefix-tuning is unlikely to be able to learn a completely new task but is able to elicit a pretraining task or to learn a new task that can be solved with skills learned during pre-training.