In [1]:
#libraries
import torch 
from torch_geometric.nn import GCNConv
from torch_geometric.nn import VGAE
from torch.nn import BatchNorm1d, LeakyReLU, Softplus
from torch_geometric.utils import negative_sampling
from torch_geometric.loader.neighbor_loader import NeighborLoader
import os

In [2]:
# Set the folder path and the number of files to select
trainfolder_path = 'E:\\dat\\'

# Get the list of all files in the folder
all_files = os.listdir(trainfolder_path)

# Filter out any non-files (like directories)
train_files = [f for f in all_files if os.path.isfile(os.path.join(trainfolder_path, f))]

In [3]:
# Set the folder path and the number of files to select
testfolder_path = 'E:\\test\\'
num_files_to_select = 50  # Define the number of files you want to select randomly

# Get the list of all files in the folder/
all_files = os.listdir(testfolder_path)

# Filter out any non-files (like directories)
test_files = [f for f in all_files if os.path.isfile(os.path.join(testfolder_path, f))]

In [4]:
train_dat = [torch.load(f'E:\\dat\\{file}') for file in train_files]
test_dat = [torch.load(f'E:\\test\\{file}') for file in test_files]

In [5]:
for data in train_dat:
    neigh_samp = NeighborLoader(data, num_neighbors=[-1,-1,-1], subgraph_type='bidirectional')
    

In [None]:
class ContactMapEncoder(torch.nn.Module):
    def __init__(self, in_channel, out_channels):
        super(ContactMapEncoder, self).__init__()
        
        # Define GCNConv layers for mean and log standard deviation
        self.conv_mu = GCNConv(in_channel, out_channels)
        self.conv_logstd = GCNConv(in_channel, out_channels)

    def forward(self, x, edge_index):
        # Compute latent mean (mu)
        z_mu = self.conv_mu(x, edge_index)

        # Compute latent log standard deviation (logstd) with Softplus
        z_logstd = self.conv_logstd(x, edge_index)

        return z_mu, z_logstd

In [22]:
# parameters
num_features = 1024
out_channels = 200

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
# model
model = VGAE(ContactMapEncoder(num_features, out_channels))

# move to GPU 
model = model.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print(model)

VGAE(
  (encoder): ContactMapEncoder(
    (conv_mu): GCNConv(1024, 200)
    (conv_logstd): GCNConv(1024, 200)
  )
  (decoder): InnerProductDecoder()
)


In [11]:
# Manually print out the model summary
def print_model_summary(model):
    print("Model Summary:")
    print("\nModel Parameters:")
    total_params = 0
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel()} parameters")
        total_params += param.numel()
    print(f"\nTotal Parameters: {total_params}")

print_model_summary(model)

Model Summary:

Model Parameters:
encoder.conv1.bias: 512 parameters
encoder.conv1.lin.weight: 524288 parameters
encoder.conv_mu.bias: 256 parameters
encoder.conv_mu.lin.weight: 131072 parameters
encoder.conv_logstd.bias: 256 parameters
encoder.conv_logstd.lin.weight: 131072 parameters
encoder.bn1.weight: 512 parameters
encoder.bn1.bias: 512 parameters

Total Parameters: 788480


In [14]:
def train(data, beta=1):  #take graph data obj as input
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)
    loss = model.recon_loss(z, data.edge_index)
    loss = loss + beta * (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)

def test():
    model.eval()
    with torch.no_grad():
        aucs = []
        for data in test_dat:
            data.to(device)
            neg = negative_sampling(data.edge_index, num_nodes= data.num_nodes, force_undirected=True)
            z = model.encode(data.x, data.edge_index)
            auc, ap = model.test(z, data.edge_index, neg)
            aucs.append(auc)
        print(torch.mean(torch.tensor(aucs)))

In [15]:
for epoch in neigh_samp:
    mini_loss = train(data)
    print(f'train loss {mini_loss}') 
    print('')
    test()
    print('')
    inp = input('next epoch ?')
    if inp == 'n':
        break
    else:
        torch.save(model.state_dict(), 'newVGAE_GCN.pt')

train loss 14.852560997009277

tensor(0.8348, dtype=torch.float64)



In [None]:
torch.save(model.state_dict(), '/home/hpc_users/2019s17273@stu.cmb.ac.lk/sachintha/structure/VGAE_GCN.pt')