In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import gc

In [None]:
# Directory holding one saved tensor per sample (e.g. "001.pt").
TENSOR_DIR = "/content/drive/MyDrive/Colab_Notebooks/CV_handwriting/lamono_tesnors"


def load_tensor_dir(directory: str) -> dict:
    """Load every ``*.pt`` file found (recursively) under ``directory``.

    Args:
        directory: folder to walk with ``os.walk``. A missing folder yields ``{}``.

    Returns:
        dict mapping the file name without its ``.pt`` extension (``"001.pt"`` -> ``"001"``)
        to the object returned by ``torch.load``. Files with any other extension are skipped,
        so stray files (notes, checkpoints folders' non-.pt files, ...) are not passed to torch.load.
    """
    tensors = {}
    for root, _dirs, files in os.walk(directory):
        for filename in sorted(files):
            stem, ext = os.path.splitext(filename)
            if ext != ".pt":
                continue
            # Join with the folder os.walk is currently visiting (not the top-level one),
            # so files inside sub-folders resolve to their real path.
            tensors[stem] = torch.load(os.path.join(root, filename))
    return tensors


X_dict = load_tensor_dir(TENSOR_DIR)

In [None]:
# Summarise the loaded samples (name -> tensor shape) instead of dumping every tensor.
{name: tuple(tensor.shape) for name, tensor in X_dict.items()}

In [None]:
# Pre-computed adjacency matrix that the GCN layer below is built from.
A_HAT_PATH = "/content/drive/MyDrive/Colab_Notebooks/CV_handwriting/A_hat.pt"
A_hat = torch.load(A_HAT_PATH)

In [None]:
class GCNLayer(nn.Module):
    """
        GCN layer computing H = ReLU(D^-1/2 * A * D^-1/2 * X * W)

        Args:
            A (torch.Tensor): 2D (N, N) adjacency matrix, sparse COO (a dense
                matrix is converted). Row sums are used as node degrees, so A is
                presumably symmetric (undirected graph) - confirm for your data.
            input_dim (int): Number of features per node. ``forward`` reshapes its
                input to (-1, input_dim); the default 1 treats every element of X
                (e.g. every pixel) as one node with a scalar feature.
            output_dim (int): Number of output features per node (ReLU
                activations, not a softmax distribution).
    """

    def __init__(self, A: torch.Tensor, input_dim: int = 1, output_dim: int = 1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        if not A.is_sparse:
            A = A.to_sparse()
        A = A.coalesce()
        indices = A.indices()
        # Use the default float dtype so the sparse product matches W (and X),
        # even when the adjacency matrix was saved with integer values.
        values = A.values().to(torch.get_default_dtype())

        # (D^-1/2 * A * D^-1/2): each element a_ij is multiplied by
        # 1/(d_i * d_j)^(1/2), where d_i is the (weighted) degree of node i,
        # i.e. the i-th row sum of A. Degrees are computed from A rather than
        # assumed: when every node has exactly 8 neighbours this reduces to the
        # constant 1/8, but it stays correct for any other degree (border
        # pixels, self-loops included in A_hat, ...).
        degree = torch.zeros(A.size(0), dtype=values.dtype, device=values.device)
        degree.index_add_(0, indices[0], values)
        inv_sqrt_degree = degree.pow(-0.5)
        inv_sqrt_degree[degree == 0] = 0.0  # isolated nodes: avoid inf / NaN
        norm_values = values * inv_sqrt_degree[indices[0]] * inv_sqrt_degree[indices[1]]

        A_fin = torch.sparse_coo_tensor(indices, norm_values, A.size()).coalesce()
        # A buffer (not a plain attribute) so .to(device) / .cuda() move it along
        # with W; non-persistent so it is not written into the state_dict.
        self.register_buffer("A_fin", A_fin, persistent=False)

        # Glorot/Xavier initialisation (as in Kipf & Welling). torch.rand would
        # give only positive weights, making the ReLU a no-op for non-negative X.
        self.W = nn.Parameter(torch.empty(input_dim, output_dim))
        nn.init.xavier_uniform_(self.W)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
            Args:
                X (torch.Tensor): node features with N * input_dim elements,
                    e.g. an (H, W) image when input_dim == 1 and N == H * W.

            Returns:
                torch.Tensor: (N, output_dim) ReLU activations.
        """
        # One row per node: an (H, W) image becomes (H*W, 1) for input_dim == 1.
        node_features = X.reshape(-1, self.input_dim)

        # (D^-1/2 * A_hat * D^-1/2) * X   (sparse @ dense -> dense)
        support_1 = torch.sparse.mm(self.A_fin, node_features)

        # (D^-1/2 * A_hat * D^-1/2) * X * W   (both operands dense: plain matmul)
        support_2 = support_1 @ self.W

        # ReLU(D^-1/2 * A_hat * D^-1/2 * X * W)
        return F.relu(support_2)

In [None]:
# Create the GCN Layer from the loaded adjacency matrix
gcn_layer = GCNLayer(A_hat)

# Example input feature matrix: densify sample '001' and cast to torch.float (float32),
# the same dtype as the layer's weights.
X = X_dict['001'].to_dense().to(torch.float)

# Inference only - no need to record an autograd graph.
with torch.no_grad():
    output = gcn_layer(X)

output

In [None]:
# Number of activations that are still non-zero after the layer's ReLU.
torch.count_nonzero(output)