In [1]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
from transformers import ViTFeatureExtractor

model_ckpt = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_ckpt)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
images_root = '../Affectnet/Manually_Annotated/Manually_Annotated_Images'

In [4]:
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import os

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class AffectNetDataset(Dataset):
    def __init__(self,
                 csv_file,
                 root,
                 mode='classification',
                 crop=False,
                 transform=None,
                 invalid_files=None):
        assert mode in ('valence', 'arousal', 'valence-arousal', 'classification')
        self.df = pd.read_csv(csv_file)
        self.root = root
        self.mode = mode
        self.crop = crop
        self.transform = transform
        self.invalid_files = invalid_files
        
        if self.invalid_files:
            self.df = self.df[~self.df['subDirectory_filePath'].isin(invalid_files)]
        
        self.df = self.df[~((self.df['expression'] == 9) | (self.df['expression'] == 10))].reset_index(drop=True)
    
    def __getitem__(self, idx):
        try:
            img = pil_loader(os.path.join(self.root, self.df['subDirectory_filePath'][idx]))
        except KeyError:
            raise IndexError
        if self.crop:
            img = img.crop((self.df['face_x'][idx],
                            self.df['face_y'][idx],
                            self.df['face_x'][idx]+self.df['face_width'][idx],
                            self.df['face_y'][idx]+self.df['face_height'][idx],))
        if self.transform:
            img = self.transform(img)
        if self.mode == 'classification':
            target = torch.tensor(self.df['expression'][idx])
        elif self.mode == 'valence':
            target = torch.tensor([self.df['valence'][idx]])
        elif self.mode == 'arousal':
            target = torch.tensor([self.df['arousal'][idx]])
        else:
            target = torch.tensor([self.df['valence'][idx],
                                   self.df['arousal'][idx]])
        return img.float(), target.float()

    def __len__(self):
        return len(self.df)

In [5]:
def collate_fn(examples):
    imgs, targets = zip(*examples)
    pixel_values = torch.stack(imgs)
    targets = torch.stack(targets)
    return {'pixel_values': pixel_values, 'labels': targets}

In [6]:
train_invalid_files = ['103/29a31ebf1567693f4644c8ba3476ca9a72ee07fe67a5860d98707a0a.jpg']
val_invalid_files = []

In [7]:
mode = 'valence'

In [8]:
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    ToTensor)

normalize = Normalize(mean=feature_extractor.image_mean,
                      std=feature_extractor.image_std)

transform = Compose([Resize(tuple(feature_extractor.size.values())),
                     ToTensor()])

train_dataset = AffectNetDataset('../Affectnet/training.csv',
                                 images_root,
                                 mode,
                                 transform=transform,
                                 invalid_files=train_invalid_files)
val_dataset = AffectNetDataset('../Affectnet/validation.csv',
                               images_root,
                               mode,
                               transform=transform,
                               invalid_files=val_invalid_files)

print('train:', len(train_dataset))
print('validation:', len(val_dataset))

train: 320739
validation: 4500


In [9]:
from transformers import Trainer
from KDEweightedMSE.losses import KDEWeightedMSESc

class CustomTrainer(Trainer):
    def __init__(self,
                 band_width = None,
                 model = None,
                 args = None,
                 data_collator = None,
                 train_dataset = None,
                 eval_dataset = None,
                 tokenizer = None,
                 model_init = None,
                 compute_metrics = None,
                 callbacks = None,
                 optimizers = (None, None),
                 preprocess_logits_for_metrics = None):
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
        
        data = train_dataset.df['valence']
        self.loss_fct = KDEWeightedMSESc(data=data, band_width=band_width, device=self.args.device, mode='divide', standardize=False)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get('labels')
        outputs = model(**inputs)
        logits = outputs.get('logits')
        loss = self.loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [10]:
from transformers import TrainingArguments
import wandb

wandb.init(project='KDE-weighted-MSE', name='nonstd-divide-valence')

args = TrainingArguments(
    f"nonstd-divide",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=1e-3,
    load_best_model_at_end=True,
    logging_dir='logs',
    logging_strategy='steps',
    logging_steps=1000,
    remove_unused_columns=False,
    report_to='wandb'
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrkn[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
from sklearn.metrics import mean_squared_error
import numpy as np

def compute_metrics(eval_pred):
    preds, targets = eval_pred
    mse = mean_squared_error(targets, preds, squared=True)
    rmse = mean_squared_error(targets, preds, squared=False)
    return {'mse': mse, 'rmse': rmse}

In [12]:
import optuna
import wandb
from transformers import ViTForImageClassification, EarlyStoppingCallback, TrainingArguments

def objective(trial: optuna.Trial):
    band_width = trial.suggest_float('band_width', low=0.01, high=0.5)
    print('-'*20)
    print('bw=', band_width)
    print('-'*20)
    
    model = ViTForImageClassification.from_pretrained(
        model_ckpt,
        num_labels=1,
        problem_type='regression'
    )
    args = TrainingArguments(
        f"nonstd-divide-valence-bw={band_width}",
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=1e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=30,
        weight_decay=1e-3,
        load_best_model_at_end=True,
        metric_for_best_model='rmse',
        greater_is_better=False,
        logging_dir='logs',
        logging_strategy='steps',
        logging_steps=1000,
        remove_unused_columns=False,
    )
    trainer = CustomTrainer(
        band_width=band_width,
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.0005)],
    )
    
    trainer.train()
    trainer.save_state()
    trainer.save_model()
    val_result = trainer.predict(val_dataset)
    print(val_result.metrics)
    return val_result.metrics['test_rmse']

In [13]:
n_trials = 10

In [14]:
study = optuna.create_study(study_name='nonstd-divide-valence', direction='minimize')
study.optimize(func=objective, n_trials=n_trials)

