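# Spider-Man vs. Deadpool classifier (v1)

Fine-tunes a pretrained ResNet-50 (fastai `ConvLearner`) to tell Spider-Man images from Deadpool images.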

## Imports

In [None]:
# Notebook setup: reload edited Python modules automatically before each cell
# runs (autoreload mode 2), and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# This file contains all the main external libs we'll use
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

In [None]:
# Optional, currently disabled: choose which GPU to use, e.g. torch.cuda.set_device(0)

## Data directory, image cleanup, and validation split

In [None]:
competition = "superhero_detector"  # dataset folder name under data/

In [None]:
# Run configuration. PATH keeps its trailing slash: later cells build paths by
# plain string concatenation (PATH+"train/...").
PATH = f"data/{competition}/"
sz = 224          # image size handed to tfms_from_model
arch = resnet50   # pretrained backbone
bs = 28           # batch size


In [None]:
from PIL import Image
# Delete every train/<class>/*.jpg that PIL cannot open, so unreadable files
# do not break the data loader later. Image.open only identifies the file from
# its header, so a truncated file with a valid header is not caught here.
for image in glob(PATH+"train/*/*.jpg"):
    try:
        # Open and immediately close: Image.open keeps the file handle open
        # lazily, so without the context manager one handle leaks per image.
        with Image.open(image):
            pass
    except OSError:  # IOError is an alias of OSError in Python 3
        print(image)
        os.remove(image)

In [None]:
from glob2 import glob
from shutil import move
import os 

# Flatten train/<superhero>/<file>.jpg into train/<superhero>_<file>.jpg and
# record each file's label in labels.csv, the layout read by
# ImageClassifierData.from_csv in get_data.
# Rows are collected in a list and turned into one DataFrame at the end:
# DataFrame.append in a loop is quadratic and was removed in pandas 2.0.
records = []
for image in glob(PATH+"train/*/*.jpg"):
    # os.path rather than split('/') so Windows separators also work.
    superhero = os.path.basename(os.path.dirname(image))
    file_ = os.path.basename(image)
    flat_name = superhero+"_"+file_
    records.append({"file": flat_name, "superhero": superhero})
    move(image, PATH+"train/"+flat_name)

df = pd.DataFrame(records, columns=["file", "superhero"])

if records:
    df.to_csv(PATH+'labels.csv', index=False)
    # Remove the emptied class directories (os.rmdir raises if any
    # non-.jpg file is still inside one).
    for sp in df["superhero"].unique():
        os.rmdir(PATH+'train/'+sp)
else:
    # After a first run the class sub-directories are gone, so nothing is
    # found; keep the existing labels.csv instead of overwriting it empty.
    print("No train/<superhero>/*.jpg files found; leaving labels.csv unchanged.")

In [None]:
label_csv = f'{PATH}labels.csv'
# Number of labelled images: lines in the CSV minus the header row.
# The with-block closes the file once counting is done.
with open(label_csv) as label_file:
    n = sum(1 for _ in label_file) - 1
# Random validation-set indices from fastai's get_cv_idxs helper.
val_idxs = get_cv_idxs(n)

In [None]:
# Reload the label table written during flattening and spot-check a few rows.
df = pd.read_csv(os.path.join(PATH, 'labels.csv'))
df.sample(n=5)

In [None]:
# Delete PATH/tmp — presumably where get_data's data.resize(..., 'tmp') caches
# resized images — so it is rebuilt from the current training set. The path is
# derived from PATH (not hard-coded) so it follows the `competition` setting.
import shutil
tmp_dir = PATH + 'tmp/'
if os.path.isdir(tmp_dir):  # like `rm -rf`: nothing to do when it is missing
    shutil.rmtree(tmp_dir)

In [None]:
def get_data(sz, bs):
    """Build the fastai ImageClassifierData for this dataset.

    Reads images from PATH/train with labels from PATH/labels.csv and uses the
    module-level `val_idxs` as the validation split. Uses the module-level
    `arch` for the model-specific transforms.

    Args:
        sz: image size passed to tfms_from_model.
        bs: batch size.

    Returns:
        The ImageClassifierData as-is when sz > 300; otherwise the result of
        data.resize(340, 'tmp'), which works from copies resized to 340 under
        PATH/tmp (presumably to speed up loading at small sizes).
    """
    tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
    # PATH already ends with '/', so no extra separator before labels.csv.
    data = ImageClassifierData.from_csv(PATH, 'train', f'{PATH}labels.csv',
                                        val_idxs=val_idxs, tfms=tfms, bs=bs)
    return data if sz > 300 else data.resize(340, 'tmp')


In [None]:
data = get_data(sz=sz, bs=bs)  # image size and batch size from the config cell

# Algorithm 

## Initial Model 

### Train the classifier head (`precompute=False`)

In [None]:
# Learner on the pretrained backbone with precomputed activations disabled,
# then a first fit at learning rate 1e-2 for 1 cycle.
# NOTE(review): the section header says "Precompute" and the next cell sets
# precompute=False again, which suggests precompute=True may have been
# intended here — confirm.
learn = ConvLearner.pretrained(
    arch,
    data,
    precompute=False,
)
learn.fit(1e-2, 1)

In [None]:
# Make sure precomputed activations are off, then train 5 cycles of one
# epoch each (cycle_len=1) at learning rate 1e-2.
learn.precompute = False
learn.fit(1e-2, 5, cycle_len=1)

In [None]:
# Discriminative learning rates, presumably one per layer group from the
# earliest layers (smallest) to the head (largest); used by the next fit.
lr = np.array([1e-4, 1e-3, 1e-2])
# Make all layer groups trainable for fine-tuning.
learn.unfreeze()

In [None]:
# Fine-tune with restarts: 3 cycles, first one epoch long, each cycle twice
# as long as the previous (cycle_mult=2).
learn.fit(lr, 3, cycle_mult=2, cycle_len=1)

In [None]:
# Checkpoint name encodes image size 224 and the ResNet-50 backbone.
checkpoint_name = '224_all_50.spiderman.deadpool'
learn.save(checkpoint_name)

In [None]:
# Restore the fine-tuned weights saved above (same checkpoint name).
checkpoint_name = '224_all_50.spiderman.deadpool'
learn.load(checkpoint_name)

## Analyzing results

In [None]:
log_preds, y = learn.TTA()
# Depending on the fastai 0.7 release, TTA() returns either log-probabilities
# already averaged over augmentations (n_samples, n_classes) or one set per
# augmentation stacked on axis 0 (n_aug, n_samples, n_classes). Average the
# probabilities in the stacked case so the indexing below is always 2-D.
if log_preds.ndim == 3:
    class_probs = np.mean(np.exp(log_preds), axis=0)
else:
    class_probs = np.exp(log_preds)
# exp is monotonic, so argmax over probabilities equals argmax over log-probs.
preds = np.argmax(class_probs, axis=1)
# Probability of class index 1 for each validation image.
probs = class_probs[:, 1]


In [None]:
def rand_by_mask(mask, n=4):
    """Return up to `n` random indices (without replacement) where `mask` is True.

    Returns fewer than `n` indices (possibly none) when the mask has fewer
    matches, instead of raising like np.random.choice does.
    """
    idxs = np.where(mask)[0]
    if idxs.size == 0:
        return idxs
    return np.random.choice(idxs, min(n, idxs.size), replace=False)


def rand_by_correct(is_correct):
    """Random validation indices whose prediction correctness equals `is_correct`."""
    return rand_by_mask((preds == data.val_y) == is_correct)


def plots(ims, figsize=(12, 6), rows=1, titles=None):
    """Show `ims` side by side in a grid with `rows` rows, axes hidden.

    Args:
        ims: sequence of images accepted by imshow.
        figsize: matplotlib figure size.
        rows: number of grid rows.
        titles: optional per-image titles (same length as `ims`).

    May shadow a `plots` brought in by `from fastai.plots import *`.
    """
    f = plt.figure(figsize=figsize)
    # Ceiling division so every image gets a cell even when len(ims) is not
    # a multiple of rows (floor division could give 0 columns).
    cols = -(-len(ims) // rows)
    for i in range(len(ims)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        sp.imshow(ims[i])


def load_img_id(ds, idx):
    """Load image `idx` of dataset `ds` (file name relative to PATH) as a numpy array."""
    # np.array forces the pixel data to be read before the file is closed.
    with PIL.Image.open(PATH+ds.fnames[idx]) as img:
        return np.array(img)


def plot_val_with_title(idxs, title):
    """Print `title`, then plot the given validation images titled with their `probs` value."""
    imgs = [load_img_id(data.val_ds, x) for x in idxs]
    title_probs = [probs[x] for x in idxs]
    print(title)
    return plots(imgs, rows=1, titles=title_probs, figsize=(16, 8))


def most_by_mask(mask, mult, n=4):
    """Up to `n` indices where `mask` is True, ordered by ascending `mult * probs`.

    mult=-1 puts the highest `probs` (class-1 score from TTA) first; mult=1 the lowest.
    """
    idxs = np.where(mask)[0]
    return idxs[np.argsort(mult * probs[idxs])[:n]]


def most_by_incorrect(y, is_correct):
    """Indices whose correctness differs from `is_correct` and whose label is not `y`.

    Ordering follows most_by_mask, with the sign picked from `y` and `is_correct`.
    """
    mult = -1 if (y == 1) != is_correct else 1
    # Both comparisons parenthesised: `&` binds tighter than `!=`, so without
    # them the mask compared correctness against `is_correct & (val_y != y)`.
    return most_by_mask(((preds == data.val_y) != is_correct) & (data.val_y != y), mult)


def most_by_correct(y, is_correct):
    """Indices whose correctness equals `is_correct` and whose label is `y`.

    For y=1, is_correct=True the highest class-1 scores come first (mult=-1).
    """
    mult = -1 if (y == 1) == is_correct else 1
    # Parenthesised for the same precedence reason as in most_by_incorrect.
    return most_by_mask(((preds == data.val_y) == is_correct) & (data.val_y == y), mult)

In [None]:
# Inspect the validation dataset object used by the plotting helpers.
data.val_ds

In [None]:
# Most confidently correct predictions among images labelled class index 1.
# The title is taken from data.classes (assumed to be the class-name list of
# the ImageClassifierData) so it names the class actually plotted; a
# hard-coded name is wrong whenever index 1 is not that class.
plot_val_with_title(most_by_correct(1, True), f"Most correct {data.classes[1]}")

In [None]:
# NOTE(review): the title is hard-coded; confirm that most_by_incorrect(0, False)
# really selects correctly classified Deadpool images for this class ordering.
plot_val_with_title(
    idxs=most_by_incorrect(0, False),
    title="Most correct Deadpool",
)

In [None]:
# Plot the misclassified selection made by most_by_correct(1, False).
plot_val_with_title(
    idxs=most_by_correct(1, False),
    title="Most incorrect images",
)