In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn.glob import global_add_pool

In [None]:
class GCN(torch.nn.Module):
    """Graph Convolutional Network producing graph-level (or node-level) outputs.

    Architecture, as built in ``__init__`` and applied in ``forward``:

    1. A stack of ``GCNConv`` layers
       (``input_dim -> gnn_latent_dim[0] -> ... -> gnn_latent_dim[-1]``),
       each followed by ReLU and dropout with probability ``gnn_dropout``.
    2. A per-node stack of ``nn.Linear`` layers
       (``gnn_latent_dim[-1] -> *fc_latent_dim -> output_dim``); every layer
       except the last is followed by ReLU and dropout with probability
       ``fc_dropout``.
    3. Unless ``forward`` is called with ``node_regression=True``, the per-node
       outputs are summed per graph with ``global_add_pool``.
    """

    @staticmethod
    def _default_batch(x):
        """Return a batch vector assigning every node of ``x`` to graph 0."""
        return torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)

    def arguments_read(self, *args, **kwargs):
        """Normalize the supported ``forward`` call forms to a 4-tuple.

        Accepted forms:

        * ``(data)`` or ``(data=data)``: an object exposing ``x`` and
          ``edge_index`` and optionally ``edge_attr`` / ``batch``;
        * ``(x, edge_index)`` with optional ``edge_attr=`` / ``batch=`` kwargs;
        * ``(x, edge_index, edge_attr)`` with optional ``batch=`` kwarg;
        * ``(x, edge_index, edge_attr, batch)``;
        * keyword-only ``x=``, ``edge_index=``, ``edge_attr=``, ``batch=``.

        Returns:
            ``(x, edge_index, edge_attr, batch)``. ``edge_attr`` may be None.
            If no batch vector is supplied (or it is None), an all-zeros
            int64 vector on ``x``'s device is returned, i.e. all nodes are
            treated as one graph.

        Raises:
            ValueError: more than four positional arguments were given.
            AssertionError: ``x`` or ``edge_index`` could not be found.
        """
        if args:
            n_args = len(args)
            if n_args == 1:
                data = args[0]
                x, edge_index = data.x, data.edge_index
                # getattr with a default tolerates objects lacking these fields.
                edge_attr = getattr(data, 'edge_attr', None)
                batch = getattr(data, 'batch', None)
            elif n_args == 2:
                x, edge_index = args
                edge_attr = kwargs.get('edge_attr')
                batch = kwargs.get('batch')
            elif n_args == 3:
                x, edge_index, edge_attr = args
                batch = kwargs.get('batch')
            elif n_args == 4:
                x, edge_index, edge_attr, batch = args
            else:
                raise ValueError(f"forward's args should take 1, 2, 3 or 4 arguments but got {n_args}")
        else:
            data = kwargs.get('data')
            if data is None:
                x = kwargs.get('x')
                edge_index = kwargs.get('edge_index')
                edge_attr = kwargs.get('edge_attr')
                batch = kwargs.get('batch')
            else:
                x = data.x
                edge_index = data.edge_index
                edge_attr = getattr(data, 'edge_attr', None)
                batch = getattr(data, 'batch', None)

        # Validate before touching x.shape so a missing x gives this message
        # rather than an AttributeError on None.
        assert x is not None, "forward's args is empty and required node features x is not in kwargs"
        assert edge_index is not None, "forward's args is empty and required edge_index is not in kwargs"

        # Use an identity check: a multi-element tensor has no truth value.
        if batch is None:
            batch = self._default_batch(x)

        return x, edge_index, edge_attr, batch

    def __init__(self, input_dim=1, output_dim=1, gnn_latent_dim=(128, 128, 128, 128, 128, 128), gnn_dropout=0, fc_latent_dim=(), fc_dropout=0):
        """Build the GCN.

        Args:
            input_dim: number of input node features.
            output_dim: size of the per-node (and pooled) output.
            gnn_latent_dim: output width of each ``GCNConv`` layer; its length
                is the number of message-passing layers (must be non-empty).
            gnn_dropout: dropout probability after each ``GCNConv`` layer.
            fc_latent_dim: widths of the hidden ``nn.Linear`` layers; may be
                empty, giving a single linear output layer.
            fc_dropout: dropout probability after each hidden linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Copy into fresh lists so neither a shared default nor the caller's
        # sequence is aliased by this instance.
        self.gnn_latent_dim = list(gnn_latent_dim)
        self.gnn_dropout = gnn_dropout
        self.num_gnn_layers = len(self.gnn_latent_dim)

        self.fc_latent_dim = list(fc_latent_dim)
        self.fc_dropout = fc_dropout
        # Hidden layers plus the final output layer.
        self.num_fc_layers = len(self.fc_latent_dim) + 1

        # input_dim -> gnn_latent_dim[0] -> ... -> gnn_latent_dim[-1]
        gnn_dims = [input_dim] + self.gnn_latent_dim
        self.convs = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(gnn_dims[:-1], gnn_dims[1:])
        )

        # gnn_latent_dim[-1] -> *fc_latent_dim -> output_dim
        fc_dims = [self.gnn_latent_dim[-1]] + self.fc_latent_dim + [self.output_dim]
        self.fcs = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(fc_dims[:-1], fc_dims[1:])
        )

    def forward(self, *args, **kwargs):
        """Run the network.

        Accepts every call form documented on ``arguments_read``. Pass
        ``node_regression=True`` to get per-node outputs of shape
        ``(num_nodes, output_dim)``. Otherwise per-node outputs are summed
        per graph via ``global_add_pool(x, batch)``.
        """
        x, edge_index, edge_attr, batch = self.arguments_read(*args, **kwargs)

        # edge_attr is forwarded as GCNConv's edge_weight.
        # NOTE(review): GCNConv documents edge_weight as one scalar per edge;
        # confirm edge_attr has that shape for the datasets used.
        for conv in self.convs:
            x = conv(x, edge_index, edge_weight=edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.gnn_dropout, training=self.training)

        # Hidden MLP layers are applied node-wise; the last layer is linear only.
        for fc in self.fcs[:-1]:
            x = fc(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.fc_dropout, training=self.training)
        x = self.fcs[-1](x)

        if kwargs.get('node_regression', False):
            return x
        return global_add_pool(x, batch)