In [1]:
import math
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist
import torch_scatter
import networkx as nx

In [2]:
class AutoregressiveModel(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        lattice: lattice system
        features: a list of feature dimensions for all layers
        nonlinearity: activation function to use 
        bias: whether to learn the bias
    """
    
    def __init__(self, lattice: Lattice, features, nonlinearity: str = 'Tanh', bias: bool = True):
        super(AutoregressiveModel, self).__init__()
        self.lattice = lattice
        self.nodes = lattice.sites
        self.max_depth = lattice.max_depth
        self.features = features
        dist.Distribution.__init__(self, event_shape=torch.Size([self.nodes, self.features[0]]))
        self.has_rsample = True
        #self.graph = self.lattice.graph
        self.layers = nn.ModuleList()
        for l in range(1, len(self.features)):
            if l == 1: # the first layer should not have self loops
                self.layers.append(GraphConv(self.lattice, self.features[0], self.features[1], bias, self_loop = False))
            else: # remaining layers are normal
                self.layers.append(nn.LayerNorm([self.features[l - 1]]))
                self.layers.append(getattr(nn, nonlinearity)()) # activatioin layer
                self.layers.append(GraphConv(self.lattice, self.features[l - 1], self.features[l], bias))

    def update_graph(self, graph):
        # update graph for all GraphConv layers
        self.graph = graph
        for layer in self.layers:
            if isinstance(layer, GraphConv):
                layer.update_graph(graph)
        return self

    def forward(self, input):
        output = input
        for layer in self.layers: # apply layers
            output = layer(output)
        return output # logits
    
    def log_prob(self, sample):
        logits = self(sample) # forward pass to get logits
        return torch.sum(sample * F.log_softmax(logits, dim=-1), (-2,-1))

    def sampler(self, logits, dim=-1): # simplified from F.gumbel_softmax
        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels += logits.detach()
        index = gumbels.max(dim, keepdim=True)[1]
        return torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)

    def _sample(self, sample_size: int, sampler = None):
        if sampler is None: # if no sampler specified, use default
            sampler = self.sampler
        # create a list of tensors to cache layer-wise outputs
        cache = [torch.zeros(self.max_depth * self.features[0], sample_size)]
        depth_assignments = [self.layers[0].depth_assignment]
        for layer in self.layers:
            if isinstance(layer, GraphConv): # for graph convolution layers
                features = layer.out_features # features get updated
                depths = layer.depth_assignment
            cache.append(torch.zeros(self.nodes * features, sample_size))
            depth_assignments.append(depths)
        # cache established. start by sampling node 0.
        # assuming global symmetry, node 0 is always sampled uniformly
        cache[0][..., 0, :] = sampler(cache[0][..., 0, :])
        # start autoregressive sampling
        for j in range(1, self.max_depth + 1): # iterate through nodes 1:all
            for l, layer in enumerate(self.layers):
                if isinstance(layer, GraphConv): # for graph convolution layers
                    if l==0: # first layer should forward from previous node
                        cache[l + 1] += layer(cache[l], j - 1)#possibly change to not be in place operation
                    else: # remaining layers forward from this node
                        cache[l + 1] += layer(cache[l], j)
                else: # for other layers, only update node j (other nodes not ready yet)
                    src = layer(cache[l][..., [j], :])
                    index = torch.tensor(j).view([1]*src.dim()).expand(src.size())
                    cache[l + 1] = cache[l + 1].scatter(-2, index, src)#scatter incorrect
            # the last cache hosts the logit, sample from it 
            cache[0][..., j, :] = sampler(cache[-1][..., j, :])
        return cache # cache[0] hosts the sample
    
    def sample(self, sample_size=1):
        with torch.no_grad():
            cache = self._sample(sample_size)
        return cache[0]
    
    def rsample(self, sample_size=1, tau=None, hard=False):
        # reparametrized Gumbel sampling
        if tau is None: # if temperature not given
            tau = 1/(self.features[-1]-1) # set by the out feature dimension
        cache = self._sample(sample_size, lambda x: F.gumbel_softmax(x, tau, hard))
        return cache[0]

    def sample_with_log_prob(self, sample_size=1):
        cache = self._sample(sample_size)
        sample = cache[0]
        logits = cache[-1]
        log_prob = torch.sum(sample * F.log_softmax(logits, dim=-1), (-2,-1))
        return sample, log_prob

NameError: name 'Lattice' is not defined