In [3]:
!pip install torch-cluster
!pip install timm


Collecting torch-cluster
  Downloading torch_cluster-1.6.3.tar.gz (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m740.5 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-cluster
  Building wheel for torch-cluster (setup.py) ... [?25l[?25hdone
  Created wheel for torch-cluster: filename=torch_cluster-1.6.3-cp310-cp310-linux_x86_64.whl size=722828 sha256=477e0a39150923a289ac2618b79620b87a59e2039d521b939b02729450dace61
  Stored in directory: /root/.cache/pip/wheels/51/78/c3/536637b3cdcc3313aa5e8851a6c72b97f6a01877e68c7595e3
Successfully built torch-cluster
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.3
Collecting timm
  Downloading timm-1.0.3-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12=

In [4]:
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch_cluster import fps
from timm.models.layers import DropPath

class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally to a context) before calling ``fn``.

    Args:
        dim: Size of the last dimension of ``x``; passed to ``nn.LayerNorm``.
        fn: Callable/module invoked as ``fn(normed_x, **kwargs)``.
        context_dim: If given, a second LayerNorm of this size is created and
            applied to ``kwargs['context']`` before ``fn`` is called.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if context_dim is None:
            self.norm_context = None
        else:
            self.norm_context = nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        """Normalise ``x`` (and the context, when configured) and delegate to ``fn``.

        Raises:
            KeyError: if a context norm is configured but no ``context`` keyword
                argument was supplied.
        """
        normed_x = self.norm(x)

        # Without a context norm the keyword arguments pass through untouched.
        if self.norm_context is None:
            return self.fn(normed_x, **kwargs)

        forwarded = dict(kwargs)
        forwarded['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed_x, **forwarded)

class GEGLU(nn.Module):
    """Gated GELU: split the last dimension in two and gate one half with GELU of the other."""

    def forward(self, x):
        """Return ``values * gelu(gates)`` where both are halves of ``x`` along the last dim."""
        values, gates = torch.chunk(x, 2, dim=-1)
        activated = F.gelu(gates)
        return values * activated

class FeedForward(nn.Module):
    """Two-layer feed-forward block with a GEGLU non-linearity and optional DropPath.

    Args:
        dim: Input and output feature size.
        mult: Expansion factor; the gated hidden width is ``dim * mult``.
        drop_path_rate: When greater than zero the output passes through
            ``DropPath(drop_path_rate)``; otherwise through ``nn.Identity``.
    """

    def __init__(self, dim, mult=4, drop_path_rate=0.0):
        super().__init__()
        hidden = dim * mult
        # The first projection is twice as wide because GEGLU halves the last dimension.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
        )
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the MLP on ``x`` and apply the configured drop-path module."""
        projected = self.net(x)
        return self.drop_path(projected)

class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Implemented with native torch reshapes and ``torch.einsum`` so that it does
    not rely on names (``rearrange``, ``repeat``, ``einsum``) that are never
    imported in this notebook.

    Args:
        query_dim: Feature size of the query input ``x`` (also the output size).
        context_dim: Feature size of the context; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        drop_path_rate: When greater than zero the output passes through
            ``DropPath(drop_path_rate)``; otherwise through ``nn.Identity``.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    @staticmethod
    def _split_heads(t, h):
        """Reshape ``(b, n, h * d)`` to ``(b * h, n, d)`` with the batch index major."""
        b, n, hd = t.shape
        return t.reshape(b, n, h, hd // h).permute(0, 2, 1, 3).reshape(b * h, n, hd // h)

    @staticmethod
    def _merge_heads(t, h):
        """Inverse of ``_split_heads``: ``(b * h, n, d)`` back to ``(b, n, h * d)``."""
        bh, n, d = t.shape
        return t.reshape(bh // h, h, n, d).permute(0, 2, 1, 3).reshape(bh // h, n, h * d)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query tensor of shape ``(b, n, query_dim)``.
            context: Optional key/value source of shape ``(b, m, context_dim)``.
            mask: Optional mask of shape ``(b, ...)`` that flattens to ``(b, m)``;
                truthy entries mark context positions that may be attended to.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        context = context if context is not None else x
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = (self._split_heads(t, h) for t in (q, k, v))

        sim = torch.einsum('bid,bjd->bij', q, k) * self.scale

        if mask is not None:
            # Flatten trailing mask dims, then repeat each batch row once per head so
            # the layout matches the (b * h) leading dimension of ``sim``.
            mask = mask.reshape(mask.shape[0], -1).bool()
            mask = mask.repeat_interleave(h, dim=0).unsqueeze(1)
            max_neg_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)

        out = torch.einsum('bij,bjd->bid', attn, v)
        out = self._merge_heads(out, h)
        return self.drop_path(self.to_out(out))

class PointEmbed(nn.Module):
    """Embed 3-D points with fixed sinusoidal features followed by a linear layer.

    Each coordinate axis gets ``hidden_dim // 6`` frequencies ``2**k * pi``; the
    sine and cosine of the projections (``hidden_dim`` values in total) are
    concatenated with the raw coordinates and mapped to ``dim`` features.

    Args:
        hidden_dim: Size of the sinusoidal embedding; must be divisible by 6.
        dim: Output feature size of the linear layer.
    """

    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()
        assert hidden_dim % 6 == 0
        self.embedding_dim = hidden_dim

        num_freqs = self.embedding_dim // 6
        # math.pi keeps this class independent of numpy being imported elsewhere.
        freqs = torch.pow(2, torch.arange(num_freqs)).float() * math.pi
        zeros = torch.zeros(num_freqs)
        # Block-diagonal layout: row i carries the frequencies for coordinate axis i.
        basis = torch.stack([
            torch.cat([freqs, zeros, zeros]),
            torch.cat([zeros, freqs, zeros]),
            torch.cat([zeros, zeros, freqs]),
        ])
        self.register_buffer('basis', basis)  # 3 x (hidden_dim // 2)
        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        """Project ``input`` (b, n, 3) onto ``basis`` and return ``[sin, cos]`` along dim 2."""
        projections = torch.einsum('bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        """Return the learned embedding of shape ``(b, n, dim)`` for points ``(b, n, 3)``."""
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))
        return embed



In [5]:
class AutoEncoder(nn.Module):
    """Point-cloud autoencoder built from cross-attention and latent self-attention.

    ``encode`` sub-samples the input cloud with farthest point sampling and lets
    the sampled points cross-attend to the full cloud. ``decode`` refines those
    latents with ``depth`` self-attention/feed-forward layers and then lets
    embedded query points cross-attend to them to produce per-query outputs.

    Args:
        depth: Number of latent (self-attention, feed-forward) layer pairs.
        dim: Feature size of the point embeddings and latents.
        queries_dim: Feature size expected by the decoder cross-attention.
            NOTE(review): queries are embedded with ``point_embed`` (size ``dim``),
            so this presumably has to equal ``dim`` - confirm.
        output_dim: Size of the final linear output; ``None`` disables the layer.
        num_inputs: Number of points every input cloud must contain.
        num_latents: Target number of sampled points (sets the sampling ratio).
        heads: Number of heads in the latent self-attention layers.
        dim_head: Per-head feature size in the latent self-attention layers.
        weight_tie_layers: If true, every latent layer reuses the same attention
            and feed-forward modules.
        decoder_ff: If true, add a residual feed-forward block after the decoder
            cross-attention.
    """

    def __init__(self, depth=24, dim=512, queries_dim=512, output_dim=1, num_inputs=2048, num_latents=512, heads=8, dim_head=64, weight_tie_layers=False, decoder_ff=False):
        super().__init__()
        self.depth = depth
        self.num_inputs = num_inputs
        self.num_latents = num_latents

        # Encoder: single-head cross-attention followed by a feed-forward block.
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
            PreNorm(dim, FeedForward(dim))
        ])

        # Shared point embedding for inputs, sampled points and decoder queries.
        self.point_embed = PointEmbed(dim=dim)

        def make_latent_attn():
            return PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, drop_path_rate=0.1))

        def make_latent_ff():
            return PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))

        # Latent stack. With weight tying the modules are built once, on the first
        # iteration, and then reused; otherwise each layer gets fresh modules.
        # No external caching helper is required.
        self.layers = nn.ModuleList([])
        tied_attn = None
        tied_ff = None
        for _ in range(depth):
            if weight_tie_layers:
                if tied_attn is None:
                    tied_attn = make_latent_attn()
                if tied_ff is None:
                    tied_ff = make_latent_ff()
                layer_attn, layer_ff = tied_attn, tied_ff
            else:
                layer_attn = make_latent_attn()
                layer_ff = make_latent_ff()
            self.layers.append(nn.ModuleList([layer_attn, layer_ff]))

        # Decoder cross-attention and optional feed-forward block.
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, dim, heads=1, dim_head=dim), context_dim=dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        # Final projection to the requested output size.
        self.to_outputs = nn.Linear(queries_dim, output_dim) if output_dim is not None else nn.Identity()

    def encode(self, pc):
        """Encode a batch of point clouds ``(B, num_inputs, D)`` into latents.

        Raises:
            AssertionError: if the number of points differs from ``num_inputs``.
        """
        B, N, D = pc.shape
        assert N == self.num_inputs

        # Flatten to (B * N, D) plus a per-point batch index, as fps is called with
        # a flat position tensor and a batch vector. reshape (rather than view)
        # also accepts non-contiguous inputs.
        pos = pc.reshape(B * N, D)
        batch = torch.arange(B, device=pc.device).repeat_interleave(N)

        # Farthest point sampling (fps) to select a subset of input points.
        ratio = 1.0 * self.num_latents / self.num_inputs
        idx = fps(pos, batch, ratio=ratio)
        sampled_pc = pos[idx].view(B, -1, D)

        # Embed the sampled points and the full cloud with the same embedding.
        sampled_pc_embeddings = self.point_embed(sampled_pc)
        pc_embeddings = self.point_embed(pc)

        # Residual cross-attention (sampled -> full cloud) then residual feed-forward.
        cross_attn, cross_ff = self.cross_attend_blocks
        x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        return x

    def decode(self, x, queries):
        """Refine latents ``x`` and evaluate them at ``queries`` of shape ``(B, Q, 3)``."""
        # Residual latent self-attention / feed-forward stack.
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # Embed the decoder queries.
        queries_embeddings = self.point_embed(queries)

        # Queries cross-attend to the refined latents (no residual connection here).
        latents = self.decoder_cross_attn(queries_embeddings, context=x)

        # Optional residual feed-forward block.
        if self.decoder_ff is not None:
            latents = latents + self.decoder_ff(latents)

        return self.to_outputs(latents)

    def forward(self, pc, queries):
        """Encode ``pc``, decode at ``queries`` and return ``{'logits': tensor}``.

        The last dimension of the decoder output is squeezed, which removes it
        only when it has size 1.
        """
        x = self.encode(pc)
        o = self.decode(x, queries).squeeze(-1)
        return {'logits': o}


In [7]:
# Smoke test: build the model with its default settings and run one forward pass
# on random data.
# NOTE(review): this import sits after the model-definition cells; consider moving
# it into the notebook's import cell so a fresh "Run All" does not depend on order.
import numpy as np
autoencoder = AutoEncoder()

# Generate some example input data. The point count (2048) matches the model's
# default num_inputs, which encode() asserts.
input_data = torch.randn(1, 2048, 3)  # Batch size 1, 2048 points, 3 dimensions

# Generate some example query data for the decoder
queries = torch.randn(1, 512, 3)  # Batch size 1, 512 queries, 3 dimensions

# Perform a forward pass through the autoencoder; the result is a dict with a
# single 'logits' entry.
output = autoencoder(input_data, queries)

# Access the output logits (one value per query with the default output_dim=1)
logits = output['logits']