In [1]:
import numpy as np
from numpy import linalg as la
import torch
import torch.nn as nn
import torchvision
from torchvision.models import alexnet
from nearestPD import *

model = alexnet()
convs = []
# list of a conv layers within alexnn
for i,k in enumerate(model.modules()):
    if isinstance(k, nn.Conv2d):
        convs.append(k)
l1 = convs[0].weight

for j in range(l1.shape[1]): # iterate over channels
    X = torch.reshape(l1[0][j], [1,l1[0][j].shape[1]**2])  # gets first elem
    for i,w in enumerate(l1): # iterates over filters
        if i == 0:
            continue
        else:
            y = torch.reshape(w[0], [1,w[j].shape[1]**2])
            X = torch.cat((X,y))
# X is a matrix with each line being a filter

In [2]:
def _mahalanobis(X):
    V = torch.inverse(cov(X))
    if not _isPD(V):
        VI = _nearestPD(V) #nearest Positive Definite of covariance matrix
    else:
        VI = V
    total_dist = 0
    for i,v in enumerate(X):
        dist = 0
        for j,u in enumerate(X):
            if i == j:
                continue
            x = (v-u).unsqueeze(0).t()
            y = (v - u).unsqueeze(0)
            
            dist = (torch.mm(torch.mm(y,VI),x)) #sqrt of dist returns NaN (?)
            
            total_dist +=dist
            print(dist)
    return total_dist

In [3]:
def _nearestPD(A):
    B = (A + A.t()) / 2
    _, s, V = torch.svd(B)
    
    H = torch.mm(V.t(), torch.mm(torch.diag(s),V))
    
    A2 = (B+H) / 2
    A3 = (A2 + A2.t())/ 2
    
    if _isPD(A3):
        return A3
    spacing = np.spacing(la.norm(A.detach().numpy()))
    
    I = torch.eye(A.shape[0])
    k = 1
    while not _isPD(A3):
        mineig = np.min(np.real(la.eigvals(A3.detach().numpy())))
        A3 += I * (-mineig * k**2 + spacing)
        k += 1
    return A3


def _isPD(B):
    try:
        M = B.detach().numpy()
        _ = la.cholesky(M)
        return True
    except la.LinAlgError:
        return False

In [4]:
for j in range(l1.shape[1]): # iterate over channels
    X = torch.reshape(l1[0][j], [1,l1[0][j].shape[1]**2])  # gets first elem
    for i,w in enumerate(l1): # iterates over filters
        if i == 0:
            continue
        else:
            y = torch.reshape(w[0], [1,w[j].shape[1]**2])
            X = torch.cat((X,y))
    dist = _mahalanobis(X)

