# Active Learning (Labeling Selection) using Augmentations

In [None]:
# Root directory: the model folder, fine-tuning output folders and the
# regenerated TFDS dataset are all located under this path.
BASE_PATH = '/content/'
# Name of the model archive (without .zip) / extracted folder used for scoring.
MODEL_FNAME = 'model_stage1'
# Text file of already-annotated image file names, read one name per line.
ANNOTATIONS_FNAME = 'annotations.txt'
# Number of additional images to select for annotation in this round.
NUM_ANNOTATE = 52
# When True, start from an empty annotation set and set the
# `zero_init_logits_layer` fine-tuning flag to True.
FROM_STAGE1 = True

In [None]:
# One-time environment setup, guarded by the `pip_installed` marker file: the
# chain only runs when the marker is absent and creates it as its last step.
# Every step is joined with `&&`, so a failing step stops the chain and leaves
# the marker uncreated; the next run of this cell then starts the chain again.
![ ! -f "pip_installed" ] && \
pip install -q tensorflow-datasets==4.4.0 tensorflow-addons && \
unzip -qq /content/drive/MyDrive/TeamSemiSuperCV/Wing/xray_reborn.zip -d /root/tensorflow_datasets && \
unzip -qq /content/drive/MyDrive/TeamSemiSuperCV/Wing/XRay_.zip -d /content && \
unzip -qq /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/$MODEL_FNAME\.zip -d /content/$MODEL_FNAME && \
cp /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/preprocess.py /content && \
cp /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/Xray_Reborn.py /content && \
cp /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/valid.txt /content && \
cp /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/test.txt /content && \
cp /content/drive/MyDrive/TeamSemiSuperCV/Active_Learn/$ANNOTATIONS_FNAME /content && \
git clone --depth 1 https://github.com/TeamSemiSuperCV/semi-super.git /content/semi-super && \
touch pip_installed
# List the working directory so the result of the setup can be checked.
!ls -F

In [None]:
import os
# Set in its own cell, before tensorflow is imported in the next one;
# presumably meant to silence TensorFlow's native log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import logging
# Suppress Python log records of severity WARNING and below.
logging.disable(logging.WARNING)

In [None]:
import json

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
from pathlib import Path
from scipy.stats import entropy
from scipy.special import softmax

from preprocess import dict2dict, IMG_SIZE
import Xray_Reborn

In [None]:
# Locate the SavedModel directory: either the extracted folder itself, or the
# first sub-folder two levels down that contains an `assets` directory.
model_root = Path(BASE_PATH + MODEL_FNAME)
if not (model_root / 'saved_model.pb').exists():
  nested_exports = [p.parent for p in model_root.glob('*/*/assets') if p.is_dir()]
  model_root = nested_exports[0]
model_path = str(model_root)
model_path

## Labeling Selection

In [None]:
# Load every split of the dataset; the 'train' split is what gets scored.
ds = tfds.load('xray_reborn')
AUTO = tf.data.AUTOTUNE

# Baseline pipeline: preprocess with dict2dict (from preprocess.py), then batch.
ds_train = (
    ds['train']
    .map(dict2dict, num_parallel_calls=AUTO)
    .batch(64)
    .prefetch(AUTO)
)

In [None]:
# Load the exported SavedModel; later cells call it as
# model(images, trainable=False) and read the 'logits_sup' output.
model = tf.saved_model.load(model_path)

In [None]:
# Baseline pass over the un-augmented training set: collect the supervised
# logits together with the label and file name of every example.
batch_logits, batch_labels, batch_fnames = [], [], []

for batch in ds_train:
    outputs = model(batch['image'], trainable=False)
    # Available output keys:
    # ['block_group1', 'block_group3', 'block_group2', 'initial_conv', 'final_avg_pool', 'block_group4',
    #  'sup_head_input', 'proj_head_output', 'proj_head_input', 'initial_max_pool', 'logits_sup'])
    batch_logits.append(outputs['logits_sup'].numpy())
    batch_labels.append(batch['label'].numpy())
    batch_fnames.append(batch['fname'].numpy())

logits = np.concatenate(batch_logits)
labels = np.concatenate(batch_labels)
# File names arrive as bytes; decode them to str.
fnames = [raw.decode('utf-8') for raw in np.concatenate(batch_fnames)]

# Reference predictions and class-1 probabilities that augmented runs are compared to.
initial_preds = np.argmax(logits, axis=1)
initial_probs = softmax(logits, axis=1)[:, 1]

