In [None]:
import pandas as pd
import pickle
import numpy as np
from pathlib import Path
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
import shutil
from torch.nn import Softmax

%matplotlib inline

In [None]:
# Directory holding one sub-directory per training run; the analysis
# functions below read their prediction pickles from `root / <run name>`.
root = Path('..') / 'output' / 'runs'

# Class names. The integer class ids used by the predictions index into
# this list, so the order must not change.
CLASSES = [
    'HTC-1-M7',
    'LG-Nexus-5x',
    'Motorola-Droid-Maxx',
    'Motorola-Nexus-6',
    'Motorola-X',
    'Samsung-Galaxy-Note3',
    'Samsung-Galaxy-S4',
    'Sony-NEX-7',
    'iPhone-4s',
    'iPhone-6',
]

In [None]:
def softmax(x, axis=1):
    """Compute softmax values for each set of scores in x.

    Parameters
    ----------
    x : array-like
        Scores. For the default ``axis=1`` a 2-D input is treated as one set
        of scores per row. A 1-D input is treated as a single set of scores.
    axis : int, optional
        Axis along which the scores are normalised (default 1). Ignored for
        1-D input, which is always normalised along its only axis.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        # A bare score vector has no axis 1; normalise over its only axis.
        axis = 0
    # Subtracting the maximum leaves the result unchanged but keeps np.exp
    # from overflowing on large scores.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

# TEST: analysis of test-set predictions

In [None]:
def atest(run, phone='LG-Nexus-5x', poor_thresh=0.7, dump_poor=False):
    """Summarise the test-set predictions of one training run.

    Loads ``predict_test_detailed.pkl`` from ``root / run``. The pickle is
    expected to hold a ``(preds, fnames)`` pair, where ``preds`` is a sequence
    of score arrays (stacked row-wise here) and ``fnames`` the matching file
    names. Scores are turned into probabilities with ``softmax``.

    For every predicted class, most frequent first, prints the class name, the
    number of predictions, their share of all predictions and the share of
    those predictions whose probability is below ``poor_thresh``. Then prints
    the overall share of such low-confidence predictions and draws a histogram
    of the top probability for samples predicted as ``phone``.

    Parameters
    ----------
    run : str
        Name of the run directory under ``root``.
    phone : str, optional
        Class name (an entry of ``CLASSES``) whose confidence histogram is drawn.
    poor_thresh : float, optional
        A prediction counts as "poor" when the probability of the predicted
        class is below this value.
    dump_poor : bool, optional
        When True, wipe ``/tmp/poor_pred`` and copy the files of the poor
        predictions for ``phone`` into it. Off by default; the original
        notebook had this step disabled.

    Raises
    ------
    ValueError
        If ``phone`` is not in ``CLASSES``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files written by
    # our own prediction scripts.
    with open(str(root/run/'predict_test_detailed.pkl'), 'rb') as f:
        preds, fnames = pickle.load(f)
    preds, fnames = np.vstack(preds), np.array(fnames)
    preds = softmax(preds)
    pred_classes = np.argmax(preds, axis=1)
    # Probability assigned to the predicted (arg-max) class of each sample.
    confidences = preds.max(axis=1)
    poor_mask = confidences < poor_thresh

    class_ids, counts = np.unique(pred_classes, return_counts=True)
    class_counts = sorted(zip(class_ids, counts), key=lambda x: x[1])
    total_poor_pred = 0
    for cls, cnt in class_counts[::-1]:
        # Count only samples actually predicted as `cls`; counting over all
        # samples would also include every sample of the other classes.
        poor_pred = np.sum(poor_mask[pred_classes == cls])
        total_poor_pred += poor_pred
        print('{}\n\t{}\t{:.2f}%\tpoor:{:.2f}%'.format(
            CLASSES[cls], cnt, 100*cnt/len(preds), 100*poor_pred/cnt))
    print('Total poor predictions: {:.2f}%, threshold {}'.format(
        100*total_poor_pred/preds.shape[0], poor_thresh))

    class_id = CLASSES.index(phone)
    mask = pred_classes == class_id
    class_probs = confidences[mask]
    plt.hist(class_probs, 50)
    plt.title(phone)

    if dump_poor:
        dst_dir = Path('/tmp/poor_pred')
        if dst_dir.exists():
            shutil.rmtree(str(dst_dir))
        dst_dir.mkdir(exist_ok=True)
        # The disabled original compared probabilities against 4, which is
        # always true after softmax; use the same threshold as the report.
        for src in fnames[np.bitwise_and(mask, poor_mask)]:
            src = Path(str(src))
            # File names are resolved relative to the parent directory and the
            # copies are always given a .png suffix, as in the original code.
            shutil.copy(str(Path('../')/src), str(dst_dir/(src.stem+'.png')))

