<a href="https://colab.research.google.com/github/SimonGiebenhain/ma_proj/blob/master/Kopie_von_GON.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# requirements
import torch
import torch.nn as nn
import torchvision
import numpy as np

# colab requirements
from IPython.display import clear_output
import matplotlib.pyplot as plt
from time import sleep

In [None]:
# ----- image data -----
dataset_name = 'mnist'  # one of: 'mnist', 'fashion'
img_size = 28           # images are img_size x img_size pixels
n_channels = 1          # number of image channels
img_coords = 2          # (x, y) coordinate inputs to the network

# ----- training hyper-parameters -----
lr = 1e-4
batch_size = 64
num_latent = 10
hidden_features = 128
num_layers = 4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# create the GON network (a SIREN as in https://vsitzmann.github.io/siren/)
class SirenLayer(nn.Module):
    """Linear layer followed by ``sin(w0 * .)``, with SIREN weight initialisation.

    Args:
        in_f: number of input features.
        out_f: number of output features.
        w0: frequency factor applied before the sine.
        is_first: use the first-layer bound ``1 / in_f`` for the weights.
        is_last: return the raw linear output (no sine).
    """

    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
        super().__init__()
        self.in_f = in_f
        self.w0 = w0
        self.linear = nn.Linear(in_f, out_f)
        self.is_first = is_first
        self.is_last = is_last
        self.init_weights()

    def init_weights(self):
        """Redraw the weights uniformly from ``[-bound, bound]``."""
        if self.is_first:
            bound = 1 / self.in_f
        else:
            bound = np.sqrt(6 / self.in_f) / self.w0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        out = self.linear(x)
        if self.is_last:
            return out
        return torch.sin(self.w0 * out)

class MLPLayer(nn.Module):
    """Fully connected layer followed by ReLU, unless it is the last layer.

    Args:
        in_f: number of input features.
        out_f: number of output features.
        is_last: if True, return the linear output without ReLU.
    """

    def __init__(self, in_f, out_f, is_last):
        super().__init__()
        self.in_f = in_f
        self.is_last = is_last
        self.linear = nn.Linear(in_f, out_f)

    def forward(self, x):
        out = self.linear(x)
        if self.is_last:
            return out
        return torch.relu(out)


def gon_model(dimensions):
    """Build a SIREN whose layer widths are listed in ``dimensions``.

    ``dimensions[0]`` is the input width and ``dimensions[-1]`` the output width.
    The first layer uses the first-layer initialisation; the last has no sine.
    """
    layers = [SirenLayer(dimensions[0], dimensions[1], is_first=True)]
    hidden_pairs = zip(dimensions[1:-2], dimensions[2:-1])
    layers.extend(SirenLayer(d_in, d_out) for d_in, d_out in hidden_pairs)
    layers.append(SirenLayer(dimensions[-2], dimensions[-1], is_last=True))
    return nn.Sequential(*layers)

def simple_model(dimensions):
    """Build a ReLU MLP whose layer widths are listed in ``dimensions``.

    Every layer but the last applies ReLU; the last returns raw outputs.
    """
    layers = [
        MLPLayer(d_in, d_out, is_last=False)
        for d_in, d_out in zip(dimensions[:-2], dimensions[1:-1])
    ]
    layers.append(MLPLayer(dimensions[-2], dimensions[-1], is_last=True))
    return nn.Sequential(*layers)

def ging_model(dimensions):
    """Build two SIRENs, a 'trainee' and a 'trainer', that share their first layer.

    The trainee outputs ``dimensions[-1]`` features. The trainer has its own
    hidden layers and outputs 6 features. The first layer is the same module
    object in both networks, so its weights are shared.

    Returns:
        (trainee, trainer) as two ``nn.Sequential`` models.
    """
    shared_first = SirenLayer(dimensions[0], dimensions[1], is_first=True)
    hidden_pairs = list(zip(dimensions[1:-2], dimensions[2:-1]))
    trainee_hidden = [SirenLayer(d_in, d_out) for d_in, d_out in hidden_pairs]
    trainer_hidden = [SirenLayer(d_in, d_out) for d_in, d_out in hidden_pairs]
    trainee_head = SirenLayer(dimensions[-2], dimensions[-1], is_last=True)
    trainer_head = SirenLayer(dimensions[-2], 6, is_last=True)
    trainee = nn.Sequential(shared_first, *trainee_hidden, trainee_head)
    trainer = nn.Sequential(shared_first, *trainer_hidden, trainer_head)
    return trainee, trainer

In [None]:
###### helper functions #####
def get_mgrid(sidelen, dim=2):
    """Return all points of a regular grid over [-1, 1]^dim as a (sidelen**dim, dim) tensor."""
    axis = torch.linspace(-1, 1, steps=sidelen)
    grids = torch.meshgrid(*([axis] * dim))
    return torch.stack(grids, dim=-1).reshape(-1, dim)

def cycle(iterable):
    """Yield the items of ``iterable`` forever, iterating it again after each pass."""
    while True:
        yield from iterable

