In [31]:
from efficientnet_pytorch import EfficientNet
from PIL import Image
import torch
from torchvision import transforms
import os
from tqdm import *
from sklearn.metrics import confusion_matrix
import numpy as np
import json


In [4]:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

Loaded pretrained weights for efficientnet-b0


# Inference on a single image

In [6]:
# Preprocess image
tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
test_image = tfms(Image.open('./dataset/ILSVRC2012_val_00000003.JPEG')).unsqueeze(0)

# Load ImageNet class names
labels_map = json.load(open('categories.json'))
labels_map = [labels_map[str(i)] for i in range(1000)]

# Classify
model.eval()
with torch.no_grad():
    outputs = model(test_image)

# Print predictions
for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():
    prob = torch.softmax(outputs, dim=1)[0, idx].item()
    print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100))

Shetland sheepdog, Shetland sheep dog, Shetland                             (66.55%)
collie                                                                      (26.09%)
borzoi, Russian wolfhound                                                   (0.27%)
Pembroke, Pembroke Welsh corgi                                              (0.13%)
papillon                                                                    (0.12%)


# Naive implementation on the whole dataset

In [22]:
directory= './dataset/'
nb_samples=200 # test on a smaller part of the dataset
labels=json.load(open('labels.json'))

grayscale= []
actual=[]
predicted=[]

tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])

for filename in tqdm(os.listdir(directory)[:nb_samples]):
    f = os.path.join(directory, filename)
    y_true= labels[filename]    
    img=Image.open(f)    
    if img.mode !='RGB':
        grayscale.append(filename)
        img = tfms(img.convert('RGB')).unsqueeze(0)
    else:
        img = tfms(img).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        outputs = model(img)
        y_pred = torch.argmax(outputs).item()
        actual.append(y_true)
        predicted.append(y_pred)

with open('results.json', 'w') as fp:
    results = { i : (actual[i], predicted[i]) for i in range(nb_samples)}
    json.dump(results, fp,  indent=4)

100%|██████████| 200/200 [00:43<00:00,  4.60it/s]


# Metrics

In [29]:
results=list(json.load(open('results.json')).values())
actual=[sample[0] for sample in results]
predicted=[sample[1] for sample in results]

print(actual)
def confusion_values (y_true,y_pred):
    confusion = confusion_matrix(y_true,y_pred)
    FP = confusion.sum(axis=0) - np.diag(confusion)  
    FN = confusion.sum(axis=1) - np.diag(confusion)
    TP = np.diag(confusion)
    TN = confusion.sum() - (FP + FN + TP)
    return FP,FN,TP,TN

def get_metrics(FP,FN,TP,TN,nb_smpl):
    # Accuracy
    ACC=sum(TP)/nb_smpl
    # Specificity or True Negative Rate
    TNR=TN/(TN+FP)
    return {'ACC':ACC,'TNR':TNR}

[511, 74, 32, 332, 928, 346, 10, 259, 56, 524, 841, 618, 31, 803, 566, 990, 964, 52, 7, 547, 52, 444, 741, 72, 113, 504, 817, 462, 205, 425, 215, 871, 203, 135, 837, 491, 703, 563, 793, 83, 898, 31, 7, 962, 354, 128, 424, 463, 438, 541, 359, 973, 985, 816, 346, 723, 148, 580, 449, 510, 486, 237, 524, 791, 979, 593, 939, 913, 451, 627, 763, 54, 401, 553, 982, 41, 135, 707, 170, 48, 180, 204, 482, 212, 674, 990, 164, 701, 519, 23, 543, 736, 264, 278, 438, 735, 649, 10, 410, 91, 92, 355, 318, 120, 596, 956, 864, 679, 42, 762, 806, 136, 690, 357, 371, 543, 414, 935, 581, 849, 30, 845, 513, 490, 350, 858, 787, 114, 14, 96, 620, 988, 48, 848, 626, 913, 209, 717, 121, 250, 896, 380, 297, 626, 395, 180, 628, 493, 551, 795, 158, 819, 356, 70, 639, 661, 817, 321, 78, 967, 939, 168, 285, 270, 448, 366, 775, 515, 53, 413, 3, 20, 928, 490, 100, 918, 597, 119, 754, 452, 886, 321, 253, 942, 417, 17, 824, 858, 800, 926, 388, 280, 666, 115, 22, 717, 607, 859, 270, 188]


# Global metrics

In [33]:
FP,FN,TP,TN = confusion_values(actual,predicted)
metrics = get_metrics(FP,FN,TP,TN,nb_samples)
print(metrics)

{'ACC': 0.745, 'TNR': array([1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.99497487, 0.995     , 0.995     , 1.        , 1.        ,
       0.995     , 0.99494949, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.995     , 1.        ,
       1.        , 1.        , 0.995     , 1.        , 1.        ,
       1.        , 1.        , 0.995     , 1.        , 1.        ,
       1.        , 0.995     , 0.99497487, 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.995     ,
       0.995     , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.995     , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.995     ,
       1.        , 0.995     , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.99497487, 0.995     ,
       1.        , 0.995     , 1.       

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Class imbalance
We count the frequence of appearance of each class in order to compute metrics on the different parts of the dataset

## Frequence
We compute the metrics on the most present (true) classes 

In [78]:
actual_dict = {x:actual.count(x) for x in actual}
predicted_dict = {x:predicted.count(x) for x in predicted}
results_dict = {actual[i]: predicted[i] for i in range(len(actual))}

In [89]:
actual_filtered = [ key for (key,value) in actual_dict.items() if value >= 5 ]
results_filtered = {key:value for (key,value) in results_dict.items() if key in actual_filtered }
fFP,fFN,fTP,fTN = confusion_values(list(results_filtered.keys()),list(results_filtered.values()))
fmetrics = get_metrics(fFP,fFN,fTP,fTN,len(actual_filtered))
print(fmetrics.ACC)

{'ACC': 0.7658450704225352, 'TNR': array([1.        , 1.        , 1.        , 1.        , 0.99823944,
       1.        , 0.99823944, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.99823944,
       1.        , 0.99823944, 1.        , 1.        , 1.        ,
       0.99823944, 1.        , 1.        , 1.        , 1.        ,
       0.99823944, 0.99823633, 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.99823944, 1.        , 1.        ,
       0.99470899, 0.99823633, 1.        , 1.        , 1.        ,
       1.        , 0.99823633, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.      

# Grayscale
We compute the metrics on the grayscale images to see if the color have an effect on the inference

In [90]:
grayscale_labels = [labels[filename] for filename in grayscale  ]
results_gray = {key:value for (key,value) in results_dict.items() if key in grayscale_labels }
gFP,gFN,gTP,gTN = confusion_values(list(results_gray.keys()),list(results_gray.values()))
gmetrics = get_metrics(gFP,gFN,gTP,gTN,len(grayscale_labels))
print(gmetrics.ACC)

{'ACC': 0.63, 'TNR': array([1.        , 1.        , 0.98947368, 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.9893617 ,
       1.        , 1.        , 0.98947368, 0.98947368, 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 0.98947368, 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 0.98947368,
       1.        , 1.        , 0.98947368, 1.        , 0.98947368,
       1.        , 0.9893617 , 1.        , 1.        , 0.98947368,
       1.        , 1.        , 0.98947368, 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.98947368, 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 0.98947368, 1.        , 0.9893617 , 1.        ,
       0.97894737, 1.        , 1.        