In [2]:
import numpy as np

In [305]:
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, Callable, List, Tuple, Union

def lra_matmul(lra1, lra2) -> np.ndarray:
    return np.matmul(lra1, lra2.to_numpy())


@dataclass
class lazyrepeatarray:
    data: Union[int, float]
    shape: Tuple
    transforms: Optional[List] = field(repr=False, default=None)
        
    def __post_init__(self):
        if self.transforms is None:
            self.transforms = []
        self._shape = self.shape
    
    def to_numpy(self, original=False) -> np.ndarray:
        if not original:
            return np.ones(self.shape) * self.data
        else:
            return np.ones(self._shape) * self.data
    
    def add_op(self, function: Callable, selection=slice(None), args={}):
        self.transforms.append((function, selection, args))
    
    def evaluate(self):
        result = self.to_numpy(original=True)
        for func, selection, args in self.transforms:
            if func == np.matmul or func == lra_matmul:
                result = func(result[selection], args)
            else:
                result[selection] = func(result[selection], args)
        self.transforms = []
        return result
    
#     def __repr__(self):
#         return f"LazyRepeatArray(data={self.data}, shape={self.shape}, op_count={len(self.transforms)})"
    
    def __add__(self, other):
        if isinstance(other, (int, np.integer, float, np.floating)):
            self.data += other
            return self
        elif isinstance(other, (np.ndarray, lazyrepeatarray)):
            try: 
                self.shape = np.broadcast_shapes(self.shape, other.shape)
                
                if isinstance(other, lazyrepeatarray):
                    self.add_op(function=np.add, args=other.data)
                elif isinstance(other, np.ndarray):
                    self.add_op(function=np.add, args=other)
                return self
            except ValueError:
                raise Exception(f"Cannot broadcast shapes {self.shape} and {other.shape}")
                
    def matmul(self, other):
        if isinstance(other, (int, np.integer, float, np.floating)):
            raise Exception
        elif isinstance(other, (np.ndarray, lazyrepeatarray)):
            if len(self.shape) != 2 or len(other.shape) != 2:
                raise Exception("Matmul only valid for 2D arrays")
            
            if self.shape[-1] != other.shape[0]:
                raise Exception(f"Matmul not possible b/w shapes {self.shape} & {other.shape}")
            else:
                self.shape = (self.shape[0], other.shape[-1])
                if isinstance(other, lazyrepeatarray):
                    self.add_op(function=lra_matmul, args=other)
                elif isinstance(other, np.ndarray):
                    self.add_op(function=np.matmul, args=other)
                return self
                
    
    def dot(self, other):
        if isinstance(other, (int, np.integer, float, np.floating)):
            self.data *= other
            return self
        elif isinstance(other, (np.ndarray, lazyrepeatarray)):
            if other.shape[0] != self.shape[0]:
                raise Exception
            else:
                pass
#                 try: 
#                     self.shape = np.broadcast_shapes(self.shape, other.shape)

#                     if isinstance(other, lazyrepeatarray):
#                         self.add_op(function=np.add, args=other.data)
#                     elif isinstance(other, np.ndarray):
#                         self.add_op(function=np.add, args=other)
#                     return self
#                 except ValueError:
#                     raise Exception(f"Cannot broadcast shapes {self.shape} and {other.shape}")
    
    
    

In [306]:
a = lazyrepeatarray(10, (5,6))
b = np.random.random((6, 5))

In [307]:
c = a.matmul(b)
print(c)
c.evaluate()

lazyrepeatarray(data=10, shape=(5, 5))


array([[32.21456948, 41.83049448, 18.22916586, 27.40842481, 22.54779985],
       [32.21456948, 41.83049448, 18.22916586, 27.40842481, 22.54779985],
       [32.21456948, 41.83049448, 18.22916586, 27.40842481, 22.54779985],
       [32.21456948, 41.83049448, 18.22916586, 27.40842481, 22.54779985],
       [32.21456948, 41.83049448, 18.22916586, 27.40842481, 22.54779985]])

In [297]:
a = lazyrepeatarray(10, (5,6))
d = lazyrepeatarray(10, (6,7))

c = a.matmul(d)
print(c)
c.evaluate()

LazyRepeatArray(data=10, shape=(5, 7), op_count=1)


array([[600., 600., 600., 600., 600., 600., 600.],
       [600., 600., 600., 600., 600., 600., 600.],
       [600., 600., 600., 600., 600., 600., 600.],
       [600., 600., 600., 600., 600., 600., 600.],
       [600., 600., 600., 600., 600., 600., 600.]])

In [298]:
(a + a + 6).evaluate()

array([[26., 26., 26., 26., 26., 26.],
       [26., 26., 26., 26., 26., 26.],
       [26., 26., 26., 26., 26., 26.],
       [26., 26., 26., 26., 26., 26.],
       [26., 26., 26., 26., 26., 26.]])

In [299]:
a = lazyrepeatarray(10, (5,6))

In [300]:
(a + 6).evaluate()

array([[16., 16., 16., 16., 16., 16.],
       [16., 16., 16., 16., 16., 16.],
       [16., 16., 16., 16., 16., 16.],
       [16., 16., 16., 16., 16., 16.],
       [16., 16., 16., 16., 16., 16.]])

In [301]:
a.data

16

In [302]:
a.transforms

[]

In [303]:
np.ones_like(a)

array(1, dtype=object)

In [304]:
np.add(np.ones(6), 6)

array([7., 7., 7., 7., 7., 7.])