In [1]:
cd ..

/home/is/akiyoshi-n/my-project


In [2]:
import os
# 使用するGPUを指定. この環境変数の場所は，pytorchをimportする前に入れる
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from pathlib import Path
from datetime import datetime
from src.my_project.dataset import load_dataset_4class_Multi_classification, split_multilabel_data, load_text_dataset
from src.my_project.train_v2 import MultiClassClassifier
from sklearn.model_selection import train_test_split
from src.my_project.dataset import load_multiclass_dataset
import wandb
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

In [3]:
DATASET_PATH = Path('/home/is/akiyoshi-n/my-project/data')
# 本日の日付
timestamp = datetime.now().strftime("%Y-%m-%d")
# 出力先ディレクトリ
output_dir = Path('/home/is/akiyoshi-n/my-project/outputs/{}'.format(timestamp))
# モデル出力先ディレクトリ
output_model_dir = Path('/home/is/akiyoshi-n/my-project/outputs_model')

### パラメータの設定

In [4]:
# 最大トークン数
MAX_LEN = 128
# バッチサイズ
BATCH_SIZE = 16
# エポック数
NUM_EPOCHS = 100
# 学習率
LEARNING_RATE = 2e-5
# Cross Validation時のFold数
NUM_FOLDS = 3
# 早期停止のための忍耐値
PATIENCE = 5
# 乱数シード
SEED = 2023
# クラス数
NUM_LABELS = 4
# 閾値
THRESH = 0.5

In [12]:
# データの読み込み
data, class_name = load_dataset_4class_Multi_classification(f"{DATASET_PATH}/act_classification_final_ChatGPT4.xlsx")

In [13]:
data_labels_np = np.array(data['labels'])
data_labels_np.sum(axis=0)

array([122, 145, 358, 611])

### モデル精度評価

In [14]:
# 東北大BERT-v3
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
Classifier_model = MultiClassClassifier(model_name = MODEL_NAME, num_labels=NUM_LABELS, seed=SEED, thresh=THRESH)

In [15]:
# testデータと訓練に使用するデータに分割
dataset, eval_data = split_multilabel_data(data=data, test_size=0.2, SEED=SEED)

In [16]:
# dataset['labels']とtest_data['labels']の各列の合計値を出力
dataset_labels_np = np.array(dataset['labels'])
eval_data_labels_np = np.array(eval_data['labels'])
print(dataset_labels_np.sum(axis=0))
print(eval_data_labels_np.sum(axis=0))

[ 98 116 286 489]
[ 24  29  72 122]


In [17]:
trainer = Classifier_model.train_model(dataset, eval_data, MAX_LEN, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, PATIENCE, output_dir, project_name='use_training', run_name='test')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/960 [00:00<?, ? examples/s]

Map:   0%|          | 0/240 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112536614139875, max=1.0…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.587,0.531818,0.016667,0.015385,"[0.0, 0.0, 0.0, 0.062]","[0.0, 0.0, 0.0, 0.033]","[0.0, 0.0, 0.0, 0.5]"
2,0.4816,0.500947,0.108333,0.114895,"[0.154, 0.0, 0.0, 0.306]","[0.083, 0.0, 0.0, 0.197]","[1.0, 0.0, 0.0, 0.686]"
3,0.4205,0.46299,0.15,0.296574,"[0.368, 0.375, 0.159, 0.284]","[0.292, 0.31, 0.097, 0.172]","[0.5, 0.474, 0.438, 0.808]"
4,0.3371,0.423856,0.4,0.459446,"[0.417, 0.423, 0.4, 0.598]","[0.417, 0.379, 0.278, 0.5]","[0.417, 0.478, 0.714, 0.744]"
5,0.2419,0.432253,0.504167,0.531505,"[0.486, 0.408, 0.597, 0.634]","[0.375, 0.345, 0.597, 0.533]","[0.692, 0.5, 0.597, 0.783]"
6,0.1693,0.491664,0.583333,0.527081,"[0.424, 0.356, 0.606, 0.723]","[0.292, 0.276, 0.597, 0.738]","[0.778, 0.5, 0.614, 0.709]"
7,0.0894,0.526732,0.604167,0.572546,"[0.486, 0.426, 0.649, 0.729]","[0.375, 0.345, 0.694, 0.705]","[0.692, 0.556, 0.61, 0.754]"
8,0.0622,0.564048,0.554167,0.56809,"[0.465, 0.526, 0.637, 0.644]","[0.417, 0.517, 0.806, 0.533]","[0.526, 0.536, 0.527, 0.812]"
9,0.0379,0.53074,0.6125,0.622557,"[0.565, 0.545, 0.653, 0.726]","[0.542, 0.517, 0.667, 0.697]","[0.591, 0.577, 0.64, 0.759]"


### unursoデータに適応

In [18]:
import pickle
with open(f'{DATASET_PATH}/unurso_Dataset_type.pkl', 'rb') as f:
    test_data = pickle.load(f)

In [23]:
predictions = trainer.predict(test_data)

In [31]:
import torch
# sigmoid関数を適応できるようにTensorに変換
logits = torch.from_numpy(predictions.predictions)
# シグモイド関数を適用し，確率に変換
predictions_proba = torch.sigmoid(logits)
# 閾値を設定し予測ラベルに変換
predictions_label = (predictions_proba>0.5).float()

In [32]:
predictions_label = predictions_label.numpy()
predictions_label

