In [None]:
import os
import shutil

base_dir = "/content/tiny-imagenet-200"
val_dir = os.path.join(base_dir, "val")
images_dir = os.path.join(val_dir, "images")
ann_file = os.path.join(val_dir, "val_annotations.txt")

# Read annotations. Each row is tab-separated: image file name, class id, then
# further fields that are ignored here. Blank lines are skipped so a trailing
# newline cannot break the tuple unpacking below.
with open(ann_file) as f:
    annotations = [line.strip().split('\t') for line in f if line.strip()]

# Move every validation image from val/images/ into val/<class id>/ so the
# directory can be read with labels inferred from folder names.
# The guard makes the cell safe to re-run: once images/ has been removed there
# is nothing left to do, and images already moved by an interrupted run are
# skipped instead of raising.
if os.path.isdir(images_dir):
    for img, cls, *_ in annotations:
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        dst = os.path.join(cls_dir, img)
        if os.path.exists(dst):
            continue
        # A genuinely missing source file still raises here.
        shutil.move(os.path.join(images_dir, img), dst)

    # Fails loudly (OSError) if unexpected files are still left behind.
    os.rmdir(images_dir)

In [None]:
!pip install -q tf-models-official

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m71.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m87.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m620.7/620.7 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m111.2 MB/s[0m eta [36m0:

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as tfla
import tensorflow.keras.models as tfm
import tensorflow.keras.optimizers as tfo
import tensorflow.keras.losses as tflo
import matplotlib.pyplot as plt
from official.vision.ops import augment
import numpy as np

In [None]:
# wnids.txt lists one class id per line; passing it as class_names fixes the
# label index order identically for the train and validation datasets.
with open("tiny-imagenet-200/wnids.txt") as wnid_file:
    wnids = list(map(str.strip, wnid_file))

# Keyword arguments shared by both loaders (one-hot labels from folder names).
_loader_kwargs = dict(
    labels="inferred",
    label_mode="categorical",
    class_names=wnids,
)

# Training images are loaded unbatched at 256x256; later cells crop, augment
# and batch them.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/train",
    image_size=(256, 256),
    batch_size=None,
    **_loader_kwargs,
)

# Validation images are resized straight to the 224x224 model input and
# batched here.
test_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/val",
    image_size=(224, 224),
    batch_size=128,
    **_loader_kwargs,
)

Found 100000 files belonging to 200 classes.
Found 10000 files belonging to 200 classes.


In [None]:
def crop_image(image, label):
  """Take a random 224x224x3 crop of one image and flip it left-right at random.

  Args:
    image: single image tensor, shape [h, w, c].
    label: label tensor, shape [num_class,]; returned unchanged.

  Returns:
    (augmented image, label).
  """
  cropped = tf.image.random_crop(image, (224, 224, 3))
  return tf.image.random_flip_left_right(cropped), label

