In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re

from utils import *

I take a curriculum approach to training here. I first expose the model to as many different images of whales as quickly as possible (no oversampling) and train on images resized to 224x224.

I would like the conv layers to start picking up on features useful for identifying whales. For that, I want to show the model as rich a dataset as possible.

I then train on images resized to 448x448.

Finally, I train on oversampled data. Here, the model will see some images more often than others, but I am hoping that this will help alleviate the class imbalance in the training data.

In [2]:
import fastai
from fastprogress import force_console_behavior
import fastprogress
# Set fastprogress's NO_BAR flag and obtain the bar classes returned by
# force_console_behavior(), then install those same classes on
# fastai.basic_train so the training loop uses them too.
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = force_console_behavior()
fastai.basic_train.master_bar = master_bar
fastai.basic_train.progress_bar = progress_bar

In [3]:
# Training table; later cells read its `Image` (file name) and `Id` (label) columns.
df = pd.read_csv('../data/train.csv')
# File names routed to the validation split by the data cells below.
# Only one image is listed, so the validation split built from it is tiny.
val_fns = {'69823499d.jpg'}

In [4]:
# Map each image file name to its whale Id. Built column-wise with zip rather
# than df.iterrows(), which constructs a Series for every row and is far
# slower; if a file name repeats, the last row still wins.
fn2label = dict(zip(df.Image, df.Id))
def path2fn(path):
    """Return the trailing ``<name>.jpg`` file name of ``path``.

    The result is the run of word characters immediately before a final
    ``.jpg``. ``path`` may be a str or any object whose ``str()`` is the path.

    Raises:
        AttributeError: if ``path`` does not end in such a name (no match).
    """
    # Raw string: '\w' and '\.' are invalid escapes in an ordinary string
    # literal and trigger an invalid-escape warning on current Python.
    return re.search(r'\w*\.jpg$', str(path)).group(0)

In [5]:
# Base name used for every saved checkpoint and for the submission file.
name = 'res50-full-train'

In [6]:
# Settings for the first (low-resolution) training pass.
SZ = 224           # image size handed to .transform(size=...)
BS = 64            # batch size handed to .databunch(bs=...)
NUM_WORKERS = 12   # handed to .databunch(num_workers=...)
SEED = 0           # NOTE(review): not referenced by any visible cell — confirm whether seeding was intended