array([[0., 0., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       ...,
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)

In [33]:
import re
import pandas as pd

# 正規表現パターンと対応する置換表現を辞書で定義
replacements = {
    r"@(\S*)": "[USR]", # ユーザ名を[USR]に置換
    r"http(\S*)": "[URL]", # URLを[URL]に置換
}

# 文中のマッチした箇所を対応する置換表現に置換する関数
def apply_regex(text):
    for pattern, replacement in replacements.items():
        text = re.sub(pattern, replacement, text)
    return text

# テキストファイルを読み込み、1つ目と3つ目のフィールドに前処理を適用し、それらを出力する関数
def preprocess_file(input_file):
    with open(input_file, 'r', encoding='utf-8') as file:
        processed_data = []
        for line in file:
            fields = line.strip().split('\t')
            # 2つ目のフィールド（前処理なし）と3つ目のフィールド（前処理あり）4つ目のフィールド（前処理なし）を取得
            first_field = fields[1] # ユーザ名
            third_field = apply_regex(fields[2]) # ツイート本文
            forth_field = fields[3] # ツイート日時
            processed_data.append([first_field, third_field, forth_field])
    return processed_data

In [34]:
# 関数を使用してファイルを前処理し、結果をDataFrameに格納
input_file = f'{DATASET_PATH}/unurso_users85.txt' # 入力ファイル名
processed_data = preprocess_file(input_file)
column_names = ['Name','text','time']
# DataFrameを作成
df = pd.DataFrame(processed_data, columns=column_names)

In [38]:
# 列の合計が0の個数
print(predictions_label.sum(axis=1).tolist().count(0))
print(len(predictions_label))

42048
150584


In [39]:
new_df = pd.DataFrame(predictions_label)
new_df

Unnamed: 0,0,1,2,3
0,0.0,0.0,1.0,0.0
1,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,1.0
...,...,...,...,...
150579,0.0,0.0,1.0,0.0
150580,0.0,0.0,0.0,0.0
150581,0.0,0.0,0.0,1.0
150582,0.0,0.0,1.0,0.0


In [40]:
action_df = pd.concat([df, new_df], axis=1)
action_df = action_df[['Name',0, 1, 2, 3]]

In [51]:
# 0~20の列の合計が0の行を削除
action_df = action_df.drop(action_df.index[predictions_label.sum(axis=1)==0])

In [52]:
action_df

Unnamed: 0,Name,0,1,2,3
0,3zhen,0.0,0.0,1.0,0.0
4,3zhen,0.0,0.0,0.0,1.0
6,3zhen,0.0,0.0,0.0,1.0
13,3zhen,0.0,0.0,0.0,1.0
18,3zhen,0.0,0.0,1.0,0.0
...,...,...,...,...,...
150576,zubora_sweet,0.0,0.0,1.0,0.0
150577,zubora_sweet,0.0,0.0,0.0,1.0
150579,zubora_sweet,0.0,0.0,1.0,0.0
150581,zubora_sweet,0.0,0.0,0.0,1.0


In [53]:
action_df.groupby('Name').mean().mean(axis=0)

0    0.147462
1    0.119273
2    0.175406
3    0.581685
dtype: float32

In [54]:
action_df.groupby('Name').mean().var(axis=0)

0    0.015059
1    0.012558
2    0.044599
3    0.049790
dtype: float32

### ursoデータ

In [55]:
import pickle
with open(f'{DATASET_PATH}/urso_Dataset_type.pkl', 'rb') as f:
    test_data = pickle.load(f)

In [57]:
predictions = trainer.predict(test_data)

In [58]:
import torch
# sigmoid関数を適応できるようにTensorに変換
logits = torch.from_numpy(predictions.predictions)
# シグモイド関数を適用し，確率に変換
predictions_proba = torch.sigmoid(logits)
# 閾値を設定し予測ラベルに変換
predictions_label = (predictions_proba>0.5).float()

In [64]:
predictions_label = predictions_label.numpy()
predictions_label

array([[0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.],
       ...,
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.]], dtype=float32)

In [65]:
# 関数を使用してファイルを前処理し、結果をDataFrameに格納
input_file = f'{DATASET_PATH}/urso_users.txt' # 入力ファイル名
processed_data = preprocess_file(input_file)
column_names = ['Name','text','time']
# DataFrameを作成
df = pd.DataFrame(processed_data, columns=column_names)

In [66]:
# 列の合計が0の個数
print(predictions_label.sum(axis=1).tolist().count(0))
print(len(predictions_label))

189310
510809


In [67]:
new_df = pd.DataFrame(predictions_label)
new_df

Unnamed: 0,0,1,2,3
0,0.0,0.0,0.0,1.0
1,1.0,0.0,0.0,0.0
2,1.0,1.0,0.0,0.0
3,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,1.0
...,...,...,...,...
510804,0.0,0.0,1.0,0.0
510805,0.0,0.0,0.0,1.0
510806,0.0,0.0,0.0,1.0
510807,0.0,0.0,0.0,1.0


In [68]:
action_df = pd.concat([df, new_df], axis=1)
action_df = action_df[['Name',0, 1, 2, 3]]

In [69]:
# 0~20の列の合計が0の行を削除
action_df = action_df.drop(action_df.index[predictions_label.sum(axis=1)==0])

In [70]:
action_df

Unnamed: 0,Name,0,1,2,3
0,468251793,0.0,0.0,0.0,1.0
1,468251793,1.0,0.0,0.0,0.0
2,468251793,1.0,1.0,0.0,0.0
4,468251793,0.0,0.0,0.0,1.0
5,468251793,0.0,0.0,0.0,1.0
...,...,...,...,...,...
510804,you1,0.0,0.0,1.0,0.0
510805,you1,0.0,0.0,0.0,1.0
510806,you1,0.0,0.0,0.0,1.0
510807,you1,0.0,0.0,0.0,1.0


In [71]:
action_df.groupby('Name').mean().mean(axis=0)

0    0.161198
1    0.106414
2    0.198668
3    0.556087
dtype: float32

In [72]:
action_df.groupby('Name').mean().var(axis=0)

0    0.008814
1    0.008805
2    0.035082
3    0.035768
dtype: float32

In [12]:
predictions = Classifier_model.predict(trainer, eval_dataset, MAX_LEN)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [13]:
predictions

array([[0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],


In [24]:
from transformers import AutoTokenizer
from src.my_project.dataset import preprocess_for_Trainer
import numpy as np
import torch
# tokenizerの定義
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
eval_dataset_use = preprocess_for_Trainer(eval_dataset, tokenizer, max_len=MAX_LEN)
predictions = trainer.predict(eval_dataset_use)

# predictions.predictionsにsigmoid関数を適用し，確率に変換
predictions = torch.sigmoid(torch.from_numpy(predictions.predictions))
thresh = 0.5
# predictions = torch.where(predictions > THRESH, 1, 0)
predictions_label = (predictions>thresh).float()

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [25]:
type(predictions)

torch.Tensor

In [26]:
predictions[0]

tensor([0.0440, 0.0457, 0.3518, 0.3285])

In [23]:
if torch.sum(predictions[0]) == 0:
    print(True)

True


In [31]:
predictions_label[0]

tensor([0., 0., 0., 0.])

In [32]:
a[1]

array([0., 0., 0., 1.])

In [33]:
max_index = torch.argmax(predictions[0])
max_index
a[0][max_index] = 1
a[1][:-1] = predictions_label[0][:-1]

wandb: Network error (ReadTimeout), entering retry loop.


In [17]:
a = np.zeros((5,4))
for i in range(5):
    a[i] = predictions[i]

In [19]:
type(a)

numpy.ndarray

In [1]:
predictions

NameError: name 'predictions' is not defined

In [26]:
eval_dataset_use

Dataset({
    features: ['texts', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 200
})

In [21]:
from transformers import AutoTokenizer
from src.my_project.dataset import preprocess_for_Trainer
import numpy as np
import torch
# tokenizerの定義
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# データセットの前処理
eval_dataset_use = preprocess_for_Trainer(eval_dataset, tokenizer, max_len=MAX_LEN)
predictions = trainer.predict(eval_dataset_use)
# predictions.predictionsにsigmoid関数を適用し，確率に変換
predictions = torch.sigmoid(torch.from_numpy(predictions.predictions))
# 0.5以上の確率を1，それ以外を0に変換
thresh = 0.5
# predictions = torch.where(predictions > THRESH, 1, 0)
predictions = (predictions>thresh).float()

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [43]:
# 保存したTrainerを読み込む
model = AutoModelForSequenceClassification.from_pretrained('/home/is/akiyoshi-n/my-project/outputs/2024-02-04/cl-tohoku/bert-base-japanese-v32024-02-04T16-22-24/checkpoint-285')
trainer_v2 = Trainer(model=model)

In [55]:
# trainerの予測値とeval_dataset['labels']のAccuracyとF1を出す
from sklearn.metrics import accuracy_score, f1_score
accuracy = accuracy_score(eval_dataset['labels'], predictions)
f1 = f1_score(eval_dataset['labels'], predictions, average='macro')
print(f'Accuracy: {accuracy:.4f}')
print(f'F1: {f1:.4f}')

Accuracy: 0.6850
F1: 0.6169


In [10]:
prediction = Classifier_model.predict(trainer, eval_dataset, MAX_LEN)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [13]:
prediction.label_ids

200

In [10]:
# 評価データでの評価
Classifier_model.evaluation(trainer, eval_dataset, MAX_LEN)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

{'eval_loss': 0.33989983797073364,
 'eval_accuracy': 0.545,
 'eval_f1': 0.5774035592587015,
 'eval_runtime': 0.8056,
 'eval_samples_per_second': 248.263,
 'eval_steps_per_second': 6.207,
 'epoch': 10.0}

In [None]:
wandb.finish()

### Cross Validation

In [12]:
# データの読み込み
data, class_name = load_dataset_4class_Multi_classification(f"{DATASET_PATH}/act_classification_final_ChatGPT4.xlsx")

In [13]:
data_labels_np = np.array(data['labels'])
data_labels_np.sum(axis=0)

array([122, 145, 358, 611])

In [14]:
# testデータと訓練に使用するデータに分割
dataset, test_data = split_multilabel_data(data=data, test_size=0.1, SEED=SEED)

In [15]:
print(np.array(dataset['labels']).sum(axis=0))
print(np.array(test_data['labels']).sum(axis=0))

[110 130 322 550]
[12 15 36 61]


In [18]:
# 東北大BERT-v3
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
NUM_LABELS=4
Classifier_model = MultiClassClassifier(model_name=MODEL_NAME, num_labels=NUM_LABELS, seed=SEED, thresh=THRESH)

In [19]:
result = Classifier_model.cross_validation(dataset, test_data, MAX_LEN, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, PATIENCE, NUM_FOLDS, output_dir, project_name='ChatGPT_data_4class_weight')

-----------------Fold: 1-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/719 [00:00<?, ? examples/s]

Map:   0%|          | 0/360 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112121534016398, max=1.0…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.5148,0.473223,0.041667,0.078205,"[0.2, 0.0, 0.0, 0.113]","[0.111, 0.0, 0.0, 0.06]","[1.0, 0.0, 0.0, 0.917]"
2,0.4598,0.466306,0.244444,0.241125,"[0.408, 0.0, 0.0, 0.556]","[0.278, 0.0, 0.0, 0.432]","[0.769, 0.0, 0.0, 0.782]"


Map:   0%|          | 0/121 [00:00<?, ? examples/s]

{'accuracy': 0.7083333333333334, 'macro_f1': 0.3355263157894737, 'class_f1': [0.5, 0.0, 0.0, 0.842], 'class_recall': [0.333, 0.0, 0.0, 1.0], 'class_precision': [1.0, 0.0, 0.0, 0.727]}
-----------------Fold: 2-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/719 [00:00<?, ? examples/s]

Map:   0%|          | 0/360 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112426449027326, max=1.0…

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [12]:
# 重みなしの場合の結果（cv=5）
average_accuracy = round(sum(d['eval_accuracy'] for d in result)/len(result), 3)
average_macro_f1 = round(sum(d['eval_macro_f1'] for d in result)/len(result), 3)
# クラスごとの平均値を計算
average_class_f1 = [round(sum(d['eval_class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_f1']))]
average_class_recall = [round(sum(d['eval_class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_recall']))]
average_class_precision = [round(sum(d['eval_class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_precision']))]
print("Average accuracy:", average_accuracy)
print("Average Macro f1:", average_macro_f1)
print("Average Class f1:", average_class_f1)
print("Average Class recall:", average_class_recall)
print("Average Class precision:", average_class_precision)

Average accuracy: 0.634
Average Macro f1: 0.53
Average Class f1: [0.431, 0.355, 0.583, 0.753]
Average Class recall: [0.314, 0.294, 0.593, 0.812]
Average Class precision: [0.833, 0.501, 0.573, 0.703]


In [12]:
# # 重みありの場合の結果（cv=5）
# average_accuracy = round(sum(d['eval_accuracy'] for d in result)/len(result), 3)
# average_macro_f1 = round(sum(d['eval_macro_f1'] for d in result)/len(result), 3)
# # クラスごとの平均値を計算
# average_class_f1 = [round(sum(d['eval_class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_f1']))]
# average_class_recall = [round(sum(d['eval_class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_recall']))]
# average_class_precision = [round(sum(d['eval_class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_precision']))]
# print("Average accuracy:", average_accuracy)
# print("Average Macro f1:", average_macro_f1)
# print("Average Class f1:", average_class_f1)
# print("Average Class recall:", average_class_recall)
# print("Average Class precision:", average_class_precision)

Average accuracy: 0.545
Average Macro f1: 0.531
Average Class f1: [0.562, 0.43, 0.461, 0.671]
Average Class recall: [0.432, 0.585, 0.455, 0.661]
Average Class precision: [0.806, 0.341, 0.472, 0.685]


In [13]:
print(np.array(dataset['labels']).sum(axis=0))

[110 130 322 550]


In [17]:
# テストサイズ×マルチラベル数の二次元リスト
majority_pred = [[0 for i in range(4)] for j in range(len(test_data['labels']))]
for i in range(len(test_data['labels'])):
    majority_pred[i][-1] = 1

In [18]:
# accuracyを計算
accuracy = round(accuracy_score(y_true=test_data['labels'], y_pred=majority_pred), 3)
# macro f1を計算
macro_f1 = round(f1_score(y_true=test_data['labels'], y_pred=majority_pred, average='macro', zero_division=0), 3)
# クラス毎のF1値を計算
class_f1 = [round(score, 3) for score in f1_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
# クラス毎のrecallを計算
class_recall = [round(score, 3) for score in recall_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
# クラス毎のprecisionを計算
class_precision = [round(score, 3) for score in precision_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
print("Average accuracy:", accuracy)
print("Average Macro f1:", macro_f1)
print("Average Class f1:", class_f1)
print("Average Class recall:", class_recall)
print("Average Class precision:", class_precision)

Average accuracy: 0.504
Average Macro f1: 0.168
Average Class f1: [0.0, 0.0, 0.0, 0.67]
Average Class recall: [0.0, 0.0, 0.0, 1.0]
Average Class precision: [0.0, 0.0, 0.0, 0.504]


### 21クラスマルチクラス分類

In [5]:
# データの読み込み
data, class_name = load_multiclass_dataset(f"{DATASET_PATH}/act_classification_final_ChatGPT4.xlsx")

In [6]:
len(class_name)

21

In [7]:
data_labels_np = np.array(data['labels'])
data_labels_np.sum(axis=0)

array([ 29,  21,  28,  47,   8,  64,  28,  11,  19,   8,  11, 138,  28,
        20,  83,  13,  54,  13,  11,   6, 611])

In [8]:
# testデータと訓練に使用するデータに分割
dataset, test_data = split_multilabel_data(data=data, test_size=0.1, SEED=SEED)

In [9]:
print(np.array(dataset['labels']).sum(axis=0))
print(np.array(test_data['labels']).sum(axis=0))

[ 26  19  25  42   7  58  25  10  17   7  10 124  25  18  75  12  49  12
  10   5 550]
[ 3  2  3  5  1  6  3  1  2  1  1 14  3  2  8  1  5  1  1  1 61]


In [10]:
# 東北大BERT-v3
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
# クラス数
NUM_LABELS = 21
Classifier_model_21 = MultiClassClassifier(model_name=MODEL_NAME, num_labels=21, seed=SEED, thresh=THRESH)

In [11]:
PATIENCE=10
result = Classifier_model_21.cross_validation(dataset, test_data, MAX_LEN, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, PATIENCE, NUM_FOLDS, output_dir, project_name='ChatGPT_data_21class')

-----------------Fold: 1-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Parameter 'fn_kwargs'={'tokenizer': BertJapaneseTokenizer(name_or_path='cl-tohoku/bert-base-japanese-v3', vocab_size=32768, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False,

Map:   0%|          | 0/716 [00:00<?, ? examples/s]

Map:   0%|          | 0/359 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.6227,0.462342,0.401114,0.049918,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.265, 0.0, 0.108, 0.0, 0.0, 0.0, 0.675]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.36, 0.0, 0.118, 0.0, 0.0, 0.0, 0.978]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.209, 0.0, 0.1, 0.0, 0.0, 0.0, 0.516]"
2,0.3316,0.209219,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1698,0.120525,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1167,0.092991,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.0991,0.082648,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.0918,0.077472,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.0875,0.074597,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0831,0.072734,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0824,0.071648,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.0798,0.070647,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'accuracy': 0.75, 'macro_f1': 0.18483245149911817, 'class_f1': [0.0, 0.0, 0.0, 1.0, 0.0, 0.667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.815], 'class_recall': [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.917], 'class_precision': [0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.75, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.733]}
-----------------Fold: 2-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/717 [00:00<?, ? examples/s]

Map:   0%|          | 0/358 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112350184056494, max=1.0…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.6597,0.505458,0.002793,0.023366,"[0.01, 0.085, 0.0, 0.029, 0.049, 0.061, 0.017, 0.0, 0.0, 0.0, 0.0, 0.138, 0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0, 0.032]","[0.111, 0.333, 0.0, 0.071, 0.5, 0.053, 0.111, 0.0, 0.0, 0.0, 0.0, 0.146, 0.0, 0.0, 0.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.016]","[0.005, 0.049, 0.0, 0.019, 0.026, 0.071, 0.009, 0.0, 0.0, 0.0, 0.0, 0.13, 0.0, 0.0, 0.0, 0.0, 0.049, 0.0, 0.0, 0.0, 0.75]"
2,0.3323,0.190807,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1608,0.118237,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1158,0.093527,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.0989,0.083187,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.0916,0.077879,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.0859,0.074821,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0856,0.072791,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0826,0.071454,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.0823,0.070913,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'accuracy': 0.75, 'macro_f1': 0.19177489177489176, 'class_f1': [0.0, 0.0, 0.0, 1.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.182, 0.0, 0.0, 0.5, 0.0, 0.667, 0.0, 0.0, 0.0, 0.879], 'class_recall': [0.0, 0.0, 0.0, 1.0, 0.0, 0.667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.111, 0.0, 0.0, 0.667, 0.0, 1.0, 0.0, 0.0, 0.0, 0.967], 'class_precision': [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.4, 0.0, 0.5, 0.0, 0.0, 0.0, 0.806]}
-----------------Fold: 3-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/717 [00:00<?, ? examples/s]

Map:   0%|          | 0/358 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111231493867106, max=1.0)…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.5928,0.471905,0.094972,0.035027,"[0.0, 0.0, 0.0, 0.126, 0.03, 0.0, 0.0, 0.0, 0.105, 0.0, 0.0, 0.218, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.256]","[0.0, 0.0, 0.0, 0.429, 0.333, 0.0, 0.0, 0.0, 0.167, 0.0, 0.0, 0.381, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.175]","[0.0, 0.0, 0.0, 0.074, 0.016, 0.0, 0.0, 0.0, 0.077, 0.0, 0.0, 0.152, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.478]"
2,0.3306,0.205577,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1689,0.125926,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1178,0.09802,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.0972,0.087,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.091,0.081738,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.0861,0.078864,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0832,0.076986,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0822,0.075807,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.0804,0.074639,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'accuracy': 0.8103448275862069, 'macro_f1': 0.30558040567825306, 'class_f1': [0.667, 0.0, 0.0, 1.0, 0.0, 0.75, 0.0, 0.0, 0.667, 0.0, 0.0, 0.6, 0.0, 0.0, 0.857, 0.0, 1.0, 0.0, 0.0, 0.0, 0.877], 'class_recall': [0.5, 0.0, 0.0, 1.0, 0.0, 0.75, 0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.941], 'class_precision': [1.0, 0.0, 0.0, 1.0, 0.0, 0.75, 0.0, 0.0, 0.5, 0.0, 0.0, 0.75, 0.0, 0.0, 0.75, 0.0, 1.0, 0.0, 0.0, 0.0, 0.821]}


In [14]:
# 重みありの場合の結果（cv=5）解答
average_accuracy = round(sum(d['accuracy'] for d in result)/len(result), 3)
average_macro_f1 = round(sum(d['macro_f1'] for d in result)/len(result), 3)
# クラスごとの平均値を計算
average_class_f1 = [round(sum(d['class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_f1']))]
average_class_recall = [round(sum(d['class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_recall']))]
average_class_precision = [round(sum(d['class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_precision']))]
print("Average accuracy:", average_accuracy)
print("Average Macro f1:", average_macro_f1)
print("Average Class f1:", average_class_f1)
print("Average Class recall:", average_class_recall)
print("Average Class precision:", average_class_precision)

Average accuracy: 0.77
Average Macro f1: 0.227
Average Class f1: [0.222, 0.0, 0.0, 1.0, 0.0, 0.739, 0.0, 0.0, 0.222, 0.0, 0.0, 0.461, 0.0, 0.0, 0.719, 0.0, 0.556, 0.0, 0.0, 0.0, 0.857]
Average Class recall: [0.167, 0.0, 0.0, 1.0, 0.0, 0.806, 0.0, 0.0, 0.333, 0.0, 0.0, 0.37, 0.0, 0.0, 0.778, 0.0, 0.667, 0.0, 0.0, 0.0, 0.942]
Average Class precision: [0.333, 0.0, 0.0, 1.0, 0.0, 0.75, 0.0, 0.0, 0.167, 0.0, 0.0, 0.667, 0.0, 0.0, 0.717, 0.0, 0.5, 0.0, 0.0, 0.0, 0.787]


In [14]:
# 重みありの場合の結果（cv=5）
average_accuracy = round(sum(d['accuracy'] for d in result)/len(result), 3)
average_macro_f1 = round(sum(d['macro_f1'] for d in result)/len(result), 3)
# クラスごとの平均値を計算
average_class_f1 = [round(sum(d['class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_f1']))]
average_class_recall = [round(sum(d['class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_recall']))]
average_class_precision = [round(sum(d['class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['class_precision']))]
print("Average accuracy:", average_accuracy)
print("Average Macro f1:", average_macro_f1)
print("Average Class f1:", average_class_f1)
print("Average Class recall:", average_class_recall)
print("Average Class precision:", average_class_precision)

Average accuracy: 0.697
Average Macro f1: 0.205
Average Class f1: [0.0, 0.0, 0.0, 1.0, 0.0, 0.611, 0.0, 0.0, 0.0, 0.0, 0.0, 0.342, 0.0, 0.0, 0.764, 0.0, 0.483, 0.0, 0.333, 0.0, 0.775]
Average Class recall: [0.0, 0.0, 0.0, 1.0, 0.0, 0.556, 0.0, 0.0, 0.0, 0.0, 0.0, 0.375, 0.0, 0.0, 0.889, 0.0, 0.611, 0.0, 0.333, 0.0, 0.767]
Average Class precision: [0.0, 0.0, 0.0, 1.0, 0.0, 0.722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.341, 0.0, 0.0, 0.681, 0.0, 0.4, 0.0, 0.333, 0.0, 0.791]


In [28]:
# # 重みなしの場合の結果（cv=5）
# average_accuracy = round(sum(d['eval_accuracy'] for d in result)/len(result), 3)
# average_macro_f1 = round(sum(d['eval_macro_f1'] for d in result)/len(result), 3)
# # クラスごとの平均値を計算
# average_class_f1 = [round(sum(d['eval_class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_f1']))]
# average_class_recall = [round(sum(d['eval_class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_recall']))]
# average_class_precision = [round(sum(d['eval_class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_precision']))]
# print("Average accuracy:", average_accuracy)
# print("Average Macro f1:", average_macro_f1)
# print("Average Class f1:", average_class_f1)
# print("Average Class recall:", average_class_recall)
# print("Average Class precision:", average_class_precision)

Average accuracy: 0.711
Average Macro f1: 0.184
Average Class f1: [0.0, 0.0, 0.0, 0.952, 0.0, 0.622, 0.0, 0.0, 0.0, 0.0, 0.0, 0.443, 0.0, 0.0, 0.547, 0.0, 0.478, 0.0, 0.0, 0.0, 0.826]
Average Class recall: [0.0, 0.0, 0.0, 1.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.396, 0.0, 0.0, 0.556, 0.0, 0.667, 0.0, 0.0, 0.0, 0.887]
Average Class precision: [0.0, 0.0, 0.0, 0.917, 0.0, 0.833, 0.0, 0.0, 0.0, 0.0, 0.0, 0.522, 0.0, 0.0, 0.556, 0.0, 0.406, 0.0, 0.0, 0.0, 0.775]


### major class

In [10]:
print(np.array(dataset['labels']).sum(axis=0))

[ 26  19  25  42   7  58  25  10  17   7  10 124  25  18  75  12  49  12
  10   5 550]


In [11]:
# テストサイズ×マルチラベル数の二次元リスト
majority_pred = [[0 for i in range(21)] for j in range(len(test_data['labels']))]
for i in range(len(test_data['labels'])):
    majority_pred[i][-1] = 1

In [12]:
# accuracyを計算
accuracy = round(accuracy_score(y_true=test_data['labels'], y_pred=majority_pred), 3)
# macro f1を計算
macro_f1 = round(f1_score(y_true=test_data['labels'], y_pred=majority_pred, average='macro', zero_division=0), 3)
# クラス毎のF1値を計算
class_f1 = [round(score, 3) for score in f1_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
# クラス毎のrecallを計算
class_recall = [round(score, 3) for score in recall_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
# クラス毎のprecisionを計算
class_precision = [round(score, 3) for score in precision_score(y_true=test_data['labels'], y_pred=majority_pred, average=None, zero_division=0)]
print("Average accuracy:", accuracy)
print("Average Macro f1:", macro_f1)
print("Average Class f1:", class_f1)
print("Average Class recall:", class_recall)
print("Average Class precision:", class_precision)

Average accuracy: 0.488
Average Macro f1: 0.031
Average Class f1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.656]
Average Class recall: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
Average Class precision: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.488]


### Cross validation

In [6]:
# データの読み込み
data, class_name = load_multiclass_dataset(f"{DATASET_PATH}/act_classification_final_ChatGPT4.xlsx")

In [7]:
# testデータと訓練に使用するデータに分割
dataset, test_data = split_multilabel_data(data=data, test_size=0.1, SEED=SEED)

In [8]:
# dataset['labels']とtest_data['labels']の各列の合計値を出力
dataset_labels_np = np.array(dataset['labels'])
test_data_labels_np = np.array(test_data['labels'])
print(dataset_labels_np.sum(axis=0))
print(test_data_labels_np.sum(axis=0))

[ 26  19  25  42   7  58  25  10  17   7  10 124  25  18  75  12  49  12
  10   5 550]
[ 3  2  3  5  1  6  3  1  2  1  1 14  3  2  8  1  5  1  1  1 61]


In [9]:
# 東北大BERT-v3
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
# クラス数
NUM_LABELS = 21
Classifier_model = MultiClassClassifier(model_name=MODEL_NAME, num_labels=NUM_LABELS, seed=SEED, thresh=THRESH)

In [10]:
result = Classifier_model.cross_validation(dataset, test_data, MAX_LEN, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, PATIENCE, NUM_FOLDS, output_dir, project_name='4Classification_cross_validation_class_weights_v2')

-----------------Fold: 1-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Parameter 'fn_kwargs'={'tokenizer': BertJapaneseTokenizer(name_or_path='cl-tohoku/bert-base-japanese-v3', vocab_size=32768, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False,

Map:   0%|          | 0/716 [00:00<?, ? examples/s]

Map:   0%|          | 0/359 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.6197,0.466773,0.0,0.007


ValueError: Found array with 0 sample(s) (shape=(0, 21)) while a minimum of 1 is required.

In [31]:
result

[{'eval_loss': 0.06489665806293488,
  'eval_accuracy': 0.516260162601626,
  'eval_f1': 0.19292948191055528,
  'eval_runtime': 0.8932,
  'eval_samples_per_second': 275.412,
  'eval_steps_per_second': 17.913,
  'epoch': 25.0},
 {'eval_loss': 0.06510759890079498,
  'eval_accuracy': 0.5528455284552846,
  'eval_f1': 0.3209843153703635,
  'eval_runtime': 0.7542,
  'eval_samples_per_second': 326.194,
  'eval_steps_per_second': 21.216,
  'epoch': 26.0},
 {'eval_loss': 0.06674114614725113,
  'eval_accuracy': 0.532520325203252,
  'eval_f1': 0.19560974371084755,
  'eval_runtime': 0.7671,
  'eval_samples_per_second': 320.689,
  'eval_steps_per_second': 20.858,
  'epoch': 23.0}]

In [32]:
average_accuracy = sum(d['eval_accuracy'] for d in result)/len(result)
average_f1 = sum(d['eval_f1'] for d in result)/len(result)
print("Average accuracy:", average_accuracy)
print("Average f1:", average_f1)

Average accuracy: 0.5338753387533876
Average f1: 0.23650784699725547


### 全て０の場合を除かなかった場合の予測

In [10]:
# 東北大BERT-v3
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
# クラス数
NUM_LABELS = 21
Classifier_model_21 = MultiClassClassifier(model_name=MODEL_NAME, num_labels=21, seed=SEED, thresh=THRESH)

In [11]:
PATIENCE=5
result = Classifier_model_21.cross_validation(dataset, test_data, MAX_LEN, NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE, PATIENCE, NUM_FOLDS, output_dir, project_name='ChatGPT_data_21class')

-----------------Fold: 1-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Parameter 'fn_kwargs'={'tokenizer': BertJapaneseTokenizer(name_or_path='cl-tohoku/bert-base-japanese-v3', vocab_size=32768, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False,

Map:   0%|          | 0/716 [00:00<?, ? examples/s]

Map:   0%|          | 0/359 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.5861,0.463071,0.016713,0.017715,"[0.0, 0.0, 0.0, 0.0, 0.042, 0.053, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.267, 0.0, 0.0, 0.011]","[0.0, 0.0, 0.0, 0.0, 0.5, 0.316, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.005]","[0.0, 0.0, 0.0, 0.0, 0.022, 0.029, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.182, 0.0, 0.0, 0.5]"
2,0.3241,0.196512,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1625,0.116295,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1147,0.092641,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.0988,0.082812,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.0917,0.077796,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.088,0.074941,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0833,0.072949,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0832,0.071685,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.0804,0.070463,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'eval_loss': 0.0631701648235321, 'eval_accuracy': 0.08, 'eval_macro_f1': 0.07984396555825127, 'eval_class_f1': [0.0, 0.0, 0.0, 0.571, 0.0, 0.286, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.0, 0.364, 0.0, 0.0, 0.0, 0.092], 'eval_class_recall': [0.0, 0.0, 0.0, 0.4, 0.0, 0.167, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.4, 0.0, 0.0, 0.0, 0.049], 'eval_class_precision': [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.667, 0.0, 0.333, 0.0, 0.0, 0.0, 0.75], 'eval_runtime': 0.4732, 'eval_samples_per_second': 264.155, 'eval_steps_per_second': 16.906, 'epoch': 24.0}
-----------------Fold: 2-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/717 [00:00<?, ? examples/s]

Map:   0%|          | 0/358 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112307881315548, max=1.0…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.597,0.461796,0.00838,0.013709,"[0.06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.077, 0.07, 0.0, 0.0, 0.0, 0.0, 0.081, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.444, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 1.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.032, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.041, 0.036, 0.0, 0.0, 0.0, 0.0, 0.051, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
2,0.3138,0.186706,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1582,0.117727,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1148,0.092949,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.0986,0.082832,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.0912,0.077592,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.0856,0.074551,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0852,0.072591,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0818,0.071396,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.0803,0.07031,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'eval_loss': 0.06025061011314392, 'eval_accuracy': 0.336, 'eval_macro_f1': 0.13770356627499483, 'eval_class_f1': [0.0, 0.0, 0.0, 0.333, 0.0, 0.667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.286, 0.0, 0.0, 0.333, 0.0, 0.667, 0.0, 0.0, 0.0, 0.606], 'eval_class_recall': [0.0, 0.0, 0.0, 0.2, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.214, 0.0, 0.0, 0.25, 0.0, 0.6, 0.0, 0.0, 0.0, 0.492], 'eval_class_precision': [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.429, 0.0, 0.0, 0.5, 0.0, 0.75, 0.0, 0.0, 0.0, 0.789], 'eval_runtime': 0.4953, 'eval_samples_per_second': 252.374, 'eval_steps_per_second': 16.152, 'epoch': 29.0}
-----------------Fold: 3-----------------


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/717 [00:00<?, ? examples/s]

Map:   0%|          | 0/358 [00:00<?, ? examples/s]



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112335117326842, max=1.0…

Epoch,Training Loss,Validation Loss,Accuracy,Macro F1,Class F1,Class Recall,Class Precision
1,0.6472,0.494113,0.103352,0.040485,"[0.0, 0.051, 0.0, 0.0, 0.051, 0.0, 0.0, 0.0, 0.033, 0.0, 0.026, 0.0, 0.0, 0.065, 0.0, 0.026, 0.0, 0.0, 0.0, 0.0, 0.599]","[0.0, 0.143, 0.0, 0.0, 0.667, 0.0, 0.0, 0.0, 0.333, 0.0, 0.25, 0.0, 0.0, 0.167, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.656]","[0.0, 0.031, 0.0, 0.0, 0.026, 0.0, 0.0, 0.0, 0.017, 0.0, 0.014, 0.0, 0.0, 0.04, 0.0, 0.014, 0.0, 0.0, 0.0, 0.0, 0.55]"
2,0.3332,0.187065,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
3,0.1572,0.119123,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
4,0.1139,0.096045,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
5,0.096,0.086605,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
6,0.0909,0.081802,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
7,0.0864,0.078951,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
8,0.0835,0.077149,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
9,0.0825,0.076071,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"
10,0.081,0.075075,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]"


Map:   0%|          | 0/125 [00:00<?, ? examples/s]

{'eval_loss': 0.058311909437179565, 'eval_accuracy': 0.272, 'eval_macro_f1': 0.1902841007402109, 'eval_class_f1': [0.0, 0.0, 0.0, 0.571, 0.0, 0.444, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.625, 0.0, 0.545, 0.0, 1.0, 0.0, 0.41], 'eval_class_recall': [0.0, 0.0, 0.0, 0.4, 0.0, 0.333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.286, 0.0, 0.0, 0.625, 0.0, 0.6, 0.0, 1.0, 0.0, 0.279], 'eval_class_precision': [0.0, 0.0, 0.0, 1.0, 0.0, 0.667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.667, 0.0, 0.0, 0.625, 0.0, 0.5, 0.0, 1.0, 0.0, 0.773], 'eval_runtime': 0.4707, 'eval_samples_per_second': 265.588, 'eval_steps_per_second': 16.998, 'epoch': 32.0}


In [13]:
# 重みなしの場合の結果（cv=5）
average_accuracy = round(sum(d['eval_accuracy'] for d in result)/len(result), 3)
average_macro_f1 = round(sum(d['eval_macro_f1'] for d in result)/len(result), 3)
# クラスごとの平均値を計算
average_class_f1 = [round(sum(d['eval_class_f1'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_f1']))]
average_class_recall = [round(sum(d['eval_class_recall'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_recall']))]
average_class_precision = [round(sum(d['eval_class_precision'][i] for d in result) / len(result), 3) for i in range(len(result[0]['eval_class_precision']))]
print("Average accuracy:", average_accuracy)
print("Average Macro f1:", average_macro_f1)
print("Average Class f1:", average_class_f1)
print("Average Class recall:", average_class_recall)
print("Average Class precision:", average_class_precision)

Average accuracy: 0.229
Average Macro f1: 0.136
Average Class f1: [0.0, 0.0, 0.0, 0.492, 0.0, 0.466, 0.0, 0.0, 0.0, 0.0, 0.0, 0.229, 0.0, 0.0, 0.441, 0.0, 0.525, 0.0, 0.333, 0.0, 0.369]
Average Class recall: [0.0, 0.0, 0.0, 0.333, 0.0, 0.333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.167, 0.0, 0.0, 0.375, 0.0, 0.533, 0.0, 0.333, 0.0, 0.273]
Average Class precision: [0.0, 0.0, 0.0, 1.0, 0.0, 0.889, 0.0, 0.0, 0.0, 0.0, 0.0, 0.365, 0.0, 0.0, 0.597, 0.0, 0.528, 0.0, 0.333, 0.0, 0.771]
