In [1]:
import os, sys
sys.path.insert(0, "..")
from __future__ import annotations
import torch
import math
from torch import Tensor
from torch.nn import Parameter
from HTorch.MCTensor.MCOpBasics import _Renormalize, _Simple_renormalize_old, _Grow_ExpN, _AddMCN,  _ScalingN,\
    _DivMCN, _MultMCN, _exp, _pow, _square, _sinh, _cosh, _tanh, _log, _exp, _sqrt, \
    _softmax, _log_softmax, _cross_entropy, _mse_loss, _layer_norm, _atanh, _log1p_standard, \
    _clamp, _norm, _sum, _mean
from HTorch.MCTensor.MCOpMatrix import _Dot_MCN, _Dot_MCN_M, _MV_MC_T, _MV_T_MC, _MV_MC_M_M, _MM_MC_T, _MM_T_MC, _MM_MC_MC, \
    _BMM_MC_T, _BMM_T_MC, _BMM_MC_MC, _4DMM_T_MC, _4DMM_MC_T, _4DMM_MC_MC
from typing import Union, List
import functools

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Registry consulted by MCTensor.__torch_function__: torch function -> MCTensor override.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Decorator factory: register the decorated function as MCTensor's override of ``torch_function``.

    The decorated function is stored in ``HANDLED_FUNCTIONS`` under
    ``torch_function`` and returned unchanged, so it stays callable directly.
    """
    def register(func):
        HANDLED_FUNCTIONS.update({torch_function: func})
        return func

    return functools.wraps(torch_function)(register)

In [3]:
class MCTensor(Tensor):
    """Multi-component tensor: a ``torch.Tensor`` subclass carrying extra components.

    The tensor's own storage holds the leading component of every entry, and
    ``res`` (shape ``self.shape + (nc - 1,)``) holds the remaining ``nc - 1``
    components. ``.tensor`` concatenates both into one plain tensor with a
    trailing component dim of size ``nc``. That is the layout the MCOp helpers
    take and return. NOTE(review): the helpers presumably treat an entry as the
    unevaluated sum of its components; confirm against MCOpBasics.

    Dispatch (``__torch_function__``): functions registered in
    ``HANDLED_FUNCTIONS`` receive the MCTensors themselves. Any other torch
    function is applied to the leading components, then again to the
    residuals (each MCTensor positional argument is swapped for its ``res``).
    """

    @staticmethod
    def __new__(cls, *args, nc=1, **kwargs):
        ret = super().__new__(cls, *args, **kwargs)
        ret._nc = nc
        # residual components start out as zeros
        ret.res = torch.zeros(ret.size() + (nc-1,), dtype=ret.dtype, device=ret.device)
        return ret

    def __init__(self, *args, nc=1, **kwargs):
        # NOTE(review): repeats the allocation already done in __new__.
        self._nc = nc
        self.res = torch.zeros(self.size() + (nc-1,), dtype=self.dtype, device=self.device)

    @staticmethod
    def wrap_tensor_to_mctensor(tensor: Tensor) -> MCTensor:
        """Build an MCTensor from a plain tensor whose last dim holds the nc components (involves copying)."""
        ret = MCTensor(tensor[..., 0], nc=tensor.size(-1))
        ret.res.data.copy_(tensor[..., 1:].data)
        return ret

    @staticmethod
    def wrap_tensor_and_res_to_mctensor(val: Tensor, res: Tensor) -> MCTensor:
        """Build an MCTensor from leading component ``val`` and residual ``res`` (involves copying)."""
        ret = MCTensor(val, nc=res.shape[-1] + 1)
        ret.res.data.copy_(res.data)
        return ret

    @property
    def tensor(self):
        """Plain tensor of shape ``self.shape + (nc,)``: leading component first, then ``res``."""
        return torch.cat([self.as_subclass(Tensor).view(*self.shape, 1), self.res], -1)

    @property
    def T(self):
        """Swap of dims 0 and 1 (note: not a reversal of all dims)."""
        return torch.transpose(self, 0, 1)

    @property
    def nc(self):
        """Number of components per entry (1 + size of the residual dim)."""
        return self._nc

    def normalize_(self, simple=False):
        """Renormalize the components in place (``_Simple_renormalize_old`` if ``simple`` else ``_Renormalize``)."""
        if simple:
            normalized_self = _Simple_renormalize_old(self.tensor, self.nc)
        else:
            normalized_self = _Renormalize(self.tensor, self.nc)
        self.data.copy_(normalized_self[..., 0].data)
        self.res.data.copy_(normalized_self[..., 1:].data)

    def __repr__(self, *args, **kwargs):
        return "{}, nc={}".format(super().__repr__(), self.nc)

    # Binary operators: when no operand carries residual components (nc == 1),
    # the plain Tensor operator is used. Otherwise the torch function is called,
    # so the registered MC override runs via __torch_function__.
    # In-place variants compute out of place and copy the result back into self.

    def __add__(self, other) -> MCTensor:
        ''' add self with other'''
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__add__(other)
        else:
            obj = torch.add(self, other)
        return obj

    def add(self, other):
        return self + other

    def add_(self, other):
        return self.copy_(self + other)

    def __radd__(self, other) -> MCTensor:
        ''' add self with other'''
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__radd__(other)
        else:
            obj = torch.add(self, other)
        return obj

    def __sub__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__sub__(other)
        else:
            obj = torch.add(self, -other)
        return obj

    def sub(self, other):
        return self - other

    def sub_(self, other):
        return self.copy_(self - other)

    def __rsub__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__rsub__(other)
        else:
            obj = torch.add(other, -self)
        return obj

    def __mul__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__mul__(other)
        else:
            obj = torch.mul(self, other)
        return obj

    def mul(self, other):
        return self * other

    def mul_(self, other):
        return self.copy_(self * other)

    def __rmul__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__rmul__(other)
        else:
            obj = torch.mul(other, self)
        return obj

    def __truediv__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__truediv__(other)
        else:
            obj = torch.div(self, other)
        return obj

    def div(self, other):
        return self / other

    def div_(self, other):
        return self.copy_(self / other)

    def __rtruediv__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__rtruediv__(other)
        else:
            obj = torch.div(other, self)
        return obj

    def mv(self, other):
        return torch.mv(self, other)

    def mm(self, other):
        return torch.mm(self, other)

    def bmm(self, other):
        return torch.bmm(self, other)

    def matmul(self, other):
        return torch.matmul(self, other)

    def __matmul__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__matmul__(other)
        else:
            obj = torch.matmul(self, other)
        return obj

    def __rmatmul__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__rmatmul__(other)
        else:
            obj = torch.matmul(other, self)
        return obj

    def __pow__(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            obj = super().__pow__(other)
        else:
            obj = torch.pow(self, other)
        return obj

    @staticmethod
    def _has_ellipsis(key):
        """True if ``key`` is a tuple index containing ``...``.

        Uses an identity check: ``Ellipsis in key`` would compare with ``==``,
        which is ill-defined for tensor indices.
        """
        return isinstance(key, tuple) and any(k is Ellipsis for k in key)

    def __getitem__(self, key):
        # Append a full slice after an Ellipsis so it cannot swallow the component dim.
        if MCTensor._has_ellipsis(key):
            key = key + (slice(None, None, None),)  # case of [..., v]
        val = self.tensor[key]
        return MCTensor.wrap_tensor_to_mctensor(val)

    def __setitem__(self, key, value: Union[MCTensor, Tensor, int, float]):
        """Assign into the selected entries.

        An MCTensor value contributes all of its components. Any other value
        sets the leading component and zeroes the residual.
        """
        if MCTensor._has_ellipsis(key):
            res_key = (*key, slice(None, None, None))
        else:
            res_key = key
        if isinstance(value, MCTensor):
            super().__setitem__(key, value.as_subclass(Tensor))
            # Assign through the index rather than calling copy_ on res[res_key]:
            # advanced (tensor/bool) indexing returns a copy, so a copy_ on it was lost.
            self.res.data[res_key] = value.res.data
        else:
            super().__setitem__(key, value)
            self.res[res_key] = 0

    def dot(self, other) -> MCTensor:
        if self.nc == 1 and (not isinstance(other, MCTensor) or other.nc == 1):
            if isinstance(other, MCTensor):
                other = other.as_subclass(Tensor)
            obj = super().as_subclass(Tensor).dot(other)
        else:
            obj = torch.dot(self, other)
        return obj

    # Method forms of the torch functions (dispatched through __torch_function__).
    # Extra keyword arguments accepted by sum/mean/norm are ignored.

    def abs(self) -> MCTensor:
        return torch.abs(self)

    def sum(self, dim=None, keepdim=False, **kw) -> MCTensor:
        return torch.sum(self, dim=dim, keepdim=keepdim)

    def mean(self, dim=None, keepdim=False, **kw) -> MCTensor:
        return torch.mean(self, dim=dim, keepdim=keepdim)

    def norm(self, dim=None, keepdim=False, p=2, **kw) -> MCTensor:
        return torch.norm(self, dim=dim, keepdim=keepdim, p=p)

    def exp(self) -> MCTensor:
        return torch.exp(self)

    def exp_(self) -> MCTensor:
        return self.copy_(torch.exp(self))

    def log(self) -> MCTensor:
        return torch.log(self)

    def log_(self) -> MCTensor:
        return self.copy_(torch.log(self))

    def square(self) -> MCTensor:
        return torch.square(self)

    def square_(self) -> MCTensor:
        return self.copy_(torch.square(self))

    def sqrt(self) -> MCTensor:
        return torch.sqrt(self)

    def sqrt_(self) -> MCTensor:
        return self.copy_(torch.sqrt(self))

    def sinh(self) -> MCTensor:
        return torch.sinh(self)

    def sinh_(self) -> MCTensor:
        return self.copy_(torch.sinh(self))

    def cosh(self) -> MCTensor:
        return torch.cosh(self)

    def cosh_(self) -> MCTensor:
        return self.copy_(torch.cosh(self))

    def tanh(self) -> MCTensor:
        return torch.tanh(self)

    def tanh_(self) -> MCTensor:
        return self.copy_(torch.tanh(self))

    def atanh(self) -> MCTensor:
        return torch.atanh(self)

    def atanh_(self) -> MCTensor:
        return self.copy_(torch.atanh(self))

    def clamp_min(self, min=None) -> MCTensor:
        return torch.clamp_min(self, min=min)

    def clamp_min_(self, min=None) -> MCTensor:
        return self.copy_(torch.clamp_min(self, min=min))

    def clamp_max(self, max=None) -> MCTensor:
        return torch.clamp_max(self, max=max)

    def clamp_max_(self, max=None) -> MCTensor:
        return self.copy_(torch.clamp_max(self, max=max))

    def clone(self) -> MCTensor:
        return torch.clone(self)

    @staticmethod
    def _resolve_dim(args, kwargs, rank):
        """Turn a negative ``dim`` (first positional arg or ``dim=`` kwarg) into a non-negative one.

        ``res`` has one more (trailing) dim than the MCTensor. The generic
        dispatch path applies the same call to ``res``, so a negative dim must
        be resolved against the logical rank first.
        """
        if args and isinstance(args[0], int) and args[0] < 0:
            args = (args[0] + rank,) + tuple(args[1:])
        if isinstance(kwargs.get('dim'), int) and kwargs['dim'] < 0:
            kwargs = dict(kwargs, dim=kwargs['dim'] + rank)
        return args, kwargs

    def unsqueeze(self, *args, **kwargs) -> MCTensor:
        # the valid range for an unsqueeze dim is [-(rank + 1), rank]
        args, kwargs = MCTensor._resolve_dim(args, kwargs, self.dim() + 1)
        return torch.unsqueeze(self, *args, **kwargs)

    def squeeze(self, *args, **kwargs) -> MCTensor:
        if not args and not kwargs:
            # Squeeze logical size-1 dims only. A bare squeeze of res would also
            # drop its component dim when nc == 2.
            val = self.as_subclass(Tensor).squeeze()
            res = self.res.reshape(tuple(val.shape) + (self.res.shape[-1],))
            return MCTensor.wrap_tensor_and_res_to_mctensor(val, res)
        args, kwargs = MCTensor._resolve_dim(args, kwargs, self.dim())
        return torch.squeeze(self, *args, **kwargs)

    def reshape(self, *shape) -> MCTensor:
        # accept both x.reshape(2, 3) and x.reshape((2, 3)), like torch.Tensor.reshape
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        return torch.reshape(self, shape)

    def transpose(self, dim0, dim1) -> MCTensor:
        return torch.transpose(self, dim0, dim1)

    def narrow(self, dim, start, length) -> MCTensor:
        # resolve negative dims against the logical rank (res has one extra trailing dim)
        if dim < 0:
            dim = dim + self.dim()
        return super().narrow(dim, start, length)

    def index_select(self, dim, index) -> MCTensor:
        # resolve negative dims against the logical rank (res has one extra trailing dim)
        if dim < 0:
            dim = dim + self.dim()
        return super().index_select(dim, index)

    def copy_(self, other: Union[MCTensor, Tensor]):
        """Copy ``other`` into ``self`` in place and return ``self``.

        An MCTensor source copies all of its components. A plain Tensor source
        (including subclasses such as ``Parameter``) sets the leading component
        and zeroes the residual, so no stale residual is left behind.

        Raises:
            TypeError: if ``other`` is not a Tensor.
        """
        if isinstance(other, MCTensor):
            super().copy_(other.as_subclass(Tensor))
            self.res.data.copy_(other.res.data)
        elif isinstance(other, Tensor):
            super().copy_(other)
            self.res.data.zero_()
        else:
            raise TypeError(f"copy_: expected a Tensor or MCTensor, got {type(other).__name__}")
        return self

    @staticmethod
    def replace_args(args):
        """Return ``args`` as a tuple with every MCTensor replaced by its ``res``.

        Lists and tuples (e.g. the tensor sequence passed to ``torch.cat``) are
        processed recursively and come back as tuples.
        """
        def _swap(arg):
            if isinstance(arg, (list, tuple)):
                return MCTensor.replace_args(arg)
            if isinstance(arg, MCTensor):
                return arg.res
            return arg

        return tuple(_swap(arg) for arg in args)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route torch calls involving MCTensors.

        * Registered overrides (``HANDLED_FUNCTIONS``) receive the MCTensors
          directly. A plain-tensor result is re-wrapped, with its last dim read
          as the component dim.
        * Anything else runs on the leading components through the default
          implementation. If that yields an MCTensor without ``_nc`` (a freshly
          created result), the call is repeated with each MCTensor positional
          argument swapped for its ``res`` to build the residual. Keyword
          arguments are passed unchanged.
        """
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:  # torch.Tensor with MCTensor contents
            ret = HANDLED_FUNCTIONS[func](*args, **kwargs)
            if isinstance(ret, Tensor) and not isinstance(ret, MCTensor):
                return cls.wrap_tensor_to_mctensor(ret)
        else:  # pytorch functions handle main and res component separately
            ret = super().__torch_function__(func, types, args, kwargs)
            if isinstance(ret, MCTensor) and not hasattr(ret, '_nc'):
                new_args = cls.replace_args(args)
                ret.res = super().__torch_function__(func, types, new_args, kwargs).as_subclass(Tensor)
                ret._nc = ret.res.size(-1) + 1
        return ret

