<a href="https://colab.research.google.com/github/0x71d3/hf-mt5-ja/blob/main/train_mt5_ja.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

tutorial

https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb

blog

https://huggingface.co/blog/how-to-generate

mt5

https://huggingface.co/transformers/v4.0.0/model_doc/mt5.html#mt5forconditionalgeneration

https://huggingface.co/transformers/v4.0.0/model_doc/t5.html#t5tokenizer

install

In [None]:
!pip install transformers==4.0.0

In [None]:
!pip install sentencepiece

import

In [None]:
import csv

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import MT5ForConditionalGeneration, T5Tokenizer

GPU

In [None]:
!nvidia-smi

In [None]:
from torch import cuda

# Pick the compute device once for the whole notebook: use the GPU when
# PyTorch reports one as available, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

dataset

In [None]:
class CustomDataset(Dataset):
    """Source/target text pairs read from a TSV file and tokenized up front.

    The file is read as tab-delimited text with quoting disabled. The first
    column of each row is the source text and the second column is the target
    text; any further columns are ignored. Blank lines are skipped.

    All pairs are tokenized in a single ``tokenizer.prepare_seq2seq_batch``
    call in the constructor, so every item shares the same padded length.

    Args:
        path: Path of the TSV file (decoded as UTF-8).
        tokenizer: Object exposing ``prepare_seq2seq_batch(src_texts=...,
            tgt_texts=..., max_length=..., return_tensors='pt')`` whose result
            is indexable by ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'``.
        max_len: Value forwarded as ``max_length`` to the tokenizer.

    Raises:
        ValueError: If a non-blank row has fewer than two columns, or if the
            file contains no usable rows.
    """

    def __init__(self, path, tokenizer, max_len):
        sources = []
        targets = []
        # Decode explicitly as UTF-8 so reading does not depend on the
        # platform's default locale encoding.
        with open(path, encoding='utf-8', newline='') as f:
            reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
            for line_no, row in enumerate(reader, start=1):
                if not row:
                    # csv.reader yields an empty list for a blank line
                    # (e.g. a trailing newline at the end of the file).
                    continue
                if len(row) < 2:
                    raise ValueError(
                        f'{path}:{line_no}: expected at least 2 '
                        f'tab-separated columns, got {len(row)}'
                    )
                sources.append(row[0])
                targets.append(row[1])

        if not sources:
            raise ValueError(f'{path}: no source/target pairs found')

        self.batch = tokenizer.prepare_seq2seq_batch(
            src_texts=sources,
            tgt_texts=targets,
            max_length=max_len,
            return_tensors='pt'
        )

    def __len__(self):
        """Return the number of tokenized examples."""
        return self.batch['input_ids'].size(0)

    def __getitem__(self, index):
        """Return the tensors of example ``index`` as a dict of long tensors."""
        input_ids = self.batch['input_ids'][index]
        attention_mask = self.batch['attention_mask'][index]
        labels = self.batch['labels'][index]

        return {
            'input_ids': input_ids.to(dtype=torch.long),
            'attention_mask': attention_mask.to(dtype=torch.long),
            'labels': labels.to(dtype=torch.long),
        }

train

