In [26]:
import os
import math
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.models import alexnet

In [89]:
model = alexnet()
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        n = m.in_channels
        for k in m.kernel_size:
            n*=k
        stdv = 1. / math.sqrt(n)
        m.weight.data.uniform_(-stdv,stdv)
        if m.bias is not None:
            m.bias.data.uniform_(-stdv,stdv)
            
convs = []
# list of a conv layers within alexnn
for i,k in enumerate(model.modules()):
    if isinstance(k, nn.Conv2d):
        convs.append(k)

In [90]:
l1 = convs[0].weight

In [92]:
for k in range(len(convs)):
    l1 = convs[k].weight
    dist = 0.0
    for j in range(l1.shape[1]): # iterate over channels
        X = l1[0][j].view(1,l1[0][j].shape[1]**2)  # gets first elem
        for i,w in enumerate(l1): # iterates over filters 0-> 64
            if i == 0:
                continue
            else:
                y = w[0].view(1,w[j].shape[1]**2)
                X = torch.cat((X,y))

        VI = torch.inverse(cov(X)) #inverse of covariance matrix
        for _input in X:
            dist += _batch_mahalanobis(VI,_input)
            #print(_batch_mahalanobis(VI,_input))
    print(dist)

tensor(0.0003, grad_fn=<ThAddBackward>)
tensor(4.0939e-06, grad_fn=<ThAddBackward>)
tensor(5.1202e-06, grad_fn=<ThAddBackward>)
tensor(8.6133e-07, grad_fn=<ThAddBackward>)
tensor(1.9252e-06, grad_fn=<ThAddBackward>)


In [4]:
# Returns the covariance matrix of m
def cov(m, rowvar=False):
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # m = m.type(torch.double)  # uncomment this line if desired
    fact = 1.0 / (m.size(1) - 1)
    m -= torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()

In [5]:
def _batch_mahalanobis(L, x):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both L and x.
    """
    # TODO: use `torch.potrs` or similar once a backwards pass is implemented.
    flat_L = L.unsqueeze(0).reshape((-1,) + L.shape[-2:])
    L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape)
    return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)

In [7]:
VI = torch.inverse(cov(X)) #inverse of covariance matrix
for _input in X:
    print(_batch_mahalanobis(VI,_input))

tensor(6.6474e-08, grad_fn=<SumBackward1>)
tensor(1.9875e-07, grad_fn=<SumBackward1>)
tensor(1.6561e-07, grad_fn=<SumBackward1>)
tensor(6.8049e-07, grad_fn=<SumBackward1>)
tensor(1.6184e-07, grad_fn=<SumBackward1>)
tensor(1.1137e-07, grad_fn=<SumBackward1>)
tensor(4.6328e-07, grad_fn=<SumBackward1>)
tensor(1.9133e-07, grad_fn=<SumBackward1>)
tensor(1.4527e-07, grad_fn=<SumBackward1>)
tensor(4.9883e-08, grad_fn=<SumBackward1>)
tensor(2.1414e-07, grad_fn=<SumBackward1>)
tensor(5.9594e-07, grad_fn=<SumBackward1>)
tensor(1.4073e-07, grad_fn=<SumBackward1>)
tensor(3.4788e-07, grad_fn=<SumBackward1>)
tensor(3.4324e-07, grad_fn=<SumBackward1>)
tensor(1.4160e-06, grad_fn=<SumBackward1>)
tensor(4.8469e-08, grad_fn=<SumBackward1>)
tensor(3.9130e-07, grad_fn=<SumBackward1>)
tensor(1.0504e-07, grad_fn=<SumBackward1>)
tensor(2.3756e-07, grad_fn=<SumBackward1>)
tensor(2.5613e-07, grad_fn=<SumBackward1>)
tensor(6.8629e-07, grad_fn=<SumBackward1>)
tensor(1.3376e-06, grad_fn=<SumBackward1>)
tensor(2.07

In [28]:
tmp = X
tmp

tensor([[-0.0474,  0.0017,  0.0472,  ...,  0.0012, -0.0300,  0.0392],
        [-0.0177,  0.0123,  0.0127,  ...,  0.0099, -0.0233,  0.0153],
        [ 0.0387, -0.0177, -0.0252,  ...,  0.0216, -0.0470, -0.0273],
        ...,
        [-0.0083, -0.0160,  0.0233,  ..., -0.0245,  0.0362,  0.0295],
        [ 0.0460,  0.0058, -0.0124,  ..., -0.0407,  0.0351,  0.0305],
        [ 0.0217, -0.0162, -0.0111,  ...,  0.0115,  0.0021,  0.0515]],
       grad_fn=<CatBackward>)

In [31]:
tmp2 = X
tmp2

tensor([[-0.0031,  0.0523, -0.0470,  ..., -0.0072,  0.0420,  0.0186],
        [-0.0177,  0.0123,  0.0127,  ...,  0.0099, -0.0233,  0.0153],
        [ 0.0387, -0.0177, -0.0252,  ...,  0.0216, -0.0470, -0.0273],
        ...,
        [-0.0083, -0.0160,  0.0233,  ..., -0.0245,  0.0362,  0.0295],
        [ 0.0460,  0.0058, -0.0124,  ..., -0.0407,  0.0351,  0.0305],
        [ 0.0217, -0.0162, -0.0111,  ...,  0.0115,  0.0021,  0.0515]],
       grad_fn=<CatBackward>)

In [33]:
X

tensor([[ 0.0184, -0.0160, -0.0388,  ...,  0.0068, -0.0235, -0.0444],
        [-0.0177,  0.0123,  0.0127,  ...,  0.0099, -0.0233,  0.0153],
        [ 0.0387, -0.0177, -0.0252,  ...,  0.0216, -0.0470, -0.0273],
        ...,
        [-0.0083, -0.0160,  0.0233,  ..., -0.0245,  0.0362,  0.0295],
        [ 0.0460,  0.0058, -0.0124,  ..., -0.0407,  0.0351,  0.0305],
        [ 0.0217, -0.0162, -0.0111,  ...,  0.0115,  0.0021,  0.0515]],
       grad_fn=<CatBackward>)