In [1]:
# Load necessary libraries
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

import sys
import warnings

# Extend the import search path with the parent directory
# (presumably where the project's `core` package lives — confirm).
sys.path.append('..')

# Suppress every warning category for the rest of the session.
warnings.filterwarnings(action='ignore')

# Project paths: the dataset root and the image folder directly beneath it
dataset_path = '../Datasets/CIFAR10'
img_path = f'{dataset_path}/images'

# Use the GPU when torch reports one is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fix the random seed through the project helper (called with its defaults)
from core.ai.utils import seed_everything
seed_everything()

In [2]:
# Data owner of interest; this id also selects the dataset sheet read later on
data_owner_id = 'A'
# Derive the checkpoint path from the id so the two settings cannot drift apart
# (previously the 'A' was hard-coded a second time inside the path).
# NOTE(review): assumes checkpoints are named model_<owner id>_1.pt — confirm for other owners.
data_owner_model_name = f'./saved_models/model_{data_owner_id}_1.pt'

## Train Dataset Evaluation

In [3]:
from core.ai.utils import label_enc
from core.ai.model import get_vit_model

# Read the sheet belonging to the selected data owner
data_owner_dataset = pd.read_excel(
    f'{dataset_path}/CIFAR10dataOwnerInfo.xlsx',
    sheet_name=data_owner_id,
)

# Turn the bare image file names into paths under the image folder
data_owner_dataset['image'] = data_owner_dataset['image'].map(
    lambda file_name: f'{img_path}/{file_name}'
)

# Number of distinct label names present in this owner's data
num_classes = data_owner_dataset['label_name'].nunique()

# Project helper returning the image list, the labels and both label mappings
images, labels, label2id, id2label = label_enc(data_owner_dataset)

# ViT feature extractor and model, created for the selected device
vit_feature_extractor, vit_model = get_vit_model(device)

In [4]:
from core.ai.dataset import get_loader
import albumentations as A

# Evaluation-time transform: a single resize to 224x224, nothing else
resize_step = A.Resize(224, 224)
eval_transform = A.Compose([resize_step])

# Loader over the data owner's images; shuffling is disabled
data_loader = get_loader(
    images, labels, vit_feature_extractor, eval_transform,
    pre_trained_model=vit_model, device=device, shuffle=False,
)

In [5]:
# Load the data owner's trained model (a fully pickled module, see the printout below).
# map_location=device remaps the checkpoint's tensors onto the device selected for this
# session, so a checkpoint saved on a GPU still loads on a CPU-only machine and the model
# ends up on the same device that `predict` moves its batches to.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default to weights_only=True, which rejects fully
# pickled modules; if loading fails there, pass weights_only=False — confirm against the
# installed version.
model = torch.load(data_owner_model_name, map_location=device)
print(model)

Sequential(
  (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Linear(in_features=768, out_features=128, bias=True)
  (2): GELU()
  (3): Linear(in_features=128, out_features=6, bias=True)
)


In [6]:
# Get predictions
def predict(data_loader, model, device='cpu'):
    """Run `model` over every batch of `data_loader` and collect predicted class ids.

    Args:
        data_loader: Iterable yielding `(batch_inputs, batch_labels)` pairs; only the
            inputs are used, the second element is ignored.
        model: Torch module mapping a batch of inputs to per-class scores laid out
            along dimension 1.
        device: Device on which inference runs. The model is moved there as well, so
            model and inputs can never end up on different devices.

    Returns:
        A flat list with one predicted class index per sample, in loader order.
    """
    # Keep the model on the same device as the batches moved below.
    model = model.to(device)
    model.eval()

    predictions = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch_images, _ in tqdm(data_loader):
            logits = model(batch_images.to(device))
            predictions.extend(logits.argmax(1).cpu().tolist())

    return predictions

# Predicted class ids for the data owner's own (training) images
predictions = predict(data_loader, model, device=device)

100%|██████████| 422/422 [10:50<00:00,  1.54s/it]


In [7]:
# Score the train-set predictions against their labels
from core.ai.utils import get_metrics
(accuracy_score,
 f1_score,
 weighted_accuracy_score,
 weighted_f1_score) = get_metrics(labels, predictions, label2id)

# Print every metric with four decimals
for metric_label, metric_value in (
    ('accuracy', accuracy_score),
    ('f1', f1_score),
    ('weighted accuracy', weighted_accuracy_score),
    ('weighted f1', weighted_f1_score),
):
    print(f'In sample {metric_label} score: {metric_value:.4f}')


In sample accuracy score: 0.9802
In sample f1 score: 0.9802
In sample weighted accuracy score: 0.9780
In sample weighted f1 score: 0.9802


## Test Dataset Evaluation

In [8]:
# Draw the test set reproducibly: seed numpy's global generator, which the
# un-seeded pandas sampling below falls back to
SEED = 20
np.random.seed(SEED)
data_df = pd.read_csv(f'{dataset_path}/data.csv')

# Take 20% of every label group, then drop the one image that fails to load
# NOTE(review): the sample is drawn from the full data.csv — confirm it cannot
# overlap with the data owner's training images.
test_dataset = (
    data_df
    .groupby('label')
    .sample(frac=.2)
    .query('image != "51101.png"')  # 51101.png has a loading problem
)

# Every label name the data owner's mapping does not know is relabelled 'other'
is_known_label = test_dataset['label_name'].isin(list(label2id.keys()))
other_index = test_dataset.index[~is_known_label]
test_dataset.loc[other_index, 'label_name'] = 'other'

# Image paths and encoded labels of the test set
images = [f'{img_path}/{file_name}' for file_name in test_dataset['image']]
labels = test_dataset['label_name'].apply(lambda name: label2id[name]).tolist()


In [9]:
# Non-shuffled loader over the test images, followed by the model's predictions on it
eval_data_loader = get_loader(
    images, labels, vit_feature_extractor, eval_transform,
    pre_trained_model=vit_model, device=device, shuffle=False,
)
predictions = predict(eval_data_loader, model, device=device)

100%|██████████| 375/375 [09:44<00:00,  1.56s/it]


In [11]:
# Score the test-set predictions against their labels
from core.ai.utils import get_metrics
(accuracy_score,
 f1_score,
 weighted_accuracy_score,
 weighted_f1_score) = get_metrics(labels, predictions, label2id)

# Print every metric with four decimals
for metric_label, metric_value in (
    ('accuracy', accuracy_score),
    ('f1', f1_score),
    ('weighted accuracy', weighted_accuracy_score),
    ('weighted f1', weighted_f1_score),
):
    print(f'Out sample {metric_label} score: {metric_value:.4f}')


Out sample accuracy score: 0.8002
Out sample f1 score: 0.8002
Out sample weighted accuracy score: 0.6888
Out sample weighted f1 score: 0.6888