def slerp(a, b, t, eps=1e-7):
    """Spherically interpolate row-wise between ``a`` and ``b``.

    Args:
        a, b: tensors of shape [n, d]. Row i of ``a`` is interpolated towards
            row i of ``b``. Only their directions determine the angle; the
            result is a weighted sum of the un-normalised rows.
        t: interpolation weight(s), broadcastable to [n, 1] (0 -> a, 1 -> b).
        eps: if |sin(angle)| is below this, the rows are (anti)parallel and the
            slerp formula is 0/0. Linear interpolation is used instead.

    Returns:
        Tensor of shape [n, d].
    """
    t = torch.as_tensor(t, dtype=a.dtype, device=a.device)
    a_dir = a / torch.norm(a, dim=1, keepdim=True)
    b_dir = b / torch.norm(b, dim=1, keepdim=True)
    # Rounding can push the dot product of unit vectors slightly past +/-1,
    # where acos returns NaN.
    cos_omega = (a_dir * b_dir).sum(1).clamp(-1.0, 1.0).unsqueeze(1)
    omega = torch.acos(cos_omega)
    sin_omega = torch.sin(omega)
    degenerate = sin_omega.abs() < eps
    # Avoid dividing by ~0; those rows take the linear branch below anyway.
    safe_sin = torch.where(degenerate, torch.ones_like(sin_omega), sin_omega)
    w_a = torch.where(degenerate, 1.0 - t, torch.sin((1.0 - t) * omega) / safe_sin)
    w_b = torch.where(degenerate, t, torch.sin(t * omega) / safe_sin)
    return w_a * a + w_b * b

