In [1]:
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

sns.set(palette='summer')

In [2]:
!pip install accelerate -U

Collecting accelerate
  Downloading accelerate-1.1.1-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.1.1-py3-none-any.whl (333 kB)
   ---------------------------------------- 0.0/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:--:--
   - -------------------------------------- 10.2/333.2 kB ? eta -:-

In [3]:
!pip install datasets



In [4]:
!pip install transformers



In [5]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
   ---------------------------------------- 0.0/84.0 kB ? eta -:--:--
   ---------------------------------------- 0.0/84.0 kB ? eta -:--:--
   ---------------------------------------- 0.0/84.0 kB ? eta -:--:--
   ---------------------------------------- 0.0/84.0 kB ? eta -:--:--
   ---- ----------------------------------- 10.2/84.0 kB ? eta -:--:--
   ---- ----------------------------------- 10.2/84.0 kB ? eta -:--:--
   ---- ----------------------------------- 10.2/84.0 kB ? eta -:--:--
   ---- ----------------------------------- 10.2/84.0 kB ? eta -:--:--
   -------------- ------------------------- 30.7/84.0 kB 109.5 kB/s eta 0:00:01
   -------------- ------------------------- 30.7/84.0 kB 109.5 kB/s eta 0:00:01
   ------------------- -------------------- 41.0/84.0 kB 93.7 kB/s eta 0:00:01
   ------------------- -------------------- 41.0/84.0 kB 93.7

In [6]:
import transformers
from datasets import load_dataset
import evaluate

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

## Подготовка данных

In [8]:
billsum = load_dataset("billsum", split="ca_test")

Downloading readme:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/91.8M [00:00<?, ?B/s]

Error while downloading from https://huggingface.co/datasets/billsum/resolve/3d8510441c06a3d9dfb32eb0d7f80151730bcc4f/data/train-00000-of-00001.parquet: HTTPSConnectionPool(host='cdn-lfs.hf.co', port=443): Read timed out.
Trying to resume download...


ChunkedEncodingError: ('Connection broken: IncompleteRead(23756451 bytes read, 36597381 more expected)', IncompleteRead(23756451 bytes read, 36597381 more expected))

In [None]:
billsum

Dataset({
    features: ['text', 'summary', 'title'],
    num_rows: 1237
})

In [None]:
billsum = billsum.train_test_split(test_size=0.1)

In [None]:
billsum

DatasetDict({
    train: Dataset({
        features: ['text', 'summary', 'title'],
        num_rows: 1113
    })
    test: Dataset({
        features: ['text', 'summary', 'title'],
        num_rows: 124
    })
})

In [None]:
billsum['train']['text']

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained("ainize/bart-base-cnn")

Downloading (…)okenizer_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [None]:
def preprocess_function(examples):
    model_inputs = tokenizer(examples["text"], max_length=1024, truncation=True)

    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
tokenized_billsum = billsum.map(preprocess_function, batched=True)

Map:   0%|          | 0/1113 [00:00<?, ? examples/s]

Map:   0%|          | 0/124 [00:00<?, ? examples/s]

In [None]:
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("ainize/bart-base-cnn")

Downloading pytorch_model.bin:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [None]:
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
training_args = transformers.Seq2SeqTrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=2,
    )

In [None]:
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [None]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
1,No log,2.054266
2,2.174500,2.005796


TrainOutput(global_step=558, training_loss=2.1474227221635935, metrics={'train_runtime': 403.3054, 'train_samples_per_second': 5.519, 'train_steps_per_second': 1.384, 'total_flos': 1357273356042240.0, 'train_loss': 2.1474227221635935, 'epoch': 2.0})

###  Предсказания на тестовом множестве

In [None]:
text_example = billsum["test"]['text'][0]
print(text_example)

The people of the State of California do enact as follows:


SECTION 1.
Section 236.14 is added to the Penal Code, to read:
236.14.
(a) If a person was arrested for or convicted of any nonviolent offense committed while he or she was a victim of human trafficking, including, but not limited to, prostitution as described in subdivision (b) of Section 647, the person may petition the court for vacatur relief of his or her convictions and arrests under th">(e) The court may, with the agreement of the petitioner and all of the involved state or local prosecutorial agencies, consolidate into one hearing a petition with multiple convictions from different jurisdictions.
(f) If the petition is opposed or if the court otherwise deems it necessary, the court shall schedule a hearing on the petition. The hearing may consist of the following:
(1) Testimony by the petitioner, which may be required in support of the petition.
(2) Evidence and supporting documentation in support of the petition.
(3)

In [None]:
input_ids = tokenizer.encode(
    text_example,
    return_tensors="pt",
    max_length=1024,
    truncation=True,
    ).to(device)

In [None]:
input_ids.shape

torch.Size([1, 1024])

In [None]:
summary_text_ids = model.generate(
    input_ids=input_ids,
    bos_token_id=model.config.bos_token_id,
    eos_token_id=model.config.eos_token_id,
    max_length=142,
    min_length=56,
    num_beams=4,
)

In [None]:
summary_text_ids

tensor([[    2,     0,  9089, 15528,   488,  2730,  7396,    10,   621,    54,
            21,  1128,    13,    50,  3828,     9,   143, 36887,  2970,  2021,
           150,    37,    50,    79,    21,    10,  1802,     9,  1050,  7492,
             6,   217,     6,    53,    45,  1804,     7,     6, 23105,     6,
             7,  5265,     5,   461,    13, 18721,  8367,  3500,     9,    39,
            50,    69, 12618,     8,  7102,     4,  3015, 15528,   488,  3441,
             5,   461,     7, 19752,    88,    65,  1576,    10,  5265,    19,
          1533, 12618,    31,   430, 17607,     4, 50118,   713,  1087,    74,
          2703,     5,   461,     6,    19,     5,  1288,     9,     5, 31390,
             8,    70,     9,     5,   963,   194,    50,   400, 42308, 17707,
          2244,     6,     7,     2]], device='cuda:0')

In [None]:
decoded_text = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)

In [None]:
len(decoded_text), len(text_example)

(547, 8857)

In [None]:
decoded_text

'Existing law authorizes a person who was arrested for or convicted of any nonviolent offense committed while he or she was a victim of human trafficking, including, but not limited to, prostitution, to petition the court for vacatur relief of his or her convictions and arrests. Existing law requires the court to consolidate into one hearing a petition with multiple convictions from different jurisdictions.\nThis bill would require the court, with the agreement of the petitioner and all of the involved state or local prosecutorial agencies, to'

In [None]:
summaries = []

for text in tqdm(billsum["test"]['text']):
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        ).to(device)

    summary_text_ids = model.generate(
        input_ids=input_ids,
        bos_token_id=model.config.bos_token_id,
        eos_token_id=model.config.eos_token_id,
        max_length=142,
        min_length=56,
        num_beams=4,
    )

    decoded_text = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
    summaries.append(decoded_text)

  0%|          | 0/124 [00:00<?, ?it/s]

## Считаем качество

### ROUGE

In [None]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=71e95dac241dd35a66fc8195d2717a975f9a1112f7431ca946d0133aafada3cd
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
rouge = evaluate.load('rouge')

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [None]:
%%time
results = rouge.compute(
        predictions=summaries,
        references=billsum["test"]['summary']
    )

CPU times: user 5.03 s, sys: 8.51 ms, total: 5.03 s
Wall time: 5.12 s


In [None]:
results

{'rouge1': 0.3295094036218852,
 'rouge2': 0.1785330127243029,
 'rougeL': 0.22409067060892754,
 'rougeLsum': 0.28382496411321023}