In [49]:
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Batch, HeteroData, data
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

class HeteroGNNEncoder(nn.Module):
    """Encode each heterogeneous graph of a (batched) input into one latent vector.

    Pipeline:
    1. ``num_layers`` rounds of ``HeteroConv``. Each round has one lazily sized
       ``SAGEConv`` per edge type, sums the results per destination node type,
       then applies ReLU.
    2. Per-graph mean pooling for every node type.
    3. A sum over node types.
    4. A linear projection to ``out_channels``.
    """

    def __init__(self, hidden_channels=64, out_channels=128, num_layers=2, metadata=None):
        """
        Initializes the Heterogeneous GNN Encoder.

        Args:
            hidden_channels (int): Number of hidden units in GNN layers.
            out_channels (int): Dimension of the output latent vector.
            num_layers (int): Number of GNN layers.
            metadata (tuple): ``(node_types, edge_types)``, where each edge type is a
                ``(src, relation, dst)`` tuple (the format of ``HeteroData.metadata()``).

        Raises:
            ValueError: If ``metadata`` is not provided.
        """
        super().__init__()

        if metadata is None:
            raise ValueError("Metadata must be provided for HeteroConv.")

        node_types, edge_types = metadata
        self.node_types = node_types
        self.edge_types = edge_types

        # One HeteroConv per layer. in_channels=-1 lets every SAGEConv infer its input
        # width on the first forward pass, so node types may have different feature sizes.
        self.convs = nn.ModuleList(
            HeteroConv(
                {edge_type: SAGEConv(-1, hidden_channels, aggr='mean') for edge_type in edge_types},
                aggr='sum',
            )
            for _ in range(num_layers)
        )

        # Linear layer to project to latent space
        self.linear = nn.Linear(hidden_channels, out_channels)
        self.activation = nn.ReLU()

    def forward(self, data):
        """
        Forward pass of the encoder.

        Args:
            data (HeteroData | Batch): Heterogeneous graph(s). For a ``Batch`` built
                with ``Batch.from_data_list``, each node store carries a ``batch``
                vector assigning nodes to graphs. A plain ``HeteroData`` is treated
                as a single graph.

        Returns:
            Tensor: Latent representations of shape (num_graphs, out_channels),
            or (1, out_channels) for a single unbatched graph.

        Raises:
            ValueError: If no node type produced an embedding to pool.
        """
        x_dict = data.x_dict  # node_type -> node features
        edge_index_dict = data.edge_index_dict  # edge_type -> edge_index

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: self.activation(x) for key, x in x_dict.items()}

        # HeteroConv only returns node types that are the destination of some edge
        # type, so node types missing from x_dict are skipped instead of raising KeyError.
        # size=num_graphs guarantees one row per graph even when the trailing graphs
        # have no nodes of a type, so the per-type tensors below can be stacked.
        num_graphs = getattr(data, "num_graphs", None)
        pooled = [
            global_mean_pool(x_dict[key], getattr(data[key], "batch", None), size=num_graphs)
            for key in data.node_types
            if key in x_dict
        ]
        if not pooled:
            raise ValueError(
                "No node type received messages from HeteroConv; check the edge types in metadata."
            )

        # stack -> (num_node_types, num_graphs, hidden_channels); summing over dim 0
        # combines the node types into one vector per graph.
        graph_emb = torch.stack(pooled, dim=0).sum(dim=0)

        # Project to desired latent dimension: (num_graphs, out_channels)
        return self.linear(graph_emb)

