In [1]:
import sys
import numpy as np
import random as rn
import pandas as pd
import torch
from pytorch_pretrained_bert import BertModel
from torch import nn

from pytorch_pretrained_bert import BertTokenizer
from keras_preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
import torch.optim as optim

from IPython.display import clear_output
from transformers import AutoTokenizer, AutoModelForMaskedLM

import sqlite3
import jieba
import jieba.posseg as pseg

from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
rn.seed(321)
np.random.seed(321)
torch.manual_seed(321)
torch.cuda.manual_seed(321)

In [3]:
import pathlib
pathlib.Path().resolve() # CHECK ROUTE

WindowsPath('C:/Users/HackerByeBye/Documents/Therapy-Chatbot-Deploying-NLP/Training')

In [4]:
train_data = pd.read_csv('../likecount_train.csv')
test_data = pd.read_csv('../likecount_test.csv')

In [5]:
train_data = train_data.to_dict(orient='records')
test_data = test_data.to_dict(orient='records')
len(train_data), len(test_data)

(30761, 10252)

In [6]:
train_texts, train_labels = list(zip(*map(lambda d: (d['title'], d['label']), train_data)))
test_texts, test_labels = list(zip(*map(lambda d: (d['title'], d['label']), test_data)))

four_train_labels, four_test_labels = [], []

for x in train_labels:
    if x==1:
        four_train_labels.append(0)
    elif x==2:
        four_train_labels.append(0)
    elif x==3:
        four_train_labels.append(1)
    elif x==4:
        four_train_labels.append(1)

        
for x in test_labels:
    if x==1:
        four_test_labels.append(0)
    elif x==2:
        four_test_labels.append(0)
    elif x==3:
        four_test_labels.append(1)
    elif x==4:
        four_test_labels.append(1)


In [11]:
def jiebaSlice(content,mode):
    stopword_set = []
    content = str(content)
#     with open('../Analyzing/stopword.txt','r', encoding='utf-8') as stopwords:
#         for stopword in stopwords:
#             stopword_set.append(stopword.strip('\n'))
    
    content = content.strip('\n')
    if mode == "POSSEG":
        words = pseg.cut(content,use_paddle=True)
        slicedWords = []
        for word, flag in words:
#             if word not in stopword_set:
            slicedWords.append(word)
        return slicedWords
    elif mode == "CUT_HMM":
        seg_list = jieba.cut(content,HMM=True,cut_all=True)
        slicedWords = list(seg_list)
        return slicedWords
    elif mode == "CUT_FOR_SEARCH":
        seg_list = jieba.cut_for_search(content,HMM=True)
        slicedWords = list(seg_list)
        return slicedWords
    elif mode == "NORMAL":
        seg_list = jieba.cut_for_search(content)
        slicedWords = list(seg_list)
        return slicedWords  

In [12]:
train_tokens = list(map(lambda t: ['[CLS]'] + jiebaSlice(t,'CUT_HMM') + ['[SEP]'], train_texts))
test_tokens = list(map(lambda t: ['[CLS]'] + jiebaSlice(t,'CUT_HMM') + ['[SEP]'], test_texts))

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese', do_lower_case=True)

train_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, train_tokens)), maxlen=512, truncating="post", padding="post", dtype="int")
test_tokens_ids = pad_sequences(list(map(tokenizer.convert_tokens_to_ids, test_tokens)), maxlen=512, truncating="post", padding="post", dtype="int")

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\HACKER~1\AppData\Local\Temp\jieba.cache
Loading model cost 0.530 seconds.
Prefix dict has been built successfully.


In [13]:
train_y = np.array(four_train_labels)
test_y = np.array(four_test_labels) 
train_y.shape, test_y.shape

((30761,), (10252,))

In [14]:
train_masks = [[float(i > 0) for i in ii] for ii in train_tokens_ids]
test_masks = [[float(i > 0) for i in ii] for ii in test_tokens_ids]

In [15]:
class BertBinaryClassifier(nn.Module):
    def __init__(self, dropout=0.1):
        super(BertBinaryClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 4)
