In [1]:
import torch
import numpy as np
import torch.nn as nn
import t3nsor as t3

In [2]:
class TensorRing(object):
    def __init__(self, tr_cores, shape=None, tr_ranks=None, convert_to_tensors=True):
        #tr_cores = list(tr_cores)
        if convert_to_tensors:
            for i in range(len(tr_cores)):
                tr_cores[i] = torch.Tensor(tr_cores[i])

        self._tr_cores = tr_cores

        if len(self._tr_cores[0].shape) == 4:
            self._is_tr_matrix = True
        else:
            self._is_tr_matrix = False

        if self._is_tr_matrix:
            self._raw_shape = [[tr_core.shape[1] for tr_core in self._tr_cores],
                               [tr_core.shape[2] for tr_core in self._tr_cores]]
            self._shape = [int(np.prod(self._raw_shape[0])), int(np.prod(self._raw_shape[1]))]
            self._ndims = len(self._raw_shape[0])

        else:
            self._raw_shape = [tr_core.shape[1] for tr_core in self._tr_cores]
            self._shape = [tr_core.shape[1] for tr_core in self._tr_cores]
            self._ndims = len(self._raw_shape)

        self._ranks = [tr_core.shape[0] for tr_core in self._tr_cores] + [1, ]
        self._is_parameter = False
        self._parameter = None
        self._dof = np.sum([np.prod(list(tr_core.shape)) for tr_core in self._tr_cores])
        self._total = np.prod(self._shape)
        
    @property    
    def tr_cores(self):
        """A list of TR-cores.
        Returns:
          A list of 4d or 5d tensors.
        """
        return self._tr_cores   
    
    @property
    def is_tr_matrix(self):
        return self._is_tr_matrix
        
    def full(self):
        num_dims = self._ndims
        ranks = self._ranks
        shape = self._shape
        raw_shape = self._raw_shape
        res = self.tr_cores[0]

        for core_idx in range(1, num_dims):
            curr_core = self.tr_cores[core_idx]
#             print('loop', core_idx, curr_core.shape)
            res = torch.tensordot(res, curr_core, dims=[[-1], [0]])
    
        res = torch.einsum('i...i->...', res) # trace  
        print(res.shape)
        
        if self.is_tr_matrix:
            transpose = []
            for i in range(0, 2 * num_dims, 2):
                transpose.append(i)
            for i in range(1, 2 * num_dims, 2):
                transpose.append(i)
            res = res.permute(*transpose)
        print(transpose)
        if self.is_tr_matrix:
            res = res.contiguous().view(*shape)
        else:
            res = res.view(*shape)
        return res
    
    def to_parameter(self):
        new_cores = []
        for core in self.tr_cores:
            core = nn.Parameter(core)
            core.is_tr = True
            new_cores.append(core)

        tr_p = TensorRing(new_cores, convert_to_tensors=False)
        tr_p._parameter = nn.ParameterList(tr_p.tr_cores)        
        tr_p._is_parameter = True
        return tr_p
    

In [3]:
r = 1

cores = [np.random.normal(size=(r, 4, 3, r)), 
        np.random.normal(size=(r, 5, 4, r)),
        np.random.normal(size=(r, 7, 3, r))
        ]
x = TensorRing(cores)


x2 = t3.TensorTrain(cores)

In [4]:
res = x.full()

torch.Size([4, 3, 5, 4, 7, 3])
[0, 2, 4, 1, 3, 5]


In [5]:
res2 = x2.full()

In [6]:
res - res2

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [7]:
x.is_tr_matrix

True

In [8]:
class TREmbedding(nn.Module):
    def __init__(self,
                 init=None,
                 shape=None,
                 voc_size=None,
                 emb_size=None,
                 auto_shapes=None,
                 auto_shape_mode='ascending',
                 auto_shape_criterion='entropy',
                 d=3,
                 tr_rank=8,
                 batch_dim_last=None,
                 padding_idx=None):

        super(TREmbedding, self).__init__()

        if auto_shapes:
            voc_quantization = t3.utils.suggest_shape(
                voc_size, d=d, criterion=auto_shape_criterion, mode=auto_shape_mode)
            emb_quantization = t3.utils.auto_shape(
                emb_size, d=d, criterion=auto_shape_criterion, mode=auto_shape_mode)

            shape = [voc_quantization, emb_quantization]
            self.shape = shape
            
        else:
            self.shape = shape

        if init is None:
            if shape is None:
                raise ValueError('if init is not provided,'
                                 ' please specify shape')
        else:
            self.shape = init.raw_shape
        

        if init is None:
            init = t3.glorot_initializer_tr(self.shape, tr_rank=tr_rank)

        self.tr_matrix = init.to_parameter()
        self.parameters = self.tr_matrix.parameter

        # for p in self.parameters():
        #    p.name = 'tt_core'

        self.batch_dim_last = batch_dim_last
        self.voc_size = int(np.prod(self.shape[0]))
        self.emb_size = int(np.prod(self.shape[1]))

        self.voc_quant = self.shape[0]
        self.emb_quant = self.shape[1]

        self.padding_idx = padding_idx

    def forward(self, x):

        xshape = list(x.shape)
        xshape_new = xshape + [self.emb_size, ]
        x = x.view(-1)

        # x_ind = t3.ind2sub(self.voc_quant, x)
        # rows = t3.gather_rows(self.tt_matrix, x_ind)

        # rows = rows.view(x.shape[0], -1)

        full = self.tr_matrix.full()
        rows = full[x]

        if self.padding_idx is not None:
            rows = torch.where(x.view(-1, 1) != self.padding_idx, rows, torch.zeros_like(rows))

        rows = rows.view(*xshape_new)

        return rows.to(x.device)

In [9]:
emb = TREmbedding(shape=[10000, 64], voc_size=10000,
                 emb_size=64, auto_shapes=True)

In [10]:
inp = np.random.randint(0, 10000, 200)
inp = torch.LongTensor(inp)

In [12]:
out = emb.forward(inp)

torch.Size([20, 4, 20, 4, 25, 4])
[0, 2, 4, 1, 3, 5]


In [13]:
out.shape

torch.Size([200, 64])