In [None]:
#!pip install --upgrade jax jaxlib
#!pip install git+https://github.com/deepmind/dm-haiku
#!pip install optax
!pip install --quiet --upgrade objax

In [None]:
# NOTE(review): disabled legacy tf.data input pipeline. The whole block is a
# bare triple-quoted string, so nothing in it executes; the PyTorch
# CustomJAXDataset / DataLoader cells further down are what actually load data.
# If revived: tf.image.convert_image_dtype already rescales uint8 to [0, 1],
# so the following `* (2. / 255) - 1` would rescale a second time.
# Consider deleting this block.
"""paths = glob.glob('images/*/*.jpg')
paths = rd.sample(paths, len(paths))

image_paths = tf.convert_to_tensor(paths, dtype=tf.string)
labels = tf.convert_to_tensor([int(i.split('/')[2]) for i in paths])

train_image_paths, train_labels = image_paths[0:110], labels[0:110]
test_image_paths, test_labels = image_paths[110:141], labels[110:141]


train_dataset = tf.data.Dataset.from_tensor_slices((train_image_paths, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_image_paths, test_labels))

def load_fn(path, label):
    image = tf.image.decode_jpeg(tf.io.read_file(path))
    image = tf.image.convert_image_dtype(image, tf.float32)
    
    
    image = image * (2. / 255) - 1
    
    image = tf.image.resize(image, size=[224, 224])
    image = tf.transpose(image, (2, 1, 0))
    
    return image, label

train_ds = train_dataset.map(load_fn, num_parallel_calls = 2).batch(10)
test_ds = test_dataset.map(load_fn, num_parallel_calls = 2).batch(1)"""

# NOTE(review): inert scratch code. The bare string on the next line, the
# commented-out duplicate-image check (which iterates over `ds`, a name not
# defined anywhere in the visible code), and the triple-quoted epoch loop
# below never execute. Consider deleting.
"""Tests for Data Loading"""
# Ensuring images aren't repeated xD
#count = 0 
#prev_next_elem = tf.ones((32, 224, 224, 3))
#for next_element in ds:
#    count += 1
#    print(next_element[0].shape)
    #print(prev_next_elem.shape)
    #print(f"Pass {count}")
    #print(next_element[0] == prev_next_elem)
    #prev_next_elem = next_element[0]
    
"""for epoch in range(20):
    count = 0
    print(f"Epoch {epoch}")
    for next_element in train_ds:
        count = count+len(next_element[0])
        print(f"Processed: {count} Images")
        loss = train_op(next_element[0].numpy(), next_element[1].numpy())[0]
        print(f"Loss is :{loss}")
        print("===" * 10)
    accuracy = 0        
    if epoch % 4 == 0:
        for next_element in test_ds:
            p = eval_op(next_element[0].numpy())
            accuracy += (np.argmax(p, axis=1) == next_element[1].numpy()).sum()
        print("***" * 10)
        print(f"Accuracy: {accuracy / len(test_ds)}")
        print("***" * 10)  
"""


####### PYTORCH
# NOTE(review): earlier draft of CustomJAXDataset kept as an inert string; the
# live class is defined in a later cell. If revived, note that
# `img - np.min(img) / (np.max(img) - np.min(img))` divides before it
# subtracts (operator precedence), so it is NOT a min-max normalisation; the
# subtraction needs its own parentheses. Consider deleting.


"""class CustomJAXDataset(Dataset):
    def __init__(self, path_to_data, labels):
        self.img_paths = path_to_data
        self.labels = labels
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img= np.array(Image.open(self.img_paths[idx]).resize((224, 224), Image.BILINEAR))
        im = (img - np.min(img) / (np.max(img) - np.min(img)))
        im = np.transpose(im, (2, 1, 0))
        
        label = np.array(self.labels[idx])
        
        return {'im' : im, 'label' : label}"""

In [None]:
import glob
import math
import os
import random as rd

import jax
import jax.numpy as jnp
from jax import grad, jit, random

import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf

import objax
from objax.zoo.resnet_v2 import ResNet18, ResNet34

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image

