In [1]:
from transformers import AutoTokenizer, AutoModel
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import plotly.express as px
import torch
from torch import nn 
from torch.optim import AdamW
from sklearn.metrics import accuracy_score
import sys 
import os
sys.path.append('../')

from src.bayes.bayes_models import *
from src.trainer import CustomTrainer
from custom_datasets.bert_bayes import *
from src.utils.data_utils import load_exp_model

In [None]:
model = BayesModule("bert-base-cased", 10, AdamW, 3e-4)
tokenizer = model.tokenizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
metric = accuracy_score
batch_size = 8
n_epochs = 5
datas = [
    load_bigbenchhard(tokenizer, batch_size), 
    load_multiscam(tokenizer, batch_size), 
    load_llmops(tokenizer, batch_size)
]
res_logs = {
    "BaseBert_loss": [],
    "BaseBert_metric": [],
}
for train_loader, test_loader in datas:
    mx = 0
    for _, q in train_loader:
        mx = max(mx, max(q))
    model = BayesModule("bert-base-cased", mx+1, AdamW, 3e-4)
    trainer = CustomTrainer(model, optimizer, criterion, metric, 'cuda')
    logs = trainer.run_train(train_loader, test_loader)
    res_logs['BaseBert_loss'].append(logs['result']['loss'])
    res_logs['BaseBert_metric'].append(logs['result']['metric'])

for lg in res_logs:
    res_logs[lg] = np.mean(res_logs[lg])

fig = px.bar(x=list(res_logs.keys()), y=list(res_logs.values()), 
             title="Гистограмма значений BaseBert", 
             labels={'x': 'Метрика', 'y': 'Значение'}, 
             color=list(res_logs.values()))

fig.update_layout(showlegend=False)  
fig.show()

TypeError: BayesModule.__init__() missing 2 required positional arguments: 'optimizer' and 'param'

In [None]:
model = BayesModule("bert-base-cased", 10, AdamW, 3e-4)
tokenizer = model.tokenizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
metric = accuracy_score
batch_size = 8
n_epochs = 5
datas = [
    load_bigbenchhard(tokenizer, batch_size), 
    load_multiscam(tokenizer, batch_size), 
    load_llmops(tokenizer, batch_size)
]
res_logs = {
    "BayesBert_loss": [],
    "BayesBert_metric": [],
}
for train_loader, test_loader in datas:
    mx = 0
    for _, q in train_loader:
        mx = max(mx, max(q))
    lrs = [3e-4]
    model = BayesModel("bert-base-cased", mx+1, AdamW, 3e-4)
    cur_logs = {
        "BayesBert_loss": [],
        "BayesBert_metric": [],
    }
    for iter in range(3):
        trainer = CustomTrainer(model, optimizer, criterion, metric, 'cuda')
        logs = trainer.run_train(train_loader, test_loader)
        cur_logs['BayesBert_loss'].append(logs['result']['loss'])
        cur_logs['BayesBert_metric'].append(logs['result']['metric'])
        model.step(cur_logs['BayesBert_loss'][-1])
    res_logs['BayesBert_loss'].append(min(cur_logs['BayesBert_loss']))
    res_logs['BayesBert_metric'].append(max(cur_logs['BayesBert_metric']))

for lg in res_logs:
    res_logs[lg] = np.mean(res_logs[lg])

fig = px.bar(x=list(res_logs.keys()), y=list(res_logs.values()), 
             title="Гистограмма значений BaseBert", 
             labels={'x': 'Метрика', 'y': 'Значение'}, 
             color=list(res_logs.values()))

