In [1]:
import json
from copy import deepcopy
import random
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util
from convlab.util import load_ontology

In [31]:
def find_coreference(state):
    value2slot = {}
    for domain in state:
        for slot in state[domain]:
            value = state[domain][slot]
            value2slot.setdefault(value, [])
            value2slot[value].append(f'{domain}-{slot}')
    return {value: value2slot[value] for value in value2slot if len(set([d_s.split('-')[0] for d_s in value2slot[value]])) > 1}

In [75]:
multi_domain_data = json.load(open('data/sgd/group0/multi_domain.json'))

In [76]:
slot_coref = {}
for dial in multi_domain_data:
    slot_coref_dial = set()
    for turn in dial['turns']:
        if 'state' in turn:
            value2slot = find_coreference(turn['state'])
            for value, ds_list in value2slot.items():
                ds_list = tuple(sorted(ds_list))
                t = (value, ds_list)
                if t not in slot_coref_dial:
                    slot_coref_dial.add(t)
    for value, ds_list in slot_coref_dial:
        slot_coref.setdefault(ds_list, [])
        slot_coref[ds_list].append(value)

In [68]:
len(slot_coref)

93

In [69]:
{slot: len(value) for slot, value in slot_coref.items() if len(value)>1}

{('attraction-area', 'hotel-area'): 227,
 ('hotel-book day', 'train-day'): 371,
 ('hotel-book people', 'train-book people'): 324,
 ('hotel-book people', 'hotel-book stay', 'train-book people'): 43,
 ('hotel-area', 'restaurant-area'): 349,
 ('restaurant-book day', 'train-day'): 486,
 ('hotel-stars', 'train-book people'): 22,
 ('hotel-book people', 'hotel-stars', 'train-book people'): 11,
 ('restaurant-book people', 'train-book people'): 390,
 ('attraction-area', 'restaurant-area'): 457,
 ('hotel-stars', 'restaurant-book people'): 22,
 ('hotel-price range', 'restaurant-price range'): 288,
 ('hotel-book people', 'restaurant-book people'): 356,
 ('hotel-book day', 'restaurant-book day'): 416,
 ('restaurant-book time', 'taxi-arrive by'): 533,
 ('restaurant-name', 'taxi-destination'): 519,
 ('attraction-name', 'taxi-departure'): 343,
 ('hotel-name', 'taxi-departure'): 454,
 ('attraction-name', 'taxi-destination'): 224,
 ('hotel-name', 'taxi-destination'): 213,
 ('restaurant-name', 'taxi-depa

In [4]:
def get_state_update(prev_state, cur_state):
    state = deepcopy(cur_state)
    for domain in prev_state:
        state.setdefault(domain, {})
        for slot in prev_state[domain]:
            if slot not in state[domain]:
                state[domain][slot] = ''
            elif prev_state[domain][slot] == state[domain][slot]:
                state[domain].pop(slot)
        if len(state[domain]) == 0:
            state.pop(domain)
    return state

In [2]:
single_domain_data = json.load(open('data/multiwoz21/single_domain.json'))

In [5]:
random.seed(42)
domain2slot2value = {}
for dial in single_domain_data:
    prev_state = {}
    for turn in dial['turns']:
        if 'state' in turn:
            state_update = get_state_update(prev_state, turn['state'])
            for domain in state_update:
                domain2slot2value.setdefault(domain, {})
                for slot, value in state_update[domain].items():
                    domain2slot2value[domain].setdefault(slot, [])
                    domain2slot2value[domain][slot].append(value)
            prev_state = turn['state']
num_sample_value = 10
for domain in domain2slot2value:
    for slot, value_set in domain2slot2value[domain].items():
        domain2slot2value[domain][slot] = random.sample(value_set, min(num_sample_value, len(value_set)))

In [9]:
domains = sorted(list(domain2slot2value.keys()))
df = pd.DataFrame([],index=domains,columns=domains)
df

Unnamed: 0,Alarm_1,Banks_1,Buses_3,Calendar_1,Events_3,Flights_4,Homes_2,Hotels_4,Media_3,Movies_1,Music_3,Payment_1,RentalCars_3,Restaurants_1,RideSharing_2,Services_1,Trains_1,Travel_1,Weather_1
Alarm_1,,,,,,,,,,,,,,,,,,,
Banks_1,,,,,,,,,,,,,,,,,,,
Buses_3,,,,,,,,,,,,,,,,,,,
Calendar_1,,,,,,,,,,,,,,,,,,,
Events_3,,,,,,,,,,,,,,,,,,,
Flights_4,,,,,,,,,,,,,,,,,,,
Homes_2,,,,,,,,,,,,,,,,,,,
Hotels_4,,,,,,,,,,,,,,,,,,,
Media_3,,,,,,,,,,,,,,,,,,,
Movies_1,,,,,,,,,,,,,,,,,,,


In [11]:
ontology = load_ontology('sgd')

In [14]:
model = SentenceTransformer('/zhangpai23/zhuqi/pre-trained-models/all-mpnet-base-v2')

In [63]:
domain2slot2embed = {}
for domain in domains:
    domain2slot2embed[domain] = {}
    for slot in ontology['state'][domain]:
        desc = ontology['domains'][domain]['slots'][slot]['description']
        embed = model.encode(desc)
        domain2slot2embed[domain][slot] = embed

In [64]:
slot_sims = []
for i in range(len(domains)):
    embed_i = domain2slot2embed[domains[i]]
    for j in range(i+1, len(domains)):
        embed_j = domain2slot2embed[domains[j]]
        sim_mat = {}
        for slot_i in embed_i:
            for slot_j in embed_j:
                slot_pair = f'{domains[i]}-{slot_i}@{domains[j]}-{slot_j}'
                sim_score = util.cos_sim(embed_i[slot_i], embed_j[slot_j]).item()
                sim_mat[f'{slot_i}-{slot_j}'] = sim_score
                slot_sims.append((slot_pair, sim_score))
        df.iloc[i,j] = sorted(sim_mat.items(),key=lambda x: x[1],reverse=True)


In [37]:
df.iloc[3,3] = pd.NA
df

Unnamed: 0,Alarm_1,Banks_1,Buses_3,Calendar_1,Events_3,Flights_4,Homes_2,Hotels_4,Media_3,Movies_1,Music_3,Payment_1,RentalCars_3,Restaurants_1,RideSharing_2,Services_1,Trains_1,Travel_1,Weather_1
Alarm_1,,"[(alarm_name-recipient_account_name, [tensor([...","[(alarm_time-departure_time, [tensor([0.5261])...","[(alarm_time-event_time, [tensor([0.5312])]), ...","[(alarm_time-time, [tensor([0.5312])]), (alarm...","[(alarm_time-outbound_departure_time, [tensor(...","[(alarm_name-property_name, [tensor([0.3040])]...","[(alarm_name-place_name, [tensor([0.3084])]), ...","[(alarm_name-title, [tensor([0.3001])]), (new_...","[(alarm_time-show_time, [tensor([0.5145])]), (...","[(alarm_name-track, [tensor([0.3566])]), (alar...","[(new_alarm_name-receiver, [tensor([0.2461])])...","[(new_alarm_time-pickup_time, [tensor([0.4137]...","[(alarm_name-restaurant_name, [tensor([0.3168]...","[(alarm_time-wait_time, [tensor([0.3057])]), (...","[(alarm_time-appointment_time, [tensor([0.4614...","[(alarm_time-journey_start_time, [tensor([0.48...","[(new_alarm_name-attraction_name, [tensor([0.3...","[(alarm_name-city, [tensor([0.2700])]), (alarm..."
Banks_1,,,"[(amount-num_passengers, [tensor([0.3327])]), ...","[(recipient_account_name-event_name, [tensor([...","[(recipient_account_name-venue, [tensor([0.242...","[(amount-price, [tensor([0.3296])]), (recipien...","[(recipient_account_name-address, [tensor([0.2...","[(recipient_account_name-street_address, [tens...","[(recipient_account_type-genre, [tensor([0.258...","[(amount-number_of_tickets, [tensor([0.2898])]...","[(recipient_account_name-artist, [tensor([0.29...","[(recipient_account_name-receiver, [tensor([0....","[(recipient_account_type-car_type, [tensor([0....","[(recipient_account_name-street_address, [tens...","[(recipient_account_name-destination, [tensor(...","[(recipient_account_name-street_address, [tens...","[(amount-total, [tensor([0.2734])]), (recipien...","[(recipient_account_type-category, [tensor([0....","[(recipient_account_name-city, [tensor([0.2633..."
Buses_3,,,,"[(departure_time-event_time, [tensor([0.5919])...","[(price-price_per_ticket, [tensor([0.7611])]),...","[(num_passengers-number_of_tickets, [tensor([0...","[(departure_date-visit_date, [tensor([0.5183])...","[(to_city-location, [tensor([0.5785])]), (depa...","[(to_city-title, [tensor([0.1951])]), (departu...","[(price-price, [tensor([0.8033])]), (num_passe...","[(to_station-device, [tensor([0.2964])]), (fro...","[(num_passengers-amount, [tensor([0.3776])]), ...","[(to_city-city, [tensor([0.4570])]), (departur...","[(departure_date-date, [tensor([0.4911])]), (n...","[(to_city-destination, [tensor([0.5701])]), (p...","[(departure_date-appointment_date, [tensor([0....","[(to_station-from_station, [tensor([0.8183])])...","[(to_city-location, [tensor([0.5863])]), (from...","[(to_city-city, [tensor([0.5500])]), (from_cit..."
Calendar_1,,,,,"[(event_time-time, [tensor([1.])]), (event_loc...","[(event_time-departure_date, [tensor([0.6173])...","[(event_date-visit_date, [tensor([0.5202])]), ...","[(event_date-check_in_date, [tensor([0.6681])]...","[(event_name-title, [tensor([0.4413])]), (even...","[(event_time-show_time, [tensor([0.6069])]), (...","[(event_name-artist, [tensor([0.4081])]), (eve...","[(event_name-receiver, [tensor([0.2097])]), (e...","[(event_date-end_date, [tensor([0.3158])]), (e...","[(event_date-date, [tensor([0.7632])]), (event...","[(event_location-destination, [tensor([0.3419]...","[(event_date-appointment_date, [tensor([0.5968...","[(event_time-journey_start_time, [tensor([0.61...","[(event_location-location, [tensor([0.5171])])...","[(event_date-date, [tensor([0.4321])]), (event..."
Events_3,,,,,,"[(price_per_ticket-price, [tensor([0.7486])]),...","[(venue_address-address, [tensor([0.6441])]), ...","[(number_of_tickets-number_of_rooms, [tensor([...","[(event_name-title, [tensor([0.4225])]), (even...","[(price_per_ticket-price, [tensor([0.9238])]),...","[(event_name-artist, [tensor([0.5871])]), (eve...","[(price_per_ticket-amount, [tensor([0.2941])])...","[(price_per_ticket-price_per_day, [tensor([0.3...","[(number_of_tickets-party_size, [tensor([0.685...","[(number_of_tickets-number_of_seats, [tensor([...","[(date-appointment_date, [tensor([0.5976])]), ...","[(number_of_tickets-number_of_adults, [tensor(...","[(city-location, [tensor([0.6118])]), (venue_a...","[(city-city, [tensor([0.5279])]), (date-date, ..."
Flights_4,,,,,,,"[(departure_date-visit_date, [tensor([0.6228])...","[(departure_date-check_in_date, [tensor([0.564...","[(seating_class-genre, [tensor([0.2168])]), (a...","[(price-price, [tensor([0.6755])]), (number_of...","[(departure_date-year, [tensor([0.2793])]), (d...","[(price-amount, [tensor([0.3588])]), (number_o...","[(return_date-end_date, [tensor([0.4634])]), (...","[(departure_date-date, [tensor([0.5904])]), (r...","[(price-ride_fare, [tensor([0.5492])]), (desti...","[(departure_date-appointment_date, [tensor([0....","[(destination_airport-from_station, [tensor([0...","[(destination_airport-location, [tensor([0.412...","[(destination_airport-city, [tensor([0.4895])]..."
Homes_2,,,,,,,,"[(address-street_address, [tensor([0.6800])]),...","[(property_name-title, [tensor([0.2757])]), (p...","[(area-location, [tensor([0.4850])]), (address...","[(property_name-artist, [tensor([0.2492])]), (...","[(address-receiver, [tensor([0.3230])]), (prop...","[(area-city, [tensor([0.4658])]), (visit_date-...","[(phone_number-phone_number, [tensor([0.5414])...","[(address-destination, [tensor([0.3640])]), (a...","[(visit_date-appointment_date, [tensor([0.5784...","[(visit_date-date_of_journey, [tensor([0.4656]...","[(area-location, [tensor([0.5410])]), (phone_n...","[(area-city, [tensor([0.5535])]), (property_na..."
Hotels_4,,,,,,,,,"[(place_name-title, [tensor([0.2237])]), (smok...","[(street_address-street_address, [tensor([0.54...","[(place_name-artist, [tensor([0.2453])]), (che...","[(street_address-receiver, [tensor([0.3123])])...","[(location-city, [tensor([0.5493])]), (price_p...","[(check_in_date-date, [tensor([0.8381])]), (ch...","[(number_of_rooms-number_of_seats, [tensor([0....","[(check_in_date-appointment_date, [tensor([0.6...","[(number_of_rooms-number_of_adults, [tensor([0...","[(location-location, [tensor([0.6436])]), (pho...","[(location-city, [tensor([0.5304])]), (street_..."
Media_3,,,,,,,,,,"[(title-movie_name, [tensor([0.8592])]), (titl...","[(title-track, [tensor([0.4742])]), (title-art...","[(genre-payment_method, [tensor([0.2725])]), (...","[(title-car_name, [tensor([0.2346])]), (genre-...","[(title-restaurant_name, [tensor([0.3741])]), ...","[(genre-ride_type, [tensor([0.2577])]), (genre...","[(title-stylist_name, [tensor([0.2406])]), (ti...","[(genre-class, [tensor([0.1590])]), (title-to_...","[(genre-category, [tensor([0.5094])]), (genre-...","[(title-city, [tensor([0.2877])]), (subtitle_l..."
Movies_1,,,,,,,,,,,"[(movie_name-track, [tensor([0.5609])]), (genr...","[(number_of_tickets-amount, [tensor([0.3567])]...","[(price-price_per_day, [tensor([0.4164])]), (l...","[(location-city, [tensor([0.5669])]), (street_...","[(number_of_tickets-number_of_seats, [tensor([...","[(location-city, [tensor([0.5239])]), (show_ti...","[(price-total, [tensor([0.6083])]), (number_of...","[(location-location, [tensor([0.6035])]), (str...","[(location-city, [tensor([0.5398])]), (theater..."


In [38]:
df.to_csv('slot_desc_sim.csv')

In [57]:
domain2slot2embed = {}
for domain in domains:
    domain2slot2embed[domain] = {}
    for slot in domain2slot2value[domain]:
        values = domain2slot2value[domain][slot]
        embeds = model.encode(values)
        embed = np.mean(embeds,axis=0)
        domain2slot2embed[domain][slot] = embed

slot_sims = []
for i in range(len(domains)):
    embed_i = domain2slot2embed[domains[i]]
    for j in range(i+1, len(domains)):
        embed_j = domain2slot2embed[domains[j]]
        sim_mat = {}
        for slot_i in embed_i:
            for slot_j in embed_j:
                slot_pair = f'{domains[i]}-{slot_i}@{domains[j]}-{slot_j}'
                sim_score = util.cos_sim(embed_i[slot_i], embed_j[slot_j]).item()
                sim_mat[f'{slot_i}-{slot_j}'] = sim_score
                slot_sims.append((slot_pair, sim_score))
        df.iloc[i,j] = sorted(sim_mat.items(),key=lambda x: x[1],reverse=True)


In [48]:
df.to_csv('value_set_sim.csv')

In [77]:
for slot, value in slot_coref.items():
    for s in slot:
        if 'True' in value:
        # if 'Buses_3' in s:
            print(slot, len(value))

('Hotels_4-smoking_allowed', 'RentalCars_3-add_insurance') 3
('Hotels_4-smoking_allowed', 'RentalCars_3-add_insurance') 3
('Payment_1-private_visibility', 'Trains_1-trip_protection') 7
('Payment_1-private_visibility', 'Trains_1-trip_protection') 7
('Buses_3-additional_luggage', 'RentalCars_3-add_insurance') 38
('Buses_3-additional_luggage', 'RentalCars_3-add_insurance') 38
('Buses_3-additional_luggage', 'Travel_1-free_entry') 3
('Buses_3-additional_luggage', 'Travel_1-free_entry') 3
('Buses_3-additional_luggage', 'Travel_1-free_entry', 'Travel_1-good_for_kids') 3
('Buses_3-additional_luggage', 'Travel_1-free_entry', 'Travel_1-good_for_kids') 3
('Buses_3-additional_luggage', 'Travel_1-free_entry', 'Travel_1-good_for_kids') 3
('Hotels_4-smoking_allowed', 'Travel_1-good_for_kids') 6
('Hotels_4-smoking_allowed', 'Travel_1-good_for_kids') 6
('Buses_3-additional_luggage', 'Travel_1-good_for_kids') 2
('Buses_3-additional_luggage', 'Travel_1-good_for_kids') 2
('Hotels_4-smoking_allowed', 'Trav

In [60]:
slot_sims = sorted(slot_sims, key=lambda x:x[1],reverse=True)
pd.DataFrame(slot_sims).to_csv('value_set_sim_all.csv')

In [65]:
slot_sims = sorted(slot_sims, key=lambda x:x[1],reverse=True)
pd.DataFrame(slot_sims).to_csv('slot_desc_sim_all.csv')

In [71]:
list(df.columns)

['Alarm_1',
 'Banks_1',
 'Buses_3',
 'Calendar_1',
 'Events_3',
 'Flights_4',
 'Homes_2',
 'Hotels_4',
 'Media_3',
 'Movies_1',
 'Music_3',
 'Payment_1',
 'RentalCars_3',
 'Restaurants_1',
 'RideSharing_2',
 'Services_1',
 'Trains_1',
 'Travel_1',
 'Weather_1']