# Training of the Hierarchical model for miniwob++

In [1]:
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
! pip install datasets transformers rouge-score nltk evaluate accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from huggingface_hub import notebook_login

notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
from datasets import load_dataset

# Load dataset regarding action-step sequences
raw_datasets_actions = load_dataset("LucasThil/miniwob_plusplus_hierarchical_training_actions")

# Load dataset regarding task of hierarchical planning
raw_datasets_planning = load_dataset("LucasThil/miniwob_plusplus_hierarchical_planning")



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
raw_datasets_actions, raw_datasets_planning

(DatasetDict({
     train: Dataset({
         features: ['history_episodes', 'instruction', 'actions', 'refs', 'keydown_text', 'subtask_completion'],
         num_rows: 42097
     })
 }),
 DatasetDict({
     train: Dataset({
         features: ['hierarchical_plans', 'standolone_instruction'],
         num_rows: 10960
     })
 }))

In [6]:
# Add extra column regarding condensed target_labels
# We append together the (action, ref, text)
# Where action is either of type 'click' or 'type'
# ref is the targeted element number to perform the action on
# and type is the text content in case the action is 'type'
def condensed_targets(dataset, label):
  condensed_targets = []
  for row in dataset[label]:
    target_text = row['keydown_text']
    if target_text is None:
      target_text = ""
    target = '{' + row['actions'] + ', ' + str(row['refs']) + ', ' + target_text + ', ' + row['subtask_completion'] + '}'
    condensed_targets.append(target)
  return condensed_targets

raw_datasets_actions = raw_datasets_actions.map(lambda example: {k: v if v == v else "" for k, v in example.items()})
raw_datasets_actions['train'] = raw_datasets_actions['train'].add_column("targets", condensed_targets(raw_datasets_actions, 'train'))
raw_datasets_actions['train'] = raw_datasets_actions['train'].remove_columns(['actions', 'refs', 'keydown_text', 'instruction', 'subtask_completion'])
raw_datasets_actions



DatasetDict({
    train: Dataset({
        features: ['history_episodes', 'targets'],
        num_rows: 42097
    })
})

In [7]:
# Format the hierchical planning dataset, later we'll marge it in the action dataset
# First add the Command specifying that this row is about asking the Model to Devise a plan
raw_datasets_planning = raw_datasets_planning.map(lambda row: {'standolone_instruction': 'Devise a plan for the following instruction: ' + row['standolone_instruction']})  
raw_datasets_planning = raw_datasets_planning.map(lambda example: {k: v if v == v else "" for k, v in example.items()})

# Rename the columns
raw_datasets_planning = raw_datasets_planning.rename_column('standolone_instruction', 'history_episodes')
raw_datasets_planning = raw_datasets_planning.rename_column('hierarchical_plans', 'targets')
raw_datasets_planning



DatasetDict({
    train: Dataset({
        features: ['targets', 'history_episodes'],
        num_rows: 10960
    })
})

In [8]:
from datasets import Dataset
from datasets import concatenate_datasets, DatasetDict

# select how much of the datset we want to use
# shuffle the elements, and split for a test set
def reduce_split_dataset(raw_datasets, coef):
  raw_datasets['train'] = raw_datasets['train'].shuffle()
  raw_datasets['train'] = raw_datasets["train"].select(range(int(len(raw_datasets["train"])*coef)))
  
  raw_datasets = raw_datasets["train"].train_test_split(train_size=0.95, seed=20)
  val_dataset = raw_datasets.pop("test")
  raw_datasets = raw_datasets["train"].train_test_split(train_size=0.95, seed=20)
  raw_datasets.update({'validation': val_dataset})

  return raw_datasets
  

In [9]:
raw_datasets_actions = reduce_split_dataset(raw_datasets_actions, 0.05)
raw_datasets_planning = reduce_split_dataset(raw_datasets_planning, 0.05)
raw_datasets_actions, raw_datasets_planning

(DatasetDict({
     train: Dataset({
         features: ['history_episodes', 'targets'],
         num_rows: 1898
     })
     test: Dataset({
         features: ['history_episodes', 'targets'],
         num_rows: 100
     })
     validation: Dataset({
         features: ['history_episodes', 'targets'],
         num_rows: 106
     })
 }),
 DatasetDict({
     train: Dataset({
         features: ['targets', 'history_episodes'],
         num_rows: 494
     })
     test: Dataset({
         features: ['targets', 'history_episodes'],
         num_rows: 26
     })
     validation: Dataset({
         features: ['targets', 'history_episodes'],
         num_rows: 28
     })
 }))