#         self.sigmoid = nn.Sigmoid()
    
    def forward(self, tokens, masks=None):
        _, pooled_output = self.bert(tokens, attention_mask=masks, output_all_encoded_layers=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
#         proba = self.sigmoid(linear_output)
        return F.log_softmax(linear_output)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [17]:
BATCH_SIZE = 1
EPOCHS = 30

In [18]:
train_tokens_tensor = torch.tensor(train_tokens_ids)
train_y_tensor = torch.tensor(train_y.reshape(-1, 1)).float()

test_tokens_tensor = torch.tensor(test_tokens_ids)
test_y_tensor = torch.tensor(test_y.reshape(-1, 1)).float()

train_masks_tensor = torch.tensor(train_masks)
test_masks_tensor = torch.tensor(test_masks)

str(torch.cuda.memory_allocated(device)/1000000 ) + 'M'

'0.0M'

In [19]:
train_dataset = TensorDataset(train_tokens_tensor, train_masks_tensor, train_y_tensor)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)

test_dataset = TensorDataset(test_tokens_tensor, test_masks_tensor, test_y_tensor)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)

In [20]:
bert_clf = BertBinaryClassifier()
bert_clf = bert_clf.cuda()
optimizer = Adam(bert_clf.parameters(), lr=3e-6)

In [21]:
import os 
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.cuda.empty_cache()

In [22]:
epoch_loss = []
for epoch_num in range(EPOCHS):
    bert_clf.train()
    train_loss = 0
    for step_num, batch_data in enumerate(train_dataloader):
        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)
        labels = torch.squeeze(labels,1)
        labels = labels.to(torch.long)
        print(str(torch.cuda.memory_allocated(device)/1000000 ) + 'M')
        logits = bert_clf(token_ids, masks)

        loss_func = nn.CrossEntropyLoss()

        batch_loss = loss_func(logits, labels) #計算loss
        train_loss += batch_loss.item()
        
        bert_clf.zero_grad() #清空前一次的gradient
        batch_loss.backward() #根據loss進行back propagation，計算gradient
        
        clip_grad_norm_(parameters=bert_clf.parameters(), max_norm=1.0)
        optimizer.step() #做gradient descent
        
        clear_output(wait=True)
        print('Epoch: ', epoch_num + 1)
        print("\r" + "{0}/{1} loss: {2} ".format(step_num, len(train_data) / BATCH_SIZE, train_loss / (step_num + 1)))
    epoch_loss.append(train_loss)

Epoch:  30
30760/30761.0 loss: 0.3680111478890567 


In [23]:
bert_clf.eval()
bert_predicted = []
all_logits = []
with torch.no_grad():
    for step_num, batch_data in enumerate(test_dataloader):
        token_ids, masks, labels = tuple(t.to(device) for t in batch_data)
        labels = torch.squeeze(labels,1)
        labels = labels.to(torch.long)
        logits = bert_clf(token_ids, masks)
#         loss_func = nn.BCELoss()
        loss_func = nn.CrossEntropyLoss()
        loss = loss_func(logits, labels)
        numpy_logits = logits.cpu().detach().numpy()
        print("\r" + "{0}/{1}".format(step_num, len(test_data) / BATCH_SIZE))
        bert_predicted.append(list(numpy_logits[0, :]).index(max(list(numpy_logits[0, :]))))
        all_logits += list(numpy_logits[:, 0])

  


0/10252.0
1/10252.0
2/10252.0
3/10252.0
4/10252.0
5/10252.0
6/10252.0
7/10252.0
8/10252.0
9/10252.0
10/10252.0
11/10252.0
12/10252.0
13/10252.0
14/10252.0
15/10252.0
16/10252.0
17/10252.0
18/10252.0
19/10252.0
20/10252.0
21/10252.0
22/10252.0
23/10252.0
24/10252.0
25/10252.0
26/10252.0
27/10252.0
28/10252.0
29/10252.0
30/10252.0
31/10252.0
32/10252.0
33/10252.0
34/10252.0
35/10252.0
36/10252.0
37/10252.0
38/10252.0
39/10252.0
40/10252.0
41/10252.0
42/10252.0
43/10252.0
44/10252.0
45/10252.0
46/10252.0
47/10252.0
48/10252.0
49/10252.0
50/10252.0
51/10252.0
52/10252.0
53/10252.0
54/10252.0
55/10252.0
56/10252.0
57/10252.0
58/10252.0
59/10252.0
60/10252.0
61/10252.0
62/10252.0
63/10252.0
64/10252.0
65/10252.0
66/10252.0
67/10252.0
68/10252.0
69/10252.0
70/10252.0
71/10252.0
72/10252.0
73/10252.0
74/10252.0
75/10252.0
76/10252.0
77/10252.0
78/10252.0
79/10252.0
80/10252.0
81/10252.0
82/10252.0
83/10252.0
84/10252.0
85/10252.0
86/10252.0
87/10252.0
88/10252.0
89/10252.0
90/10252.0
91/10252.

