In [1]:
import warnings
warnings.simplefilter('ignore')
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.metrics import f1_score, classification_report
import matplotlib.pyplot as plt
import logging
logging.basicConfig(level=logging.ERROR)
from ast import literal_eval
from torch import cuda
import json
device = 'cuda:0' if cuda.is_available() else 'cpu'

user_intents = ['initial_query', 'greeting', 'add_filter', 'remove_filter', 'continue', 'accept_response', 'reject_response', 'others']
system_intents = ['feedback_request', 'detail_attribute_request', 'passive_recommend', 'active_recommend', 'parroting_response', 'sympathetic_response', 'others']
music_attributes = ['track', 'artist', 'year', 'popularity', 'culture', 'similar_track', 'similar_artist', 'user', 'theme', 'mood', 'genre', 'instrument', 'vocal', 'tempo', 'none']
intents_dict = {'user': user_intents, 'system': system_intents, 'music': music_attributes}

df = pd.read_csv('./most_recent.csv', encoding='unicode_escape')
df['intent'] = df['intent'].apply(literal_eval)
df['music_attribute'] = df['music_attribute'].apply(literal_eval)

# df = df[df['dialog_id'].apply(lambda x: x not in error_dialog_id)]

# 20개 이하인 intent는 others로 변경 (question 13개, answer 7개)
df["intent"] = df["intent"].apply(lambda x: ["others" if item in ["item_attribute_answer", "item_attribute_question"] else item for item in x])

# others 외의 intent가 함께 있으면 others 제거
def remove_others_if_not_alone(intents):
	if 'others' in intents and len(intents) > 1:
		intents.remove('others')
	return intents
df['intent'] = df['intent'].apply(remove_others_if_not_alone)

# initial query와 함께 [remove_filter, continue, accept_response, reject_response, others]가 있으면 제거
def preprocess_initial(row):
	if 'initial_query' in row['intent']:
		for intent_to_remove in ['remove_filter', 'continue', 'accept_response', 'reject_response', 'others']:
			if intent_to_remove in row['intent']:
				row['intent'].remove(intent_to_remove)
	return row
df = df.apply(preprocess_initial, axis=1)

#######################
def concat_previous_1_rows(group):
	if len(group) < 1:
		return pd.DataFrame()
	group = group.copy()
	group['content'] = group['content'].shift(1).fillna('') + '. ' + group['content']
	group['content'].iloc[0] = group['content'].iloc[0].lstrip('. ')
	return group

def concat_previous_2_rows(group):
	if len(group) < 2:
		return pd.DataFrame()
	group = group.copy()
	group['content'] = group['content'].shift(2).fillna('') + '. ' + group['content'].shift(1).fillna('') + '. ' + group['content']
	group['content'].iloc[0] = group['content'].iloc[0].lstrip('. ')
	group['content'].iloc[1] = group['content'].iloc[1].lstrip('. ')
	return group

def concat_previous_4_rows(group):
	if len(group) < 4:
		return pd.DataFrame()
	group = group.copy()
	group['content'] = group['content'].shift(4).fillna('') + '. ' + group['content'].shift(3).fillna('') + '. ' + group['content'].shift(2).fillna('') + '. ' + group['content'].shift(1).fillna('') + '. ' + group['content']
	for i in range(4):
		group['content'].iloc[i] = group['content'].iloc[i].lstrip('. ')
	return group

def concat_previous_6_rows(group):
	if len(group) < 6:
		return pd.DataFrame()
	group = group.copy()
	group['content'] = group['content'].shift(6).fillna('') + '. ' + group['content'].shift(5).fillna('') + '. ' + group['content'].shift(4).fillna('') + '. ' + group['content'].shift(3).fillna('') + '. ' + group['content'].shift(2).fillna('') + '. ' + group['content'].shift(1).fillna('') + '. ' + group['content']
	for i in range(6):
		group['content'].iloc[i] = group['content'].iloc[i].lstrip('. ')
	return group

