In [1]:
import operator
import string
from functools import reduce
from typing import Union

import funsor
import torch
from pyro import set_rng_seed as pyro_set_rng_seed
from torch import Tensor

# Global configuration for the notebook; kept before any tensor work.
# Use torch as funsor's backend.
funsor.set_backend("torch")
# Make float32 the default floating-point dtype for newly created torch tensors.
torch.set_default_dtype(torch.float32)
# Fix the random seed (0) through pyro's seeding helper for reproducible runs.
pyro_set_rng_seed(0)

# NOTE(review): this import sits after the backend/seed setup rather than in the
# import block above — presumably deliberate; confirm before moving it.
from pyroapi import pyro

# Start from an empty parameter store (matters when the cell is re-run in a live kernel).
pyro.clear_param_store()

In [2]:
# Code points of 'a'..'z' as a 1-D integer tensor (one candidate per letter).
all_lower_alphas = torch.tensor(list(map(ord, string.ascii_lowercase)))
print(all_lower_alphas, all_lower_alphas.shape)

tensor([ 97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
        111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122]) torch.Size([26])


In [3]:
# Code points of '0'..'9' as a 1-D integer tensor (one candidate per digit).
all_digits = torch.tensor(list(map(ord, string.digits)))
print(all_digits, all_digits.shape)

tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57]) torch.Size([10])


In [4]:
def enumerate_sequences(*args, shift: int = 8) -> Tensor:
    """Enumerate every sequence described by ``args`` as packed integers.

    Each positional argument describes one or more consecutive sequence
    positions. A tensor contributes a single position whose candidate values
    are the tensor's elements; a ``(count, tensor)`` tuple contributes
    ``count`` positions that all share the same candidates.

    The positions are packed into one integer per sequence by repeatedly
    computing ``(acc << shift) + value``, starting from the LAST position.
    The first position therefore occupies the least-significant ``shift``
    bits and the last position the most-significant ones.

    Args:
        *args: Tensors or ``(count, tensor)`` tuples. A tuple whose ``count``
            is zero (or negative) contributes no positions, wherever it
            appears in the argument list.
        shift: Number of bits each position is shifted by when packing.

    Returns:
        A new tensor whose shape is the concatenation of the position shapes
        in argument order; element ``[i0, i1, ...]`` is the packed value of
        the sequence that picks candidate ``i0`` for the first position,
        ``i1`` for the second, and so on.

    Raises:
        ValueError: If no arguments are given, or if every argument expands
            to zero positions.

    Note:
        No overflow check is performed. NOTE(review): presumably the packed
        value has to fit the tensors' integer dtype (e.g. at most 8 positions
        of 8 bits for int64) — confirm for longer sequences.
    """
    # IDEA: make shift broadcastable (i.e. shift according to space requirement)
    if not args:
        raise ValueError("At least one tensor required")

    # Expand (count, tensor) pairs into `count` consecutive positions. The
    # same tensor object is reused for every repeat: it is only read below,
    # so cloning it per repeat would just cost memory.
    positions = []
    for arg in args:
        if isinstance(arg, tuple):
            replicate, tensor = arg
            positions.extend([tensor] * replicate)
        else:
            positions.append(arg)
    if not positions:
        raise ValueError("At least one tensor required")

    with torch.no_grad():
        # Clone the seed so a single-position call never aliases its input.
        result = torch.clone(positions[-1])
        # Walk the remaining positions from last to first; each one becomes a
        # new leading block of dimensions of the result.
        for t in reversed(positions[:-1]):
            # Append one singleton axis per existing result axis so that the
            # addition broadcasts to shape `t.shape + result.shape`.
            for _ in range(result.dim()):
                t = torch.unsqueeze(t, -1)
            result = (result << shift) + t

    return result

In [5]:
def count_sequences(*args) -> int:
    """Count the sequences described by ``args`` without materialising them.

    Each argument is either a tensor (one position) or a ``(repeat, tensor)``
    tuple (``repeat`` positions sharing the same candidates). The count is the
    product, over all arguments, of the tensor's element count raised to its
    repeat count.

    Args:
        *args: Tensors or ``(repeat, tensor)`` tuples.

    Returns:
        The number of sequences; ``0`` when called with no arguments.
    """
    if len(args) == 0:
        return 0

    total = 1
    for arg in args:
        # Normalise both argument forms to a (repeat, tensor) pair.
        if isinstance(arg, tuple):
            repeat, tensor = arg
        else:
            repeat, tensor = 1, arg
        total *= count_elems(tensor) ** repeat
    return total

In [6]:
def count_elems(t: Union[Tensor, torch.Size]) -> int:
    """Return the number of elements implied by a tensor or a shape.

    Args:
        t: A tensor (its ``size()`` is used) or a shape such as ``torch.Size``.

    Returns:
        The product of all dimension extents; ``1`` for an empty shape
        (i.e. a 0-dim tensor).
    """
    shape = t.size() if isinstance(t, Tensor) else t
    total = 1
    for extent in shape:
        total = total * extent
    return total


In [7]:
%%time
# Enumerate every combination of 4 lowercase letters followed by 4 digits.
# NOTE(review): the result has 26**4 * 10**4 = 4,569,760,000 elements; if the
# inputs are int64 (torch's default for Python ints) that is roughly 36 GB —
# confirm this fits in memory before re-running the notebook.
p = enumerate_sequences((4, all_lower_alphas), (4, all_digits))
count_elems(p)

CPU times: user 8.73 s, sys: 12.7 s, total: 21.4 s
Wall time: 2.7 s


4569760000

In [8]:
%%time
# Count the 4-letter/4-digit sequences arithmetically, without building the tensor.
count_sequences((4, all_lower_alphas), (4, all_digits))

CPU times: user 40 µs, sys: 55 µs, total: 95 µs
Wall time: 13.1 µs


4569760000

In [9]:
%%time
# Pattern: 'T', one of 'A'/'H'/'M', two lowercase letters, three digits, then 9.
# NOTE(review): the final position is the raw value 9, not ord('9') — confirm
# that this is intended.
p = enumerate_sequences(
    torch.tensor(ord('T')),
    torch.tensor([ord('A'), ord('H'), ord('M')]),
    (2, all_lower_alphas),
    (3, all_digits),
    torch.tensor(9),
)
count_elems(p)

CPU times: user 1.22 s, sys: 687 ms, total: 1.9 s
Wall time: 641 ms


2028000

In [10]:
%%time
# Count the sequences of the same pattern without enumerating them.
count_sequences(
    torch.tensor(ord('T')),
    torch.tensor([ord('A'), ord('H'), ord('M')]),
    (2, all_lower_alphas),
    (3, all_digits),
    torch.tensor(9),
)

CPU times: user 30 µs, sys: 38 µs, total: 68 µs
Wall time: 52.9 µs


2028000