In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

I take a curriculum approach to training here. I first expose the model to as many different images of whales as quickly as possible (no oversampling) and train on images resized to 224x224.

I would like the conv layers to start picking up on features useful for identifying whales. For that, I want to show the model as rich of a dataset as possible.

I then train on images resized to 448x448.

Finally, I train on oversampled data. Here, the model will see some images more often than others but I am hoping that this will help alleviate the class imbalance in the training data.

In [2]:
import fastai
from fastprogress import force_console_behavior
import fastprogress
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = force_console_behavior()
fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar, progress_bar

In [3]:
df = pd.read_csv('../data/train.csv')
val_fns = {'69823499d.jpg'}

In [4]:
fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}
path2fn = lambda path: re.search('\w*\.jpg$', path).group(0)

In [5]:
name = f'res50-full-train'

In [6]:
SZ = 224
SZ = 384
BS = 64
NUM_WORKERS = 12
SEED=0

In [7]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], '../data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [8]:
%%time

learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();

CPU times: user 2.41 s, sys: 798 ms, total: 3.21 s
Wall time: 3.07 s


In [8]:
learn.fit_one_cycle(14, 1e-2)
learn.save(f'{name}-stage-1')

learn.unfreeze()

max_lr = 1e-3
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(24, lrs)
learn.save(f'{name}-stage-2')

epoch     train_loss  valid_loss
1         7.501500    0.273467    
2         6.671245    0.180389    
3         6.094555    0.012937    
4         5.126923    0.341942    
5         4.270926    0.039900    
6         3.461939    1.203884    
7         2.553479    0.000077    
8         1.863998    0.000075    
9         1.227336    0.000357    
10        0.695952    0.000714    
11        0.383916    0.000061    
12        0.207382    0.000008    
13        0.143048    0.000001    
14        0.107567    0.000003    
Total time: 17:26
epoch     train_loss  valid_loss
1         0.106527    0.000008    
2         0.106264    0.000010    
3         0.164136    0.000028    
4         0.223849    0.002716    
5         0.270537    0.000025    
6         0.314849    0.003295    
7         0.333403    0.000002    
8         0.285706    0.000000    
9         0.296146    0.000027    
10        0.234786    0.000124    
11        0.197633    0.000048    
12        0.189064    0.000027    
13    

In [9]:
SZ = 224 * 2
BS = 64 // 4
NUM_WORKERS = 12
SEED=0

In [10]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], '../data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [11]:
%%time
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();
# learn.load(f'{name}-stage-2')
learn.load(f'{name}-stage-3')
# learn.freeze_to(-1)

# learn.fit_one_cycle(12, 1e-2 / 4)
# learn.save(f'{name}-stage-3')

learn.unfreeze()

max_lr = 1e-3 / 4
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(22, lrs)
learn.save(f'{name}-stage-4')

epoch     train_loss  valid_loss
1         0.518923    0.000000    
2         0.518838    0.000000    
3         0.577872    0.000002    
4         0.606787    0.000000    
5         0.684259    0.000000    
6         0.748076    0.000002    
7         0.717744    0.000000    
8         0.718881    0.000000    
9         0.735095    0.000002    
10        0.699363    0.000004    
11        0.664675    0.000001    
12        0.690950    0.000017    
13        0.563404    0.000004    
14        0.574512    0.000004    
15        0.510566    0.000038    
16        0.554237    0.000002    
17        0.498543    0.000002    
18        0.449977    0.000000    
19        0.420661    0.000002    
20        0.369140    0.000004    
21        0.397967    0.000002    
22        0.372123    0.000004    
Total time: 2:14:41
CPU times: user 1h 36min 10s, sys: 37min 51s, total: 2h 14min 2s
Wall time: 2h 14min 45s


In [12]:
# with oversampling
df = pd.read_csv('../data/oversampled_train_and_val.csv')

