In [1]:
import os
import sys

sys.path.append('..')

import pandas as pd
import pytorch_lightning as pl
import transformers
from datasets import load_from_disk
from numpy.random import default_rng

import transformer_modules
from util import load_config, get_class_object
# Notebook-wide NumPy generator. No seed is passed, so draws made through it
# differ from one kernel session to the next.
rng = default_rng()

In [14]:
# Identify the training run whose checkpoint and config this notebook inspects.
model_name = 't5-small-SGD'
exp_name = 'schema-guided'
ver_name = 'train-batch256-adafactor'
config_file = os.path.join('../models', model_name, 'logs', exp_name, ver_name, 'config.yaml')
# Later cells index into `config` (module name, tokenizer and linearizer settings),
# so it must be loaded here for the notebook to run top-to-bottom on a fresh kernel.
config = load_config(config_file)
# Annotated test output, relative to ../models/<model_name>/logs/<exp_name>/.
test_name = 'test-batch256-adafactor+beam+max_length/annotated_cases.csv'

In [30]:
def make_base_dir(model_name, exp_name, ver_name):
    """Return the log directory of one run: ``../models/<model>/logs/<exp>/<ver>``.

    The path is always assembled with forward slashes, regardless of platform.
    """
    segments = ("..", "models", format(model_name), "logs", format(exp_name), format(ver_name))
    return "/".join(segments)

def get_dataset_path(tokenizer_name, linearizer_name, split='test'):
    """Return the on-disk location of a preprocessed dataset split.

    The directory lives under ``../data`` and is named
    ``GEMSGD_<tokenizer_name><linearizer_name>_<split>``.
    """
    folder_name = 'GEMSGD_{}{}_{}'.format(tokenizer_name, linearizer_name, split)
    return os.path.join('../data', folder_name)

def sample_from_dataset(dataset_path, chosen_idx=None, seed=20, size=10):
    """Load a saved dataset and return either hand-picked rows or a random sample.

    Args:
        dataset_path: Directory handed to ``load_from_disk``.
        chosen_idx: Optional collection of row indices. When it is non-empty
            those rows are returned and ``seed`` / ``size`` are ignored.
        seed: Seed for the sampling generator; the same seed always selects
            the same rows.
        size: Number of distinct rows to draw when ``chosen_idx`` is empty.

    Returns:
        Whatever indexing the loaded dataset with a collection of row indices
        yields, after ``set_format('torch', columns=['input_ids', 'labels'])``
        has been applied.

    Raises:
        ValueError: Raised by NumPy when ``size`` exceeds the number of rows,
            because sampling is done without replacement.
    """
    dataset = load_from_disk(dataset_path)
    dataset.set_format('torch', columns=['input_ids', 'labels'])

    if chosen_idx is not None and len(chosen_idx) > 0:
        return dataset[chosen_idx]

    # The earlier version called ``rng.choice(seed, ...)``, which treated the
    # seed as the population size: only rows 0..seed-1 could ever be drawn and
    # the draw was not reproducible. Seed a dedicated generator and sample
    # from the whole dataset instead.
    sampled_idx = default_rng(seed).choice(len(dataset), size=size, replace=False)
    return dataset[sampled_idx.tolist()]

def get_checkpoint_path(model_name, exp_name, ver_name, epoch_num=7, step_num=5139, checkpoint_name=None):
    """Return the path of a checkpoint file inside a run's ``checkpoints`` folder.

    When ``checkpoint_name`` is empty or ``None`` the file name falls back to
    ``epoch=<epoch_num>-step=<step_num>.ckpt``.
    """
    run_dir = make_base_dir(model_name, exp_name, ver_name)
    if checkpoint_name:
        file_name = checkpoint_name
    else:
        file_name = f'epoch={epoch_num}-step={step_num}.ckpt'
    return os.path.join(run_dir, 'checkpoints', file_name)

def load_model(model_class_obj, checkpoint_path):
    """Restore a model by calling ``model_class_obj.load_from_checkpoint``.

    ``checkpoint_path`` is forwarded as a keyword argument of the same name.
    """
    return model_class_obj.load_from_checkpoint(checkpoint_path=checkpoint_path)

# TODO: add a `load_tokenizer(tokenizer_class)` helper. The stub below was left unfinished:
# def load_tokenizer(tokenizer_class):
#     if tokenizer

def forward_model(model, dataset):
    """Call ``model.forward`` on the ``'input_ids'`` entry of ``dataset`` and return the result."""
    return model.forward(dataset['input_ids'])

def batch_decode(texts, tokenizer):
    """Decode ``texts`` with ``tokenizer.batch_decode``, dropping special tokens."""
    return tokenizer.batch_decode(texts, skip_special_tokens=True)

