In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import pathlib

# Resolve the data folder as a sibling of the current working directory (i.e. ../data).
# Presumably the kernel is started from the notebook's own folder — TODO confirm.
data_dir = pathlib.Path.cwd().resolve().parent.joinpath("data")

In [None]:
from torch_geometric.data import Data

def create_group_data(group_features, label, n_members):
    """Build a PyTorch Geometric graph for one group, connecting every member to every other.

    Args:
        group_features: per-member node features, converted with
            ``torch.tensor(..., dtype=torch.float)``. Its first dimension must equal
            ``n_members`` (presumably shape ``(n_members, num_features)`` — TODO confirm
            against callers).
        label: group-level target, stored as a 1-element float tensor in ``y``.
        n_members: number of nodes (group members) in the graph.

    Returns:
        ``Data(x=x, edge_index=edge_index, y=y)`` where ``edge_index`` is a long tensor of
        shape ``(2, n_members * (n_members - 1))`` listing every ordered pair ``(i, j)``
        with ``i != j``, i.e. a complete directed graph without self-loops.

    Raises:
        ValueError: if the number of feature rows does not match ``n_members``.
    """
    x = torch.tensor(group_features, dtype=torch.float)
    # Guard against a graph whose edges reference nodes that have no feature row.
    if x.dim() == 0 or x.size(0) != n_members:
        raise ValueError(
            f"expected {n_members} rows of member features, got shape {tuple(x.shape)}"
        )

    # Fully connected graph: enumerate all (src, dst) pairs, then drop the i == j diagonal.
    # For 0 or 1 members this yields an empty edge_index of shape (2, 0).
    nodes = torch.arange(n_members, dtype=torch.long)
    src = nodes.repeat_interleave(n_members)  # 0,0,...,1,1,...
    dst = nodes.repeat(n_members)             # 0,1,...,0,1,...
    not_self = src != dst
    edge_index = torch.stack([src[not_self], dst[not_self]])

    y = torch.tensor([label], dtype=torch.float)
    return Data(x=x, edge_index=edge_index, y=y)

In [4]:
class GroupSuccessPredictor(nn.Module):
    """Two-layer GCN that mean-pools node embeddings per graph and emits sigmoid scores.

    Args:
        num_features: size of each input node feature vector.
        hidden_dim: width of both GCN layers and of the pooled graph embedding.
        output_dim: number of output scores per graph.
    """

    def __init__(self, num_features, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are kept stable: they define the state_dict keys.
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Score each graph in ``data``.

        Args:
            data: object exposing ``x``, ``edge_index`` and ``batch`` (e.g. a PyG
                ``Data``/``Batch``).

        Returns:
            Tensor of values in [0, 1] after a sigmoid; presumably shape
            ``(num_graphs, output_dim)`` — TODO confirm for unbatched input where
            ``data.batch`` is None.
        """
        node_repr = data.x
        for conv in (self.conv1, self.conv2):
            node_repr = conv(node_repr, data.edge_index).relu()
        graph_repr = global_mean_pool(node_repr, data.batch)
        # Output is already a probability, so pair it with a loss that expects
        # probabilities (e.g. BCELoss); BCEWithLogitsLoss would apply sigmoid twice.
        return self.fc(graph_repr).sigmoid()