# Image shifting

In [None]:
### Main header.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.autograd import Variable
from torchvision import datasets, transforms

In [None]:
### Image shifting defs.

def image_shift(image, x, y, pbc=False, fill=None):
    """Translate an image by ``x`` columns and ``y`` rows.

    Parameters
    ----------
    image : array_like
        Image with rows on axis 0 and columns on axis 1. Any trailing axes
        (e.g. colour channels of an H x W x C image) are carried along
        unchanged in both boundary modes.
    x : int-like
        Horizontal shift in pixels; positive values move content toward
        higher column indices.
    y : int-like
        Vertical shift in pixels; positive values move content toward
        higher row indices.
    pbc : bool
        If True, use periodic boundary conditions: pixels pushed off one
        edge re-enter on the opposite edge (``np.roll``). Otherwise they are
        discarded and the vacated region is set to ``fill``.
    fill : scalar, optional
        Value for vacated pixels when ``pbc`` is False. Defaults to the image
        minimum (computed in both modes, so an empty image raises).

    Returns
    -------
    numpy.ndarray
        Shifted image with the same shape as ``image``. With ``pbc=True`` the
        dtype is preserved. With ``pbc=False`` the result is concatenated with
        a float64 fill block, so integer/float32 input comes back as float64.
    """
    image = np.array(image)
    x, y = int(x), int(y)
    if fill is None:
        fill = np.min(image)

    if pbc:
        image = np.roll(image, x, axis=1)
        return np.roll(image, y, axis=0)

    # Columns first, then rows -- same order as the periodic branch.
    image = _shift_axis_with_fill(image, x, 1, fill)
    return _shift_axis_with_fill(image, y, 0, fill)


def _shift_axis_with_fill(image, offset, axis, fill):
    """Shift ``image`` by ``offset`` along ``axis``, padding the gap with ``fill``."""
    size = image.shape[axis]
    # Shifts larger than the axis just blank the whole image.
    offset = max(-size, min(offset, size))

    pad_shape = list(image.shape)
    pad_shape[axis] = abs(offset)
    # fill * ones(...) is float64, which is what promotes the output dtype.
    pad = fill * np.ones(pad_shape)

    kept = [slice(None)] * image.ndim
    if offset > 0:
        kept[axis] = slice(0, size - offset)
        parts = (pad, image[tuple(kept)])
    else:
        kept[axis] = slice(-offset, size)
        parts = (image[tuple(kept)], pad)
    return np.concatenate(parts, axis=axis)

def image_shift_polar(image, r, angle, pbc=False, fill=None):
    """Shift ``image`` by a displacement given in polar coordinates.

    ``r`` is the displacement length in pixels and ``angle`` its direction in
    radians, measured from the +column (x) axis toward the +row (y) axis.
    Both Cartesian components are rounded to the nearest integer with
    ``np.rint`` (ties go to the even integer) before delegating to
    ``image_shift``, which receives ``pbc`` and ``fill`` unchanged.
    """
    pixels = np.array(image)
    dx, dy = (int(np.rint(component * r))
              for component in (np.cos(angle), np.sin(angle)))
    return image_shift(pixels, dx, dy, pbc, fill)

def image_shift_polar_random_angle(image, r, pbc=False, fill=None):
    """Shift ``image`` by ``r`` pixels in a uniformly random direction.

    The direction is drawn from NumPy's global random state
    (``np.random.random``), so results are reproducible only if the caller
    seeds it. ``pbc`` and ``fill`` are forwarded to ``image_shift_polar``.
    """
    turn = np.random.random()  # fraction of a full turn, in [0, 1)
    return image_shift_polar(image, r, turn * 2 * np.pi, pbc, fill)

def shift_dataset(dataset, f, f_args):
    """Apply an image-transforming function to every sample of a dataset.

    Parameters
    ----------
    dataset : iterable of (image, target)
        Each image must be convertible with ``np.asarray``, and its last two
        axes must be (rows, columns) with all leading axes of size 1. An
        example is the (1, 28, 28) tensors produced by the MNIST transform
        in this notebook.
    f : callable
        Called as ``f(image_2d, *f_args)`` on a 2-D float32 array holding the
        image in the dataset's own value scale (no extra rescaling). Must
        return an array of the same 2-D shape.
    f_args : tuple
        Extra positional arguments passed to ``f``.

    Returns
    -------
    list of (torch.Tensor, target)
        Transformed images as new float32 tensors of shape (1, rows, cols),
        paired with the unchanged targets.
    """
    shifted = []
    for data, target in dataset:
        pixels = np.asarray(data, dtype=np.float32)
        image = pixels.reshape(pixels.shape[-2:])
        # np.array (not asarray) forces a contiguous copy, so the output never
        # aliases the input tensor's memory and negative strides are fine.
        result = np.array(f(image, *f_args), dtype=np.float32, order="C")
        # Build the (1, H, W) tensor directly instead of going through
        # transforms.ToTensor: depending on the torchvision version, ToTensor
        # either divides every ndarray by 255 or only uint8 ones, so the
        # previous "multiply by 255 first" trick was not portable.
        shifted.append((torch.from_numpy(result).unsqueeze(0), target))
    return shifted

In [None]:
### Sample plots.

R         = 3     # shift radius (pixels) used for the preview
pbc       = True  # periodic (wrap-around) boundaries for the preview
n_preview = 40    # images shown in the 4 x 10 grid below

test_data = datasets.MNIST("../data/",
                           train=False,
                           download = True,
                           transform = transforms.Compose([transforms.ToTensor(),
                                                           transforms.Normalize((0.1307,),(0.3081,))]))

# Shift only the images that are actually displayed; shifting the whole test
# set here would be wasted work since only n_preview samples are plotted.
test_data_shifted = shift_dataset([test_data[i] for i in range(n_preview)],
                                  image_shift_polar_random_angle,
                                  (R, pbc))

plt.figure(figsize=(15,7))

for i, (tensor, _) in enumerate(test_data_shifted):
    plt.subplot(4, 10, i + 1)
    plt.axis("off")
    image = np.reshape(np.array(tensor), (28, 28))
    plt.imshow(image, cmap="gray")
plt.show()

In [None]:
### Testing and model defs.

class Net(nn.Module):
    """Two-convolution CNN classifier returning log-probabilities over 10 classes.

    The layer attribute names (conv1, conv2, conv2_drop, fc1, fc2) become the
    ``state_dict`` keys, which must match the checkpoint loaded via
    ``load_state_dict``; keep them, and their registration order, stable.

    The hard-coded 320 features (20 channels x 4 x 4) correspond to 28 x 28
    single-channel inputs: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of images (N, 1, 28, 28) to log-probabilities (N, 10)."""
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))
        flat = features.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)

def test(test_data, model):
    model.eval()
    correct = 0
    for data, target in test_data:
        data = data.unsqueeze(0)
        data = Variable(data)        
        result = model(data)
        correct += target == np.argsort(np.array(result.data[0]))[-1]
    return 100*correct/len(test_data)

In [None]:
### Compute!

model = Net()
# map_location="cpu" lets a checkpoint saved from a GPU model load on a
# CPU-only machine; this notebook never moves the model to a GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("./state.pkl", map_location="cpu"))

test_data = datasets.MNIST("../data/",
                           train=False,
                           download = True,
                           transform = transforms.Compose([transforms.ToTensor(),
                                                           transforms.Normalize((0.1307,),(0.3081,))]))

accuracies      = []  # accuracy (%) per distance; pixels shifted off the edge are lost
accuracies_pbc  = []  # accuracy (%) per distance; periodic (wrap-around) shifts
shift_distances = range(100)

# Every image gets a fresh random direction at each radius. The pbc sweep runs
# first, then the non-pbc sweep, so np.random is consumed in the same order
# as two consecutive loops would.
for use_pbc, results in ((True, accuracies_pbc), (False, accuracies)):
    for r in shift_distances:
        data = shift_dataset(test_data,
                             image_shift_polar_random_angle,
                             (r, use_pbc))
        results.append(test(data, model))

In [None]:
### Plots.

x_max = 100
plt.figure(figsize=(15,3))

# Left panel: accuracy vs. shift up to x_max; right panel: zoom on the first
# 15 shift distances. Both compare periodic vs. non-periodic shifting.
for panel, limit in enumerate((x_max, 15), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(shift_distances[:limit], accuracies_pbc[:limit])
    plt.plot(shift_distances[:limit], accuracies[:limit])
    plt.ylabel("Accuracy")
    plt.xlabel("Shift (pixels)")
    plt.title("Plot")
    plt.legend(["pbc", "no pbc"])
    plt.grid()

plt.show()

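Peaks: with periodic boundary conditions, the accuracy curve peaks at the shift distances

$$x_n = 28\sqrt{1+n^2}, \qquad n = 1, 2, 3, \ldots$$

These are the lengths of the displacement vectors $(28, 28n)$. On the periodic $28 \times 28$ image, a shift by whole multiples of 28 pixels in both directions leaves the image unchanged. At these radii, therefore, some of the randomly oriented shifts bring the digit back to (or close to) its original position.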