In [1]:
import nltk
import fasttext
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
from tqdm import tqdm

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
data = pd.read_csv('../corpora/annotated.csv')
data['text'] = data['text'].astype(str)

In [4]:
data['tokens'] = data['text'].apply(nltk.word_tokenize)

In [5]:
# def equalize_lengths(vectors, dst_len=100):
#     if len(vectors) < dst_len:
#         return vectors + [[0] * 100] * (dst_len - len(vectors))
#     elif len(vectors) > dst_len:
#         return vectors[:dst_len]
#     else:
#         return vectors

def get_vectorizer(model, equalized=False, dst_len=100):
    def get_text_repr(tokens):
        return torch.tensor(np.stack([model.get_word_vector(word) for word in tokens]))
    def get_text_repr_eq(tokens):
        vectors = torch.zeros((dst_len, 100))
        vectors[:min(len(tokens), dst_len)] = torch.tensor(np.stack([model.get_word_vector(word) for word in tokens[:dst_len]]))
        return vectors
    if equalized:
        return get_text_repr_eq
    else:
        return get_text_repr

def prepare_data(data, model, equalized=False):
    vectorizer = get_vectorizer(model, equalized)
    X = data['text'].apply(vectorizer)
    y = data['target']
    mask = np.random.rand(len(X)) < 0.8
    X_train = list(X[mask])
    X_test = list(X[~mask])
    y_train = list(y[mask])
    y_test = list(y[~mask])
    return X_train, y_train, X_test, y_test

In [10]:
class MyDataset(Dataset):

    def __init__(self, X, y) -> None:
        super().__init__()
        self.data = list(zip(X, y))

    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

class MyNet(nn.Module):

    def __init__(self, rlayer: nn.Module) -> None:
        super().__init__()
        self.rnn = rlayer(100, 100, 3, batch_first=True)
        self.reg = nn.Linear(100, 2)
    
    def forward(self, x, x_lengths):
        packed = nn.utils.rnn.pack_padded_sequence(x, x_lengths, batch_first=True, enforce_sorted=False)
        output, _ = self.rnn(packed)
        unpacked, unpacked_len = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        indices = Variable(torch.LongTensor(np.array(unpacked_len) - 1).view(-1, 1)
                                                                       .expand(unpacked.size(0), unpacked.size(2))
                                                                       .unsqueeze(1)
                                                                       .to(device))
        last_encoded_states = unpacked.gather(dim=1, index=indices).squeeze(dim=1)
        return self.reg(last_encoded_states)

