In [None]:
import os
import sys

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import matplotlib.pyplot as plt
from PIL import Image

from softgrad import Network
from softgrad.Checkpoint import Checkpoint
from softgrad.layer.reshape import Flatten
from softgrad.layer.shim import MLX
from softgrad.optim import SGD
from softgrad.function.activation import leaky_relu, softmax
from softgrad.function.loss import CrossEntropyLoss, cross_entropy_loss
from softgrad.layer.core import Linear, Activation

sys.path.append(os.path.abspath('..'))
from util.dataset import get_cifar10

In [None]:
# Rebuild the architecture layer by layer, then restore weights from a checkpoint.
# Shape notes below assume MLX defaults (stride 1, no padding, pool stride equal
# to the pool size) -- TODO confirm against the training notebook.
network = Network(input_shape=(32, 32, 3))
add = network.add_layer  # bound-method alias; layers are still added one at a time, in order

# conv block 1: 7x7 convolution, 3 -> 96 channels, followed by 2x2 max-pooling
add(MLX(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7)))
add(Activation(leaky_relu))
add(MLX(nn.MaxPool2d(2)))

# conv block 2: 3x3 convolution, 96 -> 256 channels, followed by 2x2 max-pooling
add(MLX(nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3)))
add(Activation(leaky_relu))
add(MLX(nn.MaxPool2d(2)))

# feed forward head: flatten, 1024 hidden units, 10 outputs (one per class)
add(Flatten())
add(Linear(1024))
add(Activation(leaky_relu))
add(Linear(10))

# Path is relative to the notebook's working directory.
network.load(Checkpoint.read("checkpoints/simple_conv.pb"))

In [None]:
def normalize(x):
    """Min-max rescale ``x`` so its values span the range [0, 1].

    The global minimum is subtracted and the result is divided by the global
    maximum of the shifted array.

    Unlike the previous in-place version, the caller's array is left untouched;
    a new array is returned.  If every element of ``x`` is identical (zero
    span), an all-zero array is returned instead of NaNs from a 0/0 division.

    Args:
        x: array of values to rescale.

    Returns:
        A new array with the same shape as ``x``, with minimum 0 and, unless
        the input is constant, maximum 1.
    """
    # Out-of-place arithmetic so the argument is never mutated.
    shifted = x - mx.min(x)
    span = mx.max(shifted)

    # A constant input has span 0; divide by 1 instead so the result stays finite.
    safe_span = mx.where(span > 0, span, 1.0)
    return shifted / safe_span

def viz_input(x):
    """Show the first image of a batch with matplotlib.

    The first entry of ``x`` is converted to a NumPy array, multiplied by 255,
    cast to ``uint8`` and handed to ``plt.imshow`` as a PIL image.

    NOTE(review): values are not clipped before the ``uint8`` cast, so this
    presumably expects inputs already scaled to [0, 1] -- confirm at call sites.
    """
    first_image = np.array(x[0])
    pixels = (255 * first_image).astype('uint8')
    plt.imshow(Image.fromarray(pixels))

In [None]:
# Seed MLX's global PRNG so a fresh "Restart & Run All" reproduces the same
# random draws in this and the following cells.
mx.random.seed(0)

# Uniform noise in [0, 1): a batch holding one 32x32 image with 3 channels.
noise = mx.random.uniform(low=0, high=1, shape=(1, 32, 32, 3))
viz_input(noise)

In [None]:
# Set up optimisation of the input image rather than of the model.
# NOTE(review): freeze() presumably stops the network's own parameters from
# being updated -- confirm against softgrad.
network.freeze()

optimizer = SGD(
    eta=0.05,
    momentum=0.9,
    weight_decay=5e-4,
)
optimizer.bind_network(network)
optimizer.bind_loss_fn(cross_entropy_loss)

In [None]:
# CIFAR-10 class labels (index -> class name)
#   0 -> Airplane
#   1 -> Automobile
#   2 -> Bird
#   3 -> Cat
#   4 -> Deer
#   5 -> Dog
#   6 -> Frog
#   7 -> Horse
#   8 -> Ship
#   9 -> Truck

In [None]:
# Synthesise an image for one class, starting from near-zero uniform noise.
x = mx.random.uniform(low=0, high=0.0001, shape=(1, 32, 32, 3))

# One-hot target; index 2 is "Bird" in the CIFAR-10 label list of this notebook.
y = np.zeros((1, 10))
y[0, 2] = 1

for step in range(10000):
    # NOTE(review): step() presumably returns the loss gradient with respect
    # to x -- confirm against softgrad's SGD.
    input_grad = optimizer.step(x, y)
    x -= 0.001 * input_grad
    # Rescale to [0, 1] after every update.
    x = normalize(x)

viz_input(x)

# Class probabilities the network assigns to the synthesised image.
probabilities = softmax(network.forward(x))
print(np.array(probabilities))

In [None]:
# Load both CIFAR-10 splits and materialise each as a list so that samples
# can be picked by index later on.
train, test = (list(split) for split in get_cifar10())

In [None]:
# Start from a real test sample and push it further towards its own label.
sample = test[4]

image = sample['image']
x = mx.array(image.reshape((1, *image.shape)))  # prepend a batch axis of size 1

label = sample['label']
y = mx.array(label.reshape((1, *label.shape)))  # same batch axis for the label

for step in range(1000):
    # NOTE(review): step() presumably returns the loss gradient with respect
    # to x -- confirm against softgrad's SGD.
    input_grad = optimizer.step(x, y)
    x -= 0.1 * input_grad
    # Rescale to [0, 1] after every update.
    x = normalize(x)

viz_input(x)

# Class probabilities the network assigns to the modified image.
probabilities = softmax(network.forward(x))
print(np.array(probabilities))

In [None]:
# Class visualisation from Gaussian noise, with a shrinkage term on the image.
# (A uniform start, mx.random.uniform(0, 1, (1, 32, 32, 3)), is an alternative.)
x = mx.random.normal(shape=(1, 32, 32, 3))

# One-hot target; index 9 is "Truck" in the CIFAR-10 label list of this notebook.
y = np.zeros((1, 10))
y[0, 9] = 1

for step in range(500):
    grad = optimizer.step(x, y)
    # Equivalent to x <- 0.8 * x - 2.5 * grad: each step also shrinks the
    # image towards zero by 20%.
    x -= 0.2 * x + 2.5 * grad

# Rescale to [0, 1] only once, after the loop, for display and the final pass.
# (An affine alternative, x = 2.5 * x + 0.5, was tried here previously.)
x = normalize(x)
viz_input(x)

# Class probabilities the network assigns to the result.
probabilities = softmax(network.forward(x))
print(np.array(probabilities))

In [None]:
# Baseline: class probabilities the network assigns to unoptimised Gaussian noise.
noise = mx.random.normal(shape=(1, 32, 32, 3))

# NOTE(review): `y` is reused here for the raw network output, replacing the
# label array left over from the previous cell.
y = network.forward(noise)

# Last expression of the cell, so the probabilities are displayed as its output.
np.array(softmax(y))