# import tool

In [1]:
import numpy as np
import pandas as pd
import json 
import pickle

from tqdm import tqdm
from IPython.display import clear_output
# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset, random_split

# bert
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
clear_output()

# import self-define tool

In [2]:
from module.class_label_preprocessing import label_preprocess
with open('./module/label_encoding.pkl' , 'rb') as input:
    label_preprocessing = pickle.load(input)

# Parameter

In [3]:
model_name = "bert_small_2_model_5030"
MAX_LENGTH = 100
DEVICE = "cuda: 0"

In [4]:
PRETRAINED_MODEL_NAME = "bert-base-cased"

# read data

In [5]:
data = [json.loads(line) for line in open('./data/test_unlabeled.json', 'r')]    

In [6]:
data

[{'idx': 36000,
  'text': '@Youngdeji_ I think if uzi and carti dropping Monday you gotta drop lil woah also',
  'reply': ''},
 {'idx': 36001,
  'text': 'For the third year in a row we’re discussing trading picks for a safety.',
  'reply': ''},
 {'idx': 36002,
  'text': 'dababy album sounds like it was made for niggas who work in the kitchen at Denny’s',
  'reply': "That's why you bought it."},
 {'idx': 36003,
  'text': 'Majority of Indians don’t watch any sport other than cricket.\nMajority of Indians would’ve not known about you if the movie never came out.\nZaira Wasim did so much hardwork to make you look good on screen. \nUsing her name to justify your mentality is sick.',
  'reply': '@ZairaWasimmm got a great story because of these sis..    #ISupportBabitaPhogat'},
 {'idx': 36004,
  'text': 'everybody is just now listening to @madisonbeer after selfish came out when I’ve been listening since “as she please” wtf',
  'reply': ''},
 {'idx': 36005,
  'text': 'Might just say fuck it a

In [7]:
X = [(i['text'], i['reply']) for i in data]

In [8]:
def create_data_loader(X, batch_size_):
    X_text = [i[0] for i in X]
    X_reply = [i[1] for i in X]
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

    buf = [tokenizer.encode_plus(i[0], i[1], do_lower_case = False, add_special_tokens = True, max_length = MAX_LENGTH, pad_to_max_length = True) for i in tqdm(X)]   
    input_ids = torch.LongTensor( [i['input_ids'] for i in buf] )
    token_type_ids = torch.LongTensor( [i['token_type_ids'] for i in buf])
    attention_mask = torch.LongTensor( [i['attention_mask'] for i in buf])

    dataset = TensorDataset(input_ids, token_type_ids, attention_mask)
    loader = torch.utils.data.DataLoader(dataset = dataset, batch_size = batch_size_, shuffle = True)

    return(loader)

In [9]:
data_loader = create_data_loader(X, batch_size_ = 1)

100%|██████████| 4000/4000 [00:01<00:00, 2315.57it/s]


# model

In [10]:
model = BertForSequenceClassification.from_pretrained("./model_save/%s" % model_name, num_labels = 43)

In [11]:
model.to(DEVICE)
clear_output()

In [12]:
def output_label(data):
    input_ids, token_type_ids, attention_mask = [t.to(DEVICE) for t in data]
    outputs = model(input_ids = input_ids, 
                        token_type_ids = token_type_ids, 
                        attention_mask = attention_mask) 

    predict_prop = list(outputs[0].cpu().detach().numpy()[0])
    X = list(np.arange(len(predict_prop)))
    X.sort(key=dict(zip(list(X), list(predict_prop))).get, reverse=True)

    predict_label = X[:6]

    buf = np.array([ 1 if i in predict_label else 0 for i in range(43)]).reshape(1, 43)

    return( label_preprocessing.mlb.inverse_transform(buf)[0] )

In [13]:
answer = [output_label(data) for data in tqdm(data_loader)]

100%|██████████| 4000/4000 [01:09<00:00, 57.72it/s]


# save answer

In [14]:
data = [json.loads(line) for line in open('./data/test_unlabeled.json', 'r')]    

In [15]:
for i in range(len(data)):
    data[i]['categories'] = list(answer[i])  

# save answer

In [16]:
save_name = "./output_file/%s_eval_result.json" % model_name

In [17]:
df = pd.DataFrame(data).reindex(columns=["idx", 'categories', 'reply', 'text'])

In [18]:
df

Unnamed: 0,idx,categories,reply,text
0,36000,"[applause, hug, slow_clap, smh, yes, you_got_t...",,@Youngdeji_ I think if uzi and carti dropping ...
1,36001,"[applause, eye_roll, no, popcorn, slow_clap, yes]",,For the third year in a row we’re discussing t...
2,36002,"[awww, deal_with_it, hug, popcorn, sorry, you_...",That's why you bought it.,dababy album sounds like it was made for nigga...
3,36003,"[agree, applause, omg, shocked, slow_clap, yes]",@ZairaWasimmm got a great story because of the...,Majority of Indians don’t watch any sport othe...
4,36004,"[agree, applause, facepalm, hug, seriously, smh]",,everybody is just now listening to @madisonbee...
...,...,...,...,...
3995,39995,"[facepalm, no, omg, seriously, sigh, smh]",,Every Republican that's calling for the countr...
3996,39996,"[applause, no, omg, oops, sorry, yes]",,The point of “15 and stupid” when someone says...
3997,39997,"[agree, applause, slow_clap, win, yes, you_got...",OMG 😮 people do that to men as well,I'm not responding to any one who uses the wor...
3998,39998,"[agree, no, omg, oops, shocked, yes]",,Just made a (major?) breakthrough in my analys...


In [19]:
df.to_json(save_name, lines = True, orient = "records")

In [20]:
# with open(save_name, 'w') as outfile:
#     json.dump(data, outfile)