In [None]:
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFViTModel
from PIL import Image
import numpy as np
import os

train_dir = "/kaggle/input/stanford-car-dataset-by-classes-folder/car_data/car_data/train"
test_dir = "/kaggle/input/stanford-car-dataset-by-classes-folder/car_data/car_data/test"
img_size = (224, 224)
batch_size = 32
checkpoint_path = "./best_model_checkpoint.h5"

classes = sorted(os.listdir(train_dir))

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

def load_and_preprocess_image(img_path, label):
    if not tf.io.gfile.exists(img_path) or tf.io.gfile.isdir(img_path):
        raise ValueError(f"{img_path} is not a valid file.")

    img = tf.io.read_file(img_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.resize(img, img_size)
    img = tf.cast(img, tf.float32) / 255.0  
    
    img_np = img.numpy()
    pil_image = Image.fromarray((img_np * 255).astype(np.uint8))

    pixel_values = feature_extractor(images=pil_image, return_tensors='np')['pixel_values']
    pixel_values = tf.convert_to_tensor(pixel_values, dtype=tf.float32)  
    pixel_values = tf.squeeze(pixel_values, axis=0)  

    return pixel_values, label  

def preprocess_dataset(file_paths, labels, batch_size, shuffle=False):
    def gen():
        for path, label in zip(file_paths, labels):
            try:
                yield load_and_preprocess_image(path, label)
            except Exception as e:
                print(f"Skipping file {path} due to error: {e}")
    
    dataset = tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            tf.TensorSpec(shape=(3, 224, 224), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int64)
        )
    )
    
    dataset_length = len(file_paths)
    print("Dataset length:", dataset_length)
    
    if shuffle and dataset_length > 0:
        dataset = dataset.shuffle(buffer_size=min(dataset_length, 10000))  
    dataset = dataset.batch(batch_size)

    for image_batch, label_batch in dataset.take(1):
        tf.print("Dataset image batch shape:", tf.shape(image_batch))
        tf.print("Dataset label batch shape:", tf.shape(label_batch))
    
    return dataset

def get_file_paths_and_labels(directory):
    file_paths = []
    labels = []
    for class_name in os.listdir(directory):
        class_dir = os.path.join(directory, class_name)
        if os.path.isdir(class_dir):  
            for fname in os.listdir(class_dir):
                full_path = os.path.join(class_dir, fname)
                if os.path.isfile(full_path):  
                    file_paths.append(full_path)
                    labels.append(classes.index(class_name))
    return file_paths, labels

train_file_paths, train_labels = get_file_paths_and_labels(train_dir)
test_file_paths, test_labels = get_file_paths_and_labels(test_dir)

train_ds = preprocess_dataset(train_file_paths, train_labels, batch_size, shuffle=True)

total_size = len(test_file_paths)
val_size = 4020
test_size = total_size - val_size

test_file_paths, test_labels = zip(*list(zip(test_file_paths, test_labels)))
test_ds = preprocess_dataset(test_file_paths, test_labels, batch_size, shuffle=False)
val_ds = test_ds.take(val_size)
test_ds = test_ds.skip(val_size)

base_model = TFViTModel.from_pretrained('google/vit-base-patch16-224')

def create_model():
    inputs = tf.keras.Input(shape=(3, 224, 224), dtype=tf.float32)
    
    features = base_model(inputs)[0]  
    
    pooled_features = tf.keras.layers.GlobalAveragePooling1D()(features)  
    
    num_classes = len(classes)
    logits = tf.keras.layers.Dense(num_classes, activation='softmax')(pooled_features)
    
    model = tf.keras.Model(inputs=inputs, outputs=logits)
    return model


model = create_model()

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
    mode='min'
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=[checkpoint_callback]
)


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



Dataset length: 8144
Dataset image batch shape: [32 3 224 224]
Dataset label batch shape: [32]
Dataset length: 8041
Dataset image batch shape: [32 3 224 224]
Dataset label batch shape: [32]


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing TFViTModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFViTModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFViTModel were not initialized from the PyTorch model and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/20


  output, from_logits = _get_logits(
I0000 00:00:1726071845.148116     101 service.cc:145] XLA service 0x7a82f066fb40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1726071845.148173     101 service.cc:153]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1726071845.297695     101 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
  1/255 [..............................] - ETA: 5:58:33 - loss: 0.0044 - accuracy: 1.0000

In [9]:
model.evaluate(val_ds)



[1.2087591886520386, 0.7096132040023804]