In [1]:
from dataset_functions import from_path_to_dataloader
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
#plot   
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#create the dataloaders
path_train = './chaoyang-data/train'
path_test = './chaoyang-data/test'
batch_size = 16

train_dataloader = from_path_to_dataloader(path_train, batch_size, True, True)
test_dataloader = from_path_to_dataloader(path_test, batch_size, False, False)

#split the train dataset into train and validation
train_size = int(0.8 * len(train_dataloader.dataset))
val_size = len(train_dataloader.dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataloader.dataset, [train_size, val_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

#use the GPU if available

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#use huggingface's pretrained model for image classification vit-base-patch16-224
# model = transformers.ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=1000)
# use huggingface's pretrained model for image classification swin-base-patch4-window7-224
#model = transformers.Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
# change the number of output classes to 4
#model.classifier = nn.Linear(1024, 4)
#model.classifier = nn.Linear(768, 4)
from sm_vit import ViTSM
model = ViTSM(
    mem_blocks=64,
    image_size = 224,
    patch_size = 16,
    num_classes = 4,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)
model.to(device)
#use the AdamW optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
#use the cross entropy loss
criterion = nn.CrossEntropyLoss()


In [3]:
from tqdm import tqdm
#import classification_report from sklearn
from sklearn.metrics import classification_report
def evaluate(model, val_dataloader, criterion):
    #initialize the loss and the number of correct predictions
    val_loss = 0
    correct = 0
    predictions=[]
    labels=[]
    #for each batch
    with torch.no_grad():
        for data, target in tqdm(val_dataloader):
            #send the data and the target to the GPU
            data, target = data.to(device), target.to(device)
            #forward pass
            output = model(data)
            #compute the loss
            val_loss += criterion(output, target).item()
            #compute the number of correct predictions
            pred = output.argmax(dim=1, keepdim=True)
            #append the predictions
            predictions.extend(pred.cpu().numpy().tolist())
            #append the labels
            labels.extend(target.cpu().numpy().tolist())

    #compute the average loss
    val_loss /= len(val_dataloader.dataset)
    #return the average loss and the number of correct predictions
    #print the classification report
    print(classification_report(labels, predictions))
    return val_loss, correct

In [4]:

#train the model
output_path = './models/'
def train(model, train_dataloader, val_dataloader, optimizer, criterion, epochs):
    #set the model in training mode
    model.train()
    train_losses = []
    val_losses = []
    #for each epoch
    for epoch in range(epochs):
        train_losses_inner = []
        loss=0
        #for each batch
        for (data, target) in tqdm(train_dataloader):
            #send the data and the target to the GPU
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            train_losses_inner.append(loss.cpu().detach().numpy())
            loss += loss.item()
            loss.backward()
            optimizer.step()
        train_losses.append(np.mean(train_losses_inner))
        #evaluate the model on the validation set
        print('Epoch: ', epoch)
        print('Validation set:')
        val_loss, val_accuracy = evaluate(model, val_dataloader, criterion)
        val_losses.append(val_loss)
        print('Test set:')
        test_loss, test_accuracy = evaluate(model, test_dataloader, criterion)
        #if the path doesn't exist, create it
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        #save the model
        torch.save(model.state_dict(), output_path + 'model_' + str(epoch) + '.pth')
    #plot the train and validation losses after transferring to cpu
    

    print(train_losses)
    plt.plot(train_losses, label='train loss')
    plt.plot(val_losses, label='validation loss')
    plt.legend()
    plt.show()
    

        

In [5]:

#train the model
#train(model, val_dataloader, val_dataloader, optimizer, criterion, 20)

train(model, train_dataloader, val_dataloader, optimizer, criterion,20)

100%|██████████| 281/281 [05:37<00:00,  1.20s/it]


Epoch:  0
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:38<00:00,  1.36it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:35<00:00,  1.19s/it]


Epoch:  1
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:37<00:00,  1.37it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:35<00:00,  1.19s/it]


Epoch:  2
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:38<00:00,  1.36it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:35<00:00,  1.19s/it]


Epoch:  3
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:42<00:00,  1.31it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:35<00:00,  1.19s/it]


Epoch:  4
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:38<00:00,  1.36it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:34<00:00,  1.19s/it]


Epoch:  5
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:38<00:00,  1.36it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:39<00:00,  1.21s/it]


Epoch:  6
Validation set:


100%|██████████| 71/71 [00:52<00:00,  1.35it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:40<00:00,  1.33it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:37<00:00,  1.20s/it]


Epoch:  7
Validation set:


100%|██████████| 71/71 [00:49<00:00,  1.43it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:34<00:00,  1.42it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:25<00:00,  1.16s/it]


Epoch:  8
Validation set:


100%|██████████| 71/71 [00:48<00:00,  1.46it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:33<00:00,  1.44it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:25<00:00,  1.16s/it]


Epoch:  9
Validation set:


100%|██████████| 71/71 [00:48<00:00,  1.47it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:33<00:00,  1.43it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:25<00:00,  1.16s/it]


Epoch:  10
Validation set:


100%|██████████| 71/71 [00:48<00:00,  1.47it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:34<00:00,  1.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:26<00:00,  1.16s/it]


Epoch:  11
Validation set:


100%|██████████| 71/71 [00:48<00:00,  1.46it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:33<00:00,  1.43it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:32<00:00,  1.18s/it]


Epoch:  12
Validation set:


100%|██████████| 71/71 [00:51<00:00,  1.37it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:40<00:00,  1.34it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:37<00:00,  1.20s/it]


Epoch:  13
Validation set:


100%|██████████| 71/71 [00:50<00:00,  1.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:37<00:00,  1.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:35<00:00,  1.19s/it]


Epoch:  14
Validation set:


100%|██████████| 71/71 [00:50<00:00,  1.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:37<00:00,  1.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:43<00:00,  1.22s/it]


Epoch:  15
Validation set:


100%|██████████| 71/71 [00:54<00:00,  1.30it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:43<00:00,  1.29it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:36<00:00,  1.20s/it]


Epoch:  16
Validation set:


100%|██████████| 71/71 [00:50<00:00,  1.42it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.25      1.00      0.40       281
           1       0.00      0.00      0.00       293
           2       0.00      0.00      0.00       252
           3       0.00      0.00      0.00       298

    accuracy                           0.25      1124
   macro avg       0.06      0.25      0.10      1124
weighted avg       0.06      0.25      0.10      1124

Test set:


100%|██████████| 134/134 [01:36<00:00,  1.39it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.33      1.00      0.50       705
           1       0.00      0.00      0.00       321
           2       0.00      0.00      0.00       840
           3       0.00      0.00      0.00       273

    accuracy                           0.33      2139
   macro avg       0.08      0.25      0.12      2139
weighted avg       0.11      0.33      0.16      2139



100%|██████████| 281/281 [05:34<00:00,  1.19s/it]


Epoch:  17
Validation set:


 73%|███████▎  | 52/71 [00:38<00:13,  1.37it/s]


KeyboardInterrupt: 