In [1]:
import math
import random
import os
import yaml
from typing import List, Tuple
from easydict import EasyDict
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
from models.pointbert.point_encoder import PointTransformer

def cfg_from_yaml_file(cfg_file):
    config = EasyDict()
    with open(cfg_file, 'r') as f:
        new_config = yaml.load(f, Loader=yaml.FullLoader)
    merge_new_config(config=config, new_config=new_config)
    return config

def merge_new_config(config, new_config):
    for key, val in new_config.items():
        if not isinstance(val, dict):
            if key == '_base_':
                with open(new_config['_base_'], 'r') as f:
                    try:
                        val = yaml.load(f, Loader=yaml.FullLoader)
                    except:
                        val = yaml.load(f)
                config[key] = EasyDict()
                merge_new_config(config[key], val)
            else:
                config[key] = val
                continue
        if key not in config:
            config[key] = EasyDict()
        merge_new_config(config[key], val)
    return config

point_bert_config_name = "PointTransformer_8192point_2layer"
point_bert_config_addr = os.path.join("models/pointbert", f"{point_bert_config_name}.yaml")
print(f"Loading PointBERT config from {point_bert_config_addr}.")
point_bert_config = cfg_from_yaml_file(point_bert_config_addr)

point_bert_config.model.point_dims = 6  # Use 6D points (XYZ + RGB)
use_max_pool = False
point_encoder = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool).to(device)
print(f"Using {point_encoder.point_dims} dim of points.")

point_encoder.load_checkpoint("models/pointbert/point_bert_v1.2.pt")

backbone_output_dim = point_bert_config.model.trans_dim
print(f"Using {backbone_output_dim} output dim of points from PointBERT.")

# freeze PointBERT parameters
for param in point_encoder.parameters():
    param.requires_grad = False
point_encoder.eval()

Loading PointBERT config from models/pointbert\PointTransformer_8192point_2layer.yaml.


2025-11-22 00:58:45,599 - Transformer - INFO - PointBERT's weights are successfully loaded from models/pointbert/point_bert_v1.2.pt


Using 6 dim of points.
Using 384 output dim of points from PointBERT.


