In [1]:
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from sklearn.model_selection import train_test_split
import textwrap

from torch.utils.data import Dataset,DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from transformers import(
    AdamW,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

In [2]:
from datasets import Dataset

In [3]:
def get_device_and_set_seed(seed):
    """ Set all seeds to make results reproducible """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:1" if use_cuda else "cpu")
    return device
    
SEED = 123
device = get_device_and_set_seed(SEED)

In [4]:
device

device(type='cuda', index=1)

In [5]:
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base")
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(36096, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(36096, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [6]:
next(model.parameters()).is_cuda

True

In [7]:
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")  

In [8]:
tokenizer.max_model_input_sizes

{'t5-small': 512, 't5-base': 512, 't5-large': 512, 't5-3b': 512, 't5-11b': 512}

In [9]:
labels = tokenizer(
        't√¥i th√≠ch b·∫°n', max_length=256, truncation=True, padding=True
    )

In [10]:
labels

{'input_ids': [671, 1470, 1113, 1], 'attention_mask': [1, 1, 1, 1]}

## Prepare Data

In [11]:
train_path = '../processed_data/train_df.csv'

In [12]:
test_path = '../processed_data/test_df.csv'

In [13]:
test_df = pd.read_csv(test_path)
test_df.head()

Unnamed: 0,example_id,statement,law_id,article_id
0,qr6S2jA9GG,"N·∫øu kh√¥ng ph·∫°m t·ªôi qu·∫£ tang, m·ªôt ng∆∞·ªùi s·∫Ω kh√¥n...",Hi·∫øn ph√°p 2013,20
1,zFWTqve74q,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ...","Lu·∫≠t Ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh 2022",7
2,bv98C2F6I9,Ho·∫°t ƒë·ªông thanh tra ph·∫£i ƒë∆∞·ª£c th·ª±c hi·ªán theo k...,Lu·∫≠t Thanh tra 2022,46
3,Kzmq9qdScq,Lo·∫°i h·ª£p ƒë·ªìng n√†o sau ƒë√¢y kh√¥ng ƒë∆∞·ª£c BLDS 2015...,B·ªô Lu·∫≠t D√¢n s·ª± 2015,402
4,8x8IdMONw9,B·∫£n √°n ph√∫c th·∫©m ƒë∆∞·ª£c tuy√™n v√†o ng√†y 20/01/202...,Lu·∫≠t T·ªë t·ª•ng h√†nh ch√≠nh 2015,242


In [14]:
train_df = pd.read_csv(train_path)
train_df.head()

Unnamed: 0,example_id,label,statement,law_id,article_id,segment
0,q9zjh7Uw7Q,No,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi du...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi_du...
1,ckQFn8y202,No,"Trong v√≤ng 03 ng√†y l√†m vi·ªác, k·ªÉ t·ª´ ng√†y ng∆∞·ªùi ...","Lu·∫≠t Ph√≤ng, ch·ªëng ma t√∫y 2021",30,"Trong v√≤ng 03 ng√†y l√†m_vi·ªác , k·ªÉ t·ª´ ng√†y ng∆∞·ªùi..."
2,3ROu621ZEO,Yes,Vi√™n ch·ª©c c√≥ 02 nƒÉm li√™n ti·∫øp b·ªã ph√¢n lo·∫°i ƒë√°n...,Lu·∫≠t Vi√™n ch·ª©c 2010,29,Vi√™n_ch·ª©c c√≥ 02 nƒÉm li√™n_ti·∫øp b·ªã ph√¢n_lo·∫°i ƒë√°n...
3,VT1QuVmhCc,Yes,C√°c bi·ªán ph√°p cai nghi·ªán ma t√∫y l√† nh·ªØng bi·ªán ...,"Lu·∫≠t Ph√≤ng, ch·ªëng ma t√∫y 2021",28,C√°c bi·ªán_ph√°p cai_nghi·ªán ma_tu√Ω l√† nh·ªØng bi·ªán_...
4,0MwITLtbmg,No,Vi√™n ch·ª©c thu·ªôc m·ªôt ƒë∆°n v·ªã s·ª± nghi·ªáp c√¥ng l·∫≠p ...,Lu·∫≠t Vi√™n ch·ª©c 2010,14,Vi√™n_ch·ª©c thu·ªôc m·ªôt ƒë∆°n_v·ªã s·ª±_nghi·ªáp c√¥ng_l·∫≠p ...


In [15]:
legal_df = pd.read_csv('../processed_data/legal_df.csv')

In [16]:
merge_df = pd.merge(train_df, legal_df, on=['law_id', 'article_id'], suffixes=('_df1', '_df2'), how='outer')
merge_df

Unnamed: 0,example_id,label,statement,law_id,article_id,segment,text
0,q9zjh7Uw7Q,No,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi du...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi_du...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
1,bu3nyKAUNT,No,Ng∆∞·ªùi xem t·ª´ 18 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√°p xem phi...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi xem t·ª´ 18 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√°p xem phi...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
2,aS4Oxqxklj,Yes,Ng∆∞·ªùi t·ª´ 16 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√©p xem phim c√≥...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi t·ª´ 16 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√©p xem phim c√≥...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
3,E3nezW9AK0,No,Phim lo·∫°i C l√† phim ƒë∆∞·ª£c ph√©p ph·ªï bi·∫øn ƒë·∫øn ng∆∞...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Phim lo·∫°i C l√† phim ƒë∆∞·ª£c ph√©p ph·ªï_bi·∫øn ƒë·∫øn ng∆∞...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
4,ckQFn8y202,No,"Trong v√≤ng 03 ng√†y l√†m vi·ªác, k·ªÉ t·ª´ ng√†y ng∆∞·ªùi ...","Lu·∫≠t Ph√≤ng, ch·ªëng ma t√∫y 2021",30,"Trong v√≤ng 03 ng√†y l√†m_vi·ªác , k·ªÉ t·ª´ ng√†y ng∆∞·ªùi...","Cai nghi·ªán ma t√∫y t·ª± nguy·ªán t·∫°i gia ƒë√¨nh, c·ªông..."
...,...,...,...,...,...,...,...
2265,,,,Lu·∫≠t Thanh ni√™n 2020,36,,N·ªôi dung qu·∫£n l√Ω nh√† n∆∞·ªõc v·ªÅ thanh ni√™n\n\n1. ...
2266,,,,Lu·∫≠t Thanh ni√™n 2020,38,,Tr√°ch nhi·ªám c·ªßa B·ªô N·ªôi v·ª•\n\nB·ªô N·ªôi v·ª• ch·ªãu tr...
2267,,,,Lu·∫≠t Thanh ni√™n 2020,39,,"Tr√°ch nhi·ªám c·ªßa c√°c B·ªô, c∆° quan ngang B·ªô\n\nC√°..."
2268,,,,Lu·∫≠t Thanh ni√™n 2020,40,,"Tr√°ch nhi·ªám c·ªßa H·ªôi ƒë·ªìng nh√¢n d√¢n, ·ª¶y ban nh√¢n..."


In [17]:
merge_df_test = pd.merge(test_df, legal_df, on=['law_id', 'article_id'], suffixes=('_df1', '_df2'), how='outer')
merge_df_test
merge_df_test = merge_df_test[~merge_df_test['statement'].isnull()]
merge_df_test

Unnamed: 0,example_id,statement,law_id,article_id,text
0,qr6S2jA9GG,"N·∫øu kh√¥ng ph·∫°m t·ªôi qu·∫£ tang, m·ªôt ng∆∞·ªùi s·∫Ω kh√¥n...",Hi·∫øn ph√°p 2013,20,1. M·ªçi ng∆∞·ªùi c√≥ quy·ªÅn b·∫•t kh·∫£ x√¢m ph·∫°m v·ªÅ th√¢n...
1,zFWTqve74q,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ...","Lu·∫≠t Ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh 2022",7,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ..."
2,RwHiqOG8vb,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ...","Lu·∫≠t Ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh 2022",7,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ..."
3,wv4eBTDVKS,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ...","Lu·∫≠t Ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh 2022",7,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ..."
4,DyWakZCVVK,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ...","Lu·∫≠t Ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh 2022",7,"Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c ..."
...,...,...,...,...,...
135,IdIDVZpLW4,Kho·∫£ng th·ªùi gian x·∫£y ra tr·ªü ng·∫°i kh√°ch quan l√†...,B·ªô Lu·∫≠t D√¢n s·ª± 2015,156,Th·ªùi gian kh√¥ng t√≠nh v√†o th·ªùi hi·ªáu kh·ªüi ki·ªán v...
136,TiFpkKZtFM,Ph√°n quy·∫øt tr·ªçng t√†i kh√¥ng b·ªã xem x√©t l·∫°i theo...,Lu·∫≠t Tr·ªçng t√†i th∆∞∆°ng m·∫°i 2010,4,Nguy√™n t·∫Øc gi·∫£i quy·∫øt tranh ch·∫•p b·∫±ng Tr·ªçng t√†...
137,0XKeP3929n,Kinh ph√≠ ho·∫°t ƒë·ªông c·ªßa c∆° quan thanh tra do ng...,Lu·∫≠t Thanh tra 2022,112,Kinh ph√≠ ho·∫°t ƒë·ªông c·ªßa c∆° quan thanh tra; ch·∫ø ...
138,zw1vs0Odbf,"Theo Lu·∫≠t Gi√°o d·ª•c nƒÉm 2019, m·ªôt trong nh·ªØng y...",Lu·∫≠t Gi√°o d·ª•c 2019,7,"Y√™u c·∫ßu v·ªÅ n·ªôi dung, ph∆∞∆°ng ph√°p gi√°o d·ª•c\n\n1..."


In [18]:
merge_df = merge_df[~merge_df['statement'].isnull()]
merge_df

Unnamed: 0,example_id,label,statement,law_id,article_id,segment,text
0,q9zjh7Uw7Q,No,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi du...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi xem d∆∞·ªõi 16 tu·ªïi ƒë∆∞·ª£c xem phim c√≥ n·ªôi_du...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
1,bu3nyKAUNT,No,Ng∆∞·ªùi xem t·ª´ 18 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√°p xem phi...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi xem t·ª´ 18 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√°p xem phi...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
2,aS4Oxqxklj,Yes,Ng∆∞·ªùi t·ª´ 16 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√©p xem phim c√≥...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Ng∆∞·ªùi t·ª´ 16 tu·ªïi tr·ªü l√™n ƒë∆∞·ª£c ph√©p xem phim c√≥...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
3,E3nezW9AK0,No,Phim lo·∫°i C l√† phim ƒë∆∞·ª£c ph√©p ph·ªï bi·∫øn ƒë·∫øn ng∆∞...,Lu·∫≠t ƒêi·ªán ·∫£nh 2022,32,Phim lo·∫°i C l√† phim ƒë∆∞·ª£c ph√©p ph·ªï_bi·∫øn ƒë·∫øn ng∆∞...,Ph√¢n lo·∫°i phim\n\n1. Phim ƒë∆∞·ª£c ph√¢n lo·∫°i theo ...
4,ckQFn8y202,No,"Trong v√≤ng 03 ng√†y l√†m vi·ªác, k·ªÉ t·ª´ ng√†y ng∆∞·ªùi ...","Lu·∫≠t Ph√≤ng, ch·ªëng ma t√∫y 2021",30,"Trong v√≤ng 03 ng√†y l√†m_vi·ªác , k·ªÉ t·ª´ ng√†y ng∆∞·ªùi...","Cai nghi·ªán ma t√∫y t·ª± nguy·ªán t·∫°i gia ƒë√¨nh, c·ªông..."
...,...,...,...,...,...,...,...
71,InehZtAZn9,Yes,Vi√™n ch·ª©c b·ªã k·ª∑ lu·∫≠t t·ª´ khi·ªÉn tr√°ch ƒë·∫øn c√°ch c...,Lu·∫≠t Vi√™n ch·ª©c 2010,56,Vi√™n_ch·ª©c b·ªã k·ª∑_lu·∫≠t t·ª´ khi·ªÉn_tr√°ch ƒë·∫øn c√°ch_c...,C√°c quy ƒë·ªãnh kh√°c li√™n quan ƒë·∫øn vi·ªác k·ª∑ lu·∫≠t v...
72,3lEnngVd8Z,Yes,Vi√™n ch·ª©c b·ªã khi·ªÉn tr√°ch th√¨ th·ªùi h·∫°n n√¢ng l∆∞∆°...,Lu·∫≠t Vi√™n ch·ª©c 2010,9,Vi√™n_ch·ª©c b·ªã khi·ªÉn_tr√°ch th√¨ th·ªùi_h·∫°n n√¢ng l∆∞∆°...,ƒê∆°n v·ªã s·ª± nghi·ªáp c√¥ng l·∫≠p v√† c∆° c·∫•u t·ªï ch·ª©c qu...
73,6nfjyV5thx,No,"C∆° quan nh√† n∆∞·ªõc kh√¥ng c√≥ th·∫©m quy·ªÅn theo d√µi,...","Lu·∫≠t Ph√≤ng, ch·ªëng ma t√∫y 2021",14,C∆°_quan nh√†_n∆∞·ªõc kh√¥ng c√≥ th·∫©m_quy·ªÅn theo_d√µi ...,"Ki·ªÉm so√°t ho·∫°t ƒë·ªông v·∫≠n chuy·ªÉn ch·∫•t ma t√∫y, ti..."
74,TisBomZhjP,Yes,C√°ch ch·ª©c l√† m·ªôt trong c√°c h√¨nh th·ª©c x·ª≠ l√Ω k·ª∑ ...,Lu·∫≠t Vi√™n ch·ª©c 2010,52,C√°ch_ch·ª©c l√† m·ªôt trong c√°c h√¨nh_th·ª©c x·ª≠_l√Ω k·ª∑_...,C√°c h√¨nh th·ª©c k·ª∑ lu·∫≠t ƒë·ªëi v·ªõi vi√™n ch·ª©c \n\n1....


In [19]:
merge_df['inputs'] = 'hypothesis: ' + merge_df['statement'] + ' premise: '+ merge_df['text']

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  merge_df['inputs'] = 'hypothesis: ' + merge_df['statement'] + ' premise: '+ merge_df['text']


In [20]:
merge_df_test['inputs'] = 'hypothesis: ' + merge_df_test['statement'] + 'premise: '+ merge_df_test['text']

In [21]:
def preprocess_function(examples):
    model_inputs = tokenizer(
        examples["inputs"], max_length=1024, truncation=True, padding=True
    )
    
    
    labels = tokenizer(
        examples["labels"], max_length=256, truncation=True, padding=True
    )
    model_inputs['labels'] = labels['input_ids']
    model_inputs['input_ids'] = model_inputs['input_ids']
    return model_inputs

In [22]:
dict_obj = {'inputs': merge_df['inputs'], 'labels': merge_df['label']}
dataset = Dataset.from_dict(dict_obj)
train_data = dataset.map(preprocess_function, batched=True, remove_columns=['inputs'], num_proc=8)

                

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

#4:   0%|          | 0/1 [00:00<?, ?ba/s]

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

#7:   0%|          | 0/1 [00:00<?, ?ba/s]

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

#5:   0%|          | 0/1 [00:00<?, ?ba/s]

#6:   0%|          | 0/1 [00:00<?, ?ba/s]

In [24]:
len(train_data.__getitem__(1)['input_ids'])

893

In [25]:
len(train_data)

76

In [26]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")


training_args = Seq2SeqTrainingArguments("tmp/",
                                      do_train=True,
                                      do_eval=False,
                                      num_train_epochs=15,
                                      learning_rate=1e-5,
                                      warmup_ratio=0.05,
                                      weight_decay=0.01,
                                      per_device_train_batch_size=2,
                                      per_device_eval_batch_size=4,
                                      logging_dir='./log',
                                      group_by_length=True,
                                      save_strategy="epoch",
                                      save_total_limit=3,
                                      #eval_steps=1,
                                      #evaluation_strategy="steps",
                                      # evaluation_strategy="no",
                                      fp16=True,
                                      )

## Training

In [30]:
# trainer = Seq2SeqTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_data,
#     data_collator=data_collator,
# )

# trainer.train()

## Inference

In [20]:
from datasets import load_metric
metric = load_metric("rouge")

  metric = load_metric("rouge")


In [31]:
model = AutoModelForSeq2SeqLM.from_pretrained("./tmp/checkpoint-285")

loading configuration file ./tmp/checkpoint-285/config.json
Model config T5Config {
  "_name_or_path": "./tmp/checkpoint-285",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "torch_dtype": "float32",
  "transformers_version": "4.26.0",
  "use_cache": true,
  "vocab_size": 36096
}

loading weights file ./tmp/checkpoint-285/pytorch_model.bin
Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "t

In [32]:
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(36096, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(36096, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [24]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")

In [25]:
import tqdm
import torch 
import numpy as np
metrics = load_metric('rouge')

max_target_length = 256
dataloader = torch.utils.data.DataLoader(test_data, collate_fn=data_collator, batch_size=32)

predictions = []
references = []
for i, batch in enumerate(dataloader):
  outputs = model.generate(
      input_ids=batch['input_ids'].to('cuda'),
      max_length=max_target_length,
      attention_mask=batch['attention_mask'].to('cuda'),
  )
  outputs = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in outputs]

  labels = np.where(batch['labels'] != -100,  batch['labels'], tokenizer.pad_token_id)
  actuals = [tokenizer.decode(out, clean_up_tokenization_spaces=False, skip_special_tokens=True) for out in labels]
  predictions.extend(outputs)
  references.extend(actuals)
  metrics.add_batch(predictions=outputs, references=actuals)


metrics.compute()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'rouge1': AggregateScore(low=Score(precision=0.9988018200762616, recall=0.998607608406491, fmeasure=0.9986347304154567), mid=Score(precision=0.9993805976766871, recall=0.9992200939966306, fmeasure=0.9992122872988795), high=Score(precision=0.999738926576217, recall=0.999624013478762, fmeasure=0.9995694423257271)),
 'rouge2': AggregateScore(low=Score(precision=0.9981314400993173, recall=0.9978736809435135, fmeasure=0.9978791822185674), mid=Score(precision=0.9989075108628183, recall=0.9987212911235258, fmeasure=0.9986936873389385), high=Score(precision=0.9995065952824334, recall=0.9993792675356922, fmeasure=0.9992694402121775)),
 'rougeL': AggregateScore(low=Score(precision=0.9988835018178597, recall=0.9986260973663208, fmeasure=0.9987033673961049), mid=Score(precision=0.9993805976766871, recall=0.9992294049835948, fmeasure=0.9992270898974809), high=Score(precision=0.9997396027312231, recall=0.9996311519021017, fmeasure=0.9995801538881147)),
 'rougeLsum': AggregateScore(low=Score(precisi

In [26]:
correct = 0
correct += sum(o==a for o, a in zip(predictions, references))
correct

5341

In [27]:
correct/ len(predictions)

0.9945996275605214

In [30]:
a= next(iter(dataloader))

In [31]:
tokenizer.decode(a['input_ids'][0], skip_special_tokens=True)

'nguy·ªÖn vƒÉn ti·∫øn th√¨ d·∫° b√™n kh√¥ng cho'

In [52]:
idx = 3
merge_df_test.iloc[idx]['inputs']

'hypothesis: Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh ƒë∆∞·ª£c di·ªÖn ra v√†o m√πa h√® trong nƒÉmpremise: Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh\n\n1. Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh ƒë∆∞·ª£c t·ªï ch·ª©c v√†o th√°ng 6 h·∫±ng nƒÉm ƒë·ªÉ th√∫c ƒë·∫©y ho·∫°t ƒë·ªông ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh v√† t√¥n vinh gi√° tr·ªã gia ƒë√¨nh. \n\n2. B·ªô VƒÉn h√≥a, Th·ªÉ thao v√† Du l·ªãch ch·ªß tr√¨, ph·ªëi h·ª£p v·ªõi c∆° quan, t·ªï ch·ª©c c√≥ li√™n quan ƒë·ªÉ ch·ªâ ƒë·∫°o, h∆∞·ªõng d·∫´n v√† t·ªï ch·ª©c th·ª±c hi·ªán Th√°ng h√†nh ƒë·ªông qu·ªëc gia ph√≤ng, ch·ªëng b·∫°o l·ª±c gia ƒë√¨nh.'

In [53]:
t = merge_df_test.iloc[idx]['inputs']
b = tokenizer(t, return_tensors='pt')
b

{'input_ids': tensor([[12957,  2938, 12508, 35862,  1867,   235,   256,   405,   254,   464,
         35790,   871,  1855,   374,   254,   721,    74,   769,   170,   162,
          1630,  2369,    80,   128,  3738,   428,  4918, 35862,  1867,   235,
           256,   405,   254,   464, 35790,   871,  1855,   374,   254,   721,
            40, 35792,  1867,   235,   256,   405,   254,   464, 35790,   871,
          1855,   374,   254,   721,    74,   655,   311,   162,   287,   163,
          4307,   128,   180,  1451,  1646,   758,   256,   464, 35790,   871,
          1855,   374,   254,   721,    39,  1973,  3937,   774,   747,   254,
           721, 35792,    60, 35792,   656,   652,   837, 35790,  3811,  2622,
            39,  2497,  1114,   436,  1760, 35790,  1271,   373,   123,   316,
           206, 35790,   655,   311,    71,   483,   206,   180,   337,   679,
         35790,  1184,   861,    39,   655,   311,   305,   219,  1867,   235,
           256,   405,   254,   464, 3

In [54]:
b['input_ids'].shape

torch.Size([1, 132])

In [55]:
outputs = model.generate(
      input_ids=b['input_ids'].to('cuda:1'),
      max_length=20,
      attention_mask=b['attention_mask'].to('cuda:1'),
  )
outputs

Generate config GenerationConfig {
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.0"
}



tensor([[    0, 19074,     1]], device='cuda:1')

In [56]:
tokenizer.decode(outputs[0], skip_special_tokens=True)

'Yes'