In [1]:
import torch
import torch_geometric

from torch.nn import Linear
from torch.nn.functional import elu
from torch_geometric.nn.conv import SplineConv
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.transforms import Cartesian


import torch

from torch_geometric.data import Data
from torch_geometric.nn.pool import max_pool_x, voxel_grid
from typing import List, Optional, Tuple, Union


class MaxPoolingX(torch.nn.Module):

    def __init__(self, voxel_size: List[int], size: int):
        super(MaxPoolingX, self).__init__()
        self.voxel_size = voxel_size
        self.size = size

    def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: Optional[torch.Tensor] = None
                ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor], Data]:
        cluster = voxel_grid(pos, batch=batch, size=self.voxel_size)
        x, _ = max_pool_x(cluster, x, batch, size=self.size)
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(voxel_size={self.voxel_size}, size={self.size})"

import torch

from torch_geometric.data import Data
from torch_geometric.nn.pool import max_pool, voxel_grid
from typing import Callable, List, Optional, Tuple, Union


class MaxPooling(torch.nn.Module):

    def __init__(self, size: List[int], transform: Callable[[Data, ], Data] = None):
        super(MaxPooling, self).__init__()
        self.voxel_size = list(size)
        self.transform = transform

    def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: Optional[torch.Tensor] = None,
                edge_index: Optional[torch.Tensor] = None, return_data_obj: bool = False
                ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor], Data]:
        assert edge_index is not None, "edge_index must not be None"

        cluster = voxel_grid(pos[:, :2], batch=batch, size=self.voxel_size)
        data = Data(x=x, pos=pos, edge_index=edge_index, batch=batch)
        data = max_pool(cluster, data=data, transform=self.transform)  # transform for new edge attributes
        if return_data_obj:
            return data
        else:
            return data.x, data.pos, getattr(data, "batch"), data.edge_index, data.edge_attr

    def __repr__(self):
        return f"{self.__class__.__name__}(voxel_size={self.voxel_size})"

