In [1]:
!pip install torch_geometric
!pip install torch-scatter
!pip install gudhi
!pip install torchdiffeq
!pip install scikit-optimize

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: Operation cancelled by user[0m[31m
[0mTraceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/pip/_internal/cli/base_command.py", line 179, in exc_logging_wrapper
    status = run_func(*args)
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pip/_internal/cli/req_command.py", line 67, in wrapper
    return func(self, options, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-package

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.transforms import NormalizeFeatures
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
import numpy as np

from torch_geometric.datasets import Planetoid, WebKB, WikipediaNetwork

##############################################
# Функция загрузки датасета по имени
##############################################
def load_dataset(dataset_name, split=0):
    """Load a node-classification benchmark by name with row-normalised features.

    Args:
        dataset_name: one of "Cora", "Citeseer", "Pubmed" (Planetoid), "Texas",
            "Wisconsin", "Cornell" (WebKB), "Squirrel", "Chameleon"
            (WikipediaNetwork) or "Film" (PyG's ``Actor`` dataset).
        split: column to keep when the dataset ships 2-D masks (one column per
            predefined split). Ignored for datasets with 1-D masks.

    Returns:
        (dataset, data): the PyG dataset object and its first graph, with 1-D masks.

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
    """
    planetoid_names = ("Cora", "Citeseer", "Pubmed")
    webkb_names = ("Texas", "Wisconsin", "Cornell")
    wikipedia_names = ("Squirrel", "Chameleon")
    supported = planetoid_names + webkb_names + wikipedia_names + ("Film",)
    # Validate first so an unknown name fails fast, before any download/processing.
    if dataset_name not in supported:
        raise ValueError(f"Unknown dataset {dataset_name}")

    transform = NormalizeFeatures()
    if dataset_name in planetoid_names:
        dataset = Planetoid(root=f'data/Planetoid/{dataset_name}', name=dataset_name, transform=transform)
    elif dataset_name in webkb_names:
        dataset = WebKB(root='data/WebKB', name=dataset_name, transform=transform)
    elif dataset_name in wikipedia_names:
        dataset = WikipediaNetwork(root='data/WikipediaNetwork', name=dataset_name.lower(), transform=transform)
    else:
        # "Film" is not a WikipediaNetwork dataset (that class only accepts
        # chameleon/crocodile/squirrel); PyG exposes it as the Actor dataset.
        from torch_geometric.datasets import Actor
        dataset = Actor(root='data/Actor', transform=transform)

    data = dataset[0]
    # 2-D masks hold one column per predefined split; keep only the requested one.
    if hasattr(data, "train_mask") and data.train_mask.ndim > 1:
        data.train_mask = data.train_mask[:, split]
        data.val_mask = data.val_mask[:, split]
        data.test_mask = data.test_mask[:, split]
    return dataset, data

##############################################
# Функции для вычисления базисов и связей на ребрах
##############################################
def compute_node_bases(node_feat, edge_index, d):
    """Estimate a d x d local frame for every node from its outgoing neighbourhood.

    For node i, the offsets ``node_feat[j] - node_feat[i]`` over all edges (i, j)
    are stacked and their right singular vectors (``Vh`` of a thin SVD) are used
    as the frame. Nodes with fewer than ``d`` outgoing neighbours, or whose SVD
    fails, get the identity.

    Args:
        node_feat: node features indexed as ``node_feat[i]``; presumably (N, d) —
            the Vh shape is only (d, d) when the feature width equals ``d``.
        edge_index: (2, E) integer tensor; row 0 is the source node i, row 1 the neighbour j.
        d: frame size.

    Returns:
        Tensor of shape (N, d, d), one frame per node.
    """
    num_nodes = node_feat.size(0)
    device = node_feat.device
    if num_nodes == 0:
        # torch.cat/stack of an empty list would raise; return an empty batch instead.
        return torch.empty((0, d, d), device=device)

    row, col = edge_index
    # Group neighbours by source node with a single stable sort (O(E log E)) instead of
    # scanning every edge once per node (O(N * E)). Stability keeps each node's
    # neighbours in their original edge order.
    order = torch.sort(row, stable=True).indices
    counts = torch.bincount(row, minlength=num_nodes).tolist()
    neighbour_groups = torch.split(col[order], counts)

    identity = torch.eye(d, device=device)
    bases = []
    for i, neighbour_idx in zip(range(num_nodes), neighbour_groups):
        if neighbour_idx.numel() < d:
            bases.append(identity)
            continue
        diffs = node_feat[neighbour_idx] - node_feat[i].unsqueeze(0)
        try:
            _, _, vh = torch.linalg.svd(diffs, full_matrices=False)
        except RuntimeError:
            # torch.linalg.LinAlgError (non-convergence) subclasses RuntimeError.
            vh = identity
        bases.append(vh)
    return torch.stack(bases, dim=0)

def compute_edge_conn(bases, edge_index):
    """Compute an orthogonal connection map for every edge.

    For edge (i, j), the product ``O_i^T O_j`` of the two node frames is projected
    onto the nearest orthogonal matrix ``U @ Vh`` (the polar factor from its SVD).

    Args:
        bases: per-node frames indexed as ``bases[i]``, e.g. the (N, d, d) output
            of ``compute_node_bases``.
        edge_index: (2, E) integer tensor of (i, j) node pairs.

    Returns:
        Tensor of shape (E, d, d). If an SVD fails, the identity is used for that edge.
    """
    row, col = edge_index
    prod = torch.matmul(bases[row].transpose(-2, -1), bases[col])
    if prod.size(0) == 0:
        # No edges: prod is already an empty (0, d, d) batch; torch.cat([]) would raise.
        return prod

    # Fast path: one batched SVD over all edges instead of a Python loop.
    try:
        u, _, vh = torch.linalg.svd(prod)
        return torch.matmul(u, vh)
    except RuntimeError:
        # torch.linalg.LinAlgError subclasses RuntimeError; fall back to per-edge SVD
        # so a single degenerate edge does not discard every other edge's map.
        pass

    identity = torch.eye(prod.size(-1), device=prod.device)
    edge_conn = []
    for m in prod:
        try:
            u, _, vh = torch.linalg.svd(m)
            edge_conn.append(torch.matmul(u, vh))
        except RuntimeError:
            edge_conn.append(identity)
    return torch.stack(edge_conn, dim=0)

##############################################
# Модель обучения
##############################################
class ConnSheafLayer(nn.Module):
    """One connection-sheaf diffusion step.

    Each node averages its neighbours' features after transporting them through
    the per-edge connection maps, then applies ReLU followed by LayerNorm.
    """

    def __init__(self, d, f_dim):
        super().__init__()
        self.d, self.f_dim = d, f_dim
        # Normalises jointly over the last two (d, f_dim) dimensions of each node.
        self.ln = nn.LayerNorm([d, f_dim])

    def forward(self, x, edge_index, edge_conn):
        """Apply one diffusion step.

        Args:
            x: node features; assumed (N, d, f_dim) — the degree is reshaped to
                (N, 1, 1) and LayerNorm normalises over (d, f_dim).
            edge_index: (2, E) integer tensor; features of edge_index[1] nodes are
                aggregated into edge_index[0] nodes.
            edge_conn: per-edge maps, matrix-multiplied onto the sender features.

        Returns:
            Tensor shaped like ``x``.
        """
        dst, src = edge_index
        transported = edge_conn @ x[src]
        summed = torch.zeros_like(x).index_add(0, dst, transported)
        # Mean aggregation; isolated receivers are divided by 1 rather than 0.
        in_degree = torch.bincount(dst, minlength=x.size(0)).float().clamp(min=1.0)
        mean_msg = summed / in_degree.view(-1, 1, 1)
        return self.ln(F.relu(mean_msg))

class ConnSheafNet(nn.Module):
    """Connection-sheaf GNN for node classification.

    Pipeline: linear encoder -> replicate each d-dim embedding f_dim times ->
    ``depth`` residual ConnSheafLayer steps -> MLP decoder -> log-probabilities.

    Args:
        in_dim: input feature width.
        d: frame/embedding size.
        f_dim: number of feature channels per frame dimension.
        out_dim: number of classes.
        depth: number of ConnSheafLayer steps.
        hidden_dim: width of the decoder's hidden layer (previously fixed at 64).
        dropout: dropout probability in the decoder (previously fixed at 0.3).
    """

    def __init__(self, in_dim, d, f_dim, out_dim, depth=6, hidden_dim=64, dropout=0.3):
        super().__init__()
        self.encoder = nn.Linear(in_dim, d)
        self.d = d
        self.f_dim = f_dim
        self.layers = nn.ModuleList([ConnSheafLayer(d, f_dim) for _ in range(depth)])
        self.decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(d * f_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, data):
        """Return per-node log-probabilities (log_softmax over dim 1).

        Args:
            data: object with ``x`` (node features) and ``edge_index`` attributes.
        """
        x, edge_index = data.x, data.edge_index
        node_feat = self.encoder(x)
        # Frames and connections are recomputed from the current encoder output on
        # every forward pass and treated as constants (no gradient flows through them).
        with torch.no_grad():
            bases = compute_node_bases(node_feat, edge_index, self.d)
            edge_conn = compute_edge_conn(bases, edge_index)
        # (N, d) -> (N, d, f_dim): each frame coordinate starts with f_dim identical channels.
        x_sheaf = node_feat.unsqueeze(2).repeat(1, 1, self.f_dim)
        for layer in self.layers:
            x_sheaf = layer(x_sheaf, edge_index, edge_conn) + x_sheaf  # residual
        out = self.decoder(x_sheaf)
        return F.log_softmax(out, dim=1)

##############################################
# Функции обучения и тестирования
##############################################
def train(model, data, optimizer):
    """Run one full-batch optimisation step on the training nodes.

    Args:
        model: module mapping ``data`` to per-node log-probabilities.
        data: object with ``y`` labels and a boolean ``train_mask``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The training NLL loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    train_idx = data.train_mask
    log_probs = model(data)
    loss = F.nll_loss(log_probs[train_idx], data.y[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test(model, data):
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask].eq(data.y[mask]).sum().item()
        acc = correct / mask.sum().item()
        accs.append(acc)
    return accs

##############################################
# Целевая функция для байесовской оптимизации
##############################################
def objective(params, dataset, data, device, epochs=100, seed=0):
    """Train a ConnSheafNet with the given hyperparameters and score it for skopt.

    Args:
        params: sequence ``(lr, weight_decay, d, f_dim, depth)`` as proposed by skopt.
        dataset: PyG dataset; supplies ``num_node_features`` and ``num_classes``.
        data: graph already moved to ``device``.
        device: torch device the model is placed on.
        epochs: number of full-batch training steps.
        seed: torch RNG seed applied before building the model, so every trial
            starts from comparable initialisation and the GP sees less noise.
            Pass ``None`` to leave the RNG untouched.

    Returns:
        Negative validation accuracy (gp_minimize minimises).
    """
    lr, weight_decay, d, f_dim, depth = params
    d, f_dim, depth = int(d), int(f_dim), int(depth)

    if seed is not None:
        torch.manual_seed(seed)

    model = ConnSheafNet(
        in_dim=dataset.num_node_features,
        d=d,
        f_dim=f_dim,
        out_dim=dataset.num_classes,
        depth=depth
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for _ in range(epochs):
        train(model, data, optimizer)

    # test() evaluates under torch.no_grad(), so no autograd graph is built for scoring.
    val_acc = test(model, data)[1]
    print(f"Params: lr={lr:.1e}, wd={weight_decay:.1e}, d={d}, f_dim={f_dim}, depth={depth} => Val Acc: {val_acc:.4f}")
    return -val_acc

##############################################
#перебор датасетов и подбор гиперпараметров
##############################################
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Datasets to evaluate, in this order.
    dataset_names = ["Texas", "Wisconsin", "Film", "Squirrel", "Chameleon",
                     "Cornell", "Citeseer", "Pubmed", "Cora"]

    # Search space, in the order objective() unpacks it:
    # lr, weight_decay, d, f_dim, depth
    space = [
        Real(1e-5, 1e-1, prior='log-uniform', name='lr'),
        Real(1e-5, 1e-3, prior='log-uniform', name='weight_decay'),
        Integer(8, 64, name='d'),
        Integer(2, 8, name='f_dim'),
        Integer(4, 10, name='depth')
    ]

    for ds_name in dataset_names:
        print(f"\n{'='*40}\nDataset: {ds_name}\n{'='*40}")
        try:
            dataset, data = load_dataset(ds_name)
        except Exception as e:
            # Best-effort: report the failure and move on to the next dataset.
            print(f"Ошибка при загрузке датасета {ds_name}: {e}")
            continue
        data = data.to(device)

        def evaluate(params):
            # Closure over this iteration's dataset/data; gp_minimize calls it synchronously.
            return objective(params, dataset, data, device)

        res = gp_minimize(evaluate, dimensions=space, n_calls=30, random_state=42, verbose=True)
        lr_best, wd_best, d_best, f_dim_best, depth_best = res.x
        print(f"Лучшие гиперпараметры для {ds_name}:")
        print(f"lr: {lr_best:.1e}, wd: {wd_best:.1e}, d: {d_best}, "
              f"f_dim: {f_dim_best}, depth: {depth_best}")
        # objective() returns negative accuracy, so negate the minimum back.
        print(f"Валидационная точность: {-res.fun:.4f}")
