In [None]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# !pip install einops
# !pip install wandb

In [1]:
import random
from collections import defaultdict
from itertools import product
from typing import Callable, Optional
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv,global_mean_pool, ChebConv,global_add_pool
import torch
import numpy as np
from torch import Tensor

def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Converts indices to a mask representation.

    Args:
        idx (Tensor): The indices.
        size (int, optional). The size of the mask. If set to :obj:`None`, a
            minimal sized output mask is returned.

    Example:

        >>> index = torch.tensor([1, 3, 5])
        >>> index_to_mask(index)
        tensor([False,  True, False,  True, False,  True])

        >>> index_to_mask(index, size=7)
        tensor([False,  True, False,  True, False,  True, False])
    """
    index = index.view(-1)
    size = int(index.max()) + 1 if size is None else size
    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask

from torch_geometric.data import Data, HeteroData, InMemoryDataset
from torch_geometric.utils import coalesce, remove_self_loops, to_undirected

from torch_geometric.transforms import RandomNodeSplit
from torch_geometric.graphgym.config import cfg

# from torch_geometric.graphgym.utils.ben_utils import get_k_hop_adjacencies


class RingTransferDataset(InMemoryDataset):
    r"""A synthetic dataset that returns a Ring Transfer dataset.

    Args:
        num_graphs (int, optional): The number of graphs. (default: :obj:`1`)
        num_nodes (int, optional): The average number of nodes in a graph.
            (default: :obj:`1000`)
        num_classes (int, optional): The number of node features.
            (default: :obj:`64`)
        task (str, optional): Whether to return node-level or graph-level
            labels (:obj:`"node"`, :obj:`"graph"`, :obj:`"auto"`).
            If set to :obj:`"auto"`, will return graph-level labels if
            :obj:`num_graphs > 1`, and node-level labels other-wise.
            (default: :obj:`"auto"`)
        is_undirected (bool, optional): Whether the graphs to generate are
            undirected. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        **kwargs (optional): Additional attributes and their shapes
            *e.g.* :obj:`global_features=5`.
    """
    def __init__(
        self,
        num_graphs,
        num_nodes,
        num_classes,
        # task: str = "auto",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__('.', transform)

        self.num_graphs = num_graphs
        self.num_nodes = num_nodes
        self._num_classes = num_classes
        self.kwargs = kwargs
        # if cfg.gnn.layers_mp == 1: # the default - otherwise use specified no.

        #     cfg.gnn.layers_mp = num_nodes//2
        split = (self.num_graphs * torch.tensor([0.8, 0.1, 0.1])).long()
        data_list, split = self.load_ring_transfer_dataset(self.num_nodes,
                                                           split=split,
                                                           classes=self._num_classes)

        self.data, self.slices = self.collate(data_list)

        # add train/val split masks
        self.data.train_mask = index_to_mask(torch.tensor(split[0]), size=len(self.data.x))
        self.data.val_mask = index_to_mask(torch.tensor(split[1]), size=len(self.data.x))
        self.data.test_mask = index_to_mask(torch.tensor(split[2]), size=len(self.data.x))


    def load_ring_transfer_dataset(self, nodes=10, split=[5000, 500, 500], classes=5):
        train = self.generate_ring_transfer_graph_dataset(nodes, classes=classes, samples=split[0])
        val = self.generate_ring_transfer_graph_dataset(nodes, classes=classes, samples=split[1])
        test = self.generate_ring_transfer_graph_dataset(nodes, classes=classes, samples=split[2])
        dataset = train + val + test
        return dataset, [list(range(int(split[i]))) for i in range(3)]

    def generate_ring_transfer_graph_dataset(self, nodes, classes=5, samples=10000):
        # Generate the dataset
        dataset = []
        samples_per_class = torch.div(samples, classes, rounding_mode="floor")
        for i in range(samples):
            label = torch.div(i, samples_per_class, rounding_mode="floor")
            target_class = np.zeros(classes)
            target_class[label] = 1.0
            graph = self.generate_ring_transfer_graph(nodes, target_class)
            dataset.append(graph)
        return dataset

    def generate_ring_transfer_graph(self, nodes, target_label):
        opposite_node = nodes // 2

        # Initialise the feature matrix with a constant feature vector
        x = np.ones((nodes, len(target_label)))

        x[0, :] = 0.0
        x[opposite_node, :] = target_label
        x = torch.tensor(x, dtype=torch.float32)

        edge_index = []
        for i in range(nodes-1):
            edge_index.append([i, i + 1])
            edge_index.append([i + 1, i])

        # Add the edges that close the ring
        edge_index.append([0, nodes - 1])
        edge_index.append([nodes - 1, 0])
        edge_index = torch.tensor(edge_index, dtype=torch.long).T

        # Create a mask for the target node of the graph
        mask = torch.zeros(nodes, dtype=torch.bool)
        mask[0] = 1

        # Add the label of the graph as a graph label
        y = torch.tensor([np.argmax(target_label)], dtype=torch.long)

        return Data(x=x, edge_index=edge_index, mask=mask, y=y)



In [2]:

import torch
from torch_geometric.datasets import TUDataset, LRGBDataset
import os.path as osp
import torch_geometric.transforms as T
import wandb
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import math
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv,global_mean_pool, ChebConv,global_add_pool
#from torch_sparse import SparseTensor
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
import argparse
import random
import torch_geometric
from torch_geometric.utils import dropout_adj,dense_to_sparse
from torch_geometric.utils import to_dense_adj,dense_to_sparse #dropout_adj,
from torch_geometric.transforms import AddLaplacianEigenvectorPE, AddRandomWalkPE

parser = argparse.ArgumentParser()
parser.add_argument('--comp', type=int, default=6, help='Latent Bottleneck')
parser.add_argument('--hidden', type=int, default=64, help='Latent Dimension')
parser.add_argument('--RW', type=int, default=8, help='Latent Dimension')
parser.add_argument('--seed', type=int, default=12, help='Latent Dimension')
parser.add_argument('--batch_size', type=int, default=32, help='Latent Dimension')
parser.add_argument('--k', type=int, default=16, help='number of vectors')
parser.add_argument('--laplace', type=bool, default=False, help='Use laplacian PE')
parser.add_argument('--laplace_RW', type=bool, default=False, help='Use laplacian PE')
parser.add_argument('--FA', type=bool, default=False, help='Use FA Layer')
parser.add_argument('--learnable', type=bool, default=False, help='Use FA Layer')

parser.add_argument('--use_graph', type=bool, default=True, help='Use graph infos')
parser.add_argument('--patches', type=bool, default=True, help='Use graph infos')
parser.add_argument('--pretrain', type=bool, default=False, help='Use graph infos')

parser.add_argument('--use_weights', type=bool, default=False, help='Use graph infos')
parser.add_argument('--gConv', type=str, default='GAT', help='graph conv between latents')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--lamb', type=float, default=0.8, help='Regularizer')
parser.add_argument('--webdata', type=str, default='xwycycvu', help='which data')
args = parser.parse_args([])

In [3]:
def LapPE(edge_index, pos_enc_dim, num_nodes):
    """
        Graph positional encoding v/ Laplacian eigenvectors
    """

    # Laplacian
    degree = torch_geometric.utils.degree(edge_index[0], num_nodes)
    A = torch_geometric.utils.to_scipy_sparse_matrix(
        edge_index, num_nodes=num_nodes)
    N = sp.diags(np.array(degree.clip(1) ** -0.5, dtype=float))
    L = sp.eye(num_nodes) - N * A * N

    # Eigenvectors with numpy
    EigVal, EigVec = np.linalg.eig(L.toarray())
    idx = EigVal.argsort()  # increasing order
    EigVal, EigVec = EigVal[idx], np.real(EigVec[:, idx])
    PE = torch.from_numpy(EigVec[:, 1:pos_enc_dim+1]).float()
    if PE.size(1) < pos_enc_dim:
        zeros = torch.zeros(num_nodes, pos_enc_dim)
        zeros[:, :PE.size(1)] = PE
        PE = zeros
    return PE


def random_walk(A, n_iter):
    # Geometric diffusion features with Random Walk
    Dinv = A.sum(dim=-1).clamp(min=1).pow(-1).unsqueeze(-1)  # D^-1
    RW = A * Dinv
    M = RW
    M_power = M
    # Iterate
    PE = [torch.diagonal(M)]
    for _ in range(n_iter-1):
        M_power = torch.matmul(M_power, M)
        PE.append(torch.diagonal(M_power))
    PE = torch.stack(PE, dim=-1)
    return PE


def RWSE(edge_index, pos_enc_dim, num_nodes):
    """
        Initializing positional encoding with RWSE
    """
    if edge_index.size(-1) == 0:
        PE = torch.zeros(num_nodes, pos_enc_dim)
    else:
        A = torch_geometric.utils.to_dense_adj(
            edge_index, max_num_nodes=num_nodes)[0]
        PE = random_walk(A, pos_enc_dim)
    return PE

In [4]:

def pe(datasetz):
  outs=[]
  ins=[]
  dataset2=[]
  for p in range(len(datasetz)):
    try:
      tempo=datasetz[p]
      posi=RWSE(datasetz[p].edge_index, args.k, datasetz[p]['x'].shape[0])
      tempo['laplace']=torch.cat((datasetz[p].x,posi),dim=1)
      dataset2.append(tempo)
      ins.append(p)
    except:
      print(p)
      outs.append(p)
  return dataset2


dataset = RingTransferDataset(num_graphs=2000,num_nodes=60,num_classes=5)#.shuffle()
dataset1=dataset[:int(2000*0.8)]
validation_set1=dataset[int(2000*0.8):int(2000*0.9)]
test_set1=dataset[int(2000*0.9):]
# validation_set1 = RingTransferDataset(split="val")#.shuffle()
# test_set1 = RingTransferDataset(split="test")#.shuffle()

num_feats=dataset1.num_node_features
num_classes=dataset1.num_classes

if args.laplace==True:
  dataset1 = pe(dataset1)
  validation_set1 = pe(validation_set1)
  test_set1 = pe(test_set1)



In [5]:


from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim, heads = 3, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context , mask = None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)


        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)




from torch_geometric.nn import global_add_pool

BN = True


class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input

    def reset_parameters(self):
        pass


from torch_scatter import scatter
class MLP(nn.Module):
    def __init__(self, nin, nout, nlayer=2, with_final_activation=True, with_norm=True, bias=True):
        super().__init__()
        n_hid = nin
        self.layers = nn.ModuleList([nn.Linear(nin if i == 0 else n_hid,
                                     n_hid if i < nlayer-1 else nout,
                                     # TODO: revise later
                                               bias=True if (i == nlayer-1 and not with_final_activation and bias)
                                               or (not with_norm) else False)  # set bias=False for BN
                                     for i in range(nlayer)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(n_hid if i < nlayer-1 else nout) if with_norm else Identity()
                                    for i in range(nlayer)])
        self.nlayer = nlayer
        self.with_final_activation = with_final_activation
        self.residual = (nin == nout)  # TODO: test whether need this

    def reset_parameters(self):
        for layer, norm in zip(self.layers, self.norms):
            layer.reset_parameters()
            norm.reset_parameters()

    def forward(self, x):
        previous_x = x
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            x = layer(x)
            if i < self.nlayer-1 or self.with_final_activation:
                x = norm(x)
                x = F.relu(x)

        # if self.residual:
        #     x = x + previous_x
        return x

In [6]:

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data
class HNO(torch.nn.Module):

    def __init__(self,):
        super(HNO, self).__init__()

        if args.laplace_RW==True:
           total_feats=num_feats+args.RW

        else:
           total_feats=num_feats
        self.conv1 = ChebConv(total_feats, int(2*args.hidden),K=8)
        self.Inlin = Linear(total_feats, int(args.hidden))

        self.Inlin2 = Linear(int(args.hidden), int(args.hidden))
        # self.gin1 = GINConv(nn.Sequential(nn.Linear(num_feats, int(args.hidden)),
        #                                     #nn.ReLU(),
        #                                     #nn.Linear(args.hidden, args.hidden),
        #                                     nn.ReLU(),
        #                                     nn.Linear(int(args.hidden), int(args.hidden))))

        #self.conv1 = GCNConv(num_feats, int(2*args.hidden))
        self.conv2 = ChebConv(int(2*args.hidden), int(2*args.hidden),K=8)
        self.conv3 = ChebConv(int(2*args.hidden), args.hidden,K=8)
        self.conv4 = ChebConv(int(args.hidden), args.hidden,K=8)
        # self.conv5 = ChebConv(int(args.hidden), args.hidden,K=8)
        #self.conv4 = SAGEConv(args.hidden, args.hidden)
        self.convGraph =  GCNConv(args.hidden,args.hidden)
        self.convGraph2 =  GCNConv(args.hidden,args.comp)

        self.Fconv = GCNConv(args.hidden*args.comp, num_classes)

        self.conv_sTc= GATConv(args.hidden, args.hidden)#GATConv(args.hidden, args.hidden)
        #self.Inlin = Linear(dataset.num_features, int(args.hidden))

        self.lin = Linear(args.hidden, int(2*args.hidden))
        self.lin2 = Linear(int(2*args.hidden), args.hidden)

        self.mlp= Linear(args.hidden*args.comp, num_classes)
        self.mlp2= Linear(args.hidden, num_classes)

        self.bano1 = torch.nn.BatchNorm1d(num_features= int(2*args.hidden))
        self.bano2 = torch.nn.BatchNorm1d(num_features= int(2*args.hidden))
        self.bano3 = torch.nn.BatchNorm1d(num_features= int(args.hidden))

        self.embLin=Linear(args.hidden, int(args.hidden))

        self.LaToGr= GATConv(int(args.hidden), int(args.hidden))

        #self.cross=CrossAttention(int(args.hidden),context_dim=int(args.hidden)) ## query dim, key_dim

        # if args.gConv=='GIN':
        #   self.conv_cTc = GINConv(nn.Sequential(nn.Linear(int(args.hidden), int(args.hidden)),
        #                                       #nn.ReLU(),
        #                                       #nn.Linear(args.hidden, args.hidden),
        #                                       nn.ReLU(),
        #                                       nn.Linear(int(args.hidden), int(args.hidden))))
        # elif args.gConv=='SAGE':
        #   self.conv_cTc= SAGEConv(args.hidden, args.hidden)

        if args.gConv=='GAT':
          self.conv_cTc= GATConv(args.hidden, args.hidden)

        # self.gin2 = GINConv(nn.Sequential(nn.Linear(args.hidden, args.hidden),
        #                                     nn.ReLU(),
        #                                     nn.Linear(args.hidden, args.hidden),
        #                                     nn.ReLU(),
        #                                     nn.Linear(args.hidden, args.hidden)))
        self.mlpRep = MLP(
            int(args.hidden), num_classes, nlayer=2, with_final_activation=False)


    def connect(self,adTop):
      temp=[]
      for k in range(args.comp):
        T = torch.cat([adTop[:k], adTop[k+1:]])
        temp+=T
      return torch.stack(temp)

    def extend(self,Gnodes,comp,device):
      adTop=torch.arange(0, Gnodes.shape[0]).to(device)
      adBot=torch.arange(Gnodes.shape[0], Gnodes.shape[0]+comp).to(device)

      top=adTop.repeat(comp)
      bot=adBot.repeat(Gnodes.shape[0],1).t().reshape(-1)
      indices = torch.stack((top, bot))
      return indices#top,bot

    def adjacency_matrix(self,matrix):
        # Get the shape of the input matrix
        n= matrix.shape[0]

        # Generate the fully connected adjacency matrix
        row = torch.arange(n).repeat_interleave(n)
        col = torch.arange(n).repeat(n)
        edge_index = torch.stack([row, col], dim=0)
        return edge_index

    def my_patches(self,N,M,device):
      rows=torch.arange(0,N)
      div=N%M ## comp
      ent=N-div
      g2=torch.arange(N,+N+M)
      g2=g2.repeat_interleave(int(ent/M), dim=0)
      if div !=0:
        g2=torch.cat((g2,g2[-1].repeat(div)))
      indices = torch.tensor([rows.tolist(), g2.tolist()], dtype=torch.long)
      return indices


    def patches(self,N, M,device):

        # Calculate the number of connections per group
        connections_per_group = N // M

        # Generate the indices of connected nodes
        row_indices = []
        col_indices = []

        for i in range(M):
            start_index = i * connections_per_group
            end_index = start_index + connections_per_group

            row_indices.extend(range(start_index, end_index))
            col_indices.extend([N + i] * connections_per_group)

        if (N+M)%2 !=0:
                row_indices.append(N-1)
                perm = torch.randperm(len(col_indices))
                idx = perm[0]
                indi = col_indices[idx]

                col_indices.append(indi)


        # Create the sparse tensor
        #random.shuffle(col_indices)            ### When commented, removes random shuffling of columns in the end
        indices = torch.tensor([row_indices, col_indices], dtype=torch.long)

        return indices

    ###  insert the weights at the end of each graph in the batch
    def stack_rows(self,matrix, row_indices, weight):
        n, d = matrix.shape
        p = len(row_indices)
        new_n = n #+ k * p

        # Generate the stacked matrix
        stacked_matrix = torch.zeros(new_n, d, dtype=matrix.dtype, device=matrix.device)
        stacked_matrix[:n] = matrix

        # Stack the new matrix at the specified row indices
        k=weight.shape[0]
        for i, idx in enumerate(row_indices):
            stacked_matrix = torch.cat((stacked_matrix[:idx + i * k], weight, stacked_matrix[idx + i * k -k+ k:]))

        return stacked_matrix


    def transform_indices(self,A, B):

        n = len(A)
        m = max(B)

        shift = m - n+1

        # Step 2: Create mapping dictionary
        mapping = {}
        new_index = 0
        for value in set(B):
            mapping[value] = new_index
            new_index += 1

        # Step 3: Replace elements in B with new indices
        B_transformed = [mapping[value] for value in B]

        # Step 4: Create new tensor A
        A_transformed = list(range(shift, m + 1))

        return A_transformed, B_transformed

    def forward(self, x, edge_index, batch, params,batch_size,device,data,pretrain):

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x=self.bano1(x)
        x = F.dropout(x, training=self.training,p=0.2)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x=self.bano2(x)
        x = F.dropout(x, training=self.training,p=0.2)

        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.bano3(x)
        x = F.dropout(x, training=self.training,p=0.2)

        x = self.conv4(x, edge_index)

        classifier=self.mlpRep(x)
        #classifier=self.mlp(CompNodes)
        # classifier=F.leaky_relu(classifier)
        # classifier= F.dropout(classifier, training=self.training,p=0.3)
        # classifier=self.mlp2(classifier)
        cdd=0

        return classifier,cdd, cdd

In [7]:


import numpy as np
if args.laplace==True:
  eigs=args.k
else:
  eigs=0


import matplotlib.pyplot as plt
lplot=[]
vplot=[]
torch.manual_seed(args.seed)

model = HNO().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

criterion = torch.nn.CrossEntropyLoss()

from torch_geometric.loader import DataLoader
trainloader = DataLoader(dataset1, batch_size=args.batch_size, shuffle=True,drop_last=True)
valoader = DataLoader(validation_set1, batch_size=16, shuffle=False,drop_last=True)
testloader = DataLoader(test_set1, batch_size=16, shuffle=False,drop_last=True)

In [8]:

temp=0
for epoch in range(150):
  model.train()
  correct = 0

  totalLoss=0
  totalAcc=0
  for i, data in enumerate(trainloader):

    data=data.to(device)

    optimizer.zero_grad()

    weights=torch.rand(args.comp,args.hidden,requires_grad=True).to(device)

    if args.laplace==True:
      feats=data.laplace
    else:
      feats=data.x.float()

    classify,atts, fuser=model(feats,data.edge_index,data.batch,weights,data.batch.unique().shape[0],device,data,pretrain=args.pretrain)

    loss = criterion(classify[data.mask], data.y)

    loss.backward()

    optimizer.step()

    pred = classify.argmax(dim=1)  # Use the class with highest probability.
    correct = pred[data.mask] == data.y # Check against ground-truth labels.
    train_acc = int(correct.sum()) / int(data.mask.sum())  # Derive ratio of correct predictions.

    totalAcc+=train_acc

  totalAcc/=(i+1)

  if epoch %10==0:
      optimizer.param_groups[0]["lr"]=optimizer.param_groups[0]["lr"]*0.95

  val_correct=0
  val_precision=0
  totalVaLoss=0
  totalValAcc=0
  for j, valdata in enumerate(valoader):
    model.eval()
    valdata=valdata.to(device)

    if args.laplace==True:
      valfeats=valdata.laplace
    else:
      valfeats=valdata.x.float()

    val_classify,val_atts,val_fuser=model(valfeats,valdata.edge_index,valdata.batch,weights,valdata.batch.unique().shape[0],device,valdata,pretrain=args.pretrain)

    val_loss= criterion(val_classify[valdata.mask], valdata.y)

    totalVaLoss+=val_loss

    val_pred = val_classify.argmax(dim=1)
    #print(val_pred)
    val_correct = val_pred[valdata.mask] == valdata.y
    val_acc = int(val_correct.sum()) / int(valdata.mask.sum())

    totalValAcc+=val_acc

  totalValAcc/=(j+1)

  if val_acc>=temp:
    temp=val_acc
    when=epoch
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_weights':weights,
          }, './'+args.webdata+'_'+str(epoch)+'.pth')

  #lplot.append(totalLoss)
  #vplot.append(totalVaLoss)

  print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f},Train Acc: {totalAcc:.4f}, Val_Loss: {val_loss.item():.4f},Val Acc: {totalValAcc:.4f}')
# plt.plot(torch.stack(lplot).detach().cpu().numpy(),label="training")
# plt.plot(torch.stack(vplot).detach().cpu().numpy(),label="validation")
# plt.legend(loc="upper left")

#from torchmetrics import AUROC
#auroc = AUROC(task="binary")


Epoch: 000, Loss: 1.7126,Train Acc: 0.1956, Val_Loss: 1.8426,Val Acc: 0.2083
Epoch: 001, Loss: 1.6762,Train Acc: 0.1787, Val_Loss: 1.9571,Val Acc: 0.2083
Epoch: 002, Loss: 1.6164,Train Acc: 0.1856, Val_Loss: 1.5511,Val Acc: 0.1667
Epoch: 003, Loss: 1.6402,Train Acc: 0.1881, Val_Loss: 1.6271,Val Acc: 0.2083
Epoch: 004, Loss: 1.6058,Train Acc: 0.1944, Val_Loss: 1.6080,Val Acc: 0.2083
Epoch: 005, Loss: 1.6287,Train Acc: 0.1919, Val_Loss: 1.5677,Val Acc: 0.2083
Epoch: 006, Loss: 1.5556,Train Acc: 0.2137, Val_Loss: 1.7786,Val Acc: 0.2083
Epoch: 007, Loss: 1.5973,Train Acc: 0.2144, Val_Loss: 1.6354,Val Acc: 0.2083
Epoch: 008, Loss: 1.6168,Train Acc: 0.2194, Val_Loss: 1.6891,Val Acc: 0.2083
Epoch: 009, Loss: 1.6616,Train Acc: 0.1963, Val_Loss: 1.7497,Val Acc: 0.2083
Epoch: 010, Loss: 1.5861,Train Acc: 0.1994, Val_Loss: 1.6466,Val Acc: 0.2083
Epoch: 011, Loss: 1.6497,Train Acc: 0.2056, Val_Loss: 1.6641,Val Acc: 0.2083
Epoch: 012, Loss: 1.6337,Train Acc: 0.2013, Val_Loss: 1.5685,Val Acc: 0.2083

In [10]:
classify.shape

torch.Size([32, 5])

In [11]:
data.y.shape

torch.Size([32])

In [12]:
classify[data.mask].shape

IndexError: The shape of the mask [1920] at index 0 does not match the shape of the indexed tensor [32, 5] at index 0

In [None]:
device="cuda"
checkpoint = torch.load('./'+args.webdata+'_'+str(when)+'.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

weights2=checkpoint['best_weights']
weights2=weights2.to(device)

In [None]:
test_correct=0

for j, testdata in enumerate(testloader):
  model.eval()
  testdata=testdata.to(device)

  if args.laplace==True:
    testfeats=testdata.laplace
  else:
    testfeats=testdata.x.float()

  test_classify,val_atts,val_fuser=model(testfeats,testdata.edge_index,testdata.batch,weights2,testdata.batch.unique().shape[0],device,testdata,pretrain=args.pretrain)

  test_pred = test_classify.argmax(dim=1)
  #print(val_pred)
  test_correct = test_pred[testdata.mask] == testdata.y
  test_acc = int(test_correct.sum()) / int(testdata.mask.sum())



In [None]:
test_acc

1.0