In [49]:
import scipy.special
import torch
import hidet
from hidet.graph import ops

# print graph
from torch._dynamo import optimize
from torch._inductor import graph
# Resolve the pristine inductor class first. Once this cell has run,
# `graph.GraphLowering` already points at the wrapper below, so subclassing it
# again on a re-run would stack wrappers and print every FX graph repeatedly.
_PristineGraphLowering = getattr(graph.GraphLowering, "_wrapped_base", graph.GraphLowering)


class CustomGraphLowering(_PristineGraphLowering):
    """``GraphLowering`` subclass that prints the FX graph it is built from.

    Construction is delegated unchanged to the parent class; afterwards a
    banner and ``gm.graph.print_tabular()`` are written to stdout.

    Args:
        gm: The FX graph module handed to the lowering.
        *args: Forwarded verbatim to the parent constructor.
        **kwargs: Forwarded verbatim to the parent constructor.
    """

    # Remembered so that re-executing this cell can find the original class.
    _wrapped_base = _PristineGraphLowering

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        *args,
        **kwargs,
    ):
        super().__init__(gm, *args, **kwargs)
        print("*"*30 + "Print Fx Graph" + "*"*30)
        gm.graph.print_tabular()


# Make later `graph.GraphLowering` lookups resolve to the printing subclass.
graph.GraphLowering = CustomGraphLowering

In [50]:
def test_take(input, dim, index):
    a = hidet.from_torch(input)
    b = hidet.from_torch(index)
    c = ops.transform.take(a, b, axis=dim)
    return c


In [51]:
def test_index_select():
    input = torch.Tensor([[1, 2], [3, 4], [5, 6]])
    index = torch.as_tensor([0], dtype=torch.int32)
    dim = 1
    b = torch.index_select(input, dim, index)
    print(b)

    print(test_take(input, dim, index))


In [52]:
def test_concat():
    x = torch.rand([3, 4, 5])
    y = torch.rand([0])
    print(torch.cat([x, y]))

    x = hidet.randn([3, 4, 5])
    y = hidet.randn([0])
    print(ops.concat([x, y], axis = 0))

In [53]:
def test_norm():
    dim = 0
    x = torch.rand([3, 4, 5])
    print(x)
    y = torch.norm(x, p = 1, dim=dim)
    print("torch: ", y.shape)
    print(y)
    print("*"*99)

    x_hidet = hidet.from_torch(x)
    y_hidet = ops.normalize.lp_norm(x_hidet, p = 1, dim=dim)
    print("hidet: ", y_hidet.shape)
    print(y_hidet)


In [54]:
def test_broadcast_tensor():
    x = torch.arange(3).view(1, 1, 3)
    y = torch.arange(2).view(2, 1)
    a, b = torch.broadcast_tensors(x, y)
    print("*"*20, "torch", "*"*20)
    print("x: ", x)
    print("y: ", y)
    print("a: ", a)
    print("b: ", b)

    x_hidet = hidet.from_torch(x)
    y_hidet = hidet.from_torch(y)
    a_hidet = ops.broadcast(x_hidet, [2, 3])
    b_hidet = ops.broadcast(y_hidet, [2, 3])
    print("*"*20, "hidet", "*"*20)
    print("a: ", a_hidet)
    print("b: ", b_hidet)