In [7]:
# Stage-1 data: every labelled whale image (rows with Id 'new_whale' removed) at SZ x SZ.
data = (
    ImageItemList
        # items come from the `Image` column, resolved against ../data/train
        .from_df(df[df.Id != 'new_whale'], '../data/train', cols=['Image'])
        # an item goes to validation when its file name is in val_fns
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        # label = whale Id looked up by file name
        .label_from_func(lambda path: fn2label[path2fn(path)])
        # unlabelled test items, used later by get_preds(DatasetType.Test)
        .add_test(ImageItemList.from_folder('../data/test'))
        # get_transforms with flipping turned off; images squished to SZ
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [8]:
%%time

# Build the learner on models.resnet50; lin_ftrs=[2048] is forwarded to
# create_cnn (head layer sizes in fastai v1 — TODO confirm for the installed version).
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
# clip_grad() with its default arguments; the trailing ';' keeps the returned
# object from being echoed as the cell's output.
learn.clip_grad();

CPU times: user 2.51 s, sys: 844 ms, total: 3.35 s
Wall time: 10.2 s


In [9]:


# Stage 1: one-cycle training on the learner as created, max LR 1e-2.
learn.fit_one_cycle(14, 1e-2)
learn.save(f'{name}-stage-1')

# Stage 2: unfreeze and train with three learning rates,
# each ten times the previous one: max_lr/100, max_lr/10, max_lr.
learn.unfreeze()
max_lr = 1e-3
lrs = [max_lr / divisor for divisor in (100, 10, 1)]
learn.fit_one_cycle(24, lrs)
learn.save(f'{name}-stage-2')

epoch     train_loss  valid_loss
1         7.488658    0.212913    
2         6.698018    0.009669    
3         6.110320    1.649534    
4         5.145953    0.010905    
5         4.344628    0.622688    
6         3.427283    0.001436    
7         2.693983    0.006930    
8         1.889978    0.000161    
9         1.200347    0.000006    
10        0.743178    0.000083    
11        0.427395    0.000000    
12        0.231724    0.000001    
13        0.147966    0.000000    
14        0.114749    0.000000    
Total time: 18:46
epoch     train_loss  valid_loss
1         0.125161    0.000000    
2         0.125646    0.000000    
3         0.168645    0.000001    
4         0.224942    0.000000    
5         0.262029    0.000547    
6         0.288204    0.000000    
7         0.328155    0.000162    
8         0.306801    0.000023    
9         0.264029    0.000000    
10        0.239775    0.000000    
11        0.217391    0.000004    
12        0.192243    0.000008    
13    

In [10]:
# Settings for the high-resolution passes: twice the image size, a quarter of the batch size.
SZ = 2 * 224       # 448
BS = 64 // 4       # 16
NUM_WORKERS = 12
SEED = 0           # NOTE(review): not referenced by any visible cell — confirm whether seeding was intended

In [11]:
# Same pipeline as the 224px data cell, rebuilt so it picks up the current SZ and BS values.
data = (
    ImageItemList
        # items come from the `Image` column; rows with Id 'new_whale' are removed
        .from_df(df[df.Id != 'new_whale'], '../data/train', cols=['Image'])
        # an item goes to validation when its file name is in val_fns
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        # label = whale Id looked up by file name
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        # get_transforms with flipping turned off; images squished to SZ
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [12]:
%%time
# New learner on the current `data`, initialised from the stage-2 checkpoint.
# (Resuming from stage-3 and half-length schedules of 12 / 22 epochs were
# alternatives previously left commented out in this cell.)
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad()
learn.load(f'{name}-stage-2')
learn.freeze_to(-1)

# Stage 3: train after freeze_to(-1); learning rate is a quarter of stage 1's.
learn.fit_one_cycle(24, 1e-2 / 4)
learn.save(f'{name}-stage-3')

# Stage 4: unfreeze and train with three learning rates,
# each ten times the previous one: max_lr/100, max_lr/10, max_lr.
learn.unfreeze()
max_lr = 1e-3 / 4
lrs = [max_lr / divisor for divisor in (100, 10, 1)]
learn.fit_one_cycle(44, lrs)
learn.save(f'{name}-stage-4')

epoch     train_loss  valid_loss
1         1.488783    0.000001    
2         0.781213    0.000000    
3         0.712935    0.000000    
4         0.963552    0.000004    
5         1.219762    0.000000    
6         1.474415    0.000000    
7         1.792757    0.000019    
8         2.021285    0.000000    
9         2.130956    0.000000    
10        2.228688    0.000001    
11        2.331408    0.000000    
12        2.517759    0.000000    
13        2.376398    0.000000    
14        2.373355    0.000000    
15        2.386347    0.000000    
16        2.220315    0.000000    
17        2.051613    0.000001    
18        2.181615    0.000008    
19        2.061087    0.000000    
20        2.035984    0.000008    
21        1.833481    0.000001    
22        1.702043    0.000001    
23        1.536841    0.000002    
24        1.463844    0.000001    
Total time: 1:48:34
epoch     train_loss  valid_loss
1         1.598646    0.000001    
2         1.517286    0.000001    
3   

In [13]:
# with oversampling
# NOTE(review): this rebinds `df`, so re-running the earlier data cells after
# this one would build them from the oversampled CSV instead of train.csv.
# fn2label is not rebuilt here; labels still come from the mapping made from train.csv.
df = pd.read_csv('../data/oversampled_train_and_val.csv')

In [14]:
# Data built from the current `df` (the oversampled CSV when cells run in order).
# Unlike the earlier data cells, no 'new_whale' filter is applied to the frame here.
data = (
    ImageItemList
        # items come from the `Image` column, resolved against ../data/train
        .from_df(df, '../data/train', cols=['Image'])
        # an item goes to validation when its file name is in val_fns
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        # label = whale Id looked up by file name in fn2label
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder('../data/test'))
        # get_transforms with flipping turned off; images squished to SZ
        .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='../data')
        .normalize(imagenet_stats)
)

