In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# project source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
# Add the parent of the kernel's working directory to the module search path
# (presumably so project modules one level up can be imported — confirm).
sys.path.append('..')

In [2]:
import string
import time

import torch
import torch.nn as nn


In [9]:
class TRL(nn.Module):
    """Tensor Regression Layer with a Tucker-factorized weight tensor.

    The full weight ``W`` has shape ``w_size``: the non-ignored modes of
    ``input_size`` (in order) followed by the modes of ``output``. It is stored
    in Tucker form -- a core of shape ``rank`` plus one factor matrix of shape
    ``(rank[i], w_size[i])`` per mode -- and is rebuilt from those parameters on
    every access of ``w``, so each forward pass uses the current parameter values.

    ``forward(x)`` contracts every non-ignored mode of ``x`` against ``W``; the
    ignored modes (e.g. a batch axis) are passed through to the output.

    Args:
        input_size: int or sequence of ints; full input shape, ignored modes included.
        output: int or sequence of ints; shape of the output modes.
        rank: int or sequence of ints; Tucker rank with one entry per mode of ``W``
            (number of contracted input modes + ``len(output)``).
        ignore_modes: int or sequence of ints; indices of input modes that are not
            contracted. Defaults to ``(0,)``.
        bias: if True, a learnable bias of shape ``output`` is added in ``forward``.
        device: device on which the parameters are created.

    Raises:
        ValueError: if an entry of ``ignore_modes`` is outside ``[0, len(input_size))``,
            if ``rank`` does not have exactly one entry per mode of ``W``, or if an
            einsum formula would need more than 52 index letters.
    """

    # Distinct einsum subscript letters; their count (52) caps the number of indices.
    _ALPHABET = string.ascii_letters

    def __init__(self, input_size, output, rank, ignore_modes=(0,), bias=True, device='cpu'):
        super().__init__()
        self.device = device
        self.bias = bias

        self.input_size = self._as_tuple(input_size)
        self.output = self._as_tuple(output)
        self.rank = self._as_tuple(rank)
        self.ignore_modes = self._as_tuple(ignore_modes)

        n_in = len(self.input_size)
        bad_modes = [m for m in self.ignore_modes if not 0 <= m < n_in]
        if bad_modes:
            raise ValueError(
                f'ignore_modes {bad_modes} out of range for an input with {n_in} modes')

        # W's modes: the contracted (non-ignored) input modes, then the output modes.
        contracted = tuple(size for i, size in enumerate(self.input_size)
                           if i not in self.ignore_modes)
        self.w_size = contracted + self.output
        if len(self.rank) != len(self.w_size):
            raise ValueError(
                f'rank has {len(self.rank)} entries but the weight tensor has '
                f'{len(self.w_size)} modes {self.w_size}')

        if self.bias:
            self.b = nn.Parameter(torch.randn(self.output, device=device))
        else:
            self.b = None

        # Tucker factorization of W: a core of shape `rank` and one factor per mode.
        # NOTE(review): with N(0, 1) init each entry of W is a sum of prod(rank)
        # products of standard normals, so its std is about sqrt(prod(rank));
        # consider a scaled initialization before training.
        self.core = nn.Parameter(torch.randn(self.rank, device=device))
        self.factors = nn.ParameterList(
            [nn.Parameter(torch.randn(r, size, device=device))
             for r, size in zip(self.rank, self.w_size)])

        self.w_formula = self._tucker_formula(len(self.rank))
        self.out_formula = self._contraction_formula()

    @staticmethod
    def _as_tuple(value):
        """Return an int as a 1-tuple and any other iterable as a tuple."""
        return (value,) if isinstance(value, int) else tuple(value)

    @classmethod
    def _letters(cls, count):
        """Return `count` distinct einsum subscript letters, or raise ValueError."""
        if count > len(cls._ALPHABET):
            raise ValueError(
                f'need {count} einsum indices but only {len(cls._ALPHABET)} letters exist')
        return cls._ALPHABET[:count]

    def _tucker_formula(self, order):
        """einsum formula rebuilding W from the core and its `order` factors,
        e.g. 'abc,ad,be,cf->def' for order 3."""
        letters = self._letters(2 * order)
        core_idx, mode_idx = letters[:order], letters[order:]
        operands = [core_idx] + [c + m for c, m in zip(core_idx, mode_idx)]
        return ','.join(operands) + '->' + mode_idx

    def _contraction_formula(self):
        """einsum formula contracting x with W over the non-ignored input modes,
        e.g. 'abc,bcd->ad' for a 3-mode input, ignore_modes=(0,), one output mode."""
        n_in = len(self.input_size)
        letters = self._letters(n_in + len(self.output))
        x_idx, out_idx = letters[:n_in], letters[n_in:]
        kept = ''.join(c for i, c in enumerate(x_idx) if i in self.ignore_modes)
        summed = ''.join(c for i, c in enumerate(x_idx) if i not in self.ignore_modes)
        return f'{x_idx},{summed}{out_idx}->{kept}{out_idx}'

    @property
    def w(self):
        """Full weight tensor of shape `w_size`, rebuilt from the current core and
        factors so it always reflects parameter updates and stays in the autograd graph."""
        return torch.einsum(self.w_formula, self.core, *self.factors)

    def forward(self, x):
        """Contract `x` with W and add the bias when enabled.

        `x` must match `input_size` on the contracted modes; sizes of the ignored
        modes are not constrained by W. Returns a tensor of shape
        (x's ignored-mode sizes) + `output`, on the device of `x` and the parameters.
        """
        out = torch.einsum(self.out_formula, x, self.w)
        if self.b is not None:
            out = out + self.b
        return out

In [11]:
# Time one construction + forward pass of a TRL on a 6-mode input, contracting
# the last three modes and keeping the first three. Falls back to CPU if no GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

st = time.perf_counter()
trl = TRL(input_size=(5, 14, 14, 16, 16, 3), output=(16, 16, 3),
          rank=(16, 16, 3, 16, 16, 3), ignore_modes=(0, 1, 2), device=DEVICE).to(DEVICE)
t = torch.randn(5, 14, 14, 16, 16, 3, device=DEVICE)
output = trl(t)
if DEVICE == 'cuda':
    # CUDA kernels run asynchronously; wait for them before stopping the clock,
    # otherwise only the kernel launch time is measured.
    torch.cuda.synchronize()
elapsed = time.perf_counter() - st
print(elapsed)
output.shape

0.0439763069152832


torch.Size([5, 14, 14, 16, 16, 3])