In [1]:
from random_convolution import *

# Toy tests

In [2]:
# Smoke test of the base slicer on a single random input of shape (3, 8, 8).
size = (3, 8, 8)
# ceil(log2(8)) = 3; passed to both the slicer and the kernel-size generator
# (presumably the number of convolution stages -- confirm in random_convolution).
n = ceil(log2(size[1]))
x = torch.rand(size)
# Insert a leading axis of size 1 so x has shape (1, 3, 8, 8).
x = torch.unsqueeze(x, dim=0)
cbs = Conv_Base_Slicer(size, n)

# Print the generated kernel sizes, the sum of one sampled (3, 5, 5) kernel,
# and the slicer's output for x.
print(generate_base_kernel_size(size, n))
print(sample_unit_kernel((3, 5, 5)).sum())
print(cbs(x))

[(1, 3, 5, 5), (1, 1, 3, 3), (1, 1, 2, 2)]
tensor(1.0000)
tensor([[[[0.5244]]]])


In [3]:
# Apply the base slicer built above to 10 random inputs at once and print
# only the shape of the result.
x = torch.rand((10, *size))
print(cbs(x).size())

torch.Size([10, 1, 1, 1])


In [4]:
# Smoke test of the stride and dilatation slicers on a single random input
# of shape (3, 10, 10), a spatial size that is not a power of two.
size = (3, 10, 10)
# ceil(log2(10)) = 4
n = ceil(log2(size[1]))
x = torch.rand(size)
# Insert a leading axis of size 1 so x has shape (1, 3, 10, 10).
x = torch.unsqueeze(x, dim=0)
css = Conv_Stride_Slicer(size, n)
cds = Conv_Dilatation_Slicer(size, n)

# Print the generated kernel sizes and each slicer's output for the same x.
print(generate_kernel_size(size, n))
print(css(x))
print(cds(x))

[(1, 3, 2, 2), (1, 1, 2, 2), (1, 1, 2, 2), (1, 1, 1, 1)]
tensor([[[[0.5415]]]])
tensor([[[[0.4762]]]])


In [5]:
# Conv_Sliced_Wasserstein with L=4 and no explicit `type` argument, applied
# to a single random input of shape (1, 3, 8, 8).
size = (3, 8, 8)
# ceil(log2(8)) = 3
n = ceil(log2(size[1]))
x = torch.rand(size)
x = torch.unsqueeze(x, dim=0)
csw = Conv_Sliced_Wasserstein(size, n, L=4)
# forward is invoked directly here; as the cell's last expression its result
# is displayed (presumably one value per slice, L of them -- confirm).
csw.forward(x)

tensor([0.5177, 0.5244, 0.5086, 0.5121])

In [6]:
# Same csw object, now called on 10 random inputs at once.
x = torch.rand((10, *size))
print(csw(x))

tensor([[0.4482, 0.4650, 0.4514, 0.4477],
        [0.5103, 0.4882, 0.5042, 0.5088],
        [0.4482, 0.4642, 0.4564, 0.4523],
        [0.4798, 0.4854, 0.4832, 0.4646],
        [0.4296, 0.4388, 0.4354, 0.4353],
        [0.5111, 0.5154, 0.5130, 0.5123],
        [0.5084, 0.5012, 0.5070, 0.5106],
        [0.4974, 0.4745, 0.4927, 0.5012],
        [0.5383, 0.5449, 0.5292, 0.5501],
        [0.4968, 0.5027, 0.5000, 0.4860]])


In [7]:
# Conv_Sliced_Wasserstein with type="stride" and L=4 on a single random
# input of shape (1, 3, 10, 10). Rebinds the global `csw`.
size = (3, 10, 10)
# ceil(log2(10)) = 4
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="stride")
x = torch.rand((1, *size))
csw.forward(x)

tensor([0.5310, 0.5007, 0.4934, 0.4924])

In [8]:
# Call the current csw object (rebound by an earlier cell) on 10 random inputs.
x = torch.rand((10, *size))
print(csw(x))

tensor([[0.5723, 0.5036, 0.5627, 0.5571],
        [0.5183, 0.5242, 0.5138, 0.4950],
        [0.5537, 0.5131, 0.5409, 0.5509],
        [0.5051, 0.5005, 0.4739, 0.4964],
        [0.5253, 0.5775, 0.5064, 0.5437],
        [0.4952, 0.5361, 0.5025, 0.5083],
        [0.4640, 0.5281, 0.5445, 0.5055],
        [0.4953, 0.4972, 0.5081, 0.5106],
        [0.4700, 0.4937, 0.4505, 0.4561],
        [0.5309, 0.5441, 0.5476, 0.5377]])


