# Download Package

In [1]:
!pip install simplet5

Collecting simplet5
  Downloading simplet5-0.1.4.tar.gz (7.3 kB)
  Preparing metadata (setup.py) ... [?25l- done
Collecting transformers==4.16.2 (from simplet5)
  Downloading transformers-4.16.2-py3-none-any.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.8/61.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-lightning==1.5.10 (from simplet5)
  Downloading pytorch_lightning-1.5.10-py3-none-any.whl.metadata (31 kB)
Collecting pyDeprecate==0.3.1 (from pytorch-lightning==1.5.10->simplet5)
  Downloading pyDeprecate-0.3.1-py3-none-any.whl.metadata (10 kB)
Collecting setuptools==59.5.0 (from pytorch-lightning==1.5.10->simplet5)
  Downloading setuptools-59.5.0-py3-none-any.whl.metadata (5.0 kB)
Collecting sacremoses (from transformers==4.16.2->simplet5)
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading pytorch_lightning-1.5.10-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

# Import Libraries

In [2]:
# For data
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# For model, tokenizer
from simplet5 import SimpleT5
import transformers
from transformers import AutoTokenizer, AutoConfig, T5ForConditionalGeneration

# Utilities (OS access, progress bars, warning control)
import os
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore")

# Load dataset

In [3]:
# Locations of the train/test CSV files attached to this Kaggle notebook.
train_csv_path = "/kaggle/input/t5-skinchat-data-tokenizer-preparation/data_train.csv"
test_csv_path = "/kaggle/input/t5-skinchat-data-tokenizer-preparation/data_test.csv"

# Read both splits, then display their (rows, columns) shapes as the cell output.
train_df, test_df = pd.read_csv(train_csv_path), pd.read_csv(test_csv_path)
(train_df.shape, test_df.shape)

((66083, 2), (7346, 2))

In [4]:
# Cast the input/target text columns of both splits to `str`, in place, so
# every cell handed to the tokenizer is a string.
# NOTE(review): `.astype(str)` turns missing values into the literal text
# "nan" -- confirm the CSVs contain no empty cells, or filter them first.
text_columns = ('source_text', 'target_text')
for frame in (train_df, test_df):
    for column in text_columns:
        frame[column] = frame[column].astype(str)

# Load Tokenizer

In [5]:
# Configuration constants. Neither is referenced by the cells visible in this
# notebook, which load the tokenizer and weights from a saved checkpoint.
# Presumably MODEL_NAME is the base architecture and MAX_VOCAB a vocabulary
# cap used when the tokenizer was prepared -- TODO confirm.
MODEL_NAME, MAX_VOCAB = "t5-base", 20_000

In [6]:
# Directory of a saved checkpoint (by its path name: epoch 9 of the
# "t5-skinchat-training-1" run); the tokenizer is restored from it.
tokenizer_dir = "/kaggle/input/t5-skinchat-training-1/outputs/simplet5-epoch-9-train-loss-1.3308-val-loss-1.4264/"
new_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

# Load Model

In [7]:
# Restore the T5 weights from a saved checkpoint directory (by its path name:
# epoch 9 of the "t5-skinchat-training-1" run) so training can continue from it.
model_dir = "/kaggle/input/t5-skinchat-training-1/outputs/simplet5-epoch-9-train-loss-1.3308-val-loss-1.4264/"
model = T5ForConditionalGeneration.from_pretrained(model_dir)

In [8]:
# Create an empty SimpleT5 wrapper, then attach the tokenizer and model that
# were loaded from the checkpoint in the cells above.
simplet5 = SimpleT5()
simplet5.tokenizer, simplet5.model = new_tokenizer, model

In [9]:
# Add up the element count of every parameter tensor in the wrapped model.
total_params = sum(map(lambda tensor: tensor.numel(), simplet5.model.parameters()))

# The message is Vietnamese for "Total number of model parameters".
print(f"Tổng số tham số của mô hình: {total_params}")

# Last expression: show the module tree of the model as the cell output.
simplet5.model

Tổng số tham số của mô hình: 209531136


T5ForConditionalGeneration(
  (shared): Embedding(14716, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(14716, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

# Training model

In [10]:
# Set TOKENIZERS_PARALLELISM to "false" in the process environment before
# training starts.
# NOTE(review): presumably this avoids the tokenizers fork/parallelism warning
# when data-loader workers are spawned -- confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [11]:
# Hyperparameters for this fine-tuning run, collected in one place.
train_settings = {
    "source_max_token_len": 100,          # token limit passed for the source text
    "target_max_token_len": 250,          # token limit passed for the target text
    "batch_size": 16,
    "max_epochs": 10,
    "use_gpu": True,
    "early_stopping_patience_epochs": 3,  # presumably stops after 3 epochs without val-loss improvement -- confirm
    "precision": 16,                      # presumably 16-bit mixed precision -- confirm
}

# Continue training the checkpointed model; the test split is passed as the
# evaluation set.
simplet5.train(train_df=train_df, eval_df=test_df, **train_settings)

  from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [12]:
! (cd outputs; ls)

simplet5-epoch-0-train-loss-1.254-val-loss-1.4018
simplet5-epoch-1-train-loss-1.188-val-loss-1.3865
simplet5-epoch-2-train-loss-1.1289-val-loss-1.3752
simplet5-epoch-3-train-loss-1.0741-val-loss-1.3597
simplet5-epoch-4-train-loss-1.0214-val-loss-1.3624
simplet5-epoch-5-train-loss-0.9718-val-loss-1.3626
simplet5-epoch-6-train-loss-0.9248-val-loss-1.3602
