In [1]:
from random_convolution import *

# Toy tests

In [2]:
# Fix the global torch seed so the random inputs below (and any random
# kernels the slicers may sample) reproduce on Restart & Run All.
torch.manual_seed(0)

size = (3, 8, 8)
# n appears to set the number of kernel stages (3 kernel sizes are printed for n=3).
n = ceil(log2(size[1]))
x = torch.rand(size)
x = torch.unsqueeze(x, dim=0)  # add a leading batch dimension -> (1, *size)
cbs = Conv_Base_Slicer(size, n)

print(generate_base_kernel_size(size, n))
print(sample_unit_kernel((3, 5, 5)).sum())  # a unit kernel; its sum should be 1
print(cbs(x))

[(1, 3, 5, 5), (1, 1, 3, 3), (1, 1, 2, 2)]
tensor(1.0000)
tensor([[[[0.4878]]]])


In [3]:
# Batched input: 10 samples, each with the per-sample shape `size`.
x = torch.rand(10, *size)
print(cbs(x).shape)

torch.Size([10, 1, 1, 1])


In [4]:
size = (3, 10, 10)
n = ceil(log2(size[1]))
# One random sample with a leading batch axis of 1.
x = torch.rand(size).unsqueeze(dim=0)
css = Conv_Stride_Slicer(size, n)
cds = Conv_Dilatation_Slicer(size, n)

# Kernel sizes used for this input, then the output of each slicer variant.
print(generate_kernel_size(size, n))
print(css(x))
print(cds(x))

[(1, 3, 2, 2), (1, 1, 2, 2), (1, 1, 2, 2), (1, 1, 1, 1)]
tensor([[[[0.5285]]]])
tensor([[[[0.5332]]]])


In [5]:
size = (3, 8, 8)
n = ceil(log2(size[1]))
x = torch.rand(size)
x = torch.unsqueeze(x, dim=0)  # (1, *size)
# L=4 -> 4 output values per sample.
csw = Conv_Sliced_Wasserstein(size, n, L=4)
# Invoke the module itself rather than .forward() so any registered hooks run.
csw(x)

tensor([0.5104, 0.5057, 0.5001, 0.5211])

In [6]:
# Batched input: 10 samples of shape `size` -> one row of L values per sample.
x = torch.rand(10, *size)
print(csw(x))

tensor([[0.5055, 0.5033, 0.5067, 0.5206],
        [0.4548, 0.4567, 0.4502, 0.4650],
        [0.5645, 0.5642, 0.5676, 0.5524],
        [0.4292, 0.4362, 0.4163, 0.4254],
        [0.4693, 0.4698, 0.4758, 0.4835],
        [0.5277, 0.5325, 0.5245, 0.5264],
        [0.4618, 0.4595, 0.4537, 0.4683],
        [0.4855, 0.4821, 0.4785, 0.4986],
        [0.5383, 0.5463, 0.5460, 0.5448],
        [0.4798, 0.4801, 0.4826, 0.4866]])


In [7]:
size = (3, 10, 10)
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="stride")
x = torch.rand((1, *size))
# Invoke the module itself rather than .forward() so any registered hooks run.
csw(x)

tensor([0.5137, 0.4808, 0.4693, 0.4493])

In [8]:
# Batched input: 10 samples of shape `size` through the stride variant.
x = torch.rand(10, *size)
print(csw(x))

tensor([[0.4764, 0.5444, 0.5031, 0.4895],
        [0.4852, 0.5058, 0.5348, 0.4999],
        [0.5132, 0.4939, 0.4829, 0.4936],
        [0.4518, 0.4291, 0.4278, 0.4767],
        [0.5362, 0.5673, 0.5372, 0.5756],
        [0.5044, 0.5486, 0.5162, 0.5209],
        [0.4804, 0.4732, 0.4976, 0.4959],
        [0.4685, 0.5248, 0.4831, 0.4839],
        [0.4818, 0.5406, 0.5078, 0.4993],
        [0.4271, 0.4935, 0.4961, 0.5275]])