In [None]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one training pass over ``loader``, updating ``model`` in place.

    Args:
        epoch: Epoch index; only used in the progress message.
        tokenizer: Tokenizer whose ``pad_token_id`` marks padding positions in
            the labels. If it has no ``pad_token_id`` (or it is ``None``), the
            labels are passed to the model unchanged.
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)``; the returned object must expose ``.loss``.
        device: Device the batch tensors are moved to before the forward pass.
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: Optimizer stepped once per batch.
    """
    model.train()

    pad_token_id = getattr(tokenizer, 'pad_token_id', None)

    for i, batch in enumerate(loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device, dtype=torch.long)
        attention_mask = batch['attention_mask'].to(device, dtype=torch.long)
        labels = batch['labels'].to(device, dtype=torch.long)

        if pad_token_id is not None:
            # Labels arrive padded with the pad token id. Replace those
            # positions with -100, the label index the Hugging Face T5/mT5
            # loss ignores, so padding does not contribute to the loss.
            # masked_fill is out-of-place: the caller's tensor is untouched.
            labels = labels.masked_fill(labels == pad_token_id, -100)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss

        # Report the loss every 50 batches.
        if i % 50 == 0:
            print(f'Epoch: {epoch}, Loss: {loss.item()}')

        loss.backward()
        optimizer.step()

validation

In [None]:
def validate(epoch, tokenizer, model, device, loader):
    """Generate text for every batch in ``loader`` and decode it.

    The model is switched to eval mode and generation runs under
    ``torch.no_grad()``. For each batch the generated ids and the reference
    label ids are decoded with ``tokenizer.decode(..., skip_special_tokens=True)``.
    A progress line is printed every 10 batches.

    Args:
        epoch: Accepted for signature symmetry with ``train``; not used here.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=True)``.
        model: Object providing ``eval()`` and ``generate(...)``.
        device: Device the batch tensors are moved to.
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        A ``(predictions, actuals)`` tuple of two lists of decoded strings,
        in loader order.
    """
    def _decode_rows(rows):
        # Decode one id sequence per row, dropping special tokens.
        return [tokenizer.decode(row, skip_special_tokens=True) for row in rows]

    # Generation settings are the same for every batch.
    generation_kwargs = dict(
        max_length=50,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    model.eval()
    predictions, actuals = [], []

    with torch.no_grad():
        for i, batch in enumerate(loader):
            input_ids, attention_mask, labels = (
                batch[key].to(device, dtype=torch.long)
                for key in ('input_ids', 'attention_mask', 'labels')
            )

            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )

            preds = _decode_rows(generated_ids)
            target = _decode_rows(labels)

            if i % 10 == 0:
                print(f'Completed {i}')

            predictions += preds
            actuals += target

    return predictions, actuals

main

In [None]:
def main():
    """Fine-tune ``google/mt5-small`` on the TSV data and dump predictions.

    Steps: seed PyTorch and NumPy, build the train/validation datasets and
    loaders, train for ``TRAIN_EPOCHS`` epochs with Adam, run generation on
    the validation set once after training, and write one
    ``prediction<TAB>actual`` line per validation example to
    ``predictions.tsv`` in the current working directory.

    Relies on the notebook-level ``device`` variable for model placement.
    """
    TRAIN_BATCH_SIZE = 2
    VALID_BATCH_SIZE = 4
    TRAIN_EPOCHS = 2
    LEARNING_RATE = 1e-4
    SEED = 42
    MAX_LEN = 32

    torch.manual_seed(SEED)  # pytorch random seed
    np.random.seed(SEED)  # numpy random seed
    torch.backends.cudnn.deterministic = True

    tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

    train_set = CustomDataset(
        'drive/My Drive/stc-jpn/train.tsv',
        tokenizer,
        MAX_LEN
    )
    val_set = CustomDataset(
        'drive/My Drive/stc-jpn/val.tsv',
        tokenizer,
        MAX_LEN
    )

    train_loader = DataLoader(
        train_set,
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    val_loader = DataLoader(
        val_set,
        batch_size=VALID_BATCH_SIZE,
        shuffle=False,
        num_workers=0
    )

    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    model = model.to(device)

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=LEARNING_RATE
    )

    for epoch in range(TRAIN_EPOCHS):
        train(epoch, tokenizer, model, device, train_loader, optimizer)

    # Pass the index of the last epoch explicitly instead of relying on the
    # loop variable, which is undefined when TRAIN_EPOCHS is 0.
    predictions, actuals = validate(
        TRAIN_EPOCHS - 1, tokenizer, model, device, val_loader
    )

    # Write explicitly as UTF-8 so the decoded text is stored the same way
    # regardless of the platform's default locale encoding.
    with open('predictions.tsv', 'w', encoding='utf-8') as f:
        for prediction, actual in zip(predictions, actuals):
            f.write(prediction + '\t' + actual + '\n')

In [None]:
# Entry point: run the full train / validate / write-predictions pipeline
# when this code is executed as the main module.
if __name__ == '__main__':
    main()

copy

In [None]:
!cp predictions.tsv drive/My\ Drive/stc-jpn/