### 0. Pre-setup

In [1]:
%%capture
# Environment setup (output is hidden by %%capture, including the TPU print below).
# NOTE(review): the original note said upgrading jax, jaxlib and flax is essential, otherwise
# loading the pretrained model may fail; this cell only runs plain `pip install` (no `-U`, no
# version pins) for kaggle, flax and jax-resnet, and never installs jax/jaxlib explicitly.
import os
import requests
import torch


!pip install kaggle

# TPU mode: only when the TPU_NAME environment variable is present
if "TPU_NAME" in os.environ:
    print("TPU & Colab mode: Set up TPU devices")
    # configure JAX to use the Colab TPU backend
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

# flax: model/training utilities; jax-resnet: pretrained ResNet backbones used in section 3
!pip install flax
!pip install jax-resnet

In [2]:
import os
import joblib
from typing import *
from functools import partial

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image
from sklearn.preprocessing import LabelEncoder

# for data loading
import torch
from torch.utils.data import Dataset, DataLoader
# for augmentation
from albumentations import *

import optax
import flax
from flax import linen as nn
from flax.core import frozen_dict
from flax.core import FrozenDict
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state, checkpoints
from flax.training.common_utils import shard, shard_prng_key

import jaxlib
import jax
import jax.numpy as jnp
from jax import lax
from jax_resnet import pretrained_resnet, slice_variables, Sequential

# Report library versions and the devices JAX can see; useful when results depend on the
# runtime (GPU vs TPU) or on the installed flax/optax/jax versions.
print(f"Flax version: {flax.__version__}")
print(f"Optax version: {optax.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"JAX lib version: {jaxlib.__version__}")
print(f"JAX device: {jax.devices()}")

Flax version: 0.4.0
Optax version: 0.1.1
JAX version: 0.3.1
JAX lib version: 0.3.0
JAX device: [GpuDevice(id=0, process_index=0)]


### **1. Configuration**

In [3]:
EPOCH_N = 20
FIRST_N_ITER = None  # early stop on iterations for testing (first epoch only); None = full epoch
LEARNING_RATE = 0.05
# suggested batch sizes:
#   GPU: (128, 128) -> 64, (256, 256) -> 16
#   TPU: (128, 128) -> 64*8*2
BATCH_SIZE = 64*8*2
NUM_WORKERS = 2
PRE_RESIZE = True  # if so, images on disk already have IMAGE_SIZE: no resizing during dataloading
FREEZE_BACKBONE = True  # if so, only the classification head is trained
MOCK_DATASET = False  # control dataloading for speed test
ROOT_DIR = '/kaggle/input/happy-whale-and-dolphin'
TRAIN_CSV = os.path.join(ROOT_DIR, 'train.csv')
TEST_CSV = os.path.join(ROOT_DIR, 'sample_submission.csv')
TRAIN_DIR = '/kaggle/input/resize-images-of-happy-whales-and-dolphins/train_images_128'
TEST_DIR = '/kaggle/input/resize-images-of-happy-whales-and-dolphins/test_images_128'

# (height, width); must match the resolution of the images in TRAIN_DIR when PRE_RESIZE
IMAGE_SIZE: Optional[Tuple[int, int]] = (128, 128)
# ref: https://pytorch.org/vision/stable/models.html
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

OUTPUT_N = 15587  # number of distinct individual_id classes
MODEL_ARCH = 'resnet18'
MODEL_SAVE_DIR = 'resnet_checkpoints'

# validator: IMAGE_SIZE is unpacked later (mock data, model init, sanity checks), so it must be
# set; indexing None here used to fail with an opaque TypeError instead of a readable message
assert IMAGE_SIZE is not None, "IMAGE_SIZE must be set"
# only square images are supported
assert len(IMAGE_SIZE) == 2 and IMAGE_SIZE[0] == IMAGE_SIZE[1], "must be equal"

### **2. Data Preproc and Loading**

In [4]:
class HappyWhaleDataset(Dataset):
    """
    Map-style dataset over the Happy Whale dataframe.

    Each item is `image` (inference) or `(image, label)` (training). `image` is a numpy
    array with 3 channels after the optional `transforms`; `label` is the value of the
    `individual_id` column.

    borrow from:
    - https://www.kaggle.com/snoop2head/pytorch-0-374-effb5-kfold-with-focalloss
    """
    def __init__(self, df, transforms=None, mock_dataset=False, inference=False):
        """
        df: dataframe with a `file_path` column, plus `species` and `individual_id` unless `inference`.
        transforms: callable invoked as `transforms(image=array)['image']`, or None.
        mock_dataset: return constant images / random labels instead of reading files (speed tests).
        inference: if True, items carry no label.
        """
        if mock_dataset:
            print(f"Mocking the dataset for testing...")

        self.mock_dataset = mock_dataset
        self.transforms = transforms
        self.df = df
        self.file_names = df.file_path.values
        if inference:
            self.species = None
            self.labels = None
        else:
            self.species = df.species.values
            self.labels = df.individual_id.values

    def __getitem__(self, index):
        getter = self.__get_mock_item if self.mock_dataset else self.__get_item
        return getter(index)

    def __get_mock_item(self, index):
        # constant image of the configured size; label drawn uniformly from [0, OUTPUT_N)
        mock_img = np.ones((*IMAGE_SIZE, 3))
        if self.labels is None and self.species is None:
            return mock_img
        mock_label = np.random.randint(0, OUTPUT_N)
        return mock_img, mock_label

    def __get_item(self, index):
        image_path = self.file_names[index]
        # context-manager form closes the file deterministically once the pixels are read
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)

        image = self._ensure_three_channels(image)
        if self.transforms:
            image = self.transforms(image=image)['image']

        if self.labels is None and self.species is None:
            return image

        label = self.labels[index]
        return image, label

    @staticmethod
    def _ensure_three_channels(image):
        """Return `image` with exactly 3 channels: replicate grayscale (H, W), drop an alpha channel."""
        if image.ndim == 2:
            image = np.dstack((image,) * 3)
        elif image.ndim == 3 and image.shape[-1] == 4:
            # RGBA would otherwise break the 3-channel ImageNet normalisation
            image = image[..., :3]
        return image

    def __len__(self):
        return len(self.df)

    def set_transform(self, transform):
        """Replace the transform pipeline applied in __getitem__."""
        # must assign `transforms` (plural): that is the attribute __get_item reads;
        # the previous `self.transform = ...` silently had no effect
        self.transforms = transform


