In [1]:
 import pandas as pd
 import numpy as np

In [9]:
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


Загрузим данные

In [10]:
df = pd.read_csv('gdrive/MyDrive/train.csv')

Выделим классы

In [11]:
for num, name_class in enumerate(df.category.unique()):
  df.loc[df.category==name_class, 'class'] = num

In [None]:
df

Unnamed: 0,text,category,class
0,I am still waiting on my card?,card_arrival,0.0
1,What can I do if my card still hasn't arrived ...,card_arrival,0.0
2,I have been waiting over a week. Is the card s...,card_arrival,0.0
3,Can I track my card while it is in the process...,card_arrival,0.0
4,"How do I know if I will get my card, or if it ...",card_arrival,0.0
...,...,...,...
9998,You provide support in what countries?,country_support,76.0
9999,What countries are you supporting?,country_support,76.0
10000,What countries are getting support?,country_support,76.0
10001,Are cards available in the EU?,country_support,76.0


In [16]:
import nlpaug.augmenter.word as naw
import nltk
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


True

Разделим на обучающую и тестовую выборку

In [17]:
from sklearn.model_selection import train_test_split
train, test = train_test_split(df, test_size=0.25, random_state=241)

Проведем Аугментацию данных путем замены слов на синонимы с помощью библиотеки - nlpaug

In [18]:
aug = naw.SynonymAug(aug_src='wordnet', model_path=None, name='Synonym_Aug', aug_min=1, aug_max=10, aug_p=0.2, lang='eng', 
                     stopwords=None, tokenizer=None, reverse_tokenizer=None, stopwords_regex=None, force_reload=False, 
                     verbose=0)
 
test_sentence = df.text[0]
result = pd.DataFrame()
for text, label in zip(train.text, train['class']):
  for _ in range(10):
    result = result.append(pd.DataFrame({'text': aug.augment(text), 'class':label}, index=range(0, 1)))

Мержим все что есть и типом train и test, чтобы в тест не попали данные с аугментации

In [19]:
train = train.append(result)
train['type'] = 'train'
test['type'] = 'test'
df = train.append(test)
df

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


Unnamed: 0,text,category,class,type
9266,Tell me how to renew my new card?,activate_my_card,71.0,train
4236,i can not see my top up,topping_up_by_card,33.0,train
8715,MY transfer has still not appeared.,balance_not_updated_after_bank_transfer,67.0,train
7425,How can I configure what currency my salary is...,receiving_money,56.0,train
4080,do i need to wait for my card before i get pin,get_physical_card,31.0,train
...,...,...,...,...
1670,Help me unblock my account. I entered the PIN...,pin_blocked,13.0,test
782,My cash withdrawal isnt showing on my account,pending_cash_withdrawal,5.0,test
6107,My transfer's stuck on pending.,pending_transfer,46.0,test
5929,I tried to pay with my card and it didn't work,declined_card_payment,45.0,test


Следующий этап состоит в том, чтобы собрать мешок слов из символов

In [20]:
BOS, EOS = ' ', '\n' 

lines = df.text.apply(lambda line: BOS + line.replace(EOS, ' ') + EOS) \
            .tolist()
lines = [line.lower() for line in lines]

In [21]:
for i in range(len(lines)):
  lines[i] = lines[i].replace('\xa0', '')

In [22]:
lines[:2]

[' tell me how to renew my new card?\n', ' i can not see my top up\n']

In [23]:
from nltk.tokenize import WordPunctTokenizer
tokenizer = WordPunctTokenizer()
tokens = [' '.join(tokenizer.tokenize(line.lower())) for line in lines]
tokens = [char for line in tokens for char in line] 
tokens.append(BOS)
tokens.append(EOS)
tokens = list(set(tokens))
tokens = sorted(tokens)
n_tokens = len(tokens)

In [24]:
print('shape of tokens = %s'%len(tokens))

shape of tokens = 58


In [25]:
#Зададим каждому символу, свой порядковый номер
token_to_id = {char: t for t, char in enumerate(tokens)}

Функция для перевода текстовых строчек в мешок слов

In [None]:
def to_matrix(lines, max_len=None, pad=token_to_id[EOS], dtype=np.int64):
    max_len = max_len or max(map(len, lines))
    lines_ix = np.full([len(lines), max_len], pad, dtype=dtype)
    for i in range(len(lines)):
        line_ix = list(map(token_to_id.get, lines[i][:max_len]))
        lines_ix[i, :len(line_ix)] = line_ix
    return lines_ix

In [None]:
to_matrix(lines[:2])

array([[ 1, 36,  1, 46, 32, 32,  1, 28,  1, 31, 48, 43, 39, 36, 30, 28,
        47, 32,  1, 43, 28, 52, 40, 32, 41, 47, 12,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 1, 50, 35, 52,  1, 31, 42,  1, 36,  1, 35, 28, 49, 32,  1, 28,
         1, 30, 35, 28, 45, 34, 32,  1, 33, 42, 45,  1, 28, 41,  1, 28,
        47, 40,  1, 50, 36, 47, 35, 31, 45, 28, 50, 28, 39, 27,  1, 36,
         1, 47, 35, 42, 48, 34, 35, 47,  1, 47, 35, 32, 46, 32,  1, 50,
        32, 45, 32,  1, 33, 45, 32, 32, 27,  0]])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
n_class = len(df['class'].unique())

