In [1]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
from transformers import ViTFeatureExtractor

model_ckpt = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_ckpt)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
images_root = 'Affectnet/Manually_Annotated/Manually_Annotated_Images'

In [4]:
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import os


def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class AffectNetDataset(Dataset):
    def __init__(self,
                 csvfile,
                 root,
                 mode='classification',
                 crop=False,
                 transform=None,
                 invalid_files=None):
        self.df = pd.read_csv(csvfile)
        self.root = root
        self.mode = mode
        self.crop = crop
        self.transform = transform
        self.invalid_files = invalid_files
        
        if self.invalid_files:
            self.df = self.df[~self.df['subDirectory_filePath'].isin(invalid_files)]
            self.df = self.df.reset_index(drop=True)
    
    def __getitem__(self, idx):
        img = pil_loader(os.path.join(self.root, self.df['subDirectory_filePath'][idx]))
        if self.crop:
            img = img.crop((self.df['face_x'][idx],
                            self.df['face_y'][idx],
                            self.df['face_x'][idx]+self.df['face_width'][idx],
                            self.df['face_y'][idx]+self.df['face_height'][idx],))
        if self.transform:
            img = self.transform(img)
        if self.mode == 'classification':
            target = torch.tensor(self.df['expression'][idx])
        elif self.mode == 'valence':
            target = torch.tensor([self.df['valence'][idx]])
        elif self.mode == 'arousal':
            target = torch.tensor([self.df['arousal'][idx]])
        else:
            target = torch.tensor([self.df['valence'][idx],
                                   self.df['arousal'][idx]])
        return img.float(), target.float()
    
    def __len__(self):
        return len(self.df)

In [5]:
from tqdm import tqdm

import pandas as pd

train_df = pd.read_csv('Affectnet/training.csv')
val_df = pd.read_csv('Affectnet/validation.csv')

def check_files(df):
    invalid_files = []
    for filename in tqdm(df['subDirectory_filePath']):
        try:
            pil_loader(os.path.join(images_root, filename))
        except:
            invalid_files.append(filename)
    print(invalid_files)
    return invalid_files

# train_invalid_files = check_files(train_df)
# val_invalid_files = check_files(val_df)

In [6]:
train_invalid_files = ['103/29a31ebf1567693f4644c8ba3476ca9a72ee07fe67a5860d98707a0a.jpg']
val_invalid_files = []

In [7]:
mode = 'arousal'
val_size = 1000
seed = 0

In [8]:
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    ToTensor)
from torch.utils.data import random_split


normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

transform = Compose([Resize(tuple(feature_extractor.size.values())),
                     ToTensor()])

train_dataset = AffectNetDataset('Affectnet/training.csv',
                                 images_root,
                                 mode,
                                 crop=True,
                                 transform=transform,
                                 invalid_files=train_invalid_files)

val_dataset = AffectNetDataset('Affectnet/validation.csv',
                               images_root,
                               mode,
                               crop=True,
                               transform=transform,
                               invalid_files=val_invalid_files)

print('train:', len(train_dataset))
print('validation:', len(val_dataset))


train: 414798
validation: 5500


In [9]:
def collate_fn(examples):
    imgs, targets = zip(*examples)
    pixel_values = torch.stack(imgs)
    targets = torch.stack(targets)
    return {'pixel_values': pixel_values, 'labels': targets}

In [10]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4)
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([4, 3, 224, 224])
labels torch.Size([4, 1])


In [11]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=1,
                                                  problem_type='regression')


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Balanced MSE
- paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.pdf
- github: https://github.com/jiawei-ren/BalancedMSE/tree/main

Batch-based Monte-Carlo (BMC)を使う

In [12]:
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.distributions import MultivariateNormal as MVN

def bmc_loss_md(pred, target, noise_var, device):
    I = torch.eye(pred.shape[-1]).to(device)
    logits = MVN(pred.unsqueeze(1), noise_var*I).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]).to(device))
    loss = loss * (2 * noise_var).detach()
    
    return loss

class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma=1., device=None, root=False):
        super(BMCLoss, self).__init__()
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))
        self.device = device
        self.root = root
    
    def forward(self, pred, target):
        noise_var = self.noise_sigma ** 2
        loss = bmc_loss_md(pred, target, noise_var, self.device)
        return torch.sqrt(loss) if self.root else loss

In [13]:
from transformers import Trainer

class BMCLossTrainer(Trainer):
    def __init__(self,
                 model = None,
                 args = None,
                 data_collator = None,
                 train_dataset = None,
                 eval_dataset = None,
                 tokenizer = None,
                 model_init = None,
                 compute_metrics = None,
                 callbacks = None,
                 optimizers = (None, None),
                 preprocess_logits_for_metrics = None):
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
        self.loss_fct = BMCLoss(device=self.args.device).to(self.args.device)
    
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get('labels')
        outputs = model(**inputs)
        logits = outputs.get('logits')
        loss = self.loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [14]:
