In [None]:
import torch
import torch.distributions as distributions
import torch.nn as nn
import torch.nn.functional as f
import torchvision
from torch.optim import Adam
from torch.utils.data import SubsetRandomSampler, DataLoader
from scipy import interpolate
import numpy as np

# MNIST splits, downloaded into the working directory on first use.
# No transform is attached, so both datasets yield (PIL image, label) pairs.
data = torchvision.datasets.MNIST(root=".", train=True, download=True)
new_data = torchvision.datasets.MNIST(root=".", train=False, transform=None, download=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

In [None]:
from matplotlib.pyplot import imshow
# Build the list `images`: for every processed MNIST digit the centre rectangle
# of the pixel grid is discarded, the remaining sample points in the horizontal
# and vertical bands through the hole are displaced, and the result is
# re-interpolated on the regular pixel grid.  Each entry is a float64 tensor of
# shape (3, image_y, image_x) holding
#   channel 0: interpolated intensities,
#   channel 1: interpolated original x coordinates,
#   channel 2: interpolated original y coordinates.

# Geometry of the image and of the removed centre rectangle (loop invariant).
image_x = 28
image_y = 28
hole_size_x = 10
hole_size_y = 10
img_center_x = image_x/2
img_center_y = image_y/2
hole_beg_x = img_center_x - hole_size_x/2
hole_end_x = img_center_x + hole_size_x/2
hole_beg_y = img_center_y - hole_size_y/2
hole_end_y = img_center_y + hole_size_y/2

# Regular pixel grid the displaced samples are interpolated back onto.
x_new, y_new = np.meshgrid(np.arange(0, image_x, 1), np.arange(0, image_y, 1))

images = []
for sample_idx, (image, label) in enumerate(data):
    # Indices 0..5000 inclusive are processed, i.e. 5001 images.
    if sample_idx > 5000:
        break
    image_mod = np.array(image)

    # Flattened (row-major) pixel coordinates and intensities as float64.
    row_idx, col_idx = np.indices(image_mod.shape)
    x = col_idx.ravel().astype(np.float64)
    y = row_idx.ravel().astype(np.float64)
    z = image_mod.ravel().astype(np.float64)

    # Remove the centre rectangle: drop points strictly inside the hole in
    # both axes.  The y-range is bounded by hole_end_y (it used hole_end_x,
    # which only gave the same result because the hole is square).
    in_hole = (hole_beg_x < x) & (x < hole_end_x) & (hole_beg_y < y) & (y < hole_end_y)
    keep = ~in_hole
    x = x[keep]
    y = y[keep]
    z = z[keep]

    # Remember the undisplaced coordinates; they become channels 1 and 2.
    x_old = np.copy(x)
    y_old = np.copy(y)

    # Move points to fill the hole.  Only points inside the column band or the
    # row band spanned by the hole are displaced; which coordinate is rescaled
    # depends on the triangle (bounded by the lines y = x and y = image_x - x)
    # the point lies in.  Points exactly on those lines are left untouched.
    # Every reachable denominator is >= 2: branches that rescale y require
    # x > y >= 0 and branches that rescale x require y > x >= 0, so the
    # coordinate inside abs() is never 0.
    # NOTE(review): the pairing of triangles and rescaled axes is kept exactly
    # as written (e.g. the large-y triangle rescales x) - confirm the intent.
    for a in range(len(x)):
        in_band = hole_beg_x <= x[a] <= hole_end_x or hole_beg_y <= y[a] <= hole_end_y
        if not in_band:
            continue
        if x[a] > y[a] and (image_x - x[a]) > y[a]:
            # Triangle next to the y = 0 edge: scale y about 0.
            y[a] *= (image_x - hole_size_x) / (image_x - abs(image_x - 2*x[a]))
        elif x[a] < y[a] and (image_x - x[a]) < y[a]:
            # Triangle next to the y = image_y edge: scale x about 0.
            x[a] *= (image_y - hole_size_y) / (image_y - abs(image_y - 2*y[a]))
        elif x[a] > y[a] and (image_x - x[a]) < y[a]:
            # Triangle next to the x = image_x edge: scale y about image_y.
            y[a] = (y[a] - image_y) * ((image_x - hole_size_x) / (image_x - abs(image_x - 2*x[a]))) + image_y
        elif x[a] < y[a] and (image_x - x[a]) > y[a]:
            # Triangle next to the x = 0 edge: scale x about image_x.
            x[a] = (x[a] - image_x) * (image_y - hole_size_y) / (image_y - abs(image_y - 2*y[a])) + image_x

    # Resample intensities and original coordinates on the regular grid.
    z_new = interpolate.griddata((x, y), z, (x_new, y_new), method='linear')
    x_2 = interpolate.griddata((x, y), x_old, (x_new, y_new), method='linear')
    y_2 = interpolate.griddata((x, y), y_old, (x_new, y_new), method='linear')
    z_new[z_new<0]=0
    images.append(torch.from_numpy(np.stack([z_new, x_2, y_2])))

# Preview the last processed original digit once, instead of stacking one
# image artist per sample on the same axes inside the loop.
if images:
    imshow(image_mod)

In [None]:
# ---- Hyper-parameters -------------------------------------------------------
batch_size: int = 20
epoch: int = 10
lr: float = 0.01
test_split: float = 0.2  # fraction of the samples held out for testing

# ---- Reproducible train / test split ----------------------------------------
random_seed: int = 10
shuffle_dataset: bool = True

dataset_size: int = len(images)
split: int = int(np.floor(dataset_size * test_split))
indices: list = list(range(dataset_size))
if shuffle_dataset:
    # Seed NumPy's global RNG so the permutation is the same on every run.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` (shuffled) indices form the test set, the rest the training set.
test_indices = indices[:split]
train_indices = indices[split:]

# ---- Samplers and loaders ---------------------------------------------------
# Both loaders iterate over the same `images` list; the samplers restrict each
# loader to its own subset of indices, drawn in random order.
train_sampler: SubsetRandomSampler = SubsetRandomSampler(train_indices)
test_sampler: SubsetRandomSampler = SubsetRandomSampler(test_indices)

train_loader: DataLoader = DataLoader(dataset=images, sampler=train_sampler, batch_size=batch_size)
test_loader: DataLoader = DataLoader(dataset=images, sampler=test_sampler, batch_size=batch_size)

In [None]:
# Sizes that determine the width of the network's output vector.
k = 3
l = 1
n = 10*10


class Network(torch.nn.Module):
    """Two conv/pool stages followed by two fully connected layers.

    For a (batch, 3, 28, 28) input the convolutional stack produces
    60 x 4 x 4 = 960 features per sample, which matches the input width of
    ``linear1``.  The final layer emits ``k + k*n + n*l*k + k*n`` values per
    sample, computed from the module-level ``k``, ``l`` and ``n``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=30, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=30, out_channels=60, kernel_size=5)
        self.linear1 = nn.Linear(in_features=320*3, out_features=1256)
        head_width = k + k * n + n * l * k + k * n
        self.linear2 = nn.Linear(in_features=1256, out_features=head_width)

    def forward(self, x: torch.Tensor):
        """Return the raw (unactivated) output of the last linear layer."""
        # Each stage: convolution -> 2x2 max-pool -> ReLU.
        for conv in (self.conv1, self.conv2):
            x = f.relu(f.max_pool2d(conv(x), kernel_size=2))
        flat = x.view(-1, 320*3)
        hidden = f.relu(self.linear1(flat))
        return self.linear2(hidden)


# The network is converted to float64; the training loop feeds it `.double()` batches.
model: Network = Network().double()
optimizer = Adam(params=model.parameters(), lr=lr)

In [None]:
# Number of passes over the training loader; replaces any earlier value of `epoch`.
epoch = 30
# Keep the model's parameters on the CPU.
model = model.cpu()
def loss_function(x: torch.Tensor, input, num_components=None, rank=None):
    """Negative log-likelihood of the central 10x10 patch under a Gaussian mixture.

    Each row of ``x`` is interpreted, in this order, as

    * ``K`` mixture logits,
    * ``n*K`` means, laid out as an ``(n, K)`` matrix (column ``j`` = component ``j``),
    * ``n*L*K`` covariance factors, laid out as ``(n, L, K)``,
    * ``n*K`` diagonal covariance terms, laid out as ``(n, K)``,

    where ``n`` is the number of target pixels (100).  Every component is a
    ``LowRankMultivariateNormal``; the mixture weights are the softmax of the
    logits, so they are always positive and sum to one.

    Fixes relative to the previous implementation: the loss is now actually
    accumulated (``Tensor.add`` is out-of-place and its result was discarded,
    leaving a constant 0 with no gradient path to the model), the sign is that
    of a negative log-likelihood, and all tensors live on ``x``'s device.

    Args:
        x: Network output; reshaped to ``(len(input), -1)``.
        input: Batch of images indexable as ``input[i][0][9:19, 9:19]``
            (tensor of shape ``(batch, C, H, W)`` or a sequence of such samples).
        num_components: Number of mixture components ``K``.  ``None`` (default)
            uses the module-level ``k``.
        rank: Rank ``L`` of the covariance factor.  ``None`` (default) uses the
            module-level ``l``.

    Returns:
        Scalar tensor: the mixture negative log-likelihood summed over the batch.

    Raises:
        ValueError: if the width of ``x`` does not match the layout above.
    """
    # The globals are only read when no explicit value is supplied.
    n_comp = k if num_components is None else num_components
    n_rank = l if rank is None else rank

    batch = len(input)
    if batch == 0:
        # Nothing to score; return a zero that is still attached to x's graph.
        return x.sum() * 0
    x = x.view((batch, -1))

    # Ground truth: central 10x10 crop of channel 0, flattened per sample.
    target = torch.stack([sample[0][9:19, 9:19] for sample in input])
    target = target.reshape(batch, -1).to(device=x.device, dtype=x.dtype)
    n_dim = target.shape[1]

    expected_width = n_comp + n_comp * n_dim + n_dim * n_rank * n_comp + n_comp * n_dim
    if x.shape[1] != expected_width:
        raise ValueError(
            "expected {} values per sample, got {}".format(expected_width, x.shape[1])
        )

    # Offsets of the four parameter groups inside each row.
    mean_start = n_comp
    factor_start = mean_start + n_comp * n_dim
    diag_start = factor_start + n_dim * n_rank * n_comp

    logits = x[:, :mean_start]                                            # (B, K)
    means = x[:, mean_start:factor_start].reshape(batch, n_dim, n_comp)
    means = means.permute(0, 2, 1)                                        # (B, K, n)
    factors = x[:, factor_start:diag_start].reshape(batch, n_dim, n_rank, n_comp)
    factors = factors.permute(0, 3, 1, 2)                                 # (B, K, n, L)
    diags = x[:, diag_start:].reshape(batch, n_dim, n_comp).permute(0, 2, 1)
    # cov_diag must be strictly positive; abs() alone can hit exactly zero.
    diags = diags.abs() + 1e-6                                            # (B, K, n)

    components = distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(
        means, factors, diags
    )
    # Broadcast every target against its K components -> (B, K).
    component_log_prob = components.log_prob(target.unsqueeze(1))
    log_weights = torch.log_softmax(logits, dim=1)

    # log p(y) = logsumexp_j(log w_j + log N_j(y)); the loss is its negative.
    nll = -torch.logsumexp(log_weights + component_log_prob, dim=1)
    return nll.sum()

# Training loop: one optimizer update per mini-batch.
for e in range(epoch):
    for x in train_loader:
        optimizer.zero_grad()              # clear gradients left from the previous batch
        result = model(x.double())
        loss = loss_function(result, x)
        loss.backward()
        # Apply the gradients; without this call the parameters never change.
        optimizer.step()
    print("Epoch {} completed".format(e))

