In [None]:

import os, random, math
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report

# inline plots
%matplotlib inline

# Reproducibility: seed each RNG this notebook draws from, once, before any
# data generation or model construction.
# NOTE(review): bit-exact repeatability across hardware is not guaranteed by
# seeding alone — confirm if that matters for your results.
SEED = 42
random.seed(SEED)         # Python's built-in `random` module
np.random.seed(SEED)      # NumPy global RNG (np.random.shuffle / permutation in split_dataset)
tf.random.set_seed(SEED)  # TensorFlow global seed (tf.random.* ops in augment_once)

# Record the TensorFlow version next to the results.
print('TensorFlow version:', tf.__version__)


In [None]:

# ----------------------------
# Config
# ----------------------------
# Every tunable value lives here; the cells below read these names directly.
IMG_H, IMG_W = 32, 32            # size every image is resized to (load_gray) and model input size
NUM_CLASSES = 3                  # width of the final softmax layer / one-hot label size
CHANNELS = 1                     # grayscale input -> one channel in the model's Input layer
VARIANTS_PER_CLASS = 24          # >= 20 (change if you want more); augmented copies made per seed image
TEST_PER_CLASS = 4               # held-out test count per class
VAL_SPLIT = 0.2                  # from remaining after taking test
EPOCHS = 25                      # epochs per grid-search run
BATCH_SIZE = 32                  # mini-batch size passed to model.fit

# Optimizer / LR grid
# train_grid trains one fresh model for every (optimizer, learning rate) pair.
OPTIMIZERS = ['adam', 'sgd']
LRS = [1e-3, 1e-2]


In [None]:

def load_gray(path):
    """Read a PNG file as a single-channel float32 image.

    Args:
        path: Filesystem path of the PNG to load.

    Returns:
        A float32 tensor of shape (IMG_H, IMG_W, 1) with values in [0, 1].
    """
    raw_bytes = tf.io.read_file(path)
    # channels=1 asks the decoder for a single grayscale channel.
    pixels = tf.io.decode_png(raw_bytes, channels=1)
    # Integer -> float32 conversion also rescales the values into [0, 1].
    pixels = tf.image.convert_image_dtype(pixels, tf.float32)
    return tf.image.resize(pixels, (IMG_H, IMG_W), method='nearest')

def show_seeds(imgs, labels):
    """Plot the seed images in a single row, titled with their labels.

    The grid width follows the number of (image, label) pairs, so the
    function works for any number of seeds rather than exactly three.
    With three seeds the figure size is (6, 2).

    Args:
        imgs: Sequence of images; each must squeeze to a 2-D array
            (e.g. an (H, W, 1) tensor).
        labels: Sequence of labels, paired positionally with ``imgs``.

    Returns:
        None. Draws and shows a matplotlib figure; does nothing when there
        are no pairs to draw.
    """
    pairs = list(zip(imgs, labels))
    if not pairs:
        # plt.subplot cannot build a zero-column grid; nothing to show anyway.
        return

    n_cols = len(pairs)
    plt.figure(figsize=(2 * n_cols, 2))
    for i, (im, lb) in enumerate(pairs):
        plt.subplot(1, n_cols, i + 1)
        # Drop the trailing channel axis so imshow receives a 2-D image.
        plt.imshow(tf.squeeze(im), cmap='gray')
        plt.title(f'label {lb}')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:

# ---- lightweight TensorFlow equivalents for rotate/translate to avoid tensorflow-addons dependency ----
def tfa_image_rotate(img, radians):
    """Rotate one image about its centre with a projective transform.

    Args:
        img: Image tensor of shape (IMG_H, IMG_W, C).
        radians: Rotation angle in radians (Python float or scalar tensor).

    Returns:
        Tensor of shape (IMG_H, IMG_W, C): the rotated image. Pixels that
        would come from outside the source are filled by reflection.
    """
    # The op requires float32 transforms; cast so float64 angles also work.
    radians = tf.cast(radians, tf.float32)
    c = tf.math.cos(radians)
    s = tf.math.sin(radians)
    h, w = IMG_H, IMG_W
    cx, cy = (w - 1) / 2.0, (h - 1) / 2.0
    # Offsets that keep the image centre (cx, cy) fixed under the rotation.
    tx = cx - c * cx + s * cy
    ty = cy - s * cx - c * cy
    # Rows [a0, a1, a2, b0, b1, b2, c0, c1]: each OUTPUT pixel (x, y) samples
    # the input at (a0*x + a1*y + a2, b0*x + b1*y + b2).
    transform = tf.stack([c, -s, tx, s, c, ty, 0.0, 0.0])
    transform = tf.reshape(transform, (1, 8))
    img_b = tf.expand_dims(img, 0)  # the op works on a batch
    out = tf.raw_ops.ImageProjectiveTransformV3(
        images=img_b,
        transforms=transform,
        output_shape=[h, w],
        # fill_value is a required input of the raw op even though it only
        # takes effect with fill_mode='CONSTANT'; omitting it raises TypeError.
        fill_value=0.0,
        interpolation='BILINEAR',
        fill_mode='REFLECT',
    )
    return tf.squeeze(out, 0)