In [10]:

# Now combine both datasets in one
train_dataset = concatenate_datasets([raw_datasets_actions['train'], raw_datasets_planning['train']])
test_dataset = concatenate_datasets([raw_datasets_actions['test'], raw_datasets_planning['test']])
validate_dataset = concatenate_datasets([raw_datasets_actions['validation'], raw_datasets_planning['validation']])
dataset = DatasetDict({
    'train': train_dataset,
    'test': test_dataset,
    'validation': validate_dataset
})
dataset

DatasetDict({
    train: Dataset({
        features: ['history_episodes', 'targets'],
        num_rows: 2392
    })
    test: Dataset({
        features: ['history_episodes', 'targets'],
        num_rows: 126
    })
    validation: Dataset({
        features: ['history_episodes', 'targets'],
        num_rows: 134
    })
})

In [11]:
# Explore the dataset to ensure there are no NaNs, or lengths too big

# Show NaNs
nan_examples = dataset['train'].filter(lambda example: any(val != val for val in example.values()))

# Print the NaN examples
for example in nan_examples:
    print(f"Example ID: {example['id']}")
    for key, value in example.items():
        if value != value:  # Check for NaN values
            print(f"{key}: {value}")
    print('\n')

short_examples = dataset['train'].filter(lambda example: any(len(str(val)) < 5 for val in example.values()))

# Print the short examples
for example in short_examples:
    print(f"Example ID: {example}")
    for key, value in example.items():
        if len(str(value)) < 5:  # Check for short string values
            print(f"{key}: {value}")
    print('\n')
    break

Filter:   0%|          | 0/2392 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2392 [00:00<?, ? examples/s]

In [12]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [13]:
show_random_elements(dataset["train"])

