# Random Gabor experiment

> In this quick experiment we will be training an MNIST classifier using `RandomGabor` layers.

In [None]:
#| hide
# Reload imported modules before each cell runs, so edits to library code
# (presumably the local `flayers` package used below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

## Library imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, repeat

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

from flayers.layers import RandomGabor

## Data loading

> We will be using MNIST for a simple and quick test.

In [None]:
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Add a trailing unit channel axis, because the model below declares
# input_shape=(28, 28, 1). Then scale pixels from [0, 255] to [0, 1].
# Casting to float32 *before* dividing avoids materialising float64 arrays
# (which `uint8 / 255.0` would produce) at twice the memory. Keras computes in
# float32 by default anyway.
X_train = rearrange(X_train, "b h w -> b h w 1").astype("float32") / 255.0
X_test = rearrange(X_test, "b h w -> b h w 1").astype("float32") / 255.0

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

## Definition of a simple model

In [None]:
# Seed Python's `random`, NumPy and TensorFlow in one call, so that weight
# initialisation (e.g. the Dense layer) is reproducible across
# Restart & Run All. NOTE(review): assumes RandomGabor draws its random
# parameters from one of these RNGs; TODO confirm in flayers.
tf.keras.utils.set_random_seed(42)

# Baseline for comparison: replace the RandomGabor layer with
# layers.Conv2D(32, 3, input_shape=(28,28,1)).
model = tf.keras.Sequential([
    RandomGabor(n_gabors=4, size=20, input_shape=(28,28,1)),
    layers.MaxPool2D(2),
    layers.Flatten(),
    layers.Dense(10, activation="softmax")
])
# Integer class labels (Y_train has shape (N,)) pair with the sparse variant
# of categorical cross-entropy and the softmax output above.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 13, 13, 32)       0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 5408)              0         
                                                                 
 dense_1 (Dense)             (None, 10)                54090     
                                                                 
Total params: 54,410
Trainable params: 54,410
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Train on the MNIST training split. The held-out test split is passed as
# validation data, so every epoch also reports val_loss / val_accuracy.
# Otherwise X_test / Y_test would be loaded but never used, and the
# experiment would only report training accuracy.
with tf.device('/GPU:0'):
    history = model.fit(X_train, Y_train,
                        batch_size=128, epochs=5,
                        validation_data=(X_test, Y_test))

Epoch 1/5
  6/469 [..............................] - ETA: 4s - loss: 2.1443 - accuracy: 0.3646  

2022-09-06 10:52:08.309019: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.




KeyboardInterrupt: 