# dummy

In [None]:
class GCRN_gcn(nn.Module):
    """Graph encoder -> GRU -> graph decoder predicting the next adjacency matrix.

    One snapshot ``(x, adj)`` is encoded by ``GCNEncoder`` into a latent vector
    plus ``mu`` / ``logvar``. The latent vector is fed to a stacked ``nn.GRU``
    as a length-1 sequence with batch size 1. The GRU output is decoded by
    ``GCNDecoder`` into a predicted adjacency matrix for the next time point.

    Args:
        input_dim: Input size of ``GCNEncoder`` and output size of
            ``GCNDecoder``. It is assumed to equal the number of nodes (TODO
            confirm against the encoder/decoder definitions).
        feature_dim: Encoder output size and GRU input size.
        hidden_dim: GRU hidden size and decoder input size.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.encoder = GCNEncoder(input_dim, feature_dim)
        # NOTE(review): input_dim is reused as the decoder's output size on the
        # assumption that it is the number of nodes; confirm.
        self.decoder = GCNDecoder(hidden_dim, input_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Stacked vanilla GRU. batch_first is left at False, so the input
        # layout is (seq_len, batch, feature_dim).
        self.gru = nn.GRU(feature_dim, hidden_dim, num_layers)

    def forward(self, x, adj):
        """Predict the adjacency matrix of the next time point.

        Args:
            x: Node features of the current snapshot, passed to the encoder.
                This is presumably a 2-D (num_nodes, num_node_features) tensor
                (TODO confirm).
            adj: Adjacency matrix of the current snapshot, passed to the encoder.

        Returns:
            tuple: ``(adj_pred, mu, logvar)``, where ``adj_pred`` is the
            decoder output and ``mu`` / ``logvar`` are returned unchanged from
            the encoder.

        NOTE(review): the GRU hidden state starts at zero on every call, so
        nothing is carried between successive snapshots. The GRU therefore
        acts as a single-step gated transform. Confirm this is intended.
        """
        z, mu, logvar = self.encoder(x, adj)
        # Presumably z is a 1-D graph embedding of size feature_dim. View it as
        # a length-1 sequence with batch size 1: (1, 1, feature_dim).
        z = z.unsqueeze(0).unsqueeze(0)

        # Zero initial hidden state of shape (num_layers, batch=1, hidden_dim).
        # new_zeros matches z's dtype and device, so float64 models also work.
        h0 = z.new_zeros(self.num_layers, 1, self.hidden_dim)

        out, _ = self.gru(z, h0)
        out = out.squeeze(0).squeeze(0)  # (hidden_dim,)

        # Decode the updated latent representation into the next-step adjacency.
        adj_pred = self.decoder(out)

        return adj_pred, mu, logvar

    def count_parameters(self):
        """Return the total number of elements in trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

def loss_function(recon_adj, adj, mu, logvar, *, beta=1.0):
    """Return summed-MSE reconstruction loss plus ``beta`` times the KL divergence.

    The KL term is the closed-form divergence between N(mu, exp(logvar)) and a
    standard normal, summed over all elements:
    ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``.

    Args:
        recon_adj: Predicted adjacency tensor.
        adj: Target adjacency tensor. It is expected to have the same shape as
            ``recon_adj``; ``F.mse_loss`` broadcasts (with a warning) otherwise.
        mu: Latent mean.
        logvar: Latent log-variance, with the same shape as ``mu``.
        beta: Keyword-only weight on the KL term. The default 1.0 gives the
            plain VAE objective; other values allow KL annealing or beta-VAE
            weighting.

    Returns:
        torch.Tensor: scalar loss.
    """
    recon_loss = F.mse_loss(recon_adj, adj, reduction='sum')
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kld_loss


def train_model(model, train_features, train_adj, num_epochs=100, lr=0.001, save_path='gcrn_model.pth'):
    """Train ``model`` to predict the next adjacency matrix, then save and plot.

    For each sample ``i`` the model receives ``train_features[i, 0]`` and
    ``train_adj[i, 0]``. Its prediction is scored against ``train_adj[i, 1]``
    with ``loss_function``, and one Adam step is taken per sample.

    Args:
        model: Module called as ``model(x, adj) -> (adj_pred, mu, logvar)``.
        train_features: Indexed as ``train_features[i, 0]``. Its first
            dimension is the number of training samples.
        train_adj: Indexed as ``train_adj[i, 0]`` (current snapshot) and
            ``train_adj[i, 1]`` (next snapshot, the target). It must have at
            least as many samples as ``train_features``.
        num_epochs: Number of passes over the training samples.
        lr: Adam learning rate.
        save_path: Destination for ``model.state_dict()`` via ``torch.save``.

    Returns:
        list[float]: mean per-sample loss for each epoch.

    Raises:
        ValueError: if ``train_features`` has no samples, or ``train_adj`` has
            fewer samples than ``train_features``.
    """
    num_samples = train_features.size(0)
    # Fail fast instead of dividing by zero (or hitting IndexError) mid-training.
    if num_samples == 0:
        raise ValueError('train_features must contain at least one sample')
    if train_adj.size(0) < num_samples:
        raise ValueError(
            f'train_adj has {train_adj.size(0)} samples but train_features has {num_samples}'
        )

    optimizer = optim.Adam(model.parameters(), lr=lr)
    training_loss = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for i in range(num_samples):
            optimizer.zero_grad()

            x_t = train_features[i, 0]     # current node features
            adj_t = train_adj[i, 0]        # current adjacency
            adj_t_next = train_adj[i, 1]   # next adjacency (ground truth)

            recon_adj, mu, logvar = model(x_t, adj_t)
            loss = loss_function(recon_adj, adj_t_next, mu, logvar)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        epoch_loss /= num_samples
        training_loss.append(epoch_loss)

        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

    # Save before plotting. plt.show() can block or fail (e.g. headless), and
    # the trained weights must not be lost because of that.
    torch.save(model.state_dict(), save_path)
    print(f'Model saved to {save_path}')

    plt.plot(training_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()

    return training_loss