class SequenceGNNEncoder(nn.Module):
    """Encode a batch of equal-length sequences of heterogeneous graphs.

    All graphs are flattened into one PyG ``Batch`` and encoded in a single pass by a
    shared ``HeteroGNNEncoder``. The result is then reshaped back to
    ``(batch_size, sequence_size, out_channels)``.
    """

    def __init__(self, hidden_channels=64, out_channels=128, num_layers=2, metadata=None):
        """
        Initializes the Sequence GNN Encoder.

        Args:
            hidden_channels (int): Number of hidden units in GNN layers.
            out_channels (int): Dimension of the output latent vector.
            num_layers (int): Number of GNN layers.
            metadata (tuple): Metadata for HeteroConv, typically (node_types, edge_types).
        """
        super().__init__()
        self.encoder = HeteroGNNEncoder(hidden_channels, out_channels, num_layers, metadata)

    def forward(self, graphs_batch):
        """
        Forward pass for a batch of graph sequences.

        Args:
            graphs_batch (list of list of HeteroData):
                Outer list has length batch_size.
                Each inner list has length sequence_size, containing HeteroData graphs.

        Returns:
            Tensor: Latent representations of shape (batch_size, sequence_size, out_channels).

        Raises:
            ValueError: If ``graphs_batch`` is empty, a sequence is empty, or the
                sequences do not all have the same length.
        """
        lengths = [len(sequence) for sequence in graphs_batch]
        # view() below silently produces a wrongly shaped tensor for ragged input
        # whenever the total number of elements happens to divide evenly
        # (e.g. lengths [4, 2]), so validate the layout up front.
        if not lengths or lengths[0] == 0 or len(set(lengths)) != 1:
            raise ValueError(
                "graphs_batch must be a non-empty list of non-empty sequences of equal "
                f"length; got sequence lengths {lengths}."
            )
        batch_size, sequence_size = len(lengths), lengths[0]

        # Flatten sequence by sequence so that view() restores the original layout.
        all_graphs = [graph for sequence in graphs_batch for graph in sequence]
        batch = Batch.from_data_list(all_graphs)

        encoded = self.encoder(batch)  # (batch_size * sequence_size, out_channels)
        return encoded.view(batch_size, sequence_size, -1)

# Example Usage
if __name__ == "__main__":
    # Example metadata: define node types and edge types
    node_types = ['author', 'paper', 'institution']
    edge_types = [
        ('author', 'writes', 'paper'),
        ('paper', 'cites', 'paper'),
        ('paper', 'affiliated_with', 'institution'),
        ('institution', 'hosts', 'author'),
    ]
    metadata = (node_types, edge_types)

    # Initialize the encoder
    encoder = SequenceGNNEncoder(hidden_channels=64, out_channels=128, num_layers=2, metadata=metadata)

    def create_dummy_heterodata():
        """Return a small random HeteroData graph.

        When ``torch.randint(0, 10, (1,)) > 5`` (4 of 10 outcomes) the graph has
        10 authors / 20 papers / 5 institutions; otherwise 9 / 19 / 4. Batched
        graphs therefore differ in size. Feature widths are 32 / 64 / 16. The edges
        are fixed and only use low node indices, which are valid for both sizes.
        """
        data = HeteroData()
        if torch.randint(0, 10, (1,)) > 5:
            num_authors, num_papers, num_institutions = 10, 20, 5
        else:
            num_authors, num_papers, num_institutions = 9, 19, 4
        data['author'].x = torch.randn(num_authors, 32)
        data['paper'].x = torch.randn(num_papers, 64)
        data['institution'].x = torch.randn(num_institutions, 16)

        # Example edges
        data['author', 'writes', 'paper'].edge_index = torch.tensor([
            [0, 1, 2, 3],
            [0, 1, 2, 3],
        ], dtype=torch.long)

        data['paper', 'cites', 'paper'].edge_index = torch.tensor([
            [0, 1, 2],
            [1, 2, 3],
        ], dtype=torch.long)

        data['paper', 'affiliated_with', 'institution'].edge_index = torch.tensor([
            [0, 1, 2],
            [0, 1, 2],
        ], dtype=torch.long)

        data['institution', 'hosts', 'author'].edge_index = torch.tensor([
            [0, 1],
            [0, 1],
        ], dtype=torch.long)

        return data

    batch_size = 4
    sequence_size = 5

    # Create a batch of sequences: batch_size lists of sequence_size graphs each.
    graphs_batch = [
        [create_dummy_heterodata() for _ in range(sequence_size)]
        for _ in range(batch_size)
    ]

    # Move to device if needed
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = encoder.to(device)

    # HeteroData.to() moves every stored tensor (features and edge indices).
    # Assigning into data.x_dict / data.edge_index_dict would be a no-op, because
    # those attributes build a fresh dict on every access.
    graphs_batch = [[graph.to(device) for graph in sequence] for sequence in graphs_batch]

    # Encode the batch
    with torch.no_grad():
        latent_representations = encoder(graphs_batch)  # (batch_size, sequence_size, 128)

    print(latent_representations.shape)  # Should print torch.Size([4, 5, 128])

