In [None]:
!pip install transformers==4.3.2
import torch
import io
import torch.nn.functional as F
import random
import numpy as np
import time
import math
import datetime
import torch.nn as nn
from transformers import *
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

Collecting transformers==4.3.2
  Downloading transformers-4.3.2-py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 4.0 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 42.8 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 42.8 MB/s 
Installing collected packages: tokenizers, sacremoses, transformers
Successfully installed sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.3.2


In [None]:
if torch.cuda.is_available():    
    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

No GPU available, using the CPU instead.


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [None]:
%cd "/content/gdrive/My Drive/"

/content/gdrive/My Drive


In [None]:
#--------------------------------
#  Transformer parameters
#--------------------------------
max_seq_length = 64
batch_size = 64

#--------------------------------
#  GAN-BERT specific parameters
#--------------------------------
# number of hidden layers in the generator, 
# each of the size of the output space
num_hidden_layers_g = 1; 
# number of hidden layers in the discriminator, 
# each of the size of the input space
num_hidden_layers_d = 1; 
# size of the generator's input noisy vectors
noise_size = 100
# dropout to be applied to discriminator's input vectors
out_dropout_rate = 0.2

# Replicate labeled data to balance poorly represented datasets, 
# e.g., less than 1% of labeled material
apply_balance = True

#--------------------------------
#  Optimization parameters
#--------------------------------
learning_rate_discriminator = 5e-5
learning_rate_generator = 5e-5
epsilon = 1e-8
num_train_epochs = 10
multi_gpu = True
# Scheduler
apply_scheduler = False
warmup_proportion = 0.1
# Print
print_each_n_step = 10

In [None]:
#------------------------------
#   The Generator as in 
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Generator(nn.Module):
    def __init__(self, noise_size=100, output_size=512, hidden_sizes=[512], dropout_rate=0.1):
        super(Generator, self).__init__()
        layers = []
        hidden_sizes = [noise_size] + hidden_sizes
        for i in range(len(hidden_sizes)-1):
            layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        layers.append(nn.Linear(hidden_sizes[-1],output_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, noise):
        output_rep = self.layers(noise)
        return output_rep

#------------------------------
#   The Discriminator
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Discriminator(nn.Module):
    def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):
        super(Discriminator, self).__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        layers = []
        hidden_sizes = [input_size] + hidden_sizes
        for i in range(len(hidden_sizes)-1):
            layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        self.layers = nn.Sequential(*layers) #per il flatten
        self.logit = nn.Linear(hidden_sizes[-1],num_labels+1) # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs

In [None]:
def generate_data_loader(input_examples, label_masks, label_map, do_shuffle = False, balance_label_examples = False):
  '''
  Generate a Dataloader given the input examples, eventually masked if they are 
  to be considered NOT labeled.
  '''
  examples = []

  # Count the percentage of labeled examples  
  num_labeled_examples = 0
  for label_mask in label_masks:
    if label_mask: 
      num_labeled_examples += 1
  label_mask_rate = num_labeled_examples/len(input_examples)

  # if required it applies the balance
  for index, ex in enumerate(input_examples): 
    if label_mask_rate == 1 or not balance_label_examples:
      examples.append((ex, label_masks[index]))
    else:
      # IT SIMULATE A LABELED EXAMPLE
      if label_masks[index]:
        balance = int(1/label_mask_rate)
        balance = int(math.log(balance,2))
        if balance < 1:
          balance = 1
        for b in range(0, int(balance)):
          examples.append((ex, label_masks[index]))
      else:
        examples.append((ex, label_masks[index]))
  
  #-----------------------------------------------
  # Generate input examples to the Transformer
  #-----------------------------------------------
  input_ids = []
  input_mask_array = []
  label_mask_array = []
  label_id_array = []

  # Tokenization 
  for (text, label_mask) in examples:
    encoded_sent = tokenizer.encode(text[0], add_special_tokens=True, max_length=max_seq_length, padding="max_length", truncation=True)
    input_ids.append(encoded_sent)
    label_id_array.append(label_map[text[1]])
    label_mask_array.append(label_mask)
  
  # Attention to token (to ignore padded input wordpieces)
  for sent in input_ids:
    att_mask = [int(token_id > 0) for token_id in sent]                          
    input_mask_array.append(att_mask)
  # Convertion to Tensor
  input_ids = torch.tensor(input_ids) 
  input_mask_array = torch.tensor(input_mask_array)
  label_id_array = torch.tensor(label_id_array, dtype=torch.long)
  label_mask_array = torch.tensor(label_mask_array)

  # Building the TensorDataset
  dataset = TensorDataset(input_ids, input_mask_array, label_id_array, label_mask_array)

  if do_shuffle:
    sampler = RandomSampler
  else:
    sampler = SequentialSampler

  # Building the DataLoader
  return DataLoader(
              dataset,  # The training samples.
              sampler = sampler(dataset), 
              batch_size = batch_size) # Trains with this batch size.