698/10252.0
699/10252.0
700/10252.0
701/10252.0
702/10252.0
703/10252.0
704/10252.0
705/10252.0
706/10252.0
707/10252.0
708/10252.0
709/10252.0
710/10252.0
711/10252.0
712/10252.0
713/10252.0
714/10252.0
715/10252.0
716/10252.0
717/10252.0
718/10252.0
719/10252.0
720/10252.0
721/10252.0
722/10252.0
723/10252.0
724/10252.0
725/10252.0
726/10252.0
727/10252.0
728/10252.0
729/10252.0
730/10252.0
731/10252.0
732/10252.0
733/10252.0
734/10252.0
735/10252.0
736/10252.0
737/10252.0
738/10252.0
739/10252.0
740/10252.0
741/10252.0
742/10252.0
743/10252.0
744/10252.0
745/10252.0
746/10252.0
747/10252.0
748/10252.0
749/10252.0
750/10252.0
751/10252.0
752/10252.0
753/10252.0
754/10252.0
755/10252.0
756/10252.0
757/10252.0
758/10252.0
759/10252.0
760/10252.0
761/10252.0
762/10252.0
763/10252.0
764/10252.0
765/10252.0
766/10252.0
767/10252.0
768/10252.0
769/10252.0
770/10252.0
771/10252.0
772/10252.0
773/10252.0
774/10252.0
775/10252.0
776/10252.0
777/10252.0
778/10252.0
779/10252.0
780/10252.0
781/

1353/10252.0
1354/10252.0
1355/10252.0
1356/10252.0
1357/10252.0
1358/10252.0
1359/10252.0
1360/10252.0
1361/10252.0
1362/10252.0
1363/10252.0
1364/10252.0
1365/10252.0
1366/10252.0
1367/10252.0
1368/10252.0
1369/10252.0
1370/10252.0
1371/10252.0
1372/10252.0
1373/10252.0
1374/10252.0
1375/10252.0
1376/10252.0
1377/10252.0
1378/10252.0
1379/10252.0
1380/10252.0
1381/10252.0
1382/10252.0
1383/10252.0
1384/10252.0
1385/10252.0
1386/10252.0
1387/10252.0
1388/10252.0
1389/10252.0
1390/10252.0
1391/10252.0
1392/10252.0
1393/10252.0
1394/10252.0
1395/10252.0
1396/10252.0
1397/10252.0
1398/10252.0
1399/10252.0
1400/10252.0
1401/10252.0
1402/10252.0
1403/10252.0
1404/10252.0
1405/10252.0
1406/10252.0
1407/10252.0
1408/10252.0
1409/10252.0
1410/10252.0
1411/10252.0
1412/10252.0
1413/10252.0
1414/10252.0
1415/10252.0
1416/10252.0
1417/10252.0
1418/10252.0
1419/10252.0
1420/10252.0
1421/10252.0
1422/10252.0
1423/10252.0
1424/10252.0
1425/10252.0
1426/10252.0
1427/10252.0
1428/10252.0
1429/10252.0

