In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

In [None]:
from transformers import ViTFeatureExtractor

# Hub id of the checkpoint whose preprocessing configuration is loaded below.
model_ckpt = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_ckpt,
)

In [None]:
# Directory that the `subDirectory_filePath` values of the annotation CSVs
# are joined to when images are opened.
images_root = (
    'Affectnet/'
    'Manually_Annotated/'
    'Manually_Annotated_Images'
)

In [None]:
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import os


def pil_loader(path):
    """Open the image stored at ``path`` and return it converted to RGB.

    The conversion runs while the file handle is still open, so the returned
    image no longer depends on the handle once the ``with`` block exits.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')

class AffectNetDataset(Dataset):
    """Map-style dataset over an AffectNet annotation CSV.

    Rows listed in ``invalid_files`` and rows whose ``expression`` is 9 or 10
    are dropped at construction time; the remaining rows are re-indexed from 0
    so that integer positions and index labels coincide.

    Args:
        csvfile: Path of the annotation CSV, read with ``pandas.read_csv``.
        root: Directory that ``subDirectory_filePath`` values are joined to.
        mode: Selects the target returned by ``__getitem__``:
            ``'classification'`` -> the ``expression`` value (scalar tensor),
            ``'valence'`` / ``'arousal'`` -> that column as a 1-element tensor,
            anything else -> ``[valence, arousal]`` as a 2-element tensor.
        crop: When true, crop each image to the box given by the ``face_x``,
            ``face_y``, ``face_width`` and ``face_height`` columns.
        transform: Optional callable applied to the (possibly cropped) image.
        invalid_files: Optional collection of ``subDirectory_filePath`` values
            to exclude.
    """

    def __init__(self,
                 csvfile,
                 root,
                 mode='classification',
                 crop=False,
                 transform=None,
                 invalid_files=None):
        self.root = root
        self.mode = mode
        self.crop = crop
        self.transform = transform
        self.invalid_files = invalid_files

        df = pd.read_csv(csvfile)

        # Explicit None/length test: plain truthiness is ambiguous for
        # array-like inputs such as numpy arrays or pandas Series.
        if invalid_files is not None and len(invalid_files) > 0:
            df = df[~df['subDirectory_filePath'].isin(invalid_files)]

        # Drop the rows labelled with expression 9 or 10, then renumber the
        # index so __getitem__ can address rows by 0..len-1.
        df = df[~df['expression'].isin([9, 10])]
        self.df = df.reset_index(drop=True)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for row ``idx``.

        The target is always cast to float. The image is cast to float only
        when it is a tensor; without a tensor-producing ``transform`` the
        image is returned unchanged instead of failing on ``.float()``.

        Raises:
            IndexError: if ``idx`` is not a label of the re-indexed frame
                (this is what terminates plain ``for`` iteration).
        """
        try:
            rel_path = self.df.at[idx, 'subDirectory_filePath']
        except KeyError:
            raise IndexError(idx) from None

        img = pil_loader(os.path.join(self.root, rel_path))

        if self.crop:
            left = self.df.at[idx, 'face_x']
            top = self.df.at[idx, 'face_y']
            img = img.crop((left,
                            top,
                            left + self.df.at[idx, 'face_width'],
                            top + self.df.at[idx, 'face_height']))

        if self.transform:
            img = self.transform(img)

        if self.mode == 'classification':
            # NOTE(review): the class id is cast to float below, exactly as
            # before; confirm the consumer does not expect integer labels.
            target = torch.tensor(self.df.at[idx, 'expression'])
        elif self.mode in ('valence', 'arousal'):
            target = torch.tensor([self.df.at[idx, self.mode]])
        else:
            target = torch.tensor([self.df.at[idx, 'valence'],
                                   self.df.at[idx, 'arousal']])

        # Only tensors have .float(); anything else (e.g. an untransformed
        # PIL image) is handed back as-is.
        if isinstance(img, torch.Tensor):
            img = img.float()
        return img, target.float()

    def __len__(self):
        """Number of rows left after filtering."""
        return len(self.df)

In [None]:
from tqdm import tqdm

import pandas as pd

# Raw annotation tables, consumed by check_files() when the scan is enabled.
train_csv_path = 'Affectnet/training.csv'
val_csv_path = 'Affectnet/validation.csv'

train_df = pd.read_csv(train_csv_path)
val_df = pd.read_csv(val_csv_path)