In [None]:
#Predict
label_list = ['user_department',
 'user_country',
 'payday',
 'rewards_balance',
 'todo_list_update',
 'rollover_401k',
 'w2',
 'spending_history',
 'timer',
 'find_phone',
 'whisper_mode',
 'todo_list',
 'measurement_conversion',
 'flip_coin',
 'interest_rate',
 'translate',
 'freeze_account',
 'international_visa',
 'definition',
 'maybe',
 'carry_on',
 'pin_change',
 'pto_request_status',
 'car_rental',
 'insurance',
 'jump_start',
 'transfer',
 'share_location',
 'ingredients_list',
 'min_payment',
 'card_declined',
 'pto_balance',
 'change_volume',
 'uber',
 'directions',
 'account_blocked',
 'meal_suggestion',
 'update_playlist',
 'tell_joke',
 'last_maintenance',
 'restaurant_suggestion',
 'goodbye',
 'thank_you',
 'insurance_change',
 'plug_type',
 'pto_used',
 'pay_bill',
 'improve_credit_score',
 'bill_balance',
 'food_last',
 'cancel_reservation',
 'income',
 'exchange_rate',
 'repeat',
 'change_speed',
 'weather',
 'calendar',
 'do_you_have_pets',
 'order_checks',
 'calculator',
 'traffic',
 'shopping_list',
 'next_song',
 'roll_dice',
 'change_ai_name',
 'ingredient_substitution',
 'who_do_you_work_for',
 'how_old_are_you',
 'accept_reservations',
 'travel_alert',
 'sync_device',
 'damaged_card',
 'replacement_card_duration',
 'where_are_you_from',
 'credit_limit_change',
 'what_can_i_ask_you',
 'next_holiday',
 'distance',
 'report_lost_card',
 'fun_fact',
 'what_are_your_hobbies',
 'report_fraud',
 'who_made_you',
 'international_fees',
 'redeem_rewards',
 'order_status',
 'shopping_list_update',
 'credit_score',
 'yes',
 'calories',
 'taxes',
 'how_busy',
 'reminder_update',
 'balance',
 'flight_status',
 'pto_request',
 'alarm',
 'book_flight',
 'confirm_reservation',
 'change_user_name',
 'no',
 'schedule_meeting',
 'book_hotel',
 'transactions',
 'play_music',
 'cook_time',
 'cancel',
 'gas',
 'meaning_of_life',
 'meeting_schedule',
 'restaurant_reviews',
 'tire_change',
 'oil_change_when',
 'gas_type',
 'change_accent',
 'are_you_a_bot',
 'apr',
 'tire_pressure',
 'time',
 'reset_settings',
 'mpg',
 'oil_change_how',
 'current_location',
 'greeting',
 'nutrition_info',
 'application_status',
 'new_card',
 'what_is_your_name',
 'change_language',
 'order',
 'schedule_maintenance',
 'travel_notification',
 'text',
 'travel_suggestion',
 'make_call',
 'smart_home',
 'recipe',
 'restaurant_reservation',
 'user_name',
 'bill_due',
 'what_song',
 'lost_luggage',
 'spelling',
 'routing',
 'calendar_update',
 'direct_deposit',
 'reminder',
 'credit_limit',
 'vaccines',
 'timezone',
 'expiration_date',
 'date',
 'UNK_UNK']
label_map = {}
label_reverse_map ={}
for (i, label) in enumerate(label_list):
  label_map[label] = i
  label_reverse_map[i]=label


In [None]:
transformer = AutoModel.from_pretrained("./transf2")


In [None]:
model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435797.0, style=ProgressStyle(descripti…




In [None]:
config = AutoConfig.from_pretrained(model_name)
hidden_size = int(config.hidden_size)
# Define the number and width of hidden layers
hidden_levels_g = [hidden_size for i in range(0, num_hidden_layers_g)]
hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]

In [None]:
generator2 = Generator(noise_size=noise_size, output_size=hidden_size, hidden_sizes=hidden_levels_g, dropout_rate=out_dropout_rate)
generator2.load_state_dict(torch.load("./generator2.pt",map_location=torch.device('cpu')))
generator2.eval()

Generator(
  (layers): Sequential(
    (0): Linear(in_features=100, out_features=768, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=768, out_features=768, bias=True)
  )
)

In [None]:
discriminator2 = Discriminator(input_size=hidden_size, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate)
discriminator2.load_state_dict(torch.load("./discriminator2.pt",map_location=torch.device('cpu')))
discriminator2.eval()

Discriminator(
  (input_dropout): Dropout(p=0.2, inplace=False)
  (layers): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.2, inplace=False)
  )
  (logit): Linear(in_features=768, out_features=154, bias=True)
  (softmax): Softmax(dim=-1)
)

