In [None]:
import os
import shutil

base_dir = "/content/tiny-imagenet-200"
val_dir = os.path.join(base_dir, "val")
images_dir = os.path.join(val_dir, "images")
ann_file = os.path.join(val_dir, "val_annotations.txt")

# Tiny ImageNet ships its validation images flat in val/images/, with the class
# of each file listed in val_annotations.txt (tab-separated: file name, wnid,
# then extra columns that are ignored here). Regroup them into val/<wnid>/
# folders so image_dataset_from_directory can infer the labels.
#
# Guarded so that re-running this cell (e.g. Restart & Run All) is a no-op once
# val/images/ has been removed, instead of failing because the images were
# already moved; individual files already moved by an interrupted earlier run
# are skipped.
if os.path.isdir(images_dir):
    with open(ann_file) as f:
        annotations = [line.strip().split('\t') for line in f if line.strip()]

    for img, cls, *_ in annotations:
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        src = os.path.join(images_dir, img)
        if os.path.exists(src):
            shutil.move(src, os.path.join(cls_dir, img))

    os.rmdir(images_dir)

In [None]:
# Installs the TensorFlow Model Garden package that provides
# official.vision.ops.augment (RandAugment), imported in the next cell.
# NOTE(review): consider `%pip` (targets the kernel's environment) and pinning
# a version for reproducibility.
!pip install -q tf-models-official

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as tfla
import tensorflow.keras.models as tfm
import tensorflow.keras.optimizers as tfo
import tensorflow.keras.losses as tflo
import matplotlib.pyplot as plt
from official.vision.ops import augment
import numpy as np

In [None]:
with open("tiny-imagenet-200/wnids.txt") as f:
    wnids = [entry.strip() for entry in f]

# Settings shared by both splits: labels come from the sub-folder names and are
# one-hot encoded, with class indices pinned to the order of wnids.txt so the
# train and validation splits agree on the label layout.
_shared_loader_args = dict(
    labels="inferred",
    label_mode="categorical",
    class_names=wnids,
)

# Training images are resized to 256x256 and left unbatched (batch_size=None)
# so per-image augmentations (random 224 crop, flip, RandAugment) can be mapped
# before batching.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/train",
    image_size=(256, 256),
    batch_size=None,
    **_shared_loader_args,
)

# Validation images are resized straight to the 224x224 model input and batched
# here, since they receive no augmentation.
test_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/val",
    image_size=(224, 224),
    batch_size=128,
    **_shared_loader_args,
)

Found 100000 files belonging to 200 classes.
Found 10000 files belonging to 200 classes.


In [None]:
def crop_image(image, label):
  """Randomly crop a single image to 224x224 and randomly mirror it left/right.

  Args:
    image: one image tensor of shape [h, w, 3] with h, w >= 224.
    label: its label vector of shape [num_class]; returned unchanged.

  Returns:
    (augmented image of shape [224, 224, 3], label)
  """
  cropped = tf.image.random_crop(image, (224, 224, 3))
  return tf.image.random_flip_left_right(cropped), label

