<a href="https://colab.research.google.com/github/artinmajdi/chest_xray_private_main/blob/master/aims/aim1_1_taxonomy/SAM/my_SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to minimally implement **sharpness-aware minimization** in TensorFlow with the CIFAR-100 dataset. Sharpness-aware minimization (SAM) was proposed in the paper [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://openreview.net/pdf?id=6Tm1mposlrM)<sup>*</sup>. Some notable differences in this implementation:
* ResNet20 (attributed to [this repository](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/blob/master/zoo/resnet/resnet_cifar10.py)) is used as opposed to PyramidNet and WideResNet. 
* ShakeDrop regularization has not been used.
* Two simple augmentation transformations (random crop and random brightness) have been used instead of Cutout and AutoAugment.
* Adam has been used as the optimizer, with the default arguments provided by TensorFlow, together with a `ReduceLROnPlateau` callback. Table 1 of the original paper suggests using SGD with different configurations.
* Instead of training for the full number of epochs, I used early stopping with a patience of 10.

I referred to the following resources for this study:
* [Original Paper](https://openreview.net/pdf?id=6Tm1mposlrM) 
* [davda54](https://github.com/davda54)'s [PyTorch implementation](https://github.com/davda54/sam)

*<sub>arXiv version of the paper can be found [here](https://arxiv.org/abs/2010.01412).</sub>
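The two augmentations named in the list of differences (random crop and random brightness) can be sketched with `tf.image` ops. This is a minimal sketch, not the notebook's own pipeline: the function name `augment`, the 4-pixel padding before the crop, and the brightness delta of 0.2 are assumptions chosen for illustration.

```python
import tensorflow as tf

IMG_SIZE = 32     # CIFAR images are 32x32x3
PAD = 4           # assumed padding before the random crop
MAX_DELTA = 0.2   # assumed maximum brightness shift

def augment(image: tf.Tensor, label: tf.Tensor):
    """Hypothetical per-example augmentation: random crop + random brightness."""
    # uint8 [0, 255] -> float32 [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Zero-pad to 40x40, then take a random 32x32 crop.
    image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 2 * PAD, IMG_SIZE + 2 * PAD)
    image = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])
    # Add a random brightness offset in [-MAX_DELTA, MAX_DELTA].
    image = tf.image.random_brightness(image, max_delta=MAX_DELTA)
    # Keep pixel values in the valid [0, 1] range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
```

It would typically be applied to the training split only, e.g. `train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)` before batching.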

## Initial Setup

In [13]:
from google.colab import drive
drive.mount('/content/drive')

# !git clone https://github.com/sayakpaul/Sharpness-Aware-Minimization-TensorFlow

Mounted at /content/drive


In [2]:
!pip install -r /content/drive/MyDrive/RESEARCH/PhD/code/my_main_code/my_main_code/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mlflow~=1.12.1
  Downloading mlflow-1.12.1-py3-none-any.whl (13.9 MB)
Collecting pysftp==0.2.9
  Downloading pysftp-0.2.9.tar.gz (25 kB)
  Preparing metadata (setup.py) ... done
Collecting psycopg2==2.8.5
  Downloading psycopg2-2.8.5.tar.gz (380 kB)
  Preparing metadata (setup.py) ... done
Collecting crowd-kit
  Downloading crowd_kit-1.1.0-py3-none-any.whl (74 kB)
Collecting wget~=3.2
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... done
Collecting

In [3]:
# !pip install -e '/content/drive/MyDrive/RESEARCH/PhD/code/my_main_code/'

import my_main_code

In [4]:
# !pip install condacolab
# import condacolab
# condacolab.install()
# !conda install git gitpython

In [12]:
# !pip install colabcode
# from colabcode import ColabCode
# ColabCode(port=1000)

In [18]:
%reload_ext autoreload
%autoreload 2

import sys
import os

# dir = '/content/drive/MyDrive/RESEARCH/PhD/code/my_main_code/'
# sys.path.append(dir)

import time
import tensorflow as tf
tf.random.set_seed(42)
print('tf version:', tf.__version__)


from my_main_code.utils import funcs
from my_main_code.aims.aim1_1_taxonomy.SAM import utils 
from my_main_code.aims.aim1_1_taxonomy.SAM import resnet_cifar10
# from aims.aim1_1_taxonomy.SAM import utils, resnet_cifar10


# `%autoreload 2` (above) already reloads these modules before each cell runs;
# `%reload_ext` only applies to IPython extensions, so reload them explicitly instead.
import importlib
for module in (resnet_cifar10, utils, funcs):
    importlib.reload(module)

tf version: 2.9.2


In [6]:
artin_stuff = utils.ArtinStuff()

strategy = artin_stuff.tpu_gpu_initialization()

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.10.154.186:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
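`tpu_gpu_initialization` is a helper from this project whose source is not shown. The standard TensorFlow 2 TPU setup it presumably wraps can be sketched as follows; the function name `make_strategy` and the fallback to the default single-device strategy when no TPU is found are assumptions.

```python
import tensorflow as tf

def make_strategy(tpu_address=None) -> tf.distribute.Strategy:
    """Connect to a TPU if one is reachable, otherwise return the default strategy."""
    try:
        # With tpu_address=None the resolver tries to discover the TPU
        # (on Colab, from the runtime environment).
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    except ValueError:
        # No TPU found: fall back to the default single-device strategy (assumption).
        return tf.distribute.get_strategy()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
```

On the TPU runtime used here, `make_strategy().num_replicas_in_sync` would be expected to report the 8 cores listed in the log above.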


## Load Dataset and Prepare Data Loaders

### <span style='color:red'> a. Original dataset loading code from Google/SAM </span>

In [24]:
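X, Y = artin_stuff.load_cifar100_raw_data()

mode = 'merged'  # one of 'fine', 'coarse', 'merged'
# CIFAR-100 has 100 fine and 20 coarse classes; 'merged' uses both (100 + 20 = 120).
n_classes = {'fine': 100, 'coarse': 20, 'merged': 120}[mode]
X[mode]['test'].shape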

(10000, 32, 32, 3)
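How `load_cifar100_raw_data` builds the `'merged'` labels is not shown. Since CIFAR-100 has 100 fine and 20 coarse classes and `n_classes` is 120 in this mode, one plausible construction is to concatenate the two one-hot vectors into a 120-dimensional multi-hot target. The function name `merge_labels` and the fine-then-coarse column order are assumptions. Note that such a target has two active entries per row: a softmax output sums to 1 and so cannot assign high probability to both at once, whereas independent sigmoid outputs can.

```python
import numpy as np

N_FINE, N_COARSE = 100, 20  # CIFAR-100 label granularities

def merge_labels(fine: np.ndarray, coarse: np.ndarray) -> np.ndarray:
    """Hypothetical 'merged' target: [one-hot(fine) | one-hot(coarse)], shape (N, 120)."""
    fine = np.asarray(fine).reshape(-1)      # Keras returns labels with shape (N, 1)
    coarse = np.asarray(coarse).reshape(-1)
    merged = np.zeros((fine.shape[0], N_FINE + N_COARSE), dtype=np.float32)
    rows = np.arange(fine.shape[0])
    merged[rows, fine] = 1.0                 # columns 0..99: fine class
    merged[rows, N_FINE + coarse] = 1.0      # columns 100..119: coarse class
    return merged

# The raw labels could come from, for example:
# (_, y_fine), _ = tf.keras.datasets.cifar100.load_data(label_mode="fine")
# (_, y_coarse), _ = tf.keras.datasets.cifar100.load_data(label_mode="coarse")
```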

In [25]:
train_ds, test_ds = artin_stuff.load_cifar100_dataset( X=X, Y=Y, strategy=strategy,  mode=mode,  batch_size=128 )

Batch size: 1024
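The printed batch size (1024) is eight times the `batch_size=128` passed to the loader, matching the eight TPU cores reported earlier. This suggests the helper treats `batch_size` as a per-replica size and scales it by `strategy.num_replicas_in_sync`. A minimal `tf.data` sketch of that convention follows; the function name `make_dataset` and the shuffle buffer size are assumptions.

```python
import tensorflow as tf

def make_dataset(x, y, strategy: tf.distribute.Strategy, per_replica_batch: int = 128,
                 training: bool = True) -> tf.data.Dataset:
    """Batch (x, y) so that the global batch is per_replica_batch * num replicas."""
    global_batch = per_replica_batch * strategy.num_replicas_in_sync
    print("Batch size:", global_batch)
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        ds = ds.shuffle(buffer_size=10_000)  # assumed buffer size
    # drop_remainder=True keeps batch shapes static, which TPU execution generally needs.
    ds = ds.batch(global_batch, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)
```

When such a dataset is passed to `model.fit` inside `strategy.scope()`, Keras splits each global batch across the replicas.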


## Encapsulate SAM Logic 

SAM is implemented as follows:

<center>
<img src="https://i.ibb.co/qRSfNX7/image.png"></img><br>
<small>Source: Original Paper</small>
</center>
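In equations, the procedure in the figure targets the min-max problem below (the paper also adds a $\lambda \lVert w \rVert_2^2$ weight-decay term, omitted here). The inner maximization is approximated by a single normalized gradient step of length $\rho$ (the $p = 2$ case), and the gradient taken at that perturbed point drives the update with learning rate $\eta$. In practice $L_S$ is evaluated on a mini-batch, so each SAM step costs two forward and backward passes.

$$
\min_{w} \; \max_{\lVert \epsilon \rVert_2 \le \rho} L_S(w + \epsilon)
$$

$$
\hat{\epsilon}(w) = \rho \, \frac{\nabla_w L_S(w)}{\lVert \nabla_w L_S(w) \rVert_2}
$$

$$
w_{t+1} = w_t - \eta \, \nabla_w L_S(w)\big|_{w_t + \hat{\epsilon}(w_t)}
$$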

## Define Callbacks
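The callbacks used below come from `artin_stuff.train_callbacks()`, whose source is not shown. Following the setup described at the top of this notebook (early stopping with a patience of 10 and `ReduceLROnPlateau`), a comparable list can be sketched with standard Keras callbacks. The monitored quantity (`val_loss`), the reduction factor and patience of `ReduceLROnPlateau`, and `restore_best_weights=True` are assumptions.

```python
import tensorflow as tf

def train_callbacks(monitor: str = "val_loss"):
    """Hypothetical stand-in for artin_stuff.train_callbacks()."""
    return [
        # Stop when the monitored metric has not improved for 10 epochs,
        # and roll back to the best weights seen so far.
        tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=10,
                                         restore_best_weights=True, verbose=1),
        # Halve the learning rate after 3 epochs without improvement (assumed values).
        tf.keras.callbacks.ReduceLROnPlateau(monitor=monitor, factor=0.5,
                                             patience=3, verbose=1),
    ]
```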

## Initialize Model with SAM and Train It

In [None]:
# history = artin_stuff.fit_SAM(strategy=strategy ,  train_ds=train_ds ,  test_ds=test_ds, n_classes=n_classes, activation='softmax', loss='sparse_categorical_crossentropy')

# utils.plot_history(history)
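`fit_SAM` is a project helper whose internals are not shown. The two-pass update from the algorithm above can be sketched as a single custom training step: compute the gradient at $w$, move to $w + \hat{\epsilon}$, take the gradient there, return to $w$, and let the base optimizer apply that second gradient. The function name `sam_train_step` and the default `rho=0.05` are assumptions for illustration, and this is a custom-loop sketch rather than the notebook's `fit`-based helper.

```python
import tensorflow as tf

def sam_train_step(model, loss_fn, optimizer, x, y, rho=0.05):
    """One SAM update (p = 2) on a mini-batch; returns the loss at the original weights."""
    params = model.trainable_variables

    # 1) Gradient of the mini-batch loss at the current weights w.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, params)

    # 2) Move to the approximate worst-case point w + eps_hat,
    #    eps_hat = rho * g / ||g||_2 (norm taken over all parameters jointly).
    grad_norm = tf.linalg.global_norm(grads) + 1e-12
    eps = [rho * g / grad_norm for g in grads]
    for p, e in zip(params, eps):
        p.assign_add(e)

    # 3) Gradient of the loss evaluated at the perturbed weights.
    with tf.GradientTape() as tape:
        perturbed_loss = loss_fn(y, model(x, training=True))
    sam_grads = tape.gradient(perturbed_loss, params)

    # 4) Return to w, then let the base optimizer apply the SAM gradient.
    for p, e in zip(params, eps):
        p.assign_sub(e)
    optimizer.apply_gradients(zip(sam_grads, params))
    return loss
```

A training loop would call it once per batch, e.g. `for x, y in train_ds: sam_train_step(model, loss_fn, optimizer, x, y)`. Wrapping it in `tf.function` is advisable for speed, and under `TPUStrategy` it would be dispatched through `strategy.run`, in which case each replica normalizes by its own local gradient norm.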

## Train a Regular (Non-SAM) EfficientNetB3 Model

In [28]:
# num_classes = 120 if mode=='merged' else 100
model = funcs.architecture(architecture_name='EfficientNetB3', input_shape=[224,224,3], num_classes=n_classes, activation='softmax', first_index_trainable=-5)

model.compile(  optimizer="adam", loss=tf.keras.losses.binary_crossentropy, metrics=["accuracy"] ) 
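`funcs.architecture` is a project helper whose source is not shown. A comparable baseline can be sketched with `tf.keras.applications.EfficientNetB3`. Interpreting `first_index_trainable=-5` as "freeze every backbone layer except the last five", the global-average-pooling head, and loading ImageNet weights by default are all assumptions; the function name `build_baseline` is hypothetical.

```python
import tensorflow as tf

def build_baseline(num_classes: int, input_shape=(224, 224, 3),
                   first_index_trainable: int = -5, weights="imagenet"):
    """Hypothetical equivalent of funcs.architecture(architecture_name='EfficientNetB3', ...)."""
    backbone = tf.keras.applications.EfficientNetB3(
        include_top=False, weights=weights, input_shape=input_shape)
    # Freeze all backbone layers before `first_index_trainable` (assumed meaning).
    for layer in backbone.layers[:first_index_trainable]:
        layer.trainable = False
    # Pool the final feature map and add a softmax classification head.
    x = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(backbone.input, outputs)
```

Keras EfficientNet models include their own input rescaling and expect pixel values in the [0, 255] range, so images normalized to [0, 1] elsewhere in a pipeline would need to be scaled back before being fed to this backbone.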

In [29]:
start = time.time()

history = model.fit(train_ds,  validation_data=test_ds,  callbacks=artin_stuff.train_callbacks(),  epochs=200)
print(f"Total training time: {(time.time() - start)/60.} minutes")

utils.plot_history(history)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200

KeyboardInterrupt: ignored