def get_train_file_path(filename, image_dir):
    """Return the full path of image `filename` located inside `image_dir`."""
    full_path = os.path.join(image_dir, filename)
    return full_path

In [5]:
# load the training metadata and attach the on-disk path of every image
df = pd.read_csv(TRAIN_CSV)
df['file_path'] = df['image'].map(lambda image_name: get_train_file_path(image_name, image_dir=TRAIN_DIR))
print(f"No. of unique labels: {len(df.individual_id.unique())}")

# encode individual_id strings as integer class indices and persist the fitted encoder
label_encoder = LabelEncoder()
df['individual_id'] = label_encoder.fit_transform(df['individual_id'])
joblib.dump(label_encoder, "label_encoder.pkl")

# the model head has OUTPUT_N outputs, one per encoded individual_id
assert df['individual_id'].nunique() == OUTPUT_N

df.head(3)

No. of unique labels: 15587


Unnamed: 0,image,species,individual_id,file_path
0,00021adfb725ed.jpg,melon_headed_whale,12348,/kaggle/input/resize-images-of-happy-whales-an...
1,000562241d384d.jpg,humpback_whale,1636,/kaggle/input/resize-images-of-happy-whales-an...
2,0007c33415ce37.jpg,false_killer_whale,5842,/kaggle/input/resize-images-of-happy-whales-an...


In [6]:
# define transformations applied on dataset
if not PRE_RESIZE:
    # on-the-fly resizing needs a target size
    assert IMAGE_SIZE is not None
# with PRE_RESIZE the images are already the right size, so only normalisation is applied
transforms_ls = [] if PRE_RESIZE else [Resize(*IMAGE_SIZE, p=1.0)]
transforms_ls += [Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0, p=1.0)]
transforms = Compose(transforms_ls)

# set dataset and dataloader
train_dataset = HappyWhaleDataset(
    df, transforms=transforms,
    mock_dataset=MOCK_DATASET,
    inference=False
)
# default collate_fn converts numpy arrays to torch tensors in dataloading
# reference: https://discuss.pytorch.org/t/torch-dataloader-gives-torch-tensor-as-ouput/31084/6
# drop_last=True: the training loop only consumes len(train_dataset) // BATCH_SIZE full batches
# per epoch (a partial batch can't be sharded evenly across devices), so don't load the remainder
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=True,
    drop_last=True,
)

# sanity check: one sample and one batch have the configured shapes
x, _ = train_dataset[0]
batch, labels = next(iter(train_loader))

assert x.shape[:2] == IMAGE_SIZE and batch.shape[0] == BATCH_SIZE
print(f"Shape for one batch: {batch.shape}")

Shape for one batch: torch.Size([1024, 128, 128, 3])


### **3. Model Construction**

In [7]:
# @TODO fill in later
class MarginLayer(nn.Module):
    """
    Placeholder for an ArcFace-style margin layer; not implemented yet.

    reference:
    - https://www.kaggle.com/ragnar123/unsupervised-baseline-arcface
    - https://arxiv.org/pdf/1801.07698.pdf
    """
    @nn.compact
    def __call__(self, inputs):
        # intentionally unimplemented: calling this layer always raises
        raise NotImplementedError()


class Head(nn.Module):
    """
    Classification head applied to pooled backbone features:
    BatchNorm -> Dropout(0.25) -> Dense(in_features) -> ReLU
    -> BatchNorm -> Dropout(0.5) -> Dense(OUTPUT_N).

    references:
    - fastai
    - https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
    """
    batch_norm_cls: partial = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, inputs, train: bool):
        """`train` enables dropout and switches BatchNorm off its running averages."""
        deterministic = not train
        hidden_width = inputs.shape[-1]  # hidden layer keeps the input feature width

        # layers are created in the same order as before, so parameter names are unchanged
        h = self.batch_norm_cls(use_running_average=deterministic)(inputs)
        h = nn.Dropout(rate=0.25)(h, deterministic=deterministic)
        h = nn.relu(nn.Dense(features=hidden_width)(h))

        h = self.batch_norm_cls(use_running_average=deterministic)(h)
        h = nn.Dropout(rate=0.5)(h, deterministic=deterministic)
        return nn.Dense(features=OUTPUT_N)(h)

    