In [9]:
size = (3, 10, 10)
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="dilatation")
x = torch.rand((1, *size))
# Invoke the module itself rather than .forward() so any registered hooks run.
csw(x)

tensor([0.5155, 0.5588, 0.5134, 0.5568])

In [10]:
# Batched input: 10 samples of shape `size` through the dilatation variant.
x = torch.rand(10, *size)
print(csw(x))

tensor([[0.4979, 0.5227, 0.4602, 0.4823],
        [0.4217, 0.4307, 0.3996, 0.4431],
        [0.4702, 0.4784, 0.5042, 0.4783],
        [0.5758, 0.5533, 0.5593, 0.5412],
        [0.4645, 0.4922, 0.4463, 0.4335],
        [0.4622, 0.4536, 0.5017, 0.4867],
        [0.4253, 0.4358, 0.4416, 0.5065],
        [0.5069, 0.5179, 0.5341, 0.5601],
        [0.5568, 0.5473, 0.5540, 0.5549],
        [0.6072, 0.5827, 0.5771, 0.5598]])


In [11]:
size = (3, 10, 10)
n = ceil(log2(size[1]))
csw = Conv_Sliced_Wasserstein(size, n, L=4, type="stride")
mu = torch.rand((2, *size))
nu = torch.rand((2, *size))
# Project both batches through the same module; the result holds one value
# per sample (2 values for a batch of 2).
wasserstein_distance(csw(mu), csw(nu))

tensor([2.0789e-05, 3.6989e-04])

# CelebA tests

In [12]:
from torchvision import datasets, transforms

