In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os,sys
import optax
from flax.training.train_state import TrainState
from flax import linen as nn
from jax import numpy as jnp

In [2]:
# Reading in the Data
train_df = pd.read_csv("data/sign_mnist_train.csv")
valid_df = pd.read_csv("data/sign_mnist_test.csv")

In [3]:
sample_df = train_df.head().copy()  # Grab the top 5 rows
sample_df.pop('label')
sample_x = sample_df.values
sample_x

array([[107, 118, 127, ..., 204, 203, 202],
       [155, 157, 156, ..., 103, 135, 149],
       [187, 188, 188, ..., 195, 194, 195],
       [211, 211, 212, ..., 222, 229, 163],
       [164, 167, 170, ..., 163, 164, 179]], shape=(5, 784))

In [4]:
sample_x.shape

(5, 784)

In [5]:
IMG_HEIGHT = 28
IMG_WIDTH = 28
IMG_CHS = 1

sample_x = sample_x.reshape(-1, IMG_CHS, IMG_HEIGHT, IMG_WIDTH)
sample_x.shape

(5, 1, 28, 28)

In [6]:
# SOLUTION
class MyDataset:
    def __init__(self, base_df):
        x_df = base_df.copy()  # Some operations below are in-place
        y_df = x_df.pop('label')
        x_df = x_df.values / 255  # Normalize values from 0 to 1
        x_df = x_df.reshape(-1, IMG_CHS, IMG_WIDTH, IMG_HEIGHT)
        self.xs = jnp.array(x_df)
        self.ys = jnp.array(y_df)

    def __getitem__(self, idx):
        x = self.xs[idx]
        y = self.ys[idx]
        return x, y

    def __len__(self):
        return len(self.xs)

In [7]:
def create_batches(dataset, batch_size):
    num_batches = len(dataset) // batch_size
    for i in range(num_batches):
        yield dataset[i * batch_size:(i + 1) * batch_size]

In [8]:
BATCH_SIZE = 32

train_data = MyDataset(train_df)
# Creating batches
train_batches = list(create_batches(train_data, BATCH_SIZE))

# Putting the data on the device
train_loader = [jax.device_put(batch) for batch in train_batches]


In [9]:
valid_data = MyDataset(valid_df)
valid_batches = list(create_batches(valid_data, BATCH_SIZE))
valid_loader = [jax.device_put(batch) for batch in valid_batches]

In [10]:
train_N = len(train_loader)
valid_N = len(valid_loader)

In [11]:
batch = next(iter(train_loader))
batch

(Array([[[[0.41960785, 0.4627451 , 0.49803922, ..., 0.6666667 ,
           0.6666667 , 0.6627451 ],
          [0.43529412, 0.4745098 , 0.5058824 , ..., 0.67058825,
           0.67058825, 0.6666667 ],
          [0.44313726, 0.48235294, 0.5137255 , ..., 0.67058825,
           0.67058825, 0.67058825],
          ...,
          [0.5568628 , 0.5882353 , 0.62352943, ..., 0.7921569 ,
           0.7882353 , 0.78431374],
          [0.5568628 , 0.5921569 , 0.627451  , ..., 0.8       ,
           0.79607844, 0.7921569 ],
          [0.5568628 , 0.5921569 , 0.627451  , ..., 0.8       ,
           0.79607844, 0.7921569 ]]],
 
 
        [[[0.60784316, 0.6156863 , 0.6117647 , ..., 0.5411765 ,
           0.36078432, 0.42352942],
          [0.61960787, 0.62352943, 0.62352943, ..., 0.5568628 ,
           0.45490196, 0.56078434],
          [0.6313726 , 0.6313726 , 0.6313726 , ..., 0.5764706 ,
           0.49019608, 0.54901963],
          ...,
          [0.63529414, 0.62352943, 0.5686275 , ..., 0.35686275,


In [12]:
batch[0].shape

(32, 1, 28, 28)

In [13]:
batch[1].shape

(32,)

In [38]:
# model
n_classes = 24
kernel_size = 3
flattened_img_size = 75 * 3 * 3

class CNN(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=28, strides=(1, 1), kernel_size=(kernel_size,kernel_size), padding="SAME")(x)

        return x

In [39]:
# Initializing the Model
rng = jax.random.PRNGKey(0)
model = CNN()

In [40]:
# initialize the parameters
params = model.init(jax.random.PRNGKey(0), batch[0])
# forward pass
y_pred = model.apply(params, batch[0])

In [41]:
y_pred.shape

(32, 1, 28, 28)