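# PyTorch Vs Numba Comparison
- Implementation of a simple distance-based loss (mean Euclidean distance from a point to a set of data points) and its gradient, using plain Python, PyTorch autograd and Numba JIT
- Speed comparison of the implementations
- Surface plot of the loss function for both implementations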

In [None]:
import random
import timeit
from math import isclose, sqrt

import matplotlib.pyplot as plt
import torch

import functions as fn



## Pytorch AutoGrad

In [None]:
#* sample data: 1000 points from fn.gen_pts_, plus the query point (xp, yp)
data_x, data_y = fn.gen_pts_(1000)
xp = yp = 5.0

#* define loss function in python
def Lossfn(xp, yp, data_x, data_y):
    """Return the summed Euclidean distance from (xp, yp) to each (data_x[i], data_y[i]),
    divided by len(data_x), i.e. the mean distance when both sequences have equal length.

    Raises ZeroDivisionError when data_x is empty.
    """
    distances = (((xi - xp) ** 2 + (yi - yp) ** 2) ** 0.5 for xi, yi in zip(data_x, data_y))
    return sum(distances) / len(data_x)
#* assurance that the function is correct: the two losses are floats computed by
#* different implementations, so compare with a relative tolerance instead of ==
print(isclose(fn.loss(xp, yp, data_x, data_y), Lossfn(xp, yp, data_x, data_y)))

#* create tensors from data: one (x, y) row per data point; pnt is the point we differentiate w.r.t.
data = torch.tensor([data_x, data_y]).T
pnt = torch.tensor([xp, yp], requires_grad=True)

#* compute gradient numerically
def GradLimit(xp, yp, data_x, data_y, H=0.001, loss_fn=None):
    """Approximate the gradient of the loss at (xp, yp) with forward differences.

    Each partial derivative is (loss(p + H * e_i) - loss(p)) / H, so the
    truncation error is O(H).

    Args:
        xp, yp: coordinates of the point at which the gradient is taken.
        data_x, data_y: data coordinates, forwarded unchanged to the loss.
        H: finite-difference step.
        loss_fn: callable loss_fn(xp, yp, data_x, data_y). Defaults to the
            module-level Lossfn, looked up at call time.

    Returns:
        Tuple (dl_dx, dl_dy).
    """
    if loss_fn is None:
        # resolved at call time so a later redefinition of Lossfn is picked up
        loss_fn = Lossfn
    # the base value is shared by both partial derivatives: evaluate it once
    base = loss_fn(xp, yp, data_x, data_y)
    dl_dx = (loss_fn(xp + H, yp, data_x, data_y) - base) / H
    dl_dy = (loss_fn(xp, yp + H, data_x, data_y) - base) / H
    return dl_dx, dl_dy

#* compute gradient analytically - closed form
def GradCloseform(xp, yp, data_x, data_y):
    """Return the analytic gradient (dl_dx, dl_dy) of the mean Euclidean distance loss.

    For L(p) = mean_i ||d_i - p||, dL/dp = -mean_i (d_i - p) / ||d_i - p||.

    A data point that coincides with (xp, yp) makes ||d_i - p|| non-differentiable.
    It contributes zero (a valid subgradient) instead of raising ZeroDivisionError.

    Raises:
        ZeroDivisionError: if data_x is empty.
    """
    sum_x, sum_y = 0, 0
    for xi, yi in zip(data_x, data_y):
        dx, dy = xi - xp, yi - yp
        sq_dist = dx ** 2 + dy ** 2
        if sq_dist == 0:
            continue  # coincident point: 0 ** -0.5 would raise
        inv_sqrt = sq_dist ** -0.5
        sum_x += inv_sqrt * dx
        sum_y += inv_sqrt * dy

    # both components share one normaliser, matching Lossfn's len(data_x)
    n = len(data_x)
    dl_dx = -sum_x / n
    dl_dy = -sum_y / n
    return dl_dx, dl_dy

#* compute gradient auto-grad
def GradTroch(data, pnt):
    """Return d(loss)/d(pnt) of the mean Euclidean distance loss via autograd.

    Args:
        data: tensor of data points; squared differences are summed over dim 1
            (e.g. shape (N, 2) as built from data_x, data_y).
        pnt: leaf tensor with requires_grad=True, broadcast against data.

    Returns:
        pnt.grad after a fresh backward pass.
    """
    # backward() accumulates into .grad. Clear it so every call (e.g. inside the
    # timeit loop) returns the gradient of this call only, not a running sum.
    pnt.grad = None
    loss = torch.mean(torch.sqrt(torch.sum((data - pnt) ** 2, dim=1)))
    loss.backward()
    return pnt.grad


#* the three gradients are expected to agree (up to GradLimit's finite-difference error)
print(*(grad(xp, yp, data_x, data_y) for grad in (GradLimit, GradCloseform)), sep="\n")
print(GradTroch(data, pnt))

In [None]:
#* time 100 calls of each gradient implementation
limit_timer, closeform_timer, torch_timer = (
    timeit.timeit(bench, number=100)
    for bench in (
        lambda: GradLimit(xp, yp, data_x, data_y),
        lambda: GradCloseform(xp, yp, data_x, data_y),
        lambda: GradTroch(data, pnt),
    )
)

print(f"close form is faster than limit by {(limit_timer / closeform_timer):.3f} times")
print(f"torch is faster than limit by {(limit_timer / torch_timer):.3f} times")
print(f"torch is faster than close form by {(closeform_timer / torch_timer):.3f} times")

## Comments 
- the closed-form gradient is faster than the numerical (finite-difference) approximation
- torch autograd is faster than the closed-form gradient

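## Numba compilation 
- Numba compiles Python functions to machine code using LLVM
- Numba is a just-in-time (JIT) compiler: a decorated function is compiled on its first call
- the code is
  - first translated to an intermediate representation (IR) and compiled to machine code
  - then executed

- the first call is slow because it includes the compilation time
- subsequent calls (with the same argument types) are fast because the cached machine code is executed directly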

## Loss mesh 
- Compute the loss mesh using plain Python types (lists, dicts, etc.)
- speed it up with Numba

In [None]:
import time
import warnings

try:
    from numba import jit
except ImportError:
    # Only a missing package should trigger the fallback. A bare except would
    # also hide real errors raised while importing numba.
    print(
        """
          Numba is not installed.
          Falling back to plain Python: @jit is a no-op decorator here.
            """
    )

    def jit(*args, **kwargs):
        """No-op stand-in for numba.jit; supports both ``@jit`` and ``@jit(...)``."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # used as a bare @jit
        return lambda func: func  # used as @jit(...) with options
else:
    # NOTE: this silences *all* warnings process-wide (presumably aimed at
    # Numba's deprecation warnings), not only Numba's.
    warnings.filterwarnings("ignore")
    

@jit(nopython=True)
def Lossfn(xp, yp, data_x, data_y):
    """Numba-compiled twin of the pure-Python Lossfn: summed Euclidean distance
    from (xp, yp) to each (data_x[i], data_y[i]), divided by len(data_x)."""
    distances = [((px - xp) ** 2 + (py - yp) ** 2) ** 0.5 for px, py in zip(data_x, data_y)]
    return sum(distances) / len(data_x)


@jit(nopython=True)
def CreatLossMesh(step, mesh_size, data_x, data_y):
    """Evaluate Lossfn on a mesh_size x mesh_size grid.

    Both axes are -1 + k * step for k in range(mesh_size). The result is a list
    of rows, one per y value, each holding the loss at every x value.
    """
    axis = [-1 + k * step for k in range(mesh_size)]
    return [[Lossfn(x, y, data_x, data_y) for x in axis] for y in axis]


#* perf_counter is the monotonic, high-resolution clock intended for benchmarks
#* first call: includes JIT compilation of CreatLossMesh (and Lossfn) when Numba is available
start = time.perf_counter()
loss_mesh = CreatLossMesh(0.01, 200, data_x, data_y)
end = time.perf_counter()
sped_mesh_jit_1st = end - start

#* second call: reuses the already-compiled code
start = time.perf_counter()
loss_mesh = CreatLossMesh(0.01, 200, data_x, data_y)
end = time.perf_counter()
sped_mesh_jit_2nd = end - start

print(f"first run: {sped_mesh_jit_1st:.3f} seconds")
print(f"second run: {sped_mesh_jit_2nd:.3f} seconds")

## Torch Loss Mesh
- split the loss-mesh computation into two parts 
  - create the mesh and define the loss 
  - compute the loss
- compute the whole loss mesh in one go 

In [None]:
def torch_loss_fn(pnt, data):
    """Mean Euclidean distance from pnt to the rows of data (distances summed over dim 1)."""
    return torch.mean(torch.sqrt(torch.sum((data - pnt) ** 2, dim=1)))


torch_data = torch.tensor([data_x, data_y]).T
torch_pnt = torch.tensor([xp, yp], requires_grad=True)  # not used in this cell
x_mesh = torch.linspace(-1, 2, 300)
y_mesh = torch.linspace(-1, 2, 300)
print("mesh size: ", len(x_mesh) * len(y_mesh))

#* one loss call per grid point; rows are indexed by y, columns by x
start = time.perf_counter()
loss_mesh_tensor = [
    [torch_loss_fn(torch.tensor([x, y]), torch_data) for x in x_mesh] for y in y_mesh
]
end = time.perf_counter()
sped_mesh_torch_1st = end - start
print(f"first run: {sped_mesh_torch_1st:.3f} seconds")
#* NOTE: not a like-for-like comparison with the Numba timing. That mesh was 200x200
#* on [-1, 0.99], this one is 300x300 on [-1, 2]. This loop also still dispatches one
#* small torch call per grid point from the Python interpreter, so per-call overhead
#* presumably dominates.


## Contour plot

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={"projection": "3d"})
# TODO: stem plot
#* rows of loss_mesh_tensor follow y_mesh, columns follow x_mesh
ax.contour(x_mesh, y_mesh, loss_mesh_tensor, levels=200)
ax.view_init(elev=0)

## 3d plot of single surface

In [None]:
import numpy as np
plt.figure(figsize=(10, 10))
ax = plt.axes(projection='3d')
# plot_surface needs 2-D X and Y. Passing the 1-D axes would broadcast both along
# columns, making Y identical to X and collapsing the surface onto the x == y plane.
# loss_mesh_tensor is indexed [y][x] (outer loop over y_mesh), which matches
# meshgrid's default "xy" indexing: X[i, j] = x[j], Y[i, j] = y[i].
ax.plot_surface(
    *np.meshgrid(np.array(x_mesh), np.array(y_mesh)),
    np.array(loss_mesh_tensor),
    cmap='viridis',
    edgecolor='none',
)
ax.view_init(elev=0,)
plt.show()

## Recompute the loss mesh with torch 

In [None]:
torch_data = torch.tensor([data_x, data_y]).T  #* one (x, y) row per data point
#* flattened 300x300 grid over [-1, 2]^2 with x varying fastest:
#* point k = (l[k % 300], l[k // 300]) where l = linspace(-1, 2, 300)
x_space = torch.linspace(-1, 2, 300).repeat(300).unsqueeze(1)  #* (90000, 1)
y_space = torch.linspace(-1, 2, 300).repeat_interleave(300).unsqueeze(1)  #* (90000, 1)
print(x_space.shape, y_space.shape)
grid_points = torch.cat((x_space, y_space), dim=1)  #* (90000, 2)
print(grid_points.shape)

def FasterGridLoss(pnts, data, chunk_size=4096):
    """Mean Euclidean distance from every grid point to all data points.

    Args:
        pnts: (P, 2) tensor of grid points.
        data: (N, 2) tensor of data points.
        chunk_size: grid points processed per step. This bounds the
            (chunk_size, N, 2) intermediate instead of materialising (N, P, 2).

    Returns:
        (P,) tensor whose entry k is mean_i ||pnts[k] - data[i]||.

    Raises:
        ValueError: if chunk_size < 1.
    """
    # Why not data.repeat(P, 1, 1).view(-1, P, 2): .view only reinterprets memory,
    # so element [a, b] of that view is data[(a * P + b) % N]. When P is a multiple
    # of N (e.g. 90000 and 1000), grid point b is only ever paired with data[b % N].
    # Broadcasting pairs every grid point with every data point.
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    losses = []
    for start in range(0, pnts.shape[0], chunk_size):
        chunk = pnts[start:start + chunk_size]  # (C, 2)
        diff = chunk.unsqueeze(1) - data.unsqueeze(0)  # (C, N, 2)
        losses.append((torch.sum(diff ** 2, dim=2) ** 0.5).mean(dim=1))  # (C,)
    if not losses:
        return pnts.new_zeros(0)  # no grid points
    return torch.cat(losses)

grid_faster = FasterGridLoss(grid_points, torch_data).tolist()
#* grid_points is flattened with x varying fastest (point k = (x[k % 300], y[k // 300])),
#* so each consecutive run of 300 values is one row at fixed y. contour(X, Y, Z) expects
#* Z[row][col] = value at (X[col], Y[row]), i.e. rows indexed by y. Strided slicing
#* grid_faster[i::300] would give rows at fixed x, i.e. the transpose.
adjusted_grid = [grid_faster[row * 300:(row + 1) * 300] for row in range(300)]

In [None]:
MESH_SIZE = 300
#* spacing of torch.linspace(-1, 2, MESH_SIZE), the grid the loss was evaluated on,
#* so the axis labels span exactly [-1, 2] (a fixed 0.01 step would stop at 1.99)
STEP = 3 / (MESH_SIZE - 1)

#* define mesh [-1, 2] for x and y
x_mesh = [-1 + i * STEP for i in range(MESH_SIZE)]
y_mesh = [-1 + i * STEP for i in range(MESH_SIZE)]


def plt_contour():
    """Draw a 3-D contour plot of adjusted_grid over the (x_mesh, y_mesh) axes.

    Deliberately not JIT-compiled: the body only calls matplotlib, which Numba
    cannot compile in nopython mode, and object mode offers little or no speed-up.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.contour(x_mesh, y_mesh, adjusted_grid, levels=100)
    ax.view_init(elev=0,)
    plt.show()


plt_contour()

## Faster implementation


In [None]:
x_space, y_space = torch.linspace(-1, 2, 300), torch.linspace(-1, 2, 300)
#* "xy" indexing: x[i, j] = x_space[j], y[i, j] = y_space[i]
x, y = torch.meshgrid(x_space, y_space, indexing="xy")
data_x, data_y = fn.gen_pts_(1000)
data = torch.tensor([data_x, data_y]).T  #* one (x, y) row per data point
pnts = torch.stack((x, y), dim=-1).reshape(-1, 2)  #* point i * 300 + j = (x[i, j], y[i, j])



def FasterGridLoss(pnts, data, chunk_size=4096):
    """Mean Euclidean distance from every grid point to all data points.

    Args:
        pnts: (P, 2) tensor of grid points.
        data: (N, 2) tensor of data points.
        chunk_size: grid points processed per step. This bounds the
            (chunk_size, N, 2) intermediate instead of materialising (N, P, 2).

    Returns:
        (P,) tensor whose entry k is mean_i ||pnts[k] - data[i]||.

    Raises:
        ValueError: if chunk_size < 1.
    """
    # Why not data.repeat(P, 1, 1).view(-1, P, 2): .view only reinterprets memory,
    # so element [a, b] of that view is data[(a * P + b) % N]. When P is a multiple
    # of N (e.g. 90000 and 1000), grid point b is only ever paired with data[b % N].
    # Broadcasting pairs every grid point with every data point.
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    losses = []
    for start in range(0, pnts.shape[0], chunk_size):
        chunk = pnts[start:start + chunk_size]  # (C, 2)
        diff = chunk.unsqueeze(1) - data.unsqueeze(0)  # (C, N, 2)
        losses.append((torch.sum(diff ** 2, dim=2) ** 0.5).mean(dim=1))  # (C,)
    if not losses:
        return pnts.new_zeros(0)  # no grid points
    return torch.cat(losses)

grid_faster = FasterGridLoss(pnts, data).reshape(300, 300)  #* [i, j] <-> grid point (x[i, j], y[i, j])
ax = plt.axes(projection="3d")
ax.plot_surface(
    *(t.numpy() for t in (x, y, grid_faster)),
    cmap="viridis",
    edgecolor="none",
)
ax.view_init(elev=0)
plt.show()

# OPTIONAL TO UNDERSTAND

## repeat and hstack  
![repeat](misc/repeat.png)
```python
x.repeat(2,1) #* repeat 2 times along dim 0, 1 times along dim 1
```

![stack](misc/stack.png)


In [None]:
#* 5x5 toy grid built with repeat + hstack; x cycles fastest, y slowest
x_axies = torch.linspace(0, 4, 5).repeat(1, 5).T  #* (25, 1)
y_axies = torch.linspace(0, 4, 5).repeat(5, 1).T.reshape(-1, 1)  #* (25, 1)
print(x_axies.shape, y_axies.shape)
pnts = torch.hstack([x_axies, y_axies])  #* (25, 2)
print(pnts.shape)

plt.scatter(*pnts.T)

## Notes 
- Goal: create a grid of points in 2-D space
- the grid has 90000 points (300 x 300), each with 2 coordinates
- calculate the loss for each point in the grid

In [None]:
#* full 300x300 grid over [-1, 2]^2 built with the same repeat trick -> (90000, 2)
x_space = torch.linspace(-1, 2, 300).repeat(1, 300).T
y_space = torch.linspace(-1, 2, 300).repeat(300, 1).T.reshape(-1, 1)
pnts = torch.hstack([x_space, y_space])
#* NOTE: the author flagged this as memory-hungry; the shape prints and the
#* 90000-point scatter plot are left disabled.

## Understanding the faster formulation


In [None]:
#* data.shape = torch.Size([1000, 2])
#* pnts.shape = torch.Size([90000, 2])
data_repeated = data.repeat((pnts.shape[0], 1, 1))  #* torch.Size([90000, 1000, 2])
#* the whole data set is copied 90000 times along dim 0
(data_repeated[0] == data).all()  #* True

repeated_batches = data_repeated.view(-1, pnts.shape[0], 2)  #* torch.Size([1000, 90000, 2])
repeated_data_permuted = data_repeated.permute(1, 0, 2)  #* torch.Size([1000, 90000, 2]); element [a, b] is data[a]
#* .view only reinterprets the flat memory, it does not swap axes:
#* element [a, b] of repeated_batches is data[(a * 90000 + b) % 1000] == data[b % 1000]
#* (90000 is a multiple of 1000). So it is NOT the permutation, and every slice along dim 0 is identical.
print("view equals permute:", (repeated_data_permuted == repeated_batches).all().item())  #* False unless all data rows are equal
print("every slice along dim 0 is identical:", (repeated_batches[0] == repeated_batches[1]).all().item())  #* True
print(repeated_batches.shape)


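## Why does this not work?
- `.view(-1, 90000, 2)` moves no data: it reinterprets the flat memory of the `(90000, 1000, 2)` tensor as `(1000, 90000, 2)`
- element `[a, b]` of the view is data row `(a * 90000 + b) % 1000`, which equals `b % 1000` because 90000 is a multiple of 1000
- so every `(90000, 2)` slice along dim 0 is identical (as the check above shows), and grid point `b` is only ever paired with data point `b % 1000`
- the mean over dim 0 is therefore the distance to a single data point, not the mean distance to all data points
- to pair every data point with every grid point, use `.permute(1, 0, 2)` or broadcasting: `data[:, None, :] - pnts[None, :, :]`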