[32m[I 2023-05-09 13:46:28,191][0m A new study created in memory with name: nonstd-divide-valence[0m


--------------------
bw= 0.2234383873255575
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.5569, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:33<63:53:17,  1.30it/s]

{'loss': 7.716, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:17<64:38:12,  1.28it/s]

{'loss': 7.4426, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [51:03<62:22:07,  1.32it/s]

{'loss': 7.0663, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:49<62:04:15,  1.32it/s]

{'loss': 7.0389, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:35<61:03:38,  1.34it/s]

{'loss': 6.7667, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:24<63:18:12,  1.29it/s]

{'loss': 6.8033, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:16<62:58:08,  1.29it/s]

{'loss': 6.8254, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:10<60:39:52,  1.34it/s]

{'loss': 6.6262, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:08:02<62:33:06,  1.29it/s]

{'loss': 6.661, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


                                                           
  3%|▎         | 10024/300720 [2:09:31<52:11:37,  1.55it/s]

{'eval_loss': 13.64694595336914, 'eval_mse': 0.14866700768470764, 'eval_rmse': 0.3855735957622528, 'eval_runtime': 71.451, 'eval_samples_per_second': 62.98, 'eval_steps_per_second': 1.973, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:55<60:57:31,  1.32it/s]  

{'loss': 6.056, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:33<61:36:42,  1.30it/s]

{'loss': 6.0211, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:15<60:02:18,  1.33it/s]

{'loss': 5.8411, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [3:00:01<59:47:45,  1.33it/s]

{'loss': 5.7656, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:45<63:22:11,  1.25it/s]

{'loss': 5.9279, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:30<59:01:03,  1.34it/s]

{'loss': 5.877, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:17<58:59:15,  1.34it/s]

{'loss': 5.9921, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:08<60:56:59,  1.29it/s]

{'loss': 5.9372, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:04:02<61:43:02,  1.27it/s]

{'loss': 5.997, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:58<58:18:23,  1.34it/s]

{'loss': 5.7743, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


                                                           
  7%|▋         | 20048/300720 [4:18:47<50:47:08,  1.54it/s]

{'eval_loss': 13.282628059387207, 'eval_mse': 0.1489625722169876, 'eval_rmse': 0.38595670461654663, 'eval_runtime': 72.5957, 'eval_samples_per_second': 61.987, 'eval_steps_per_second': 1.942, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:30:53<57:57:15,  1.34it/s]  

{'loss': 5.0183, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:38<58:21:37,  1.33it/s]

{'loss': 5.0799, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:21<58:23:21,  1.32it/s]

{'loss': 5.0295, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:09:07<56:30:13,  1.36it/s]

{'loss': 5.1954, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:21:52<59:03:25,  1.30it/s]

{'loss': 5.1705, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:34:34<58:11:02,  1.31it/s]

{'loss': 5.0042, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:19<57:18:24,  1.33it/s]

{'loss': 5.2495, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:02<56:55:27,  1.33it/s]

{'loss': 5.1645, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:12:56<61:09:48,  1.23it/s]

{'loss': 5.1659, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:25:46<59:30:35,  1.26it/s]

{'loss': 5.0416, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


                                                           
 10%|█         | 30072/300720 [6:27:54<51:10:18,  1.47it/s]

{'eval_loss': 14.581066131591797, 'eval_mse': 0.15706320106983185, 'eval_rmse': 0.39631199836730957, 'eval_runtime': 72.5655, 'eval_samples_per_second': 62.013, 'eval_steps_per_second': 1.943, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:27:56<58:11:27,  1.29it/s]


{'train_runtime': 23276.4169, 'train_samples_per_second': 413.387, 'train_steps_per_second': 12.92, 'train_loss': 6.092131996002562, 'epoch': 3.0}


100%|██████████| 141/141 [01:09<00:00,  2.04it/s]
[32m[I 2023-05-09 20:15:38,546][0m Trial 0 finished with value: 0.3855735957622528 and parameters: {'band_width': 0.2234383873255575}. Best is trial 0 with value: 0.3855735957622528.[0m


{'test_loss': 13.64694595336914, 'test_mse': 0.14866700768470764, 'test_rmse': 0.3855735957622528, 'test_runtime': 69.5084, 'test_samples_per_second': 64.74, 'test_steps_per_second': 2.029}
--------------------
bw= 0.3752841277577092
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.9216, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:18<63:41:26,  1.30it/s]

{'loss': 8.0054, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [37:56<63:44:36,  1.30it/s]

{'loss': 7.7766, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:35<61:52:50,  1.33it/s]

{'loss': 7.3837, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:16<61:51:14,  1.33it/s]

{'loss': 7.3033, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:00<60:54:59,  1.34it/s]

{'loss': 7.041, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:28:47<63:12:49,  1.29it/s]

{'loss': 7.0894, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:41:37<62:27:33,  1.30it/s]

{'loss': 7.1157, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:54:29<60:41:43,  1.34it/s]

{'loss': 6.9362, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:07:19<62:31:00,  1.29it/s]

{'loss': 6.9131, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


                                                           
  3%|▎         | 10024/300720 [2:08:46<53:08:39,  1.52it/s]

{'eval_loss': 14.173173904418945, 'eval_mse': 0.15163668990135193, 'eval_rmse': 0.3894055485725403, 'eval_runtime': 68.9456, 'eval_samples_per_second': 65.269, 'eval_steps_per_second': 2.045, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:06<60:25:17,  1.33it/s]  

{'loss': 6.2789, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:33:42<61:08:49,  1.31it/s]

{'loss': 6.2807, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:46:20<60:24:23,  1.32it/s]

{'loss': 6.1257, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [2:59:00<60:26:27,  1.32it/s]

{'loss': 6.0738, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:11:40<62:11:17,  1.28it/s]

{'loss': 6.172, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:24:23<59:22:51,  1.33it/s]

{'loss': 6.1056, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:37:07<59:18:00,  1.33it/s]

{'loss': 6.2465, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:49:52<60:55:08,  1.29it/s]

{'loss': 6.2177, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:02:43<61:12:20,  1.28it/s]

{'loss': 6.2342, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:15:36<58:29:48,  1.33it/s]

{'loss': 5.9981, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


                                                           
  7%|▋         | 20048/300720 [4:17:24<51:11:08,  1.52it/s]

{'eval_loss': 13.794588088989258, 'eval_mse': 0.15443077683448792, 'eval_rmse': 0.3929768204689026, 'eval_runtime': 70.4761, 'eval_samples_per_second': 63.851, 'eval_steps_per_second': 2.001, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:29:25<57:41:43,  1.35it/s]  

{'loss': 5.2422, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:42:05<57:44:21,  1.34it/s]

{'loss': 5.2919, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:54:45<58:14:29,  1.32it/s]

{'loss': 5.2786, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:07:26<56:19:01,  1.36it/s]

{'loss': 5.4159, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:20:08<59:15:37,  1.29it/s]

{'loss': 5.4269, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:32:48<58:02:46,  1.31it/s]

{'loss': 5.1781, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:45:31<56:53:09,  1.34it/s]

{'loss': 5.5014, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [5:58:15<57:32:47,  1.32it/s]

{'loss': 5.3683, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:11:05<60:32:54,  1.25it/s]

{'loss': 5.3108, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:23:55<57:34:58,  1.31it/s]

{'loss': 5.268, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


                                                           
 10%|█         | 30072/300720 [6:26:02<50:10:46,  1.50it/s]

{'eval_loss': 14.663272857666016, 'eval_mse': 0.15867534279823303, 'eval_rmse': 0.39834073185920715, 'eval_runtime': 71.4448, 'eval_samples_per_second': 62.986, 'eval_steps_per_second': 1.974, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:26:03<57:54:32,  1.30it/s]


{'train_runtime': 23163.6785, 'train_samples_per_second': 415.399, 'train_steps_per_second': 12.982, 'train_loss': 6.348678668022917, 'epoch': 3.0}


100%|██████████| 141/141 [01:07<00:00,  2.07it/s]
[32m[I 2023-05-10 02:42:52,922][0m Trial 1 finished with value: 0.3894055485725403 and parameters: {'band_width': 0.3752841277577092}. Best is trial 0 with value: 0.3855735957622528.[0m


{'test_loss': 14.173173904418945, 'test_mse': 0.15163668990135193, 'test_rmse': 0.3894055485725403, 'test_runtime': 68.4639, 'test_samples_per_second': 65.728, 'test_steps_per_second': 2.059}
--------------------
bw= 0.13604422505592878
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.4514, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:21<63:35:04,  1.30it/s]

{'loss': 7.6146, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:04<63:42:03,  1.30it/s]

{'loss': 7.361, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:49<62:24:06,  1.32it/s]

{'loss': 6.9925, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:35<62:05:44,  1.32it/s]

{'loss': 7.0308, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:20<61:18:59,  1.34it/s]

{'loss': 6.7316, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:07<63:32:01,  1.28it/s]

{'loss': 6.743, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:41:57<62:24:26,  1.30it/s]

{'loss': 6.7798, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:54:48<60:29:38,  1.34it/s]

{'loss': 6.6029, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:07:40<63:55:31,  1.26it/s]

{'loss': 6.5753, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:07:57<53:05:19,  1.52it/s]
  3%|▎         | 10024/300720 [2:09:08<53:05:19,  1.52it/s]

{'eval_loss': 13.605753898620605, 'eval_mse': 0.14592194557189941, 'eval_rmse': 0.38199731707572937, 'eval_runtime': 70.3996, 'eval_samples_per_second': 63.921, 'eval_steps_per_second': 2.003, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:27<61:05:18,  1.32it/s]  

{'loss': 6.0001, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:03<61:40:06,  1.30it/s]

{'loss': 5.9774, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:46:42<60:26:30,  1.32it/s]

{'loss': 5.7745, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [2:59:23<59:57:01,  1.33it/s]

{'loss': 5.7414, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:02<62:02:35,  1.28it/s]

{'loss': 5.8557, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:24:44<59:37:21,  1.33it/s]

{'loss': 5.8156, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:37:30<59:02:55,  1.33it/s]

{'loss': 5.9551, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:50:18<60:20:56,  1.30it/s]

{'loss': 5.8408, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:03:08<61:23:13,  1.27it/s]

{'loss': 5.9034, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:03<58:52:17,  1.32it/s]

{'loss': 5.7082, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:16:40<50:47:18,  1.54it/s]
  7%|▋         | 20048/300720 [4:17:52<50:47:18,  1.54it/s]

{'eval_loss': 13.863944053649902, 'eval_mse': 0.1514253467321396, 'eval_rmse': 0.38913410902023315, 'eval_runtime': 72.2808, 'eval_samples_per_second': 62.257, 'eval_steps_per_second': 1.951, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:29:56<58:19:16,  1.33it/s]  

{'loss': 4.986, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:42:38<57:54:02,  1.34it/s]

{'loss': 5.03, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:55:20<58:25:37,  1.32it/s]

{'loss': 4.9662, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:08:03<56:11:32,  1.37it/s]

{'loss': 5.1345, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:20:47<59:25:08,  1.29it/s]

{'loss': 5.1277, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:33:30<58:13:39,  1.31it/s]

{'loss': 4.928, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:46:17<57:00:06,  1.33it/s]

{'loss': 5.1879, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [5:59:04<57:33:34,  1.32it/s]

{'loss': 5.0765, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:11:56<61:08:32,  1.23it/s]

{'loss': 5.0397, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:24:47<58:17:04,  1.29it/s]

{'loss': 4.9649, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:25:43<50:32:24,  1.49it/s]
 10%|█         | 30072/300720 [6:26:54<50:32:24,  1.49it/s]

{'eval_loss': 14.746831893920898, 'eval_mse': 0.15573565661907196, 'eval_rmse': 0.39463356137275696, 'eval_runtime': 71.353, 'eval_samples_per_second': 63.067, 'eval_steps_per_second': 1.976, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:26:56<58:02:28,  1.30it/s]


{'train_runtime': 23216.5659, 'train_samples_per_second': 414.453, 'train_steps_per_second': 12.953, 'train_loss': 6.028474497712651, 'epoch': 3.0}


100%|██████████| 141/141 [01:07<00:00,  2.08it/s]
[32m[I 2023-05-10 09:11:00,402][0m Trial 2 finished with value: 0.38199731707572937 and parameters: {'band_width': 0.13604422505592878}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.605753898620605, 'test_mse': 0.14592194557189941, 'test_rmse': 0.38199731707572937, 'test_runtime': 68.4037, 'test_samples_per_second': 65.786, 'test_steps_per_second': 2.061}
--------------------
bw= 0.22717420190217377
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.5229, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:18<63:16:29,  1.31it/s]

{'loss': 7.6885, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [37:59<64:44:50,  1.28it/s]

{'loss': 7.4698, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:42<62:34:35,  1.32it/s]

{'loss': 7.0761, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:25<61:30:31,  1.34it/s]

{'loss': 7.077, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:09<61:44:44,  1.33it/s]

{'loss': 6.7697, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:28:57<63:30:34,  1.28it/s]

{'loss': 6.8039, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:41:49<63:03:36,  1.29it/s]

{'loss': 6.8516, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:54:43<61:28:52,  1.32it/s]

{'loss': 6.6738, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:07:36<63:30:11,  1.27it/s]

{'loss': 6.6383, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:07:54<52:52:51,  1.53it/s]
  3%|▎         | 10024/300720 [2:09:05<52:52:51,  1.53it/s]

{'eval_loss': 13.761797904968262, 'eval_mse': 0.15040460228919983, 'eval_rmse': 0.3878203332424164, 'eval_runtime': 70.9424, 'eval_samples_per_second': 63.432, 'eval_steps_per_second': 1.988, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:27<61:03:52,  1.32it/s]  

{'loss': 6.0481, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:06<62:13:39,  1.29it/s]

{'loss': 6.0334, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:46:48<60:37:03,  1.32it/s]

{'loss': 5.8471, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [2:59:33<60:36:07,  1.31it/s]

{'loss': 5.8321, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:16<61:53:27,  1.28it/s]

{'loss': 5.9153, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:03<59:21:01,  1.33it/s]

{'loss': 5.891, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:37:51<59:38:46,  1.32it/s]

{'loss': 6.0167, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:50:43<61:29:52,  1.28it/s]

{'loss': 5.9303, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:03:36<61:07:18,  1.28it/s]

{'loss': 5.9961, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:34<58:29:16,  1.33it/s]

{'loss': 5.7871, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:11<51:05:50,  1.53it/s]
  7%|▋         | 20048/300720 [4:18:24<51:05:50,  1.53it/s]

{'eval_loss': 13.471317291259766, 'eval_mse': 0.15115264058113098, 'eval_rmse': 0.3887835443019867, 'eval_runtime': 72.5283, 'eval_samples_per_second': 62.045, 'eval_steps_per_second': 1.944, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:30:30<57:56:24,  1.34it/s]  

{'loss': 5.0407, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:14<57:44:41,  1.34it/s]

{'loss': 5.0943, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:55:59<58:27:51,  1.32it/s]

{'loss': 5.0409, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:08:45<56:05:43,  1.37it/s]

{'loss': 5.2076, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:21:31<58:44:51,  1.30it/s]

{'loss': 5.1954, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:34:16<57:34:21,  1.33it/s]

{'loss': 5.0162, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:06<58:23:51,  1.30it/s]

{'loss': 5.3007, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [5:59:57<57:43:52,  1.31it/s]

{'loss': 5.1511, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:12:52<60:58:24,  1.24it/s]

{'loss': 5.0887, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:25:46<58:46:50,  1.28it/s]

{'loss': 5.044, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:26:42<49:44:50,  1.51it/s]
 10%|█         | 30072/300720 [6:27:54<49:44:50,  1.51it/s]

{'eval_loss': 14.307493209838867, 'eval_mse': 0.1552175134420395, 'eval_rmse': 0.3939765393733978, 'eval_runtime': 72.3171, 'eval_samples_per_second': 62.226, 'eval_steps_per_second': 1.95, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:27:56<58:11:28,  1.29it/s]


{'train_runtime': 23276.538, 'train_samples_per_second': 413.385, 'train_steps_per_second': 12.919, 'train_loss': 6.1001987787090926, 'epoch': 3.0}


100%|██████████| 141/141 [01:08<00:00,  2.07it/s]
[32m[I 2023-05-10 15:40:07,944][0m Trial 3 finished with value: 0.3878203332424164 and parameters: {'band_width': 0.22717420190217377}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.761797904968262, 'test_mse': 0.15040460228919983, 'test_rmse': 0.3878203332424164, 'test_runtime': 68.5442, 'test_samples_per_second': 65.651, 'test_steps_per_second': 2.057}
--------------------
bw= 0.2892402282524722
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.7063, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:24<64:15:37,  1.29it/s]

{'loss': 7.8285, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:09<63:27:16,  1.30it/s]

{'loss': 7.6311, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:56<62:47:14,  1.31it/s]

{'loss': 7.2019, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:43<62:19:57,  1.32it/s]

{'loss': 7.1537, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:30<61:17:08,  1.34it/s]

{'loss': 6.9063, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:21<63:07:33,  1.29it/s]

{'loss': 6.9355, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:14<62:44:47,  1.30it/s]

{'loss': 6.948, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:11<61:36:51,  1.32it/s]

{'loss': 6.7975, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:08:06<63:23:28,  1.27it/s]

{'loss': 6.7738, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


                                                           
  3%|▎         | 10024/300720 [2:09:35<53:32:58,  1.51it/s]

{'eval_loss': 13.680593490600586, 'eval_mse': 0.14933007955551147, 'eval_rmse': 0.3864324986934662, 'eval_runtime': 71.188, 'eval_samples_per_second': 63.213, 'eval_steps_per_second': 1.981, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:59<60:28:13,  1.33it/s]  

{'loss': 6.1491, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:40<62:27:13,  1.28it/s]

{'loss': 6.1185, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:24<60:41:12,  1.32it/s]

{'loss': 5.9776, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [3:00:10<61:04:13,  1.30it/s]

{'loss': 5.9483, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:55<62:45:23,  1.26it/s]

{'loss': 6.0229, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:42<59:01:02,  1.34it/s]

{'loss': 5.979, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:32<59:27:04,  1.33it/s]

{'loss': 6.1106, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:26<61:08:36,  1.28it/s]

{'loss': 6.0444, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:04:21<62:19:09,  1.26it/s]

{'loss': 6.0994, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:17:20<58:47:29,  1.33it/s]

{'loss': 5.8877, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


                                                           
  7%|▋         | 20048/300720 [4:19:09<50:32:16,  1.54it/s]

{'eval_loss': 13.474503517150879, 'eval_mse': 0.15178950130939484, 'eval_rmse': 0.3896017074584961, 'eval_runtime': 72.2138, 'eval_samples_per_second': 62.315, 'eval_steps_per_second': 1.953, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:31:15<58:30:30,  1.33it/s]  

{'loss': 5.1327, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:59<58:30:44,  1.32it/s]

{'loss': 5.1833, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:43<58:26:58,  1.32it/s]

{'loss': 5.1312, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:09:30<56:39:23,  1.36it/s]

{'loss': 5.3027, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:22:18<59:21:44,  1.29it/s]

{'loss': 5.3178, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:35:03<58:26:12,  1.31it/s]

{'loss': 5.0991, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:53<57:44:27,  1.32it/s]

{'loss': 5.3694, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:42<57:45:37,  1.31it/s]

{'loss': 5.2611, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:13:36<60:42:12,  1.24it/s]

{'loss': 5.1807, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:26:31<58:37:29,  1.28it/s]

{'loss': 5.1378, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


                                                           
 10%|█         | 30072/300720 [6:28:38<50:16:46,  1.50it/s]

{'eval_loss': 14.630603790283203, 'eval_mse': 0.15877428650856018, 'eval_rmse': 0.3984649181365967, 'eval_runtime': 71.8124, 'eval_samples_per_second': 62.663, 'eval_steps_per_second': 1.963, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:28:40<58:18:04,  1.29it/s]


{'train_runtime': 23320.5073, 'train_samples_per_second': 412.606, 'train_steps_per_second': 12.895, 'train_loss': 6.2096332611952665, 'epoch': 3.0}


100%|██████████| 141/141 [01:08<00:00,  2.07it/s]
[32m[I 2023-05-10 22:09:59,471][0m Trial 4 finished with value: 0.3864324986934662 and parameters: {'band_width': 0.2892402282524722}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.680593490600586, 'test_mse': 0.14933007955551147, 'test_rmse': 0.3864324986934662, 'test_runtime': 68.5375, 'test_samples_per_second': 65.658, 'test_steps_per_second': 2.057}
--------------------
bw= 0.15652739552192593
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.4188, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:24<63:41:28,  1.30it/s]

{'loss': 7.5925, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:07<64:00:10,  1.29it/s]

{'loss': 7.3738, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:51<62:40:00,  1.32it/s]

{'loss': 6.9836, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:38<61:55:39,  1.33it/s]

{'loss': 6.9899, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:24<61:28:47,  1.33it/s]

{'loss': 6.7078, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:14<63:06:13,  1.29it/s]

{'loss': 6.7481, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:07<63:32:25,  1.28it/s]

{'loss': 6.7833, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:03<61:20:03,  1.32it/s]

{'loss': 6.5864, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:07:58<63:23:43,  1.27it/s]

{'loss': 6.5734, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:08:16<53:11:36,  1.52it/s]
  3%|▎         | 10024/300720 [2:09:26<53:11:36,  1.52it/s]

{'eval_loss': 13.652792930603027, 'eval_mse': 0.14744333922863007, 'eval_rmse': 0.3839835226535797, 'eval_runtime': 69.6489, 'eval_samples_per_second': 64.61, 'eval_steps_per_second': 2.024, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:50<61:22:18,  1.31it/s]  

{'loss': 5.997, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:31<61:41:40,  1.30it/s]

{'loss': 5.9624, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:15<60:36:53,  1.32it/s]

{'loss': 5.7732, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [3:00:01<60:46:33,  1.31it/s]

{'loss': 5.7449, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:47<62:19:46,  1.27it/s]

{'loss': 5.8427, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:35<59:16:44,  1.33it/s]

{'loss': 5.8321, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:24<59:48:09,  1.32it/s]

{'loss': 5.9253, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:18<60:59:58,  1.29it/s]

{'loss': 5.8359, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:04:13<62:32:03,  1.25it/s]

{'loss': 5.9298, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:17:13<58:51:57,  1.32it/s]

{'loss': 5.7124, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:50<51:00:54,  1.53it/s]
  7%|▋         | 20048/300720 [4:19:02<51:00:54,  1.53it/s]

{'eval_loss': 13.803255081176758, 'eval_mse': 0.15215224027633667, 'eval_rmse': 0.3900669813156128, 'eval_runtime': 71.911, 'eval_samples_per_second': 62.577, 'eval_steps_per_second': 1.961, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:31:09<57:52:54,  1.34it/s]  

{'loss': 4.9749, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:53<58:24:56,  1.33it/s]

{'loss': 5.0241, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:39<58:45:15,  1.31it/s]

{'loss': 4.9734, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:09:26<56:27:14,  1.36it/s]

{'loss': 5.1167, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:22:14<59:28:09,  1.29it/s]

{'loss': 5.1201, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:35:01<58:04:04,  1.31it/s]

{'loss': 4.9134, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:52<58:02:21,  1.31it/s]

{'loss': 5.206, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:44<57:58:22,  1.31it/s]

{'loss': 5.0752, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:13:39<60:57:20,  1.24it/s]

{'loss': 5.0285, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:26:35<59:21:03,  1.27it/s]

{'loss': 4.9529, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:27:31<51:24:11,  1.46it/s]
 10%|█         | 30072/300720 [6:28:43<51:24:11,  1.46it/s]

{'eval_loss': 14.382702827453613, 'eval_mse': 0.15476210415363312, 'eval_rmse': 0.39339813590049744, 'eval_runtime': 71.8983, 'eval_samples_per_second': 62.588, 'eval_steps_per_second': 1.961, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:28:44<58:18:44,  1.29it/s]


{'train_runtime': 23325.0242, 'train_samples_per_second': 412.526, 'train_steps_per_second': 12.893, 'train_loss': 6.021998299661866, 'epoch': 3.0}


100%|██████████| 141/141 [01:08<00:00,  2.07it/s]
[32m[I 2023-05-11 04:39:54,946][0m Trial 5 finished with value: 0.3839835226535797 and parameters: {'band_width': 0.15652739552192593}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.652792930603027, 'test_mse': 0.14744333922863007, 'test_rmse': 0.3839835226535797, 'test_runtime': 68.523, 'test_samples_per_second': 65.671, 'test_steps_per_second': 2.058}
--------------------
bw= 0.43501629307970335
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 10.0654, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:25<64:43:41,  1.28it/s]

{'loss': 8.0975, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:10<64:28:59,  1.28it/s]

{'loss': 7.8678, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:57<62:31:26,  1.32it/s]

{'loss': 7.4917, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:45<62:04:21,  1.32it/s]

{'loss': 7.4081, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:32<60:54:20,  1.34it/s]

{'loss': 7.1431, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:24<63:01:34,  1.29it/s]

{'loss': 7.1799, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:17<62:50:38,  1.29it/s]

{'loss': 7.2123, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:12<60:43:52,  1.33it/s]

{'loss': 7.039, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:08:08<63:33:34,  1.27it/s]

{'loss': 7.0077, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:08:26<52:52:28,  1.53it/s]
  3%|▎         | 10024/300720 [2:09:36<52:52:28,  1.53it/s]

{'eval_loss': 14.567038536071777, 'eval_mse': 0.1538856327533722, 'eval_rmse': 0.3922826051712036, 'eval_runtime': 69.7387, 'eval_samples_per_second': 64.527, 'eval_steps_per_second': 2.022, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:22:00<61:06:37,  1.32it/s]  

{'loss': 6.3654, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:40<61:32:52,  1.30it/s]

{'loss': 6.3604, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:24<60:13:28,  1.33it/s]

{'loss': 6.2037, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [3:00:10<61:04:05,  1.30it/s]

{'loss': 6.1772, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:55<62:16:43,  1.27it/s]

{'loss': 6.2515, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:44<59:31:58,  1.33it/s]

{'loss': 6.1993, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:34<59:56:26,  1.31it/s]

{'loss': 6.3258, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:28<61:04:31,  1.29it/s]

{'loss': 6.3242, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:04:23<62:35:13,  1.25it/s]

{'loss': 6.3136, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:17:22<59:22:23,  1.31it/s]

{'loss': 6.0796, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:59<51:17:09,  1.52it/s]
  7%|▋         | 20048/300720 [4:19:11<51:17:09,  1.52it/s]

{'eval_loss': 13.99327278137207, 'eval_mse': 0.15443281829357147, 'eval_rmse': 0.3929794132709503, 'eval_runtime': 72.0204, 'eval_samples_per_second': 62.482, 'eval_steps_per_second': 1.958, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:31:18<57:54:01,  1.34it/s]  

{'loss': 5.321, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:44:02<58:30:50,  1.32it/s]

{'loss': 5.3837, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:47<59:14:35,  1.30it/s]

{'loss': 5.3437, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:09:34<56:45:42,  1.35it/s]

{'loss': 5.5082, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:22:21<58:59:36,  1.30it/s]

{'loss': 5.5396, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:35:07<58:09:47,  1.31it/s]

{'loss': 5.2844, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:56<58:03:47,  1.31it/s]

{'loss': 5.5733, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:46<58:09:30,  1.30it/s]

{'loss': 5.4635, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:13:42<61:15:05,  1.23it/s]

{'loss': 5.3642, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:26:37<58:28:20,  1.29it/s]

{'loss': 5.338, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:27:33<51:09:52,  1.47it/s]
 10%|█         | 30072/300720 [6:28:45<51:09:52,  1.47it/s]

{'eval_loss': 14.912881851196289, 'eval_mse': 0.1595155894756317, 'eval_rmse': 0.39939403533935547, 'eval_runtime': 72.046, 'eval_samples_per_second': 62.46, 'eval_steps_per_second': 1.957, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:28:47<58:19:04,  1.29it/s]


{'train_runtime': 23327.2273, 'train_samples_per_second': 412.487, 'train_steps_per_second': 12.891, 'train_loss': 6.439510985808539, 'epoch': 3.0}


100%|██████████| 141/141 [01:08<00:00,  2.07it/s]
[32m[I 2023-05-11 11:09:53,142][0m Trial 6 finished with value: 0.3922826051712036 and parameters: {'band_width': 0.43501629307970335}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 14.567038536071777, 'test_mse': 0.1538856327533722, 'test_rmse': 0.3922826051712036, 'test_runtime': 68.5711, 'test_samples_per_second': 65.625, 'test_steps_per_second': 2.056}
--------------------
bw= 0.15080808249757954
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.4492, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:26<64:07:42,  1.29it/s]

{'loss': 7.6249, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:11<63:37:25,  1.30it/s]

{'loss': 7.3542, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:58<62:17:21,  1.32it/s]

{'loss': 6.965, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:44<62:01:39,  1.32it/s]

{'loss': 6.9999, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:31<61:04:24,  1.34it/s]

{'loss': 6.7294, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:21<63:03:46,  1.29it/s]

{'loss': 6.741, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:14<63:15:27,  1.29it/s]

{'loss': 6.7838, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:08<61:32:50,  1.32it/s]

{'loss': 6.5941, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:08:02<62:49:39,  1.29it/s]

{'loss': 6.5796, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:08:20<52:27:04,  1.54it/s]
  3%|▎         | 10024/300720 [2:09:29<52:27:04,  1.54it/s]

{'eval_loss': 13.737706184387207, 'eval_mse': 0.1483323574066162, 'eval_rmse': 0.3851394057273865, 'eval_runtime': 68.8284, 'eval_samples_per_second': 65.38, 'eval_steps_per_second': 2.049, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:50<60:48:52,  1.32it/s]  

{'loss': 6.0113, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:29<61:13:53,  1.31it/s]

{'loss': 5.9415, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:12<60:18:28,  1.33it/s]

{'loss': 5.7674, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [2:59:56<60:19:24,  1.32it/s]

{'loss': 5.7314, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:39<62:11:54,  1.28it/s]

{'loss': 5.8461, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:25<59:02:32,  1.34it/s]

{'loss': 5.8233, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:12<59:27:33,  1.33it/s]

{'loss': 5.9417, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:05<61:13:30,  1.28it/s]

{'loss': 5.8308, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:03:58<61:32:33,  1.27it/s]

{'loss': 5.9313, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:55<59:09:28,  1.32it/s]

{'loss': 5.7146, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:32<50:39:06,  1.54it/s]
  7%|▋         | 20048/300720 [4:18:43<50:39:06,  1.54it/s]

{'eval_loss': 13.690694808959961, 'eval_mse': 0.15309475362300873, 'eval_rmse': 0.39127326011657715, 'eval_runtime': 70.7638, 'eval_samples_per_second': 63.592, 'eval_steps_per_second': 1.993, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:30:48<58:36:12,  1.33it/s]  

{'loss': 4.9683, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:31<58:10:28,  1.33it/s]

{'loss': 5.0188, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:14<58:27:50,  1.32it/s]

{'loss': 4.9721, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:09:00<56:12:10,  1.37it/s]

{'loss': 5.142, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:21:47<59:43:35,  1.28it/s]

{'loss': 5.1501, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:34:32<58:01:10,  1.32it/s]

{'loss': 4.9315, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:22<58:27:21,  1.30it/s]

{'loss': 5.1882, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:11<57:49:35,  1.31it/s]

{'loss': 5.0818, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:13:04<60:39:01,  1.24it/s]

{'loss': 5.0506, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:25:58<58:49:47,  1.28it/s]

{'loss': 4.9802, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:26:54<51:12:20,  1.47it/s]
 10%|█         | 30072/300720 [6:28:06<51:12:20,  1.47it/s]

{'eval_loss': 14.843402862548828, 'eval_mse': 0.15754254162311554, 'eval_rmse': 0.39691630005836487, 'eval_runtime': 71.8438, 'eval_samples_per_second': 62.636, 'eval_steps_per_second': 1.963, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:28:07<58:13:11,  1.29it/s]


{'train_runtime': 23287.9169, 'train_samples_per_second': 413.183, 'train_steps_per_second': 12.913, 'train_loss': 6.027010421164336, 'epoch': 3.0}


100%|██████████| 141/141 [01:07<00:00,  2.07it/s]
[32m[I 2023-05-11 17:39:11,374][0m Trial 7 finished with value: 0.3851394057273865 and parameters: {'band_width': 0.15080808249757954}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.737706184387207, 'test_mse': 0.1483323574066162, 'test_rmse': 0.3851394057273865, 'test_runtime': 68.4218, 'test_samples_per_second': 65.769, 'test_steps_per_second': 2.061}
--------------------
bw= 0.38356344280567367
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.9205, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:25<64:14:06,  1.29it/s]

{'loss': 7.9964, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:11<64:10:40,  1.29it/s]

{'loss': 7.7792, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:59<62:44:13,  1.31it/s]

{'loss': 7.3881, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:47<61:41:10,  1.33it/s]

{'loss': 7.3054, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:34<61:51:50,  1.32it/s]

{'loss': 7.0669, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:26<63:28:19,  1.29it/s]

{'loss': 7.0952, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:20<63:31:03,  1.28it/s]

{'loss': 7.1177, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:14<61:36:42,  1.32it/s]

{'loss': 6.9609, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:08:08<63:03:00,  1.28it/s]

{'loss': 6.9316, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:08:26<52:45:34,  1.53it/s]
  3%|▎         | 10024/300720 [2:09:36<52:45:34,  1.53it/s]

{'eval_loss': 14.590437889099121, 'eval_mse': 0.1537625640630722, 'eval_rmse': 0.3921256959438324, 'eval_runtime': 70.368, 'eval_samples_per_second': 63.949, 'eval_steps_per_second': 2.004, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:58<60:45:38,  1.32it/s]  

{'loss': 6.2909, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:37<61:29:26,  1.30it/s]

{'loss': 6.2793, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:18<59:40:37,  1.34it/s]

{'loss': 6.1375, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [3:00:03<61:01:00,  1.31it/s]

{'loss': 6.0826, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:45<62:04:44,  1.28it/s]

{'loss': 6.1571, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:31<59:02:20,  1.34it/s]

{'loss': 6.1323, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:18<59:10:24,  1.33it/s]

{'loss': 6.2586, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:10<61:22:40,  1.28it/s]

{'loss': 6.202, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:04:02<61:55:28,  1.26it/s]

{'loss': 6.2558, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:58<58:23:12,  1.34it/s]

{'loss': 5.994, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:35<50:12:01,  1.55it/s]
  7%|▋         | 20048/300720 [4:18:47<50:12:01,  1.55it/s]

{'eval_loss': 13.70345401763916, 'eval_mse': 0.1538609117269516, 'eval_rmse': 0.39225107431411743, 'eval_runtime': 71.8449, 'eval_samples_per_second': 62.635, 'eval_steps_per_second': 1.963, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:30:52<57:59:36,  1.34it/s]  

{'loss': 5.2351, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:33<57:52:53,  1.34it/s]

{'loss': 5.3214, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:14<58:15:12,  1.32it/s]

{'loss': 5.249, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:08:57<56:49:21,  1.35it/s]

{'loss': 5.4318, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:21:41<59:17:30,  1.29it/s]

{'loss': 5.4568, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:34:25<58:21:52,  1.31it/s]

{'loss': 5.2141, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:12<58:22:28,  1.30it/s]

{'loss': 5.487, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:01<58:05:50,  1.30it/s]

{'loss': 5.3775, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:12:54<60:39:46,  1.24it/s]

{'loss': 5.3122, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:25:48<58:02:59,  1.30it/s]

{'loss': 5.2552, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:26:43<50:59:22,  1.47it/s]
 10%|█         | 30072/300720 [6:27:55<50:59:22,  1.47it/s]

{'eval_loss': 15.064465522766113, 'eval_mse': 0.16112564504146576, 'eval_rmse': 0.40140458941459656, 'eval_runtime': 71.7471, 'eval_samples_per_second': 62.72, 'eval_steps_per_second': 1.965, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:27:57<58:11:33,  1.29it/s]


{'train_runtime': 23277.1255, 'train_samples_per_second': 413.374, 'train_steps_per_second': 12.919, 'train_loss': 6.354926354645985, 'epoch': 3.0}


100%|██████████| 141/141 [01:08<00:00,  2.07it/s]
[32m[I 2023-05-12 00:08:19,003][0m Trial 8 finished with value: 0.3921256959438324 and parameters: {'band_width': 0.38356344280567367}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 14.590437889099121, 'test_mse': 0.1537625640630722, 'test_rmse': 0.3921256959438324, 'test_runtime': 68.6016, 'test_samples_per_second': 65.596, 'test_steps_per_second': 2.055}
--------------------
bw= 0.17081934079556765
--------------------


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 1000/300720 [12

{'loss': 9.4121, 'learning_rate': 9.966746475126364e-06, 'epoch': 0.1}


  1%|          | 2000/300720 [25:23<63:40:03,  1.30it/s]

{'loss': 7.6057, 'learning_rate': 9.933492950252728e-06, 'epoch': 0.2}


  1%|          | 3000/300720 [38:06<64:02:02,  1.29it/s]

{'loss': 7.3571, 'learning_rate': 9.900239425379091e-06, 'epoch': 0.3}


  1%|▏         | 4000/300720 [50:51<62:20:04,  1.32it/s]

{'loss': 6.9745, 'learning_rate': 9.866985900505454e-06, 'epoch': 0.4}


  2%|▏         | 5000/300720 [1:03:37<61:39:19,  1.33it/s]

{'loss': 6.9955, 'learning_rate': 9.833732375631819e-06, 'epoch': 0.5}


  2%|▏         | 6000/300720 [1:16:23<61:11:23,  1.34it/s]

{'loss': 6.7282, 'learning_rate': 9.800478850758181e-06, 'epoch': 0.6}


  2%|▏         | 7000/300720 [1:29:13<62:51:07,  1.30it/s]

{'loss': 6.7365, 'learning_rate': 9.767225325884544e-06, 'epoch': 0.7}


  3%|▎         | 8000/300720 [1:42:05<62:36:27,  1.30it/s]

{'loss': 6.7734, 'learning_rate': 9.733971801010907e-06, 'epoch': 0.8}


  3%|▎         | 9000/300720 [1:55:00<61:24:51,  1.32it/s]

{'loss': 6.5905, 'learning_rate': 9.70071827613727e-06, 'epoch': 0.9}


  3%|▎         | 10000/300720 [2:07:53<63:14:07,  1.28it/s]

{'loss': 6.5636, 'learning_rate': 9.667464751263635e-06, 'epoch': 1.0}


  3%|▎         | 10024/300720 [2:08:11<53:10:30,  1.52it/s]
  3%|▎         | 10024/300720 [2:09:22<53:10:30,  1.52it/s]

{'eval_loss': 13.464204788208008, 'eval_mse': 0.14687635004520416, 'eval_rmse': 0.38324451446533203, 'eval_runtime': 71.009, 'eval_samples_per_second': 63.372, 'eval_steps_per_second': 1.986, 'epoch': 1.0}


  4%|▎         | 11000/300720 [2:21:45<61:24:46,  1.31it/s]  

{'loss': 6.0074, 'learning_rate': 9.634211226389998e-06, 'epoch': 1.1}


  4%|▍         | 12000/300720 [2:34:24<61:11:21,  1.31it/s]

{'loss': 5.952, 'learning_rate': 9.600957701516362e-06, 'epoch': 1.2}


  4%|▍         | 13000/300720 [2:47:06<60:20:58,  1.32it/s]

{'loss': 5.7705, 'learning_rate': 9.567704176642725e-06, 'epoch': 1.3}


  5%|▍         | 14000/300720 [2:59:50<60:46:14,  1.31it/s]

{'loss': 5.7513, 'learning_rate': 9.534450651769088e-06, 'epoch': 1.4}


  5%|▍         | 15000/300720 [3:12:33<62:10:04,  1.28it/s]

{'loss': 5.8578, 'learning_rate': 9.501197126895452e-06, 'epoch': 1.5}


  5%|▌         | 16000/300720 [3:25:20<59:22:50,  1.33it/s]

{'loss': 5.8182, 'learning_rate': 9.467943602021815e-06, 'epoch': 1.6}


  6%|▌         | 17000/300720 [3:38:08<60:05:42,  1.31it/s]

{'loss': 5.948, 'learning_rate': 9.434690077148178e-06, 'epoch': 1.7}


  6%|▌         | 18000/300720 [3:51:00<60:59:30,  1.29it/s]

{'loss': 5.849, 'learning_rate': 9.401436552274543e-06, 'epoch': 1.8}


  6%|▋         | 19000/300720 [4:03:53<61:21:48,  1.28it/s]

{'loss': 5.9265, 'learning_rate': 9.368183027400906e-06, 'epoch': 1.9}


  7%|▋         | 20000/300720 [4:16:50<58:34:13,  1.33it/s]

{'loss': 5.7143, 'learning_rate': 9.334929502527269e-06, 'epoch': 2.0}


  7%|▋         | 20048/300720 [4:17:27<50:14:45,  1.55it/s]
  7%|▋         | 20048/300720 [4:18:39<50:14:45,  1.55it/s]

{'eval_loss': 13.6866455078125, 'eval_mse': 0.1524030566215515, 'eval_rmse': 0.3903883397579193, 'eval_runtime': 72.3349, 'eval_samples_per_second': 62.211, 'eval_steps_per_second': 1.949, 'epoch': 2.0}


  7%|▋         | 21000/300720 [4:30:45<58:29:06,  1.33it/s]  

{'loss': 4.9689, 'learning_rate': 9.301675977653633e-06, 'epoch': 2.09}


  7%|▋         | 22000/300720 [4:43:28<58:10:20,  1.33it/s]

{'loss': 5.0352, 'learning_rate': 9.268422452779996e-06, 'epoch': 2.19}


  8%|▊         | 23000/300720 [4:56:11<58:07:52,  1.33it/s]

{'loss': 4.979, 'learning_rate': 9.235168927906359e-06, 'epoch': 2.29}


  8%|▊         | 24000/300720 [5:08:57<55:11:19,  1.39it/s]

{'loss': 5.1343, 'learning_rate': 9.201915403032722e-06, 'epoch': 2.39}


  8%|▊         | 25000/300720 [5:21:42<58:47:23,  1.30it/s]

{'loss': 5.1269, 'learning_rate': 9.168661878159085e-06, 'epoch': 2.49}


  9%|▊         | 26000/300720 [5:34:27<58:11:03,  1.31it/s]

{'loss': 4.9338, 'learning_rate': 9.13540835328545e-06, 'epoch': 2.59}


  9%|▉         | 27000/300720 [5:47:16<57:34:34,  1.32it/s]

{'loss': 5.2179, 'learning_rate': 9.102154828411812e-06, 'epoch': 2.69}


  9%|▉         | 28000/300720 [6:00:06<57:45:03,  1.31it/s]

{'loss': 5.0709, 'learning_rate': 9.068901303538175e-06, 'epoch': 2.79}


 10%|▉         | 29000/300720 [6:12:59<61:15:53,  1.23it/s]

{'loss': 5.0444, 'learning_rate': 9.03564777866454e-06, 'epoch': 2.89}


 10%|▉         | 30000/300720 [6:25:52<58:57:52,  1.28it/s]

{'loss': 4.9683, 'learning_rate': 9.002394253790902e-06, 'epoch': 2.99}


 10%|█         | 30072/300720 [6:26:48<51:04:49,  1.47it/s]
 10%|█         | 30072/300720 [6:28:00<51:04:49,  1.47it/s]

{'eval_loss': 14.695479393005371, 'eval_mse': 0.15725095570087433, 'eval_rmse': 0.3965488076210022, 'eval_runtime': 71.8213, 'eval_samples_per_second': 62.655, 'eval_steps_per_second': 1.963, 'epoch': 3.0}


 10%|█         | 30072/300720 [6:28:01<58:12:17,  1.29it/s]


{'train_runtime': 23282.0047, 'train_samples_per_second': 413.288, 'train_steps_per_second': 12.916, 'train_loss': 6.025541778954143, 'epoch': 3.0}


100%|██████████| 141/141 [01:07<00:00,  2.08it/s]
[32m[I 2023-05-12 06:37:31,724][0m Trial 9 finished with value: 0.38324451446533203 and parameters: {'band_width': 0.17081934079556765}. Best is trial 2 with value: 0.38199731707572937.[0m


{'test_loss': 13.464204788208008, 'test_mse': 0.14687635004520416, 'test_rmse': 0.38324451446533203, 'test_runtime': 68.328, 'test_samples_per_second': 65.859, 'test_steps_per_second': 2.064}


In [15]:
study.best_params

{'band_width': 0.13604422505592878}