In [1]:
%pip install -q tensor_parallel py7zr transformers datasets


### Train flan-t5-xl on text summarization

This notebook will teach you to fine-tune flan-t5-xl model to summarize texts on [samsum](https://huggingface.co/datasets/samsum) dataset.

- [__flan-t5-xl__](https://huggingface.co/google/flan-t5-xl/tree/main) is a large pre-trained transformer with 3 billion parameters, about 10 times larger than BERT-large
- [__tensor_parallel__](https://github.com/BlackSamorez/tensor_parallel) is a library that splits your model between GPUs in 2 lines of code

You can run this notebook on your own hardware or using __[kaggle's free cloud instances with dual T4](https://www.kaggle.com/product-feedback/361104)__ (requires phone verification).

The code is based on [this tutorial](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization), but supports way larger models. It was originally run on **four rusty 1080Ti**, just to prove that it's possible. If you're running on something more serious, you can probably tune batch size.

In [2]:
import torch
import tensor_parallel as tp
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = transformers.T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, offload_state_dict=True)

  from .autonotebook import tqdm as notebook_tqdm


### Apply tensor_parallel

You wrap the model - and it becomes parallel across GPUs.

In [3]:
model = tp.tensor_parallel(model)

Using automatic config: no tensor parallel config provided and no predefined configs can be used
Using ZeRO-3 sharding for 249856 non tensor-parallel parameters


In [4]:
input_ids = tokenizer("A cat sat on a mat", return_tensors="pt").input_ids.to("cuda")
output_ids = tokenizer("A cat sat did not sit on a mat", return_tensors="pt").input_ids.to("cuda")

# forward and backward works as usual
loss = model(input_ids=input_ids, labels=output_ids).loss
loss.backward()  # check nvidia-smi for gpu memory usage :)

### Fine-tuning

We are reusing basic code from [the official tutorial](https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py), except that our model is wrapped with tensor_parallel.

In [5]:
import datasets
data = datasets.load_dataset("samsum")
print("Example:", data['train'][25])

Found cached dataset samsum (/home/jheuristic/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 807.58it/s]

Example: {'id': '13810064', 'dialogue': 'Julius: dude, your assessment of manutd\r\nLawrence: i have nothing to say, im so offended and hopeless of them this season\r\nJulius: me too\r\nLawrence: i dont even know whats wrong with the team\r\nJulius: the quality is there but nothing is happening\r\nLawrence: the players look tired of something\r\nJulius:  with mourinhos conservative football!!\r\nLawrence: its so boring\r\nJulius: so lifeless\r\nLawrence: man!!\r\nJulius: it needs to change, hope the board sees it\r\nLawrence: sooner than later\r\nJulius: yeah\r\nLawrence: yeah', 'summary': "Lawrence doesn't like the play of Manchester United. He and Julius complain about the team and Mourinho's style."}





In [6]:
def preprocess_function(examples, prefix="summarize:"):
    inputs, targets = examples['dialogue'], examples['summary']
    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(inputs, max_length=256, truncation=True)
    labels = tokenizer(text_target=targets, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


train_data = data['train'].filter(lambda row: row['dialogue'] and row['summary']).map(
    preprocess_function, batched=True, remove_columns=['id', 'dialogue', 'summary'])

Loading cached processed dataset at /home/jheuristic/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-b7c74e4eb9dea8c6.arrow
Loading cached processed dataset at /home/jheuristic/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e2ba6db2279730ad.arrow


In [7]:
trainer = transformers.Seq2SeqTrainer(
    model=model, train_dataset=train_data,
    args=transformers.Seq2SeqTrainingArguments(
        do_train=True, remove_unused_columns=False,
        per_device_train_batch_size=4, gradient_accumulation_steps=2,
        optim='adafactor', warmup_steps=250, max_steps=1000, learning_rate=1e-5,
        logging_steps=1, output_dir='outputs'),
    tokenizer=tokenizer,
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, padding=True, max_length=512, pad_to_multiple_of=8)
)

max_steps is given, it will override any value given in num_train_epochs


In [None]:
trainer.train()

***** Running training *****
  Num examples = 14731
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 2
  Total optimization steps = 1000
  Number of trainable parameters = 2849757184
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
1,1.0391
2,1.1094
3,1.1953
4,1.0781
5,1.3242
6,1.1836
7,1.248
8,1.0215
9,1.2266
10,1.2812


#### And that's it!

This tutorial keeps things simple to focus on tensor parallelism. If you want to train a more advanced summarization model, open the [`transformers/examples/pytorch/summarization`](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) and wrap the model with __`tp.tensor_parallel(model)`__.