class HappyWhaleModel(nn.Module):
    """Pretrained backbone -> global average pooling over axes 1 and 2 -> classification Head."""
    backbone: Sequential
    head: Head

    def __call__(self, inputs, train: bool):
        # inputs are built as (batch, H, W, 3) elsewhere, so axes 1 and 2 are the spatial ones
        feature_map = self.backbone(inputs)
        pooled = jnp.mean(feature_map, axis=(1, 2))
        logits = self.head(pooled, train)
        return logits

    
def _get_backbone_and_params(model_arch: str) -> Tuple[nn.Module, FrozenDict]:
    """
    Load a pretrained ResNet from jax_resnet and cut off its last two layers.

    Args:
        model_arch: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'.

    Returns:
        (backbone, backbone_variables): a Sequential of all but the last two layers of the
        pretrained model, and the matching slice of its variables.

    Raises:
        NotImplementedError: if `model_arch` is not a supported ResNet.
    """
    # get model & param structure for pretrained model
    resnet_depths = {f"resnet{depth}": depth for depth in (18, 34, 50, 101, 152)}
    if model_arch not in resnet_depths:
        raise NotImplementedError(
            f"unsupported model_arch {model_arch!r}; expected one of {sorted(resnet_depths)}"
        )
    resnet_tmpl, params = pretrained_resnet(resnet_depths[model_arch])
    model = resnet_tmpl()

    # get model & param structure for backbone: drop the final two layers
    # (presumably global pooling + the ImageNet classifier — confirm in jax_resnet);
    # HappyWhaleModel does its own pooling and adds its own Head
    start, end = 0, len(model.layers) - 2
    backbone = Sequential(model.layers[start:end])
    backbone_params = slice_variables(params, start, end)
    return backbone, backbone_params


def get_model_and_variables(model_arch: str, head_init_key: int) -> Tuple[nn.Module, FrozenDict]:
    """
    Build HappyWhaleModel (pretrained backbone + freshly initialised Head) and its variables.

    Args:
        model_arch: backbone name passed to _get_backbone_and_params (e.g. 'resnet18').
        head_init_key: integer seed for the Head's parameter initialisation.

    Returns:
        (model, variables) where variables is a FrozenDict with collections 'params' and
        'batch_stats', each split into 'backbone' and 'head' sub-trees.
    """
    backbone, backbone_variables = _get_backbone_and_params(model_arch)

    # dummy single-image (1, H, W, 3) batch, only used to discover the backbone's feature width
    dummy_images = jnp.ones((1, *IMAGE_SIZE, 3), jnp.float32)
    feature_map = backbone.apply(backbone_variables, dummy_images, mutable=False)
    dummy_features = jnp.ones((1, feature_map.shape[-1]), jnp.float32)

    head = Head()
    head_variables = head.init(jax.random.PRNGKey(head_init_key), dummy_features, train=False)

    model = HappyWhaleModel(backbone, head)
    # re-key every collection by submodule name so it matches HappyWhaleModel's fields
    variables = FrozenDict({
        collection: {
            'backbone': backbone_variables[collection],
            'head': head_variables[collection],
        }
        for collection in ('params', 'batch_stats')
    })
    return model, variables

In [8]:
%%capture
# sanity check
model, variables = get_model_and_variables('resnet18', 0)
inputs = jnp.ones((1, *IMAGE_SIZE, 3), jnp.float32)
key = jax.random.PRNGKey(0)
o = model.apply(variables, inputs, train=False, mutable=False)

assert o.shape[0] == inputs.shape[0] and o.shape[-1] == OUTPUT_N


You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-p

### **4. Loss Function and Metrics**

In [9]:
def topk_accuracy(logits, labels, top_k = (1,)) -> List[jnp.ndarray]:
    """
    Top-k accuracy, in percent, for each k in `top_k`.

    reference:
    - https://www.kaggle.com/snoop2head/pytorch-0-374-effb5-kfold-with-focalloss

    logits: (bs, class_n)
    labels: (bs, ) class indices (integer-valued; a float dtype is tolerated)
    top_k: k values to report; the largest one decides how many predictions are ranked

    Returns a list of scalar arrays, one per k, in the order of `top_k`.
    """
    max_k = max(top_k)
    batch_size = labels.shape[0]

    # (bs, max_k) class indices sorted by descending logit, transposed to (max_k, bs)
    _, preds = jax.lax.top_k(logits, k=max_k)
    preds = preds.T
    # broadcast labels over the rank axis (no explicit repeat needed); cast the boolean
    # hits to float explicitly instead of relying on `* 1.` precedence
    correct = (preds == labels[None, :]).astype(jnp.float32)

    # top-k predictions are distinct, so each sample contributes at most one hit in correct[:k]
    return [correct[:k].sum() * 100. / batch_size for k in top_k]


# loss_fn(logits, one_hot_labels) -> per-example cross-entropy (train_step takes the mean)
loss_fn = optax.softmax_cross_entropy
# eval_fn(logits, labels) -> [top-1 accuracy %, top-3 accuracy %]
eval_fn = partial(topk_accuracy, top_k=(1, 3))

### **5. Optimizer and TrainState**

In [10]:
class TrainState(train_state.TrainState):
    """flax TrainState extended with BatchNorm statistics and the loss/metric functions used by the step functions."""
    # the 'batch_stats' collection; a regular pytree field, so it is replicated and passed through pmap like params
    batch_stats: FrozenDict
    # pytree_node=False: stored as static metadata (plain Python callables), not as traced/replicated arrays
    loss_fn: Callable = flax.struct.field(pytree_node=False)
    eval_fn: Callable = flax.struct.field(pytree_node=False)


