# dummy

# Domain-specific model with an inner-product decoder

# In [None]:
class InnerProductDecoder(nn.Module):
    """Score every pair of rows of an embedding matrix as ``act(Z @ Z^T)``.

    Args:
        act: elementwise activation applied to the Gram matrix
            (``torch.sigmoid`` by default).
        dropout: dropout probability applied to the input embeddings while
            the module is in training mode.
    """

    def __init__(self, act=torch.sigmoid, dropout=0.):
        super().__init__()
        self.act = act
        self.dropout = dropout

    def forward(self, inp):
        """Return ``act(Z @ Z^T)`` for the 2-D embedding matrix ``inp``.

        Row ``i`` / column ``j`` of the result is the activated inner product
        of embeddings ``i`` and ``j``.
        """
        embeddings = F.dropout(inp, self.dropout, training=self.training)
        gram = torch.mm(embeddings, embeddings.transpose(0, 1))
        return self.act(gram)


class GCRN(nn.Module):
    """Graph convolutional recurrent model with an inner-product decoder.

    Each ``forward`` call handles one graph snapshot:
    1. encode it with ``GCNEncoder``;
    2. run the latent vector through one GRU step;
    3. broadcast the GRU output to every node;
    4. decode the node embeddings with ``InnerProductDecoder``.

    Args:
        input_dim: first argument forwarded to ``GCNEncoder``.
        feature_dim: size of the encoder latent ``z``; also the GRU input size.
        hidden_dim: GRU hidden size; also the per-node embedding size that is
            fed to the decoder.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_layers=2):
        super(GCRN, self).__init__()
        self.encoder = GCNEncoder(input_dim, feature_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Vanilla stacked GRU consuming the encoder's latent vector.
        self.gru = nn.GRU(feature_dim, hidden_dim, num_layers)

        # Parameter-free decoder: act(Z @ Z^T).
        self.decoder = InnerProductDecoder()

    def forward(self, x, adj):
        """Predict an adjacency matrix from one graph snapshot.

        Args:
            x: node feature matrix. It must be 2-D, and its first dimension is
                used as the number of nodes.
            adj: adjacency information, forwarded unchanged to the encoder.

        Returns:
            ``(adj_pred, mu, logvar)``:
            - ``adj_pred`` has shape ``(num_nodes, num_nodes)``;
            - ``mu`` and ``logvar`` are returned exactly as the encoder
              produced them.

        NOTE(review): every row of the decoder input is the same vector. With
        the default ``InnerProductDecoder`` (dropout=0), every entry of
        ``adj_pred`` therefore equals ``sigmoid(||z||^2)``, so the prediction
        carries no per-node structure. Node-specific scores would need
        node-level embeddings from the encoder.
        """
        # Unpacking also enforces that x is 2-D (ValueError otherwise).
        num_nodes, _ = x.size()

        z, mu, logvar = self.encoder(x, adj)
        # Presumably (feature_dim,) -> GRU input (seq_len=1, batch=1, feature_dim).
        z = z.unsqueeze(0).unsqueeze(0)

        # NOTE(review): the GRU starts from a zero state on every call and the
        # returned state is discarded. Nothing is carried between time points,
        # so the GRU acts as a single gated step.
        h0 = torch.zeros(self.num_layers, 1, self.hidden_dim, device=x.device)
        out, _ = self.gru(z, h0)
        z = out.squeeze(0).squeeze(0)  # (hidden_dim,)

        # Broadcast the single embedding to every node. expand() gives a
        # zero-copy view with the same values (and gradients) as repeat(),
        # without allocating num_nodes copies.
        z_nodes = z.unsqueeze(0).expand(num_nodes, -1)  # (num_nodes, hidden_dim)

        adj_pred = self.decoder(z_nodes)
        return adj_pred, mu, logvar

    def count_parameters(self):
        """Return the total number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Domain-specific model with a GCN decoder