1986/10252.0
1987/10252.0
1988/10252.0
1989/10252.0
1990/10252.0
1991/10252.0
1992/10252.0
1993/10252.0
1994/10252.0
1995/10252.0
1996/10252.0
1997/10252.0
1998/10252.0
1999/10252.0
2000/10252.0
2001/10252.0
2002/10252.0
2003/10252.0
2004/10252.0
2005/10252.0
2006/10252.0
2007/10252.0
2008/10252.0
2009/10252.0
2010/10252.0
2011/10252.0
2012/10252.0
2013/10252.0
2014/10252.0
2015/10252.0
2016/10252.0
2017/10252.0
2018/10252.0
2019/10252.0
2020/10252.0
2021/10252.0
2022/10252.0
2023/10252.0
2024/10252.0
2025/10252.0
2026/10252.0
2027/10252.0
2028/10252.0
2029/10252.0
2030/10252.0
2031/10252.0
2032/10252.0
2033/10252.0
2034/10252.0
2035/10252.0
2036/10252.0
2037/10252.0
2038/10252.0
2039/10252.0
2040/10252.0
2041/10252.0
2042/10252.0
2043/10252.0
2044/10252.0
2045/10252.0
2046/10252.0
2047/10252.0
2048/10252.0
2049/10252.0
2050/10252.0
2051/10252.0
2052/10252.0
2053/10252.0
2054/10252.0
2055/10252.0
2056/10252.0
2057/10252.0
2058/10252.0
2059/10252.0
2060/10252.0
2061/10252.0
2062/10252.0

2619/10252.0
2620/10252.0
2621/10252.0
2622/10252.0
2623/10252.0
2624/10252.0
2625/10252.0
2626/10252.0
2627/10252.0
2628/10252.0
2629/10252.0
2630/10252.0
2631/10252.0
2632/10252.0
2633/10252.0
2634/10252.0
2635/10252.0
2636/10252.0
2637/10252.0
2638/10252.0
2639/10252.0
2640/10252.0
2641/10252.0
2642/10252.0
2643/10252.0
2644/10252.0
2645/10252.0
2646/10252.0
2647/10252.0
2648/10252.0
2649/10252.0
2650/10252.0
2651/10252.0
2652/10252.0
2653/10252.0
2654/10252.0
2655/10252.0
2656/10252.0
2657/10252.0
2658/10252.0
2659/10252.0
2660/10252.0
2661/10252.0
2662/10252.0
2663/10252.0
2664/10252.0
2665/10252.0
2666/10252.0
2667/10252.0
2668/10252.0
2669/10252.0
2670/10252.0
2671/10252.0
2672/10252.0
2673/10252.0
2674/10252.0
2675/10252.0
2676/10252.0
2677/10252.0
2678/10252.0
2679/10252.0
2680/10252.0
2681/10252.0
2682/10252.0
2683/10252.0
2684/10252.0
2685/10252.0
2686/10252.0
2687/10252.0
2688/10252.0
2689/10252.0
2690/10252.0
2691/10252.0
2692/10252.0
2693/10252.0
2694/10252.0
2695/10252.0

3254/10252.0
3255/10252.0
3256/10252.0
3257/10252.0
3258/10252.0
3259/10252.0
3260/10252.0
3261/10252.0
3262/10252.0
3263/10252.0
3264/10252.0
3265/10252.0
3266/10252.0
3267/10252.0
3268/10252.0
3269/10252.0
3270/10252.0
3271/10252.0
3272/10252.0
3273/10252.0
3274/10252.0
3275/10252.0
3276/10252.0
3277/10252.0
3278/10252.0
3279/10252.0
3280/10252.0
3281/10252.0
3282/10252.0
3283/10252.0
3284/10252.0
3285/10252.0
3286/10252.0
3287/10252.0
3288/10252.0
3289/10252.0
3290/10252.0
3291/10252.0
3292/10252.0
3293/10252.0
3294/10252.0
3295/10252.0
3296/10252.0
3297/10252.0
3298/10252.0
3299/10252.0
3300/10252.0
3301/10252.0
3302/10252.0
3303/10252.0
3304/10252.0
3305/10252.0
3306/10252.0
3307/10252.0
3308/10252.0
3309/10252.0
3310/10252.0
3311/10252.0
3312/10252.0
3313/10252.0
3314/10252.0
3315/10252.0
3316/10252.0
3317/10252.0
3318/10252.0
3319/10252.0
3320/10252.0
3321/10252.0
3322/10252.0
3323/10252.0
3324/10252.0
3325/10252.0
3326/10252.0
3327/10252.0
3328/10252.0
3329/10252.0
3330/10252.0

