In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import warnings
# NOTE(review): this silences *every* warning for the whole kernel, including
# deprecation notices from torch/torchvision. Consider narrowing it with
# `category=` / `message=` once the noisy warning is identified.
warnings.filterwarnings("ignore")

# Training hyperparameters (read by the optimizer and the training loop).
learning_rate = 0.01
momentum = 0.5
log_interval = 10000  # print/record the loss every `log_interval` batches

# Reproducibility: seed torch's global RNG, which drives nn.Linear weight init and
# DataLoader shuffling. cuDNN is disabled, presumably to rule out nondeterministic kernels.
random_seed = 1
torch.backends.cudnn.enabled = False
# Trailing `;` stops Jupyter from echoing the returned torch.Generator repr.
torch.manual_seed(random_seed);



<torch._C.Generator at 0x7efd66fb2ab0>

In [2]:
# Shared preprocessing for both splits, defined once so train and test can never
# drift apart: image -> tensor (ToTensor) -> 14x14 -> flattened vector, which is
# what Net.fc1 (196 input features) consumes.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((14, 14)),
    torchvision.transforms.Lambda(lambda x: torch.flatten(x)),
])


def _mnist_loader(train, root='/tmp', batch_size=1):
    """Build a shuffled DataLoader over the MNIST train (train=True) or test split.

    The dataset is downloaded into `root` when not already present.
    `batch_size` defaults to 1 (DataLoader's own default, as before); later cells
    rely on single-sample batches, e.g. `example_targets.item()` and
    `example_data[0]`.
    """
    dataset = torchvision.datasets.MNIST(root, train=train, download=True,
                                         transform=mnist_transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


train_loader = _mnist_loader(train=True)
test_loader = _mnist_loader(train=False)

In [3]:
# Peek at the first test batch: sanity-check its shapes and keep it as the demo
# input for inference, ONNX export and Cairo serialization later on.
examples = enumerate(test_loader)
batch_idx, example_batch = next(examples)
example_data, example_targets = example_batch
print(f"example_data.shape: {example_data.shape}")
print(f"example_targets.shape: {example_targets.shape}")

example_data.shape: torch.Size([1, 196])
example_targets.shape: torch.Size([1])


In [4]:
class Net(nn.Module):
    """MLP classifier 196 -> 40 -> 20 -> 10 that returns log-probabilities.

    NOTE(review): with the default ``activation=None`` there is no nonlinearity
    between the layers, so fc3(fc2(fc1(x))) is mathematically a single affine map;
    the hidden layers add parameters but no expressive power. The default is kept
    because the exported ONNX model / downstream Cairo code presumably mirrors this
    purely linear forward pass. TODO confirm before switching to e.g. ``nn.ReLU()``.

    Args:
        activation: optional module applied after fc1 and after fc2. ``None``
            keeps the original purely linear stack. Parameter-free modules such as
            ``nn.ReLU()`` leave the ``state_dict`` keys (fc1/fc2/fc3) unchanged.
    """

    def __init__(self, activation=None):
        super().__init__()
        self.fc1 = nn.Linear(196, 40)
        self.fc2 = nn.Linear(40, 20)
        self.fc3 = nn.Linear(20, 10)
        # nn.Identity is a parameter-free pass-through, so the default reproduces
        # the original forward pass (and consumes no extra RNG at init).
        self.activation = nn.Identity() if activation is None else activation

    def forward(self, x):
        """Map a (batch, 196) float tensor to (batch, 10) log-probabilities."""
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [5]:
# Build the model plus an SGD-with-momentum optimizer over all of its parameters,
# using the hyperparameters from the configuration cell.
network = Net()
optimizer = optim.SGD(
    params=network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [6]:
train_losses = []
train_counter = []


def train(epoch, model=None, loader=None, opt=None, interval=None):
    """Run one pass over the training data, logging every `interval` batches.

    Args:
        epoch: 1-based epoch number. It only labels the log lines and offsets
            `train_counter`; each call always performs exactly one pass.
        model: module to train; defaults to the global `network`.
        loader: DataLoader yielding (data, target) batches; defaults to `train_loader`.
        opt: optimizer over `model`'s parameters; defaults to `optimizer`.
        interval: log every `interval` batches; defaults to `log_interval`.

    Side effects:
        Prints progress, appends the logged loss to `train_losses`, and appends
        the number of training samples seen so far (across epochs) to `train_counter`.
    """
    model = network if model is None else model
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    interval = log_interval if interval is None else interval

    dataset_size = len(loader.dataset)
    samples_seen = 0  # samples processed in this epoch *before* the current batch
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        opt.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        opt.step()
        if batch_idx % interval == 0:
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, dataset_size,
                100. * batch_idx / len(loader), loss_value))
            train_losses.append(loss_value)
            # Count real samples rather than assuming a fixed batch size (a
            # hard-coded 64 is wrong for the batch-size-1 loaders used here).
            train_counter.append(samples_seen + (epoch - 1) * dataset_size)
        samples_seen += len(data)

