# dummy

In [None]:
class GAT(nn.Module):
    """Single-head graph attention layer operating on a dense adjacency matrix.

    Node features are linearly projected, a scalar attention score is computed
    for every ordered node pair from the concatenation of the two projected
    feature vectors, scores are restricted to edges present in ``adj`` (plus
    self-loops), normalised with a row-wise softmax and used to aggregate the
    projected features.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        activation: Callable applied to the aggregated features, or ``None``
            to return them unchanged.
    """

    def __init__(self, in_features, out_features, activation=F.relu):
        super(GAT, self).__init__()
        # Stored so that __repr__ can report the layer dimensions.
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.phi = nn.Parameter(torch.empty(2 * out_features, 1))
        self.activation = activation
        self.drop = nn.Dropout(p=0.5)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` and ``phi`` uniformly and zero ``bias``."""
        # Same 1/sqrt(size) bound used by GraphSAGE.reset_parameters in this
        # file; phi is scaled by the length of the concatenated feature vector.
        weight_bound = self.weight.size(1) ** -0.5
        phi_bound = self.phi.size(0) ** -0.5
        with torch.no_grad():
            self.weight.uniform_(-weight_bound, weight_bound)
            self.phi.uniform_(-phi_bound, phi_bound)
            self.bias.zero_()

    def forward(self, input, adj):
        """Apply attention-weighted aggregation.

        Args:
            input: ``(N, in_features)`` node feature matrix.
            adj: ``(N, N)`` adjacency matrix; non-zero entries mark edges.

        Returns:
            ``(N, out_features)`` tensor of aggregated (and activated) features.
        """
        input = self.drop(input)
        h = torch.mm(input, self.weight) + self.bias

        # score[i, j] = phi . [h_i ; h_j] = phi_src . h_i + phi_dst . h_j, so
        # the N x N x 2F concatenation never has to be materialised.
        phi_src = self.phi[:self.out_features]
        phi_dst = self.phi[self.out_features:]
        scores = torch.mm(h, phi_src) + torch.mm(h, phi_dst).t()

        # Use the device of the features rather than a module-level global.
        self_loops = torch.eye(h.size(0), dtype=torch.bool, device=h.device)
        mask = adj.to(h.device).bool() | self_loops

        # Non-edges get a very large negative score so softmax gives them ~0.
        scores = scores.masked_fill(~mask, -9e15)
        attention_weights = F.softmax(scores, dim=1)
        h = torch.matmul(attention_weights, h)
        return h if self.activation is None else self.activation(h)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

class GraphSAGE(torch.nn.Module):
    """Mean-aggregating graph layer driven by an edge index.

    Node features are linearly projected (optionally with a bias), passed
    through ``activation`` and then averaged over each node's neighbours with
    ``scatter_mean``.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        activation: Callable applied to the projected features before
            aggregation, or ``None`` to skip it.
        normalize: If true, L2-normalise every output row.
        bias: If true, add a learnable bias to the projection.
    """

    def __init__(self, in_channels, out_channels, activation=F.relu, normalize=False, bias=False):
        super(GraphSAGE, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.weight = torch.nn.Parameter(torch.empty(self.in_channels, out_channels))
        self.activation = activation
        # The previous `self.pool = pool` referenced an undefined name and
        # raised NameError on construction; the attribute was never read.

        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` (and ``bias``) uniformly from ±1/sqrt(out_channels)."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, edge_index):
        """Project node features and average them over neighbours.

        Args:
            x: Node features; a 1-D tensor is treated as one feature per node.
            edge_index: Pair ``(row, col)`` of index tensors; features of
                ``col`` nodes are averaged into the corresponding ``row`` node.

        Returns:
            Tensor with one row of ``out_channels`` features per node.
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        row, col = edge_index

        out = torch.matmul(x, self.weight)
        if self.bias is not None:
            out = out + self.bias
        if self.activation is not None:
            out = self.activation(out)
        # dim_size keeps one output row per node even for nodes that never
        # appear in `row`.
        out = scatter_mean(out[col], row, dim=0, dim_size=out.size(0))

        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)

        return out

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_channels) + ' -> ' \
               + str(self.out_channels) + ')'

In [None]:
class Encoder(nn.Module):
    """Graph encoder that maps node features and a graph to one latent vector.

    A stack of graph layers (each followed by batch norm and ReLU) takes the
    node features from ``feature_dim`` through ``hidden_dims`` to
    ``latent_dim``; the per-node outputs are then flattened and projected to a
    single ``latent_dim`` vector by a linear layer.

    Args:
        input_dim: Number of nodes in the graph.
        feature_dim: Size of each input node feature vector.
        hidden_dims: Iterable of hidden layer widths.
        latent_dim: Size of the latent vector.
        block_type: Which graph layer to stack: ``'GCN'``, ``'GAT'`` or
            ``'GraphSAGE'``.

    Raises:
        ValueError: If ``block_type`` is not one of the supported names.
    """

    def __init__(self, input_dim, feature_dim, hidden_dims, latent_dim, block_type='GAT'):
        super(Encoder, self).__init__()
        self.num_nodes = input_dim
        self.feature_dim = feature_dim
        self.hidden_dims = hidden_dims
        self.latent_dim = latent_dim
        self.block_type = block_type

        # Graph layers and their matching batch-norm layers.
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()

        widths = [feature_dim] + list(hidden_dims) + [latent_dim]
        for prev_dim, next_dim in zip(widths[:-1], widths[1:]):
            self.layers.append(self._make_layer(block_type, prev_dim, next_dim))
            self.bns.append(nn.BatchNorm1d(next_dim))

        self.fc = nn.Linear(latent_dim * input_dim, latent_dim)

    @staticmethod
    def _make_layer(block_type, in_dim, out_dim):
        """Build one graph layer of the requested type."""
        # Names are resolved lazily so only the selected layer class has to
        # be defined in the module.
        if block_type == 'GCN':
            return GCN(in_dim, out_dim)
        if block_type == 'GAT':
            return GAT(in_dim, out_dim)
        if block_type == 'GraphSAGE':
            # NOTE(review): forward() passes `adj` straight to the layer, but
            # GraphSAGE.forward unpacks its second argument as (row, col) --
            # confirm callers supply an edge index for this block type.
            return GraphSAGE(in_dim, out_dim)
        raise ValueError(
            "block_type must be 'GCN', 'GAT' or 'GraphSAGE', got %r" % (block_type,)
        )

    def forward(self, x, adj):
        """Encode node features ``x`` on graph ``adj`` into a latent vector."""
        for layer, bn in zip(self.layers, self.bns):
            x = F.relu(bn(layer(x, adj)))

        x = F.dropout(x, training=self.training)
        # reshape (not view) so a non-contiguous layer output cannot fail here.
        x_flat = x.reshape(-1)
        z = self.fc(x_flat)

        return z

class Decoder(nn.Module):
    """Graph decoder that expands a latent vector into a predicted adjacency.

    The latent vector is projected to ``num_nodes`` rows of ``latent_dim``
    features, passed through a stack of graph layers (each followed by batch
    norm and ReLU) evaluated on an identity adjacency, and the final
    ``num_nodes x num_nodes`` output is symmetrised with its diagonal zeroed.

    Args:
        latent_dim: Size of the latent vector.
        num_nodes: Number of nodes in the predicted graph.
        hidden_dims: Iterable of hidden layer widths.
        block_type: Which graph layer to stack: ``'GCN'``, ``'GAT'`` or
            ``'GraphSAGE'``.

    Raises:
        ValueError: If ``block_type`` is not one of the supported names.
    """

    def __init__(self, latent_dim, num_nodes, hidden_dims, block_type='GAT'):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.hidden_dims = hidden_dims
        self.block_type = block_type

        # Fully connected layer to expand z to node features
        self.fc = nn.Linear(latent_dim, num_nodes * latent_dim)
        self.bn0 = nn.BatchNorm1d(latent_dim)

        # Graph layers and their matching batch-norm layers.
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()

        # The last layer emits num_nodes features per node, i.e. one row of
        # the predicted adjacency matrix.
        widths = [latent_dim] + list(hidden_dims) + [num_nodes]
        for prev_dim, next_dim in zip(widths[:-1], widths[1:]):
            self.layers.append(self._make_layer(block_type, prev_dim, next_dim))
            self.bns.append(nn.BatchNorm1d(next_dim))

    @staticmethod
    def _make_layer(block_type, in_dim, out_dim):
        """Build one graph layer of the requested type."""
        # Names are resolved lazily so only the selected layer class has to
        # be defined in the module.
        if block_type == 'GCN':
            return GCN(in_dim, out_dim)
        if block_type == 'GAT':
            return GAT(in_dim, out_dim)
        if block_type == 'GraphSAGE':
            # NOTE(review): forward() passes a dense identity matrix, but
            # GraphSAGE.forward unpacks its second argument as (row, col) --
            # confirm this block type is usable here.
            return GraphSAGE(in_dim, out_dim)
        raise ValueError(
            "block_type must be 'GCN', 'GAT' or 'GraphSAGE', got %r" % (block_type,)
        )

    def forward(self, z):
        """Decode latent vector ``z`` into a symmetric, zero-diagonal matrix."""
        x = self.bn0(self.fc(z).view(self.num_nodes, self.latent_dim))
        adj = torch.eye(self.num_nodes, device=z.device)

        for layer, bn in zip(self.layers, self.bns):
            x = F.relu(bn(layer(x, adj)))

        # Ensure symmetry of the adjacency matrix
        adj_pred = (x + x.t()) / 2

        # Set the diagonal elements to zero (no self-connections)
        adj_pred = adj_pred - torch.diag(torch.diag(adj_pred))
        return adj_pred

class GCRN(nn.Module):
    """Encoder -> GRU -> decoder model predicting the next adjacency matrix.

    The encoder compresses node features and the current graph into a latent
    vector, a GRU (started from a zero hidden state on every call) updates
    that vector, and the decoder expands it into a predicted adjacency matrix.
    """

    def __init__(self, input_dim, feature_dim, latent_dim, encoder_hidden_dims, decoder_hidden_dims, num_layers=2):
        super().__init__()
        # Fix the seed before any sub-module draws its initial weights.
        set_seed(42)
        self.hidden_dim = latent_dim
        self.num_layers = num_layers

        self.encoder = Encoder(
            input_dim, feature_dim, encoder_hidden_dims, latent_dim
        )
        self.gru = nn.GRU(latent_dim, self.hidden_dim, num_layers)
        self.decoder = Decoder(latent_dim, input_dim, decoder_hidden_dims)

    def forward(self, x, adj):
        """Return ``(adj_pred, z)`` for node features ``x`` and graph ``adj``."""
        # Unpacking keeps the implicit requirement that x is two-dimensional.
        _num_nodes, _feature_dim = x.size()

        # Latent code for the current time point, shaped (seq=1, batch=1, latent).
        latent = self.encoder(x, adj)
        gru_input = latent.unsqueeze(0).unsqueeze(0)

        # The recurrent state always starts from zeros.
        state_shape = (self.num_layers, 1, self.hidden_dim)
        initial_state = torch.zeros(*state_shape, device=x.device)

        gru_output, _final_state = self.gru(gru_input, initial_state)
        z = gru_output.squeeze(0).squeeze(0)

        # Expand the updated latent code into the next-step adjacency matrix.
        return self.decoder(z), z

    def count_parameters(self):
        """Count the elements of all parameters that require gradients."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total