3885/10252.0
3886/10252.0
3887/10252.0
3888/10252.0
3889/10252.0
3890/10252.0
3891/10252.0
3892/10252.0
3893/10252.0
3894/10252.0
3895/10252.0
3896/10252.0
3897/10252.0
3898/10252.0
3899/10252.0
3900/10252.0
3901/10252.0
3902/10252.0
3903/10252.0
3904/10252.0
3905/10252.0
3906/10252.0
3907/10252.0
3908/10252.0
3909/10252.0
3910/10252.0
3911/10252.0
3912/10252.0
3913/10252.0
3914/10252.0
3915/10252.0
3916/10252.0
3917/10252.0
3918/10252.0
3919/10252.0
3920/10252.0
3921/10252.0
3922/10252.0
3923/10252.0
3924/10252.0
3925/10252.0
3926/10252.0
3927/10252.0
3928/10252.0
3929/10252.0
3930/10252.0
3931/10252.0
3932/10252.0
3933/10252.0
3934/10252.0
3935/10252.0
3936/10252.0
3937/10252.0
3938/10252.0
3939/10252.0
3940/10252.0
3941/10252.0
3942/10252.0
3943/10252.0
3944/10252.0
3945/10252.0
3946/10252.0
3947/10252.0
3948/10252.0
3949/10252.0
3950/10252.0
3951/10252.0
3952/10252.0
3953/10252.0
3954/10252.0
3955/10252.0
3956/10252.0
3957/10252.0
3958/10252.0
3959/10252.0
3960/10252.0
3961/10252.0

4520/10252.0
4521/10252.0
4522/10252.0
4523/10252.0
4524/10252.0
4525/10252.0
4526/10252.0
4527/10252.0
4528/10252.0
4529/10252.0
4530/10252.0
4531/10252.0
4532/10252.0
4533/10252.0
4534/10252.0
4535/10252.0
4536/10252.0
4537/10252.0
4538/10252.0
4539/10252.0
4540/10252.0
4541/10252.0
4542/10252.0
4543/10252.0
4544/10252.0
4545/10252.0
4546/10252.0
4547/10252.0
4548/10252.0
4549/10252.0
4550/10252.0
4551/10252.0
4552/10252.0
4553/10252.0
4554/10252.0
4555/10252.0
4556/10252.0
4557/10252.0
4558/10252.0
4559/10252.0
4560/10252.0
4561/10252.0
4562/10252.0
4563/10252.0
4564/10252.0
4565/10252.0
4566/10252.0
4567/10252.0
4568/10252.0
4569/10252.0
4570/10252.0
4571/10252.0
4572/10252.0
4573/10252.0
4574/10252.0
4575/10252.0
4576/10252.0
4577/10252.0
4578/10252.0
4579/10252.0
4580/10252.0
4581/10252.0
4582/10252.0
4583/10252.0
4584/10252.0
4585/10252.0
4586/10252.0
4587/10252.0
4588/10252.0
4589/10252.0
4590/10252.0
4591/10252.0
4592/10252.0
4593/10252.0
4594/10252.0
4595/10252.0
4596/10252.0

5154/10252.0
5155/10252.0
5156/10252.0
5157/10252.0
5158/10252.0
5159/10252.0
5160/10252.0
5161/10252.0
5162/10252.0
5163/10252.0
5164/10252.0
5165/10252.0
5166/10252.0
5167/10252.0
5168/10252.0
5169/10252.0
5170/10252.0
5171/10252.0
5172/10252.0
5173/10252.0
5174/10252.0
5175/10252.0
5176/10252.0
5177/10252.0
5178/10252.0
5179/10252.0
5180/10252.0
5181/10252.0
5182/10252.0
5183/10252.0
5184/10252.0
5185/10252.0
5186/10252.0
5187/10252.0
5188/10252.0
5189/10252.0
5190/10252.0
5191/10252.0
5192/10252.0
5193/10252.0
5194/10252.0
5195/10252.0
5196/10252.0
5197/10252.0
5198/10252.0
5199/10252.0
5200/10252.0
5201/10252.0
5202/10252.0
5203/10252.0
5204/10252.0
5205/10252.0
5206/10252.0
5207/10252.0
5208/10252.0
5209/10252.0
5210/10252.0
5211/10252.0
5212/10252.0
5213/10252.0
5214/10252.0
5215/10252.0
5216/10252.0
5217/10252.0
5218/10252.0
5219/10252.0
5220/10252.0
5221/10252.0
5222/10252.0
5223/10252.0
5224/10252.0
5225/10252.0
5226/10252.0
5227/10252.0
5228/10252.0
5229/10252.0
5230/10252.0

