In [1]:
import os
import pathlib

import tensorflow as tf
import tensorflow_io as tfio

from tensorflow import keras
import keras_cv
from tensorflow.keras import layers
from keras_cv import utils
from keras_cv.layers import BaseImageAugmentationLayer

import tensorflow_addons as tfa
from keras_flops import get_flops

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

import math
!set XLA_FLAGS=--xla_gpu_cuda_data_dir="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.8"
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



Num GPUs Available:  1


In [2]:
# DATA
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
#INPUT_SHAPE = (32, 32, 3)
#INPUT_SHAPE = (124, 129, 1)
INPUT_SHAPE = (624, 129, 1)
#NUM_CLASSES = 8
NUM_CLASSES = 50
#NUM_CLASSES = 5

# OPTIMIZER
LEARNING_RATE = 1e-6
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 20

# AUGMENTATION
#IMAGE_SIZE = 32  # We will resize input images to this size.
IMAGE_SIZE = 64
PATCH_SIZE = 8  # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 8
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]

# TOKENLEARNER
NUM_TOKENS = 8

# Data Processing

In [3]:
DATASET_PATH = 'datasets/ECS-50-master'

data_dir = pathlib.Path(DATASET_PATH)

In [4]:
DATASET_PATH = 'datasets/ESC-50-master/audio'
data_dir = pathlib.Path(DATASET_PATH)

In [5]:
import pandas as pd
esc50_csv = './datasets/ESC-50-master/meta/esc50.csv'
base_data_path = './datasets/ESC-50-master/audio/'

pd_data = pd.read_csv(esc50_csv)
pd_data.set_index('category', inplace=True)
pd_data.head()

Unnamed: 0_level_0,filename,fold,target,esc10,src_file,take
category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
dog,1-100032-A-0.wav,1,0,True,100032,A
chirping_birds,1-100038-A-14.wav,1,14,False,100038,A
vacuum_cleaner,1-100210-A-36.wav,1,36,False,100210,A
vacuum_cleaner,1-100210-B-36.wav,1,36,False,100210,B
thunderstorm,1-101296-A-19.wav,1,19,False,101296,A


In [6]:
print(pd_data)
pd_short = pd_data.rename(index={'dog' : 'animal', 'chirping_birds' : 'natural', 'vacuum_cleaner' : 'interior', 'thunderstorm' : 'natural',
       'door_wood_knock' : 'interior', 'can_opening' : 'interior', 'crow' : 'animal', 'clapping' : 'human', 'fireworks' : 'exterior',
       'chainsaw' : 'exterior', 'airplane' : 'exterior', 'mouse_click' : 'interior', 'pouring_water' : 'natural', 'train' : 'exterior',
       'sheep' : 'animal', 'water_drops' : 'natural', 'church_bells' : 'exterior', 'clock_alarm' : 'interior',
       'keyboard_typing' : 'interior', 'wind' : 'natural', 'footsteps' : 'human', 'frog' : 'animal', 'cow' : 'animal', 'brushing_teeth' : 'human',
       'car_horn' : 'exterior', 'crackling_fire' : 'natural', 'helicopter' : 'exterior', 'drinking_sipping' : 'human', 'rain' : 'natural',
       'insects' : 'animal', 'laughing' : 'human', 'hen' : 'animal', 'engine' : 'exterior', 'breathing' : 'human', 'crying_baby' : 'human',
       'hand_saw' : 'exterior', 'coughing' : 'human', 'glass_breaking' : 'interior', 'snoring' : 'human', 'toilet_flush' : 'natural',
       'pig' : 'animal', 'washing_machine' : 'interior', 'clock_tick' : 'interior', 'sneezing' : 'human', 'rooster' : 'animal',
       'sea_waves' : 'natural', 'siren' : 'exterior', 'cat' : 'animal', 'door_wood_creaks' : 'interior', 'crickets' : 'natural'})
print(pd_short)
labels = pd_data.target.values.tolist()
print(labels)
#pd_data=pd_short
categories = pd_data.index.unique()
print(categories)

                         filename  fold  target  esc10  src_file take
category                                                             
dog              1-100032-A-0.wav     1       0   True    100032    A
chirping_birds  1-100038-A-14.wav     1      14  False    100038    A
vacuum_cleaner  1-100210-A-36.wav     1      36  False    100210    A
vacuum_cleaner  1-100210-B-36.wav     1      36  False    100210    B
thunderstorm    1-101296-A-19.wav     1      19  False    101296    A
...                           ...   ...     ...    ...       ...  ...
hen              5-263831-B-6.wav     5       6  False    263831    B
vacuum_cleaner  5-263902-A-36.wav     5      36  False    263902    A
footsteps        5-51149-A-25.wav     5      25  False     51149    A
sheep             5-61635-A-8.wav     5       8  False     61635    A
dog                5-9032-A-0.wav     5       0   True      9032    A

