Reference: [Sequence Of Boxes](https://docs.google.com/document/d/14vPryt1JFE6Si89VamwMOi-RASeXmMXpo0_BjT_hn-U/edit?usp=sharing)

Features:
- Incorporates LR decay from the [Segmenter](https://arxiv.org/abs/2105.05633) paper.
- Attention and MLP do not use Bias.
- Uses mIoU.
- Input Augmentations
  - Mean Subtraction
  - Random Horizontal Flip. Both images and masks should be identically augmented.
  - Random Resize.
  - Rotation and Scaling Augmentations.
- Reloads model checkpoints for continued training.
- Uses Stochastic Depth for regularization. Reference: https://keras.io/examples/vision/cct/
- Ignores background pixels.


# Running Instructions
Typically, we develop notebooks on an Apple M1 local machine. When importing this notebook to platforms such as Colab or Kaggle, the following adaptations are required:
- Enable package installations in the [Import Modules](#import-modules) section.
- [Initialize WANDB](#initialize-wandb).
- Adjust [dataset splits](#download).

# Instructions to Reload Last Run's Weights
- [ ] Adjust `EPOCHS` and `EPOCHS_DONE`.
- [ ] Ensure that the weights are loaded from the previous run.

# Import Modules

Note: This section requires changes to adapt to the target environments. Please refer to the [instructions](#running-instructions).

In [None]:
# ! pip install -q git+https://github.com/EfficientDL/codelab_utils.git

import pickle
import math
import operator
import os
import wandb
import torch

import codelab_utils.mpl_styles as mpl_styles
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import albumentations as A
import matplotlib.pyplot as plt


from datasets import load_dataset, load_dataset_builder
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint
from kaggle_secrets import UserSecretsClient # Kaggle
from matplotlib import patches as patches
from itertools import accumulate
from functools import reduce
from collections import Counter
from pathlib import Path

# Global debug switch for this notebook.
DEBUG = False

# Plot styling shared by every figure below.
mpl_styles.set_default_styles()
plt.rcParams['font.family'] = 'Poppins'

# Notebook name that wandb reads from the environment.
os.environ['WANDB_NOTEBOOK_NAME'] = 'PT.Segmenter'

# Initialize WANDB

Note: This section requires changes to adapt to the target environments. Please refer to the [instructions](#running-instructions).

In [None]:
# user_secrets = UserSecretsClient()
# wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
# os.environ['WANDB_API_KEY'] = wandb_api_key

# Device Selection

In [None]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

print(f'Device: {DEVICE}')

# Data Preparation

* Download scene_parse_150, a semantic segmentation dataset.
* Dataset class labels are available at [CSAILVision](https://github.com/CSAILVision/sceneparsing/blob/master/objectInfo150.csv)

## Hyperparameters

In [None]:
# Kaggle setting. A smaller value of 2 is the alternative (presumably for
# local runs — confirm) and can be substituted here.
BATCH_SIZE = 8

## Procurement

Note: This section requires changes to adapt to the target environments. Please refer to the [instructions](#running-instructions).

In [None]:
DS_NAME = "scene_parse_150"

# Hugging Face split specs. "<split>[:5]" takes only the first five examples,
# which keeps quick iterations cheap; switch to the full splits below on Kaggle.
DS_SPLITS = {
    'train': 'train[:5]',
    'validation': 'validation[:5]',
    'test': 'test[:5]',
}
# DS_SPLITS = {
#     'train': 'train',
#     'validation': 'validation',
#     'test': 'test',
# }  # Kaggle

def load_split(ds_name, split, with_annotation=True):
    """Load one split of a Hugging Face dataset, formatted as torch tensors on ``DEVICE``.

    Args:
        ds_name: Hugging Face dataset identifier, e.g. ``"scene_parse_150"``.
        split: Split specification passed to ``datasets.load_dataset``;
            slice syntax such as ``"train[:5]"`` is accepted.
        with_annotation: Not used by this function; accepted so existing
            callers that pass it keep working.

    Returns:
        A ``(dataset, num_items)`` tuple.
    """
    # Use the module-level ``DEVICE`` chosen in the Device Selection cell.
    # The previous ``device`` name was never defined, so every call raised
    # NameError before any data was returned.
    ds = load_dataset(ds_name, split=split).with_format("torch", device=DEVICE)
    # NOTE(review): the Kaggle variant below omits
    # `.with_format("torch", device=DEVICE)`; append it when switching.
    # ds = load_dataset(
    #     ds_name,
    #     split=split,
    #     cache_dir="/kaggle/input/scene-parse-150-hf/huggingface"
    # ) # Kaggle
    num_items = len(ds)
    print(f"Split: {split} Items: {num_items} Features: {ds.features.keys()}")

    return ds, num_items

# Report every split the dataset publishes, with its example count.
ds_builder = load_dataset_builder(DS_NAME)
splits = ds_builder.info.splits
split_infos = [(split_name, splits[split_name].num_examples) for split_name in splits.keys()]

print(f"Available Splits: {split_infos}")

# Load the configured (possibly sliced) splits; the test split is loaded
# with with_annotation=False.
train_ds, train_count = load_split(DS_NAME, DS_SPLITS['train'])
val_ds, val_count = load_split(DS_NAME, DS_SPLITS['validation'])
test_ds, test_count = load_split(
    DS_NAME,
    DS_SPLITS['test'],
    with_annotation=False,
)

def ds_shape(ds, size, with_annotation=True, name='Training',
             image_key='image', mask_key='annotation'):
    """Print the size of a dataset split and the shapes of its first item.

    Items may be ``(image, mask)`` tuples (or bare images when
    ``with_annotation`` is False), or dict records such as those yielded by
    iterating a Hugging Face ``datasets.Dataset``. Tuple-unpacking a dict
    would iterate its keys rather than its tensors, so dict items are read
    via ``image_key`` / ``mask_key`` instead.

    Args:
        ds: Any iterable of dataset items.
        size: Number of items; printed as given, not recomputed.
        with_annotation: Whether to report a mask shape as well.
        name: Label for the printed header, e.g. ``'Validation'``.
        image_key: Key of the image in dict items.
        mask_key: Key of the mask in dict items. The default
            ``'annotation'`` is assumed to be scene_parse_150's mask feature
            name — confirm against the ``Features`` printed by ``load_split``.
    """
    from collections.abc import Mapping

    print(f'\n{name} Set')
    print('------------------')
    print(f'Size: {size}')

    missing = object()
    first = next(iter(ds), missing)
    if first is missing:
        # Empty split: nothing to inspect (previously raised StopIteration).
        print('No items to inspect.')
        return

    is_record = isinstance(first, Mapping)
    if with_annotation:
        image, mask = (first[image_key], first[mask_key]) if is_record else first
        print(f'Image Shape: {image.shape} Mask Shape: {mask.shape}')
    else:
        image = first[image_key] if is_record else first
        print(f'Image Shape: {image.shape}')

# Summarize each split under its own header. Without an explicit ``name``,
# the validation and test splits were also labelled "Training".
ds_shape(train_ds, train_count, name='Training')
ds_shape(val_ds, val_count, name='Validation')
ds_shape(test_ds, test_count, with_annotation=False, name='Test')

## Exploration

## Visualization

### Visualization Functions

In [None]:
def visualize_item(ax, item):
    """Render ``item`` on the axes ``ax`` via ``imshow`` and hide the axis decorations."""
    draw = ax.imshow
    draw(item)
    ax.set_axis_off()

def show_related_images(*relatives, batch_to_rows=True, title='', size=1.5):
    """Plot several aligned image batches (e.g. images and their masks) in one grid.

    Args:
        *relatives: Arrays/tensors indexed in lockstep. If the first one has
            rank 3 every relative is treated as a single item and promoted to
            a batch of one; otherwise the leading axis is the batch axis.
            Items are passed straight to ``imshow`` (presumably channels-last
            ``(H, W, C)`` — confirm).
        batch_to_rows: If True, each batch element is a row and each relative
            a column; if False, the grid is transposed.
        title: Figure title.
        size: Edge length, in inches, of one grid cell.

    Raises:
        ValueError: If no relatives are given.
    """
    if not relatives:
        raise ValueError('show_related_images() needs at least one input')

    num_relatives = len(relatives)
    if len(relatives[0].shape) == 3:
        # Promote single items to a batch of one. ``x[None, ...]`` works for
        # NumPy arrays, TensorFlow tensors and torch tensors alike.
        relatives = [x[None, ...] for x in relatives]

    batch_size = relatives[0].shape[0]
    if batch_to_rows:
        n_rows, n_cols = batch_size, num_relatives
    else:
        n_rows, n_cols = num_relatives, batch_size

    # squeeze=False always returns a 2-D array of Axes; without it a 1x1 grid
    # yields a bare Axes and ``axes.ravel()`` fails.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * size, n_rows * size), squeeze=False
    )

    if batch_to_rows:
        fig.supylabel('Batch')
        fig.supxlabel('Relatives')
    else:
        fig.supxlabel('Batch')
        fig.supylabel('Relatives')

    # ``ravel`` walks the grid row-major, so divmod recovers (row, col).
    for item_id, ax in enumerate(axes.ravel()):
        row, col = divmod(item_id, n_cols)
        item = relatives[col][row] if batch_to_rows else relatives[row][col]
        visualize_item(ax, item)

    fig.suptitle(title)
    fig.tight_layout()
    

# Smoke test on random data: a batch of RGB noise images plus single-channel
# integer masks with labels drawn from [0, classes).
batch_size = 3
img_size = 8
classes = 3

image = tf.random.normal(shape=(batch_size, img_size, img_size, 3))
mask = tf.random.uniform(
    shape=(batch_size, img_size, img_size, 1),
    maxval=classes,
    dtype=tf.int32,
)

show_related_images(image, mask, batch_to_rows=False)

### Single Input

### Multiple Inputs