In [4]:
@implements(torch.add)
def add(input: Union[MCTensor, int, float], other: Union[MCTensor, int, float], alpha=1) -> MCTensor:
    """``torch.add`` override computing ``input + alpha * other`` where at least one operand is an MCTensor.

    Python ``int``/``float`` operands are promoted to 0-d tensors with the other
    operand's dtype and device. Two MCTensors are combined with ``_AddMCN``. An
    MCTensor plus a plain tensor goes through ``_Grow_ExpN``. The helper's result
    is returned as-is; ``MCTensor.__torch_function__`` wraps a plain-tensor
    result back into an MCTensor.
    """
    scalar_types = (int, float)
    if type(input) in scalar_types:
        input = torch.tensor(input, dtype=other.dtype, device=other.device)
    if type(other) in scalar_types:
        other = torch.tensor(other, device=input.device, dtype=input.dtype)
    if alpha != 1:
        other = alpha * other  # should check for MCTensor multiplication

    input_is_mc = isinstance(input, MCTensor)
    if input_is_mc and isinstance(other, MCTensor):
        return _AddMCN(input.tensor, other.tensor, simple=False)

    mc_operand, plain_operand = (input, other) if input_is_mc else (other, input)
    components = mc_operand.tensor
    if components.dim() != 1:
        return _Grow_ExpN(components, plain_operand)
    # A 0-d MCTensor has a 1-d component tensor: add a batch dim for _Grow_ExpN, then strip it.
    return _Grow_ExpN(components.unsqueeze(0), plain_operand)[0]

