Guidance on using this dataset: https://www.kaggle.com/code/acelevin/identifying-playing-cards

Download dataset from: https://www.kaggle.com/datasets/gunhcolab/object-detection-dataset-standard-52card-deck/data

In [1]:
import collections
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image

In [23]:
# Hyperparameters used by the cells below.
# BATCH_SIZE: number of examples per batch produced by the tf.data input pipeline.
BATCH_SIZE = 32
# NUM_EPOCHS: value passed as `epochs` to model.fit.
NUM_EPOCHS = 20

### Model Architecture

In [5]:
class CardPredictor(tf.keras.Model):
    """Convolutional classifier that maps an RGB image to a playing-card class.

    The network is three stages of (Conv-BatchNorm-ReLU) x 2 followed by
    max-pooling and dropout, with 32, 64 and 128 filters respectively, then
    global average pooling, a regularized 256-unit dense layer and a softmax
    output layer.

    Args:
        image_shape: (height, width, channels) of the input images.
        num_classes: number of output classes (size of the softmax layer).
        learning_rate: learning rate of the Adam optimizer stored on
            ``self.optimizer``.
    """

    def __init__(self, image_shape=(300, 300, 3), num_classes=52, learning_rate=0.0001):
        super().__init__()
        # Kept as an attribute because the training cell passes
        # `model.optimizer` to `model.compile`.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        # Random geometric augmentation placed at the front of the network.
        data_augmentation = tf.keras.models.Sequential([
            tf.keras.layers.RandomRotation(0.1),
            tf.keras.layers.RandomZoom(0.1),
        ])

        layers = [
            tf.keras.layers.InputLayer(image_shape),
            data_augmentation,
        ]

        # (filters, dropout rate) per convolutional stage; layer order within
        # each stage matches the original hand-written stack exactly.
        for filters, dropout_rate in ((32, 0.2), (64, 0.3), (128, 0.4)):
            layers += self._conv_bn_relu(filters)
            layers += self._conv_bn_relu(filters)
            layers.append(tf.keras.layers.MaxPooling2D(2, 2))
            layers.append(tf.keras.layers.Dropout(dropout_rate))

        layers += [
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(
                256,
                activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(0.001),
            ),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ]

        self.architecture = layers
        self.sequential = tf.keras.Sequential(self.architecture, name="card_predictor")

    @staticmethod
    def _conv_bn_relu(filters):
        """Returns one 3x3 'same' convolution (no bias) + BatchNorm + ReLU unit."""
        return [
            # The bias is omitted because BatchNormalization adds its own offset.
            tf.keras.layers.Conv2D(filters, 3, padding='same', use_bias=False),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
        ]

    def call(self, x, training=None):
        """Passes a batch of images through the network.

        Args:
            x: batch of input images.
            training: whether the layers should run in training mode. It is
                forwarded explicitly so that Dropout, BatchNormalization and
                the augmentation layers are switched consistently.
                NOTE(review): without a `training` parameter on `call`, recent
                Keras versions appear to invoke the model without any training
                flag, leaving each layer on its own default mode - confirm
                against the installed Keras version.

        Returns:
            Per-class probabilities produced by the final softmax layer.
        """
        return self.sequential(x, training=training)

    @staticmethod
    def loss_fn(labels, predictions):
        """Sparse categorical cross-entropy between labels and predictions.

        `predictions` are treated as probabilities (the network ends in a
        softmax), which is the default of the Keras loss function used here.
        """
        return tf.keras.losses.sparse_categorical_crossentropy(labels, predictions)

### Dataset Loading

In [6]:
# SECURITY: pickle.load can execute arbitrary code while deserializing, so only
# open a train.pkl that was obtained from a source you trust.
with open('train.pkl', 'rb') as file:
    data = pickle.load(file)

# Flatten the annotations into {image path: class label}. The outer keys of
# `data` are not needed; if two records share an image path, the later one wins.
new_data = {
    record['img_path']: record['class_label']
    for record in data.values()
}

In [24]:
# Fraction of the examples used for training; the rest is held out for validation.
TRAIN_FRACTION = 0.8
# Fixed seed so the train/validation partition is the same on every run.
SPLIT_SEED = 42
# Target (height, width) every image is resized to.
IMAGE_SIZE = (300, 300)


