In [1]:
from convlab.util import load_dataset
from pprint import pprint
import queue
from copy import deepcopy
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_statistics(dataset_name, min_count=10):
    """Count dialogues per domain combination, split by data split.

    Dialogues whose domains include 'police' or 'hospital' are skipped, and the
    'general' domain is ignored when forming a combination.

    Args:
        dataset_name: name handed to ``load_dataset``.
        min_count: combinations with fewer dialogues than this (summed over
            all splits) are left out of the result.

    Returns:
        A list of dicts, one per kept combination, with the keys 'domains'
        (sorted tuple of domain names), 'all' (total number of dialogues) and
        one key per data split. Rows are ordered by number of domains first,
        then by total count.
    """
    data = load_dataset(dataset_name)
    splits = list(data)
    domain_cnt = {}
    for data_split in splits:
        for dial in data[data_split]:
            if 'police' in dial['domains'] or 'hospital' in dial['domains']:
                continue
            domains = tuple(sorted(set(dial['domains']) - {'general'}))
            # Counters are keyed by the splits actually present in `data`, so a
            # dataset with differently named splits does not raise KeyError.
            split_cnt = domain_cnt.setdefault(domains, dict.fromkeys(splits, 0))
            split_cnt[data_split] += 1

    table = []
    # Sort on a (number of domains, total) tuple so that a large total can
    # never spill over into the domain-count ordering.
    ordered = sorted(domain_cnt.items(), key=lambda x: (len(x[0]), sum(x[1].values())))
    for domains, stat in ordered:
        total = sum(stat.values())
        if total < min_count:
            continue
        res = {'domains': domains, 'all': total}
        for data_split in splits:
            res[data_split] = stat[data_split]
        table.append(res)
    return table

In [3]:
# Dialogue counts per domain combination for the 'sgd' dataset.
table = get_statistics(dataset_name='sgd')

In [4]:
# Split the statistics rows in two:
#   domains      -> {domain prefix: {service name: dialogue count}} for rows with one service
#   multi_domain -> {tuple of service names: dialogue count} for every other row
domains = {}
multi_domain = {}
for res in table:
    if len(res['domains']) != 1:
        multi_domain[tuple(res['domains'])] = res['all']
        continue
    service = res['domains'][0]
    # The domain is the part of the service name before the first underscore.
    domain = service.partition('_')[0]
    domains.setdefault(domain, {})[service] = res['all']
pprint(domains)

