In [1]:
import torch

In [20]:
def _combine(xs_b, ys_b):
        """Interleaves the x's and the y's into a single sequence."""
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            axis=2,
        )
        zs = torch.stack((xs_b, ys_b_wide), dim=2)
        zs = zs.view(bsize, 2 * points, dim)
        return zs

In [21]:
x = torch.tensor([[[1,2,3], [10,11,12]]])
y = torch.tensor([[7,8]])
print(_combine(x,y))

tensor([[[ 1.,  2.,  3.],
         [ 7.,  0.,  0.],
         [10., 11., 12.],
         [ 8.,  0.,  0.]]])


In [2]:
b_size = 2
n_dims = 10
n_points = 20
scale = 1
xs_b = torch.randn(b_size, n_points, n_dims)
w_b = torch.randn(b_size, n_dims, 1)
ys_b = scale * (xs_b @ w_b)[:, :, 0]
# print(xs_b)
# print(w_b)
ys_b

tensor([[-3.2249,  1.0662,  2.5475,  4.5052,  4.1052, -7.2817, -1.6772,  0.2272,
          0.1621, -2.7903, -2.2233,  1.4837,  9.6558,  4.4323, -1.7924,  2.7732,
          4.8481, -0.7044,  2.9636, -7.9988],
        [ 4.8723, -1.1972, -2.5540,  1.9891,  1.2408, -1.1358,  2.8412, -0.8591,
          0.0524, -1.5446,  1.0500, -0.0445,  0.6956,  0.4533, -2.1465, -0.6740,
          2.2280,  0.8928,  0.0525, -0.8362]])

In [3]:
ys_b.shape

torch.Size([2, 20])

In [9]:
z_b = _combine(xs_b, ys_b)
z_b.shape

torch.Size([2, 40, 10])

In [None]:
xs_b

In [13]:
z_b[0][1]

tensor([4.2928, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000])

In [24]:
def _combine(xs_b, ys_b, cat_b):  # Added argument for category
    """Interleaves the x's, y's and category into a single sequence."""
    bsize, points, dim = xs_b.shape
    ys_b_wide = torch.cat(
        (
            ys_b.view(bsize, points, 1),
            torch.zeros(bsize, points, dim , device=ys_b.device),
        ),
        axis=2,
    )
    # Add 'cat' to the 'xs_b' tensor as the first dimension
    xs_b_cat = torch.cat([cat_b, xs_b], dim=-1)  # cat_b and xs_b are now of shape (batch_size, num_points, num_features+1)
    # ys_b stays zeros at the end because it does not have a 'cat'
    zs = torch.stack((xs_b_cat, ys_b_wide), dim=2)
    zs = zs.view(bsize, 2 * points, dim + 1)  # dim increased by 1
    return zs

In [9]:
b_size = 2
n_dims = 10
n_points = 20
scale = 1
xs_b = torch.randn(b_size, n_points, n_dims)
w_b = torch.randn(b_size, n_dims, 1)
ys_b = scale * (xs_b @ w_b)[:, :, 0]
# print(xs_b)
# print(w_b)
ys_b

tensor([[ 2.3142,  1.8115,  1.7231,  6.1020, -0.2173,  0.7389,  3.2170, -3.8504,
         -3.7699,  0.5092,  4.1580, -0.0885, -1.7097, -1.1708,  2.5282,  0.4906,
          7.7920,  4.2260, -1.2555,  2.5978],
        [-3.5541, -1.7211,  2.9887, -2.6853,  0.1160,  0.0864,  0.5158,  0.4945,
          3.3078,  1.8155, -0.8824, -2.5839,  2.9885, -1.8227, -0.7254,  1.2981,
          2.2626, -0.1825,  1.8890, -1.4638]])

In [25]:
cat_b = torch.full(xs_b.shape[:-1] + (1,), float(3))
z_b = _combine(xs_b, ys_b, cat_b)
z_b.shape

torch.Size([2, 40, 11])

In [26]:
z_b[0][0]

tensor([ 3.0000e+00, -1.5289e+00, -2.2036e-01, -2.0727e-04,  1.1929e+00,
         9.5127e-01,  1.0182e+00,  1.0993e+00,  1.1125e+00,  1.6158e+00,
        -2.1588e-01])

In [27]:
z_b[0][1]

tensor([2.3142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000])

In [8]:
import torch
one_hot_tensor = torch.nn.functional.one_hot(torch.tensor(0), num_classes=3).float()
one_hot_tensor

tensor([1., 0., 0.])

In [12]:
one_hot_tensor = torch.nn.functional.one_hot(torch.tensor(1), num_classes=3).float().view(1, 1, 3)
one_hot_tensor

tensor([[[0., 1., 0.]]])

In [15]:
y = torch.randn(b_size, n_points)
y1 = torch.randn(b_size, n_points)
(y - y1).square().mean()

tensor(2.1891)

In [5]:
print(5**42 - 75**16)

-774885900318622589111328125000


In [9]:
print(9**4)

6561