Unnamed: 0,history_episodes,targets
0,"Select IIozst <body ref=""22""><div id=""wrap"" ref=""3""><div id=""area"" ref=""20""><div id=""boxes-left"" ref=""21""><label ref=""25""><input type=""checkbox"" id=""ch0"" ref=""4"" value=""true""></input><t ref=""-351"" text=""vfxrn""></t></label><label ref=""16""><input type=""checkbox"" id=""ch1"" ref=""7"" value=""false""></input><t ref=""-352"" text=""mzu""></t></label><label ref=""19""><input type=""checkbox"" id=""ch2"" ref=""11"" value=""true""></input><t ref=""-353"" text=""iiavcs""></t></label><label ref=""14""><input type=""checkbox"" id=""ch3"" ref=""9"" value=""false""></input><t ref=""-354"" text=""ydc""></t></label><label ref=""23""><input type=""checkbox"" id=""ch4"" ref=""10"" value=""true""></input><t ref=""-355"" text=""bupff""></t></label></div><div id=""boxes-right"" ref=""26""><label ref=""1""><input type=""checkbox"" id=""ch5"" ref=""24"" value=""false""></input><t ref=""-356"" text=""bqt04""></t></label><label ref=""13""><input type=""checkbox"" id=""ch6"" ref=""18"" value=""true""></input><t ref=""-357"" text=""iiozst""></t></label><label ref=""17""><input type=""checkbox"" id=""ch7"" ref=""5"" value=""true""></input><t ref=""-358"" text=""yh""></t></label><label ref=""6""><input type=""checkbox"" id=""ch8"" ref=""15"" value=""false""></input><t ref=""-359"" text=""qepnort""></t></label><label ref=""12""><input type=""checkbox"" id=""ch9"" ref=""2"" value=""true""></input><t ref=""-360"" text=""gi3kd""></t></label></div><button id=""subbtn"" classes=""secondary-action"" ref=""8"" text=""submit""></button></div></div></body>,","{click, 18, , stop}"
1,"Select Destination City Las Vegas <body ref=""32""><a classes=""banner"" ref=""233""><img classes=""alaska_logo"" ref=""210""></img></a><div id=""main"" ref=""44""><div id=""starth1"" classes=""header"" ref=""249""><h1 classes=""h1"" ref=""57"" text=""book a flight""></h1></div><form id=""searchform"" classes=""miniwob-main-form"" ref=""182""><div classes=""form-row group"" ref=""228""><div classes=""left"" ref=""60""><label classes=""css-lbl"" ref=""162"" text=""one-way""></label></div><div id=""use-miles-div"" classes=""right"" ref=""169""><label classes=""css-lbl"" ref=""121"" text=""use miles""></label></div></div><div id=""geo-from-wrap"" classes=""form-row geo-wrap"" ref=""111""><label classes=""text-label"" ref=""34""><t ref=""-79"" text=""from""></t><input_text id=""geo-from"" classes=""text-input-pad "" ref=""145"" value=""austin, tx (aus-austin/bergstrom intl.)""></input_text></label><div classes=""cleartxt"" ref=""245""></div><div id=""geo-from-button"" classes=""geo-button"" ref=""270""><span classes=""geo-img"" ref=""276"" text=""geolocation""></span></div></div><div id=""geo-to-wrap"" classes=""form-row geo-wrap"" ref=""243""><label classes=""text-label"" ref=""227""><t ref=""-80"" text=""to""></t><input_text id=""geo-to"" classes=""text-input-pad"" ref=""76""></input_text></label><div classes=""cleartxt"" ref=""139""></div><div id=""geo-to-button"" classes=""geo-button"" ref=""236""><span classes=""geo-img"" ref=""187"" text=""geolocation""></span></div></div><div classes=""form-row group"" ref=""125""><div classes=""datecnt left"" ref=""267""><label id=""lbldep-date"" classes=""text-label"" ref=""84"" text=""depart""></label><input_text id=""departure-date"" classes=""text-input calbg"" ref=""242""></input_text></div><div id=""rt-container"" classes=""datecnt right"" ref=""48""><label id=""lblret-date"" classes=""text-label"" ref=""239"" text=""return""></label><input_text id=""return-date"" classes=""text-input calbg"" ref=""176""></input_text></div></div><div classes=""form-row"" ref=""39""><label classes=""text-label"" ref=""66"" text=""number of passengers""></label><div id=""num-travelers-cnt"" classes=""left group"" ref=""263""><div classes=""tnum-button"" ref=""18"" text=""-""></div><div id=""tnum-display"" ref=""188"" text=""1""></div><div classes=""tnum-button"" ref=""225"" text=""+""></div></div><a id=""umnrlink"" classes=""right"" ref=""195"" text=""child traveling alone?""></a></div><div id=""moreoptionsdiv"" classes=""drop"" ref=""98""><h2 classes=""closed drop-head"" ref=""230"" text=""more search options""></h2></div><div id=""is-cal-div"" ref=""240""><label classes=""css-lbl"" ref=""218"" text=""view results on low-fare calendar""></label></div><div classes=""form-row last"" ref=""209""><input_submit classes=""button"" ref=""43"" value=""find flights""></input_submit></div></form></div><div id=""footer-wrapper"" ref=""207""><ul id=""footer-nav"" classes=""group"" ref=""50""><li id=""help-link"" ref=""189""><a ref=""27"" text=""faq""></a></li><li id=""full-link"" ref=""130""><a ref=""206"" text=""full site""></a></li><li id=""legal-link"" ref=""261""><a ref=""135"" text=""legal""></a></li><li id=""privacy-link"" ref=""68""><a ref=""112"" text=""privacy""></a></li><li id=""contact-link"" ref=""200""><a ref=""171"" text=""contact us""></a></li></ul><div id=""more-links"" ref=""83""></div><div id=""copyright"" ref=""265"" text=""© 2017 alaska airlines, inc.""></div></div></body>,","{click, 76, , continue}"
2,"Devise a plan for the following instruction: Switch between the tabs to find and click on the link ""maecenas."".","Switch between the tabs to find the link ""maecenas."".; Click on the link ""maecenas."".;"
3,"Select 01/12/2016 as the date <body ref=""6""><div id=""wrap"" ref=""135""><div id=""area"" ref=""12""><p ref=""183""><t ref=""-6"" text=""date:""></t><input_text id=""datepicker"" classes=""hasdatepicker"" ref=""178""></input_text></p><button id=""subbtn"" classes=""secondary-action"" ref=""153"" text=""submit""></button></div></div><div id=""ui-datepicker-div"" classes=""ui-datepicker ui-widget ui-widget-content ui-helper-clearfix ui-corner-all"" ref=""26""><div classes=""ui-datepicker-header ui-widget-header ui-helper-clearfix ui-corner-all"" ref=""17""><a classes=""ui-datepicker-prev ui-corner-all"" ref=""66""><span classes=""ui-icon ui-icon-circle-triangle-w"" ref=""116"" text=""prev""></span></a><a classes=""ui-datepicker-next ui-corner-all ui-state-disabled"" ref=""91""><span classes=""ui-icon ui-icon-circle-triangle-e"" ref=""129"" text=""next""></span></a><div classes=""ui-datepicker-title"" ref=""200""><span classes=""ui-datepicker-month"" ref=""80"" text=""december""></span><span classes=""ui-datepicker-year"" ref=""100"" text=""2016""></span></div></div><table classes=""ui-datepicker-calendar"" ref=""42""><thead ref=""67""><tr ref=""31""><th classes=""ui-datepicker-week-end"" ref=""205""><span ref=""207"" text=""su""></span></th><th ref=""133""><span ref=""117"" text=""mo""></span></th><th ref=""181""><span ref=""13"" text=""tu""></span></th><th ref=""27""><span ref=""57"" text=""we""></span></th><th ref=""161""><span ref=""103"" text=""th""></span></th><th ref=""60""><span ref=""174"" text=""fr""></span></th><th classes=""ui-datepicker-week-end"" ref=""88""><span ref=""87"" text=""sa""></span></th></tr></thead><tbody ref=""108""><tr ref=""70""><td classes="" ui-datepicker-week-end ui-datepicker-other-month ui-datepicker-unselectable ui-state-disabled"" ref=""204""></td><td classes="" ui-datepicker-other-month ui-datepicker-unselectable ui-state-disabled"" ref=""71""></td><td classes="" ui-datepicker-other-month ui-datepicker-unselectable ui-state-disabled"" ref=""98""></td><td classes="" ui-datepicker-other-month ui-datepicker-unselectable ui-state-disabled"" ref=""196""></td><td classes="" "" ref=""86""><a classes=""ui-state-default"" ref=""134"" text=""1""></a></td><td classes="" "" ref=""159""><a classes=""ui-state-default"" ref=""171"" text=""2""></a></td><td classes="" ui-datepicker-week-end "" ref=""19""><a classes=""ui-state-default"" ref=""187"" text=""3""></a></td></tr><tr ref=""140""><td classes="" ui-datepicker-week-end "" ref=""172""><a classes=""ui-state-default"" ref=""2"" text=""4""></a></td><td classes="" "" ref=""28""><a classes=""ui-state-default"" ref=""62"" text=""5""></a></td><td classes="" "" ref=""53""><a classes=""ui-state-default"" ref=""189"" text=""6""></a></td><td classes="" "" ref=""44""><a classes=""ui-state-default"" ref=""118"" text=""7""></a></td><td classes="" "" ref=""15""><a classes=""ui-state-default"" ref=""201"" text=""8""></a></td><td classes="" "" ref=""85""><a classes=""ui-state-default"" ref=""16"" text=""9""></a></td><td classes="" ui-datepicker-week-end "" ref=""34""><a classes=""ui-state-default"" ref=""123"" text=""10""></a></td></tr><tr ref=""127""><td classes="" ui-datepicker-week-end "" ref=""136""><a classes=""ui-state-default"" ref=""72"" text=""11""></a></td><td classes="" "" ref=""18""><a classes=""ui-state-default"" ref=""97"" text=""12""></a></td><td classes="" "" ref=""63""><a classes=""ui-state-default"" ref=""177"" text=""13""></a></td><td classes="" "" ref=""199""><a classes=""ui-state-default"" ref=""148"" text=""14""></a></td><td classes="" "" ref=""23""><a classes=""ui-state-default"" ref=""35"" text=""15""></a></td><td classes="" "" ref=""101""><a classes=""ui-state-default"" ref=""164"" text=""16""></a></td><td classes="" ui-datepicker-week-end "" ref=""113""><a classes=""ui-state-default"" ref=""105"" text=""17""></a></td></tr><tr ref=""139""><td classes="" ui-datepicker-week-end "" ref=""202""><a classes=""ui-state-default"" ref=""46"" text=""18""></a></td><td classes="" "" ref=""198""><a classes=""ui-state-default"" ref=""54"" text=""19""></a></td><td classes="" "" ref=""182""><a classes=""ui-state-default"" ref=""124"" text=""20""></a></td><td classes="" "" ref=""76""><a classes=""ui-state-default"" ref=""120"" text=""21""></a></td><td classes="" "" ref=""102""><a classes=""ui-state-default"" ref=""190"" text=""22""></a></td><td classes="" "" ref=""77""><a classes=""ui-state-default"" ref=""5"" text=""23""></a></td><td classes="" ui-datepicker-week-end "" ref=""59""><a classes=""ui-state-default"" ref=""206"" text=""24""></a></td></tr><tr ref=""142""><td classes="" ui-datepicker-week-end "" ref=""29""><a classes=""ui-state-default"" ref=""56"" text=""25""></a></td><td classes="" "" ref=""50""><a classes=""ui-state-default"" ref=""126"" text=""26""></a></td><td classes="" "" ref=""107""><a classes=""ui-state-default"" ref=""10"" text=""27""></a></td><td classes="" "" ref=""20""><a classes=""ui-state-default"" ref=""188"" text=""28""></a></td><td classes="" "" ref=""143""><a classes=""ui-state-default"" ref=""14"" text=""29""></a></td><td classes="" "" ref=""33""><a classes=""ui-state-default"" ref=""163"" text=""30""></a></td><td classes="" ui-datepicker-week-end ui-datepicker-days-cell-over "" ref=""83""><a classes=""ui-state-default ui-state-hover"" ref=""8"" text=""31""></a></td></tr></tbody></table></div></body>,","{click, 178, , continue}"
4,"Select words similar to genuine <body ref=""3""><div id=""wrap"" ref=""9""><div id=""area"" ref=""5""><div id=""boxes"" ref=""6""><label ref=""12""><input type=""checkbox"" id=""ch0"" ref=""7"" value=""true""></input><t ref=""-26"" text=""actual""></t></label><label ref=""11""><input type=""checkbox"" id=""ch1"" ref=""2"" value=""false""></input><t ref=""-27"" text=""immoral""></t></label><label ref=""15""><input type=""checkbox"" id=""ch2"" ref=""14"" value=""false""></input><t ref=""-28"" text=""rabbits""></t></label><label ref=""1""><input type=""checkbox"" id=""ch3"" ref=""10"" value=""false""></input><t ref=""-29"" text=""dumb""></t></label><label ref=""8""><input type=""checkbox"" id=""ch4"" ref=""4"" value=""false""></input><t ref=""-30"" text=""carve""></t></label></div><button id=""subbtn"" classes=""secondary-action"" ref=""13"" text=""submit""></button></div></div></body>,","{click, 7, , stop}"
5,"Select quiet <body ref=""9""><div id=""wrap"" ref=""14""><div id=""area"" ref=""7""><div id=""boxes"" ref=""15""><label ref=""11""><input type=""checkbox"" id=""ch0"" ref=""4"" value=""false""></input><t ref=""-139"" text=""conceal""></t></label><label ref=""8""><input type=""checkbox"" id=""ch1"" ref=""10"" value=""false""></input><t ref=""-140"" text=""serene""></t></label><label ref=""16""><input type=""checkbox"" id=""ch2"" ref=""3"" value=""true""></input><t ref=""-141"" text=""carve""></t></label><label ref=""13""><input type=""checkbox"" id=""ch3"" ref=""5"" value=""true""></input><t ref=""-142"" text=""end""></t></label><label ref=""6""><input type=""checkbox"" id=""ch4"" ref=""2"" value=""false""></input><t ref=""-143"" text=""funny""></t></label><label ref=""12""><input type=""checkbox"" id=""ch5"" ref=""9"" value=""false""></input><t ref=""-144"" text=""gigantic""></t></label></div><button id=""subbtn"" classes=""secondary-action"" ref=""1"" text=""submit""></button></div></div></body>,","{click, 5, , stop}"
6,Devise a plan for the following instruction: Select words similar to weird and click Submit.,Select words similar to weird; click submit;
7,Devise a plan for the following instruction: Enter 4:20 AM as the time and press submit.,Enter 4:20 AM as the time; click submit;
8,Devise a plan for the following instruction: Enter 6:43 PM as the time and press submit.,Enter 6:43 PM as the time; click submit;
9,"Select JU32ZOg <body ref=""11""><div id=""wrap"" ref=""2""><div id=""area"" ref=""22""><div id=""boxes-left"" ref=""12""><label ref=""13""><input type=""checkbox"" id=""ch0"" ref=""1"" value=""false""></input><t ref=""-172"" text=""uxgadjs""></t></label><label ref=""3""><input type=""checkbox"" id=""ch1"" ref=""24"" value=""false""></input><t ref=""-173"" text=""baadl9""></t></label><label ref=""10""><input type=""checkbox"" id=""ch2"" ref=""7"" value=""true""></input><t ref=""-174"" text=""ju32zog""></t></label><label ref=""8""><input type=""checkbox"" id=""ch3"" ref=""16"" value=""false""></input><t ref=""-175"" text=""ywxs""></t></label><label ref=""9""><input type=""checkbox"" id=""ch4"" ref=""18"" value=""false""></input><t ref=""-176"" text=""ytdb""></t></label></div><div id=""boxes-right"" ref=""20""><label ref=""17""><input type=""checkbox"" id=""ch5"" ref=""6"" value=""true""></input><t ref=""-177"" text=""varwjfp""></t></label><label ref=""14""><input type=""checkbox"" id=""ch6"" ref=""19"" value=""false""></input><t ref=""-178"" text=""prj""></t></label><label ref=""5""><input type=""checkbox"" id=""ch7"" ref=""21"" value=""true""></input><t ref=""-179"" text=""zmoo2bx""></t></label><label ref=""23""><input type=""checkbox"" id=""ch8"" ref=""4"" value=""false""></input><t ref=""-180"" text=""bxgrot""></t></label></div><button id=""subbtn"" classes=""secondary-action"" ref=""15"" text=""submit""></button></div></div></body>,","{click, 7, , stop}"


