In [1]:
from torch import nn
import torch


class E_GCL(nn.Module):
    """
    E(n) Equivariant Convolutional Layer
    re
    """

    def __init__(
        self, 
        input_nf, 
        output_nf, 
        hidden_nf, 
        edges_in_d=0, 
        act_fn=nn.SiLU(), 
        residual=True, 
        attention=False, 
        normalize=False, 
        coords_agg='mean', 
        tanh=False
    ):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        edge_coords_nf = 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 2*hidden_nf),
            act_fn,
            nn.Linear(2*hidden_nf, hidden_nf)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 2*hidden_nf),
            act_fn,
            nn.Linear(2*hidden_nf, output_nf)
        )

        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.residual:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        if self.coords_agg == 'sum':
            agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        elif self.coords_agg == 'mean':
            agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        else:
            raise Exception('Wrong coords_agg parameter' % self.coords_agg)
        coord = coord + agg
        return coord

    def coord2radial(self, edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff**2, 1).unsqueeze(1)

        if self.normalize:
            norm = torch.sqrt(radial).detach() + self.epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

        return h, coord

    
    
def unsorted_segment_sum(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

In [2]:
import torch.nn.functional as F
from torch_geometric.nn import voxel_grid, max_pool, max_pool_x, global_mean_pool

class egnn1(torch.nn.Module):
    def __init__(self):
        super(egnn1, self).__init__()
        
        self.conv1 = E_GCL(2, 32, 64, 2, residual=False)
        self.bn1 = torch.nn.BatchNorm1d(32)
        
        self.conv2 = E_GCL(32, 64, 128, 2, residual=False)
        self.bn2 = torch.nn.BatchNorm1d(64)
        
        self.conv3 = E_GCL(64, 128, 150, 2, residual=False)
        self.bn3 = torch.nn.BatchNorm1d(128)
        
        self.fc1 = torch.nn.Linear(128 + 2, 256)
        self.fc2 = torch.nn.Linear(256, 128)
        self.fc3 = torch.nn.Linear(128, 2)
        
    def forward(self, data):
        data.x, data.pos = self.conv1(data.x, data.edge_index, data.pos, data.edge_attr)
        data.x, data.pos = F.elu(data.x), F.elu(data.pos)
        data.x = self.bn1(data.x)
        
        data.x, data.pos = self.conv2(data.x, data.edge_index, data.pos, data.edge_attr)
        data.x, data.pos = F.elu(data.x), F.elu(data.pos)
        data.x = self.bn2(data.x)
        
        data.x, data.pos = self.conv3(data.x, data.edge_index, data.pos, data.edge_attr)
        data.x, data.pos = F.elu(data.x), F.elu(data.pos)
        data.x = self.bn3(data.x)

        x = global_mean_pool(data.x, data.batch)
        pos = global_mean_pool(data.pos, data.batch)
        x = torch.hstack([pos, x])
        
        x = x.view(-1, self.fc1.weight.size(1))
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        
        x = F.elu(self.fc2(x))
        x = F.dropout(x, training=self.training)
        x = self.fc3(x)
        
        return x

In [3]:
model = egnn1()

In [4]:
model

egnn1(
  (conv1): E_GCL(
    (edge_mlp): Sequential(
      (0): Linear(in_features=7, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=128, bias=True)
      (3): SiLU()
      (4): Linear(in_features=128, out_features=64, bias=True)
    )
    (node_mlp): Sequential(
      (0): Linear(in_features=66, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=128, bias=True)
      (3): SiLU()
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
    (coord_mlp): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=1, bias=False)
    )
  )
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): E_GCL(
    (edge_mlp): Sequential(
      (0): Linear(in_features=67, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=256, bias=True)

In [5]:
from imports.ExtractContactCases import ExtractContactCases

ex=ExtractContactCases('/home/hussain/tactile/data/contact_extraction1', bag_file_name='../dataset_ENVTACT_new2.bag', _keep_raw=True)
#ex.extract()

In [8]:
from imports.TrainModel import TrainModel
from torch_geometric.transforms import Distance, Cartesian, Center, Compose
from torch_geometric import seed_everything
seed_everything(0)
#!rm ../data/contact_extraction1/{train,test,val}/processed/*

tm = TrainModel(
    '/home/hussain/tactile/data/contact_extraction1/', 
    model.to('cuda'), 
    n_epochs=150, 
    transform=Compose([Center(), Cartesian(cat=False)]), 
    features='pol_time', 
    lr = 0.001,
    weight_decay=0.005,
    augment=False, patience=10, batch=4)

In [9]:
tm.train()

training:   0%|          | 0/150 [00:00<?, ?epoch/s]

  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

logging


  0%|          | 0/97 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [10]:
for i in model.parameters()

<generator object Module.parameters at 0x7fb06d69a820>

In [None]:
tm.test()