In [None]:
class CustomJAXDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from image files.

    Each image is opened with PIL, converted to RGB (so grayscale / CMYK JPEGs
    still produce 3 channels) and resized bilinearly to ``size``. If
    ``transform`` is given, it is applied to the resized PIL image. Otherwise
    the image is returned as a ``numpy`` array of shape (height, width, 3).

    Args:
        path_to_data: sequence of image file paths.
        labels: sequence of labels, aligned index-for-index with ``path_to_data``.
        transform: optional callable applied to the resized PIL image
            (e.g. a ``torchvision.transforms.Compose``).
        size: target ``(width, height)`` passed to ``PIL.Image.resize``.

    Raises:
        ValueError: if ``path_to_data`` and ``labels`` differ in length.
    """

    def __init__(self, path_to_data, labels, transform=None, size=(224, 224)):
        if len(path_to_data) != len(labels):
            raise ValueError(
                f"got {len(path_to_data)} image paths but {len(labels)} labels")
        self.img_paths = path_to_data
        self.labels = labels
        self.transform = transform
        self.size = tuple(size)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # The `with` block closes the file handle; resize() returns a new, fully
        # loaded image, so it stays valid after the source image is closed.
        with Image.open(self.img_paths[idx]) as src:
            img = src.convert("RGB").resize(self.size, Image.BILINEAR)

        # Previously `im` was only bound inside the transform branch, so
        # transform=None raised UnboundLocalError.
        if self.transform is not None:
            im = self.transform(img)
        else:
            im = np.asarray(img)

        label = np.array(self.labels[idx])
        return im, label

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # ToTensor scales to [0, 1]; (x - 0.5) / 0.5 maps that to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

SPLIT_SEED = 0  # seed for the one-off train/test shuffle
N_TRAIN = 110   # images used for training; all remaining images form the test set


def label_from_path(path):
    """Return the integer class label encoded as the image's parent directory name.

    Uses os.path instead of ``path.split('/')[2]``, so it does not depend on
    the prefix depth or on '/' being the path separator.
    """
    return int(os.path.basename(os.path.dirname(path)))


# Sort before shuffling: glob order is filesystem-dependent, so even a seeded
# shuffle of the raw glob result would give different splits on different machines.
paths = sorted(glob.glob('./images/*/*.jpg'))
if not paths:
    raise FileNotFoundError("no images found matching './images/*/*.jpg'")
image_paths = rd.Random(SPLIT_SEED).sample(paths, len(paths))

# Everything after the first N_TRAIN images is test data (previously capped at
# index 141, silently discarding any additional images).
train_image_paths, test_image_paths = image_paths[:N_TRAIN], image_paths[N_TRAIN:]
train_labels = [label_from_path(p) for p in train_image_paths]
test_labels = [label_from_path(p) for p in test_image_paths]

train_ds = CustomJAXDataset(train_image_paths, train_labels, transform=transform)
test_ds = CustomJAXDataset(test_image_paths, test_labels, transform=transform)

# Reshuffle training batches every epoch; evaluation order can stay fixed.
train_dl = DataLoader(train_ds, batch_size=10, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=5)

In [None]:
#model = ResNet18(in_channels = 3, num_classes = 3)

def conv_relu_pool(in_layers, out_layers, pool=True, kernel_size=5,
                   pool_size=2, pool_strides=1):
    """Build one conv -> ReLU (-> average-pool) stage as a list of callables.

    The returned list is meant to be concatenated with other stages and wrapped
    in ``objax.nn.Sequential``.

    Args:
        in_layers: number of input channels of the convolution.
        out_layers: number of output channels of the convolution.
        pool: if True, append a 2-D average-pooling op after the ReLU.
        kernel_size: convolution kernel size (default 5, as before).
        pool_size: pooling window size (default 2, as before).
        pool_strides: pooling stride (default 1, as before).
            NOTE(review): assuming objax's default VALID pooling padding, stride 1
            shrinks each spatial dim by only ``pool_size - 1`` pixels, so this
            stage barely downsamples. ``pool_strides=pool_size`` is the usual
            choice; confirm the current value is intended.

    Returns:
        list: ``[Conv2D, relu]``, plus the pooling callable when ``pool`` is True.
    """
    ops = [objax.nn.Conv2D(in_layers, out_layers, kernel_size),
           objax.functional.relu]
    if pool:
        ops.append(lambda x: objax.functional.average_pool_2d(
            x, size=pool_size, strides=pool_strides))
    return ops

# Four conv/ReLU/avg-pool stages (channels 3 -> 32 -> 32 -> 64 -> 64), then a
# 3x3 conv down to 3 output channels and a mean over axes (2, 3), which gives one
# score per output channel for each example. Stages are built in the same order
# as listed, so parameter initialisation order is unchanged.
model = objax.nn.Sequential(
    [op
     for c_in, c_out in ((3, 32), (32, 32), (32, 64), (64, 64))
     for op in conv_relu_pool(c_in, c_out)]
    + [objax.nn.Conv2D(64, 3, 3),
       lambda x: x.mean((2, 3))]
)

In [None]:
## MODEL PARAMS

# Training hyper-parameters: learning rate (intended for the optimizer step in
# train_model) and number of passes over the training data.
lr, epochs = 0.01, 20

In [None]:
def _evaluate_accuracy(eval_op, loader):
    """Return the fraction of samples in ``loader`` whose arg-max prediction matches the label.

    Args:
        eval_op: callable mapping a numpy batch of images to per-class scores;
            the predicted class is the argmax over axis 1.
        loader: iterable of ``(images, labels)`` torch-tensor batches.

    Returns:
        float: correct predictions / number of samples seen, or 0.0 if the
        loader yields no samples (instead of dividing by zero).
    """
    correct = 0
    seen = 0
    for img, label in loader:
        labels_np = label.numpy()
        preds = np.argmax(eval_op(img.numpy()), axis=1)
        correct += int((preds == labels_np).sum())
        seen += len(labels_np)
    return correct / seen if seen else 0.0


def train_model(model, *, learning_rate=None, num_epochs=None,
                train_loader=None, test_loader=None):
    """Train ``model`` with Adam on sparse softmax cross-entropy and report progress.

    After each epoch, prints the mean training loss over that epoch's batches
    and the accuracy on the test loader (correct predictions / test samples).

    Args:
        model: objax module called as ``model(x, training=...)`` that returns
            per-class logits.
        learning_rate: Adam step size; defaults to the notebook-level ``lr``
            (previously ignored in favour of a hard-coded 0.1).
        num_epochs: number of passes over the training data; defaults to ``epochs``.
        train_loader: iterable of ``(images, labels)`` torch-tensor batches;
            defaults to ``train_dl``.
        test_loader: evaluation batches in the same format; defaults to ``test_dl``.
    """
    # Fall back to the notebook-level configuration at call time.
    learning_rate = lr if learning_rate is None else learning_rate
    num_epochs = epochs if num_epochs is None else num_epochs
    train_loader = train_dl if train_loader is None else train_loader
    test_loader = test_dl if test_loader is None else test_loader

    opt = objax.optimizer.Adam(model.vars())

    def loss_fn(x, labels):
        logits = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()

    gv = objax.GradValues(loss_fn, model.vars())

    def train_op(x, y, step_lr):
        # One optimisation step; returns the list of values from GradValues
        # (element 0 is the batch loss).
        g, v = gv(x, y)
        opt(lr=step_lr, grads=g)
        return v

    train_op = objax.Jit(train_op, gv.vars() + opt.vars())
    eval_op = objax.Jit(
        lambda x: objax.functional.softmax(model(x, training=False)), model.vars())

    for epoch in range(num_epochs):
        batch_losses = []
        for img, label in train_loader:
            batch_loss = train_op(img.numpy(), label.numpy(), learning_rate)[0]
            batch_losses.append(float(batch_loss))
        # NaN (rather than a crash) if the training loader is empty.
        mean_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')

        # Accumulated over ALL test batches and divided by the sample count
        # (previously overwritten per batch and divided by the batch count).
        accuracy = _evaluate_accuracy(eval_op, test_loader)

        print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, mean_loss, 100 * accuracy))


In [None]:
# Train the model built above; train_model prints loss and accuracy per epoch.
train_model(model=model)