In [1]:
import os
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.models import alexnet

In [2]:
model = alexnet()
convs = []
# list of a conv layers within alexnn
for i,k in enumerate(model.modules()):
    if isinstance(k, nn.Conv2d):
        convs.append(k)

In [3]:
l1 = convs[0].weight

In [9]:
# Returns the covariance matrix of m
def cov(m, rowvar=False):
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # m = m.type(torch.double)  # uncomment this line if desired
    fact = 1.0 / (m.size(1) - 1)
    m -= torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()

In [10]:
def _mahalanobis(X):
    VI = torch.inverse(cov(X)) #covariance matrix
    total_dist = 0
    for i,v in enumerate(X):
        dist = 0
        for j,u in enumerate(X):
            if i == j:
                continue
            x = (v-u).unsqueeze(0).t()
            y = (v - u).unsqueeze(0)
            
            dist = (torch.mm(torch.mm(y,VI),x)) #sqrt of dist returns NaN (?)
            
            total_dist +=dist
            print(dist)
    return total_dist

In [11]:
for j in range(l1.shape[1]): # iterate over channels
    X = torch.reshape(l1[0][j], [1,l1[0][j].shape[1]**2])  # gets first elem
    for i,w in enumerate(l1): # iterates over filters
        if i == 0:
            continue
        else:
            y = torch.reshape(w[0], [1,w[j].shape[1]**2])
            X = torch.cat((X,y))
    dist = _mahalanobis(X)

tensor([[749.8835]], grad_fn=<MmBackward>)
tensor([[-534.8154]], grad_fn=<MmBackward>)
tensor([[-592.5972]], grad_fn=<MmBackward>)
tensor([[-582.0947]], grad_fn=<MmBackward>)
tensor([[-130.7014]], grad_fn=<MmBackward>)
tensor([[-85.6831]], grad_fn=<MmBackward>)
tensor([[-93.1754]], grad_fn=<MmBackward>)
tensor([[103.4254]], grad_fn=<MmBackward>)
tensor([[-122.3977]], grad_fn=<MmBackward>)
tensor([[-312.2463]], grad_fn=<MmBackward>)
tensor([[707.6191]], grad_fn=<MmBackward>)
tensor([[365.0967]], grad_fn=<MmBackward>)
tensor([[-300.7344]], grad_fn=<MmBackward>)
tensor([[-266.0461]], grad_fn=<MmBackward>)
tensor([[-33.8097]], grad_fn=<MmBackward>)
tensor([[135.6397]], grad_fn=<MmBackward>)
tensor([[97.8306]], grad_fn=<MmBackward>)
tensor([[275.6489]], grad_fn=<MmBackward>)
tensor([[-373.2925]], grad_fn=<MmBackward>)
tensor([[-62.2648]], grad_fn=<MmBackward>)
tensor([[594.2887]], grad_fn=<MmBackward>)
tensor([[303.8413]], grad_fn=<MmBackward>)
tensor([[65.8462]], grad_fn=<MmBackward>)
tens

