In [1]:
# Put these at the top of every notebook, to get automatic reloading and inline plotting
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.basic_data import *
import pandas as pd

In [3]:
import datetime

In [4]:
#DATA_PATH 目录下应该有train文件夹和valid文件夹 fastai 根据这两个目录的来进行train和valid
DATA_PATH = './data/'
TRAIN_DATA_PATH = DATA_PATH + 'train'
TEST_IMG_PATH = DATA_PATH + 'test'
SUBMISSION_PATH = DATA_PATH + 'submission/'
BBOX_PATH = DATA_PATH + 'bounding_boxes.csv'

VALIDATION_RATIO = 0.16

In [5]:
train_df = pd.read_csv(DATA_PATH + 'train.csv')
train_df.head()

Unnamed: 0,Image,Id
0,0000e88ab.jpg,w_f48451c
1,0001f9222.jpg,w_c3d896a
2,00029d126.jpg,w_20df2c5
3,00050a15a.jpg,new_whale
4,0005c1ef8.jpg,new_whale


In [6]:
train_df.Id.nunique()

5005

In [7]:
new_whale_df = train_df[train_df['Id'] == 'new_whale']
new_whale_df.shape

(9664, 2)

In [8]:
without_new_whale_train_df = train_df[~(train_df['Id'] == 'new_whale')]

验证集拆分
1.每个id按找ratio进行随机获取20%的样本作为验证集合，如果随机出来的样本数小于1个，则不放入验证集
2.对于new_whale 则不存在验证集中，
3.为new_whale计算一个阈值，在最终的结果中每个验证集合有5个候选值和对应的匹配概率，new_whale 按阈值插入对应的候选匹配概率中

In [10]:
validation_filenames = []
count_df = train_df.groupby('Id').count().reset_index()
count_df.columns = ['Id', 'Count']


def append_validation_data(id):
    
    df = without_new_whale_train_df[without_new_whale_train_df['Id'] == id].reset_index()['Image']
    count = df.count()
    validation_size = count * VALIDATION_RATIO
    if validation_size >= 1:
#         print('count = {0}, validation_size = {1}'.format(count, int(validation_size)))
        for i in range(0, int(validation_size)):
            validation_filenames.append(df[i])
    elif validation_size > 0 and validation_size < 1 and count > 1:
#         print('count = {0}, validation_size = {1}'.format(count, 1))
        validation_filenames.append(df[0])
#     else:
#         print('count = {0}, validation_size = 0'.format(count))
# todo 改成for循环        
count_df.apply(lambda item: append_validation_data(item.Id), axis=1)

0       None
1       None
2       None
3       None
4       None
5       None
6       None
7       None
8       None
9       None
10      None
11      None
12      None
13      None
14      None
15      None
16      None
17      None
18      None
19      None
20      None
21      None
22      None
23      None
24      None
25      None
26      None
27      None
28      None
29      None
        ... 
4975    None
4976    None
4977    None
4978    None
4979    None
4980    None
4981    None
4982    None
4983    None
4984    None
4985    None
4986    None
4987    None
4988    None
4989    None
4990    None
4991    None
4992    None
4993    None
4994    None
4995    None
4996    None
4997    None
4998    None
4999    None
5000    None
5001    None
5002    None
5003    None
5004    None
Length: 5005, dtype: object

In [11]:
len(validation_filenames)

3273

In [12]:
count_df = without_new_whale_train_df.groupby('Id').count().reset_index()

In [13]:
count_df.columns = ['Id', 'Count']

In [14]:
count_df.head()

Unnamed: 0,Id,Count
0,w_0003639,1
1,w_0003c59,1
2,w_0027efa,10
3,w_00289b1,2
4,w_002c810,1


In [15]:
over_sample_num = 15
over_sample_train_df = None
for i in count_df.index:
# for i in range(3):
#     rows = without_new_whale_train_df[without_new_whale_train_df['Id']== id]
#     print(rows)
    row = count_df.loc[i]
    if row['Count'] < over_sample_num:
#         print(row['Id'])
        group = without_new_whale_train_df[without_new_whale_train_df['Id'] == row['Id']]
#         print(group)
        new_group = group.sample(over_sample_num, replace=True)
        
        if over_sample_train_df is None:
            over_sample_train_df = new_group
        else:
            over_sample_train_df = pd.concat((over_sample_train_df, new_group))
#         print(over_sample_train_df)
        

In [16]:
count_df.loc[1]['Id']

'w_0003c59'

In [20]:
over_sample_train_df['is_valid'] = False

In [17]:
over_sample_train_df = over_sample_train_df.reset_index(drop=True)

In [21]:
over_sample_train_df

Unnamed: 0,Image,Id,is_valid
0,833675975.jpg,w_0003639,False
1,833675975.jpg,w_0003639,False
2,833675975.jpg,w_0003639,False
3,833675975.jpg,w_0003639,False
4,833675975.jpg,w_0003639,False
5,833675975.jpg,w_0003639,False
6,833675975.jpg,w_0003639,False
7,833675975.jpg,w_0003639,False
8,833675975.jpg,w_0003639,False
9,833675975.jpg,w_0003639,False


In [22]:
validation_filenames

['2f31725c6.jpg',
 '4e6290672.jpg',
 '204c7a64b.jpg',
 '108f230d8.jpg',
 '1eccb4eba.jpg',
 '20e7c6af4.jpg',
 '0d5777fc2.jpg',
 '2c0892b4d.jpg',
 '1cc6523fc.jpg',
 '4bb9c9727.jpg',
 '166a9e05d.jpg',
 '3ff8e6bd6.jpg',
 '4f495a9c8.jpg',
 '04eec03fa.jpg',
 '11116ae07.jpg',
 '58bc661b5.jpg',
 '7abceec23.jpg',
 '028bbb7af.jpg',
 '65d9f620f.jpg',
 '056472b40.jpg',
 '08db58a5f.jpg',
 '12093d2aa.jpg',
 '5fe4bd878.jpg',
 '5233e6139.jpg',
 '14bf04de1.jpg',
 '0301ca72f.jpg',
 '14c5eb662.jpg',
 '1cb57f0b9.jpg',
 '1ddec56e7.jpg',
 '39a3458a3.jpg',
 '2110619d0.jpg',
 '1887a9835.jpg',
 '04ef6e1d6.jpg',
 'f694ac18b.jpg',
 'ca2d288dc.jpg',
 'ae7993c1d.jpg',
 '4c9966db6.jpg',
 '9b9716164.jpg',
 '49f92915d.jpg',
 '0124d2989.jpg',
 '3ee226849.jpg',
 '276f83d14.jpg',
 '3332b245e.jpg',
 '34e29a50e.jpg',
 '00e9e5122.jpg',
 '06ae83bbc.jpg',
 '092ef02b8.jpg',
 '12670516b.jpg',
 '12f5ad5a6.jpg',
 '1c11e4f25.jpg',
 '20d4abba8.jpg',
 '2bbad75d4.jpg',
 '2bbfd538c.jpg',
 '618221292.jpg',
 '511f49e5d.jpg',
 '4b24dd76

In [22]:
# for valid_filename in ['2f31725c6.jpg']:
for valid_filename in validation_filenames:
    valid_df = over_sample_train_df[over_sample_train_df['Image'] == valid_filename]
#     print(valid_df)
    over_sample_train_df.loc[valid_df.index,'is_valid'] = True

In [23]:
over_sample_train_df.loc[37]

Image       cc6c1a235.jpg
Id              w_0027efa
is_valid            False
Name: 37, dtype: object

In [25]:
over_sample_train_df[over_sample_train_df['is_valid'] == True].count()

Image       14856
Id          14856
is_valid    14856
dtype: int64

In [26]:
over_sample_train_df.count()

Image       73080
Id          73080
is_valid    73080
dtype: int64

In [27]:
over_sample_train_df.to_csv('data/over_sample_train.csv', index=False)