In [None]:
# Apply the random crop + horizontal flip to every (still unbatched) training
# image. NOTE: train_ds is rebound in place, so re-running this cell alone
# stacks another map on top; re-run from the loading cell instead.
train_ds = train_ds.map(crop_image, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def apply_randaugment(image, label):
  """Distort one image with RandAugment (num_layers=2, magnitude=9).

  Args:
    image: single image tensor, shape [h, w, c].
    label: label tensor, shape [num_class,]; returned unchanged.

  Returns:
    (distorted image, label).
  """
  policy = augment.RandAugment(num_layers=2, magnitude=9)
  distorted = policy.distort(image)
  return distorted, label

In [None]:
# Apply RandAugment per image, after cropping and before normalisation.
# NOTE: rebinding train_ds makes this cell non-idempotent when re-run alone.
train_ds = train_ds.map(apply_randaugment, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def normalise_image(image, label):
  """Scale pixel values by 1/255, then standardise each channel.

  The per-channel mean/std constants are presumably the usual ImageNet RGB
  statistics (TODO confirm). Because they broadcast over the last axis, the
  function works for a single [h, w, c] image as well as a batched tensor.

  Args:
    image: image tensor whose last axis holds the 3 colour channels.
    label: label tensor; returned unchanged.

  Returns:
    (normalised image, label).
  """
  channel_mean = tf.constant([0.485, 0.456, 0.406])
  channel_std = tf.constant([0.229, 0.224, 0.225])
  normalised = (image / 255.0 - channel_mean) / channel_std
  return normalised, label

In [None]:
# Normalise both pipelines identically: train_ds is still per-image here,
# test_ds is already batched (normalise_image broadcasts over the channel axis).
# NOTE: both names are rebound in place, so re-running this cell alone would
# normalise the data a second time; re-run from the loading cell instead.
train_ds = train_ds.map(normalise_image, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(normalise_image, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Batch the per-image pipeline. The plain (un-mixed) batches become the first
# segment of combined_ds; later cells append further augmented passes to it.
train_ds = train_ds.batch(128)
combined_ds = train_ds

In [None]:
def mixup(images, labels):
  """Blend every sample of a batch with a randomly chosen partner (MixUp).

  A separate blend weight is drawn per sample as g1 / (g1 + g2) with
  g1, g2 ~ Gamma(0.2), i.e. a Beta(0.2, 0.2) variate. Images and labels are
  mixed with the same weight.

  Args:
    images: batch of images, shape [batchsize, h, w, c].
    labels: batch of labels, shape [batchsize, num_class].

  Returns:
    (mixed images, mixed labels) with the same shapes as the inputs.
  """
  n = tf.shape(images)[0]

  # Per-sample weight, shape [n, 1].
  g_first = tf.random.gamma(shape=[n, 1], alpha=0.2)
  g_second = tf.random.gamma(shape=[n, 1], alpha=0.2)
  weight = g_first / (g_first + g_second)

  # Partner for each sample: a random permutation of the same batch.
  perm = tf.random.shuffle(tf.range(n))
  partner_images = tf.gather(images, perm)
  partner_labels = tf.gather(labels, perm)

  # Reshape so the weight broadcasts over h, w and c.
  image_weight = tf.reshape(weight, [-1, 1, 1, 1])

  mixed_images = image_weight * images + (1 - image_weight) * partner_images
  mixed_labels = weight * labels + (1 - weight) * partner_labels
  return mixed_images, mixed_labels

In [None]:
# Second pass over the batched training data with MixUp applied, appended
# after the plain batches already in combined_ds.
mixup_ds = train_ds.map(mixup, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(mixup_ds)

In [None]:
def cutmix(images, labels):
  """Paste a random rectangle from a partner image into each image (CutMix).

  A target area fraction `lam` is drawn per sample as g1 / (g1 + g2) with
  g1, g2 ~ Gamma(0.2). The rectangle sides are the image sides scaled by
  sqrt(lam), and its centre is drawn so the box normally stays inside the image.

  The labels are mixed with the fraction of pixels that were actually
  replaced, not with the sampled `lam`: the box is built from truncated integer
  half-sizes with inclusive bounds, so its real area differs from `lam`.

  Args:
    images: batch of images, shape [batchsize, h, w, c].
    labels: batch of labels, shape [batchsize, num_class].

  Returns:
    (mixed images, mixed labels) with the same shapes as the inputs.
  """
  batch_size = tf.shape(images)[0]
  img_height = tf.shape(images)[1]
  img_width = tf.shape(images)[2]

  gamma_1 = tf.random.gamma(shape=[batch_size, 1], alpha=0.2)
  gamma_2 = tf.random.gamma(shape=[batch_size, 1], alpha=0.2)
  # lam is the target cut fraction with shape: [batch_size, 1]
  lam = gamma_1 / (gamma_1 + gamma_2)

  # Half extents of the box, shape [batch_size, 1].
  side_scale = tf.sqrt(lam)
  half_height = tf.cast(tf.cast(img_height, tf.float32) * side_scale, tf.int32) // 2
  half_width = tf.cast(tf.cast(img_width, tf.float32) * side_scale, tf.int32) // 2

  # Box centre along the height axis, in [half, extent - 1 - half].
  min_row = half_height
  max_row = img_height - 1 - half_height
  centre_row = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                 dtype=tf.float32)
  centre_row = tf.cast(
      centre_row * tf.cast(max_row - min_row, tf.float32) + tf.cast(min_row, tf.float32),
      tf.int32)

  # Box centre along the width axis.
  min_col = half_width
  max_col = img_width - 1 - half_width
  centre_col = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                 dtype=tf.float32)
  centre_col = tf.cast(
      centre_col * tf.cast(max_col - min_col, tf.float32) + tf.cast(min_col, tf.float32),
      tf.int32)

  # Inclusive box bounds, reshaped to [batch_size, 1, 1] for broadcasting.
  top = tf.reshape(centre_row - half_height, [batch_size, 1, 1])
  bottom = tf.reshape(centre_row + half_height, [batch_size, 1, 1])
  left = tf.reshape(centre_col - half_width, [batch_size, 1, 1])
  right = tf.reshape(centre_col + half_width, [batch_size, 1, 1])

  # Pixel coordinates: rows as [1, h, 1], columns as [1, 1, w].
  rows = tf.reshape(tf.range(img_height), [1, img_height, 1])
  cols = tf.reshape(tf.range(img_width), [1, 1, img_width])

  # inside has shape [batch_size, h, w]; True where the partner image is pasted.
  inside = tf.logical_and(
      tf.logical_and(top <= rows, rows <= bottom),
      tf.logical_and(left <= cols, cols <= right))
  # mask shape: [batch_size, h, w, 1]
  mask = tf.cast(inside, tf.float32)[..., tf.newaxis]

  indices = tf.random.shuffle(tf.range(batch_size))
  shuffled_images = tf.gather(images, indices)
  shuffled_labels = tf.gather(labels, indices)

  images = images * (1.0 - mask) + shuffled_images * mask

  # Fraction of each image that really came from the partner, shape [batch_size, 1].
  pasted_fraction = tf.reduce_mean(mask, axis=[1, 2])
  labels = labels * (1.0 - pasted_fraction) + shuffled_labels * pasted_fraction

  return images, labels

In [None]:
# Another pass over the batched training data with CutMix applied, appended
# to the end of combined_ds.
cutmix_ds = train_ds.map(cutmix, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(cutmix_ds)

In [None]:
def erase(images, labels):
  """Zero out one random rectangle in every image of a batch (random erasing).

  The target area fraction is drawn uniformly from [0.2, 0.5) per sample. The
  rectangle sides are the image sides scaled by the square root of that
  fraction. Labels are returned unchanged.

  Args:
    images: batch of images, shape [batch_size, h, w, c].
    labels: batch of labels, shape [batch_size, num_class].

  Returns:
    (images with the rectangle set to 0, labels).
  """
  n = tf.shape(images)[0]
  height = tf.shape(images)[1]
  width = tf.shape(images)[2]

  # Target erased fraction per sample, shape [n, 1].
  fraction = tf.random.uniform(shape=(n, 1), minval=0.2, maxval=0.5, dtype=tf.float32)
  side_scale = tf.sqrt(fraction)

  # Half extents of the rectangle, shape [n, 1].
  half_h = tf.cast(tf.cast(height, tf.float32) * side_scale, tf.int32) // 2
  half_w = tf.cast(tf.cast(width, tf.float32) * side_scale, tf.int32) // 2

  # Centre along the height axis, drawn between half_h and height - 1 - half_h.
  row_low = half_h
  row_high = height - 1 - half_h
  unit_row = tf.random.uniform(shape=(n, 1), minval=0, maxval=1, dtype=tf.float32)
  centre_row = tf.cast(
      unit_row * tf.cast(row_high - row_low, tf.float32) + tf.cast(row_low, tf.float32),
      tf.int32)

  # Centre along the width axis.
  col_low = half_w
  col_high = width - 1 - half_w
  unit_col = tf.random.uniform(shape=(n, 1), minval=0, maxval=1, dtype=tf.float32)
  centre_col = tf.cast(
      unit_col * tf.cast(col_high - col_low, tf.float32) + tf.cast(col_low, tf.float32),
      tf.int32)

  # Inclusive rectangle bounds as [n, 1, 1] so they broadcast over the image.
  top = tf.reshape(centre_row - half_h, [n, 1, 1])
  bottom = tf.reshape(centre_row + half_h, [n, 1, 1])
  left = tf.reshape(centre_col - half_w, [n, 1, 1])
  right = tf.reshape(centre_col + half_w, [n, 1, 1])

  # Pixel coordinates: rows as [1, h, 1], columns as [1, 1, w].
  rows = tf.reshape(tf.range(height), [1, height, 1])
  cols = tf.reshape(tf.range(width), [1, 1, width])

  # erased has shape [n, h, w]; True inside the rectangle.
  erased = tf.logical_and(
      tf.logical_and(top <= rows, rows <= bottom),
      tf.logical_and(left <= cols, cols <= right))

  # Multiplicative keep-mask: 0 inside the rectangle, 1 elsewhere; [n, h, w, 1].
  keep = tf.cast(tf.logical_not(erased), tf.float32)[..., tf.newaxis]

  return images * keep, labels

In [None]:
# Another pass over the batched training data with random erasing applied,
# appended to the end of combined_ds.
erase_ds = train_ds.map(erase, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(erase_ds)

In [None]:
def label_smoothing(labels, epsilon=0.1):
  """Shrink labels towards the uniform distribution.

  Args:
    labels: batch of label vectors, shape [batch_size, num_class].
    epsilon: total probability mass spread evenly over all classes.

  Returns:
    labels * (1 - epsilon) + epsilon / num_class, same shape as `labels`.
  """
  class_count = tf.cast(tf.shape(labels)[1], tf.float32)
  uniform_share = epsilon / class_count
  return labels * (1.0 - epsilon) + uniform_share

In [None]:
# Smooth the labels of every batch in combined_ds (images pass through
# untouched). This runs on top of any label mixing done by earlier cells.
combined_ds = combined_ds.map(lambda images, labels: (images, label_smoothing(labels)), num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Shuffle at batch level, then prefetch so input preprocessing overlaps with
# the training step instead of stalling it.
# NOTE(review): combined_ds is a concatenation of several full passes over the
# training batches, and a 100-batch buffer only mixes neighbouring batches, so
# each epoch still visits the augmentation variants largely one after another.
combined_ds = combined_ds.shuffle(buffer_size=100).prefetch(tf.data.AUTOTUNE)

In [None]:
def window_partition(x, window_size):
  """Cut a feature map into non-overlapping square windows.

  Args:
    x: tensor of shape [B, H, W, C]; H and W are expected to be multiples of
      window_size.
    window_size: side length of each window.

  Returns:
    Tensor of shape [B_, window_size, window_size, C], where B_ is the total
    number of windows in the batch.
  """
  dims = tf.shape(x)
  batch, height, width, channels = dims[0], dims[1], dims[2], dims[3]

  # [B, row_num, row_in_window, col_num, col_in_window, C]
  tiles = tf.reshape(
      x,
      shape=[batch, height // window_size, window_size,
             width // window_size, window_size, channels])

  # Bring the two window-index axes together:
  # [B, row_num, col_num, row_in_window, col_in_window, C]
  tiles = tf.transpose(tiles, perm=[0, 1, 3, 2, 4, 5])

  # Fold batch and window indices into a single leading axis.
  return tf.reshape(tiles, shape=[-1, window_size, window_size, channels])

In [None]:
def window_reverse(x, window_size, H, W):
  """Reassemble square windows into full feature maps (inverse of window_partition).

  Args:
    x: tensor of shape [B_, window_size, window_size, C], where B_ is the total
      number of windows in the batch.
    window_size: side length of each window.
    H: height of the reassembled feature map.
    W: width of the reassembled feature map.

  Returns:
    Tensor of shape [B, H, W, C].
  """
  channels = tf.shape(x)[3]
  rows, cols = H // window_size, W // window_size

  # [B, row_num, col_num, row_in_window, col_in_window, C]
  grid = tf.reshape(x, shape=[-1, rows, cols, window_size, window_size, channels])

  # Interleave window index and in-window offset per spatial axis:
  # [B, row_num, row_in_window, col_num, col_in_window, C]
  grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])

  return tf.reshape(grid, shape=[-1, H, W, channels])

In [None]:
class window_attention(tfla.Layer):
  """Multi-head self-attention inside each window, with a learned relative position bias.

  Input to `call` is [B_, window_size, window_size, C]; the output is
  [B_, N, C] with N = window_size * window_size.
  """

  def __init__(self, window_size, C, num_heads):
    """Create the projections, the bias table and the relative-position index.

    Args:
      window_size: side length of the square attention window.
      C: channel count of the input; split evenly across heads
        (C // num_heads per head).
      num_heads: number of attention heads.
    """
    super().__init__()
    self.C = C
    self.num_heads = num_heads
    self.window_size = window_size
    self.head_dims = C // num_heads
    # Scaling factor applied to the q·k logits: head_dims ** -0.5.
    self.scale = tf.cast(self.head_dims, tf.float32) ** -0.5

    # Joint projection producing q, k and v; `dense` is the output projection.
    self.qkv = tfla.Dense(3 * C)
    self.dense = tfla.Dense(C)
    # Number of distinct (row offset, column offset) pairs inside a window.
    self.rel_pos_num = (window_size * 2 - 1) * (window_size * 2 - 1)
    # rel_pos_bias shape:[rel_pos_num, num_heads]
    self.rel_pos_bias = self.add_weight(
        shape=[self.rel_pos_num, num_heads],
        initializer=tf.random_normal_initializer(stddev=0.02),
        trainable=True
    )

    self.coord_x = tf.range(window_size)
    self.coord_y = tf.range(window_size)

    # coord shape:[2, window_size, window_size]
    self.coord = tf.stack(tf.meshgrid(self.coord_x, self.coord_y, indexing="ij"))
    # coord shape:[2, N]
    self.coord = tf.reshape(self.coord, shape=[2, -1])
    # Pairwise coordinate differences between all window positions.
    # coord shape:[2, N, N]
    self.coord = self.coord[:,None,:] - self.coord[:, :, None]

    # Shift offsets from [-(window_size-1), window_size-1] to start at 0.
    # rel_pos_h, rel_pos_w shape:[N, N]
    self.rel_pos_h = self.coord[0, :, :] + window_size - 1
    self.rel_pos_w = self.coord[1, :, :] + window_size - 1

    # Flatten the (row offset, column offset) pair into one row index of rel_pos_bias.
    # rel_pos shape:[N, N]
    self.rel_pos = self.rel_pos_h * (2 * window_size - 1) + self.rel_pos_w

  def call(self, x):
    """Run attention over the tokens of each window.

    Args:
      x: tensor of shape [B_, window_size, window_size, C].

    Returns:
      Tensor of shape [B_, N, C].
    """
    # x shape:[B_, window_size, window_size, C]
    B_ = tf.shape(x)[0]

    # Flatten each window into a sequence of N tokens.
    # x shape:[B_, N, C]
    x = tf.reshape(x, shape=[B_, self.window_size * self.window_size, self.C])

    N = self.window_size * self.window_size
    # qkv shape:[B_, N, 3 * C]
    qkv = self.qkv(x)
    # qkv shape:[B_, N, 3, num_heads, head_dims]
    qkv = tf.reshape(qkv, shape=[B_, N, 3, self.num_heads, self.head_dims])
    # qkv shape:[3, B_, num_heads, N, head_dims]
    qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])

    # q, k, v shape:[B_, num_heads, N, head_dims]
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Scaled dot-product logits.
    # attn shape:[B_, num_heads, N, N]
    attn = tf.matmul(q, k, transpose_b=True) * self.scale

    # Look up the learned bias for every (query, key) position pair.
    # rel_emb shape:[N, N, num_heads]
    rel_emb = tf.gather(self.rel_pos_bias, self.rel_pos)
    # rel_emb shape:[num_heads, N, N]
    rel_emb = tf.transpose(rel_emb, perm=[2, 0, 1])
    # rel_emb shape:[1, num_heads, N, N]
    rel_emb = tf.reshape(rel_emb, shape=[1, self.num_heads, N, N])

    # The bias is added to the logits before the softmax over keys.
    attn = tf.nn.softmax(attn + rel_emb, axis=-1)

    # Weighted sum of values.
    # attn shape:[B_, num_heads, N, head_dims]
    attn = attn @ v

    # Merge the heads back into the channel axis: [B_, N, num_heads, head_dims] -> [B_, N, C]
    attn = tf.transpose(attn, perm=[0, 2, 1, 3])
    attn = tf.reshape(attn, shape=[B_, N, self.C])

    attn = self.dense(attn)

    # attn shape:[B_, N, C]
    return attn

In [None]:
class block_attention(tfla.Layer):
  """Pre-norm transformer block whose attention runs inside local windows.

  The feature map is split into contiguous window_size x window_size windows,
  attention is applied within each window, and the result is added back to the
  input. A two-layer MLP (4 * C hidden units, GELU) with its own residual
  connection follows.
  """

  def __init__(self, H, W, window_size, C, num_heads):
    """Store the feature-map geometry and create the sublayers.

    Args:
      H: height of the input feature map.
      W: width of the input feature map.
      window_size: side length of the attention windows.
      C: channel count of the input.
      num_heads: number of attention heads.
    """
    super().__init__()
    self.H, self.W = H, W
    self.window_size = window_size
    self.C = C
    self.num_heads = num_heads

    # Sublayers are created in a fixed order: norms, MLP, attention.
    self.norm1 = tfla.LayerNormalization(epsilon=1e-6)
    self.norm2 = tfla.LayerNormalization(epsilon=1e-6)
    self.fc1 = tfla.Dense(4 * C, activation="gelu")
    self.fc2 = tfla.Dense(C)
    self.window_attention = window_attention(window_size, C, num_heads)

  def call(self, x):
    """Apply window attention and the MLP, each with a residual connection.

    Args:
      x: tensor of shape [B, H, W, C].

    Returns:
      Tensor of shape [B, H, W, C].
    """
    ws = self.window_size

    # Attention branch: norm -> windows [B_, ws, ws, C] -> attention [B_, N, C].
    windows = window_partition(self.norm1(x), ws)
    attended = self.window_attention(windows)
    # Back to window layout, then to the full [B, H, W, C] map.
    attended = tf.reshape(attended, shape=[-1, ws, ws, self.C])
    attended = window_reverse(attended, ws, self.H, self.W)
    x = x + attended

    # MLP branch with residual.
    mlp_out = self.fc2(self.fc1(self.norm2(x)))
    return mlp_out + x

In [None]:
def grid_partition(x, H, W, window_size):
  """Group a feature map into strided (grid) windows.

  Each returned window collects window_size x window_size elements that lie
  H // window_size rows and W // window_size columns apart in the input, so
  one window spans the whole feature map. H and W are expected to be
  multiples of window_size.

  Args:
    x: tensor of shape [B, H, W, C].
    H: height of the feature map.
    W: width of the feature map.
    window_size: side length of each returned window.

  Returns:
    Tensor of shape [B_, window_size, window_size, C].
  """
  channels = tf.shape(x)[3]
  stride_rows = H // window_size
  stride_cols = W // window_size
  groups_rows = H // stride_rows
  groups_cols = W // stride_cols

  # Split each spatial axis into (group index, offset inside the group).
  split = tf.reshape(
      x, shape=[-1, groups_rows, stride_rows, groups_cols, stride_cols, channels])

  # Move the offsets in front so the group indices become the window axes.
  split = tf.transpose(split, perm=[0, 2, 4, 1, 3, 5])

  return tf.reshape(split, shape=[-1, window_size, window_size, channels])

In [None]:
def reverse_grid_partition(x, H, W, window_size):
  """Scatter grid-window tokens back into full feature maps (inverse of grid_partition).

  Args:
    x: tensor of shape [B_, N, C] with N = window_size * window_size.
    H: height of the reassembled feature map.
    W: width of the reassembled feature map.
    window_size: side length of each grid window.

  Returns:
    Tensor of shape [B, H, W, C].
  """
  channels = tf.shape(x)[2]
  stride_rows = H // window_size
  stride_cols = W // window_size

  # Unfold the leading axis into (batch, row offset, column offset) and the
  # token axis into the two window axes.
  grid = tf.reshape(
      x, shape=[-1, stride_rows, stride_cols, window_size, window_size, channels])

  # Pair each window axis with its offset axis again.
  grid = tf.transpose(grid, perm=[0, 3, 1, 4, 2, 5])

  return tf.reshape(grid, shape=[-1, H, W, channels])

In [None]:
class grid_attention(tfla.Layer):
  """Pre-norm transformer block whose attention runs over strided grid windows.

  Same structure as a windowed attention block, but tokens are grouped with
  grid_partition, so each attention window spans the whole feature map. A
  two-layer MLP (4 * C hidden units, GELU) with its own residual follows.
  """

  def __init__(self, H, W, window_size, C, num_heads):
    """Store the feature-map geometry and create the sublayers.

    Args:
      H: height of the input feature map.
      W: width of the input feature map.
      window_size: side length of the grid windows.
      C: channel count of the input.
      num_heads: number of attention heads.
    """
    super().__init__()
    self.H, self.W = H, W
    self.window_size = window_size
    self.C = C
    self.num_heads = num_heads

    # Sublayers are created in a fixed order: norms, attention, MLP.
    self.norm1 = tfla.LayerNormalization(epsilon=1e-6)
    self.norm2 = tfla.LayerNormalization(epsilon=1e-6)

    self.window_attention = window_attention(window_size, C, num_heads)

    self.fc1 = tfla.Dense(4 * C, activation="gelu")
    self.fc2 = tfla.Dense(C)

  def call(self, x):
    """Apply grid attention and the MLP, each with a residual connection.

    Args:
      x: tensor of shape [B, H, W, C].

    Returns:
      Tensor of shape [B, H, W, C].
    """
    # Attention branch: norm -> grid windows -> attention [B_, N, C] -> [B, H, W, C].
    grid_windows = grid_partition(self.norm1(x), self.H, self.W, self.window_size)
    attended = self.window_attention(grid_windows)
    attended = reverse_grid_partition(attended, self.H, self.W, self.window_size)
    x = x + attended

    # MLP branch with residual.
    mlp_out = self.fc2(self.fc1(self.norm2(x)))
    return x + mlp_out

In [None]:
class MBConv(tfla.Layer):
  """Inverted-bottleneck convolution block with squeeze-and-excitation.

  Main path: 1x1 expansion to 4 * C channels -> BN -> GELU -> 3x3 depthwise
  conv (stride 2 when downsampling) -> BN -> GELU -> squeeze-and-excitation
  gating -> 1x1 projection to out_channels -> BN.
  Shortcut path: optional 2x2 average pooling, then a 1x1 convolution to
  out_channels. The two paths are summed.
  """

  def __init__(self, C, out_channels, downsample):
    """Create all sublayers.

    Args:
      C: base channel count; sets the expansion width (4 * C) and the
        squeeze-and-excitation bottleneck width (C).
      out_channels: channel count of the block output.
      downsample: when truthy, spatial resolution is halved on both paths.
    """
    super().__init__()
    self.C = C
    self.out_channels = out_channels
    self.downsample = downsample

    # 1x1 projection on the shortcut path (applied in both modes).
    self.shortcut = tfla.Conv2D(out_channels, kernel_size=1, strides=1, use_bias=False)
    self.expansion = tfla.Conv2D(4 * C, kernel_size=1, strides=1, padding="same", use_bias=False)
    self.bn1 = tfla.BatchNormalization()
    self.gelu1 = tfla.Activation("gelu")
    # Both depthwise variants are created; call() uses exactly one of them,
    # chosen by self.downsample.
    self.depthwiseds = tfla.DepthwiseConv2D(3, strides=2, padding="same", use_bias=False)
    self.depthwise = tfla.DepthwiseConv2D(3, strides=1, padding="same", use_bias=False)
    # Squeeze-and-excitation: global pool -> Dense(C, gelu) -> Dense(4 * C, sigmoid).
    self.globalaverage = tfla.GlobalAveragePooling2D()
    self.dense1 = tfla.Dense(C, activation="gelu")
    self.dense2 = tfla.Dense(4 * C, activation="sigmoid")
    # 1x1 projection back down to out_channels.
    self.conv1 = tfla.Conv2D(out_channels, kernel_size=1, strides=1, padding="same", use_bias=False)
    self.bn2 = tfla.BatchNormalization()
    self.bn3 = tfla.BatchNormalization()
    self.gelu2 = tfla.Activation("gelu")
    # Average pooling used on the shortcut path when downsampling.
    self.shortcutglo = tfla.AveragePooling2D(pool_size=2, strides=2)

  def build(self, input_shape):
    """Pass-through to the base class build; no extra state is created here."""
    super().build(input_shape)

  def call(self, x):
    """Run the block.

    Args:
      x: tensor of shape [B, H, W, C].

    Returns:
      Sum of the main path and the shortcut path, with out_channels channels.
      Both paths halve the spatial resolution when self.downsample is truthy.
    """
    # x shape:[B, H, W, C]

    # Shortcut path: pool first when downsampling so both paths end up the same size.
    if(self.downsample):
      shortcut = self.shortcutglo(x)
      shortcut = self.shortcut(shortcut)
    else:
      shortcut = self.shortcut(x)

    # Expand to 4 * C channels.
    x = self.expansion(x)
    x = self.bn1(x)
    x = self.gelu1(x)
    # Depthwise 3x3; the strided variant performs the downsampling.
    if(self.downsample):
      x = self.depthwiseds(x)
    else:
      x = self.depthwise(x)
    x = self.bn2(x)
    x = self.gelu2(x)

    # Squeeze-and-excitation: per-channel gate computed from the global average.
    sne = self.globalaverage(x)
    sne = self.dense1(sne)
    sne = self.dense2(sne)
    # Reshape the gate from [B, 4 * C] to [B, 1, 1, 4 * C] so it broadcasts spatially.
    sne = tf.expand_dims(sne, axis=1)
    sne = tf.expand_dims(sne, axis=1)

    x = x * sne
    # Project to out_channels.
    x = self.conv1(x)
    x = self.bn3(x)

    return x + shortcut

In [None]:
inputs = tfla.Input(shape=(224, 224, 3))

# Stem: two stride-2 3x3 convolutions, 224x224x3 -> 56x56x64.
x = inputs
for _ in range(2):
  x = tfla.Conv2D(64, kernel_size=3, strides=2, padding="same", use_bias=False)(x)

# One row per stage:
# (feature-map side, C of the first MBConv, stage channels, attention heads,
#  number of MBConv/block-attention/grid-attention triples, downsample first?)
stage_specs = [
    (56, 64, 64, 2, 2, False),     # block 1, x shape:[B, 56, 56, 64]
    (28, 64, 128, 4, 2, True),     # block 2, x shape:[B, 28, 28, 128]
    (14, 128, 256, 8, 5, True),    # block 3, x shape:[B, 14, 14, 256]
    (7, 256, 512, 16, 2, True),    # block 4, x shape:[B, 7, 7, 512]
]

for side, first_c, channels, heads, depth, shrink in stage_specs:
  for position in range(depth):
    is_first = position == 0
    # Only the first MBConv of a stage changes resolution / channel count.
    x = MBConv(
        C=first_c if is_first else channels,
        out_channels=channels,
        downsample=shrink and is_first,
    )(x)
    x = block_attention(H=side, W=side, window_size=7, C=channels, num_heads=heads)(x)
    x = grid_attention(H=side, W=side, window_size=7, C=channels, num_heads=heads)(x)

# Classification head: global average pool + 200-way softmax.
x = tfla.GlobalAveragePooling2D()(x)
outputs = tfla.Dense(200, activation="softmax")(x)

model = tfm.Model(inputs=inputs, outputs=outputs)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
epochs = 40

# The cosine schedule must span the whole run, so its length is taken from the
# dataset itself instead of a hand-computed constant that goes stale whenever
# the batch size or the set of augmentation passes changes.
steps_per_epoch = int(combined_ds.cardinality().numpy())
if steps_per_epoch <= 0:
  # Cardinality is unknown or infinite (reported as a negative value):
  # fall back to the previously hard-coded number of batches per epoch.
  steps_per_epoch = 3128
total_steps = steps_per_epoch * epochs

# Learning rate decays from 5e-4 along a cosine curve over all training steps.
lr = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=5e-4,
    decay_steps=total_steps
)

opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=5e-2)

In [None]:
# Categorical cross-entropy accepts the soft label vectors produced by the
# input pipeline; accuracy is tracked as the only metric.
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])