def tfa_image_translate(img, deltas_xy):
    """Shift one image by (dx, dy) pixels with a projective transform.

    Args:
        img: Image tensor of shape (IMG_H, IMG_W, C).
        deltas_xy: Two-element sequence ``[dx, dy]`` of pixel offsets
            (Python floats or scalar tensors).

    Returns:
        Tensor of shape (IMG_H, IMG_W, C): the shifted image. Pixels that
        would come from outside the source are filled by reflection.
    """
    # The op requires float32 transforms; cast so other dtypes also work.
    dx = tf.cast(deltas_xy[0], tf.float32)
    dy = tf.cast(deltas_xy[1], tf.float32)
    # The transform maps OUTPUT coordinates to INPUT coordinates, so the
    # offsets are negated: output (x, y) samples the input at (x - dx, y - dy).
    transform = tf.stack([1.0, 0.0, -dx, 0.0, 1.0, -dy, 0.0, 0.0])
    transform = tf.reshape(transform, (1, 8))
    img_b = tf.expand_dims(img, 0)  # the op works on a batch
    out = tf.raw_ops.ImageProjectiveTransformV3(
        images=img_b,
        transforms=transform,
        output_shape=[IMG_H, IMG_W],
        # fill_value is a required input of the raw op even though it only
        # takes effect with fill_mode='CONSTANT'; omitting it raises TypeError.
        fill_value=0.0,
        interpolation='BILINEAR',
        fill_mode='REFLECT',
    )
    return tf.squeeze(out, 0)

def random_zoom_keep_size(img, zoom):
    """Zoom an image by ``zoom`` and return it at (IMG_H, IMG_W).

    zoom > 1 crops a centred window and enlarges it; otherwise the image is
    shrunk and reflect-padded back to full size.

    Args:
        img: Image tensor indexed as (row, col, channel).
        zoom: Scalar zoom factor (Python float or scalar tensor).

    Returns:
        The zoomed image resized to (IMG_H, IMG_W) with bilinear resampling.
    """
    out_h, out_w = IMG_H, IMG_W
    if zoom > 1.0:
        # Zoom in: keep the central 1/zoom portion, then scale it back up.
        win_h = tf.cast(tf.round(out_h / zoom), tf.int32)
        win_w = tf.cast(tf.round(out_w / zoom), tf.int32)
        row0 = (out_h - win_h) // 2
        col0 = (out_w - win_w) // 2
        window = img[row0:row0 + win_h, col0:col0 + win_w, :]
        return tf.image.resize(window, (out_h, out_w), method='bilinear')
    else:
        # Zoom out: shrink, then centre the result inside a reflected border.
        small_h = tf.cast(tf.round(out_h * zoom), tf.int32)
        small_w = tf.cast(tf.round(out_w * zoom), tf.int32)
        shrunk = tf.image.resize(img, (small_h, small_w), method='bilinear')
        above = (out_h - small_h) // 2
        below = out_h - small_h - above
        before = (out_w - small_w) // 2
        after = out_w - small_w - before
        border = [[above, below], [before, after], [0, 0]]
        framed = tf.pad(shrunk, border, mode='REFLECT')
        # Already out_h x out_w; this resize mirrors the zoom-in branch's final step.
        return tf.image.resize(framed, (out_h, out_w), method='bilinear')

