In [1]:
# Set up
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import ViTFeatureExtractor, ViTForImageClassification

import warnings
# Silence every warning category for the rest of the session.
warnings.filterwarnings(action='ignore')

# Add the parent directory to the import search path
# (the `core` package imported in later cells is presumably located there).
sys.path.append('..')

# Dataset root and the image folder beneath it.
dataset_path = '../Datasets/CIFAR10'
img_path = dataset_path + '/images'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Identifier of the data owner under evaluation.
# It is also passed as `sheet_name` when the data owner workbook is read.
data_owner_id = 'A'

In [3]:
# `seed_everything` is a project helper from core.ai.utils (its body is not shown here).
# It is called with its default arguments; presumably it seeds the random number
# generators used in this notebook so runs are reproducible — TODO confirm.
from core.ai.utils import seed_everything
seed_everything()

In [4]:
from core.ai.utils import label_enc

# Read the sheet that belongs to the selected data owner.
data_owner_dataset = pd.read_excel(
    dataset_path + '/CIFAR10dataOwnerInfo.xlsx',
    sheet_name=data_owner_id,
)

# Prefix every entry of the `image` column with the image directory.
data_owner_dataset['image'] = data_owner_dataset['image'].map(
    lambda image: f'{img_path}/{image}'
)

# Number of distinct label names present in this owner's sheet.
num_classes = data_owner_dataset['label_name'].nunique()

# `label_enc` is a project helper (not shown); it returns the images, the labels
# and the two label<->id mappings unpacked below.
images, labels, label2id, id2label = label_enc(data_owner_dataset)

# Build the feature extractor / ViT model pair for the chosen device.
from core.ai.model import get_vit_model
vit_feature_extractor, vit_model = get_vit_model(device)

In [8]:
import albumentations as A
from core.ai.dataset import get_loader

# The only augmentation is a resize to 224x224.
train_transform = A.Compose([A.Resize(224, 224)])

# shuffle=False is requested because the predictions collected later are compared
# position-by-position with `labels`
# (assumes get_loader then preserves dataset order — TODO confirm).
data_loader = get_loader(
    images,
    labels,
    vit_feature_extractor,
    train_transform,
    pre_trained_model=vit_model,
    device=device,
    shuffle=False,
)

In [11]:
import torch

# Load the saved classifier head for this data owner.
#
# map_location=device remaps the stored tensors onto the device selected in the
# setup cell. Without it, a checkpoint saved on a CUDA machine fails to load on a
# CPU-only machine, and a checkpoint saved on CPU would stay on CPU while the
# batches are later moved to `device`.
#
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which may
# reject a fully pickled module like this one — confirm against the installed version.
model = torch.load('./saved_models/model_A_1.pt', map_location=device)
print(model)

Sequential(
  (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Linear(in_features=768, out_features=128, bias=True)
  (2): GELU()
  (3): Linear(in_features=128, out_features=6, bias=True)
)


In [14]:
# Get predictions
# Make sure the classifier sits on the same device the batches are moved to.
model.to(device)
# Evaluation mode: layers such as BatchNorm use their running statistics
# instead of per-batch statistics.
model.eval()

predictions = []
# Inference only: disabling autograd avoids building a graph for every forward
# pass, which would otherwise cost memory for no benefit.
with torch.no_grad():
    # The loader yields (inputs, labels) pairs; the labels are not needed here.
    for batch_images, _batch_labels in tqdm(data_loader):
        logits = model(batch_images.to(device))
        # Predicted class id = index of the largest logit along dimension 1.
        predictions.extend(logits.argmax(dim=1).cpu().tolist())


In [40]:
from sklearn import metrics

# In-sample accuracy: share of predictions that equal the true label id.
accuracy_score = metrics.accuracy_score(y_true=labels, y_pred=predictions)
print(f'In sample accuracy score: {accuracy_score:.4f}')

# In-sample F1, micro-averaged over all samples.
# NOTE(review): for single-label multiclass targets the micro-averaged F1
# coincides with accuracy, so this line adds no new information; consider
# average='macro' if a distinct metric is wanted.
f1_score = metrics.f1_score(y_true=labels, y_pred=predictions, average='micro')
print(f'In sample f1 score: {f1_score:.4f}')


In sample accuracy score: 0.980
In sample f1 score: 0.980


In [41]:
# Weighted in-sample metrics: samples whose true label is 'other' count more
# than the rest when accuracy and F1 are computed.

# Weight given to every sample labelled 'other'; all other samples weigh 1.
OTHER_CLASS_WEIGHT = 5

# Look the id up once instead of on every iteration of the comprehension.
# Raises KeyError if this data owner has no 'other' label.
other_label_id = label2id['other']
sample_weight = [OTHER_CLASS_WEIGHT if label == other_label_id else 1 for label in labels]

# Get weighted accuracy
weighted_accuracy_score = metrics.accuracy_score(labels, predictions, sample_weight=sample_weight)
print(f'In sample weighted accuracy score: {weighted_accuracy_score:.4f}')

# Get weighted f1 score
weighted_f1_score = metrics.f1_score(labels, predictions, average='micro', sample_weight=sample_weight)
print(f'In sample weighted f1 score: {weighted_f1_score:.4f}')

# The original cell stored the weighted value in `f1_score`, replacing the
# unweighted score. The rebinding is kept so any later cell sees the same value
# as before; new code should read `weighted_f1_score` instead.
f1_score = weighted_f1_score

In sample weighted accuracy score: 0.978
In sample weighted f1 score: 0.978
