In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np

In [104]:
class Conv2dMod(nn.Module):
    """Thin ``nn.Module`` wrapper around a single 3->3 channel, 3x3 ``nn.Conv2d``.

    Stride and padding are left at their ``nn.Conv2d`` defaults (1 and 0), so each
    spatial dimension shrinks by 2 (e.g. 227x227 -> 225x225).
    """

    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        return self.conv2d(x)


conv2d = Conv2dMod()
#conv2d = conv2d.cuda()

In [105]:
class MaxPool2dMod(nn.Module):
    """Thin ``nn.Module`` wrapper around a 2x2 ``nn.MaxPool2d``.

    The stride defaults to the kernel size, so each spatial dimension is halved
    and floored (e.g. 227x227 -> 113x113).
    """

    def __init__(self):
        super().__init__()
        self.maxpool2d = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        return self.maxpool2d(x)


maxpool2d = MaxPool2dMod()
#maxpool2d = maxpool2d.cuda()

In [216]:
class Threemm(nn.Module):
    """Chained product of four row-stacked square blocks ("3mm").

    ``X`` is split row-wise into four consecutive blocks B, C, E, F of
    ``n = X.shape[0] // 4`` rows each. The result is ``G = (B @ C) @ (E @ F)``.
    For the matmuls to be valid each block must be square, i.e. ``X`` has shape
    ``(4n, n)``. Rows beyond ``4n`` (when ``X.shape[0]`` is not a multiple of 4)
    are ignored.
    """

    def __init__(self):
        super(Threemm, self).__init__()

    def forward(self, X):
        # Derive the block size from the input instead of hard-coding 1024, so any
        # (4n, n) input works. For the 4096x1024 test input this is still 1024.
        n = X.shape[0] // 4
        B = X[0:n, :]
        C = X[n:2 * n, :]
        E = X[2 * n:3 * n, :]
        # Named F_blk (not F) so it does not shadow the torch.nn.functional alias.
        F_blk = X[3 * n:4 * n, :]
        A = torch.matmul(B, C)      # A = BC
        D = torch.matmul(E, F_blk)  # D = EF
        G = torch.matmul(A, D)      # G = AD
        return G


threemm = Threemm()
#threemm = threemm.cuda()

In [217]:
class Alexnet(nn.Module):
    """Wrapper around torchvision's AlexNet loaded with ImageNet-pretrained weights.

    The wrapped network is stored as ``self.alexnet``, so parameter names are
    prefixed with ``alexnet.``.
    """

    def __init__(self):
        super(Alexnet, self).__init__()
        # ``pretrained=True`` is deprecated since torchvision 0.13 in favour of the
        # weights enums. IMAGENET1K_V1 is the checkpoint ``pretrained=True`` loads.
        # Fall back to the legacy keyword on torchvision versions without the enum.
        weights_enum = getattr(models, "AlexNet_Weights", None)
        if weights_enum is None:
            self.alexnet = models.alexnet(pretrained=True)
        else:
            self.alexnet = models.alexnet(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, x):
        return self.alexnet(x)


# Switch to inference mode so AlexNet's Dropout layers are inactive and the
# forward pass on the test image is deterministic. nn.Module.eval() returns self.
alexnet = Alexnet().eval()
#alexnet = alexnet.cuda()

In [218]:
class Covariance(nn.Module):
    """Sample covariance matrix of the columns of ``X``.

    Rows are treated as observations and columns as variables. Each column is
    centred by its mean. For an ``(n, m)`` input the result is the ``(m, m)``
    matrix ``Cov(X) = M^T M / (n - 1)`` with ``M = X - mean(X, dim=0)``.
    Requires ``n >= 2``.
    """

    def __init__(self):
        super(Covariance, self).__init__()

    def forward(self, X):
        n = X.shape[0]
        factor = 1.0 / (n - 1)  # Bessel's correction (unbiased estimator)
        # Centre each column (variable) across the rows (observations).
        M = X - torch.mean(X, dim=0)
        # Contract over the observation axis (rows): M^T M, not M M^T, which would
        # pair column-centred data with row-wise products and is not a covariance.
        cov = factor * torch.matmul(M.t(), M)
        return cov


