In [None]:
from torch_geometric.datasets import  TUDataset
from torch_geometric.transforms import ToDense
from torch_geometric.loader import DenseDataLoader
from  torch_geometric.nn import DenseSAGEConv
import torch
from math import ceil
from torch_geometric.nn import dense_diff_pool, GraphNorm
import torch.nn.functional as F

In [None]:
max_num_nodes = 50


class Reduce(object):
    """``pre_filter`` that keeps only graphs with at most ``max_nodes`` nodes.

    Larger graphs would not fit the fixed-size tensors produced by
    ``ToDense(max_num_nodes)``, so they are discarded at processing time.

    Args:
        max_nodes: node-count threshold (inclusive). ``None`` (default) reads
            the notebook-level ``max_num_nodes`` each time the filter is called.
    """

    def __init__(self, max_nodes=None):
        self.max_nodes = max_nodes

    def __call__(self, data):
        """Return True if ``data`` should be kept (``data.num_nodes <= limit``)."""
        limit = max_num_nodes if self.max_nodes is None else self.max_nodes
        return data.num_nodes <= limit


In [None]:
# Load PROTEINS.
# - pre_filter=Reduce() runs once, while the raw data is processed, and keeps
#   only graphs with at most max_num_nodes nodes.
# - transform=ToDense(...) runs on every access and pads each kept graph to
#   dense tensors of that size.
# - force_reload re-runs processing so a changed filter takes effect.
dataset = TUDataset(
    name="PROTEINS",
    root="data/TUDataset",
    pre_filter=Reduce(),
    transform=ToDense(max_num_nodes),
    force_reload=True,
)


In [None]:
# dataset = TUDataset(
#     root="data/TUDataset",
#     name="PROTEINS",
#     force_reload=True
# )
# len(dataset)


In [None]:
# max_num_nodes = 0
# for data in dataset:
#     if data.num_nodes > max_num_nodes:
#         max_num_nodes = data.num_nodes
# max_num_nodes

In [None]:
# dataset = TUDataset(
#     root="data/TUDataset",
#     name="PROTEINS",
#     transform=ToDense(max_num_nodes),
#     force_reload=True,
# )


In [None]:
len(dataset)
# Number of graphs kept after filtering. In the earlier commented-out attempt
# the count did not decrease: max_num_nodes was set to the largest graph in the
# dataset, and ToDense only pads graphs, it does not drop them. Capping the
# graph size therefore requires pre_filter=Reduce() to discard larger graphs.


In [None]:
# Seed torch's RNG so the shuffle (and therefore the train/val/test split that
# slices it) and the later model initialisation are reproducible under
# Restart & Run All. Run this cell once: re-running it reshuffles the
# already-shuffled dataset.
torch.manual_seed(42)
dataset = dataset.shuffle()


In [None]:

# Test split: the first 10% of the shuffled graphs.
test_dataset = dataset[0 : int(len(dataset) * 0.1)]
len(test_dataset)


In [None]:
# Validation split: from the end of the test split up to the 20% mark.
val_dataset = dataset[slice(len(test_dataset), int(len(dataset) * 0.2))]
len(val_dataset)

In [None]:
# Training split: every graph after the test and validation slices.
train_dataset = dataset[len(test_dataset) + len(val_dataset) :]
len(train_dataset)

In [None]:
BATCH_SIZE = 32

# Evaluation loaders keep a fixed order.
test_loader = DenseDataLoader(test_dataset, batch_size=BATCH_SIZE)
val_loader = DenseDataLoader(val_dataset, batch_size=BATCH_SIZE)
# The training loader reshuffles every epoch; otherwise the model sees the same
# batches in the same order each epoch. drop_last avoids a trailing batch of a
# single graph, on which the model's BatchNorm1d raises in training mode.
train_loader = DenseDataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)