def slerp_batch(model, z, coords):
    """Decode a grid of spherical interpolations between latent codes.

    Let ``n`` be the number of latents and ``col_size = int(sqrt(n))``. Each of
    the first ``col_size`` latents is interpolated towards the latent half a
    batch further on, ``z[(i + n // 2) % n]``.

    Output rows come in blocks of ``col_size``. Block k uses the weight
    ``t_k = k / (col_size - 1)`` for every pair, so a grid with
    ``nrow=col_size`` shows one interpolation path per column.

    Args:
        model: GON applied to [rows, pixels, coord_dim + latent_dim] inputs.
        z: latent codes of shape [n, 1, latent_dim].
        coords: coordinates of shape [>= col_size**2, pixels, coord_dim]. Only
            the first ``col_size**2`` rows are used.

    Returns:
        ``model`` output for ``col_size**2`` rows.
    """
    lz = z.detach().clone().squeeze(1)
    n = lz.size(0)
    col_size = int(np.sqrt(n))
    n_grid = col_size * col_size
    # Sources: the first col_size latents, tiled col_size times.
    src_z = lz[:col_size].repeat(col_size, 1)
    # Targets: the batch rotated by half its length
    # (for even n: second half followed by first half).
    tgt_z = torch.roll(lz, shifts=-(n // 2), dims=0)[:col_size].repeat(col_size, 1)
    # [t0]*col_size + [t1]*col_size + ..., one weight per output row.
    t = torch.linspace(0, 1, col_size, device=lz.device).repeat_interleave(col_size).unsqueeze(1)
    z_slerp = slerp(src_z, tgt_z, t)
    grid_coords = coords[:n_grid]
    z_slerp_rep = z_slerp.unsqueeze(1).repeat(1, grid_coords.size(1), 1)
    return model(torch.cat((grid_coords, z_slerp_rep), dim=-1))

def gon_sample(model, recent_zs, coords):
    """Decode latents drawn from a Gaussian fitted to recently seen latents.

    Args:
        model: GON applied to [batch, pixels, coord_dim + latent_dim] inputs.
        recent_zs: list of latent tensors of shape [k, 1, latent_dim]. They are
            concatenated, and their mean and covariance define the Gaussian.
        coords: coordinate tensor [batch, pixels, coord_dim]. One latent is
            drawn per batch row, on ``coords``' device and with its dtype.

    Returns:
        ``model`` output for the concatenated (coords, sampled latent) input.
    """
    zs = torch.cat(recent_zs, dim=0).squeeze(1).cpu().numpy()
    mean = np.mean(zs, axis=0)
    cov = np.cov(zs.T)
    n_samples, n_pixels = coords.size(0), coords.size(1)
    sample = np.random.multivariate_normal(mean, cov, size=n_samples)
    sample = torch.as_tensor(sample, dtype=coords.dtype, device=coords.device)
    # Same latent for every pixel of an image.
    sample = sample.unsqueeze(1).repeat(1, n_pixels, 1)
    return model(torch.cat((coords, sample), dim=-1))

def get_windows(loc_samps, radius, img_size, smoothness):
    """Return pixel windows around continuous sample locations and their weights.

    Args:
        loc_samps: tensor [batch_size, sample_size, 2] of locations. The range
            [-1, 1] maps linearly onto pixel coordinates [0, img_size - 1].
        radius: half-width of the square window, which has (2*radius+1)**2
            pixels.
        img_size: side length of the square image in pixels.
        smoothness: inverse temperature of the softmin over distances. Larger
            values put more weight on the nearest pixel.

    Returns:
        windows: long tensor [batch_size, sample_size, (2*radius+1)**2, 2] of
            pixel indices. Index components outside the image are set to 0, so
            the tensor can always be used for indexing.
        w: tensor [batch_size, sample_size, (2*radius+1)**2] holding
            softmax(-smoothness * distance) over the window. Entries outside the
            image get weight 0. The remaining weights are not renormalised, so
            near the border they sum to less than 1.

    Raises:
        ValueError: if the locations are not 2-D (the window grid is 2-D only).
    """
    batch_size, sample_size, dims = loc_samps.shape
    if dims != 2:
        raise ValueError(f"get_windows only supports 2-D locations, got dims={dims}")
    dev = loc_samps.device

    # Map [-1, 1] onto pixel coordinates and round to the nearest pixel.
    scaled_samps = (loc_samps + 1) / 2 * (img_size - 1)
    centers = torch.round(scaled_samps).long()

    # Offsets of the square window, added to every center.
    offsets = torch.arange(-radius, radius + 1, device=dev, dtype=torch.long)
    grid_x, grid_y = torch.meshgrid(offsets, offsets)
    grid = torch.stack([grid_x, grid_y], dim=2)
    windows = grid.view(1, 1, -1, dims) + centers.view(batch_size, sample_size, 1, dims)
    valid_locs = torch.logical_and(windows <= img_size - 1, windows >= 0)

    # Euclidean distance from each continuous location to each window pixel.
    ds = torch.sqrt(torch.sum(
        (scaled_samps.view(batch_size, sample_size, 1, dims) - windows) ** 2, dim=3))

    # Softmin weights; pixels outside the image get weight 0.
    w = torch.where(valid_locs.all(3), torch.softmax(-smoothness * ds, dim=2), torch.zeros_like(ds))

    # Out-of-image components get index 0 to prevent indexing errors.
    windows[~valid_locs] = 0

    return windows, w

def linearize_idx(idx, img_size):
    """Convert multi-dimensional pixel indices into flat row-major indices.

    Args:
        idx: integer tensor whose last dimension holds the per-axis indices,
            e.g. [batch, sample, window, dims] as returned by get_windows.
        img_size: side length of the image along every axis.

    Returns:
        Long tensor of shape ``idx.shape[:-1]`` equal to
        ``sum_d idx[..., d] * img_size**(dims - 1 - d)``.
    """
    dims = idx.shape[-1]
    lin_idx = torch.zeros(idx.shape[:-1], dtype=torch.long, device=idx.device)
    for d in range(dims):
        lin_idx += idx[..., d] * img_size ** (dims - 1 - d)
    return lin_idx

from torch.utils.data import Dataset, DataLoader
class IdxedDataset(Dataset):
    """Wrap a map-style dataset so that each item also carries its index.

    ``dataset[i]`` must return a ``(data, target)`` pair; this wrapper returns
    ``(data, target, i)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, label, index

In [None]:
# TEST: linearize_idx()
#xx, yy, zz = torch.meshgrid(torch.arange(2), torch.arange(2), torch.arange(2))
#print(xx)
#idx = torch.stack([xx, yy, zz], dim=3)
#print(idx)
#print(idx.shape)
#I = torch.arange(8).view(2,2,2)
#print(I)
#print(I.view(-1))
#idx = idx.view(8,3)
#print(idx)
##print(idx[:, 0]*3**0)
##print(idx[:, 1]*3**1)
##print(idx[:, 0]*3**0 + idx[:, 1]*3**1)
#linearize_idx(idx.view(1, 1, 8, 3), 2)


In [None]:
##### load datasets #####
if dataset_name == 'mnist':
    dataset_cls = torchvision.datasets.MNIST
elif dataset_name == 'fashion':
    dataset_cls = torchvision.datasets.FashionMNIST
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"unknown dataset_name {dataset_name!r}; expected 'mnist' or 'fashion'")

to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
dataset_train = dataset_cls('data', train=True, download=True, transform=to_tensor)
dataset_test = dataset_cls('data', train=False, download=True, transform=to_tensor)
print(dataset_train)
print(dataset_test)

# Every item becomes (image, label, index).
dataset_train = IdxedDataset(dataset_train)
dataset_test = IdxedDataset(dataset_test)

# drop_last=True keeps every batch at exactly `batch_size`; later cells rely on
# that, e.g. `x.reshape(batch_size, ...)` and the fixed coordinate tensor `c`.
train_loader = torch.utils.data.DataLoader(dataset_train, sampler=None, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset_test, sampler=None, shuffle=True, batch_size=batch_size, drop_last=True)

# Endless iterators for the step-based training loops.
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

In [None]:
##### SETUP MODEL #####

# GON layer widths: (coordinates + latent) -> num_layers hidden layers -> channels,
# for example gon_shape = [34, 256, 256, 256, 256, 1]
gon_shape = [img_coords + num_latent, *([hidden_features] * num_layers), n_channels]
F = gon_model(gon_shape).to(device)
optim_main = torch.optim.Adam(lr=lr, params=F.parameters())

# One copy of the flattened pixel-coordinate grid per batch element.
c = get_mgrid(img_size, 2).unsqueeze(0).repeat(batch_size, 1, 1).to(device)  # coordinates

recent_zs = []
print(f'> Number of parameters {sum(p.numel() for p in F.parameters())}')

In [None]:
##### TRAINING #####
# Gradient-origin training. For each batch the latent is the negative gradient
# of the reconstruction loss at z = 0, and F is trained to reconstruct the batch
# from that latent.

for step in range(501):
    # sample a batch of data and flatten it to [batch, pixels, channels]
    x, t, idx = next(train_iterator)
    x, t, idx = x.to(device), t.to(device), idx.to(device)
    x = x.permute(0, 2, 3, 1)
    x = x.reshape(batch_size, -1, n_channels)

    # Inner step. create_graph=True so that the outer loss can backpropagate
    # through this gradient into F's parameters.
    z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()
    z_rep_main = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c, z_rep_main), dim=-1))
    L_inner = ((g - x)**2).sum(1).mean()
    z_main = -torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
    z_rep_main = z_main.repeat(1, c.size(1), 1)

    # now with z as our new latent points, optimise the data fitting loss
    g = F(torch.cat((c, z_rep_main), dim=-1))
    L_outer = ((g - x)**2).sum(1).mean()
    if step % 20 == 0:
        print('sampled loss:{}'.format(L_outer.data))
    optim_main.zero_grad()
    L_outer.backward()
    optim_main.step()

    # keep the 100 most recent latents for the Gaussian sampler
    recent_zs.append(z_main.detach())
    recent_zs = recent_zs[-100:]

    if step % 100 == 0 and step > 0:
        print(f"Step: {step}   Loss: {L_outer.item():.3f}")
        # plotting only, so no autograd graph is built
        with torch.no_grad():
            z_rep = z_main.repeat(1, c.shape[1], 1)
            g = F(torch.cat((c, z_rep), dim=-1))
            # plot reconstructions, interpolations, and samples
            recons = torchvision.utils.make_grid(torch.clamp(g, 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow=8)
            slerps = torchvision.utils.make_grid(torch.clamp(slerp_batch(F, z_main.data, c), 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow=8)
            sample = torchvision.utils.make_grid(torch.clamp(gon_sample(F, recent_zs, c), 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size))

        # make_grid returns [3, H, W]; show channel 0 as an RGB image of any size.
        plt.title('Reconstructions')
        plt.imshow(recons[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
        plt.figure()
        plt.title('Spherical Interpolations')
        plt.imshow(slerps[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
        plt.figure()
        plt.title('Samples')
        plt.imshow(sample[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
        plt.show()
        sleep(1)

In [None]:
##### TESTING #####
# Evaluate on the test set with the same gradient-origin inner step used in
# training. F is not updated here.

recent_zs_test = []
cum_loss_test = 0
num_batches_test = 0

for step, (x, t, idx) in enumerate(test_loader):
    x, t, idx = x.to(device), t.to(device), idx.to(device)
    x = x.permute(0, 2, 3, 1)
    x = x.reshape(batch_size, -1, n_channels)

    # Latent = negative gradient of the inner loss at z = 0. Nothing is
    # backpropagated through it at test time, so no create_graph is needed.
    z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()
    z_rep_main = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c, z_rep_main), dim=-1))
    L_inner = ((g - x)**2).sum(1).mean()
    z_main = -torch.autograd.grad(L_inner, [z])[0]

    # Reconstruction loss with the new latent; no autograd graph needed.
    with torch.no_grad():
        z_rep_main = z_main.repeat(1, c.size(1), 1)
        g = F(torch.cat((c, z_rep_main), dim=-1))
        cum_loss_test += ((g - x)**2).sum(1).mean().item()
    num_batches_test += 1
    # keep the 100 most recent test latents for the Gaussian sampler
    recent_zs_test.append(z_main.detach())
    recent_zs_test = recent_zs_test[-100:]


avg_loss_test = cum_loss_test / num_batches_test
print('Average test loss: {}'.format(avg_loss_test))

# plot reconstructions, interpolations, and samples for the last test batch
with torch.no_grad():
    z_rep = z_main.repeat(1, c.shape[1], 1)
    g = F(torch.cat((c, z_rep), dim=-1))
    recons = torchvision.utils.make_grid(torch.clamp(g, 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow=8)
    slerps = torchvision.utils.make_grid(torch.clamp(slerp_batch(F, z_main.data, c), 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size), nrow=8)
    # Samples use the statistics of the test latents collected above.
    sample = torchvision.utils.make_grid(torch.clamp(gon_sample(F, recent_zs_test, c), 0, 1).permute(0,2,1).reshape(-1, n_channels, img_size, img_size))

plt.title('Reconstructions')
plt.imshow(recons[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
plt.figure()
plt.title('Spherical Interpolations')
plt.imshow(slerps[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
plt.figure()
plt.title('Samples')
plt.imshow(sample[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy())
plt.show()


In [None]:
# Adversarial sampler: an MLP mapping (noise, GON latent) to a 2-D location.
num_latent_adv = 3  # only referenced by commented-out code further down
num_noise = 10

adv_shape = [num_latent + num_noise] + 4 * [128] + [img_coords]
adv = simple_model(adv_shape).to(device)
optim_adv = torch.optim.Adam(lr=1e-3, params=adv.parameters())


In [None]:
# Number of locations the adversarial sampler proposes per image.
sample_size = 300



In [None]:
# Directly optimisable sample locations, an alternative to the adversarial
# network. Only referenced by commented-out code in the training loop.
direct_idx = torch.zeros(batch_size, sample_size, 2, device=device, requires_grad=True)
optim_direct = torch.optim.Adam([direct_idx], lr=1e-3)



# TODO
- Check whether boundary cases (sample windows that reach past the image border) also work with the entropy term.


In [None]:
from tqdm.notebook import trange
from tqdm.notebook import tqdm

# Only the adversarial sampler is optimised below. F contains no dropout or
# batch-norm layers, so eval() does not change its outputs.
F.eval()

OPT_STEPS = 0
N_EPOCHS = 10

# Sampler output on the first batch of every epoch, kept for later inspection.
locs = np.zeros((N_EPOCHS, batch_size, sample_size, 2))

# Ring buffer of recently sampled (x, y) locations, used for the heatmaps.
recent_samples_x = np.zeros(250000)
recent_samples_y = np.zeros_like(recent_samples_x)
heatmap_history = []
count = 0  # next write position in the ring buffer

# Header of the per-epoch statistics table.
names = ["Epoch", "SamplingLoss", "Entropy"]
layout = "{!s:15} " * len(names)

def print_stats(epoch, values, decimals=6):
    """Print one table row: the epoch centred, then each value rounded to ``decimals``."""
    row_format = "{!s:^15}" + " {!s:15}" * len(values)
    rounded = np.round(values, decimals)
    print(row_format.format(epoch, *rounded))

for epoch in trange(N_EPOCHS, desc='Epoch'):
    sampling_loss_train = 0
    entropy_train = 0

    tq = tqdm(iter(train_loader), leave=False, total=len(train_loader), position=0)
    for i, (x, t, idx) in enumerate(tq):
        x, t = x.to(device), t.to(device)
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(batch_size, -1, n_channels)

        # GON latent for this batch: negative gradient of the reconstruction
        # loss at z = 0. The sampler only uses it as a constant input, so no
        # higher-order graph (create_graph) is kept.
        z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()
        g = F(torch.cat((c, z.repeat(1, c.size(1), 1)), dim=-1))
        L_inner = ((g - x)**2).sum(1).mean()
        z_main = -torch.autograd.grad(L_inner, [z])[0]
        z_rep_main = z_main.repeat(1, sample_size, 1)

        # The sampler proposes `sample_size` locations per image; tanh keeps
        # them in [-1, 1].
        noise1 = torch.from_numpy(np.random.normal(0, 1, [batch_size, sample_size, num_noise])).float().to(device)
        samps1 = torch.tanh(adv(torch.cat([noise1, z_rep_main], dim=2)))

        # Store the locations in the ring buffer. Each time it wraps around,
        # append a histogram of its contents to heatmap_history.
        lin_samps = samps1.detach().view(-1, 2).cpu().numpy()
        n_samps = lin_samps.shape[0]
        if count + n_samps > recent_samples_x.shape[0]:
            overflow = count + n_samps - recent_samples_x.shape[0]
            recent_samples_x[count:] = lin_samps[:n_samps - overflow, 0]
            recent_samples_y[count:] = lin_samps[:n_samps - overflow, 1]
            recent_samples_x[:overflow] = lin_samps[-overflow:, 0]
            recent_samples_y[:overflow] = lin_samps[-overflow:, 1]
            count = overflow
            heatmap, _, _ = np.histogram2d(recent_samples_x, recent_samples_y, bins=(28, 28), range=[[-1, 1], [-1, 1]])
            heatmap_history.append(heatmap)
        else:
            recent_samples_x[count:count + n_samps] = lin_samps[:, 0]
            recent_samples_y[count:count + n_samps] = lin_samps[:, 1]
            count += n_samps

        if i == 0:
            locs[epoch] = samps1.detach().cpu().numpy()

        # Soft targets: weighted average of the 3x3 pixel window around each
        # location, with softmin(5 * distance) weights from get_windows.
        windows1, w1 = get_windows(samps1, radius=1, img_size=img_size, smoothness=5)
        # flat pixel indices, shape [batch_size, sample_size * window_size]
        win_lin1 = linearize_idx(windows1, img_size).view(batch_size, -1)
        x_cropped1 = x[torch.arange(batch_size).unsqueeze(1), win_lin1, :].view(batch_size, sample_size, -1)
        target1 = torch.matmul(w1.view(batch_size, sample_size, 1, -1), x_cropped1.view(batch_size, sample_size, -1, 1))

        # Coverage term: scatter each window's weights onto the image grid and
        # average over samples. accumulate=True is needed because get_windows
        # clamps out-of-image components to index 0 (with weight 0), which can
        # repeat an in-window pixel. A plain indexed assignment with duplicate
        # indices is nondeterministic and may overwrite the real weight with 0.
        counts_decompressed1 = torch.zeros([batch_size, sample_size, img_size**2], device=device)
        counts_decompressed1.index_put_(
            (torch.arange(batch_size, device=device).unsqueeze(1).unsqueeze(2),
             torch.arange(sample_size, device=device).unsqueeze(0).unsqueeze(2),
             win_lin1.view(batch_size, sample_size, -1)),
            w1,
            accumulate=True,
        )
        counts1 = counts_decompressed1.sum(1) / sample_size
        # log1p is concave: for a fixed total weight this term is smallest when
        # the weight is spread over many pixels.
        entropy1 = -torch.sum(torch.log1p(counts1), dim=1).mean()

        # Sampler objective: maximise the GON's mean absolute error at the
        # sampled locations (hence the minus sign), plus the coverage term.
        g1 = F(torch.cat((samps1, z_rep_main), dim=-1))
        L_outer = -((g1.squeeze() - target1.squeeze()).abs().mean(1).mean())
        entropy = entropy1
        L_total = L_outer + 7.5 * entropy
        sampling_loss_train += L_outer.item()
        entropy_train += entropy.item()
        tq.set_postfix(loss=L_outer.item(), entropy=entropy.item())

        # NOTE(review): backward() also accumulates gradients into F's
        # parameters. Only optim_adv steps here, and F's grads are never
        # cleared in this loop.
        optim_adv.zero_grad()
        L_total.backward()
        optim_adv.step()

    if epoch == 0:
        print(f"\n{layout.format(*names)}")
    print_stats(epoch, [sampling_loss_train / len(train_loader), entropy_train / len(train_loader)])
    if heatmap_history:
        heatmap = heatmap_history[-1]
    else:
        # ring buffer has not wrapped yet: histogram the entries written so far
        heatmap, _, _ = np.histogram2d(recent_samples_x[:count], recent_samples_y[:count], bins=(28, 28), range=[[-1, 1], [-1, 1]])
    plt.figure(figsize=(2, 2))
    plt.imshow(heatmap)
    plt.show()


# Histogram (28x28 bins over [-1, 1]^2) of every entry in the sample ring buffer.
heatmap, xedges, yedges = np.histogram2d(
    recent_samples_x, recent_samples_y,
    bins=(28, 28), range=[[-1, 1], [-1, 1]],
)
extent = [*xedges[[0, -1]], *yedges[[0, -1]]]

plt.figure()
plt.imshow(heatmap)
plt.show()

In [None]:
torch.max(samps1)  # largest coordinate the sampler proposed in the last batch

In [None]:
# One figure per stored histogram snapshot. np.histogram2d puts x on the first
# axis, so transpose and use origin='lower' to draw x horizontally with y
# increasing upwards.
for snapshot in heatmap_history:
    plt.figure()
    plt.imshow(snapshot.T, origin='lower')
    plt.show()

In [None]:
# Compare the GON's reconstruction with the data for the current `x` / `z_main`
# (whatever the most recently executed loop left behind). Prints some sampler
# locations in pixel units for the first and last stored epochs, then plots
# |g - x|, x and g for the first 10 images.
with torch.no_grad():
    g = F(torch.cat((c, z_main.repeat(1, c.size(1), 1)), dim=-1))
    L = (g - x).abs()

# Per-pixel absolute error summed over the batch.
error_map = L.reshape(batch_size, img_size, img_size).sum(0).cpu().numpy()

for n in range(10):
    print(((locs[0, n, :5, :] + 1) / 2) * (img_size - 1))
    print(((locs[-1, n, :5, :] + 1) / 2) * (img_size - 1))
    print(L.shape)
    err_img = L[n, :, 0].reshape(img_size, img_size).cpu()
    print(err_img.shape)
    plt.figure()
    plt.subplot(131)
    plt.imshow(err_img.numpy(), vmin=0, vmax=1)
    print(err_img.max())
    plt.subplot(132)
    plt.imshow(x[n, :, 0].reshape(img_size, img_size).cpu().numpy(), vmin=0, vmax=1)
    plt.subplot(133)
    plt.imshow(g[n, :, 0].reshape(img_size, img_size).cpu().numpy())
    plt.show()

plt.figure()
plt.imshow(error_map)
plt.show()

plt.figure()
plt.imshow(heatmap)
plt.show()

In [None]:
# Inspect image `idx`: the data, the GON output on a res x res grid, a
# softmin(35 * distance)-weighted target at every grid point, and the absolute
# difference between output and target.
idx = 1
img = x.reshape(batch_size, img_size, img_size, 1)
print(torch.squeeze(img[idx, :, :, :]).shape)
plt.figure()
plt.imshow(torch.squeeze(img[idx, :, :, :]).cpu().numpy())

res = 28
with torch.no_grad():
    c_hd = get_mgrid(res).reshape(1, res**2, 2).repeat(batch_size, 1, 1).to(device)
    print(c_hd.shape)
    g = F(torch.cat((c_hd, z_main.repeat(1, c_hd.size(1), 1)), dim=2))
    print(g.shape)

    # Target at each grid point: softmin-weighted average of all image pixels.
    # Local names are used so the sampler's global `samps1` is not clobbered.
    n_pix = c.size(1)
    target = torch.zeros(batch_size, res**2, 1, device=device)
    for i in range(res**2):
        query = c_hd[:, i, :].unsqueeze(1)
        dist = torch.sqrt(((c - query.repeat(1, n_pix, 1))**2).sum(2))
        weights = torch.softmax(-35 * dist, dim=1)
        tmp = torch.bmm(weights.view(batch_size, 1, n_pix), x.reshape(batch_size, n_pix, 1))
        target[:, i, :] = tmp[:, 0, :]

    L_samps = (g - target).abs()

plt.figure()
plt.imshow(g[idx, :, :].reshape(res, res).cpu().numpy(), vmin=0, vmax=1)

plt.figure()
plt.imshow(L_samps[idx, :].reshape(res, res).cpu().numpy())

plt.figure()
plt.imshow(target[idx, :, :].reshape(res, res).cpu().numpy(), vmin=0, vmax=1)

In [None]:
# Scratch checks: a small coordinate grid, one pixel value of `img`, and `x`
# flattened to [batch, pixels, channels].
mg = get_mgrid(5)
print(img[0][24][13])
x_lin = torch.reshape(x, (batch_size, -1, n_channels))


In [None]:
# Build a fixed (non-trainable) depthwise Gaussian smoothing filter.
# Set these to whatever you want for your gaussian filter
kernel_size = 15
sigma = 3
channels = 1

# Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2).
# A float grid keeps the squared-distance arithmetic below in floating point.
x_cord = torch.arange(kernel_size, dtype=torch.float32)
x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
y_grid = x_grid.t()
xy_grid = torch.stack([x_grid, y_grid], dim=-1)

mean = (kernel_size - 1) / 2.
variance = sigma ** 2.

# The 2-D Gaussian kernel is the product of two 1-D Gaussians, one in x and one in y.
gaussian_kernel = (1. / (2. * np.pi * variance)) * torch.exp(
    -torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance)
)
# Normalise so the kernel sums to 1 (this makes the analytic prefactor irrelevant).
gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

# Reshape to a depthwise conv weight of shape (channels, 1, kernel_size, kernel_size).
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                            kernel_size=kernel_size, groups=channels, bias=False)
# Install the fixed kernel as a frozen parameter instead of assigning to `.data`.
gaussian_filter.weight = nn.Parameter(gaussian_kernel, requires_grad=False)

# Plot the kernel itself: plt.imshow cannot render an nn.Module.
plt.imshow(gaussian_kernel[0, 0])
plt.show()

In [None]:
# Debug aid: anomaly detection names the forward op behind a failing backward,
# but it slows every step considerably.
torch.autograd.set_detect_anomaly(True)
for step in range(501):
    # sample a batch of data and flatten it to (batch, pixels, channels)
    x, t = next(train_iterator)
    x, t = x.to(device), t.to(device)
    x = x.permute(0, 2, 3, 1)
    x = x.reshape(batch_size, -1, n_channels)

    # Gradient origin: the latent is the negative gradient of the inner loss
    # with respect to an all-zero vector z.
    z = torch.zeros(batch_size, 1, num_latent, device=device).requires_grad_()
    z_rep_main = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c, z_rep_main), dim=-1))
    L_inner = ((g - x)**2).sum(1).mean()
    z_main = -torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]

    z_rep_main = z_main.repeat(1, sample_size, 1)
    # the sampler only ever sees a detached copy of the latent
    z_rep_adv = z_rep_main.detach().clone()

    # sample new coordinates with the adversary network
    noise = torch.from_numpy(np.random.normal(0, 10, [batch_size, sample_size, num_noise])).float().to(device)
    samps = torch.sigmoid(adv(torch.cat([noise, z_rep_adv], dim=-1)))
    # Map sigmoid outputs to integer pixel indices in [0, img_size - 1]. Scaling by
    # img_size instead could round up to img_size and index past the end of the
    # flattened image (sigmoid saturates to exactly 1.0 in float32).
    samps_discr = torch.round(samps * (img_size - 1)).long()
    samps_discr = samps_discr[:, :, 0] * img_size + samps_discr[:, :, 1]
    if step % 10 == 0:
        print(samps_discr)

    # ground-truth values at the sampled pixels: (batch, sample_size, n_channels)
    batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)
    x_samp = x[batch_idx, samps_discr, :]

    # NOTE(review): F is fed `samps` in (0, 1) here, while the reconstructions below
    # feed it `c`; elsewhere in the notebook coordinates are rescaled as if they lie
    # in [-1, 1] -- confirm that both coordinate ranges agree.

    # model update: with z_main as the latent, fit the data at the sampled coordinates
    g = F(torch.cat((samps.detach(), z_rep_main), dim=-1))
    L_outer = ((g - x_samp)**2).sum(1).mean()
    print('sampled loss')
    print(L_outer)
    optim_main.zero_grad()
    L_outer.backward()
    optim_main.step()

    # Sampler update. The latent must be detached: the backward() above freed the
    # graph behind z_main, and optim_main.step() changed F's weights in place, so
    # backpropagating into it again would raise. The sampler only needs gradients
    # with respect to `samps`.
    # NOTE(review): this minimises the same loss as the model; if `adv` is meant to
    # pick hard coordinates, the sign of the loss probably needs flipping.
    optim_adv.zero_grad()
    g = F(torch.cat((samps, z_rep_main.detach()), dim=-1))
    L_outer = ((g - x_samp)**2).sum(1).mean()
    L_outer.backward()
    optim_adv.step()

    # sampling statistics are collected from the actual latents (z itself is all zeros)
    recent_zs.append(z_main.detach())
    recent_zs = recent_zs[-100:]

    if step % 100 == 0 and step > 0:
        print(f"Step: {step}   Loss: {L_outer.item():.3f}")
        z_vis = z_main.detach()
        z_rep = z_vis.repeat(1, c.shape[1], 1)
        g = F(torch.cat((c, z_rep), dim=-1))
        # plot reconstructions, interpolations, and samples
        recons = torchvision.utils.make_grid(torch.clamp(g, 0, 1).permute(0, 2, 1).reshape(-1, n_channels, img_size, img_size), nrow=8)
        slerps = torchvision.utils.make_grid(torch.clamp(slerp_batch(F, z_vis, c), 0, 1).permute(0, 2, 1).reshape(-1, n_channels, img_size, img_size), nrow=8)
        sample = torchvision.utils.make_grid(torch.clamp(gon_sample(F, recent_zs, c), 0, 1).permute(0, 2, 1).reshape(-1, n_channels, img_size, img_size))

        clear_output()
        # Show channel 0 of each grid (indexed channel-first) replicated to RGB.
        # The grid's height and width are used as-is rather than hard-coding 242.
        plt.title('Reconstructions')
        plt.imshow(recons[0].unsqueeze(-1).repeat(1, 1, 3).detach().cpu().numpy())
        plt.figure()
        plt.title('Spherical Interpolations')
        plt.imshow(slerps[0].unsqueeze(-1).repeat(1, 1, 3).detach().cpu().numpy())
        plt.figure()
        plt.title('Samples')
        plt.imshow(sample[0].unsqueeze(-1).repeat(1, 1, 3).detach().cpu().numpy())
        plt.show()
        sleep(1)

**Comments:**

The gradient origin network loss is:

$$G_{\mathbf{x}} = \int \mathcal{L} \Big( \Phi_{\mathbf{x}}(\mathbf{c}), F\Big(\mathbf{c} \oplus -\nabla_{\mathbf{z}_0} \int \mathcal{L} \big( \Phi_{\mathbf{x}}(\mathbf{c}), F(\mathbf{c} \oplus \mathbf{z}_0) \big) \mathrm{d}\mathbf{c} \Big) \Big) \mathrm{d}\mathbf{c},$$

where we first compute the negative gradient of the inner loss with respect to the zero vector $\mathbf{z}_0$ (the gradient origin):
$$\mathbf{z}=-\nabla_{\mathbf{z}_0} \int \mathcal{L} \big( \Phi_{\mathbf{x}}(\mathbf{c}), F(\mathbf{c} \oplus \mathbf{z}_0) \big) \mathrm{d}\mathbf{c}.$$

```
z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()
z_rep = z.repeat(1,c.size(1),1)
g = F(torch.cat((c, z_rep), dim=-1))
L_inner = ((g - x)**2).sum(1).mean()
z = -torch.autograd.grad(L_inner, [z], create_graph=True, retain_graph=True)[0]
```

The negated gradient serves as the latent vector, which we call $\mathbf{z}$. It is concatenated ($\oplus$) with the coordinates $\mathbf{c}$, and we can then optimise the outer loss:

$$G_{\mathbf{x}} = \int \mathcal{L} \Big( \Phi_{\mathbf{x}}(\mathbf{c}), F(\mathbf{c} \oplus \mathbf{z} ) \Big) \mathrm{d}\mathbf{c}$$

```
z_rep = z.repeat(1, c.size(1), 1)
g = F(torch.cat((c, z_rep), dim=-1))
L_outer = ((g - x)**2).sum(1).mean()
optim.zero_grad()
L_outer.backward()
optim.step()
```

Once trained, we can sample $\mathbf{z}\sim p_{\mathbf{z}}$ and query the model $F(\mathbf{c} \oplus \mathbf{z})$:

```
z_rep = z.repeat(1, c.size(1), 1)
g_sampled = F(torch.cat((c, z_rep), dim=-1))
```