<a href="https://colab.research.google.com/github/95-sanya-95/Summer_ML_internship/blob/main/YOLO_trial_(Basic_Haiku_code).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

First, let us brush up on the **basic structure of a Haiku model**.

In [2]:
pip install git+https://github.com/deepmind/dm-haiku

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-9tdd8kpb
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/dm-haiku /tmp/pip-req-build-9tdd8kpb
  Resolved https://github.com/deepmind/dm-haiku to commit a7b7e73dae840153ecd828e97a64b6a875b168f7
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [3]:
# Important libraries
import jax
import jax.numpy as jnp
import haiku as hk
import optax

In [None]:

# Simple Haiku module
class SimpleModel(hk.Module):
    """Two-layer MLP: Linear(hidden_dim) -> ReLU -> Linear(output_dim).

    Args:
        output_dim: Size of the final (output) layer.
        hidden_dim: Width of the hidden layer. Defaults to 128.
        name: Optional Haiku module name; ``None`` lets Haiku derive it from
            the class name.
    """

    def __init__(self, output_dim, hidden_dim=128, name=None):
        super().__init__(name=name)
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

    # The network architecture is defined in __call__.
    def __call__(self, x):
        """Apply the MLP to ``x``.

        Returns:
            An array whose last dimension is ``output_dim``.
        """
        hidden = hk.Linear(self.hidden_dim)(x)
        hidden = jax.nn.relu(hidden)
        return hk.Linear(self.output_dim)(hidden)

# used to initialise the model with the given input
def forward_fn(x):
    model = SimpleModel(output_dim=10)
    return model(x)

# Dummy batch used only to trace shapes during initialisation.
# Shape is [batch_size, input_dim]; batch_size is the number of examples
# processed together in one forward pass (here a single example).
example_input = jnp.ones((1, 32))
rng = jax.random.PRNGKey(42)

# hk.transform turns forward_fn into a pair of pure functions:
#   forward.init  -> builds the parameter tree
#   forward.apply -> runs the model with given parameters
forward = hk.transform(forward_fn)
params = forward.init(rng, example_input)

def loss_fn(params, x, y):
    """Mean squared error of the transformed model ``forward`` on a batch.

    Args:
        params: Parameter tree returned by ``forward.init``.
        x: Model inputs.
        y: Targets; must broadcast against the model's predictions.

    Returns:
        The mean of the squared prediction errors, as a scalar.
    """
    # ``hk.transform`` produces ``apply(params, rng, *args)``. The model uses no
    # randomness, so pass ``None`` as the rng; otherwise ``x`` would be consumed
    # as the rng and the model would receive no input.
    predictions = forward.apply(params, None, x)
    return jnp.mean((predictions - y) ** 2)

LEARNING_RATE = 1e-3  # step size for the Adam optimiser

# Optimiser used by the training step, plus its state initialised from `params`.
optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    """Perform one optimisation step with the module-level ``optimizer``.

    Returns:
        A ``(new_params, new_opt_state, loss)`` tuple. ``loss`` is evaluated
        at the incoming ``params``, i.e. before the update is applied.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    batch_loss, gradients = grad_fn(params, x, y)
    updates, new_opt_state = optimizer.update(gradients, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, batch_loss

# **YOLO Model**

In [6]:
class YOLO(hk.Module):
    """Small YOLO-style detection network: conv backbone plus dense head.

    Backbone: one stage per entry of ``conv_channels``. Each stage is a 3x3
    ``SAME`` convolution followed by ReLU. Every stage except the last is also
    followed by a 2x2, stride-2 ``VALID`` max-pool, so the spatial size shrinks
    by ``2 ** (len(conv_channels) - 1)``.

    Head: flatten, ``Linear(hidden_dim)`` + ReLU, then a linear projection to
    ``grid_size * grid_size * (num_bboxes * 5 + num_classes)`` values. These
    are reshaped into a per-cell prediction grid.

    Args:
        num_classes: Number of classes predicted per grid cell.
        grid_size: Side length S of the S x S output grid.
        num_bboxes: Number of boxes per grid cell (5 values per box).
        conv_channels: Output channels of each backbone stage. Defaults to
            ``(32, 64, 128, 256, 512, 1024)``.
        hidden_dim: Width of the hidden fully connected layer. Defaults to 4096.
        name: Optional Haiku module name; ``None`` lets Haiku derive it from
            the class name.
    """

    def __init__(self, num_classes, grid_size, num_bboxes,
                 conv_channels=(32, 64, 128, 256, 512, 1024),
                 hidden_dim=4096, name=None):
        super().__init__(name=name)
        self.num_classes = num_classes
        self.grid_size = grid_size
        self.num_bboxes = num_bboxes
        self.conv_channels = tuple(conv_channels)
        self.hidden_dim = hidden_dim

    def __call__(self, x):
        """Run the detector on a batch of images.

        Args:
            x: Input images. Presumably NHWC, which is hk.Conv2D's default
                data format. TODO confirm against the data pipeline.

        Returns:
            Array reshaped to
            ``(-1, grid_size, grid_size, num_bboxes * 5 + num_classes)``.
        """
        # Convolutional backbone. Modules are created in the same order as the
        # original unrolled code, so Haiku parameter names are unchanged.
        last_stage = len(self.conv_channels) - 1
        for stage, channels in enumerate(self.conv_channels):
            x = hk.Conv2D(channels, kernel_shape=3, stride=1, padding='SAME')(x)
            x = jax.nn.relu(x)
            if stage < last_stage:
                x = hk.MaxPool(window_shape=2, strides=2, padding='VALID')(x)

        # Fully connected head.
        cell_depth = self.num_bboxes * 5 + self.num_classes
        x = hk.Flatten()(x)
        x = hk.Linear(self.hidden_dim)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self.grid_size * self.grid_size * cell_depth)(x)

        # Reshape to one prediction vector per grid cell.
        return jnp.reshape(x, (-1, self.grid_size, self.grid_size, cell_depth))

## **YOLO Forward function**

In [7]:

def yolo_forward(x, num_classes, grid_size, num_bboxes):
    """Build a ``YOLO`` module with the given head configuration and apply it to ``x``.

    Haiku modules can only be constructed inside a function wrapped by
    ``hk.transform``, so the model is created here on every call.
    """
    detector = YOLO(num_classes=num_classes, grid_size=grid_size, num_bboxes=num_bboxes)
    return detector(x)

## YOLO Loss function
