In [19]:
import torch
from torch import Tensor
from typing import Union
def get_row_col_indicator_V1(subgraphs: list[Union[Tensor,list]]) -> tuple[Tensor,Tensor]:
    r"""Coordinate (row, col) indices of a block-diagonal matrix with one dense block per subgraph.

    Global indices are assigned consecutively in ``subgraphs`` order starting at 0:
    subgraph ``k`` of size ``n_k`` owns ``[offset_k, offset_k + n_k)``, where
    ``offset_k`` is the summed size of the subgraphs before it. Each block
    contributes all ``n_k * n_k`` index pairs, enumerated row-major.

    Args:
        subgraphs: sequence of subgraphs; only ``len()`` of each entry is used.
            An empty subgraph takes no indices and contributes no pairs.

    Returns:
        ``(rows, cols)``: 1-D int64 tensors of length ``sum(n_k ** 2)``. Both are
        empty when ``subgraphs`` is empty.
    """
    rows = []
    cols = []
    count = 0
    for subg in subgraphs:
        size = len(subg)
        ar = torch.arange(count, count + size)
        # row-major enumeration of the size x size block:
        # the row index repeats `size` times, the col index cycles through the block
        rows.append(ar.unsqueeze(-1).broadcast_to(-1, size).flatten())
        cols.append(ar.unsqueeze(0).broadcast_to(size, -1).flatten())
        count += size
    if not rows:
        # torch.cat rejects an empty list; no subgraphs means no index pairs
        empty = torch.empty(0, dtype=torch.long)
        return empty, empty.clone()
    return torch.cat(rows), torch.cat(cols)

def _assert_row_col(subgraphs, erows, ecols):
    """Assert get_row_col_indicator_V1(subgraphs) returns exactly (erows, ecols)."""
    rows, cols = get_row_col_indicator_V1(subgraphs)
    # torch.equal also requires matching shapes, so a wrong-length result fails
    # the assert instead of broadcasting (or raising a RuntimeError) inside `==`
    assert torch.equal(rows, erows), rows
    assert torch.equal(cols, ecols), cols

def test1():
    """A single subgraph of size 3 yields one dense 3x3 block enumerated row-major."""
    _assert_row_col(
        [[0]*3],
        torch.tensor([0,0,0,1,1,1,2,2,2]),
        torch.tensor([0,1,2,0,1,2,0,1,2]),
    )

def test2():
    """Subgraphs of sizes 3, 1, 2: each block is offset by the sizes of earlier subgraphs."""
    _assert_row_col(
        [[0]*3,[0],[0]*2],
        torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 4, 5, 5]),
        torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 4, 5]),
    )

# def test3_generic(a, b):
#     domain_ind = torch.randint(30,(200,))
#     row1, col1 = a(domain_ind)
#     row2, col2 = b(domain_ind)
#     assert (row1 == row2).all(), f"{row1} !=\n {row2}"
#     assert (col1 == col2).all(), f"{col1} !=\n {col2}"

# Execute the fixed-input checks when this cell runs; a failing assert fails the cell.
test1()
test2()

In [36]:
def get_transpose(subgraphs: list[Union[Tensor,list]]) -> Tensor:
    r"""Index permutation that transposes every dense block of a block-diagonal layout.

    Subgraph ``k`` of size ``n_k`` owns ``n_k * n_k`` consecutive slots (after the
    slots of all earlier subgraphs) holding its block entries in row-major order.
    This matches the layout produced by ``get_row_col_indicator_V1``. If slot
    ``i`` holds block entry ``(r, c)``, then ``result[i]`` is the slot holding
    ``(c, r)``. Indexing a per-slot tensor with the result therefore transposes
    each block, and the permutation is its own inverse.

    Args:
        subgraphs: sequence of subgraphs; only ``len()`` of each entry is used.

    Returns:
        1-D int64 tensor of length ``sum(n_k ** 2)``. It is empty when there are no
        slots, either because ``subgraphs`` is empty or all subgraphs are empty.
    """
    pieces = []
    offset = 0
    for subg in subgraphs:
        n = len(subg)
        # explicit (n, n) shape: a -1 in view() is ambiguous for zero elements
        block = torch.arange(offset, offset + n * n).view(n, n)
        pieces.append(block.t().flatten())
        offset += n * n
    if not pieces:
        # torch.cat rejects an empty list
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pieces)

def test3():
    """get_transpose on fixed inputs: a single 3x3 block, then mixed block sizes."""
    etranspose = torch.tensor([
        0,3,6,
        1,4,7,
        2,5,8
    ])
    ptranspose = get_transpose([[0]*3])
    # torch.equal also checks the shape, so a wrong-length result fails the assert
    assert torch.equal(ptranspose, etranspose), ptranspose
    # sizes 2, 0, 1: the empty subgraph takes no slots and the 1x1 block maps to itself
    ptranspose = get_transpose([[0]*2, [], [0]])
    assert torch.equal(ptranspose, torch.tensor([0,2,1,3,4])), ptranspose

def test4(n,m):
    """Randomized consistency check of get_transpose against get_row_col_indicator_V1.

    Draws ``m`` subgraph sizes uniformly from ``[0, n)`` using the global torch RNG.

    Args:
        n: exclusive upper bound on a subgraph's size.
        m: number of subgraphs.
    """
    # .tolist() yields Python ints, so list repetition does not go through 0-d tensors
    sizes = torch.randint(n,(m,)).tolist()
    subgraphs = [[0]*v for v in sizes]
    ptranspose = get_transpose(subgraphs)
    # involution: applying the transpose twice is the identity. Comparing p with
    # p[p[p]] is weaker: it also holds for non-permutations such as all zeros.
    identity = torch.arange(len(ptranspose))
    assert torch.equal(ptranspose[ptranspose], identity), f"not involutive\n{ptranspose}"
    row, col = get_row_col_indicator_V1(subgraphs)
    assert torch.equal(row[ptranspose], col), \
        f"transpose of row incorrect\nrow\n{row}\ncol\n{col}\nrow[ptranspose]\n{row[ptranspose]}\nsubgraphs\n{subgraphs}\nptranspose\n{ptranspose}"
    # the symmetric direction: transposing swaps the roles of row and col
    assert torch.equal(col[ptranspose], row), \
        f"transpose of col incorrect\nrow\n{row}\ncol\n{col}\ncol[ptranspose]\n{col[ptranspose]}\nsubgraphs\n{subgraphs}\nptranspose\n{ptranspose}"
test3()
# Seed the global RNG so test4 draws the same subgraph sizes on every
# Restart & Run All; the second call covers larger blocks.
torch.manual_seed(0)
test4(3,4)
test4(6,8)