In [3]:
import torch

In [20]:
def _combine(xs_b, ys_b):
        """Interleaves the x's and the y's into a single sequence."""
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            axis=2,
        )
        zs = torch.stack((xs_b, ys_b_wide), dim=2)
        zs = zs.view(bsize, 2 * points, dim)
        return zs

In [21]:
x = torch.tensor([[[1,2,3], [10,11,12]]])
y = torch.tensor([[7,8]])
print(_combine(x,y))

tensor([[[ 1.,  2.,  3.],
         [ 7.,  0.,  0.],
         [10., 11., 12.],
         [ 8.,  0.,  0.]]])


In [6]:
b_size = 2
n_dims = 10
n_points = 20
scale = 1
xs_b = torch.randn(b_size, n_points, n_dims)
w_b = torch.randn(b_size, n_dims, 1)
ys_b = scale * (xs_b @ w_b)[:, :, 0]
# print(xs_b)
# print(w_b)
ys_b

tensor([[ 4.2928,  0.7752, -2.0067, -2.7497, -2.2065,  1.6115, -2.3812, -1.0388,
          2.3826,  0.9312, -1.0446,  0.2875, -2.2057, -0.9914, -2.4912,  2.3056,
          1.5789,  0.7910,  1.9145, -0.0324],
        [ 0.2469,  1.3184, -2.2832,  2.8630, -0.3621, -0.6629,  0.7160,  2.6438,
          1.0135,  3.1089,  0.7629,  0.4742,  2.7834,  0.2663, -0.3894,  1.6009,
          2.6995,  2.0222, -0.8605, -3.5497]])

In [9]:
z_b = _combine(xs_b, ys_b)
z_b.shape

torch.Size([2, 40, 10])

In [None]:
xs_b

In [13]:
z_b[0][1]

tensor([4.2928, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000])

In [24]:
def _combine(xs_b, ys_b, cat_b):  # Added argument for category
    """Interleaves the x's, y's and category into a single sequence."""
    bsize, points, dim = xs_b.shape
    ys_b_wide = torch.cat(
        (
            ys_b.view(bsize, points, 1),
            torch.zeros(bsize, points, dim , device=ys_b.device),
        ),
        axis=2,
    )
    # Add 'cat' to the 'xs_b' tensor as the first dimension
    xs_b_cat = torch.cat([cat_b, xs_b], dim=-1)  # cat_b and xs_b are now of shape (batch_size, num_points, num_features+1)
    # ys_b stays zeros at the end because it does not have a 'cat'
    zs = torch.stack((xs_b_cat, ys_b_wide), dim=2)
    zs = zs.view(bsize, 2 * points, dim + 1)  # dim increased by 1
    return zs

In [9]:
b_size = 2
n_dims = 10
n_points = 20
scale = 1
xs_b = torch.randn(b_size, n_points, n_dims)
w_b = torch.randn(b_size, n_dims, 1)
ys_b = scale * (xs_b @ w_b)[:, :, 0]
# print(xs_b)
# print(w_b)
ys_b

tensor([[ 2.3142,  1.8115,  1.7231,  6.1020, -0.2173,  0.7389,  3.2170, -3.8504,
         -3.7699,  0.5092,  4.1580, -0.0885, -1.7097, -1.1708,  2.5282,  0.4906,
          7.7920,  4.2260, -1.2555,  2.5978],
        [-3.5541, -1.7211,  2.9887, -2.6853,  0.1160,  0.0864,  0.5158,  0.4945,
          3.3078,  1.8155, -0.8824, -2.5839,  2.9885, -1.8227, -0.7254,  1.2981,
          2.2626, -0.1825,  1.8890, -1.4638]])

In [25]:
cat_b = torch.full(xs_b.shape[:-1] + (1,), float(3))
z_b = _combine(xs_b, ys_b, cat_b)
z_b.shape

torch.Size([2, 40, 11])

In [26]:
z_b[0][0]

tensor([ 3.0000e+00, -1.5289e+00, -2.2036e-01, -2.0727e-04,  1.1929e+00,
         9.5127e-01,  1.0182e+00,  1.0993e+00,  1.1125e+00,  1.6158e+00,
        -2.1588e-01])

In [27]:
z_b[0][1]

tensor([2.3142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000])