def create_mask(params, label_fn):
    """
    Build an optax.multi_transform label tree for `params`.

    Walks the nested mapping `params`. A key for which `label_fn(key)` is True is labelled
    'zero' (covering its whole subtree as a prefix label). Other nested mappings are recursed
    into, and remaining leaves are labelled 'adam'.

    Args:
        params: nested mapping of parameters (FrozenDict or plain dict).
        label_fn: predicate applied to each individual key (not to the full path).

    Returns:
        FrozenDict mirroring `params` down to the labelled nodes.

    reference:
    - https://colab.research.google.com/drive/16wcmLt0AIKzMmLPrliuBMfmBvM8VEc4p#scrollTo=TqDvTL_tIQCH
    """
    from collections.abc import Mapping

    def _label(tree):
        labels = {}
        for key, value in tree.items():
            if label_fn(key):
                labels[key] = 'zero'
            elif isinstance(value, Mapping):
                # recurse into any mapping (FrozenDict *or* plain dict); checking only FrozenDict
                # mislabelled plain-dict subtrees as leaves, so the label tree no longer matched params
                labels[key] = _label(value)
            else:
                labels[key] = 'adam'
        return labels

    return frozen_dict.freeze(_label(params))


def zero_grads():
    """
    Stateless optax transformation that replaces every update with zeros.

    Used with optax.multi_transform to freeze the parameters it is assigned to.

    reference:
    - https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    """
    def init_fn(_):
        # nothing to track: the transformation is stateless
        return ()

    def update_fn(updates, state, params=None):
        # jax.tree_util.tree_map rather than the deprecated (and later removed) jax.tree_map alias
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)

In [11]:
# AdamW for the trainable parameters
adamw = optax.adamw(
    learning_rate=LEARNING_RATE,
    b1=0.9,
    b2=0.999,
    eps=1e-6,
    weight_decay=1e-2,
)

# FREEZE_BACKBONE: params subtrees whose key starts with "backbone" go to zero_grads()
# (never updated); everything else goes through AdamW. Otherwise AdamW updates all params.
optimizer = (
    optax.multi_transform(
        {'adam': adamw, 'zero': zero_grads()},
        create_mask(variables['params'], lambda key: key.startswith('backbone')),
    )
    if FREEZE_BACKBONE
    else adamw
)

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=optimizer,
    loss_fn=loss_fn,
    eval_fn=eval_fn,
)


You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


### **6. Training Loop**

In [12]:
def train_step(state: TrainState, batch, labels, dropout_rng) -> Tuple[TrainState, dict, jnp.ndarray]:
    """
    One optimisation step on this device's shard; meant to run under
    jax.pmap(..., axis_name='batch').

    Args:
        state: replicated TrainState (params, batch_stats, optimizer state, loss/eval fns).
        batch: this device's images — presumably (per_device_bs, H, W, 3) float32; confirm against caller.
        labels: this device's class indices (integer-valued; cast to int32 here).
        dropout_rng: this device's PRNG key for dropout.

    Returns:
        (new_state, metadata, new_dropout_rng); metadata holds the device-averaged 'loss'
        and 'accuracy' (the list returned by state.eval_fn).
    """
    # split so the key consumed by this step's dropout is never reused
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    # one_hot and the metrics expect integral class indices; callers may pass float labels
    int_labels = labels.astype(jnp.int32)

    # params as input because we differentiate wrt it
    def loss_function(params):
        # if you set state.params, then params can't be backpropagated through!
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # return mutated states if mutable is specified
        logits, new_model_state = state.apply_fn(
            variables, batch, train=True,
            mutable=['batch_stats'],
            rngs={'dropout': dropout_rng}
        )
        # logits: (BS, OUTPUT_N), one_hot: (BS, OUTPUT_N)
        one_hot = jax.nn.one_hot(int_labels, OUTPUT_N)
        loss = state.loss_fn(logits, one_hot).mean()
        return loss, (logits, new_model_state)

    # if you wanna vary lr per step
    #lr = learning_rate_fn(state.step)

    # backpropagation and update params & batch_stats
    grad_fn = jax.value_and_grad(loss_function, has_aux=True)
    (loss, (logits, new_model_state)), grads = grad_fn(state.params)
    # average gradients so every replica applies the same update
    grads = lax.pmean(grads, axis_name='batch')
    # each device updated its running BatchNorm stats from its own shard only; average them
    # so the replicas stay identical (otherwise unreplicate() keeps just device 0's stats)
    new_batch_stats = lax.pmean(new_model_state['batch_stats'], axis_name='batch')
    new_state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)

    # evaluation metrics
    accuracy = state.eval_fn(logits, int_labels)

    # store metadata, averaged across devices
    metadata = lax.pmean(
        {'loss': loss, 'accuracy': accuracy},
        axis_name='batch'
    )
    return new_state, metadata, new_dropout_rng


def val_step(state: "TrainState", batch, labels):
    """
    Evaluate one device's shard in inference mode; meant to run under
    jax.pmap(..., axis_name='batch').

    Returns state.eval_fn(logits, labels) averaged across devices, so that
    unreplicate() yields the metric over the whole batch (matching train_step's metadata).
    """
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    # train=False: dropout disabled and BatchNorm uses running averages; nothing is mutated
    logits = state.apply_fn(variables, batch, train=False)
    return lax.pmean(state.eval_fn(logits, labels), axis_name='batch')