In [None]:
def get_preds_probs(ds):
  """Run the global `model` over `ds` and return (preds, probs).

  `ds` yields batches with an 'image' entry. `preds` is the argmax of the
  'logits_sup' output per example; `probs` is the softmax probability of
  class index 1 per example.
  """
  per_batch = [
      model(batch['image'], trainable=False)['logits_sup'].numpy()
      for batch in ds
  ]
  all_logits = np.concatenate(per_batch)
  class1_probs = softmax(all_logits, axis=1)[:, 1]
  return np.argmax(all_logits, axis=1), class1_probs

In [None]:
# Number of augmentation variants scored below (columns 0..NUM_AUGMENTS-1).
NUM_AUGMENTS = 6

len_train = len(ds['train'])
# One row per training example, one column per augmentation.
board_shape = (len_train, NUM_AUGMENTS)
board_score = np.zeros(board_shape)  # 1 where the prediction matched the baseline
board_probs = np.zeros(board_shape)  # class-1 probability under each augmentation

def update_scores(aug_num, new_preds):
  """Record, in column `aug_num` of `board_score`, which predictions are unchanged.

  Prints the fraction of examples whose prediction equals the baseline
  `initial_preds`, then stores 1 for a match and 0 otherwise.
  """
  unchanged = (initial_preds == new_preds)
  print(unchanged.sum() / len_train)
  board_score[:, aug_num] = unchanged.astype(int)

def update_probs(aug_num, new_probs):
  """Store `new_probs` in column `aug_num` of `board_probs`.

  Also prints two diagnostics against the baseline `initial_probs`: the mean
  squared difference and the ratio of the summed probabilities.
  """
  squared_diff = np.square(new_probs - initial_probs)
  sum_ratio = new_probs.sum() / initial_probs.sum()
  print(squared_diff.mean(), sum_ratio)
  board_probs[:, aug_num] = new_probs

In [None]:
def do_augment(aug_num, aug_fn, ds=ds_train):
  """Score one augmentation and record it in board column `aug_num`.

  `aug_fn` is mapped over `ds`, the model is run on the result, and both the
  prediction-agreement and probability boards are updated. The default `ds`
  is the `ds_train` object that existed when this function was defined.
  """
  preds, probs = get_preds_probs(ds.map(aug_fn, num_parallel_calls=AUTO))
  update_scores(aug_num, preds)
  update_probs(aug_num, probs)

In [None]:
def aug_nocrop(d):
  """Convert d['image'] to float32 and resize it to IMG_SIZE[:2] without cropping."""
  as_float = tf.image.convert_image_dtype(d['image'], dtype=tf.float32)
  d['image'] = tf.image.resize(as_float, IMG_SIZE[:2])
  return d

# Built from the raw 'train' split, so the augmentation itself does the resizing;
# do_augment therefore only needs an identity map.
ds_nocrop = (
    ds['train']
    .map(aug_nocrop, num_parallel_calls=AUTO)
    .batch(64)
    .prefetch(AUTO)
)
do_augment(0, lambda x: x, ds=ds_nocrop)

In [None]:
def aug_crop_left(d):
  """Keep the left 90% of the image width and resize to IMG_SIZE[:2].

  The box is given as normalised [y1, x1, y2, x2]; `box_indices=[0]` means the
  input is expected to carry a leading batch axis holding one image.
  """
  as_float = tf.image.convert_image_dtype(d['image'], dtype=tf.float32)
  keep_box = tf.constant([[0.0, 0.0, 1.0, 0.9]])
  d['image'] = tf.image.crop_and_resize(
      as_float,
      boxes=keep_box,
      box_indices=tf.constant([0]),
      crop_size=IMG_SIZE[:2])
  return d

# batch(1) gives every image a leading batch axis for the crop, unbatch()
# removes it again before forming the real batches of 64.
ds_crop = (
    ds['train']
    .batch(1)
    .map(aug_crop_left, num_parallel_calls=AUTO)
    .unbatch()
    .batch(64)
    .prefetch(AUTO)
)
do_augment(1, lambda x: x, ds=ds_crop)

In [None]:
def aug_crop_right(d):
  """Keep the right 90% of the image width and resize to IMG_SIZE[:2].

  The box is given as normalised [y1, x1, y2, x2]; `box_indices=[0]` means the
  input is expected to carry a leading batch axis holding one image.
  """
  as_float = tf.image.convert_image_dtype(d['image'], dtype=tf.float32)
  keep_box = tf.constant([[0.0, 0.1, 1.0, 1.0]])
  d['image'] = tf.image.crop_and_resize(
      as_float,
      boxes=keep_box,
      box_indices=tf.constant([0]),
      crop_size=IMG_SIZE[:2])
  return d