# Train on the combined dataset; no validation data is passed, so `history`
# only contains training metrics.
history = model.fit(combined_ds, epochs=epochs)

Epoch 1/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1936s[0m 537ms/step - accuracy: 0.0690 - loss: 5.1043
Epoch 2/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1565s[0m 495ms/step - accuracy: 0.2140 - loss: 3.9878
Epoch 3/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1566s[0m 496ms/step - accuracy: 0.2904 - loss: 3.6471
Epoch 4/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1566s[0m 496ms/step - accuracy: 0.3552 - loss: 3.3693
Epoch 5/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1566s[0m 496ms/step - accuracy: 0.4125 - loss: 3.1247
Epoch 6/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1567s[0m 496ms/step - accuracy: 0.4633 - loss: 2.9242
Epoch 7/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1570s[0m 497ms/step - accuracy: 0.5086 - loss: 2.7546
Epoch 8/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1568s[0m 496ms/step - accuracy: 0.5483

In [None]:
# Evaluate on the held-out validation images; returns [loss, accuracy].
model.evaluate(test_ds)

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 581ms/step - accuracy: 0.6913 - loss: 1.6756


[1.6835989952087402, 0.6927000284194946]

In [None]:
# Training accuracy per epoch, drawn through the explicit Figure/Axes interface.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'])
ax.set_title('Model Accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.legend(['Train'], loc='lower right')
plt.show()