# In [None]:
import os
import json
import torch
import torch.nn.functional as F
import pandas as pd
from torch import nn
from tqdm import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
from torch_geometric.nn import GATv2Conv as GATConv, SAGEConv, GCNConv


# --- Run configuration ------------------------------------------------------
# Compute device: CUDA when it is available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Layer widths handed to the encoders and the link predictor.
NODE_FEATURE_DIM = 778
HIDDEN_DIM = 128
EMBEDDING_DIM = 64

# Training schedule and sampling.
EPOCHS = 10
BATCH_SIZE = 10 * 1024
NEG_SAMPLE_RATIO = 2.0  # negatives drawn per positive edge during training
TRAIN_RATIO = 0.85      # fraction of edges kept for training

# Input datasets are read from DATA_DIR; trained weights go under MODEL_DIR.
DATA_DIR = 'data'
MODEL_DIR = 'models'


class GATEncoder(nn.Module):
    """Three-layer graph-attention encoder that maps node features to embeddings.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head output width of the first two layers.
        out_channels: Per-head output width of the last layer.
        heads: Number of attention heads for layers 1, 2 and 3, in that
            order. Only the first three entries are read.
        dropout: Probability forwarded to every attention layer and also
            applied to the activations between layers.
    """

    # The default for ``heads`` is a tuple (not a list) so the default object
    # cannot be mutated and shared between instances.
    def __init__(self, in_channels, hidden_channels, out_channels, heads=(8, 4, 1), dropout=0.2):
        super().__init__()
        self.dropout = dropout
        # Each layer's input width is the previous layer's per-head width
        # multiplied by its head count.
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads[0], dropout=dropout)
        self.gat2 = GATConv(hidden_channels * heads[0], hidden_channels, heads=heads[1], dropout=dropout)
        self.gat3 = GATConv(hidden_channels * heads[1], out_channels, heads=heads[2], dropout=dropout)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``.

        ELU and dropout follow the first two layers; the output of the last
        layer is returned without an activation.
        """
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.gat3(x, edge_index)


class GraphSAGEEncoder(nn.Module):
    """Node encoder made of three stacked ``SAGEConv`` layers.

    Layer widths run ``in_channels -> hidden_channels -> hidden_channels ->
    out_channels``. ``dropout`` is the probability applied after each of the
    two hidden layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the output of the third layer for node features ``x``."""
        hidden = x
        # Hidden layers: convolution, ReLU, then dropout (active only in
        # training mode).
        for layer in (self.conv1, self.conv2):
            hidden = F.relu(layer(hidden, edge_index))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        # The final layer is left without activation or dropout.
        return self.conv3(hidden, edge_index)


class GCNEncoder(nn.Module):
    """Node encoder made of three stacked ``GCNConv`` layers.

    The first two layers produce ``hidden_channels`` features and are each
    followed by ReLU and dropout; the third produces ``out_channels``
    features and is returned as-is.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` using the edges in ``edge_index``."""

        def activate(tensor):
            # ReLU followed by dropout; dropout is a no-op in eval mode.
            return F.dropout(F.relu(tensor), p=self.dropout, training=self.training)

        first = activate(self.conv1(x, edge_index))
        second = activate(self.conv2(first, edge_index))
        return self.conv3(second, edge_index)


class LinkPredictor(nn.Module):
    """MLP that scores candidate edges from the embeddings of their endpoints.

    For every (src, dst) pair the input feature vector is the concatenation
    ``[src, dst, src * dst, <src, dst>]``, i.e. ``3 * embedding_dim + 1``
    values.

    Args:
        embedding_dim: Width of the node embeddings passed to ``forward``.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            is half as wide.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, embedding_dim, hidden_dim=128, dropout=0.3):
        super().__init__()
        input_dim = 3 * embedding_dim + 1
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, emb, edge_pairs):
        """Return one link probability per row of ``edge_pairs``.

        Args:
            emb: Node embedding matrix, indexed by node id along dim 0.
            edge_pairs: Integer tensor whose column 0 holds source node ids
                and column 1 holds destination node ids.

        Returns:
            1-D tensor of sigmoid scores with one entry per pair.
        """
        src = emb[edge_pairs[:, 0]]
        dst = emb[edge_pairs[:, 1]]
        # Element-wise product, reused for the dot-product feature.
        h_mul = src * dst
        h_dot = h_mul.sum(dim=1, keepdim=True)
        x = torch.cat([src, dst, h_mul, h_dot], dim=1)
        x = self.dropout(F.relu(self.bn1(self.fc1(x))))
        x = self.dropout(F.relu(self.bn2(self.fc2(x))))
        # squeeze(-1) drops only the trailing feature axis. A bare squeeze()
        # would also collapse a single-pair batch to a 0-dim scalar, which no
        # longer lines up with a 1-D label tensor.
        return torch.sigmoid(self.fc3(x)).squeeze(-1)


# Encoder architectures to benchmark, keyed by the short name that appears in
# progress bars, output directories and result records.
model_classes = {
    'GAT': GATEncoder,
    'SAGE': GraphSAGEEncoder,
    'GCN': GCNEncoder,
}


# Every '.pt' file in DATA_DIR is treated as one dataset. os.listdir returns
# entries in arbitrary order, so sort to make the run order reproducible.
dataset_files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith('.pt'))

# One metrics record per (model, dataset) combination.
all_results = []

# Train and evaluate every encoder architecture on every dataset, saving the
# trained weights and appending one metrics record per run to all_results.
for model_name, EncoderClass in model_classes.items():
    for dataset_file in dataset_files:
        dataset_path = os.path.join(DATA_DIR, dataset_file)
        # splitext removes only the trailing extension; str.replace would
        # also strip a '.pt' that occurs in the middle of the file name.
        dataset_name = os.path.splitext(dataset_file)[0]
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files. Newer PyTorch releases default to weights_only=True,
        # which may reject non-tensor objects; pass weights_only=False there
        # if loading fails -- TODO confirm against the installed version.
        pyg_data = torch.load(dataset_path)
        # Keep the per-dataset input width in a local instead of rebinding
        # the module-level NODE_FEATURE_DIM constant.
        in_channels = pyg_data.x.shape[1]
        edge_index = pyg_data.edge_index
        num_edges = edge_index.size(1)

        # Seeded permutation so every model is trained and tested on the same
        # edge split of a given dataset, which keeps their metrics comparable.
        # NOTE(review): if edge_index stores both directions of undirected
        # edges, the reverse of a test edge can land in the training split.
        split_rng = torch.Generator().manual_seed(0)
        perm = torch.randperm(num_edges, generator=split_rng)
        train_size = int(num_edges * TRAIN_RATIO)
        train_edges = edge_index[:, perm[:train_size]]
        test_edges = edge_index[:, perm[train_size:]]
        train_data = Data(x=pyg_data.x, edge_index=train_edges)

        # NOTE(review): two sampling hops feed three-layer encoders; consider
        # adding a third fan-out entry so the last layer sees new neighbours.
        train_loader = NeighborLoader(
            train_data,
            num_neighbors=[10, 5],
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        encoder = EncoderClass(in_channels, HIDDEN_DIM, EMBEDDING_DIM).to(DEVICE)
        predictor = LinkPredictor(EMBEDDING_DIM).to(DEVICE)

        optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(predictor.parameters()),
            lr=0.001,
            weight_decay=1e-4
        )

        # Training
        for epoch in range(1, EPOCHS + 1):
            encoder.train()
            predictor.train()
            total_loss = 0.0
            num_batches = 0

            for batch in tqdm(train_loader, desc=f"{model_name}-{dataset_name} Epoch {epoch}", leave=False):
                batch = batch.to(DEVICE)
                # Positives are the edges of the sampled subgraph, expressed
                # in the batch's local node indices.
                pos_edge_index = batch.edge_index.t()
                num_pos = pos_edge_index.size(0)
                if num_pos == 0:
                    # No edges were sampled, so there is nothing to supervise.
                    continue

                optimizer.zero_grad()
                out = encoder(batch.x.float(), batch.edge_index)

                neg_edge_index = negative_sampling(
                    edge_index=batch.edge_index,
                    num_nodes=batch.num_nodes,
                    num_neg_samples=int(num_pos * NEG_SAMPLE_RATIO),
                    method='sparse'
                ).t()
                # The sampler can return fewer negatives than requested, so
                # size the labels from what actually came back.
                num_neg = neg_edge_index.size(0)

                all_edges = torch.cat([pos_edge_index, neg_edge_index], dim=0)
                labels = torch.cat([
                    torch.ones(num_pos, device=DEVICE),
                    torch.zeros(num_neg, device=DEVICE)
                ])
                preds = predictor(out, all_edges)
                loss = F.binary_cross_entropy(preds, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1

            # Report the mean batch loss without breaking the progress bars.
            mean_loss = total_loss / max(num_batches, 1)
            tqdm.write(f"{model_name}-{dataset_name} epoch {epoch}: mean loss {mean_loss:.4f}")

        # Evaluation on the held-out edges.
        encoder.eval()
        predictor.eval()
        with torch.no_grad():
            test_pos = test_edges.t().to(DEVICE)
            num_test = test_pos.size(0)
            # Sample negatives against the full edge set so that neither a
            # training edge nor a test edge can be labelled as a non-edge.
            test_neg = negative_sampling(
                edge_index=edge_index,
                num_nodes=pyg_data.num_nodes,
                num_neg_samples=num_test,
                method='sparse'
            ).t().to(DEVICE)

            test_all_edges = torch.cat([test_pos, test_neg], dim=0)
            test_labels = torch.cat([
                torch.ones(num_test, device=DEVICE),
                torch.zeros(test_neg.size(0), device=DEVICE)
            ])
            # The loader samples on the CPU, so hand it CPU node ids.
            unique_nodes = torch.unique(test_all_edges).cpu()

            # Message passing uses the training graph only: encoding on the
            # full graph would expose the held-out test edges to the encoder.
            subgraph_loader = NeighborLoader(
                train_data,
                input_nodes=unique_nodes,
                num_neighbors=[10, 5],
                batch_size=8192,
                shuffle=False
            )

            embedding_bank = torch.zeros((pyg_data.num_nodes, EMBEDDING_DIM), device=DEVICE)

            for batch in tqdm(subgraph_loader, desc=f"Encoding {model_name}-{dataset_name}", leave=False):
                batch = batch.to(DEVICE)
                out = encoder(batch.x.float(), batch.edge_index)
                # Only the first batch_size rows are seed nodes with a fully
                # sampled neighbourhood. Storing the remaining rows would let
                # truncated embeddings overwrite good ones from other batches.
                num_seeds = batch.batch_size
                embedding_bank[batch.n_id[:num_seeds]] = out[:num_seeds]

            # reshape(-1) keeps the scores 1-D even for a single test pair.
            preds = predictor(embedding_bank, test_all_edges).reshape(-1).detach().cpu()
            labels = test_labels.detach().cpu()

            y_true = labels.numpy()
            y_score = preds.numpy()
            # NOTE(review): decision threshold is 0.3 rather than 0.5,
            # presumably to offset the extra negatives used in training --
            # confirm.
            y_pred = y_score > 0.3

            # Cast to built-in floats so the records serialise cleanly to JSON.
            auc = float(roc_auc_score(y_true, y_score))
            acc = float(accuracy_score(y_true, y_pred))
            f1 = float(f1_score(y_true, y_pred))

        model_save_path = os.path.join(MODEL_DIR, model_name, dataset_name)
        os.makedirs(model_save_path, exist_ok=True)
        torch.save(encoder.state_dict(), os.path.join(model_save_path, 'encoder.pt'))
        torch.save(predictor.state_dict(), os.path.join(model_save_path, 'predictor.pt'))

        result = {
            'model': model_name,
            'dataset': dataset_name,
            'AUC': auc,
            'Accuracy': acc,
            'F1': f1
        }
        all_results.append(result)

# Persist the collected metric records as indented JSON.
with open('results.json', 'w') as results_file:
    json.dump(all_results, fp=results_file, indent=4)