참고 url : https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb

In [None]:
!pip install transformers
!pip install sentencepiece==0.1.91

Collecting transformers
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 15.0 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 76.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 89.0 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 91.5 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[K     |████████████████████████████████| 61 kB 574 kB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers


transformers version = '4.12.5'

In [None]:
# Colab-only: mount Google Drive so the checkpoint and dataset paths under
# /content/drive used by later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# model.generate(pieces)
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

In [None]:
# Checkpoint directory on the mounted Drive; the same folder provides both the
# tokenizer files and the model weights.
model_folder = '/content/drive/MyDrive/3차 프로젝트/eT5_epoch8/pretrained_270000'

# The two loads do not depend on each other.
tokenizer = T5Tokenizer.from_pretrained(model_folder)
model = T5ForConditionalGeneration.from_pretrained(model_folder)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
class CustomDataset:
    """Dataset that tokenizes (document, summary) pairs on demand.

    Args:
        dataframe: pandas DataFrame with a ``ctext`` column (model input) and
            a ``text`` column (target summary).
        tokenizer: tokenizer object exposing ``batch_encode_plus``.
        source_len: token length every encoded input is padded/truncated to.
        summ_len: token length every encoded target is padded/truncated to.
    """

    def __init__(self, dataframe, tokenizer, source_len, summ_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.text
        self.ctext = self.data.ctext

    def __len__(self):
        """Return the number of examples (rows of the ``text`` column)."""
        return len(self.text)

    def _encode(self, raw, max_length):
        """Collapse whitespace in ``raw`` and encode it to exactly ``max_length`` tokens."""
        cleaned = ' '.join(str(raw).split())
        # Explicit padding/truncation replaces the deprecated `pad_to_max_length`
        # flag and the implicit truncation fallback the tokenizer warned about.
        return self.tokenizer.batch_encode_plus(
            [cleaned],
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, index):
        """Return the encoded tensors for the example at position ``index``.

        Returns:
            dict with ``source_ids``, ``source_mask``, ``target_ids`` and
            ``target_ids_y`` (same values as ``target_ids``), all ``torch.long``.
        """
        # Positional lookup: a DataLoader hands out positions 0..len-1, which
        # need not match the DataFrame's index labels.
        source = self._encode(self.ctext.iloc[index], self.source_len)
        target = self._encode(self.text.iloc[index], self.summ_len)

        # Drop only the leading batch dimension added by batch_encode_plus.
        source_ids = source['input_ids'].squeeze(0)
        source_mask = source['attention_mask'].squeeze(0)
        target_ids = target['input_ids'].squeeze(0)

        return {
            'source_ids': source_ids.to(dtype=torch.long),
            'source_mask': source_mask.to(dtype=torch.long),
            'target_ids': target_ids.to(dtype=torch.long),
            'target_ids_y': target_ids.to(dtype=torch.long)
        }

In [None]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one training epoch over ``loader``, updating ``model`` in place.

    Args:
        epoch: epoch number, used only in the progress message.
        tokenizer: provides ``pad_token_id`` for masking padded label positions.
        model: model called with ``input_ids``, ``attention_mask`` and ``labels``;
            the first element of its output is used as the loss.
        device: device the batch tensors are moved to.
        loader: iterable of batches with ``source_ids``, ``source_mask`` and
            ``target_ids`` tensors.
        optimizer: optimizer stepped once per batch.
    """
    model.train()
    # Wrapping the loader itself (not enumerate) lets tqdm show a total.
    for step, data in enumerate(tqdm(loader)):
        ids = data['source_ids'].to(device, dtype=torch.long)
        mask = data['source_mask'].to(device, dtype=torch.long)

        # Clone so masking never writes into the batch tensor.
        labels = data['target_ids'].to(device, dtype=torch.long).clone()
        # -100 marks padded positions so they are ignored by the loss.
        labels[labels == tokenizer.pad_token_id] = -100

        # Pass the full target as `labels` and let the model derive
        # decoder_input_ids by shifting right (prepending the decoder start
        # token). Feeding y[:, :-1] as decoder input and y[:, 1:] as labels
        # skipped the first target token and never showed the decoder the
        # start token that generate() begins from.
        outputs = model(input_ids=ids, attention_mask=mask, labels=labels)
        loss = outputs[0]

        if step % 500 == 0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
def validate(epoch, tokenizer, model, device, loader):
    """Generate summaries for each batch in ``loader`` and decode them.

    Args:
        epoch: accepted for symmetry with ``train``; not used in the body.
        tokenizer: provides ``decode`` for turning token ids into text.
        model: model exposing ``eval`` and ``generate``.
        device: device the batch tensors are moved to.
        loader: iterable of batches with ``source_ids``, ``source_mask`` and
            ``target_ids`` tensors.

    Returns:
        ``(predictions, actuals)``: decoded generated texts and decoded
        reference texts, in loader order.
    """
    def _to_text(rows):
        # Decode every row of token ids, dropping special tokens.
        return [
            tokenizer.decode(row, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for row in rows
        ]

    # Beam-search settings passed to model.generate for every batch.
    generation_kwargs = dict(
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    model.eval()
    predictions, actuals = [], []
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            references = batch['target_ids'].to(device, dtype=torch.long)
            input_ids = batch['source_ids'].to(device, dtype=torch.long)
            attention_mask = batch['source_mask'].to(device, dtype=torch.long)

            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_kwargs,
            )
            batch_preds = _to_text(generated)
            batch_refs = _to_text(references)

            if batch_idx % 100 == 0:
                print(f'Completed {batch_idx}')

            predictions += batch_preds
            actuals += batch_refs
    return predictions, actuals

In [None]:
# Move the model to the selected device. As the cell's last expression, the
# returned module is displayed, which is why the full layer listing is printed.
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(45100, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(45100, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=768, out_features=3072, bias=False)
              (wi_1): Linear(in_features=768, out_features=3072, bias=False)
              (wo)

hyper-parameters

In [None]:
# Hyper-parameters are stored as extra attributes on a T5Config instance, which
# serves here only as a namespace object.
config = T5Config()
config.MAX_LEN = 1024             # token length of each encoded source document
config.SUMMARY_LEN = 150          # token length of each encoded target summary
config.TRAIN_BATCH_SIZE = 2       # batch size of the training DataLoader
config.VALID_BATCH_SIZE = 2       # batch size of the validation DataLoader
config.TRAIN_EPOCHS = 8           # number of passes over the training set
config.VAL_EPOCHS = 1             # not referenced by the cells visible here
config.LEARNING_RATE = 1e-4       # learning rate passed to Adam
config.SEED = 42                  # random seed

# Actually apply the seed (it was previously defined but never used) so that
# PyTorch's random number generator starts from a known state.
# The trailing semicolon keeps the returned Generator from being displayed.
torch.manual_seed(config.SEED);

In [None]:
# Keyword arguments unpacked into the training DataLoader.
train_params = dict(
    batch_size=config.TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Same settings for validation, but with the order kept fixed.
val_params = dict(
    batch_size=config.VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

# Adam over every model parameter at the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

dataset

In [None]:
import pandas as pd

# Parse the CSV once and split it, instead of reading the same file twice.
# Note that both splits come from valid.csv.
_labelled = pd.read_csv('/content/drive/MyDrive/3차 프로젝트/dataset/valid.csv')[['document', 'label']]

# The first 20,000 rows are used for training and the remainder is held out.
# .copy() makes each split an independent frame, so the in-place edits applied
# to train_dataset in later cells act on its own data rather than on a slice.
train_dataset = _labelled.iloc[:20000].copy()
validation_dataset = _labelled.iloc[20000:].copy()

In [None]:
import numpy as np

# Relabel the rows 0..n-1 using the frame's actual length; the hard-coded 20000
# raised a length-mismatch error whenever fewer rows had been loaded.
train_dataset.set_index(np.arange(len(train_dataset)), inplace=True)

train

In [None]:
# Rename the two columns (positionally) to the names CustomDataset reads.
train_dataset.columns = ['ctext','text']
# Prepend the task prefix to every input document.
# NOTE(review): this mutates train_dataset in place, so re-running the cell
# without reloading the data prepends the prefix a second time.
train_dataset.ctext = 'summarize: ' + train_dataset.ctext

# Wrap the frame in the tokenizing dataset and batch it with the training settings.
training_set = CustomDataset(train_dataset, tokenizer, config.MAX_LEN, config.SUMMARY_LEN)
training_loader = DataLoader(training_set, **train_params)


# One call to train() per epoch; the bare epoch number is printed before each pass.
for epoch in range(config.TRAIN_EPOCHS):
    print (epoch)
    train(epoch, tokenizer, model, device, training_loader, optimizer)

0


0it [00:00, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Epoch: 0, Loss:  1.0179436206817627


501it [02:41,  3.12it/s]

Epoch: 0, Loss:  2.4281959533691406


1001it [05:21,  3.13it/s]

Epoch: 0, Loss:  1.7643043994903564


1501it [08:01,  3.13it/s]

Epoch: 0, Loss:  1.28212308883667


2001it [10:41,  3.08it/s]

Epoch: 0, Loss:  1.8295304775238037


2501it [13:22,  3.12it/s]

Epoch: 0, Loss:  1.7735939025878906


3001it [16:02,  3.07it/s]

Epoch: 0, Loss:  2.3337581157684326


3501it [18:42,  3.14it/s]

Epoch: 0, Loss:  2.0601871013641357


4001it [21:22,  3.09it/s]

Epoch: 0, Loss:  2.016383647918701


4501it [24:04,  3.12it/s]

Epoch: 0, Loss:  2.4833502769470215


5001it [26:44,  3.12it/s]

Epoch: 0, Loss:  1.5827765464782715


5501it [29:24,  3.12it/s]

Epoch: 0, Loss:  1.836774230003357


6001it [32:04,  3.13it/s]

Epoch: 0, Loss:  1.1149574518203735


6501it [34:44,  3.13it/s]

Epoch: 0, Loss:  2.3753550052642822


7001it [37:25,  3.12it/s]

Epoch: 0, Loss:  2.0718300342559814


7501it [40:05,  3.12it/s]

Epoch: 0, Loss:  1.9655580520629883


8001it [42:45,  3.11it/s]

Epoch: 0, Loss:  1.5127049684524536


8501it [45:25,  3.11it/s]

Epoch: 0, Loss:  2.350362539291382


9001it [48:05,  3.11it/s]

Epoch: 0, Loss:  1.0204596519470215


9501it [50:45,  3.11it/s]

Epoch: 0, Loss:  2.4175682067871094


10000it [53:26,  3.12it/s]


1


0it [00:00, ?it/s]

Epoch: 1, Loss:  1.5985511541366577


501it [02:40,  3.10it/s]

Epoch: 1, Loss:  1.469878911972046


1001it [05:21,  3.13it/s]

Epoch: 1, Loss:  1.443392038345337


1501it [08:01,  3.09it/s]

Epoch: 1, Loss:  1.625374436378479


2001it [10:41,  3.13it/s]

Epoch: 1, Loss:  1.3250776529312134


2501it [13:22,  3.10it/s]

Epoch: 1, Loss:  1.1945077180862427


3001it [16:02,  3.09it/s]

Epoch: 1, Loss:  0.8842711448669434


3501it [18:42,  3.12it/s]

Epoch: 1, Loss:  0.9388953447341919


4001it [21:23,  3.12it/s]

Epoch: 1, Loss:  1.4701943397521973


4501it [24:03,  3.09it/s]

Epoch: 1, Loss:  0.9114287495613098


5001it [26:44,  3.10it/s]

Epoch: 1, Loss:  1.9835315942764282


5501it [29:25,  3.10it/s]

Epoch: 1, Loss:  0.5731541514396667


6001it [32:07,  3.10it/s]

Epoch: 1, Loss:  1.4761725664138794


6501it [34:48,  3.10it/s]

Epoch: 1, Loss:  1.2935709953308105


7001it [37:29,  3.11it/s]

Epoch: 1, Loss:  1.3554624319076538


7501it [40:11,  3.09it/s]

Epoch: 1, Loss:  2.2100460529327393


8001it [42:52,  3.12it/s]

Epoch: 1, Loss:  1.7183125019073486


8501it [45:33,  3.07it/s]

Epoch: 1, Loss:  1.7860759496688843


9001it [48:15,  3.09it/s]

Epoch: 1, Loss:  1.3479446172714233


9501it [50:56,  3.11it/s]

Epoch: 1, Loss:  1.2716609239578247


10000it [53:37,  3.11it/s]


2


0it [00:00, ?it/s]

Epoch: 2, Loss:  0.7035086154937744


501it [02:41,  3.09it/s]

Epoch: 2, Loss:  0.6442661285400391


1001it [05:22,  3.10it/s]

Epoch: 2, Loss:  0.6438519954681396


1501it [08:03,  3.10it/s]

Epoch: 2, Loss:  0.5682475566864014


2001it [10:45,  3.11it/s]

Epoch: 2, Loss:  0.4925718307495117


2501it [13:26,  3.10it/s]

Epoch: 2, Loss:  1.1779292821884155


3001it [16:07,  3.11it/s]

Epoch: 2, Loss:  1.3146321773529053


3501it [18:49,  3.11it/s]

Epoch: 2, Loss:  0.5677058696746826


4001it [21:30,  3.07it/s]

Epoch: 2, Loss:  0.45388028025627136


4501it [24:11,  3.12it/s]

Epoch: 2, Loss:  0.5948448181152344


5001it [26:52,  3.12it/s]

Epoch: 2, Loss:  1.0735149383544922


5501it [29:33,  3.09it/s]

Epoch: 2, Loss:  1.0695632696151733


6001it [32:14,  3.08it/s]

Epoch: 2, Loss:  0.9111840724945068


6501it [34:56,  3.11it/s]

Epoch: 2, Loss:  0.6029452681541443


7001it [37:37,  3.09it/s]

Epoch: 2, Loss:  0.6462777256965637


7501it [40:19,  3.06it/s]

Epoch: 2, Loss:  0.9772745370864868


8001it [43:00,  3.06it/s]

Epoch: 2, Loss:  1.473323106765747


8501it [45:42,  3.11it/s]

Epoch: 2, Loss:  0.5455979108810425


9001it [48:23,  3.09it/s]

Epoch: 2, Loss:  0.3375103175640106


9501it [51:05,  3.11it/s]

Epoch: 2, Loss:  1.0440514087677002


10000it [53:47,  3.10it/s]


3


0it [00:00, ?it/s]

Epoch: 3, Loss:  0.3979508578777313


501it [02:42,  3.08it/s]

Epoch: 3, Loss:  0.2917020618915558


1001it [05:24,  3.07it/s]

Epoch: 3, Loss:  0.5654667019844055


1500it [08:05,  3.05it/s]

Epoch: 3, Loss:  0.501425564289093


2001it [10:47,  3.10it/s]

Epoch: 3, Loss:  0.8856915235519409


2501it [13:29,  3.11it/s]

Epoch: 3, Loss:  0.642139732837677


3001it [16:10,  3.11it/s]

Epoch: 3, Loss:  0.35548996925354004


3501it [18:51,  3.12it/s]

Epoch: 3, Loss:  0.42682909965515137


4001it [21:32,  3.09it/s]

Epoch: 3, Loss:  0.5767117738723755


4501it [24:14,  3.10it/s]

Epoch: 3, Loss:  0.48366546630859375


5001it [26:55,  3.12it/s]

Epoch: 3, Loss:  0.7322551608085632


5501it [29:36,  3.08it/s]

Epoch: 3, Loss:  0.4278084933757782


6001it [32:17,  3.11it/s]

Epoch: 3, Loss:  0.42172539234161377


6501it [34:59,  3.09it/s]

Epoch: 3, Loss:  0.4063580334186554


7001it [37:40,  3.09it/s]

Epoch: 3, Loss:  0.31396016478538513


7501it [40:21,  3.10it/s]

Epoch: 3, Loss:  0.43488720059394836


8001it [43:02,  3.09it/s]

Epoch: 3, Loss:  0.7637120485305786


8501it [45:44,  3.10it/s]

Epoch: 3, Loss:  0.8417660593986511


9001it [48:25,  3.12it/s]

Epoch: 3, Loss:  0.461762934923172


9501it [51:06,  3.11it/s]

Epoch: 3, Loss:  0.23026441037654877


10000it [53:47,  3.10it/s]


4


0it [00:00, ?it/s]

Epoch: 4, Loss:  0.23612849414348602


501it [02:42,  3.10it/s]

Epoch: 4, Loss:  0.3897140622138977


1001it [05:23,  3.08it/s]

Epoch: 4, Loss:  0.5160859823226929


1501it [08:05,  3.10it/s]

Epoch: 4, Loss:  0.26389798521995544


2001it [10:46,  3.09it/s]

Epoch: 4, Loss:  0.30514833331108093


2501it [13:29,  3.08it/s]

Epoch: 4, Loss:  0.3611358404159546


3001it [16:10,  3.10it/s]

Epoch: 4, Loss:  0.49235251545906067


3501it [18:52,  3.07it/s]

Epoch: 4, Loss:  0.28892776370048523


4001it [21:34,  3.10it/s]

Epoch: 4, Loss:  0.1615063101053238


4501it [24:16,  3.10it/s]

Epoch: 4, Loss:  0.4057304859161377


5001it [26:57,  3.08it/s]

Epoch: 4, Loss:  0.10943176597356796


5501it [29:39,  3.09it/s]

Epoch: 4, Loss:  0.39039963483810425


6001it [32:21,  3.10it/s]

Epoch: 4, Loss:  0.4213894009590149


6501it [35:03,  3.07it/s]

Epoch: 4, Loss:  0.36032405495643616


7001it [37:44,  3.09it/s]

Epoch: 4, Loss:  0.38144415616989136


7501it [40:26,  3.08it/s]

Epoch: 4, Loss:  0.3581073582172394


8001it [43:08,  3.08it/s]

Epoch: 4, Loss:  0.3829452097415924


8501it [45:50,  3.11it/s]

Epoch: 4, Loss:  0.3490884006023407


9001it [48:32,  3.10it/s]

Epoch: 4, Loss:  0.6395193934440613


9501it [51:13,  3.07it/s]

Epoch: 4, Loss:  0.2794356346130371


10000it [53:55,  3.09it/s]


5


0it [00:00, ?it/s]

Epoch: 5, Loss:  0.15445606410503387


501it [02:42,  3.10it/s]

Epoch: 5, Loss:  0.2655979096889496


1001it [05:23,  3.08it/s]

Epoch: 5, Loss:  0.12379536777734756


1501it [08:04,  3.11it/s]

Epoch: 5, Loss:  0.23637570440769196


2001it [10:45,  3.12it/s]

Epoch: 5, Loss:  0.22151370346546173


2501it [13:27,  3.10it/s]

Epoch: 5, Loss:  0.1730654239654541


3001it [16:08,  3.10it/s]

Epoch: 5, Loss:  0.21844661235809326


3501it [18:50,  3.10it/s]

Epoch: 5, Loss:  0.1971435695886612


4001it [21:31,  3.07it/s]

Epoch: 5, Loss:  0.27657318115234375


4501it [24:13,  3.09it/s]

Epoch: 5, Loss:  0.13694696128368378


5001it [26:55,  3.06it/s]

Epoch: 5, Loss:  0.21067358553409576


5501it [29:36,  3.10it/s]

Epoch: 5, Loss:  0.22294440865516663


6001it [32:18,  3.11it/s]

Epoch: 5, Loss:  0.2679937481880188


6501it [34:59,  3.10it/s]

Epoch: 5, Loss:  0.24879774451255798


7001it [37:41,  3.10it/s]

Epoch: 5, Loss:  0.3455445468425751


7501it [40:22,  3.08it/s]

Epoch: 5, Loss:  0.20682772994041443


8001it [43:04,  3.10it/s]

Epoch: 5, Loss:  0.43496009707450867


8501it [45:46,  3.10it/s]

Epoch: 5, Loss:  0.2186824232339859


9001it [48:27,  3.10it/s]

Epoch: 5, Loss:  0.15243880450725555


9500it [51:09,  3.10it/s]

Epoch: 5, Loss:  0.18610069155693054


10000it [53:50,  3.10it/s]


6


0it [00:00, ?it/s]

Epoch: 6, Loss:  0.055953580886125565


501it [02:42,  3.07it/s]

Epoch: 6, Loss:  0.2741834223270416


1001it [05:23,  3.11it/s]

Epoch: 6, Loss:  0.09903175383806229


1501it [08:05,  3.09it/s]

Epoch: 6, Loss:  0.15076255798339844


2001it [10:47,  3.10it/s]

Epoch: 6, Loss:  0.19665281474590302


2501it [13:29,  3.11it/s]

Epoch: 6, Loss:  0.1375505030155182


3001it [16:10,  3.10it/s]

Epoch: 6, Loss:  0.10771556943655014


3501it [18:52,  3.11it/s]

Epoch: 6, Loss:  0.21325621008872986


4001it [21:33,  3.08it/s]

Epoch: 6, Loss:  0.2798163890838623


4501it [24:15,  3.11it/s]

Epoch: 6, Loss:  0.05893850326538086


5001it [26:56,  3.10it/s]

Epoch: 6, Loss:  0.1931663155555725


5501it [29:37,  3.10it/s]

Epoch: 6, Loss:  0.1717403084039688


6001it [32:19,  3.07it/s]

Epoch: 6, Loss:  0.09010182321071625


6501it [35:00,  3.09it/s]

Epoch: 6, Loss:  0.129190593957901


7001it [37:41,  3.08it/s]

Epoch: 6, Loss:  0.20594145357608795


7501it [40:23,  3.08it/s]

Epoch: 6, Loss:  0.17842760682106018


8001it [43:05,  3.05it/s]

Epoch: 6, Loss:  0.20705877244472504


8501it [45:46,  3.08it/s]

Epoch: 6, Loss:  0.1342467963695526


9001it [48:28,  3.11it/s]

Epoch: 6, Loss:  0.2095559686422348


9501it [51:09,  3.11it/s]

Epoch: 6, Loss:  0.20448696613311768


10000it [53:50,  3.10it/s]


7


0it [00:00, ?it/s]

Epoch: 7, Loss:  0.15694980323314667


501it [02:41,  3.10it/s]

Epoch: 7, Loss:  0.17117370665073395


1001it [05:23,  3.11it/s]

Epoch: 7, Loss:  0.13855169713497162


1501it [08:04,  3.09it/s]

Epoch: 7, Loss:  0.12910106778144836


2001it [10:46,  3.09it/s]

Epoch: 7, Loss:  0.09587527811527252


2501it [13:27,  3.08it/s]

Epoch: 7, Loss:  0.1817949414253235


3001it [16:09,  3.08it/s]

Epoch: 7, Loss:  0.15388034284114838


3501it [18:50,  3.10it/s]

Epoch: 7, Loss:  0.18405736982822418


4001it [21:32,  3.09it/s]

Epoch: 7, Loss:  0.32748502492904663


4501it [24:14,  3.10it/s]

Epoch: 7, Loss:  0.36244404315948486


5001it [26:55,  3.09it/s]

Epoch: 7, Loss:  0.1167621910572052


5501it [29:37,  3.10it/s]

Epoch: 7, Loss:  0.11731458455324173


6001it [32:19,  3.09it/s]

Epoch: 7, Loss:  0.13131320476531982


6501it [35:01,  3.10it/s]

Epoch: 7, Loss:  0.275181382894516


7001it [37:43,  3.08it/s]

Epoch: 7, Loss:  0.22769537568092346


7501it [40:24,  3.09it/s]

Epoch: 7, Loss:  0.19854170083999634


8001it [43:06,  3.06it/s]

Epoch: 7, Loss:  0.12499503046274185


8501it [45:47,  3.10it/s]

Epoch: 7, Loss:  0.07190177589654922


9001it [48:29,  3.06it/s]

Epoch: 7, Loss:  0.1369173526763916


9501it [51:11,  3.09it/s]

Epoch: 7, Loss:  0.47515732049942017


10000it [53:53,  3.09it/s]


In [None]:
# Write the tokenizer files and the model weights into the same Drive folder.
save_dir = '/content/drive/MyDrive/3차 프로젝트/eT5_epoch8/pretrained_290000(8)/'
tokenizer.save_pretrained(save_dir)
model.save_pretrained(save_dir)