In [1]:
from cltrier_prosem import Pipeline

In [2]:
# Build the experiment from a single nested configuration mapping:
# one sub-mapping per component (encoder, dataset, classifier, pooler, trainer).
pipeline = Pipeline(
    {
        # Encoder to load. A local checkpoint was used as an alternative here:
        #   "./data/results/querdenker_tu/model.bin"
        "encoder": {
            "model": "deepset/gbert-base",
        },
        # Input data location plus the columns holding the text and the label.
        # Binary setup; the three-class variant was ["corona", "web1", "web2"].
        "dataset": {
            "path": "./data/preparation_output",
            "text_column": "text",
            "label_column": "label",
            "label_classes": ["corona", "web2"],
        },
        # Classification head settings (hidden layer width and dropout rate).
        "classifier": {
            "hid_size": 512,
            "dropout": 0.2,
        },
        # Pooling strategy applied to the encoder output.
        # NOTE(review): "span_column" is presumably only consulted by
        # span-based pooling forms and unused with "cls" -- confirm.
        "pooler": {
            "form": "cls",
            "span_column": "span",
        },
        # Training hyperparameters and the directory for exported results.
        # NOTE(review): the evaluation report in this notebook covers 1984 of
        # the 1995 test samples (62 full batches of 32), which suggests the
        # trailing partial batch is dropped -- confirm in the trainer.
        "trainer": {
            "num_epochs": 5,
            "batch_size": 32,
            "learning_rate": 1e-3,
            "export_path": "./data/results",
        },
    }
)

[--- SETUP ---]
> Computation Device: cuda
[--- LOAD ENCODER ---]
> Encoder Name: deepset/gbert-base
  Memory Usage: 419.3486 MB
> f(__init__) took: 1.6905 sec
[--- LOAD TRAINER ---]
> Dataset: train
  Samples: 7982
> Dataset: test
  Samples: 1995
> Model: Multilayer perceptron
  Memory Usage: 1.5059 MB 
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): Dropout(p=0.2, inplace=False)
  (2): LeakyReLU(negative_slope=0.01)
  (3): Linear(in_features=512, out_features=2, bias=True)



In [3]:
# Execute the configured pipeline. The log below shows the trainer printing
# per-epoch train/test loss and F1, followed by a per-class evaluation report
# for the epoch with the highest f1_test.
pipeline()

[--- RUN TRAINER ---]
[@001]: 	loss_train=0.1861 	loss_test=0.1314 	f1_train=0.9218 	f1_test=0.9471 	duration=0:00:00
[@002]: 	loss_train=0.1242 	loss_test=0.1516 	f1_train=0.9533 	f1_test=0.9354 	duration=0:00:00
[@003]: 	loss_train=0.1096 	loss_test=0.1255 	f1_train=0.9577 	f1_test=0.9491 	duration=0:00:00
[@004]: 	loss_train=0.0924 	loss_test=0.1121 	f1_train=0.9646 	f1_test=0.9627 	duration=0:00:00
[@005]: 	loss_train=0.0814 	loss_test=0.1189 	f1_train=0.9706 	f1_test=0.9562 	duration=0:00:00
[--- EVALUATION on max(f1_test) ---]
              precision    recall  f1-score  support
corona         0.969792  0.953893  0.961777    976.0
web2           0.956055  0.971230  0.963583   1008.0
macro avg      0.962923  0.962562  0.962680   1984.0
weighted avg   0.962812  0.962702  0.962694   1984.0