5788/10252.0
5789/10252.0
5790/10252.0
5791/10252.0
5792/10252.0
5793/10252.0
5794/10252.0
5795/10252.0
5796/10252.0
5797/10252.0
5798/10252.0
5799/10252.0
5800/10252.0
5801/10252.0
5802/10252.0
5803/10252.0
5804/10252.0
5805/10252.0
5806/10252.0
5807/10252.0
5808/10252.0
5809/10252.0
5810/10252.0
5811/10252.0
5812/10252.0
5813/10252.0
5814/10252.0
5815/10252.0
5816/10252.0
5817/10252.0
5818/10252.0
5819/10252.0
5820/10252.0
5821/10252.0
5822/10252.0
5823/10252.0
5824/10252.0
5825/10252.0
5826/10252.0
5827/10252.0
5828/10252.0
5829/10252.0
5830/10252.0
5831/10252.0
5832/10252.0
5833/10252.0
5834/10252.0
5835/10252.0
5836/10252.0
5837/10252.0
5838/10252.0
5839/10252.0
5840/10252.0
5841/10252.0
5842/10252.0
5843/10252.0
5844/10252.0
5845/10252.0
5846/10252.0
5847/10252.0
5848/10252.0
5849/10252.0
5850/10252.0
5851/10252.0
5852/10252.0
5853/10252.0
5854/10252.0
5855/10252.0
5856/10252.0
5857/10252.0
5858/10252.0
5859/10252.0
5860/10252.0
5861/10252.0
5862/10252.0
5863/10252.0
5864/10252.0

6422/10252.0
6423/10252.0
6424/10252.0
6425/10252.0
6426/10252.0
6427/10252.0
6428/10252.0
6429/10252.0
6430/10252.0
6431/10252.0
6432/10252.0
6433/10252.0
6434/10252.0
6435/10252.0
6436/10252.0
6437/10252.0
6438/10252.0
6439/10252.0
6440/10252.0
6441/10252.0
6442/10252.0
6443/10252.0
6444/10252.0
6445/10252.0
6446/10252.0
6447/10252.0
6448/10252.0
6449/10252.0
6450/10252.0
6451/10252.0
6452/10252.0
6453/10252.0
6454/10252.0
6455/10252.0
6456/10252.0
6457/10252.0
6458/10252.0
6459/10252.0
6460/10252.0
6461/10252.0
6462/10252.0
6463/10252.0
6464/10252.0
6465/10252.0
6466/10252.0
6467/10252.0
6468/10252.0
6469/10252.0
6470/10252.0
6471/10252.0
6472/10252.0
6473/10252.0
6474/10252.0
6475/10252.0
6476/10252.0
6477/10252.0
6478/10252.0
6479/10252.0
6480/10252.0
6481/10252.0
6482/10252.0
6483/10252.0
6484/10252.0
6485/10252.0
6486/10252.0
6487/10252.0
6488/10252.0
6489/10252.0
6490/10252.0
6491/10252.0
6492/10252.0
6493/10252.0
6494/10252.0
6495/10252.0
6496/10252.0
6497/10252.0
6498/10252.0

7055/10252.0
7056/10252.0
7057/10252.0
7058/10252.0
7059/10252.0
7060/10252.0
7061/10252.0
7062/10252.0
7063/10252.0
7064/10252.0
7065/10252.0
7066/10252.0
7067/10252.0
7068/10252.0
7069/10252.0
7070/10252.0
7071/10252.0
7072/10252.0
7073/10252.0
7074/10252.0
7075/10252.0
7076/10252.0
7077/10252.0
7078/10252.0
7079/10252.0
7080/10252.0
7081/10252.0
7082/10252.0
7083/10252.0
7084/10252.0
7085/10252.0
7086/10252.0
7087/10252.0
7088/10252.0
7089/10252.0
7090/10252.0
7091/10252.0
7092/10252.0
7093/10252.0
7094/10252.0
7095/10252.0
7096/10252.0
7097/10252.0
7098/10252.0
7099/10252.0
7100/10252.0
7101/10252.0
7102/10252.0
7103/10252.0
7104/10252.0
7105/10252.0
7106/10252.0
7107/10252.0
7108/10252.0
7109/10252.0
7110/10252.0
7111/10252.0
7112/10252.0
7113/10252.0
7114/10252.0
7115/10252.0
7116/10252.0
7117/10252.0
7118/10252.0
7119/10252.0
7120/10252.0
7121/10252.0
7122/10252.0
7123/10252.0
7124/10252.0
7125/10252.0
7126/10252.0
7127/10252.0
7128/10252.0
7129/10252.0
7130/10252.0
7131/10252.0