In [None]:
test_examples = [('Which department has responded the most?', 'user_department'), ('Students from which department think pressure is the reason for cheating?', 'user_department'), ('Department with international students', 'user_department'), ('Which departments have students who think pressure causes cheating?', 'user_department'), ('Department with local students', 'user_department'), ('Department with students of age less than 30', 'user_department'), ('Department related to health', 'user_department'), ('Which department has local students?', 'user_department'), ('Which department have students of age less than 20?p', 'user_department'), ('Students from which country has responded the most?', 'user_country'), ('Students from which country think pressure is the reason for cheating?', 'user_country'), ('Which country are most international students from', 'user_country'), ('Which country has students who think pressure causes cheating?', 'user_country'), ('country with students of age less than 30', 'user_country'), ('country which is related to health', 'user_country'), ('Which country are most students from?', 'user_country'), ('Which countries have students of age less than 20?p', 'user_country'), ('i want to know when i was last paid', 'payday'), ('do you know how many rewards points are currently on my discover card', 'rewards_balance'), ('on my to do list, add exercising', 'todo_list_update'), ('help me roll over my 401k', 'rollover_401k'), ('can i get a w2 form online', 'w2'), ('how much have i spent on clothes recently', 'spending_history'), ('you need to set the timer for me', 'timer'), ('how can i find my phone', 'find_phone'), ('use whisper voice', 'whisper_mode'), ('check my domestic to do list for carpet cleaning', 'todo_list'), ('i wanna know how many teaspoons are in a tablespoon', 'measurement_conversion'), ('go ahead and flip a coin, i am calling tails', 'flip_coin'), ('what is the interest rate at chase', 'interest_rate'), ('what do you call a subway if you were english', 'translate'), ('put a freeze on my bank account', 'freeze_account'), ('does mexico require an international visa', 'international_visa'), ('what is let it be mean', 'definition'), ("i don't really know", 'maybe'), ('what is the date of my last paycheck', 'payday'), ("what's restricted in my carry-on with united", 'carry_on'), ('is there a way to get my pin number', 'pin_change'), ('has my vacation request been approved', 'pto_request_status'), ('how can i rent a car in boston', 'car_rental'), ('tell me what my health plan is called', 'insurance'), ('find repair shop that will diagnosis dead car battery', 'jump_start'), ('send 2000 dollars between chase and rabobank accounts', 'transfer'), ('can you let ben and jerry know my current location', 'share_location'), ('what ingredients do i need for philly cheesesteaks', 'ingredients_list'), ("what's the minimum required payment on my visa card", 'min_payment'), ('either or should work', 'maybe'), ('at target trying to buy a mug and my card was declined', 'card_declined'), ('how many pto days do i have for the year', 'pto_balance'), ('lower the volume', 'change_volume'), ('uber, i have 3 people who are going to union station', 'uber'), ('please tell me how much i have spent recently', 'spending_history'), ('how do i get to the closest starbucks', 'directions'), ('how do i convert four inches into centimeters', 'measurement_conversion'), ('i think my account is blocked but i do not know the reason', 'account_blocked'), ("i'd like you to give me an idea for a meal from iceland", 'meal_suggestion'), ('i would like to add you say by lauren daigle to my playlist', 'update_playlist'), ('tell me a joke about animals', 'tell_joke'), ('do you know how long its been since i had my oil changed and a tune up', 'last_maintenance'), ('which restaurants in reno are good for sushi', 'restaurant_suggestion'), ('buhbye', 'goodbye'), ('i appreciate the assistance', 'thank_you'), ('are there some good thai restaurants in san antonio', 'restaurant_suggestion'), ('my account is blocked, help me figure out why', 'account_blocked'), ('help me find new insurance providers', 'insurance_change'), ('how many kilos are in 150 pounds', 'measurement_conversion'), ('do i need a socket converter in england', 'plug_type'), ("what is the total number of vacation days i've used", 'pto_used'), ('i want to know if i shouldget a tourist visa for italy', 'international_visa'), ('schedule an uber to the bean', 'uber'), ('i need help paying my hoa bill', 'pay_bill'), ('what do i need to cook chicken soup', 'ingredients_list'), ('how do i build my credit score', 'improve_credit_score'), ('how much is my electric bill for this month', 'bill_balance'), ('is taking out the trash on my to do list', 'todo_list'), ('my chinese food has sat in the fridge since friday but can i still eat it', 'food_last'), ('please tell wanda where i am now', 'share_location'), ("because of circumstances i don't need my reservation anymore", 'cancel_reservation'), ('i would like to know what my salary is at this time', 'income'), ('what is the total amount that i owe to sony and verizon', 'bill_balance'), ('what is the quantity of dollars i receive for trading 6 yens', 'exchange_rate'), ("what are delta's carry-on restrictions", 'carry_on'), ('would you repeat that', 'repeat'), ('speak faster', 'change_speed'), ("what's the climate like in chicago", 'weather'), ('i want to pay my bill, please', 'pay_bill'), ('what items do i have on my calendar for easter', 'calendar'), ('what animals do you have', 'do_you_have_pets'), ('please submit an order for new checks on my pnc money market account', 'order_checks'), ('is there a carry on weight limit', 'carry_on'), ('pull up my calendar and tell me what i have scheduled for march 30', 'calendar'), ('what is 5 x 4', 'calculator'), ('let me know the number of vacation days i have', 'pto_balance'), ('please share the definition of episcopalian', 'definition'), ('give me an idea of traffic on the way to the doctors office at 6', 'traffic'), ('what did i put on my shopping list so far', 'shopping_list'), ('what is the next song to be played', 'next_song'), ("what's on my agenda for tomorrow", 'calendar'), ('please roll a 6 sided dice', 'roll_dice'), ('can we change your name to lisa', 'change_ai_name'), ('i didnt hear what you just said, can you say it again', 'repeat'), ('i have no more checks, how can i order more', 'order_checks'), ('locate my phone', 'find_phone'), ('can you sub half and half for heavy whipping cream', 'ingredient_substitution'), ('who do you work for', 'who_do_you_work_for'), ('how long have you been alive for', 'how_old_are_you'), ('do you know if outback allows reservations', 'accept_reservations'), ('what alerts are listed for traveling to paris', 'travel_alert'), ('does moscow require travel visas to visit', 'international_visa'), ('i want that volume bumped up to 4', 'change_volume'), ('please link up to my cell', 'sync_device'), ('could you direct me to any good places in kentucky that serve sushi', 'restaurant_suggestion'), ('my card is not working and i need to let them know', 'damaged_card'), ('can you get me an uber to the science museum', 'uber'), ('how soon will you mail me a new card', 'replacement_card_duration'), ('where were you manufactured', 'where_are_you_from'), ('can my limit be increased for my visa', 'credit_limit_change'), ('how do you order new checks', 'order_checks'), ('what can i ask you to accomplish', 'what_can_i_ask_you'), ('could you skip to the next song on this playlist', 'next_song'), ('tell me when i will next have the day off', 'next_holiday'), ("how long is the drive from sfo to fisherman's wharf", 'distance'), ('how old are you', 'how_old_are_you'), ('where to report lost discover credit card', 'report_lost_card'), ('help, i cannot find my phone', 'find_phone'), ('how much time to reach doctors office', 'distance'), ('what is the next holiday', 'next_holiday'), ('increase volume to 4', 'change_volume'), ('what should i do with an unusable and damaged card', 'damaged_card'), ('can you tell me a fun fact about elephants', 'fun_fact'), ('do you have any hobbies', 'what_are_your_hobbies'), ("i'd like to rent an automobile in pittsburgh from this tuesday until next thursday can i do that", 'car_rental'), ("i can't locate my mastercard and i want to report it as lost, please", 'report_lost_card'), ('i have a fraudulent transaction from red robin on my account', 'report_fraud'), ('which company is responsible for your design', 'who_made_you'), ('what date did i receive my last paycheck', 'payday'), ('does home have a starbucks nearby', 'directions'), ('will i get charged extra if i use my card while in italy', 'international_fees'), ('what do i do to redeem credit card points', 'redeem_rewards'), ('can i still make changes to my order, or has it already been shipped', 'order_status'), ('say it one more time', 'repeat'), ('good answer, thanks for providing it', 'thank_you'), ('how many holiday days do i have left to use', 'pto_balance'), ('i would like you to add milk if it is not on my shopping list already', 'shopping_list_update'), ('what things do you enjoy doing', 'what_are_your_hobbies'), ('really dont know', 'maybe'), ('how many vacation days are available to me', 'pto_balance'), ('are you paid by somebody', 'who_do_you_work_for'), ('how does my credit score look', 'credit_score'), ("i think that's true", 'yes'), ('please give me the minimum payment total for my sprint phone bill', 'min_payment'), ('get me an uber to chilis', 'uber'), ('please transfer $x from checking to saving', 'transfer'), ('how many calories in meatloaf', 'calories'), ('am i your manager', 'who_do_you_work_for'), ('tell me if there is an event called annual physical in my calendar', 'calendar'), ('i need to know how much i should pay in tax', 'taxes'), ('is there an extra fee for using my card in costa rica', 'international_fees'), ('how long will it take to be seated at the flying w', 'how_busy'), ('how many kilos are in 20 pounds', 'measurement_conversion'), ('i need to take out the trash please remind me', 'reminder_update'), ('what is my money market account balance', 'balance'), ('how do you say hello in french', 'translate'), ('tell me when my flight is landing', 'flight_status'), ('can you let me know if it will be warm this week', 'weather'), ('please let me know how much my gas bill is', 'bill_balance'), ('track my package now please', 'order_status'), ('get the details on the pto i have used', 'pto_used'), ('where do i make a vacation request', 'pto_request'), ('describe the meaning of "dog', 'definition'), ('i just got a new job and need help transferring my 401k', 'rollover_401k'), ('please cancel the table for two at burger king', 'cancel_reservation'), ('set my alarm for 5pm', 'alarm'), ('do you have another job', 'who_do_you_work_for'), ('remind me tommorow at 3pm i have a doctors appointment', 'reminder_update'), ('for the dates april 1st to the 7th, find me round trip air tickets from la to sfo', 'book_flight'), ('my card split in half, how do i report it', 'damaged_card'), ('are car rentals available out of new york from march 2 to march 3', 'car_rental'), ('whats the most recent status for my virgin air flight', 'flight_status'), ('how long will it take to walk to the safeway in the strip mall', 'distance'), ('is my reservation confirmed for lu shi at 7 pm', 'confirm_reservation'), ('i left my phone somewhere', 'find_phone'), ('will i be charged if i use the card in australia', 'international_fees'), ('i demagnetized my card and want to report it', 'damaged_card'), ('may you repeat what you said again', 'repeat'), ('yes that is correct', 'yes'), ('i need for you to connect to my phone', 'sync_device'), ('im sorry can you repeat yourself', 'repeat'), ('how do i keep good credit', 'improve_credit_score'), ("i'm called tim", 'change_user_name'), ('what is my total annual income', 'income'), ('negative', 'no'), ('can you schedule a meeting with damon for 1', 'schedule_meeting'), ('help me find a good reviews hotel in la', 'book_hotel'), ("if i'm using citibank, what is the interest rate", 'interest_rate'), ('i need to know the calorie content in a piece of pepperoni pizza', 'calories'), ('i want to see the list of transaction on my bank of hawaii', 'transactions'), ('how much of my pto have i depleted', 'pto_used'), ('play me the playlist entitled jazz', 'play_music'), ('get me roundtrip flight information for flights from dallas to houston for june 8 and june 17', 'book_flight'), ('what is my income from work', 'income'), ('change my reservation with carl at umami to canceled', 'cancel_reservation'), ('what state were you born in', 'where_are_you_from'), ('book me an uber to olive garden', 'uber'), ('what is the wait time at this restaurant', 'how_busy'), ('how long would it take to get to times square by bus', 'distance'), ('say a fun fact about mt everest', 'fun_fact'), ('how long should i fry an egg for', 'cook_time'), ('good seeing you', 'goodbye'), ("i can't find my phone do you know where it is", 'find_phone'), ("actually forget that i don't need it", 'cancel'), ('what amount of gas is in my car', 'gas'), ('what is my current gas level', 'gas'), ('are you now working for me', 'who_do_you_work_for'), ('tell me some amazing trivia about radios', 'fun_fact'), ('can you put on music by beach house', 'play_music'), ('define the meaning of life', 'meaning_of_life'), ('tell me the lowest amount i can pay for my cable bill', 'min_payment'), ('explain to me how to rollover my 401k', 'rollover_401k'), ('what is 20 + 20', 'calculator'), ('please show me what transactions i made on the first of this month', 'transactions'), ('what is the weather report for seattle', 'weather'), ('make a payment on the electric bill', 'pay_bill'), ('when is my meeting with dan scheduled for', 'meeting_schedule'), ('how many meetings do i have today between noon and one', 'meeting_schedule'), ('how many points have been earned on my amex', 'rewards_balance'), ("what's traffic usually like around 5pm going to the bank", 'traffic'), ('is it too much trouble to put a stop on my bank account', 'freeze_account'), ("what do people think about the sandwiches at wendy's", 'restaurant_reviews'), ('get an uber to the airport', 'uber'), ('can you share the meaning of life', 'meaning_of_life'), ('how much longer until i have to change my tires', 'tire_change'), ("i have how many tier credits on my caesar's card", 'rewards_balance'), ('i appreciate your help!', 'thank_you'), ("how long does it take to get to spago's in los angeles", 'distance'), ('can my credit limit be increased to one thousand dollars', 'credit_limit_change'), ('i appreciate that', 'thank_you'), ('will i pay extra if i use my card in juarez', 'international_fees'), ('when would you recommend i get my oil changed', 'oil_change_when'), ('slow your speech down', 'change_speed'), ("i'd like to me told the meaning of life", 'meaning_of_life'), ('i need to know what kind of gas to put in this car', 'gas_type'), ('are you from the uk', 'where_are_you_from'), ('switch over from female to the male voice', 'change_accent'), ('for anew card to be mailed to me, how long does it take', 'replacement_card_duration'), ('due to fraudulent activity on my card i need to make a report', 'report_fraud'), ('is this a computer right now or a human being', 'are_you_a_bot'), ('what is the apr to my credit card', 'apr'), ('play that song called colony of birchman', 'play_music'), ('decrease volume', 'change_volume'), ('can you locate my phone for me', 'find_phone'), ('please measure my tire pressure', 'tire_pressure'), ('what are you familiar with', 'what_can_i_ask_you'), ('research the meaning of life', 'meaning_of_life'), ('what percentage of my income is taken out for taxes', 'taxes'), ('can you help me book a car rental in paris', 'car_rental'), ('will i receive a fee if i use my card in ireland', 'international_fees'), ('websites that share credit ratings', 'credit_score'), ('complete a transaction from savings to checking of $20000', 'transfer'), ('let me know when should i get my tires changed next', 'tire_change'), ('are you able to track a phone', 'find_phone'), ('what does it say on the clock', 'time'), ("can you tell me what's the car's tire pressure", 'tire_pressure'), ("what's the forecast like for pittsburgh", 'weather'), ('give my my current points on my marriot rewards card', 'rewards_balance'), ('what kind of gas does my car use', 'gas_type'), ('discontinue please', 'cancel'), ('what must i do to report my card lost', 'report_lost_card'), ('roll a die with 6 sides', 'roll_dice'), ('change back to your factory settins', 'reset_settings'), ('do you know the fuel economy of this car', 'mpg'), ('how do i build credit', 'improve_credit_score'), ('forget this song and go to the next', 'next_song'), ('i intend to learn how to change oil in my car', 'oil_change_how'), ("please tell me if my reservation is scheduled for sally's at 3 pm", 'confirm_reservation'), ('i think i made a fraudulent transaction', 'report_fraud'), ('may i know my insurance benefits', 'insurance'), ('are there any meeting rooms free from 10:00 am until 10:30 am', 'schedule_meeting'), ('how many carry ons can i take on a flight with american airlines to seattle', 'carry_on'), ("tell me the subjects you're aware of", 'what_can_i_ask_you'), ('please give me my gps coordinates', 'current_location'), ('tell me the meeting schdule please', 'meeting_schedule'), ('i wanna know the gas i need to fill this car up with', 'gas_type'), ('might be true, might be false', 'maybe'), ('i went to target to buy a mug but my card did not work', 'card_declined'), ('yes, that is accurate', 'yes'), ('what are my gps coordinates at this location', 'current_location'), ('i want you to speak to me faster', 'change_speed'), ("hello how's it going", 'greeting'), ('help me change my insurance plan', 'insurance_change'), ('connect with my phone please', 'sync_device'), ('name some awesome things about dogs', 'fun_fact'), ('your answer was enjoyable', 'thank_you'), ('timer for 5 minutes', 'timer'), ("what's the facts about nutrients in rice milk", 'nutrition_info'), ('i have suspicious charges on my discovery card', 'report_fraud'), ('put dr feelgood by motley crue on my playlist', 'update_playlist'), ("call hr so i can figure out how many days off i've taken", 'pto_used'), ("does chili's take reservations", 'accept_reservations'), ('approximately how long does it take to get to the resort in miami', 'distance'), ('you have to connect to my phone', 'sync_device'), ('do you have any suggestions as to what i should cook for dinner', 'meal_suggestion'), ('let me know how much gas i have', 'gas'), ('do they take reservations at bar tartine', 'accept_reservations'), ('update me on my vacation request', 'pto_request_status'), ('how much longer until my next day off', 'next_holiday'), ('tell me straight if it has been processed or not', 'application_status'), ('please increase the speed of your talking', 'change_speed'), ("i'd like to get on a new united healthcare plan", 'insurance_change'), ('is it possible to rent a car from new york on 3/2/19 to 3/5/19', 'car_rental'), ('credit limit change', 'credit_limit_change'), ('are there any meetings for today on my calendar', 'meeting_schedule'), ('please apply cc visa card for me', 'new_card'), ("i'll talk to you later ai", 'goodbye'), ('how can i get my year end summary for taxes', 'w2'), ('can you locate my w-2', 'w2'), ('would you let me know what flights are available out of dallas to el paso on march 20 for under $400', 'book_flight'), ('what kind of gas should i put in this car', 'gas_type'), ('provide the name i should use to refer to you', 'what_is_your_name'), ('what is 20% of a thousand', 'calculator'), ('what would be the conversion between tablespoons and teaspoons', 'measurement_conversion'), ('my card is damaged and no longer function', 'damaged_card'), ('can you show me my shopping list', 'shopping_list'), ('set your language setting to english', 'change_language'), ('has may supervisor approved the vacation days i asked for', 'pto_request_status'), ('is there really an answer to the meaning of life', 'meaning_of_life'), ('take watering the plants off of my to do list', 'todo_list_update'), ('i need a recent transaction looked into', 'transactions'), ('show me the minimum payment for my boat bill', 'min_payment'), ('i need to know my salary', 'income'), ('i need you to buy a mouse for me', 'order'), ('what sorts of subjects are you well versed in', 'what_can_i_ask_you'), ('when is the next scheduled holiday, please', 'next_holiday'), ('please start counting down from 10 minutes', 'timer'), ('get me an uber for 4 heading to miam', 'uber'), ('about how many meetings am i attending between 12 and 3 today', 'meeting_schedule'), ('i need to schedule some car maintenance', 'schedule_maintenance'), ('let my bank know i will be on travel to fes', 'travel_notification'), ('could you speak a little faster, please', 'change_speed'), ("send mom a text i'll talk to you later", 'text'), ("can i make a reservation at chevy's", 'accept_reservations'), ('what are some interesting things to do in dc', 'travel_suggestion'), ('what is our purpose here on earth', 'meaning_of_life'), ('tell me what the weather is like', 'weather'), ('how do i use reward for my first hawaiian bank', 'redeem_rewards'), ('can you help me with anything i need', 'what_can_i_ask_you'), ('how much is 10kg in pounds and ounces', 'measurement_conversion'), ('i need to know how many vacation days i have', 'pto_balance'), ('do i need to get gas soon', 'gas'), ('what meetings are on my calendar', 'meeting_schedule'), ('can i get more money on my discover card', 'credit_limit_change'), ('play a song for me', 'play_music'), ('how long do i have before frozen chicken will go bad', 'food_last'), ('can you please locate my phone', 'find_phone'), ('my card declined yesterday and i want to know why', 'card_declined'), ("how do you say it's snowing in german", 'translate'), ("i don't understand why i have been barred from accessing my own account", 'account_blocked'), ('do i have anything to do march 2nd', 'calendar'), ('can you help me solve a math problem', 'calculator'), ('i must call dr smith', 'make_call'), ('does longhorn steakhouse have good reviews', 'restaurant_reviews'), ('i wanna know the point of life', 'meaning_of_life'), ('what is the process for requesting a vacation', 'pto_request'), ('set reminder to feed cat tonight at 6pm', 'reminder_update'), ('when do i have meetings today', 'meeting_schedule'), ("i'd like a block on my charles schwab account immediately", 'freeze_account'), ('you must skip this song and play the next one', 'next_song'), ('are you familiar with any types of subjects', 'what_can_i_ask_you'), ('my gps coordiantes will be shared with aunt sunny', 'share_location'), ('how much is my income', 'income'), ('what kind of fuel should i use to fill the car', 'gas_type'), ("my new playlist well be having god's plan adding to it", 'update_playlist'), ('read my calendar events', 'calendar'), ('check to see who is at the doof', 'smart_home'), ('has my credit card application been approved', 'application_status'), ('i am interested in applying for a visa card', 'new_card'), ('how long to cook a frozen pizza', 'cook_time'), ('is spaghetti healthy', 'nutrition_info'), ('can you give me a recipe for german chocolate cake', 'recipe'), ('what are the reviews for mountain mikes', 'restaurant_reviews'), ('purchase a flight from boise to sacramento on sunday and returning on wednesday', 'book_flight'), ("what's the nutritional info for spaghetti", 'nutrition_info'), ('i wanna know your name', 'what_is_your_name'), ('please whisper', 'whisper_mode'), ('tell me what time my flight ought to be landing', 'flight_status'), ('is my application processed for credit card', 'application_status'), ('start a timer for two minutes', 'timer'), ('what can i use you to help me with', 'what_can_i_ask_you'), ('find an uber xl to take me to the kroger near me', 'uber'), ('it was great to see you again, see ya later!', 'goodbye'), ('what date should i get my tires changed', 'tire_change'), ('yesterday what did i spend on lunch', 'spending_history'), ("i need you to reserve at table for a party of four at devon's for 6:00 pm", 'restaurant_reservation'), ("what's your designation", 'what_is_your_name'), ('who do you function for', 'who_do_you_work_for'), ('what are my insurance rewards', 'insurance'), ('do you my name', 'user_name'), ('i want to report a broken card', 'damaged_card'), ('revert to factory settings please', 'reset_settings'), ('i need you to schedule a meeting with bob brown at noon the day after tomorrow', 'schedule_meeting'), ("what's the due date for my american express payment", 'bill_due'), ('how do they say hello in germany', 'translate'), ('how are things with you', 'greeting'), ('find a chow mein recipe for me, please', 'recipe'), ('please revert all settings to factory default', 'reset_settings'), ('do you know of any good restaurants', 'restaurant_suggestion'), ("what's the miles per gallon on this car", 'mpg'), ('do you know you are not human', 'are_you_a_bot'), ('what if you want to obtain a new credit card', 'new_card'), ('what do you think is the meaning of like', 'meaning_of_life'), ('where can i find a place where i can schedule to check my tires out', 'schedule_maintenance'), ('what are the fees to use my card in nigeria', 'international_fees'), ("what's the name of this tune", 'what_song'), ('your boss is who', 'who_do_you_work_for'), ('what meetings are scheduled for today', 'meeting_schedule'), ('no, that information is wrong', 'no'), ('please check the pressure in my tires', 'tire_pressure'), ('what name do you know me by', 'user_name'), ('will qdoba take reservations', 'accept_reservations'), ('tell me the exchange rate between rubles and dollars', 'exchange_rate'), ('please let me know if giving the dog a bath is on my list of tasks to complete', 'todo_list'), ('can you tell me the mane of the song playing', 'what_song'), ('what do you want me to refer to you as', 'what_is_your_name'), ("how long's it been since my car was at autozone", 'last_maintenance'), ("do you know if anyone's even looked at the application i sent in for the new visa card", 'application_status'), ('do you know any fun facts about shampoo', 'fun_fact'), ('i want to hear my to do list please', 'todo_list'), ('flip a coin for me', 'flip_coin'), ('send a text message to chris and ask what he wants to eat for dinner', 'text'), ('when was the last time i got my oil changed', 'last_maintenance'), ('can you tell bob to get his dog via text', 'text'), ('book an uber to school', 'uber'), ('how do you make pot roast', 'recipe'), ('ai, do you like your name', 'what_is_your_name'), ('bring up my most recent purchases', 'transactions'), ('do i need a visa for germany', 'international_visa'), ('find out for me if my amex card application was received', 'application_status'), ('on may 12 to may 16 can i get a car from aiken', 'car_rental'), ('when am i next having a meeting in december', 'calendar'), ('set a reminder for the movie', 'reminder_update'), ('can you please help me find my lost luggage', 'lost_luggage'), ('tell me how healthy mac and cheese is', 'nutrition_info'), ("what's my checking balance", 'balance'), ('are my tires good on air', 'tire_pressure'), ("let's switch to whisper voice", 'whisper_mode'), ('does my car have enough gas to get to chicago', 'gas'), ("favorite what's currently playing on my playlist", 'update_playlist'), ('how long should i wait before i can bake bread with homemade dough', 'cook_time'), ('can you tell me how to spell manipulation', 'spelling'), ('i must apply for a new credit card', 'new_card'), ('assist me by setting my alarm for 9:00 and another for 11:00', 'alarm'), ('check the interest rate on my savings account', 'interest_rate'), ('i would like to change my pin number for my chase account', 'pin_change'), ('what amount of time will pass in order for a person at our location to take bus to detroit', 'distance'), ('add this track to my rock playlist', 'update_playlist'), ('please remind me to take out the trash', 'reminder_update'), ('please change pin to 1234 on my bank account trailing in 3829', 'pin_change'), ('set default language to english', 'change_language'), ('i would like to know how to rollover my 401k', 'rollover_401k'), ('is there a wait time to get into tgifridays', 'how_busy'), ('i lost my phone', 'find_phone'), ("what's bank of america's routing number", 'routing'), ('i want to check on my vacation request', 'pto_request_status'), ('i need to figure out what to do about lost luggage', 'lost_luggage'), ('play beatles', 'play_music'), ('i need to let my bank know i am traveling to chicago', 'travel_notification'), ('i need an uber to times square asap', 'uber'), ("i'm looking for a good suggestion for norse cuisine", 'meal_suggestion'), ('tell me a joke about lawyers', 'tell_joke'), ('is an international visa needed to travel to z', 'international_visa'), ('can you ping priest chris with a text and send the following confession', 'text'), ('i need to switch back to my factory settings', 'reset_settings'), ('i want to reserve an uber to go to the airport', 'uber'), ('i need to know my income', 'income'), ('go ahead and change your accent to the male british one', 'change_accent'), ('what is the remainder of my starbucks rewards balance', 'rewards_balance'), ('give me the nutrition facts for chicken breast', 'nutrition_info'), ('can i hear some music by cloud control', 'play_music'), ('how long can i go before i need to change my tires', 'tire_change'), ('is a bar close to my church', 'directions'), ('i need to change your voice settings', 'change_speed'), ('tell me how long i ought to spend preparing fajitas', 'cook_time'), ('not quite sure how to respond', 'maybe'), ('i want you to call me pam', 'change_user_name')]
test_label_masks = np.ones(len(test_examples), dtype=bool)

