ライブラリインポート

In [1]:
!pip install transformers
!pip install wandb 



In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

In [3]:
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

In [4]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration

In [5]:
import wandb

In [6]:
!nvidia-smi

Fri Sep  3 18:21:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    87W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
from torch import cuda
device = "cuda" if cuda.is_available() else "cpu"

In [8]:
device

'cuda'

In [9]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mnaoya[0m (use `wandb login --relogin` to force relogin)


In [10]:
!pip install sentencepiece



データ準備

In [11]:
df = pd.read_csv("./data/news_summary.csv", encoding = "latin-1")

In [12]:
df.head()

Unnamed: 0,author,date,headlines,read_more,text,ctext
0,Chhavi Tyagi,"03 Aug 2017,Thursday",Daman & Diu revokes mandatory Rakshabandhan in...,http://www.hindustantimes.com/india-news/raksh...,The Administration of Union Territory Daman an...,The Daman and Diu administration on Wednesday ...
1,Daisy Mowke,"03 Aug 2017,Thursday",Malaika slams user who trolled her for 'divorc...,http://www.hindustantimes.com/bollywood/malaik...,Malaika Arora slammed an Instagram user who tr...,"From her special numbers to TV?appearances, Bo..."
2,Arshiya Chopra,"03 Aug 2017,Thursday",'Virgin' now corrected to 'Unmarried' in IGIMS...,http://www.hindustantimes.com/patna/bihar-igim...,The Indira Gandhi Institute of Medical Science...,The Indira Gandhi Institute of Medical Science...
3,Sumedha Sehra,"03 Aug 2017,Thursday",Aaj aapne pakad liya: LeT man Dujana before be...,http://indiatoday.intoday.in/story/abu-dujana-...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Aarushi Maheshwari,"03 Aug 2017,Thursday",Hotel staff to get training to spot signs of s...,http://indiatoday.intoday.in/story/sex-traffic...,Hotels in Maharashtra will train their staff t...,Hotels in Mumbai and other Indian cities are t...


In [13]:
df.ctext.iat[0]

'The Daman and Diu administration on Wednesday withdrew a circular that asked women staff to tie rakhis on male colleagues after the order triggered a backlash from employees and was ripped apart on social media.The union territory?s administration was forced to retreat within 24 hours of issuing the circular that made it compulsory for its staff to celebrate Rakshabandhan at workplace.?It has been decided to celebrate the festival of Rakshabandhan on August 7. In this connection, all offices/ departments shall remain open and celebrate the festival collectively at a suitable time wherein all the lady staff shall tie rakhis to their colleagues,? the order, issued on August 1 by Gurpreet Singh, deputy secretary (personnel), had said.To ensure that no one skipped office, an attendance report was to be sent to the government the next evening.The two notifications ? one mandating the celebration of Rakshabandhan (left) and the other withdrawing the mandate (right) ? were issued by the Dama

In [14]:
df.text.iat[0]

'The Administration of Union Territory Daman and Diu has revoked its order that made it compulsory for women to tie rakhis to their male colleagues on the occasion of Rakshabandhan on August 7. The administration was forced to withdraw the decision within 24 hours of issuing the circular after it received flak from employees and was slammed on social media.'

In [15]:
df = df[["text", "ctext"]]

In [16]:
len(df.index)

4514

In [17]:
df = df.sample(100)

In [18]:
df.head()

Unnamed: 0,text,ctext
1583,"Viacom18, the producers of the John Abraham an...",A dispute has broken out between the productio...
2658,"Amid communal violence in West Bengal, Haryana...",Even as West Bengal continues to simmer in the...
3970,"Uday Kotak, Kotak Mahindra Bank's Managing Dir...","Mumbai, Apr 2 (PTI) Banker Uday Kotak feels th..."
1831,A Kotak Mahindra Bank branch manager arrested ...,The manager of the Kotak Mahindra Bank branch ...
4385,Union Minister Uma Bharti has asked Congress V...,"New Delhi, Mar 3 (PTI) Taking on Rahul Gandhi ..."


In [19]:
train_size = 0.8

train_dataset = df.sample(frac = train_size, random_state = 0)
val_dataset = df.drop(train_dataset.index).reset_index(drop = True)
train_dataset = train_dataset.reset_index(drop = True)

In [20]:
print(len(train_dataset.index))
print(len(val_dataset.index))

80
20


In [21]:
class CustomDataset(Dataset):
  def __init__(self, dataframe, tokenizer, source_len, summ_len):
    self.tokenizer = tokenizer
    self.data = dataframe
    self.source_len = source_len
    self.summ_len = summ_len
    self.text = self.data.text
    self.ctext = self.data.ctext

  def __len__(self):
    return len(self.text)
  
  def __getitem__(self, index):
    ctext = str(self.ctext[index])
    ctext = " ".join(ctext.split())

    text = str(self.text[index])
    text = " ".join(text.split())

    source = self.tokenizer.batch_encode_plus([ctext],
                                              max_length = self.source_len,
                                              pad_to_max_length = True,
                                              return_tensors = "pt")
    target = self.tokenizer.batch_encode_plus([text],
                                              max_length = self.summ_len,
                                              pad_to_max_length = True,
                                              return_tensors = "pt")
    
    source_ids = source["input_ids"].squeeze()
    source_mask = source["attention_mask"].squeeze()
    target_ids = target["input_ids"].squeeze()
    target_mask = target["attention_mask"].squeeze()

    return {
        "source_ids": source_ids,
        "source_mask": source_mask,
        "target_ids": target_ids,
        "target_mask": target_mask
    }

In [22]:
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [23]:
training_set = CustomDataset(train_dataset, tokenizer, 512, 150) 
val_set = CustomDataset(val_dataset, tokenizer, 512, 150)
"""
512 : max_sourcelen
150 : max_summlen
"""

'\n512 : max_sourcelen\n150 : max_summlen\n'

In [24]:
train_params = {
    "batch_size": 2,
    "shuffle": True,
    "num_workers": 0
}
val_params = {
    "batch_size": 2,
    "shuffle": False,
    "num_workers": 0
}

In [25]:
training_loader = DataLoader(training_set, **train_params)
val_loader = DataLoader(val_set, **val_params)

In [26]:
for i in training_loader:
  print(i)
  break

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


{'source_ids': tensor([[ 5264,   116,   736,  ...,     0,     0,     0],
        [15810,     6, 10299,  ...,    19,  1631,     1]]), 'source_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'target_ids': tensor([[   86,     3,     9,  3046,  3634,    28,  7556,   736,    23, 10309,
           107, 16453,  8694,     6, 15039,   162,    49, 16738,  2832,     6,
            96,   196,    31,    51, 11320,    24,    25,   237,   214,    27,
          3223,  4720, 15039,   162,    49,  2832,    48,   227,   736,    23,
         10309,   107, 25932,    81,   376,    59,  8776,    53,    12,   112,
         12780,  3591, 18660,     7,     5, 15039,   162,    49, 27975,   856,
             6,    96,  5420,  6873,    24,    27,  3001, 16359,  5121, 18606,
            55,    27,   131,  2269,  7122,    34,   396,    55,   148,   130,
            16,   685,    80,    13,     8,   166,    12,   237,  1663,  4720,
             1,     0,     0,     0,     0,     0,     0,



モデル構築

In [27]:
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model = model.to(device)

モデル訓練

In [28]:
def train(epoch, tokenizer, model, device, loader, optimizer):
  model.train()
  for i, data in enumerate(loader, 0):
    y = data["target_ids"].to(device, dtype = torch.long)
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone().detach()
    lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100 #ちょっとしたテクニック
    ids = data["source_ids"].to(device, dtype = torch.long)
    mask = data["source_mask"].to(device, dtype = torch.long)

    outputs = model(input_ids = ids, 
                    attention_mask = mask,
                    decoder_input_ids = y_ids,
                    labels = lm_labels)
    loss = outputs[0]

    if i % 10 == 0:
      wandb.log({"Training Loss": loss.item()})
    
    if i % 500 == 0:
      print(f"Epoch:{epoch + 1}, Loss:{loss.item()}")
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [29]:
def validate(epoch, tokenizer, model, device, loader):
  model.eval()
  predictions = []
  actuals = []

  with torch.no_grad():
    for i, data in enumerate(loader, 0):
      y = data["target_ids"].to(device, dtype = torch.long)
      ids = data["source_ids"].to(device, dtype = torch.long)
      mask = data["source_mask"].to(device, dtype = torch.long)

      generated_ids = model.generate(
          input_ids = ids,
          attention_mask = mask,
          max_length = 150, 
          num_beams = 2,
          repetition_penalty = 2.5,
          length_penalty = 1.0,
          early_stopping = True
      )
      preds = [tokenizer.decode(g, 
                                skip_special_tokens = True,
                                clean_up_tokenization_spaces = True)
      for g in generated_ids]
      target = [tokenizer.decode(t,
                                 skip_special_tokens = True,
                                 clean_up_tokenization_spaces = True)
      for t in y]

      predictions.extend(preds)
      actuals.extend(target)

      if i % 100 == 0:
        print(f"completed: {i}")
    return predictions, actuals

In [30]:
def main():
  wandb.init(project = "transformers_tutorials_summarization")

  config = wandb.config
  config.TRAIN_BATCH_SIZE = 2
  config.VALID_BATCH_SIZE = 2
  config.TRAIN_EPOCHS = 2
  config.VAL_EPOCHS = 1
  config.LEARNIG_RATE = 1e-4
  config.SEED = 42 # ハッシュの固定
  config.MAX_LEN = 512
  config.SUMMARY_LEN = 150

  # 再現性を担保
  torch.manual_seed(config.SEED)
  np.random.seed(config.SEED)
  torch.backends.cudnn.deterministic = True

  optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNIG_RATE)
  wandb.watch(model, log = "all")
  for epoch in range(config.TRAIN_EPOCHS):
    train(epoch, tokenizer, model, device, training_loader, optimizer)

  for epoch in range(config.VAL_EPOCHS):
    predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({"Generated Text":predictions, "Actual Text":actuals})
    final_df.to_csv("models/predictions.csv")
    print("Output files generated for review")

In [31]:
main()

[34m[1mwandb[0m: Currently logged in as: [33mnaoya[0m (use `wandb login --relogin` to force relogin)




Epoch:1, Loss:3.1496410369873047
Epoch:2, Loss:1.4310293197631836
completed: 0
Output files generated for review


推論(要約)

In [32]:
final_df = pd.read_csv("./models/predictions.csv")

In [33]:
final_df.head()

Unnamed: 0.1,Unnamed: 0,Generated Text,Actual Text
0,0,administration and the police on Friday lodged...,Farmers on Friday blocked the passage of Haldi...
1,1,": The Beginning hit the screens in 2015, Baahu...",The trailer of SS Rajamouli directorial 'Baahu...
2,2,Minister Ashok Gajapathi Raju wrote to the Civ...,A day after Air India lifted the ban on Shiv S...
3,3,python was bigger than Matthew Bager. It took ...,An Australian man saved the life of a 2.5-metr...
4,4,Rahul Gandhi met the protesting Tamil Nadu far...,Congress Vice-President Rahul Gandhi on Friday...


In [34]:
final_df.at[0, "Generated Text"]

'administration and the police on Friday lodged an FIR against 20 farmers for blocking the passage of Haldia Express train on the Delhi-Howrah line.Farmers have been protesting near Bil Akbarpur village for the last four months demanding increased compensation for their agricultural land earmarked for 135km Eastern Peripheral Expressway. They decided to block railway line after the administration forcibly removed them from the EPE site. Farmers are angry with the administration for filing an FIR and forcibly'

In [35]:
val_dataset.ctext.iat[0]

' The Gautam Budh Nagar district administration and the police on Friday lodged an FIR against 20 farmers for blocking the passage of Haldia Express train on the Delhi-Howrah railway line that passes through Greater Noida.Farmers sat on a  protest at Bodaki village along Delhi-Howrah line, thereby blocking trains.Farmers have been protesting near Bil Akbarpur village for the last four months demanding increased compensation for their agricultural land earmarked for 135km Eastern Peripheral Expressway. They decided to the block railway line after the administration forcibly removed them from the Eastern Peripheral Expressway site. The expressway is meant to decongest Delhi with an aim to reduce air pollution in Delhi-NCR. The Supreme Court, in 2006, had  directed the Centre to build the expressway as soon as possible. The National Highways Authority of India said 17km of the expressway falls in 39 villages of Greater Noida. Farmers are demanding that instead of the Rs 3,640 per square m

1件要約
- Starbucksの説明文(Wikipedia)の記事を入力に、要約を1件単位で行うソースを書いて要約を行う。

In [36]:
english_text = """
Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington. As the world's largest coffeehouse chain, Starbucks is seen to be the major representation of the United States' second wave of coffee culture.[3][4] As of September 2020, the company had 32,660 stores in 83 countries, including 16,637 company operated stores and 16,023 licensed stores.[2] Of these 32,660 stores, 18,354 were in the United States, Canada, and Latin America.[2] Starbucks locations serve hot and cold drinks, whole-bean coffee, micro-ground instant coffee, espresso, caffe latte, full and loose-leaf teas, juices, Frappuccino beverages, pastries, and snacks. Some offerings are seasonal or specific to the locality of the store. Depending on the country, most locations offer free Wi-Fi.

Headquartered in the Starbucks Center, the company was founded in 1971 by Jerry Baldwin, Zev Siegl, and Gordon Bowker at Seattle's Pike Place Market. During the early 1980s, they sold the company to Howard Schultz who – after a business trip to Milan, Italy – decided to make the coffee bean store a coffeeshop serving espresso-based drinks. Schultz's first tenure as chief executive, from 1986 to 2000, led to an aggressive expansion of the franchise, first in Seattle, then across the West Coast of the United States. Despite an initial economic downturn with its expansion into the Midwestern United States and British Columbia, the company experienced revitalized prosperity with its entry into California in the early 1990s through a series of highly publicized coffee wars. Schultz was succeeded by Orin Smith who ran the company for five years, positioning Starbucks as a large player in fair trade coffee and increasing sales to $5 billion. Jim Donald served as chief executive from 2005 to 2008, orchestrating a large-scale earnings expansion. Schultz returned as CEO during the financial crisis of 2007–08 and spent the succeeding decade growing its market share, expanding its offerings, and reorienting itself around corporate social responsibility. Kevin Johnson succeeded Schultz in 2017, and continues to serve as the firm's chief executive.

Many stores sell pre-packaged food items, pastries, hot and cold sandwiches, and drinkware including mugs and tumblers. There are also several select "Starbucks Evenings" locations which offer beer, wine, and appetizers. Starbucks-brand coffee, ice cream, and bottled cold coffee drinks are sold at grocery stores in the United States and other countries. In 2010, the company began its Starbucks Reserve program for single-origin coffees and high-end coffee shops. It planned to open 1,000 Reserve coffee shops by the end of 2017.[5] Starbucks operates six roasteries with tasting rooms and 43 coffee bars as part of the program. The latest roastery location opened on Chicago's Magnificent Mile in November 2019, and is the world's largest Starbucks location. The company has received significant criticism about its business practices, corporate affairs, and role in society. Conversely, its franchise has commanded substantial brand loyalty, market share, and company value.
"""

In [38]:
encoded_text_mask = tokenizer.encode_plus(english_text,
                                          max_length = 256,
                                          pad_to_max_length = True,
                                          return_tensors = "pt")



In [39]:
encoded_text_mask["input_ids"]

tensor([[24474,  6708,    19,    46,   797, 23089,  3741,    13,  1975,  1840,
             7,    11, 14238,  4203, 11222,     3, 27630,    16,  8854,     6,
          2386,     5,   282,     8,   296,    31,     7,  2015,  1975,  1840,
          3741,     6, 24474,    19,   894,    12,    36,     8,   779,  6497,
            13,     8,   907,  1323,    31,   511,  6772,    13,  1975,  1543,
             5,  6306,   519,   908,  6306,   591,   908,   282,    13,  1600,
          6503,     6,     8,   349,   141,  3538,     6, 27720,  3253,    16,
             3,  4591,  1440,     6,   379, 11940,  3891,   940,   349,  7747,
          3253,    11, 11940,   632,  2773,  6681,  3253,     5,  6306,   357,
           908,  1129,   175,  3538,     6, 27720,  3253,     6, 14985,  2469,
           591,   130,    16,     8,   907,  1323,     6,  1894,     6,    11,
          6271,  1371,     5,  6306,   357,   908, 24474,  3248,  1716,  1312,
            11,  2107,  6750,     6,   829,    18,  

In [40]:
encoded_text_mask["attention_mask"]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [41]:
generated_ids = model.generate(
    input_ids = encoded_text_mask["input_ids"].to(device),
    attention_mask = encoded_text_mask["attention_mask"].to(device),
    max_length = 30,
    num_beams = 2,
    repetition_penalty = 2.5,
    length_penalty = 1.0,
    early_stopping = True
)

In [43]:
sum_sturbucks = [tokenizer.decode(g, 
                                  skip_special_tokens = True,
                                  clean_up_tokenization_spaces = True)
for g in generated_ids]
" ".join(sum_sturbucks)

"Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington. Starbucks Corporation is the world's"