def concat_previous_8_rows(group):
	if len(group) < 8:
		return pd.DataFrame()
	group = group.copy()
	group['content'] = (
		group['content'].shift(8).fillna('') + '. ' +
		group['content'].shift(7).fillna('') + '. ' +
		group['content'].shift(6).fillna('') + '. ' + 
		group['content'].shift(5).fillna('') + '. ' + 
		group['content'].shift(4).fillna('') + '. ' + 
		group['content'].shift(3).fillna('') + '. ' + 
		group['content'].shift(2).fillna('') + '. ' + 
		group['content'].shift(1).fillna('') + '. ' + 
		group['content']
	)
	for i in range(8):
		group['content'].iloc[i] = group['content'].iloc[i].lstrip('. ')
	return group

# 'dialog_id'별로 그룹화하여 이전 n개 row를 concat
df_1 = df.groupby('dialog_id').apply(concat_previous_1_rows).reset_index(drop=True)
df_2 = df.groupby('dialog_id').apply(concat_previous_2_rows).reset_index(drop=True)
df_4 = df.groupby('dialog_id').apply(concat_previous_4_rows).reset_index(drop=True)
df_6 = df.groupby('dialog_id').apply(concat_previous_6_rows).reset_index(drop=True)
df_8 = df.groupby('dialog_id').apply(concat_previous_8_rows).reset_index(drop=True)

df = df
################################

#model = DistilBERTClass_noFinetune().to(device)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# def text_to_768(text):
#     inputs = tokenizer.encode_plus(
#         text,
#         None,
#         add_special_tokens=True,
#         max_length=128,
#         padding='max_length',
#         return_token_type_ids=False,
#         truncation=True,
#         return_tensors='pt'
#     )
#     ids = inputs['input_ids'].to(device)
#     mask = inputs['attention_mask'].to(device)
    
#     with torch.no_grad():
#         output = model(ids, mask)
    
#     return output.cpu().numpy().flatten()

# df['vector'] = df['content'].apply(text_to_768)

##########################
user_df = df[df['role']=='user']
system_df = df[df['role']=='system']

del user_df['role']
del user_df['music_attribute']
del system_df['role']
del system_df['music_attribute']

def encode_intents(intent_list, intents):
	return [1 if intent in intent_list else 0 for intent in intents]

user_df.loc[:, 'intent'] = user_df['intent'].apply(lambda x: encode_intents(x, user_intents))
user_df = user_df.reset_index(drop=True)

system_df.loc[:, 'intent'] = system_df['intent'].apply(lambda x: encode_intents(x, system_intents))
system_df = system_df.reset_index(drop=True)

music_df = df[['index','dialog_id', 'role', 'content', 'music_attribute']]
music_df.loc[:, 'music_attribute'] = music_df['music_attribute'].apply(lambda x: encode_intents(x, music_attributes))
music_df.rename(columns={'music_attribute': 'intent'}, inplace=True)
music_df = music_df.reset_index(drop=True)

user_y = torch.stack([torch.tensor(item) for item in user_df['intent']])
system_y = torch.stack([torch.tensor(item) for item in system_df['intent']])
music_y = torch.stack([torch.tensor(item) for item in music_df['intent']])

# Train, Valid Split

from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in msss.split(user_df['content'].values, user_y):
	user_train_df, user_val_df = user_df.iloc[train_index], user_df.iloc[test_index]
	user_train_y, user_val_y = user_y[train_index], user_y[test_index]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)

for train_index, test_index in msss.split(user_val_df['content'].values, user_val_y):
	user_val_df, user_test_df = user_val_df.iloc[train_index], user_val_df.iloc[test_index]
	user_val_y, user_test_y = user_val_y[train_index], user_val_y[test_index]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in msss.split(system_df['content'].values, system_y):
	system_train_df, system_val_df = system_df.iloc[train_index], system_df.iloc[test_index]
	system_train_y, system_val_y = system_y[train_index], system_y[test_index]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)

for train_index, test_index in msss.split(system_val_df['content'].values, system_val_y):
	system_val_df, system_test_df = system_val_df.iloc[train_index], system_val_df.iloc[test_index]
	system_val_y, system_test_y = system_val_y[train_index], system_val_y[test_index]
 
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in msss.split(music_df['content'].values, music_y):
	music_train_df, music_val_df = music_df.iloc[train_index], music_df.iloc[test_index]
	music_train_y, music_val_y = music_y[train_index], music_y[test_index]

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)

for train_index, test_index in msss.split(music_val_df['content'].values, music_val_y):
	music_val_df, music_test_df = music_val_df.iloc[train_index], music_val_df.iloc[test_index]
	music_val_y, music_test_y = music_val_y[train_index], music_val_y[test_index]