def check_files(df, root=None):
    """Return the ``subDirectory_filePath`` entries that cannot be loaded.

    Every path in ``df['subDirectory_filePath']`` is joined to ``root`` and
    opened with ``pil_loader``; entries whose load raises are collected,
    printed, and returned.

    Args:
        df: Object whose ``'subDirectory_filePath'`` item is iterable.
        root: Directory the relative paths are joined to. Defaults to the
            notebook-level ``images_root``.

    Returns:
        list: The relative paths that failed to load, in input order.
    """
    if root is None:
        root = images_root

    invalid_files = []
    for filename in tqdm(df['subDirectory_filePath']):
        try:
            pil_loader(os.path.join(root, filename))
        except Exception:
            # `Exception` rather than a bare `except`, so KeyboardInterrupt /
            # SystemExit still stop a long scan instead of being recorded as
            # a broken file.
            invalid_files.append(filename)
    print(invalid_files)
    return invalid_files

# train_invalid_files = check_files(train_df)
# val_invalid_files = check_files(val_df)

In [None]:
# Relative paths excluded from each split when the datasets are built.
train_invalid_files = [
    '103/29a31ebf1567693f4644c8ba3476ca9a72ee07fe67a5860d98707a0a.jpg',
]
val_invalid_files = list()

In [None]:
# Run configuration. `mode` selects the regression target of the datasets.
# NOTE(review): `val_size` and `seed` are not referenced by any cell visible
# in this notebook view — confirm whether they are still needed.
mode, val_size, seed = 'arousal', 1000, 0

In [None]:
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    ToTensor)
from torch.utils.data import random_split


# Per-channel statistics taken from the checkpoint's preprocessing config.
# NOTE(review): `normalize` is constructed here but is not part of the
# `Compose` below, so images reach the model un-normalized — confirm whether
# that is intended before changing it (later cells reuse these datasets).
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Resize to the size stored in the feature extractor, then convert to a tensor.
target_size = tuple(feature_extractor.size.values())
transform = Compose([Resize(target_size), ToTensor()])

# Settings shared by both splits.
shared_dataset_kwargs = dict(root=images_root,
                             mode=mode,
                             crop=False,
                             transform=transform)

train_dataset = AffectNetDataset(csvfile='Affectnet/training.csv',
                                 invalid_files=train_invalid_files,
                                 **shared_dataset_kwargs)

val_dataset = AffectNetDataset(csvfile='Affectnet/validation.csv',
                               invalid_files=val_invalid_files,
                               **shared_dataset_kwargs)

print('train:', len(train_dataset))
print('validation:', len(val_dataset))


In [None]:
def collate_fn(examples):
    """Stack ``(image, target)`` pairs into the batch dict the model consumes.

    Returns a dict with ``'pixel_values'`` (stacked images) and ``'labels'``
    (stacked targets), each with a new leading batch axis.
    """
    image_list, target_list = map(list, zip(*examples))
    return dict(pixel_values=torch.stack(image_list),
                labels=torch.stack(target_list))

In [None]:
from torch.utils.data import DataLoader

# Smoke test: pull one small batch and show the shape of every tensor in it.
train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(train_dataloader))
for key, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(key, value.shape)

In [None]:
from transformers import ViTForImageClassification

# Load the backbone from the same checkpoint id as the feature extractor
# (`model_ckpt`) so the two cannot drift apart; a single-output regression
# head is attached on top.
model = ViTForImageClassification.from_pretrained(model_ckpt,
                                                  num_labels=1,
                                                  problem_type='regression')


## Balanced MSE
- paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.pdf
- github: https://github.com/jiawei-ren/BalancedMSE/tree/main

This notebook uses the Batch-based Monte-Carlo (BMC) variant of the Balanced MSE loss.

In [None]:
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.distributions import MultivariateNormal as MVN

def bmc_loss_md(pred, target, noise_var, device=None):
    """Batch-based Monte-Carlo (BMC) Balanced-MSE loss.

    Every prediction is scored against every target in the batch under an
    isotropic Gaussian with variance ``noise_var``; the resulting square
    matrix is used as logits of a cross-entropy whose correct class for
    row ``i`` is column ``i``.

    Args:
        pred: Predictions; ``pred.shape[0]`` is used as the batch size and
            ``pred.shape[-1]`` as the target dimensionality
            (assumes a 2-D ``(batch, dim)`` tensor — TODO confirm).
        target: Targets, assumed to have the same layout as ``pred``.
        noise_var: Tensor holding the noise variance (must be a tensor, since
            ``.detach()`` is called on it).
        device: Device for the helper tensors. ``None`` (the default) uses
            ``pred.device`` so the helpers always live next to the inputs.

    Returns:
        Scalar loss tensor.
    """
    if device is None:
        device = pred.device

    identity = torch.eye(pred.shape[-1], device=device)
    # pred -> (batch, 1, dim), target -> (1, batch, dim): broadcasting yields
    # one log-probability per (prediction, target) pair.
    logits = MVN(pred.unsqueeze(1), noise_var * identity).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0], device=device))
    # Rescale by 2 * variance; detached so the scale factor itself does not
    # contribute a gradient to the noise parameter.
    loss = loss * (2 * noise_var).detach()

    return loss


