In [1]:
import math
from itertools import combinations

import torch

In [2]:
def totally_unimodular_matrix(A, tol: float = 1e-9):
    """Brute-force check of whether ``A`` is totally unimodular (TU).

    A matrix is TU when every square submatrix has determinant -1, 0 or +1.
    Every k x k minor is enumerated for k = 1 .. min(m, n), smallest size
    first, so the cost grows combinatorially with the matrix size; this is
    only practical for small matrices.

    Args:
        A: 2-D ``torch.Tensor`` (or anything ``torch.as_tensor`` accepts).
            It is converted to ``float64``; a tensor keeps its device.
        tol: Absolute tolerance for deciding that a floating-point
            determinant is an integer.

    Returns:
        ``(True, None)`` if every minor passes, otherwise
        ``(False, {"rows": rows, "cols": cols, "det": d})`` for the first
        failing minor found, where ``rows``/``cols`` are index tuples
        (enumerated in lexicographic order) and ``d`` is the computed
        determinant.

    Raises:
        ValueError: If ``A`` is not 2-D.
    """
    A = torch.as_tensor(A, dtype=torch.float64)
    if A.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {tuple(A.shape)}")
    m, n = A.shape
    # Build index tensors on A's device so index_select also works for
    # non-CPU inputs.
    device = A.device

    for k in range(1, min(m, n) + 1):  # k = size of the square minor
        for rows in combinations(range(m), k):
            row_block = A.index_select(0, torch.tensor(rows, device=device))
            for cols in combinations(range(n), k):
                minor = row_block.index_select(1, torch.tensor(cols, device=device))
                d = torch.linalg.det(minor).item()
                # NaN/inf can never be in {-1, 0, 1}; check this first because
                # round() raises on non-finite floats.
                if not math.isfinite(d):
                    return False, {"rows": rows, "cols": cols, "det": d}
                nearest = round(d)
                # Reject if d is not (approximately) an integer, or is an
                # integer outside {-1, 0, 1}.
                if abs(d - nearest) > tol or nearest not in (-1, 0, 1):
                    return False, {"rows": rows, "cols": cols, "det": d}
    return True, None

In [3]:
# Example 1: a 4 x 5 matrix with entries in {0, 1}.
A1 = torch.tensor(
    [
        [1, 0, 1, 0, 1],
        [0, 1, 1, 1, 0],
        [0, 0, 0, 1, 1],
        [1, 1, 0, 0, 0],
    ],
    dtype=torch.int32,
)

# Example 2: a 5 x 7 matrix with entries in {-1, 0, 1}.
A2 = torch.tensor(
    [
        [-1, 0, 1, 0, -1,  0, 0],
        [ 0, 1, 0, 1,  1,  0, 1],
        [-1, 1, 0, 0,  0,  0, 0],
        [ 0, 0, 1, 0,  0,  1, 1],
        [ 0, 0, 0, 1,  0, -1, 0],
    ],
    dtype=torch.int32,
)

In [4]:
totally_unimodular_matrix(A1)  # (is_TU, witness of the first failing minor or None)

(False, {'rows': (0, 1, 2), 'cols': (2, 3, 4), 'det': 2.0})


In [5]:
totally_unimodular_matrix(A2)  # (is_TU, witness of the first failing minor or None)

(True, None)