class GraphRes_lightheavy(torch.nn.Module):

    def __init__(self):
        super(GraphRes_lightheavy, self).__init__()
        dim = 3

        bias = False
        root_weight = False
        pooling_size=(16/346, 12/260)

        # Set dataset specific hyper-parameters.
        kernel_size = 2
        n = [1, 8, 16, 16, 16, 32, 32, 32, 32]
        pooling_outputs = 32
        #kernel_size = 8
        #n = [1, 16, 32, 32, 32, 128, 128, 128]
        #pooling_outputs = 128

        self.conv1 = SplineConv(n[0], n[1], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm1 = BatchNorm(in_channels=n[1])
        self.conv2 = SplineConv(n[1], n[2], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm2 = BatchNorm(in_channels=n[2])

        self.conv3 = SplineConv(n[2], n[3], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm3 = BatchNorm(in_channels=n[3])
        self.conv4 = SplineConv(n[3], n[4], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm4 = BatchNorm(in_channels=n[4])

        self.conv5 = SplineConv(n[4], n[5], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm5 = BatchNorm(in_channels=n[5])
        self.pool5 = MaxPooling(pooling_size, transform=Cartesian(norm=True, cat=False))

        self.conv6 = SplineConv(n[5], n[6], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm6 = BatchNorm(in_channels=n[6])
        self.conv7 = SplineConv(n[6], n[7], dim=dim, kernel_size=kernel_size, bias=bias, root_weight=root_weight)
        self.norm7 = BatchNorm(in_channels=n[7])

        self.pool7 = MaxPoolingX(0.25, size=16)
        self.fc = Linear(pooling_outputs * 16, out_features=2, bias=bias)

    def forward(self, data: torch_geometric.data.Batch) -> torch.Tensor:
        data.x = elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm1(data.x)
        data.x = elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm2(data.x)

        x_sc = data.x.clone()
        data.x = elu(self.conv3(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm3(data.x)
        data.x = elu(self.conv4(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm4(data.x)
        data.x = data.x + x_sc

        data.x = elu(self.conv5(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm5(data.x)
        data = self.pool5(data.x, pos=data.pos, batch=data.batch, edge_index=data.edge_index, return_data_obj=True)

        x_sc = data.x.clone()
        data.x = elu(self.conv6(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm6(data.x)
        data.x = elu(self.conv7(data.x, data.edge_index, data.edge_attr))
        data.x = self.norm7(data.x)
        data.x = data.x + x_sc

        x = self.pool7(data.x, pos=data.pos[:, :2], batch=data.batch)
        x = x.view(-1, self.fc.in_features)
        return self.fc(x)

model = GraphRes_lightheavy().cuda()

In [2]:
torch.load('/home/hussain/Downloads/model.pt')

GraphRes_lightheavy(
  (conv1): SplineConv(1, 8, dim=3)
  (norm1): BatchNorm(8)
  (conv2): SplineConv(8, 16, dim=3)
  (norm2): BatchNorm(16)
  (conv3): SplineConv(16, 16, dim=3)
  (norm3): BatchNorm(16)
  (conv4): SplineConv(16, 16, dim=3)
  (norm4): BatchNorm(16)
  (conv5): SplineConv(16, 32, dim=3)
  (norm5): BatchNorm(32)
  (pool5): MaxPooling(voxel_size=[0.046242774566473986, 0.046153846153846156])
  (conv6): SplineConv(32, 32, dim=3)
  (norm6): BatchNorm(32)
  (conv7): SplineConv(32, 32, dim=3)
  (norm7): BatchNorm(32)
  (pool7): MaxPoolingX(voxel_size=0.25, size=16)
  (fc): Linear(in_features=512, out_features=2, bias=False)
)

In [5]:
import json
with open('data/extractions/contact_extraction2/extraction_params.json', 'w') as f:
    params = json.dump({}, f)

In [9]:
from src.imports.TrainModel import TrainModel
!rm -rf data/extractions/contact_extraction2/{test,train,val}/processed/*.pt

tm = TrainModel('data/extractions/contact_extraction2/', model.cuda(), batch=4, features='pol')

Processing...


{'kNN': 32, 'node_features': 'pol'}


Done!
Processing...


{'kNN': 32, 'node_features': 'pol'}


Done!
Processing...


{'kNN': 32, 'node_features': 'pol'}


Done!


In [10]:
tm.validate()

(2.6035049446375975, 6.5740485191345215)

In [7]:
for i in tm.val_loader:
    print(i)

DataBatch(x=[2851, 4], edge_index=[2, 44938], y=[1, 2], pos=[2851, 3], edge_attr=[44938, 3], batch=[2851], ptr=[2])
DataBatch(x=[3227, 4], edge_index=[2, 51127], y=[1, 2], pos=[3227, 3], edge_attr=[51127, 3], batch=[3227], ptr=[2])
DataBatch(x=[6523, 4], edge_index=[2, 104087], y=[1, 2], pos=[6523, 3], edge_attr=[104087, 3], batch=[6523], ptr=[2])
DataBatch(x=[9330, 4], edge_index=[2, 149280], y=[1, 2], pos=[9330, 3], edge_attr=[149280, 3], batch=[9330], ptr=[2])
DataBatch(x=[6930, 4], edge_index=[2, 110725], y=[1, 2], pos=[6930, 3], edge_attr=[110725, 3], batch=[6930], ptr=[2])
DataBatch(x=[7467, 4], edge_index=[2, 119373], y=[1, 2], pos=[7467, 3], edge_attr=[119373, 3], batch=[7467], ptr=[2])
DataBatch(x=[3565, 4], edge_index=[2, 56503], y=[1, 2], pos=[3565, 3], edge_attr=[56503, 3], batch=[3565], ptr=[2])
DataBatch(x=[5554, 4], edge_index=[2, 88549], y=[1, 2], pos=[5554, 3], edge_attr=[88549, 3], batch=[5554], ptr=[2])
DataBatch(x=[7403, 4], edge_index=[2, 118427], y=[1, 2], pos=[74