In [2]:
import pandas as pd
import gc
import math
import torch
from transformers import DebertaTokenizer, DebertaForSequenceClassification

In [4]:
!tar -xvf  'checkpoint-4000.tar'

deberta_results/checkpoint-4000/
deberta_results/checkpoint-4000/model.safetensors
deberta_results/checkpoint-4000/config.json
deberta_results/checkpoint-4000/trainer_state.json


In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = DebertaTokenizer.from_pretrained('microsoft/deberta-base')
model = DebertaForSequenceClassification.from_pretrained('deberta_results/checkpoint-4000').to(device)

def tokenize_function(text):
    return tokenizer(text, padding='max_length', return_tensors='pt', truncation=True)

In [None]:
file_path = 'mind2web.csv'
df = pd.read_csv(file_path)

with torch.no_grad():
    for i in range(0, math.floor(len(df)*0.2)):
        ex = df.iloc[len(df)-i-1]
        elements = [el.strip() for el in ex['OBSERVATION'].split('\n')]
        action_string = ex['ACTION']
        objective = ex['OBJECTIVE']
        prompts = [f'Objective: {objective}.\nElement: {element}' for element in elements]
        positive_logits = None
        print(len(df)-i-1, objective, "; Action: ", action_string)
        for j in range(len(prompts)):
            ex = tokenize_function(prompts[j]).to(device)
            out = model(**ex)
            cur = out.logits[:, 1]
            positive_logits = torch.cat((positive_logits, cur)) if positive_logits is not None else cur
            del cur, out, ex
            gc.collect() 
            torch.cuda.empty_cache()
        for k in [1, 5, 10, 50]:
            top_k_indices = sorted(range(len(positive_logits)), key=lambda i: positive_logits[i], reverse=True)[:k]
            top_k_labels = [elements[i] for i in top_k_indices]
            print("k=", k, ": ", top_k_labels)

6914 Show me the most expensive things to do anywhere in December in a map ; Action:  click [346]
k= 1 :  ["[339] StaticText 'Anywhere'"]
k= 5 :  ["[339] StaticText 'Anywhere'", "[297] combobox '' hasPopup: menu expanded: False", "[296] combobox '' hasPopup: menu expanded: False", "[346] button 'View ListMap'", "[343] StaticText 'All Dates'"]
k= 10 :  ["[339] StaticText 'Anywhere'", "[297] combobox '' hasPopup: menu expanded: False", "[296] combobox '' hasPopup: menu expanded: False", "[346] button 'View ListMap'", "[343] StaticText 'All Dates'", "[444] StaticText 'All Dates'", "[423] button 'Sort by Price: High to Low'", "[263] button 'Filter: ONFilter'", "[293] textbox 'Anytime...' required: False", "[298] button 'Search Deals'"]
k= 50 :  ["[339] StaticText 'Anywhere'", "[297] combobox '' hasPopup: menu expanded: False", "[296] combobox '' hasPopup: menu expanded: False", "[346] button 'View ListMap'", "[343] StaticText 'All Dates'", "[444] StaticText 'All Dates'", "[423] button 'Sor

In [5]:
objective = 'Show most expensive cruise deals in Europe and Mediterranean.'
# TODO: Replace with WebArena elements?
elements = [
    '[180] Search bar',
    '[183] Login button',
    '[190] Image',
]
prompts = [f'Objective: {objective}.\nElement: {element}' for element in elements]

In [9]:
# TODO: Batch data to prevent OOM
k = 2
with torch.no_grad():
    ex = tokenize_function(prompts).to(device)
    print(ex)
    out = model(**ex)
    positive_logits = out.logits[:, 1]
    print(positive_logits)
    top_k_indices = sorted(range(len(positive_logits)), key=lambda i: positive_logits[i], reverse=True)[:k]
    print(top_k_indices)
    top_k_labels = [prompts[i] for i in top_k_indices]
    print(top_k_labels)

{'input_ids': tensor([[    1, 46674,  2088,  ...,     0,     0,     0],
        [    1, 46674,  2088,  ...,     0,     0,     0],
        [    1, 46674,  2088,  ...,     0,     0,     0]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}
tensor([-0.4908, -2.5371, -2.7766], device='cuda:0')
[0, 1]
['Objective: Show most expensive cruise deals in Europe and Mediterranean..\nElement: [180] Search bar', 'Objective: Show most expensive cruise deals in Europe and Mediterranean..\nElement: [183] Login button']


In [32]:
# TODO: Take the top-k elements and keep them
positive_logits

tensor([-0.4908, -2.5371, -2.7766], device='cuda:0')