In [149]:
from torch.nn import Sequential, Linear, SiLU, ModuleList
from torch_geometric.nn import MessagePassing, MLP, AttentionalAggregation, MaxAggregation
from torch_geometric.nn import PointNetConv, PositionalEncoding
import torch

import torch
from torch_cluster import knn_graph, fps


class PointNetEncoder(torch.nn.Module):
    """Encode each point cloud in a batch into a single ``zdim``-sized vector.

    Pipeline: kNN graph -> ``PointNetConv`` -> ``fps`` subsampling ->
    kNN graph -> ``PointNetConv`` -> per-cloud max aggregation -> ``Linear``.

    Args:
        zdim: Size of the output vector produced for every point cloud.
        hidden_channels: Output width of both convolutions (default 32).
        k: Neighbours per point when building each kNN graph (default 16).
        ratio: Fraction of points kept by the ``fps`` subsampling step
            (default 0.5).
    """

    def __init__(self, zdim, hidden_channels=32, k=16, ratio=0.5):
        super().__init__()
        self.k = k
        self.ratio = ratio

        # The "+ 3" presumably accounts for the 3-d relative position that
        # PointNetConv feeds to local_nn next to the features — confirm
        # against the PyG documentation.
        self.conv1 = PointNetConv(
            local_nn=MLP([3 + 3, hidden_channels], act=SiLU(), plain_last=True),
            global_nn=SiLU(),
            aggr=MaxAggregation()
        )
        self.conv2 = PointNetConv(
            local_nn=MLP([hidden_channels + 3, hidden_channels], act=SiLU(), plain_last=True),
            global_nn=SiLU(),
            aggr=MaxAggregation()
        )
        self.aggr = MaxAggregation()
        self.net = Linear(hidden_channels, zdim)

    def forward(self, pos: torch.Tensor, batch: torch.Tensor):
        """Return one ``zdim`` vector per point cloud.

        Args:
            pos: Point coordinates; the first convolution expects 3 channels.
            batch: For every row of ``pos``, the index of the cloud it belongs to.
        """
        h: torch.Tensor
        # First stage: the raw coordinates double as the input features.
        edge_index = knn_graph(pos, k=self.k, batch=batch, loop=True)
        h = self.conv1(x=pos, pos=pos, edge_index=edge_index)

        # Subsample, then rebuild the neighbourhood graph on the kept points.
        index = fps(pos, batch, ratio=self.ratio)
        h, pos, batch = h[index], pos[index], batch[index]
        edge_index = knn_graph(pos, k=self.k, batch=batch, loop=True)
        h = self.conv2(x=h, pos=pos, edge_index=edge_index)

        h = self.aggr(h, batch)  # [batch_size, hidden_channels]
        return self.net(h)
    

class ConcatSquashLinear(torch.nn.Module):
    """Linear layer whose output is gated and shifted by a context vector.

    Computes ``layer(x) * sigmoid(hyper_gate(ctx))[batch] + hyper_bias(ctx)[batch]``:
    the gate and bias are produced once per context row and then looked up
    for every row of ``x`` through ``batch``.
    """

    def __init__(self, dim_in, dim_out, dim_ctx):
        super().__init__()
        # Submodules are created in this order; seeded initialisation depends on it.
        self._layer = Linear(dim_in, dim_out)
        self._hyper_bias = Linear(dim_ctx, dim_out, bias=False)
        self._hyper_gate = Linear(dim_ctx, dim_out)

    def forward(self, x: torch.Tensor, ctx: torch.Tensor, batch: torch.Tensor):
        """Apply the gated linear map.

        Args:
            x: Input rows, ``dim_in`` channels each.
            ctx: Context rows, ``dim_ctx`` channels each.
            batch: For every row of ``x``, the index of its row in ``ctx``.
        """
        scale = self._hyper_gate(ctx).sigmoid()
        shift = self._hyper_bias(ctx)
        projected = self._layer(x)
        return projected * scale[batch] + shift[batch]


class PointwiseNet(torch.nn.Module):
    """Residual per-point network made of ``ConcatSquashLinear`` layers.

    Hidden widths run 3 -> 128 -> 256 -> 512 -> 256 -> 128 with SiLU after
    each hidden layer; a final layer maps back to 3 channels and the result
    is added to the input.
    """

    def __init__(self, dim_ctx):
        super().__init__()
        widths = (3, 128, 256, 512, 256, 128)
        self.net = ModuleList([
            ConcatSquashLinear(width_in, width_out, dim_ctx)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ])
        self.out = ConcatSquashLinear(widths[-1], 3, dim_ctx)

    def forward(self, x: torch.Tensor, ctx: torch.Tensor, batch: torch.Tensor):
        """Return ``x`` plus the network's 3-channel output for every row of ``x``.

        Args:
            x: Input rows with 3 channels.
            ctx: Context rows passed to every layer.
            batch: For every row of ``x``, the index of its row in ``ctx``.
        """
        hidden: torch.Tensor = x
        for layer in self.net:
            hidden = torch.nn.functional.silu(layer(hidden, ctx, batch))

        return x + self.out(hidden, ctx, batch)



