# Ch. 1: Selectively Sampled Training

## 0. General Imports

In [None]:
import os
import sys
!pip install wget
import wget

import numpy as np
import pandas as pd
import random

import tensorflow as tf
from tensorflow import keras


import matplotlib.pyplot as plt
import cv2


from tqdm.notebook import tqdm

## 1. Loading the Dataset

In [None]:
# '''
# Borrowed from this user:
# https://stackoverflow.com/questions/66288078/any-easy-way-to-get-imagenet-dataset-for-training-custom-model-in-tensorflow

# Images are stored in the following manner:
# tiny-imagenet-200/train/<class>/images/<image_ID>.jpeg
# '''

# from zipfile import ZipFile

# url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
# tiny_imgdataset = wget.download('http://cs231n.stanford.edu/tiny-imagenet-200.zip', out = os.getcwd())
# for file in tqdm(os.listdir(os.getcwd())):
#     if file.endswith(".zip"):
#         zip_ = ZipFile(file)
#         zip_.extractall()
#     else:
#         print("Not found.")

In [None]:
class TinyImageNetData(tf.keras.utils.Sequence):
    '''
    Keras ``Sequence`` serving Tiny ImageNet training batches.

    ``data[i]`` returns training batch ``i`` as ``(images, labels)``, where
    ``images`` are the resized images scaled by 1/255 and ``labels`` has shape
    ``(batch_size, 1)`` with ordinal class ids. When a model is attached (via
    the ``model`` argument or by assigning ``self.model``), it is evaluated on
    the held-out validation split at the end of every epoch and the results
    are appended to ``self.valid_loss`` and ``self.valid_acc``.
    '''

    def __init__(self,
                 directory,
                 batch_size,
                 model = None,
                 train_split = 0.8,
                 img_size = (256, 256),
                 random_seed = 42):
        '''
        Assumes directory is structured like
        <class>/images/<image_ID>

        Args:
            directory: root folder with one sub-folder per class.
            batch_size: number of images per batch.
            model: optional Keras model used by ``validate``; it can also be
                attached later by assigning ``self.model``.
            train_split: fraction of the shuffled files used for training;
                the remaining files form the validation split.
            img_size: target size passed to ``cv2.resize`` (width, height).
            random_seed: seed for the file shuffle, so instances built from
                the same directory share the same order and split.

        Raises:
            ValueError: if no image files are found.
        '''
        super().__init__()
        self.model = model

        self.image_files = []
        self.labels = []
        self.id2class = {}

        self.valid_loss = []
        self.valid_acc = []

        # Sort so class ids and file order do not depend on the filesystem's
        # listdir order: difficulty scores computed on one instance are later
        # matched to other instances by position.
        class_names = sorted(
            name for name in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, name))
        )
        for ind, img_class in enumerate(class_names):
            self.id2class[ind] = img_class
            img_class_folder = os.path.join(directory, img_class, 'images')
            for img_name in sorted(os.listdir(img_class_folder)):
                self.image_files.append(os.path.join(img_class_folder, img_name))
                self.labels.append(ind)  # we can just use an ordinal encoding
        if not self.image_files:
            raise ValueError(f'no images found under {directory!r}')

        # Shuffle with a private RNG: the same seed gives the same permutation,
        # but the global `random` state is left untouched.
        joined = list(zip(self.image_files, self.labels))
        random.Random(random_seed).shuffle(joined)
        self.image_files, self.labels = zip(*joined)

        # Train/validation split: validation is the held-out remainder.
        split_ind = int(train_split * len(self.image_files))
        self.X_train = self.image_files[:split_ind]
        self.y_train = self.labels[:split_ind]
        self.X_valid = self.image_files[split_ind:]
        self.y_valid = self.labels[split_ind:]

        # calculate and store stats (a partial trailing batch is dropped)
        self.batch_size = batch_size
        self.num_samples = len(self.image_files)
        self.train_size = len(self.X_train)
        self.train_batches = self.train_size // batch_size
        self.valid_size = len(self.X_valid)
        self.valid_batches = self.valid_size // batch_size
        self.img_size = img_size

    def __len__(self):
        '''Number of full training batches per epoch.'''
        return self.train_batches

    def _load_image(self, path):
        '''Read one image, resize it to ``self.img_size`` and force 3 channels.'''
        img = plt.imread(path)
        img = cv2.resize(img, self.img_size)
        if img.ndim == 2:  # grayscale, no color depth
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # cvt to color
        return img

    def _load_batch(self, files, labels, idx):
        '''Load batch ``idx`` of ``files``/``labels`` as (images / 255, labels[:, None]).'''
        start = idx * self.batch_size
        stop = start + self.batch_size
        # Out-of-range batches raise IndexError rather than returning a short batch.
        if idx < 0 or stop > len(files):
            raise IndexError(f'batch index {idx} out of range')
        imgs = np.stack([self._load_image(path) for path in files[start:stop]])
        batch_labels = np.asarray(labels[start:stop])
        return imgs.astype(np.float32) / 255, np.expand_dims(batch_labels, 1)

    def __getitem__(self, idx):
        '''Return training batch ``idx`` as (images / 255, labels of shape (batch_size, 1)).'''
        return self._load_batch(self.X_train, self.y_train, idx)

    def get_valid_item(self, idx):
        '''Return validation batch ``idx`` in the same format as ``__getitem__``.'''
        return self._load_batch(self.X_valid, self.y_valid, idx)

    def validate(self):
        '''
        Evaluate ``self.model`` on every full validation batch.

        Returns:
            (mean sparse categorical cross-entropy, mean accuracy), averaged
            over ``self.valid_batches`` batches, as Python floats.

        Raises:
            ValueError: if no model is attached or the validation split holds
                fewer than ``batch_size`` samples.
        '''
        if self.model is None:
            raise ValueError('attach a model (self.model = ...) before validating')
        if self.valid_batches == 0:
            raise ValueError('validation split is smaller than one batch')
        scce = keras.losses.SparseCategoricalCrossentropy()
        scce_sum, acc_sum = 0.0, 0.0
        for idx in tqdm(range(self.valid_batches)):
            imgs, labels = self.get_valid_item(idx)
            # predict_on_batch avoids a Keras progress bar per batch
            pred = np.asarray(self.model.predict_on_batch(imgs))
            scce_sum += float(scce(labels, pred))
            acc_sum += float(np.mean(np.argmax(pred, axis=1) == labels[:, 0]))
        return scce_sum / self.valid_batches, acc_sum / self.valid_batches

    def on_epoch_end(self):
        '''Keras hook: record and print validation metrics if a model is attached.'''
        if self.model is None:
            return
        scce, acc = self.validate()
        self.valid_loss.append(scce)
        self.valid_acc.append(acc)
        print(f'\tSCCE: {np.round(scce,2)} | Acc: {np.round(acc, 2)}')

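Visualize one batch of training data. No model is attached yet; a model is attached before training so that validation can run at the end of each epoch.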

In [None]:
# Baseline dataset: 32-image batches resized to 256x256; no model attached yet.
data = TinyImageNetData(
    directory='tiny-imagenet-200/train',
    batch_size=32,
    img_size=(256, 256),
)

In [None]:
# Fetch training batch 8: `a` holds the images, `b` the (batch, 1) labels.
(a, b) = data[8]

In [None]:
# Expected (batch_size, height, width, 3) for RGB inputs.
np.shape(a)

In [None]:
# Show every image of the batch in a grid with 8 columns. The number of rows
# comes from the batch length: a fixed 8x8 grid indexes past the end of a
# 32-image batch.
n_cols = 8
n_rows = (len(a) + n_cols - 1) // n_cols
plt.figure(figsize=(10, 10 * n_rows / n_cols), dpi=400)
for i, img in enumerate(a):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(img)
    plt.title(data.id2class[int(b[i, 0])], fontsize=4)  # class folder name
    plt.axis('off')
plt.show()

## 2. Training a Baseline Model

In [None]:
# Baseline network: InceptionV3 trained from scratch (weights=None) with a
# 200-way softmax classification head.
base_model = keras.applications.InceptionV3(
    input_shape=(256, 256, 3),
    include_top=True,
    classes=200,
    classifier_activation="softmax",
    weights=None,
)

In [None]:
# Give the dataset a handle on the baseline so its on_epoch_end hook can
# evaluate it on the validation split after each epoch.
data.model = base_model

In [None]:
# Labels are integer class ids, hence the sparse cross-entropy loss.
base_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# The History callback collects per-epoch training loss/accuracy; validation
# metrics are appended to `data` by its on_epoch_end hook.
base_history = keras.callbacks.History()
base_model.fit(data, callbacks=[base_history], epochs=5)

In [None]:
# Baseline training vs. validation loss per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(base_history.history['loss'], label='Training')
plt.plot(data.valid_loss, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

In [None]:
# Baseline training vs. validation accuracy per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(base_history.history['accuracy'], label='Training')
plt.plot(data.valid_acc, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

## 3. Calculating Difficulty Scores

In [None]:
# Difficulty score of each training sample = per-sample cross-entropy of the
# baseline model, in `data.X_train` order. Only full batches are scored, so
# len(scores) == len(data) * data.batch_size.
scce = keras.losses.SparseCategoricalCrossentropy(reduction='none')
batch_losses = []
for batch_idx in tqdm(range(len(data))):
    imgs, labels = data[batch_idx]
    # predict_on_batch avoids a Keras progress bar for every batch
    pred = base_model.predict_on_batch(imgs)
    # keep float32: a float16 round-trip loses precision and creates ties
    # at the percentile threshold used for deterministic sampling
    batch_losses.append(np.asarray(scce(labels, pred), dtype=np.float32))
scores = np.concatenate(batch_losses) if batch_losses else np.empty(0, dtype=np.float32)

In [None]:
# The baseline is no longer needed. `data.model` still references it, so drop
# that reference too; otherwise `del` only removes the name and the model
# stays in memory.
data.model = None
del base_model

In [None]:
# Distribution of baseline per-sample losses (difficulty scores).
plt.figure(dpi=400, figsize=(8, 6))
plt.hist(scores, alpha=0.5, bins=100)
plt.xlabel('Difficulty Score (Loss)')
plt.ylabel('Count')
plt.show()

## 4. Probabilistic Sampling

In [None]:
class ProbSampleImageNetData(TinyImageNetData):
    '''
    Training sequence that keeps each training sample with its own
    probability ``self.th_probs[i]``.

    Samples are scanned in ``X_train`` order with a cursor (``curr_sample``)
    that wraps around and persists across calls, so batches depend on call
    order rather than on the requested index.
    '''

    curr_sample = 0            # cursor into X_train / th_probs
    scores = None              # kept for compatibility; sampling uses th_probs
    th_probs = None            # per-sample keep probabilities (attach before use)
    operative_train_size = 2500  # batches drawn per epoch

    def __len__(self):
        # Batches per epoch, independent of the training-set size.
        return self.operative_train_size

    def __getitem__(self, idx):
        '''
        Return the next batch of probabilistically sampled training images.

        Requires the user to attach ``self.th_probs``: one keep probability
        per training sample, aligned with ``self.X_train``. Starting at the
        cursor, each sample is kept if a uniform draw falls below its
        probability, until ``batch_size`` samples are collected.

        ``idx`` is ignored. NOTE(review): this assumes Keras fetches batches
        sequentially from a single worker.

        Raises:
            ValueError: if ``th_probs`` is missing or has no positive entry
                (the sampling loop would otherwise never terminate).
        '''
        if self.th_probs is None:
            raise ValueError('attach per-sample keep probabilities to self.th_probs first')
        # Only samples that have a probability can be drawn; this also guards
        # against th_probs being shorter than the training set.
        n_candidates = min(self.train_size, len(self.th_probs))
        if n_candidates == 0 or not np.any(np.asarray(self.th_probs[:n_candidates]) > 0):
            raise ValueError('self.th_probs has no positive entries; no sample can be drawn')

        imgs, labels = [], []
        cursor = self.curr_sample % n_candidates
        while len(imgs) < self.batch_size:
            if np.random.uniform() < self.th_probs[cursor]:
                img = plt.imread(self.X_train[cursor])
                img = cv2.resize(img, self.img_size)
                if img.ndim == 2:  # grayscale, no color depth
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # cvt to color
                imgs.append(img)
                labels.append(self.y_train[cursor])
            cursor = (cursor + 1) % n_candidates
        self.curr_sample = cursor

        return np.stack(imgs).astype(np.float32) / 255, np.expand_dims(np.stack(labels), 1)

In [None]:
# Same folder, batch size, image size and default seed as `data`, so both
# instances share one training order.
probs_data = ProbSampleImageNetData(
    directory='tiny-imagenet-200/train',
    batch_size=32,
    img_size=(256, 256),
)

In [None]:
def s2tp(s):
    '''
    Map difficulty scores to keep probabilities via exp(-s).

    For non-negative losses the result lies in (0, 1] and decreases as the
    score grows, so easier samples are kept more often. The alternative
    ``1 - np.exp(-s)`` would favour harder samples instead.
    '''
    negated = -s
    return np.exp(negated)
# `scores` were computed in `data.X_train` order; `probs_data` is a separate
# instance, so confirm both share that order before pairing them by position.
if tuple(probs_data.X_train[:len(scores)]) != tuple(data.X_train[:len(scores)]):
    raise ValueError('probs_data and data have different training orders; scores would be misaligned')
probs_data.th_probs = s2tp(scores)

In [None]:
# Distribution of per-sample keep probabilities derived from the scores.
plt.figure(dpi=400, figsize=(8, 6))
plt.hist(s2tp(scores), alpha=0.5, bins=100)
plt.xlabel('Theoretical Probability')
plt.ylabel('Count')
plt.show()

In [None]:
# Sanity-check one sampled batch, then rewind the sampler cursor so that
# training starts from the beginning of the training order.
a, b = probs_data[0]
probs_data.curr_sample = 0

In [None]:
# Fresh, randomly initialised InceptionV3 for the probabilistic-sampling run;
# attach it so probs_data can validate it after each epoch.
probs_model = keras.applications.InceptionV3(
    input_shape=(256, 256, 3),
    include_top=True,
    classes=200,
    classifier_activation="softmax",
    weights=None,
)
probs_data.model = probs_model

In [None]:
# Same optimizer, loss and metric as the baseline.
probs_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train on probabilistically sampled batches; validation metrics accumulate
# on `probs_data` via its on_epoch_end hook.
probs_history = keras.callbacks.History()
probs_model.fit(probs_data, callbacks=[probs_history], epochs=5)

In [None]:
# Probabilistic sampling: training vs. validation loss per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(probs_history.history['loss'], label='Training')
plt.plot(probs_data.valid_loss, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

In [None]:
# Probabilistic sampling: training vs. validation accuracy per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(probs_history.history['accuracy'], label='Training')
plt.plot(probs_data.valid_acc, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

In [None]:
# `probs_data.model` still references the model; clear it so that `del`
# actually releases it.
probs_data.model = None
del probs_model

## 5. Deterministic Sampling

In [None]:
# Same construction as `data`; its training set is filtered by score below.
dets_data = TinyImageNetData(
    directory='tiny-imagenet-200/train',
    batch_size=32,
    img_size=(256, 256),
)

In [None]:
# Deterministic sampling: keep only training samples whose difficulty score is
# below the `percentile`-th percentile (the easiest ~80%).
percentile = 80
# `scores` follow `data.X_train` order; make sure dets_data shares that order.
if tuple(dets_data.X_train[:len(scores)]) != tuple(data.X_train[:len(scores)]):
    raise ValueError('dets_data and data have different training orders; scores would be misaligned')
thresh = np.percentile(scores, percentile)
ind_mask = np.flatnonzero(scores < thresh)  # always 1-D, even for a single match
dets_data.X_train = np.array(dets_data.X_train)[ind_mask]
dets_data.y_train = np.array(dets_data.y_train)[ind_mask]
dets_data.train_size = len(dets_data.X_train)
# number of full batches in the filtered training set
dets_data.train_batches = dets_data.train_size // dets_data.batch_size

In [None]:
# Fresh, randomly initialised InceptionV3 for the deterministic-sampling run;
# attach it so dets_data can validate it after each epoch.
dets_model = keras.applications.InceptionV3(
    input_shape=(256, 256, 3),
    include_top=True,
    classes=200,
    classifier_activation="softmax",
    weights=None,
)
dets_data.model = dets_model

In [None]:
# Same optimizer, loss and metric as the baseline.
dets_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train on the filtered (easier) subset; validation metrics accumulate on
# `dets_data` via its on_epoch_end hook.
dets_history = keras.callbacks.History()
dets_model.fit(dets_data, callbacks=[dets_history], epochs=5)

In [None]:
# Deterministic sampling: training vs. validation loss per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(dets_history.history['loss'], label='Training')
plt.plot(dets_data.valid_loss, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

In [None]:
# Deterministic sampling: training vs. validation accuracy per epoch.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(dets_history.history['accuracy'], label='Training')
plt.plot(dets_data.valid_acc, linestyle='--', label='Validation')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()  # render the figure (an empty plt.plot() only adds an empty line)

## 6. Comparison

In [None]:
# Training loss of the three runs (the first curve is the baseline, which was
# previously mislabelled as probabilistic sampling).
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(base_history.history['loss'], label='Baseline')
plt.plot(probs_history.history['loss'], linestyle='--', label='Prob. Sampling')
plt.plot(dets_history.history['loss'], linestyle='dotted', label='Det. Sampling')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.show()

In [None]:
# Validation loss of the three runs. The datasets record it in `valid_loss`
# (there is no `valid_perf` attribute).
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(data.valid_loss, label='Baseline')
plt.plot(probs_data.valid_loss, linestyle='--', label='Prob. Sampling')
plt.plot(dets_data.valid_loss, linestyle='dotted', label='Det. Sampling')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Valid. Loss')
plt.show()

In [None]:
# Training accuracy of the three runs; the first curve is the baseline.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(base_history.history['accuracy'], label='Baseline')
plt.plot(probs_history.history['accuracy'], linestyle='--', label='Prob. Sampling')
plt.plot(dets_history.history['accuracy'], linestyle='dotted', label='Det. Sampling')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Train Accuracy')
plt.show()

In [None]:
# Validation accuracy of the three runs; the first curve is the baseline.
plt.figure(figsize=(10, 5), dpi=400)
plt.plot(data.valid_acc, label='Baseline')
plt.plot(probs_data.valid_acc, linestyle='--', label='Prob. Sampling')
plt.plot(dets_data.valid_acc, linestyle='dotted', label='Det. Sampling')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Valid. Accuracy')
plt.show()