In [1]:
import torch 
from torch_geometric.nn import GCNConv
from torch_geometric.nn import VGAE, global_mean_pool
import os

In [2]:
class ContactMapEncoder(torch.nn.Module):
    """Graph encoder with two parallel single-layer GCN branches.

    One branch produces the latent mean and the other the latent log standard
    deviation for every node. Both branches read the same input features and
    no activation function is applied to either output.

    Args:
        in_channel: Number of input features per node.
        out_channels: Size of the latent vector produced for each node.
    """

    def __init__(self, in_channel, out_channels):
        super().__init__()

        # The attribute names below end up as keys in the state dict, so they
        # have to stay in sync with any saved checkpoint.
        self.conv_mu = GCNConv(in_channel, out_channels)
        self.conv_logstd = GCNConv(in_channel, out_channels)

    def forward(self, x, edge_index):
        """Encode node features into ``(mu, logstd)``.

        Args:
            x: Node feature matrix passed to both GCN branches.
            edge_index: Graph connectivity passed to both GCN branches.

        Returns:
            Tuple ``(mu, logstd)`` holding the raw outputs of the mean branch
            and of the log-standard-deviation branch.
        """
        mu = self.conv_mu(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd

In [3]:
# Hyperparameters of the encoder. They have to describe the same architecture
# as the checkpoint loaded below, otherwise load_state_dict() will complain.
num_features = 1024
out_channels = 200

# Inference is pinned to the CPU on purpose: the later cells load the graphs
# and build the pooling index without moving anything to a GPU.
device = torch.device('cpu')

# Build the VGAE around the contact-map encoder and place it on `device`.
model = VGAE(ContactMapEncoder(num_features, out_channels))
model = model.to(device)

# Trained weights. map_location remaps the stored tensors onto `device`, so a
# checkpoint that was saved from a CUDA session still loads on a CPU-only
# machine. weights_only=True restricts unpickling to plain tensor data.
checkpoint_path = 'D:\\year 4\\semester 1\\BT\\BT 4033\\structure and seq\\newVGAE_GCN.pt'
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

<All keys matched successfully>

In [4]:
def genPoolEmbedd(data):
    """Encode one graph with the global ``model`` and mean-pool its nodes.

    Args:
        data: Graph object exposing ``x`` (node features) and ``edge_index``
            (connectivity). It is treated as a single graph: every node is
            assigned to pooling group 0.

    Returns:
        The mean of the node embeddings, flattened to a 1-D tensor.
    """
    # Inference mode; NOTE(review): VGAE.encode is expected to return the
    # latent mean without sampling noise in eval mode -- confirm against the
    # installed torch_geometric version.
    model.eval()
    with torch.no_grad():
        graph_embedd = model.encode(data.x, data.edge_index)
        # All-zero batch vector = "all nodes belong to graph 0". It is created
        # on the same device as the embeddings so pooling does not fail with a
        # device mismatch when the model runs somewhere other than the CPU.
        batch = torch.zeros(
            graph_embedd.shape[0], dtype=torch.long, device=graph_embedd.device
        )
        pooled = global_mean_pool(graph_embedd, batch).flatten()
    return pooled

In [5]:
# Folder containing the input files that the embedding loop reads.
# The trailing separator matters: file names are appended to it directly.
folder_path = 'D:\\year 4\\semester 1\\BT\\BT 4033\\structure and seq\\aug\\'

# Names of the regular files inside the folder; anything that is not a file
# (for example a sub-directory) is left out.
all_files = [
    entry
    for entry in os.listdir(folder_path)
    if os.path.isfile(os.path.join(folder_path, entry))
]

In [None]:
# Output folder for the pooled embeddings. Each result is written under the
# same file name as its input, appended directly to this string, so the
# trailing separator is required.
saving_path = 'D:\\year 4\\semester 1\\BT\\BT 4033\\structure and seq\\pool\\'

In [7]:
# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(saving_path, exist_ok=True)

# Encode every input graph and store its pooled embedding under the same name.
# The loop variable is deliberately not called `map`, which would shadow the
# builtin for the rest of the kernel session.
for file_name in all_files:
    # NOTE(security): weights_only=False runs the full pickle machinery, which
    # can execute arbitrary code -- only load files you created yourself.
    # It is stated explicitly because the inputs are whole graph objects
    # (presumably torch_geometric Data -- confirm) and newer PyTorch releases
    # default to weights_only=True, which rejects such objects.
    # map_location keeps the loaded tensors on the same device as the model.
    data = torch.load(
        os.path.join(folder_path, file_name),
        map_location=device,
        weights_only=False,
    )
    pooled = genPoolEmbedd(data)
    torch.save(pooled, os.path.join(saving_path, file_name))