from transformers import TrainingArguments
import wandb

wandb.init(project='AffectNet-vit', name='arousal')

args = TrainingArguments(
    f"affectnet-balancedMSE-aro",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=1e-3,
    load_best_model_at_end=True,
    logging_dir='logs',
    remove_unused_columns=False,
    report_to='wandb'
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrkn[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
from sklearn.metrics import mean_squared_error

def compute_metrics(eval_pred):
    preds, targets = eval_pred
    rmse = mean_squared_error(targets, preds, squared=False) / 2
    return {'rmse': rmse}

class ComputeMetrics(object):
    def __init__(self):
        self.metrics = BMCLoss(device=device).to(device)
    
    def __call__(self, eval_pred):
        preds, targets = eval_pred
        preds, targets = torch.tensor(preds).to(device), torch.tensor(targets).to(device)
        bmse = self.metrics(preds, targets)
        rmse = compute_metrics(eval_pred)
        return {'bmse': bmse, 'rmse': rmse}

compute_bmse_metrics = ComputeMetrics()


In [16]:
from transformers import EarlyStoppingCallback

trainer = BMCLossTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # compute_metrics=compute_bmse_metrics,
    tokenizer=feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

In [17]:
trainer.train()

  0%|          | 500/388890 [06:57<100:09:55,  1.08it/s]

{'loss': 7.0826, 'learning_rate': 9.987142893877447e-07, 'epoch': 0.04}


  0%|          | 1000/388890 [14:32<109:42:22,  1.02s/it]

{'loss': 6.8168, 'learning_rate': 9.97428578775489e-07, 'epoch': 0.08}


  0%|          | 1500/388890 [22:22<106:33:36,  1.01it/s]

{'loss': 6.7, 'learning_rate': 9.961428681632338e-07, 'epoch': 0.12}


  1%|          | 2000/388890 [30:14<108:47:48,  1.01s/it]

{'loss': 6.6521, 'learning_rate': 9.948571575509785e-07, 'epoch': 0.15}


  1%|          | 2500/388890 [38:09<98:09:01,  1.09it/s] 

{'loss': 6.6248, 'learning_rate': 9.93571446938723e-07, 'epoch': 0.19}


  1%|          | 3000/388890 [46:06<100:42:31,  1.06it/s]

{'loss': 6.6148, 'learning_rate': 9.922857363264676e-07, 'epoch': 0.23}


  1%|          | 3500/388890 [54:03<106:48:19,  1.00it/s]

{'loss': 6.5777, 'learning_rate': 9.910000257142122e-07, 'epoch': 0.27}


  1%|          | 4000/388890 [1:02:15<101:46:09,  1.05it/s]

{'loss': 6.555, 'learning_rate': 9.89714315101957e-07, 'epoch': 0.31}


  1%|          | 4500/388890 [1:10:15<100:41:25,  1.06it/s]

{'loss': 6.538, 'learning_rate': 9.884286044897014e-07, 'epoch': 0.35}


  1%|▏         | 5000/388890 [1:18:13<104:31:15,  1.02it/s]

{'loss': 6.5529, 'learning_rate': 9.87142893877446e-07, 'epoch': 0.39}


  1%|▏         | 5500/388890 [1:26:07<101:33:28,  1.05it/s]

{'loss': 6.5462, 'learning_rate': 9.858571832651907e-07, 'epoch': 0.42}


  2%|▏         | 6000/388890 [1:34:00<96:28:18,  1.10it/s] 

{'loss': 6.5471, 'learning_rate': 9.845714726529353e-07, 'epoch': 0.46}


  2%|▏         | 6500/388890 [1:41:58<96:14:52,  1.10it/s] 

{'loss': 6.5356, 'learning_rate': 9.832857620406798e-07, 'epoch': 0.5}


  2%|▏         | 7000/388890 [1:49:53<98:42:06,  1.07it/s] 

{'loss': 6.5169, 'learning_rate': 9.820000514284246e-07, 'epoch': 0.54}


  2%|▏         | 7500/388890 [1:57:49<101:30:57,  1.04it/s]

{'loss': 6.519, 'learning_rate': 9.80714340816169e-07, 'epoch': 0.58}


  2%|▏         | 8000/388890 [2:05:42<97:04:53,  1.09it/s] 

{'loss': 6.5261, 'learning_rate': 9.794286302039136e-07, 'epoch': 0.62}


  2%|▏         | 8500/388890 [2:13:35<100:45:30,  1.05it/s]

{'loss': 6.4937, 'learning_rate': 9.781429195916584e-07, 'epoch': 0.66}


  2%|▏         | 9000/388890 [2:21:34<96:45:17,  1.09it/s] 

{'loss': 6.5018, 'learning_rate': 9.76857208979403e-07, 'epoch': 0.69}


  2%|▏         | 9500/388890 [2:29:31<96:54:36,  1.09it/s] 

{'loss': 6.5083, 'learning_rate': 9.755714983671475e-07, 'epoch': 0.73}


  3%|▎         | 10000/388890 [2:37:29<96:43:06,  1.09it/s]

{'loss': 6.509, 'learning_rate': 9.742857877548922e-07, 'epoch': 0.77}


  3%|▎         | 10500/388890 [2:45:24<96:11:09,  1.09it/s] 

{'loss': 6.5175, 'learning_rate': 9.730000771426368e-07, 'epoch': 0.81}


  3%|▎         | 11000/388890 [2:53:17<99:22:50,  1.06it/s] 

{'loss': 6.488, 'learning_rate': 9.717143665303813e-07, 'epoch': 0.85}


  3%|▎         | 11500/388890 [3:01:14<101:23:52,  1.03it/s]

{'loss': 6.5001, 'learning_rate': 9.70428655918126e-07, 'epoch': 0.89}


  3%|▎         | 12000/388890 [3:09:07<105:52:37,  1.01s/it]

{'loss': 6.5057, 'learning_rate': 9.691429453058706e-07, 'epoch': 0.93}


  3%|▎         | 12500/388890 [3:17:01<101:45:59,  1.03it/s]

{'loss': 6.49, 'learning_rate': 9.678572346936151e-07, 'epoch': 0.96}


                                                            
  3%|▎         | 12963/388890 [3:26:19<86:46:03,  1.20it/s]

{'eval_loss': 5.432745933532715, 'eval_rmse': 0.49152782559394836, 'eval_runtime': 120.2544, 'eval_samples_per_second': 45.736, 'eval_steps_per_second': 2.861, 'epoch': 1.0}


  3%|▎         | 13000/388890 [3:27:04<94:20:08,  1.11it/s]  

{'loss': 6.5005, 'learning_rate': 9.665715240813599e-07, 'epoch': 1.0}


  3%|▎         | 13500/388890 [3:34:43<90:40:08,  1.15it/s] 

{'loss': 6.498, 'learning_rate': 9.652858134691044e-07, 'epoch': 1.04}


  4%|▎         | 14000/388890 [3:42:32<92:04:26,  1.13it/s] 

{'loss': 6.4771, 'learning_rate': 9.64000102856849e-07, 'epoch': 1.08}


  4%|▎         | 14500/388890 [3:50:26<94:09:10,  1.10it/s] 

{'loss': 6.4641, 'learning_rate': 9.627143922445935e-07, 'epoch': 1.12}


  4%|▍         | 15000/388890 [3:58:57<96:46:19,  1.07it/s] 

{'loss': 6.4719, 'learning_rate': 9.614286816323382e-07, 'epoch': 1.16}


  4%|▍         | 15500/388890 [4:06:56<96:41:28,  1.07it/s] 

{'loss': 6.4833, 'learning_rate': 9.601429710200828e-07, 'epoch': 1.2}


  4%|▍         | 16000/388890 [4:14:52<94:30:37,  1.10it/s] 

{'loss': 6.4837, 'learning_rate': 9.588572604078273e-07, 'epoch': 1.23}


  4%|▍         | 16500/388890 [4:23:07<95:05:21,  1.09it/s] 

{'loss': 6.4788, 'learning_rate': 9.57571549795572e-07, 'epoch': 1.27}


  4%|▍         | 17000/388890 [4:31:02<98:48:47,  1.05it/s] 

{'loss': 6.4658, 'learning_rate': 9.562858391833166e-07, 'epoch': 1.31}


  4%|▍         | 17500/388890 [4:38:58<96:45:33,  1.07it/s] 

{'loss': 6.4574, 'learning_rate': 9.550001285710611e-07, 'epoch': 1.35}


  5%|▍         | 18000/388890 [4:46:57<97:31:59,  1.06it/s] 

{'loss': 6.4672, 'learning_rate': 9.53714417958806e-07, 'epoch': 1.39}


  5%|▍         | 18500/388890 [4:54:50<91:06:18,  1.13it/s] 

{'loss': 6.483, 'learning_rate': 9.524287073465503e-07, 'epoch': 1.43}


  5%|▍         | 19000/388890 [5:02:45<97:33:45,  1.05it/s] 

{'loss': 6.4617, 'learning_rate': 9.51142996734295e-07, 'epoch': 1.47}


  5%|▌         | 19500/388890 [5:10:45<100:10:55,  1.02it/s]

{'loss': 6.468, 'learning_rate': 9.498572861220396e-07, 'epoch': 1.5}


  5%|▌         | 20000/388890 [5:18:43<103:12:10,  1.01s/it]

{'loss': 6.4581, 'learning_rate': 9.485715755097842e-07, 'epoch': 1.54}


  5%|▌         | 20500/388890 [5:26:41<101:07:25,  1.01it/s]

{'loss': 6.4844, 'learning_rate': 9.472858648975288e-07, 'epoch': 1.58}


  5%|▌         | 21000/388890 [5:34:37<96:25:29,  1.06it/s] 

{'loss': 6.4628, 'learning_rate': 9.460001542852734e-07, 'epoch': 1.62}


  6%|▌         | 21500/388890 [5:42:34<94:49:51,  1.08it/s] 

{'loss': 6.4429, 'learning_rate': 9.44714443673018e-07, 'epoch': 1.66}


  6%|▌         | 22000/388890 [5:50:28<99:36:42,  1.02it/s] 

{'loss': 6.4575, 'learning_rate': 9.434287330607626e-07, 'epoch': 1.7}


  6%|▌         | 22500/388890 [5:58:24<96:56:29,  1.05it/s] 

{'loss': 6.4673, 'learning_rate': 9.421430224485073e-07, 'epoch': 1.74}


  6%|▌         | 23000/388890 [6:06:22<98:54:42,  1.03it/s] 

{'loss': 6.4486, 'learning_rate': 9.408573118362518e-07, 'epoch': 1.77}


  6%|▌         | 23500/388890 [6:14:19<100:40:58,  1.01it/s]

{'loss': 6.4349, 'learning_rate': 9.395716012239965e-07, 'epoch': 1.81}


  6%|▌         | 24000/388890 [6:22:18<97:28:05,  1.04it/s] 

{'loss': 6.4411, 'learning_rate': 9.38285890611741e-07, 'epoch': 1.85}


  6%|▋         | 24500/388890 [6:30:10<96:08:16,  1.05it/s] 

{'loss': 6.4499, 'learning_rate': 9.370001799994856e-07, 'epoch': 1.89}


  6%|▋         | 25000/388890 [6:38:02<91:20:15,  1.11it/s] 

{'loss': 6.4334, 'learning_rate': 9.357144693872303e-07, 'epoch': 1.93}


  7%|▋         | 25500/388890 [6:45:53<91:53:34,  1.10it/s] 

{'loss': 6.449, 'learning_rate': 9.344287587749748e-07, 'epoch': 1.97}


                                                            
  7%|▋         | 25926/388890 [6:54:38<89:44:35,  1.12it/s]

{'eval_loss': 5.396358489990234, 'eval_rmse': 0.4805545210838318, 'eval_runtime': 120.7844, 'eval_samples_per_second': 45.536, 'eval_steps_per_second': 2.848, 'epoch': 2.0}


  7%|▋         | 26000/388890 [6:55:58<89:32:20,  1.13it/s]  

{'loss': 6.423, 'learning_rate': 9.331430481627195e-07, 'epoch': 2.01}


  7%|▋         | 26500/388890 [7:03:45<91:45:57,  1.10it/s] 

{'loss': 6.4315, 'learning_rate': 9.318573375504641e-07, 'epoch': 2.04}


  7%|▋         | 27000/388890 [7:11:32<90:00:28,  1.12it/s] 

{'loss': 6.4341, 'learning_rate': 9.305716269382087e-07, 'epoch': 2.08}


  7%|▋         | 27500/388890 [7:19:33<96:02:27,  1.05it/s] 

{'loss': 6.405, 'learning_rate': 9.292859163259533e-07, 'epoch': 2.12}


  7%|▋         | 28000/388890 [7:27:31<96:25:15,  1.04it/s] 

{'loss': 6.4305, 'learning_rate': 9.280002057136979e-07, 'epoch': 2.16}


  7%|▋         | 28500/388890 [7:35:45<101:19:15,  1.01s/it]

{'loss': 6.4334, 'learning_rate': 9.267144951014425e-07, 'epoch': 2.2}


  7%|▋         | 29000/388890 [7:43:49<88:48:45,  1.13it/s] 

{'loss': 6.4306, 'learning_rate': 9.254287844891871e-07, 'epoch': 2.24}


  8%|▊         | 29500/388890 [7:51:48<89:39:42,  1.11it/s] 

{'loss': 6.4287, 'learning_rate': 9.241430738769317e-07, 'epoch': 2.28}


  8%|▊         | 30000/388890 [8:00:09<93:03:42,  1.07it/s] 

{'loss': 6.4301, 'learning_rate': 9.228573632646763e-07, 'epoch': 2.31}


  8%|▊         | 30500/388890 [8:08:10<102:20:15,  1.03s/it]

{'loss': 6.4378, 'learning_rate': 9.21571652652421e-07, 'epoch': 2.35}


  8%|▊         | 31000/388890 [8:16:09<94:30:57,  1.05it/s] 

{'loss': 6.4189, 'learning_rate': 9.202859420401655e-07, 'epoch': 2.39}


  8%|▊         | 31500/388890 [8:24:07<90:31:57,  1.10it/s] 

{'loss': 6.4227, 'learning_rate': 9.190002314279101e-07, 'epoch': 2.43}


  8%|▊         | 32000/388890 [8:32:03<92:14:11,  1.07it/s] 

{'loss': 6.4317, 'learning_rate': 9.177145208156548e-07, 'epoch': 2.47}


  8%|▊         | 32500/388890 [8:40:01<97:57:54,  1.01it/s] 

{'loss': 6.4377, 'learning_rate': 9.164288102033993e-07, 'epoch': 2.51}


  8%|▊         | 33000/388890 [8:48:00<91:27:41,  1.08it/s] 

{'loss': 6.4203, 'learning_rate': 9.15143099591144e-07, 'epoch': 2.55}


  9%|▊         | 33500/388890 [8:56:00<106:28:32,  1.08s/it]

{'loss': 6.4309, 'learning_rate': 9.138573889788886e-07, 'epoch': 2.58}


  9%|▊         | 34000/388890 [9:03:59<93:50:57,  1.05it/s] 

{'loss': 6.42, 'learning_rate': 9.125716783666332e-07, 'epoch': 2.62}


  9%|▉         | 34500/388890 [9:12:00<92:26:01,  1.06it/s] 

{'loss': 6.4112, 'learning_rate': 9.112859677543778e-07, 'epoch': 2.66}


  9%|▉         | 35000/388890 [9:19:59<90:59:36,  1.08it/s] 

{'loss': 6.4411, 'learning_rate': 9.100002571421223e-07, 'epoch': 2.7}


  9%|▉         | 35500/388890 [9:27:54<97:40:45,  1.00it/s] 

{'loss': 6.4184, 'learning_rate': 9.08714546529867e-07, 'epoch': 2.74}


  9%|▉         | 36000/388890 [9:35:56<93:09:16,  1.05it/s] 

{'loss': 6.4179, 'learning_rate': 9.074288359176116e-07, 'epoch': 2.78}


  9%|▉         | 36500/388890 [9:43:55<92:55:12,  1.05it/s] 

{'loss': 6.429, 'learning_rate': 9.061431253053562e-07, 'epoch': 2.82}


 10%|▉         | 37000/388890 [9:51:55<96:12:15,  1.02it/s] 

{'loss': 6.4178, 'learning_rate': 9.048574146931008e-07, 'epoch': 2.85}


 10%|▉         | 37500/388890 [9:59:56<87:10:00,  1.12it/s] 

{'loss': 6.4101, 'learning_rate': 9.035717040808455e-07, 'epoch': 2.89}


 10%|▉         | 38000/388890 [10:07:54<90:52:56,  1.07it/s] 

{'loss': 6.4018, 'learning_rate': 9.0228599346859e-07, 'epoch': 2.93}


 10%|▉         | 38500/388890 [10:15:55<88:17:46,  1.10it/s] 

{'loss': 6.4205, 'learning_rate': 9.010002828563346e-07, 'epoch': 2.97}


                                                             
 10%|█         | 38889/388890 [10:24:09<84:12:11,  1.15it/s]

{'eval_loss': 5.388031482696533, 'eval_rmse': 0.5169902443885803, 'eval_runtime': 121.0701, 'eval_samples_per_second': 45.428, 'eval_steps_per_second': 2.841, 'epoch': 3.0}


 10%|█         | 39000/388890 [10:26:00<81:18:03,  1.20it/s]  

{'loss': 6.4143, 'learning_rate': 8.997145722440793e-07, 'epoch': 3.01}


 10%|█         | 39500/388890 [10:33:33<84:13:00,  1.15it/s]

{'loss': 6.3934, 'learning_rate': 8.984288616318238e-07, 'epoch': 3.05}


 10%|█         | 40000/388890 [10:41:17<86:03:10,  1.13it/s] 

{'loss': 6.4021, 'learning_rate': 8.971431510195685e-07, 'epoch': 3.09}


 10%|█         | 40500/388890 [10:49:13<101:56:23,  1.05s/it]

{'loss': 6.3652, 'learning_rate': 8.95857440407313e-07, 'epoch': 3.12}


 11%|█         | 41000/388890 [10:57:32<116:28:06,  1.21s/it]

{'loss': 6.4092, 'learning_rate': 8.945717297950577e-07, 'epoch': 3.16}


 11%|█         | 41500/388890 [11:05:37<101:17:46,  1.05s/it]

{'loss': 6.4015, 'learning_rate': 8.932860191828023e-07, 'epoch': 3.2}


 11%|█         | 42000/388890 [11:13:41<89:39:56,  1.07it/s] 

{'loss': 6.3749, 'learning_rate': 8.920003085705468e-07, 'epoch': 3.24}


 11%|█         | 42500/388890 [11:21:49<93:52:31,  1.02it/s] 

{'loss': 6.4005, 'learning_rate': 8.907145979582915e-07, 'epoch': 3.28}


 11%|█         | 43000/388890 [11:29:50<95:21:44,  1.01it/s] 

{'loss': 6.4118, 'learning_rate': 8.894288873460361e-07, 'epoch': 3.32}


 11%|█         | 43500/388890 [11:37:50<90:43:38,  1.06it/s] 

{'loss': 6.3941, 'learning_rate': 8.881431767337807e-07, 'epoch': 3.36}


 11%|█▏        | 44000/388890 [11:45:53<95:51:40,  1.00s/it] 

{'loss': 6.4048, 'learning_rate': 8.868574661215253e-07, 'epoch': 3.39}


 11%|█▏        | 44500/388890 [11:53:52<91:07:44,  1.05it/s] 

{'loss': 6.4144, 'learning_rate': 8.855717555092698e-07, 'epoch': 3.43}


 12%|█▏        | 45000/388890 [12:01:51<89:50:14,  1.06it/s] 

{'loss': 6.4015, 'learning_rate': 8.842860448970145e-07, 'epoch': 3.47}


 12%|█▏        | 45500/388890 [12:09:50<91:13:44,  1.05it/s] 

{'loss': 6.3783, 'learning_rate': 8.830003342847591e-07, 'epoch': 3.51}


 12%|█▏        | 46000/388890 [12:17:49<91:50:24,  1.04it/s] 

{'loss': 6.399, 'learning_rate': 8.817146236725037e-07, 'epoch': 3.55}


 12%|█▏        | 46500/388890 [12:25:46<87:12:21,  1.09it/s] 

{'loss': 6.388, 'learning_rate': 8.804289130602483e-07, 'epoch': 3.59}


 12%|█▏        | 47000/388890 [12:33:45<89:21:13,  1.06it/s] 

{'loss': 6.4146, 'learning_rate': 8.79143202447993e-07, 'epoch': 3.63}


 12%|█▏        | 47500/388890 [12:41:45<90:03:19,  1.05it/s] 

{'loss': 6.4003, 'learning_rate': 8.778574918357375e-07, 'epoch': 3.66}


 12%|█▏        | 48000/388890 [12:49:46<85:56:03,  1.10it/s] 

{'loss': 6.3783, 'learning_rate': 8.765717812234821e-07, 'epoch': 3.7}


 12%|█▏        | 48500/388890 [12:57:53<85:15:06,  1.11it/s] 

{'loss': 6.3902, 'learning_rate': 8.752860706112268e-07, 'epoch': 3.74}


 13%|█▎        | 49000/388890 [13:05:51<86:45:45,  1.09it/s] 

{'loss': 6.3764, 'learning_rate': 8.740003599989713e-07, 'epoch': 3.78}


 13%|█▎        | 49500/388890 [13:13:50<89:00:10,  1.06it/s] 

{'loss': 6.4013, 'learning_rate': 8.72714649386716e-07, 'epoch': 3.82}


 13%|█▎        | 50000/388890 [13:21:51<92:20:22,  1.02it/s] 

{'loss': 6.4027, 'learning_rate': 8.714289387744605e-07, 'epoch': 3.86}


 13%|█▎        | 50500/388890 [13:29:51<87:34:43,  1.07it/s] 

{'loss': 6.382, 'learning_rate': 8.701432281622052e-07, 'epoch': 3.9}


 13%|█▎        | 51000/388890 [13:37:52<92:34:17,  1.01it/s] 

{'loss': 6.3824, 'learning_rate': 8.688575175499498e-07, 'epoch': 3.93}


 13%|█▎        | 51500/388890 [13:45:50<86:32:30,  1.08it/s] 

{'loss': 6.4158, 'learning_rate': 8.675718069376943e-07, 'epoch': 3.97}


                                                             
 13%|█▎        | 51852/388890 [13:53:27<75:04:46,  1.25it/s]

{'eval_loss': 5.3756866455078125, 'eval_rmse': 0.5580943822860718, 'eval_runtime': 122.1294, 'eval_samples_per_second': 45.034, 'eval_steps_per_second': 2.817, 'epoch': 4.0}


 13%|█▎        | 52000/388890 [13:55:53<78:35:42,  1.19it/s]  

{'loss': 6.3545, 'learning_rate': 8.66286096325439e-07, 'epoch': 4.01}


 13%|█▎        | 52500/388890 [14:03:27<82:10:55,  1.14it/s] 

{'loss': 6.4035, 'learning_rate': 8.650003857131836e-07, 'epoch': 4.05}


 14%|█▎        | 53000/388890 [14:11:08<82:43:34,  1.13it/s]

{'loss': 6.3795, 'learning_rate': 8.637146751009282e-07, 'epoch': 4.09}


 14%|█▍        | 53500/388890 [14:19:04<86:55:36,  1.07it/s] 

{'loss': 6.3849, 'learning_rate': 8.624289644886728e-07, 'epoch': 4.13}


 14%|█▍        | 54000/388890 [14:27:00<88:16:44,  1.05it/s]

{'loss': 6.3624, 'learning_rate': 8.611432538764175e-07, 'epoch': 4.17}


 14%|█▍        | 54500/388890 [14:35:05<92:54:58,  1.00s/it] 

{'loss': 6.3629, 'learning_rate': 8.59857543264162e-07, 'epoch': 4.2}


 14%|█▍        | 55000/388890 [14:43:35<91:54:37,  1.01it/s] 

{'loss': 6.3701, 'learning_rate': 8.585718326519066e-07, 'epoch': 4.24}


 14%|█▍        | 55500/388890 [14:51:39<86:32:18,  1.07it/s] 

{'loss': 6.377, 'learning_rate': 8.572861220396512e-07, 'epoch': 4.28}


 14%|█▍        | 56000/388890 [14:59:40<89:03:16,  1.04it/s] 

{'loss': 6.3598, 'learning_rate': 8.560004114273958e-07, 'epoch': 4.32}


 15%|█▍        | 56500/388890 [15:07:41<87:46:39,  1.05it/s] 

{'loss': 6.3743, 'learning_rate': 8.547147008151405e-07, 'epoch': 4.36}


 15%|█▍        | 57000/388890 [15:15:45<89:54:50,  1.03it/s] 

{'loss': 6.3623, 'learning_rate': 8.53428990202885e-07, 'epoch': 4.4}


 15%|█▍        | 57500/388890 [15:23:43<84:04:22,  1.09it/s]

{'loss': 6.396, 'learning_rate': 8.521432795906297e-07, 'epoch': 4.44}


 15%|█▍        | 58000/388890 [15:31:47<89:34:28,  1.03it/s]

{'loss': 6.377, 'learning_rate': 8.508575689783743e-07, 'epoch': 4.47}


 15%|█▌        | 58500/388890 [15:39:47<87:29:50,  1.05it/s] 

{'loss': 6.3555, 'learning_rate': 8.495718583661188e-07, 'epoch': 4.51}


 15%|█▌        | 59000/388890 [15:47:51<95:39:08,  1.04s/it] 

{'loss': 6.3659, 'learning_rate': 8.482861477538635e-07, 'epoch': 4.55}


 15%|█▌        | 59500/388890 [15:56:00<88:56:47,  1.03it/s] 

{'loss': 6.3932, 'learning_rate': 8.470004371416081e-07, 'epoch': 4.59}


 15%|█▌        | 60000/388890 [16:04:27<94:01:39,  1.03s/it] 

{'loss': 6.3667, 'learning_rate': 8.457147265293527e-07, 'epoch': 4.63}


 16%|█▌        | 60500/388890 [16:13:03<83:18:35,  1.09it/s] 

{'loss': 6.3532, 'learning_rate': 8.444290159170973e-07, 'epoch': 4.67}


 16%|█▌        | 61000/388890 [16:21:45<92:25:48,  1.01s/it] 

{'loss': 6.3665, 'learning_rate': 8.431433053048419e-07, 'epoch': 4.71}


 16%|█▌        | 61500/388890 [16:30:28<104:40:36,  1.15s/it]

{'loss': 6.3603, 'learning_rate': 8.418575946925865e-07, 'epoch': 4.74}


 16%|█▌        | 62000/388890 [16:39:05<89:02:48,  1.02it/s] 

{'loss': 6.366, 'learning_rate': 8.405718840803311e-07, 'epoch': 4.78}


 16%|█▌        | 62500/388890 [16:48:13<111:50:19,  1.23s/it]

{'loss': 6.3587, 'learning_rate': 8.392861734680757e-07, 'epoch': 4.82}


 16%|█▌        | 63000/388890 [16:57:19<100:00:01,  1.10s/it]

{'loss': 6.345, 'learning_rate': 8.380004628558203e-07, 'epoch': 4.86}


 16%|█▋        | 63500/388890 [17:06:38<86:58:07,  1.04it/s] 

{'loss': 6.3712, 'learning_rate': 8.367147522435651e-07, 'epoch': 4.9}


 16%|█▋        | 64000/388890 [17:15:44<130:46:43,  1.45s/it]

{'loss': 6.367, 'learning_rate': 8.354290416313095e-07, 'epoch': 4.94}


 17%|█▋        | 64500/388890 [17:25:07<89:54:04,  1.00it/s] 

{'loss': 6.3641, 'learning_rate': 8.341433310190543e-07, 'epoch': 4.98}


                                                             
 17%|█▋        | 64815/388890 [17:34:50<99:11:14,  1.10s/it]

{'eval_loss': 5.381923675537109, 'eval_rmse': 0.5891162157058716, 'eval_runtime': 122.86, 'eval_samples_per_second': 44.766, 'eval_steps_per_second': 2.8, 'epoch': 5.0}


 17%|█▋        | 65000/388890 [17:38:17<87:41:02,  1.03it/s]  

{'loss': 6.3474, 'learning_rate': 8.328576204067989e-07, 'epoch': 5.01}


 17%|█▋        | 65500/388890 [17:47:11<127:36:07,  1.42s/it]

{'loss': 6.3267, 'learning_rate': 8.315719097945433e-07, 'epoch': 5.05}


 17%|█▋        | 66000/388890 [17:56:44<87:32:27,  1.02it/s] 

{'loss': 6.3519, 'learning_rate': 8.302861991822881e-07, 'epoch': 5.09}


 17%|█▋        | 66500/388890 [18:06:06<91:16:24,  1.02s/it] 

{'loss': 6.3609, 'learning_rate': 8.290004885700325e-07, 'epoch': 5.13}


 17%|█▋        | 67000/388890 [18:15:47<98:05:09,  1.10s/it] 

{'loss': 6.3646, 'learning_rate': 8.277147779577773e-07, 'epoch': 5.17}


 17%|█▋        | 67500/388890 [18:31:38<177:14:58,  1.99s/it]

{'loss': 6.3301, 'learning_rate': 8.264290673455219e-07, 'epoch': 5.21}


 17%|█▋        | 68000/388890 [18:44:28<129:43:20,  1.46s/it]

{'loss': 6.3423, 'learning_rate': 8.251433567332665e-07, 'epoch': 5.25}


 18%|█▊        | 68500/388890 [18:55:39<165:08:10,  1.86s/it]

{'loss': 6.3535, 'learning_rate': 8.238576461210111e-07, 'epoch': 5.28}


 18%|█▊        | 69000/388890 [19:06:00<99:49:36,  1.12s/it] 

{'loss': 6.3403, 'learning_rate': 8.225719355087557e-07, 'epoch': 5.32}


 18%|█▊        | 69500/388890 [19:16:34<89:15:11,  1.01s/it] 

{'loss': 6.3402, 'learning_rate': 8.212862248965003e-07, 'epoch': 5.36}


 18%|█▊        | 70000/388890 [19:27:21<139:39:49,  1.58s/it]

{'loss': 6.3425, 'learning_rate': 8.200005142842449e-07, 'epoch': 5.4}


 18%|█▊        | 70500/388890 [19:37:51<95:39:15,  1.08s/it] 

{'loss': 6.3372, 'learning_rate': 8.187148036719895e-07, 'epoch': 5.44}


 18%|█▊        | 71000/388890 [19:48:22<100:54:10,  1.14s/it]

{'loss': 6.3534, 'learning_rate': 8.174290930597341e-07, 'epoch': 5.48}


 18%|█▊        | 71331/388890 [19:56:03<166:59:03,  1.89s/it]

In [None]:
trainer.save_state()
trainer.save_model()
wandb.finish()

In [None]:
from tqdm import tqdm

def CLE_tokens(model, tokenizer, dataset, device):
    tokens = []
    labels = []
    for img, label in tqdm(dataset):
        feature = tokenizer(img, return_tensors='pt').pixel_values.to(device)
        with torch.no_grad():
            token = model(feature, output_hidden_states=True).hidden_states[-1][0,0,:]
        tokens.append(token.cpu())
        labels.append(label)
    return torch.stack(tokens).squeeze(), torch.stack(labels)

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import random

def plot_tokens(tokens, targets, n_neighbors):
    # tsne = TSNE(n_components=2)
    # zs = tsne.fit_transform(tokens.numpy())
    umap = UMAP(n_neighbors=n_neighbors)
    zs = umap.fit_transform(tokens.numpy())
    ys = targets.numpy()
    print(zs.shape)
    print(ys.shape)
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.set_xlabel('feature-1')
    ax.set_ylabel('feature-2')
    
    for x, y in zip(zs, ys):
        mp = ax.scatter(x[0], x[1],
                        alpha=1,
                        c=y,
                        cmap='Oranges',
                        s=3,)
    fig.colorbar(mp, ax=ax)
    plt.show()


In [None]:
from transformers import ViTForImageClassification

non_finetuned_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                                num_labels=2,
                                                                problem_type='regression')


In [None]:
tokens, targets = CLE_tokens(non_finetuned_model.to(device),
                             feature_extractor,
                             val_dataset,
                             device)
print(tokens.shape, targets.shape)

In [None]:
plot_tokens()