@implements(torch.mul)
def mul(input: Union[MCTensor, int, float], other: Union[MCTensor, int, float]) -> MCTensor:
    """``torch.mul`` override for products where at least one operand is an MCTensor.

    Python ``int``/``float`` operands are promoted to 0-d tensors with the other
    operand's dtype and device. MCTensor × MCTensor uses ``_MultMCN``. An
    MCTensor times a plain tensor uses ``_ScalingN``.
    """
    scalar_types = (int, float)
    if type(input) in scalar_types:
        input = torch.tensor(input, dtype=other.dtype, device=other.device)
    if type(other) in scalar_types:
        other = torch.tensor(other, device=input.device, dtype=input.dtype)

    if isinstance(input, MCTensor) and isinstance(other, MCTensor):
        # case 0: no renormalization inside _MultMCN (explore carefully)
        return _MultMCN(input.tensor, other.tensor, case=0)
    if isinstance(input, MCTensor):
        return _ScalingN(input.tensor, other)
    return _ScalingN(other.tensor, input)

@implements(torch.div)
def div(x: Union[MCTensor, int, float], y: Union[MCTensor, int, float]) -> MCTensor:
    """``torch.div`` override computing ``x / y`` where at least one operand is an MCTensor.

    Python ``int``/``float`` operands are promoted to 0-d tensors matching the
    other operand. A plain-tensor operand (any non-MCTensor ``Tensor``, e.g. a
    ``Parameter``) is lifted to the component layout with zero residuals before
    calling ``_DivMCN``.

    Raises:
        TypeError: if neither operand is an MCTensor. This is the same exception
            type the former ``raise NotImplemented`` produced, now with a message.
    """
    if type(x) == int or type(x) == float:
        x = torch.tensor(x, device=y.device, dtype=y.dtype)
    if type(y) == int or type(y) == float:
        y = torch.tensor(y, device=x.device, dtype=x.dtype)
    normalize_case = 2  # case 2: renormalize inside _DivMCN (explore carefully)

    def _lift(plain, nc):
        # plain tensor -> (..., nc) component layout: value in slot 0, zero residuals
        lifted = torch.zeros(plain.size() + (nc,), device=plain.device, dtype=plain.dtype)
        lifted[..., 0] = plain
        return lifted

    x_is_mc = isinstance(x, MCTensor)
    y_is_mc = isinstance(y, MCTensor)
    if x_is_mc and y_is_mc:
        return _DivMCN(x.tensor, y.tensor, case=normalize_case)
    if x_is_mc and isinstance(y, Tensor):
        return _DivMCN(x.tensor, _lift(y, x.nc), case=normalize_case)
    if y_is_mc and isinstance(x, Tensor):
        return _DivMCN(_lift(x, y.nc), y.tensor, case=normalize_case)
    raise TypeError(f"div: unsupported operand types {type(x).__name__} and {type(y).__name__}")

@implements(torch.rand_like)
def rand_like(input: MCTensor, requires_grad=False, device=None, dtype=None) -> MCTensor:
    """``torch.rand_like`` override: uniform samples on [0, 1) shaped like ``input``.

    Only the leading component is sampled; the residual of the returned
    MCTensor (same ``nc`` as ``input``) starts at zero.
    """
    # torch.rand_like (uniform), not torch.randn_like (standard normal) as before
    val = torch.rand_like(input.as_subclass(Tensor), requires_grad=requires_grad, device=device, dtype=dtype)
    return MCTensor(val, nc=input.nc)

@implements(torch.zeros_like)
def zeros_like(input: MCTensor, requires_grad=False, device=None, dtype=None) -> MCTensor:
    """``torch.zeros_like`` override: an all-zero MCTensor with ``input``'s shape and ``nc``."""
    leading = input.as_subclass(Tensor)
    options = dict(requires_grad=requires_grad, device=device, dtype=dtype)
    zeros = torch.zeros_like(leading, **options)
    return MCTensor(zeros, nc=input.nc)

@implements(torch.ones_like)
def ones_like(input: MCTensor, requires_grad=False, device=None, dtype=None) -> MCTensor:
    """``torch.ones_like`` override: leading component all ones, residual zero, same ``nc`` as ``input``."""
    leading = input.as_subclass(Tensor)
    options = dict(requires_grad=requires_grad, device=device, dtype=dtype)
    ones = torch.ones_like(leading, **options)
    return MCTensor(ones, nc=input.nc)

@implements(torch.clamp)
def clamp(input: MCTensor, min=None, max=None) -> MCTensor:
    """``torch.clamp`` override; both bounds are applied by ``_clamp`` on the component layout."""
    components = input.tensor
    return _clamp(components, min=min, max=max)

@implements(torch.clamp_min)
def clamp_min(input: MCTensor, min=None) -> MCTensor:
    """``torch.clamp_min`` override: lower bound only, applied by ``_clamp``."""
    components = input.tensor
    return _clamp(components, min=min)

@implements(torch.clamp_max)
def clamp_max(input: MCTensor, max=None) -> MCTensor:
    """``torch.clamp_max`` override: upper bound only, applied by ``_clamp``."""
    components = input.tensor
    return _clamp(components, max=max)

@implements(torch.norm)
def norm(input: MCTensor, dim=None, keepdim=False, p=2, **kw) -> MCTensor:
    """``torch.norm`` override delegating to ``_norm``; any extra keyword arguments are ignored."""
    components = input.tensor
    return _norm(components, dim=dim, keepdim=keepdim, p=p)

@implements(torch.sum)
def sum(input: MCTensor, dim=None, keepdim=False, **kw) -> MCTensor:
    """``torch.sum`` override delegating to ``_sum``; any extra keyword arguments are ignored."""
    components = input.tensor
    return _sum(components, dim=dim, keepdim=keepdim)

@implements(torch.abs)
def abs(input: MCTensor) -> MCTensor:
    """``torch.abs`` override.

    The components are renormalized first (presumably so that the leading
    component carries the sign of the whole entry). Entries whose leading
    component is then negative have every component negated.
    """
    renormed = _Renormalize(input.tensor, input.nc)
    negative = renormed[..., 0] < 0
    renormed[negative] = -renormed[negative]
    return renormed

@implements(torch.dot)
def dot(input: Union[MCTensor, int, float], other: Union[MCTensor, int, float]) -> MCTensor:
    """``torch.dot`` override for 1-D operands, at least one of them an MCTensor.

    MCTensor · MCTensor uses ``_Dot_MCN_M``. MCTensor · Tensor uses ``_Dot_MCN``
    with the MCTensor first (dot is symmetric). The plain operand may be any
    non-MCTensor ``Tensor``.

    Raises:
        TypeError: if neither operand is an MCTensor (previously an opaque
            TypeError from ``raise NotImplemented``).
    """
    input_is_mc = isinstance(input, MCTensor)
    other_is_mc = isinstance(other, MCTensor)
    if input_is_mc and other_is_mc:
        return _Dot_MCN_M(input.tensor, other.tensor)
    if input_is_mc and isinstance(other, Tensor):
        return _Dot_MCN(input.tensor, other)
    if other_is_mc and isinstance(input, Tensor):
        return _Dot_MCN(other.tensor, input)
    raise TypeError(f"dot: unsupported operand types {type(input).__name__} and {type(other).__name__}")