def construct_df(model, tokenizer, dataset):
    """Build a comparison table of decoded inputs, references and predictions.

    ``dataset['input_ids']`` and ``dataset['labels']`` are decoded with the
    tokenizer (special tokens skipped). The model's ``forward`` output for the
    input ids is decoded the same way. The result has the columns ``input``,
    ``target`` and ``pred``, in that order.
    """
    def decode(token_ids):
        return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    source_text = decode(dataset['input_ids'])
    reference_text = decode(dataset['labels'])
    generated = model.forward(dataset['input_ids'])
    generated_text = decode(generated)

    table = pd.DataFrame()
    for column, values in (('input', source_text), ('target', reference_text), ('pred', generated_text)):
        table[column] = values
    return table

def predict_with_model(model_name, exp_name, ver_name, epoch_num=7, step_num=5139, checkpoint_name=None, chosen_idx=[], seed=20, size=10):
    """Load a trained run and return its predictions for a few test rows.

    The run's ``config.yaml`` names the LightningModule class, the tokenizer
    class and path, and the linearizer class. Those values locate the
    checkpoint and the test dataset. ``chosen_idx``, ``seed`` and ``size`` are
    forwarded unchanged to ``sample_from_dataset``.

    Returns:
        The DataFrame produced by ``construct_df`` for the selected rows.
    """
    config = load_config(
        os.path.join('../models', model_name, 'logs', exp_name, ver_name, 'config.yaml')
    )
    module_cls = get_class_object(transformer_modules, config['LightningModuleName'])
    ckpt_path = get_checkpoint_path(
        model_name,
        exp_name,
        ver_name,
        epoch_num=epoch_num,
        step_num=step_num,
        checkpoint_name=checkpoint_name,
    )
    tok_name = config['LightningDataModuleParas']['tokenizer_class']
    tok_cls = get_class_object(transformers, tok_name)
    data_dir = get_dataset_path(tok_name, config['LightningDataModuleParas']['linearizer_class'])

    # Heavy work happens last: restore the model, load the tokenizer, pick rows.
    model = load_model(module_cls, ckpt_path)
    tokenizer = tok_cls.from_pretrained(config['LightningDataModuleParas']['tokenizer_path'])
    samples = sample_from_dataset(data_dir, chosen_idx=chosen_idx, seed=seed, size=size)
    return construct_df(model, tokenizer, samples)

def get_bad_case(annotated_file, by='has_slot_error', limit=20):
    """Read an annotated-cases CSV and return up to ``limit`` suspicious rows.

    Args:
        annotated_file: CSV path or buffer accepted by ``pd.read_csv``.
        by: Selection clue. ``'has_slot_error'`` keeps rows where that column
            equals ``True``. ``'PARENT-recall'`` and ``'PARENT-precision'``
            sort by that column with ``sort_values`` defaults.
        limit: Maximum number of rows returned.

    Raises:
        ValueError: If ``by`` is none of the supported clues. The CSV is read
            before this check.
    """
    cases = pd.read_csv(annotated_file)
    if by == 'has_slot_error':
        flagged = cases[by] == True
        return cases[flagged][:limit]
    if by not in {'PARENT-recall', 'PARENT-precision'}:
        raise ValueError("Invalid bad case clue!")
    ranked = cases.sort_values(by)
    return ranked[:limit]


In [31]:
base_dir = make_base_dir(model_name, exp_name, ver_name)
# The annotated test output sits under the same logs/<exp_name>/ directory as the
# training run, so derive its path from the config-cell variables instead of
# repeating the literal path (which silently goes stale when those change).
annotated_file = os.path.join(os.path.dirname(base_dir), test_name)
get_bad_case(annotated_file, by="PARENT-precision")

Unnamed: 0.1,Unnamed: 0,ref,pred,dialog_acts,domain,has_slot_error,PARENT-precision,PARENT-recall,PARENT-fscore
682,682,You can contact them at 702-678-5780.,Their phone number is 702-678-5780.,"[{'act': 4, 'slot': 'phone_number', 'values': ...",Travel,False,0.0,0.003,0.0
1242,1242,"You can call the property on 510-943-8264. No,...",It does not have a garage. The phone number i...,"[{'act': 4, 'slot': 'has_garage', 'values': ['...",Homes,False,0.0,0.003,0.0
5301,5301,It was 2018.,It came out in 2018.,"[{'act': 4, 'slot': 'year', 'values': ['2018']}]",Music,False,0.0,0.003,0.0
5795,5795,You can reach them by phone at 925-945-1221.,Their number is 925-945-1221.,"[{'act': 4, 'slot': 'phone_number', 'values': ...",Services,False,0.0,0.003,0.0
215,215,The ticket price is $45.,It costs $45 per ticket.,"[{'act': 4, 'slot': 'price_per_ticket', 'value...",Events,False,0.0,0.0,0.0
9260,9260,You can call them at 213-626-1901.,Their phone number is 213-626-1901.,"[{'act': 4, 'slot': 'phone_number', 'values': ...",Travel,False,0.0,0.003,0.0
9249,9249,Are there any particular categories of events ...,Are you interested in Music or Theater?,"[{'act': 13, 'slot': 'event_type', 'values': [...",Events,False,0.0,0.0,0.0
7503,7503,Tickets are $35.,The ticket costs $35.,"[{'act': 4, 'slot': 'price_per_ticket', 'value...",Events,False,0.0,0.003,0.0
3513,3513,The cost per day is $32.00.,It costs $32.00 per day.,"[{'act': 4, 'slot': 'price_per_day', 'values':...",RentalCars,False,0.0,0.0,0.0
3038,3038,"You can contact them at 707-575-5123, and your...",I've confirmed your appointment. The phone nu...,"[{'act': 4, 'slot': 'phone_number', 'values': ...",Services,False,0.0,0.003,0.0