{'Alarm': {'Alarm_1': 84},
 'Banks': {'Banks_1': 207, 'Banks_2': 42},
 'Buses': {'Buses_1': 195, 'Buses_2': 159, 'Buses_3': 88},
 'Calendar': {'Calendar_1': 169},
 'Events': {'Events_1': 289, 'Events_2': 572, 'Events_3': 76},
 'Flights': {'Flights_1': 800,
             'Flights_2': 185,
             'Flights_3': 94,
             'Flights_4': 87},
 'Homes': {'Homes_1': 349, 'Homes_2': 89},
 'Hotels': {'Hotels_1': 142, 'Hotels_2': 304, 'Hotels_3': 129, 'Hotels_4': 115},
 'Media': {'Media_1': 281, 'Media_2': 46, 'Media_3': 80},
 'Movies': {'Movies_1': 376, 'Movies_2': 47, 'Movies_3': 48},
 'Music': {'Music_1': 98, 'Music_2': 331, 'Music_3': 25},
 'Payment': {'Payment_1': 36},
 'RentalCars': {'RentalCars_1': 143, 'RentalCars_2': 111, 'RentalCars_3': 64},
 'Restaurants': {'Restaurants_1': 367, 'Restaurants_2': 146},
 'RideSharing': {'RideSharing_1': 106, 'RideSharing_2': 92},
 'Services': {'Services_1': 265,
              'Services_2': 185,
              'Services_3': 188,
              'Se

In [5]:
def get_multi_domain(combination):
    """Measure how much multi-service data a set of services covers.

    A key of the module-level ``multi_domain`` mapping counts as covered when
    every service in it is contained in ``combination``.

    Returns:
        A ``(covered_keys, covered_dialogues)`` tuple: the number of covered
        keys and the sum of their dialogue counts.
    """
    covered_counts = [
        count
        for group, count in multi_domain.items()
        if all(name in combination for name in group)
    ]
    return len(covered_counts), sum(covered_counts)

def get_single_domain(combination):
    """Add up the single-service dialogue counts of the services in ``combination``.

    Each service is looked up in the module-level ``domains`` mapping under its
    domain prefix (the text before the first underscore).
    """
    return sum(domains[name.split('_')[0]][name] for name in combination)
    
def remove_combination(all_services, combination):
    """Return a copy of ``all_services`` with the services of ``combination`` taken out.

    From every service list holding more than one entry, the first entry that
    also appears in ``combination`` is dropped. Lists with a single entry are
    kept unchanged. The input is deep-copied, so it is never modified.
    """
    pruned = deepcopy(all_services)
    for group in pruned:
        if len(group) <= 1:
            continue
        for position, name in enumerate(group):
            if name in combination:
                # Safe to delete while iterating: the loop stops right after.
                del group[position]
                break
    return pruned

In [6]:
all_combinations = []
def dfs(i, services, combination, min_comb=42):
    """Enumerate every pick of one service per domain and record the well-covered ones.

    Args:
        i: index of the entry of ``services`` to choose from next.
        services: list of per-domain lists of service names.
        combination: set of the services chosen so far. It is mutated during
            the search and restored to its initial content before returning.
        min_comb: a complete pick is recorded only when the first value
            returned by ``get_multi_domain`` is at least this large.

    Complete picks that pass the threshold are appended to the module-level
    ``all_combinations`` list as
    ``(services, num_comb, num_multi, num_single)`` tuples.
    """
    if i == len(services):
        num_comb, num_multi = get_multi_domain(combination)
        if num_comb >= min_comb:
            num_single = get_single_domain(combination)
            # Store a snapshot: `combination` keeps changing while backtracking.
            all_combinations.append((set(combination), num_comb, num_multi, num_single))
        return
    for service in services[i]:
        combination.add(service)
        dfs(i + 1, services, combination, min_comb)
        combination.remove(service)

In [7]:
# One list of service names per domain, in the order the domains were collected.
all_services = [list(services) for services in domains.values()]
# dfs() appends to the shared all_combinations list; empty it in place first so
# that re-running this cell does not pile up duplicate entries.
all_combinations.clear()
dfs(0, all_services, set())
len(all_combinations)

7577

In [8]:
# Every combination contains the service of each single-service domain, so two
# combinations can never share fewer services than there are such domains.
# Pairs with exactly this overlap are disjoint everywhere they can be.
min_overlap = sum(1 for services in all_services if len(services) == 1)

# Rank the maximally disjoint pairs by how balanced they are: the priority is
# the absolute difference in multi-service counts plus the absolute difference
# in single-service counts (smaller means more balanced).
q = queue.PriorityQueue()
for i in tqdm(range(len(all_combinations))):
    comb1, num_comb1, num_multi1, num_single1 = all_combinations[i]
    for j in range(i + 1, len(all_combinations)):
        comb2, num_comb2, num_multi2, num_single2 = all_combinations[j]
        if len(comb1 & comb2) == min_overlap:
            prior = abs(num_multi1 - num_multi2) + abs(num_single1 - num_single2)
            q.put((prior, deepcopy(comb1), deepcopy(comb2),
                   num_comb1, num_multi1, num_single1,
                   num_comb2, num_multi2, num_single2))
q.qsize()

100%|██████████| 7577/7577 [00:16<00:00, 452.17it/s] 


942

In [9]:
# Pop the entry with the smallest priority. get_nowait() raises queue.Empty when
# the queue has no entries, whereas a plain get() would block the kernel forever.
# Each run of this cell removes one entry, so re-running it yields the next pair.
res = q.get_nowait()
res

(376,
 {'Alarm_1',
  'Banks_1',
  'Buses_3',
  'Calendar_1',
  'Events_3',
  'Flights_4',
  'Homes_2',
  'Hotels_4',
  'Media_3',
  'Movies_1',
  'Music_3',
  'Payment_1',
  'RentalCars_3',
  'Restaurants_1',
  'RideSharing_2',
  'Services_1',
  'Trains_1',
  'Travel_1',
  'Weather_1'},
 {'Alarm_1',
  'Banks_2',
  'Buses_1',
  'Calendar_1',
  'Events_2',
  'Flights_3',
  'Homes_1',
  'Hotels_1',
  'Media_2',
  'Movies_3',
  'Music_1',
  'Payment_1',
  'RentalCars_1',
  'Restaurants_2',
  'RideSharing_1',
  'Services_4',
  'Trains_1',
  'Travel_1',
  'Weather_1'},
 42,
 3387,
 2456,
 42,
 3185,
 2630)