In [None]:
# Per-image random crop + horizontal flip on the unbatched training stream.
# NOTE: this rebinds train_ds, so re-running the cell applies crop_image a
# second time on top of the first.
train_ds = train_ds.map(crop_image, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def apply_randaugment(image, label):
  """Apply RandAugment (2 random ops, magnitude 9) to a single image.

  image: one image tensor [h, w, c]; label: [num_class] tensor, passed through.
  """
  randaugment = augment.RandAugment(num_layers=2, magnitude=9)
  distorted = randaugment.distort(image)
  return distorted, label

In [None]:
# RandAugment on each (unbatched, already cropped) training image.
# NOTE: rebinding train_ds makes this cell non-idempotent; re-running it stacks
# a second RandAugment pass.
train_ds = train_ds.map(apply_randaugment, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def normalise_image(image, label):
  """Scale pixels from [0, 255] to [0, 1], then standardise per channel.

  Uses the common ImageNet RGB mean/std. Works on one [h, w, c] image or a
  batch [b, h, w, c], since the per-channel constants broadcast over the last
  axis. The label is returned unchanged.
  """
  channel_mean = tf.constant([0.485, 0.456, 0.406])
  channel_std = tf.constant([0.229, 0.224, 0.225])
  scaled = image / 255.0
  return (scaled - channel_mean) / channel_std, label

In [None]:
# Same normalisation for both splits; test_ds is already batched, which works
# because the per-channel constants broadcast over the leading axes.
train_ds = train_ds.map(normalise_image, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(normalise_image, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Batch the per-image pipeline: the MixUp / CutMix / erase functions below
# operate on whole batches [batch_size, h, w, c].
train_ds = train_ds.batch(128)
# The combined training stream starts with the plain (RandAugment-only) batches.
combined_ds = train_ds

In [None]:
def mixup(images, labels):
  """MixUp: blend every example with a randomly chosen partner from its batch.

  images: [batch_size, h, w, c]; labels: [batch_size, num_class].
  A per-example weight lam ~ Beta(0.2, 0.2) (the ratio X / (X + Y) of two
  Gamma(0.2) draws) is applied identically to pixels and labels.
  """
  n = tf.shape(images)[0]

  # Beta(a, a) via the Gamma ratio; shape [n, 1].
  g_a = tf.random.gamma(shape=[n, 1], alpha=0.2)
  g_b = tf.random.gamma(shape=[n, 1], alpha=0.2)
  lam = g_a / (g_a + g_b)

  partner = tf.random.shuffle(tf.range(n))
  partner_images = tf.gather(images, partner)
  partner_labels = tf.gather(labels, partner)

  # Broadcast the same weight over h, w, c for the pixels.
  pixel_lam = tf.reshape(lam, [-1, 1, 1, 1])
  mixed_images = pixel_lam * images + (1 - pixel_lam) * partner_images
  mixed_labels = lam * labels + (1 - lam) * partner_labels
  return mixed_images, mixed_labels

In [None]:
# MixUp applied to every batch of the (normalised) training stream, appended
# after the batches already in combined_ds.
# NOTE: re-running this cell concatenates another copy of mixup_ds.
mixup_ds = train_ds.map(mixup, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(mixup_ds)

In [None]:
def cutmix(images, labels):
  """CutMix: paste a random box from a shuffled partner image into each image.

  Args:
    images: [batch_size, h, w, c] float tensor.
    labels: [batch_size, num_class] label tensor (one-hot or soft).

  Returns:
    (images, labels): each image has one axis-aligned box replaced by the same
    box of a randomly chosen partner from the batch, and each label is mixed
    with the partner's label in proportion to the fraction of pixels actually
    taken from the partner.
  """
  batch_size = tf.shape(images)[0]
  img_height = tf.shape(images)[1]
  img_width = tf.shape(images)[2]

  # Target cut ratio lam ~ Beta(0.2, 0.2), drawn as the ratio of two Gamma(0.2)
  # samples; shape [batch_size, 1]. Box sides scale with sqrt(lam), so the box
  # area is approximately lam * h * w.
  gamma_1 = tf.random.gamma(shape=[batch_size, 1], alpha=0.2)
  gamma_2 = tf.random.gamma(shape=[batch_size, 1], alpha=0.2)
  lam = gamma_1 / (gamma_1 + gamma_2)

  # Half box extents along rows / columns, shape [batch_size, 1].
  half_h = tf.cast(tf.cast(img_height, tf.float32) * tf.sqrt(lam), tf.int32) // 2
  half_w = tf.cast(tf.cast(img_width, tf.float32) * tf.sqrt(lam), tf.int32) // 2

  def _sample_centre(half, size):
    # Uniform centre in [half, size - 1 - half] so the box stays in the image.
    low = tf.cast(half, tf.float32)
    span = tf.cast(size - 1 - 2 * half, tf.float32)
    u = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                          dtype=tf.float32)
    return tf.cast(u * span + low, tf.int32)

  centre_row = _sample_centre(half_h, img_height)
  centre_col = _sample_centre(half_w, img_width)

  # Inclusive box corners, reshaped to [batch_size, 1, 1] to broadcast against
  # the pixel coordinate grid.
  top = tf.reshape(centre_row - half_h, [batch_size, 1, 1])
  bottom = tf.reshape(centre_row + half_h, [batch_size, 1, 1])
  left = tf.reshape(centre_col - half_w, [batch_size, 1, 1])
  right = tf.reshape(centre_col + half_w, [batch_size, 1, 1])

  rows = tf.reshape(tf.range(img_height), [1, img_height, 1])
  cols = tf.reshape(tf.range(img_width), [1, 1, img_width])
  in_box = tf.logical_and(
      tf.logical_and(top <= rows, rows <= bottom),
      tf.logical_and(left <= cols, cols <= right))
  # 1.0 inside the box, 0.0 elsewhere; shape [batch_size, h, w, 1].
  mask = tf.cast(in_box, tf.float32)[..., tf.newaxis]

  indices = tf.random.shuffle(tf.range(batch_size))
  shuffled_images = tf.gather(images, indices)
  shuffled_labels = tf.gather(labels, indices)

  images = images * (1.0 - mask) + shuffled_images * mask

  # Weight the labels by the pixel fraction actually pasted (shape
  # [batch_size, 1]) instead of the sampled lam. Integer rounding of the box
  # size and the inclusive (2 * half + 1) side length make the two differ.
  pasted_fraction = tf.reduce_mean(mask, axis=[1, 2])
  labels = labels * (1.0 - pasted_fraction) + shuffled_labels * pasted_fraction

  return images, labels

In [None]:
# CutMix applied to every training batch, appended to combined_ds.
# NOTE: re-running this cell concatenates another copy of cutmix_ds.
cutmix_ds = train_ds.map(cutmix, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(cutmix_ds)

In [None]:
def erase(images, labels):
  """Random erasing: zero out one random rectangle in every image of the batch.

  images: [batch_size, h, w, c]; labels: [batch_size, num_class], returned
  unchanged. Each image gets its own box, with sides scaled by the sqrt of a
  Uniform(0.2, 0.5) draw (so roughly 20%-50% of the area), placed so that the
  box stays inside the image. In this notebook erase runs after
  normalise_image, so zero corresponds to the per-channel mean colour.
  """
  n = tf.shape(images)[0]
  height = tf.shape(images)[1]
  width = tf.shape(images)[2]

  area_frac = tf.random.uniform(shape=(n, 1), minval=0.2, maxval=0.5, dtype=tf.float32)
  side_scale = tf.sqrt(area_frac)
  half_h = tf.cast(tf.cast(height, tf.float32) * side_scale, tf.int32) // 2
  half_w = tf.cast(tf.cast(width, tf.float32) * side_scale, tf.int32) // 2

  # Box centre row in [half_h, height - 1 - half_h].
  u_row = tf.random.uniform(shape=(n, 1), minval=0, maxval=1, dtype=tf.float32)
  lo_row = half_h
  hi_row = height - 1 - half_h
  centre_row = tf.cast(
      u_row * tf.cast(hi_row - lo_row, tf.float32) + tf.cast(lo_row, tf.float32),
      tf.int32)

  # Box centre column in [half_w, width - 1 - half_w].
  u_col = tf.random.uniform(shape=(n, 1), minval=0, maxval=1, dtype=tf.float32)
  lo_col = half_w
  hi_col = width - 1 - half_w
  centre_col = tf.cast(
      u_col * tf.cast(hi_col - lo_col, tf.float32) + tf.cast(lo_col, tf.float32),
      tf.int32)

  # Inclusive corners as [n, 1, 1], compared against broadcast row/col indices.
  top = tf.reshape(centre_row - half_h, [n, 1, 1])
  bottom = tf.reshape(centre_row + half_h, [n, 1, 1])
  left = tf.reshape(centre_col - half_w, [n, 1, 1])
  right = tf.reshape(centre_col + half_w, [n, 1, 1])

  row_ids = tf.reshape(tf.range(height), [1, height, 1])
  col_ids = tf.reshape(tf.range(width), [1, 1, width])
  in_rows = tf.logical_and(top <= row_ids, row_ids <= bottom)
  in_cols = tf.logical_and(left <= col_ids, col_ids <= right)

  # keep: 0.0 inside the box, 1.0 elsewhere; shape [n, height, width].
  keep = 1.0 - tf.cast(tf.logical_and(in_rows, in_cols), tf.float32)
  return images * keep[..., tf.newaxis], labels

In [None]:
# Random erasing applied to every training batch, appended to combined_ds.
# NOTE: re-running this cell concatenates another copy of erase_ds.
erase_ds = train_ds.map(erase, num_parallel_calls=tf.data.AUTOTUNE)
combined_ds = combined_ds.concatenate(erase_ds)

In [None]:
def label_smoothing(labels, epsilon=0.1):
  """Uniform label smoothing: labels * (1 - epsilon) + epsilon / K.

  labels: [batch_size, K] tensor of (possibly already mixed) class weights.
  """
  k = tf.cast(tf.shape(labels)[1], tf.float32)
  uniform_share = epsilon / k
  return labels * (1.0 - epsilon) + uniform_share

In [None]:
# Label smoothing (epsilon=0.1) on every batch of the combined stream, applied
# after any MixUp / CutMix label mixing.
combined_ds = combined_ds.map(lambda images, labels: (images, label_smoothing(labels)), num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Build the epoch stream by randomly interleaving the four augmentation streams
# (plain, MixUp, CutMix, random erasing) batch by batch.
#
# Concatenating them yields ~782 plain batches, then ~782 MixUp batches, and so
# on. A 100-batch shuffle buffer cannot undo that ordering, so the model would
# train on one augmentation regime at a time. sample_from_datasets draws each
# next batch from a uniformly chosen stream and continues until all four are
# exhausted, so an epoch still contains every batch of every stream.
combined_ds = tf.data.Dataset.sample_from_datasets(
    [train_ds, mixup_ds, cutmix_ds, erase_ds],
    stop_on_empty_dataset=False,
)
# Same label smoothing as applied to the concatenated stream earlier.
combined_ds = combined_ds.map(
    lambda images, labels: (images, label_smoothing(labels)),
    num_parallel_calls=tf.data.AUTOTUNE,
)
# Light batch-level shuffle on top, then overlap input preparation with training.
combined_ds = combined_ds.shuffle(buffer_size=100).prefetch(tf.data.AUTOTUNE)

In [None]:
# TEST CODE
def convert_back(image):
  """Undo normalise_image for display.

  De-standardises with the same mean/std, rescales to [0, 255], clips, and
  casts to int32 so matplotlib can show it.
  """
  mean = tf.constant([0.485, 0.456, 0.406])
  std = tf.constant([0.229, 0.224, 0.225])
  pixels = (image * std + mean) * 255
  return tf.cast(tf.clip_by_value(pixels, 0, 255), tf.int32)

In [None]:
# TEST CODE
def load_tinyimagenet_label_maps(tiny_imagenet_root):
    """Load the Tiny ImageNet class list and the wnid -> description map.

    Args:
        tiny_imagenet_root: dataset root containing ``wnids.txt`` and
            ``words.txt``.

    Returns:
        (wnids, wnid_to_words): ``wnids`` is the list of class ids in file
        order (blank lines skipped); ``wnid_to_words`` maps every wnid listed
        in ``words.txt`` to its description string.
    """
    # wnids.txt: one wnid per line.
    with open(os.path.join(tiny_imagenet_root, "wnids.txt"), encoding="utf-8") as f:
        wnids = [line.strip() for line in f if line.strip()]

    # words.txt: "<wnid>\t<words>". Split on the first tab only, so a
    # description that itself contains a tab is kept whole instead of raising
    # ValueError. Blank lines are skipped instead of failing to unpack, and a
    # line with no tab maps its wnid to "".
    wnid_to_words = {}
    with open(os.path.join(tiny_imagenet_root, "words.txt"), encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            wnid, _, words = line.strip().partition("\t")
            wnid_to_words[wnid.strip()] = words.strip()

    return wnids, wnid_to_words

In [None]:
# TEST CODE
# Reload wnids (same file and order as the dataset-loading cell) together with
# the wnid -> description map used by one_hot_to_tinyimagenet_word.
wnids, wnid_to_words = load_tinyimagenet_label_maps("tiny-imagenet-200")

In [None]:
# TEST CODE
def one_hot_to_tinyimagenet_word(one_hot_label, wnids, wnid_to_words):
    """Return the human-readable class name for a one-hot or soft label vector.

    Args:
        one_hot_label: label vector of shape (num_classes,), given as an eager
            tf.Tensor, a NumPy array or a plain Python sequence. For soft
            labels (MixUp, CutMix, label smoothing) the highest-weighted class
            is used.
        wnids: class ids, where ``wnids[i]`` is the id of label index ``i``.
        wnid_to_words: mapping from wnid to its description.

    Returns:
        The description string of the arg-max class.

    Raises:
        KeyError: if the arg-max wnid is not in ``wnid_to_words``.
    """
    # np.argmax accepts eager tensors (via __array__), arrays and lists alike,
    # so no per-type dispatch is needed.
    label_idx = int(np.argmax(one_hot_label))
    return wnid_to_words[wnids[label_idx]]

In [None]:
# TEST CODE
# Set to True to eyeball one training batch from combined_ds. Random erasing is
# applied on top; the dominant class of the first example is printed and the
# image is shown after undoing the normalisation.
SHOW_AUGMENTED_SAMPLE = False

if SHOW_AUGMENTED_SAMPLE:
  for images, labels in combined_ds.take(1):
    erased_images, erased_labels = erase(images, labels)
    print(one_hot_to_tinyimagenet_word(erased_labels[0], wnids, wnid_to_words))
    plt.imshow(convert_back(erased_images[0]).numpy())
    plt.axis("off")
    plt.show()

'%matplotlib inline\n\nfor images, labels in combined_ds.take(1):\n  imagess, labelss = erase(images, labels)\n  print(one_hot_to_tinyimagenet_word(labelss[0], wnids, wnid_to_words))\n  plt.imshow(convert_back(imagess[0]).numpy())\n  plt.axis("off")\n  plt.show()'

In [None]:
class patch_merging(tfla.Layer):
  """Patch merging: 2x spatial downsampling with 2x channel expansion.

  The input holds h*w tokens of width `dim`, either as [B, h*w, dim] or
  [B, h, w, dim]. Each 2x2 neighbourhood is concatenated into a 4*dim vector,
  layer-normalised, and projected (no bias) to 2*dim channels.
  Output shape: [B, (h // 2) * (w // 2), 2 * dim].
  """

  def __init__(self, dim):
    super().__init__()
    self.dim = dim  # C in the [B, N, C] token layout
    self.norm = tfla.LayerNormalization()
    self.proj = tfla.Dense(dim * 2, use_bias=False)

  def call(self, x, h, w):
    b = tf.shape(x)[0]
    grid = tf.reshape(x, [b, h, w, self.dim])

    # Concatenation order (row parity, col parity):
    # (even, even), (odd, even), (even, odd), (odd, odd).
    quadrants = [
        grid[:, row_off::2, col_off::2, :]
        for row_off, col_off in ((0, 0), (1, 0), (0, 1), (1, 1))
    ]
    merged = tf.concat(quadrants, axis=-1)
    merged = tf.reshape(merged, [b, (h // 2) * (w // 2), 4 * self.dim])

    return self.proj(self.norm(merged))

In [None]:
def window_partition(x, window_size):
  """Split a [B, H, W, C] feature map into non-overlapping square windows.

  H and W must be multiples of window_size. Windows are emitted image by image
  and row-major within each image, giving shape
  [B * (H // window_size) * (W // window_size), window_size, window_size, C].
  """
  dims = tf.shape(x)
  batch, height, width, channels = dims[0], dims[1], dims[2], dims[3]
  rows = height // window_size
  cols = width // window_size

  # [B, rows, ws, cols, ws, C] -> [B, rows, cols, ws, ws, C]
  tiles = tf.reshape(x, [batch, rows, window_size, cols, window_size, channels])
  tiles = tf.transpose(tiles, perm=[0, 1, 3, 2, 4, 5])
  return tf.reshape(tiles, [-1, window_size, window_size, channels])

In [None]:
def window_reverse(windows, window_size, h, w):
  """Inverse of window_partition: stitch windows back into a [B, h, w, C] map.

  windows: [B * (h // window_size) * (w // window_size), window_size,
  window_size, C], in the order produced by window_partition.
  """
  rows = h // window_size
  cols = w // window_size
  batch = tf.shape(windows)[0] // rows // cols
  channels = tf.shape(windows)[3]

  grid = tf.reshape(windows, [batch, rows, cols, window_size, window_size, channels])
  # [B, rows, cols, ws, ws, C] -> [B, rows, ws, cols, ws, C]
  grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
  return tf.reshape(grid, [batch, h, w, channels])

In [None]:
class MLP(tfla.Layer):
  """Feed-forward block: Dense(dim * mlp_ratio, GELU) -> Dense(dim).

  Both projections are bias-free. They act on the last axis, so any leading
  shape ([B, N, C] or [B, H, W, C]) is preserved.
  """

  def __init__(self, dim, mlp_ratio=4):
    super().__init__()
    self.dim = dim
    self.mlp_ratio = mlp_ratio
    hidden_units = dim * mlp_ratio
    self.fc1 = tfla.Dense(hidden_units, use_bias=False, activation="gelu")
    self.fc2 = tfla.Dense(dim, use_bias=False)

  def call(self, x):
    return self.fc2(self.fc1(x))

In [None]:
class window_attention(tfla.Layer):
  """Multi-head self-attention within a window, with relative position bias.

  Modelled on Swin Transformer window attention: a fused QKV projection,
  scaled dot-product attention per head, and a learned bias table indexed by
  the relative (row, col) offset of every token pair in the window. It also
  applies an optional additive mask (for shifted windows) and a final output
  projection.

  Args:
    dim: token width C; assumed divisible by num_heads (the reshapes in call
      require num_heads * (dim // num_heads) == dim).
    window_size: side length of the square window (N = window_size**2 tokens).
    num_heads: number of attention heads.
  """

  def __init__(self, dim, window_size, num_heads):
    super().__init__()
    self.dim = dim
    self.num_heads = num_heads
    self.window_size = window_size
    self.head_dim = dim // num_heads
    self.scale = tf.cast(self.head_dim, tf.float32) ** -0.5
    self.qkv = tfla.Dense(dim * 3, use_bias=True)
    # Output projection applied to the concatenated head outputs in call().
    self.dense = tfla.Dense(dim, use_bias=True)

    # One learnable bias per head for each of the (2*ws - 1)^2 possible
    # relative offsets between two tokens of a window.
    self.num_rel_pos = (2 * window_size - 1) * (2 * window_size - 1)
    self.rel_pos_emb = self.add_weight(
        shape=(self.num_rel_pos, self.num_heads),
        initializer=tf.random_normal_initializer(stddev=0.02),
        trainable=True
    )

    self.coords_h = tf.range(self.window_size)
    self.coords_w = tf.range(self.window_size)
    # coords shape:[2, window_size, window_size]
    self.coords = tf.stack(tf.meshgrid(self.coords_h, self.coords_w, indexing="ij"))
    # coords shape:[2, N]
    self.coords = tf.reshape(self.coords, [2, -1])
    # rel_pos shape:[2, N, N]; offsets in [-(ws-1), ws-1] per axis
    self.rel_pos = self.coords[:, :, None] - self.coords[:, None, :]
    # rel_pos shape:[N, N, 2]
    self.rel_pos = tf.transpose(self.rel_pos, [1, 2, 0])

    # Shift both offsets to [0, 2*ws - 2] and flatten (row, col) into a single
    # index into rel_pos_emb.
    self.col_pos = self.rel_pos[:,:,0] + self.window_size - 1
    self.row_pos = self.rel_pos[:,:,1] + self.window_size - 1

    self.rel_pos_index = self.col_pos * (2 * self.window_size - 1) + self.row_pos

    # rel_pos_index shape:[N, N]
    self.rel_pos_index = tf.cast(self.rel_pos_index, tf.int32)

  def call(self, x, mask=None):
    """Attend within each window.

    Args:
      x: [B_, N, C] tokens, where B_ is the total number of windows in the
        batch and N = window_size * window_size.
      mask: optional additive mask [nW, N, N] (nW = windows per image), added
        to the logits of every image; the windows of each image must be
        contiguous in x.

    Returns:
      [B_, N, C] attended tokens after the output projection.
    """
    B_ = tf.shape(x)[0]
    N = self.window_size * self.window_size

    # [B_, N, 3 * C] -> [B_, N, 3, heads, head_dim] -> [3, B_, heads, N, head_dim]
    qkv = self.qkv(x)
    qkv = tf.reshape(qkv, [B_, N, 3, self.num_heads, self.head_dim])
    qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])

    # q, k, v shape:[B_, num_heads, N, head_dim]
    q, k, v = qkv[0], qkv[1], qkv[2]

    q = q * self.scale
    # attn shape:[B_, num_heads, N, N]
    attn = tf.matmul(q, k, transpose_b=True)

    # Per-head relative position bias for every token pair: [1, heads, N, N].
    rel_position_index = tf.reshape(self.rel_pos_index, [N*N])
    bias = tf.gather(self.rel_pos_emb, rel_position_index)
    bias = tf.reshape(bias, [N, N, self.num_heads])
    bias = tf.transpose(bias, [2, 0, 1])
    bias = tf.reshape(bias, [1, self.num_heads, N, N])
    attn = attn + bias

    if mask is not None:
      num_windows_per_image = tf.shape(mask)[0]

      # Expose the per-image window axis so the [nW, N, N] mask broadcasts
      # over images and heads.
      attn = tf.reshape(attn, [-1, num_windows_per_image, self.num_heads, N, N])
      mask = tf.reshape(mask, [1, num_windows_per_image, 1, N, N])
      mask = tf.cast(mask, tf.float32)
      attn = attn + mask
      attn = tf.reshape(attn, [-1, self.num_heads, N, N])

    attn = tf.nn.softmax(attn, axis=-1)
    # out shape:[B_, num_heads, N, head_dim]
    out = tf.matmul(attn, v)
    # out shape:[B_, N, num_heads, head_dim] -> [B_, N, C]
    out = tf.transpose(out, [0, 2, 1, 3])
    out = tf.reshape(out, [B_, N, self.dim])

    # Output projection: mixes the concatenated per-head outputs, as in
    # standard multi-head attention.
    return self.dense(out)

In [None]:
def create_mask(H, W, window_size, shifted_size):
  """Additive attention mask for shifted-window attention.

  The caller rolls the feature map by -shifted_size along both spatial axes
  before windowing. Windows along the bottom/right edges then contain tokens
  that wrapped around from the opposite side and are not spatial neighbours.
  As in the Swin Transformer reference implementation, the *rolled* map is
  split into up to 3 x 3 regions by cuts at -window_size and -shifted_size on
  each axis. Tokens may attend only to tokens in the same region, so interior
  windows still attend freely across the old window borders; that is what
  gives shifted windows their cross-window connections.

  Args:
    H: feature-map height (multiple of window_size).
    W: feature-map width (multiple of window_size).
    window_size: window side length.
    shifted_size: roll amount; 0 yields an all-zero (no-op) mask.

  Returns:
    float32 tensor [num_windows_per_image, N, N] with N = window_size**2,
    0 where attention is allowed and -1e9 where it is blocked.
  """
  region_ids = np.zeros([H, W], dtype=np.float32)
  if shifted_size > 0:
    cuts = (
        slice(0, -window_size),
        slice(-window_size, -shifted_size),
        slice(-shifted_size, None),
    )
    region = 0
    for row_slice in cuts:
      for col_slice in cuts:
        region_ids[row_slice, col_slice] = region
        region += 1

  # Partition into windows in the same order as window_partition:
  # [H, W] -> [num_windows_per_image, N], row-major over windows.
  rows, cols = H // window_size, W // window_size
  windows = region_ids.reshape(rows, window_size, cols, window_size)
  windows = windows.transpose(0, 2, 1, 3).reshape(-1, window_size * window_size)

  # [nW, N, N]: block token pairs whose region ids differ. Built directly as
  # float; multiplying a bool tensor by a float raises a dtype error in TF.
  blocked = windows[:, None, :] != windows[:, :, None]
  mask = np.where(blocked, -1e9, 0.0).astype(np.float32)

  return tf.convert_to_tensor(mask, dtype=tf.float32)

In [None]:
class swin_window_attention_forward(tfla.Layer):
  """(Shifted-)window multi-head self-attention over an H x W token map.

  Accepts tokens as [B, H*W, C] or [B, H, W, C] and returns [B, H, W, C].
  With shift_size > 0 the map is cyclically rolled by -shift_size on both
  spatial axes before windowing, attention uses the mask from create_mask, and
  the roll is undone afterwards.
  """

  def __init__(self, H, W, C, window_size, num_heads, shift_size):
    super().__init__()
    self.H = H
    self.W = W
    self.C = C
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.window_attention = window_attention(self.C, self.window_size, self.num_heads)

  def call(self, x):
    ws = self.window_size
    shifted = self.shift_size > 0

    # Bring token sequences into the [B, H, W, C] layout.
    if x.shape.rank == 3:
      x = tf.reshape(x, [-1, self.H, self.W, self.C])

    attn_mask = None
    if shifted:
      x = tf.roll(x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2])
      attn_mask = create_mask(H=self.H, W=self.W, window_size=ws,
                              shifted_size=self.shift_size)

    # [B, H, W, C] -> [B_, ws*ws, C], one row per window.
    tokens = tf.reshape(window_partition(x, ws), [-1, ws * ws, self.C])
    attended = self.window_attention(tokens, mask=attn_mask)

    # [B_, ws*ws, C] -> [B_, ws, ws, C] -> [B, H, W, C]
    attended = tf.reshape(attended, [-1, ws, ws, self.C])
    x = window_reverse(attended, ws, self.H, self.W)

    if shifted:
      x = tf.roll(x, shift=[self.shift_size, self.shift_size], axis=[1, 2])

    return x

In [None]:
class swin_block(tfla.Layer):
  """Pre-norm Swin Transformer block.

  x -> x + WindowAttention(LayerNorm(x)) -> x + MLP(LayerNorm(x)).
  Input may be [B, H*W, C] or [B, H, W, C]; output is always [B, H, W, C].
  shift_size == 0 gives regular windows, shift_size > 0 shifted windows.
  """

  def __init__(self, H, W, C, num_heads, window_size, shift_size):
    super().__init__()
    self.H = H
    self.W = W
    self.C = C
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size

    self.norm1 = tfla.LayerNormalization()
    self.norm2 = tfla.LayerNormalization()
    self.attn = swin_window_attention_forward(H, W, C, window_size, num_heads, shift_size)
    self.MLP = MLP(C)

  def call(self, x):
    # The attention branch always returns [B, H, W, C], so bring the residual
    # to the same layout before adding.
    residual = x
    if residual.shape.rank == 3:
      residual = tf.reshape(residual, [-1, self.H, self.W, self.C])

    x = self.attn(self.norm1(x)) + residual
    x = self.MLP(self.norm2(x)) + x

    # x shape: [B, H, W, C]
    return x

In [None]:
inputs = tfla.Input(shape=(224, 224, 3))

# Patch embedding: non-overlapping 4x4 patches projected to 96 channels,
# giving a [B, 56, 56, 96] map.
x = tfla.Conv2D(96, 4, strides=4, padding="valid")(inputs)
x = tfla.LayerNormalization()(x)

H = W = 56
C = 96
num_heads = [3, 6, 12, 24]
depths = [2, 2, 6, 2]
window_size = 7
shift_size = window_size // 2

last_stage = len(depths) - 1
for stage, (depth, heads) in enumerate(zip(depths, num_heads)):
  for i in range(depth):
    # Alternate regular / shifted windows. In the last stage the 7x7 map is a
    # single window, so shifting is disabled there.
    block_shift = shift_size if (i % 2 == 1 and stage != last_stage) else 0
    x = swin_block(H, W, C, heads, window_size, block_shift)(x)
  if stage != last_stage:
    # Downsample 2x spatially and double the channels between stages.
    x = patch_merging(C)(x, h=H, w=W)
    H = W = H // 2
    C = C * 2

x = tfla.LayerNormalization()(x)
x = tfla.GlobalAveragePooling2D()(x)
outputs = tfla.Dense(200, activation="softmax")(x)

model = tfm.Model(inputs, outputs)



In [None]:
# Layer-by-layer output shapes and parameter counts.
model.summary()

In [None]:
# Batches per epoch = batches in the (batched) training split times the number
# of augmentation streams fed into combined_ds (plain, MixUp, CutMix, erase).
# Derived rather than hard-coded so the LR schedule stays in sync if the batch
# size or data change (100k images / 128 -> 782 batches, x 4 = 3128).
num_streams = 4
train_batches = int(train_ds.cardinality())
assert train_batches > 0, "train_ds cardinality is unknown/infinite; set steps_per_epoch manually"
steps_per_epoch = num_streams * train_batches
epochs = 40
total_steps = steps_per_epoch * epochs

# Cosine decay of the learning rate from 5e-4 over the whole run (no warmup).
lr = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=5e-4,
    decay_steps=total_steps
)

opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=5e-2)

In [None]:
# The classifier head ends in a softmax, so the loss takes probabilities.
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])

# Train on the combined augmentation stream; no validation data is monitored.
history = model.fit(combined_ds, epochs=epochs)

Epoch 1/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1152s[0m 335ms/step - accuracy: 0.0124 - loss: 5.2870
Epoch 2/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1011s[0m 318ms/step - accuracy: 0.0167 - loss: 5.1936
Epoch 3/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1012s[0m 319ms/step - accuracy: 0.0347 - loss: 5.0189
Epoch 4/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1011s[0m 319ms/step - accuracy: 0.0602 - loss: 4.8289
Epoch 5/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1012s[0m 319ms/step - accuracy: 0.0943 - loss: 4.6114
Epoch 6/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1012s[0m 319ms/step - accuracy: 0.1210 - loss: 4.4479
Epoch 7/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1012s[0m 319ms/step - accuracy: 0.1495 - loss: 4.2917
Epoch 8/40
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1012s[0m 319ms/step - accuracy: 0.1775

In [None]:
# Loss and accuracy on the 10k validation images, used here as the test set.
model.evaluate(test_ds)

In [None]:
# Training accuracy per epoch.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'])
ax.set_title('Model Accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.legend(['Train'], loc='lower right')
plt.show()