In [1]:
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset_builder
from datasets import load_dataset
import numpy as np
import evaluate
import torch

In [None]:
!pip install transformers
!pip install evaluate
!pip install datasets

In [2]:
tokenizer = AutoTokenizer.from_pretrained("6mtx9/train_iwslt2017",model_max_length=128)
model = AutoModelForSeq2SeqLM.from_pretrained("6mtx9/train_iwslt2017")

# Getting dataset

In [3]:
def prepare_dataset(data):
    source_language = [value['ko'] for key, value in data.items()]
    target_language = [value['en'] for key, value in data.items()]
    return source_language, target_language

In [4]:
train  = load_dataset("msarmi9/korean-english-multitarget-ted-talks-task", split="train")
#test = load_dataset("Moo/korean-parallel-corpora", split="test")
validation = load_dataset("msarmi9/korean-english-multitarget-ted-talks-task", split="validation")

In [5]:
print(train)

Dataset({
    features: ['korean', 'english'],
    num_rows: 166215
})


In [6]:
print(validation)

Dataset({
    features: ['korean', 'english'],
    num_rows: 1958
})


# Tokenizer

In [15]:
inputs_train = tokenizer(train['korean'],return_tensors="pt", max_length=128, truncation=True,padding=True)
outputs_train = tokenizer(train['english'],return_tensors="pt",max_length=128, truncation=True,padding=True)

In [16]:
inputs_validation = tokenizer(validation['korean'],return_tensors="pt", max_length=128, truncation=True,padding=True)
outputs_validation = tokenizer(validation['english'],return_tensors="pt",max_length=128, truncation=True,padding=True)

In [17]:
train_dataset = torch.utils.data.TensorDataset(inputs_train.input_ids, inputs_train.attention_mask, outputs_train.input_ids, outputs_train.attention_mask)

In [18]:
validation_dataset = torch.utils.data.TensorDataset(inputs_validation.input_ids, inputs_validation.attention_mask, outputs_validation.input_ids, outputs_validation.attention_mask)

In [19]:
train_dataset[0]

(tensor([20004, 20015, 16765, 20018, 17935,    12, 22996, 59665, 12175, 20006,
           363,     4, 21819, 62048,  3862,   513,     7,     1,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [20]:
validation_dataset[0]

(tensor([20004, 35163, 24938,    11,   729,    10,     4,    24,    59,    58,
          1198,     4,  3784,     6,   248,    13, 20016,    90,    24,     9,
            13, 20006,    48,  3784,     6, 10620, 20015,     2, 20018,  1527,
           399,    83,  2155,     4,    24,  3862,   513,     7, 20005,     1,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
pip install accelerate -U

In [None]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime !')

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [22]:
from transformers import TrainingArguments, Trainer

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    fp16=True,
    remove_unused_columns=False,
    logging_dir="./logs",
)

def data_collator(batch):
        return {
            "input_ids": torch.stack([item[0] for item in batch]),
            "attention_mask": torch.stack([item[1] for item in batch]),
            "labels": torch.stack([item[2] for item in batch]),
        }

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)


import transformers
transformers.logging.set_verbosity_info()

trainer.train()

# Save the trained model
output_dir = "./train_translatorKO_EN"
trainer.save_model(output_dir)

***** Running training *****
  Num examples = 166,215
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 10,389
  Number of trainable parameters = 296,696,448


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 