In [29]:
# pointtransformer_pure_torch.py
import torch
import torch.nn as nn
import torch.nn.functional as F

In [30]:
# -----------------------------
# Helpers for offset handling
# -----------------------------
def _offset_to_ranges(offset: torch.Tensor):
    """
    offset: (B,) cumulative counts, e.g., [n1, n1+n2, n1+n2+n3, ...]
    yields (start, end) per batch.
    """
    ranges = []
    start = 0
    for i in range(offset.shape[0]):
        end = int(offset[i].item())
        ranges.append((start, end))
        start = end
    return ranges

# -----------------------------
# Pure PyTorch FPS / kNN / Interp (batched by offset)
# -----------------------------
def fps(points: torch.Tensor, npoint: int):
    """
    Farthest Point Sampling on a single point set (no offsets).

    points: (N, 3)
    npoint: number of indices to produce
    returns: (npoint,) long tensor of indices into points

    The first index is drawn uniformly at random with torch.randint, so the
    result depends on the global torch RNG state. Every following index is
    the point whose squared distance to the already selected set is largest.
    No check is made that npoint <= N.
    """
    num_points = points.shape[0]
    dev = points.device
    selected = torch.zeros(npoint, dtype=torch.long, device=dev)
    # Running squared distance from each point to its closest selected point.
    closest_sq = torch.full((num_points,), 1e10, device=dev)
    current = torch.randint(0, num_points, (1,), device=dev, dtype=torch.long)
    for slot in range(npoint):
        selected[slot] = current
        delta = points - points[current]
        sq_dist = (delta ** 2).sum(dim=-1)
        closest_sq = torch.minimum(closest_sq, sq_dist)
        # Next pick: the point that is currently worst covered.
        current = torch.argmax(closest_sq)
    return selected

def furthest_sampling_offset(xyz: torch.Tensor, offset: torch.Tensor, new_offset: torch.Tensor):
    """
    Run farthest point sampling independently inside every batch.

    xyz: (N, 3) points of all batches stacked along dim 0
    offset: (B,) cumulative point counts of xyz
    new_offset: (B,) cumulative number of samples requested per batch
    return: (M,) long tensor of global row indices into xyz

    Batches that request zero (or fewer) samples contribute nothing. A batch
    that requests at least as many samples as it has points contributes all
    of its points in order, without calling fps.
    """
    src_spans = _offset_to_ranges(offset)
    dst_spans = _offset_to_ranges(new_offset)
    assert len(src_spans) == len(dst_spans)

    picked = []
    for (lo, hi), (new_lo, new_hi) in zip(src_spans, dst_spans):
        n_avail = hi - lo
        n_want = new_hi - new_lo
        if n_want <= 0:
            continue
        if n_want < n_avail:
            local = fps(xyz[lo:hi, :], n_want)  # (n_want,)
        else:
            local = torch.arange(n_avail, device=xyz.device, dtype=torch.long)
        # Shift batch-local indices to rows of the stacked xyz tensor.
        picked.append(local + lo)

    if not picked:
        return torch.empty((0,), dtype=torch.long, device=xyz.device)
    return torch.cat(picked, dim=0)

