In [121]:
import torch
import numpy as np
import t3f
import tensorflow as tf

In [112]:
class TensorTrainMatrix:
    """A matrix stored in Tensor-Train (TT) format as a tuple of 4-D torch cores.

    Core ``k`` has shape ``[r_{k-1}, n_k, m_k, r_k]``. The represented dense
    matrix has shape ``[prod(n_k), prod(m_k)]``, and the boundary ranks
    ``r_0`` and ``r_d`` must both be 1.
    """

    def __init__(self, tt_cores, shape=None, tt_ranks=None, convert_to_tensors=True):
        """Build a TT-matrix from its cores.

        Args:
            tt_cores: iterable of 4-D array-likes (numpy arrays or torch tensors),
                core ``k`` shaped ``[r_{k-1}, n_k, m_k, r_k]``.
            shape: currently ignored; the shape is inferred from the cores.
            tt_ranks: currently ignored; the ranks are inferred from the cores.
            convert_to_tensors: if True, every core is converted with
                ``torch.Tensor`` (i.e. to the default float dtype).

        Raises:
            ValueError: if no cores are given, a core is not 4-D, adjacent
                cores disagree on the shared rank, or a boundary rank is not 1.
        """
        tt_cores = list(tt_cores)
        if convert_to_tensors:
            # torch.Tensor(...) yields the default float dtype (float32 unless
            # changed), so float64 numpy cores are down-cast here.
            tt_cores = [torch.Tensor(core) for core in tt_cores]
        self._validate_cores(tt_cores)

        self._tt_cores = tuple(tt_cores)

        # Row modes n_k and column modes m_k, one entry per core.
        self._raw_shape = [[int(core.shape[1]) for core in self._tt_cores],
                           [int(core.shape[2]) for core in self._tt_cores]]

        self._shape = [int(np.prod(self._raw_shape[0])), int(np.prod(self._raw_shape[1]))]

        # [r_0, r_1, ..., r_d]; _validate_cores guarantees r_d == 1.
        self._ranks = [int(core.shape[0]) for core in self._tt_cores] + [1]

        self._ndims = len(self._tt_cores)

    @staticmethod
    def _validate_cores(tt_cores):
        """Raise ValueError unless `tt_cores` is a non-empty, rank-consistent list of 4-D cores."""
        if not tt_cores:
            raise ValueError("TensorTrainMatrix requires at least one TT-core")
        for k, core in enumerate(tt_cores):
            if len(core.shape) != 4:
                raise ValueError(
                    "TT-core %d must be 4-D [r_prev, n, m, r_next], got shape %s"
                    % (k, tuple(core.shape)))
        left_ranks = [int(core.shape[0]) for core in tt_cores]
        right_ranks = [int(core.shape[3]) for core in tt_cores]
        if left_ranks[0] != 1 or right_ranks[-1] != 1:
            raise ValueError(
                "boundary TT-ranks must be 1, got r_0=%d and r_d=%d"
                % (left_ranks[0], right_ranks[-1]))
        for k in range(1, len(tt_cores)):
            if right_ranks[k - 1] != left_ranks[k]:
                raise ValueError(
                    "rank mismatch between TT-cores %d and %d: %d != %d"
                    % (k - 1, k, right_ranks[k - 1], left_ranks[k]))

    @property
    def tt_cores(self):
        """A tuple of TT-cores.
        Returns:
          A tuple of 4d tensors of shape
            `[r_k-1, n_k, m_k, r_k]`
        """
        return self._tt_cores

    @property
    def raw_shape(self):
        """``[[n_1, ..., n_d], [m_1, ..., m_d]]``: row and column modes per core."""
        return self._raw_shape

    @property
    def shape(self):
        """``[prod(n_k), prod(m_k)]``: shape of the dense matrix."""
        return self._shape

    @property
    def ranks(self):
        """``[r_0, r_1, ..., r_d]``: TT-ranks, with ``r_0 == r_d == 1``."""
        return self._ranks

    @property
    def ndims(self):
        """Number of TT-cores."""
        return self._ndims

    def to(self, *args, **kwargs):
        """Move and/or cast every TT-core with ``torch.Tensor.to`` and return ``self``.

        All arguments are forwarded, so ``ttm.to('cuda')``,
        ``ttm.to(torch.float64)`` and ``ttm.to(device=...)`` all work.
        ``Tensor.to`` returns a tensor instead of modifying it in place, so the
        stored cores are replaced by the returned tensors.
        """
        self._tt_cores = tuple(core.to(*args, **kwargs) for core in self._tt_cores)
        return self

    def full(self):
        """Contract the TT-cores into the dense matrix.

        Returns:
            A torch tensor of shape ``self.shape`` == ``[prod(n_k), prod(m_k)]``.
        """
        num_dims = self.ndims
        ranks = self.ranks
        raw_shape = self.raw_shape

        # Left-to-right sweep: after step i, `res` holds the contraction of
        # cores 0..i with elements ordered (n_0, m_0, ..., n_i, m_i, r_{i+1}).
        # reshape (not view) so non-contiguous cores are handled too.
        res = self.tt_cores[0]
        for i in range(1, num_dims):
            res = res.reshape(-1, ranks[i])
            curr_core = self.tt_cores[i].reshape(ranks[i], -1)
            res = torch.matmul(res, curr_core)

        # Split the interleaved (n_0, m_0, n_1, m_1, ...) axes, then move all
        # row modes in front of all column modes before flattening.
        intermediate_shape = []
        for n_k, m_k in zip(raw_shape[0], raw_shape[1]):
            intermediate_shape.extend((n_k, m_k))
        res = res.reshape(*intermediate_shape)
        permutation = list(range(0, 2 * num_dims, 2)) + list(range(1, 2 * num_dims, 2))
        res = res.permute(*permutation)
        return res.reshape(*self.shape)
    
    

