In [None]:
import os
import shutil

def reorganize_val_split(val_dir):
    """Move validation images into one sub-folder per class.

    Reads ``<val_dir>/val_annotations.txt`` (tab-separated; the first two
    columns are the image file name and the class id) and moves every image
    from ``<val_dir>/images`` into ``<val_dir>/<class id>/``.

    The function is safe to re-run: images that were already moved are
    skipped, and if ``<val_dir>/images`` no longer exists the split is
    treated as already reorganised and nothing is done.

    Args:
        val_dir: Path of the validation directory.

    Returns:
        The number of images moved by this call.
    """
    images_dir = os.path.join(val_dir, "images")
    ann_file = os.path.join(val_dir, "val_annotations.txt")

    # Nothing left to do once the flat images folder has been removed.
    if not os.path.isdir(images_dir):
        return 0

    # Read annotations, ignoring blank lines (e.g. a trailing newline at EOF)
    # that would otherwise break the tuple unpacking below.
    with open(ann_file) as f:
        annotations = [line.strip().split('\t') for line in f if line.strip()]

    # Create class folders and move images
    moved = 0
    for img, cls, *_ in annotations:
        src = os.path.join(images_dir, img)
        if not os.path.exists(src):
            # Already moved by an earlier, possibly interrupted, run.
            continue
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        shutil.move(src, os.path.join(cls_dir, img))
        moved += 1

    # Only remove the folder when it is really empty; os.rmdir raises otherwise.
    if not os.listdir(images_dir):
        os.rmdir(images_dir)

    return moved


base_dir = "/content/tiny-imagenet-200"
val_dir = os.path.join(base_dir, "val")
images_dir = os.path.join(val_dir, "images")
ann_file = os.path.join(val_dir, "val_annotations.txt")

n_moved = reorganize_val_split(val_dir)
print(f"Moved {n_moved} validation images into class folders.")

In [None]:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"

!pip -q install tf-keras

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as layers

In [None]:
# load training set
# Images are resized to 256x256 and left unbatched (batch_size=None) so that
# per-image augmentation can be applied before batching.
# A fixed seed makes the shuffle order reproducible across runs.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "/content/tiny-imagenet-200/train",
    image_size=(256, 256),
    batch_size=None,
    label_mode="int",
    shuffle=True,
    seed=42,
)

# load test set
# Evaluation does not need shuffling; a fixed order keeps predictions aligned
# with the files on disk.
# NOTE(review): training crops 224x224 patches out of 256x256 images, while
# evaluation resizes straight to 224x224, so object scale differs slightly
# between the two pipelines. Confirm this is intended.
test_ds = tf.keras.utils.image_dataset_from_directory(
    "/content/tiny-imagenet-200/val",
    image_size=(224, 224),
    batch_size=64,
    label_mode="int",
    shuffle=False,
)

Found 100000 files belonging to 200 classes.
Found 10000 files belonging to 200 classes.


In [None]:
def random_crop(image, crop_size=224):
  """Take a random square crop of an image and randomly mirror it.

  Args:
    image: A single 3-channel image tensor of rank 3 (the crop size passed to
      `tf.image.random_crop` has three entries ending in 3). Its first two
      dimensions must be at least `crop_size`.
    crop_size: Side length of the square crop. Defaults to 224, the value
      that was previously hard-coded.

  Returns:
    A tensor of shape `(crop_size, crop_size, 3)`.
  """
  image = tf.image.random_crop(image, (crop_size, crop_size, 3))
  image = tf.image.random_flip_left_right(image)
  return image

In [None]:
# Augment every training image individually (the dataset was loaded with
# batch_size=None), then group the results into batches of 64.
# num_parallel_calls lets tf.data run the augmentation on several elements
# concurrently instead of one at a time.
train_ds = train_ds.map(
    lambda x, y: (random_crop(x), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)
train_ds = train_ds.batch(64)

In [None]:
def normalisation(images):
  """Rescale pixel values and standardise each channel with mean 0.5, std 0.5.

  The input is divided by 255, then 0.5 is subtracted and the result is
  divided by 0.5 per channel. Inputs in [0, 255] therefore end up in [-1, 1].

  Args:
    images: Tensor whose last axis can broadcast against a length-3 vector.

  Returns:
    The normalised tensor.
  """
  channel_mean = tf.constant([0.5, 0.5, 0.5])
  channel_std = tf.constant([0.5, 0.5, 0.5])
  scaled = images / 255.0
  centred = scaled - channel_mean
  return centred / channel_std

In [None]:
def _normalise_images(images, labels):
  """Normalise the images of an (images, labels) pair; labels pass through."""
  return normalisation(images), labels

# Same pixel normalisation for both the training and the test pipeline.
train_ds = train_ds.map(_normalise_images)
test_ds = test_ds.map(_normalise_images)

In [None]:
def tonchw(image, label):
  """Move axis 3 of a rank-4 batch to position 1 (NHWC -> NCHW).

  Args:
    image: Rank-4 tensor; the permutation used has four entries.
    label: Passed through unchanged.

  Returns:
    A `(transposed_image, label)` tuple.
  """
  channels_first = tf.transpose(image, perm=[0, 3, 1, 2])
  return channels_first, label

In [None]:
# Convert both pipelines to channels-first batches, then prefetch so that
# input preparation can overlap with model execution.
train_ds = train_ds.map(tonchw).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(tonchw).prefetch(tf.data.AUTOTUNE)

In [None]:
from transformers import TFViTModel

# Hugging Face Hub id of the backbone checkpoint.
VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# from_pt=True initialises the TensorFlow model from the PyTorch weights.
# use_safetensors=False presumably forces the non-safetensors weight file;
# verify against the installed transformers version.
vit = TFViTModel.from_pretrained(VIT_CHECKPOINT, from_pt=True, use_safetensors=False)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.
All PyTorch model weights were used when initializing TFViTModel.

All the weights of TFViTModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTModel for predictions without further training.


In [None]:
# Classifier: ViT backbone -> first token of the last hidden state ->
# dropout -> linear layer with 200 outputs (one logit per class).
inputs = layers.Input(shape=(3, 224, 224))

# Do not pass `training=True` here. An explicit value is baked into the
# functional graph, so the backbone would also run in training mode during
# `evaluate()` and `predict()`. Without it, Keras forwards the right mode
# itself: training in `fit()`, inference otherwise.
outputs = vit(inputs)
cls_token = outputs.last_hidden_state[:, 0, :]
x = layers.Dropout(0.1)(cls_token)
x = layers.Dense(200)(x)  # raw logits; the loss is built with from_logits=True

model = tf.keras.Model(inputs, x)

In [None]:
# Stage 1: freeze the backbone so only the layers added on top are trained.
# compile() is called after changing `trainable` so the change is picked up.
vit.trainable = False

head_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
head_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=head_optimizer,
    loss=head_loss,
    metrics=["accuracy"],
)

model.fit(train_ds, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7b27ebd2cd10>

In [None]:
# Stage 2: unfreeze the backbone and fine-tune the whole model with a much
# smaller learning rate. The model is re-compiled after changing `trainable`.
vit.trainable = True

finetune_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
finetune_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=finetune_optimizer,
    loss=finetune_loss,
    metrics=["accuracy"],
)

model.fit(train_ds, epochs=20)

Epoch 1/20




Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tf_keras.src.callbacks.History at 0x7b27e3f0b680>

In [None]:
# Evaluate on the test dataset. The model was compiled with one loss and
# metrics=["accuracy"], so the displayed list is [loss, accuracy].
model.evaluate(test_ds)



[0.5162237882614136, 0.8986999988555908]