In [1]:
%load_ext autoreload
%autoreload 2

# Import and Setup

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import gc
import string
import random
import numpy as np
import pandas as pd
from functools import partial
from argparse import Namespace
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *


import wandb
from wandb.keras import WandbCallback

from model import GetModel, get_feature_extractor
from config import get_config
from data import GetDataloader
from utils import ShowBatch
from callbacks import GetCallbacks

In [3]:
# Switch for the optional sanity-check cells further down (batch preview, model summary).
DEBUG = False

# Experiment settings; printed in full so the run configuration is visible in the notebook.
args = get_config()
print(args)

Namespace(batch_size=256, early_patience=6, embedding_save_path='../embeddings', epochs=1, image_height=128, image_width=128, labels={'melon_headed_whale': 0, 'humpback_whale': 1, 'false_killer_whale': 2, 'bottlenose_dolphin': 3, 'beluga': 4, 'minke_whale': 5, 'fin_whale': 6, 'blue_whale': 7, 'gray_whale': 8, 'southern_right_whale': 9, 'common_dolphin': 10, 'killer_whale': 11, 'short_finned_pilot_whale': 12, 'dusky_dolphin': 13, 'long_finned_pilot_whale': 14, 'sei_whale': 15, 'spinner_dolphin': 16, 'cuviers_beaked_whale': 17, 'spotted_dolphin': 18, 'brydes_whale': 19, 'commersons_dolphin': 20, 'white_sided_dolphin': 21, 'rough_toothed_dolphin': 22, 'pantropic_spotted_dolphin': 23, 'pygmy_killer_whale': 24, 'frasiers_dolphin': 25}, model_save_path='../models', num_folds=5, num_labels=26, resize=False, rlrp_factor=0.2, rlrp_patience=3, train_img_path='../128x128/train_images-128-128/train_images-128-128')


In [5]:
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Build a random identifier string.

    Args:
        size: Number of characters to produce.
        chars: Pool of characters to draw from; defaults to the uppercase
            ASCII letters followed by the digits.

    Returns:
        A string of ``size`` characters, each drawn independently from
        ``chars`` with ``random.choice``.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return ''.join(picked)

# Tag this experiment with a fresh 8-character id; args.exp_id is later used
# to name the per-experiment model and embedding output folders.
random_id = id_generator(size=8)
args.exp_id = random_id
print(random_id)

AGBSEVK8


# Prepare Dataset

In [6]:
print('Num Labels:', args.num_labels)

# Training metadata, one row per image, including the precomputed `fold` column
# that drives the cross-validation split.
df = pd.read_csv('../cleaned_5_fold_train.csv')
df.head()

Num Labels: 26


Unnamed: 0,image,species,individual_id,img_path,target,fold
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9,../128x128/train_images-128-128/train_images-1...,0,2.0
1,000562241d384d.jpg,humpback_whale,1a71fbb72250,../128x128/train_images-128-128/train_images-1...,1,3.0
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b,../128x128/train_images-128-128/train_images-1...,2,2.0
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063,../128x128/train_images-128-128/train_images-1...,3,2.0
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392,../128x128/train_images-128-128/train_images-1...,1,4.0


# Dataloader

In [7]:
# Optional smoke test: build loaders for a single split and eyeball one batch.
if DEBUG:
    # Fold 0 is held out for validation; every other row goes to training.
    holdout_mask = df.fold == 0
    train_df = df[~holdout_mask]
    valid_df = df[holdout_mask]

    # A single GetDataloader instance serves both splits (train first, then validation).
    dataset = GetDataloader(args)
    trainloader, validloader = (dataset.dataloader(split) for split in (train_df, valid_df))

    # Pull the first training batch and render it.
    sample_imgs, sample_labels = next(iter(trainloader))

    show_batch = ShowBatch(args)
    show_batch.show_batch(sample_imgs, sample_labels)

# Model

In [8]:
# Optional smoke test: build the network once and print its layer summary.
if DEBUG:
    # Reset Keras' global state so this build starts from a clean session.
    tf.keras.backend.clear_session()
    get_model = GetModel(args)
    # NOTE(review): presumably returns an EfficientNet-based Keras model (the name
    # and the .summary() call suggest so) — confirm in model.py.
    model = get_model.get_efficientnet()
    model.summary()

# Callbacks