tensor([[26378825728.]], grad_fn=<MmBackward>)
tensor([[29433227264.]], grad_fn=<MmBackward>)
tensor([[25699596288.]], grad_fn=<MmBackward>)
tensor([[28511621120.]], grad_fn=<MmBackward>)
tensor([[24464375808.]], grad_fn=<MmBackward>)
tensor([[27644585984.]], grad_fn=<MmBackward>)
tensor([[28142780416.]], grad_fn=<MmBackward>)
tensor([[20345993216.]], grad_fn=<MmBackward>)
tensor([[26889943040.]], grad_fn=<MmBackward>)
tensor([[25004042240.]], grad_fn=<MmBackward>)
tensor([[24303572992.]], grad_fn=<MmBackward>)
tensor([[24955303936.]], grad_fn=<MmBackward>)
tensor([[23217147904.]], grad_fn=<MmBackward>)
tensor([[26546788352.]], grad_fn=<MmBackward>)
tensor([[25650143232.]], grad_fn=<MmBackward>)
tensor([[29459294208.]], grad_fn=<MmBackward>)
tensor([[17585946624.]], grad_fn=<MmBackward>)
tensor([[22142582784.]], grad_fn=<MmBackward>)
tensor([[21952075776.]], grad_fn=<MmBackward>)
tensor([[25523359744.]], grad_fn=<MmBackward>)
tensor([[24857501696.]], grad_fn=<MmBackward>)
tensor([[2734

tensor([[23941416960.]], grad_fn=<MmBackward>)
tensor([[25894907904.]], grad_fn=<MmBackward>)
tensor([[23129964544.]], grad_fn=<MmBackward>)
tensor([[24279068672.]], grad_fn=<MmBackward>)
tensor([[21914488832.]], grad_fn=<MmBackward>)
tensor([[22713673728.]], grad_fn=<MmBackward>)
tensor([[21155110912.]], grad_fn=<MmBackward>)
tensor([[19417843712.]], grad_fn=<MmBackward>)
tensor([[25636501504.]], grad_fn=<MmBackward>)
tensor([[25526423552.]], grad_fn=<MmBackward>)
tensor([[20907700224.]], grad_fn=<MmBackward>)
tensor([[25978449920.]], grad_fn=<MmBackward>)
tensor([[25065392128.]], grad_fn=<MmBackward>)
tensor([[22709958656.]], grad_fn=<MmBackward>)
tensor([[27189803008.]], grad_fn=<MmBackward>)
tensor([[23980296192.]], grad_fn=<MmBackward>)
tensor([[26648041472.]], grad_fn=<MmBackward>)
tensor([[21345193984.]], grad_fn=<MmBackward>)
tensor([[22890823680.]], grad_fn=<MmBackward>)
tensor([[25317648384.]], grad_fn=<MmBackward>)
tensor([[24168460288.]], grad_fn=<MmBackward>)
tensor([[2034

tensor([[21015343104.]], grad_fn=<MmBackward>)
tensor([[29870735360.]], grad_fn=<MmBackward>)
tensor([[23235770368.]], grad_fn=<MmBackward>)
tensor([[22976200704.]], grad_fn=<MmBackward>)
tensor([[26565584896.]], grad_fn=<MmBackward>)
tensor([[26223929344.]], grad_fn=<MmBackward>)
tensor([[23832328192.]], grad_fn=<MmBackward>)
tensor([[27515248640.]], grad_fn=<MmBackward>)
tensor([[25807964160.]], grad_fn=<MmBackward>)
tensor([[26663516160.]], grad_fn=<MmBackward>)
tensor([[25305776128.]], grad_fn=<MmBackward>)
tensor([[22997671936.]], grad_fn=<MmBackward>)
tensor([[27011041280.]], grad_fn=<MmBackward>)
tensor([[24580929536.]], grad_fn=<MmBackward>)
tensor([[23770052608.]], grad_fn=<MmBackward>)
tensor([[20694464512.]], grad_fn=<MmBackward>)
tensor([[27837609984.]], grad_fn=<MmBackward>)
tensor([[29038213120.]], grad_fn=<MmBackward>)
tensor([[22695772160.]], grad_fn=<MmBackward>)
tensor([[29139046400.]], grad_fn=<MmBackward>)
tensor([[29459294208.]], grad_fn=<MmBackward>)
tensor([[2693

tensor([[26413006848.]], grad_fn=<MmBackward>)
tensor([[29145704448.]], grad_fn=<MmBackward>)
tensor([[29077968896.]], grad_fn=<MmBackward>)
tensor([[25803106304.]], grad_fn=<MmBackward>)
tensor([[28152412160.]], grad_fn=<MmBackward>)
tensor([[26418577408.]], grad_fn=<MmBackward>)
tensor([[27309314048.]], grad_fn=<MmBackward>)
tensor([[23949694976.]], grad_fn=<MmBackward>)
tensor([[28433944576.]], grad_fn=<MmBackward>)
tensor([[20233304064.]], grad_fn=<MmBackward>)
tensor([[25302048768.]], grad_fn=<MmBackward>)
tensor([[28202616832.]], grad_fn=<MmBackward>)
tensor([[27249328128.]], grad_fn=<MmBackward>)
tensor([[20621367296.]], grad_fn=<MmBackward>)
tensor([[29014992896.]], grad_fn=<MmBackward>)
tensor([[22269960192.]], grad_fn=<MmBackward>)
tensor([[25705271296.]], grad_fn=<MmBackward>)
tensor([[27352438784.]], grad_fn=<MmBackward>)
tensor([[21976401920.]], grad_fn=<MmBackward>)
tensor([[25494437888.]], grad_fn=<MmBackward>)
tensor([[21502115840.]], grad_fn=<MmBackward>)
tensor([[2121

tensor([[23007725568.]], grad_fn=<MmBackward>)
tensor([[30274066432.]], grad_fn=<MmBackward>)
tensor([[27285069824.]], grad_fn=<MmBackward>)
tensor([[23828144128.]], grad_fn=<MmBackward>)
tensor([[26344525824.]], grad_fn=<MmBackward>)
tensor([[20421498880.]], grad_fn=<MmBackward>)
tensor([[25240936448.]], grad_fn=<MmBackward>)
tensor([[28207687680.]], grad_fn=<MmBackward>)
tensor([[23634466816.]], grad_fn=<MmBackward>)
tensor([[28508264448.]], grad_fn=<MmBackward>)
tensor([[30675496960.]], grad_fn=<MmBackward>)
tensor([[30298306560.]], grad_fn=<MmBackward>)
tensor([[30276745216.]], grad_fn=<MmBackward>)
tensor([[26593566720.]], grad_fn=<MmBackward>)
tensor([[26332274688.]], grad_fn=<MmBackward>)
tensor([[24230115328.]], grad_fn=<MmBackward>)
tensor([[23656071168.]], grad_fn=<MmBackward>)
tensor([[25566244864.]], grad_fn=<MmBackward>)
tensor([[24657563648.]], grad_fn=<MmBackward>)
tensor([[33758765056.]], grad_fn=<MmBackward>)
tensor([[21664180224.]], grad_fn=<MmBackward>)
tensor([[2328

tensor([[28746774528.]], grad_fn=<MmBackward>)
tensor([[23516690432.]], grad_fn=<MmBackward>)
tensor([[20965928960.]], grad_fn=<MmBackward>)
tensor([[26285604864.]], grad_fn=<MmBackward>)
tensor([[26853244928.]], grad_fn=<MmBackward>)
tensor([[22923749376.]], grad_fn=<MmBackward>)
tensor([[25848688640.]], grad_fn=<MmBackward>)
tensor([[26803376128.]], grad_fn=<MmBackward>)
tensor([[27365828608.]], grad_fn=<MmBackward>)
tensor([[26650505216.]], grad_fn=<MmBackward>)
tensor([[25674287104.]], grad_fn=<MmBackward>)
tensor([[24235483136.]], grad_fn=<MmBackward>)
tensor([[22971400192.]], grad_fn=<MmBackward>)
tensor([[23446452224.]], grad_fn=<MmBackward>)
tensor([[23229859840.]], grad_fn=<MmBackward>)
tensor([[23738779648.]], grad_fn=<MmBackward>)
tensor([[32655581184.]], grad_fn=<MmBackward>)
tensor([[25487069184.]], grad_fn=<MmBackward>)
tensor([[26523856896.]], grad_fn=<MmBackward>)
tensor([[21700212736.]], grad_fn=<MmBackward>)
tensor([[26399252480.]], grad_fn=<MmBackward>)
tensor([[2215

tensor([[25677684736.]], grad_fn=<MmBackward>)
tensor([[22598031360.]], grad_fn=<MmBackward>)
tensor([[27613474816.]], grad_fn=<MmBackward>)
tensor([[27298121728.]], grad_fn=<MmBackward>)
tensor([[22549065728.]], grad_fn=<MmBackward>)
tensor([[27549593600.]], grad_fn=<MmBackward>)
tensor([[25664749568.]], grad_fn=<MmBackward>)
tensor([[29543667712.]], grad_fn=<MmBackward>)
tensor([[26201692160.]], grad_fn=<MmBackward>)
tensor([[24090404864.]], grad_fn=<MmBackward>)
tensor([[26446446592.]], grad_fn=<MmBackward>)
tensor([[29240186880.]], grad_fn=<MmBackward>)
tensor([[29170339840.]], grad_fn=<MmBackward>)
tensor([[23538241536.]], grad_fn=<MmBackward>)
tensor([[22328631296.]], grad_fn=<MmBackward>)
tensor([[26384173056.]], grad_fn=<MmBackward>)
tensor([[25987809280.]], grad_fn=<MmBackward>)
tensor([[26708905984.]], grad_fn=<MmBackward>)
tensor([[24086677504.]], grad_fn=<MmBackward>)
tensor([[30355990528.]], grad_fn=<MmBackward>)
tensor([[22713673728.]], grad_fn=<MmBackward>)
tensor([[2666

tensor([[22778345472.]], grad_fn=<MmBackward>)
tensor([[27298121728.]], grad_fn=<MmBackward>)
tensor([[26478376960.]], grad_fn=<MmBackward>)
tensor([[26792974336.]], grad_fn=<MmBackward>)
tensor([[21393113088.]], grad_fn=<MmBackward>)
tensor([[30422736896.]], grad_fn=<MmBackward>)
tensor([[25392920576.]], grad_fn=<MmBackward>)
tensor([[27499819008.]], grad_fn=<MmBackward>)
tensor([[22255814656.]], grad_fn=<MmBackward>)
tensor([[24741990400.]], grad_fn=<MmBackward>)
tensor([[25955194880.]], grad_fn=<MmBackward>)
tensor([[25818916864.]], grad_fn=<MmBackward>)
tensor([[24430569472.]], grad_fn=<MmBackward>)
tensor([[28235294720.]], grad_fn=<MmBackward>)
tensor([[22914996224.]], grad_fn=<MmBackward>)
tensor([[25485158400.]], grad_fn=<MmBackward>)
tensor([[28012120064.]], grad_fn=<MmBackward>)
tensor([[25096972288.]], grad_fn=<MmBackward>)
tensor([[26770735104.]], grad_fn=<MmBackward>)
tensor([[33650325504.]], grad_fn=<MmBackward>)
tensor([[29430472704.]], grad_fn=<MmBackward>)
tensor([[2847

tensor([[23295270912.]], grad_fn=<MmBackward>)
tensor([[24207726592.]], grad_fn=<MmBackward>)
tensor([[28411617280.]], grad_fn=<MmBackward>)
tensor([[24381024256.]], grad_fn=<MmBackward>)
tensor([[25240936448.]], grad_fn=<MmBackward>)
tensor([[26476939264.]], grad_fn=<MmBackward>)
tensor([[23127492608.]], grad_fn=<MmBackward>)
tensor([[21513402368.]], grad_fn=<MmBackward>)
tensor([[24637136896.]], grad_fn=<MmBackward>)
tensor([[22880346112.]], grad_fn=<MmBackward>)
tensor([[18610210816.]], grad_fn=<MmBackward>)
tensor([[27321464832.]], grad_fn=<MmBackward>)
tensor([[25848688640.]], grad_fn=<MmBackward>)
tensor([[25962473472.]], grad_fn=<MmBackward>)
tensor([[24352262144.]], grad_fn=<MmBackward>)
tensor([[24361971712.]], grad_fn=<MmBackward>)
tensor([[27447517184.]], grad_fn=<MmBackward>)
tensor([[24751474688.]], grad_fn=<MmBackward>)
tensor([[24547641344.]], grad_fn=<MmBackward>)
tensor([[28008304640.]], grad_fn=<MmBackward>)
tensor([[29240186880.]], grad_fn=<MmBackward>)
tensor([[2868

tensor([[30033293312.]], grad_fn=<MmBackward>)
tensor([[30329673728.]], grad_fn=<MmBackward>)
tensor([[27742208000.]], grad_fn=<MmBackward>)
tensor([[31652722688.]], grad_fn=<MmBackward>)
tensor([[28392308736.]], grad_fn=<MmBackward>)
tensor([[33242460160.]], grad_fn=<MmBackward>)
tensor([[26418409472.]], grad_fn=<MmBackward>)
tensor([[25710247936.]], grad_fn=<MmBackward>)
tensor([[27067514880.]], grad_fn=<MmBackward>)
tensor([[23475120128.]], grad_fn=<MmBackward>)
tensor([[30696939520.]], grad_fn=<MmBackward>)
tensor([[25540595712.]], grad_fn=<MmBackward>)
tensor([[24836315136.]], grad_fn=<MmBackward>)
tensor([[27675533312.]], grad_fn=<MmBackward>)
tensor([[20701362176.]], grad_fn=<MmBackward>)
tensor([[32134576128.]], grad_fn=<MmBackward>)
tensor([[25676414976.]], grad_fn=<MmBackward>)
tensor([[28600838144.]], grad_fn=<MmBackward>)
tensor([[26592299008.]], grad_fn=<MmBackward>)
tensor([[24025346048.]], grad_fn=<MmBackward>)
tensor([[25511337984.]], grad_fn=<MmBackward>)
tensor([[2427

tensor([[29178353664.]], grad_fn=<MmBackward>)
tensor([[31160969216.]], grad_fn=<MmBackward>)
tensor([[22954450944.]], grad_fn=<MmBackward>)
tensor([[26774431744.]], grad_fn=<MmBackward>)
tensor([[29124151296.]], grad_fn=<MmBackward>)
tensor([[27523241984.]], grad_fn=<MmBackward>)
tensor([[24219461632.]], grad_fn=<MmBackward>)
tensor([[21305638912.]], grad_fn=<MmBackward>)
tensor([[24076834816.]], grad_fn=<MmBackward>)
tensor([[32539369472.]], grad_fn=<MmBackward>)
tensor([[30714279936.]], grad_fn=<MmBackward>)
tensor([[24620500992.]], grad_fn=<MmBackward>)
tensor([[24958539776.]], grad_fn=<MmBackward>)
tensor([[28420194304.]], grad_fn=<MmBackward>)
tensor([[26604285952.]], grad_fn=<MmBackward>)
tensor([[26623633408.]], grad_fn=<MmBackward>)
tensor([[30893185024.]], grad_fn=<MmBackward>)
tensor([[26636355584.]], grad_fn=<MmBackward>)
tensor([[26629238784.]], grad_fn=<MmBackward>)
tensor([[29961488384.]], grad_fn=<MmBackward>)
tensor([[29584130048.]], grad_fn=<MmBackward>)
tensor([[2459

tensor([[31233986560.]], grad_fn=<MmBackward>)
tensor([[30309322752.]], grad_fn=<MmBackward>)
tensor([[29982937088.]], grad_fn=<MmBackward>)
tensor([[27566319616.]], grad_fn=<MmBackward>)
tensor([[26411698176.]], grad_fn=<MmBackward>)
tensor([[29766787072.]], grad_fn=<MmBackward>)
tensor([[34735943680.]], grad_fn=<MmBackward>)
tensor([[34352963584.]], grad_fn=<MmBackward>)
tensor([[29538398208.]], grad_fn=<MmBackward>)
tensor([[29239255040.]], grad_fn=<MmBackward>)
tensor([[35377864704.]], grad_fn=<MmBackward>)
tensor([[30015670272.]], grad_fn=<MmBackward>)
tensor([[20875411456.]], grad_fn=<MmBackward>)
tensor([[30060732416.]], grad_fn=<MmBackward>)
tensor([[30587318272.]], grad_fn=<MmBackward>)
tensor([[21734348800.]], grad_fn=<MmBackward>)
tensor([[30320621568.]], grad_fn=<MmBackward>)
tensor([[32483538944.]], grad_fn=<MmBackward>)
tensor([[24000210944.]], grad_fn=<MmBackward>)
tensor([[24901226496.]], grad_fn=<MmBackward>)
tensor([[32457410560.]], grad_fn=<MmBackward>)
tensor([[2876

tensor([[26410342400.]], grad_fn=<MmBackward>)
tensor([[27067365376.]], grad_fn=<MmBackward>)
tensor([[29557268480.]], grad_fn=<MmBackward>)
tensor([[24023463936.]], grad_fn=<MmBackward>)
tensor([[33035356160.]], grad_fn=<MmBackward>)
tensor([[29692940288.]], grad_fn=<MmBackward>)
tensor([[32028618752.]], grad_fn=<MmBackward>)
tensor([[27224668160.]], grad_fn=<MmBackward>)
tensor([[31469531136.]], grad_fn=<MmBackward>)
tensor([[31878887424.]], grad_fn=<MmBackward>)
tensor([[24255275008.]], grad_fn=<MmBackward>)
tensor([[29909377024.]], grad_fn=<MmBackward>)
tensor([[29623314432.]], grad_fn=<MmBackward>)
tensor([[30368397312.]], grad_fn=<MmBackward>)
tensor([[22954450944.]], grad_fn=<MmBackward>)
tensor([[21135167488.]], grad_fn=<MmBackward>)
tensor([[30755371008.]], grad_fn=<MmBackward>)
tensor([[25352534016.]], grad_fn=<MmBackward>)
tensor([[30278973440.]], grad_fn=<MmBackward>)
tensor([[28426942464.]], grad_fn=<MmBackward>)
tensor([[32627752960.]], grad_fn=<MmBackward>)
tensor([[2897

tensor([[35593543680.]], grad_fn=<MmBackward>)
tensor([[29494374400.]], grad_fn=<MmBackward>)
tensor([[33371744256.]], grad_fn=<MmBackward>)
tensor([[28639281152.]], grad_fn=<MmBackward>)
tensor([[28508999680.]], grad_fn=<MmBackward>)
tensor([[31760136192.]], grad_fn=<MmBackward>)
tensor([[26846570496.]], grad_fn=<MmBackward>)
tensor([[28735195136.]], grad_fn=<MmBackward>)
tensor([[30185459712.]], grad_fn=<MmBackward>)
tensor([[31638052864.]], grad_fn=<MmBackward>)
tensor([[32539369472.]], grad_fn=<MmBackward>)
tensor([[26333925376.]], grad_fn=<MmBackward>)
tensor([[31741992960.]], grad_fn=<MmBackward>)
tensor([[34596261888.]], grad_fn=<MmBackward>)
tensor([[30726809600.]], grad_fn=<MmBackward>)
tensor([[26168219648.]], grad_fn=<MmBackward>)
tensor([[25585129472.]], grad_fn=<MmBackward>)
tensor([[25602695168.]], grad_fn=<MmBackward>)
tensor([[32483538944.]], grad_fn=<MmBackward>)
tensor([[34852724736.]], grad_fn=<MmBackward>)
tensor([[26867417088.]], grad_fn=<MmBackward>)
tensor([[3770

tensor([[35566395392.]], grad_fn=<MmBackward>)
tensor([[33295228928.]], grad_fn=<MmBackward>)
tensor([[29299259392.]], grad_fn=<MmBackward>)
tensor([[35829198848.]], grad_fn=<MmBackward>)
tensor([[26162997248.]], grad_fn=<MmBackward>)
tensor([[29297530880.]], grad_fn=<MmBackward>)
tensor([[31421943808.]], grad_fn=<MmBackward>)
tensor([[30118967296.]], grad_fn=<MmBackward>)
tensor([[28429502464.]], grad_fn=<MmBackward>)
tensor([[35148722176.]], grad_fn=<MmBackward>)
tensor([[23748601856.]], grad_fn=<MmBackward>)
tensor([[26044309504.]], grad_fn=<MmBackward>)
tensor([[28390975488.]], grad_fn=<MmBackward>)
tensor([[30384269312.]], grad_fn=<MmBackward>)
tensor([[28672968704.]], grad_fn=<MmBackward>)
tensor([[23341086720.]], grad_fn=<MmBackward>)
tensor([[25389830144.]], grad_fn=<MmBackward>)
tensor([[28433614848.]], grad_fn=<MmBackward>)
tensor([[29045219328.]], grad_fn=<MmBackward>)
tensor([[25079285760.]], grad_fn=<MmBackward>)
tensor([[26805141504.]], grad_fn=<MmBackward>)
tensor([[2987

tensor([[30795038720.]], grad_fn=<MmBackward>)
tensor([[27897683968.]], grad_fn=<MmBackward>)
tensor([[25585723392.]], grad_fn=<MmBackward>)
tensor([[30257596416.]], grad_fn=<MmBackward>)
tensor([[32395741184.]], grad_fn=<MmBackward>)
tensor([[25098211328.]], grad_fn=<MmBackward>)
tensor([[27898040320.]], grad_fn=<MmBackward>)
tensor([[27108104192.]], grad_fn=<MmBackward>)
tensor([[26469210112.]], grad_fn=<MmBackward>)
tensor([[27181426688.]], grad_fn=<MmBackward>)
tensor([[27857031168.]], grad_fn=<MmBackward>)
tensor([[28548962304.]], grad_fn=<MmBackward>)
tensor([[29967933440.]], grad_fn=<MmBackward>)
tensor([[25987504128.]], grad_fn=<MmBackward>)
tensor([[25625507840.]], grad_fn=<MmBackward>)
tensor([[30840037376.]], grad_fn=<MmBackward>)
tensor([[31529816064.]], grad_fn=<MmBackward>)
tensor([[24580837376.]], grad_fn=<MmBackward>)
tensor([[26176215040.]], grad_fn=<MmBackward>)
tensor([[26916069376.]], grad_fn=<MmBackward>)
tensor([[26493722624.]], grad_fn=<MmBackward>)
tensor([[2282

tensor([[27477260288.]], grad_fn=<MmBackward>)
tensor([[27211436032.]], grad_fn=<MmBackward>)
tensor([[23679772672.]], grad_fn=<MmBackward>)
tensor([[27251169280.]], grad_fn=<MmBackward>)
tensor([[33594183680.]], grad_fn=<MmBackward>)
tensor([[31861155840.]], grad_fn=<MmBackward>)
tensor([[27740194816.]], grad_fn=<MmBackward>)
tensor([[28437026816.]], grad_fn=<MmBackward>)
tensor([[29774694400.]], grad_fn=<MmBackward>)
tensor([[28382875648.]], grad_fn=<MmBackward>)
tensor([[29421189120.]], grad_fn=<MmBackward>)
tensor([[33033979904.]], grad_fn=<MmBackward>)
tensor([[31170328576.]], grad_fn=<MmBackward>)
tensor([[31212513280.]], grad_fn=<MmBackward>)
tensor([[25104992256.]], grad_fn=<MmBackward>)
tensor([[29210464256.]], grad_fn=<MmBackward>)
tensor([[30398875648.]], grad_fn=<MmBackward>)
tensor([[27843534848.]], grad_fn=<MmBackward>)
tensor([[29128216576.]], grad_fn=<MmBackward>)
tensor([[28120301568.]], grad_fn=<MmBackward>)
tensor([[24930617344.]], grad_fn=<MmBackward>)
tensor([[3170

tensor([[19810625536.]], grad_fn=<MmBackward>)
tensor([[20552738816.]], grad_fn=<MmBackward>)
tensor([[19119308800.]], grad_fn=<MmBackward>)
tensor([[19139647488.]], grad_fn=<MmBackward>)
tensor([[21787881472.]], grad_fn=<MmBackward>)
tensor([[17657759744.]], grad_fn=<MmBackward>)
tensor([[18449360896.]], grad_fn=<MmBackward>)
tensor([[19671857152.]], grad_fn=<MmBackward>)
tensor([[16355533824.]], grad_fn=<MmBackward>)
tensor([[19807954944.]], grad_fn=<MmBackward>)
tensor([[17761851392.]], grad_fn=<MmBackward>)
tensor([[17767622656.]], grad_fn=<MmBackward>)
tensor([[21940766720.]], grad_fn=<MmBackward>)
tensor([[18046523392.]], grad_fn=<MmBackward>)
tensor([[22460499968.]], grad_fn=<MmBackward>)
tensor([[16351988736.]], grad_fn=<MmBackward>)
tensor([[19494901760.]], grad_fn=<MmBackward>)
tensor([[18387847168.]], grad_fn=<MmBackward>)
tensor([[21134329856.]], grad_fn=<MmBackward>)
tensor([[17801986048.]], grad_fn=<MmBackward>)
tensor([[21188462592.]], grad_fn=<MmBackward>)
tensor([[2197

tensor([[20252798976.]], grad_fn=<MmBackward>)
tensor([[19834562560.]], grad_fn=<MmBackward>)
tensor([[22706704384.]], grad_fn=<MmBackward>)
tensor([[23480084480.]], grad_fn=<MmBackward>)
tensor([[19734341632.]], grad_fn=<MmBackward>)
tensor([[19668365312.]], grad_fn=<MmBackward>)
tensor([[19407589376.]], grad_fn=<MmBackward>)
tensor([[18048569344.]], grad_fn=<MmBackward>)
tensor([[19912478720.]], grad_fn=<MmBackward>)
tensor([[21817235456.]], grad_fn=<MmBackward>)
tensor([[21684936704.]], grad_fn=<MmBackward>)
tensor([[23266310144.]], grad_fn=<MmBackward>)
tensor([[20779739136.]], grad_fn=<MmBackward>)
tensor([[21606578176.]], grad_fn=<MmBackward>)
tensor([[20430921728.]], grad_fn=<MmBackward>)
tensor([[16983042048.]], grad_fn=<MmBackward>)
tensor([[16904112128.]], grad_fn=<MmBackward>)
tensor([[20922441728.]], grad_fn=<MmBackward>)
tensor([[19200350208.]], grad_fn=<MmBackward>)
tensor([[17657759744.]], grad_fn=<MmBackward>)
tensor([[19368091648.]], grad_fn=<MmBackward>)
tensor([[2216

tensor([[16521389056.]], grad_fn=<MmBackward>)
tensor([[21870235648.]], grad_fn=<MmBackward>)
tensor([[19991629824.]], grad_fn=<MmBackward>)
tensor([[20993665024.]], grad_fn=<MmBackward>)
tensor([[19980068864.]], grad_fn=<MmBackward>)
tensor([[17928148992.]], grad_fn=<MmBackward>)
tensor([[19418046464.]], grad_fn=<MmBackward>)
tensor([[18645362688.]], grad_fn=<MmBackward>)
tensor([[17529100288.]], grad_fn=<MmBackward>)
tensor([[20892807168.]], grad_fn=<MmBackward>)
tensor([[18682679296.]], grad_fn=<MmBackward>)
tensor([[17099048960.]], grad_fn=<MmBackward>)
tensor([[19472695296.]], grad_fn=<MmBackward>)
tensor([[21810657280.]], grad_fn=<MmBackward>)
tensor([[21445908480.]], grad_fn=<MmBackward>)
tensor([[22982828032.]], grad_fn=<MmBackward>)
tensor([[18776856576.]], grad_fn=<MmBackward>)
tensor([[21623902208.]], grad_fn=<MmBackward>)
tensor([[16472231936.]], grad_fn=<MmBackward>)
tensor([[21621145600.]], grad_fn=<MmBackward>)
tensor([[20185872384.]], grad_fn=<MmBackward>)
tensor([[2033

tensor([[15509648384.]], grad_fn=<MmBackward>)
tensor([[21465161728.]], grad_fn=<MmBackward>)
tensor([[13659868160.]], grad_fn=<MmBackward>)
tensor([[16160998400.]], grad_fn=<MmBackward>)
tensor([[17432332288.]], grad_fn=<MmBackward>)
tensor([[17736452096.]], grad_fn=<MmBackward>)
tensor([[15650904064.]], grad_fn=<MmBackward>)
tensor([[18776856576.]], grad_fn=<MmBackward>)
tensor([[18367795200.]], grad_fn=<MmBackward>)
tensor([[18364069888.]], grad_fn=<MmBackward>)
tensor([[16841626624.]], grad_fn=<MmBackward>)
tensor([[17711788032.]], grad_fn=<MmBackward>)
tensor([[16280263680.]], grad_fn=<MmBackward>)
tensor([[17022770176.]], grad_fn=<MmBackward>)
tensor([[18429712384.]], grad_fn=<MmBackward>)
tensor([[15928893440.]], grad_fn=<MmBackward>)
tensor([[19622582272.]], grad_fn=<MmBackward>)
tensor([[17671680000.]], grad_fn=<MmBackward>)
tensor([[18442811392.]], grad_fn=<MmBackward>)
tensor([[16652156928.]], grad_fn=<MmBackward>)
tensor([[15430371328.]], grad_fn=<MmBackward>)
tensor([[1967

tensor([[17780520960.]], grad_fn=<MmBackward>)
tensor([[19515420672.]], grad_fn=<MmBackward>)
tensor([[18183583744.]], grad_fn=<MmBackward>)
tensor([[19343847424.]], grad_fn=<MmBackward>)
tensor([[17103271936.]], grad_fn=<MmBackward>)
tensor([[20907153408.]], grad_fn=<MmBackward>)
tensor([[18830972928.]], grad_fn=<MmBackward>)
tensor([[18579251200.]], grad_fn=<MmBackward>)
tensor([[13863864320.]], grad_fn=<MmBackward>)
tensor([[17667844096.]], grad_fn=<MmBackward>)
tensor([[16903699456.]], grad_fn=<MmBackward>)
tensor([[16742212608.]], grad_fn=<MmBackward>)
tensor([[16249539584.]], grad_fn=<MmBackward>)
tensor([[15119256576.]], grad_fn=<MmBackward>)
tensor([[16946307072.]], grad_fn=<MmBackward>)
tensor([[13613245440.]], grad_fn=<MmBackward>)
tensor([[16953266176.]], grad_fn=<MmBackward>)
tensor([[13730896896.]], grad_fn=<MmBackward>)
tensor([[16312983552.]], grad_fn=<MmBackward>)
tensor([[18553180160.]], grad_fn=<MmBackward>)
tensor([[22224484352.]], grad_fn=<MmBackward>)
tensor([[1977

tensor([[20909977600.]], grad_fn=<MmBackward>)
tensor([[16474177536.]], grad_fn=<MmBackward>)
tensor([[19316316160.]], grad_fn=<MmBackward>)
tensor([[18376130560.]], grad_fn=<MmBackward>)
tensor([[19723950080.]], grad_fn=<MmBackward>)
tensor([[19296856064.]], grad_fn=<MmBackward>)
tensor([[19036164096.]], grad_fn=<MmBackward>)
tensor([[13929764864.]], grad_fn=<MmBackward>)
tensor([[18804355072.]], grad_fn=<MmBackward>)
tensor([[21171324928.]], grad_fn=<MmBackward>)
tensor([[20219621376.]], grad_fn=<MmBackward>)
tensor([[24855822336.]], grad_fn=<MmBackward>)
tensor([[21638223872.]], grad_fn=<MmBackward>)
tensor([[17408270336.]], grad_fn=<MmBackward>)
tensor([[17940084736.]], grad_fn=<MmBackward>)
tensor([[20261832704.]], grad_fn=<MmBackward>)
tensor([[16299182080.]], grad_fn=<MmBackward>)
tensor([[21842917376.]], grad_fn=<MmBackward>)
tensor([[20990070784.]], grad_fn=<MmBackward>)
tensor([[18425081856.]], grad_fn=<MmBackward>)
tensor([[20139802624.]], grad_fn=<MmBackward>)
tensor([[1603

tensor([[20302886912.]], grad_fn=<MmBackward>)
tensor([[16781262848.]], grad_fn=<MmBackward>)
tensor([[19890124800.]], grad_fn=<MmBackward>)
tensor([[20242884608.]], grad_fn=<MmBackward>)
tensor([[19834562560.]], grad_fn=<MmBackward>)
tensor([[18339567616.]], grad_fn=<MmBackward>)
tensor([[18781906944.]], grad_fn=<MmBackward>)
tensor([[16662912000.]], grad_fn=<MmBackward>)
tensor([[19859257344.]], grad_fn=<MmBackward>)
tensor([[15471842304.]], grad_fn=<MmBackward>)
tensor([[19553900544.]], grad_fn=<MmBackward>)
tensor([[14971152384.]], grad_fn=<MmBackward>)
tensor([[19743541248.]], grad_fn=<MmBackward>)
tensor([[17833146368.]], grad_fn=<MmBackward>)
tensor([[20527618048.]], grad_fn=<MmBackward>)
tensor([[16211366912.]], grad_fn=<MmBackward>)
tensor([[19097475072.]], grad_fn=<MmBackward>)
tensor([[22705143808.]], grad_fn=<MmBackward>)
tensor([[19510982656.]], grad_fn=<MmBackward>)
tensor([[17598777344.]], grad_fn=<MmBackward>)
tensor([[20001140736.]], grad_fn=<MmBackward>)
tensor([[2094

tensor([[19552221184.]], grad_fn=<MmBackward>)
tensor([[19064981504.]], grad_fn=<MmBackward>)
tensor([[21817235456.]], grad_fn=<MmBackward>)
tensor([[20081455104.]], grad_fn=<MmBackward>)
tensor([[20030584832.]], grad_fn=<MmBackward>)
tensor([[20477052928.]], grad_fn=<MmBackward>)
tensor([[18531584000.]], grad_fn=<MmBackward>)
tensor([[21908668416.]], grad_fn=<MmBackward>)
tensor([[18828034048.]], grad_fn=<MmBackward>)
tensor([[18066104320.]], grad_fn=<MmBackward>)
tensor([[21038276608.]], grad_fn=<MmBackward>)
tensor([[20302100480.]], grad_fn=<MmBackward>)
tensor([[22968057856.]], grad_fn=<MmBackward>)
tensor([[20307458048.]], grad_fn=<MmBackward>)
tensor([[18155470848.]], grad_fn=<MmBackward>)
tensor([[20008024064.]], grad_fn=<MmBackward>)
tensor([[18894223360.]], grad_fn=<MmBackward>)
tensor([[16232227840.]], grad_fn=<MmBackward>)
tensor([[20206950400.]], grad_fn=<MmBackward>)
tensor([[20226355200.]], grad_fn=<MmBackward>)
tensor([[18740516864.]], grad_fn=<MmBackward>)
tensor([[2183

tensor([[16300719104.]], grad_fn=<MmBackward>)
tensor([[19552335872.]], grad_fn=<MmBackward>)
tensor([[19193600000.]], grad_fn=<MmBackward>)
tensor([[17705992192.]], grad_fn=<MmBackward>)
tensor([[17709219840.]], grad_fn=<MmBackward>)
tensor([[16741783552.]], grad_fn=<MmBackward>)
tensor([[16904112128.]], grad_fn=<MmBackward>)
tensor([[17051830272.]], grad_fn=<MmBackward>)
tensor([[17091088384.]], grad_fn=<MmBackward>)
tensor([[17176495104.]], grad_fn=<MmBackward>)
tensor([[21617612800.]], grad_fn=<MmBackward>)
tensor([[19360948224.]], grad_fn=<MmBackward>)
tensor([[18546397184.]], grad_fn=<MmBackward>)
tensor([[18014451712.]], grad_fn=<MmBackward>)
tensor([[18857572352.]], grad_fn=<MmBackward>)
tensor([[21301043200.]], grad_fn=<MmBackward>)
tensor([[17042391040.]], grad_fn=<MmBackward>)
tensor([[16898421760.]], grad_fn=<MmBackward>)
tensor([[16456390656.]], grad_fn=<MmBackward>)
tensor([[18994778112.]], grad_fn=<MmBackward>)
tensor([[20992229376.]], grad_fn=<MmBackward>)
tensor([[1720