In [13]:
data = (
    ImageItemList
        .from_df(df, '../data/train', cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [14]:
%%time
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad();
learn.load(f'{name}-stage-4')
learn.freeze_to(-1)

learn.fit_one_cycle(2, 1e-2 / 4)
learn.save(f'{name}-stage-5')

learn.unfreeze()

max_lr = 1e-3 / 4
lrs = [max_lr/100, max_lr/10, max_lr]

learn.fit_one_cycle(3, lrs)
learn.save(f'{name}-stage-6')

epoch     train_loss  valid_loss
1         1.765668    0.000024    
2         0.624019    0.000000    
Total time: 42:09
epoch     train_loss  valid_loss
1         0.701914    0.000000    
2         0.672839    0.000021    
3         0.579063    0.000016    
Total time: 1:27:37
CPU times: user 1h 33min 6s, sys: 36min 28s, total: 2h 9min 34s
Wall time: 2h 9min 48s


## Predict

In [15]:
preds, _ = learn.get_preds(DatasetType.Test)

In [16]:
preds = torch.cat((preds, torch.ones_like(preds[:, :1])), 1)

In [66]:
preds[:, 5004] = 0.04

In [59]:
classes = learn.data.classes + ['new_whale']

In [60]:
# targs = torch.tensor([classes.index(label.obj) if label else 5004 for label in learn.data.valid_ds.y])

In [51]:
# %reload_ext autoreload
# %autoreload 2

In [52]:
def map5fast(preds, targs, k=10):
    predicted_idxs = preds.sort(descending=True)[1]
    top_5 = predicted_idxs[:, :5]
    scores = torch.zeros(len(preds), k).float()
    for kk in range(k):
        scores[:,kk] = (top_5[:,kk] == targs).float() / float((kk+1))
    return scores.max(dim=1)[0].mean()

def map55(preds,targs):
    if type(preds) is list:
        return torch.cat([map5fast(p, targs, 5).view(1) for p in preds ]).mean()
    return map5fast(preds,targs, 5)


In [41]:
%%time
from tqdm import tqdm_notebook
res = []
ps = np.linspace(0, 1, 101)
for p in tqdm_notebook(ps):
    preds[:, 5004] = p
    res.append(map55(preds, targs).item())


CPU times: user 7min 44s, sys: 28.4 s, total: 8min 12s
Wall time: 7min 57s


In [44]:
res

[0.00759631535038352,
 0.006987018510699272,
 0.005956868175417185,
 0.005766331683844328,
 0.0056574540212750435,
 0.005573702044785023,
 0.005510888062417507,
 0.005462730303406715,
 0.0053999163210392,
 0.0053999163210392,
 0.005337102338671684,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.005211473908275366,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.00514865992590785,
 0.0051486599

In [42]:
# without predicting new_whale
map55(preds, targs)

tensor(0.0045)

In [45]:
np.argmax(res)

0

In [43]:
best_p = ps[np.argmax(res)]; best_p

0.0

In [67]:
create_submission(preds, learn.data, name, classes)

In [68]:
pd.read_csv(f'subs/{name}.csv.gz').head()

Unnamed: 0,Image,Id
0,41d6736e1.jpg,w_7e56d66 new_whale w_2ac6611 w_700ebb4 w_bca4304
1,c68904c64.jpg,new_whale w_a92e68e w_4a8a4c9 w_abd456f w_b9edd7b
2,361293a53.jpg,new_whale w_171ca39 w_150a6f5 w_bf5403e w_d0e6e7d
3,0a9b3c0dc.jpg,w_0abdaf4 new_whale w_5babd8c w_40a6c9c w_b3ca4b7
4,0f41d9dee.jpg,new_whale w_5482351 w_ad88c85 w_0e0f074 w_3ff114c


In [69]:
pd.read_csv(f'subs/{name}.csv.gz').Id.str.split().apply(lambda x: x[0] == 'new_whale').mean()

0.4614321608040201

In [70]:
!kaggle competitions submit -c humpback-whale-identification -f subs/{name}.csv.gz -m "{name}"

100%|████████████████████████████████████████| 184k/184k [00:13<00:00, 13.6kB/s]
Successfully submitted to Humpback Whale Identification

In [71]:
!kaggle competitions submissions -c humpback-whale-identification

fileName                 date                 description       status    publicScore  privateScore  
-----------------------  -------------------  ----------------  --------  -----------  ------------  
res50-full-train.csv.gz  2019-02-20 23:10:32  res50-full-train  complete  0.757        None          
res50-full-train.csv.gz  2019-02-20 23:08:39  res50-full-train  complete  0.712        None          
res50-full-train.csv.gz  2019-02-20 23:07:12  res50-full-train  complete  0.649        None          
res50-full-train.csv.gz  2019-02-20 19:54:45  res50-full-train  complete  0.749        None          
sub7d.csv                2019-02-20 00:22:56  None              complete  0.866        None          
sub7c.csv                2019-02-16 21:40:09  None              complete  0.276        None          
sub7b.csv                2019-02-14 19:57:55  None              complete  0.868        None          
sub7a.csv                2019-02-14 03:10:42  None              complete 