In [9]:
# Callback factory built from args; the training loop obtains its
# reduce-LR-on-plateau callback from it via get_reduce_lr_on_plateau().
callbacks = GetCallbacks(args)

# Train

In [10]:
# Cross-validation: for every fold, train on the remaining folds, then persist
# the model, the out-of-fold predictions and the embeddings of the held-out fold.
for fold in range(args.num_folds):
    # Get dataloaders
    train_df = df[df.fold != fold]
    valid_df = df[df.fold == fold]

    dataset = GetDataloader(args)
    trainloader = dataset.dataloader(train_df)
    validloader = dataset.dataloader(valid_df)

    # Initialize model (fresh Keras session per fold so graphs/state do not pile up)
    tf.keras.backend.clear_session()
    model = GetModel(args).get_efficientnet()

    # Compile model
    model.compile('adam',
                  loss='categorical_crossentropy',
                  metrics=['acc',
                           tf.keras.metrics.TopKCategoricalAccuracy(1, name='top@1_acc'),
                           tf.keras.metrics.TopKCategoricalAccuracy(5, name='top@5_acc')])

    # Per-experiment output locations, built once and reused below.
    model_dir = f'{args.model_save_path}/{args.exp_id}'
    model_path = f'{model_dir}/model_{fold}.h5'
    embedding_dir = f'{args.embedding_save_path}/{args.exp_id}'

    # Initialize W&B run; the fold index is added to the config so that the
    # per-fold runs of one experiment can be told apart in the dashboard.
    run = wandb.init(project='happywhale',
                     config={**vars(args), 'fold': fold},
                     group='effnetb0-new',
                     job_type='train')

    try:
        # Train
        # NOTE(review): args.early_patience is configured, but only the
        # reduce-LR-on-plateau callback is passed here — confirm that early
        # stopping is intentionally left out.
        model.fit(trainloader,
                  epochs=args.epochs,
                  validation_data=validloader,
                  callbacks=[WandbCallback(save_model=False),
                             callbacks.get_reduce_lr_on_plateau()])

        # Save the model
        os.makedirs(model_dir, exist_ok=True)
        model.save(model_path)

        # Load the model back from disk, so evaluation runs on the saved file
        model = tf.keras.models.load_model(model_path)

        # Evaluate: store the arg-max class for the held-out rows.
        # NOTE(review): assumes validloader yields every row of valid_df exactly
        # once and in valid_df order (no shuffling / dropped remainder) — verify
        # in GetDataloader.
        preds = model.predict(validloader)
        df.loc[valid_df.index, 'preds'] = np.argmax(preds, axis=1)

        # Get Embedding
        feature_extractor = get_feature_extractor(model)
        embedding = feature_extractor.predict(validloader)

        os.makedirs(embedding_dir, exist_ok=True)
        np.savez(f'{embedding_dir}/embedding_{fold}.npz',
                 embedding=embedding,
                 individual_id=valid_df.individual_id.values)

        # Release per-fold objects before the next fold is built.
        del trainloader, validloader, model, feature_extractor, embedding
        _ = gc.collect()
    except BaseException:
        # A failed or interrupted fold must not leave its W&B run open;
        # close it with a non-zero exit code, then surface the original error.
        run.finish(exit_code=1)
        raise
    else:
        # Close W&B run
        run.finish()

# Out-of-fold predictions for all rows, written once every fold has finished.
df[['image', 'individual_id', 'target', 'preds']].to_csv('../oof.csv', index=False)

# Save Model as Artifacts

In [2]:
# Initialize W&B run dedicated to uploading the saved models as an artifact.
# NOTE(review): this run uses group='effnetb0' while the training runs use
# group='effnetb0-new' — confirm whether they are meant to share a group.
run = wandb.init(project='happywhale', group='effnetb0', job_type='log_model')
model_artifact = wandb.Artifact('EfficientNetB0', type='supervised')
# NOTE(review): 'models' is resolved against the current working directory,
# whereas training saves under args.model_save_path ('../models' in the printed
# config) — verify this points at the intended folder.
model_artifact.add_dir('models')
run.log_artifact(model_artifact)
# Finish the active W&B run.
wandb.finish()

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Adding directory to artifact (./models)... Done. 3.7s


VBox(children=(Label(value=' 236.57MB of 236.57MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…