In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
from skimage.util import montage
import pandas as pd
from torch import optim
import re
import os

from utils import *

## Prepare data

In [4]:
data_dir = '/home/ys1/dataset/Humpback_Whale'
df = pd.read_csv(os.path.join(data_dir,'train.csv'))
df.head()

Unnamed: 0,Image,Id
0,0000e88ab.jpg,w_f48451c
1,0001f9222.jpg,w_c3d896a
2,00029d126.jpg,w_20df2c5
3,00050a15a.jpg,new_whale
4,0005c1ef8.jpg,new_whale


In [5]:
im_count = df[df.Id != 'new_whale'].Id.value_counts()
im_count.name = 'sighting_count'
df = df.join(im_count, on='Id')
val_fns = set(df.sample(frac=1)[(df.Id != 'new_whale') & (df.sighting_count > 1)].groupby('Id').first().Image)

  after removing the cwd from sys.path.


In [7]:
pd.to_pickle(val_fns, os.path.join(data_dir,'val_fns'))
# val_fns = pd.read_pickle(os.path.join(data_dir,'val_fns'))

In [8]:
fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}

In [9]:
SZ = 224
BS = 64
NUM_WORKERS = 12
SEED=0

In [10]:
path2fn = lambda path: re.search('\w*\.jpg$', path).group(0)

In [11]:
df = df[df.Id != 'new_whale']

In [12]:
df.shape

(15697, 3)

In [13]:
df.sighting_count.max()

73.0

In [14]:
df_val = df[df.Image.isin(val_fns)]
df_train = df[~df.Image.isin(val_fns)]
df_train_with_val = df

In [15]:
df_val.shape, df_train.shape, df_train_with_val.shape

((2931, 3), (12766, 3), (15697, 3))

In [16]:
%%time

res = None
sample_to = 15

for grp in df_train.groupby('Id'):
    n = grp[1].shape[0]
    additional_rows = grp[1].sample(0 if sample_to < n  else sample_to - n, replace=True)
    rows = pd.concat((grp[1], additional_rows))
    
    if res is None: res = rows
    else: res = pd.concat((res, rows))

CPU times: user 6.8 s, sys: 39 ms, total: 6.83 s
Wall time: 6.83 s


In [17]:
%%time

res_with_val = None
sample_to = 15

for grp in df_train_with_val.groupby('Id'):
    n = grp[1].shape[0]
    additional_rows = grp[1].sample(0 if sample_to < n  else sample_to - n, replace=True)
    rows = pd.concat((grp[1], additional_rows))
    
    if res_with_val is None: res_with_val = rows
    else: res_with_val = pd.concat((res_with_val, rows))

CPU times: user 11.2 s, sys: 7.96 ms, total: 11.3 s
Wall time: 11.3 s


In [18]:
res.shape, res_with_val.shape

((76174, 3), (76287, 3))

Our training set increased 6-fold, but that is still an amount of data that is okay. I don't think it makes sense to worry about breaking up the data into smaller epochs.

In [19]:
pd.concat((res, df_val))[['Image', 'Id']].to_csv(os.path.join(data_dir,'oversampled_train.csv'), index=False)
res_with_val[['Image', 'Id']].to_csv(os.path.join(data_dir,'oversampled_train_and_val.csv'), index=False)

The naming here is not very fortunate, but the idea is that `oversampled_train` has single entries for images in `val_fns` and `oversampled_train_and_val` is both `val` and `train` combined. Meaning, `oversampled_train_and_val` is one we might want to use when retraining on the entire train set.

In [20]:
df = pd.read_csv(os.path.join(data_dir,'oversampled_train.csv'))

In [21]:
data = (
    ImageItemList
        .from_df(df[df.Id != 'new_whale'], os.path.join(data_dir,'train'), cols=['Image'])
        .split_by_valid_func(lambda path: path2fn(path) in val_fns)
        .label_from_func(lambda path: fn2label[path2fn(path)])
        .add_test(ImageItemList.from_folder(os.path.join(data_dir,'test')))
        .transform(get_transforms(do_flip=False, max_zoom=1, max_warp=0, max_rotate=2), size=SZ, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')
        .normalize(imagenet_stats)
)

In [22]:
data

ImageDataBunch;

Train: LabelList
y: CategoryList (76174 items)
[Category w_0003639, Category w_0003639, Category w_0003639, Category w_0003639, Category w_0003639]...
Path: /home/ys1/dataset/Humpback_Whale/train
x: ImageItemList (76174 items)
[Image (3, 700, 1050), Image (3, 700, 1050), Image (3, 700, 1050), Image (3, 700, 1050), Image (3, 700, 1050)]...
Path: /home/ys1/dataset/Humpback_Whale/train;

Valid: LabelList
y: CategoryList (2931 items)
[Category w_c3d896a, Category w_a6f9d33, Category w_cb622a2, Category w_3881f28, Category w_8a235b6]...
Path: /home/ys1/dataset/Humpback_Whale/train
x: ImageItemList (2931 items)
[Image (3, 325, 758), Image (3, 667, 1000), Image (3, 285, 1050), Image (3, 409, 1050), Image (3, 307, 1050)]...
Path: /home/ys1/dataset/Humpback_Whale/train;

Test: LabelList
y: CategoryList (7960 items)
[Category w_0003639, Category w_0003639, Category w_0003639, Category w_0003639, Category w_0003639]...
Path: /home/ys1/dataset/Humpback_Whale/train
x: ImageItemList