1
2
2
1
2
1
2
2
1
1
2
1
1
2
1
2
2
2
1
1
graph_emb.shape=torch.Size([20, 64])
torch.Size([4, 5, 128])


In [25]:
def debug_print(arg):
    """Print *arg* as ``arg=<repr(arg)>``, the same text ``print(f"{arg=}")`` emits.

    The label is always the literal parameter name ``arg``, not the caller's
    variable name: ``debug_print("kuk")`` prints ``arg='kuk'``.
    """
    print("arg=" + repr(arg))

In [13]:
kuk = 34
# debug_print labels its output with its own parameter name, so this prints
# arg='kuk' (the string that was passed), not the value of kuk.
debug_print(arg="kuk")

arg='kuk'


In [43]:
# Toy integer tensor of shape (2, 4, 3): two stacked 4x3 slices, used to check
# how .sum(dim=0) collapses a stacked leading axis.
ar = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]],
    [[11, 12, 13], [41, 53, 60], [31, 42, 53], [46, 5, 16]],
])

In [44]:
# Displays torch.Size([2, 4, 3]).
ar.size()

torch.Size([2, 4, 3])

In [45]:
# Reducing over dim 0 collapses the leading axis, (2, 4, 3) -> (4, 3), by adding the
# two 4x3 slices element-wise. HeteroGNNEncoder.forward does the same across node types.
ar.sum(0)

tensor([[12, 14, 16],
        [45, 58, 66],
        [32, 44, 56],
        [50, 10, 22]])

In [None]:
def _to_device_and_trim(dem, ob, ac, l, scores, idx, device):
    """Move one loader batch to ``device`` and cut the time axis to the longest sequence.

    ``max_length`` is ``int(l.max())``. ``dem``, ``ob``, ``ac`` and ``scores`` are
    sliced as ``[:, :max_length, :]``, so they must have at least 3 dimensions.
    Returns ``(dem, ob, ac, l, scores, idx, max_length)``.
    """
    # Per the original notes: dem holds the static demographics
    # (gender, ventilation, re-admission, age, weight), ob the 33 time-varying
    # measures, and ac the actions.
    dem, ob, ac, l, scores, idx = (tensor.to(device) for tensor in (dem, ob, ac, l, scores, idx))
    max_length = int(l.max().item())
    return (
        dem[:, :max_length, :],
        ob[:, :max_length, :],
        ac[:, :max_length, :],
        l,
        scores[:, :max_length, :],
        idx,
        max_length,
    )


def _checkpoint_dict(trainer, epoch):
    """Return the checkpoint dictionary saved by ``train_loop`` for ``epoch``.

    It contains the gen/pred/optimizer state dicts and both loss histories. For the
    'DDM' autoencoder it also contains the ``dyn`` state dict.
    """
    save_dict = {
        'epoch': epoch,
        'gen_state_dict': trainer.gen.state_dict(),
        'pred_state_dict': trainer.pred.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'loss': trainer.autoencoding_losses,
        'validation_loss': trainer.autoencoding_losses_validation,
    }
    if trainer.autoencoder == 'DDM':
        save_dict['dyn_state_dict'] = trainer.dyn.state_dict()
    return save_dict


