In [None]:
import glob
import math
import random as rd
from pathlib import Path

import jax.numpy as jnp
import objax
import torch
from objax.zoo.vgg import VGG19
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

rd.seed(123)

In [None]:
class CustomJAXDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs, loading each image lazily from disk.

    Args:
        path_to_data: sequence of image file paths.
        labels: sequence of labels aligned index-for-index with ``path_to_data``.
        transform: optional callable applied to each loaded PIL image.
        convert_mode: PIL mode every image is converted to before ``transform``
            (default ``"RGB"``, so grayscale or CMYK JPEGs still yield 3 channels and
            work with a 3-channel ``Normalize``); ``None`` keeps the file's own mode.

    Raises:
        ValueError: if ``path_to_data`` and ``labels`` differ in length.
    """

    def __init__(self, path_to_data, labels, transform = None, convert_mode = "RGB"):
        if len(path_to_data) != len(labels):
            raise ValueError(
                "path_to_data and labels must have the same length "
                f"({len(path_to_data)} != {len(labels)})"
            )
        self.img_paths = path_to_data
        self.labels = labels
        self.transform = transform
        self.convert_mode = convert_mode

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for position ``idx``."""
        # Open inside ``with`` so the file handle is released immediately. convert()/copy()
        # produce a new in-memory image, so the result stays valid after the file closes.
        with Image.open(self.img_paths[idx]) as img:
            im = img.convert(self.convert_mode) if self.convert_mode else img.copy()
        if self.transform is not None:
            im = self.transform(im)
        return im, self.labels[idx]

In [None]:
# Standard ImageNet per-channel (RGB) mean / std.
# NOTE(review): confirm the objax.zoo VGG19 pretrained weights expect this normalisation.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

transform = transforms.Compose([
    # Passing PIL resampling ints (Image.BICUBIC) as the interpolation is deprecated
    # in torchvision; the InterpolationMode enum is the supported form.
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

# Split configuration: a fixed seed plus fixed sizes make the train/test split reproducible.
SPLIT_SEED = 123
N_TRAIN, N_TEST = 110, 31


def label_from_path(path):
    """Return the integer class label encoded as the image's parent directory name.

    ``./images/train/7/img.jpg`` -> ``7``. Uses ``pathlib`` rather than ``split('/')[3]``,
    so it works with platform path separators and any root depth.
    """
    return int(Path(path).parent.name)


# glob order depends on the filesystem, so sort first. Otherwise the seeded shuffle
# would still yield a different split on another machine.
paths = sorted(glob.glob('./images/train/*/*.jpg'))
# A private Random instance keeps this cell idempotent: re-running it gives the same
# split no matter how much of the global ``random`` state has been consumed.
image_paths = rd.Random(SPLIT_SEED).sample(paths, len(paths))

train_image_paths = image_paths[:N_TRAIN]
test_image_paths = image_paths[N_TRAIN:N_TRAIN + N_TEST]
train_labels = [label_from_path(p) for p in train_image_paths]
test_labels = [label_from_path(p) for p in test_image_paths]


TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 10, 5
LOADER_SEED = 123

train_ds = CustomJAXDataset(train_image_paths, train_labels, transform=transform)
test_ds = CustomJAXDataset(test_image_paths, test_labels, transform=transform)

# Reshuffle the training set every epoch; the one-off shuffle of the paths only fixes
# the split, so without this every epoch would see identical batches. A seeded
# generator keeps the batch order reproducible across Restart & Run All.
train_dl = DataLoader(
    train_ds,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
# Evaluation order does not affect accuracy, so the test loader stays sequential.
test_dl = DataLoader(test_ds, batch_size=TEST_BATCH_SIZE)

In [None]:
# Backbone: VGG19 from the objax model zoo, loaded with pretrained weights.
model = VGG19(pretrained=True)
model_vars = model.vars()

# Drop the backbone's last layer and append a new 4096 -> 102 linear head.
# NOTE(review): labels are parsed from integer folder names; confirm they lie in
# [0, 102). 1-based folders (1..102) would put one class outside the head.
new_layer = objax.nn.Linear(4096, 102)
backbone_layers = model[:-1]
new_model = objax.nn.Sequential([*backbone_layers, new_layer])

In [None]:
lr = 1e-2  # Adam learning rate used by train_model

In [None]:
def train_model(model, train_loader=None, test_loader=None, epochs=20, learning_rate=None):
    """Fine-tune the last layer of ``model`` with Adam and print per-epoch test accuracy.

    Every layer except ``model[-1]`` is frozen: only the variables of the final layer
    are passed to the optimizer and the gradient computation.

    Args:
        model: indexable objax module (e.g. ``objax.nn.Sequential``) whose last element
            is the layer to train; it is called as ``model(x, training=...)``.
        train_loader: iterable of ``(images, labels)`` torch-tensor batches; defaults to
            the notebook-level ``train_dl``.
        test_loader: iterable of evaluation batches; defaults to ``test_dl``.
        epochs: number of passes over ``train_loader`` (default 20).
        learning_rate: Adam step size; defaults to the notebook-level ``lr``.
    """
    # Resolve defaults at call time so ``train_model(new_model)`` keeps working unchanged.
    train_loader = train_dl if train_loader is None else train_loader
    test_loader = test_dl if test_loader is None else test_loader
    learning_rate = lr if learning_rate is None else learning_rate

    def loss_fn(x, labels):
        logits = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()

    # Pick the head's variables by identity from the model that was passed in, instead of
    # matching a hard-coded '(Sequential)[42](Linear)' key on a global model.
    head_var_ids = {id(v) for v in model[-1].vars().values()}
    vars_train = objax.VarCollection(
        (name, var) for name, var in model.vars().items() if id(var) in head_var_ids
    )
    opt = objax.optimizer.Adam(vars_train)
    gv = objax.GradValues(loss_fn, vars_train)

    def train_op(x, y, step_lr):
        grads, values = gv(x, y)
        opt(lr=step_lr, grads=grads)
        return values

    train_op = objax.Jit(train_op, gv.vars() + opt.vars())
    eval_op = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)), model.vars())

    for epoch in range(epochs):
        # ``batch_loss`` (not ``loss``) so the loss function name is never shadowed.
        batch_losses = []
        progress = tqdm(train_loader)
        for img, label in progress:
            batch_loss = float(train_op(img.numpy(), label.numpy(), learning_rate)[0])
            batch_losses.append(batch_loss)
            progress.set_postfix(loss='%.2f' % batch_loss)

        correct_preds = 0
        n_test = 0
        for img, label in tqdm(test_loader):
            preds = jnp.argmax(eval_op(img.numpy()), axis=1)
            correct_preds += int((preds == label.numpy()).sum())
            n_test += len(label)
        # Accuracy = correct predictions / number of test samples (not number of batches).
        accuracy = correct_preds / n_test if n_test else 0.0
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

        print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, mean_loss, 100 * accuracy))

In [None]:
# Fine-tune the classifier head of new_model, printing per-epoch loss and test accuracy.
train_model(model=new_model)