7687/10252.0
7688/10252.0
7689/10252.0
7690/10252.0
7691/10252.0
7692/10252.0
7693/10252.0
7694/10252.0
7695/10252.0
7696/10252.0
7697/10252.0
7698/10252.0
7699/10252.0
7700/10252.0
7701/10252.0
7702/10252.0
7703/10252.0
7704/10252.0
7705/10252.0
7706/10252.0
7707/10252.0
7708/10252.0
7709/10252.0
7710/10252.0
7711/10252.0
7712/10252.0
7713/10252.0
7714/10252.0
7715/10252.0
7716/10252.0
7717/10252.0
7718/10252.0
7719/10252.0
7720/10252.0
7721/10252.0
7722/10252.0
7723/10252.0
7724/10252.0
7725/10252.0
7726/10252.0
7727/10252.0
7728/10252.0
7729/10252.0
7730/10252.0
7731/10252.0
7732/10252.0
7733/10252.0
7734/10252.0
7735/10252.0
7736/10252.0
7737/10252.0
7738/10252.0
7739/10252.0
7740/10252.0
7741/10252.0
7742/10252.0
7743/10252.0
7744/10252.0
7745/10252.0
7746/10252.0
7747/10252.0
7748/10252.0
7749/10252.0
7750/10252.0
7751/10252.0
7752/10252.0
7753/10252.0
7754/10252.0
7755/10252.0
7756/10252.0
7757/10252.0
7758/10252.0
7759/10252.0
7760/10252.0
7761/10252.0
7762/10252.0
7763/10252.0

8321/10252.0
8322/10252.0
8323/10252.0
8324/10252.0
8325/10252.0
8326/10252.0
8327/10252.0
8328/10252.0
8329/10252.0
8330/10252.0
8331/10252.0
8332/10252.0
8333/10252.0
8334/10252.0
8335/10252.0
8336/10252.0
8337/10252.0
8338/10252.0
8339/10252.0
8340/10252.0
8341/10252.0
8342/10252.0
8343/10252.0
8344/10252.0
8345/10252.0
8346/10252.0
8347/10252.0
8348/10252.0
8349/10252.0
8350/10252.0
8351/10252.0
8352/10252.0
8353/10252.0
8354/10252.0
8355/10252.0
8356/10252.0
8357/10252.0
8358/10252.0
8359/10252.0
8360/10252.0
8361/10252.0
8362/10252.0
8363/10252.0
8364/10252.0
8365/10252.0
8366/10252.0
8367/10252.0
8368/10252.0
8369/10252.0
8370/10252.0
8371/10252.0
8372/10252.0
8373/10252.0
8374/10252.0
8375/10252.0
8376/10252.0
8377/10252.0
8378/10252.0
8379/10252.0
8380/10252.0
8381/10252.0
8382/10252.0
8383/10252.0
8384/10252.0
8385/10252.0
8386/10252.0
8387/10252.0
8388/10252.0
8389/10252.0
8390/10252.0
8391/10252.0
8392/10252.0
8393/10252.0
8394/10252.0
8395/10252.0
8396/10252.0
8397/10252.0