# batch(1) gives every image a leading batch axis for the crop, unbatch()
# removes it again before forming the real batches of 64.
ds_crop = (
    ds['train']
    .batch(1)
    .map(aug_crop_right, num_parallel_calls=AUTO)
    .unbatch()
    .batch(64)
    .prefetch(AUTO)
)
do_augment(2, lambda x: x, ds=ds_crop)

In [None]:
def aug_crop_upper(d):
  """Keep the upper 90% of the image height and resize to IMG_SIZE[:2].

  The box is given as normalised [y1, x1, y2, x2]; `box_indices=[0]` means the
  input is expected to carry a leading batch axis holding one image.
  """
  as_float = tf.image.convert_image_dtype(d['image'], dtype=tf.float32)
  keep_box = tf.constant([[0.0, 0.0, 0.9, 1.0]])
  d['image'] = tf.image.crop_and_resize(
      as_float,
      boxes=keep_box,
      box_indices=tf.constant([0]),
      crop_size=IMG_SIZE[:2])
  return d

# batch(1) gives every image a leading batch axis for the crop, unbatch()
# removes it again before forming the real batches of 64.
ds_crop = (
    ds['train']
    .batch(1)
    .map(aug_crop_upper, num_parallel_calls=AUTO)
    .unbatch()
    .batch(64)
    .prefetch(AUTO)
)
do_augment(3, lambda x: x, ds=ds_crop)

In [None]:
def aug_crop_lower(d):
  """Keep the lower 90% of the image height and resize to IMG_SIZE[:2].

  The box is given as normalised [y1, x1, y2, x2]; `box_indices=[0]` means the
  input is expected to carry a leading batch axis holding one image.
  """
  as_float = tf.image.convert_image_dtype(d['image'], dtype=tf.float32)
  keep_box = tf.constant([[0.1, 0.0, 1.0, 1.0]])
  d['image'] = tf.image.crop_and_resize(
      as_float,
      boxes=keep_box,
      box_indices=tf.constant([0]),
      crop_size=IMG_SIZE[:2])
  return d

# batch(1) gives every image a leading batch axis for the crop, unbatch()
# removes it again before forming the real batches of 64.
ds_crop = (
    ds['train']
    .batch(1)
    .map(aug_crop_lower, num_parallel_calls=AUTO)
    .unbatch()
    .batch(64)
    .prefetch(AUTO)
)
do_augment(4, lambda x: x, ds=ds_crop)

In [None]:
def aug_noise(d):
  """Add zero-mean Gaussian noise (stddev 0.015) to every element of d['image'].

  The noise tensor is sampled with the image's own (dynamic) shape, so every
  pixel of every image in the batch gets an independent perturbation.
  Previously the shape was `tf.shape(IMG_SIZE)`, i.e. the shape of the size
  tuple itself (its length), which produced only a handful of values that were
  broadcast over the whole batch instead of per-pixel noise.
  """
  image = d['image']
  noise = tf.random.normal(tf.shape(image), 0, 0.015)
  d['image'] = tf.add(image, noise)
  return d

# No `ds` argument: do_augment falls back to its default, the preprocessed and
# batched `ds_train`, so the noise is added on top of the baseline inputs.
do_augment(5, aug_noise)

In [None]:
# def aug_horz_flip(d):
#   d['image'] = tf.image.flip_left_right(d['image'])
#   return d

# do_augment(6, aug_horz_flip)

In [None]:
# def aug_vert_flip(d):
#   d['image'] = tf.image.flip_up_down(d['image'])
#   return d

# do_augment(7, aug_vert_flip)

In [None]:
# def aug_brightness_incr(d):
#   d['image'] = tf.image.adjust_brightness(d['image'], +0.15)
#   return d

# do_augment(8, aug_brightness_incr)

In [None]:
# def aug_brightness_decr(d):
#   d['image'] = tf.image.adjust_brightness(d['image'], -0.15)
#   return d

# do_augment(9, aug_brightness_decr)

In [None]:
# def aug_contrast_incr(d):
#   d['image'] = tf.image.adjust_contrast(d['image'], 1.0 + 0.2)
#   return d

# do_augment(10, aug_contrast_incr)

