In [None]:
import os
import shutil

# Tiny ImageNet ships its validation images in one flat folder (val/images)
# and lists the class of every file in val/val_annotations.txt. The loader
# used later infers labels from sub-folder names, so each image is moved into
# val/<class id>/.
base_dir = "/content/tiny-imagenet-200"
val_dir = os.path.join(base_dir, "val")
images_dir = os.path.join(val_dir, "images")
ann_file = os.path.join(val_dir, "val_annotations.txt")

# Read annotations: tab-separated, first column is the file name, second the
# class id. Blank lines are skipped so they cannot break the unpacking below.
with open(ann_file) as f:
    annotations = [line.strip().split('\t') for line in f if line.strip()]

# Create class folders and move images.
# Guarded on the flat folder so that re-running this cell after a successful
# run is a no-op instead of failing on files that were already moved.
if os.path.isdir(images_dir):
    for img, cls, *_ in annotations:
        src = os.path.join(images_dir, img)
        if not os.path.exists(src):
            # Already moved by an earlier, interrupted run.
            continue
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        shutil.move(src, os.path.join(cls_dir, img))

    # os.rmdir only removes empty directories, so this raises if any file was
    # not covered by the annotations instead of silently deleting it.
    os.rmdir(images_dir)

In [None]:
!pip install -q tf-models-official

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m87.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m103.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m620.7/620.7 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m120.1 MB/s[0m eta [36m0

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as tfla
import tensorflow.keras.models as tfm
import tensorflow.keras.optimizers as tfo
import tensorflow.keras.losses as tflo
import matplotlib.pyplot as plt
from official.vision.ops import augment
import numpy as np

In [None]:
# Class ids, one per line. Passing them as class_names fixes the index of each
# class in the one-hot labels, identically for the train and val datasets.
with open("tiny-imagenet-200/wnids.txt") as f:
    wnids = [line.strip() for line in f if line.strip()]

# Training images are loaded one by one (batch_size=None) at 256x256 so that
# the per-image random 224x224 crop can be applied before batching.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/train",
    labels="inferred",
    label_mode="categorical",
    class_names=wnids,
    image_size=(256, 256),
    batch_size=None,
)

# Evaluation data: resized straight to the model input size and batched.
# shuffle=False because ordering is irrelevant for evaluation and a fixed
# order keeps the evaluation pass deterministic.
test_ds = tf.keras.utils.image_dataset_from_directory(
    "tiny-imagenet-200/val",
    labels="inferred",
    label_mode="categorical",
    class_names=wnids,
    image_size=(224, 224),
    batch_size=128,
    shuffle=False,
)

Found 100000 files belonging to 200 classes.
Found 10000 files belonging to 200 classes.


In [None]:
def crop_image(image, label):
  """Randomly crop one image to 224x224x3 and randomly mirror it left-right.

  Args:
    image: a single image tensor, expected as [h, w, c].
    label: the label tensor for that image, expected as [num_class,].

  Returns:
    Tuple (augmented image, label); the label is passed through unchanged.
  """
  cropped = tf.image.random_crop(image, size=(224, 224, 3))
  augmented = tf.image.random_flip_left_right(cropped)
  return augmented, label

In [None]:
# Per-image random crop + horizontal flip, applied while the training data is
# still unbatched.
# NOTE: this rebinds train_ds, so re-running the cell stacks another map.
train_ds = train_ds.map(
    crop_image,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
def normalise_image(image, label):
  """Divide pixels by 255, then standardise each channel with a fixed mean/std.

  The constants are broadcast over the last axis, so the function works for a
  single [h, w, c] image as well as for a batched [b, h, w, c] tensor.
  NOTE(review): the constants look like the widely used ImageNet RGB
  statistics -- confirm that is the intent.

  Args:
    image: image tensor with 3 channels on the last axis.
    label: label tensor, returned unchanged.

  Returns:
    Tuple (standardised image, label).
  """
  channel_mean = tf.constant([0.485, 0.456, 0.406])
  channel_std = tf.constant([0.229, 0.224, 0.225])

  scaled = image / 255.0
  standardised = (scaled - channel_mean) / channel_std

  return standardised, label

In [None]:
# Same normalisation for both splits: train_ds is still unbatched here,
# test_ds is already batched (the function broadcasts over the channel axis).
# NOTE: both names are rebound, so re-running this cell normalises twice.
train_ds = train_ds.map(
    normalise_image,
    num_parallel_calls=tf.data.AUTOTUNE,
)
test_ds = test_ds.map(
    normalise_image,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
# Batch the augmented training images; the batch-level augmentations defined
# below (mixup / cutmix / erase) all operate on these batches.
train_ds = train_ds.batch(batch_size=128)

# combined_ds starts as the plain batches and is extended with one augmented
# copy of the training data per technique.
combined_ds = train_ds

In [None]:
def mixup(images, labels, alpha=0.2):
  """Apply mixup to one batch: blend every sample with a random partner.

  Each sample gets its own mixing weight lam, built as X / (X + Y) from two
  independent Gamma(alpha) draws (i.e. a Beta(alpha, alpha) sample). The
  partner is picked by shuffling the batch, and images and labels are blended
  with the same weight.

  Args:
    images: float tensor of shape [batch_size, h, w, c].
    labels: float tensor of shape [batch_size, num_class].
    alpha: concentration of the two Gamma draws. Defaults to 0.2, the value
      that was previously hard-coded.

  Returns:
    Tuple (mixed images, mixed labels) with the same shapes as the inputs.
  """
  batch_size = tf.shape(images)[0]

  # Per-sample mixing weight, shape [batch_size, 1].
  gamma_1 = tf.random.gamma(shape=[batch_size, 1], alpha=alpha)
  gamma_2 = tf.random.gamma(shape=[batch_size, 1], alpha=alpha)
  lam = gamma_1 / (gamma_1 + gamma_2)

  # One permutation is used for both images and labels so partners match.
  indices = tf.random.shuffle(tf.range(batch_size))
  shuffled_images = tf.gather(images, indices)
  shuffled_labels = tf.gather(labels, indices)

  # Reshape so the weight broadcasts over h, w and c.
  images_lam = tf.reshape(lam, [-1, 1, 1, 1])
  labels_lam = lam

  images = images_lam * images + (1 - images_lam) * shuffled_images
  labels = labels_lam * labels + (1 - labels_lam) * shuffled_labels

  return images, labels

In [None]:
# A mixup-augmented copy of the training batches, appended after whatever
# combined_ds already holds.
mixup_ds = train_ds.map(
    mixup,
    num_parallel_calls=tf.data.AUTOTUNE,
)
combined_ds = combined_ds.concatenate(mixup_ds)

In [None]:
def cutmix(images, labels, alpha=0.2):
  """Apply CutMix to one batch: paste a random box from a partner image.

  For every sample a target area fraction lam is drawn as X / (X + Y) from two
  Gamma(alpha) draws, a box with side lengths proportional to sqrt(lam) is
  placed at a random centre, and the pixels inside the box are replaced by the
  pixels of a randomly chosen partner from the same batch.

  The labels are mixed with the fraction of pixels that was *actually*
  replaced, not with the sampled lam: integer truncation makes the box
  2 * (side // 2) + 1 pixels wide, and for lam close to 1 the box can reach
  past the image border, so the two values differ.

  Args:
    images: float tensor of shape [batch_size, h, w, c].
    labels: float tensor of shape [batch_size, num_class].
    alpha: concentration of the two Gamma draws. Defaults to 0.2, the value
      that was previously hard-coded.

  Returns:
    Tuple (mixed images, mixed labels) with the same shapes as the inputs.
  """
  batch_size = tf.shape(images)[0]
  img_height = tf.shape(images)[1]
  img_width = tf.shape(images)[2]

  gamma_1 = tf.random.gamma(shape=[batch_size, 1], alpha=alpha)
  gamma_2 = tf.random.gamma(shape=[batch_size, 1], alpha=alpha)
  # lam is the target cut fraction with shape: [batch_size, 1]
  lam = gamma_1 / (gamma_1 + gamma_2)

  # Box height and width, both with shape [batch_size, 1]. Scaling each side
  # by sqrt(lam) makes the box area roughly lam * h * w.
  cut_height = tf.cast(tf.cast(img_height, tf.float32) * tf.sqrt(lam), tf.int32)
  cut_width = tf.cast(tf.cast(img_width, tf.float32) * tf.sqrt(lam), tf.int32)

  # Centre along the height axis ("x" indexes rows throughout this function),
  # sampled so that the box normally stays inside the image.
  min_height = cut_height // 2
  max_height = img_height - 1 - cut_height // 2
  cut_centre_x = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                   dtype=tf.float32)
  cut_centre_x = cut_centre_x * tf.cast(max_height - min_height, tf.float32) + tf.cast(min_height, tf.float32)

  # Centre along the width axis ("y" indexes columns).
  min_width = cut_width // 2
  max_width = img_width - 1 - cut_width // 2
  cut_centre_y = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                   dtype=tf.float32)
  cut_centre_y = cut_centre_y * tf.cast(max_width - min_width, tf.float32) + tf.cast(min_width, tf.float32)

  # Inclusive box bounds, all with shape [batch_size, 1].
  x1 = tf.cast(cut_centre_x, tf.int32) - cut_height // 2
  x2 = tf.cast(cut_centre_x, tf.int32) + cut_height // 2
  y1 = tf.cast(cut_centre_y, tf.int32) - cut_width // 2
  y2 = tf.cast(cut_centre_y, tf.int32) + cut_width // 2

  # Row / column index of every pixel, shape [1, h, w].
  x_indices = tf.range(img_height)
  y_indices = tf.range(img_width)
  y_grid, x_grid = tf.meshgrid(y_indices, x_indices)
  x_grid = tf.reshape(x_grid, [1, img_height, img_width])
  y_grid = tf.reshape(y_grid, [1, img_height, img_width])

  x1 = tf.reshape(x1, [batch_size, 1, 1])
  x2 = tf.reshape(x2, [batch_size, 1, 1])
  y1 = tf.reshape(y1, [batch_size, 1, 1])
  y2 = tf.reshape(y2, [batch_size, 1, 1])

  # mask is 1.0 inside the box and 0.0 outside, shape [batch_size, h, w]
  mask = tf.logical_and(tf.logical_and(x1 <= x_grid, x_grid <= x2), tf.logical_and(
      y1 <= y_grid, y_grid <= y2))
  mask = tf.cast(mask, dtype=tf.float32)

  # Fraction of pixels really taken from the partner image, shape
  # [batch_size, 1]; this is the weight of the partner's label.
  pasted_fraction = tf.reshape(tf.reduce_mean(mask, axis=[1, 2]), [batch_size, 1])

  indices = tf.random.shuffle(tf.range(batch_size))
  shuffled_images = tf.gather(images, indices)
  shuffled_labels = tf.gather(labels, indices)

  mask = tf.reshape(mask, [batch_size, img_height, img_width, 1])

  images = images * (1.0 - mask) + shuffled_images * mask
  labels = labels * (1.0 - pasted_fraction) + shuffled_labels * pasted_fraction

  return images, labels

In [None]:
# A CutMix-augmented copy of the training batches, appended after whatever
# combined_ds already holds.
cutmix_ds = train_ds.map(
    cutmix,
    num_parallel_calls=tf.data.AUTOTUNE,
)
combined_ds = combined_ds.concatenate(cutmix_ds)

In [None]:
def erase(images, labels):
  """Zero out one random rectangle in every image of the batch.

  For each sample an area fraction lam is drawn uniformly from [0.2, 0.5); the
  rectangle's sides are the image sides scaled by sqrt(lam), and its centre is
  sampled so the rectangle stays inside the image. Pixels inside the rectangle
  are multiplied by 0; labels are returned unchanged.

  Args:
    images: float tensor of shape [batch_size, h, w, c].
    labels: tensor of shape [batch_size, num_class].

  Returns:
    Tuple (images with the rectangle zeroed, labels).
  """
  shape = tf.shape(images)
  batch_size, img_height, img_width = shape[0], shape[1], shape[2]

  # Erased area fraction per sample, shape [batch_size, 1].
  lam = tf.random.uniform(shape=(batch_size, 1), minval=0.2, maxval=0.5, dtype=tf.float32)
  side_scale = tf.sqrt(lam)

  # Rectangle side lengths and their halves, all int32 of shape [batch_size, 1].
  erase_height = tf.cast(tf.cast(img_height, tf.float32) * side_scale, tf.int32)
  erase_width = tf.cast(tf.cast(img_width, tf.float32) * side_scale, tf.int32)
  half_height = erase_height // 2
  half_width = erase_width // 2

  # Centre row: uniform between the lowest and highest row that keep the
  # rectangle inside the image.
  row_lo = half_height
  row_hi = img_height - 1 - half_height
  row_centre = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                 dtype=tf.float32)
  row_centre = row_centre * tf.cast(row_hi - row_lo, tf.float32) + tf.cast(row_lo, tf.float32)

  # Centre column, same construction along the width.
  col_lo = half_width
  col_hi = img_width - 1 - half_width
  col_centre = tf.random.uniform(shape=(batch_size, 1), minval=0, maxval=1,
                                 dtype=tf.float32)
  col_centre = col_centre * tf.cast(col_hi - col_lo, tf.float32) + tf.cast(col_lo, tf.float32)

  row_centre = tf.cast(row_centre, tf.int32)
  col_centre = tf.cast(col_centre, tf.int32)

  # Inclusive rectangle bounds, reshaped to [batch_size, 1, 1] so they
  # broadcast against the [1, h, w] index grids.
  top = tf.reshape(row_centre - half_height, [batch_size, 1, 1])
  bottom = tf.reshape(row_centre + half_height, [batch_size, 1, 1])
  left = tf.reshape(col_centre - half_width, [batch_size, 1, 1])
  right = tf.reshape(col_centre + half_width, [batch_size, 1, 1])

  # rows[i, j] == i and cols[i, j] == j for every pixel position.
  rows, cols = tf.meshgrid(tf.range(img_height), tf.range(img_width), indexing="ij")
  rows = rows[tf.newaxis, ...]
  cols = cols[tf.newaxis, ...]

  # True inside the rectangle, shape [batch_size, h, w].
  inside_rows = tf.logical_and(top <= rows, rows <= bottom)
  inside_cols = tf.logical_and(left <= cols, cols <= right)
  inside = tf.logical_and(inside_rows, inside_cols)

  # 1.0 where the pixel is kept, 0.0 where it is erased.
  keep = 1.0 - tf.cast(inside, dtype=tf.float32)
  keep = tf.reshape(keep, [batch_size, img_height, img_width, 1])

  images = images * keep

  return images, labels

