# LLM Activation Fine-Tuning Demo

In [None]:
# Notebook Deps
from IPython.lib.pretty import pretty
from pprint import pprint


## Overview

In this demo, we will explore an **experimental** workflow for fine-tuning models on top of activations retrieved from foundation models hosted on the Vector cluster. We will briefly demonstrate a few fundamental concepts in the following sections:

* Text generation
* Model querying and activation generation
* Fine-tuning

We will be interfacing with a deployment of the Open Pre-trained Transformers (OPT). This demonstration uses the small 125M-parameter model (OPT-125M) for simplicity; the workflow remains the same when querying larger models.

## Text Generation


The Vector OPT Client class will be our primary tool for querying the OPT deployment. The OPT deployment exposes a RESTful API which is conveniently wrapped by the Client. 

### Client Initialization

In [None]:
from opt_client import Client

In [None]:
OPT_HOST = "172.17.8.104"
OPT_PORT = "6969"

client = Client(host=OPT_HOST, port=OPT_PORT)

The client provides a set of functions for interacting with the remote model. For example, we can use the ``generate`` function to perform text generation.

In [None]:
prompt = "Hello World"
response = client.generate(prompt)

print("Prompt: ", prompt)
print("Generation: ", response['choices'][0]['text'])

We can also pass in a list of prompts and adjust generation hyperparameters such as ``temperature``.

In [None]:
prompts = [
    "Hello World",
    "Fizz Buzz"
]

response = client.generate(prompts, temperature=0.8)


for prompt, generation in zip(prompts, response['choices']):
    print("Prompt: ", prompt)
    print("Generation: ", generation['text'], "\n")
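
For intuition about the ``temperature`` argument: in the usual sampling formulation, next-token scores (logits) are divided by the temperature before the softmax, so values below 1 sharpen the distribution and values above 1 flatten it. The sketch below is a generic illustration with made-up logits. It is not taken from the OPT deployment, whose sampling code is not shown in this notebook, and ``softmax_with_temperature`` is a hypothetical helper name.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert raw scores to probabilities after dividing by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [score / temperature for score in logits]
    # Subtract the maximum for numerical stability before exponentiating.
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical next-token scores
sharp = softmax_with_temperature(toy_logits, temperature=0.8)
neutral = softmax_with_temperature(toy_logits, temperature=1.0)

print("temperature=0.8:", [round(p, 3) for p in sharp])
print("temperature=1.0:", [round(p, 3) for p in neutral])
```

With ``temperature=0.8`` the most likely token receives a larger share of the probability mass than with ``temperature=1.0``, which is why lower temperatures tend to produce more conservative generations.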

## Activation Generation

Activation generation is also quite easy. We can use the client to query the remote model and explore the various modules. 

In [None]:
client.module_names

We can select the module names of interest and pass them into a ``get_activations`` function alongside our set of prompts.

In [None]:
module_names = ['decoder.layers.11.fc2']

response = client.get_activations(prompts, module_names)

pprint(response)
print("Tensor Shape:", response[0]['decoder.layers.11.fc2'].shape)
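
To make the idea of a per-module activation concrete, here is a minimal local sketch that captures the output of a named module with a PyTorch forward hook. It uses a toy ``nn.Sequential`` model with arbitrary layer sizes and a hypothetical ``capture_activations`` helper. How the remote deployment collects activations is not shown in this notebook, so treat this as an analogy rather than the client's implementation.

```python
import torch
import torch.nn as nn

# Toy stand-in for a remote model; the layer sizes are arbitrary.
toy_model = nn.Sequential()
toy_model.add_module("fc1", nn.Linear(8, 16))
toy_model.add_module("act", nn.ReLU())
toy_model.add_module("fc2", nn.Linear(16, 4))


def capture_activations(model, inputs, module_names):
    """Run one forward pass and return {module_name: output tensor}."""
    captured = {}
    handles = []
    modules = dict(model.named_modules())
    for name in module_names:
        # Bind `name` as a default argument so each hook stores under its own key.
        def hook(module, args, output, name=name):
            captured[name] = output.detach()
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return captured


tokens = torch.randn(5, 8)  # 5 "tokens", each an 8-dimensional vector
toy_activations = capture_activations(toy_model, tokens, ["fc2"])
print(toy_activations["fc2"].shape)  # torch.Size([5, 4])
```

As in the response above, the result is keyed by module name and holds one vector per input token.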

### IMDB Sentiment Classification Dataset

For a set of prompts, we can use the ``get_activations`` function to generate a set of activations that can be cached for offline use. Let's now take a look at how to do this using the [IMDB Sentiment Analysis Dataset](https://huggingface.co/datasets/imdb).

The dataset consists of 25,000 highly polar movie reviews for training, and another 25,000 for testing. The task is binary sentiment classification, where labels indicate either a positive or negative review.

Sample Review:
> "This movie sucked. It really was a waste of my life. The acting was atrocious, the plot completely implausible. Long, long story short, these people get "terrorized" by this pathetic "crazed killer", but completely fail to fight back in any manner. And this is after they take a raft on a camping trip, with no gear, and show up at a campsite that is already assembled and completely stocked with food and clothes and the daughters headphones. Additionally, after their boat goes missing, they panic that they're stuck in the woods, but then the daughters boyfriend just shows up and they apparently never consider that they could just hike out of the woods like he did to get to them. Like I said, this movie sucks. A complete joke. Don't let your girlfriend talk you into watching it."

Label: ``Negative``

The dataset is provided by the HuggingFace ``datasets`` package. 

In [None]:
import pickle
from tqdm import tqdm
from datasets import load_dataset

In [None]:
imdb = load_dataset("imdb")

For demonstration purposes, let's filter this dataset down to 100 training samples and 100 test samples. The ``generate_dataset_activations`` function will then generate and cache a pickled set of activations.

In [None]:
small_train_dataset = imdb["train"].shuffle(seed=42).select(range(100))
small_test_dataset = imdb["test"].shuffle(seed=42).select(range(100))

In [None]:
def batcher(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

def generate_dataset_activations(split, dataset, client):

    print("Generating Activations: " + split)

    module_names = [
        'decoder.layers.11.fc2'
    ]

    activations = []
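    BATCH_SIZE = 16
    # Ceiling division so the progress bar also counts the final partial batch.
    num_batches = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE
    for batch in tqdm(batcher(dataset, BATCH_SIZE), total=num_batches):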
        prompts = batch['text']
        activations.append(client.get_activations(prompts, module_names))

    parsed_activations = []
    for batch in activations:
        for prompt_activation in batch:
            parsed_activations.append(prompt_activation['decoder.layers.11.fc2'])

    cached_activations = {
        'activations': parsed_activations,
        'labels': dataset['label']
    }

    with open(split + '_activations_demo.pkl', 'wb') as handle:
        pickle.dump(cached_activations, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
generate_dataset_activations('train', small_train_dataset, client)
generate_dataset_activations('test', small_test_dataset, client)
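
The batching and caching steps can be checked in isolation without the remote model. The sketch below runs ``batcher`` on a plain list standing in for 100 reviews and round-trips a cache dictionary through ``pickle``. The placeholder data and the ``toy_activations.pkl`` file name are made up for illustration.

```python
import math
import os
import pickle
import tempfile


def batcher(seq, size):
    """Yield consecutive slices of `seq` with at most `size` items each."""
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))


samples = list(range(100))  # stand-in for 100 reviews
batch_size = 16
batches = list(batcher(samples, batch_size))

# 100 items in batches of 16 give 6 full batches plus one batch of 4.
expected_batches = math.ceil(len(samples) / batch_size)
print(len(batches), expected_batches, len(batches[-1]))

# Round-trip a cache dictionary through pickle, mirroring the on-disk cache.
cache = {"activations": batches, "labels": [0] * len(samples)}
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_activations.pkl")
    with open(path, "wb") as handle:
        pickle.dump(cache, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(path, "rb") as handle:
        restored = pickle.load(handle)

print(restored == cache)
```

Note that the last batch is smaller than the rest, so any progress total or downstream shape assumption should account for a partial final batch.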

## Fine-Tuning


The cached activations can be loaded from disk to facilitate the fine-tuning of a classification model on the sentiment analysis task.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.rnn as rnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Let's define an Activation Dataset which will load our activations from disk.

In [None]:
class ActivationDataset(Dataset):

    def __init__(self, activations_path):
        self._load_activations(activations_path)
        
    def _load_activations(self, path):
        with open(path, 'rb') as handle:
            cached_activations = pickle.load(handle)
        self.activations = cached_activations['activations']
        self.labels = cached_activations['labels']
    
    def __len__(self):
        return len(self.activations)

    def __getitem__(self, idx):
        return self.activations[idx], self.labels[idx]

We will perform classification on the activation of the last token in each sequence, which is common practice for autoregressive models (e.g. GPT-3) because the last position is the only one that has attended to the entire input. The following ``batch_last_token`` collate function will be passed into the dataloader to extract the last token activation from each sequence.

In [None]:
def batch_last_token(batch):
    (x, y) = zip(*batch)
    
    x = torch.stack([seq[-1] for seq in x])

    return x, y
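
To see why the collate function is needed, note that reviews differ in length, so their activation tensors cannot be stacked directly. Keeping only the final token gives one fixed-size vector per review. The sketch below applies the same function to random toy tensors. It assumes each cached activation has shape ``(sequence_length, hidden_dim)`` with no trailing padding, and the small ``hidden_dim`` is arbitrary.

```python
import torch


def batch_last_token(batch):
    """Collate (activation, label) pairs, keeping each sequence's final token."""
    activations, labels = zip(*batch)
    last_tokens = torch.stack([seq[-1] for seq in activations])
    return last_tokens, labels


hidden_dim = 6  # arbitrary toy size
toy_batch = [
    (torch.randn(3, hidden_dim), 1),   # 3-token sequence, positive label
    (torch.randn(10, hidden_dim), 0),  # 10-token sequence, negative label
]

x, y = batch_last_token(toy_batch)
print(x.shape)  # torch.Size([2, 6])
print(y)        # (1, 0)
```

The labels come back as a tuple rather than a tensor, which is why the training loop converts them with ``torch.tensor`` before computing the loss.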

We also define a small MLP to perform the classification.

In [None]:
class MLP(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.linear = nn.Linear(cfg['embedding_dim'], cfg['hidden_dim'], bias=False)
        self.out = nn.Linear(cfg['hidden_dim'], cfg['label_dim'])

    def forward(self, x):
        x = F.relu(self.linear(x))
        x = self.out(x)  
        return x

In [None]:
# File names match those written by generate_dataset_activations.
train_dataset = ActivationDataset('./train_activations_demo.pkl')
test_dataset = ActivationDataset('./test_activations_demo.pkl')

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=batch_last_token)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=batch_last_token)

We can now write a relatively simple script to train and evaluate our model.

In [None]:
model = MLP({
        "embedding_dim": 768,
        "hidden_dim": 128,
        "label_dim": 2
    })
model.cuda()

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

NUM_EPOCHS = 10
pbar = tqdm(range(NUM_EPOCHS))
for epoch_idx in pbar:

    pbar.set_description("Epoch: %s" % epoch_idx)
    training_params = {
        "Train-Loss": 0.0,
        "Test-Accuracy": 0.0
    }
    pbar.set_postfix(training_params)

    model.train()
    for batch in train_dataloader:

        activations, labels = batch
        activations = activations.float().cuda()
        labels = torch.tensor(labels).cuda()

        optimizer.zero_grad()

        logits = model(activations)
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()

        training_params["Train-Loss"] = loss.detach().item()
        pbar.set_postfix(training_params)

    model.eval()
    with torch.no_grad():
        predictions = []
        for batch in test_dataloader:
            activations, labels = batch
            activations = activations.float().cuda()
            labels = torch.tensor(labels).cuda()

            logits = model(activations)
            predictions.extend(logits.argmax(dim=1) == labels)


        accuracy = torch.stack(predictions).sum() / len(predictions)

        training_params["Test-Accuracy"] = accuracy.detach().item()
        pbar.set_postfix(training_params)
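
The evaluation step can also be factored into a small helper that counts correct predictions directly and runs on whichever device is available. This is a sketch under assumptions: ``evaluate_accuracy`` is a hypothetical name, the batches are expected in the ``(activations, labels)`` form produced by ``batch_last_token``, and the toy classifier and data below are placeholders.

```python
import torch
import torch.nn as nn


def evaluate_accuracy(model, batches, device):
    """Return the fraction of correct argmax predictions over all batches."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for activations, labels in batches:
            activations = activations.float().to(device)
            labels = torch.as_tensor(labels, device=device)
            predictions = model(activations).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    # Guard against an empty iterable of batches.
    return correct / total if total else 0.0


# Fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

toy_model = nn.Linear(4, 2).to(device)  # stand-in classifier
toy_batches = [
    (torch.randn(8, 4), (0, 1, 0, 1, 0, 1, 0, 1)),
    (torch.randn(3, 4), (1, 0, 1)),
]
toy_accuracy = evaluate_accuracy(toy_model, toy_batches, device)
print(f"Toy accuracy: {toy_accuracy:.2f}")
```

Accumulating integer counts avoids building a list of per-sample tensors, and selecting the device once keeps the same code usable on machines without CUDA.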


## Conclusion

We have demonstrated a simple workflow for interfacing with remote models on the Vector cluster. Additionally, we have shown how to generate and cache activations for fine-tuning models offline.