In [None]:
# def aug_contrast_decr(d):
#   d['image'] = tf.image.adjust_contrast(d['image'], 1.0 - 0.2)
#   return d

# do_augment(11, aug_contrast_decr)

In [None]:
# def aug_gamma_incr(d):
#   d['image'] = tf.image.adjust_gamma(d['image'], 1.1)
#   return d

# do_augment(12, aug_gamma_incr)

In [None]:
# def aug_gamma_decr(d):
#   d['image'] = tf.image.adjust_gamma(d['image'], 0.9)
#   return d

# do_augment(13, aug_gamma_decr)

In [None]:
# Per example: number of augmentations that kept the baseline prediction, and
# the variance of the class-1 probability across the augmentations.
final_score = np.sum(board_score, axis=1)
final_probs = board_probs.var(axis=1)
for summary in (final_score, final_probs):
  print(summary.max(), summary.min(), summary.mean())

In [None]:
# Histogram of agreement counts: how many training examples kept the baseline
# prediction under exactly `fs` augmentations. A score can be anything from 0
# to NUM_AUGMENTS inclusive, so the range has to include NUM_AUGMENTS itself;
# stopping one short silently dropped the "agreed under every augmentation" bucket.
for fs in range(NUM_AUGMENTS + 1):
  print(f'{fs}:{(final_score == fs).sum()}', end=' ')

In [None]:
# Rank every training example by the variance of its class-1 probability
# across augmentations, most unstable first (ties keep their dataset order).
ent_lbl_fname = sorted(
    zip(final_probs, labels, fnames),
    key=lambda row: row[0],
    reverse=True,
)
# ent_lbl_fname = [x for i, x in enumerate(ent_lbl_fname) if (i % 6) == 0]  # skip adjacents
# Sum of the labels among the top NUM_ANNOTATE candidates.
top_labels = [lbl for _, lbl, _ in ent_lbl_fname[:NUM_ANNOTATE]]
num_ones = sum(top_labels)
print(f'{num_ones}')
# File names in ranking order; the annotation cells below pick from the front.
selected = [fname.strip() for _, _, fname in ent_lbl_fname]

In [None]:
# Build the set of file names that are already annotated.
if FROM_STAGE1:
  # Starting from stage 1: no previous annotations are carried over.
  # Disabled alternative that seeded the set from the 'train_1pc' split:
  # ds_train_1pc = ds['train_1pc']
  # annotations = {d['fname'].numpy().decode('utf-8') for d in ds_train_1pc}
  annotations = set()
elif os.path.isfile(ANNOTATIONS_FNAME):
  print(ANNOTATIONS_FNAME)
  with open(ANNOTATIONS_FNAME) as f:
    # One file name per line; blank lines are ignored.
    annotations = {line.strip() for line in f if line.strip()}
else:
  # Previously this branch only printed a message, which left `annotations`
  # undefined (NameError just below) or, worse, silently reused a stale value
  # from an earlier run of this cell. Fail loudly instead.
  raise FileNotFoundError(f'{ANNOTATIONS_FNAME} not found!')

len(annotations)

In [None]:
# New annotation set = previous annotations + the NUM_ANNOTATE top-ranked files.
# If some top-ranked files were already annotated, keep walking down the ranking
# until NUM_ANNOTATE genuinely new files have been added.
annotate_new = set(selected[:NUM_ANNOTATE]) | annotations
target_annotations = len(annotations) + NUM_ANNOTATE

# Index of the next candidate in `selected`. The first NUM_ANNOTATE entries are
# already in the set, so the walk starts at index NUM_ANNOTATE. (The index used
# to be incremented before it was used, which skipped that candidate.)
annotate_thresh = NUM_ANNOTATE
while len(annotate_new) < target_annotations and annotate_thresh < len(selected):
  annotate_new.add(selected[annotate_thresh])
  annotate_thresh += 1
  print('!', end='')
print()
if len(annotate_new) < target_annotations:
  # Ran out of candidates instead of raising IndexError on selected[...].
  print(f'only {len(annotate_new)} of {target_annotations} annotations available')
len(annotate_new)

In [None]:
with open('annotations.txt', 'w') as f:
    for fname in annotate_new:
        print(fname, file=f)
!wc -l annotations.txt

## Generate New Dataset