user_train_df = user_train_df.reset_index(drop=True)
user_val_df = user_val_df.reset_index(drop=True)
user_test_df = user_test_df.reset_index(drop=True)

system_train_df = system_train_df.reset_index(drop=True)
system_val_df = system_val_df.reset_index(drop=True)
system_test_df = system_test_df.reset_index(drop=True)

music_train_df = music_train_df.reset_index(drop=True)
music_val_df = music_val_df.reset_index(drop=True)
music_test_df = music_test_df.reset_index(drop=True)

# Generate Data Dictionary

data_dict = {
	'user': {
		'train': {
			'dataframe': user_train_df,
			'label': user_train_y
		},
		'val': {
			'dataframe': user_val_df,
			'label': user_val_y
		},
  		'test': {
			'dataframe': user_test_df,
			'label': user_test_y
		}
	},
	'system': {
		'train': {
			'dataframe': system_train_df,
			'label': system_train_y
		},
		'val': {
			'dataframe': system_val_df,
			'label': system_val_y
		},
  		'test': {
			'dataframe': system_test_df,
			'label': system_test_y
		}
	},
	'music': {
		'train': {
			'dataframe': music_train_df,
			'label': music_train_y
		},
		'val': {
			'dataframe': music_val_df,
			'label': music_val_y
		},
  		'test': {
			'dataframe': music_test_df,
			'label': music_test_y
		}
	}
}

# Define Dataset Class

class MultiLabelDataset(Dataset):
	def __init__(self, dataframe, tokenizer, max_len):
		self.tokenizer = tokenizer
		self.data = dataframe
		self.text = dataframe.content
		self.targets = self.data.intent
		self.max_len = max_len

	def __len__(self):
		return len(self.text)

	def __getitem__(self, index):
		text = str(self.text[index])
		text = " ".join(text.split())

		tokens = self.tokenizer.tokenize(text)
		if len(tokens) > self.max_len:
			tokens = tokens[-self.max_len:]
		truncated_text = self.tokenizer.convert_tokens_to_string(tokens)

		inputs = self.tokenizer.encode_plus(
			truncated_text,
			None,
			add_special_tokens=True,
			max_length=self.max_len,
			pad_to_max_length=True,
			return_token_type_ids=False,
   			truncation=True
		)
		ids = inputs['input_ids']
		mask = inputs['attention_mask']

		return {
			'ids': torch.tensor(ids, dtype=torch.int),
			'mask': torch.tensor(mask, dtype=torch.int),
			'targets': torch.tensor(self.targets[index], dtype=torch.int)
		}
  
class Dataset_768(Dataset):
	def __init__(self, dataframe):
		self.tokenizer = tokenizer
		self.data = dataframe
		self.vector = dataframe.vector
		self.targets = self.data.intent

	def __getitem__(self, index):
		return {
			'vector': torch.tensor(self.vector, dtype=torch.float),
			'targets': torch.tensor(self.targets[index], dtype=torch.int)
		}

# Define Model

class DistilBERTClass(nn.Module):
	def __init__(self, num_intents):
		super(DistilBERTClass, self).__init__()
		self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
		self.fc1 = nn.Sequential(
			nn.Linear(768, 64),
			nn.BatchNorm1d(64),
			nn.ReLU(),
		)
		self.fc2 = nn.Sequential(
			nn.Linear(64, num_intents)
		)

	def forward(self, input_ids, attention_mask):
		output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
		hidden_state = output_1[0]
		pooler = hidden_state[:, 0]
		pooler = self.fc1(pooler)
		output = self.fc2(pooler)
		return output

