In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

from model.stratified_transformer import Stratified
from util import config



In [2]:
# Read the S3DIS experiment configuration for the Stratified Transformer.
args = config.load_cfg_from_cfg_file('config/s3dis/s3dis_stratified_transformer.yaml')

# Each successive layer works at twice the scale of the previous one: 1, 2, 4, ...
layer_scales = [2**layer for layer in range(args.num_layers)]

# The config stores patch/window sizes as multiples; turn them into absolute
# sizes by scaling with the base grid size, then expand them per layer.
args.patch_size = args.grid_size * args.patch_size
base_window_size = args.patch_size * args.window_size
args.window_size = [base_window_size * scale for scale in layer_scales]
args.grid_sizes = [args.patch_size * scale for scale in layer_scales]
args.quant_sizes = [args.quant_size * scale for scale in layer_scales]

# Number of input feature channels the model is built for.
FEATURE_DIM = 64

# Build the network from the derived config and place it on the GPU.
model = Stratified(
    args.downsample_scale,
    args.depths,
    args.channels,
    args.num_heads,
    args.window_size,
    args.up_k,
    args.grid_sizes,
    args.quant_sizes,
    rel_query=args.rel_query,
    rel_key=args.rel_key,
    rel_value=args.rel_value,
    drop_path_rate=args.drop_path_rate,
    concat_xyz=args.concat_xyz,
    num_classes=args.classes,
    ratio=args.ratio,
    k=args.k,
    prev_grid_size=args.grid_size,
    sigma=1.0,
    num_layers=args.num_layers,
    stem_transformer=args.stem_transformer,
    features_in_dim=FEATURE_DIM,
).cuda()

In [3]:
# Total number of scalar entries summed over every parameter tensor.
num_params = sum(param.nelement() for param in model.parameters())
print('#Model parameters: {}'.format(num_params))

# Report the device of the first parameter as a proxy for where the model lives.
first_param = next(model.parameters())
print("The model is on device: {}".format(first_param.device))

#Model parameters: 8064730
The model is on device: cuda:0


In [4]:
import torch
import torch_points_kernels as tp

# Synthetic smoke-test input: B point clouds of N points with C features each.
# A dedicated, seeded generator keeps the fake data identical across
# "Restart & Run All" without touching PyTorch's global RNG state.
N = 1024         # points per cloud
C = FEATURE_DIM  # feature channels per point
B = 8            # clouds in the batch
fake_data_rng = torch.Generator().manual_seed(0)
pos = torch.rand(B, N, 3, generator=fake_data_rng) * 2   # xyz coordinates in [0, 2)
features = torch.rand(B, N, C, generator=fake_data_rng)  # feature values in [0, 1)

def batch_for_stratified_point_transformer(points, features, grid_size=None,
                                           max_num_neighbors=None, sigma=1.0):
    """Pack a dense batch of equally sized point clouds into flat tensors.

    Args:
        points: tensor of shape (B, N, 3) holding the coordinates of B clouds
            with N points each.
        features: per-point features; the last dimension is the channel count
            and the leading dimensions must flatten to B * N rows
            (e.g. shape (B, N, C)).
        grid_size: base grid size used for the neighbour search radius. When
            None, falls back to the notebook-level ``args.grid_size``.
        max_num_neighbors: neighbour cap passed to ``tp.ball_query``. When
            None, falls back to the notebook-level ``args.max_num_neighbors``.
        sigma: multiplier applied to the radius
            (``radius = 2.5 * grid_size * sigma``).

    Returns:
        A tuple ``(points, features, offsets, batch, neighbor_idx)``:
          * points: (B * N, 3) flattened coordinates.
          * features: (B * N, C) flattened features.
          * offsets: int32 tensor of length B with the cumulative end index of
            every cloud, i.e. ``[N, 2N, ..., B * N]``.
          * batch: int64 tensor of length B * N mapping each point to the
            index of the cloud it belongs to.
          * neighbor_idx: first element of the ``tp.ball_query`` result in
            "partial_dense" mode (presumably of shape
            (B * N, max_num_neighbors) -- confirm against torch_points_kernels).

    Raises:
        ValueError: if ``points`` is not of shape (B, N, 3) or if ``features``
            does not provide exactly one row per point.
    """
    if points.dim() != 3 or points.shape[-1] != 3:
        raise ValueError(
            "points must have shape (B, N, 3), got {}".format(tuple(points.shape)))

    batch_size, seq_len = points.shape[0], points.shape[1]
    num_points = batch_size * seq_len

    points = points.reshape(num_points, 3)
    features = features.reshape(-1, features.shape[-1])
    if features.shape[0] != num_points:
        raise ValueError(
            "features must provide one row per point: expected {} rows, got {}".format(
                num_points, features.shape[0]))

    # Keep the index tensors on the same device as the coordinates they describe.
    device = points.device

    # Cumulative end index of each cloud: [N, 2N, ..., B * N].
    offsets = torch.arange(1, batch_size + 1, dtype=torch.int32, device=device) * seq_len

    # Cloud index of every point: [0]*N + [1]*N + ... Built on-tensor, which
    # also handles an empty batch (torch.cat on an empty list would raise).
    batch = torch.arange(batch_size, device=device).repeat_interleave(seq_len)

    # Fall back to the notebook configuration only when no explicit value is given.
    if grid_size is None:
        grid_size = args.grid_size
    if max_num_neighbors is None:
        max_num_neighbors = args.max_num_neighbors

    # Neighbour lookup among the points of the same cloud within the given radius.
    radius = 2.5 * grid_size * sigma
    neighbor_idx = tp.ball_query(radius, max_num_neighbors, points, points,
                                 mode="partial_dense", batch_x=batch, batch_y=batch)[0]

    return points, features, offsets, batch, neighbor_idx

# Pack the fake clouds; note that `features` is rebound to its flattened form.
packed_batch = batch_for_stratified_point_transformer(pos, features)
points, features, offsets, batch, neighbour_idx = packed_batch

In [5]:
# Inspect shape, device and dtype of every packed tensor before the transfer.
packed = (points, features, offsets, batch, neighbour_idx)
print(*(t.shape for t in packed))
print(*(t.device for t in packed))
print(*(t.dtype for t in packed))

# Move all tensors to the GPU and confirm where they ended up.
points, features, offsets, batch, neighbour_idx = (t.cuda() for t in packed)
packed = (points, features, offsets, batch, neighbour_idx)
print(*(t.device for t in packed))

# Show the per-cloud end offsets and the per-point cloud indices.
print(offsets)
print(batch)

torch.Size([8192, 3]) torch.Size([8192, 64]) torch.Size([8]) torch.Size([8192]) torch.Size([8192, 34])
cpu cpu cpu cpu cpu
torch.float32 torch.float32 torch.int32 torch.int64 torch.int64
cuda:0 cuda:0 cuda:0 cuda:0 cuda:0
tensor([1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192], device='cuda:0',
       dtype=torch.int32)
tensor([0, 0, 0,  ..., 7, 7, 7], device='cuda:0')


In [6]:
# Forward pass on the packed batch; features come first, then the coordinates.
model_inputs = (features, points, offsets, batch, neighbour_idx)
output = model(*model_inputs)


In [7]:
# Show the shape of the model output for the whole packed batch.
print(output.size())

torch.Size([8192, 13])
