# Flatten/Unflatten Tensors

This notebook illustrates merging (or **flattening**) of two ranks (or levels) of a tensor into one and then **unflattening** that one rank back into two.

First, include some libraries.

In [None]:
# Begin - startup boilerplate code

import os

def run_prelude(**kwargs):

  switches =  " ".join([ f"--{k}={v}" for (k,v) in kwargs.items()])

  for prelude_file in ['./prelude.py', '../prelude.py']:
    if os.path.exists(prelude_file):
      command=f"{prelude_file} {switches}"
      %run $command
      return
    
  print("Downloading prelude.py")
  ! curl -LJOs https://raw.githubusercontent.com/Fibertree-Project/fibertree-notebooks/colab/notebooks/prelude.py
  command=f"./prelude.py {switches}"
  %run $command

# End - startup boilerplate code

run_prelude(style="tree", animation='movie')


## Flatten/Unflatten tests

The following cells flatten a tensor and then unflatten it back to its original form. First, load and display the two input tensors.

In [None]:
t0 = Tensor.fromYAMLfile(datafileName("draw-a.yaml"))
#print(f"{t0.getRoot():n*}")
print(t0.getName())
print(t0.getRankIds())
print(t0.getShape())
displayTensor(t0)

t1 = Tensor.fromYAMLfile(datafileName("draw-b.yaml"))
#print(f"{t1.getRoot():n*}")
print(t1.getName())
print(t1.getRankIds())
print(t1.getShape())
displayTensor(t1)

## Flatten ranks of a tensor

Flatten the top two ranks of `t0` into a single rank. Note that the coordinates of the resulting rank are now tuples of the coordinates from the fibers in the original ranks.

In [None]:
f = t0.flattenRanks()
print(f"{f:n*}")
displayTensor(f)
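
As a rough mental model of what flattening does to coordinates, the sketch below uses plain nested dictionaries in place of fibers. It does not use the fibertree API; the function name `flatten_top` and the sample data are assumptions made only for illustration.

```python
# Sketch only: nested dicts (coordinate -> payload) stand in for fibers.
# This is not the fibertree implementation.
def flatten_top(fiber):
    """Merge the top two ranks into one rank keyed by coordinate tuples."""
    flat = {}
    for c0, sub in fiber.items():
        for c1, payload in sub.items():
            flat[(c0, c1)] = payload
    return flat

# Hypothetical rank-3 tensor; the leaf payloads are plain values.
t = {0: {1: {2: 5, 4: 7}}, 3: {0: {1: 9}}}
f = flatten_top(t)
print(f)  # {(0, 1): {2: 5, 4: 7}, (3, 0): {1: 9}}
```

The bottom rank is untouched: each tuple coordinate still points at the fiber that sat under the second rank.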

## Flatten and then unflatten the top rank of a tensor

The result of this sequence is to re-generate the original tensor.

Note that we explicitly specify, with keyword arguments, the `depth` (relative to the top rank) of the rank into which the specified number of `levels` will be **flattened**. So `levels=1` combines fibers from two ranks into one. The same keyword arguments apply to the **unflatten** operation.


In [None]:
f01 = t0.flattenRanks(depth=0, levels=1)
print(f"{f01:n*}")
displayTensor(f01)

u01 = f01.unflattenRanks(depth=0, levels=1)
displayTensor(u01)

print(u01 == t0)
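
The role of `levels` in the round trip can be sketched with nested dictionaries standing in for fibers. This is a simplified model under stated assumptions, not the fibertree implementation: `flatten_top` and `unflatten_top` are hypothetical helpers that only operate at the top rank (the `depth=0` case), and the input coordinates are assumed to be scalars.

```python
# Sketch only: nested dicts stand in for fibers; not the fibertree API.
def flatten_top(fiber, levels=1):
    """Merge the top rank with the `levels` ranks below it (levels >= 1)."""
    flat = {(c,): payload for c, payload in fiber.items()}
    for _ in range(levels):
        flat = {cs + (c,): p for cs, sub in flat.items() for c, p in sub.items()}
    return flat

def unflatten_top(fiber, levels=1):
    """Split tuple coordinates of the top rank back into `levels` + 1 ranks."""
    out = {}
    for coords, payload in fiber.items():
        head, rest = coords[:levels], coords[levels:]
        # A single leftover coordinate becomes a scalar again.
        last = rest[0] if len(rest) == 1 else rest
        node = out
        for c in head:
            node = node.setdefault(c, {})
        node[last] = payload
    return out

# Hypothetical rank-3 tensor
t = {0: {1: {2: 5, 4: 7}}, 3: {0: {1: 9}}}
f01 = flatten_top(t, levels=1)
u01 = unflatten_top(f01, levels=1)
print(f01)       # {(0, 1): {2: 5, 4: 7}, (3, 0): {1: 9}}
print(u01 == t)  # True
```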

## Illustrate flattening and unflattening three ranks

Again, the result of **flattening** and then **unflattening** is to restore the original tensor.

Note that the coordinates after flattening with `levels=2` are 3-element tuples (not very visible in the diagram).

In [None]:
f02 = t0.flattenRanks(depth=0, levels=2)
print(f"{f02:n*}")
displayTensor(f02)


u02a = f02.unflattenRanks(depth=0, levels=1)
u02b = u02a.unflattenRanks(depth=1, levels=1)
displayTensor(u02b)

print(u02b == t0)

u02 = f02.unflattenRanks(depth=0, levels=2)

print(u02 == t0)
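
The two unflattening paths above (one level at a time, or both levels at once) can be mimicked on a dictionary keyed by 3-element tuples. The helper `unflatten_top` and the sample data are assumptions for illustration and do not come from fibertree.

```python
# Sketch only: a dict keyed by 3-tuples stands in for a fully flattened rank.
def unflatten_top(fiber, levels=1):
    """Peel `levels` leading coordinates off each tuple key into their own ranks."""
    out = {}
    for coords, payload in fiber.items():
        head, rest = coords[:levels], coords[levels:]
        last = rest[0] if len(rest) == 1 else rest
        node = out
        for c in head:
            node = node.setdefault(c, {})
        node[last] = payload
    return out

f = {(0, 1, 2): 5, (0, 1, 4): 7, (3, 0, 1): 9}

# Path 1: split off one coordinate at the top, then one more a rank lower.
step1 = unflatten_top(f, levels=1)
step2 = {c: unflatten_top(sub, levels=1) for c, sub in step1.items()}

# Path 2: split off both coordinates in a single step.
direct = unflatten_top(f, levels=2)

print(step1)            # {0: {(1, 2): 5, (1, 4): 7}, 3: {(0, 1): 9}}
print(step2 == direct)  # True
```

After the first step, the lower rank still carries 2-element tuple coordinates, which is why the second step is applied one rank below the top.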

## Flatten and unflatten at a lower rank in the tensor

Flatten with `levels=1` at `depth=1`, that is, merge the two ranks that start one rank below the top of the tensor.

In [None]:
f12 = t0.flattenRanks(depth=1, levels=1)
print(f"{f12:n*}")
displayTensor(f12)

u12 = f12.unflattenRanks(depth=1, levels=1)
displayTensor(u12)

print(u12 == t0)
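
One way to think about `depth` is as "descend this many ranks, then apply the operation to every fiber found there". The sketch below models that with nested dictionaries. The helpers `at_depth` and `merge_two` are hypothetical names introduced here and are not part of fibertree.

```python
# Sketch only: nested dicts stand in for fibers; not the fibertree API.
def at_depth(fiber, depth, op):
    """Apply `op` to every sub-fiber found `depth` ranks below the top."""
    if depth == 0:
        return op(fiber)
    return {c: at_depth(sub, depth - 1, op) for c, sub in fiber.items()}

def merge_two(fiber):
    """Flatten one level: combine a rank with the rank directly below it."""
    return {(c0, c1): p for c0, sub in fiber.items() for c1, p in sub.items()}

# Hypothetical rank-3 tensor
t = {0: {1: {2: 5, 4: 7}}, 3: {0: {1: 9}}}
f12 = at_depth(t, 1, merge_two)
print(f12)  # {0: {(1, 2): 5, (1, 4): 7}, 3: {(0, 1): 9}}
```

The top rank keeps its scalar coordinates, and only the two lower ranks are merged into tuple coordinates.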

## Illustrate flattening/unflattening for a rank-4 tensor

In [None]:
t2 = Tensor.fromFiber(["A", "B", "C", "D"], 
                      Fiber([1, 4], [t0.getRoot(), t1.getRoot()]),
                      name="t2")
displayTensor(t2)

## Flatten/unflatten 3 ranks of a rank-4 tensor at a depth of 1

In [None]:
f13 = t2.flattenRanks(depth=1, levels=2)
print(f"{f13:n*}")
displayTensor(f13)

u13 = f13.unflattenRanks(depth=1, levels=2)
displayTensor(u13)

print(u13 == t2)

## Flatten/unflatten all 4 ranks of a rank-4 tensor

In [None]:
f04 = t2.flattenRanks(depth=0, levels=3)
print(f"{f04:n*}")
displayTensor(f04)

u04 = f04.unflattenRanks(depth=0, levels=3)
displayTensor(u04)

print(u04 == t2)