In [1]:
from transformers import AutoTokenizer, AutoModel
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import plotly.express as px
import torch
from torch import nn 
from torch.optim import AdamW
from sklearn.metrics import accuracy_score
import sys 
import os
sys.path.append('../')

from src.bayes.bayes_models import *
from src.trainer import CustomTrainer
from custom_datasets.bert_bayes import *
from src.utils.data_utils import load_exp_model

In [2]:
# Experiment: for each dataset, train a freshly initialised BayesModule and
# report the loss / metric returned by CustomTrainer, averaged over datasets.
SEED = 42
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(SEED)

# This first model is used only to obtain the tokenizer for the data loaders;
# every dataset below gets its own model (and its own optimizer).
model = BayesModule("bert-base-cased", 10)
tokenizer = model.tokenizer
criterion = nn.CrossEntropyLoss()
metric = accuracy_score
batch_size = 8
# NOTE(review): n_epochs is not passed to CustomTrainer below; the progress
# bars show 5 epochs, presumably CustomTrainer's own default — confirm.
n_epochs = 5
datas = [
    load_bigbenchhard(tokenizer, batch_size),
    load_multiscam(tokenizer, batch_size),
    load_llmops(tokenizer, batch_size)
]
res_logs = {
    "BaseBert_loss": [],
    "BaseBert_metric": [],
}
for train_loader, test_loader in datas:
    # Number of classes = largest training label + 1.
    # Assumes each batch is (inputs, labels) with 0-based integer labels — TODO confirm.
    mx = 0
    for _, q in train_loader:
        # int(...) so the class count is a plain Python int, not a 0-d tensor.
        mx = max(mx, int(max(q)))
    model = BayesModule("bert-base-cased", mx + 1)
    # The optimizer must be built over THIS model's parameters. Reusing an
    # optimizer created for an earlier model would step parameters that never
    # receive gradients, leaving the new model untrained.
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    trainer = CustomTrainer(model, optimizer, criterion, metric, DEVICE)
    logs = trainer.run_train(train_loader, test_loader)
    res_logs['BaseBert_loss'].append(logs['result']['loss'])
    res_logs['BaseBert_metric'].append(logs['result']['metric'])

# Collapse the per-dataset results into a single mean per key.
res_logs = {name: np.mean(values) for name, values in res_logs.items()}

fig = px.bar(x=list(res_logs.keys()), y=list(res_logs.values()),
             title="Гистограмма значений BaseBert",
             labels={'x': 'Метрика', 'y': 'Значение'},
             color=list(res_logs.values()))

fig.update_layout(showlegend=False)
fig.show()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]