In [11]:
# Run the whole load-and-predict pipeline on dataset rows 1 and 2; passing an
# explicit chosen_idx bypasses the random sampling branch.
res_df = predict_with_model(model_name, exp_name, ver_name, chosen_idx=[1,2])

In [12]:
# Display input / target / pred for the two hand-picked rows.
res_df

Unnamed: 0,input,target,pred
0,CONFIRM ( Place to pick up the car = LGB Airpo...,You are picking up a hatchback from LGB Airpor...,Please confirm: You want me to reserve a Hatch...
1,OFFER ( Address of the house = 100 Capitol mal...,"There are 10 houses available, of which there ...",There are 10 houses available. There is a nice...


In [28]:
# Step-by-step version of what predict_with_model does internally: resolve the
# LightningModule class named in the config, then restore it from the default
# checkpoint file (epoch=7-step=5139.ckpt).
model_class_object = get_class_object(transformer_modules, config['LightningModuleName'])
checkpoint_path = get_checkpoint_path(model_name=model_name, exp_name=exp_name, ver_name=ver_name)
model = load_model(model_class_obj=model_class_object, checkpoint_path=checkpoint_path)

In [45]:
# Resolve the tokenizer class named in the config and load it from the configured path.
data_module_paras = config['LightningDataModuleParas']
tokenizer_name = data_module_paras['tokenizer_class']
tokenizer_class = get_class_object(transformers, tokenizer_name)
tokenizer = tokenizer_class.from_pretrained(data_module_paras['tokenizer_path'])


In [22]:
# Linearizer identifier from the config; get_dataset_path embeds it in the dataset directory name.
linearizer_class = config['LightningDataModuleParas']['linearizer_class']


In [33]:
# get_dataset_path formats its first argument into the directory name, so it needs
# the tokenizer's name string (as predict_with_model passes), not the class object.
dataset_path = get_dataset_path(tokenizer_name, linearizer_class)
test_dataset = sample_from_dataset(dataset_path)

In [35]:
# Raw forward pass on the sampled rows. NOTE(review): `outputs` is not read by
# any later cell shown here; construct_df runs the forward pass again itself.
outputs = forward_model(model, test_dataset)

In [50]:
# Decode inputs, references and predictions for the sampled rows.
# Note: this rebinds `res_df`, replacing the frame produced by predict_with_model earlier.
res_df = construct_df(model, tokenizer, test_dataset)

In [51]:
# Display input / target / pred for the sampled rows.
res_df

Unnamed: 0,input,target,pred
0,REQ_MORE,What else can I do?,Is there anything else I can help you with?
1,INFORM ( The cost for renting the car per day ...,"Your rental has been booked, and you will pay ...",Your car has been booked. The cost of renting ...
2,INFORM ( Price per night of the house = $522 )...,Your reservation has been made. The total is $...,Your reservation is complete. The total price ...
3,REQUEST ( Type of cab ride = Luxury Regular ),Sure thing! Would you like a luxury ride or a ...,"Luxury, Regular, or something else?"
4,REQUEST ( Start date of the trip = March 1st ),Are you departing on March 1st?,Will you be leaving on March 1st?
5,OFFER ( The company that provides air transpor...,"Ok, there is a Southwest Airlines connecting f...",Southwest Airlines has a flight that leaves at...
6,OFFER ( Temperature in Fahrenheit = 83 ) OFFER...,There average temperature is 83 degrees Fahren...,The average temperature for the day should be ...
7,INFORM ( Contact number of the therapist = 510...,"Yes, the contact number they have listed is 51...",Their number is 510-797-3941.
8,OFFER ( Name of artist or play = Advanced Acti...,The Advanced Acting Scene Study at the TGW Act...,There is Advanced Acting Scene Study at TGW Ac...
9,GOODBYE,Thank you very much.,Have a great day.