@tf.function
def augment_once(x):
    """Return one randomly augmented copy of image ``x``.

    Applies, in order: rotation, translation, zoom, an occasional horizontal
    flip, brightness and contrast jitter, and additive Gaussian noise; the
    result is clipped to [0, 1].

    Args:
        x: Image tensor. Assumed to be (IMG_H, IMG_W, C) float32 in [0, 1],
            as produced by load_gray — TODO confirm for other callers.

    Returns:
        The augmented image, clipped to [0, 1].
    """
    # random rotation ±15°
    # Drawn in degrees, then converted to radians for tfa_image_rotate.
    angle = tf.random.uniform([], minval=-15.0, maxval=15.0) * math.pi / 180.0
    x = tfa_image_rotate(x, angle)

    # small translations ±3px
    tx = tf.random.uniform([], -3.0, 3.0)
    ty = tf.random.uniform([], -3.0, 3.0)
    x = tfa_image_translate(x, [tx, ty])

    # random zoom/crop (0.9–1.1)
    zoom = tf.random.uniform([], 0.9, 1.1)
    x = random_zoom_keep_size(x, zoom)

    # optional flip (set to low prob for digits)
    # NOTE(review): a mirrored digit may no longer look like its class — confirm
    # that a 30% flip is wanted for this data.
    if tf.random.uniform([]) < 0.3:
        x = tf.image.flip_left_right(x)

    # brightness/contrast jitter
    # These can push values outside [0, 1]; the clip below restores the range.
    x = tf.image.random_brightness(x, max_delta=0.15)
    x = tf.image.random_contrast(x, lower=0.8, upper=1.2)

    # Gaussian noise
    noise = tf.random.normal(tf.shape(x), mean=0.0, stddev=0.03)
    x = tf.clip_by_value(x + noise, 0.0, 1.0)
    return x


In [None]:

def make_dataset(seed_paths):
    """Build an augmented dataset from a handful of seed images.

    Every seed is loaded and displayed, then contributes itself plus
    VARIANTS_PER_CLASS augmented copies, all carrying the seed's label.

    Args:
        seed_paths: Iterable of ``(path, label)`` pairs.

    Returns:
        ``(X, y)``: ``X`` is a float32 array of stacked images and ``y`` an
        int32 array of labels, ordered seed by seed (original first).

    Raises:
        AssertionError: If a seed path does not exist.
    """
    seed_imgs, seed_labels = [], []
    for p, seed_label in seed_paths:
        assert os.path.exists(p), f'Missing file: {p}'
        seed_imgs.append(load_gray(p))
        seed_labels.append(seed_label)

    # Visual sanity check of the seeds before augmenting them.
    show_seeds(seed_imgs, seed_labels)

    samples, targets = [], []
    for seed_img, seed_label in zip(seed_imgs, seed_labels):
        original = seed_img.numpy()
        variants = [augment_once(seed_img).numpy() for _ in range(VARIANTS_PER_CLASS)]
        # The untouched seed goes first, followed by its augmented copies.
        samples.extend([original, *variants])
        targets.extend([seed_label] * (1 + len(variants)))

    X = np.stack(samples, axis=0).astype('float32')
    y = np.array(targets, dtype=np.int32)
    return X, y

# Define your seed paths here (update if needed)
# Each entry is (path, integer class label). make_dataset asserts that every
# path exists, so relative paths resolve against the current working directory.
seed_paths = [
    ('zero_0.png', 0),
    ('one_0.png', 1),
    ('two_0.png', 2),
]

# Build the dataset: each seed plus VARIANTS_PER_CLASS augmented copies of it.
X, y = make_dataset(seed_paths)
print('Dataset shape:', X.shape, y.shape)


In [None]:

def split_dataset(X, y, num_classes=NUM_CLASSES, test_per_class=TEST_PER_CLASS, val_split=VAL_SPLIT):
    """Split (X, y) into train, validation and test sets.

    For each class, ``test_per_class`` randomly chosen samples are held out
    for testing. The remaining samples are shuffled together and the first
    ``val_split`` fraction becomes the validation set. Randomness comes from
    the global NumPy RNG.

    Args:
        X: Array of samples, indexed along axis 0.
        y: Array of integer class labels aligned with ``X``.
        num_classes: Classes are taken to be ``0 .. num_classes - 1``.
        test_per_class: Number of samples per class reserved for testing.
        val_split: Fraction of the non-test samples used for validation.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``.
    """
    held_out, remaining = [], []
    for cls in range(num_classes):
        members = np.nonzero(y == cls)[0]
        np.random.shuffle(members)  # in place
        held_out.append(members[:test_per_class])
        remaining.append(members[test_per_class:])

    # Gather the indices first, then slice X and y once per subset.
    test_idx = np.concatenate(held_out, axis=0)
    keep_idx = np.concatenate(remaining, axis=0)
    X_test, y_test = X[test_idx], y[test_idx]
    X_keep, y_keep = X[keep_idx], y[keep_idx]

    # Mix the classes before carving off the validation slice.
    order = np.random.permutation(len(X_keep))
    X_keep, y_keep = X_keep[order], y_keep[order]
    n_val = int(len(X_keep) * val_split)
    X_val, X_train = X_keep[:n_val], X_keep[n_val:]
    y_val, y_train = y_keep[:n_val], y_keep[n_val:]

    print(f'Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}')
    return X_train, y_train, X_val, y_val, X_test, y_test

