In [1]:
!pip install transformers datasets

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.1-py3-none-any.whl (471 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[

In [4]:
from datasets import load_dataset
from transformers import T5TokenizerFast, TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, create_optimizer
import matplotlib.pyplot as plt

In [2]:
BATCH_SIZE=64
MAX_LENGTH=128

In [None]:
dataset_id="liweili/c4_200m"
dataset = load_dataset(dataset_id)

c4_200m.py:   0%|          | 0.00/2.79k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/937 [00:00<?, ?B/s]

The repository for liweili/c4_200m contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/liweili/c4_200m.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


data.zip:   0%|          | 0.00/14.9G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
dataset

In [None]:
dataset['train'][0]

In [None]:
model_id="t5-small"
tokenizer=T5TokenizerFast.from_pretrained(model_id)

In [None]:
def preprocess_function(examples):

  inputs = [example for example in examples['input']]
  targets = [example for example in examples['output']]

  model_inputs = tokenizer(inputs, text_target=targets,max_length=MAX_LENGTH, truncation=True)
  return model_inputs

In [None]:
tokenized_dataset=dataset.map(preprocess_function,batched=True,remove_columns=dataset["train"].column_names)

In [None]:
tokenized_dataset

In [None]:
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_id)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,model=model, return_tensors="tf")

In [None]:
train_dataset=tokenized_dataset["train"].to_tf_dataset(
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

In [None]:
val_dataset=tokenized_dataset["test"].to_tf_dataset(
    shuffle=False,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

In [None]:
model.summary()

In [None]:
num_epochs = 5
num_train_steps=len(train_dataset)*num_epochs

optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=0,
    num_train_steps=num_train_steps,
)
model.compile(optimizer=optimizer)

In [None]:
history=model.fit(
  train_dataset,
  validation_data=val_dataset,
  epochs=num_epochs
)

In [None]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model_loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

In [None]:
wrong_english=[
    "Dady hav'e eateing her foot",
    "DJ Sorryyouwastedyourmoneytobehere",
    "i used to like to swimming",
    "maybe we should organized a meetin with the people from unesco",
    "when are we goinge to start play football",
    "many a time rain fall in my city",
   ]
tokenized=tokenizer(
  wrong_english,
  padding="longest",
  max_length=MAX_LENGTH,
  truncation=True,
  return_tensors='pt'
)
out=model.generate(**tokenized, max_length=128)
print(out)

In [None]:
for i in range(len(wrong_english)):
  print(wrong_english[i]+"------------>"+tokenizer.decode(out[i], skip_special_tokens=True))