In [13]:
# NOTE(review): Resize(255) is unusual (the common recipe is Resize(256)
# before CenterCrop(224)) — confirm it is intended.
transform = transforms.Compose(
    [
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Images are read from the local 'data' folder in ImageFolder layout.
dataset = datasets.ImageFolder(root='data', transform=transform)

In [14]:
# Default loader settings: one image per batch, taken in dataset index order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
dat = iter(dataloader)
# Take the first two images; labels are discarded.
(images1, _), (images2, _) = next(dat), next(dat)
images1.shape

torch.Size([1, 3, 224, 224])

In [15]:
size = (3, 224, 224)
n = ceil(log2(size[1]))
# L=800 -> 800 output values per image.
csw = Conv_Sliced_Wasserstein(size, n, L=800, type="stride")
# Invoke the module itself rather than .forward() so any registered hooks run.
print(csw(images1))
print(csw(images2))

tensor([0.5451, 0.5508, 0.5938, 0.5469, 0.6189, 0.5832, 0.5885, 0.6948, 0.7179,
        0.6511, 0.7055, 0.6716, 0.5576, 0.7057, 0.7084, 0.5791, 0.5595, 0.6839,
        0.6313, 0.5667, 0.6000, 0.4963, 0.6307, 0.6575, 0.5424, 0.6341, 0.7896,
        0.6285, 0.5313, 0.6868, 0.6522, 0.6465, 0.6395, 0.6369, 0.7399, 0.6671,
        0.5754, 0.6450, 0.6376, 0.6162, 0.7175, 0.6727, 0.6326, 0.5426, 0.7159,
        0.5639, 0.5797, 0.6759, 0.7883, 0.6054, 0.5951, 0.6188, 0.6805, 0.5972,
        0.6853, 0.6662, 0.5144, 0.6305, 0.5696, 0.6442, 0.6621, 0.5649, 0.7944,
        0.6385, 0.6090, 0.6555, 0.7119, 0.7031, 0.5886, 0.6918, 0.5456, 0.5532,
        0.7675, 0.5968, 0.6774, 0.5409, 0.5577, 0.5760, 0.6810, 0.7208, 0.6024,
        0.6029, 0.7351, 0.5961, 0.6985, 0.5011, 0.6602, 0.7254, 0.5856, 0.6045,
        0.7265, 0.6952, 0.6847, 0.5458, 0.5974, 0.6330, 0.6952, 0.5382, 0.5714,
        0.5431, 0.6137, 0.5939, 0.5866, 0.5888, 0.6331, 0.5992, 0.5923, 0.6617,
        0.6375, 0.6019, 0.6178, 0.5720, 

In [16]:
def mono_wasserstein_distance(mu: torch.Tensor, nu: torch.Tensor, p=2):
    """
    Empirical 1D Wasserstein cost between two sample sets.

    Both tensors are sorted along their last dimension and matched
    element-wise. The function returns the mean of
    ``|sorted(mu) - sorted(nu)| ** p`` over all elements. No ``1/p`` root is
    taken, so for p=2 this is the squared distance.

    Args:
        mu (torch.Tensor): tensor of samples from measure mu
        nu (torch.Tensor): tensor of samples from measure nu; expected to have
            the same shape as ``mu`` (otherwise torch broadcasting applies)
        p (int): power of distance metric

    Return:
        torch.Tensor: 0-dim tensor holding the mean p-th power difference.
    """
    mu_sorted = torch.sort(mu).values
    nu_sorted = torch.sort(nu).values

    # abs() before pow(): without it an odd p lets negative and positive
    # gaps cancel out, and a fractional p turns negative gaps into NaN.
    gaps = (mu_sorted - nu_sorted).abs()
    return gaps.pow(p).mean()

In [17]:
mono_wasserstein_distance(csw(images1), csw(images2))

tensor(0.0424)

In [18]:
# Seed the global torch RNG so the shuffled batches drawn below are reproducible.
torch.manual_seed(0)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True)
dat = iter(dataloader)
images1, _ = next(dat)
images2, _ = next(dat)

In [24]:
%time
wasserstein_distance(csw.forward(images1), csw.forward(images2))

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 10.7 µs


tensor([0.0379, 0.0042, 0.0093, 0.0009, 0.0603])

In [22]:
from losses import *

# One flat vector per image. The batch size is taken from the tensor rather
# than hard-coding the DataLoader's batch_size (5).
vec_images1 = images1.flatten(start_dim=1)
vec_images2 = images2.flatten(start_dim=1)

In [23]:
%time
sliced_wasserstein_distance(vec_images1, vec_images2, 1000)

CPU times: user 8 µs, sys: 0 ns, total: 8 µs
Wall time: 16.7 µs


tensor([0.0401, 0.0242, 0.0540, 0.0631, 0.1428, 0.0058, 0.1297, 0.0864, 0.1324,
        0.0318, 0.0281, 0.0857, 0.0170, 0.0179, 0.0243, 0.0628, 0.0509, 0.1052,
        0.0384, 0.0059, 0.0173, 0.0226, 0.0164, 0.0999, 0.0056, 0.0978, 0.0589,
        0.0102, 0.0233, 0.0309, 0.0078, 0.0582, 0.0227, 0.0406, 0.0707, 0.0166,
        0.0095, 0.0132, 0.0243, 0.0255, 0.0091, 0.0213, 0.0395, 0.0173, 0.0882,
        0.1329, 0.0161, 0.0359, 0.0089, 0.0147, 0.0293, 0.1167, 0.0335, 0.0547,
        0.0147, 0.0194, 0.0092, 0.0094, 0.0301, 0.0757, 0.0596, 0.0186, 0.1053,
        0.1157, 0.2106, 0.0066, 0.0528, 0.0507, 0.0474, 0.0566, 0.0458, 0.1675,
        0.0852, 0.1355, 0.0378, 0.0338, 0.0020, 0.0405, 0.0503, 0.0239, 0.0476,
        0.0220, 0.0239, 0.0137, 0.0204, 0.0217, 0.0194, 0.0287, 0.0145, 0.0140,
        0.0338, 0.0861, 0.0028, 0.0485, 0.0379, 0.0303, 0.0192, 0.0197, 0.0284,
        0.0133, 0.0167, 0.0153, 0.0576, 0.0993, 0.0691, 0.0800, 0.0299, 0.0223,
        0.0499, 0.0766, 0.0524, 0.0466, 