@implements(torch.mv)
def mv(input: Union[Tensor, MCTensor], other: Union[Tensor, MCTensor]) -> MCTensor:
    """Matrix-vector product where at least one operand is an MCTensor.

    ``mv(matrix, vector)`` computes ``matrix @ vector``. The reversed order
    ``mv(vector, matrix)`` computes ``vector @ matrix``, i.e.
    ``matrix.T @ vector``. Previously the matrix was used untransposed, which
    gave ``matrix @ vector`` (wrong, or a shape error for non-square matrices).

    Raises:
        TypeError: for operands that are not one 2-D and one 1-D tensor, or
            when neither operand is an MCTensor.
    """
    if input.dim() == 2 and other.dim() == 1:
        matrix, vector = input, other
    elif input.dim() == 1 and other.dim() == 2:
        # vector @ matrix == matrix^T @ vector
        matrix, vector = other.transpose(0, 1), input
    else:
        raise TypeError(f"mv: expected one 2-D and one 1-D operand, got {input.dim()}-D and {other.dim()}-D")

    matrix_is_mc = isinstance(matrix, MCTensor)
    vector_is_mc = isinstance(vector, MCTensor)
    if matrix_is_mc and vector_is_mc:
        return _MV_MC_M_M(matrix.tensor, vector.tensor)
    if matrix_is_mc and isinstance(vector, Tensor):
        return _MV_MC_T(matrix.tensor, vector)
    if vector_is_mc and isinstance(matrix, Tensor):
        return _MV_T_MC(matrix, vector.tensor)
    raise TypeError(f"mv: unsupported operand types {type(input).__name__} and {type(other).__name__}")

@implements(torch.mm)
def mm(input: Union[Tensor, MCTensor], other: Union[Tensor, MCTensor]) -> MCTensor:
    """2-D matrix product ``input @ other`` where at least one operand is an MCTensor.

    Dispatch: MC×T -> ``_MM_MC_T``, T×MC -> ``_MM_T_MC``, MC×MC -> ``_MM_MC_MC``.
    The plain operand may be any non-MCTensor ``Tensor`` (e.g. a ``Parameter``).

    Raises:
        TypeError: if neither operand is an MCTensor (previously an opaque
            TypeError from ``raise NotImplemented``).
    """
    input_is_mc = isinstance(input, MCTensor)
    other_is_mc = isinstance(other, MCTensor)
    if input_is_mc and other_is_mc:
        return _MM_MC_MC(input.tensor, other.tensor)
    if input_is_mc and isinstance(other, Tensor):
        return _MM_MC_T(input.tensor, other)
    if other_is_mc and isinstance(input, Tensor):
        return _MM_T_MC(input, other.tensor)
    raise TypeError(f"mm: unsupported operand types {type(input).__name__} and {type(other).__name__}")

@implements(torch.bmm)
def bmm(input: Union[Tensor, MCTensor], other: Union[Tensor, MCTensor]) -> MCTensor:
    """3-D batched matrix product where at least one operand is an MCTensor.

    Dispatch: MC×T -> ``_BMM_MC_T``, T×MC -> ``_BMM_T_MC``, MC×MC -> ``_BMM_MC_MC``.
    Each kernel returns ``(result, size, nc)``; only ``result`` is returned.

    Raises:
        TypeError: if neither operand is an MCTensor (previously an opaque
            TypeError from ``raise NotImplemented``).
    """
    input_is_mc = isinstance(input, MCTensor)
    other_is_mc = isinstance(other, MCTensor)
    if input_is_mc and other_is_mc:
        result, _size, _nc = _BMM_MC_MC(input.tensor, other.tensor)
    elif input_is_mc and isinstance(other, Tensor):
        result, _size, _nc = _BMM_MC_T(input.tensor, other)
    elif other_is_mc and isinstance(input, Tensor):
        result, _size, _nc = _BMM_T_MC(input, other.tensor)
    else:
        raise TypeError(f"bmm: unsupported operand types {type(input).__name__} and {type(other).__name__}")
    return result

def _matmul_same_rank(x, y):
    """Batched product of two equal-rank (3-D or 4-D) operands, at least one an MCTensor.

    The kernels are looked up at call time, so a later notebook-level
    redefinition (e.g. of ``_4DMM_MC_MC``) is picked up. Returns the raw
    component-layout result (trailing dim of size nc).
    """
    rank = x.dim()
    if rank == 3:
        mc_t, t_mc, mc_mc = _BMM_MC_T, _BMM_T_MC, _BMM_MC_MC
    elif rank == 4:
        mc_t, t_mc, mc_mc = _4DMM_MC_T, _4DMM_T_MC, _4DMM_MC_MC
    else:
        raise TypeError(f"matmul: batched MCTensor matmul supports 3-D and 4-D operands, got {rank}-D")
    x_is_mc = isinstance(x, MCTensor)
    y_is_mc = isinstance(y, MCTensor)
    if x_is_mc and y_is_mc:
        # both operands must be passed in component layout (.tensor)
        result, _size, _nc = mc_mc(x.tensor, y.tensor)
    elif x_is_mc and isinstance(y, Tensor):
        result, _size, _nc = mc_t(x.tensor, y)
    elif y_is_mc and isinstance(x, Tensor):
        result, _size, _nc = t_mc(x, y.tensor)
    else:
        raise TypeError(f"matmul: unsupported operand types {type(x).__name__} and {type(y).__name__}")
    return result


@implements(torch.matmul)
def matmul(input: Union[MCTensor, Tensor], other: Union[MCTensor, Tensor]) -> MCTensor:
    """``torch.matmul`` override following torch's rank rules.

    * 1-D @ 1-D -> ``dot``; 2-D @ 2-D -> ``mm``; 2-D @ 1-D -> ``mv``.
    * 1-D @ 2-D -> ``vector @ matrix`` (computed as ``matrix.T @ vector``).
    * N-D @ 1-D (N > 2): the vector is treated as a column and that dim is
      dropped afterwards. Previously this fell back to elementwise ``mul``,
      which neither summed nor had the right shape.
    * 1-D @ N-D (N > 2): the vector is treated as a row and that dim is dropped.
    * Otherwise the lower-rank operand gets leading unit dims and the 3-D/4-D
      batched kernels are used.

    Returns the raw component-layout result (trailing dim of size nc), which
    ``MCTensor.__torch_function__`` wraps.

    Raises:
        TypeError: for unsupported ranks (> 4-D batched) or when neither operand
            is an MCTensor (previously ``raise NotImplemented`` or an
            UnboundLocalError).
    """
    x_dim, y_dim = input.dim(), other.dim()
    if x_dim == 1 and y_dim == 1:
        return dot(input, other)
    if x_dim == 2 and y_dim == 2:
        return mm(input, other)
    if x_dim == 2 and y_dim == 1:
        return mv(input, other)
    if x_dim == 1 and y_dim == 2:
        # vector @ matrix == matrix^T @ vector
        return mv(other.transpose(0, 1), input)
    if y_dim == 1:
        # batched matrix @ vector: use the vector as a (k, 1) column, then drop that column dim
        return matmul(input, other[:, None])[..., 0, :]
    if x_dim == 1:
        # vector @ batched matrix: use the vector as a (1, k) row, then drop that row dim
        return matmul(input[None, :], other)[..., 0, :, :]
    rank = max(x_dim, y_dim)
    x = input if x_dim == rank else input[(None,) * (rank - x_dim)]  # unsqueeze leading dims
    y = other if y_dim == rank else other[(None,) * (rank - y_dim)]  # unsqueeze leading dims
    return _matmul_same_rank(x, y)

@implements(torch.addmm)
def addmm(input: Union[MCTensor, Tensor],
          mat1: Union[MCTensor, Tensor],
          mat2: Union[MCTensor, Tensor],
          beta=1.0, alpha=1.0) -> MCTensor:
    """``torch.addmm`` override: ``beta * input + alpha * (mat1 @ mat2)``.

    NOTE(review): unlike ``torch.addmm``, ``input`` is always multiplied by
    ``beta``, even when ``beta == 0``, so NaN/inf in ``input`` still propagate.
    """
    product = mat1 @ mat2
    scaled_input = beta * input
    return scaled_input + alpha * product

@implements(torch.transpose)
def transpose(input: MCTensor, dim0, dim1) -> MCTensor:
    """Swap two logical dims, applying the same swap to the leading component and to ``res``.

    Negative dims are resolved against the logical rank first. ``res`` has one
    extra trailing dim, so a negative index would otherwise address the wrong axis.
    """
    rank = input.dim()
    dim0 = dim0 + rank if dim0 < 0 else dim0
    dim1 = dim1 + rank if dim1 < 0 else dim1
    leading = input.as_subclass(Tensor).transpose(dim0, dim1)
    residual = input.res.transpose(dim0, dim1)
    return MCTensor.wrap_tensor_and_res_to_mctensor(leading, residual)