test_dataloader = generate_data_loader(test_examples, test_label_masks, label_map, do_shuffle = False, balance_label_examples = False)

In [None]:
total_test_loss = 0
nb_test_steps = 0

all_preds = []
all_labels_ids = []

#loss
nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

# Evaluate data for one epoch
for batch in test_dataloader:
    
    # Unpack this training batch from our dataloader. 
    b_input_ids = batch[0].to(device)
    b_input_mask = batch[1].to(device)
    b_labels = batch[2].to(device)
    
    # Tell pytorch not to bother with constructing the compute graph during
    # the forward pass, since this is only needed for backprop (training).
    with torch.no_grad():        
        model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[-1]
        _, logits, probs = discriminator2(hidden_states)
        ###log_probs = F.log_softmax(probs[:,1:], dim=-1)
        filtered_logits = logits[:,0:-1]
        # Accumulate the test loss.
        total_test_loss += nll_loss(filtered_logits, b_labels)
        
    # Accumulate the predictions and the input labels
    _, preds = torch.max(filtered_logits, 1)
    all_preds += preds.detach().cpu()
    all_labels_ids += b_labels.detach().cpu()

# Report the final accuracy for this validation run.
all_preds = torch.stack(all_preds).numpy()
all_labels_ids = torch.stack(all_labels_ids).numpy()
test_accuracy = np.sum(all_preds == all_labels_ids) / len(all_preds)
print("  Accuracy: {0:.3f}".format(test_accuracy))

  Accuracy: 1.000