[2000 rows x 6 columns]
                   filename  fold  target  esc10  src_file take
c

In [7]:
labels = []
files_ = []
for root, dirs, files in os.walk(data_dir):
   files_ = files
print(files_)
for file in files_:
  labels.append(pd_data.loc[pd_data.filename == file, 'target'].values[0])

['1-100032-A-0.wav', '1-100038-A-14.wav', '1-100210-A-36.wav', '1-100210-B-36.wav', '1-101296-A-19.wav', '1-101296-B-19.wav', '1-101336-A-30.wav', '1-101404-A-34.wav', '1-103298-A-9.wav', '1-103995-A-30.wav', '1-103999-A-30.wav', '1-104089-A-22.wav', '1-104089-B-22.wav', '1-105224-A-22.wav', '1-110389-A-0.wav', '1-110537-A-22.wav', '1-115521-A-19.wav', '1-115545-A-48.wav', '1-115545-B-48.wav', '1-115545-C-48.wav', '1-115546-A-48.wav', '1-115920-A-22.wav', '1-115920-B-22.wav', '1-115921-A-22.wav', '1-116765-A-41.wav', '1-11687-A-47.wav', '1-118206-A-31.wav', '1-118559-A-17.wav', '1-119125-A-45.wav', '1-121951-A-8.wav', '1-12653-A-15.wav', '1-12654-A-15.wav', '1-12654-B-15.wav', '1-13571-A-46.wav', '1-13572-A-46.wav', '1-13613-A-37.wav', '1-137-A-32.wav', '1-137296-A-16.wav', '1-14262-A-37.wav', '1-155858-A-25.wav', '1-155858-B-25.wav', '1-155858-C-25.wav', '1-155858-D-25.wav', '1-155858-E-25.wav', '1-155858-F-25.wav', '1-15689-A-4.wav', '1-15689-B-4.wav', '1-160563-A-48.wav', '1-160563-

In [8]:
print(labels)

[0, 14, 36, 36, 19, 19, 30, 34, 9, 30, 30, 22, 22, 22, 0, 22, 19, 48, 48, 48, 48, 22, 22, 22, 41, 47, 31, 17, 45, 8, 15, 15, 15, 46, 46, 37, 32, 16, 37, 25, 25, 25, 25, 25, 25, 4, 4, 48, 48, 3, 15, 27, 27, 43, 12, 40, 40, 40, 40, 40, 40, 29, 10, 12, 7, 12, 12, 12, 26, 4, 6, 6, 40, 40, 44, 44, 23, 31, 20, 4, 4, 4, 49, 43, 24, 24, 7, 8, 8, 36, 36, 36, 41, 41, 41, 39, 3, 28, 18, 2, 2, 2, 2, 2, 20, 20, 20, 10, 46, 35, 38, 38, 25, 20, 20, 46, 44, 15, 15, 19, 19, 49, 35, 35, 43, 43, 19, 19, 19, 47, 43, 48, 48, 2, 2, 2, 21, 43, 43, 30, 10, 1, 35, 35, 28, 28, 1, 18, 11, 11, 43, 16, 10, 21, 26, 26, 18, 0, 0, 23, 23, 23, 24, 6, 6, 42, 42, 21, 4, 4, 0, 35, 35, 29, 26, 5, 5, 1, 1, 14, 14, 37, 38, 26, 26, 23, 23, 23, 23, 47, 29, 14, 14, 9, 9, 11, 11, 1, 28, 46, 28, 1, 28, 34, 12, 38, 1, 11, 34, 47, 47, 47, 47, 1, 27, 31, 14, 12, 18, 49, 36, 7, 7, 41, 41, 21, 21, 16, 16, 5, 5, 5, 28, 46, 38, 35, 8, 8, 10, 44, 44, 15, 17, 44, 17, 16, 16, 16, 25, 18, 17, 17, 33, 33, 33, 33, 33, 33, 33, 33, 24, 30, 24,

In [9]:
import shutil
sorted_dir = './datasets/ESC-50-master/sorted/'
if os.path.exists(pathlib.Path(sorted_dir)) is True:
  shutil.rmtree(pathlib.Path(sorted_dir))
if os.path.exists(pathlib.Path(sorted_dir)) is False:
  os.mkdir(sorted_dir)

for category in categories:
  curr_dir = sorted_dir + category
  if os.path.exists(pathlib.Path(curr_dir)) is False:
    os.mkdir(curr_dir)
  list_files = pd_data.loc[category].filename.values.tolist()
  for cat_file in list_files:
    shutil.copy(base_data_path + cat_file, curr_dir+'/'+cat_file)

In [10]:
train_ds, val_ds = tf.keras.utils.audio_dataset_from_directory(
    directory=pathlib.Path(sorted_dir),
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    seed=0,
    output_sequence_length=80000,
    subset='both')

label_names = np.array(train_ds.class_names)
print()
print("label names:", label_names)

Found 2000 files belonging to 50 classes.
Using 1600 files for training.
Using 400 files for validation.

label names: ['airplane' 'breathing' 'brushing_teeth' 'can_opening' 'car_horn' 'cat'
 'chainsaw' 'chirping_birds' 'church_bells' 'clapping' 'clock_alarm'
 'clock_tick' 'coughing' 'cow' 'crackling_fire' 'crickets' 'crow'
 'crying_baby' 'dog' 'door_wood_creaks' 'door_wood_knock'
 'drinking_sipping' 'engine' 'fireworks' 'footsteps' 'frog'
 'glass_breaking' 'hand_saw' 'helicopter' 'hen' 'insects'
 'keyboard_typing' 'laughing' 'mouse_click' 'pig' 'pouring_water' 'rain'
 'rooster' 'sea_waves' 'sheep' 'siren' 'sneezing' 'snoring' 'thunderstorm'
 'toilet_flush' 'train' 'vacuum_cleaner' 'washing_machine' 'water_drops'
 'wind']


In [11]:
def squeeze(audio, labels):
  audio = tf.squeeze(audio, axis=-1)
  return audio, labels

train_ds = train_ds.map(squeeze, tf.data.AUTOTUNE)
val_ds = val_ds.map(squeeze, tf.data.AUTOTUNE)

In [12]:
test_ds = val_ds.shard(num_shards=2, index=0)
val_ds = val_ds.shard(num_shards=2, index=1)

In [13]:
for example_audio, example_labels in train_ds.take(1):  
  print(example_audio.shape)
  print(example_labels.shape)
print(train_ds.element_spec)

(32, 80000)
(32,)
(TensorSpec(shape=(None, 80000), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))


In [14]:
def get_spectrogram(waveform):
  # Convert the waveform to a spectrogram via a STFT.
  spectrogram = tf.signal.stft(
      waveform, frame_length=255, frame_step=128)
  # Obtain the magnitude of the STFT.
  spectrogram = tf.abs(spectrogram)
  # Add a `channels` dimension, so that the spectrogram can be used
  # as image-like input data with convolution layers (which expect
  # shape (`batch_size`, `height`, `width`, `channels`).
  spectrogram = spectrogram[..., tf.newaxis]
  return spectrogram

def make_spec_ds(ds):
  return ds.map(
      map_func=lambda audio,label: (get_spectrogram(audio), label),
      num_parallel_calls=tf.data.AUTOTUNE)

def get_spectrogram_ts(waveform):
  # Convert the waveform to a spectrogram via a STFT.
  spectrogram = tf.signal.stft(
      waveform, frame_length=255, frame_step=128)
  spectrogram = tfio.audio.freq_mask(spectrogram, param=10)
  spectrogram = tfio.audio.time_mask(spectrogram, param=10)
  # Obtain the magnitude of the STFT.
  spectrogram = tf.abs(spectrogram)
  # Add a `channels` dimension, so that the spectrogram can be used
  # as image-like input data with convolution layers (which expect
  # shape (`batch_size`, `height`, `width`, `channels`).
  spectrogram = spectrogram[..., tf.newaxis]
  return spectrogram

def make_spec_ts_ds(ds):
  return ds.map(
      map_func=lambda audio,label: (get_spectrogram_ts(audio), label),
      num_parallel_calls=tf.data.AUTOTUNE)

In [15]:
from IPython import display

for i in range(3):
  label = label_names[example_labels[i]]
  waveform = example_audio[i]
  spectrogram = get_spectrogram(waveform)

  print('Label:', label)
  print('Waveform shape:', waveform.shape)
  print('Spectrogram shape:', spectrogram.shape)
  print('Audio playback')
  display.display(display.Audio(waveform, rate=16000))

Label: siren
Waveform shape: (80000,)
Spectrogram shape: (624, 129, 1)
Audio playback


Label: crackling_fire
Waveform shape: (80000,)
Spectrogram shape: (624, 129, 1)
Audio playback


Label: breathing
Waveform shape: (80000,)
Spectrogram shape: (624, 129, 1)
Audio playback


In [16]:
train_spectrogram_ds = make_spec_ds(train_ds)
val_spectrogram_ds = make_spec_ds(val_ds)
test_spectrogram_ds = make_spec_ds(test_ds)

In [17]:
train_spectrogram_ds = train_spectrogram_ds.cache().shuffle(10000).prefetch(tf.data.AUTOTUNE)
val_spectrogram_ds = val_spectrogram_ds.cache().prefetch(tf.data.AUTOTUNE)
test_spectrogram_ds = test_spectrogram_ds.cache().prefetch(tf.data.AUTOTUNE)

In [18]:
class TimeMask(keras_cv.layers.BaseImageAugmentationLayer):
    def augment_image(self, image, transformation=None):
        return tfio.audio.time_mask(image, param=10)

In [19]:
norm_layer = layers.Normalization()
norm_layer.adapt(data=train_spectrogram_ds.map(map_func=lambda spec, label: spec))



data_augmentation_wresize = keras.Sequential(
    [
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        norm_layer,
    ],
    name="data_augmentation",
)

data_augmentation_nresize = keras.Sequential(
    [
        norm_layer,
    ],
    name="data_augmentation",
)



# Token Learner

In [20]:
def position_embedding(
    projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
):
    # Build the positions.
    positions = tf.range(start=0, limit=num_patches, delta=1)

    # Encode the positions with an Embedding layer.
    encoded_positions = layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim
    )(positions)

    # Add encoded positions to the projected patches.
    return projected_patches + encoded_positions

In [21]:
def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [22]:
def token_learner(inputs, number_of_tokens=NUM_TOKENS):
    # Layer normalize the inputs.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)  # (B, H, W, C)

    # Applying Conv2D => Reshape => Permute
    # The reshape and permute is done to help with the next steps of
    # multiplication and Global Average Pooling.
    attention_maps = keras.Sequential(
        [
            # 3 layers of conv with gelu activation as suggested
            # in the paper.
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            # This conv layer will generate the attention maps
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation="sigmoid",  # Note sigmoid for [0, 1] output
                padding="same",
                use_bias=False,
            ),
            # Reshape and Permute
            layers.Reshape((-1, number_of_tokens)),  # (B, H*W, num_of_tokens)
            layers.Permute((2, 1)),
        ]
    )(
        x
    )  # (B, num_of_tokens, H*W)

    # Reshape the input to align it with the output of the conv block.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, H*W, C)

    # Element-Wise multiplication of the attention maps and the inputs
    attended_inputs = (
        attention_maps[..., tf.newaxis] * inputs
    )  # (B, num_tokens, H*W, C)

    # Global average pooling the element wise multiplication result.
    outputs = tf.reduce_mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs

In [23]:
def transformer(encoded_patches):
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
    )(x1, x1)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

    # Skip connection 2.
    encoded_patches = layers.Add()([x4, x2])
    return encoded_patches

In [24]:
def bottleneck_block(x, expand=64, squeeze=16, dropout_rate = 1.0):

  m = keras.Sequential([
      layers.Conv2D(expand, (1,1), padding='same'),
      layers.BatchNormalization(),
      layers.Activation('relu6'),
      layers.DepthwiseConv2D((3,3), padding='same'),
      layers.BatchNormalization(),
      layers.Activation('relu6'),
      layers.Conv2D(squeeze, (1,1), padding='same'),
      layers.BatchNormalization()
  ])(x)

  m = layers.Dropout(dropout_rate)(m)
    
  return layers.Add()([m, x])

In [25]:
def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
    inputs = layers.Input(shape=INPUT_SHAPE)  # (B, H, W, C)
    
    # Augment data.
    if use_token_learner:
        augmented = data_augmentation_wresize(inputs)
    else:
        augmented = data_augmentation_wresize(inputs)

    #mbconv = bottleneck_block(augmented, dropout_rate = 0.1)

    # Create patches and project the pathces.
    projected_patches = layers.Conv2D(
        filters=PROJECTION_DIM,
        kernel_size=(PATCH_SIZE, PATCH_SIZE),
        strides=(PATCH_SIZE, PATCH_SIZE),
        padding="VALID",
    )(augmented)
    
    if use_token_learner:
            projected_patches = token_learner(
                projected_patches, token_learner_units
            )
            # _, hh, c = projected_patches.shape
            # h = int(math.sqrt(hh))
            # projected_patches = layers.Reshape((h, h, c))(
            #     projected_patches)
            # _, h, w, c = projected_patches.shape
            # print(h,w,c)
            # projected_patches = layers.Reshape((h * w, c))(
            #     projected_patches
            # )  # (B, number_patches, projection_dim)

            # Add positional embeddings to the projected patches.
            encoded_patches = position_embedding(
                projected_patches, num_patches = NUM_TOKENS
            )  # (B, number_patches, projection_dim)
            encoded_patches = layers.Dropout(0.1)(encoded_patches)
    else:
        _, h, w, c = projected_patches.shape
        print(h,w,c)
        projected_patches = layers.Reshape((h * w, c))(
            projected_patches
        )  # (B, number_patches, projection_dim)

        # Add positional embeddings to the projected patches.
        encoded_patches = position_embedding(
            projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = layers.Dropout(0.1)(encoded_patches)

    # Iterate over the number of layers and stack up blocks of
    # Transformer.
    for i in range(NUM_LAYERS):
        # Add a Transformer block.
        encoded_patches = transformer(encoded_patches)

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

In [26]:
def run_experiment(model):
    # Initialize the AdamW optimizer.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Compile the model with the optimizer, loss function
    # and the metrics.
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Define callbacks
    checkpoint_filepath = "./tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # Custom Scheduler Function
    lr_start   = 1e-6
    lr_max     = 0.000015 * BATCH_SIZE
    lr_min     = 1e-7
    lr_ramp_ep = 4
    lr_sus_ep  = 0
    lr_decay   = 0.7
      
    def lrfn(epoch):
        if epoch < lr_ramp_ep:
            lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
                
        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max
                
        else:
            lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
                
        return lr


    # Using this Custom Function, create a Callback
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)


    # Train the model.
    _ = model.fit(
        train_spectrogram_ds,
        epochs=EPOCHS,
        validation_data=val_spectrogram_ds,
        callbacks=[lr_callback, checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_spectrogram_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")


In [27]:
vit_token_learner = create_vit_classifier()
print(get_flops(vit_token_learner, batch_size=1))
run_experiment(vit_token_learner)

Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
11526533

Epoch 1: LearningRateScheduler setting learning rate to 1e-06.
Epoch 1/20

Epoch 2: LearningRateScheduler setting learning rate to 0.00012075.
Epoch 2/20

Epoch 3: LearningRateScheduler setting learning rate to 0.0002405.
Epoch 3/20

Epoch 4: LearningRateScheduler setting learning rate to 0.00036025.
Epoch 4/20

Epoch 5: LearningRateScheduler setting learning rate to 0.00048.
Epoch 5/20

Epoch 6: LearningRateScheduler setting learning rate to 0.00033602999999999997.
Epoch 6/20

Epoch 7: LearningRateScheduler setting learning rate to 0.000235251.
Epoch 7/20

Epoch 8: LearningRateScheduler setting learning rate to 0.00016470569999999996.
Epoch 8/20

Epoch 9: LearningRateScheduler setting learning rate to 0.00011532398999999998.
Epoch 9/20

Epoch 10: LearningRateScheduler setting learning rate to 8.075679299999997e-05.
Epoch 10/20

Epoch 11: LearningRateScheduler setting learning rate to 5.6

In [28]:
vit_token_learner = create_vit_classifier(use_token_learner=False)
print(get_flops(vit_token_learner, batch_size=1))
run_experiment(vit_token_learner)

8 8 128
72343533

Epoch 1: LearningRateScheduler setting learning rate to 1e-06.
Epoch 1/20

Epoch 2: LearningRateScheduler setting learning rate to 0.00012075.
Epoch 2/20

Epoch 3: LearningRateScheduler setting learning rate to 0.0002405.
Epoch 3/20

Epoch 4: LearningRateScheduler setting learning rate to 0.00036025.
Epoch 4/20

Epoch 5: LearningRateScheduler setting learning rate to 0.00048.
Epoch 5/20

Epoch 6: LearningRateScheduler setting learning rate to 0.00033602999999999997.
Epoch 6/20

Epoch 7: LearningRateScheduler setting learning rate to 0.000235251.
Epoch 7/20

Epoch 8: LearningRateScheduler setting learning rate to 0.00016470569999999996.
Epoch 8/20

Epoch 9: LearningRateScheduler setting learning rate to 0.00011532398999999998.
Epoch 9/20

Epoch 10: LearningRateScheduler setting learning rate to 8.075679299999997e-05.
Epoch 10/20

Epoch 11: LearningRateScheduler setting learning rate to 5.655975509999999e-05.
Epoch 11/20

Epoch 12: LearningRateScheduler setting learning