# Notebook for distillation training on the CIFAR10 dataset
This notebook trains MobileNetV2 on the CIFAR10 dataset, using a ViT fine-tuned on the same dataset as the teacher model.

MobileNetV2 is used in three settings: randomly initialized, initialized with ImageNet-pretrained weights with only the classification head trained, and initialized with the same pretrained weights with the whole model trained. Each of the three tasks is trained both in the standard way and with distillation from the teacher model mentioned above.

Distillation uses logits precomputed in the precompute_logits notebook.

In [1]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision optuna

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.


## Library imports and method definitions

In [2]:
from transformers import Trainer, TrainingArguments, MobileNetV2Config, MobileNetV2ForImageClassification, EarlyStoppingCallback
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
from PIL import Image
import torch.nn as nn
import numpy as np
import evaluate
import random
import pickle
import optuna
import torch
import os 

Resets the random seed so that results are reproducible.
Some of these settings can probably be dropped.

TODO: Remove the unnecessary settings.

In [3]:
def reset_seed(seed=42):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
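
The effect of reseeding can be illustrated with a minimal sketch. It uses only Python's built-in `random` module to stay dependency-free (an assumption of this example; the notebook's `reset_seed` additionally seeds NumPy and PyTorch), and the helper name `draw_after_seed` is hypothetical. Reseeding with the same value replays the same pseudo-random sequence.

```python
import random


def draw_after_seed(seed, count=3):
    """Reseed the generator, then draw `count` pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(count)]


first_run = draw_after_seed(42)
second_run = draw_after_seed(42)
other_seed = draw_after_seed(43)

print(first_run == second_run)  # same seed, same sequence
print(first_run == other_seed)  # different seed, different sequence
```

This is why `reset_seed(42)` is called inside `get_random_init_mobilenet`: every trial should start from the same initial weights.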

A new wrapper that works directly with the files of the downloaded and modified CIFAR10 dataset.
Loading the dataset with the method used previously is not possible because the checksum differs.

The teacher logits are now also read directly from the dataset.

In [4]:
class CustomCIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        self.data = []
        self.targets = []
        self.logits = []
        
        if self.train:
            # The training set is split across five batch files.
            for i in range(1, 6):
                data_file = os.path.join(self.root, 'cifar-10-batches-py', f'data_batch_{i}')
                with open(data_file, 'rb') as fo:
                    batch = pickle.load(fo, encoding='bytes')
                    self.data.append(batch[b'data'])
                    self.targets.extend(batch[b'labels'])
                    self.logits.extend(batch[b'logits'])
        else:
            data_file = os.path.join(self.root, 'cifar-10-batches-py', 'test_batch')
            with open(data_file, 'rb') as fo:
                batch = pickle.load(fo, encoding='bytes')
                self.data.append(batch[b'data'])
                self.targets.extend(batch[b'labels'])
                self.logits.extend(batch[b'logits'])

        self.data = np.concatenate(self.data, axis=0)
        self.targets = np.array(self.targets)
        self.logits = np.array(self.logits)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        logit = self.logits[index]
        
        image = Image.fromarray(image.astype('uint8'), 'RGB')
        logit = torch.tensor(logit, dtype=torch.float)
        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)
            
        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }
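
The file layout the wrapper expects can be sketched as follows. The contents of the miniature batch are made up for illustration, plain lists stand in for the NumPy arrays of the real files, and the helper `pixel` is hypothetical. How the precompute_logits notebook actually writes the `b'logits'` key is not shown here. The sketch also spells out the index arithmetic behind `reshape(3, 32, 32).transpose(1, 2, 0)`: each row stores all red values first, then green, then blue.

```python
import os
import pickle
import tempfile

# Hypothetical miniature batch: two "images" of 3072 values each, their labels,
# and 10 teacher logits per image, all stored under byte-string keys.
batch = {
    b"data": [list(range(3072)), [0] * 3072],
    b"labels": [3, 7],
    b"logits": [[0.1 * k for k in range(10)], [0.0] * 10],
}

# Write the batch to a temporary file and read it back the way the wrapper does.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "data_batch_1")
    with open(path, "wb") as handle:
        pickle.dump(batch, handle)
    with open(path, "rb") as handle:
        loaded = pickle.load(handle, encoding="bytes")


def pixel(flat_row, row, col, channel, size=32):
    """Value at (row, col, channel) of a channel-first flattened image."""
    return flat_row[channel * size * size + row * size + col]


first_image = loaded[b"data"][0]
print(sorted(loaded.keys()))
print(pixel(first_image, 0, 0, 0), pixel(first_image, 0, 0, 1), pixel(first_image, 0, 1, 0))
```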


Definition of the evaluation metrics (accuracy, precision, recall and F1) used while training the model.

In [5]:
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Take the arg-max class once and reuse it for all four metrics.
    predicted_classes = np.argmax(predictions, axis=1)

    accuracy = accuracy_metric.compute(predictions=predicted_classes, references=labels)
    precision = precision_metric.compute(predictions=predicted_classes, references=labels, average='macro')
    recall = recall_metric.compute(predictions=predicted_classes, references=labels, average='macro')
    f1 = f1_metric.compute(predictions=predicted_classes, references=labels, average='macro')

    return {
        "accuracy": accuracy["accuracy"],
        "precision": precision["precision"],
        "recall": recall["recall"],
        "f1": f1["f1"]
    }
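
What `average='macro'` computes can be sketched in plain Python: each metric is calculated per class and the per-class values are averaged with equal weight. The function name `macro_scores` and the toy labels are hypothetical, and undefined ratios are set to 0.0 as a simplifying assumption of this sketch. On a class-balanced evaluation set such as the CIFAR10 test split, macro recall equals accuracy, which is consistent with the identical Accuracy and Recall columns in the training logs.

```python
def macro_scores(predicted, reference, num_classes):
    """Return (accuracy, macro precision, macro recall, macro F1)."""
    precisions, recalls, f1_values = [], [], []
    pairs = list(zip(predicted, reference))
    for cls in range(num_classes):
        tp = sum(1 for p, r in pairs if p == cls and r == cls)
        fp = sum(1 for p, r in pairs if p == cls and r != cls)
        fn = sum(1 for p, r in pairs if p != cls and r == cls)
        # Undefined ratios are set to 0.0 (assumption of this sketch).
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1_values.append(f1)
    accuracy = sum(1 for p, r in pairs if p == r) / len(pairs)
    return (
        accuracy,
        sum(precisions) / num_classes,
        sum(recalls) / num_classes,
        sum(f1_values) / num_classes,
    )


# Balanced toy data: two samples per class.
reference = [0, 0, 1, 1, 2, 2]
predicted = [0, 1, 1, 1, 2, 0]
accuracy, macro_precision, macro_recall, macro_f1 = macro_scores(predicted, reference, 3)
print(accuracy, macro_precision, macro_recall, macro_f1)
```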

Training arguments for the trainer.

In [6]:
class Custom_training_args(TrainingArguments):
    def __init__(self, lambda_param, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)    
        self.lambda_param = lambda_param
        self.temperature = temperature

In [7]:
def get_training_args(output_dir:str, logging_dir:str, remove_unused_columns:bool):
    return (
        Custom_training_args(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,  # Default value
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=20,
        weight_decay=0.01,
        seed=42,  # Default value
        metric_for_best_model="f1",
        load_best_model_at_end=True,
        fp16=False, 
        logging_dir=logging_dir,
        remove_unused_columns=remove_unused_columns,
        lambda_param = 0.5, 
        temperature = 5
    ))

Randomly initialized MobileNetV2.

In [8]:
def get_random_init_mobilenet():
    reset_seed(42)
    student_config = MobileNetV2Config()
    student_config.num_labels = 10
    return MobileNetV2ForImageClassification(student_config)

In [9]:
reset_seed(42)

In [10]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


Applying the transformations to the dataset.

In [11]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

test = CustomCIFAR10(root='../data/10-logits', train=False, transform=transform)
train = CustomCIFAR10(root='../data/10-logits', train=True, transform=transform)
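
The arithmetic of the normalization step can be sketched in plain Python. `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then maps that range to [-1, 1]. The helper `normalize` below is a hypothetical scalar stand-in for the per-channel tensor operation.

```python
def normalize(value, mean=0.5, std=0.5):
    """Scalar version of the per-channel normalization: (x - mean) / std."""
    return (value - mean) / std


for raw_byte in (0, 128, 255):
    scaled = raw_byte / 255  # what ToTensor does to an 8-bit pixel value
    print(raw_byte, round(normalize(scaled), 4))
```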

### Standard training of the randomly initialized model

In [12]:
training_args = get_training_args("./results/cifar10-random", './logs/cifar10-random', True)

In [13]:
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
        "weight_decay": trial.suggest_float("weight_decay", 0, 1e-2, step=1e-3)
    }
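
The meaning of `log=True` for the learning rate can be illustrated with a minimal sketch: values are drawn uniformly in the logarithm of the range, so every order of magnitude between 1e-6 and 5e-4 is covered about equally. The sketch shows only plain log-uniform sampling with the standard library, not the TPE sampler that Optuna uses in this notebook, and the helper `sample_log_uniform` is hypothetical.

```python
import math
import random


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(42)
low, high = 1e-6, 5e-4
samples = [sample_log_uniform(rng, low, high) for _ in range(10_000)]

# Share of samples in the lowest decade [1e-6, 1e-5).
share_lowest_decade = sum(1 for s in samples if s < 1e-5) / len(samples)
expected_share = math.log(1e-5 / low) / math.log(high / low)
print(round(share_lowest_decade, 3), round(expected_share, 3))
```

With linear-scale sampling the same decade would receive under 2 % of the draws, so small learning rates would hardly be explored.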

In [14]:
pruner = optuna.pruners.HyperbandPruner(min_resource=5, reduction_factor=4)
sampler = optuna.samplers.TPESampler(seed=42, multivariate=True)



In [15]:
trainer = Trainer(
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    model_init=get_random_init_mobilenet,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)
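
The patience rule behind `EarlyStoppingCallback(early_stopping_patience=3)` can be sketched as follows: training stops once the monitored metric (here F1, via `metric_for_best_model`) has failed to improve for three consecutive evaluations. This is a simplified illustration, not the library's implementation, and the function `stopping_epoch` is hypothetical. The F1 values are those logged for trial 58 of the search in this notebook, which ended after epoch 9.

```python
def stopping_epoch(scores, patience):
    """Return the 1-based epoch at which training stops, or None if it never does."""
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Per-epoch F1 values copied from the trial 58 log.
trial_58_f1 = [0.5615, 0.700866, 0.753986, 0.787486, 0.773413,
               0.798314, 0.755129, 0.716713, 0.789817]
print(stopping_epoch(trial_58_f1, patience=3))
```

The best F1 is reached in epoch 6, and epochs 7 to 9 all stay below it, so the third miss triggers the stop.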

In [16]:
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    hp_space=hp_space,
    compute_objective=lambda metrics: metrics["eval_f1"],
    pruner=pruner,
    sampler=sampler,
    n_trials=80
)

[I 2025-01-01 23:59:01,142] A new study created in memory with name: no-name-7c8c698d-acf7-4823-8e1e-c939c40ae55e


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8701,1.732741,0.3678,0.372398,0.3678,0.344483
2,1.6867,1.530936,0.4364,0.462839,0.4364,0.436615
3,1.5182,1.405519,0.4907,0.496632,0.4907,0.48304
4,1.3986,1.321123,0.5247,0.526839,0.5247,0.514257
5,1.309,1.367839,0.5183,0.582725,0.5183,0.525557
6,1.2555,1.170599,0.5875,0.578888,0.5875,0.576408
7,1.1877,1.262417,0.5646,0.591695,0.5646,0.553731
8,1.1264,1.838385,0.4837,0.560011,0.4837,0.464588
9,1.0557,1.09533,0.6149,0.646942,0.6149,0.607511
10,1.0028,0.970768,0.6575,0.669641,0.6575,0.659654


[I 2025-01-02 00:47:58,443] Trial 0 finished with value: 0.6166393892685925 and parameters: {'learning_rate': 1.0253509690168497e-05, 'per_device_train_batch_size': 16, 'weight_decay': 0.001}. Best is trial 0 with value: 0.6166393892685925.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3683,2.179152,0.1999,0.191525,0.1999,0.160997
2,2.1017,1.920544,0.2721,0.279815,0.2721,0.243458
3,1.9801,1.825716,0.304,0.306483,0.304,0.292905
4,1.9156,1.773887,0.3318,0.329674,0.3318,0.317815
5,1.8717,1.73213,0.345,0.362109,0.345,0.336211
6,1.8224,1.681978,0.3668,0.372624,0.3668,0.357031
7,1.7831,1.72638,0.3518,0.357473,0.3518,0.341093
8,1.7404,2.282531,0.2768,0.327437,0.2768,0.247815
9,1.7207,1.607467,0.4023,0.4183,0.4023,0.383987
10,1.6963,1.560524,0.4174,0.425994,0.4174,0.413825


[I 2025-01-02 01:49:21,289] Trial 1 finished with value: 0.44523251373587086 and parameters: {'learning_rate': 2.6364803038431666e-06, 'per_device_train_batch_size': 32, 'weight_decay': 0.007}. Best is trial 0 with value: 0.6166393892685925.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3953,2.261401,0.1661,0.170477,0.1661,0.117911


[I 2025-01-02 01:52:38,380] Trial 2 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3339,2.119989,0.2224,0.217168,0.2224,0.188756
2,2.0354,1.868605,0.2929,0.296512,0.2929,0.2768
3,1.9303,1.785219,0.323,0.329949,0.323,0.309395
4,1.8687,1.71348,0.3528,0.358296,0.3528,0.341369
5,1.8172,1.713547,0.3606,0.383036,0.3606,0.35216
6,1.7661,1.621587,0.3868,0.390573,0.3868,0.374108
7,1.7231,1.703216,0.378,0.392654,0.378,0.367826
8,1.6809,2.144433,0.3023,0.350604,0.3023,0.270876
9,1.6529,1.565638,0.4213,0.442707,0.4213,0.401052
10,1.629,1.499206,0.4443,0.452705,0.4443,0.441679


[I 2025-01-02 02:32:31,646] Trial 3 finished with value: 0.4338594876433339 and parameters: {'learning_rate': 3.1261029103110603e-06, 'per_device_train_batch_size': 32, 'weight_decay': 0.003}. Best is trial 0 with value: 0.6166393892685925.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0084,1.504491,0.4468,0.468078,0.4468,0.437246
2,1.3802,1.259138,0.5523,0.584425,0.5523,0.555834
3,1.1892,1.074039,0.6204,0.634809,0.6204,0.617809
4,0.9473,1.004617,0.6637,0.671333,0.6637,0.656792
5,0.8336,1.035498,0.6475,0.707577,0.6475,0.658204
6,0.6949,0.976839,0.6889,0.707067,0.6889,0.678144
7,0.6378,1.067553,0.6739,0.705155,0.6739,0.665933
8,0.5076,1.38573,0.6307,0.676386,0.6307,0.619246
9,0.4312,0.978333,0.7078,0.747607,0.7078,0.704496
10,0.3603,0.855144,0.7358,0.752542,0.7358,0.738398


[I 2025-01-02 03:12:18,564] Trial 4 finished with value: 0.7051612455436723 and parameters: {'learning_rate': 4.480975918214949e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.005}. Best is trial 4 with value: 0.7051612455436723.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7674,1.187583,0.5744,0.587323,0.5744,0.566282
2,0.9553,0.897048,0.6902,0.724477,0.6902,0.693114
3,0.7559,0.683947,0.762,0.774543,0.762,0.761134
4,0.5551,0.645303,0.7868,0.787565,0.7868,0.782396
5,0.4628,0.728655,0.7662,0.803764,0.7662,0.773938
6,0.3383,0.675208,0.7933,0.815355,0.7933,0.787776
7,0.291,0.853146,0.7714,0.796303,0.7714,0.76721
8,0.1865,1.035169,0.7502,0.770742,0.7502,0.73652
9,0.1421,0.839619,0.7909,0.822381,0.7909,0.78979
10,0.1001,0.679991,0.8195,0.825433,0.8195,0.820145


[I 2025-01-02 03:56:17,622] Trial 5 finished with value: 0.7865497806697535 and parameters: {'learning_rate': 0.00013157287601765647, 'per_device_train_batch_size': 64, 'weight_decay': 0.0}. Best is trial 5 with value: 0.7865497806697535.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.007,1.587795,0.4223,0.451711,0.4223,0.399987


[I 2025-01-02 03:59:13,300] Trial 6 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7358,1.168352,0.5851,0.617776,0.5851,0.57725
2,0.9346,0.935845,0.6781,0.730667,0.6781,0.681333


[I 2025-01-02 04:05:20,539] Trial 7 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.451,2.24939,0.1723,0.137835,0.1723,0.121726


[I 2025-01-02 04:08:23,907] Trial 8 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9334,1.562255,0.4449,0.456515,0.4449,0.429148
2,1.2697,1.142043,0.5936,0.625074,0.5936,0.596526
3,1.0745,0.925571,0.6777,0.677464,0.6777,0.671984
4,0.815,0.88632,0.71,0.712231,0.71,0.70326
5,0.7021,0.873201,0.7006,0.748855,0.7006,0.708819
6,0.539,0.80343,0.7369,0.747884,0.7369,0.72968
7,0.4771,1.062685,0.6948,0.734492,0.6948,0.686168
8,0.3423,1.435754,0.6449,0.689787,0.6449,0.621531
9,0.2705,0.893126,0.743,0.772763,0.743,0.739122
10,0.2064,0.838018,0.7643,0.787154,0.7643,0.768363


[I 2025-01-02 04:48:09,440] Trial 9 finished with value: 0.7571503808239928 and parameters: {'learning_rate': 6.139426050898147e-05, 'per_device_train_batch_size': 64, 'weight_decay': 0.002}. Best is trial 5 with value: 0.7865497806697535.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4137,1.24854,0.547,0.568961,0.547,0.535213


[I 2025-01-02 04:51:26,472] Trial 10 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.971,1.530671,0.4494,0.471789,0.4494,0.428965


[I 2025-01-02 04:54:30,037] Trial 11 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1915,1.070803,0.6164,0.626575,0.6164,0.608866
2,0.863,0.776705,0.7272,0.743918,0.7272,0.728822
3,0.6737,0.687962,0.7652,0.775329,0.7652,0.762128
4,0.5427,0.619863,0.7949,0.802723,0.7949,0.791474
5,0.4656,0.750789,0.7639,0.812754,0.7639,0.771767
6,0.3804,0.618239,0.8114,0.825079,0.8114,0.809961
7,0.3034,0.670298,0.81,0.822979,0.81,0.808273
8,0.2388,1.072393,0.7456,0.76945,0.7456,0.730947
9,0.18,0.740629,0.8134,0.830662,0.8134,0.813216
10,0.1419,0.655372,0.8353,0.843014,0.8353,0.836418


[I 2025-01-02 05:40:43,970] Trial 12 finished with value: 0.8184546045653107 and parameters: {'learning_rate': 0.000135907004719098, 'per_device_train_batch_size': 32, 'weight_decay': 0.002}. Best is trial 12 with value: 0.8184546045653107.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2868,1.157032,0.5831,0.591586,0.5831,0.570268


[I 2025-01-02 05:43:49,606] Trial 13 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1577,1.048958,0.6313,0.635528,0.6313,0.621107


[I 2025-01-02 05:46:53,991] Trial 14 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6075,1.109848,0.6132,0.619436,0.6132,0.601243


[I 2025-01-02 05:49:57,496] Trial 15 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6605,1.546782,0.4341,0.45656,0.4341,0.412766


[I 2025-01-02 05:53:02,528] Trial 16 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.4472,2.242061,0.1807,0.143483,0.1807,0.128761


[I 2025-01-02 05:56:04,832] Trial 17 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5362,1.409183,0.4828,0.50059,0.4828,0.466408


[I 2025-01-02 05:59:09,114] Trial 18 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.213,1.05221,0.6191,0.629933,0.6191,0.608866


[I 2025-01-02 06:02:25,142] Trial 19 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2066,1.063382,0.6198,0.637609,0.6198,0.614094


[I 2025-01-02 06:05:42,234] Trial 20 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9424,1.514701,0.4541,0.475718,0.4541,0.433605


[I 2025-01-02 06:08:46,622] Trial 21 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4616,1.388304,0.4994,0.517712,0.4994,0.476002


[I 2025-01-02 06:11:50,027] Trial 22 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8338,1.752831,0.3644,0.37482,0.3644,0.338871


[I 2025-01-02 06:15:06,442] Trial 23 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5203,1.530983,0.4515,0.497126,0.4515,0.431207


[I 2025-01-02 06:18:23,354] Trial 24 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3239,1.236328,0.5624,0.58238,0.5624,0.544056


[I 2025-01-02 06:21:27,399] Trial 25 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.2969,1.852785,0.3093,0.315926,0.3093,0.284555


[I 2025-01-02 06:24:29,835] Trial 26 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5708,1.500128,0.4636,0.486704,0.4636,0.444188


[I 2025-01-02 06:27:33,292] Trial 27 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6331,1.057979,0.6245,0.642431,0.6245,0.616596


[I 2025-01-02 06:30:35,482] Trial 28 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8519,1.302542,0.5225,0.5273,0.5225,0.506887


[I 2025-01-02 06:33:37,522] Trial 29 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9193,1.799984,0.3457,0.345477,0.3457,0.322155


[I 2025-01-02 06:36:54,066] Trial 30 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8654,1.397234,0.4892,0.517728,0.4892,0.476481


[I 2025-01-02 06:39:56,441] Trial 31 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3164,1.858459,0.3083,0.312438,0.3083,0.286179


[I 2025-01-02 06:42:58,831] Trial 32 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3965,1.291039,0.5311,0.553568,0.5311,0.516085


[I 2025-01-02 06:46:16,712] Trial 33 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0571,1.595298,0.4102,0.435216,0.4102,0.395932


[I 2025-01-02 06:49:20,341] Trial 34 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0579,1.595669,0.4168,0.426818,0.4168,0.399479


[I 2025-01-02 06:52:22,171] Trial 35 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1858,1.082619,0.6106,0.618222,0.6106,0.604473


[I 2025-01-02 06:55:26,192] Trial 36 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6977,1.131158,0.5911,0.620521,0.5911,0.585297


[I 2025-01-02 06:58:31,145] Trial 37 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9225,1.490484,0.4483,0.485375,0.4483,0.436142


[I 2025-01-02 07:01:34,324] Trial 38 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1764,1.956935,0.2728,0.266753,0.2728,0.246739


[I 2025-01-02 07:04:40,958] Trial 39 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0596,1.8584,0.2999,0.294364,0.2999,0.273723


[I 2025-01-02 07:07:47,247] Trial 40 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9756,1.837129,0.3225,0.320707,0.3225,0.300023


[I 2025-01-02 07:11:02,995] Trial 41 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7965,1.727597,0.3727,0.393913,0.3727,0.345369


[I 2025-01-02 07:14:18,762] Trial 42 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.2003,2.015417,0.2576,0.248484,0.2576,0.230786


[I 2025-01-02 07:17:37,480] Trial 43 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1591,1.063461,0.6206,0.642919,0.6206,0.610323
2,0.8507,0.867755,0.6984,0.743373,0.6984,0.700888
3,0.6655,0.659845,0.7746,0.784872,0.7746,0.77392
4,0.5406,0.588791,0.8051,0.808632,0.8051,0.802679
5,0.4658,0.657078,0.7846,0.818442,0.7846,0.788793
6,0.3837,0.588223,0.8124,0.833557,0.8124,0.811818
7,0.3105,0.669038,0.7994,0.816719,0.7994,0.798017
8,0.2449,1.02375,0.764,0.784828,0.764,0.749109
9,0.197,0.643717,0.8264,0.83912,0.8264,0.82653
10,0.1515,0.644903,0.8376,0.845637,0.8376,0.836981


[I 2025-01-02 08:19:46,996] Trial 44 finished with value: 0.8657942234681372 and parameters: {'learning_rate': 0.0001691374604609239, 'per_device_train_batch_size': 32, 'weight_decay': 0.002}. Best is trial 44 with value: 0.8657942234681372.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1732,1.029646,0.6304,0.647527,0.6304,0.629093


[I 2025-01-02 08:22:53,583] Trial 45 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1629,1.015507,0.6372,0.651568,0.6372,0.633206


[I 2025-01-02 08:25:58,063] Trial 46 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2286,1.071035,0.6144,0.635883,0.6144,0.60701


[I 2025-01-02 08:29:02,011] Trial 47 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8234,1.325684,0.5139,0.548655,0.5139,0.508766


[I 2025-01-02 08:32:04,108] Trial 48 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4913,1.36059,0.505,0.529291,0.505,0.490918


[I 2025-01-02 08:35:10,181] Trial 49 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6293,1.06381,0.6231,0.627216,0.6231,0.61875


[I 2025-01-02 08:38:06,034] Trial 50 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1365,1.693558,0.3729,0.386415,0.3729,0.352934


[I 2025-01-02 08:41:08,479] Trial 51 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3852,2.249013,0.1748,0.193869,0.1748,0.128068


[I 2025-01-02 08:44:24,733] Trial 52 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1496,0.97122,0.6537,0.665939,0.6537,0.647727


[I 2025-01-02 08:47:42,806] Trial 53 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2754,1.163602,0.5886,0.601674,0.5886,0.5725


[I 2025-01-02 08:50:48,666] Trial 54 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8908,1.778811,0.3482,0.349138,0.3482,0.324418


[I 2025-01-02 08:54:05,329] Trial 55 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2845,1.1584,0.5862,0.611901,0.5862,0.578416


[I 2025-01-02 08:57:10,666] Trial 56 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.435,2.272146,0.1477,0.110085,0.1477,0.096591


[I 2025-01-02 09:00:16,171] Trial 57 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7562,1.186995,0.5713,0.588868,0.5713,0.5615
2,0.9466,0.873025,0.6976,0.723813,0.6976,0.700866
3,0.744,0.703449,0.7542,0.772017,0.7542,0.753986
4,0.5461,0.641315,0.7912,0.794859,0.7912,0.787486
5,0.4526,0.721466,0.7663,0.805936,0.7663,0.773413
6,0.3281,0.627953,0.8009,0.819969,0.8009,0.798314
7,0.2759,0.903049,0.756,0.796154,0.756,0.755129
8,0.1749,1.175558,0.7288,0.757436,0.7288,0.716713
9,0.1345,0.849059,0.79,0.816473,0.79,0.789817


[I 2025-01-02 09:27:36,252] Trial 58 finished with value: 0.7898174905543811 and parameters: {'learning_rate': 0.00013938335495096915, 'per_device_train_batch_size': 64, 'weight_decay': 0.003}. Best is trial 44 with value: 0.8657942234681372.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7949,1.307226,0.5281,0.558497,0.5281,0.514258


[I 2025-01-02 09:30:31,061] Trial 59 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7682,1.201812,0.5713,0.594202,0.5713,0.560891


[I 2025-01-02 09:33:33,309] Trial 60 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6282,1.08137,0.6143,0.628155,0.6143,0.606234


[I 2025-01-02 09:36:35,829] Trial 61 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5992,1.057122,0.6265,0.63465,0.6265,0.617968


[I 2025-01-02 09:39:37,900] Trial 62 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9219,1.466374,0.465,0.492591,0.465,0.447282


[I 2025-01-02 09:42:32,688] Trial 63 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1517,1.025239,0.6444,0.656923,0.6444,0.643589


[I 2025-01-02 09:45:37,800] Trial 64 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8503,1.716463,0.3784,0.381629,0.3784,0.353788


[I 2025-01-02 09:48:43,123] Trial 65 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8953,1.697678,0.3633,0.356916,0.3633,0.346378


[I 2025-01-02 09:51:59,834] Trial 66 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1882,1.067581,0.6201,0.634565,0.6201,0.613788


[I 2025-01-02 09:55:06,496] Trial 67 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8913,1.376721,0.5003,0.5054,0.5003,0.482058


[I 2025-01-02 09:58:08,805] Trial 68 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9722,1.504391,0.4422,0.459842,0.4422,0.422996


[I 2025-01-02 10:01:11,313] Trial 69 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7365,1.190998,0.5789,0.600939,0.5789,0.570389


[I 2025-01-02 10:04:13,516] Trial 70 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.4098,2.241018,0.1764,0.158579,0.1764,0.126092


[I 2025-01-02 10:07:17,857] Trial 71 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1917,1.968132,0.2675,0.258725,0.2675,0.243569


[I 2025-01-02 10:10:22,321] Trial 72 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1651,1.730219,0.3704,0.374917,0.3704,0.348944


[I 2025-01-02 10:13:24,602] Trial 73 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.4612,2.278863,0.134,0.111617,0.134,0.080842


[I 2025-01-02 10:16:26,450] Trial 74 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3845,2.205562,0.1901,0.205238,0.1901,0.147229


[I 2025-01-02 10:19:30,701] Trial 75 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7131,1.561944,0.4227,0.432257,0.4227,0.400689


[I 2025-01-02 10:22:46,830] Trial 76 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7971,1.191304,0.578,0.590367,0.578,0.571984


[I 2025-01-02 10:25:49,006] Trial 77 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.4462,2.23876,0.1775,0.15646,0.1775,0.125903


[I 2025-01-02 10:28:52,219] Trial 78 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3982,1.267022,0.5393,0.552113,0.5393,0.525163


[I 2025-01-02 10:31:58,127] Trial 79 pruned. 


## Definition of distillation training

A class that adapts the Hugging Face Trainer for knowledge distillation. It now works with the logits stored in the dataset.

In [17]:
class ImageDistilTrainer(Trainer):
    def __init__(self, model_init, *args, **kwargs):
        self.model_init = model_init
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
       
        super().__init__(model_init=model_init, *args, **kwargs)
        
        self.student = self.model_init()
        self.temperature = self.args.temperature
        self.lambda_param = self.args.lambda_param

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.student.to(device)


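    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer transformers releases pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):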
        logits = inputs.pop("logits")

        student_output = model(**inputs)
        self.lambda_param = self.args.lambda_param
        self.temperature = self.args.temperature
        
        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        student_target_loss = student_output.loss

        loss = ((1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss)
        return (loss, student_output) if return_outputs else loss
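
The loss computed above can be written as L = (1 - λ) · CE(student, label) + λ · T² · KL(softmax(z_teacher / T) ‖ softmax(z_student / T)). The following minimal sketch evaluates it for a single example in plain Python. The three-class logits are made up for illustration, and the helper names are hypothetical. For one example the `batchmean` reduction divides by a batch size of 1, so it is omitted.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, shifted by the maximum for numerical stability."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def distillation_loss(student_logits, teacher_logits, label, lambda_param, temperature):
    """Weighted sum of hard-label cross-entropy and temperature-scaled KL divergence."""
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(soft_teacher, soft_student))
    soft_term = kl * temperature ** 2
    hard_term = -math.log(softmax(student_logits)[label])
    return (1.0 - lambda_param) * hard_term + lambda_param * soft_term


teacher = [4.0, 1.0, 0.5]  # made-up teacher logits
student = [2.0, 1.5, 0.2]  # made-up student logits
print(distillation_loss(student, teacher, label=0, lambda_param=0.5, temperature=5.0))
```

With `lambda_param=0` the loss reduces to ordinary cross-entropy, and with `lambda_param=1` the student is trained on the teacher's softened distribution alone.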

### Training the randomly initialized model with knowledge distillation

In [18]:
reset_seed(42)

In [19]:
training_args = get_training_args("./results/cifar10-random-KD", './logs/cifar10-random-KD', False)

In [20]:
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 5e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
        "weight_decay": trial.suggest_float("weight_decay", 0, 1e-2, step=1e-3),
        "lambda_param": trial.suggest_float("lambda_param",0,1,step=.1),
        "temperature": trial.suggest_float("temperature", 2,7, step=.5)
    }

In [21]:
trainer = ImageDistilTrainer(
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    compute_metrics=compute_metrics,
    model_init=get_random_init_mobilenet,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)

In [22]:
pruner = optuna.pruners.HyperbandPruner(min_resource=1, reduction_factor=4)
sampler = optuna.samplers.TPESampler(seed=42, multivariate=True)



In [23]:
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    hp_space=hp_space,
    compute_objective=lambda metrics: metrics["eval_f1"],
    pruner=pruner,
    sampler=sampler,
    n_trials=80
)

[I 2025-01-02 10:31:58,319] A new study created in memory with name: no-name-c949043e-bbe5-4784-8e94-3711ffa24f5e


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8261,1.794448,0.3562,0.363023,0.3562,0.332477
2,1.6639,1.638281,0.429,0.451496,0.429,0.430032
3,1.506,1.536666,0.4901,0.496502,0.4901,0.486227
4,1.3961,1.514397,0.5165,0.521185,0.5165,0.506952
5,1.3117,1.508902,0.5278,0.575985,0.5278,0.532914
6,1.2622,1.367924,0.5864,0.578568,0.5864,0.576121
7,1.1986,1.423449,0.5768,0.601245,0.5768,0.572559
8,1.1257,1.765431,0.5099,0.563286,0.5099,0.494695
9,1.0585,1.310685,0.6264,0.646086,0.6264,0.620627
10,1.0291,1.226045,0.662,0.67427,0.662,0.660745


[I 2025-01-02 11:14:36,692] Trial 0 finished with value: 0.6493830756403727 and parameters: {'learning_rate': 1.0253509690168497e-05, 'per_device_train_batch_size': 16, 'weight_decay': 0.001, 'lambda_param': 0.1, 'temperature': 2.0}. Best is trial 0 with value: 0.6493830756403727.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6508,1.645658,0.6416,0.651808,0.6416,0.630255
2,0.4678,1.807167,0.7252,0.751262,0.7252,0.725268
3,0.372,1.774964,0.796,0.79762,0.796,0.793375
4,0.3116,1.782879,0.8207,0.821734,0.8207,0.81995
5,0.2664,1.785201,0.8013,0.82813,0.8013,0.804786
6,0.2327,1.71434,0.8237,0.833181,0.8237,0.820774
7,0.2029,1.73078,0.816,0.830533,0.816,0.8146
8,0.1778,1.671794,0.7674,0.791711,0.7674,0.761982
9,0.1544,1.754976,0.85,0.855467,0.85,0.849485
10,0.136,1.758484,0.8534,0.860455,0.8534,0.8544


[I 2025-01-02 12:06:48,211] Trial 1 finished with value: 0.8775070142695549 and parameters: {'learning_rate': 0.00021766241123453658, 'per_device_train_batch_size': 32, 'weight_decay': 0.01, 'lambda_param': 0.9, 'temperature': 3.0}. Best is trial 1 with value: 0.8775070142695549.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0266,1.83874,0.1926,0.190376,0.1926,0.143647


[I 2025-01-02 12:09:50,409] Trial 2 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1866,2.000362,0.1742,0.134926,0.1742,0.122639
2,2.0727,1.905863,0.2219,0.216746,0.2219,0.173206
3,1.9723,1.817852,0.272,0.263115,0.272,0.25065
4,1.8452,1.777897,0.2949,0.285927,0.2949,0.278934
5,1.8184,1.783804,0.3045,0.331433,0.3045,0.292158
6,1.773,1.753153,0.3229,0.330373,0.3229,0.307589
7,1.7537,1.781673,0.3105,0.309499,0.3105,0.294976
8,1.7208,2.030673,0.261,0.288998,0.261,0.224095
9,1.6927,1.684682,0.3663,0.370945,0.3663,0.352001
10,1.6895,1.673737,0.3747,0.377942,0.3747,0.369316


[I 2025-01-02 13:10:45,151] Trial 3 finished with value: 0.39022170376656556 and parameters: {'learning_rate': 2.379522116387725e-06, 'per_device_train_batch_size': 64, 'weight_decay': 0.008, 'lambda_param': 0.2, 'temperature': 4.5}. Best is trial 1 with value: 0.8775070142695549.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6648,1.092367,0.4227,0.459306,0.4227,0.390251


[I 2025-01-02 13:13:50,371] Trial 4 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6859,1.289853,0.5999,0.617377,0.5999,0.590284


[I 2025-01-02 13:16:53,827] Trial 5 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7116,1.602961,0.1753,0.142933,0.1753,0.130146


[I 2025-01-02 13:20:11,131] Trial 6 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4267,1.416727,0.2528,0.251181,0.2528,0.229749


[I 2025-01-02 13:23:28,723] Trial 7 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.1714,1.999753,0.1432,0.110177,0.1432,0.092197


[I 2025-01-02 13:26:31,907] Trial 8 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9743,1.83692,0.3201,0.3136,0.3201,0.298563


[I 2025-01-02 13:29:36,746] Trial 9 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6415,1.609443,0.6555,0.670329,0.6555,0.649023
2,0.4677,1.665924,0.7134,0.749679,0.7134,0.71601
3,0.3789,1.665458,0.7837,0.793046,0.7837,0.782119


[I 2025-01-02 13:38:45,576] Trial 10 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5397,1.72343,0.4542,0.473658,0.4542,0.431718


[I 2025-01-02 13:42:01,117] Trial 11 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.0416,1.937631,0.2213,0.216447,0.2213,0.1908


[I 2025-01-02 13:45:16,357] Trial 12 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.8073,1.686657,0.6353,0.649883,0.6353,0.627057
2,0.5961,1.635921,0.7302,0.764407,0.7302,0.733458
3,0.4773,1.564477,0.7897,0.796995,0.7897,0.786955
4,0.4012,1.509111,0.8153,0.816522,0.8153,0.814655
5,0.3592,1.624529,0.806,0.831573,0.806,0.810094
6,0.3157,1.57329,0.8287,0.841215,0.8287,0.828446
7,0.2735,1.636031,0.829,0.839301,0.829,0.827579
8,0.2444,1.514421,0.8112,0.82412,0.8112,0.80896
9,0.2129,1.593595,0.8513,0.856352,0.8513,0.850958
10,0.1863,1.598731,0.8642,0.867108,0.8642,0.864882


[I 2025-01-02 14:51:05,124] Trial 13 finished with value: 0.8775142231826768 and parameters: {'learning_rate': 0.00010562422487867538, 'per_device_train_batch_size': 16, 'weight_decay': 0.006, 'lambda_param': 0.7000000000000001, 'temperature': 3.0}. Best is trial 13 with value: 0.8775142231826768.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.8636,1.897928,0.5118,0.516317,0.5118,0.485337
2,0.641,2.023914,0.6657,0.698177,0.6657,0.66941
3,0.4992,1.995993,0.7334,0.740914,0.7334,0.727995
4,0.4063,2.107063,0.781,0.782289,0.781,0.777834
5,0.3424,2.191655,0.7478,0.795487,0.7478,0.754354
6,0.2992,2.152781,0.8003,0.812695,0.8003,0.797772
7,0.2569,2.111333,0.7903,0.80694,0.7903,0.788578
8,0.2212,2.132159,0.7699,0.788331,0.7699,0.764204
9,0.1868,2.192748,0.8197,0.834518,0.8197,0.818955
10,0.1566,2.109633,0.8454,0.847213,0.8454,0.845268


[I 2025-01-02 15:36:57,665] Trial 14 finished with value: 0.8393768114963797 and parameters: {'learning_rate': 5.070319408723441e-05, 'per_device_train_batch_size': 16, 'weight_decay': 0.006, 'lambda_param': 1.0, 'temperature': 2.5}. Best is trial 13 with value: 0.8775142231826768.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9173,1.454623,0.6557,0.663914,0.6557,0.648298


[I 2025-01-02 15:40:02,363] Trial 15 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0642,1.465229,0.6686,0.681861,0.6686,0.668477
2,0.8165,1.360149,0.7421,0.761589,0.7421,0.743705


[I 2025-01-02 15:46:34,346] Trial 16 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1126,1.311957,0.4418,0.479723,0.4418,0.40881


[I 2025-01-02 15:49:36,026] Trial 17 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5256,1.792656,0.4155,0.425195,0.4155,0.39269


[I 2025-01-02 15:52:38,130] Trial 18 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9531,1.179795,0.3425,0.347448,0.3425,0.322543


[I 2025-01-02 15:55:41,693] Trial 19 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9927,1.475629,0.5558,0.570293,0.5558,0.53655


[I 2025-01-02 15:58:57,928] Trial 20 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9215,1.833634,0.5451,0.56258,0.5451,0.526838


[I 2025-01-02 16:02:14,403] Trial 21 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2212,1.650682,0.3403,0.346941,0.3403,0.318851


[I 2025-01-02 16:05:31,073] Trial 22 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.734,2.036732,0.6393,0.663486,0.6393,0.633929


[I 2025-01-02 16:08:47,252] Trial 23 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.594,1.468202,0.6346,0.659192,0.6346,0.629873


[I 2025-01-02 16:11:51,706] Trial 24 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6798,1.906213,0.6414,0.660885,0.6414,0.632942


[I 2025-01-02 16:14:55,823] Trial 25 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.7194,1.694801,0.6473,0.662804,0.6473,0.640999


[I 2025-01-02 16:18:12,209] Trial 26 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9313,1.908035,0.6298,0.645091,0.6298,0.622301


[I 2025-01-02 16:21:16,050] Trial 27 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9379,1.426069,0.563,0.575587,0.563,0.540515


[I 2025-01-02 16:24:32,325] Trial 28 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9227,1.432718,0.6136,0.619687,0.6136,0.595488


[I 2025-01-02 16:27:48,218] Trial 29 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1975,1.556503,0.4274,0.450959,0.4274,0.409366


[I 2025-01-02 16:30:54,160] Trial 30 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3001,1.957914,0.2562,0.25312,0.2562,0.227429


[I 2025-01-02 16:33:57,126] Trial 31 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.2843,2.159734,0.178,0.170701,0.178,0.130561


[I 2025-01-02 16:37:13,903] Trial 32 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9151,1.848692,0.3279,0.325475,0.3279,0.306867


[I 2025-01-02 16:40:18,042] Trial 33 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.7446,1.533749,0.5156,0.537144,0.5156,0.492623


[I 2025-01-02 16:43:35,296] Trial 34 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6718,1.563451,0.4275,0.465244,0.4275,0.411582


[I 2025-01-02 16:46:52,288] Trial 35 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.795,1.661769,0.388,0.401225,0.388,0.364767


[I 2025-01-02 16:50:09,319] Trial 36 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.105,1.519232,0.5907,0.600787,0.5907,0.576703


[I 2025-01-02 16:53:12,069] Trial 37 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7513,1.651152,0.2473,0.243097,0.2473,0.223189


[I 2025-01-02 16:56:15,048] Trial 38 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9926,1.823636,0.5743,0.597146,0.5743,0.563294


[I 2025-01-02 16:59:32,107] Trial 39 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3369,1.677657,0.36,0.373771,0.36,0.341022


[I 2025-01-02 17:02:35,414] Trial 40 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.329,2.152216,0.1382,0.116866,0.1382,0.084491


[I 2025-01-02 17:05:38,226] Trial 41 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9139,1.728208,0.2069,0.199192,0.2069,0.164699


[I 2025-01-02 17:08:40,898] Trial 42 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.2475,1.8881,0.2716,0.270982,0.2716,0.244194


[I 2025-01-02 17:11:43,044] Trial 43 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.4339,1.587345,0.4032,0.426729,0.4032,0.373247


[I 2025-01-02 17:15:00,289] Trial 44 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9294,1.715594,0.6177,0.636102,0.6177,0.604075


[I 2025-01-02 17:18:04,823] Trial 45 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9583,1.835287,0.3054,0.295889,0.3054,0.279486


[I 2025-01-02 17:21:08,155] Trial 46 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.7165,1.402947,0.6234,0.63648,0.6234,0.611129


[I 2025-01-02 17:24:12,097] Trial 47 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.92,1.839371,0.5049,0.525058,0.5049,0.484643


[I 2025-01-02 17:27:28,696] Trial 48 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.6656,1.776729,0.3457,0.360688,0.3457,0.317261


[I 2025-01-02 17:30:45,151] Trial 49 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6235,1.582114,0.6191,0.617891,0.6191,0.597189


[I 2025-01-02 17:33:49,449] Trial 50 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0977,1.372472,0.6512,0.667586,0.6512,0.646812


[I 2025-01-02 17:37:06,060] Trial 51 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7649,1.179881,0.5718,0.583125,0.5718,0.560764


[I 2025-01-02 17:40:09,002] Trial 52 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0512,1.703946,0.5877,0.610459,0.5877,0.572156


[I 2025-01-02 17:43:27,231] Trial 53 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0096,0.895696,0.2205,0.215084,0.2205,0.186628


[I 2025-01-02 17:46:30,657] Trial 54 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1238,1.408782,0.6416,0.65533,0.6416,0.638465


[I 2025-01-02 17:49:47,553] Trial 55 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0247,1.767879,0.4539,0.473591,0.4539,0.430663


[I 2025-01-02 17:52:51,986] Trial 56 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7054,1.701337,0.3864,0.410349,0.3864,0.363824


[I 2025-01-02 17:56:09,362] Trial 57 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9377,1.742919,0.2577,0.248751,0.2577,0.226796


[I 2025-01-02 17:59:11,668] Trial 58 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2921,1.788777,0.4656,0.496162,0.4656,0.440067


[I 2025-01-02 18:02:28,588] Trial 59 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2658,1.626632,0.3021,0.321277,0.3021,0.279703


[I 2025-01-02 18:05:46,758] Trial 60 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.714,1.62575,0.6495,0.666292,0.6495,0.641938


[I 2025-01-02 18:08:52,208] Trial 61 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1942,1.513356,0.6262,0.653919,0.6262,0.623126


[I 2025-01-02 18:11:54,478] Trial 62 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1225,1.419581,0.6398,0.658429,0.6398,0.638277


[I 2025-01-02 18:14:58,139] Trial 63 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.8484,1.764161,0.2486,0.24747,0.2486,0.227665


[I 2025-01-02 18:18:14,721] Trial 64 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9052,1.743548,0.1456,0.129052,0.1456,0.096428


[I 2025-01-02 18:21:16,968] Trial 65 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2166,1.072678,0.614,0.631255,0.614,0.603854
2,0.9373,0.988015,0.657,0.698383,0.657,0.660612
3,0.7622,0.743426,0.7493,0.761935,0.7493,0.747534
4,0.6453,0.711066,0.7707,0.777858,0.7707,0.768524
5,0.5827,0.633615,0.7793,0.807959,0.7793,0.783624
6,0.5,0.722539,0.7801,0.801005,0.7801,0.773408
7,0.4413,0.743682,0.784,0.804393,0.784,0.783129
8,0.3678,1.221466,0.7243,0.751061,0.7243,0.710855


[I 2025-01-02 18:45:53,795] Trial 66 finished with value: 0.7108547303462907 and parameters: {'learning_rate': 0.0004192037404572582, 'per_device_train_batch_size': 32, 'weight_decay': 0.003, 'lambda_param': 0.0, 'temperature': 7.0}. Best is trial 13 with value: 0.8775142231826768.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1901,1.213143,0.6133,0.631634,0.6133,0.599124


[I 2025-01-02 18:48:57,411] Trial 67 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.241,1.092313,0.6043,0.62028,0.6043,0.594696


[I 2025-01-02 18:52:14,084] Trial 68 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.318,1.321736,0.5601,0.578742,0.5601,0.534668


[I 2025-01-02 18:55:31,834] Trial 69 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1984,1.200755,0.6198,0.640503,0.6198,0.608948


[I 2025-01-02 18:58:35,890] Trial 70 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.1979,1.085242,0.6149,0.634268,0.6149,0.604737


[I 2025-01-02 19:01:39,791] Trial 71 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,2.3289,2.153829,0.1377,0.113611,0.1377,0.085904


[I 2025-01-02 19:04:42,067] Trial 72 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.7528,1.753912,0.3764,0.382832,0.3764,0.353604


[I 2025-01-02 19:07:59,703] Trial 73 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2316,1.519022,0.3475,0.357001,0.3475,0.322437


[I 2025-01-02 19:11:05,193] Trial 74 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6653,1.486519,0.5762,0.590446,0.5762,0.561158


[I 2025-01-02 19:14:08,355] Trial 75 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6702,1.912372,0.6303,0.640771,0.6303,0.620805


[I 2025-01-02 19:17:10,884] Trial 76 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2416,1.490284,0.5533,0.57234,0.5533,0.541789


[I 2025-01-02 19:20:27,459] Trial 77 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.5943,1.438909,0.1707,0.148205,0.1707,0.118925


[I 2025-01-02 19:23:30,635] Trial 78 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6093,1.698447,0.6154,0.631447,0.6154,0.606751


[I 2025-01-02 19:26:47,257] Trial 79 pruned. 
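
The log above is long, so a short summary of which trials finished, which were pruned, and which finished trial scored highest can be useful. The following is a minimal sketch that parses Optuna's log lines with regular expressions. The names `summarize_trials`, `FINISHED`, `PRUNED` and `SAMPLE_LOG` are hypothetical helpers that are not part of the notebook, and the sample text is just three lines copied from the output above. It assumes the log format shown here and that a higher objective value is better.

```python
import re

# Three lines copied verbatim from the search log above (illustrative sample only).
SAMPLE_LOG = """\
[I 2025-01-02 12:06:48,211] Trial 1 finished with value: 0.8775070142695549 and parameters: {'learning_rate': 0.00021766241123453658, 'per_device_train_batch_size': 32, 'weight_decay': 0.01, 'lambda_param': 0.9, 'temperature': 3.0}. Best is trial 1 with value: 0.8775070142695549.
[I 2025-01-02 12:09:50,409] Trial 2 pruned.
[I 2025-01-02 14:51:05,124] Trial 13 finished with value: 0.8775142231826768 and parameters: {'learning_rate': 0.00010562422487867538, 'per_device_train_batch_size': 16, 'weight_decay': 0.006, 'lambda_param': 0.7000000000000001, 'temperature': 3.0}. Best is trial 13 with value: 0.8775142231826768.
"""

# "Trial N finished with value: V" marks a completed trial and its objective value.
FINISHED = re.compile(r"Trial (\d+) finished with value: ([0-9.eE+-]+)")
# "Trial N pruned" marks a trial stopped early by the pruner.
PRUNED = re.compile(r"Trial (\d+) pruned")


def summarize_trials(log_text):
    """Return finished trial values, pruned trial numbers, and the best finished trial."""
    finished = {int(n): float(v) for n, v in FINISHED.findall(log_text)}
    pruned = [int(n) for n in PRUNED.findall(log_text)]
    # Assumption: the objective is maximized, so the best trial has the largest value.
    best_trial = max(finished, key=finished.get) if finished else None
    return {"finished": finished, "pruned": pruned, "best_trial": best_trial}


summary = summarize_trials(SAMPLE_LOG)
print(f"finished: {sorted(summary['finished'])}")
print(f"pruned:   {summary['pruned']}")
print(f"best:     trial {summary['best_trial']}")
```

Applied to the full log of this run, the same parsing should report trials 0, 1, 3, 13, 14 and 66 as finished and the remaining 74 as pruned. If the return value of `hyperparameter_search` was stored in a variable, its `hyperparameters` and `objective` attributes give the best configuration directly, without parsing the log.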