In [55]:
import scipy
import math
def test_torch_digamma():
    """Check a pure-Python digamma against ``torch.digamma`` on sample points.

    The reference tensor from torch is printed, then the pure-Python value for
    each sample. Every value must agree with torch to within ``1e-4``.

    Raises:
        AssertionError: If a pure-Python value deviates from torch's result.
    """

    # 0 < x < 8: shift the argument up into [8, 9) with the recurrence
    # psi(x) = psi(x + 1) - 1 / x, then apply the asymptotic series there.
    def compute_2(x):
        sum_k = sum(1 / (x + i) for i in range(int(8 - x // 1)))
        x_1 = 8 + x % 1
        return -sum_k + compute_3(x_1)

    # x >= 8: truncated asymptotic expansion of digamma.
    def compute_3(x):
        xx = 1 / (x * x)
        return math.log(x) - 0.5 / x - xx * (1 / 12 - xx * (1 / 120 + xx / 252))

    def compute_1(x):
        return compute_2(x) if x < 8 else compute_3(x)

    def digamma(x):
        if x == 0:
            return float('-inf')
        if x < 0:
            # Every negative integer is a pole. tan(pi * x) is only a tiny
            # non-zero float there, so the reflection formula would return a
            # huge meaningless number; report NaN instead.
            if x == math.floor(x):
                return float('nan')
            # Reflection: psi(x) = psi(1 - x) - pi / tan(pi * x).
            return compute_1(1 - x) - math.pi / math.tan(math.pi * x)
        return compute_1(x)

    samples = [0.9988, 1.1198, 0.81]
    reference = torch.digamma(torch.Tensor([samples]))
    print(reference)
    for x, expected in zip(samples, reference.flatten().tolist()):
        value = digamma(x)
        print(value)
        # torch computes in float32 here, hence the loose absolute tolerance.
        assert math.isclose(value, expected, abs_tol=1e-4), (x, value, expected)

In [56]:
def test_tan():
    """Print hidet's ``ops.tan`` and ``torch.tan`` for a random 2x2 CUDA tensor.

    The tensor is moved to the GPU with ``.cuda()``, converted with
    ``hidet.from_torch`` and cast with ``astype("f32")`` before ``ops.tan`` is
    applied. The hidet result is printed first, then the torch result.
    """
    sample = torch.randn([2, 2]).cuda()

    sample_f32 = hidet.from_torch(sample).astype("f32")
    hidet_result = ops.tan(sample_f32)
    print(hidet_result)
    print(torch.tan(sample))

In [57]:
def test_hidet_ir_util_broadcastshape():
    """Print ``broadcast_shapes`` for the shapes [1, 2, 1], [1, 1, 4], [3, 1, 1]."""
    from hidet.ir.utils.broadcast_utils import broadcast_shapes

    shapes = [
        [1, 2, 1],
        [1, 1, 4],
        [3, 1, 1],
    ]
    print(broadcast_shapes(shapes))

In [58]:
def test_torch_Tensor_lgamma():
    """Print ``math.gamma(1.2)`` followed by a Lanczos approximation of it.

    Two gamma implementations are defined locally: a numerical-integration
    version (not called here) and a Lanczos approximation with ``g = 7`` and
    nine coefficients. Only the Lanczos one is compared against ``math.gamma``.
    """

    def gamma(x, num_points=10000):
        # Integrate t**(x-1) * exp(-t) over [0, 50], sampling each of the
        # num_points sub-intervals at its midpoint.
        step = 50 / num_points
        total = 0
        for k in range(num_points):
            t = k * step + step / 2
            total += t**(x-1) * math.exp(-t) * step
        return total

    def lanczos_gamma(x):
        # Lanczos approximation parameters.
        g = 7
        p = [
            0.99999999999980993,
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7
        ]

        def series_branch(z):
            # Direct Lanczos series; used for arguments >= 0.5.
            z = z - 1
            acc = p[0]
            for k, coeff in enumerate(p[1:g + 2], start=1):
                acc += coeff / (z + k)
            t = z + g + 0.5
            return math.sqrt(2 * math.pi) * math.pow(t, (z + 0.5)) * math.exp(-t) * acc

        if x < 0.5:
            # Reflection formula: gamma(x) = pi / (sin(pi * x) * gamma(1 - x)).
            return math.pi / (math.sin(math.pi * x) * series_branch(1 - x))
        return series_branch(x)

    print(math.gamma(1.2))
    print(lanczos_gamma(1.2))


In [59]:
def test_reshape():
    def reshape(x, *shape):
        return ops.reshape(x, shape)

    x = torch.randn([2, 2, 2])
    y = x.reshape([2, -1])
    print(y)
    x_hidet = hidet.from_torch(x)
    y_hidet = ops.reshape(x_hidet, [2, -1])
    print(reshape(x_hidet, [2, -1]))


In [60]:
def test_clamp():
    """Compile ``torch.clamp(x, 1, 3)`` with the hidet backend and print it.

    ``dump_graph_ir("hidet_graph")`` is set on hidet's dynamo config before
    compiling; the compiled function is then run on ``[1, 2, 3]``.
    """
    # NOTE(review): `clamp` is imported but never referenced below; kept in
    # case the import itself is wanted -- confirm whether it can be dropped.
    from hidet.graph.frontend.torch.register_functions import clamp

    def test(x):
        return torch.clamp(x, 1, 3)

    hidet.torch.dynamo_config.dump_graph_ir("hidet_graph")
    values = torch.Tensor([1, 2, 3])
    compiled = torch.compile(test, backend="hidet")
    print(compiled(values))

In [61]:
def test_truediv():
    """Print eager and hidet-compiled ``x / y`` for every operand pairing.

    Each side is drawn from an int tensor, a float tensor, a Python int and a
    Python float. For every (x, y) pair a header naming the operand types is
    printed, followed by the eager result and the compiled result.
    """
    def truediv(x, y):
        return x / y

    def kind_of(value):
        # Tensors are labelled by dtype, plain Python numbers by their type.
        if isinstance(value, torch.Tensor):
            return value.dtype
        return type(value)

    opt_div = torch.compile(truediv, backend="hidet")

    numerators = [
        torch.asarray([2, 3, 4], dtype=torch.int),
        torch.asarray([2.1, 2.2, 2.3], dtype=torch.float),
        int(2),
        float(3.1),
    ]
    denominators = [
        torch.asarray([1, 1, 1], dtype=torch.int),
        torch.asarray([3.1, 3.2, 3.3], dtype=torch.float),
        int(4),
        float(5.2),
    ]

    banner = "*" * 20
    for x in numerators:
        for y in denominators:
            print(banner, kind_of(x), " div ", kind_of(y), banner)
            # Compute both results before printing either one.
            eager_result = truediv(x, y)
            hidet_result = opt_div(x, y)
            print("eager: ", eager_result)
            print("hidet: ", hidet_result)

In [62]:
def test_correct():
    def func(a):
        x = torch.empty(0)
        return torch.cat(x, a)

    hidet.torch.dynamo_config.correctness_report()
    opt_func = torch.compile(func, backend="hidet")
    x = torch.asarray([2, 3, 4])
    y = torch.asarray([4, 5, 6])
    z = opt_func(x)
    print(z)


In [63]:
def test_torch_empty():
    x = torch.empty(0)
    # x
    # y = torch.Tensor([])
    # print(x)

    x_hidet = hidet.asarray([], dtype=hidet.int32)
    y_hidet = hidet.asarray([2,3,4], dtype=hidet.int32)
    print(ops.concat([x_hidet, y_hidet], 0))

In [65]:
def test_getitem():
    """Print eager and hidet-compiled results for two indexing patterns.

    A random [3, 3, 3] tensor is printed, then for each pattern a header, the
    eager result and the result of the function compiled with the hidet
    backend. The patterns are a list of ints (``x[[2, 0, 1]]``) and a slice
    combined with an int (``x[:, 1]``).
    """
    x = torch.randn([3, 3, 3])
    print("x: ", x)

    def func1(x):
        idx = [2, 0, 1]
        return x[idx]

    def func2(x):
        return x[:, 1]

    def run_case(header, fn):
        # Header first, then compile, then eager output before compiled output.
        print(header)
        compiled = torch.compile(fn, backend="hidet")
        print(fn(x))
        print(compiled(x))

    run_case("================== x[list[int]] ==================", func1)
    run_case("================== x[list[int|slice]] ==================", func2)

test_getitem()

x:  tensor([[[-0.4590, -0.2242, -0.3001],
         [ 0.1112, -0.0772, -0.6744],
         [ 1.3309,  1.9713,  0.9461]],

        [[-0.2830, -1.7146,  1.0066],
         [ 0.0935,  0.6668, -1.2713],
         [ 0.9443, -2.2736, -0.2316]],

        [[-0.7779, -1.8326,  1.5674],
         [-0.6061, -0.1652,  0.1898],
         [-0.0588,  0.2220, -0.3239]]])
tensor([[[-0.7779, -1.8326,  1.5674],
         [-0.6061, -0.1652,  0.1898],
         [-0.0588,  0.2220, -0.3239]],

        [[-0.4590, -0.2242, -0.3001],
         [ 0.1112, -0.0772, -0.6744],
         [ 1.3309,  1.9713,  0.9461]],

        [[-0.2830, -1.7146,  1.0066],
         [ 0.0935,  0.6668, -1.2713],
         [ 0.9443, -2.2736, -0.2316]]])
tensor([[[-0.7779, -1.8326,  1.5674],
         [-0.6061, -0.1652,  0.1898],
         [-0.0588,  0.2220, -0.3239]],

        [[-0.4590, -0.2242, -0.3001],
         [ 0.1112, -0.0772, -0.6744],
         [ 1.3309,  1.9713,  0.9461]],

        [[-0.2830, -1.7146,  1.0066],
         [ 0.0935,  0.6668, -1

BackendCompilerFailed: backend='hidet' raised:
RuntimeError: name 'idx' is not defined, occurred when interpreting operator.getitem with
  getitem(tensor(...), (slice(None, None, None), 1))
getitem is defined at
  File "/root/miniconda3/envs/pytorch2_2_1/lib/python3.8/site-packages/hidet/graph/frontend/torch/register_functions.py", line 232

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [40]:
def test_slice():
    """Print ``x[:, 0]``-style indexing from torch and an equivalent from hidet.

    The torch side indexes with the list ``[slice(None, None), 0]``. The hidet
    side calls ``ops.strided_slice`` with starts ``[None, 0]`` and ends
    ``[None, 1]``, then reshapes the result to ``[3, 3]``.
    """
    x = torch.randn([3, 3, 3])
    x_hidet = hidet.from_torch(x)
    print("x: ", x)

    torch_index = [slice(None, None), 0]
    print("torch: ", x[torch_index])

    starts, ends = [None, 0], [None, 1]
    kept_axis = ops.strided_slice(x_hidet, starts, ends)
    y_hidet = ops.reshape(kept_axis, [3, 3])
    print("hidet: ", y_hidet)

test_slice()

Compiling cpu task reshape(x=float32(3, 1, 3), y=float32(3, 3), shape=[3, 3])...


x:  tensor([[[ 0.6180,  0.2747,  0.0779],
         [-0.9967, -1.8447,  1.3099],
         [ 0.3335,  0.1612,  0.5073]],

        [[-0.5005, -1.6400, -1.3435],
         [-0.6357,  1.2300, -0.6855],
         [ 0.7223,  2.3741, -0.6154]],

        [[-0.3486,  0.1112, -0.2765],
         [-2.8281, -0.1209, -0.0776],
         [-0.1415, -0.4260, -0.5322]]])
torch:  tensor([[ 0.6180,  0.2747,  0.0779],
        [-0.5005, -1.6400, -1.3435],
        [-0.3486,  0.1112, -0.2765]])
hidet:  Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[ 0.6179514   0.27466476  0.07788834]
 [-0.500519   -1.6399834  -1.3435045 ]
 [-0.34859186  0.11118111 -0.27647382]]


In [42]:
def test_getitem():
    x = torch.randn([3,3,3])
    x_hidet = hidet.from_torch(x)
    print("x: ", x)
    print(x[:, 1, 2])
    # print(x[:, 0])



test_getitem()

x:  tensor([[[ 0.5130, -0.8551, -1.0871],
         [-1.0447,  1.8052,  0.1656],
         [-1.2294,  0.5877,  0.5736]],

        [[-0.6809,  0.0297, -0.8670],
         [-0.1760, -0.2163, -0.3050],
         [-0.5036,  1.6644, -0.2396]],

        [[-0.1223, -0.9100, -0.1767],
         [-1.3702,  0.1823, -0.8587],
         [-0.0469,  0.2566,  0.9508]]])


IndexError: too many indices for tensor of dimension 3