def load_train_image(image_path, label):
    """Reads one JPEG from disk and returns it as a normalized float image.

    Args:
        image_path: scalar string tensor holding the path of the image file.
        label: class label of the image; returned unchanged.

    Returns:
        (image, label) where image is a float32 tensor resized to IMAGE_SIZE
        with 3 channels and values divided by 255.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    image = tf.cast(image, tf.float32) / 255.0
    return image, label


def _make_pipeline(paths, labels, shuffle):
    """Builds a batched, prefetched dataset of (image, label) from paths/labels."""
    pipeline = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        # Shuffle the path strings, not decoded images: a full-size buffer of
        # paths is cheap, whereas buffering decoded 300x300x3 float images
        # costs roughly 1 MB per buffered element.
        pipeline = pipeline.shuffle(buffer_size=len(paths))
    return (
        pipeline
        .map(load_train_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )


# Materialize the dict views as lists so they convert cleanly to tensors.
image_paths = list(new_data.keys())
image_labels = list(new_data.values())
if len(image_paths) < 2:
    raise ValueError("Need at least two labelled images to build a train/validation split.")

# Partition the examples ONCE, before any tf.data shuffling. Splitting a dataset
# with take()/skip() downstream of a shuffle() that reshuffles on every
# iteration lets validation examples change each epoch and overlap with the
# examples the model was trained on.
split_order = np.random.default_rng(SPLIT_SEED).permutation(len(image_paths))
# train_size is counted in examples (not batches).
train_size = int(TRAIN_FRACTION * len(image_paths))
train_indices = split_order[:train_size]
val_indices = split_order[train_size:]

train_dataset = _make_pipeline(
    [image_paths[i] for i in train_indices],
    [image_labels[i] for i in train_indices],
    shuffle=True,   # reshuffled every epoch, within the training split only
)
val_dataset = _make_pipeline(
    [image_paths[i] for i in val_indices],
    [image_labels[i] for i in val_indices],
    shuffle=False,  # fixed order keeps validation metrics comparable across epochs
)

# Kept for backward compatibility: every batch, training batches first.
dataset = train_dataset.concatenate(val_dataset)

In [35]:
# This cell is optional.
# It checks that every class is represented in both the validation and the
# training split. Iterating the datasets runs the full input pipeline for
# every example, so this cell can take a while.
def count_labels(batched_dataset):
    """Tallies how often each label occurs across all batches of a dataset.

    Args:
        batched_dataset: iterable of (images, labels) batches, where labels is
            a tensor of class labels.

    Returns:
        collections.Counter mapping label -> number of occurrences.
    """
    counts = collections.Counter()
    for _, label_batch in batched_dataset:
        counts.update(label_batch.numpy().tolist())
    return counts


v_counts = count_labels(val_dataset)
t_counts = count_labels(train_dataset)

# Each row shows the validation count first, then the training count.
print("Label counts (validation, train):")
for i in range(52):
    print(f"Label {i}: {v_counts[i]}, {t_counts[i]}")

2025-04-27 15:09:06.453901: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


KeyboardInterrupt: 

### Model Loading and Training

In [25]:
# Start from a freshly initialized network. To continue from a previously saved
# model instead, swap in the commented-out line below.
model = CardPredictor()
#model = tf.keras.models.load_model("MODEL_PATH")

# The optimizer and loss both come from the model class itself.
model.compile(
    loss=model.loss_fn,
    optimizer=model.optimizer,
    metrics=['accuracy'],
)

In [33]:
# Keep the returned History object so the per-epoch loss/accuracy values
# (history.history) remain available for inspection or plotting afterwards.
# Note: re-running this cell continues training from the current weights.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=NUM_EPOCHS,
    verbose=1,
)

Epoch 1/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 209ms/step - accuracy: 0.8509 - loss: 10751.3350 - val_accuracy: 0.8620 - val_loss: 9784.5518
Epoch 2/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 214ms/step - accuracy: 0.8442 - loss: 13309.9189 - val_accuracy: 0.8815 - val_loss: 8711.1084
Epoch 3/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 216ms/step - accuracy: 0.8650 - loss: 11565.9199 - val_accuracy: 0.8721 - val_loss: 11484.7246
Epoch 4/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 214ms/step - accuracy: 0.8758 - loss: 10543.7910 - val_accuracy: 0.8699 - val_loss: 11041.3184
Epoch 5/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 212ms/step - accuracy: 0.8615 - loss: 12944.3428 - val_accuracy: 0.8577 - val_loss: 16974.2871
Epoch 6/20
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 211ms/step - accuracy: 0.8621 - loss: 15146.7959 - val_accu

<keras.src.callbacks.history.History at 0x1335817f0>

In [34]:
# Write the trained weights to an HDF5 weights file in the working directory.
weights_path = 'model_weights.weights.h5'
model.save_weights(weights_path)