In [None]:
class GNNMODULE(torch.nn.Module):
    """Stack of ``DenseSAGEConv`` layers, each followed by GraphNorm and ReLU.

    Channel sizes go ``_in -> _hidden -> ... -> _hidden -> _out``. With a single
    layer it is just ``_in -> _out``.

    Args:
        _in: number of input node features.
        _hidden: width of the intermediate layers (unused when ``_num_layer == 1``).
        _num_layer: number of convolution layers; must be >= 1.
        _out: number of output node features.

    Raises:
        ValueError: if ``_num_layer < 1``.
    """

    def __init__(
        self,
        _in,
        _hidden,
        _num_layer,
        _out,
    ):
        super().__init__()
        if _num_layer < 1:
            raise ValueError(f"_num_layer must be >= 1, got {_num_layer}")
        # Feature size at every layer boundary; layer i maps dims[i] -> dims[i+1].
        dims = [_in] + [_hidden] * (_num_layer - 1) + [_out]
        self.gnn = torch.nn.ModuleList(
            DenseSAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.graph_norm = torch.nn.ModuleList(GraphNorm(d) for d in dims[1:])
        self.num_layer = _num_layer

    @staticmethod
    def _dense_graph_norm(norm, x, mask=None):
        """Apply ``norm`` (a GraphNorm) to each graph of a dense batch separately.

        GraphNorm called without a ``batch`` vector averages over dim 0. On a
        dense (batch, nodes, channels) tensor, that mixes statistics across
        *different graphs*. Instead, real nodes are flattened and tagged with
        their graph index, so mean and variance are computed within each graph.
        Positions where ``mask`` is False are left at zero.
        """
        batch_size, num_nodes, _ = x.shape
        if mask is None:
            mask = torch.ones(batch_size, num_nodes, dtype=torch.bool, device=x.device)
        mask = mask.bool()
        # Graph id of every selected node, in the same row-major order as x[mask].
        graph_index = (
            torch.arange(batch_size, device=x.device)
            .view(-1, 1)
            .expand(batch_size, num_nodes)[mask]
        )
        out = torch.zeros_like(x)
        out[mask] = norm(x[mask], graph_index)
        return out

    def forward(self, x, adj, mask=None):
        """Run all layers.

        Args:
            x: dense node features, shape (batch, nodes, _in).
            adj: dense adjacency, shape (batch, nodes, nodes).
            mask: optional boolean (batch, nodes) marking real (non-padding)
                nodes. When given, padded nodes are zeroed by the convolutions
                and excluded from normalisation statistics. When omitted, every
                node slot is treated as real.

        Returns:
            Node features of shape (batch, nodes, _out), after a final ReLU.
        """
        for conv, norm in zip(self.gnn, self.graph_norm):
            x = F.relu(self._dense_graph_norm(norm, conv(x, adj, mask), mask))
        return x

In [None]:
class DiFFPooLMODEL(torch.nn.Module):
    """Graph classifier with two DiffPool coarsening stages.

    Pipeline:
        gnn_in -> DiffPool to ceil(0.25 * max_nodes) clusters
        -> gnn_mid -> DiffPool to a further 25% of those clusters
        -> gnn_out -> mean over cluster nodes -> Linear/BatchNorm/ReLU -> Linear.

    Args:
        _in: number of input node features.
        _hidden: hidden width used by every sub-network.
        _out: number of target classes.
        max_nodes: padded node count of the dense input graphs. ``None``
            (default) uses the notebook-level ``max_num_nodes``.
    """

    def __init__(self, _in, _hidden, _out, max_nodes=None):
        super().__init__()
        if max_nodes is None:
            max_nodes = max_num_nodes

        self.gnn_in = GNNMODULE(_in, _hidden, 1, _hidden)

        # First coarsening stage: embedding GNN plus assignment GNN that emits
        # one score per target cluster.
        self.gnn_emb_1 = GNNMODULE(_hidden, _hidden, 2, _hidden)
        num_nodes = ceil(0.25 * max_nodes)
        self.gnn_pool_1 = GNNMODULE(_hidden, _hidden, 2, num_nodes)

        self.gnn_mid = GNNMODULE(_hidden, _hidden, 1, _hidden)

        # Second coarsening stage, down to 25% of the first stage's clusters.
        self.gnn_emb_2 = GNNMODULE(_hidden, _hidden, 2, _hidden)
        num_nodes = ceil(0.25 * num_nodes)
        self.gnn_pool_2 = GNNMODULE(_hidden, _hidden, 2, num_nodes)

        self.gnn_out = GNNMODULE(_hidden, _hidden, 1, _hidden)

        self.lin_1 = torch.nn.Linear(_hidden, _hidden)
        self.lin_2 = torch.nn.Linear(_hidden, _out)
        # A single BatchNorm1d after lin_1. The attribute name is kept so
        # existing checkpoints still load.
        self.bns = torch.nn.BatchNorm1d(_hidden)

    def forward(self, x, adj, mask):
        """Classify a dense batch of graphs.

        Args:
            x: dense node features.
            adj: dense adjacency.
            mask: boolean mask of real nodes, used by the first pooling step.

        Returns:
            ``(log_probs, aux_loss)``:
            - ``log_probs`` are per-class log-probabilities of shape (batch, _out).
            - ``aux_loss`` is the sum of the two auxiliary losses returned by
              each ``dense_diff_pool`` call.
        """
        x = self.gnn_in(x, adj)

        s = self.gnn_pool_1(x, adj)
        x = self.gnn_emb_1(x, adj)
        x, adj, llp_1, le_1 = dense_diff_pool(x, adj, s, mask)

        x = self.gnn_mid(x, adj)

        # No mask here: the first pooling produced a fixed set of cluster nodes.
        s = self.gnn_pool_2(x, adj)
        x = self.gnn_emb_2(x, adj)
        x, adj, llp_2, le_2 = dense_diff_pool(x, adj, s)

        x = self.gnn_out(x, adj)
        x = x.mean(dim=1)  # readout: average over the remaining cluster nodes

        x = F.relu(self.bns(self.lin_1(x)))
        x = self.lin_2(x)
        # Return log-probabilities, not probabilities. The training code passes
        # this output to F.cross_entropy, which applies log_softmax itself.
        # Because log_softmax is idempotent, cross_entropy on log-probs equals
        # the intended NLL. Feeding softmax outputs instead (a "double softmax")
        # squashes the loss and its gradients. The argmax used for accuracy is
        # unchanged.
        return F.log_softmax(x, dim=-1), llp_1 + le_1 + llp_2 + le_2


In [None]:
# DiffPool classifier: input width from the dataset's node features, hidden
# width 64, one output per class; optimised with Adam at lr = 0.01.
model = DiFFPooLMODEL(
    _in=dataset.num_features,
    _hidden=64,
    _out=dataset.num_classes,
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)


In [None]:
# Define training function
def train(epoch):
    """Run one optimisation epoch of ``model`` over ``train_loader``.

    Args:
        epoch: current epoch number. Accepted for the caller's convenience;
            not used.

    Returns:
        Per-graph average of ``cross_entropy + 0.1 * aux_loss`` over the graphs
        actually processed this epoch (0.0 if the loader yields nothing).
    """
    model.train()
    loss_sum = 0.0
    num_seen = 0

    for data in train_loader:
        optimizer.zero_grad()
        output, aux_loss = model(data.x, data.adj, data.mask)
        loss = F.cross_entropy(output, data.y.view(-1))
        # Add the model's auxiliary (DiffPool) loss with a fixed weight of 0.1.
        total_loss = loss + 0.1 * aux_loss
        total_loss.backward()
        optimizer.step()

        batch_graphs = data.y.size(0)
        loss_sum += batch_graphs * total_loss.item()
        # Count graphs actually seen, so the average stays correct when the
        # loader drops or skips samples.
        num_seen += batch_graphs
    return loss_sum / max(num_seen, 1)


# Evaluation helper: accuracy of the current model on a loader
@torch.no_grad()
def test(loader):
    """Return the fraction of graphs in ``loader.dataset`` classified correctly.

    A graph's prediction is the index of the largest entry in the model's first
    output. Targets ``y`` are flattened to match.
    """
    model.eval()
    num_correct = 0
    for batch in loader:
        scores = model(batch.x, batch.adj, batch.mask)[0]
        predicted = scores.max(dim=1).indices
        num_correct += int((predicted == batch.y.view(-1)).sum())
    return num_correct / len(loader.dataset)


# Training loop.
# Model selection uses validation accuracy only. The test set is evaluated
# only when a new best-validation checkpoint is saved, so the reported
# "Test Acc" is always that of the checkpoint in best_model.pth.
NUM_EPOCHS = 150

best_val_acc = test_acc = 0
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = test(test_loader)
        # Save the best model (by validation accuracy)
        torch.save(model.state_dict(), "best_model.pth")
    print(
        f"Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, "
        f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}"
    )