class MyConvNet(nn.Module):

    def __init__(self, seq_len=100) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(100, 100, 5),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(100, 100, 3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(100, 100, 2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(1100, 2),
            nn.Softmax()
        )
    
    def forward(self, x, lengths=None):
        return self.net(x)

def my_collate(batch):
    X_batch, y_batch = zip(*batch)
    lengths = [sample.shape[0] for sample in X_batch]
    X_batch = nn.utils.rnn.pad_sequence(X_batch, batch_first=True)
    return X_batch, torch.tensor(y_batch), lengths

def to_dl(X, y, equalized=False):
    if equalized:
        ds = TensorDataset(torch.stack(X), torch.tensor(y))
        dl = DataLoader(ds, 128, shuffle=True, pin_memory=True)
    else:
        ds = MyDataset(X, y)
        dl = DataLoader(ds, 128, shuffle=True, collate_fn=my_collate, pin_memory=True)
    return dl

In [7]:
def validate(model, dl, loss_fn, equalized=False):
    y_pred = []
    y_true = []
    for batch in dl:
        if equalized:
            (X_batch, y_batch) = batch
            lengths = None
        else:
            (X_batch, y_batch, lengths) = batch
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        y_true.append(y_batch)
        y_pred.append(model(X_batch, lengths))
    y_pred = torch.concat(y_pred)
    y_true = torch.concat(y_true)
    loss = loss_fn(y_pred, y_true)
    y_pred = torch.argmax(y_pred, axis=1).cpu()
    y_true = y_true.cpu()
    return precision_score(y_true, y_pred), recall_score(y_true, y_pred), f1_score(y_true, y_pred), loss

def fit(model, loss_fn, optimizer, train_dl, val_dl, epochs=50, equalized=False, show_metrics=True):
    metrics = {
        'train': {
            'precision': [],
            'recall': [],
            'f1': [],
            'loss': []
        },
        'val': {
            'precision': [],
            'recall': [],
            'f1': [],
            'loss': []
        }
    }
    for epoch in range(epochs):
        for batch in tqdm(train_dl):
            if equalized:
                (X_batch, y_batch) = batch
                lengths = None
            else:
                (X_batch, y_batch, lengths) = batch
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            y_pred = model(X_batch, lengths)
            loss = loss_fn(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        with torch.no_grad():
            train_prec, train_recall, train_f1, train_loss = validate(model, train_dl, loss_fn, equalized)
            val_prec, val_recall, val_f1, val_loss = validate(model, val_dl, loss_fn, equalized)
            metrics['train']['precision'].append(train_prec)
            metrics['train']['recall'].append(train_recall)
            metrics['train']['f1'].append(train_f1)
            metrics['train']['loss'].append(train_loss.item())
            metrics['val']['precision'].append(val_prec)
            metrics['val']['recall'].append(val_recall)
            metrics['val']['f1'].append(val_f1)
            metrics['val']['loss'].append(val_loss.item())
            if show_metrics:
                print(f'Epoch: {epoch}\ttrain: f1 = {train_f1} loss = {train_loss}\tval: f1 = {val_f1} loss = {val_loss}')
    return metrics


In [8]:
data['target'].value_counts()

0    20959
1    13328
Name: target, dtype: int64

In [11]:
metrics = {}
classifiers = [
    'cnn',
    # 'rnn',
    # 'lstm'
]
for model_path in os.listdir('../dist_models'):
    model_path = '../dist_models/' + model_path
    metrics[model_path] = {}
    for classifier in classifiers:
        equalize = (classifier == 'cnn')
        vectorizer = fasttext.FastText.load_model(model_path)
        X_train, y_train, X_test, y_test = prepare_data(data, vectorizer, equalize)
        train_dl = to_dl(X_train, y_train, equalize)
        test_dl = to_dl(X_test, y_test, equalize)
        if classifier == 'cnn':
            model = MyConvNet().to(device)
        elif classifier == 'rnn':
            model = MyNet(nn.RNN).to(device)
        elif classifier == 'lstm':
            model = MyNet(nn.LSTM).to(device)
        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        print('\n', model_path, classifier,'\n')
        metrics[model_path][classifier] = fit(model, loss_fn, optimizer, train_dl, test_dl, epochs=50, equalized=equalize)

  0%|          | 0/215 [00:00<?, ?it/s]


 ../dist_models/full_m1.bin cnn 



  input = module(input)
100%|██████████| 215/215 [00:01<00:00, 176.42it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 245.10it/s]

Epoch: 0	train: f1 = 0.21150011962676446 loss = 0.6305745244026184	val: f1 = 0.15226469643430776 loss = 0.6323307752609253


100%|██████████| 215/215 [00:00<00:00, 252.64it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  input = module(input)
  _warn_prf(average, modifier, msg_start, len(result))
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 245.09it/s]

Epoch: 1	train: f1 = 0.0 loss = 0.6400322914123535	val: f1 = 0.0 loss = 0.6443073749542236


100%|██████████| 215/215 [00:00<00:00, 250.00it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 242.72it/s]

Epoch: 2	train: f1 = 0.2945664739884393 loss = 0.6066299080848694	val: f1 = 0.17278617710583152 loss = 0.6387745141983032


100%|██████████| 215/215 [00:00<00:00, 251.46it/s]
  input = module(input)
  input = module(input)
 23%|██▎       | 50/215 [00:00<00:00, 245.10it/s]

Epoch: 3	train: f1 = 0.3186419560816072 loss = 0.5855923891067505	val: f1 = 0.15457713018688626 loss = 0.6464352607727051


100%|██████████| 215/215 [00:00<00:00, 252.05it/s]
  input = module(input)
  input = module(input)
 23%|██▎       | 50/215 [00:00<00:00, 246.91it/s]

Epoch: 4	train: f1 = 0.34013188161324953 loss = 0.5797343850135803	val: f1 = 0.2016250376166115 loss = 0.6468380689620972


100%|██████████| 215/215 [00:00<00:00, 255.34it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 5	train: f1 = 0.3454782541995858 loss = 0.5759356617927551	val: f1 = 0.2052663076002394 loss = 0.6477394700050354


100%|██████████| 215/215 [00:00<00:00, 255.04it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 6	train: f1 = 0.3466420493884145 loss = 0.5754237174987793	val: f1 = 0.20048455481526345 loss = 0.6475369930267334


100%|██████████| 215/215 [00:00<00:00, 257.79it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 254.90it/s]

Epoch: 7	train: f1 = 0.3478995468857999 loss = 0.5763253569602966	val: f1 = 0.20578586340590518 loss = 0.6516066789627075


100%|██████████| 215/215 [00:00<00:00, 257.48it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 8	train: f1 = 0.34048818287485466 loss = 0.5751975178718567	val: f1 = 0.1500638569604087 loss = 0.6501412987709045


100%|██████████| 215/215 [00:00<00:00, 255.34it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 9	train: f1 = 0.3346557759626605 loss = 0.5767582058906555	val: f1 = 0.13190383365821962 loss = 0.6532734036445618


100%|██████████| 215/215 [00:00<00:00, 251.17it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 245.10it/s]

Epoch: 10	train: f1 = 0.34783273450042046 loss = 0.5755226016044617	val: f1 = 0.22287735849056606 loss = 0.6490431427955627


100%|██████████| 215/215 [00:00<00:00, 253.84it/s]
  input = module(input)
  input = module(input)
 24%|██▍       | 52/215 [00:00<00:00, 254.90it/s]

Epoch: 11	train: f1 = 0.3453193283753737 loss = 0.5757966041564941	val: f1 = 0.20066285025610123 loss = 0.6506432890892029


100%|██████████| 215/215 [00:00<00:00, 256.87it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 12	train: f1 = 0.3397743082392951 loss = 0.5761271119117737	val: f1 = 0.17037269025994364 loss = 0.6497892737388611


100%|██████████| 215/215 [00:00<00:00, 253.84it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 13	train: f1 = 0.34524714828897335 loss = 0.5786301493644714	val: f1 = 0.2173913043478261 loss = 0.6533048748970032


100%|██████████| 215/215 [00:00<00:00, 254.74it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 14	train: f1 = 0.3410984264116014 loss = 0.5762283205986023	val: f1 = 0.177449168207024 loss = 0.6524011492729187


100%|██████████| 215/215 [00:00<00:00, 255.34it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 15	train: f1 = 0.3410458121240167 loss = 0.576094388961792	val: f1 = 0.1796850879901204 loss = 0.6510370969772339


100%|██████████| 215/215 [00:00<00:00, 254.44it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.53it/s]

Epoch: 16	train: f1 = 0.34661965353365015 loss = 0.5753017067909241	val: f1 = 0.20209895052473764 loss = 0.6520302891731262


100%|██████████| 215/215 [00:00<00:00, 255.04it/s]
  input = module(input)
  input = module(input)
 23%|██▎       | 50/215 [00:00<00:00, 250.00it/s]

Epoch: 17	train: f1 = 0.3451761102603369 loss = 0.5764454007148743	val: f1 = 0.20549581839904416 loss = 0.6521902680397034


100%|██████████| 215/215 [00:00<00:00, 255.34it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 18	train: f1 = 0.33607320099255583 loss = 0.5766331553459167	val: f1 = 0.15178571428571427 loss = 0.6517460346221924


100%|██████████| 215/215 [00:00<00:00, 255.04it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 250.05it/s]

Epoch: 19	train: f1 = 0.33888845825904973 loss = 0.5753331780433655	val: f1 = 0.14892937040588047 loss = 0.6528979539871216


100%|██████████| 215/215 [00:00<00:00, 255.66it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 20	train: f1 = 0.3441339066339067 loss = 0.5757885575294495	val: f1 = 0.19382192610539067 loss = 0.6518751382827759


100%|██████████| 215/215 [00:00<00:00, 257.18it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 21	train: f1 = 0.34003401886500695 loss = 0.5755324959754944	val: f1 = 0.16813048933500627 loss = 0.6515231728553772


100%|██████████| 215/215 [00:00<00:00, 250.58it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.53it/s]

Epoch: 22	train: f1 = 0.315640258308395 loss = 0.5810014009475708	val: f1 = 0.10395707578806172 loss = 0.6541670560836792


100%|██████████| 215/215 [00:00<00:00, 243.21it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 250.00it/s]

Epoch: 23	train: f1 = 0.3451714351461226 loss = 0.5757362246513367	val: f1 = 0.20096560048280024 loss = 0.6511601209640503


100%|██████████| 215/215 [00:00<00:00, 254.14it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 24	train: f1 = 0.3450774805334978 loss = 0.5741269588470459	val: f1 = 0.17694204685573364 loss = 0.6534164547920227


100%|██████████| 215/215 [00:00<00:00, 251.46it/s]
  input = module(input)
  input = module(input)
 24%|██▍       | 52/215 [00:00<00:00, 252.43it/s]

Epoch: 25	train: f1 = 0.33944173818912854 loss = 0.5759729146957397	val: f1 = 0.1581920903954802 loss = 0.6548151969909668


100%|██████████| 215/215 [00:00<00:00, 248.84it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 26	train: f1 = 0.33909069785454266 loss = 0.5755571722984314	val: f1 = 0.15757957768673178 loss = 0.6547338962554932


100%|██████████| 215/215 [00:00<00:00, 253.84it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 254.90it/s]

Epoch: 27	train: f1 = 0.34283066944637486 loss = 0.5746512413024902	val: f1 = 0.17199627444892893 loss = 0.6538594961166382


100%|██████████| 215/215 [00:00<00:00, 255.95it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/215 [00:00<00:00, 239.99it/s]

Epoch: 28	train: f1 = 0.3322429906542056 loss = 0.5770396590232849	val: f1 = 0.1272845953002611 loss = 0.6550006866455078


100%|██████████| 215/215 [00:00<00:00, 247.41it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 29	train: f1 = 0.33485989288209267 loss = 0.5770063400268555	val: f1 = 0.15180265654648958 loss = 0.6564581990242004


100%|██████████| 215/215 [00:00<00:00, 250.29it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.42it/s]

Epoch: 30	train: f1 = 0.3363516755669066 loss = 0.577172577381134	val: f1 = 0.15845959595959597 loss = 0.6544317007064819


100%|██████████| 215/215 [00:00<00:00, 253.54it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/215 [00:00<00:00, 235.29it/s]

Epoch: 31	train: f1 = 0.3418316259724255 loss = 0.5758830308914185	val: f1 = 0.1741555624418965 loss = 0.6535577774047852


100%|██████████| 215/215 [00:00<00:00, 247.41it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 242.72it/s]

Epoch: 32	train: f1 = 0.3435788192331102 loss = 0.5788419246673584	val: f1 = 0.2172106824925816 loss = 0.6501418352127075


100%|██████████| 215/215 [00:00<00:00, 255.65it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 250.00it/s]

Epoch: 33	train: f1 = 0.342038363762422 loss = 0.5757935047149658	val: f1 = 0.18041871921182265 loss = 0.6536883115768433


100%|██████████| 215/215 [00:00<00:00, 252.35it/s]
  input = module(input)
  input = module(input)
 24%|██▎       | 51/215 [00:00<00:00, 246.55it/s]

Epoch: 34	train: f1 = 0.34101951106655354 loss = 0.5765252113342285	val: f1 = 0.18198585920688598 loss = 0.653664231300354


100%|██████████| 215/215 [00:00<00:00, 255.04it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 250.00it/s]

Epoch: 35	train: f1 = 0.34089524969549334 loss = 0.5804638266563416	val: f1 = 0.2045656685443226 loss = 0.6562837958335876


100%|██████████| 215/215 [00:00<00:00, 249.13it/s]
  input = module(input)
  input = module(input)
 22%|██▏       | 48/215 [00:00<00:00, 236.22it/s]

Epoch: 36	train: f1 = 0.3365650969529086 loss = 0.5787374973297119	val: f1 = 0.18476598348118692 loss = 0.6537705063819885


100%|██████████| 215/215 [00:00<00:00, 235.49it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 37	train: f1 = 0.27916331456154464 loss = 0.5909081697463989	val: f1 = 0.09284988139613691 loss = 0.6554970145225525


100%|██████████| 215/215 [00:00<00:00, 246.84it/s]
  input = module(input)
  input = module(input)
 10%|▉         | 21/215 [00:00<00:00, 201.92it/s]

Epoch: 38	train: f1 = 0.3353950827276944 loss = 0.5783088207244873	val: f1 = 0.17342482844666252 loss = 0.6519681215286255


100%|██████████| 215/215 [00:00<00:00, 243.21it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 242.72it/s]

Epoch: 39	train: f1 = 0.3422402572544216 loss = 0.5776113867759705	val: f1 = 0.20489844683393071 loss = 0.6534018516540527


100%|██████████| 215/215 [00:00<00:00, 252.94it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 40	train: f1 = 0.329466718567536 loss = 0.5782440304756165	val: f1 = 0.14212218649517686 loss = 0.6544206142425537


100%|██████████| 215/215 [00:00<00:00, 250.58it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 41	train: f1 = 0.34010777521170127 loss = 0.5770073533058167	val: f1 = 0.1730232558139535 loss = 0.6545342803001404


100%|██████████| 215/215 [00:00<00:00, 255.04it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 248.30it/s]

Epoch: 42	train: f1 = 0.3266037882921506 loss = 0.5793122053146362	val: f1 = 0.14249037227214378 loss = 0.6549794673919678


100%|██████████| 215/215 [00:00<00:00, 247.79it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/215 [00:00<00:00, 237.62it/s]

Epoch: 43	train: f1 = 0.3389176252605574 loss = 0.5769763588905334	val: f1 = 0.16883923029174427 loss = 0.6564470529556274


100%|██████████| 215/215 [00:00<00:00, 253.24it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 247.52it/s]

Epoch: 44	train: f1 = 0.3401309353600722 loss = 0.5843250751495361	val: f1 = 0.23258491652274038 loss = 0.6547430753707886


100%|██████████| 215/215 [00:00<00:00, 250.29it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/215 [00:00<00:00, 245.10it/s]

Epoch: 45	train: f1 = 0.3198684828558009 loss = 0.5810782313346863	val: f1 = 0.13733075435203096 loss = 0.6559439897537231


100%|██████████| 215/215 [00:00<00:00, 257.79it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 252.43it/s]

Epoch: 46	train: f1 = 0.33713581008170185 loss = 0.5780174136161804	val: f1 = 0.17868145409735056 loss = 0.6543813347816467


100%|██████████| 215/215 [00:00<00:00, 253.54it/s]
  input = module(input)
  input = module(input)
 23%|██▎       | 50/215 [00:00<00:00, 240.19it/s]

Epoch: 47	train: f1 = 0.3336424045742544 loss = 0.5788038969039917	val: f1 = 0.17202970297029704 loss = 0.6557523012161255


100%|██████████| 215/215 [00:00<00:00, 257.18it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/215 [00:00<00:00, 258.80it/s]

Epoch: 48	train: f1 = 0.34288326729655594 loss = 0.5789474248886108	val: f1 = 0.20708122582564714 loss = 0.6538868546485901


100%|██████████| 215/215 [00:00<00:00, 258.58it/s]
  input = module(input)


Epoch: 49	train: f1 = 0.3406862745098039 loss = 0.578438401222229	val: f1 = 0.18738629472407523 loss = 0.6564867496490479


  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 257.43it/s]


 ../dist_models/full_m2.bin cnn 



100%|██████████| 216/216 [00:00<00:00, 259.62it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 237.63it/s]

Epoch: 0	train: f1 = 0.29228287173971496 loss = 0.6327171921730042	val: f1 = 0.3047120418848167 loss = 0.6368463039398193


100%|██████████| 216/216 [00:00<00:00, 257.76it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.52it/s]

Epoch: 1	train: f1 = 0.29613333333333336 loss = 0.6314786076545715	val: f1 = 0.3086900129701686 loss = 0.6365037560462952


100%|██████████| 216/216 [00:00<00:00, 254.72it/s]
  input = module(input)
  input = module(input)
 24%|██▎       | 51/216 [00:00<00:00, 252.18it/s]

Epoch: 2	train: f1 = 0.22287717062290904 loss = 0.6251424551010132	val: f1 = 0.15947874650946323 loss = 0.6378347277641296


100%|██████████| 216/216 [00:00<00:00, 256.53it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.53it/s]

Epoch: 3	train: f1 = 0.20496277915632752 loss = 0.6144179105758667	val: f1 = 0.11657440573103224 loss = 0.6437643766403198


100%|██████████| 216/216 [00:00<00:00, 254.72it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 239.99it/s]

Epoch: 4	train: f1 = 0.28621070920865577 loss = 0.5970348715782166	val: f1 = 0.16037441497659904 loss = 0.6501312851905823


100%|██████████| 216/216 [00:00<00:00, 247.14it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 237.62it/s]

Epoch: 5	train: f1 = 0.31906614785992216 loss = 0.5861878991127014	val: f1 = 0.1670746634026928 loss = 0.6519937515258789


100%|██████████| 216/216 [00:00<00:00, 231.26it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 237.62it/s]

Epoch: 6	train: f1 = 0.33705443067210716 loss = 0.5804919600486755	val: f1 = 0.21370023419203746 loss = 0.651189923286438


100%|██████████| 216/216 [00:00<00:00, 234.02it/s]
  input = module(input)
  input = module(input)
 10%|█         | 22/216 [00:00<00:00, 215.69it/s]

Epoch: 7	train: f1 = 0.34025025258413 loss = 0.5754513740539551	val: f1 = 0.17580398162327718 loss = 0.6531650424003601


100%|██████████| 216/216 [00:00<00:00, 216.65it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 240.00it/s]

Epoch: 8	train: f1 = 0.34178593652769135 loss = 0.5745996832847595	val: f1 = 0.16424822476072864 loss = 0.6559463739395142


100%|██████████| 216/216 [00:00<00:00, 249.14it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 250.00it/s]

Epoch: 9	train: f1 = 0.3380941231561695 loss = 0.5748610496520996	val: f1 = 0.14985773000316155 loss = 0.6534700989723206


100%|██████████| 216/216 [00:00<00:00, 250.14it/s]
  input = module(input)
  input = module(input)
 24%|██▎       | 51/216 [00:00<00:00, 250.36it/s]

Epoch: 10	train: f1 = 0.3452131376659679 loss = 0.5732201337814331	val: f1 = 0.19632265717674968 loss = 0.6554413437843323


100%|██████████| 216/216 [00:00<00:00, 250.58it/s]
  input = module(input)
  input = module(input)
 24%|██▎       | 51/216 [00:00<00:00, 247.25it/s]

Epoch: 11	train: f1 = 0.34546442151004886 loss = 0.5737773776054382	val: f1 = 0.19785458879618595 loss = 0.6562924385070801


100%|██████████| 216/216 [00:00<00:00, 250.87it/s]
  input = module(input)
  input = module(input)
 24%|██▍       | 52/216 [00:00<00:00, 253.16it/s]

Epoch: 12	train: f1 = 0.34362187669219457 loss = 0.5749818086624146	val: f1 = 0.20520402128917797 loss = 0.6535796523094177


100%|██████████| 216/216 [00:00<00:00, 253.52it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 245.10it/s]

Epoch: 13	train: f1 = 0.344238091537804 loss = 0.5733038783073425	val: f1 = 0.18131868131868134 loss = 0.6533448100090027


100%|██████████| 216/216 [00:00<00:00, 252.93it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.43it/s]

Epoch: 14	train: f1 = 0.34459197142635295 loss = 0.5734177827835083	val: f1 = 0.18109996961409905 loss = 0.6537577509880066


100%|██████████| 216/216 [00:00<00:00, 246.01it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.42it/s]

Epoch: 15	train: f1 = 0.3433543074652373 loss = 0.5737264156341553	val: f1 = 0.18181818181818182 loss = 0.6555233001708984


100%|██████████| 216/216 [00:00<00:00, 245.73it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 240.38it/s]

Epoch: 16	train: f1 = 0.3433056181521486 loss = 0.5742785334587097	val: f1 = 0.17620067298868156 loss = 0.6558271646499634


100%|██████████| 216/216 [00:00<00:00, 248.56it/s]
  input = module(input)
  input = module(input)
 22%|██▏       | 48/216 [00:00<00:00, 238.01it/s]

Epoch: 17	train: f1 = 0.34292827135522014 loss = 0.5737132430076599	val: f1 = 0.1720958819913952 loss = 0.6562511324882507


100%|██████████| 216/216 [00:00<00:00, 244.07it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 250.01it/s]

Epoch: 18	train: f1 = 0.345805252149663 loss = 0.5734778046607971	val: f1 = 0.20673360897814527 loss = 0.6536928415298462


100%|██████████| 216/216 [00:00<00:00, 252.63it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 240.38it/s]

Epoch: 19	train: f1 = 0.3467204217708172 loss = 0.572746217250824	val: f1 = 0.20444444444444443 loss = 0.6532027125358582


100%|██████████| 216/216 [00:00<00:00, 244.07it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 240.39it/s]

Epoch: 20	train: f1 = 0.34339505212385246 loss = 0.5731948614120483	val: f1 = 0.16076970825574177 loss = 0.6558522582054138


100%|██████████| 216/216 [00:00<00:00, 252.93it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.52it/s]

Epoch: 21	train: f1 = 0.33129411764705885 loss = 0.5761941075325012	val: f1 = 0.13556055252168328 loss = 0.6556144952774048


100%|██████████| 216/216 [00:00<00:00, 249.71it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.52it/s]

Epoch: 22	train: f1 = 0.3445773832497105 loss = 0.574933648109436	val: f1 = 0.21148213239601643 loss = 0.654420018196106


100%|██████████| 216/216 [00:00<00:00, 251.16it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.42it/s]

Epoch: 23	train: f1 = 0.34713820381572824 loss = 0.5726929306983948	val: f1 = 0.19856887298747763 loss = 0.6536232233047485


100%|██████████| 216/216 [00:00<00:00, 249.71it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.43it/s]

Epoch: 24	train: f1 = 0.3469704013637068 loss = 0.5726555585861206	val: f1 = 0.20385185185185187 loss = 0.6550868153572083


100%|██████████| 216/216 [00:00<00:00, 252.63it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.43it/s]

Epoch: 25	train: f1 = 0.3346638491038585 loss = 0.575270414352417	val: f1 = 0.13952 loss = 0.6559618711471558


100%|██████████| 216/216 [00:00<00:00, 256.23it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 250.00it/s]

Epoch: 26	train: f1 = 0.34510595358224017 loss = 0.5729709267616272	val: f1 = 0.19472738166566803 loss = 0.6543720364570618


100%|██████████| 216/216 [00:00<00:00, 258.37it/s]
  input = module(input)
  input = module(input)
 24%|██▎       | 51/216 [00:00<00:00, 251.69it/s]

Epoch: 27	train: f1 = 0.34692134482490883 loss = 0.5722079277038574	val: f1 = 0.18766917293233082 loss = 0.6567940711975098


100%|██████████| 216/216 [00:00<00:00, 246.86it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 240.38it/s]

Epoch: 28	train: f1 = 0.33799515587155243 loss = 0.5742582678794861	val: f1 = 0.14744200826183668 loss = 0.6541734933853149


100%|██████████| 216/216 [00:00<00:00, 252.04it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.43it/s]

Epoch: 29	train: f1 = 0.3462312335551772 loss = 0.5735010504722595	val: f1 = 0.2129165437923916 loss = 0.6523505449295044


100%|██████████| 216/216 [00:00<00:00, 249.14it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.52it/s]

Epoch: 30	train: f1 = 0.33760349945321044 loss = 0.5743852257728577	val: f1 = 0.16249999999999998 loss = 0.6536580920219421


100%|██████████| 216/216 [00:00<00:00, 251.16it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 249.99it/s]

Epoch: 31	train: f1 = 0.34341235911387485 loss = 0.5733432769775391	val: f1 = 0.1857359635811836 loss = 0.654394268989563


100%|██████████| 216/216 [00:00<00:00, 251.75it/s]
  input = module(input)
  input = module(input)
 11%|█         | 23/216 [00:00<00:00, 227.72it/s]

Epoch: 32	train: f1 = 0.3428838368109623 loss = 0.5732712745666504	val: f1 = 0.17176397899289467 loss = 0.6538372039794922


100%|██████████| 216/216 [00:00<00:00, 247.42it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 242.72it/s]

Epoch: 33	train: f1 = 0.3406344999610258 loss = 0.5737276673316956	val: f1 = 0.17202970297029704 loss = 0.6532871127128601


100%|██████████| 216/216 [00:00<00:00, 250.00it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 240.39it/s]

Epoch: 34	train: f1 = 0.3422301766399502 loss = 0.5733816623687744	val: f1 = 0.17830882352941177 loss = 0.6538968682289124


100%|██████████| 216/216 [00:00<00:00, 258.37it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 247.62it/s]

Epoch: 35	train: f1 = 0.34537841191067 loss = 0.5732279419898987	val: f1 = 0.19630071599045346 loss = 0.6551087498664856


100%|██████████| 216/216 [00:00<00:00, 244.34it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 233.01it/s]

Epoch: 36	train: f1 = 0.3462848297213622 loss = 0.5732268691062927	val: f1 = 0.20444444444444443 loss = 0.6540798544883728


100%|██████████| 216/216 [00:00<00:00, 252.04it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 235.30it/s]

Epoch: 37	train: f1 = 0.3450419645632577 loss = 0.5727773308753967	val: f1 = 0.18891916439600362 loss = 0.6535951495170593


100%|██████████| 216/216 [00:00<00:00, 252.34it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 247.53it/s]

Epoch: 38	train: f1 = 0.34615981380915434 loss = 0.5726956129074097	val: f1 = 0.20342992312241276 loss = 0.6559641361236572


100%|██████████| 216/216 [00:00<00:00, 256.84it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 242.72it/s]

Epoch: 39	train: f1 = 0.3460433330744738 loss = 0.572431743144989	val: f1 = 0.20403321470937127 loss = 0.654970645904541


100%|██████████| 216/216 [00:00<00:00, 243.24it/s]
  input = module(input)
  input = module(input)
 23%|██▎       | 50/216 [00:00<00:00, 247.52it/s]

Epoch: 40	train: f1 = 0.3443451224251846 loss = 0.5729185938835144	val: f1 = 0.1932193219321932 loss = 0.6550403833389282


100%|██████████| 216/216 [00:00<00:00, 254.12it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 242.72it/s]

Epoch: 41	train: f1 = 0.3464188927603562 loss = 0.5731143355369568	val: f1 = 0.19798458802608182 loss = 0.6576799750328064


100%|██████████| 216/216 [00:00<00:00, 255.92it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 245.09it/s]

Epoch: 42	train: f1 = 0.3426791277258567 loss = 0.5729762315750122	val: f1 = 0.17516902274124158 loss = 0.6548171043395996


100%|██████████| 216/216 [00:00<00:00, 251.16it/s]
  input = module(input)
  input = module(input)
 11%|█         | 23/216 [00:00<00:00, 227.72it/s]

Epoch: 43	train: f1 = 0.3310582882246803 loss = 0.5760524868965149	val: f1 = 0.1362905818064931 loss = 0.6547420620918274


100%|██████████| 216/216 [00:00<00:00, 248.28it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 250.00it/s]

Epoch: 44	train: f1 = 0.3437013996889581 loss = 0.5730798840522766	val: f1 = 0.18077039733090688 loss = 0.6564350724220276


100%|██████████| 216/216 [00:00<00:00, 253.22it/s]
  input = module(input)
  input = module(input)
 11%|█         | 23/216 [00:00<00:00, 221.16it/s]

Epoch: 45	train: f1 = 0.33999063816508035 loss = 0.5736397504806519	val: f1 = 0.15519399249061325 loss = 0.6562162041664124


100%|██████████| 216/216 [00:00<00:00, 252.63it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 26/216 [00:00<00:00, 252.43it/s]

Epoch: 46	train: f1 = 0.3444726127371898 loss = 0.5762539505958557	val: f1 = 0.2392725206024439 loss = 0.6531237363815308


100%|██████████| 216/216 [00:00<00:00, 254.72it/s]
  input = module(input)
  input = module(input)
 11%|█         | 24/216 [00:00<00:00, 237.63it/s]

Epoch: 47	train: f1 = 0.3414330218068536 loss = 0.5736791491508484	val: f1 = 0.17447595561035759 loss = 0.6539187431335449


100%|██████████| 216/216 [00:00<00:00, 256.23it/s]
  input = module(input)
  input = module(input)
 12%|█▏        | 25/216 [00:00<00:00, 245.10it/s]

Epoch: 48	train: f1 = 0.34569052615217216 loss = 0.5723620653152466	val: f1 = 0.17830882352941177 loss = 0.6544350981712341


100%|██████████| 216/216 [00:00<00:00, 255.32it/s]
  input = module(input)


Epoch: 49	train: f1 = 0.34686631676443347 loss = 0.5737553238868713	val: f1 = 0.22794330344229097 loss = 0.6524492502212524


  input = module(input)
  8%|▊         | 17/214 [00:00<00:02, 97.70it/s] 


 ../dist_models/sample_m1.bin cnn 






KeyboardInterrupt: 

In [None]:
# import json

# with open('metrics_cnn.json', 'w') as f:
#     json.dump(metrics, f)