def train_loop(self):
    """Train the selected autoencoder, validating and checkpointing every ``saving_period`` epochs.

    This code was lifted out of a trainer class, so ``self`` must be that trainer
    object. The previous zero-argument version failed with ``NameError`` on ``self``.

    The function reads the trainer's configuration (``autoencoder``,
    ``autoencoder_num_epochs``, loaders, ``gen``/``pred``/``container``,
    ``optimizer``, file paths, loss coefficients). Each epoch it appends that epoch's
    per-batch training losses to ``self.autoencoding_losses``. On validation epochs
    it appends the validation losses to ``self.autoencoding_losses_validation``.
    After the final epoch it writes the gen/pred weights (plus dyn for 'DDM') and a
    final checkpoint. File-writing failures are printed, not raised.

    NOTE(review): ``device`` is read from notebook/module globals; confirm it is
    defined before calling.
    """
    epoch_0 = 0
    for epoch in range(epoch_0, self.autoencoder_num_epochs):
        epoch_loss = []
        print("Experiment: autoencoder {0}: training Epoch = ".format(self.autoencoder), epoch+1, 'out of', self.autoencoder_num_epochs, 'epochs')

        # Loop through all the train data using the data loader.
        # t and rewards are not used by the autoencoder losses.
        for dem, ob, ac, l, _t, scores, _rewards, idx in self.train_loader:
            dem, ob, ac, l, scores, idx, max_length = _to_device_and_trim(dem, ob, ac, l, scores, idx, device)

            # Set training mode (this only sets the mode; it does not train).
            self.gen.train()
            self.pred.train()

            if self.autoencoder == 'CDE':
                loss_pred, _, _ = self.container.loop(
                    ob, dem, ac, scores, l, max_length, self.context_input,
                    corr_coeff_param=self.corr_coeff_param, device=device,
                    coefs=self.train_coefs, idx=idx)
            else:
                loss_pred, _, _ = self.container.loop(
                    ob, dem, ac, scores, l, max_length, self.context_input,
                    corr_coeff_param=self.corr_coeff_param, device=device,
                    autoencoder=self.autoencoder)

            self.optimizer.zero_grad()
            if self.autoencoder != 'DDM':
                loss_pred.backward()
            else:
                # DDM returns a tuple of loss terms. The optimised objective is rebuilt
                # here, and the tuple's last element is the value that gets logged.
                _, dec_loss, inv_loss, _, _, forward_loss, corr_loss, loss_pred = loss_pred
                train_loss = (forward_loss
                              + self.inv_loss_coef * inv_loss
                              + self.dec_loss_coef * dec_loss
                              - self.corr_coeff_param * corr_loss.sum())
                train_loss.backward()
                # Clip gradients to prevent exploding gradients (clip_grad_norm is deprecated).
                torch.nn.utils.clip_grad_norm_(self.all_params, self.max_grad_norm)
            self.optimizer.step()
            epoch_loss.append(loss_pred.detach().cpu().numpy())

        self.autoencoding_losses.append(epoch_loss)

        if (epoch + 1) % self.saving_period == 0:  # Run validation and also save checkpoint
            epoch_validation_loss = []
            with torch.no_grad():
                for dem, ob, ac, l, _t, scores, _rewards, idx in self.val_loader:
                    dem, ob, ac, l, scores, idx, max_length = _to_device_and_trim(dem, ob, ac, l, scores, idx, device)

                    self.gen.eval()
                    self.pred.eval()

                    # Pass self.context_input exactly as in training so that validation
                    # runs the same model configuration.
                    if self.autoencoder == 'CDE':
                        loss_val, mse_loss, _ = self.container.loop(
                            ob, dem, ac, scores, l, max_length, self.context_input,
                            corr_coeff_param=0, device=device,
                            coefs=self.val_coefs, idx=idx)
                    else:
                        loss_val, mse_loss, _ = self.container.loop(
                            ob, dem, ac, scores, l, max_length, self.context_input,
                            corr_coeff_param=0, device=device,
                            autoencoder=self.autoencoder)

                    if self.autoencoder in ['DST', 'ODERNN', 'CDE']:
                        epoch_validation_loss.append(mse_loss)
                    elif self.autoencoder == "DDM":
                        epoch_validation_loss.append(loss_val[-1].detach().cpu().numpy())
                    else:
                        epoch_validation_loss.append(loss_val.detach().cpu().numpy())

            self.autoencoding_losses_validation.append(epoch_validation_loss)

            save_dict = _checkpoint_dict(self, epoch)
            try:
                torch.save(save_dict, self.checkpoint_file)
                np.save(self.data_folder + '/{}_losses.npy'.format(self.autoencoder.lower()), np.array(self.autoencoding_losses))
            except Exception as e:
                print(e)

            try:
                np.save(self.data_folder + '/{}_validation_losses.npy'.format(self.autoencoder.lower()), np.array(self.autoencoding_losses_validation))
            except Exception as e:
                print(e)

    # Final checkpoint: written once, after the last epoch has finished.
    try:
        save_dict = _checkpoint_dict(self, self.autoencoder_num_epochs - 1)
        if self.autoencoder == 'DDM':
            torch.save(self.dyn.state_dict(), self.dyn_file)
        torch.save(self.gen.state_dict(), self.gen_file)
        torch.save(self.pred.state_dict(), self.pred_file)
        torch.save(save_dict, self.checkpoint_file)
        np.save(self.data_folder + '/{}_losses.npy'.format(self.autoencoder.lower()), np.array(self.autoencoding_losses))
    except Exception as e:
        print(e)
               