class BMCLoss(_Loss):
    """Module wrapper around ``bmc_loss_md`` holding the noise scale.

    Args:
        init_noise_sigma: Initial value of the ``noise_sigma`` parameter.
        device: Device forwarded to ``bmc_loss_md``; ``None`` lets the loss
            follow the device of the predictions.
        root: When true, ``forward`` returns the square root of the loss.
    """

    def __init__(self, init_noise_sigma=1., device=None, root=False):
        super().__init__()
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))
        self.device = device
        self.root = root

    def forward(self, pred, target):
        """Return the BMC loss (or its square root when ``root`` is set)."""
        noise_var = self.noise_sigma ** 2
        loss = bmc_loss_md(pred, target, noise_var, self.device)
        return torch.sqrt(loss) if self.root else loss

In [None]:
from transformers import Trainer

class BMCLossTrainer(Trainer):
    """``Trainer`` that optimises the BMC Balanced-MSE loss.

    The constructor accepts the same arguments as before and forwards them to
    ``Trainer`` by keyword, so each value stays bound to the intended
    parameter even if the base signature gains or reorders parameters.

    NOTE(review): ``self.loss_fct.noise_sigma`` is an ``nn.Parameter`` that
    lives outside ``model``; nothing in this class hands it to the optimizer,
    so it presumably stays at its initial value — confirm that is intended.
    """

    def __init__(self,
                 model = None,
                 args = None,
                 data_collator = None,
                 train_dataset = None,
                 eval_dataset = None,
                 tokenizer = None,
                 model_init = None,
                 compute_metrics = None,
                 callbacks = None,
                 optimizers = (None, None),
                 preprocess_logits_for_metrics = None):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # Build the loss on the device the Trainer resolved from `args`.
        self.loss_fct = BMCLoss(device=self.args.device).to(self.args.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the BMC loss between the model's logits and ``labels``.

        ``**kwargs`` absorbs extra keyword arguments that newer ``Trainer``
        versions may pass (e.g. ``num_items_in_batch``); they are not used.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        labels = inputs.get('labels')
        outputs = model(**inputs)
        logits = outputs.get('logits')
        loss = self.loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [None]:
from transformers import TrainingArguments
import wandb

# Open the Weights & Biases run that the Trainer reports to.
wandb.init(project='AffectNet-vit', name='arousal')

# Checkpointing and evaluation both happen once per epoch, which
# `load_best_model_at_end` requires to line up.
args = TrainingArguments(
    output_dir="affectnet-balancedMSE-aro-no910",
    # optimisation
    learning_rate=1e-6,
    weight_decay=1e-3,
    num_train_epochs=30,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    # checkpointing / evaluation cadence
    save_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    # logging
    logging_dir='logs',
    report_to='wandb',
    # keep every key produced by the custom collator
    remove_unused_columns=False,
)

In [None]:
from sklearn.metrics import mean_squared_error

def compute_metrics(eval_pred):
    """Return the halved RMSE between predictions and targets.

    The RMSE is computed per output column and averaged over columns, which
    is what ``mean_squared_error(..., squared=False)`` returned; it is done
    with numpy because the ``squared`` keyword is deprecated and removed in
    newer scikit-learn releases.

    Args:
        eval_pred: Pair ``(predictions, targets)`` of array-likes with the
            same shape.

    Returns:
        dict: ``{'rmse': value}`` where ``value`` is the RMSE divided by 2.
        NOTE(review): the division by 2 is kept from the original —
        presumably it normalises by the width of a [-1, 1] label range;
        confirm.

    Raises:
        ValueError: if predictions and targets do not have matching shapes.
    """
    import numpy as np  # local import: numpy is not imported at notebook level

    preds, targets = eval_pred
    preds = np.asarray(preds, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)

    # View both as (n_samples, n_outputs) so 1-D and 2-D inputs share a path.
    n_samples = targets.shape[0]
    preds = preds.reshape(n_samples, -1)
    targets = targets.reshape(n_samples, -1)
    if preds.shape != targets.shape:
        raise ValueError(
            'predictions and targets must have the same shape, got '
            f'{preds.shape} and {targets.shape}'
        )

    per_output_rmse = np.sqrt(((targets - preds) ** 2).mean(axis=0))
    rmse = float(per_output_rmse.mean()) / 2
    return {'rmse': rmse}

# class ComputeMetrics(object):
#     def __init__(self):
#         self.metrics = BMCLoss(device=device).to(device)
    
#     def __call__(self, eval_pred):
#         preds, targets = eval_pred
#         preds, targets = torch.tensor(preds).to(device), torch.tensor(targets).to(device)
#         bmse = self.metrics(preds, targets)
#         rmse = compute_metrics(eval_pred)
#         return {'bmse': bmse, 'rmse': rmse}

# compute_bmse_metrics = ComputeMetrics()


In [None]:
from transformers import EarlyStoppingCallback

# Wire model, data, metric and early stopping into the BMC-loss trainer.
# Early stopping ends the run after 2 evaluations without improvement.
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

trainer = BMCLossTrainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# Run fine-tuning with the trainer built in the previous cell; the value
# returned by train() is displayed as this cell's output.
trainer.train()

In [None]:
# Write the trainer state and the model to disk via the Trainer API, then
# close the Weights & Biases run opened before training.
trainer.save_state()
trainer.save_model()
wandb.finish()

In [None]:
from tqdm import tqdm

def CLE_tokens(model, tokenizer, dataset, device):
    """Collect one embedding per sample from the model's last hidden state.

    For every ``(img, label)`` pair the model is run without gradients and
    position 0 of the last hidden state is kept (presumably the [CLS] token
    for ViT models).

    Images that are already tensors are fed to the model directly: they were
    produced by the dataset's own transform, and sending them through the
    feature extractor again would rescale/normalise them a second time. Any
    other image type is still preprocessed with ``tokenizer``.

    Args:
        model: Model called as ``model(pixel_values, output_hidden_states=True)``.
        tokenizer: Feature extractor used only for non-tensor images.
        dataset: Iterable of ``(img, label)`` pairs; tensor images are assumed
            to be a single un-batched image (a batch axis is added here).
        device: Device the inputs are moved to; the model is expected to be
            on it already.

    Returns:
        tuple: ``(tokens, labels)`` — ``tokens`` has one row per sample (the
        batch axis is kept even for a single sample), ``labels`` is the
        stacked labels.
    """
    was_training = model.training
    model.eval()  # inference: disable dropout etc. while embedding
    tokens = []
    labels = []
    try:
        for img, label in tqdm(dataset):
            if isinstance(img, torch.Tensor):
                feature = img.unsqueeze(0).to(device)
            else:
                feature = tokenizer(img, return_tensors='pt').pixel_values.to(device)
            with torch.no_grad():
                hidden_states = model(feature, output_hidden_states=True).hidden_states
            # last layer, first (only) batch item, sequence position 0
            tokens.append(hidden_states[-1][0, 0, :].cpu())
            labels.append(label)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return torch.stack(tokens), torch.stack(labels)

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import random

def plot_tokens(tokens, targets, n_neighbors, random_state=None):
    """Project embeddings with UMAP and scatter-plot them coloured by target.

    Args:
        tokens: Tensor of embeddings, one row per sample.
        targets: Tensor of values used for colouring; mapped onto the
            'Oranges' colormap with the colour scale fixed to [-1, 1].
        n_neighbors: ``n_neighbors`` passed to ``UMAP``.
        random_state: Optional seed passed to ``UMAP`` for a reproducible
            layout; ``None`` keeps the previous (unseeded) behaviour.
    """
    reducer = UMAP(n_neighbors=n_neighbors, random_state=random_state)
    zs = reducer.fit_transform(tokens.numpy())
    ys = targets.numpy()
    print(zs.shape)
    print(ys.shape)

    fig = plt.figure()
    ax = fig.add_subplot()
    ax.set_xlabel('feature-1')
    ax.set_ylabel('feature-2')

    # One vectorised call draws every point; creating a separate artist per
    # point is slow and leaves nothing to attach the colorbar to when there
    # are no points.
    mp = ax.scatter(zs[:, 0], zs[:, 1],
                    alpha=1,
                    c=ys,
                    cmap='Oranges',
                    vmin=-1,
                    vmax=1,
                    s=3)
    fig.colorbar(mp, ax=ax)
    plt.show()


In [None]:
# Embed the validation set with the fine-tuned model and plot the result.
tokens, targets = CLE_tokens(model=model,
                             tokenizer=feature_extractor,
                             dataset=val_dataset,
                             device=device)
plot_tokens(tokens, targets.squeeze(), n_neighbors=30)

In [None]:
from transformers import ViTForImageClassification

# Baseline for comparison: the same checkpoint (`model_ckpt`) with a freshly
# initialised single-output regression head and no fine-tuning.
non_finetuned_model = ViTForImageClassification.from_pretrained(model_ckpt,
                                                                num_labels=1,
                                                                problem_type='regression')


In [None]:
# Embed the validation set again, this time with the baseline model.
baseline_model = non_finetuned_model.to(device)
tokens, targets = CLE_tokens(model=baseline_model,
                             tokenizer=feature_extractor,
                             dataset=val_dataset,
                             device=device)
print(tokens.shape, targets.shape)

In [None]:
# Range check on the targets that drive the colour scale of the plots.
print(targets.min(), targets.max())

In [None]:
# Plot the baseline embeddings with the same neighbourhood size as above.
plot_tokens(tokens, targets.squeeze(), n_neighbors=30)