@implements(torch.reshape)
def reshape(input: MCTensor, shape) -> MCTensor:
    """``torch.reshape`` override: reshape the leading component and ``res`` consistently.

    ``res`` is reshaped to the already-resolved ``data.shape`` plus its
    component dim. This handles ``-1`` entries, list/``torch.Size`` shapes and
    non-contiguous ``res``. It also handles ``nc == 1`` (an empty residual
    dim), where the former ``numel() // extra_nc`` raised ZeroDivisionError.
    """
    data = torch.reshape(input.as_subclass(Tensor), shape)
    extra_nc = input.res.shape[-1]
    res = input.res.reshape(tuple(data.shape) + (extra_nc,))
    return MCTensor.wrap_tensor_and_res_to_mctensor(data, res)

@implements(torch.nn.functional.relu)
def relu(input: MCTensor, inplace=False) -> MCTensor:
    """ReLU on an MCTensor.

    The leading component goes through ``F.relu``. The residual is zeroed
    wherever the rectified leading component is exactly 0. With
    ``inplace=True``, ``input`` itself is modified and returned.
    """
    val = torch.nn.functional.relu(input.as_subclass(Tensor), inplace=inplace)
    if not inplace:
        res = input.res.clone()
        res[val == 0] = 0
        return MCTensor.wrap_tensor_and_res_to_mctensor(val, res)
    input.res[input.as_subclass(Tensor) == 0] = 0
    return input
    
@implements(torch.sigmoid)
def sigmoid(input) -> MCTensor:
    """``torch.sigmoid`` override built from MC-aware ops: ``1 / (exp(-input) + 1)``."""
    denominator = torch.exp(-input) + 1
    return 1 / denominator

@implements(torch.nn.functional.softmax)
def softmax(x: MCTensor, dim, *args, **kwargs) -> MCTensor:
    """``F.softmax`` override via ``_softmax``; any further arguments (e.g. ``dtype``) are ignored."""
    components = x.tensor
    return _softmax(components, dim=dim)

@implements(torch.erf)
def erf(input: MCTensor) -> MCTensor:
    """Approximate ``torch.erf``: evaluated on the leading component only, with a zero residual."""
    leading_erf = torch.erf(input.as_subclass(Tensor))
    return MCTensor(leading_erf, nc=input.nc)

@implements(torch.nn.functional.dropout)
def dropout(input: MCTensor, p=0.5, training=True, inplace=False) -> MCTensor:
    """Dropout on an MCTensor.

    ``F.dropout`` zeroes dropped entries of the leading component and scales the
    survivors by ``1 / (1 - p)``. The residual gets the same treatment: zeroed
    where dropped, scaled where kept. Previously it was only zeroed, so the
    components of a surviving entry were scaled inconsistently. An entry whose
    leading component ends up exactly 0 is treated as dropped. When
    ``training`` is False, ``input`` is returned unchanged.
    """
    if not training:
        return input
    # p == 1 drops everything; avoid 1 / 0 (and 0 * inf = nan) in the residual scale
    scale = 0.0 if p == 1 else 1.0 / (1.0 - p)
    val = torch.nn.functional.dropout(input.as_subclass(Tensor), p=p, training=True, inplace=inplace)
    if inplace:
        input.res[input.as_subclass(Tensor) == 0] = 0
        input.res.mul_(scale)
        return input
    res = input.res.clone()
    res[val == 0] = 0
    return MCTensor.wrap_tensor_and_res_to_mctensor(val, res * scale)
    
@implements(torch.square)
def square(input: MCTensor) -> MCTensor:
    """``torch.square`` override evaluated by ``_square`` on the component layout."""
    components = input.tensor
    return _square(components)

@implements(torch.atanh)
def atanh(input: MCTensor) -> MCTensor:
    """``torch.atanh`` override evaluated by ``_atanh`` on the component layout."""
    components = input.tensor
    return _atanh(components)

@implements(torch.log1p)
def log1p(input: MCTensor) -> MCTensor:
    """``torch.log1p`` override evaluated by ``_log1p_standard`` on the component layout."""
    components = input.tensor
    return _log1p_standard(components)

@implements(torch.nn.functional.linear)
def linear(input: Union[MCTensor, Tensor], weight: Union[MCTensor, Tensor], bias=None) -> MCTensor:
    """``F.linear`` override: ``input @ weight.T``, plus ``bias`` when given.

    When both ``input`` and ``weight`` are MCTensors, ``input`` is demoted to its
    leading component first, so its residual does not enter the product.
    """
    if isinstance(input, MCTensor) and isinstance(weight, MCTensor):
        ## attention, here make input as tensor, as mul between MCTensors not supported yet
        input = input.as_subclass(Tensor)
    out = torch.matmul(input, weight.T)
    return out if bias is None else out + bias
    
@implements(torch.diag)
def diag(x: MCTensor, diagonal=0) -> MCTensor:
    """``torch.diag`` override for MCTensors.

    * 2-D input: returns the ``diagonal``-th diagonal as a 1-D MCTensor.
    * 1-D input: returns a square 2-D MCTensor with ``x`` on the
      ``diagonal``-th diagonal and zeros elsewhere. The former index-gather
      implementation only handled 2-D; for 1-D it returned a flattened n*n
      vector filled with copies of ``x[0]`` off the diagonal.

    Raises:
        RuntimeError: for inputs that are not 1-D or 2-D (as ``torch.diag`` does).
    """
    components = x.tensor  # logical dims + trailing component dim
    if x.dim() == 2:
        # torch.diagonal appends the diagonal as the last dim: (nc, k) -> transpose -> (k, nc)
        picked = torch.diagonal(components, offset=diagonal, dim1=0, dim2=1).transpose(0, 1)
    elif x.dim() == 1:
        # diag_embed places the last input dim on the (dim1, dim2) diagonal: (nc, n) -> (m, m, nc)
        picked = torch.diag_embed(components.transpose(0, 1), offset=diagonal, dim1=0, dim2=1)
    else:
        raise RuntimeError(f"diag(): Supports 1D or 2D tensors. Got {x.dim()}D")
    return MCTensor.wrap_tensor_to_mctensor(picked)

@implements(torch.mean)
def mean(input: MCTensor, dim=None, keepdim=False, **kw) -> MCTensor:
    """``torch.mean`` override delegating to ``_mean``; any extra keyword arguments are ignored."""
    components = input.tensor
    return _mean(components, dim=dim, keepdim=keepdim)