In [None]:
all_preds

array([31, 31,  4, 31, 44, 31, 31,  4, 36, 31, 11, 31, 36, 36, 45, 30, 36,
       36, 31, 45, 10, 31,  3, 36, 31, 31, 36, 36, 11,  6, 39, 49, 39, 31,
       31, 31, 23, 30, 30, 31,  6, 33, 41, 39, 36, 31, 31, 33, 36, 10, 39,
       31, 31, 31, 49, 31, 36,  6,  6, 45, 36, 31, 36, 31, 31,  3, 45, 31,
       31, 31, 11,  3,  6, 30, 36,  4, 31, 14, 31, 45, 36, 45, 31, 45, 31,
       36, 36, 39, 31, 31, 33, 36, 39, 31, 31,  6, 36, 31, 10, 31, 39, 41,
        6, 39, 31, 31, 39, 45, 31, 45, 39, 31, 36, 39])

In [None]:
all_labels_ids

array([39, 18, 14,  0,  0, 39, 18, 14,  6, 31, 45, 31,  6, 36,  3, 29, 36,
        6, 31,  3, 33, 31, 40,  5, 31, 18,  6, 36, 25, 34, 10, 23, 10, 31,
       18, 39, 12,  4,  4, 39, 37,  2, 11,  7, 36, 31, 31,  2, 36, 37, 26,
       39, 31, 31, 32, 31, 36, 34,  5,  3, 36, 18, 36, 39, 39, 40,  3, 31,
       18, 39, 45, 40, 34, 29, 36, 14, 31, 41, 31,  3,  5, 30, 39,  3, 31,
        6,  6, 10, 31, 31,  2, 36, 44, 31, 31,  5, 36, 31, 33, 31, 26, 49,
        5,  1, 31, 31,  7,  3, 18,  3, 44, 31, 36,  7])