covariance = Covariance()
# covariance = covariance.cuda()

In [219]:
class Correlation(nn.Module):
    """Pearson correlation matrix of the columns of ``X``.

    Rows are observations and columns are variables. Each column is standardised
    (centred, then divided by its sample standard deviation). For an ``(n, m)``
    input the result is the ``(m, m)`` matrix ``Z^T Z / (n - 1)``.

    A constant column has zero standard deviation and yields NaN/inf entries.
    """

    def __init__(self):
        super(Correlation, self).__init__()

    def forward(self, X):
        n = X.shape[0]
        # torch.std is unbiased (divides by n - 1) by default, so the normalising
        # factor must also be 1 / (n - 1) for the diagonal to come out as 1.
        factor = 1.0 / (n - 1)
        means_col = torch.mean(X, dim=0)
        stds_col = torch.std(X, dim=0)
        Z = (X - means_col) / stds_col
        # Contract over the observation axis (rows): Z^T Z gives the (m, m) matrix.
        corr = factor * torch.matmul(Z.t(), Z)
        return corr


correlation = Correlation()
# correlation = correlation.cuda()

In [220]:
# test conv2d, maxpool2d, alexnet
M = 227  # height and width of the random test image
# A seeded local generator makes the test image (and so the outputs computed from
# it) reproducible on Restart & Run All, without touching the global RNG state.
img_random = torch.randn((1, 3, M, M), generator=torch.Generator().manual_seed(0))

In [221]:
output_conv2d = conv2d(img_random)
output_conv2d.size()  # 3x3 kernel, no padding: 227 -> 225 per spatial dim

torch.Size([1, 3, 225, 225])

In [222]:
# Trace conv2d on the test image (passed as a 1-tuple of inputs) and write conv2d.onnx.
torch.onnx.export(conv2d, (img_random,), 'conv2d.onnx')

In [223]:
output_maxpool2d = maxpool2d(img_random)
output_maxpool2d.size()  # 2x2 pooling, stride 2: floor(227 / 2) = 113 per spatial dim

torch.Size([1, 3, 113, 113])

In [224]:
# Trace maxpool2d on the test image (passed as a 1-tuple of inputs) and write maxpool2d.onnx.
torch.onnx.export(maxpool2d, (img_random,), 'maxpool2d.onnx')

In [225]:
output_alexnet = alexnet(img_random)
output_alexnet.size()  # one logit per class: (batch, 1000)

torch.Size([1, 1000])

In [226]:
# Trace alexnet on the test image (passed as a 1-tuple of inputs) and write alexnet.onnx.
torch.onnx.export(alexnet, (img_random,), 'alexnet.onnx')

In [227]:
# test 3mm
N = 1024
# Four row-stacked N x N blocks consumed by threemm. A seeded local generator keeps
# the input reproducible without touching the global RNG state.
X = torch.randn((4*N, N), generator=torch.Generator().manual_seed(0))

In [228]:
output_threemm = threemm(X)
output_threemm.size()  # (N, N): product of N x N blocks

torch.Size([1024, 1024])

In [229]:
# Trace threemm on the stacked input (passed as a 1-tuple of inputs) and write threemm.onnx.
torch.onnx.export(threemm, (X,), 'threemm.onnx')

In [230]:
# test covariance and correlation
N = 1024
# Square N x N input (rows = observations, columns = variables). A seeded local
# generator keeps it reproducible without touching the global RNG state.
X = torch.randn((N,N), generator=torch.Generator().manual_seed(0))

In [231]:
output_covariance = covariance(X)
output_covariance.size()  # (N, N) for the square N x N input

torch.Size([1024, 1024])

In [232]:
# Trace covariance on X (passed as a 1-tuple of inputs) and write covariance.onnx.
torch.onnx.export(covariance, (X,), 'covariance.onnx')

In [233]:
output_correlation = correlation(X)
output_correlation.size()  # (N, N) for the square N x N input

torch.Size([1024, 1024])

In [234]:
# Trace correlation on X (passed as a 1-tuple of inputs) and write correlation.onnx.
torch.onnx.export(correlation, (X,), 'correlation.onnx')