In [1]:
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

In [2]:
from transformers import ViTFeatureExtractor

# Checkpoint identifier for the pretrained ViT; the matching preprocessing
# configuration is loaded from the same checkpoint.
model_ckpt = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_ckpt,
)



In [3]:
# Directory that the CSV's `subDirectory_filePath` entries are joined onto
# (relative to the notebook's working directory).
images_root = 'Affectnet/Manually_Annotated/Manually_Annotated_Images'

In [4]:
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import os


def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The conversion happens while the file handle is still open, so the
    returned image no longer depends on the handle.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image

class AffectNetDataset(Dataset):
    """Map-style dataset that pairs AffectNet images with their annotations.

    Rows come from an annotation CSV; the image for a row is located by joining
    ``root`` with that row's ``subDirectory_filePath`` value.

    Args:
        csvfile: Path of the annotation CSV (read with ``pd.read_csv``).
        root: Directory that ``subDirectory_filePath`` entries are relative to.
        mode: Selects the target that is returned.
            ``'classification'`` -> ``expression`` column as an int64 scalar tensor;
            ``'valence'`` / ``'arousal'`` -> that column as a float32 scalar tensor;
            any other value -> float32 tensor ``[valence, arousal]``.
        transform: Optional callable applied to the loaded image.
        invalid_files: Optional collection of ``subDirectory_filePath`` values
            whose rows are dropped from the dataset.
        loader: Optional callable ``path -> image``. Defaults to the
            module-level ``pil_loader``.
    """

    # Modes whose target is a single CSV column. Every other mode yields the
    # two-element [valence, arousal] target.
    _SCALAR_TARGET_COLUMNS = {
        'classification': 'expression',
        'valence': 'valence',
        'arousal': 'arousal',
    }

    def __init__(self,
                 csvfile,
                 root,
                 mode='classification',
                 transform=None,
                 invalid_files=None,
                 loader=None):
        self.df = pd.read_csv(csvfile)
        self.root = root
        self.mode = mode
        self.transform = transform
        self.invalid_files = invalid_files
        # Resolve the default lazily so a custom loader never needs pil_loader.
        self.loader = loader if loader is not None else pil_loader

        if self.invalid_files:
            keep = ~self.df['subDirectory_filePath'].isin(self.invalid_files)
            # Re-number rows so positions 0..len-1 stay contiguous after filtering.
            self.df = self.df[keep].reset_index(drop=True)

    def _load_target(self, idx):
        """Build the target tensor for the row at position ``idx``."""
        column = self._SCALAR_TARGET_COLUMNS.get(self.mode)
        if column is None:
            pair = [float(self.df['valence'].iloc[idx]),
                    float(self.df['arousal'].iloc[idx])]
            return torch.tensor(pair, dtype=torch.float32)

        value = self.df[column].iloc[idx]
        if self.mode == 'classification':
            # Class indices stay integral; pandas hands back a NumPy scalar,
            # which has no tensor methods, so wrap it explicitly.
            return torch.tensor(int(value), dtype=torch.long)
        return torch.tensor(float(value), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the row at position ``idx``."""
        rel_path = self.df['subDirectory_filePath'].iloc[idx]
        img = self.loader(os.path.join(self.root, rel_path))
        if self.transform is not None:
            img = self.transform(img)
        # Only tensors can be cast; without a tensor-producing transform the
        # image is returned exactly as the loader produced it.
        if isinstance(img, torch.Tensor):
            img = img.float()
        return img, self._load_target(idx)

    def __len__(self):
        """Number of rows left after filtering out ``invalid_files``."""
        return len(self.df)

In [None]:
from tqdm import tqdm

import pandas as pd

# Annotation tables for the two official splits; used below to scan for
# images that cannot be opened.
train_df = pd.read_csv('Affectnet/training.csv')
val_df = pd.read_csv('Affectnet/validation.csv')

def check_files(df):
    """Return the ``subDirectory_filePath`` entries whose image cannot be loaded.

    Every path in ``df['subDirectory_filePath']`` is joined onto the global
    ``images_root`` and opened with ``pil_loader``; paths that raise are
    collected, printed, and returned.

    Args:
        df: DataFrame with a ``subDirectory_filePath`` column.

    Returns:
        list: The entries that failed to load, in iteration order.
    """
    invalid_files = []
    for filename in tqdm(df['subDirectory_filePath']):
        try:
            pil_loader(os.path.join(images_root, filename))
        except Exception:
            # Deliberately broad: a missing, truncated or undecodable file all
            # count as invalid. Unlike a bare `except`, this still lets
            # KeyboardInterrupt / SystemExit stop the scan.
            invalid_files.append(filename)
    print(invalid_files)
    return invalid_files

# Scan both splits once; the resulting lists are passed to AffectNetDataset
# so unreadable images are excluded.
train_invalid_files = check_files(train_df)
val_invalid_files = check_files(val_df)

In [None]:
# Experiment configuration.
# 'valence-arousal' matches none of the named AffectNetDataset modes, so the
# dataset falls through to its default branch and yields [valence, arousal].
mode = 'valence-arousal'
# Number of samples drawn from validation.csv for validation; the rest become the test set.
val_size = 1000
# Seed for the generator that performs the validation/test split.
seed = 0

In [None]:
from torchvision.transforms import (Compose,
                                    Normalize,
                                    Resize,
                                    ToTensor)
from torch.utils.data import random_split


# Per-channel normalisation using the statistics stored with the checkpoint's
# feature extractor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Resize -> tensor -> normalise. `normalize` was previously built but never
# added to the pipeline, so images reached the model un-normalised.
transform = Compose([Resize(tuple(feature_extractor.size.values())),
                     ToTensor(),
                     normalize])

# Training split: built from training.csv with unreadable images removed.
train_dataset = AffectNetDataset(
    csvfile='Affectnet/training.csv',
    root=images_root,
    mode=mode,
    transform=transform,
    invalid_files=train_invalid_files,
)

# validation.csv is the only held-out annotation file, so it is divided into a
# validation part of `val_size` samples and a test part holding the remainder.
valtest_dataset = AffectNetDataset(
    csvfile='Affectnet/validation.csv',
    root=images_root,
    mode=mode,
    transform=transform,
    invalid_files=val_invalid_files,
)
val_dataset, test_dataset = random_split(
    dataset=valtest_dataset,
    lengths=[val_size, len(valtest_dataset) - val_size],
    # Seeded generator keeps the split identical across runs.
    generator=torch.Generator().manual_seed(seed),
)

print('train:', len(train_dataset))
print('validation:', len(val_dataset))
print('test:', len(test_dataset))

In [None]:
def collate_fn(examples):
    """Batch ``(image, target)`` pairs into a dict of stacked tensors.

    Args:
        examples: Sequence of ``(image_tensor, target_tensor)`` pairs.

    Returns:
        dict: ``'pixel_values'`` holds the images stacked along a new leading
        dimension and ``'labels'`` holds the targets stacked the same way.
    """
    images, labels = zip(*examples)
    batch = {}
    batch['pixel_values'] = torch.stack(images)
    batch['labels'] = torch.stack(labels)
    return batch

In [None]:
from torch.utils.data import DataLoader

# Sanity check: pull one small batch through the collate function and show
# the shape of every tensor it contains.
train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(train_dataloader))
for k, v in batch.items():
    if not isinstance(v, torch.Tensor):
        continue
    print(k, v.shape)

In [None]:
from transformers import ViTForImageClassification

# Load the backbone from the same checkpoint as the feature extractor
# (`model_ckpt`) so preprocessing and weights cannot drift apart.
# Two regression outputs: one for valence, one for arousal.
model = ViTForImageClassification.from_pretrained(model_ckpt,
                                                  num_labels=2,
                                                  problem_type='regression')


## Balanced MSE
- paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.pdf
- github: https://github.com/jiawei-ren/BalancedMSE/tree/main

Batch-based Monte-Carlo (BMC)を使う

In [None]:
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.distributions import MultivariateNormal as MVN

def bmc_loss_md(pred, target, noise_var, device=None):
    """Batch-based Monte-Carlo (BMC) Balanced MSE loss for multi-dimensional targets.

    Each prediction is scored against every target in the batch under an
    isotropic Gaussian with variance ``noise_var``; the resulting
    (batch, batch) log-likelihood matrix is treated as classification logits
    in which prediction ``i`` should pick target ``i``.

    Args:
        pred: Predictions, indexed as (batch, dim).
        target: Targets with the same layout as ``pred``.
        noise_var: Scalar tensor holding the noise variance.
        device: Device for the helper tensors. Defaults to ``pred.device``.

    Returns:
        Scalar loss tensor.
    """
    if device is None:
        device = pred.device
    # Build helper tensors directly on the target device instead of CPU + copy.
    I = torch.eye(pred.shape[-1], device=device)
    # (batch, 1, dim) against (1, batch, dim) broadcasts to a (batch, batch) matrix.
    logits = MVN(pred.unsqueeze(1), noise_var * I).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0], device=device))
    # Rescale by 2 * noise_var; detached so the factor itself receives no gradient.
    loss = loss * (2 * noise_var).detach()

    return loss

class BMCLoss(_Loss):
    """Balanced MSE (BMC variant) with a learnable noise scale.

    Args:
        init_noise_sigma: Initial value of the learnable noise standard
            deviation. Integers are accepted and stored as floats.
        device: Device passed on to ``bmc_loss_md``. When ``None``, the device
            of the predictions is used at call time.
    """

    def __init__(self, init_noise_sigma=1., device=None):
        super(BMCLoss, self).__init__()
        # Force a floating-point tensor: a parameter built from an int tensor
        # cannot require gradients.
        self.noise_sigma = torch.nn.Parameter(
            torch.tensor(float(init_noise_sigma), dtype=torch.float32))
        self.device = device

    def forward(self, pred, target):
        """Return the BMC loss of ``pred`` against ``target``."""
        noise_var = self.noise_sigma ** 2
        # Fall back to the predictions' device when none was configured.
        device = self.device if self.device is not None else pred.device
        return bmc_loss_md(pred, target, noise_var, device)

In [None]:
from transformers import Trainer

class BMCLossTrainer(Trainer):
    """``Trainer`` that optimises the BMC Balanced-MSE loss.

    The constructor takes the same arguments as the parent ``Trainer`` and
    adds a ``BMCLoss`` criterion placed on ``self.args.device``. The
    criterion's learnable noise scale is registered with the optimizer so it
    is trained together with the model.
    """

    def __init__(self,
                 model = None,
                 args = None,
                 data_collator = None,
                 train_dataset = None,
                 eval_dataset = None,
                 tokenizer = None,
                 model_init = None,
                 compute_metrics = None,
                 callbacks = None,
                 optimizers = (None, None),
                 preprocess_logits_for_metrics = None):
        # Forward by keyword so each argument stays bound to the intended
        # parameter even if the parent's positional order differs by version.
        super().__init__(model=model,
                         args=args,
                         data_collator=data_collator,
                         train_dataset=train_dataset,
                         eval_dataset=eval_dataset,
                         tokenizer=tokenizer,
                         model_init=model_init,
                         compute_metrics=compute_metrics,
                         callbacks=callbacks,
                         optimizers=optimizers,
                         preprocess_logits_for_metrics=preprocess_logits_for_metrics)
        self.loss_fct = BMCLoss(device=self.args.device).to(self.args.device)

    def create_optimizer(self):
        """Create the optimizer, then add the loss module's parameters to it.

        The parent only collects the model's parameters, so without this step
        ``loss_fct.noise_sigma`` would never be updated.
        """
        super().create_optimizer()
        optimizer = self.optimizer
        loss_fct = getattr(self, 'loss_fct', None)
        if optimizer is not None and loss_fct is not None:
            tracked = {id(p) for group in optimizer.param_groups for p in group['params']}
            extra = [p for p in loss_fct.parameters() if id(p) not in tracked]
            if extra:
                # No weight decay: decay would pull the noise scale towards zero.
                optimizer.add_param_group({'params': extra, 'weight_decay': 0.0})
        return optimizer

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and score its logits with the BMC loss.

        Args:
            model: Model to evaluate.
            inputs: Batch dict; must contain ``'labels'``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it; unused here.
        """
        labels = inputs.get('labels')
        outputs = model(**inputs)
        logits = outputs.get('logits')
        loss = self.loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [None]:
from transformers import TrainingArguments

# Training configuration. Checkpoints are written and evaluation is run once
# per epoch, and the best checkpoint is restored when training ends.
args = TrainingArguments(
    output_dir="affectnet-balancedMSE-vit",
    logging_dir='logs',
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # The dataset yields tensors via collate_fn, not named columns, so nothing may be dropped.
    remove_unused_columns=False,
)

In [None]:
from datasets import load_metric
from sklearn.metrics import mean_squared_error

# NOTE(review): `metric` is not used by any visible cell — its only call site
# in compute_metrics is commented out. Confirm whether this load is still needed.
metric = load_metric('mse')

def compute_metrics(eval_pred):
    """Compute the mean squared error for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, targets)``.

    Returns:
        dict: ``{'mse': <mean squared error of predictions vs. targets>}``.
    """
    predictions, references = eval_pred
    return {'mse': mean_squared_error(references, predictions)}

In [None]:
# Assemble the trainer: model, configuration, data splits, batching function
# and metric callback.
trainer = BMCLossTrainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run prediction on the validation split before training starts — presumably
# a baseline / pipeline sanity check (confirm intent).
trainer.predict(val_dataset)

In [None]:
# Fine-tune the model using the configuration in `args`.
trainer.train()

In [None]:
# Persist the trainer state and the model once training has finished.
trainer.save_state()
trainer.save_model()