In [None]:
# A random-erasing copy of the training batches, appended after whatever
# combined_ds already holds.
erase_ds = train_ds.map(
    erase,
    num_parallel_calls=tf.data.AUTOTUNE,
)
combined_ds = combined_ds.concatenate(erase_ds)

In [None]:
def label_smoothing(labels, epsilon=0.1):
  """Blend the label rows with a uniform distribution over the classes.

  Every entry is scaled by (1 - epsilon) and epsilon / num_class is added,
  where num_class is the size of axis 1 of `labels`.

  Args:
    labels: label tensor of shape [batch_size, num_class].
    epsilon: total weight moved to the uniform distribution.

  Returns:
    Smoothed labels with the same shape as `labels`.
  """
  num_class = tf.cast(tf.shape(labels)[1], tf.float32)
  keep_weight = 1.0 - epsilon
  uniform_share = epsilon / num_class
  return labels * keep_weight + uniform_share

In [None]:
# Smooth the (possibly already mixed) labels of every batch; images pass
# through untouched.
combined_ds = combined_ds.map(
    lambda images, labels: (images, label_smoothing(labels)),
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
# Shuffle at batch granularity, then prefetch so that the input pipeline
# prepares upcoming batches while the model is busy with the current step.
# NOTE(review): combined_ds is a concatenation of the plain / mixup / cutmix /
# erase copies, and a 100-batch buffer only mixes neighbouring batches, so
# within an epoch the augmentation types still arrive largely one after the
# other. Interleaving the datasets would mix them properly.
combined_ds = combined_ds.shuffle(buffer_size=100).prefetch(tf.data.AUTOTUNE)

In [None]:
class MBConv(tfla.Layer):
  """Inverted-bottleneck convolution block with a squeeze-and-excitation gate.

  Main path: 1x1 conv expanding to 4 * dim channels -> batch norm -> GELU ->
  3x3 depthwise conv -> channel gate (global average pool -> Dense(dim, gelu)
  -> Dense(4 * dim, sigmoid)) -> 1x1 conv projecting to out_channels -> batch
  norm. The input is also projected to out_channels by a separate 1x1 conv and
  added to the main path. All convolutions use stride 1, so the spatial size
  is preserved.

  Args:
    dim: base width; the expanded width is 4 * dim and the gate's hidden
      width is dim.
    out_channels: number of channels of the block output.
    **kwargs: standard Keras Layer arguments (e.g. name, dtype).
  """

  def __init__(self, dim, out_channels, **kwargs):
    super().__init__(**kwargs)
    self.dim = dim
    # Stored so that get_config can report it.
    self.out_channels = out_channels
    self.conv1 = tfla.Conv2D(dim * 4, 1, strides=1, padding="same", use_bias=False)
    self.bn1 = tfla.BatchNormalization()
    self.gelu1 = tfla.Activation("gelu")
    self.depthwise = tfla.DepthwiseConv2D(3, strides=1, padding="same", use_bias=False)
    self.globalaverage = tfla.GlobalAveragePooling2D()
    self.dense1 = tfla.Dense(dim, activation="gelu")
    self.dense2 = tfla.Dense(dim * 4, activation="sigmoid")
    self.conv2 = tfla.Conv2D(out_channels, 1, strides=1, padding="same", use_bias=False)
    self.bn2 = tfla.BatchNormalization()
    self.shortcut = tfla.Conv2D(out_channels, 1, strides=1, padding="same", use_bias=False)

  def call(self, x):
    """Run the block on x of shape [B, H, W, C]; returns [B, H, W, out_channels]."""
    shortcut = self.shortcut(x)

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.gelu1(x)
    x = self.depthwise(x)

    # Squeeze-and-excitation: one sigmoid weight per channel, shape [B, 4 * dim].
    sne = self.globalaverage(x)
    sne = self.dense1(sne)
    sne = self.dense2(sne)
    # Insert the two spatial axes -> [B, 1, 1, 4 * dim] so the gate broadcasts.
    sne = tf.expand_dims(sne, axis=1)
    sne = tf.expand_dims(sne, axis=1)

    x = x * sne
    x = self.conv2(x)
    x = self.bn2(x)

    # return shape:[B, H, W, out_channels]
    return x + shortcut

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({"dim": self.dim, "out_channels": self.out_channels})
    return config

In [None]:
class self_attention(tfla.Layer):
  """Multi-head self-attention over a window_size x window_size token grid.

  A learned relative position bias, one value per head and per relative
  (row, column) offset, is added to the attention logits before the softmax.
  The concatenated heads are passed through the output projection.

  Args:
    dim: channel width C of the input and output tokens.
    window_size: side length of the token grid; inputs must carry
      N = window_size * window_size tokens.
    num_heads: number of attention heads; must divide dim.
    **kwargs: standard Keras Layer arguments (e.g. name, dtype).

  Raises:
    ValueError: if dim is not divisible by num_heads.
  """

  def __init__(self, dim, window_size, num_heads, **kwargs):
    super().__init__(**kwargs)
    if dim % num_heads != 0:
      # Otherwise the reshape into heads in call() would fail with an opaque
      # shape error.
      raise ValueError(
          f"dim ({dim}) must be divisible by num_heads ({num_heads})")
    self.dim = dim
    self.window_size = window_size
    self.num_heads = num_heads
    self.head_dims = dim // num_heads
    self.N = window_size * window_size
    # Plain Python float: 1 / sqrt(head_dims).
    self.scale = float(self.head_dims) ** -0.5
    self.toqkv = tfla.Dense(dim * 3, use_bias=False)
    self.dense = tfla.Dense(dim)

    # One bias per head for each of the (2w - 1)^2 possible relative offsets.
    self.num_rel_pos = (2 * window_size - 1) * (2 * window_size - 1)
    self.rel_pos = self.add_weight(
        shape=(self.num_rel_pos, num_heads),
        initializer=tf.random_normal_initializer(stddev=0.02),
        trainable=True
    )

    coords_h = tf.range(window_size)
    coords_w = tf.range(window_size)
    # coords shape:[2, window_size, window_size]
    coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
    # coords shape:[2, N]
    coords = tf.reshape(coords, shape=(2, -1))
    # rel_pos shape:[2, N, N] -- offset between every pair of tokens
    rel_pos = coords[:, :, None] - coords[:, None, :]

    # Shift offsets from [-(w - 1), w - 1] to [0, 2w - 2].
    # rel_pos_h and rel_pos_w shape:[N, N]
    rel_pos_h = rel_pos[0] + window_size - 1
    rel_pos_w = rel_pos[1] + window_size - 1

    # Flatten (row offset, column offset) into a row index of self.rel_pos.
    # rel_index shape:[N, N]
    self.rel_index = rel_pos_h * (2 * window_size - 1) + rel_pos_w

  def call(self, x):
    """Attend over x of shape [B, N, C]; returns a tensor of shape [B, N, C]."""
    # qkv shape:[B, N, 3 * C]
    qkv = self.toqkv(x)

    # qkv shape:[B, N, 3, num_heads, head_dims]
    qkv = tf.reshape(qkv, shape=(-1, self.N, 3, self.num_heads, self.head_dims))

    # qkv shape:[3, B, num_heads, N, head_dims]
    qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
    # q, k, v shape:[B, num_heads, N, head_dims]
    q, k, v = qkv[0], qkv[1], qkv[2]

    # attn shape:[B, num_heads, N, N]
    attn = tf.matmul(q, k, transpose_b=True)
    attn = attn * self.scale

    # rel_pos shape:[N, N, num_heads]
    rel_pos = tf.gather(self.rel_pos, self.rel_index)
    rel_pos = tf.transpose(rel_pos, perm=(2, 0, 1))
    # rel_pos shape:[1, num_heads, N, N]
    rel_pos = tf.reshape(rel_pos, shape=(1, self.num_heads, self.N, self.N))

    attn = attn + rel_pos
    attn = tf.nn.softmax(attn, axis=-1)
    # out shape:[B, num_heads, N, head_dims]
    out = tf.matmul(attn, v)

    # out shape:[B, N, num_heads, head_dims] -> [B, N, C]
    out = tf.transpose(out, perm=(0, 2, 1, 3))
    out = tf.reshape(out, shape=(-1, self.N, self.dim))

    # Output projection: mixes information across the concatenated heads.
    # return shape:[B, N, C]
    return self.dense(out)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
        "dim": self.dim,
        "window_size": self.window_size,
        "num_heads": self.num_heads,
    })
    return config