In [15]:
%%time
# New learner on the current `data`, initialised from the stage-4 checkpoint.
# (Short 2- and 3-epoch schedules were alternatives previously left
# commented out in this cell.)
learn = create_cnn(data, models.resnet50, lin_ftrs=[2048])
learn.clip_grad()
learn.load(f'{name}-stage-4')
learn.freeze_to(-1)

# Stage 5: train after freeze_to(-1).
learn.fit_one_cycle(14, 1e-2 / 4)
learn.save(f'{name}-stage-5')

# Stage 6: unfreeze and train with three learning rates,
# each ten times the previous one: max_lr/100, max_lr/10, max_lr.
learn.unfreeze()
max_lr = 1e-3 / 4
lrs = [max_lr / divisor for divisor in (100, 10, 1)]
learn.fit_one_cycle(24, lrs)
learn.save(f'{name}-stage-6')

epoch     train_loss  valid_loss
1         2.704631    0.000149    
2         2.569817    557.161499  
3         2.684457    197.819656  
4         2.248218    68.272034   
5         1.829801    6.009163    
6         1.637777    248.953278  
7         1.571471    0.000134    
8         1.352860    0.000011    
9         1.480265    0.000137    
10        1.605890    0.002937    
11        1.982956    145.073669  
12        2.344408    0.099056    
13        2.887636    12.574604   
14        3.240505    2.706734    
Total time: 5:09:19
epoch     train_loss  valid_loss
1         3.637343    0.010223    
2         3.955417    0.076408    
3         3.521924    2.532550    
4         2.965347    0.206098    
5         2.585325    0.010826    
6         2.331424    1.396721    
7         2.243938    0.013039    
8         2.298476    51.148140   
9         2.164433    62.201359   
10        2.156010    0.027977    
11        2.097930    0.001144    
12        2.077401    0.024689    
13  

## Predict

In [16]:
# Predictions for the test items attached with add_test; the second return value is discarded.
preds, _ = learn.get_preds(DatasetType.Test)

In [17]:
# Append one extra column of ones to preds; later cells overwrite it with a
# constant and label it 'new_whale'.
# NOTE(review): not idempotent — re-running this cell appends another column.
preds = torch.cat((preds, torch.ones_like(preds[:, :1])), 1)

In [30]:
# Give the appended 'new_whale' column a constant score of 0.04 for every test image.
# The column index is taken from the learner's class list instead of the
# hard-coded 5004, so it always matches the position 'new_whale' occupies in
# `classes = learn.data.classes + ['new_whale']`.
preds[:, len(learn.data.classes)] = 0.04

In [31]:
# Save the adjusted predictions under ../cache for later reuse.
np.save('../cache/preds_resnet50_2', preds)

In [32]:
# Class names for the columns of preds: the learner's classes followed by
# 'new_whale' for the column appended above.
classes = learn.data.classes + ['new_whale']

In [33]:
# targs = torch.tensor([classes.index(label.obj) if label else 5004 for label in learn.data.valid_ds.y])

In [34]:
# %reload_ext autoreload
# %autoreload 2