class DistilBERTClass_noFinetune(torch.nn.Module):
    def __init__(self):
        super(DistilBERTClass_noFinetune, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1.last_hidden_state # output_1[0]이랑 같음 (아마도)
        pooler = hidden_state[:, 0]
        return pooler

class MLP_768(nn.Module):
	def __init__(self, num_intents):
		super(MLP_768, self).__init__()
		self.fc1 = nn.Sequential(
			nn.Linear(768, 64),
			nn.BatchNorm1d(64),
			nn.ReLU(),
		)
		self.fc2 = nn.Linear(64, num_intents)
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		x = self.fc1(x)
		x = self.fc2(x)
		x = self.sigmoid(x)
		return x

# Define functions

def decode_intents(encoded_list, data_type):
	return [intent for intent, flag in zip(intents_dict[data_type], encoded_list) if flag == 1]

In [32]:
class ExampleDataset(Dataset):
	def __init__(self, dataframe, tokenizer, max_len):
		self.tokenizer = tokenizer
		self.data = dataframe
		self.text = dataframe.content
		self.max_len = max_len

	def __len__(self):
		return len(self.text)

	def __getitem__(self, index):
		text = str(self.text[index])
		text = " ".join(text.split())

		tokens = self.tokenizer.tokenize(text)
		if len(tokens) > self.max_len:
			tokens = tokens[-self.max_len:]
		truncated_text = self.tokenizer.convert_tokens_to_string(tokens)

		inputs = self.tokenizer.encode_plus(
			truncated_text,
			None,
			add_special_tokens=True,
			max_length=self.max_len,
			pad_to_max_length=True,
			return_token_type_ids=False,
   			truncation=True
		)
		ids = inputs['input_ids']
		mask = inputs['attention_mask']

		return {
			'ids': torch.tensor(ids, dtype=torch.int),
			'mask': torch.tensor(mask, dtype=torch.int)
		}
  
MAX_LEN = 128

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)
params = {'batch_size': 1, 'shuffle': False, 'num_workers': 4}
num_intents_dict = {'user': 8, 'system': 7, 'music': 15}

user_best_thresholds = [0.28, 0.24, 0.58, 0.42, 0.32, 0.29, 0.69]
system_best_thresholds = [0.49, 0.21, 0.5, 0.23, 0.37, 0.18]
music_best_thresholds = [0.52, 0.69, 0.32, 0.14, 0.13, 0.14, 0.15, 0.46, 0.65, 0.1, 0.52, 0.2, 0.14, 0.23]

thresholds_dict = {'user': user_best_thresholds, 'system': system_best_thresholds, 'music': music_best_thresholds}

def text_to_intent(data_type, text):
	df = pd.DataFrame({'content': [text]})
	
	example_set = ExampleDataset(df, tokenizer, MAX_LEN)

	example_loader = DataLoader(example_set, **params)

	num_intents = num_intents_dict[data_type]

	model = DistilBERTClass(num_intents)
	if data_type=='music':
		model.load_state_dict(torch.load('./models/' + data_type + '_model.pth'))
	else:
		model.load_state_dict(torch.load('./models/' + data_type + '_model_concatone.pth'))
	model.to(device)
	
	# Test with best thresholds
	model.eval()
	probability_outputs=[]
	with torch.no_grad():
		for data in example_loader:
			ids = data['ids'].to(device, dtype = torch.int)
			mask = data['mask'].to(device, dtype = torch.int)
			outputs = model(ids, mask)
			probability_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())

	probability_outputs = np.array(probability_outputs)[0][:-1]
	binary_outputs = (probability_outputs >= thresholds_dict[data_type]).astype(int)
	if np.sum(binary_outputs)==0:
		binary_outputs[-1] = True

	return decode_intents(binary_outputs, data_type)

text_dict = [
	{'data_type': 'system', 'text': ""},
	{'data_type': 'user', 'text': "i dont like the vocals in 'the blood', its a bit dark. ill continue to rate these, can you please add some deadmaus, especially 'i remember' with kaskade?"},
]

concated_text = ""
for i in range(len(text_dict)):
	if i==0:
		concated_text = text_dict[i]['text']
	else:
		concated_text = text_dict[i-1]['text'] + '. ' + text_dict[i]['text']
  
	intents = text_to_intent(text_dict[i]['data_type'], concated_text)
	attributes = text_to_intent('music', text_dict[i]['text'])
  
	print(f"Text: {text_dict[i]['text']}")
	print(f"Intents: {intents}")
	print(f"Music Attributes: {attributes}")
	print()

Text: 
Intents: ['sympathetic_response']
Music Attributes: ['tempo']

Text: i dont like the vocals in 'the blood', its a bit dark. ill continue to rate these, can you please add some deadmaus, especially 'i remember' with kaskade?
Intents: ['add_filter', 'remove_filter']
Music Attributes: ['track', 'artist', 'mood', 'vocal']