In [None]:
inputs = tfla.Input(shape=(224, 224, 3))

# Stem: [B, 224, 224, 3] -> [B, 112, 112, 64]
x = tfla.Conv2D(filters=64, kernel_size=3, strides=2, padding="same", use_bias=False)(inputs)
x = tfla.BatchNormalization()(x)
x = tfla.Activation("gelu")(x)

# Convolutional stage 1: -> [B, 112, 112, 96], downsampled to [B, 56, 56, 96]
x = MBConv(dim=64, out_channels=96)(x)
x = tfla.Conv2D(filters=96, kernel_size=3, strides=2, padding="same", use_bias=False)(x)

# Convolutional stage 2: -> [B, 56, 56, 192], downsampled to [B, 28, 28, 192]
x = MBConv(dim=96, out_channels=192)(x)
x = tfla.Conv2D(filters=192, kernel_size=3, strides=2, padding="same", use_bias=False)(x)

# Attention stage 1 over the 28 x 28 grid, flattened to 784 tokens:
# [B, 28, 28, 192] -> [B, 784, 192] -> attention -> [B, 28, 28, 192]
x = tfla.Reshape(target_shape=(28 * 28, 192))(x)
x = self_attention(dim=192, window_size=28, num_heads=12)(x)
x = tfla.Reshape(target_shape=(28, 28, 192))(x)