In [13]:
# compile the step functions for all local devices; 'batch' is the axis their lax.pmean reduces over.
# train_step returns a fresh state, so the incoming state buffers can be donated to it.
parallel_train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
# val_step does NOT return the state: donating argument 0 here would invalidate the caller's
# `state` buffers after the first validation call, so no donation.
parallel_val_step = jax.pmap(val_step, axis_name='batch')

# required for parallelism: one copy of the state per local device
state = replicate(state)

# control randomness on dropout: one key per local device, advanced inside train_step
rng = jax.random.PRNGKey(0)
dropout_rng = jax.random.split(rng, jax.local_device_count())


You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


In [14]:
# number of full batches per epoch: the trailing remainder batch can't be split evenly across devices
iter_n = len(train_dataset)//BATCH_SIZE

for epoch_i in tqdm(range(EPOCH_N), desc=f"{EPOCH_N} epochs", position=0, leave=True):
    # training set
    train_loss, train_accuracy = [], []

    # @TODO: print out from tqdm can be hard to read, find an alternative
    with tqdm(total=iter_n, desc=f"{iter_n} iterations", leave=False) as progress_bar:
        for iter_i, (batch, labels) in enumerate(train_loader):

            # the unavoidable remainder batch can't be evenly distributed across devices
            if iter_i + 1 > iter_n:
                break
            if (epoch_i == 0) and (FIRST_N_ITER is not None) and (iter_i >= FIRST_N_ITER):
                print(f"Early stop at iteration: {iter_i+1}")
                break

            # shard to enable parallelism: split the leading batch axis across local devices
            batch, labels = shard(batch), shard(labels)
            batch = jnp.array(batch, dtype=jnp.float32)
            # labels are class indices, so keep them integral (not float32)
            labels = jnp.array(labels, dtype=jnp.int32)

            # backprop and update param & batch stats
            state, train_metadata, dropout_rng = parallel_train_step(state, batch, labels, dropout_rng)
            train_metadata = unreplicate(train_metadata)

            # update train statistics
            _train_loss, _train_top1_acc, _train_top3_acc = map(float, [train_metadata['loss'], *train_metadata['accuracy']])
            train_loss.append(_train_loss)
            train_accuracy.append(_train_top1_acc)
            progress_bar.update(1)

    if not train_loss:
        # e.g. FIRST_N_ITER == 0 or fewer samples than BATCH_SIZE: nothing to average
        print(f"[{epoch_i+1}/{EPOCH_N}] no training iterations ran")
        continue

    avg_train_loss = sum(train_loss)/len(train_loss)
    avg_train_acc = sum(train_accuracy)/len(train_accuracy)
    print(f"[{epoch_i+1}/{EPOCH_N}] Train Loss: {avg_train_loss:.03} | Train Accuracy: {avg_train_acc:.03}")

    # validation set
    pass

20 epochs:   0%|          | 0/20 [00:00<?, ?it/s]

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