In [None]:
# reload(Xray_Reborn)
# Delete the previously generated dataset folder under BASE_PATH so that the
# load below has to build it again.
# NOTE(review): presumably the Xray_Reborn builder reads annotations.txt to
# form the 'train_act' split - confirm in Xray_Reborn.py.
!rm -rf {BASE_PATH + 'xray_reborn'}
ds = tfds.load('XrayReborn', data_dir=BASE_PATH)  # will re-generate TFDS dataset
# Show the sizes of the splits used by the fine-tuning stage.
len(ds['train_act']), len(ds['validation']), len(ds['test'])

## Stage 2 Fine-Tuning

In [None]:
# Run `wc -l` in the shell; IPython captures the command's output lines in a list.
wc_annotations = !wc -l annotations.txt
# The first whitespace-separated field of the first output line is the line count.
len_annotations = int(wc_annotations[0].split()[0])
len_annotations

In [None]:
class simclrCommand():
  def __init__(self, params):
    self.params = params

  def compile_command(self):
    simclr_command = ['python3 /content/semi-super/run.py']
    for k,v in self.params.items():
      simclr_command.append(f'--{k}={v}')
    return (" ").join(simclr_command)

  def run_command(self):
    !{self.compile_command()}

In [None]:
# Flag values handed to run.py; each entry becomes `--key=value`.
params = {
    # --- Dataset ---
    'dataset': 'xray_reborn',

    # --- Training logistics ---
    'train_mode': 'finetune',
    'mode': 'train_then_eval',
    'train_split': 'train_act',
    'eval_split': 'validation',
    'checkpoint_epochs': 20,
    'save_only_last_ckpt': True,
    'eval_per_loop': False,
    'zero_init_logits_layer': False,
    'use_tpu': False,

    # --- Training hyper-parameters ---
    'warmup_epochs': 0,
    'train_epochs': 60,
    'fine_tune_after_block': 3,
    'train_batch_size': 14,
    'learning_rate': 0.0005,
    'learning_rate_scaling': 'sqrt',
    'weight_decay': 0.001,
    'temperature': 0.1,

    # --- Architecture ---
    'image_size': 224,
    'resnet_depth': 50,
    'width_multiplier': 2,
    'sk_ratio': 0.0625,

    # --- Augmentations ---
    'color_jitter_strength': 0.5,
    'use_blur': False,
    'area_range_min': 1.0,

    # --- Static ---
    'data_dir': '/content/',
}

# When starting from the stage-1 model, turn the zero_init_logits_layer flag on.
if FROM_STAGE1:
  params['zero_init_logits_layer'] = True
slimsk2 = simclrCommand(params)

In [None]:
# 1st Fine-Tuning /w Validation Split Results
def FT1(run, rerun=False):
  model_ft_name = f'model_{len_annotations}-{run}'
  global model_ft_path
  model_ft_path = BASE_PATH + model_ft_name
  if not rerun:
    !rm -rf $model_ft_path
    assert not os.path.isdir(model_ft_path)

  slimsk2.params['mode'] = 'train_then_eval'
  slimsk2.params['checkpoint'] = model_path
  slimsk2.params['model_dir'] = model_ft_path
  slimsk2.run_command()

In [None]:
# 2nd (Follow-on) Fine-Tuning /w Validation Split Results
def FT2(run, rerun=False):
  model_ft_name = f'model_{len_annotations}-{run}'
  model_ft_path = BASE_PATH + model_ft_name
  model_ft_path_sm = str([p.parent for p in Path(model_ft_path).glob('*/*/assets') if p.is_dir()][0])
  global model_ft2_path
  model_ft2_path = model_ft_path + '+'
  if not rerun:
    !rm -rf $model_ft2_path
    assert not os.path.isdir(model_ft2_path)

  slimsk2.params['mode'] = 'train_then_eval'
  slimsk2.params['checkpoint'] = model_ft_path_sm
  slimsk2.params['model_dir'] = model_ft2_path
  slimsk2.run_command()

In [None]:
# Four fine-tuning runs, each starting from the same `model_path` checkpoint
# and writing to its own model_<len_annotations>-<run> directory.
FT1(1) # Run #1

In [None]:
FT1(2) # Run #2

In [None]:
FT1(3) # Run #3

In [None]:
FT1(4) # Run #4

In [None]:
# Average Validation Accuracy based on 4 Runs
# NOTE(review): the four values are hard-coded, presumably copied by hand from
# the output of the FT1 runs - update them after re-running.
avg_eval_accuracy = np.mean([0.956107, 0.956107, 0.948473, 0.948473])
avg_eval_accuracy