# In [None]:
class GCNDecoder(nn.Module):
    """Decode a latent vector into a symmetric adjacency matrix with GCN layers.

    Pipeline:
    1. project the latent vector to an initial ``(num_nodes, latent_dim)``
       node-feature matrix;
    2. refine it with two GCN + BatchNorm + ReLU + dropout stages on an
       identity adjacency;
    3. map each node's row to ``num_nodes`` edge logits;
    4. apply a sigmoid and symmetrise.

    Args:
        latent_dim: size of the latent vector ``z``.
        num_nodes: number of nodes (side length of the predicted matrix).
    """

    def __init__(self, latent_dim, num_nodes):
        super(GCNDecoder, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes

        # Expand z to num_nodes * latent_dim values (reshaped per node in forward).
        self.fc = nn.Linear(latent_dim, num_nodes * latent_dim)

        # GCN layers
        self.conv1 = GCN(latent_dim, latent_dim * 2)
        self.bn1 = nn.BatchNorm1d(latent_dim * 2)
        self.conv2 = GCN(latent_dim * 2, latent_dim)
        self.bn2 = nn.BatchNorm1d(latent_dim)

        # Per-node edge head. Applied row-wise to the (num_nodes, latent_dim)
        # node matrix, it yields exactly (num_nodes, num_nodes) logits.
        # out_features = num_nodes * num_nodes would produce num_nodes**3
        # values, which cannot be viewed as (num_nodes, num_nodes).
        self.fc_adj = nn.Linear(latent_dim, num_nodes)

    def forward(self, z):
        """Return a symmetric ``(num_nodes, num_nodes)`` matrix of edge probabilities.

        Args:
            z: latent vector. ``self.fc(z)`` must have exactly
                ``num_nodes * latent_dim`` elements, so ``z`` is expected to
                be ``(latent_dim,)``.
        """
        x = self.fc(z).view(self.num_nodes, self.latent_dim)

        # Identity adjacency. NOTE(review): presumably this makes each GCN
        # layer a per-node transform with no neighbour mixing; confirm
        # against GCN's implementation.
        adj = torch.eye(self.num_nodes, device=z.device)

        x = F.relu(self.bn1(self.conv1(x, adj)))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.bn2(self.conv2(x, adj)))
        x = F.dropout(x, training=self.training)

        # (num_nodes, latent_dim) -> (num_nodes, num_nodes). This assumes GCN
        # keeps one row per node; the view fails loudly if it does not.
        adj_pred = torch.sigmoid(self.fc_adj(x)).view(self.num_nodes, self.num_nodes)

        # Average with the transpose so the result is symmetric.
        adj_pred = (adj_pred + adj_pred.T) / 2
        return adj_pred