In [None]:
label_reverse_map[31]


'HUM_ind'

In [None]:
def get_prediction(input_list):
  test_examples = []
  for item in input_list:
    test_examples.append((item,'UNK_UNK'))
  test_label_masks = np.ones(len(test_examples), dtype=bool)
  test_dataloader = generate_data_loader(test_examples, test_label_masks, label_map, do_shuffle = False, balance_label_examples = False)
  

  total_test_loss = 0
  nb_test_steps = 0

  all_preds = []
  all_labels_ids = []

  #loss
  nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

  # Evaluate data for one epoch
  for batch in test_dataloader:
      
      # Unpack this training batch from our dataloader. 
      b_input_ids = batch[0].to(device)
      b_input_mask = batch[1].to(device)
      b_labels = batch[2].to(device)
      
      # Tell pytorch not to bother with constructing the compute graph during
      # the forward pass, since this is only needed for backprop (training).
      with torch.no_grad():        
          model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
          hidden_states = model_outputs[-1]
          _, logits, probs = discriminator2(hidden_states)
          ###log_probs = F.log_softmax(probs[:,1:], dim=-1)
          filtered_logits = logits[:,0:-1]
          # Accumulate the test loss.
          total_test_loss += nll_loss(filtered_logits, b_labels)
          
      # Accumulate the predictions and the input labels
      _, preds = torch.max(filtered_logits, 1)
      all_preds += preds.detach().cpu()
      all_labels_ids += b_labels.detach().cpu()
  predictions = []
  all_preds = torch.stack(all_preds).numpy()
  for item in all_preds:
    print(item)
    predictions.append(label_reverse_map[item])
  return predictions

In [None]:
get_prediction(["which department is good","students from which country are obedient"])

0
1


['user_department', 'user_country']