tensor([[8.2937]], grad_fn=<MmBackward>)
tensor([[10.0557]], grad_fn=<MmBackward>)
tensor([[-72.8325]], grad_fn=<MmBackward>)
tensor([[259.3216]], grad_fn=<MmBackward>)
tensor([[28.7845]], grad_fn=<MmBackward>)
tensor([[-438.4448]], grad_fn=<MmBackward>)
tensor([[-103.3669]], grad_fn=<MmBackward>)
tensor([[-2.5659]], grad_fn=<MmBackward>)
tensor([[-423.2612]], grad_fn=<MmBackward>)
tensor([[45.7019]], grad_fn=<MmBackward>)
tensor([[253.4697]], grad_fn=<MmBackward>)
tensor([[-309.7466]], grad_fn=<MmBackward>)
tensor([[367.6222]], grad_fn=<MmBackward>)
tensor([[-181.0813]], grad_fn=<MmBackward>)
tensor([[483.4954]], grad_fn=<MmBackward>)
tensor([[-251.3340]], grad_fn=<MmBackward>)
tensor([[237.9578]], grad_fn=<MmBackward>)
tensor([[393.3701]], grad_fn=<MmBackward>)
tensor([[-83.3506]], grad_fn=<MmBackward>)
tensor([[35.1517]], grad_fn=<MmBackward>)
tensor([[140.3040]], grad_fn=<MmBackward>)
tensor([[273.4810]], grad_fn=<MmBackward>)
tensor([[112.8409]], grad_fn=<MmBackward>)
tensor([[-80

tensor([[-108.7422]], grad_fn=<MmBackward>)
tensor([[-647.2025]], grad_fn=<MmBackward>)
tensor([[-260.0858]], grad_fn=<MmBackward>)
tensor([[-525.1993]], grad_fn=<MmBackward>)
tensor([[138.9768]], grad_fn=<MmBackward>)
tensor([[-33.8043]], grad_fn=<MmBackward>)
tensor([[-415.6133]], grad_fn=<MmBackward>)
tensor([[-215.7888]], grad_fn=<MmBackward>)
tensor([[-204.6566]], grad_fn=<MmBackward>)
tensor([[207.5641]], grad_fn=<MmBackward>)
tensor([[332.9723]], grad_fn=<MmBackward>)
tensor([[-98.6996]], grad_fn=<MmBackward>)
tensor([[79.9187]], grad_fn=<MmBackward>)
tensor([[-238.4611]], grad_fn=<MmBackward>)
tensor([[-116.9240]], grad_fn=<MmBackward>)
tensor([[-221.5742]], grad_fn=<MmBackward>)
tensor([[-63.6620]], grad_fn=<MmBackward>)
tensor([[-307.1343]], grad_fn=<MmBackward>)
tensor([[-573.4358]], grad_fn=<MmBackward>)
tensor([[165.3218]], grad_fn=<MmBackward>)
tensor([[302.5840]], grad_fn=<MmBackward>)
tensor([[-69.2216]], grad_fn=<MmBackward>)
tensor([[-512.9858]], grad_fn=<MmBackward>)

tensor([[-41.7153]], grad_fn=<MmBackward>)
tensor([[274.2154]], grad_fn=<MmBackward>)
tensor([[221.9221]], grad_fn=<MmBackward>)
tensor([[-736.9059]], grad_fn=<MmBackward>)
tensor([[-269.9543]], grad_fn=<MmBackward>)
tensor([[75.3652]], grad_fn=<MmBackward>)
tensor([[120.0381]], grad_fn=<MmBackward>)
tensor([[-185.4738]], grad_fn=<MmBackward>)
tensor([[436.8979]], grad_fn=<MmBackward>)
tensor([[-395.9124]], grad_fn=<MmBackward>)
tensor([[-423.2612]], grad_fn=<MmBackward>)
tensor([[534.6362]], grad_fn=<MmBackward>)
tensor([[279.6924]], grad_fn=<MmBackward>)
tensor([[-90.5048]], grad_fn=<MmBackward>)
tensor([[-368.6672]], grad_fn=<MmBackward>)
tensor([[-360.6439]], grad_fn=<MmBackward>)
tensor([[526.1634]], grad_fn=<MmBackward>)
tensor([[581.6711]], grad_fn=<MmBackward>)
tensor([[155.6825]], grad_fn=<MmBackward>)
tensor([[79.0070]], grad_fn=<MmBackward>)
tensor([[422.1490]], grad_fn=<MmBackward>)
tensor([[165.5236]], grad_fn=<MmBackward>)
tensor([[179.5701]], grad_fn=<MmBackward>)
tensor

tensor([[-355.2850]], grad_fn=<MmBackward>)
tensor([[-279.1022]], grad_fn=<MmBackward>)
tensor([[217.5464]], grad_fn=<MmBackward>)
tensor([[43.6486]], grad_fn=<MmBackward>)
tensor([[-404.7889]], grad_fn=<MmBackward>)
tensor([[-98.5031]], grad_fn=<MmBackward>)
tensor([[140.3040]], grad_fn=<MmBackward>)
tensor([[95.9202]], grad_fn=<MmBackward>)
tensor([[124.0964]], grad_fn=<MmBackward>)
tensor([[36.1555]], grad_fn=<MmBackward>)
tensor([[-258.1501]], grad_fn=<MmBackward>)
tensor([[-269.2947]], grad_fn=<MmBackward>)
tensor([[-101.8331]], grad_fn=<MmBackward>)
tensor([[-47.8862]], grad_fn=<MmBackward>)
tensor([[362.9493]], grad_fn=<MmBackward>)
tensor([[326.3384]], grad_fn=<MmBackward>)
tensor([[278.1084]], grad_fn=<MmBackward>)
tensor([[170.3495]], grad_fn=<MmBackward>)
tensor([[-16.3545]], grad_fn=<MmBackward>)
tensor([[61.2056]], grad_fn=<MmBackward>)
tensor([[-285.7826]], grad_fn=<MmBackward>)
tensor([[-341.6729]], grad_fn=<MmBackward>)
tensor([[-98.4189]], grad_fn=<MmBackward>)
tensor(

tensor([[290.3252]], grad_fn=<MmBackward>)
tensor([[515.2598]], grad_fn=<MmBackward>)
tensor([[91.7067]], grad_fn=<MmBackward>)
tensor([[389.1761]], grad_fn=<MmBackward>)
tensor([[-179.4839]], grad_fn=<MmBackward>)
tensor([[447.2991]], grad_fn=<MmBackward>)
tensor([[479.1746]], grad_fn=<MmBackward>)
tensor([[-105.6992]], grad_fn=<MmBackward>)
tensor([[241.0988]], grad_fn=<MmBackward>)
tensor([[-390.0127]], grad_fn=<MmBackward>)
tensor([[98.4911]], grad_fn=<MmBackward>)
tensor([[739.7866]], grad_fn=<MmBackward>)
tensor([[356.1343]], grad_fn=<MmBackward>)
tensor([[253.4495]], grad_fn=<MmBackward>)
tensor([[-223.3508]], grad_fn=<MmBackward>)
tensor([[-113.9581]], grad_fn=<MmBackward>)
tensor([[262.9363]], grad_fn=<MmBackward>)
tensor([[46.1683]], grad_fn=<MmBackward>)
tensor([[532.2051]], grad_fn=<MmBackward>)
tensor([[304.7719]], grad_fn=<MmBackward>)
tensor([[352.0063]], grad_fn=<MmBackward>)
tensor([[150.0676]], grad_fn=<MmBackward>)
tensor([[497.9631]], grad_fn=<MmBackward>)
tensor([[

tensor([[831.4781]], grad_fn=<MmBackward>)
tensor([[1012.3713]], grad_fn=<MmBackward>)
tensor([[436.4626]], grad_fn=<MmBackward>)
tensor([[242.3545]], grad_fn=<MmBackward>)
tensor([[341.1550]], grad_fn=<MmBackward>)
tensor([[896.8667]], grad_fn=<MmBackward>)
tensor([[106.8124]], grad_fn=<MmBackward>)
tensor([[22.8551]], grad_fn=<MmBackward>)
tensor([[407.9058]], grad_fn=<MmBackward>)
tensor([[731.0262]], grad_fn=<MmBackward>)
tensor([[533.9814]], grad_fn=<MmBackward>)
tensor([[339.2502]], grad_fn=<MmBackward>)
tensor([[-245.9580]], grad_fn=<MmBackward>)
tensor([[674.9114]], grad_fn=<MmBackward>)
tensor([[588.9087]], grad_fn=<MmBackward>)
tensor([[399.8811]], grad_fn=<MmBackward>)
tensor([[-169.0452]], grad_fn=<MmBackward>)
tensor([[-216.9084]], grad_fn=<MmBackward>)
tensor([[61.7717]], grad_fn=<MmBackward>)
tensor([[103.3193]], grad_fn=<MmBackward>)
tensor([[242.7627]], grad_fn=<MmBackward>)
tensor([[826.0184]], grad_fn=<MmBackward>)
tensor([[-20.8643]], grad_fn=<MmBackward>)
tensor([[

tensor([[471.8921]], grad_fn=<MmBackward>)
tensor([[-375.4700]], grad_fn=<MmBackward>)
tensor([[1094.8954]], grad_fn=<MmBackward>)
tensor([[328.2067]], grad_fn=<MmBackward>)
tensor([[-409.7949]], grad_fn=<MmBackward>)
tensor([[277.3687]], grad_fn=<MmBackward>)
tensor([[-526.1770]], grad_fn=<MmBackward>)
tensor([[91.0345]], grad_fn=<MmBackward>)
tensor([[-411.8632]], grad_fn=<MmBackward>)
tensor([[236.5439]], grad_fn=<MmBackward>)
tensor([[191.5377]], grad_fn=<MmBackward>)
tensor([[-0.5713]], grad_fn=<MmBackward>)
tensor([[-193.5916]], grad_fn=<MmBackward>)
tensor([[-1096.9517]], grad_fn=<MmBackward>)
tensor([[239.7316]], grad_fn=<MmBackward>)
tensor([[-26.4056]], grad_fn=<MmBackward>)
tensor([[261.3760]], grad_fn=<MmBackward>)
tensor([[-595.9834]], grad_fn=<MmBackward>)
tensor([[67.3742]], grad_fn=<MmBackward>)
tensor([[-688.4055]], grad_fn=<MmBackward>)
tensor([[-637.5889]], grad_fn=<MmBackward>)
tensor([[242.7627]], grad_fn=<MmBackward>)
tensor([[-265.7280]], grad_fn=<MmBackward>)
te

tensor([[406.9629]], grad_fn=<MmBackward>)
tensor([[-1070.4061]], grad_fn=<MmBackward>)
tensor([[-391.5581]], grad_fn=<MmBackward>)
tensor([[-448.4030]], grad_fn=<MmBackward>)
tensor([[213.8337]], grad_fn=<MmBackward>)
tensor([[-216.9972]], grad_fn=<MmBackward>)
tensor([[-1255.1301]], grad_fn=<MmBackward>)
tensor([[-549.3735]], grad_fn=<MmBackward>)
tensor([[-934.0615]], grad_fn=<MmBackward>)
tensor([[-1276.3579]], grad_fn=<MmBackward>)
tensor([[-184.9670]], grad_fn=<MmBackward>)
tensor([[-14.5925]], grad_fn=<MmBackward>)
tensor([[900.5566]], grad_fn=<MmBackward>)
tensor([[559.7996]], grad_fn=<MmBackward>)
tensor([[-316.6885]], grad_fn=<MmBackward>)
tensor([[-364.4199]], grad_fn=<MmBackward>)
tensor([[-151.5361]], grad_fn=<MmBackward>)
tensor([[606.7737]], grad_fn=<MmBackward>)
tensor([[-777.6082]], grad_fn=<MmBackward>)
tensor([[547.5378]], grad_fn=<MmBackward>)
tensor([[-296.3629]], grad_fn=<MmBackward>)
tensor([[-66.3805]], grad_fn=<MmBackward>)
tensor([[-264.9292]], grad_fn=<MmBack

tensor([[770.5427]], grad_fn=<MmBackward>)
tensor([[-9.9824]], grad_fn=<MmBackward>)
tensor([[-276.2842]], grad_fn=<MmBackward>)
tensor([[132.1133]], grad_fn=<MmBackward>)
tensor([[548.3164]], grad_fn=<MmBackward>)
tensor([[-440.9365]], grad_fn=<MmBackward>)
tensor([[-211.8877]], grad_fn=<MmBackward>)
tensor([[-340.6841]], grad_fn=<MmBackward>)
tensor([[-727.0939]], grad_fn=<MmBackward>)
tensor([[411.8291]], grad_fn=<MmBackward>)
tensor([[207.0312]], grad_fn=<MmBackward>)
tensor([[1040.8486]], grad_fn=<MmBackward>)
tensor([[-295.3212]], grad_fn=<MmBackward>)
tensor([[-24.1962]], grad_fn=<MmBackward>)
tensor([[181.9121]], grad_fn=<MmBackward>)
tensor([[-45.2627]], grad_fn=<MmBackward>)
tensor([[48.0681]], grad_fn=<MmBackward>)
tensor([[-691.1227]], grad_fn=<MmBackward>)
tensor([[213.2470]], grad_fn=<MmBackward>)
tensor([[-118.6947]], grad_fn=<MmBackward>)
tensor([[-109.4834]], grad_fn=<MmBackward>)
tensor([[237.0806]], grad_fn=<MmBackward>)
tensor([[-57.1379]], grad_fn=<MmBackward>)
ten

tensor([[-100.2192]], grad_fn=<MmBackward>)
tensor([[72.4883]], grad_fn=<MmBackward>)
tensor([[-617.1241]], grad_fn=<MmBackward>)
tensor([[556.6722]], grad_fn=<MmBackward>)
tensor([[342.2986]], grad_fn=<MmBackward>)
tensor([[253.1741]], grad_fn=<MmBackward>)
tensor([[1144.9718]], grad_fn=<MmBackward>)
tensor([[-581.4565]], grad_fn=<MmBackward>)
tensor([[625.0166]], grad_fn=<MmBackward>)
tensor([[306.3662]], grad_fn=<MmBackward>)
tensor([[136.5100]], grad_fn=<MmBackward>)
tensor([[9.1450]], grad_fn=<MmBackward>)
tensor([[1110.4365]], grad_fn=<MmBackward>)
tensor([[-545.8406]], grad_fn=<MmBackward>)
tensor([[201.4707]], grad_fn=<MmBackward>)
tensor([[-1311.2302]], grad_fn=<MmBackward>)
tensor([[690.3354]], grad_fn=<MmBackward>)
tensor([[587.5513]], grad_fn=<MmBackward>)
tensor([[448.1807]], grad_fn=<MmBackward>)
tensor([[421.9098]], grad_fn=<MmBackward>)
tensor([[525.8169]], grad_fn=<MmBackward>)
tensor([[32.4147]], grad_fn=<MmBackward>)
tensor([[-332.8782]], grad_fn=<MmBackward>)
tensor

tensor([[149.3470]], grad_fn=<MmBackward>)
tensor([[-27.1023]], grad_fn=<MmBackward>)
tensor([[-8.1085]], grad_fn=<MmBackward>)
tensor([[227.5879]], grad_fn=<MmBackward>)
tensor([[-43.4352]], grad_fn=<MmBackward>)
tensor([[-38.6246]], grad_fn=<MmBackward>)
tensor([[54.4117]], grad_fn=<MmBackward>)
tensor([[160.4006]], grad_fn=<MmBackward>)
tensor([[82.5381]], grad_fn=<MmBackward>)
tensor([[373.8635]], grad_fn=<MmBackward>)
tensor([[115.2407]], grad_fn=<MmBackward>)
tensor([[158.8371]], grad_fn=<MmBackward>)
tensor([[132.3941]], grad_fn=<MmBackward>)
tensor([[423.2581]], grad_fn=<MmBackward>)
tensor([[259.8892]], grad_fn=<MmBackward>)
tensor([[428.2164]], grad_fn=<MmBackward>)
tensor([[162.4255]], grad_fn=<MmBackward>)
tensor([[174.2199]], grad_fn=<MmBackward>)
tensor([[401.9761]], grad_fn=<MmBackward>)
tensor([[71.1528]], grad_fn=<MmBackward>)
tensor([[57.7042]], grad_fn=<MmBackward>)
tensor([[20.3089]], grad_fn=<MmBackward>)
tensor([[365.4258]], grad_fn=<MmBackward>)
tensor([[221.7783

tensor([[208.0748]], grad_fn=<MmBackward>)
tensor([[316.1853]], grad_fn=<MmBackward>)
tensor([[394.5621]], grad_fn=<MmBackward>)
tensor([[339.9115]], grad_fn=<MmBackward>)
tensor([[136.7828]], grad_fn=<MmBackward>)
tensor([[351.2020]], grad_fn=<MmBackward>)
tensor([[168.9940]], grad_fn=<MmBackward>)
tensor([[370.2239]], grad_fn=<MmBackward>)
tensor([[187.9184]], grad_fn=<MmBackward>)
tensor([[210.6504]], grad_fn=<MmBackward>)
tensor([[139.7827]], grad_fn=<MmBackward>)
tensor([[261.8187]], grad_fn=<MmBackward>)
tensor([[251.2572]], grad_fn=<MmBackward>)
tensor([[113.6602]], grad_fn=<MmBackward>)
tensor([[206.2564]], grad_fn=<MmBackward>)
tensor([[195.4327]], grad_fn=<MmBackward>)
tensor([[178.7332]], grad_fn=<MmBackward>)
tensor([[77.4753]], grad_fn=<MmBackward>)
tensor([[109.4015]], grad_fn=<MmBackward>)
tensor([[64.7884]], grad_fn=<MmBackward>)
tensor([[155.6767]], grad_fn=<MmBackward>)
tensor([[140.3190]], grad_fn=<MmBackward>)
tensor([[219.5689]], grad_fn=<MmBackward>)
tensor([[61.9

tensor([[-91.2293]], grad_fn=<MmBackward>)
tensor([[121.0939]], grad_fn=<MmBackward>)
tensor([[-151.2271]], grad_fn=<MmBackward>)
tensor([[49.8111]], grad_fn=<MmBackward>)
tensor([[-90.6724]], grad_fn=<MmBackward>)
tensor([[20.1142]], grad_fn=<MmBackward>)
tensor([[46.4977]], grad_fn=<MmBackward>)
tensor([[72.9224]], grad_fn=<MmBackward>)
tensor([[20.6614]], grad_fn=<MmBackward>)
tensor([[88.2596]], grad_fn=<MmBackward>)
tensor([[-35.2654]], grad_fn=<MmBackward>)
tensor([[54.0845]], grad_fn=<MmBackward>)
tensor([[-93.0115]], grad_fn=<MmBackward>)
tensor([[-118.9243]], grad_fn=<MmBackward>)
tensor([[309.7517]], grad_fn=<MmBackward>)
tensor([[88.1548]], grad_fn=<MmBackward>)
tensor([[206.6460]], grad_fn=<MmBackward>)
tensor([[343.8764]], grad_fn=<MmBackward>)
tensor([[75.7586]], grad_fn=<MmBackward>)
tensor([[316.2842]], grad_fn=<MmBackward>)
tensor([[302.4252]], grad_fn=<MmBackward>)
tensor([[31.3940]], grad_fn=<MmBackward>)
tensor([[163.1002]], grad_fn=<MmBackward>)
tensor([[233.5654]]

tensor([[190.9014]], grad_fn=<MmBackward>)
tensor([[-29.6510]], grad_fn=<MmBackward>)
tensor([[197.5696]], grad_fn=<MmBackward>)
tensor([[68.3154]], grad_fn=<MmBackward>)
tensor([[251.2923]], grad_fn=<MmBackward>)
tensor([[195.4327]], grad_fn=<MmBackward>)
tensor([[87.1496]], grad_fn=<MmBackward>)
tensor([[-142.1763]], grad_fn=<MmBackward>)
tensor([[10.8404]], grad_fn=<MmBackward>)
tensor([[348.2704]], grad_fn=<MmBackward>)
tensor([[51.3752]], grad_fn=<MmBackward>)
tensor([[47.7799]], grad_fn=<MmBackward>)
tensor([[118.9293]], grad_fn=<MmBackward>)
tensor([[-22.9425]], grad_fn=<MmBackward>)
tensor([[210.5892]], grad_fn=<MmBackward>)
tensor([[174.5713]], grad_fn=<MmBackward>)
tensor([[282.3329]], grad_fn=<MmBackward>)
tensor([[63.0584]], grad_fn=<MmBackward>)
tensor([[94.9333]], grad_fn=<MmBackward>)
tensor([[-59.2373]], grad_fn=<MmBackward>)
tensor([[179.9707]], grad_fn=<MmBackward>)
tensor([[154.1233]], grad_fn=<MmBackward>)
tensor([[100.7070]], grad_fn=<MmBackward>)
tensor([[-214.416

tensor([[-53.1506]], grad_fn=<MmBackward>)
tensor([[-198.7183]], grad_fn=<MmBackward>)
tensor([[41.1946]], grad_fn=<MmBackward>)
tensor([[-134.9808]], grad_fn=<MmBackward>)
tensor([[-114.4982]], grad_fn=<MmBackward>)
tensor([[136.7871]], grad_fn=<MmBackward>)
tensor([[148.6234]], grad_fn=<MmBackward>)
tensor([[81.5210]], grad_fn=<MmBackward>)
tensor([[85.0544]], grad_fn=<MmBackward>)
tensor([[280.7804]], grad_fn=<MmBackward>)
tensor([[17.4111]], grad_fn=<MmBackward>)
tensor([[223.5494]], grad_fn=<MmBackward>)
tensor([[-25.8014]], grad_fn=<MmBackward>)
tensor([[200.0257]], grad_fn=<MmBackward>)
tensor([[21.8481]], grad_fn=<MmBackward>)
tensor([[-58.5450]], grad_fn=<MmBackward>)
tensor([[44.0520]], grad_fn=<MmBackward>)
tensor([[120.7646]], grad_fn=<MmBackward>)
tensor([[159.9402]], grad_fn=<MmBackward>)
tensor([[-30.5381]], grad_fn=<MmBackward>)
tensor([[114.5359]], grad_fn=<MmBackward>)
tensor([[-135.7844]], grad_fn=<MmBackward>)
tensor([[24.5793]], grad_fn=<MmBackward>)
tensor([[63.97

tensor([[-136.6023]], grad_fn=<MmBackward>)
tensor([[323.6923]], grad_fn=<MmBackward>)
tensor([[-12.8186]], grad_fn=<MmBackward>)
tensor([[83.6335]], grad_fn=<MmBackward>)
tensor([[2.7881]], grad_fn=<MmBackward>)
tensor([[172.7063]], grad_fn=<MmBackward>)
tensor([[174.0801]], grad_fn=<MmBackward>)
tensor([[-35.2654]], grad_fn=<MmBackward>)
tensor([[168.5516]], grad_fn=<MmBackward>)
tensor([[20.2822]], grad_fn=<MmBackward>)
tensor([[21.2712]], grad_fn=<MmBackward>)
tensor([[26.7152]], grad_fn=<MmBackward>)
tensor([[210.5642]], grad_fn=<MmBackward>)
tensor([[-129.4084]], grad_fn=<MmBackward>)
tensor([[242.5989]], grad_fn=<MmBackward>)
tensor([[195.1788]], grad_fn=<MmBackward>)
tensor([[43.9515]], grad_fn=<MmBackward>)
tensor([[-41.6573]], grad_fn=<MmBackward>)
tensor([[2.4714]], grad_fn=<MmBackward>)
tensor([[-14.6715]], grad_fn=<MmBackward>)
tensor([[261.1427]], grad_fn=<MmBackward>)
tensor([[285.7524]], grad_fn=<MmBackward>)
tensor([[-189.3043]], grad_fn=<MmBackward>)
tensor([[62.4422]

In [27]:
x = X[3] #vector1
#y = X[2] #vector 2
### mahalanobis params
VI = torch.inverse(cov(X)) #inverse of covariance matrix
#p2 = (x - y).unsqueeze(0)}
#p1 = (x-y).unsqueeze(0).t()

In [28]:
#_mahalanobis(X)

In [29]:
def _batch_mahalanobis(L, x):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both L and x.
    """
    # TODO: use `torch.potrs` or similar once a backwards pass is implemented.
    flat_L = L.unsqueeze(0).reshape((-1,) + L.shape[-2:])
    L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape)
    return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)

#x = np.random.rand(200)
#x = x.reshape(100,2)
#cov = np.cov(x)

#x = torch.Tensor(x)
#cov = torch.Tensor(cov)

#_batch_mahalanobis(cov,x)

In [33]:
for _input in X:
    print(_batch_mahalanobis(VI,_input))

tensor(5.0696e-07, grad_fn=<SumBackward1>)
tensor(1.4493e-07, grad_fn=<SumBackward1>)
tensor(1.9977e-07, grad_fn=<SumBackward1>)
tensor(3.3181e-07, grad_fn=<SumBackward1>)
tensor(7.6387e-08, grad_fn=<SumBackward1>)
tensor(1.2702e-07, grad_fn=<SumBackward1>)
tensor(8.1833e-08, grad_fn=<SumBackward1>)
tensor(3.4277e-08, grad_fn=<SumBackward1>)
tensor(3.2667e-08, grad_fn=<SumBackward1>)
tensor(5.7837e-08, grad_fn=<SumBackward1>)
tensor(4.1629e-07, grad_fn=<SumBackward1>)
tensor(3.0943e-07, grad_fn=<SumBackward1>)
tensor(3.3732e-07, grad_fn=<SumBackward1>)
tensor(5.0971e-08, grad_fn=<SumBackward1>)
tensor(5.5605e-08, grad_fn=<SumBackward1>)
tensor(9.4276e-08, grad_fn=<SumBackward1>)
tensor(1.4449e-07, grad_fn=<SumBackward1>)
tensor(3.7462e-07, grad_fn=<SumBackward1>)
tensor(6.2903e-08, grad_fn=<SumBackward1>)
tensor(7.0058e-07, grad_fn=<SumBackward1>)
tensor(7.5599e-08, grad_fn=<SumBackward1>)
tensor(1.4389e-07, grad_fn=<SumBackward1>)
tensor(3.8602e-08, grad_fn=<SumBackward1>)
tensor(1.23

In [41]:
_batch_mahalanobis(cov(l1[0][0]),)

tensor([[ 1.3569e-03, -5.2147e-04,  3.7936e-06, -1.4454e-04,  3.6596e-04,
         -2.8298e-04,  6.9286e-04, -8.2219e-04, -2.5268e-05,  9.5080e-05,
         -2.8236e-04],
        [-5.2147e-04,  7.7735e-04, -3.1055e-04, -4.7665e-05, -1.9722e-04,
          1.7001e-04, -1.4300e-04,  3.8126e-04, -7.6989e-05, -2.7402e-04,
          2.1344e-04],
        [ 3.7936e-06, -3.1055e-04,  8.1334e-04, -4.2161e-04, -5.2512e-05,
         -2.9197e-04,  1.3822e-04, -3.1444e-04,  2.8925e-04,  2.0379e-04,
         -1.3866e-04],
        [-1.4454e-04, -4.7665e-05, -4.2161e-04,  1.2845e-03,  1.3761e-04,
          1.5035e-04, -7.0705e-04,  6.5575e-04,  1.0620e-04,  3.0898e-04,
         -1.6821e-04],
        [ 3.6596e-04, -1.9722e-04, -5.2512e-05,  1.3761e-04,  5.9166e-04,
         -2.8644e-04, -4.8070e-05, -4.6564e-04, -2.9627e-04,  1.4054e-04,
         -2.1196e-04],
        [-2.8298e-04,  1.7001e-04, -2.9197e-04,  1.5035e-04, -2.8644e-04,
          1.0938e-03, -3.8682e-04,  2.7550e-04,  2.1929e-04, -7.6928e-0

In [32]:
_batch_mahalanobis(VI,x)

tensor(3.3181e-07, grad_fn=<SumBackward1>)

In [17]:
flat_L = VI.unsqueeze(0).reshape((-1,) + VI.shape[-2:])
flat_L.shape

torch.Size([1, 121, 121])

In [18]:
L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(VI.shape)
print(L_inv.shape)
print(x.shape)

torch.Size([121, 121])
torch.Size([121])


In [26]:
(x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)

tensor(0.0000, grad_fn=<SumBackward1>)

In [7]:
from numpy import linalg as la

def nearestPD(A):
    """Find the nearest positive-definite matrix to input

    A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1], which
    credits [2].

    [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd

    [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite
    matrix" (1988): https://doi.org/10.1016/0024-3795(88)90223-6
    """

    B = (A + A.T) / 2
    _, s, V = la.svd(B)

    H = np.dot(V.T, np.dot(np.diag(s), V))

    A2 = (B + H) / 2

    A3 = (A2 + A2.T) / 2

    if isPD(A3):
        return A3

    spacing = np.spacing(la.norm(A))
    # The above is different from [1]. It appears that MATLAB's `chol` Cholesky
    # decomposition will accept matrixes with exactly 0-eigenvalue, whereas
    # Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
    # for `np.spacing`), we use the above definition. CAVEAT: our `spacing`
    # will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
    # the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
    # `spacing` will, for Gaussian random matrixes of small dimension, be on
    # othe order of 1e-16. In practice, both ways converge, as the unit test
    # below suggests.
    I = np.eye(A.shape[0])
    k = 1
    while not isPD(A3):
        mineig = np.min(np.real(la.eigvals(A3)))
        A3 += I * (-mineig * k**2 + spacing)
        k += 1

    return A3

def isPD(B):
    """Returns true when input is positive-definite, via Cholesky"""
    try:
        _ = la.cholesky(B)
        return True
    except la.LinAlgError:
        return False

if __name__ == '__main__':
    import numpy as np
    for i in range(10):
        for j in range(2, 100):
            A = np.random.randn(j, j)
            B = nearestPD(A)
            assert(isPD(B))
    print('unit test passed!')

unit test passed!


In [34]:
from scipy.spatial import distance
import numpy as np
V = np.cov(M.T)
VI = np.linalg.inv(V)
if not isPD(VI):
    VI = nearestPD(VI)
for i in range(M.shape[0]):
    u = M[i]
    for j in range(M.shape[0]):
        if i == j:
            continue
        v = M[j]
        print(distance.mahalanobis(u,v,VI))
#mahalanobis test in numpy PASSED

NameError: name 'M' is not defined