def queryandgroup_offset(nsample: int,
                         xyz: torch.Tensor,   # (N,3)
                         new_xyz: torch.Tensor,  # (M,3)
                         feat: torch.Tensor,  # (N,C)
                         offset: torch.Tensor,  # (B,)
                         new_offset: torch.Tensor,  # (B,)
                         use_xyz: bool = True):
    """
    Batched kNN grouping using offsets.

    For every query point in new_xyz, the nsample nearest points of xyz that
    belong to the same batch are looked up and their features gathered.

    Args:
        nsample: number of neighbours per query point.
        xyz: (N,3) source coordinates, batches stacked along dim 0.
        new_xyz: (M,3) query coordinates, batches stacked along dim 0.
        feat: (N,C) source features aligned with xyz.
        offset: (B,) cumulative point counts of xyz.
        new_offset: (B,) cumulative point counts of new_xyz.
        use_xyz: when True, the neighbour coordinates relative to the query
            point are prepended to the gathered features.

    Returns:
        (M, nsample, 3+C) if use_xyz else (M, nsample, C)

    Notes:
        * Batches without query points are skipped.
        * When a batch holds fewer than nsample source points, the missing
          neighbour slots are filled with zeros (relative xyz and features
          alike); callers see them as ordinary neighbours.
        * The empty result (no query point in any batch) has the same dtype
          and device as a non-empty result would have.
    """
    ranges = _offset_to_ranges(offset)
    new_ranges = _offset_to_ranges(new_offset)
    assert len(ranges) == len(new_ranges)

    grouped_list = []
    for (s, e), (ns, ne) in zip(ranges, new_ranges):
        if ne - ns == 0:
            continue
        xyz_b = xyz[s:e, :]            # (Nb,3)
        feat_b = feat[s:e, :]          # (Nb,C)
        new_xyz_b = new_xyz[ns:ne, :]  # (Mb,3)

        # (Mb, Nb) pairwise distances; keep the k smallest per query row.
        dist = torch.cdist(new_xyz_b, xyz_b)
        k = min(nsample, xyz_b.shape[0])
        idx = dist.topk(k, largest=False)[1]  # (Mb, k)

        grouped = feat_b[idx]  # (Mb, k, C)
        if use_xyz:
            rel = xyz_b[idx] - new_xyz_b.unsqueeze(1)   # (Mb, k, 3)
            grouped = torch.cat([rel, grouped], dim=-1)  # (Mb, k, 3+C)

        # Zero-pad the neighbour axis when the batch has fewer than nsample points.
        missing = nsample - grouped.shape[1]
        if missing > 0:
            pad_shape = list(grouped.shape)
            pad_shape[1] = missing
            grouped = torch.cat([grouped, grouped.new_zeros(pad_shape)], dim=1)
        grouped_list.append(grouped)

    if len(grouped_list) == 0:
        # Mirror what the loop above would have produced: torch.cat promotes
        # (xyz, feat) dtypes when use_xyz is set, otherwise the features pass
        # through untouched; gathered features live on feat's device.
        if use_xyz:
            out_channels = 3 + feat.shape[1]
            out_dtype = torch.promote_types(xyz.dtype, feat.dtype)
        else:
            out_channels = feat.shape[1]
            out_dtype = feat.dtype
        return torch.zeros((0, nsample, out_channels), device=feat.device, dtype=out_dtype)
    return torch.cat(grouped_list, dim=0)

def interpolation_offset(src_xyz: torch.Tensor,  # (N2,3)
                         tgt_xyz: torch.Tensor,  # (N1,3)
                         src_feat: torch.Tensor, # (N2,C)
                         src_offset: torch.Tensor,  # (B,)
                         tgt_offset: torch.Tensor,  # (B,)
                         k: int = 3):
    """
    Carry features from a source point set onto a target point set, batch by
    batch, by inverse-distance weighting of the k nearest source points.

    Returns:
        (N1, C)

    Batches without target points are skipped. A batch that has target points
    but no source points yields zero features for those targets. When a batch
    has fewer than k source points, all of them are used.
    """
    src_spans = _offset_to_ranges(src_offset)
    tgt_spans = _offset_to_ranges(tgt_offset)
    assert len(src_spans) == len(tgt_spans)

    pieces = []
    for (src_lo, src_hi), (tgt_lo, tgt_hi) in zip(src_spans, tgt_spans):
        if tgt_hi - tgt_lo == 0:
            continue
        anchors = src_xyz[src_lo:src_hi, :]       # (N2b,3)
        anchor_feat = src_feat[src_lo:src_hi, :]  # (N2b,C)
        queries = tgt_xyz[tgt_lo:tgt_hi, :]       # (N1b,3)

        if anchors.shape[0] == 0:
            # Nothing to interpolate from: emit zeros for this batch.
            pieces.append(torch.zeros((queries.shape[0], anchor_feat.shape[1]),
                                      device=src_feat.device, dtype=src_feat.dtype))
            continue

        pairwise = torch.cdist(queries, anchors)  # (N1b, N2b)
        n_nbrs = min(k, anchors.shape[0])
        nbr_dist, nbr_idx = pairwise.topk(n_nbrs, largest=False)  # both (N1b, n_nbrs)
        # The epsilon keeps the reciprocal finite for coincident points.
        inv = 1.0 / (nbr_dist + 1e-8)
        weights = inv / inv.sum(dim=1, keepdim=True)  # rows sum to 1
        blended = (anchor_feat[nbr_idx] * weights.unsqueeze(-1)).sum(dim=1)  # (N1b,C)
        pieces.append(blended)

    if not pieces:
        return torch.zeros((0, src_feat.shape[1]), device=src_feat.device, dtype=src_feat.dtype)
    return torch.cat(pieces, dim=0)