In [None]:
# Per-node-type timestamps, read by time_select_up_to_t through `node_time_dict`.
# NOTE(review): `dataset` is not defined in the visible cells, and each tensor's
# length must equal that node type's number of rows in `x`. Confirm both.
dataset.node_time_dict = dict(
    user=torch.tensor([1, 3, 5, 7]),  # Timestamps for user nodes
    item=torch.tensor([2, 4, 6, 8]),  # Timestamps for item nodes
)

In [None]:
def _collect_or_empty(data, attr):
    """Return ``getattr(data, attr)``, or ``{}`` when nothing is stored under that name.

    Recent PyG versions resolve unknown ``*_dict`` attributes on ``HeteroData`` by
    collecting the per-type attribute. When no type has it they may raise
    ``KeyError`` rather than ``AttributeError``, so both are treated as "absent".
    """
    try:
        return getattr(data, attr)
    except (AttributeError, KeyError):
        return {}


# @timeit
def time_select_up_to_t(dataset, t):
    """Select nodes and edges up to and including time step ``t``.

    Node types:
    - Types listed in ``dataset.node_time_dict`` keep only the rows of ``x`` whose
      timestamp is ``<= t``.
    - Node types without timestamps are copied unchanged.

    Edge types:
    - Types with timestamps (``dataset.edge_time_dict``) keep only the edges whose
      time is ``<= t``.
    - Edge types without timestamps are copied unchanged, mirroring the node
      handling.
    - Edges whose source or destination node was filtered out are dropped. The
      remaining edge indices are renumbered to point at the filtered node rows.

    @param dataset: HeteroData to slice.
    @param t: time index (inclusive), compared with ``<=`` against the timestamps.
    @return: new HeteroData with ``x`` per node type (and ``num_nodes`` when the
        input stored it explicitly) and ``edge_index`` per edge type (plus
        ``edge_time`` for timed edge types).
    """
    dt = HeteroData()
    d = dataset

    node_time_dict = _collect_or_empty(d, "node_time_dict")
    edge_time_dict = _collect_or_empty(d, "edge_time_dict")

    # new_index[ntype][old_row] -> row in the filtered tensor, or -1 if the node was dropped.
    new_index = {}

    for ntype, value in _collect_or_empty(d, "x_dict").items():
        has_num_nodes = "num_nodes" in d[ntype].keys()
        if ntype in node_time_dict:
            # reshape(-1) instead of squeeze(-1): squeezing a single-node [1] time
            # tensor gives a 0-d mask, which would add a dimension when indexing.
            mask = (node_time_dict[ntype] <= t).reshape(-1)
            kept = int(mask.sum())
            dt[ntype].x = value[mask]
            remap = torch.full((mask.numel(),), -1, dtype=torch.long, device=mask.device)
            remap[mask] = torch.arange(kept, device=mask.device)
            new_index[ntype] = remap
            if has_num_nodes:
                dt[ntype].num_nodes = kept
        else:  # No time information: keep every node.
            dt[ntype].x = value
            if has_num_nodes:
                dt[ntype].num_nodes = d[ntype].num_nodes

    for etype, edge_index in _collect_or_empty(d, "edge_index_dict").items():
        src, _, dst = etype
        edge_time = edge_time_dict.get(etype)
        if edge_time is not None:
            keep = (edge_time <= t).reshape(-1)  # Include all edges with time <= t
            edge_index = edge_index[:, keep]
            edge_time = edge_time[keep]

        # Renumber endpoints of filtered node types; -1 marks a removed endpoint.
        row, col = edge_index[0], edge_index[1]
        if src in new_index:
            row = new_index[src].to(row.device)[row]
        if dst in new_index:
            col = new_index[dst].to(col.device)[col]
        valid = (row >= 0) & (col >= 0)

        dt[etype].edge_index = torch.stack([row, col], dim=0)[:, valid]
        if edge_time is not None:
            dt[etype].edge_time = edge_time[valid]

    return dt