In [None]:
# Test-set summaries, one cell per training run. Each argument is the name of
# a run directory under `root`; every call prints per-class statistics and
# draws a confidence histogram.
atest('resnet50_random_crop')

In [None]:
atest('resnet50_random_crop_sometimes_0.3')

In [None]:
atest('resnet50_random_crop_sometimes_0.5')

In [None]:
atest('resnet50_adam_lr_1e-3_rand_crop_lr_ch_3')

In [None]:
atest('resnet50_adam_lr_1e-3_rand_crop')

In [None]:
atest('resnet50_class_aware')

In [None]:
atest('dense121_512')

In [None]:
atest('dense121_512_lr_2e-4')

In [None]:
atest('dense121_mul_fc_lr')

In [None]:
atest('dense121_mul_fc_lr_no_rot')

In [None]:
atest('densenet121_2stage_tform')

In [None]:
# NOTE(review): same run name as the previous cell - duplicate call.
atest('densenet121_2stage_tform')

In [None]:
atest('densenet121_2stage_tform_full_flickr')

In [None]:
# NOTE(review): same run name as the previous cell - duplicate call.
atest('densenet121_2stage_tform_full_flickr')

# VALID: analysis of validation-set predictions

In [None]:
def plot_confusion_matrix(cm, classes=CLASSES,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print a one-line header and draw the confusion matrix on the current figure.

    Parameters
    ----------
    cm : numpy.ndarray
        Confusion matrix; rows are plotted as true labels and columns as
        predicted labels. Must hold integers when ``normalize`` is False,
        because the cells are then formatted with ``'d'``.
    classes : sequence of str
        Tick labels for both axes, in matrix order.
    normalize : bool
        When True every row is divided by its sum, so each cell shows the
        fraction of that true class. Rows that sum to zero (a class with no
        true samples) are shown as zeros instead of NaN.
    title : str
        Figure title.
    cmap : matplotlib colormap
        Colormap used for the cells.
    """
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row has samples; an empty row would otherwise
        # produce 0/0 = NaN, which also poisons the colour threshold below.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout last so that the axis labels are taken into account.
    plt.tight_layout()

def avalid(run):
    """Report validation accuracy and plot the confusion matrix of one run.

    Loads ``predict_valid.pkl`` from ``root / run``. The pickle is expected to
    hold three sequences: per-sample class scores, target class ids and a
    manipulation value per sample. Samples whose manipulation value is ``-1``
    are treated as unaltered, all others as manipulated.

    Prints the number of predictions, the accuracy on unaltered and on
    manipulated samples, and their weighted combination
    ``0.7 * unaltered + 0.3 * manipulated``. Then draws the row-normalised
    confusion matrix on a new 8x8 figure.

    Parameters
    ----------
    run : str
        Name of the run directory under ``root``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files written by
    # our own prediction scripts.
    with open(str(root / run / 'predict_valid.pkl'), 'rb') as f:
        data = pickle.load(f)
    preds, targets, manips = [np.array(d) for d in data]

    y_pred = np.argmax(preds, axis=1)
    y_true = targets

    mask = manips != -1
    acc_manip = np.mean(y_pred[mask] == y_true[mask])
    acc_unalt = np.mean(y_pred[~mask] == y_true[~mask])
    print('Predictions\t{}'.format(preds.shape[0]))
    print('Acc unalt\t{}\nAcc manip\t{}\nAcc\t{}'.format(acc_unalt, acc_manip, 0.7*acc_unalt+0.3*acc_manip))

    # Fix the label set so the matrix always has one row/column per entry of
    # CLASSES; otherwise a class missing from both targets and predictions
    # shrinks the matrix and shifts every tick label after it.
    cnf_matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(CLASSES)))
    plt.figure(figsize=(8, 8))
    plot_confusion_matrix(cnf_matrix, classes=CLASSES, normalize=True)

In [None]:
# Validation summaries for selected runs. Each argument is the name of a run
# directory under `root`; every call prints accuracy figures and draws a
# normalised confusion matrix.
avalid('resnet50_refine_nexus_5x_no_sea_repeat_4')

In [None]:
avalid('resnet50_random_crop')

In [None]:
avalid('dense121_512_lr_2e-4')