In [1]:
%pip install -q petals datasets wandb scikit-learn

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.5/88.5 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m38.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.2/9.2 MB[0m [31m99.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m115.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os

import torch
import transformers
import wandb
from datasets import load_dataset
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BloomTokenizerFast, get_scheduler

from petals import DistributedBloomForCausalLM

In [3]:
# Model to prompt-tune. The smaller 7.1B BLOOM checkpoint
# (bigscience/bloom-7b1-petals) is recommended for faster prototyping;
# once the code is ready, switch to the full-scale 176B-parameter BLOOM
# (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).
MODEL_NAME = "bigscience/bloom-7b1-petals"

# Prompt-tuning mode: "ptune" or "deep_ptune".
# "deep_ptune" fine-tunes a separate prefix for each transformer block, so it
# takes more time but yields better results.
# Background: https://arxiv.org/pdf/2110.07602.pdf
TUNING_MODE = "ptune"

# Number of prefix tokens (passed to the model as pre_seq_len).
NUM_PREFIX_TOKENS = 16
# Device the model and every training batch are moved to.
DEVICE = "cuda"

# Optimizer settings.
BATCH_SIZE = 8
LR = 1e-2
WEIGHT_DECAY = 0.0

# How many examples of the shuffled train split are used for training.
NUM_SAMPLES = 1000
# Seed used when shuffling the train split.
SEED = 42
# Value assigned to tokenizer.model_max_length.
MODEL_MAX_LENGTH = 256

Prepare tokenizer and distributed model, connect it to servers.

In [5]:
# Tokenizer: pad on the right-hand side and use MODEL_MAX_LENGTH as the
# maximum sequence length.
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.model_max_length = MODEL_MAX_LENGTH
tokenizer.padding_side = 'right'

# Distributed model configured for prompt tuning (NUM_PREFIX_TOKENS prefix
# tokens, mode TUNING_MODE), then moved to DEVICE.
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, tuning_mode=TUNING_MODE, pre_seq_len=NUM_PREFIX_TOKENS
)
model = model.to(DEVICE)

Downloading (…)okenizer_config.json:   0%|          | 0.00/322 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/96.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/786 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

May 01 13:12:13.940 [[1m[34mINFO[0m] Prompt embeddings and their optimizer statistics will be kept in float32 to increase ptune quality


Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization.

In [6]:
# Load the truecased Personachat dataset. No config name is passed, so the
# library falls back to its default config for this dataset.
dataset = load_dataset("bavard/personachat_truecased")


def chunking(examples):
    """Flatten a batch of dialogues into one text chunk per candidate answer.

    For every (history, candidates) pair in the batch, each candidate yields
    one chunk: the history utterances joined by a separator line of five
    dashes, followed by the separator and the candidate itself.

    Args:
        examples: batch mapping with parallel "history" and "candidates"
            columns (lists of lists of strings).

    Returns:
        A dict with a single "chunks" key holding the list of built strings,
        ordered by example and then by candidate.
    """
    separator = "\n-----\n"
    chunks = []
    for history, candidates in zip(examples["history"], examples["candidates"]):
        chunks.extend(
            separator.join(history) + separator + candidate
            for candidate in candidates
        )
    return {"chunks": chunks}


def tokenize(examples):
    """Tokenize a batch of text chunks for causal language modelling.

    Chunks are padded to the tokenizer's maximum length and truncated when
    longer. The resulting token ids are used both as model inputs and as
    labels.

    Args:
        examples: batch mapping with a "chunks" column of strings.

    Returns:
        A dict with "input_ids" and "labels", both referring to the same
        list of token ids.
    """
    encoded = tokenizer(examples["chunks"], padding='max_length', truncation=True)
    token_ids = encoded["input_ids"]
    # NOTE(review): labels are identical to the inputs, so padded positions
    # are part of the labels too -- confirm this is intended.
    return {"input_ids": token_ids, "labels": token_ids}


# Step 1: build text chunks, dropping the columns of the raw dataset.
chunked_datasets = dataset.map(
    chunking, batched=True, remove_columns=dataset["train"].column_names
)
# Step 2: tokenize the chunks, dropping the intermediate text column.
tokenized_datasets = chunked_datasets.map(
    tokenize, batched=True, remove_columns=["chunks"]
)
tokenized_datasets.set_format("torch")

# Shuffle the train split with a fixed seed and train on its first
# NUM_SAMPLES examples; incomplete final batches are dropped.
train_dataset = tokenized_datasets["train"].shuffle(seed=SEED)
train_subset = train_dataset.select(list(range(NUM_SAMPLES)))
train_dataloader = DataLoader(
    train_subset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)

Downloading builder script:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/5.47k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.00k [00:00<?, ?B/s]

May 01 13:12:52.349 [[1m[38;5;208mWARN[0m] [[1mdatasets.builder._create_builder_config:465[0m] No config specified, defaulting to: personachat_truecased/full


Downloading and preparing dataset personachat_truecased/full (download: 195.70 MiB, generated: 210.99 MiB, post-processed: Unknown size, total: 406.69 MiB) to /root/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/193M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/131438 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7801 [00:00<?, ? examples/s]

Dataset personachat_truecased downloaded and prepared to /root/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/131438 [00:00<?, ? examples/s]

Map:   0%|          | 0/7801 [00:00<?, ? examples/s]

Map:   0%|          | 0/2628760 [00:00<?, ? examples/s]

Map:   0%|          | 0/156020 [00:00<?, ? examples/s]

Before setting up optimizers, check the model parameters that will be trained.

In [7]:
# List every parameter that will receive gradients, with the device it lives on.
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue  # frozen parameter, nothing to report
    print(n, p.requires_grad, p.device)

transformer.prompt_embeddings.weight True cuda:0


The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.

In [8]:
# AdamW over the model's parameters, configured from the settings cell.
optimizer = AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Linear schedule without warmup; the number of training steps equals the
# number of batches in train_dataloader (the scheduler is stepped once per batch).
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader),
)

Let's initialize wandb for logging and start the training loop!

In [9]:
# Hyperparameters stored alongside the W&B run.
run_config = {
    "num_samples": NUM_SAMPLES,
    "batch_size": BATCH_SIZE,
    "learning_rate": LR,
    "weight_decay": WEIGHT_DECAY,
    "num_prefix_tokens": NUM_PREFIX_TOKENS,
    "model_name": MODEL_NAME,
    "seed": SEED,
}
wandb.init(project="bloom-personachat", config=run_config)

# Training mode is set once up front; nothing inside the loop changes it.
model.train()

# One pass over train_dataloader: forward, backward, optimizer and scheduler
# step, then log the batch loss to W&B.
for batch in tqdm(train_dataloader):
    # Move every tensor of the batch to the training device.
    batch = {k: v.to(DEVICE) for k, v in batch.items()}

    # The batch carries "labels", and the loss is read from the model output.
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Log a plain Python float instead of the loss tensor.
    wandb.log({"Train Loss": loss.item()})

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


100%|██████████| 125/125 [2:19:25<00:00, 66.93s/it]


Try to talk with the trained model! Submit an empty input to stop the execution.


__Note__: In this example, we pass the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster "interactive" dialogue mode, so that generating a new reply can reuse the inference caches from the previous one.

In [1]:
# Sampling settings used for generation.
TOP_K = 100
TEMPERATURE = 0.6

with model.inference_session(max_length=512) as sess:
    # Every generate() call samples exactly one new token within this session.
    generation_kwargs = dict(
        temperature=TEMPERATURE,
        do_sample=True,
        top_k=TOP_K,
        max_new_tokens=1,
        session=sess,
    )
    while True:
        user_phrase = input()
        if not user_phrase:
            # An empty line ends the conversation.
            break
        prompt = f"{user_phrase}\n-----\n"
        inputs = tokenizer([prompt], return_tensors='pt')['input_ids'].to(DEVICE)
        while True:
            outputs = model.generate(inputs, **generation_kwargs)
            # Decode and print only the newest token so the reply streams out.
            bloom_answer_token = tokenizer.decode(outputs[0, -1:])
            print(bloom_answer_token, end="", flush=True)
            if bloom_answer_token == "\n":
                # A newline token marks the end of the model's reply.
                break
            # The prompt is only passed on the first step; later steps pass None.
            inputs = None

NameError: ignored