In [113]:
# Fixed seed so the random TT-cores -- and every matrix printed below -- are
# reproducible under Restart & Run All. Cores are [r_{k-1}, n_k, m_k, r_k]
# with TT-ranks [1, 3, 1].
SEED = 0
rng = np.random.RandomState(SEED)
tt_cores = [rng.rand(1, 2, 2, 3), rng.rand(3, 2, 2, 1)]

In [114]:
# Wrap the numpy cores in the torch TT-matrix implementation defined above.
ttm = TensorTrainMatrix(tt_cores=tt_cores)

In [115]:
# Dense shape [prod(n_k), prod(m_k)] of the represented matrix.
ttm.shape

[4, 4]

In [116]:
# Dense matrix reconstructed from the TT-cores by the torch implementation.
res = ttm.full()

In [117]:
# Torch result; compare against the t3f reference computed below.
res

tensor([[ 0.5880,  0.5528,  0.9631,  1.0056],
        [ 0.8675,  0.6784,  1.3647,  0.9304],
        [ 1.3756,  1.3364,  0.9143,  1.0003],
        [ 1.8526,  1.1428,  1.3255,  0.9342]])

In [118]:
# Same cores in t3f, used as the reference implementation.
mat = t3f.TensorTrain(tt_cores=tt_cores)

In [119]:
# Only builds a symbolic TF graph tensor (see the repr below); it is
# evaluated with a session further down.
t3f.full(mat)

<tf.Tensor 'Reshape_3:0' shape=(4, 4) dtype=float64>

In [122]:
# TF1-style graph execution: a session is needed to evaluate t3f.full(mat).
# NOTE(review): tf.Session is TF 1.x API (tf.compat.v1.Session in TF 2.x), and
# this session is never closed -- call sess.close() when finished.
sess = tf.Session()

In [123]:
# Dense reference from t3f. Assert that the torch implementation agrees instead
# of comparing printed matrices by eye; the torch cores were converted with
# torch.Tensor (default float dtype, float32 unless changed), hence the tolerance.
reference = sess.run(t3f.full(mat))
np.testing.assert_allclose(ttm.full().numpy(), reference, rtol=1e-5, atol=1e-6)
reference

array([[0.58804406, 0.55276174, 0.96309014, 1.00555499],
       [0.86747826, 0.67840304, 1.36465704, 0.93035856],
       [1.37556302, 1.33642596, 0.91427059, 1.00026065],
       [1.85256753, 1.14279743, 1.32545634, 0.93420146]])