# Split once using the config-cell defaults (NUM_CLASSES, TEST_PER_CLASS, VAL_SPLIT).
X_train, y_train, X_val, y_val, X_test, y_test = split_dataset(X, y)


In [None]:

def build_model():
    """Create the small CNN classifier (uncompiled).

    Architecture: two Conv2D(3x3, relu, same) + MaxPool2D stages with 16 and
    32 filters, then Flatten, Dense(32, relu) and a NUM_CLASSES-way softmax.

    Returns:
        A ``keras.Sequential`` model taking (IMG_H, IMG_W, CHANNELS) inputs.
    """
    stack = [layers.Input(shape=(IMG_H, IMG_W, CHANNELS))]
    # Convolutional feature extractor: each stage ends in a MaxPool2D.
    for n_filters in (16, 32):
        stack.append(layers.Conv2D(n_filters, 3, activation='relu', padding='same'))
        stack.append(layers.MaxPool2D())
    # Classification head.
    stack.append(layers.Flatten())
    stack.append(layers.Dense(32, activation='relu'))
    stack.append(layers.Dense(NUM_CLASSES, activation='softmax'))
    return keras.Sequential(stack)


In [None]:

def _make_optimizer(opt_name, lr):
    """Return a fresh Keras optimizer for ``opt_name`` at learning rate ``lr``."""
    if opt_name == 'adam':
        return keras.optimizers.Adam(learning_rate=lr)
    if opt_name == 'sgd':
        return keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
    raise ValueError('Unknown optimizer')


def _fit_keep_best(model, X_train, y_train_cat, X_val, y_val_cat):
    """Fit ``model`` and leave it holding the weights of its best validation epoch."""
    snapshot = {'val_acc': -1.0, 'weights': None}

    def _remember_best(epoch, logs=None):
        current = (logs or {}).get('val_accuracy')
        # Strict '>' keeps the earliest epoch that reaches the maximum.
        if current is not None and current > snapshot['val_acc']:
            snapshot['val_acc'] = current
            snapshot['weights'] = model.get_weights()

    hist = model.fit(
        X_train, y_train_cat,
        validation_data=(X_val, y_val_cat),
        epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,
        callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=_remember_best)],
    )
    if snapshot['weights'] is not None:
        model.set_weights(snapshot['weights'])
    return hist


def train_grid(X_train, y_train, X_val, y_val, optimizers=OPTIMIZERS, lrs=LRS):
    """Grid-search optimizer and learning rate; return the best run.

    One fresh model is trained per (optimizer, learning rate) pair. Runs are
    ranked by their highest per-epoch validation accuracy, and each model is
    rolled back to the weights of that epoch, so the returned model is the
    one that actually achieved the reported ``val_acc`` (not whatever the
    last epoch happened to produce).

    Args:
        X_train, y_train: Training images and integer labels.
        X_val, y_val: Validation images and integer labels.
        optimizers: Optimizer names to try; ``'adam'`` and ``'sgd'`` are supported.
        lrs: Learning rates to try.

    Returns:
        Dict with keys ``'val_acc'``, ``'history'``, ``'model'``,
        ``'opt_name'`` and ``'lr'`` describing the best run.

    Raises:
        ValueError: If an optimizer name is not recognised.
    """
    best = {'val_acc': -1.0, 'history': None, 'model': None, 'opt_name': None, 'lr': None}
    # categorical_crossentropy expects one-hot targets.
    y_train_cat = keras.utils.to_categorical(y_train, NUM_CLASSES)
    y_val_cat = keras.utils.to_categorical(y_val, NUM_CLASSES)

    for opt_name in optimizers:
        for lr in lrs:
            model = build_model()
            opt = _make_optimizer(opt_name, lr)

            model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
            print(f'=== Training with {opt_name.upper()} lr={lr} ===')
            hist = _fit_keep_best(model, X_train, y_train_cat, X_val, y_val_cat)
            val_acc = max(hist.history['val_accuracy'])
            print(f'Max val_acc: {val_acc:.4f}')
            # Strict '>' keeps the earliest configuration on ties.
            if val_acc > best['val_acc']:
                best.update({'val_acc': val_acc, 'history': hist, 'model': model, 'opt_name': opt_name, 'lr': lr})
    return best