# Prepare Model

In [14]:
from transformers import AutoTokenizer

checkpoint = "google/t5-v1_1-base"
#checkpoint = 'LucasThil/T5_base_hierarchy1'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [15]:
tokenizer

T5TokenizerFast(name_or_path='google/t5-v1_1-base', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>'

In [16]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [doc for doc in examples["history_episodes"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding=True)

    # Setup the tokenizer for targets
    #with tokenizer.as_target_tokenizer():
    labels = tokenizer(examples["targets"], max_length=max_target_length, truncation=True, padding=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [17]:
preprocess_function(dataset['train'][:2])

{'input_ids': [[13218, 344, 8, 3808, 7, 12, 253, 8, 1309, 96, 553, 155, 9, 15, 1280, 3, 2, 6965, 6273, 17592, 2266, 121, 3155, 2, 8481, 3, 23, 26, 17592, 210, 5846, 121, 6273, 17592, 2368, 121, 3155, 2, 8481, 3, 23, 26, 17592, 498, 121, 2287, 17592, 76, 23, 18, 10309, 7, 3, 76, 23, 18, 13165, 49, 18, 1748, 3, 76, 23, 18, 12018, 2782, 3, 76, 23, 18, 12018, 2782, 18, 14819, 121, 6273, 17592, 2596, 121, 3155, 2, 83, 2287, 17592, 76, 23, 18, 10309, 7, 18, 14128, 3, 76, 23, 18, 13165, 49, 18, 1748, 3, 76, 23, 18, 15061, 49, 18, 60, 2244, 3, 76, 23, 18, 15061, 49, 18, 2482, 291, 12304, 3, 76, 23, 18, 12018, 2782, 18, 3313, 49, 121, 6273, 17592, 17395, 3155, 2, 40, 23, 2287, 17592, 76, 23, 18, 10309, 7, 18, 10309, 3, 76, 23, 18, 13165, 49, 18, 2916, 3, 76, 23, 18, 5540, 18, 31026, 3, 76, 23, 18, 10309, 3, 76, 23, 18, 10309, 7, 18, 6645, 3, 76, 23, 18, 5540, 18, 6645, 121, 6273, 17592, 2555, 121, 3155, 2, 9, 3, 23, 26, 17592, 76, 23, 18, 23, 26, 9169, 121, 2287, 17592, 76, 23, 18, 10309, 7, 18

In [18]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/2392 [00:00<?, ? examples/s]

Map:   0%|          | 0/126 [00:00<?, ? examples/s]

Map:   0%|          | 0/134 [00:00<?, ? examples/s]

In [19]:
print(len(tokenized_datasets['train'][0]['input_ids']), len(tokenized_datasets['train'][1]['input_ids']))


1024 1024


In [20]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

#model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# new:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [21]:
import evaluate
# Use ROUGE as a metric
rouge = evaluate.load("rouge")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [22]:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [23]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model.config.ctc_zero_infinity = True

Downloading pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [24]:
trainer = ""
import gc
import torch
torch.cuda.empty_cache()
gc.collect()

69

In [26]:
training_args = Seq2SeqTrainingArguments(
    output_dir="t5-v1_1-base_hierarchy1",
    evaluation_strategy="epoch",
    learning_rate=2e-6,
    #per_device_train_batch_size=16,
    #per_device_eval_batch_size=16,
    auto_find_batch_size=True,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    #fp16_full_eval=True,
    push_to_hub=False,
    #logging_nan_inf_filter=True,
    dataloader_drop_last=True,
    #optim='adamw_torch'
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,No log,,0.0405,0.0173,0.036,0.0363,19.0


In [None]:
trainer.push_to_hub()

In [None]:
trainer.save_model("LucasThil/T5_base_hierarchy3")
trainer.save_config("LucasThil/T5_base_hierarchy3")

In [None]:
!transformers-cli upload LucasThil/T5_base_hierarchy3

In [None]:
model = ""
trainer = ""

In [None]:
tokenized_datasets['train'][0]['targets']

'{click, 14, , stop}'

In [None]:
# Test compute metrics
pred = tokenized_datasets['train'][0]['labels']
reference = tokenized_datasets['train'][0]['targets']
pred_labels = [pred, pred]
print(compute_metrics(pred_labels))
decoded_pred = tokenizer.batch_decode(pred, skip_special_tokens=True)
print(decoded_pred)

result = rouge.compute(predictions=decoded_pred, references=decoded_pred, use_stemmer=True)
print(result)
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for p in pred]
print(prediction_lens)
result["gen_len"] = np.mean(prediction_lens)
print(result)

{'rouge1': 0.0714, 'rouge2': 0.0, 'rougeL': 0.0714, 'rougeLsum': 0.0714, 'gen_len': 0.2381}
['', '', 'click', ',', '14,', '', ',', 'stop', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '']
{'rouge1': 0.07142857142857142, 'rouge2': 0.0, 'rougeL': 0.07142857142857142, 'rougeLsum': 0.07142857142857142}
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
{'rouge1': 0.07142857142857142, 'rouge2': 0.0, 'rougeL': 0.07142857142857142, 'rougeLsum': 0.07142857142857142, 'gen_len': 1.0}


# Test Tensorflow

In [None]:
from transformers import create_optimizer, AdamWeightDecay

optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)

In [None]:
from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Downloading tf_model.h5:   0%|          | 0.00/892M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint, return_tensors="tf")

In [None]:
tf_train_set = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator,
)

tf_test_set = model.prepare_tf_dataset(
    tokenized_datasets["test"],
    shuffle=False,
    batch_size=8,
    collate_fn=data_collator,
)

tf_validation_set = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    shuffle=False,
    batch_size=8,
    collate_fn=data_collator,
)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
import tensorflow as tf

model.compile(optimizer=optimizer)

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


In [None]:
from transformers.keras_callbacks import KerasMetricCallback

metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set)

In [None]:
from transformers.keras_callbacks import PushToHubCallback

push_to_hub_callback = PushToHubCallback(
    output_dir="tf_T5_hierarchical_1",
    tokenizer=tokenizer,
)

/content/tf_T5_hierarchical_1 is already a clone of https://huggingface.co/LucasThil/tf_T5_hierarchical_1. Make sure you pull the latest changes with `repo.git_pull()`.


In [None]:
callbacks = [metric_callback, push_to_hub_callback]

In [None]:
model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3, callbacks=callbacks)

Epoch 1/3

ValueError: ignored

In [None]:
model = ""

In [None]:
optimizer = ""

# Inference

In [None]:
text = "Devise a plan for the following instruction: Select 12rh2X4, Hjfr, iMDDWUP, Jc0Wr4A, Rg and click Submit.	"

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/content/T5_base_hierarchy4/checkpoint-10500")
inputs = tokenizer(text, return_tensors="pt").input_ids

In [None]:
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("/content/T5_base_hierarchy4/checkpoint-10500")
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)

In [None]:
tokenizer.decode(outputs[0], skip_special_tokens=True)

''