In [35]:
def map5fast(preds, targs, k=10):
    """Mean average precision at ``k`` for single-label targets.

    Args:
        preds: 2-D tensor of scores, one row per sample and one column per class.
        targs: 1-D tensor with the true class index of each sample.
        k: how many top-ranked predictions to score; capped at the number of
            columns in ``preds``.

    Returns:
        0-dim float tensor: the mean over samples of ``1 / rank`` of the true
        class within the top ``k`` predictions (0 when it is not among them).
    """
    # Slice the top k columns rather than a fixed 5: the loop below indexes up
    # to column k-1, which raised IndexError for any k > 5 (including the default).
    k = min(k, preds.shape[1])
    top_k = preds.sort(descending=True)[1][:, :k]
    scores = torch.zeros(len(preds), k).float()
    for kk in range(k):
        # a hit at 0-based position kk is worth 1 / (kk + 1)
        scores[:, kk] = (top_k[:, kk] == targs).float() / float(kk + 1)
    return scores.max(dim=1)[0].mean()

def map55(preds, targs):
    """MAP@5 of ``preds`` against ``targs``.

    When ``preds`` is exactly a ``list``, each element is scored on its own
    with ``map5fast(..., 5)`` and the scores are averaged; any other input is
    scored directly.
    """
    if type(preds) is not list:
        return map5fast(preds, targs, 5)
    per_item = [map5fast(p, targs, 5).view(1) for p in preds]
    return torch.cat(per_item).mean()


In [36]:
# %%time
# from tqdm import tqdm_notebook
# res = []
# ps = np.linspace(0, 1, 101)
# for p in tqdm_notebook(ps):
#     preds[:, 5004] = p
#     res.append(map55(preds, targs).item())

In [37]:
# Write the submission via create_submission (presumably from utils — confirm);
# the following cells read the result back from subs/{name}.csv.gz.
create_submission(preds, learn.data, name, classes)

In [38]:
# Sanity check: show the first rows of the written submission file.
pd.read_csv(f'subs/{name}.csv.gz').head()

Unnamed: 0,Image,Id
0,41d6736e1.jpg,new_whale w_7e56d66 w_bca4304 w_71df35a w_af8df07
1,c68904c64.jpg,w_778e474 w_08630fd new_whale w_3de9056 w_59052ad
2,361293a53.jpg,new_whale w_c7d8935 w_0027efa w_171ca39 w_e16924b
3,0a9b3c0dc.jpg,new_whale w_0abdaf4 w_d72771c w_3756834 w_8eae2c3
4,0f41d9dee.jpg,new_whale w_91cc02c w_efcbb06 w_60ce6fc w_70fc054


In [39]:
# Fraction of submission rows whose first listed Id is 'new_whale'.
pd.read_csv(f'subs/{name}.csv.gz').Id.str.split().apply(lambda x: x[0] == 'new_whale').mean()

0.8006281407035176

In [40]:
# Upload subs/{name}.csv.gz to the competition, using `name` as the message.
# NOTE(review): a Restart-and-Run-All re-executes this upload — confirm that is intended.
!kaggle competitions submit -c humpback-whale-identification -f subs/{name}.csv.gz -m "{name}"

100%|████████████████████████████████████████| 176k/176k [00:07<00:00, 23.6kB/s]
Successfully submitted to Humpback Whale Identification

In [41]:
# List this account's submissions to the competition with their scores.
!kaggle competitions submissions -c humpback-whale-identification

fileName                  date                 description        status    publicScore  privateScore  
------------------------  -------------------  -----------------  --------  -----------  ------------  
res50-full-train.csv.gz   2019-02-25 10:34:52  res50-full-train   complete  0.575        None          
res50-full-train.csv.gz   2019-02-25 01:55:44  res50-full-train   complete  0.587        None          
sub7l.csv                 2019-02-23 14:59:37  None               complete  0.890        None          
sub7k.csv                 2019-02-23 14:58:45  None               complete  0.892        None          
sub7j.csv                 2019-02-23 14:58:00  None               complete  0.891        None          
sub7h.csv                 2019-02-22 06:16:45  None               complete  0.890        None          
sub7h.csv                 2019-02-22 06:13:13  None               complete  0.889        None          
sub7h.csv                 2019-02-22 06:07:58  None    