In [150]:
from torch_geometric.datasets import GeometricShapes
from torch_geometric.transforms import NormalizeScale, SamplePoints, Compose

# Per-sample pipeline: NormalizeScale first, then SamplePoints draws 1024 points.
transform = Compose([
    NormalizeScale(),
    SamplePoints(1024),
])
dataset = GeometricShapes(
    root='data/GeometricShapes',
    transform=transform,
)
print(dataset)

GeometricShapes(40)


In [151]:
from torch_geometric.loader import DataLoader

# Train and test splits, each with its own 128-point sampling transform.
train_dataset, test_dataset = (
    GeometricShapes(root='data/GeometricShapes', train=is_train,
                    transform=SamplePoints(128))
    for is_train in (True, False)
)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

# The encoder is used as a classifier here: one output per class.
model = PointNetEncoder(zdim=train_dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train(model, optimizer, loader, loss_fn=None):
    """Run one training epoch over ``loader`` and return the mean per-graph loss.

    Args:
        model: Module called as ``model(data.pos, data.batch)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        loader: Iterable of batches exposing ``pos``, ``batch``, ``y`` and
            ``num_graphs``; it must also expose a sized ``dataset`` attribute.
        loss_fn: Optional callable ``loss_fn(logits, y)``. When omitted, the
            notebook-level ``criterion`` is looked up at call time.

    Returns:
        Sum of the batch losses, each weighted by ``num_graphs``, divided by
        ``len(loader.dataset)``.
    """
    if loss_fn is None:
        # Notebook-level default; note that later cells may rebind `criterion`.
        loss_fn = criterion

    model.train()

    total_loss = 0
    for data in loader:
        optimizer.zero_grad()  # Clear gradients.
        logits = model(data.pos, data.batch)  # Forward pass.
        loss = loss_fn(logits, data.y)  # Loss computation.
        loss.backward()  # Backward pass.
        optimizer.step()  # Update model parameters.
        total_loss += loss.item() * data.num_graphs

    # Normalise by the loader that was actually iterated, not the global
    # `train_loader`, so the average is right for any loader passed in.
    return total_loss / len(loader.dataset)  # type: ignore


@torch.no_grad()
def test(model, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        logits = model(data.pos, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(loader.dataset)

# 50 epochs, numbered from 1: train on the training split, then score the test split.
for epoch in range(1, 50 + 1):
    loss, test_acc = train(model, optimizer, train_loader), test(model, test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

Epoch: 01, Loss: 3.8056, Test Accuracy: 0.0250
Epoch: 02, Loss: 3.6662, Test Accuracy: 0.0500
Epoch: 03, Loss: 3.6004, Test Accuracy: 0.0250
Epoch: 04, Loss: 3.5283, Test Accuracy: 0.0500
Epoch: 05, Loss: 3.4957, Test Accuracy: 0.0500
Epoch: 06, Loss: 3.4506, Test Accuracy: 0.0500
Epoch: 07, Loss: 3.4021, Test Accuracy: 0.0750
Epoch: 08, Loss: 3.3420, Test Accuracy: 0.1000
Epoch: 09, Loss: 3.2696, Test Accuracy: 0.1750
Epoch: 10, Loss: 3.1986, Test Accuracy: 0.1750
Epoch: 11, Loss: 3.0949, Test Accuracy: 0.1750
Epoch: 12, Loss: 2.9843, Test Accuracy: 0.2000
Epoch: 13, Loss: 2.8761, Test Accuracy: 0.2250
Epoch: 14, Loss: 2.7336, Test Accuracy: 0.2500
Epoch: 15, Loss: 2.6267, Test Accuracy: 0.3000
Epoch: 16, Loss: 2.4353, Test Accuracy: 0.3750
Epoch: 17, Loss: 2.3108, Test Accuracy: 0.4250
Epoch: 18, Loss: 2.1443, Test Accuracy: 0.5250
Epoch: 19, Loss: 1.9847, Test Accuracy: 0.6750
Epoch: 20, Loss: 1.8719, Test Accuracy: 0.7500
Epoch: 21, Loss: 1.7290, Test Accuracy: 0.6750
Epoch: 22, Lo

In [152]:
import numpy as np

class VarianceSchedule(torch.nn.Module):
    """Linear beta schedule with the derived per-step quantities as buffers.

    Every buffer has length ``num_steps + 1``. Index 0 is padding (beta = 0),
    so the timesteps produced by ``uniform_sample_t`` are ``1..num_steps``.

    Buffers:
        betas: ``[0, linspace(beta_1, beta_T, num_steps)]``.
        alphas: ``1 - betas``.
        alpha_bars: Cumulative product of ``alphas`` (computed in log space).
        sigmas_flex: ``sqrt(betas)``.
        sigmas_inflex: ``sqrt((1 - alpha_bars[t-1]) / (1 - alpha_bars[t]) * betas[t])``,
            with entry 0 left at zero.

    Args:
        num_steps: Number of timesteps T.
        beta_1: Beta at the first timestep.
        beta_T: Beta at the last timestep.
    """

    def __init__(self, num_steps, beta_1, beta_T):
        super().__init__()
        self.num_steps = num_steps
        self.beta_1 = beta_1
        self.beta_T = beta_T

        betas = torch.linspace(beta_1, beta_T, steps=num_steps)
        betas = torch.cat([torch.zeros([1]), betas], dim=0)     # Padding

        alphas = 1 - betas
        # Running product of alphas, accumulated as a running sum of logs.
        alpha_bars = torch.cumsum(torch.log(alphas), dim=0).exp()

        sigmas_flex = torch.sqrt(betas)
        # Entry t pairs alpha_bars[t-1] with alpha_bars[t]; entry 0 stays zero.
        sigmas_inflex = torch.zeros_like(sigmas_flex)
        sigmas_inflex[1:] = ((1 - alpha_bars[:-1]) / (1 - alpha_bars[1:])) * betas[1:]
        sigmas_inflex = torch.sqrt(sigmas_inflex)

        self.betas: torch.Tensor
        self.alphas: torch.Tensor
        self.alpha_bars: torch.Tensor
        self.sigmas_flex: torch.Tensor
        self.sigmas_inflex: torch.Tensor
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('sigmas_flex', sigmas_flex)
        self.register_buffer('sigmas_inflex', sigmas_inflex)

    def uniform_sample_t(self, batch_size):
        """Return ``batch_size`` timesteps drawn uniformly from ``1..num_steps`` as a list of ints."""
        ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size)
        return ts.tolist()

    def get_sigmas(self, t, flexibility):
        """Blend the two sigma tables at index ``t``.

        Args:
            t: Index (or index tensor) into the sigma buffers.
            flexibility: Weight in ``[0, 1]``; 0 selects ``sigmas_inflex``,
                1 selects ``sigmas_flex``.

        Raises:
            AssertionError: If ``flexibility`` lies outside ``[0, 1]``.
        """
        assert 0 <= flexibility <= 1
        sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
        return sigmas

In [153]:
class Model(torch.nn.Module):
    """Noise-prediction model: encoder for the latent, decoder for the noise.

    ``forward`` encodes every cloud to a latent ``z``, noises the points as
    ``sqrt(alpha_bar_t) * pos + sqrt(1 - alpha_bar_t) * e_rand`` and asks the
    decoder, conditioned on ``z``, for its output on the noised points.

    Args:
        zdim: Latent size, shared by the encoder output and decoder context.
        num_steps: Number of timesteps in the variance schedule.
        beta_1: Beta at the first timestep.
        beta_T: Beta at the last timestep.
    """

    def __init__(self, zdim, num_steps, beta_1, beta_T):
        super().__init__()
        self.encoder = PointNetEncoder(zdim)
        self.decoder = PointwiseNet(zdim)
        self.schedule = VarianceSchedule(num_steps, beta_1, beta_T)

    def forward(self, pos: torch.Tensor, batch: torch.Tensor, t=None):
        """Return ``(e_theta, e_rand)``: the decoder output and the noise that was added.

        Args:
            pos: Point coordinates for the whole batch.
            batch: For every row of ``pos``, the index of the cloud it belongs to.
            t: Optional timestep(s) in ``1..num_steps`` — an int, or one value
                per cloud. When omitted, one timestep per cloud is drawn
                uniformly, so training covers every noise level instead of
                a single fixed step.
        """
        z: torch.Tensor = self.encoder(pos, batch)
        batch_size = z.size(0)

        if t is None:
            # randint's upper bound is exclusive; index 0 of the schedule is padding.
            t = torch.randint(1, self.schedule.num_steps + 1, (batch_size,),
                              device=pos.device)
        else:
            t = torch.as_tensor(t, dtype=torch.long, device=pos.device).expand(batch_size)
        alpha_bar = self.schedule.alpha_bars[t]  # one value per cloud

        # Look the per-cloud coefficients up for every point, as column vectors.
        c0 = torch.sqrt(alpha_bar)[batch].view(-1, 1)
        c1 = torch.sqrt(1 - alpha_bar)[batch].view(-1, 1)

        e_rand = torch.randn_like(pos)
        # NOTE(review): the decoder is conditioned on z only and never sees t —
        # confirm whether a timestep embedding should be added to the context.
        e_theta = self.decoder(c0 * pos + c1 * e_rand, ctx=z, batch=batch)

        return e_theta, e_rand

In [154]:
def get_data_iterator(iterable):
    """Yield items from ``iterable`` forever, restarting it whenever it is exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, data in enumerate(get_data_iterator(train_loader)):

    If a full pass over ``iterable`` produces no items (an empty container,
    or a one-shot iterator that has already been consumed), the generator
    stops instead of spinning forever without yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # Nothing came out of a fresh pass; restarting again cannot help.
            return

In [155]:
# Rebinds the notebook-level model / optimizer / criterion for the diffusion experiment.
model = Model(
    zdim=train_dataset.num_classes,
    num_steps=1000,
    beta_1=1e-4,
    beta_T=0.02,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()  # applied to (e_theta, e_rand) in train()

def train(max_iters=None):
    """Train the notebook-level ``model`` on ``train_loader``, cycling over it endlessly.

    Uses the notebook-level ``optimizer`` and ``criterion`` and prints the
    loss after every iteration.

    Args:
        max_iters: Stop after this many iterations. ``None`` (the default)
            keeps going until the kernel is interrupted.
    """
    # Nothing in the loop leaves training mode, so switch once up front.
    model.train()

    for i, data in enumerate(get_data_iterator(train_loader), 1):
        if max_iters is not None and i > max_iters:
            break

        optimizer.zero_grad()  # Clear gradients.

        e_theta, e_rand = model(data.pos, data.batch)  # Forward pass.
        loss = criterion(e_theta, e_rand)  # Loss computation.

        loss.backward()  # Backward pass.
        optimizer.step()  # Update model parameters.

        print("Iteration: {}, Loss: {}".format(i, loss.item()))

# Runs until interrupted; pass max_iters to bound the run.
train()


Iteration: 1, Loss: 1.6197115182876587
Iteration: 2, Loss: 0.8131834864616394
Iteration: 3, Loss: 0.6480311155319214
Iteration: 4, Loss: 0.6618724465370178
Iteration: 5, Loss: 0.6184800863265991
Iteration: 6, Loss: 1.0993030071258545
Iteration: 7, Loss: 0.6762861609458923
Iteration: 8, Loss: 0.9229393005371094
Iteration: 9, Loss: 0.6136332750320435
Iteration: 10, Loss: 0.7686660289764404
Iteration: 11, Loss: 4.202695846557617
Iteration: 12, Loss: 0.7083160877227783
Iteration: 13, Loss: 0.7865323424339294
Iteration: 14, Loss: 1.1750801801681519
Iteration: 15, Loss: 0.8006055951118469
Iteration: 16, Loss: 0.624251127243042
Iteration: 17, Loss: 0.6602266430854797
Iteration: 18, Loss: 0.9424404501914978
Iteration: 19, Loss: 1.1071655750274658
Iteration: 20, Loss: 0.717014491558075
Iteration: 21, Loss: 0.6872360110282898
Iteration: 22, Loss: 0.6284244656562805
Iteration: 23, Loss: 0.7244883179664612
Iteration: 24, Loss: 1.4483073949813843
Iteration: 25, Loss: 1.0629996061325073
Iteration: 2

KeyboardInterrupt: 

In [77]:
# Stand-alone 100-step schedule for inspecting the sigma tables.
schedule = VarianceSchedule(num_steps=100, beta_1=1e-4, beta_T=0.02)

In [116]:
# Draw 10 timesteps and show the sigmas at flexibility 0 (the sigmas_inflex table).
steps = schedule.uniform_sample_t(10)
print(steps)
schedule.get_sigmas(torch.tensor(steps), flexibility=0.0)

[43, 12, 24, 72, 92, 81, 93, 19, 88, 26]


tensor([0.0905, 0.0441, 0.0659, 0.1186, 0.1347, 0.1261, 0.1354, 0.0578, 0.1316,
        0.0689])