# Downsample and widen: -> [B, 14, 14, 384]
x = tfla.Conv2D(filters=384, kernel_size=3, strides=2, padding="same", use_bias=False)(x)

# Attention stage 2 over the 14 x 14 grid:
# [B, 14, 14, 384] -> [B, 196, 384] -> attention -> [B, 14, 14, 384]
x = tfla.Reshape(target_shape=(14 * 14, 384))(x)
x = self_attention(dim=384, window_size=14, num_heads=12)(x)
x = tfla.Reshape(target_shape=(14, 14, 384))(x)

# Final downsample: -> [B, 7, 7, 384]
x = tfla.Conv2D(filters=384, kernel_size=3, strides=2, padding="same", use_bias=False)(x)

# Classification head: [B, 7, 7, 384] -> [B, 384] -> [B, 200] probabilities
x = tfla.GlobalAveragePooling2D()(x)
outputs = tfla.Dense(units=200, activation="softmax")(x)

model = tfm.Model(inputs=inputs, outputs=outputs)



In [None]:
# Print the layer-by-layer overview of the model built above.
model.summary()

In [None]:
# Steps per epoch, derived rather than hard-coded so that the schedule stays
# correct if the batch size or the number of augmented copies changes.
num_train_images = 100_000  # reported when train_ds was created
train_batch_size = 128      # must match train_ds.batch(...)
num_dataset_copies = 4      # plain + mixup + cutmix + erase, concatenated
# Ceiling division: the last, partial batch of every copy is kept.
batches_per_copy = -(-num_train_images // train_batch_size)
steps_per_epoch = num_dataset_copies * batches_per_copy  # 4 * 782 = 3128
epochs = 80
total_steps = steps_per_epoch * epochs

# Cosine decay from 5e-4 over the whole run, ending at alpha * initial rate.
lr = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=5e-4,
    decay_steps=total_steps,
    alpha=1e-2
)

opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=5e-2)

In [None]:
# The model ends in a softmax and the labels are (smoothed / mixed) class
# distributions, hence categorical cross-entropy on probabilities.
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])

# Train on the combined dataset; no validation data is passed, so the history
# only contains training metrics.
history = model.fit(combined_ds, epochs=epochs)

Epoch 1/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m698s[0m 180ms/step - accuracy: 0.0982 - loss: 4.5823
Epoch 2/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m508s[0m 161ms/step - accuracy: 0.2442 - loss: 3.7919
Epoch 3/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m507s[0m 161ms/step - accuracy: 0.3115 - loss: 3.5013
Epoch 4/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m507s[0m 161ms/step - accuracy: 0.3585 - loss: 3.3166
Epoch 5/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m507s[0m 161ms/step - accuracy: 0.3908 - loss: 3.1849
Epoch 6/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m508s[0m 161ms/step - accuracy: 0.4199 - loss: 3.0748
Epoch 7/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m508s[0m 161ms/step - accuracy: 0.4451 - loss: 2.9805
Epoch 8/80
[1m3128/3128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m509s[0m 161ms/step - accuracy: 0.4663 - loss:

In [None]:
# Evaluate on the dataset loaded from the val folder; with the compile
# settings above the returned list is [loss, accuracy].
model.evaluate(test_ds)

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 217ms/step - accuracy: 0.4480 - loss: 2.6958


[2.679743528366089, 0.44670000672340393]

In [None]:
# Training accuracy per epoch, drawn through the explicit figure/axes interface.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'])
ax.set_title('Model Accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.legend(['Train'], loc='lower right')
plt.show()