PointTransformer(
  (group_divider): Group()
  (encoder): Encoder(
    (first_conv): Sequential(
      (0): Conv1d(6, 128, kernel_size=(1,), stride=(1,))
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
    )
    (second_conv): Sequential(
      (0): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
    )
  )
  (reduce_dim): Linear(in_features=256, out_features=384, bias=True)
  (pos_embed): Sequential(
    (0): Linear(in_features=3, out_features=128, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=128, out_features=384, bias=True)
  )
  (blocks): TransformerEncoder(
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-05, 

In [5]:
model_name = "gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
gpt2 = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 doesn't have a pad token by default; set pad = eos
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

gpt2.resize_token_embeddings(len(tokenizer))
gpt2.to(device)

print("GPT-2 vocab size:", len(tokenizer))
print("GPT-2 hidden size:", gpt2.config.n_embd)

GPT-2 vocab size: 50257
GPT-2 hidden size: 768


In [4]:
class PointBertClipCap3D(nn.Module):
    """
    CLIPCap-style model:
        point cloud -> point encoder (Point-BERT) -> MLP mapper -> GPT-2 prefix
    """
     
    def __init__(self, point_encoder: nn.Module, gpt2: GPT2LMHeadModel, prefix_len: int = 10):
        super().__init__()
        self.point_encoder = point_encoder
        self.gpt2 = gpt2
        self.prefix_len = prefix_len

        point_emb_dim = backbone_output_dim
        gpt_emb_dim = gpt2.config.n_embd

        # Map global point embedding -> (prefix_len * gpt_emb_dim)
        self.mapper = nn.Sequential(
            nn.Linear(point_emb_dim, gpt_emb_dim * prefix_len),
            nn.Tanh(),
        )

    def encode_prefix(self, pts: torch.Tensor) -> torch.Tensor:
        """
        pts: (B, N, C)
        Returns:
            prefix: (B, prefix_len, gpt_emb_dim)
        """
        B = pts.size(0)
        feats = self.point_encoder(pts)  # (B, D)
        # print(global_feat.shape)
        global_feat = feats.mean(dim=1)
        mapped = self.mapper(global_feat)      # (B, prefix_len * H)
        prefix = mapped.view(B, self.prefix_len, self.gpt2.config.n_embd)
        return prefix
    
    def forward(
        self,
        pts: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
    ):
        """
        pts: (B, N, C)
        input_ids: (B, T)
        attention_mask: (B, T), optional
        labels: (B, T), optional
        """

        B = pts.size(0)
        prefix = self.encode_prefix(pts)  # (B, prefix_len, gpt_emb_dim)

        # Token embeddings from GPT-2
        token_embeds = self.gpt2.transformer.wte(input_ids)  # (B, T, H)

        # Concatenate prefix + tokens along sequence dimension
        inputs_embeds = torch.cat([prefix, token_embeds], dim=1)  # (B, prefix_len + T, H)

        # Build attention mask, padding ones for the prefix
        if attention_mask is not None:
            prefix_mask = torch.ones(
                (B, self.prefix_len), dtype=attention_mask.dtype, device=attention_mask.device
            )
            attention_mask_full = torch.cat([prefix_mask, attention_mask], dim=1)
        else:
            attention_mask_full = None

        # Build labels, ignoring prefix positions with -100
        if labels is not None:
            prefix_labels = torch.full(
                (B, self.prefix_len), -100, dtype=labels.dtype, device=labels.device
            )
            labels_full = torch.cat([prefix_labels, labels], dim=1)
        else:
            labels_full = None

        outputs = self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask_full,
            labels=labels_full,
        )
        return outputs
    
prefix_len = 10  # you can tune this
model = PointBertClipCap3D(point_encoder, gpt2, prefix_len=prefix_len).to(device)

print("Model ready.")

Model ready.


In [5]:
from dataset.dataset import Cap3DShapeNetPreprocessed

def collate_fn(batch: List[Tuple[torch.Tensor, str]]):
    """
    Collate function to:
      - stack point clouds
      - tokenize captions
    """
    pts_list, captions = zip(*batch)

    # Stack point clouds -> (B, N, 6)
    pts_batch = torch.stack(pts_list, dim=0).float()

    # Tokenize captions
    enc = tokenizer(
        list(captions),
        padding=True,
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )

    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    # Use input_ids as labels (standard LM training)
    labels = input_ids.clone()

    return pts_batch, input_ids, attention_mask, labels, captions

dataset = Cap3DShapeNetPreprocessed(
    points_path="dataset/data/shapenet/processed_points.pt",
    ids_path="dataset/data/shapenet/point_ids.json",
    csv_path="dataset/data/shapenet/Cap3D_automated_ShapeNet.csv",
    device =device,
)

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

len(dataset), next(iter(train_loader))[0].shape

                                          id  \
0  03001627_5ceabffee1c333293002761e7a3ba3bd   
1   04090263_4ce26b6d23caecb3cc34b900bb2492e   
2  03001627_7178731312819be3ecb14096838a20c5   
3  03691459_91b781b40d32b74dc491effd0ae881ea   
4  04401088_c3b9cb70c6a80ed686a04ec9e4169973   

                                             caption  
0  A modern office chair featuring a plush green ...  
1  A modern tactical rifle featuring an ergonomic...  
2  A modern chair featuring a sleek design with a...  
3  A rectangular storage unit featuring a promine...  
4  A sleek, rectangular smartphone featuring a gl...  


(52472, torch.Size([32, 16384, 6]))

In [None]:

freeze_gpt2 = True  # set False if you want to fine-tune GPT-2

if freeze_gpt2:
    for param in model.gpt2.parameters():
        param.requires_grad = False
    print("GPT-2 frozen (only mapper will train).")
else:
    print("GPT-2 will be fine-tuned.")

GPT-2 frozen (only mapper + point encoder will train).


In [None]:
num_epochs = 3
lr = 1e-4

# Only train the parts that require grad
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    for step, (pts, input_ids, attention_mask, labels, raw_caps) in enumerate(tqdm(train_loader)):
        pts = pts.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(
            pts=pts,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

        # if (step + 1) % 20 == 0:
        #     avg_loss = total_loss / (step + 1)
        #     print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{len(train_loader)}], Loss: {avg_loss:.4f}")

    print(f"Epoch {epoch+1} finished. Avg loss: {total_loss / len(train_loader):.4f}")

torch.save(model.state_dict(), "checkpoints/test_model.pth")

  0%|          | 0/1640 [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
100%|██████████| 1640/1640 [16:25<00:00,  1.66it/s]


Epoch 1 finished. Avg loss: 3.0136


100%|██████████| 1640/1640 [16:15<00:00,  1.68it/s]


Epoch 2 finished. Avg loss: 2.7640


100%|██████████| 1640/1640 [16:13<00:00,  1.68it/s]


Epoch 3 finished. Avg loss: 2.7092


In [41]:
@torch.no_grad()
def generate_caption_from_points(
    model: PointBertClipCap3D,
    tokenizer: GPT2TokenizerFast,
    pts: torch.Tensor,
    max_new_tokens: int = 30,
    temperature: float = 1.0,
    top_k: int = 0,
) -> str:
    """
    pts: (N, C) tensor on CPU
    """
    model.eval()

    pts = pts.unsqueeze(0).to(device)  # (1, N, C)
    prefix = model.encode_prefix(pts)  # (1, prefix_len, H)

    # Start with BOS token
    bos_id = tokenizer.bos_token_id or tokenizer.eos_token_id
    generated = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        # Token embeddings for current generated sequence
        token_embeds = model.gpt2.transformer.wte(generated)  # (1, t, H)

        # Concatenate prefix + tokens
        inputs_embeds = torch.cat([prefix, token_embeds], dim=1)  # (1, prefix_len + t, H)

        outputs = model.gpt2(inputs_embeds=inputs_embeds)
        next_token_logits = outputs.logits[:, -1, :]  # (1, vocab)

        # Optionally apply temperature & top-k
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        if top_k > 0:
            values, indices = torch.topk(next_token_logits, top_k)
            probs = torch.softmax(values, dim=-1)
            next_token = indices[0, torch.multinomial(probs[0], num_samples=1)]
            next_token = next_token.unsqueeze(0).unsqueeze(0)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        generated = torch.cat([generated, next_token], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    # Drop BOS, decode
    caption_ids = generated[0, 1:]
    caption = tokenizer.decode(caption_ids, skip_special_tokens=True)
    return caption.strip()


# Test on a random sample from the dummy dataset
test_pts, test_caption = dataset[10000]
print("Ground truth caption:", test_caption)

gen_caption = generate_caption_from_points(model, tokenizer, test_pts, max_new_tokens=20)
print("Generated caption:", gen_caption)

Ground truth caption: A solid rectangular block with a rich mahogany finish, featuring a smooth surface that exhibits subtle wood grain patterns. Accompanying this are broader structures with multiple circular handles, also finished in mahogany, offering a contemporary design suitable for storage solutions. The arrangement showcases a cohesive aesthetic with a warm, inviting color palette.
Generated caption: The sleek, modern wooden cabinet features a rich mahogany finish with a rich mahogany finish
