In [1]:
import os
import pickle

from pathlib import Path

import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from PIL import Image
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms

# Raise pandas' display limits: up to 999 columns and 100 rows are rendered.
pd.set_option('display.max_columns', 999)
pd.set_option('display.max_rows', 100)

In [2]:
# All input/output locations are expressed relative to ROOT_DIR.
ROOT_DIR = '../'
MODEL_PATH = f'{ROOT_DIR}models/Patryk-ResNeXt-long-train.pkt'
VALIDATION_LABELS_PATH = f'{ROOT_DIR}data/validation_labels.csv'
VALIDATION_DATA_PATH = f'{ROOT_DIR}data/validation_images'
OPTIMAL_THRESHOLD_PATH = f'{ROOT_DIR}models/optimal_thresholds-Patryk-ResNeXt-long-train.npy'

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(DEVICE)

cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
def OFO(y_pred, y_true):
    """Fit one decision threshold per label with a single ordered pass.

    Every label keeps two running counters, ``hits`` and ``mass``, and its
    threshold is always ``hits / mass``. The counters start at 1 and 2, so
    each threshold starts at 0.5. Rows are consumed in order: a row is
    binarised with the thresholds current at that moment, then ``hits`` grows
    by the labels that were both predicted and true, and ``mass`` grows by the
    predicted labels plus the true labels.

    Args:
        y_pred: 2-D array of scores; ``y_pred.shape[1]`` is the label count.
        y_true: array indexed by the same row positions as ``y_pred``
            (presumably 0/1 targets -- the code only relies on truthiness
            and addition).

    Returns:
        1-D float array holding the final threshold of every label.
    """
    n_labels = y_pred.shape[1]
    hits = np.ones(n_labels)
    mass = np.full(n_labels, 2.0)
    thresholds = hits / mass

    for row_idx, scores in enumerate(y_pred):
        targets = y_true[row_idx]
        # The row is judged with the thresholds from *before* this update.
        predicted = (scores > thresholds).astype(int)
        hits += np.logical_and(predicted, targets)
        mass += predicted + targets
        thresholds = hits / mass
    return thresholds


def skyhacks_f1_score(preds, y):
    """Return the macro-averaged F1 score of ``preds`` against ``y``.

    Args:
        preds: predicted labels, forwarded to scikit-learn as ``y_pred``.
        y: ground-truth labels, forwarded to scikit-learn as ``y_true``.

    Returns:
        The result of ``sklearn.metrics.f1_score`` with ``average='macro'``.
    """
    return f1_score(y_true=y, y_pred=preds, average='macro')

