# MNIST with linear optical neural network

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

In [3]:
# MNIST dataset: train/test splits converted to tensors; download=True fetches the
# files into ./data if they are not already there.
MNIST_train = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
MNIST_test = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Dataloaders: batch_size equals the split size so that a single batch holds the
# whole split (it is consumed with next(iter(...)) in the next cell).
# The training shuffle is driven by a seeded generator so the sample order -- and
# everything computed from it -- is reproducible on Restart & Run All.
MNIST_SHUFFLE_SEED = 0
train_loader_MNIST = torch.utils.data.DataLoader(
    dataset=MNIST_train,
    batch_size=len(MNIST_train),
    shuffle=True,
    generator=torch.Generator().manual_seed(MNIST_SHUFFLE_SEED),
)
test_loader_MNIST = torch.utils.data.DataLoader(
    dataset=MNIST_test,
    batch_size=len(MNIST_test),
    shuffle=False,
)


In [4]:
# Each loader is configured so that one batch holds its entire split: take that
# single batch and convert images and labels to NumPy arrays.
X_train_MNIST, Y_train_MNIST = next(iter(train_loader_MNIST))
# Drop only the singleton channel axis. ToTensor presumably yields (N, 1, 28, 28)
# for grayscale MNIST; squeezing axis=1 (instead of every size-1 axis) keeps the
# batch axis even if N were 1, and raises if axis 1 is not a singleton.
X_train_MNIST = np.squeeze(X_train_MNIST.numpy(), axis=1)
Y_train_MNIST = Y_train_MNIST.numpy()

X_test_MNIST, Y_test_MNIST = next(iter(test_loader_MNIST))
X_test_MNIST = np.squeeze(X_test_MNIST.numpy(), axis=1)
Y_test_MNIST = Y_test_MNIST.numpy()

# Guard against a loader whose batch_size is smaller than its dataset, which would
# silently keep only the first batch.
assert len(X_train_MNIST) == len(Y_train_MNIST) == len(MNIST_train)
assert len(X_test_MNIST) == len(Y_test_MNIST) == len(MNIST_test)

#### Comparing FFT implementations

- Applied to the NumPy arrays created above (the full MNIST training set).
- Libraries: NumPy, SciPy, pyFFTW, CuPy, torch.

In [5]:
import time
import scipy
import pyfftw

In [6]:
# Time NumPy's 2-D FFT, applied over the last two axes of X_train_MNIST in one
# vectorised call (one 2-D FFT per image, assuming the array is (N, H, W)).
# perf_counter is a monotonic, high-resolution clock intended for measuring intervals.
start_time = time.perf_counter()

X_train_FFT_1 = np.fft.fft2(X_train_MNIST, axes=(-2, -1))

end_time = time.perf_counter()

numpy_FFT_time = end_time - start_time
# Keep the previous (misspelled) name as an alias in case later cells still read it.
npumpy_FFT_time = numpy_FFT_time

print(f'Time for numpy FFT: {numpy_FFT_time}')

Time for numpy FFT: 0.873779296875


In [7]:
# Time SciPy's 2-D FFT on the same data and axes as the NumPy run.
# Import the submodule explicitly: a bare `import scipy` is not guaranteed to expose
# `scipy.fft` on every SciPy version.
import scipy.fft

# NOTE(review): output dtypes may differ between libraries (e.g. scipy.fft can keep
# float32 input in single precision), so timings may not be like-for-like --
# compare X_train_FFT_1.dtype and X_train_FFT_2.dtype.
start_time = time.perf_counter()

X_train_FFT_2 = scipy.fft.fft2(X_train_MNIST, axes=(-2, -1))

end_time = time.perf_counter()

scipy_FFT_time = end_time - start_time

print(f'Time for scipy FFT: {scipy_FFT_time}')


Time for scipy FFT: 0.7127516269683838


In [8]:
# Time pyFFTW's NumPy-compatible interface (FFTW under the hood), which can use
# several threads.
PYFFTW_NUM_THREADS = 8  # FFTW worker threads; tune to the machine's core count
# Configure threads before starting the clock so setup is not part of the measurement.
pyfftw.config.NUM_THREADS = PYFFTW_NUM_THREADS

# NOTE(review): a first call through pyfftw.interfaces may include FFTW planning
# time; consider a warm-up call for a like-for-like comparison with NumPy/SciPy.
start_time = time.perf_counter()

X_train_FFT_3 = pyfftw.interfaces.numpy_fft.fft2(X_train_MNIST, axes=(-2, -1))

end_time = time.perf_counter()

# Stored under its own name so the SciPy measurement from the previous cell is kept.
pyfftw_FFT_time = end_time - start_time

print(f'Time for pyfftw FFT: {pyfftw_FFT_time}')

Time for scipy FFT: 0.3736300468444824