In [7]:
# NOTE(review): train(epoch) performs exactly ONE pass over train_loader; the
# argument only labels the log lines and offsets train_counter. If three epochs
# were intended, loop over range(1, 4) instead. TODO confirm intent.
train(epoch=3)



In [8]:
# Sanity-check the trained model on the demo image taken from the test set.
network.eval()
with torch.no_grad():
    pred = network(example_data)

predicted_digit = pred.argmax()
print(f"Prediction: {predicted_digit}")
print(f"Real Value: {example_targets.item()}")

Prediction: 3
Real Value: 3


In [9]:
# Export the trained network to ONNX, using the demo batch as the example input.
# No dynamic axes are requested, so the exported input shape presumably matches
# example_data's shape exactly. Confirm if other batch sizes are needed.
torch.onnx.export(network, args=example_data, f="mnist_3.onnx")

In [22]:
def to_fixed_point(val, bits):
    """Scale a real number by 2**bits and round to the nearest integer.

    Uses Python's built-in ``round``, so exact ties round half-to-even
    (e.g. 2.5 -> 2). Negative inputs give negative results.
    """
    return round(val * (2**bits))


def mnist_image_to_fixed_point(data, bits=16):
    """Convert each element of a 1-D tensor to a fixed-point integer.

    `bits` is the number of fractional bits; the default of 16 matches the
    FP16x16 type used by `generate_input_cairo`.
    """
    return [to_fixed_point(val.item(), bits) for val in data]


def generate_input_cairo(data):
    """Render a 1-D tensor as comma-separated Cairo FP16x16 constructor calls.

    ``FixedTrait::<FP16x16>::new(mag, sign)`` takes a magnitude plus a separate
    sign flag (sign-magnitude representation), so the magnitude is written as a
    non-negative integer and the sign goes only in the flag. Writing the signed
    value itself would double-encode the sign and is not a valid unsigned literal.
    """
    values = mnist_image_to_fixed_point(data, bits=16)
    entries = [
        f"FixedTrait::<FP16x16>::new({abs(val)}, {'true' if val < 0 else 'false'})"
        for val in values
    ]
    return ",\n ".join(entries)

# Serialize the single demo image as a Cairo module exposing
# `fn input() -> Tensor<FP16x16>` inside the mnist_cairo project.
example_image = example_data[0]
input_cairo = generate_input_cairo(example_image)
# Derive the declared tensor shape from the data instead of hard-coding 196, so
# the shape can never disagree with the number of values written.
n_pixels = example_image.numel()

cairo_source = (
    "use array::{SpanTrait, ArrayTrait};\n"
    "use orion::operators::tensor::{TensorTrait, FP16x16Tensor, Tensor};\n"
    "use orion::numbers::{FixedTrait, FP16x16};\n"
    "fn input() -> Tensor<FP16x16> {\n"
    "    TensorTrait::<FP16x16>::new(\n"
    f"        array![{n_pixels}].span(),\n"
    "        array![\n"
    f"    {input_cairo}\n"
    "        ].span()\n"
    "    )\n"
    "}\n"
)

with open("mnist_cairo/src/input.cairo", "w") as f:
    f.write(cairo_source)