In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Imports and Setup

In [22]:
import os
# Set before `import tensorflow` further down in this cell.
# NOTE(review): presumably silences TensorFlow's native log output — confirm the level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Platform tag for this run; it is read back with os.getenv('PLATFORM') in a later cell.
os.environ['PLATFORM'] = 'GCP' # Kaggle, Colab, Paperspace, Local

import gc
import json
import pprint
import numpy as np
import pandas as pd
from tqdm import tqdm
from functools import partial
from argparse import Namespace
import matplotlib.pyplot as plt

import tensorflow as tf
# NOTE(review): the wildcard imports below place every public Keras layer/model
# name in the notebook namespace; prefer explicit names once usage is settled.
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

# Project-local modules (kept fresh by the autoreload cell at the top).
from model import SimpleSupervisedModel, ArcFaceSupervisedModel, get_feature_extractor
from config import get_train_config
from data import GetDataloader
from utils import ShowBatch, id_generator, get_stratified_k_fold, setup_device, count_data_items
from callbacks import GetCallbacks

# Shared pretty-printer instance using a one-space indent per nesting level.
pp = pprint.PrettyPrinter(indent=1)

In [6]:
# Setup Device for training
# setup_device() is imported from utils; the object it returns is kept as
# `strategy` and its replica count is printed as a sanity check.
# NOTE(review): presumably a tf.distribute strategy — confirm in utils.
strategy = setup_device()
print("REPLICAS: ", strategy.num_replicas_in_sync)

#### GPU Available ####
REPLICAS:  1


In [7]:
# Setup W&B Login
try:
    import wandb
    wandb.login()
except:
    !pip install -qqq wandb
    import wandb
    wandb.login()
else:
    from wandb.keras import WandbCallback

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [16]:
# Setup Platform
# Look up the platform tag in the process environment; the result is None
# when the PLATFORM variable is not set.
platform = os.environ.get('PLATFORM')
print('Where are we working? - ', platform)

Where are we working? -  GCP


In [17]:
# Setup configs
# get_train_config() is imported from config; the object it returns is kept as
# the notebook-wide `args` that the data-loading helpers below read from.
args = get_train_config()

# Setup experiment id for W&B
# id_generator (from utils) is asked for an id of size 8; the id is stored on
# args so it travels with the rest of the configuration.
# NOTE(review): presumably random, so re-running this cell yields a new id — confirm.
random_id = id_generator(size=8)
args.exp_id = random_id
print('Experiment ID: ', args.exp_id)

Experiment ID:  WDEREODW


# Dataset

In [20]:
# GCS bucket that the happywhale-2022 TFRecord shards are globbed from.
# A previous value,
#   'gs://kds-d916c3252bf3bc5b3500b904f05f51ce57c8df85221d11b7711bcda9',
# used to be assigned here and was overwritten on the very next line; that dead
# assignment is dropped so the cell states only the path that is actually used.
GCS_PATH = 'gs://kds-6f96f0bc6a675e5d7d316d604702c92b164fc68a7fe681084b7a873d'

In [24]:
# Glob the train and test TFRecord shards under GCS_PATH and keep them sorted by path.
train_glob = GCS_PATH + '/happywhale-2022-train*.tfrec'
test_glob = GCS_PATH + '/happywhale-2022-test*.tfrec'
train_files = np.sort(np.array(tf.io.gfile.glob(train_glob)))
test_files = np.sort(np.array(tf.io.gfile.glob(test_glob)))

print(GCS_PATH)

# Shard counts first, then the item counts reported by count_data_items (from utils).
n_train_items = count_data_items(train_files)
n_test_items = count_data_items(test_files)
print(len(train_files), len(test_files), n_train_items, n_test_items)

gs://kds-6f96f0bc6a675e5d7d316d604702c92b164fc68a7fe681084b7a873d
10 10 51033 27956


# DataLoader

In [38]:
# Shorthand for tf.data.AUTOTUNE; passed as the parallel-read, parallel-map and
# prefetch setting by the pipeline helpers below.
AUTO = tf.data.AUTOTUNE

def data_augment(image, label):
    """Augmentation hook applied to training samples; currently a pass-through.

    Args:
        image: The sample's image.
        label: The sample's label.

    Returns:
        The tuple ``(image, label)``, with both values unchanged.
    """
    sample = (image, label)
    return sample

def decode_image(image_data):
    """Decode JPEG bytes into a resized, rescaled float image.

    Reads the target size from the notebook-level ``args``
    (``args.image_height``, ``args.image_width``).

    Args:
        image_data: Encoded JPEG content, as accepted by ``tf.image.decode_jpeg``.

    Returns:
        A float32 tensor holding the 3-channel image resized to
        ``[args.image_height, args.image_width]`` and divided by 255.
    """
    target_size = [args.image_height, args.image_width]

    pixels = tf.image.decode_jpeg(image_data, channels = 3)
    resized = tf.image.resize(pixels, target_size)

    # Dividing by 255 maps 8-bit intensities into the unit range.
    return tf.cast(resized, tf.float32) / 255.0