49 iterations:   2%|▏         | 1/49 [00:24<19:30, 24.38s/it][A
49 iterations:   4%|▍         | 2/49 [00:24<07:58, 10.18s/it][A
49 iterations:   6%|▌         | 3/49 [00:24<04:19,  5.65s/it][A
49 iterations:   8%|▊         | 4/49 [00:25<02:38,  3.52s/it][A
49 iterations:  10%|█         | 5/49 [00:25<01:43,  2.35s/it][A
49 iterations:  12%|█▏        | 6/49 [00:31<02:38,  3.68s/it][A
49 iterations:  14%|█▍        | 7/49 [00:32<01:51,  2.65s/it][A
49 iterations:  16%|█▋        | 8/49 [00:39<02:45,  4.05s/it][A
49 iterations:  18%|█▊        | 9/49 [00:39<01:59,  2.98s/it][A
49 iterations:  20%|██        | 10/49 [00:46<02:44,  4.21s/it][A
49 iterations:  22%|██▏       | 11/49 [00:47<01:56,  3.07s/it][A
49 iterations:  24%|██▍       | 12/49 [00:

[1/20] Train Loss: 9.5 | Train Accuracy: 0.68



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:03,  2.57s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:57,  1.23s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:14,  1.63s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:49,  1.10s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:07,  1.53s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:47,  1.11s/it][A
49 iterations:  14%|█▍        | 7/49 [00:10<01:02,  1.49s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:45,  1.10s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:55,  1.39s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:40,  1.05s/it][A
49 iterations:  22%|██▏       | 11/49 [00:15<00:57,  1.51s/it][A
49 iterations:  24%|██▍       | 12/49 [00:15<00:42,  1.14s/it][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:51,  1.44s/it][A
49 iterations:  29%|██▊       | 14/49 [00:18<00:40,  1.17s/it][A
49 iterations:  31%|███    

[2/20] Train Loss: 7.42 | Train Accuracy: 2.37



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:10,  2.72s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:00,  1.29s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:17,  1.68s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:50,  1.13s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:04,  1.47s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:45,  1.06s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:58,  1.39s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:42,  1.04s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:52,  1.32s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:41,  1.06s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:51,  1.35s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:37,  1.02s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:48,  1.35s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:36,  1.04s/it][A
49 iterations:  31%|███    

[3/20] Train Loss: 6.38 | Train Accuracy: 5.23



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:55,  2.40s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:54,  1.15s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:12,  1.57s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:47,  1.06s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:01,  1.39s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:43,  1.01s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:59,  1.41s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:43,  1.05s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:52,  1.31s/it][A
49 iterations:  20%|██        | 10/49 [00:11<00:38,  1.01it/s][A
49 iterations:  22%|██▏       | 11/49 [00:13<00:48,  1.27s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:36,  1.03it/s][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:45,  1.25s/it][A
49 iterations:  29%|██▊       | 14/49 [00:16<00:33,  1.04it/s][A
49 iterations:  31%|███    

[4/20] Train Loss: 5.4 | Train Accuracy: 11.5



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:02,  2.55s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:57,  1.22s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:16,  1.67s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:50,  1.12s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:01,  1.40s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:43,  1.02s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:55,  1.33s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:42,  1.03s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:51,  1.29s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.02s/it][A
49 iterations:  22%|██▏       | 11/49 [00:13<00:47,  1.26s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:37,  1.01s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:48,  1.35s/it][A
49 iterations:  29%|██▊       | 14/49 [00:16<00:37,  1.08s/it][A
49 iterations:  31%|███    

[5/20] Train Loss: 4.57 | Train Accuracy: 20.8



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:16,  2.84s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:02,  1.34s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:16,  1.65s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:49,  1.11s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:10,  1.61s/it][A
49 iterations:  12%|█▏        | 6/49 [00:08<00:49,  1.16s/it][A
49 iterations:  14%|█▍        | 7/49 [00:10<01:01,  1.46s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:44,  1.09s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:57,  1.43s/it][A
49 iterations:  20%|██        | 10/49 [00:13<00:42,  1.08s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:50,  1.33s/it][A
49 iterations:  24%|██▍       | 12/49 [00:15<00:37,  1.00s/it][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:46,  1.29s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:34,  1.02it/s][A
49 iterations:  31%|███    

[6/20] Train Loss: 3.98 | Train Accuracy: 28.1



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:58,  2.47s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:59,  1.27s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:13,  1.61s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:48,  1.08s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:01,  1.41s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:45,  1.05s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:59,  1.41s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:42,  1.05s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:53,  1.34s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.02s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:52,  1.38s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:41,  1.11s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:49,  1.37s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:36,  1.05s/it][A
49 iterations:  31%|███    

[7/20] Train Loss: 3.59 | Train Accuracy: 33.6



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:01,  2.53s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:56,  1.21s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:17,  1.69s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:52,  1.17s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:07,  1.54s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:47,  1.11s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:59,  1.41s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:43,  1.05s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:54,  1.36s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.02s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:52,  1.38s/it][A
49 iterations:  24%|██▍       | 12/49 [00:15<00:39,  1.07s/it][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:47,  1.32s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:35,  1.01s/it][A
49 iterations:  31%|███    

[8/20] Train Loss: 3.35 | Train Accuracy: 36.8



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:59,  2.50s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:56,  1.19s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:10,  1.53s/it][A
49 iterations:   8%|▊         | 4/49 [00:04<00:46,  1.03s/it][A
49 iterations:  10%|█         | 5/49 [00:06<00:59,  1.36s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:42,  1.00it/s][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:55,  1.31s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:42,  1.04s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:56,  1.42s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:41,  1.07s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:50,  1.32s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:37,  1.00s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:45,  1.28s/it][A
49 iterations:  29%|██▊       | 14/49 [00:16<00:34,  1.02it/s][A
49 iterations:  31%|███    

[9/20] Train Loss: 3.12 | Train Accuracy: 40.1



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:20,  2.92s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:06,  1.41s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:16,  1.65s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:50,  1.12s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:05,  1.50s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:46,  1.08s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:57,  1.36s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:41,  1.01s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:51,  1.30s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:38,  1.01it/s][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:49,  1.31s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:36,  1.01it/s][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:53,  1.48s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:39,  1.13s/it][A
49 iterations:  31%|███    

[10/20] Train Loss: 2.98 | Train Accuracy: 42.3



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:06,  2.64s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:58,  1.25s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:14,  1.61s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:48,  1.09s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:04,  1.46s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:47,  1.09s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:59,  1.42s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:43,  1.06s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:55,  1.39s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:40,  1.04s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:52,  1.38s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:39,  1.05s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:47,  1.33s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:35,  1.01s/it][A
49 iterations:  31%|███    

[11/20] Train Loss: 2.87 | Train Accuracy: 44.1



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:53,  2.37s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:53,  1.14s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:11,  1.56s/it][A
49 iterations:   8%|▊         | 4/49 [00:04<00:47,  1.06s/it][A
49 iterations:  10%|█         | 5/49 [00:06<01:01,  1.39s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:43,  1.02s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:58,  1.40s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:42,  1.04s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:53,  1.34s/it][A
49 iterations:  20%|██        | 10/49 [00:11<00:39,  1.01s/it][A
49 iterations:  22%|██▏       | 11/49 [00:13<00:49,  1.31s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:36,  1.00it/s][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:52,  1.45s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:38,  1.10s/it][A
49 iterations:  31%|███    

[12/20] Train Loss: 2.79 | Train Accuracy: 45.5



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:13,  2.79s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:02,  1.34s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:19,  1.74s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:52,  1.17s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:03,  1.45s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:45,  1.05s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:56,  1.34s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:41,  1.00s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:52,  1.32s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.00s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:48,  1.29s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:36,  1.02it/s][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:48,  1.35s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:35,  1.03s/it][A
49 iterations:  31%|███    

[13/20] Train Loss: 2.7 | Train Accuracy: 46.5



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:03,  2.57s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:57,  1.23s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:11,  1.55s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:49,  1.10s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:07,  1.54s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:47,  1.11s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:57,  1.37s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:42,  1.02s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:55,  1.38s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:40,  1.04s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:50,  1.33s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:37,  1.01s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:46,  1.28s/it][A
49 iterations:  29%|██▊       | 14/49 [00:16<00:34,  1.02it/s][A
49 iterations:  31%|███    

[14/20] Train Loss: 2.66 | Train Accuracy: 47.4



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:55,  2.41s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<01:01,  1.30s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:16,  1.67s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:50,  1.12s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:04,  1.47s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:48,  1.12s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:57,  1.37s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:41,  1.02s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:51,  1.29s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:38,  1.01it/s][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:54,  1.43s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:40,  1.08s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:47,  1.33s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:35,  1.01s/it][A
49 iterations:  31%|███    

[15/20] Train Loss: 2.61 | Train Accuracy: 48.1



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:57,  2.45s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:55,  1.18s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:15,  1.65s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:49,  1.11s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:01,  1.40s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:43,  1.02s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:56,  1.34s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:43,  1.07s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:56,  1.41s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:41,  1.06s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:50,  1.33s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:38,  1.04s/it][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:52,  1.45s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:38,  1.10s/it][A
49 iterations:  31%|███    

[16/20] Train Loss: 2.55 | Train Accuracy: 48.8



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<01:54,  2.39s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:54,  1.16s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:09,  1.51s/it][A
49 iterations:   8%|▊         | 4/49 [00:04<00:46,  1.02s/it][A
49 iterations:  10%|█         | 5/49 [00:06<00:58,  1.34s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:42,  1.02it/s][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:54,  1.30s/it][A
49 iterations:  16%|█▋        | 8/49 [00:09<00:41,  1.00s/it][A
49 iterations:  18%|█▊        | 9/49 [00:11<00:54,  1.36s/it][A
49 iterations:  20%|██        | 10/49 [00:11<00:40,  1.03s/it][A
49 iterations:  22%|██▏       | 11/49 [00:13<00:49,  1.29s/it][A
49 iterations:  24%|██▍       | 12/49 [00:13<00:36,  1.02it/s][A
49 iterations:  27%|██▋       | 13/49 [00:15<00:45,  1.27s/it][A
49 iterations:  29%|██▊       | 14/49 [00:16<00:34,  1.01it/s][A
49 iterations:  31%|███    

[17/20] Train Loss: 2.49 | Train Accuracy: 49.8



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:23,  2.99s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:05,  1.40s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:17,  1.68s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:51,  1.13s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:06,  1.52s/it][A
49 iterations:  12%|█▏        | 6/49 [00:08<00:47,  1.10s/it][A
49 iterations:  14%|█▍        | 7/49 [00:10<00:59,  1.42s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:43,  1.07s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:52,  1.32s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.00s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:48,  1.28s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:36,  1.02it/s][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:45,  1.26s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:35,  1.01s/it][A
49 iterations:  31%|███    

[18/20] Train Loss: 2.47 | Train Accuracy: 50.3



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:13,  2.77s/it][A
49 iterations:   4%|▍         | 2/49 [00:03<01:01,  1.31s/it][A
49 iterations:   6%|▌         | 3/49 [00:04<01:12,  1.58s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:48,  1.09s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:10,  1.61s/it][A
49 iterations:  12%|█▏        | 6/49 [00:08<00:49,  1.16s/it][A
49 iterations:  14%|█▍        | 7/49 [00:09<00:58,  1.39s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:42,  1.04s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:52,  1.31s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:40,  1.03s/it][A
49 iterations:  22%|██▏       | 11/49 [00:14<00:51,  1.35s/it][A
49 iterations:  24%|██▍       | 12/49 [00:14<00:37,  1.03s/it][A
49 iterations:  27%|██▋       | 13/49 [00:16<00:46,  1.29s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:34,  1.01it/s][A
49 iterations:  31%|███    

[19/20] Train Loss: 2.44 | Train Accuracy: 50.7



49 iterations:   0%|          | 0/49 [00:00<?, ?it/s][A
49 iterations:   2%|▏         | 1/49 [00:02<02:01,  2.54s/it][A
49 iterations:   4%|▍         | 2/49 [00:02<00:57,  1.22s/it][A
49 iterations:   6%|▌         | 3/49 [00:05<01:22,  1.80s/it][A
49 iterations:   8%|▊         | 4/49 [00:05<00:53,  1.20s/it][A
49 iterations:  10%|█         | 5/49 [00:07<01:03,  1.45s/it][A
49 iterations:  12%|█▏        | 6/49 [00:07<00:45,  1.06s/it][A
49 iterations:  14%|█▍        | 7/49 [00:10<01:00,  1.44s/it][A
49 iterations:  16%|█▋        | 8/49 [00:10<00:43,  1.07s/it][A
49 iterations:  18%|█▊        | 9/49 [00:12<00:52,  1.32s/it][A
49 iterations:  20%|██        | 10/49 [00:12<00:39,  1.02s/it][A
49 iterations:  22%|██▏       | 11/49 [00:15<00:56,  1.50s/it][A
49 iterations:  24%|██▍       | 12/49 [00:15<00:41,  1.13s/it][A
49 iterations:  27%|██▋       | 13/49 [00:17<00:48,  1.36s/it][A
49 iterations:  29%|██▊       | 14/49 [00:17<00:36,  1.03s/it][A
49 iterations:  31%|███    

[20/20] Train Loss: 2.41 | Train Accuracy: 51.2





In [15]:
# Collapse the per-device copies into a single host-side train state so it
# can be serialised (the submission section replicates it again for pmap).
# NOTE(review): this cell is not idempotent — running it a second time would
# call `unreplicate` on an already-unreplicated state.
state = unreplicate(state)

os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
checkpoints.save_checkpoint(
    ckpt_dir=MODEL_SAVE_DIR,
    target=state,
    step=int(state.step),  # tag the checkpoint with the train state's step counter
    overwrite=True,
)
# To reload later:
#   restored_state = checkpoints.restore_checkpoint(ckpt_dir=MODEL_SAVE_DIR, target=state)
print(f"Model checkpoint saved: {MODEL_SAVE_DIR}")

Model checkpoint saved: resnet_checkpoints


### 5. Submission

In [16]:
def infer_step(state: TrainState, batch):
    """Return the indices of the five highest-scoring classes per sample.

    The model is applied in evaluation mode (``train=False``) with the state's
    parameters and batch-norm statistics.

    Args:
        state: train state exposing ``apply_fn``, ``params`` and ``batch_stats``.
        batch: model input batch (assumed NHWC images — TODO confirm).

    Returns:
        Integer indices of the 5 largest logits along the last axis, ordered
        from highest to lowest score.
    """
    model_vars = dict(params=state.params, batch_stats=state.batch_stats)
    scores = state.apply_fn(model_vars, batch, train=False)
    # lax.top_k returns (values, indices); only the indices are needed.
    top_indices = jax.lax.top_k(scores, 5)[1]
    return top_indices


# Load the submission template and attach a resolved image path per row
# (the training-set path helper is reused with the test image directory).
test_df = pd.read_csv(TEST_CSV)
test_df['file_path'] = test_df['image'].apply(
    partial(get_train_file_path, image_dir=TEST_DIR)
)
print(f"No. of Submission Entries: {len(test_df)}")

# Dataloader for submission. shuffle=False keeps batch order aligned with the
# rows of test_df, which is how predictions are written back later.
# NOTE(review): this reuses the training `transforms`; confirm it contains no
# random augmentation, otherwise test predictions are non-deterministic.
test_dataset = HappyWhaleDataset(
    test_df, transforms=transforms,
    mock_dataset=MOCK_DATASET,
    inference=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=False
)

# Multi-device inference step. The train state is deliberately NOT donated:
# the same replicated parameters are fed to every batch, so donating them
# would invalidate them after the first call. The integer top-k output could
# never reuse those buffers anyway (hence JAX's "donated buffers were not
# usable" warning when donation was requested).
parallel_infer_step = jax.pmap(infer_step, axis_name='batch')
# Put one copy of the (single-host) state on every local device for pmap.
state = replicate(state)

No. of Submission Entries: 27956


In [17]:
# Batch-process the test set. Full batches are split across the local devices
# with pmap; the trailing partial batch (which may not shard evenly) runs on a
# single device instead.
top5_idxs = []
iter_n = len(test_dataset) // BATCH_SIZE  # number of full batches
# Single-device copy of the weights for the tail batch. `state` itself stays
# replicated, so this cell can be re-run without any re-replication.
single_state = unreplicate(state)
for iter_i, batch in enumerate(tqdm(test_loader)):  # tqdm picks up len(test_loader)
    # Convert on the host once so both code paths see identical float32 input.
    batch = np.asarray(batch, dtype=np.float32)
    if iter_i < iter_n:
        batch_top5 = parallel_infer_step(state, shard(batch))
        # (devices, per_device, k) -> (devices * per_device, k)
        batch_top5 = batch_top5.reshape(-1, batch_top5.shape[-1])
    else:
        batch_top5 = infer_step(single_state, batch)
    top5_idxs.append(batch_top5)

top5_idxs = jnp.concatenate(top5_idxs)
# Every test image must receive exactly one prediction row, in order.
assert top5_idxs.shape[0] == test_df.shape[0]

See an explanation at https://jax.readthedocs.io/en/latest/notebooks/faq.html#buffer-donation.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-p

In [18]:
# Label decoding: map each row of top-5 class indices back to the original
# label strings and join them into the space-separated `predictions` field.
def prediction_parser(idxs) -> str:
    """Decode one row of class indices into a space-separated label string."""
    return ' '.join(label_encoder.inverse_transform(idxs)).strip()


# Copy the predictions to host memory once instead of slicing the device
# array row by row.
top5_idxs_host = np.asarray(top5_idxs)
top5_labels = [prediction_parser(row) for row in top5_idxs_host]
test_df['predictions'] = top5_labels

# Write the submission without the helper `file_path` column. Rebinding
# (instead of `drop(..., inplace=True)`) leaves `file_path` intact on the
# DataFrame object passed to the test dataset (presumably still referenced
# there — TODO confirm), and errors='ignore' makes re-running this cell safe
# instead of raising KeyError.
test_df = test_df.drop(columns=['file_path'], errors='ignore')
test_df.to_csv('submission.csv', index=False)
print("Submission csv written!!")


You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


Submission csv written!!
