In [1]:
import os
import pickle
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import cuda
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### Fine-Tuning

The cached activations can be loaded from disk to faciliate the fine-tuning of a classification model on the sentiment analysis task.

In [2]:
activations_path = "./resources/"

Let's define an Activation Dataset which will load our activations from disk.

In [3]:
class ActivationDataset(Dataset):
    def __init__(self, activations_path: str) -> None:
        self._load_activations(activations_path)

    def _load_activations(self, path: str) -> None:
        with open(path, "rb") as handle:
            cached_activations = pickle.load(handle)
        self.activations = cached_activations["activations"]
        self.labels = cached_activations["labels"]

    def __len__(self) -> int:
        return len(self.activations)

    def __getitem__(self, idx: int) -> Tuple[List[float], int]:
        return self.activations[idx], self.labels[idx]

We will be performing classification on the last token of the sequence, common practice for autoregressive models (e.g. GPT-3). The following batch_last_token collate function will be passed into the dataloader to stack the last token activation from each sequence.

In [4]:
def batch_last_token(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, int]:
    last_token_activations: List[torch.Tensor] = []
    labels: List[int] = []
    for activations, label in batch:
        last_token_activations.append(activations)
        labels.append(label)

    activation_batch = torch.stack(last_token_activations)

    return activation_batch, labels  # type: ignore

And an MLP to perform the classification.

In [5]:
class MLP(nn.Module):
    def __init__(self, cfg: Dict[str, int]) -> None:
        super().__init__()
        self.linear = nn.Linear(cfg["embedding_dim"], cfg["hidden_dim"], bias=False)
        self.out = nn.Linear(cfg["hidden_dim"], cfg["label_dim"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.linear(x))
        x = self.out(x)
        return x

#### Train and Test Model for Activations without Prompts

In [6]:
train_dataset = ActivationDataset(os.path.join(activations_path, "train_activations_demo.pkl"))
test_dataset = ActivationDataset(os.path.join(activations_path, "test_activations_demo.pkl"))

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=batch_last_token)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, collate_fn=batch_last_token)

We can now write a relatively simple script to train and evaluate our model.

In [8]:
model = MLP({"embedding_dim": 12288, "hidden_dim": 128, "label_dim": 2})
device = "cuda" if cuda.is_available() else "cpu"
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.001)

NUM_EPOCHS = 25
pbar = tqdm(range(NUM_EPOCHS))
for epoch_idx in pbar:
    pbar.set_description("Epoch: %s" % epoch_idx)
    training_params = {"Train-Loss": 0.0, "Test-Accuracy": 0.0}
    pbar.set_postfix(training_params)

    model.train()
    for batch in train_dataloader:
        activations, labels = batch
        activations = activations.to(device)
        labels = torch.tensor(labels).to(device)

        optimizer.zero_grad()

        logits = model(activations)
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()

        training_params["Train-Loss"] = loss.detach().item()
        pbar.set_postfix(training_params)

    model.eval()
    with torch.no_grad():
        predictions = []
        for batch in test_dataloader:
            activations, labels = batch
            activations = activations.float().to(device)
            labels = torch.tensor(labels).to(device)

            logits = model(activations)
            predictions.extend((logits.argmax(dim=1) == labels))

        accuracy = torch.stack(predictions).sum() / len(predictions)

        training_params["Test-Accuracy"] = accuracy.detach().item()
        pbar.set_postfix(training_params)

Epoch: 24: 100%|██████████| 25/25 [00:00<00:00, 31.38it/s, Train-Loss=0.0312, Test-Accuracy=0.763]


#### Train and Test Model for Activations with Prompts

In [9]:
train_dataset = ActivationDataset(os.path.join(activations_path, "train_activations_with_prompts_demo.pkl"))
test_dataset = ActivationDataset(os.path.join(activations_path, "test_activations_with_prompts_demo.pkl"))

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=batch_last_token)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, collate_fn=batch_last_token)

We can now write a relatively simple script to train and evaluate our model.

In [11]:
model = MLP({"embedding_dim": 12288, "hidden_dim": 128, "label_dim": 2})
device = "cuda" if cuda.is_available() else "cpu"
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.001)

NUM_EPOCHS = 25
pbar = tqdm(range(NUM_EPOCHS))
for epoch_idx in pbar:
    pbar.set_description("Epoch: %s" % epoch_idx)
    training_params = {"Train-Loss": 0.0, "Test-Accuracy": 0.0}
    pbar.set_postfix(training_params)

    model.train()
    for batch in train_dataloader:
        activations, labels = batch
        activations = activations.to(device)
        labels = torch.tensor(labels).to(device)

        optimizer.zero_grad()

        logits = model(activations)
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()

        training_params["Train-Loss"] = loss.detach().item()
        pbar.set_postfix(training_params)

    model.eval()
    with torch.no_grad():
        predictions = []
        for batch in test_dataloader:
            activations, labels = batch
            activations = activations.float().to(device)
            labels = torch.tensor(labels).to(device)

            logits = model(activations)
            predictions.extend((logits.argmax(dim=1) == labels))

        accuracy = torch.stack(predictions).sum() / len(predictions)

        training_params["Test-Accuracy"] = accuracy.detach().item()
        pbar.set_postfix(training_params)

Epoch: 24: 100%|██████████| 25/25 [00:00<00:00, 33.77it/s, Train-Loss=0.0813, Test-Accuracy=0.96] 
