In [1]:
# !pip install pandas
import pandas as pd
import gc
import json
import math
import torch
from transformers import DebertaTokenizer, DebertaForSequenceClassification
from transformers import T5Tokenizer, T5ForSequenceClassification
from tqdm.auto import tqdm
import os

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32  # Uses ~5GB VRAM

# tokenizer = DebertaTokenizer.from_pretrained('microsoft/deberta-base')
# ckpt_dir = 'deberta_results/deberta_ckpt/checkpoint-12000'
# model = DebertaForSequenceClassification.from_pretrained(ckpt_dir).to(device)
ckpt_dir = 'deberta_results/flan_ckpt/checkpoint-4000'
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForSequenceClassification.from_pretrained('t5-base').to(device)

def tokenize_function(text):
    return tokenizer(text, padding='max_length', return_tensors='pt', truncation=True)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of T

In [4]:
def filter_page(html_page: str, objective: str, top_k=(1, 5, 10, 50)):
    with torch.no_grad():
        elements = [el.strip() for el in html_page.split('\n')]
        elements = [el for el in elements if el]
        prompts = [f'Objective: {objective}.\nElement: {element}' for element in elements]
        positive_logits = None
    
        for j in range(0, len(prompts), batch_size):
            ex = tokenize_function(prompts[j:j+batch_size]).to(device)
            out = model(**ex)
            cur = out.logits[:, 1]
            positive_logits = torch.cat((positive_logits, cur)) if positive_logits is not None else cur
            del cur, out, ex
            gc.collect() 
            torch.cuda.empty_cache()
    
        results = []
        for k in top_k:
            top_k_indices = sorted(range(len(positive_logits)), key=lambda i: positive_logits[i], reverse=True)[:k]
            top_k_labels = [elements[i] for i in top_k_indices]
            results.append((k, top_k_labels))
    
        return results

In [None]:
output_file = 'mind2web.txt'
file_path = 'data/mind2web.csv'
df = pd.read_csv(file_path)
val_start_idx =  math.floor(len(df) * 0.8)

with open(output_file, 'w') as wf:
    for i in tqdm(range(val_start_idx, len(df))):
        ex = df.iloc[len(df)-i-1]
        action_string = ex['ACTION']
        objective = ex['OBJECTIVE']
        wf.write(f'{len(df)-i-1}, Task: {objective}; Action: {action_string}\n')
        results = filter_page(ex['OBSERVATION'], objective)
        for k, top_k_labels in results:
            wf.write(f'k={k}: {top_k_labels}\n')
        wf.write('=' * 50 + '\n')

## Run on WebArena

In [3]:
with open('data/webarena_test.json', 'r') as f:
    webarena_data = json.load(f)
    id2objective = {}
    for d in webarena_data:
        id2objective[d['task_id']] = d['intent']

In [5]:
import os
import glob

base_dir = 'data/webarena_acc_tree'
pattern = os.path.join(base_dir, 'render_*_tree_0.txt')
output_file = os.path.join(ckpt_dir, 'webarena_results.txt')

with open(output_file, 'w') as output:
    for file_path in tqdm(glob.glob(pattern)):
        with open(file_path, 'r') as f:
            html_content = '\n'.join([s for s in f.readlines()])
            # Extracting the task_id from the file name
            task_id = int(file_path.split('/')[-1].split('_')[1])
            objective = id2objective[task_id]

        output.write(f"{task_id} Task: {objective}\n")
        results = filter_page(html_content, objective)
        for k, top_k_labels in results:
            output.write(f"k={k}: {top_k_labels}\n")
        output.write('=' * 50 + '\n')

  0%|          | 0/96 [00:00<?, ?it/s]