In [31]:
# -----------------------------
class PointTransformerLayer(nn.Module):
    """
    Vector self-attention over the nsample nearest neighbours of every point.

    Args:
        in_planes: channel count of the input features.
        out_planes: channel count of the output features (mid_planes is set
            to the same value).
        share_planes: number of channel groups that share one attention
            weight vector of size out_planes // share_planes.
        nsample: number of neighbours attended to per point.

    forward takes [p, x, o] with p: (N,3) coordinates, x: (N,in_planes)
    features, o: (B,) cumulative batch offsets, and returns (N,out_planes).
    """

    def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
        super().__init__()
        self.mid_planes = mid_planes = out_planes // 1
        self.out_planes = out_planes
        self.share_planes = share_planes
        self.nsample = nsample

        self.linear_q = nn.Linear(in_planes, mid_planes)
        self.linear_k = nn.Linear(in_planes, mid_planes)
        self.linear_v = nn.Linear(in_planes, out_planes)

        # Position-encoding MLP applied to relative neighbour coordinates.
        self.linear_p = nn.Sequential(
            nn.Linear(3, 3),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_planes),
        )
        # Maps the attention logits down to one weight per shared channel slot.
        self.linear_w = nn.Sequential(
            nn.BatchNorm1d(mid_planes),
            nn.ReLU(inplace=True),
            nn.Linear(mid_planes, mid_planes // share_planes),
            nn.BatchNorm1d(mid_planes // share_planes),
            nn.ReLU(inplace=True),
            nn.Linear(out_planes // share_planes, out_planes // share_planes),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, pxo) -> torch.Tensor:
        p, x, o = pxo  # p: (N,3), x: (N,C_in), o: (B,)
        x_q, x_k, x_v = self.linear_q(x), self.linear_k(x), self.linear_v(x)  # (N, *)

        # Keys and values share the same neighbourhoods, so group them with a
        # single neighbour search instead of building the distance matrix twice.
        grouped = queryandgroup_offset(
            self.nsample, p, p, torch.cat([x_k, x_v], dim=1), o, o, use_xyz=True
        )  # (N, ns, 3 + C_mid + C_out)
        p_r, x_kn, x_v_group = torch.split(
            grouped, [3, self.mid_planes, self.out_planes], dim=2
        )  # (N,ns,3), (N,ns,C_mid), (N,ns,C_out)
        n, ns = grouped.shape[0], grouped.shape[1]

        # BatchNorm1d wants (rows, C): fold the neighbour axis into the rows,
        # run the MLP, then unfold again.
        p_r_enc = self.linear_p(p_r.reshape(n * ns, 3)).view(n, ns, self.out_planes)  # (N,ns,C_out)

        # Attention logits: key - query + position encoding.
        w = x_kn - x_q.unsqueeze(1) + p_r_enc.view(
            n, ns, self.out_planes // self.mid_planes, self.mid_planes
        ).sum(2)  # (N,ns,C_mid)
        w = self.linear_w(w.reshape(n * ns, w.shape[-1])).view(n, ns, -1)  # (N,ns,C_mid/share)
        w = self.softmax(w)  # softmax along neighbor dim (dim=1)

        s = self.share_planes
        c = x_v_group.shape[2]
        # Each of the s channel groups is weighted by the same (C_out/s)-wide
        # attention vector, then neighbours are summed out.
        x_out = ((x_v_group + p_r_enc).view(n, ns, s, c // s) * w.unsqueeze(2)).sum(1).view(n, c)
        return x_out

In [121]:
class TransitionDown(nn.Module):
    """
    Encoder transition: changes the feature width and, when stride != 1,
    also reduces the number of points.

    Args:
        in_planes: channel count of the incoming features.
        out_planes: channel count of the produced features.
        stride: per-batch point reduction factor; 1 keeps every point.
        nsample: neighbours pooled per kept point (only used when stride != 1).

    forward takes and returns a [p, x, o] list with p: (N,3) coordinates,
    x: (N,C) features and o: (B,) cumulative batch offsets.
    """

    def __init__(self, in_planes, out_planes, stride=1, nsample=16):
        super().__init__()
        self.stride, self.nsample = stride, nsample
        if stride != 1:
            # The grouped input carries 3 relative-xyz channels in front of the features.
            self.linear = nn.Linear(3 + in_planes, out_planes, bias=False)
            self.pool = nn.MaxPool1d(nsample)
        else:
            self.linear = nn.Linear(in_planes, out_planes, bias=False)
        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def _downsampled_offset(self, o):
        """Cumulative offsets after keeping floor(count / stride) points in every batch."""
        # Read the offsets once instead of issuing several .item() calls per batch.
        ends = [int(v) for v in o.reshape(-1).tolist()]
        kept_total, prev_end, new_ends = 0, 0, []
        for end in ends:
            kept_total += (end - prev_end) // self.stride
            new_ends.append(kept_total)
            prev_end = end
        return torch.tensor(new_ends, device=o.device, dtype=o.dtype)

    def forward(self, pxo):
        p, x, o = pxo  # (N,3), (N,C), (B,)
        if self.stride != 1:
            n_o = self._downsampled_offset(o)

            idx = furthest_sampling_offset(p, o, n_o)  # (M,)
            n_p = p[idx, :]  # (M,3)
            # (M, ns, 3+C_in)
            grouped = queryandgroup_offset(self.nsample, p, n_p, x, o, n_o, use_xyz=True)
            # linear over last dim -> (M, ns, out_planes), then channels-first
            # (M, out_planes, ns) so BatchNorm1d and MaxPool1d see channels on dim 1.
            xg = self.linear(grouped).transpose(1, 2).contiguous()
            xg = self.relu(self.bn(xg))
            xg = self.pool(xg).squeeze(-1)  # (M, out_planes)
            p, x, o = n_p, xg, n_o
        else:
            x = self.relu(self.bn(self.linear(x)))  # (N,out_planes)
        return [p, x, o]


In [122]:
class TransitionUp(nn.Module):
    """
    Decoder transition with two modes, chosen at construction time.

    out_planes is None (head mode): every point's features are concatenated
    with a transformed per-batch mean feature and projected back to
    in_planes channels.

    out_planes given (fusion mode): features of a coarser level
    (in_planes channels) are projected to out_planes, interpolated onto the
    finer level's points and added to the projected finer features.
    """

    def __init__(self, in_planes, out_planes=None):
        super().__init__()
        if out_planes is not None:
            # Fusion mode: linear1 acts on the fine level, linear2 on the coarse level.
            self.linear1 = nn.Sequential(
                nn.Linear(out_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True)
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True)
            )
        else:
            # Head mode: linear2 transforms the batch mean, linear1 fuses [point, mean].
            self.linear1 = nn.Sequential(
                nn.Linear(2 * in_planes, in_planes),
                nn.BatchNorm1d(in_planes),
                nn.ReLU(inplace=True)
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, in_planes),
                nn.ReLU(inplace=True)
            )

    def forward(self, pxo1, pxo2=None):
        if pxo2 is not None:
            p1, x1, o1 = pxo1
            p2, x2, o2 = pxo2
            # Project the coarse features, then interpolate them from p2 onto p1.
            upsampled = interpolation_offset(p2, p1, self.linear2(x2), o2, o1, k=3)
            return self.linear1(x1) + upsampled

        _, feats, offs = pxo1  # (N,3), (N,C), (B,)
        fused = []
        start = 0
        for b in range(offs.shape[0]):
            end = int(offs[b].item())
            count = end - start
            chunk = feats[start:end, :]
            # Per-batch mean feature, transformed and copied onto every point of the batch.
            summary = self.linear2(chunk.sum(0, keepdim=True) / count)
            fused.append(torch.cat((chunk, summary.repeat(count, 1)), 1))
            start = end
        return self.linear1(torch.cat(fused, 0))


In [123]:
class PointTransformerBlock(nn.Module):
    """
    Residual block: linear -> point-transformer attention -> linear, each
    followed by batch norm, with the block input added back before the
    final ReLU.

    The residual addition requires the input feature width to equal
    planes * expansion.
    """
    expansion = 1

    def __init__(self, in_planes, planes, share_planes=8, nsample=16):
        super().__init__()
        widened = planes * self.expansion
        self.linear1 = nn.Linear(in_planes, planes, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.transformer2 = PointTransformerLayer(planes, planes, share_planes, nsample)
        self.bn2 = nn.BatchNorm1d(planes)
        self.linear3 = nn.Linear(planes, widened, bias=False)
        self.bn3 = nn.BatchNorm1d(widened)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, pxo):
        p, x, o = pxo
        shortcut = x

        hidden = self.linear1(x)
        hidden = self.relu(self.bn1(hidden))

        hidden = self.transformer2([p, hidden, o])
        hidden = self.relu(self.bn2(hidden))

        hidden = self.bn3(self.linear3(hidden))
        # Residual connection, then the final activation.
        out = self.relu(hidden + shortcut)
        return [p, out, o]


In [141]:
class PointTransformerSeg(nn.Module):
    """
    Five-level encoder/decoder network producing one logit vector per point.

    Args:
        block: block class used after every transition; must expose an
            `expansion` attribute and accept
            (in_planes, planes, share_planes, nsample=...).
        blocks: five integers, one per encoder level. Level i is built from
            one TransitionDown followed by blocks[i] - 1 instances of `block`.
        c: feature width fed to the first encoder level. forward concatenates
            xyz in front of the given features unless c == 3, in which case
            xyz alone is used as the feature.
        k: width of the final linear layer, i.e. logits per point.
        verbose: when True, forward prints the tensor shapes after every
            encoder and decoder level. Off by default so that a forward pass
            writes nothing to stdout.

    forward takes [p, x, o] with p: (N,3), x: (N,C_in), o: (B,) cumulative
    batch offsets and returns (N, k).
    """

    def __init__(self, block, blocks, c=6, k=13, verbose=False):
        super().__init__()

        self.c = c
        self.verbose = verbose
        # self.in_planes is a running "current width" that _make_enc/_make_dec
        # read and update while the levels are being built in order.
        self.in_planes, planes = c, [32, 64, 128, 256, 512]
        share_planes = 8
        stride, nsample = [1, 4, 4, 4, 4], [8, 16, 16, 16, 16]

        self.enc1 = self._make_enc(block, planes[0], blocks[0], share_planes, stride=stride[0], nsample=nsample[0])  # N/1
        self.enc2 = self._make_enc(block, planes[1], blocks[1], share_planes, stride=stride[1], nsample=nsample[1])  # N/4
        self.enc3 = self._make_enc(block, planes[2], blocks[2], share_planes, stride=stride[2], nsample=nsample[2])  # N/16
        self.enc4 = self._make_enc(block, planes[3], blocks[3], share_planes, stride=stride[3], nsample=nsample[3])  # N/64
        self.enc5 = self._make_enc(block, planes[4], blocks[4], share_planes, stride=stride[4], nsample=nsample[4])  # N/256

        self.dec5 = self._make_dec(block, planes[4], 2, share_planes, nsample=nsample[4], is_head=True)  # transform p5
        self.dec4 = self._make_dec(block, planes[3], 2, share_planes, nsample=nsample[3])  # fusion p5 and p4
        self.dec3 = self._make_dec(block, planes[2], 2, share_planes, nsample=nsample[2])  # fusion p4 and p3
        self.dec2 = self._make_dec(block, planes[1], 2, share_planes, nsample=nsample[1])  # fusion p3 and p2
        self.dec1 = self._make_dec(block, planes[0], 2, share_planes, nsample=nsample[0])  # fusion p2 and p1

        self.cls = nn.Sequential(
            nn.Linear(planes[0], planes[0]),
            nn.BatchNorm1d(planes[0]),
            nn.ReLU(inplace=True),
            nn.Linear(planes[0], k)
        )

    def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16):
        """Build one encoder level: a TransitionDown followed by blocks - 1 `block`s."""
        layers = [TransitionDown(self.in_planes, planes * block.expansion, stride, nsample)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, self.in_planes, share_planes, nsample=nsample))
        return nn.Sequential(*layers)

    def _make_dec(self, block, planes, blocks, share_planes=8, nsample=16, is_head=False):
        """Build one decoder level: a TransitionUp followed by blocks - 1 `block`s."""
        layers = [TransitionUp(self.in_planes, None if is_head else planes * block.expansion)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, self.in_planes, share_planes, nsample=nsample))
        return nn.Sequential(*layers)

    def _trace(self, *fields):
        """Print one shape-trace line, but only when verbose is enabled."""
        if self.verbose:
            print(*fields)

    @staticmethod
    def _decode(stage, fine, coarse=None):
        """
        Run one decoder level and return its output features.

        stage[0] is the TransitionUp; it takes two [p, x, o] lists and so
        cannot be driven through nn.Sequential directly. The remaining
        modules of the stage are then applied to the fused features on the
        fine level's points.
        """
        p, _, o = fine
        up, rest = stage[0], stage[1:]
        fused = up(fine) if coarse is None else up(fine, coarse)
        return rest([p, fused, o])[1]

    def forward(self, pxo):
        p0, x0, o0 = pxo  # p0:(N,3), x0:(N,C_in), o0:(B,)
        x0 = p0 if self.c == 3 else torch.cat((p0, x0), 1)
        self._trace("p0=", p0.shape, "x0=", x0.shape, "o0=", o0)

        p1, x1, o1 = self.enc1([p0, x0, o0])
        self._trace("p1=", p1.shape, "x1=", x1.shape, "o1=", o1)
        p2, x2, o2 = self.enc2([p1, x1, o1])
        self._trace("p2=", p2.shape, "x2=", x2.shape, "o2=", o2)
        p3, x3, o3 = self.enc3([p2, x2, o2])
        self._trace("p3=", p3.shape, "x3=", x3.shape, "o3=", o3)
        p4, x4, o4 = self.enc4([p3, x3, o3])
        self._trace("p4=", p4.shape, "x4=", x4.shape, "o4=", o4)
        p5, x5, o5 = self.enc5([p4, x4, o4])
        self._trace("p5=", p5.shape, "x5=", x5.shape, "o5=", o5)

        self._trace("****************** Decoder *****************")

        # Each level fuses its own encoder features with the decoder output
        # of the next coarser level; the deepest level has nothing coarser.
        x5 = self._decode(self.dec5, [p5, x5, o5])
        self._trace("x5=", x5.shape)
        x4 = self._decode(self.dec4, [p4, x4, o4], [p5, x5, o5])
        self._trace("x4=", x4.shape)
        x3 = self._decode(self.dec3, [p3, x3, o3], [p4, x4, o4])
        self._trace("x3=", x3.shape)
        x2 = self._decode(self.dec2, [p2, x2, o2], [p3, x3, o3])
        self._trace("x2=", x2.shape)
        x1 = self._decode(self.dec1, [p1, x1, o1], [p2, x2, o2])
        self._trace("x1=", x1.shape)

        x = self.cls(x1)
        self._trace("x=", x.shape)
        return x


In [142]:
def pointtransformer_seg_repro(**kwargs):
    """
    Build a PointTransformerSeg made of PointTransformerBlock units with
    encoder depths [2, 3, 4, 6, 3].

    Keyword arguments are forwarded unchanged to PointTransformerSeg.
    """
    encoder_depths = [2, 3, 4, 6, 3]
    return PointTransformerSeg(PointTransformerBlock, encoder_depths, **kwargs)

In [143]:
# Example shapes and reproducibility seed
SEED = 0             # one seed for the dummy data, the weight init and the random FPS start
torch.manual_seed(SEED)

N1 = 1024            # total points across batches
N2 = 2048            # not used by the demo cells visible here; kept in case later cells need it
B = 2                # batches
C_in = 3             # input feature channels (excluding xyz)

In [144]:
# Dummy data: N1 random points split evenly into two batches.
# p, x and o together form the [p, x, o] list the model's forward expects.
p = torch.randn(N1, 3)                 # (N,3) point coordinates
x = torch.randn(N1, C_in)              # (N,C) per-point features, xyz not included
o = torch.tensor([N1//2, N1], dtype=torch.long)  # cumulative end index of each batch: [512, 1024]

In [145]:
# The model's forward concatenates xyz in front of the features (unless c == 3),
# so the network is built for C_in + 3 input channels; k is the number of logits per point.
model = pointtransformer_seg_repro(c=C_in+3, k=13)  # if you concat xyz to feats like original code
logits = model([p, x, o])  # (N, k)
print(logits.shape)

[TransitionDown(
  (linear): Linear(in_features=6, out_features=32, bias=False)
  (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
), PointTransformerBlock(
  (linear1): Linear(in_features=32, out_features=32, bias=False)
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (transformer2): PointTransformerLayer(
    (linear_q): Linear(in_features=32, out_features=32, bias=True)
    (linear_k): Linear(in_features=32, out_features=32, bias=True)
    (linear_v): Linear(in_features=32, out_features=32, bias=True)
    (linear_p): Sequential(
      (0): Linear(in_features=3, out_features=3, bias=True)
      (1): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Linear(in_features=3, out_features=32, bias=True)
    )
    (linear_w): Sequential(
      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running

In [147]:
#model