8956/10252.0
8957/10252.0
8958/10252.0
8959/10252.0
8960/10252.0
8961/10252.0
8962/10252.0
8963/10252.0
8964/10252.0
8965/10252.0
8966/10252.0
8967/10252.0
8968/10252.0
8969/10252.0
8970/10252.0
8971/10252.0
8972/10252.0
8973/10252.0
8974/10252.0
8975/10252.0
8976/10252.0
8977/10252.0
8978/10252.0
8979/10252.0
8980/10252.0
8981/10252.0
8982/10252.0
8983/10252.0
8984/10252.0
8985/10252.0
8986/10252.0
8987/10252.0
8988/10252.0
8989/10252.0
8990/10252.0
8991/10252.0
8992/10252.0
8993/10252.0
8994/10252.0
8995/10252.0
8996/10252.0
8997/10252.0
8998/10252.0
8999/10252.0
9000/10252.0
9001/10252.0
9002/10252.0
9003/10252.0
9004/10252.0
9005/10252.0
9006/10252.0
9007/10252.0
9008/10252.0
9009/10252.0
9010/10252.0
9011/10252.0
9012/10252.0
9013/10252.0
9014/10252.0
9015/10252.0
9016/10252.0
9017/10252.0
9018/10252.0
9019/10252.0
9020/10252.0
9021/10252.0
9022/10252.0
9023/10252.0
9024/10252.0
9025/10252.0
9026/10252.0
9027/10252.0
9028/10252.0
9029/10252.0
9030/10252.0
9031/10252.0
9032/10252.0

9588/10252.0
9589/10252.0
9590/10252.0
9591/10252.0
9592/10252.0
9593/10252.0
9594/10252.0
9595/10252.0
9596/10252.0
9597/10252.0
9598/10252.0
9599/10252.0
9600/10252.0
9601/10252.0
9602/10252.0
9603/10252.0
9604/10252.0
9605/10252.0
9606/10252.0
9607/10252.0
9608/10252.0
9609/10252.0
9610/10252.0
9611/10252.0
9612/10252.0
9613/10252.0
9614/10252.0
9615/10252.0
9616/10252.0
9617/10252.0
9618/10252.0
9619/10252.0
9620/10252.0
9621/10252.0
9622/10252.0
9623/10252.0
9624/10252.0
9625/10252.0
9626/10252.0
9627/10252.0
9628/10252.0
9629/10252.0
9630/10252.0
9631/10252.0
9632/10252.0
9633/10252.0
9634/10252.0
9635/10252.0
9636/10252.0
9637/10252.0
9638/10252.0
9639/10252.0
9640/10252.0
9641/10252.0
9642/10252.0
9643/10252.0
9644/10252.0
9645/10252.0
9646/10252.0
9647/10252.0
9648/10252.0
9649/10252.0
9650/10252.0
9651/10252.0
9652/10252.0
9653/10252.0
9654/10252.0
9655/10252.0
9656/10252.0
9657/10252.0
9658/10252.0
9659/10252.0
9660/10252.0
9661/10252.0
9662/10252.0
9663/10252.0
9664/10252.0

10213/10252.0
10214/10252.0
10215/10252.0
10216/10252.0
10217/10252.0
10218/10252.0
10219/10252.0
10220/10252.0
10221/10252.0
10222/10252.0
10223/10252.0
10224/10252.0
10225/10252.0
10226/10252.0
10227/10252.0
10228/10252.0
10229/10252.0
10230/10252.0
10231/10252.0
10232/10252.0
10233/10252.0
10234/10252.0
10235/10252.0
10236/10252.0
10237/10252.0
10238/10252.0
10239/10252.0
10240/10252.0
10241/10252.0
10242/10252.0
10243/10252.0
10244/10252.0
10245/10252.0
10246/10252.0
10247/10252.0
10248/10252.0
10249/10252.0
10250/10252.0
10251/10252.0


In [24]:
print(classification_report(test_y, bert_predicted))

              precision    recall  f1-score   support

           0       0.63      0.73      0.67      6037
           1       0.00      0.00      0.00         8
           2       0.38      0.15      0.22        39
           3       0.48      0.37      0.42      4168

    accuracy                           0.58     10252
   macro avg       0.37      0.31      0.33     10252
weighted avg       0.57      0.58      0.57     10252



In [25]:
print(epoch_loss)

[48161.826351630734, 45839.93867501954, 41989.24006806989, 37090.992195491795, 32753.273734059185, 29288.746379895223, 26071.458497231943, 23714.98107702726, 21888.41091365992, 20454.021743885263, 19172.97132580796, 18105.878856272637, 17250.542053121622, 16391.622835510138, 15824.127572997426, 15372.637246203249, 14842.386928096948, 14644.625992708778, 14355.812987334624, 13775.606519482018, 13301.202704979045, 12877.856288157487, 12306.736937414338, 12178.804884861906, 12282.2350730996, 11742.634107740558, 11202.829733081533, 10920.099223856261, 11155.251404406901, 11320.390920215274]
