In [1]:
! pip install -q kaggle
from google.colab import files
files.upload()

! mkdir ~/.kaggle/
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

Saving kaggle.json to kaggle.json


In [2]:
!kaggle competitions download -c dogs-vs-cats

Downloading dogs-vs-cats.zip to /content
 98% 799M/812M [00:02<00:00, 242MB/s]
100% 812M/812M [00:02<00:00, 293MB/s]


In [3]:
from zipfile import ZipFile
file_name = "/content/dogs-vs-cats.zip"
with ZipFile(file_name, 'r') as zip:
    print('Extracting all the files now...')
    zip.extractall()
    print('Done!')

Extracting all the files now...
Done!


In [4]:
!pip install jax
!pip install -q flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[K     |████████████████████████████████| 185 kB 4.6 MB/s 
[K     |████████████████████████████████| 237 kB 42.3 MB/s 
[K     |████████████████████████████████| 145 kB 71.8 MB/s 
[K     |████████████████████████████████| 51 kB 8.5 MB/s 
[K     |████████████████████████████████| 85 kB 4.8 MB/s 
[?25h

In [14]:
import numpy as np
import pandas as pd 
import cv2
import matplotlib.pyplot as plt
import jax
from jax import lax, random, numpy as jnp
from jax import grad, jit, vmap, pmap

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn  
from flax.training import train_state  
from flax.training.train_state import TrainState

import optax
import os

import seaborn as sns
import cv2
import time

import torch
from torch.utils.data import Dataset, random_split, DataLoader
import torch.nn.functional as F
import torch.nn as tnn
import tensorflow as tf
from torchvision import transforms
from PIL import Image

In [6]:
import zipfile

with zipfile.ZipFile("/content/train.zip","r") as z:
    z.extractall(".")
    
with zipfile.ZipFile("/content/test1.zip","r") as z:
    z.extractall(".")

In [7]:
data_df = pd.DataFrame({"file": os.listdir("/content/train")})
data_df["label"] = data_df["file"].apply(lambda x: x.split(".")[0])

data_df.head()

Unnamed: 0,file,label
0,cat.11080.jpg,cat
1,cat.3402.jpg,cat
2,cat.4622.jpg,cat
3,cat.6101.jpg,cat
4,cat.7125.jpg,cat


In [8]:
for i in range(len(data_df)):
    data_df.iloc[i,0]=f'/content/train/{data_df.iloc[i,0]}'
    
data_df

Unnamed: 0,file,label
0,/content/train/cat.11080.jpg,cat
1,/content/train/cat.3402.jpg,cat
2,/content/train/cat.4622.jpg,cat
3,/content/train/cat.6101.jpg,cat
4,/content/train/cat.7125.jpg,cat
...,...,...
24995,/content/train/cat.9318.jpg,cat
24996,/content/train/cat.10604.jpg,cat
24997,/content/train/cat.488.jpg,cat
24998,/content/train/dog.11808.jpg,dog


In [9]:
data_df.replace({'cat':0,'dog':1},inplace=True)
data_df.head()

Unnamed: 0,file,label
0,/content/train/cat.11080.jpg,0
1,/content/train/cat.3402.jpg,0
2,/content/train/cat.4622.jpg,0
3,/content/train/cat.6101.jpg,0
4,/content/train/cat.7125.jpg,0


In [15]:
class DatasetCreator(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)    
    
    def __getitem__(self, idx):
        row = self.df.loc[idx]
        img_id, img_label = row['file'], row['label']
        img_fname =  "/" + str(img_id)
        img = Image.open(img_fname)
        if self.transform:
            img = self.transform(img)
        return img, img_label


IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128

training_transform = transforms.Compose([
    transforms.RandomAffine(degrees=(-30, 30),
                            translate=(0.0, 0.2)),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((IMAGE_HEIGHT,
                        IMAGE_WIDTH)),
    np.array])

testing_transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT,
                       IMAGE_WIDTH)),
    np.array])


train_df = data_df[:24000].reset_index()
val_df = data_df[24000:].reset_index()

train_ds = DatasetCreator(train_df, transform=training_transform)
val_ds = DatasetCreator(val_df, transform=testing_transform)
len(train_ds), len(val_ds)

(24000, 1000)

In [16]:
BATCH_SIZE = 16
train_dataloader = DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True, drop_last=True,num_workers=0)
test_dataloader = DataLoader(val_ds,batch_size=BATCH_SIZE,shuffle=True, drop_last=True,num_workers=0)

In [17]:
(image_batch, label_batch) = next(iter(train_dataloader))
print(image_batch.shape)
print(label_batch.shape)

torch.Size([16, 128, 128, 3])
torch.Size([16])


In [18]:
import warnings
warnings.filterwarnings("ignore")

In [21]:
class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(32, (3, 3), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (2, 2), (2, 2))
        x = nn.Conv(64, (3, 3), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (2, 2), (2, 2))
        x = x.reshape(x.shape[0], -1)
        x = nn.Dense(features=1)(x)
        x = nn.sigmoid(x)
        return x

@jax.jit
def train_step(state, imgs, gt_labels):
    def loss_fn(params):
        logits = Model().apply({'params': params}, imgs)
        one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=2)
        loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
        return loss, logits
  
    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads) 
    metrics = compute_metrics(logits=logits, gt_labels=gt_labels)  
    return state, metrics

@jax.jit
def eval_step(state, imgs, gt_labels):
    logits = Model().apply({'params': state.params}, imgs)
    return compute_metrics(logits=logits, gt_labels=gt_labels)

def train_one_epoch(state, dataloader, epoch):
    batch_metrics = []
    for cnt, (imgs, labels) in enumerate(dataloader):
        state, metrics = train_step(state, jnp.array(imgs), jnp.array(labels))
        batch_metrics.append(metrics)

    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    return state, epoch_metrics_np

def evaluate_model(state, test_imgs, test_lbls):
    metrics = eval_step(state, test_imgs, test_lbls)
    metrics = jax.device_get(metrics)  
    metrics = jax.tree_map(lambda x: x.item(), metrics)  
    return metrics

def create_train_state(key, learning_rate):
    model = Model()
    params = model.init(key, jnp.ones([1, *(128,128,3)]))['params']
    sgd_opt = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=sgd_opt)

def compute_metrics(*, logits, gt_labels):
    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=2)
    loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)

    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics

seed = 0  
learning_rate = 0.01
num_epochs = 10

train_state = create_train_state(jax.random.PRNGKey(seed), learning_rate)

for epoch in range(1, num_epochs + 1):
    train_state, train_metrics = train_one_epoch(train_state, train_dataloader, epoch)
    print(f"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}")

Train epoch: 1, loss: -0.999875009059906, accuracy: 50.0083327293396
Train epoch: 2, loss: -1.0, accuracy: 50.0083327293396
Train epoch: 3, loss: -1.0, accuracy: 50.0083327293396


KeyboardInterrupt: ignored