In [6]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Union, List
from enum import Enum, auto
from functools import cached_property
from functools import reduce
from operator import mul
from layout_module import NestedTuple, Atom, Profile 
from layout_module import nested_tuple_algebra as na
from layout_module import flat_algebra as fa 



In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Layout:
  """An immutable layout: a shape paired with a stride of the same profile.

  Both fields are ``NestedTuple`` values. ``__post_init__`` requires their
  profiles to match, so every shape entry has a corresponding stride entry.
  """

  shape: NestedTuple   # extents of each (possibly nested) mode
  stride: NestedTuple  # stride for each entry of ``shape``; same profile

  def __post_init__(self):
    # Shape and stride must be structurally identical (same nesting profile).
    assert self.shape.prof == self.stride.prof

  def get_mode(self, i: int) -> "Layout":
    """Return the ``i``-th top-level mode as its own layout.

    Requires ``0 <= i < self.rank`` (checked by assertion).
    """
    assert 0 <= i < self.rank
    return Layout(self.shape.get_mode(i), self.stride.get_mode(i))

  def get_entry(self, i: int) -> "Layout":
    """Return the ``i``-th entry as its own layout.

    Requires ``0 <= i < self.length`` (checked by assertion).
    """
    assert 0 <= i < self.length
    return Layout(self.shape.get_entry(i), self.stride.get_entry(i))

  def is_flat(self) -> bool:
    """Return True when both the shape and the stride are flat."""
    return self.shape.is_flat() and self.stride.is_flat()

  def flatten(self) -> "Layout":
    """Return a layout whose shape and stride are both flattened."""
    return Layout(self.shape.flatten(), self.stride.flatten())

  def coalesce(self) -> "Layout":
    """Return a flat layout with size-1 modes removed and adjacent modes merged.

    The layout is flattened first, then its modes are scanned left to right:

    * a mode with shape 1 is skipped -- its only coordinate is 0, so its
      stride never contributes to the result;
    * a mode ``s:d`` is merged into the previously kept mode ``s0:d0`` when
      ``d == s0 * d0``, producing ``(s0 * s):d0``;
    * otherwise it is kept as a new mode.

    If every mode has shape 1, the result is the rank-1 layout ``1:0``.
    A rank-0 (empty) layout is returned flattened but otherwise unchanged.
    """
    Lf = self.flatten()

    shape = Lf.shape.int_tuple
    stride = Lf.stride.int_tuple

    # Empty layout: nothing to coalesce.
    if len(shape) == 0:
      return Lf

    new_shape = []
    new_stride = []

    for s, d in zip(shape, stride):
      if s == 1:
        # Size-1 modes do not affect the layout function; dropping them lets
        # their neighbours merge (e.g. (2,1,2):(1,7,2) -> 4:1).
        continue
      if new_shape and d == new_shape[-1] * new_stride[-1]:
        # This mode continues the previous one contiguously: fuse them.
        new_shape[-1] *= s
      else:
        new_shape.append(s)
        new_stride.append(d)

    if not new_shape:
      # Every mode had shape 1: the layout is the constant map to 0.
      new_shape, new_stride = [1], [0]

    # Build a flat profile of the resulting rank.
    flat_prof = Profile(tuple(Profile(Atom.STAR) for _ in new_shape))

    return Layout(
      NestedTuple(tuple(new_shape), flat_prof),
      NestedTuple(tuple(new_stride), flat_prof),
    )

  @property
  def rank(self) -> int:
    """Number of top-level modes (delegates to ``shape.rank``)."""
    return self.shape.rank

  @property
  def length(self) -> int:
    """Number of entries (delegates to ``shape.length``)."""
    return self.shape.length

  @property
  def depth(self) -> int:
    """Nesting depth (delegates to ``shape.depth``)."""
    return self.shape.depth

  @property
  def size(self) -> int:
    """Size of the domain (delegates to ``shape.size``)."""
    return self.shape.size

  @property
  def cosize(self) -> int:
    """Return ``1 + sum((s - 1) * d)`` over the shape/stride ``int_tuple`` entries.

    Returns 0 for a layout with no entries.
    """
    shape = self.shape.int_tuple
    stride = self.stride.int_tuple
    assert len(shape) == len(stride)

    if len(shape) == 0:
      return 0

    return 1 + sum((s - 1) * st for s, st in zip(shape, stride))


    
    

In [None]:
def concatenate(ls: List[Layout]) -> Layout:
  """Join layouts end to end by concatenating their shapes and strides.

  The shapes and the strides are each combined with ``na.concatenate``.
  The two results are then paired into a single ``Layout``.
  """
  return Layout(
    na.concatenate([layout.shape for layout in ls]),
    na.concatenate([layout.stride for layout in ls]),
  )

def substitute_modes(L: Layout, P: Profile) -> Layout:
  """Rebuild ``L`` by substituting its top-level modes into profile ``P``.

  ``P.length`` must equal ``L.rank`` (checked by assertion).
  Mode ``i`` of ``L`` supplies the ``i``-th substituted piece.
  Shapes and strides are substituted separately via ``na.substitute``.
  """
  assert L.rank == P.length
  modes = [L.get_mode(idx) for idx in range(L.rank)]
  substituted_shape = na.substitute(P, [mode.shape for mode in modes])
  substituted_stride = na.substitute(P, [mode.stride for mode in modes])
  return Layout(substituted_shape, substituted_stride)


  