def read_labeled_tfrecord(example):
    """Parse one serialized TFRecord example into its name, image and label.

    Args:
        example: A single serialized example, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        A tuple ``(image_name, image, label)``: the stored name string, the
        image produced by ``decode_image``, and the ``target`` value cast to int32.
    """
    feature_spec = {
        "image_name": tf.io.FixedLenFeature([], tf.string), # Name of image
        "image": tf.io.FixedLenFeature([], tf.string), # Encoded image bytes
        "target": tf.io.FixedLenFeature([], tf.int64), # Individual ID
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    return (
        parsed['image_name'],
        decode_image(parsed['image']),
        tf.cast(parsed['target'], tf.int32),
    )


def load_dataset(filenames, ordered = False):
    """Open TFRecord files and parse every record with ``read_labeled_tfrecord``.

    Args:
        filenames: TFRecord path(s) handed to ``tf.data.TFRecordDataset``.
        ordered: When falsy, the dataset options switch off deterministic
            ordering; when truthy, the options are left at their defaults.

    Returns:
        A dataset of ``(image_name, image, label)`` elements.
    """
    read_options = tf.data.Options()
    if not ordered:
        # Unordered reads: allow elements to be produced in any order.
        read_options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    return (
        records
        .with_options(read_options)
        .map(read_labeled_tfrecord, num_parallel_calls = AUTO)
    )

In [46]:
def get_dataloader(filenames,
                   type='train', # valid, eval, test
                   get_targets=True,
                   get_names=False):
    """Build a batched, prefetched ``tf.data`` pipeline for one data split.

    Batch size comes from the notebook-level ``args.batch_size``.

    Args:
        filenames: TFRecord path(s), forwarded to ``load_dataset``.
        type: Which split to build (the name shadows the builtin but is kept
            for caller compatibility). One of:
            - ``'train'``: unordered read, shuffled, mapped to
              ``(image, label)``, passed through ``data_augment``, repeated
              indefinitely.
            - ``'valid'``: ordered read, ``(image, label)``.
            - ``'eval'``: ordered read, ``(image, label)``, or ``image`` only
              when ``get_targets`` is False.
            - ``'test'``: ordered read, ``(image_name, image)``, or ``image``
              only when ``get_names`` is False.
        get_targets: Only consulted when ``type == 'eval'``.
        get_names: Only consulted when ``type == 'test'``.

    Returns:
        The dataset for the requested split, batched and prefetched.

    Raises:
        ValueError: If ``type`` is not one of the four supported values.
            Previously such a value fell through every branch and silently
            returned batches of unmapped ``(image_name, image, label)``.
    """
    supported_types = ('train', 'valid', 'eval', 'test')
    # Validate first so a typo fails fast, before any pipeline is constructed.
    if type not in supported_types:
        raise ValueError(
            f"Unknown dataloader type {type!r}; expected one of {supported_types}."
        )

    # Only the training split is read without a deterministic order.
    dataloader = load_dataset(filenames, ordered=(type != 'train'))

    if type=='train':
        # Shuffle buffer is sized relative to the batch size.
        dataloader = dataloader.shuffle(args.batch_size*100)
        dataloader = dataloader.map(lambda image_name, image, label: (image, label))
        dataloader = dataloader.map(data_augment, num_parallel_calls=AUTO)
        dataloader = dataloader.repeat()
    elif type=='valid':
        dataloader = dataloader.map(lambda image_name, image, label: (image, label))
    elif type=='eval':
        dataloader = dataloader.map(lambda image_name, image, label: (image, label))
        if not get_targets:
            dataloader = dataloader.map(lambda image, label: image)
    elif type=='test':
        dataloader = dataloader.map(lambda image_name, image, label: (image_name, image))
        if not get_names:
            dataloader = dataloader.map(lambda image_name, image: image)

    dataloader = dataloader.batch(args.batch_size)
    dataloader = dataloader.prefetch(AUTO)

    return dataloader

In [58]:
# Training pipeline built from the first train shard only (train_files[0]).
# NOTE(review): get_dataloader does not consult get_targets or get_names when
# type='train', so the two flags passed here have no effect on the result.
trainloader = get_dataloader(train_files[0],
                             type='train', # valid, eval, test
                             get_targets=True, 
                             get_names=True)

In [None]:
# Pull a single (images, labels) batch from the training pipeline and hand it
# to the project's ShowBatch helper (imported from utils) for display.
sample_imgs, sample_labels = next(iter(trainloader))
show_batch = ShowBatch(args)
show_batch.show_batch(sample_imgs, sample_labels)