class GCRN(nn.Module):
    """Graph convolutional recurrent model with a ``GCNDecoder``.

    Each ``forward`` call handles one graph snapshot:
    1. encode it with ``GCNEncoder``;
    2. run the latent vector through one GRU step;
    3. decode the GRU output into a predicted adjacency matrix.

    Args:
        input_dim: first argument forwarded to ``GCNEncoder``.
        feature_dim: size of the encoder latent; also the GRU input size.
        hidden_dim: GRU hidden size; also the decoder latent size.
        num_layers: number of stacked GRU layers.
        num_nodes: side length of the predicted adjacency matrix. Defaults to
            ``input_dim``, which preserves the original behaviour. That
            default is only correct when ``input_dim`` equals the number of
            nodes (e.g. one-hot node features).
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_layers=2, num_nodes=None):
        super(GCRN, self).__init__()
        if num_nodes is None:
            num_nodes = input_dim
        self.encoder = GCNEncoder(input_dim, feature_dim)
        self.decoder = GCNDecoder(hidden_dim, num_nodes)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Vanilla stacked GRU consuming the encoder's latent vector.
        self.gru = nn.GRU(feature_dim, hidden_dim, num_layers)

    def forward(self, x, adj):
        """Predict the adjacency matrix for the next time point.

        Args:
            x: node feature matrix; must be 2-D.
            adj: adjacency information, forwarded unchanged to the encoder.

        Returns:
            ``(adj_pred, mu, logvar)``:
            - ``adj_pred`` is the decoder output;
            - ``mu`` and ``logvar`` are returned exactly as the encoder
              produced them.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"expected 2-D node features, got shape {tuple(x.shape)}")

        z, mu, logvar = self.encoder(x, adj)
        # Presumably (feature_dim,) -> GRU input (seq_len=1, batch=1, feature_dim).
        z = z.unsqueeze(0).unsqueeze(0)

        # NOTE(review): the GRU starts from a zero state on every call and the
        # returned state is discarded. Nothing is carried between time points.
        h0 = torch.zeros(self.num_layers, 1, self.hidden_dim, device=x.device)
        out, _ = self.gru(z, h0)
        z = out.squeeze(0).squeeze(0)  # (hidden_dim,)

        adj_pred = self.decoder(z)
        return adj_pred, mu, logvar

    def count_parameters(self):
        """Return the total number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Domain-specific model with a fully connected (MLP) decoder

# In [None]:
class GCNDecoder(nn.Module):
    """Decode a latent vector into a symmetric adjacency matrix with an MLP.

    Despite the class name, this variant has no graph convolutions. A
    two-layer MLP maps ``z`` to ``num_nodes * num_nodes`` logits, which are
    reshaped, passed through a sigmoid and symmetrised.

    Args:
        latent_dim: size of the latent vector ``z``.
        num_nodes: side length of the predicted matrix.
    """

    def __init__(self, latent_dim, num_nodes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes

        hidden_width = 2 * latent_dim
        # latent -> hidden -> one logit per (i, j) node pair
        self.fc1 = nn.Linear(latent_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, num_nodes * num_nodes)

    def forward(self, z):
        """Return a symmetric ``(num_nodes, num_nodes)`` matrix of edge probabilities."""
        n = self.num_nodes
        hidden = F.dropout(F.relu(self.fc1(z)), training=self.training)
        probs = torch.sigmoid(self.fc2(hidden).view(n, n))
        # Averaging with the transpose enforces symmetry.
        return 0.5 * (probs + probs.T)

class GCRN(nn.Module):
    """Graph convolutional recurrent model with a ``GCNDecoder``.

    Each ``forward`` call handles one graph snapshot:
    1. encode it with ``GCNEncoder``;
    2. run the latent vector through one GRU step;
    3. decode the GRU output into a predicted adjacency matrix.

    Args:
        input_dim: first argument forwarded to ``GCNEncoder``.
        feature_dim: size of the encoder latent; also the GRU input size.
        hidden_dim: GRU hidden size; also the decoder latent size.
        num_layers: number of stacked GRU layers.
        num_nodes: side length of the predicted adjacency matrix. Defaults to
            ``input_dim``, which preserves the original behaviour. That
            default is only correct when ``input_dim`` equals the number of
            nodes (e.g. one-hot node features).
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_layers=2, num_nodes=None):
        super(GCRN, self).__init__()
        if num_nodes is None:
            num_nodes = input_dim
        self.encoder = GCNEncoder(input_dim, feature_dim)
        self.decoder = GCNDecoder(hidden_dim, num_nodes)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Vanilla stacked GRU consuming the encoder's latent vector.
        self.gru = nn.GRU(feature_dim, hidden_dim, num_layers)

    def forward(self, x, adj):
        """Predict the adjacency matrix for the next time point.

        Args:
            x: node feature matrix; must be 2-D.
            adj: adjacency information, forwarded unchanged to the encoder.

        Returns:
            ``(adj_pred, mu, logvar)``:
            - ``adj_pred`` is the decoder output;
            - ``mu`` and ``logvar`` are returned exactly as the encoder
              produced them.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"expected 2-D node features, got shape {tuple(x.shape)}")

        z, mu, logvar = self.encoder(x, adj)
        # Presumably (feature_dim,) -> GRU input (seq_len=1, batch=1, feature_dim).
        z = z.unsqueeze(0).unsqueeze(0)

        # NOTE(review): the GRU starts from a zero state on every call and the
        # returned state is discarded. Nothing is carried between time points.
        h0 = torch.zeros(self.num_layers, 1, self.hidden_dim, device=x.device)
        out, _ = self.gru(z, h0)
        z = out.squeeze(0).squeeze(0)  # (hidden_dim,)

        adj_pred = self.decoder(z)
        return adj_pred, mu, logvar

    def count_parameters(self):
        """Return the total number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