fig.update_layout(showlegend=False)  
fig.show()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.8374104690044484]]
[[0.0003]]
[[0.4187052345022242]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.8374104690044484], [0.8592380729127438]]
[[0.0003], [0.00030888976824755185]]
[[0.4296190364563719]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.8374104690044484], [0.8592380729127438], [0.8300971350771316]]
[[0.0003], [0.00030888976824755185], [0.0006510418572541991]]
[[0.4150485675385658]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.8374104690044484], [0.8592380729127438], [0.8300971350771316], [0.8371449899166188]]
[[0.0003], [0.00030888976824755185], [0.0006510418572541991], [0.20070026251968887]]
[[0.4185724949583094]]


  samples = np.abs(np.random.multivariate_normal(mu.ravel(), cov, 500))


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.8374104690044484], [0.8592380729127438], [0.8300971350771316], [0.8371449899166188], [0.7939894871508821]]
[[0.0003], [0.00030888976824755185], [0.0006510418572541991], [0.20070026251968887], [431837.8980124456]]
[[0.39699474357544107]]


  samples = np.abs(np.random.multivariate_normal(mu.ravel(), cov, 500))


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[[2.311596834287047]]
[[0.0003]]
[[1.1557984171435236]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[[2.311596834287047], [2.314451437443495]]
[[0.0003], [0.0016216727722672972]]
[[1.1572257187217474]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[[2.311596834287047], [2.314451437443495], [2.3138049717992546]]
[[0.0003], [0.0016216727722672972], [0.0026688401749311863]]
[[1.1569024858996273]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[[2.311596834287047], [2.314451437443495], [2.3138049717992546], [2.316240653768182]]
[[0.0003], [0.0016216727722672972], [0.0026688401749311863], [645.5954893143509]]
[[1.158120326884091]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/160 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[[2.311596834287047], [2.314451437443495], [2.3138049717992546], [2.316240653768182], [2.319893315434456]]
[[0.0003], [0.0016216727722672972], [0.0026688401749311863], [645.5954893143509], [86341631.2743172]]
[[1.159946657717228]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[[2.9674726288805724]]
[[0.0003]]
[[1.4837363144402862]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[[2.9674726288805724], [2.944992742230815]]
[[0.0003], [0.0036016079270571794]]
[[1.4724963711154075]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[[2.9674726288805724], [2.944992742230815], [2.9232788355119768]]
[[0.0003], [0.0036016079270571794], [0.003342611079280397]]
[[1.4616394177559884]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[[2.9674726288805724], [2.944992742230815], [2.9232788355119768], [2.9390434424082437]]
[[0.0003], [0.0036016079270571794], [0.003342611079280397], [0.2604151364299443]]
[[1.4695217212041218]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[[2.9674726288805724], [2.944992742230815], [2.9232788355119768], [2.9390434424082437], [2.9763093084417362]]
[[0.0003], [0.0036016079270571794], [0.003342611079280397], [0.2604151364299443], [104443.3329378651]]
[[1.4881546542208681]]


In [None]:
from tqdm.auto import tqdm 

model = BayesModule("bert-base-cased", 10, AdamW, 3e-4)
tokenizer = model.tokenizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
metric = accuracy_score
batch_size = 8
n_epochs = 5
datas = [
    load_bigbenchhard(tokenizer, batch_size), 
    load_multiscam(tokenizer, batch_size), 
    load_llmops(tokenizer, batch_size)
]
res_logs = {
    "BayesEpochBert_loss": [],
    "BayesEpochBert_metric": [],
}
for train_loader, test_loader in datas:
    mx = 0
    for _, q in train_loader:
        mx = max(mx, max(q))
    lrs = [3e-4]
    model = BayesModel("bert-base-cased", mx+1, AdamW, 3e-4)
    cur_logs = {
        "BayesEpochBert_loss": [],
        "BayesEpochBert_metric": [],
    }
    for iter in range(3):
        trainer = CustomTrainer(model, optimizer, criterion, metric, 'cuda')

        pbar = tqdm(list(range(n_epochs)))
        last_log = None
        for epoch in pbar:
            train_logs = trainer.train_epoch(train_loader)
            last_log=train_logs
            if test_loader is not None:
                val_logs = trainer.val_epoch(test_loader)
                last_log=val_logs
                pbar.set_description(f"Train Loss: {train_logs['result']['loss']} Train Metric: {train_logs['result']['metric']}\nVal Loss: {val_logs['result']['loss']} Val Metric: {val_logs['result']['metric']}")
            else:
                pbar.set_description(f"Train Loss: {train_logs['result']['loss']} Train Metric: {train_logs['result']['metric']}")
            model.step(val_logs['result']['loss'])
        logs = last_log

        logs = trainer.run_train(train_loader, test_loader)
        cur_logs['BayesEpochBert_loss'].append(logs['result']['loss'])
        cur_logs['BayesEpochBert_metric'].append(logs['result']['metric'])
        model.step(cur_logs['BayesEpochBert_loss'][-1])
    res_logs['BayesEpochBert_loss'].append(min(cur_logs['BayesEpochBert_loss']))
    res_logs['BayesEpochBert_metric'].append(max(cur_logs['BayesEpochBert_metric']))

for lg in res_logs:
    res_logs[lg] = np.mean(res_logs[lg])

fig = px.bar(x=list(res_logs.keys()), y=list(res_logs.values()), 
             title="Гистограмма значений BaseBert", 
             labels={'x': 'Метрика', 'y': 'Значение'}, 
             color=list(res_logs.values()))

fig.update_layout(showlegend=False)  
fig.show()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643]]
[[0.0003]]
[[0.37977881697898214]]


  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538]]
[[0.0003], [0.00016845099385258189]]
[[0.4017587424592769]]


  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05]]
[[0.41417152640667365]]


  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473], [0.8700362418560271]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05], [4.171955075092312e-05]]
[[0.43501812092801356]]


  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473], [0.8700362418560271], [0.8800795471414606]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05], [4.171955075092312e-05], [0.02626273253069258]]
[[0.4400397735707303]]



covariance is not symmetric positive-semidefinite.



  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473], [0.8700362418560271], [0.8800795471414606], [0.9091100515203273]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05], [4.171955075092312e-05], [0.02626273253069258], [98.1905944607836]]
[[0.45455502576016366]]


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473], [0.8700362418560271], [0.8800795471414606], [0.9091100515203273], [0.9256889693280483]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05], [4.171955075092312e-05], [0.02626273253069258], [98.1905944607836], [989230.5582597057]]
[[0.46284448466402417]]



covariance is not symmetric positive-semidefinite.



  0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[[0.7595576339579643], [0.8035174849185538], [0.8283430528133473], [0.8700362418560271], [0.8800795471414606], [0.9091100515203273], [0.9256889693280483], [0.8739601868264218]]
[[0.0003], [0.00016845099385258189], [4.5158002113292834e-05], [4.171955075092312e-05], [0.02626273253069258], [98.1905944607836], [989230.5582597057], [3716038183.772287]]
[[0.4369800934132109]]


LinAlgError: Singular matrix