# Run the optimizer x learning-rate grid defined in the config cell.
# best['val_acc'] is the highest per-epoch validation accuracy of the winning run.
best = train_grid(X_train, y_train, X_val, y_val)
print(f"Best setting -> optimizer: {best['opt_name']}, lr: {best['lr']}, best_val_acc: {best['val_acc']:.4f}")


In [None]:

def plot_accuracy(history):
    """Plot training and validation accuracy per epoch.

    Args:
        history: Object whose ``.history`` dict holds the per-epoch lists
            ``'accuracy'`` and ``'val_accuracy'`` (as returned by ``model.fit``).

    Returns:
        None. Draws and shows a matplotlib figure.
    """
    train_curve = history.history['accuracy']
    val_curve = history.history['val_accuracy']
    # Epochs are numbered from 1 on the x-axis.
    epoch_axis = np.arange(1, len(train_curve) + 1)

    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(epoch_axis, train_curve, label='train acc')
    ax.plot(epoch_axis, val_curve, label='val acc')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('Train vs. Val Accuracy (best run)')
    ax.legend()
    fig.tight_layout()
    plt.show()

# Learning curves of the winning grid-search run.
plot_accuracy(best['history'])


In [None]:

def plot_confmat(cm, class_names=('0','1','2')):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: 2-D integer array; rows are labelled 'True' and columns
            'Predicted'.
        class_names: Tick labels, one per class, in matrix order.

    Returns:
        None. Draws and shows a matplotlib figure.
    """
    plt.figure(figsize=(3.8,3.2))
    plt.imshow(cm, interpolation='nearest')
    plt.title('Confusion Matrix')
    plt.colorbar()

    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names)
    plt.yticks(positions, class_names)

    # Cells above half the peak count get white text, the rest black text.
    peak = cm.max()
    thresh = peak / 2.0 if peak > 0 else 0.5
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            count = cm[row, col]
            text_color = 'white' if count > thresh else 'black'
            plt.text(col, row, format(count, 'd'),
                     horizontalalignment='center',
                     color=text_color)

    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.tight_layout()
    plt.show()

# Evaluate the model selected by the grid search on the held-out test split.
model = best['model']
y_pred_prob = model.predict(X_test, batch_size=64, verbose=0)
y_pred = np.argmax(y_pred_prob, axis=1)  # most probable class per sample

# Derive the label set from the config instead of hard-coding [0, 1, 2], so
# the metrics, confusion matrix and report all cover the same classes and
# stay correct when NUM_CLASSES changes.
class_ids = list(range(NUM_CLASSES))
class_names = tuple(str(c) for c in class_ids)

acc = accuracy_score(y_test, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_test, y_pred, labels=class_ids, average='macro', zero_division=0
)

print('=== Test Metrics ===')
print(f'Accuracy:  {acc:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall:    {recall:.4f}')
print(f'F1 (macro):{f1:.4f}')

cm = confusion_matrix(y_test, y_pred, labels=class_ids)
print('\nConfusion Matrix:\n', cm)
plot_confmat(cm, class_names=class_names)

print('\nClassification report:')
# zero_division=0 matches the macro metrics above: a class that is never
# predicted scores 0 rather than emitting an undefined-metric warning.
print(classification_report(y_test, y_pred, labels=class_ids, digits=4, zero_division=0))



### Short Summary (fill with your actual results)
- **Best setting:** *(e.g., Adam @ 1e-3)* achieved the highest validation accuracy with stable learning curves.
- **Generalization:** Test **F1 (macro)** ≈ *…*; most confusion between *1 ↔ 2* (expected with tiny data).
- **Fit signs:** Slight overfitting after epoch ~… (train↑, val↔/↓). Consider stronger augmentation or early stopping.