In [4]:
class MultiClassDataset(Dataset):
    """Image dataset whose targets are read from a label CSV.

    The CSV needs a ``Name`` column holding the image file name. Every column
    after the first one is used as a target value for that image.

    Args:
        csv_file: path of the label CSV, read once with ``pd.read_csv``.
        img_dir: directory containing the files listed in ``Name``.
        transform: optional callable applied to the RGB image before it is
            returned.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the CSV row at position ``idx``.

        ``label`` is a ``float32`` tensor built from every column after the
        first; ``image`` is the RGB image, passed through ``transform`` when
        one was supplied.
        """
        row = self.df.iloc[idx]
        image_path = f'{self.img_dir}/{row["Name"]}'
        # ``convert`` returns an independent in-memory image, so the source
        # file can be closed deterministically here. Pillow only closes it on
        # its own after loading single-frame files.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        # Explicit positional slice: skip the first column.
        label = torch.tensor(row.iloc[1:].tolist(), dtype=torch.float32)

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of rows in the label CSV."""
        return len(self.df)

In [5]:
# NOTE: torch.load unpickles the stored object, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
model = torch.load(MODEL_PATH, map_location=DEVICE)
# Switch to inference behaviour before scoring. A pickled module keeps the
# `training` flag it was saved with, and in training mode dropout stays active
# and batch-norm layers normalise with (and update from) per-batch statistics.
# Assumes the checkpoint is a full nn.Module (it is called as `model(x)` later).
model = model.eval()

In [6]:
batch_size = 16

# Deterministic preprocessing: resize to 256, centre-crop to 224, convert to a
# tensor, then normalise each channel with the constants below.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

validation_set = MultiClassDataset(
    VALIDATION_LABELS_PATH,
    VALIDATION_DATA_PATH,
    transform,
)
# shuffle=False keeps batches in dataset order; the predictions are later
# matched to the label CSV rows purely by position.
validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)

In [7]:
# Score the validation set, collecting one array of sigmoid outputs per batch.
results = []

# Make sure dropout / batch-norm layers use their inference behaviour.
model.eval()
# No gradients are needed for scoring; skipping autograd avoids building a
# graph for every batch and makes an explicit detach() unnecessary.
with torch.no_grad():
    for x, _ in validation_loader:
        logits = model(x.to(DEVICE))
        probabilities = torch.sigmoid(logits).to(torch.float32).cpu().numpy()
        results.append(probabilities)

In [8]:
# Ground truth as stored on disk; its column layout is the reference layout.
df_val_true = pd.read_csv(VALIDATION_LABELS_PATH)

# Stack the per-batch outputs into one table, attach the image names (matched
# by row position/index) and reorder the columns to mirror the ground truth.
label_columns = df_val_true.columns[1:]
df_val_pred = (
    pd.DataFrame(np.vstack(results), columns=label_columns)
    .assign(Name=df_val_true['Name'].copy())
)[df_val_true.columns]

In [9]:
# Plain NumPy matrices of scores and ground truth; the first column of each
# frame is skipped so only the label columns remain.
y_pred = df_val_pred.iloc[:, 1:].to_numpy()
y_true = df_val_true.iloc[:, 1:].to_numpy()

In [10]:
# Fit one decision threshold per label from the validation scores and targets.
optimal_thresholds = OFO(y_pred, y_true)

In [11]:
# Display the fitted per-label thresholds.
optimal_thresholds

array([0.21428571, 0.4       , 0.25      , 0.407173  , 0.30769231,
       0.4       , 0.41818182, 0.28787879, 0.34782609, 0.15116279,
       0.11111111, 0.29120879, 0.34782609, 0.31034483, 0.33333333,
       0.27272727, 0.33333333, 0.32857143, 0.25      , 0.32      ,
       0.27692308, 0.20689655, 0.18181818, 0.05555556, 0.10638298,
       0.39784946, 0.3003413 , 0.35714286, 0.23076923, 0.3037037 ,
       0.35      , 0.46153846, 0.25925926, 0.17391304, 0.13513514,
       0.42417062, 0.5       , 0.32278481])

In [12]:
# Persist the thresholds in NumPy's binary format at OPTIMAL_THRESHOLD_PATH.
with open(OPTIMAL_THRESHOLD_PATH, mode='wb') as threshold_file:
    np.save(threshold_file, optimal_thresholds)

In [13]:
# Baseline: macro F1 when every label uses the same fixed 0.5 cut-off.
skyhacks_f1_score((y_pred > 0.5).astype(int), y_true)

0.5061888429024908

In [14]:
# Macro F1 with the fitted per-label thresholds (broadcast across rows).
# NOTE(review): the thresholds were fitted on this same y_pred / y_true, so
# this score is optimistic -- confirm the gain on data not used for fitting.
skyhacks_f1_score((y_pred > optimal_thresholds).astype(int), y_true)

0.5595717084499242

# Error analysis

In [15]:
# Binarise the scores with the per-label thresholds: 1 where score > threshold.
y_hat = (y_pred > optimal_thresholds).astype(int)

In [17]:
# Training labels, located via ROOT_DIR like every other path in the notebook
# (the resulting path is unchanged: '../data/training_labels.csv').
df = pd.read_csv(ROOT_DIR + 'data/training_labels.csv')

In [18]:
# Per-label F1 on the validation set, as a one-row frame (index 0).
# - scikit-learn's signature is f1_score(y_true, y_pred); pass them in that order.
# - Columns are named from the validation CSV, i.e. the same source that fixed
#   the column order of y_true / y_hat, rather than from a different file.
label_names = df_val_true.columns[1:]
scoring = pd.DataFrame(
    [[f1_score(y_true[:, i], y_hat[:, i]) for i in range(y_hat.shape[1])]],
    columns=label_names,
)

In [19]:
# Display the one-row table of per-label F1 scores.
scoring

Unnamed: 0,Amusement park,Animals,Bench,Building,Castle,Cave,Church,City,Cross,Cultural institution,Food,Footpath,Forest,Furniture,Grass,Graveyard,Lake,Landscape,Mine,Monument,Motor vehicle,Mountains,Museum,Open-air museum,Park,Person,Plants,Reservoir,River,Road,Rocks,Snow,Sport,Sports facility,Stairs,Trees,Watercraft,Windows
0,0.166667,0.75,0.4,0.809224,0.666667,0.75,0.844037,0.617647,0.6875,0.273684,0.153846,0.583784,0.680851,0.611765,0.671096,0.454545,0.695652,0.638889,0.428571,0.538462,0.508475,0.357143,0.28169,0.153846,0.175439,0.793478,0.608392,0.7,0.56,0.59375,0.65,0.909091,0.48,0.272727,0.296296,0.850356,1.0,0.650155


In [20]:
# Summary statistics of the training-label table, one column per label.
df.describe()

Unnamed: 0,Amusement park,Animals,Bench,Building,Castle,Cave,Church,City,Cross,Cultural institution,Food,Footpath,Forest,Furniture,Grass,Graveyard,Lake,Landscape,Mine,Monument,Motor vehicle,Mountains,Museum,Open-air museum,Park,Person,Plants,Reservoir,River,Road,Rocks,Snow,Sport,Sports facility,Stairs,Trees,Watercraft,Windows
count,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0,3745.0
mean,0.013351,0.022163,0.053672,0.541522,0.018158,0.014419,0.173298,0.071562,0.143124,0.055274,0.009613,0.189319,0.10761,0.074499,0.338852,0.033111,0.022163,0.073698,0.014686,0.050734,0.077971,0.030975,0.077437,0.020828,0.036582,0.20721,0.281175,0.048064,0.018959,0.161549,0.054473,0.020027,0.035514,0.0251,0.061682,0.526836,0.00988,0.299065
std,0.114789,0.147233,0.225399,0.498339,0.133539,0.119227,0.378555,0.257796,0.350246,0.228544,0.097586,0.391814,0.309929,0.262617,0.473383,0.17895,0.147233,0.261314,0.12031,0.219484,0.268161,0.173272,0.267319,0.142827,0.187759,0.405361,0.449633,0.21393,0.136397,0.368085,0.226978,0.14011,0.1851,0.15645,0.24061,0.499346,0.098918,0.45791
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
75%,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [21]:
# One row per label, sorted ascending by F1 (column 0) -- weakest labels first.
scoring.T.sort_values(0)

Unnamed: 0,0
Open-air museum,0.153846
Food,0.153846
Amusement park,0.166667
Park,0.175439
Sports facility,0.272727
Cultural institution,0.273684
Museum,0.28169
Stairs,0.296296
Mountains,0.357143
Bench,0.4