In [9]:
# Conv_Sliced_Wasserstein with type="dilatation" and L=4 on a single random
# input of shape (1, 3, 10, 10). Rebinds the global `csw`.
size = (3, 10, 10)
# ceil(log2(10)) = 4
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="dilatation")
x = torch.rand((1, *size))
csw.forward(x)

tensor([0.5010, 0.5122, 0.4790, 0.5224])

In [10]:
# Call the current csw object (rebound by an earlier cell) on 10 random inputs.
x = torch.rand((10, *size))
print(csw(x))

tensor([[0.4749, 0.4964, 0.4005, 0.4859],
        [0.4924, 0.5412, 0.4902, 0.5077],
        [0.5176, 0.5081, 0.5235, 0.5484],
        [0.4569, 0.4706, 0.4343, 0.4661],
        [0.5320, 0.5565, 0.5496, 0.4978],
        [0.4744, 0.4862, 0.4657, 0.4863],
        [0.5090, 0.4668, 0.4989, 0.4474],
        [0.4026, 0.3991, 0.4497, 0.4131],
        [0.4902, 0.4793, 0.5610, 0.4679],
        [0.5034, 0.4976, 0.4961, 0.5258]])


In [11]:
# Compare two independent random batches (2 inputs each) by slicing both
# with the same stride-type csw object.
size = (3, 10, 10)
# ceil(log2(10)) = 4
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="stride")
mu = torch.rand((2, *size))
nu = torch.rand((2, *size))
# wasserstein_distance is not defined in this notebook; it presumably comes
# from the star import of random_convolution -- confirm. The recorded output
# holds one value per batch element.
wasserstein_distance(csw.forward(mu), csw.forward(nu))

tensor([0.0017, 0.0038])

# CelebA tests

In [12]:
from torchvision import datasets, transforms

In [20]:
# Preprocessing applied to every image: resize, centre-crop to 224x224,
# then convert to a tensor.
transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

# Images are read from the relative folder 'data'; ImageFolder expects the
# images to sit in per-class subdirectories.
dataset = datasets.ImageFolder('data', transform=transform)

In [21]:
# Loader with default settings; the displayed shape below shows one image
# per batch.
dataloader = torch.utils.data.DataLoader(dataset)
dat = iter(dataloader)
# Take the first two batches; the labels are discarded.
images1, _ = next(dat)
images2, _ = next(dat)
images1.shape

torch.Size([1, 3, 224, 224])

In [22]:
# Slice the two loaded images with a stride-type Conv_Sliced_Wasserstein
# using L=10 and print both results.
size = (3, 224, 224)
# ceil(log2(224)) = 8
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=10, type="stride")
print(csw.forward(images1))
print(csw.forward(images2))

tensor([0.5488, 0.6603, 0.8200, 0.7236, 0.5125, 0.5549, 0.6736, 0.6272, 0.4873,
        0.5333])
tensor([0.4280, 0.4161, 0.2483, 0.3599, 0.4128, 0.3805, 0.3417, 0.4116, 0.3922,
        0.4099])


In [23]:
def mono_wasserstein_distance(mu:torch.Tensor, nu:torch.Tensor, p=2):
    """
    Empirical p-th power Wasserstein cost between two sets of 1D samples.

    Both inputs are sorted along their last dimension, the sorted values are
    paired element-wise, and the mean of ``|sorted(mu) - sorted(nu)| ** p``
    over all elements is returned. No p-th root is applied to the result.

    Args:
        mu (torch.Tensor): tensor of samples from measure mu
        nu (torch.Tensor): tensor of samples from measure nu; its shape must
            be broadcastable against that of mu
        p (int): power of distance metric

    Return:
        torch.Tensor: zero-dimensional tensor holding the mean cost
    """
    sorted_mu = torch.sort(mu).values
    sorted_nu = torch.sort(nu).values

    # Take the absolute value before raising to p: with an odd p the signed
    # differences would otherwise cancel and could even give a negative
    # "distance". For even p (including the default) the result is unchanged.
    transport_cost = torch.pow(torch.abs(sorted_mu - sorted_nu), p)
    return transport_cost.mean()

In [24]:
# Distance between the sliced representations of the two single images.
mono_wasserstein_distance(csw.forward(images1), csw.forward(images2))

tensor(0.0596)

In [25]:
# Rebuild the loader with shuffled batches of 5 images and take two batches;
# the labels are discarded. No seed is set, so the batches differ between runs.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)
dat = iter(dataloader)
images1, _ = next(dat)
images2, _ = next(dat)
images1.shape

torch.Size([5, 3, 224, 224])

In [26]:
# Batched distance between the two 5-image batches; the recorded output
# holds one value per batch element.
wasserstein_distance(csw.forward(images1), csw.forward(images2))

tensor([0.0693, 0.0038, 0.0314, 0.0047, 0.0711])