In [None]:
class ClassifierRNN(nn.Module):
    def __init__(self, n_tokens=n_tokens, emb_size=32, hid_size=256, n_class=n_class):
       
        super().__init__() 

        self.emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=emb_size)
        self.LSTM = nn.LSTM(input_size=emb_size, hidden_size=hid_size, batch_first=True, num_layers=2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(hid_size, hid_size)
        self.linear2 = nn.Linear(hid_size, hid_size)
        self.linear3 = nn.Linear(hid_size, n_class)
        self.softmax = nn.Softmax()
        
    
    def __call__(self, input_ix):
        x = self.emb(input_ix)
        x = self.LSTM(x)[0][:,-1]
        x = self.relu(self.linear(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

Попытка сделать RNN сеть не увенчалась успехом, плохой результат получился

In [None]:
class ClassifierCNN(nn.Module):
    def __init__(self, n_tokens=len(tokens), emb_size=64, hid_size=128, n_class=n_class):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=emb_size)
        self.cnn1 = nn.Sequential(
                            nn.Conv1d(emb_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU()
                            )
        self.cnn2 = nn.Sequential(
                            nn.Conv1d(hid_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU()
                            )
        self.cnn3 = nn.Sequential(
                            nn.Conv1d(hid_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU(),
                            nn.AdaptiveMaxPool1d(output_size=1)
                            )
        self.linear = nn.Sequential(
                            nn.Linear(hid_size, 4*hid_size),
                            nn.ReLU(),
                            nn.Linear(hid_size*4, hid_size*4),
                            nn.ReLU(),
                            nn.Linear(4*hid_size, n_class)
                            )
        
    def __call__(self, input_ix):
        x = self.emb(input_ix).transpose(1, 2)
        x = self.cnn1(x)
        x = self.cnn2(x)
        x = self.cnn3(x).transpose(1, 2)
        return self.linear(x).squeeze()
        

Простая сеть с 3-мя сверточными слоями

In [None]:
class ClassifierCNN(nn.Module):
    def __init__(self, n_tokens=len(tokens), emb_size=64, hid_size=128, n_class=n_class):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=emb_size)
        self.cnn1 = nn.Sequential(
                            nn.Conv1d(emb_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU(),
                            nn.BatchNorm1d(hid_size)
                            )
        self.cnn2 = nn.Sequential(
                            nn.Conv1d(hid_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU(),
                            nn.BatchNorm1d(hid_size)
                            )
        self.cnn3 = nn.Sequential(
                            nn.Conv1d(hid_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU(),
                            nn.BatchNorm1d(hid_size)
                            )
        self.cnn4 = nn.Sequential(
                            nn.Conv1d(hid_size, hid_size, kernel_size=3),
                            nn.Dropout(p=0.25),
                            nn.ReLU(),
                            nn.AdaptiveMaxPool1d(output_size=1),
                            nn.BatchNorm1d(hid_size)
                            )
        self.linear = nn.Sequential(
                            nn.Linear(hid_size, 4*hid_size),
                            nn.ReLU(),
                            nn.Linear(hid_size*4, hid_size*4),
                            nn.ReLU(),
                            nn.Dropout(p=0.25),
                            nn.Linear(hid_size*4, hid_size*4),
                            nn.ReLU(),
                            nn.Linear(4*hid_size, n_class)
                            )
        
    def __call__(self, input_ix):
        x = self.emb(input_ix).transpose(1, 2)
        x = self.cnn1(x)
        x = self.cnn2(x)
        x = self.cnn3(x)
        x = self.cnn4(x).transpose(1, 2)
        return self.linear(x).squeeze()
        

Немного улучшаем

In [None]:
df['lines'] = lines
df.head(2)

Unnamed: 0,text,category,class,type,lines
9266,Tell me how to renew my new card?,activate_my_card,71.0,train,tell me how to renew my new card?\n
4236,i can not see my top up,topping_up_by_card,33.0,train,i can not see my top up\n


In [None]:
train_text = df.loc[df.type=='train', 'lines'].to_list()
test_text = df.loc[df.type=='test', 'lines'].to_list()
train_labels = df.loc[df.type=='train', 'class'].to_list()
test_labels = df.loc[df.type=='test', 'class'].to_list()

In [None]:
#переводим в тензора
train_matrix = torch.tensor(to_matrix(train_text))

test_matrix = torch.tensor(to_matrix(test_text))

In [None]:
train_labels = torch.tensor(train_labels)

In [None]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

In [None]:
batch_size = 256
train_data = TensorDataset(train_matrix, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

Т.к метки несбалансированные, подсчитаем веса для каждой метки и передадим их, как аргумет в функцию потерь - кросс_энтропию

In [None]:
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)


In [None]:
device = torch.device('cuda')
weights= torch.tensor(class_weights,dtype=torch.float)

weights = weights.to(device)


Обучение

In [None]:
from tqdm.notebook import tqdm
from torch.optim import lr_scheduler
model = ClassifierCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)
EPOCHS = 10
for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for  i, batch in enumerate(tqdm(train_dataloader)):
        batch = [r.to(device) for r in batch]
        sent_id,  labels = batch
        optimizer.zero_grad()
        pred = model(sent_id)
        labels = labels.long()
        #target = batch['class'].long()
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        #scheduler.step()
        if i % 20 == 0:
          print('loss = %f'%loss.item())

epoch: 0


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 4.341788
loss = 4.334428
loss = 4.308545
loss = 4.292790
loss = 4.274245
loss = 4.310888
loss = 4.235419
loss = 4.136908
loss = 4.085417
loss = 3.910193
loss = 3.814699
loss = 3.816343
loss = 3.638532
loss = 3.636296
loss = 3.551883
loss = 3.350270
loss = 3.360361
epoch: 1


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.182056
loss = 3.090386
loss = 3.249537
loss = 3.284030
loss = 3.126810
loss = 3.083251
loss = 3.095666
loss = 2.865837
loss = 2.832069
loss = 2.892827
loss = 2.824306
loss = 2.712342
loss = 2.615083
loss = 2.790845
loss = 2.536636
loss = 2.570494
loss = 2.648888
epoch: 2


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 2.551025
loss = 2.513076
loss = 2.597655
loss = 2.465110
loss = 2.440705
loss = 2.449447
loss = 2.272830
loss = 2.364008
loss = 2.336810
loss = 2.216341
loss = 2.294706
loss = 2.289957
loss = 2.203597
loss = 1.900822
loss = 2.452011
loss = 2.108922
loss = 2.219896
epoch: 3


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 2.211082
loss = 2.083792
loss = 2.065789
loss = 2.150692
loss = 1.871441
loss = 1.978140
loss = 1.970117
loss = 1.806178
loss = 1.764886
loss = 1.961727
loss = 1.682358
loss = 1.809179
loss = 1.761168
loss = 1.655488
loss = 1.988308
loss = 1.825673
loss = 1.532641
epoch: 4


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.537594
loss = 1.763197
loss = 1.580875
loss = 1.814865
loss = 1.562471
loss = 1.637715
loss = 1.591374
loss = 1.684320
loss = 1.590917
loss = 1.619012
loss = 1.603112
loss = 1.575046
loss = 1.499650
loss = 1.636661
loss = 1.672900
loss = 1.486451
loss = 1.592847
epoch: 5


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.395523
loss = 1.389082
loss = 1.341252
loss = 1.565843
loss = 1.482944
loss = 1.624638
loss = 1.664808
loss = 1.382774
loss = 1.377626
loss = 1.376261
loss = 1.422557
loss = 1.475945
loss = 1.471000
loss = 1.320946
loss = 1.395864
loss = 1.288824
loss = 1.397500
epoch: 6


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.321754
loss = 1.089076
loss = 1.286401
loss = 1.378406
loss = 1.292358
loss = 1.141913
loss = 1.297963
loss = 1.166074
loss = 1.194276
loss = 1.330861
loss = 1.250726
loss = 1.421867
loss = 1.139443
loss = 1.227530
loss = 1.058413
loss = 1.171807
loss = 1.321053
epoch: 7


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.295890
loss = 1.176460
loss = 1.180734
loss = 1.138168
loss = 1.314819
loss = 1.132066
loss = 1.274629
loss = 1.127044
loss = 1.087569
loss = 1.118252
loss = 1.233875
loss = 1.095969
loss = 1.017406
loss = 0.939166
loss = 1.158645
loss = 1.155480
loss = 1.058894
epoch: 8


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.146747
loss = 1.123288
loss = 1.210436
loss = 1.185087
loss = 1.008755
loss = 1.122216
loss = 1.024320
loss = 0.959082
loss = 1.010312
loss = 1.034966
loss = 0.980612
loss = 1.111919
loss = 1.067711
loss = 0.884091
loss = 0.962864
loss = 1.131087
loss = 1.078777
epoch: 9


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 1.037803
loss = 1.159962
loss = 1.119490
loss = 1.012737
loss = 0.927251
loss = 1.233433
loss = 1.032953
loss = 0.996858
loss = 0.899910
loss = 0.976690
loss = 0.950469
loss = 0.921913
loss = 1.089092
loss = 0.953282
loss = 1.039637
loss = 0.839102
loss = 1.033888


In [None]:
from sklearn.metrics import f1_score
def test_score(model, test_matrix, device):
  f1 = 0

  model.eval()
  with torch.no_grad():
    pred = model(test_matrix.to(device))
    pred = torch.argmax(pred, dim=1).cpu().detach().numpy()
    f1 = f1_score(pred, test_labels, pred, average='macro')
  return f1

In [None]:
from tqdm.notebook import tqdm
from torch.optim import lr_scheduler
model = ClassifierCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=400, gamma=0.1)
EPOCHS = 15
f1 = 0
params_model = 0
for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for  i, batch in enumerate(tqdm(train_dataloader)):
        batch = [r.to(device) for r in batch]
        sent_id,  labels = batch
        optimizer.zero_grad()
        pred = model(sent_id)
        labels = labels.long()
        #target = batch['class'].long()
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        #scheduler.step()
        if test_score(model, test_matrix, device)>f1:
          f1 = test_score(model, test_matrix, device)
          params_model = model.state_dict()
        if i % 20 == 0:
          print('loss = %f'%loss.item(), 'f1_score test = ',f1 )

epoch: 0


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 4.333051 f1_score test =  0.022924901185770754
loss = 4.340577 f1_score test =  0.028379976350019712
loss = 4.330057 f1_score test =  0.028379976350019712
loss = 4.315484 f1_score test =  0.03148366784730422
loss = 4.333817 f1_score test =  0.03148366784730422
loss = 4.306513 f1_score test =  0.03148366784730422
loss = 4.303718 f1_score test =  0.03148366784730422
loss = 4.255809 f1_score test =  0.03714548493888704
loss = 4.241342 f1_score test =  0.06785494711392036
loss = 4.056478 f1_score test =  0.07029318281462509
loss = 3.909437 f1_score test =  0.07029318281462509
loss = 3.931465 f1_score test =  0.07612423107954289
loss = 3.803548 f1_score test =  0.08484851004657169
loss = 3.796975 f1_score test =  0.10953963712647456
loss = 3.563515 f1_score test =  0.129126056432543
loss = 3.506716 f1_score test =  0.14294784479666917
loss = 3.429619 f1_score test =  0.17693112712247328
epoch: 1


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.807784 f1_score test =  0.17693112712247328
loss = 3.627548 f1_score test =  0.17693112712247328
loss = 3.561411 f1_score test =  0.17693112712247328
loss = 3.450225 f1_score test =  0.17693112712247328
loss = 3.399137 f1_score test =  0.18096971935267955
loss = 3.343088 f1_score test =  0.18469601181704737
loss = 3.228672 f1_score test =  0.18521022415658053
loss = 3.095377 f1_score test =  0.18876770576263258
loss = 3.170963 f1_score test =  0.21161039869554685
loss = 3.036689 f1_score test =  0.21773232702057818
loss = 3.154443 f1_score test =  0.22206235067561275
loss = 3.081253 f1_score test =  0.2395267948425143
loss = 2.938586 f1_score test =  0.24056825046033223
loss = 2.805140 f1_score test =  0.2623518054681758
loss = 2.905184 f1_score test =  0.2683226039825396
loss = 2.706948 f1_score test =  0.283183532515984
loss = 2.718599 f1_score test =  0.30522775760979265
epoch: 2


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 4.125410 f1_score test =  0.30522775760979265
loss = 2.713878 f1_score test =  0.30522775760979265
loss = 2.569620 f1_score test =  0.3059008961899888
loss = 2.517426 f1_score test =  0.33040076244276856
loss = 2.648285 f1_score test =  0.3335433072408317
loss = 2.450604 f1_score test =  0.3498404231038464
loss = 2.474101 f1_score test =  0.37451485142565105
loss = 2.256472 f1_score test =  0.38190560712565186
loss = 2.389289 f1_score test =  0.38190560712565186
loss = 2.319353 f1_score test =  0.4036390055771493
loss = 2.187809 f1_score test =  0.419533510301105
loss = 2.191029 f1_score test =  0.4361768785029343
loss = 2.338873 f1_score test =  0.4483031966311235
loss = 2.176456 f1_score test =  0.4637030069466462
loss = 1.992315 f1_score test =  0.4665615662757876
loss = 2.022861 f1_score test =  0.4708047970769973
loss = 1.893453 f1_score test =  0.5003617723865703
epoch: 3


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.375527 f1_score test =  0.5003617723865703
loss = 1.990347 f1_score test =  0.5003617723865703
loss = 1.899928 f1_score test =  0.5135871628566451
loss = 1.921692 f1_score test =  0.5135871628566451
loss = 1.758223 f1_score test =  0.5227213561789341
loss = 1.701598 f1_score test =  0.5284558947347965
loss = 1.619028 f1_score test =  0.5459613789526689
loss = 1.810137 f1_score test =  0.5460100099297716
loss = 1.571254 f1_score test =  0.5624743906873475
loss = 1.618328 f1_score test =  0.5714454442904449
loss = 1.547958 f1_score test =  0.5726810140469023
loss = 1.505723 f1_score test =  0.5786128161337085
loss = 1.572198 f1_score test =  0.5870290430692523
loss = 1.512695 f1_score test =  0.5937329972098956
loss = 1.614436 f1_score test =  0.5987299548601119
loss = 1.516585 f1_score test =  0.6052293476751373
loss = 1.361745 f1_score test =  0.6117388446069615
epoch: 4


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.490172 f1_score test =  0.6154852300515288
loss = 1.387516 f1_score test =  0.6154852300515288
loss = 1.311036 f1_score test =  0.6204046583860413
loss = 1.321962 f1_score test =  0.6255298028275972
loss = 1.261350 f1_score test =  0.6325629893097076
loss = 1.457396 f1_score test =  0.6411754974104583
loss = 1.349374 f1_score test =  0.6411754974104583
loss = 1.336953 f1_score test =  0.6425137027082166
loss = 1.371755 f1_score test =  0.6480448898737124
loss = 1.224570 f1_score test =  0.6533546768998527
loss = 1.141081 f1_score test =  0.6549283480866483
loss = 1.140454 f1_score test =  0.6636154586462134
loss = 1.131668 f1_score test =  0.6657568966978775
loss = 1.188670 f1_score test =  0.6703795287684835
loss = 1.083627 f1_score test =  0.6777034733323295
loss = 1.042001 f1_score test =  0.6777034733323295
loss = 1.131606 f1_score test =  0.6826194083663292
epoch: 5


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.092832 f1_score test =  0.6826194083663292
loss = 1.144620 f1_score test =  0.6826194083663292
loss = 1.138501 f1_score test =  0.6826194083663292
loss = 1.108262 f1_score test =  0.6860797881477567
loss = 0.964052 f1_score test =  0.6886232098110638
loss = 1.135424 f1_score test =  0.6994585180681973
loss = 1.107702 f1_score test =  0.6994585180681973
loss = 0.835872 f1_score test =  0.6994585180681973
loss = 1.012085 f1_score test =  0.7050032817817498
loss = 0.983231 f1_score test =  0.7050032817817498
loss = 0.995173 f1_score test =  0.7070139879695629
loss = 1.107189 f1_score test =  0.7106363655450273
loss = 1.040139 f1_score test =  0.711021603411607
loss = 0.961374 f1_score test =  0.7184371738058515
loss = 0.873121 f1_score test =  0.7184371738058515
loss = 0.859716 f1_score test =  0.7197439424160568
loss = 0.885727 f1_score test =  0.7197439424160568
epoch: 6


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.021949 f1_score test =  0.7197439424160568
loss = 0.889665 f1_score test =  0.7197439424160568
loss = 0.935612 f1_score test =  0.7197439424160568
loss = 0.841377 f1_score test =  0.7268730787007175
loss = 1.116075 f1_score test =  0.7268730787007175
loss = 0.975735 f1_score test =  0.7394990269129431
loss = 0.938074 f1_score test =  0.7394990269129431
loss = 0.833858 f1_score test =  0.7394990269129431
loss = 0.877671 f1_score test =  0.7394990269129431
loss = 0.914679 f1_score test =  0.7418543021625289
loss = 0.734507 f1_score test =  0.7475596443273606
loss = 0.702900 f1_score test =  0.7475596443273606
loss = 0.824121 f1_score test =  0.7475596443273606
loss = 0.927404 f1_score test =  0.7520342499227104
loss = 0.740040 f1_score test =  0.7520342499227104
loss = 0.831211 f1_score test =  0.7520342499227104
loss = 0.745060 f1_score test =  0.7520342499227104
epoch: 7


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.250595 f1_score test =  0.7520342499227104
loss = 0.777071 f1_score test =  0.7520342499227104
loss = 0.686313 f1_score test =  0.7520342499227104
loss = 0.757890 f1_score test =  0.75917159016205
loss = 0.747861 f1_score test =  0.75917159016205
loss = 0.744653 f1_score test =  0.75917159016205
loss = 0.700282 f1_score test =  0.7612585289959763
loss = 0.770998 f1_score test =  0.7612585289959763
loss = 0.724212 f1_score test =  0.7612585289959763
loss = 0.749513 f1_score test =  0.7626632537543743
loss = 0.605444 f1_score test =  0.76841777671888
loss = 0.835398 f1_score test =  0.76841777671888
loss = 0.796150 f1_score test =  0.7694633837597781
loss = 0.697391 f1_score test =  0.7694633837597781
loss = 0.739443 f1_score test =  0.7738202211905767
loss = 0.804535 f1_score test =  0.7738202211905767
loss = 0.727260 f1_score test =  0.7738202211905767
epoch: 8


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.270973 f1_score test =  0.7738202211905767
loss = 0.869378 f1_score test =  0.7738202211905767
loss = 0.634208 f1_score test =  0.7738202211905767
loss = 0.618409 f1_score test =  0.7772440177152751
loss = 0.534739 f1_score test =  0.780906473878836
loss = 0.704580 f1_score test =  0.780906473878836
loss = 0.683045 f1_score test =  0.780906473878836
loss = 0.569660 f1_score test =  0.7830493976692323
loss = 0.529541 f1_score test =  0.7830493976692323
loss = 0.756713 f1_score test =  0.7830493976692323
loss = 0.504991 f1_score test =  0.7833644623574343
loss = 0.551087 f1_score test =  0.7833644623574343
loss = 0.608194 f1_score test =  0.7860740207316241
loss = 0.602959 f1_score test =  0.7860740207316241
loss = 0.519684 f1_score test =  0.7878070641733494
loss = 0.695902 f1_score test =  0.7878070641733494
loss = 0.658101 f1_score test =  0.7878070641733494
epoch: 9


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.181536 f1_score test =  0.7878070641733494
loss = 0.602166 f1_score test =  0.7878070641733494
loss = 0.516754 f1_score test =  0.7878070641733494
loss = 0.478679 f1_score test =  0.7893974128327439
loss = 0.524916 f1_score test =  0.7917551132030506
loss = 0.503570 f1_score test =  0.7917551132030506
loss = 0.514861 f1_score test =  0.7949162596968731
loss = 0.573149 f1_score test =  0.7949162596968731
loss = 0.557895 f1_score test =  0.7949162596968731
loss = 0.564213 f1_score test =  0.7949162596968731
loss = 0.468306 f1_score test =  0.7949162596968731
loss = 0.578894 f1_score test =  0.7949162596968731
loss = 0.546632 f1_score test =  0.7949162596968731
loss = 0.575810 f1_score test =  0.7949162596968731
loss = 0.472259 f1_score test =  0.8009016042081177
loss = 0.496793 f1_score test =  0.8009016042081177
loss = 0.542721 f1_score test =  0.8012775807148784
epoch: 10


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.325005 f1_score test =  0.8012775807148784
loss = 0.466984 f1_score test =  0.8012775807148784
loss = 0.496997 f1_score test =  0.8012775807148784
loss = 0.461167 f1_score test =  0.8012775807148784
loss = 0.460474 f1_score test =  0.8057386748212881
loss = 0.533477 f1_score test =  0.8057386748212881
loss = 0.448382 f1_score test =  0.8057386748212881
loss = 0.457149 f1_score test =  0.8057386748212881
loss = 0.539009 f1_score test =  0.8057386748212881
loss = 0.651297 f1_score test =  0.8057386748212881
loss = 0.556331 f1_score test =  0.8057386748212881
loss = 0.381759 f1_score test =  0.8057386748212881
loss = 0.585529 f1_score test =  0.8058636784044034
loss = 0.514419 f1_score test =  0.8058636784044034
loss = 0.439694 f1_score test =  0.8058636784044034
loss = 0.421242 f1_score test =  0.8058636784044034
loss = 0.400606 f1_score test =  0.8070245727139858
epoch: 11


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.447986 f1_score test =  0.8070245727139858
loss = 0.380190 f1_score test =  0.8070245727139858
loss = 0.490674 f1_score test =  0.8070245727139858
loss = 0.378580 f1_score test =  0.8070245727139858
loss = 0.393265 f1_score test =  0.8070245727139858
loss = 0.382801 f1_score test =  0.8121927777380351
loss = 0.451050 f1_score test =  0.8121927777380351
loss = 0.455696 f1_score test =  0.8121927777380351
loss = 0.425944 f1_score test =  0.8121927777380351
loss = 0.400314 f1_score test =  0.8121927777380351
loss = 0.317511 f1_score test =  0.8121927777380351
loss = 0.430117 f1_score test =  0.8121927777380351
loss = 0.448108 f1_score test =  0.8121927777380351
loss = 0.394805 f1_score test =  0.8121927777380351
loss = 0.386156 f1_score test =  0.8134403949044786
loss = 0.397250 f1_score test =  0.8134403949044786
loss = 0.456386 f1_score test =  0.8134403949044786
epoch: 12


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.575835 f1_score test =  0.8136282538840737
loss = 0.318610 f1_score test =  0.8136282538840737
loss = 0.443527 f1_score test =  0.8136282538840737
loss = 0.357025 f1_score test =  0.8136282538840737
loss = 0.357926 f1_score test =  0.8136282538840737
loss = 0.318624 f1_score test =  0.8157704299070311
loss = 0.415966 f1_score test =  0.8157704299070311
loss = 0.364197 f1_score test =  0.817218599579626
loss = 0.394739 f1_score test =  0.817218599579626
loss = 0.233619 f1_score test =  0.817218599579626
loss = 0.463955 f1_score test =  0.817218599579626
loss = 0.320580 f1_score test =  0.817218599579626
loss = 0.383343 f1_score test =  0.817218599579626
loss = 0.421067 f1_score test =  0.8198972238740403
loss = 0.327417 f1_score test =  0.8198972238740403
loss = 0.418147 f1_score test =  0.8198972238740403
loss = 0.428344 f1_score test =  0.8198972238740403
epoch: 13


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.569323 f1_score test =  0.8198972238740403
loss = 0.343627 f1_score test =  0.8198972238740403
loss = 0.299583 f1_score test =  0.8198972238740403
loss = 0.360819 f1_score test =  0.8202880157539338
loss = 0.311093 f1_score test =  0.8202880157539338
loss = 0.326485 f1_score test =  0.8202880157539338
loss = 0.355652 f1_score test =  0.8202880157539338
loss = 0.356214 f1_score test =  0.8208707249628232
loss = 0.279828 f1_score test =  0.8208707249628232
loss = 0.315582 f1_score test =  0.8226709402364537
loss = 0.325685 f1_score test =  0.8226709402364537
loss = 0.441789 f1_score test =  0.8226709402364537
loss = 0.324359 f1_score test =  0.8227125275923517
loss = 0.395139 f1_score test =  0.8227125275923517
loss = 0.374317 f1_score test =  0.8227125275923517
loss = 0.275440 f1_score test =  0.8227125275923517
loss = 0.302723 f1_score test =  0.8227125275923517
epoch: 14


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.583187 f1_score test =  0.8227125275923517
loss = 0.222897 f1_score test =  0.8227125275923517
loss = 0.276702 f1_score test =  0.8227125275923517
loss = 0.338383 f1_score test =  0.8227125275923517
loss = 0.286983 f1_score test =  0.8236797605785647
loss = 0.352642 f1_score test =  0.8236797605785647
loss = 0.320782 f1_score test =  0.8253582445297397
loss = 0.280404 f1_score test =  0.8253582445297397
loss = 0.365577 f1_score test =  0.8253582445297397
loss = 0.319370 f1_score test =  0.8253582445297397
loss = 0.285768 f1_score test =  0.8253582445297397
loss = 0.245819 f1_score test =  0.8253582445297397
loss = 0.347355 f1_score test =  0.8253582445297397
loss = 0.330593 f1_score test =  0.8253582445297397
loss = 0.318129 f1_score test =  0.8253582445297397
loss = 0.253024 f1_score test =  0.8253582445297397
loss = 0.257543 f1_score test =  0.8272277243625022


In [None]:
import os
torch.save(params_model, os.path.join('gdrive/MyDrive/', 'best_model'))

Проверка модели с SelfAdjDiceLoss 

In [None]:
class SelfAdjDiceLoss(torch.nn.Module):

    def __init__(self, alpha: float = 1.0, gamma: float = 1.0, reduction: str = "mean") -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.softmax(logits, dim=1)
        probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))

        probs_with_factor = ((1 - probs) ** self.alpha) * probs
        loss = 1 - (2 * probs_with_factor + self.gamma) / (probs_with_factor + 1 + self.gamma)

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "none" or self.reduction is None:
            return loss
        else:
            raise NotImplementedError(f"Reduction `{self.reduction}` is not supported.")

In [None]:
from tqdm.notebook import tqdm
from torch.optim import lr_scheduler
model = ClassifierCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=400, gamma=0.1)
EPOCHS = 15
f1 = 0
params_model = 0
for epoch in range(EPOCHS):
    print(f"epoch: {epoch}")
    model.train()
    for  i, batch in enumerate(tqdm(train_dataloader)):
        batch = [r.to(device) for r in batch]
        sent_id,  labels = batch
        optimizer.zero_grad()
        pred = model(sent_id)
        labels = labels.long()
        #target = batch['class'].long()
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        #scheduler.step()
        f1_tmp = test_score(model, test_matrix, device)
        if f1_tmp>f1:
          f1 = f1_tmp
          params_model = model.state_dict()
        if i % 20 == 0:
          print('loss = %f'%loss.item(), 'f1_score test = ', f1_tmp)

epoch: 0


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 4.344694 f1_score test =  0.02448657187993681
loss = 4.340944 f1_score test =  0.02448657187993681
loss = 4.335081 f1_score test =  0.030442506380471534
loss = 4.317577 f1_score test =  0.022924901185770754
loss = 4.312341 f1_score test =  0.022924901185770754
loss = 4.283731 f1_score test =  0.022924901185770754
loss = 4.286858 f1_score test =  0.022924901185770754
loss = 4.270424 f1_score test =  0.04503894816576089
loss = 4.201200 f1_score test =  0.054843555818678294
loss = 4.053119 f1_score test =  0.06728562332163933
loss = 3.999486 f1_score test =  0.04358608192657985
loss = 3.964964 f1_score test =  0.06489640787508695
loss = 3.902173 f1_score test =  0.07323194949495875
loss = 3.798434 f1_score test =  0.08455297987978956
loss = 3.821606 f1_score test =  0.09547246815090615
loss = 3.687636 f1_score test =  0.1070552172275461
loss = 3.674948 f1_score test =  0.12693617361499132
epoch: 1


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 4.089890 f1_score test =  0.03976312406008121
loss = 3.780951 f1_score test =  0.10108022997953953
loss = 3.513255 f1_score test =  0.15518148476325036
loss = 3.382892 f1_score test =  0.17075319535983954
loss = 3.347241 f1_score test =  0.21583700155755906
loss = 3.328513 f1_score test =  0.1800862803971159
loss = 3.130100 f1_score test =  0.22333680845299905
loss = 3.017849 f1_score test =  0.2503979184954317
loss = 3.075322 f1_score test =  0.25768610051047597
loss = 2.947673 f1_score test =  0.26772933586186604
loss = 2.776262 f1_score test =  0.29160209459519637
loss = 2.713614 f1_score test =  0.2777351926617068
loss = 2.605471 f1_score test =  0.2965042309691886
loss = 2.636892 f1_score test =  0.31201821755432707
loss = 2.539486 f1_score test =  0.32957006748672146
loss = 2.615698 f1_score test =  0.3370173272372037
loss = 2.230172 f1_score test =  0.34757784581857865
epoch: 2


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.921682 f1_score test =  0.1235369106481857
loss = 2.429591 f1_score test =  0.3238333751062314
loss = 2.517101 f1_score test =  0.3676881762559751
loss = 2.327060 f1_score test =  0.3818045315414476
loss = 2.394382 f1_score test =  0.40456317312769285
loss = 2.217238 f1_score test =  0.426156339445794
loss = 2.167068 f1_score test =  0.43908487026472665
loss = 2.163267 f1_score test =  0.4365675585801656
loss = 2.108879 f1_score test =  0.45237682786082556
loss = 2.000051 f1_score test =  0.4617527622218911
loss = 2.004923 f1_score test =  0.4812489037684134
loss = 1.770822 f1_score test =  0.500890027254186
loss = 1.901386 f1_score test =  0.4860917941080601
loss = 1.816019 f1_score test =  0.49765525750459993
loss = 1.847995 f1_score test =  0.5132799650008699
loss = 1.735033 f1_score test =  0.5279750393521258
loss = 1.898532 f1_score test =  0.5386012128390758
epoch: 3


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.485085 f1_score test =  0.41573362674954945
loss = 1.616508 f1_score test =  0.5190179252006598
loss = 1.822491 f1_score test =  0.5387820602856368
loss = 1.675851 f1_score test =  0.5527155405281997
loss = 1.530312 f1_score test =  0.5625435330249012
loss = 1.556380 f1_score test =  0.5789091599231558
loss = 1.493488 f1_score test =  0.5898418782343458
loss = 1.575471 f1_score test =  0.5821454701081366
loss = 1.618262 f1_score test =  0.6054742749261963
loss = 1.450235 f1_score test =  0.5900114208758659
loss = 1.602924 f1_score test =  0.6207184747090072
loss = 1.325365 f1_score test =  0.5993478187026641
loss = 1.321588 f1_score test =  0.6065120600527778
loss = 1.467043 f1_score test =  0.6218645330638704
loss = 1.429854 f1_score test =  0.6388057932140009
loss = 1.502958 f1_score test =  0.6333357240943567
loss = 1.286069 f1_score test =  0.6333135804355545
epoch: 4


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.372087 f1_score test =  0.557482902735366
loss = 1.255290 f1_score test =  0.6470595652995129
loss = 1.328600 f1_score test =  0.6332507949713387
loss = 1.070583 f1_score test =  0.6348088738714291
loss = 1.176201 f1_score test =  0.6388143789064076
loss = 1.045590 f1_score test =  0.66526388753064
loss = 1.058048 f1_score test =  0.6509930471652436
loss = 1.165683 f1_score test =  0.6676442109325001
loss = 1.319538 f1_score test =  0.6755998715817824
loss = 1.069525 f1_score test =  0.6784862648593798
loss = 1.173520 f1_score test =  0.6788220360626205
loss = 1.219678 f1_score test =  0.6807172345805964
loss = 1.181854 f1_score test =  0.6914815606712995
loss = 0.980418 f1_score test =  0.6946814299334924
loss = 1.038193 f1_score test =  0.6901074388860203
loss = 1.083134 f1_score test =  0.7013157294700447
loss = 1.083646 f1_score test =  0.689738536869538
epoch: 5


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.308193 f1_score test =  0.5567041160571975
loss = 1.058555 f1_score test =  0.7093964394963195
loss = 1.175316 f1_score test =  0.7032061175769744
loss = 1.163705 f1_score test =  0.7172094203538585
loss = 0.979850 f1_score test =  0.7216310426314114
loss = 1.005127 f1_score test =  0.7279283174059712
loss = 0.922866 f1_score test =  0.7226886713522563
loss = 0.853454 f1_score test =  0.7199844769248449
loss = 1.148207 f1_score test =  0.7353762781228745
loss = 0.949875 f1_score test =  0.7185069956510921
loss = 0.908638 f1_score test =  0.7303913781810011
loss = 0.778424 f1_score test =  0.7394707506149537
loss = 0.842806 f1_score test =  0.7334282735859823
loss = 0.917480 f1_score test =  0.7445536545731295
loss = 0.883093 f1_score test =  0.7393322496120128
loss = 0.800551 f1_score test =  0.7458102009684354
loss = 0.966514 f1_score test =  0.7450988363006593
epoch: 6


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.208591 f1_score test =  0.6834975591761808
loss = 0.912339 f1_score test =  0.7454617609669713
loss = 0.728503 f1_score test =  0.750221919069894
loss = 0.908293 f1_score test =  0.7391122938933667
loss = 0.751924 f1_score test =  0.7465213977187592
loss = 0.670084 f1_score test =  0.7450940470588042
loss = 0.724849 f1_score test =  0.7553852981160661
loss = 0.774512 f1_score test =  0.7534462932596064
loss = 0.713051 f1_score test =  0.7549415049485725
loss = 0.775060 f1_score test =  0.7494610387344585
loss = 0.892745 f1_score test =  0.7653691824721214
loss = 0.678937 f1_score test =  0.7687238992252985
loss = 0.801581 f1_score test =  0.7580420410568304
loss = 0.740800 f1_score test =  0.7636208838931307
loss = 0.629077 f1_score test =  0.7715995822571248
loss = 0.795037 f1_score test =  0.7697799362274975
loss = 0.781844 f1_score test =  0.7710036863390384
epoch: 7


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.293194 f1_score test =  0.6739520852820214
loss = 0.707934 f1_score test =  0.7683411584735056
loss = 0.707708 f1_score test =  0.7781472796891098
loss = 0.670446 f1_score test =  0.7678549447166463
loss = 0.766737 f1_score test =  0.774269149421142
loss = 0.648472 f1_score test =  0.7791197917544609
loss = 0.776533 f1_score test =  0.7759437764826542
loss = 0.764055 f1_score test =  0.7752859970452779
loss = 0.626658 f1_score test =  0.7851626450286218
loss = 0.683078 f1_score test =  0.7819008429081454
loss = 0.704870 f1_score test =  0.7853637919722358
loss = 0.606638 f1_score test =  0.7748442390846992
loss = 0.685560 f1_score test =  0.7816031030176991
loss = 0.806843 f1_score test =  0.7725725067916377
loss = 0.704107 f1_score test =  0.788421054559213
loss = 0.706031 f1_score test =  0.781884840516761
loss = 0.640915 f1_score test =  0.7973875117075198
epoch: 8


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.272333 f1_score test =  0.7264375650913274
loss = 0.722148 f1_score test =  0.7729717069693947
loss = 0.641918 f1_score test =  0.7842572238868344
loss = 0.600847 f1_score test =  0.7856342945779402
loss = 0.594931 f1_score test =  0.7825276612212261
loss = 0.584876 f1_score test =  0.7863389188053918
loss = 0.674710 f1_score test =  0.7766613826535846
loss = 0.596354 f1_score test =  0.7909537607938925
loss = 0.521934 f1_score test =  0.7948564265330748
loss = 0.639961 f1_score test =  0.794052055882271
loss = 0.529237 f1_score test =  0.7824180990937906
loss = 0.554386 f1_score test =  0.7863406243907132
loss = 0.659194 f1_score test =  0.7893637256229665
loss = 0.533844 f1_score test =  0.7849892610991962
loss = 0.541920 f1_score test =  0.7918493807681666
loss = 0.682156 f1_score test =  0.7845811187716621
loss = 0.602169 f1_score test =  0.8040392185733548
epoch: 9


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.339424 f1_score test =  0.7559865270567687
loss = 0.525857 f1_score test =  0.7885534704064563
loss = 0.502508 f1_score test =  0.7933551001551431
loss = 0.641863 f1_score test =  0.7985056689304169
loss = 0.518088 f1_score test =  0.7951882432829916
loss = 0.568474 f1_score test =  0.7873105388367763
loss = 0.537083 f1_score test =  0.8072023794440758
loss = 0.447326 f1_score test =  0.8034979529878782
loss = 0.564425 f1_score test =  0.8023850409645948
loss = 0.600060 f1_score test =  0.7954636920521545
loss = 0.484310 f1_score test =  0.805698622106311
loss = 0.433493 f1_score test =  0.8018080925103379
loss = 0.763904 f1_score test =  0.7985255860595082
loss = 0.477908 f1_score test =  0.7929225759583054
loss = 0.493146 f1_score test =  0.7992329744122776
loss = 0.455510 f1_score test =  0.8000151160582595
loss = 0.490468 f1_score test =  0.803145349892342
epoch: 10


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.667579 f1_score test =  0.7604850746976706
loss = 0.533356 f1_score test =  0.7894073804562372
loss = 0.396478 f1_score test =  0.796617894958363
loss = 0.447351 f1_score test =  0.8038409391350303
loss = 0.366609 f1_score test =  0.8004244429962898
loss = 0.392708 f1_score test =  0.802196826036353
loss = 0.566435 f1_score test =  0.8002970873225396
loss = 0.437431 f1_score test =  0.8078717148064359
loss = 0.495223 f1_score test =  0.8094698973942244
loss = 0.523638 f1_score test =  0.8042036042816303
loss = 0.388580 f1_score test =  0.8058707306185258
loss = 0.422175 f1_score test =  0.8041077429816387
loss = 0.498512 f1_score test =  0.8069500665395257
loss = 0.532696 f1_score test =  0.8033493794227666
loss = 0.428323 f1_score test =  0.807495406960704
loss = 0.480160 f1_score test =  0.8121297062383352
loss = 0.496415 f1_score test =  0.8009588495779969
epoch: 11


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.590042 f1_score test =  0.7691186294476285
loss = 0.400644 f1_score test =  0.7957862587698915
loss = 0.465263 f1_score test =  0.809360645660407
loss = 0.387943 f1_score test =  0.8051769801751687
loss = 0.503329 f1_score test =  0.8051001215875008
loss = 0.510958 f1_score test =  0.8126148333529974
loss = 0.524208 f1_score test =  0.8113561746309389
loss = 0.498960 f1_score test =  0.8116528040742895
loss = 0.271811 f1_score test =  0.8147338448131098
loss = 0.366621 f1_score test =  0.8075191207252865
loss = 0.484091 f1_score test =  0.8033836665101098
loss = 0.348371 f1_score test =  0.8195547915220879
loss = 0.396028 f1_score test =  0.8052552590711656
loss = 0.323102 f1_score test =  0.8105873029712359
loss = 0.371016 f1_score test =  0.8063503437417835
loss = 0.364810 f1_score test =  0.8085136685568559
loss = 0.510053 f1_score test =  0.8127013639777794
epoch: 12


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.464351 f1_score test =  0.7934575328627635
loss = 0.339922 f1_score test =  0.798452528187645
loss = 0.443840 f1_score test =  0.8110090755202326
loss = 0.318464 f1_score test =  0.8088856554066459
loss = 0.399199 f1_score test =  0.8089658897591595
loss = 0.377918 f1_score test =  0.814912852437234
loss = 0.375582 f1_score test =  0.8202225015661401
loss = 0.375551 f1_score test =  0.8153170672684724
loss = 0.428398 f1_score test =  0.8082719279069469
loss = 0.311597 f1_score test =  0.8140191027077762
loss = 0.380004 f1_score test =  0.8104699427590427
loss = 0.392910 f1_score test =  0.8142878399760126
loss = 0.370969 f1_score test =  0.8149019923467593
loss = 0.348088 f1_score test =  0.8072396189804305
loss = 0.329974 f1_score test =  0.8180904863551768
loss = 0.321437 f1_score test =  0.8202542971220085
loss = 0.365932 f1_score test =  0.8179408890826835
epoch: 13


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.388175 f1_score test =  0.8029036315598538
loss = 0.319861 f1_score test =  0.8105850684646768
loss = 0.261368 f1_score test =  0.8163123311253251
loss = 0.287383 f1_score test =  0.8093850198360231
loss = 0.254900 f1_score test =  0.8167804892355119
loss = 0.269838 f1_score test =  0.8087530556620461
loss = 0.348378 f1_score test =  0.8137593273452564
loss = 0.363147 f1_score test =  0.8119435255247883
loss = 0.418221 f1_score test =  0.8255463486368007
loss = 0.401102 f1_score test =  0.8172350448812238
loss = 0.281728 f1_score test =  0.8152209542896368
loss = 0.315673 f1_score test =  0.8153781382116447
loss = 0.319890 f1_score test =  0.8217079549370292
loss = 0.274401 f1_score test =  0.8163870636270549
loss = 0.277367 f1_score test =  0.8152532934301442
loss = 0.337910 f1_score test =  0.8147864407940613
loss = 0.342172 f1_score test =  0.8183610724050515
epoch: 14


  0%|          | 0/323 [00:00<?, ?it/s]

loss = 3.710405 f1_score test =  0.7935447762105456
loss = 0.268102 f1_score test =  0.8223030928029146
loss = 0.295395 f1_score test =  0.8153810315807836
loss = 0.251980 f1_score test =  0.8161008140370621
loss = 0.348536 f1_score test =  0.8236929480298286
loss = 0.296359 f1_score test =  0.8176507664379907
loss = 0.379505 f1_score test =  0.8130335458565513
loss = 0.262314 f1_score test =  0.8245846633361984
loss = 0.241769 f1_score test =  0.8126664150231078
loss = 0.238844 f1_score test =  0.8168662320316952
loss = 0.349524 f1_score test =  0.8218681545817035
loss = 0.279137 f1_score test =  0.8203770992217918
loss = 0.339232 f1_score test =  0.8139177341690035
loss = 0.249672 f1_score test =  0.8159514381652749
loss = 0.266252 f1_score test =  0.8136234917864512
loss = 0.318513 f1_score test =  0.8138429369072003
loss = 0.230122 f1_score test =  0.812831771660833


In [None]:
f1

0.8281057443793844