@implements(torch.nn.functional.nll_loss)
def nll_loss(input: MCTensor, target: Tensor, *, reduction='mean', **kw) -> MCTensor:
    """Negative log-likelihood loss for ``(N, C)`` MCTensor log-probabilities.

    Picks ``input[i, target[i]]`` for every row with a paired index. The former
    version built the ``(N, N)`` matrix ``input[:, target]`` and took its
    diagonal, at O(N^2) time and memory. ``reduction`` ('mean' | 'sum' | 'none')
    is honored; it was previously ignored (always 'mean'). Other keyword
    arguments (``weight``, ``ignore_index``, ...) are accepted but ignored.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    rows = torch.arange(target.shape[0], device=target.device)
    losses = -input[rows, target]
    if reduction == 'mean':
        return torch.mean(losses)
    if reduction == 'sum':
        return torch.sum(losses)
    if reduction == 'none':
        return losses
    raise ValueError(f"{reduction} is not a valid value for reduction")

@implements(torch.nn.functional.log_softmax)
def log_softmax(x: MCTensor, dim=None, **kw) -> MCTensor:
    """``F.log_softmax`` override via ``_log_softmax``; extra keyword arguments are ignored."""
    components = x.tensor
    return _log_softmax(components, dim=dim)

@implements(torch.nn.functional.cross_entropy)
def cross_entropy(x: MCTensor, target: Tensor, reduction='mean', label_smoothing=0.0, **kw) -> MCTensor:
    """``F.cross_entropy`` override via ``_cross_entropy``.

    Only ``reduction`` and ``label_smoothing`` are forwarded; other keyword
    arguments (e.g. ``weight``, ``ignore_index``) are ignored.
    """
    components = x.tensor
    return _cross_entropy(components, target=target, reduction=reduction, label_smoothing=label_smoothing)

@implements(torch.nn.functional.mse_loss)
def mse_loss(x: MCTensor, y: MCTensor, reduction='mean', **kw) -> MCTensor:
    """``F.mse_loss`` override via ``_mse_loss``; both operands must be MCTensors (``.tensor`` is read from each)."""
    prediction, reference = x.tensor, y.tensor
    return _mse_loss(prediction, reference, reduction=reduction)

@implements(torch.sqrt)
def sqrt(input: MCTensor) -> MCTensor:
    """``torch.sqrt`` override evaluated by ``_sqrt`` on the component layout."""
    components = input.tensor
    return _sqrt(components)

@implements(torch.log)
def log(input: MCTensor) -> MCTensor:
    """``torch.log`` override evaluated by ``_log`` on the component layout."""
    components = input.tensor
    return _log(components)

@implements(torch.pow)
def pow(input: MCTensor, exponent: Union[Tensor, float, int]) -> MCTensor:
    """``torch.pow`` override for an MCTensor base; ``exponent`` is passed to ``_pow`` unchanged."""
    base_components = input.tensor
    return _pow(base_components, exponent)

@implements(torch.exp)
def exp(input: MCTensor) -> MCTensor:
    """``torch.exp`` override evaluated by ``_exp`` on the component layout."""
    components = input.tensor
    return _exp(components)

@implements(torch.sinh)
def sinh(input: MCTensor) -> MCTensor:
    """``torch.sinh`` override evaluated by ``_sinh`` on the component layout."""
    components = input.tensor
    return _sinh(components)

@implements(torch.cosh)
def cosh(input: MCTensor) -> MCTensor:
    """``torch.cosh`` override evaluated by ``_cosh`` on the component layout."""
    components = input.tensor
    return _cosh(components)

@implements(torch.tanh)
def tanh(input: MCTensor) -> MCTensor:
    """``torch.tanh`` override evaluated by ``_tanh`` on the component layout."""
    components = input.tensor
    return _tanh(components)

@implements(torch.nn.functional.layer_norm)
def layer_norm(x: MCTensor, normalized_shape, weight=None, bias=None, eps=1e-05) -> MCTensor:
    """``F.layer_norm`` override via ``_layer_norm``.

    ``weight`` and ``bias`` are converted to the component layout (trailing dim
    of size ``x.nc``):

    * MCTensor: all of its components are used. MCTensor is a ``torch.Tensor``
      subclass, so the old ``isinstance(weight, torch.Tensor)`` test caught it
      first and discarded its residual.
    * plain Tensor: lifted with zero residual.
    * None (e.g. ``nn.LayerNorm(elementwise_affine=False)``): identity scale
      (ones) / zero shift. This previously crashed with ``None.tensor``.
    """
    nc = x.nc
    affine_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)

    def _as_components(param, fill):
        if isinstance(param, MCTensor):
            return param.tensor
        base = param if param is not None else torch.full(affine_shape, fill, device=x.device, dtype=x.dtype)
        lifted = torch.zeros(base.size() + (nc,), device=x.device, dtype=x.dtype)
        lifted[..., 0] = base
        return lifted

    mc_weight = _as_components(weight, 1.0)
    mc_bias = _as_components(bias, 0.0)
    return _layer_norm(x.tensor, normalized_shape, mc_weight, mc_bias, eps=eps)

In [5]:
# Broadcasting check: 1-D MCTensor (shape [2]) + 2-D MCTensor (shape [1, 2]), both nc=2.
# Because nc != 1, `+` routes through torch.add to the registered `add` override;
# the displayed result has shape [1, 2].
a = MCTensor([0.1, 0.1], nc=2)
a + MCTensor([[0.1, 0.1]], nc=2)

MCTensor([[0.2000, 0.2000]]), nc=2

In [6]:
# Disabled scratch copy of `_Grow_ExpN` with debugging prints of the `h` / `hval` shapes.
# Every line is commented out, so running this cell has no effect: the `_Grow_ExpN`
# imported in the first cell remains the one used by `add`.
# import HTorch.MCTensor.MCOpBasics
# def _Grow_ExpN(x_tensor, value, simple=True):
#     nc = x_tensor.size(-1)
#     Q = value
#     h = torch.zeros_like(x_tensor)
#     for i in range(1, nc+1):
#         Q, hval = HTorch.MCTensor.MCOpBasics.Two_Sum(x_tensor[..., -i], Q)
#         if i == 1:
#             last_tensor = hval.data
#         else:
#             print(h.shape)
#             print(hval.shape)
#             h[..., -(i-1)] = hval.data
#     h[..., 0] = Q
#     # change from .unsqueeze(-1) to (*.shape, 1)
#     if simple:
#         res = torch.cat([h.data, last_tensor.data.view(
#             (*last_tensor.shape, 1))], dim=-1)
#         # if res.dim() == 1:
#         #     h.data.copy_(_Simple_renormalize_old(res.view(1, len(res)), r_nc=nc)[0])
#         # else:
#         h.data.copy_(_Simple_renormalize_old(res, r_nc=nc))
#     else:
#         res = torch.cat([h.data, last_tensor.data.view(
#             (*last_tensor.shape, 1))], dim=-1)
#         # if res.dim() == 1:
#         #     h.data.copy_(_Renormalize(res.view(1, len(res)), r_nc=nc)[0])
#         # else:
#         h.data.copy_(_Renormalize(res, r_nc=nc))
#     return h


In [7]:
# 0-d case: indexing a 1-element MCTensor with an int yields a 0-d MCTensor (shape []),
# whose component tensor `.tensor` is 1-D with shape [nc]. The last print adds that
# raw [nc]-shaped tensor to the 0-d MCTensor, exercising `add`'s 0-d (unsqueeze) path.
# NOTE(review): the plain operand here has shape [nc], not [] -- confirm this mixed-shape add is intended.
a = MCTensor([0.1], nc=2)
# print(a[0])
# print(a[0] + a[0]) # []
# print(a[0].tensor.shape)
print(a[0].shape)
print(a[0].tensor.shape)
print(a[0] + a[0].tensor)

# print(_Grow_ExpN(a[0].tensor.unsqueeze(0), a[0].tensor)[0])

# print(a[0] * a[0])
# print(a[0] * a[0].tensor)
# print(a[0] / a[0])
# print(a[0] / a[0].tensor)

torch.Size([])
torch.Size([2])
MCTensor(0.2000), nc=2


### To-dos and caveats
1. Verify that each customized override is actually dispatched and works, rather than silently falling back to the base `torch.Tensor` implementation.
2. Fix `torch.abs` so that it negates every component of the entries whose first (leading) component is negative.
3. Support `torch.norm` with `p='fro'` (Frobenius norm).
4. Review the renormalization settings used by `add`, `mul`, `div`, etc. These are the `case` number (0, 1, 2) and the `simple` flag (True/False). Pay particular attention where they call lower-level helpers such as `_AddMCN` and `_ScalingN`.
5. Implement MCTensor × MCTensor multiplication for the high-level products, i.e. `dot`, `mv`, `mm`, etc.

In [8]:
# Layout check: `res` has the MCTensor's shape plus one trailing dim of size nc - 1
# ([2, 3] -> [2, 3, 1] for nc=2). `a.res.add_` shifts the residual in place by -0.001;
# only the two shapes are displayed.
a = MCTensor([[0.1, 0.3, 0.2], [0.2, 0.4, 0.5]], nc=2)
a.res.add_(-0.001)
a.shape, a.res.shape

(torch.Size([2, 3]), torch.Size([2, 3, 1]))

In [9]:
# abs check: set each residual to -1e-4 times its leading component, then apply torch.abs.
# Every entry (including the negative ones) should come out positive.
# NOTE(review): the printed residuals are ~1e-9, far smaller than 1e-4 * |leading|. That is
# presumably because `abs` renormalizes first (e.g. 0.5 is shown as 0.4999); confirm.
a = MCTensor([[-0.1, 0.3, 0.2], [0.2, -0.4, 0.5]], nc=2)
a.res = (a.as_subclass(Tensor) * -1e-4).view(a.res.shape)
print(torch.abs(a))
print(torch.abs(a).res)


MCTensor([[0.1000, 0.3000, 0.2000],
          [0.2000, 0.4000, 0.4999]]), nc=2
tensor([[[-1.3206e-09],
         [ 1.0938e-08],
         [-2.6412e-09]],

        [[-2.6412e-09],
         [-5.2823e-09],
         [ 8.2982e-09]]])


In [10]:
# Norm along the last dim with p='fro'. The next cell computes the plain-tensor 2-norm of
# this same `a` as a reference (it reads `a` from here), and the outputs should match.
a = torch.arange(30).reshape(5, 3, 2).float()
mc_a = MCTensor(a, nc=2)
torch.norm(mc_a, p='fro', keepdim=True, dim=-1)

MCTensor([[[ 1.0000],
           [ 3.6056],
           [ 6.4031]],

          [[ 9.2195],
           [12.0416],
           [14.8661]],

          [[17.6918],
           [20.5183],
           [23.3452]],

          [[26.1725],
           [29.0000],
           [31.8277]],

          [[34.6554],
           [37.4833],
           [40.3113]]]), nc=2

In [11]:
# Reference: plain-tensor 2-norm of `a` from the previous cell; should match the MCTensor output above.
torch.norm(a, p=2, keepdim=True, dim=-1)

tensor([[[ 1.0000],
         [ 3.6056],
         [ 6.4031]],

        [[ 9.2195],
         [12.0416],
         [14.8661]],

        [[17.6918],
         [20.5183],
         [23.3452]],

        [[26.1725],
         [29.0000],
         [31.8277]],

        [[34.6554],
         [37.4833],
         [40.3113]]])

In [12]:
# dot: plain vs MCTensor result; the last line prints MC matmul (1-D @ 1-D routes to dot)
# minus the plain result, expected to be 0.
a = torch.arange(30).float()
mc_a = MCTensor(a, nc=2)

b = torch.arange(30).float()
mc_b = MCTensor(b, nc=2)

print(torch.dot(a, b))
print(torch.dot(mc_a, mc_b))
print(torch.matmul(mc_a, mc_b) - torch.matmul(a, b))


tensor(8555.)
MCTensor(8555.), nc=2
MCTensor(0.), nc=2


In [13]:
# mv: (5, 6) @ (6,) -> (5,). MC vs plain result, then their difference via matmul (expected 0).
a = torch.arange(30).reshape(5, 6).float()
mc_a = MCTensor(a, nc=2)

b = torch.arange(6).float()
mc_b = MCTensor(b, nc=2)

print(torch.mv(mc_a, mc_b))
print(torch.mv(a, b))
print(torch.matmul(mc_a, mc_b) - torch.matmul(a, b))


MCTensor([ 55., 145., 235., 325., 415.]), nc=2
tensor([ 55., 145., 235., 325., 415.])
MCTensor([0., 0., 0., 0., 0.]), nc=2


In [14]:
# mm: MC vs plain result, then their difference via matmul (expected all zeros).
a = torch.arange(30).reshape(5, 6).float()
mc_a = MCTensor(a, nc=2)

b = torch.arange(60).reshape(6, 10).float()
mc_b = MCTensor(b, nc=2)

# (5, 6) @ (6, 10) -> (5, 10)
print(torch.mm(mc_a, mc_b))
print(torch.mm(a, b))
print(torch.matmul(mc_a, mc_b) - torch.matmul(a, b))

MCTensor([[ 550.,  565.,  580.,  595.,  610.,  625.,  640.,  655.,  670.,  685.],
          [1450., 1501., 1552., 1603., 1654., 1705., 1756., 1807., 1858., 1909.],
          [2350., 2437., 2524., 2611., 2698., 2785., 2872., 2959., 3046., 3133.],
          [3250., 3373., 3496., 3619., 3742., 3865., 3988., 4111., 4234., 4357.],
          [4150., 4309., 4468., 4627., 4786., 4945., 5104., 5263., 5422., 5581.]]), nc=2
tensor([[ 550.,  565.,  580.,  595.,  610.,  625.,  640.,  655.,  670.,  685.],
        [1450., 1501., 1552., 1603., 1654., 1705., 1756., 1807., 1858., 1909.],
        [2350., 2437., 2524., 2611., 2698., 2785., 2872., 2959., 3046., 3133.],
        [3250., 3373., 3496., 3619., 3742., 3865., 3988., 4111., 4234., 4357.],
        [4150., 4309., 4468., 4627., 4786., 4945., 5104., 5263., 5422., 5581.]])
MCTensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0.,

In [15]:
# bmm: MC vs plain result, then their difference via matmul (expected all zeros).
a = torch.arange(30).reshape(5, 2, 3).float()
mc_a = MCTensor(a, nc=2)

b = torch.arange(30).reshape(5, 3, 2).float()
mc_b = MCTensor(b, nc=2)

# batched: (5, 2, 3) @ (5, 3, 2) -> (5, 2, 2)
print(torch.bmm(mc_a, mc_b))
print(torch.bmm(a, b))
print(torch.matmul(mc_a, mc_b) - torch.matmul(a, b))

MCTensor([[[  10.,   13.],
           [  28.,   40.]],

          [[ 172.,  193.],
           [ 244.,  274.]],

          [[ 550.,  589.],
           [ 676.,  724.]],

          [[1144., 1201.],
           [1324., 1390.]],

          [[1954., 2029.],
           [2188., 2272.]]]), nc=2
tensor([[[  10.,   13.],
         [  28.,   40.]],

        [[ 172.,  193.],
         [ 244.,  274.]],

        [[ 550.,  589.],
         [ 676.,  724.]],

        [[1144., 1201.],
         [1324., 1390.]],

        [[1954., 2029.],
         [2188., 2272.]]])
MCTensor([[[0., 0.],
           [0., 0.]],

          [[0., 0.],
           [0., 0.]],

          [[0., 0.],
           [0., 0.]],

          [[0., 0.],
           [0., 0.]],

          [[0., 0.],
           [0., 0.]]]), nc=2


In [16]:
def _4DMM_MC_MC(x_tensor, y_tensor):
    """Batched (two batch axes) matrix product of two MC operands in expanded form.

    Both inputs must be 5-D. ``x_tensor`` unpacks as (B1, B2, M, K, nc) and
    ``y_tensor`` as (B1', B2', K, N, nc). The last axis is presumably the
    multi-component axis of the ``.tensor`` representation (TODO confirm
    against the caller).

    Returns:
        A tuple ``(result, size, nc)``:
        - ``result`` is the MC product summed over the shared K axis. Assuming
          ``_MultMCN`` broadcasts element-wise over the leading axes, its shape
          is (max(B1, B1'), max(B2, B2'), M, N, nc).
        - ``size`` is the 4-tuple (max(B1, B1'), max(B2, B2'), M, N).
        - ``nc`` is read from ``y_tensor``.

    NOTE(review): this notebook cell re-defines, and therefore shadows, the
    ``_4DMM_MC_MC`` imported from HTorch.MCTensor.MCOpMatrix in the first
    cell. The nc of ``x_tensor`` is never compared with that of ``y_tensor``.
    """
    b1_x, b2_x, rows, _, _ = x_tensor.size()
    b1_y, b2_y, _, cols, nc = y_tensor.size()
    size = (max(b1_x, b1_y), max(b2_x, b2_y), rows, cols)

    # x -> (..., M, 1, K, nc) and y^T -> (..., 1, N, K, nc), so the element-wise
    # MC product pairs every (row, col) along the shared K axis (axis -2).
    lhs = x_tensor.unsqueeze(-3)
    rhs = y_tensor.transpose(-2, -3).unsqueeze(2)
    products = _MultMCN(lhs, rhs)

    # Fold the K partial products left-to-right with MC addition.
    depth = products.size(-2)
    total = functools.reduce(
        _AddMCN,
        (products[..., k, :] for k in range(1, depth)),
        products[..., 0, :],
    )
    return total, size, nc

In [17]:
# Batched 4-D product: (4, 5, 2, 3) @ (4, 5, 3, 2) -> (4, 5, 2, 2).
a = torch.arange(120).reshape(4, 5, 2, 3).float()
b = torch.arange(120).reshape(4, 5, 3, 2).float()
mc_a, mc_b = MCTensor(a, nc=2), MCTensor(b, nc=2)

# MC result minus the plain float32 result; prints all zeros when they agree.
print(torch.matmul(mc_a, mc_b) - torch.matmul(a, b))

MCTensor([[[[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]]],


          [[[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]]],


          [[[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]]],


          [[[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 0.],
            [0., 0.]]]]), nc=2


In [18]:
# Addition: the operator, torch.add, the method and the in-place method must all
# match the plain float32 sum. The inputs are small integers, so the reference is exact.
# `.tensor[..., 0]` is the leading component of the MC value.
a = torch.arange(120).reshape(4, 5, 2, 3).float()
mc_a = MCTensor(a, nc=2)
expected = a + a  # computed up front, before mc_a.add_ below mutates mc_a

for result in (mc_a + mc_a, torch.add(mc_a, mc_a), mc_a.add(mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)

# in-place variant last, so the out-of-place calls above all saw the original mc_a
mc_a.add_(mc_a)
assert torch.allclose(mc_a.tensor[..., 0], expected)







In [19]:
# Multiplication: every spelling must match the plain float32 product. The largest
# value is 119**2 = 14161, which float32 represents exactly.
# `.tensor[..., 0]` is the leading component of the MC value.
a = torch.arange(120).reshape(4, 5, 2, 3).float()
mc_a = MCTensor(a, nc=2)
expected = a * a  # computed up front, before mc_a.mul_ below mutates mc_a

for result in (mc_a * mc_a, torch.mul(mc_a, mc_a), mc_a.mul(mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)

# in-place variant last, so the out-of-place calls above all saw the original mc_a
mc_a.mul_(mc_a)
assert torch.allclose(mc_a.tensor[..., 0], expected)







In [20]:
# Division: every spelling must match the plain float32 quotient. The values start
# at 1 because arange(120) contains 0, and 0 / 0 is NaN. With nonzero x, x / x == 1 exactly.
# `.tensor[..., 0]` is the leading component of the MC value.
a = torch.arange(1, 121).reshape(4, 5, 2, 3).float()
mc_a = MCTensor(a, nc=2)
expected = a / a  # computed up front, before mc_a.div_ below mutates mc_a

for result in (mc_a / mc_a, torch.div(mc_a, mc_a), mc_a.div(mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)

# in-place variant last, so the out-of-place calls above all saw the original mc_a
mc_a.div_(mc_a)
assert torch.allclose(mc_a.tensor[..., 0], expected)







In [21]:
# Dot product. torch.dot is defined for 1-D inputs only (plain tensors raise on 2-D),
# so test with a vector. sum(i**2 for i in range(12)) == 506, exact in float32.
a = torch.arange(12).float()
mc_a = MCTensor(a, nc=2)
expected = torch.dot(a, a)

for result in (mc_a.dot(mc_a), torch.dot(mc_a, mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)





In [22]:
# 2-D product: (4, 3) @ (3, 4) -> (4, 4). Integer inputs keep the float32 reference exact.
a = torch.arange(12).reshape(4, 3).float()
mc_a = MCTensor(a, nc=2)
expected = torch.mm(a, a.T)

for result in (mc_a.mm(mc_a.T), torch.mm(mc_a, mc_a.T)):
    assert torch.allclose(result.tensor[..., 0], expected)





In [23]:
# Batched product: (2, 4, 3) @ (2, 3, 4) -> (2, 4, 4), compared with the plain float32 bmm.
a = torch.arange(24).reshape(2, 4, 3).float()
mc_a = MCTensor(a, nc=2)
expected = torch.bmm(a, a.transpose(-1, -2))

for result in (
    mc_a.bmm(mc_a.transpose(-1, -2)),
    torch.bmm(mc_a, mc_a.transpose(-1, -2)),
):
    assert torch.allclose(result.tensor[..., 0], expected)





In [24]:
# matmul on 3-D inputs: (2, 4, 3) @ (2, 3, 4) -> (2, 4, 4), compared with plain float32.
a = torch.arange(24).reshape(2, 4, 3).float()
mc_a = MCTensor(a, nc=2)
expected = torch.matmul(a, a.transpose(-1, -2))

for result in (
    mc_a.matmul(mc_a.transpose(-1, -2)),
    torch.matmul(mc_a, mc_a.transpose(-1, -2)),
):
    assert torch.allclose(result.tensor[..., 0], expected)





In [25]:
# NOTE(review): repeats the matmul check from the preceding In[24] cell.
# TODO: delete it, or point it at a case not covered yet (e.g. MC @ plain tensor).
a = torch.arange(24).reshape(2, 4, 3).float()
mc_a = MCTensor(a, nc=2)
mc_a_t = mc_a.transpose(-1, -2)  # (2, 3, 4)

mc_a.matmul(mc_a_t)
print()

torch.matmul(mc_a, mc_a_t)
print()






In [26]:
# exp: the out-of-place calls run first and the in-place call runs last, so every call
# sees the same input. Running exp_() before torch.exp would make torch.exp compute exp(exp(a)).
a = torch.randn(3, 2)
mc_a = MCTensor(a, nc=2)
expected = torch.exp(a)  # computed before mc_a.exp_ below mutates mc_a

for result in (mc_a.exp(), torch.exp(mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)

mc_a.exp_()
assert torch.allclose(mc_a.tensor[..., 0], expected)







In [27]:
# square: the out-of-place calls run first and the in-place call runs last, so every call
# sees the same input. Running square_() before torch.square would make torch.square compute a**4.
a = torch.randn(3, 2)
mc_a = MCTensor(a, nc=2)
expected = torch.square(a)  # computed before mc_a.square_ below mutates mc_a

for result in (mc_a.square(), torch.square(mc_a)):
    assert torch.allclose(result.tensor[..., 0], expected)

mc_a.square_()
assert torch.allclose(mc_a.tensor[..., 0], expected)








In [28]:
# Hyperbolic ops. Each op gets its own fresh MC copy of `a`. Chaining the in-place
# calls (sinh_, then cosh_, ...) would make every later op run on the previous op's
# output; for example, atanh would be tested on tanh(cosh(sinh(a))) instead of on `a`.
a = torch.rand(3, 2)  # values in [0, 1), inside atanh's domain
for op_name in ("sinh", "cosh", "tanh", "atanh"):
    torch_op = getattr(torch, op_name)
    expected = torch_op(a)
    # clone so an in-place call cannot leak into `a`, in case MCTensor shares its storage
    mc_a = MCTensor(a.clone(), nc=2)
    for result in (getattr(mc_a, op_name)(), torch_op(mc_a)):
        assert torch.allclose(result.tensor[..., 0], expected), op_name
    # in-place variant last, after the out-of-place checks on the same input
    getattr(mc_a, op_name + "_")()
    assert torch.allclose(mc_a.tensor[..., 0], expected), op_name















In [29]:
# reshape to a *different* shape. Reshaping (3, 2) to (3, 2) exercised nothing.
# Reshape only moves data, so the leading components must match exactly.
a = torch.rand(3, 2)
mc_a = MCTensor(a, nc=2)
expected = a.reshape(2, 3)

for result in (torch.reshape(mc_a, (2, 3)), mc_a.reshape(2, 3)):
    assert torch.equal(result.tensor[..., 0], expected)





In [30]:
# mean over dim 0: (3, 2) -> (2,), compared with the plain float32 mean.
a = torch.rand(3, 2)
mc_a = MCTensor(a, nc=2)
expected = torch.mean(a, dim=0)

for result in (torch.mean(mc_a, dim=0), mc_a.mean(dim=0)):
    assert torch.allclose(result.tensor[..., 0], expected)





In [31]:
# 2-norm over dim 0: (3, 2) -> (2,), compared with the plain float32 norm.
a = torch.rand(3, 2)
mc_a = MCTensor(a, nc=2)
expected = torch.norm(a, dim=0)

for result in (torch.norm(mc_a, dim=0), mc_a.norm(dim=0)):
    assert torch.allclose(result.tensor[..., 0], expected)





In [32]:
# Integer-valued (2, 3, 2) tensor wrapped with nc=5, for inspecting how indexing
# maps onto the MC tensor's `.tensor` and `.res` views.
a = torch.arange(12).view(2, 3, 2)
mc_a = MCTensor(a, nc=5)

In [33]:
# `.tensor` carries a trailing nc axis, so the axis that `a[..., 0, :]` selects is axis -3 here
mc_a.tensor.select(-3, 0)

tensor([[[0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0]],

        [[6, 0, 0, 0, 0],
         [7, 0, 0, 0, 0]]])

In [34]:
# plain-tensor reference: index 0 of the second-to-last axis
a.select(-2, 0)

tensor([[0, 1],
        [6, 7]])

In [35]:
# `.res` has a trailing (nc - 1) axis, so axis -2 here is the MC tensor's LAST axis,
# not the one `a[..., 0, :]` selects
mc_a.res.select(-2, 0).shape

torch.Size([2, 3, 4])

In [36]:
# shape of the plain-tensor reference selection
a.select(-2, 0).shape

torch.Size([2, 2])