In [1]:
import torch
import numpy as np

from pathlib import Path
import os
import json 

from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.nn.pool import radius_graph, knn_graph
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.transforms import Cartesian
from tqdm.auto import tqdm
import h5py


torch.set_float32_matmul_precision('medium')

def make_graph(ev_arr, gt, beta=0.5e4):
    ts_sample = ev_arr[:, 3] - ev_arr[0, 3]
    ts_sample = torch.tensor(ts_sample*beta).float().reshape(-1, 1)

    coords = torch.tensor(ev_arr[:, :2]).float()
    pos = torch.hstack((ts_sample, coords))

    edge_index = knn_graph(pos, k=32)

    pol = torch.tensor(ev_arr[:, 3]).float().reshape(-1, 1)
    feature = pol
    #feature = torch.hstack((pos, pol))

    graph = Data(x=feature, edge_index=edge_index, pos=pos, y = torch.tensor(gt)[None, :])
    graph = Cartesian()(graph)

    return graph

class RecurrentTactileDataset:

    def __init__(
            self,
            graphs_per_pred = 5,
            root = '../data/extractions/recurrent/train/'
            ):
        self.root = Path(root).resolve()
        self.graphs_per_pred = graphs_per_pred

        if len(os.listdir(self.root / 'raw')) > len(os.listdir(self.root / 'processed')):
            print('processing')
            self.process()

    def process(self):
        raw_files = os.listdir(self.root / 'raw')           
        for idx, file in enumerate(tqdm(raw_files, desc='processing')):
            h5_file = h5py.File(self.root / 'raw' / file)
            events = h5_file['events'][:]
            case = h5_file['case'][:]
            
            knots = np.linspace(events[0, 2], events[-1, 2], num=self.graphs_per_pred+1)
            knot_idx = np.searchsorted(events[:, 2], knots)
            
            for i in range(1, self.graphs_per_pred+1):
                outdir = self.root / 'processed' / f'{idx:04}_{i}.pt'
                torch.save(make_graph(events[knot_idx[i-1]:knot_idx[i]], gt=case), outdir)

    def __getitem__(self, idx):
        out_graphs = []
        for i in range(1, self.graphs_per_pred+1):
            out_graphs.append(torch.load(self.root / 'processed' / f'{idx:04}_{i}.pt' ))
        return out_graphs
    
    def __len__(self):
        return len(os.listdir(self.root / 'raw'))

In [2]:
train_dset = RecurrentTactileDataset(root='../data/extractions/recurrent/train/')
val_dset = RecurrentTactileDataset(root='../data/extractions/recurrent/val/')

In [3]:
def collate_fn(list_of_list_of_graphs):
    batch_size = len(list_of_list_of_graphs)
    graphs_per_pred = len(list_of_list_of_graphs[0])
    make_batch = lambda list_of_graphs: Batch.from_data_list(list_of_graphs)
    return [
        make_batch([list_of_list_of_graphs[i][j] for i in range(batch_size)]) 
        for j in range(graphs_per_pred)
        ]


In [4]:
from models.modules import MaxPooling, MaxPoolingX
from torch_geometric.nn import SplineConv
import torch.nn as nn
from typing import Callable, List, Optional, Tuple, Union
from torch_geometric.nn.pool import max_pool, voxel_grid, max_pool_x, global_max_pool

class MaxPooling2(torch.nn.Module):

    def __init__(self, stride: int, transform: Callable[[Data, ], Data] = None):
        super(MaxPooling2, self).__init__()
        self.voxel_size = [stride+1, stride+1]
        self.scale = stride
        self.transform = transform

    def forward(self, data: Data, return_data_obj: bool = False
                ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.LongTensor, torch.Tensor, torch.Tensor], Data]:

        cluster = voxel_grid(data.pos[:, 1:3], batch=data.batch, size=self.voxel_size)  # cluster on spatial dimension
        # data = Data(x=x, pos=pos, edge_index=edge_index, batch=batch)
        data = max_pool(cluster, data=data, transform=self.transform)  # transform for new edge attributes
        data.pos[:,1:3] = data.pos[:,1:3] // self.scale
        if return_data_obj:
            return data
        else:
            return data.x, data.pos, getattr(data, "batch"), data.edge_index, data.edge_attr

    def __repr__(self):
        return f"{self.__class__.__name__}(voxel_size={self.voxel_size})"

In [5]:
class GraphEncoder(nn.Module):

    def __init__(
            self, 
            n_feature=4
            ) -> None:
        super().__init__()

        pseudo = Cartesian(norm=True, cat=False)

        self.conv1 = SplineConv(n_feature, 8, dim=3, kernel_size=2)
        self.norm1 = BatchNorm(8)

        self.conv2 = SplineConv(8, 16, dim=3, kernel_size=2)
        self.norm2 = BatchNorm(16)
        self.pool2 = MaxPooling2(2, transform=pseudo)

        self.conv3 = SplineConv(16, 16, dim=3, kernel_size=2)
        self.norm3 = BatchNorm(16)
        self.pool3 = MaxPooling2(2, transform=pseudo)
        
        self.conv4 = SplineConv(16, 16, dim=3, kernel_size=2)
        self.norm4 = BatchNorm(16)
        self.pool4 = MaxPooling2(2, transform=pseudo)

        self.conv5 = SplineConv(16, 32, dim=3, kernel_size=2)
        self.norm5 = BatchNorm(32)

        self.conv6 = SplineConv(32, 64, dim=3, kernel_size=2)
        self.norm6 = BatchNorm(64)


    def forward(self, in_data):
        
        out = []
        is_list = isinstance(in_data, tuple) or isinstance(in_data, list)
        if is_list:
            for data in in_data:
                data = data.to('cuda')
                
                batch_size = data.num_graphs
                
                data.x = nn.functional.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm1(data.x)

                data.x = nn.functional.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm2(data.x)
                data = self.pool2(data, return_data_obj=True)

                data.x = nn.functional.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm3(data.x)
                data = self.pool3(data, return_data_obj=True)


                data.x = nn.functional.elu(self.conv4(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm4(data.x)
                data = self.pool4(data, return_data_obj=True)

                data.x = nn.functional.elu(self.conv5(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm5(data.x)

                data.x = nn.functional.elu(self.conv6(data.x, data.edge_index, data.edge_attr))
                data.x = self.norm6(data.x)
                x = global_max_pool(data.x, data.batch)

                out.append(x)
        else:
            data = in_data
            data.x = nn.functional.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm1(data.x)

            data.x = nn.functional.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm2(data.x)
            data = self.pool2(data, return_data_obj=True)

            data.x = nn.functional.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm3(data.x)
            data = self.pool3(data, return_data_obj=True)

            data.x = nn.functional.elu(self.conv4(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm4(data.x)
            data = self.pool4(data, return_data_obj=True)

            data.x = nn.functional.elu(self.conv5(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm5(data.x)

            data.x = nn.functional.elu(self.conv6(data.x, data.edge_index, data.edge_attr))
            data.x = self.norm6(data.x)
            x = global_max_pool(data.x, data.batch)

            out = x            
        # print(out)
        return out

In [6]:
batch_size = 4

train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=1, collate_fn=collate_fn)

In [7]:
sample = next(iter(train_loader))
sample

[DataBatch(x=[679, 1], edge_index=[2, 21612], y=[4, 2], pos=[679, 3], edge_attr=[21612, 3], batch=[679], ptr=[5]),
 DataBatch(x=[4876, 1], edge_index=[2, 156032], y=[4, 2], pos=[4876, 3], edge_attr=[156032, 3], batch=[4876], ptr=[5]),
 DataBatch(x=[7595, 1], edge_index=[2, 243040], y=[4, 2], pos=[7595, 3], edge_attr=[243040, 3], batch=[7595], ptr=[5]),
 DataBatch(x=[13801, 1], edge_index=[2, 441632], y=[4, 2], pos=[13801, 3], edge_attr=[441632, 3], batch=[13801], ptr=[5]),
 DataBatch(x=[15009, 1], edge_index=[2, 480288], y=[4, 2], pos=[15009, 3], edge_attr=[480288, 3], batch=[15009], ptr=[5])]

In [8]:
import pytorch_lightning as pl

class GNNLSTM(pl.LightningModule):
    
    def __init__(self):
        super().__init__()
        self.gnn = GraphEncoder(n_feature=1)
        self.lstm = torch.nn.LSTM(64, 64)
        self.pred = torch.nn.Linear(64, 2)
        
    def training_step(self, batch, batch_idx):
        y = batch[0].y
        out = self.gnn(batch)
        out = torch.stack(out)
        out, (hn, cn) = self.lstm(out)
        out = self.pred(cn)
        loss = nn.functional.l1_loss(out, y)
        self.log_dict({'train_loss': loss, 'train_loss_degrees': torch.rad2deg(loss)}, batch_size=batch_size, prog_bar=True)
        return loss
        
    def validation_step(self, batch, batch_idx):
        y = batch[0].y
        out = self.gnn(batch)
        out = torch.stack(out)
        out, (hn, cn) = self.lstm(out)
        out = self.pred(cn)
        loss = nn.functional.l1_loss(out, y)
        self.log_dict({'val_loss': loss, 'val_loss_degrees': torch.rad2deg(loss)}, batch_size=batch_size, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def forward(self, in_data):
        out = self.gnn(in_data)
        out = torch.stack(out)
        out, (hn, cn) = self.lstm(out)
        out = self.pred(cn[0])
        return out

In [9]:
model = GNNLSTM().cuda()
model

GNNLSTM(
  (gnn): GraphEncoder(
    (conv1): SplineConv(1, 8, dim=3)
    (norm1): BatchNorm(8)
    (conv2): SplineConv(8, 16, dim=3)
    (norm2): BatchNorm(16)
    (pool2): MaxPooling2(voxel_size=[3, 3])
    (conv3): SplineConv(16, 16, dim=3)
    (norm3): BatchNorm(16)
    (pool3): MaxPooling2(voxel_size=[3, 3])
    (conv4): SplineConv(16, 16, dim=3)
    (norm4): BatchNorm(16)
    (pool4): MaxPooling2(voxel_size=[3, 3])
    (conv5): SplineConv(16, 32, dim=3)
    (norm5): BatchNorm(32)
    (conv6): SplineConv(32, 64, dim=3)
    (norm6): BatchNorm(64)
  )
  (lstm): LSTM(64, 64)
  (pred): Linear(in_features=64, out_features=2, bias=True)
)

In [10]:
trainer = pl.Trainer(accelerator="gpu", max_epochs=300, logger=pl.loggers.CSVLogger('../results/pl'))
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type         | Params
--------------------------------------
0 | gnn  | GraphEncoder | 29.3 K
1 | lstm | LSTM         | 33.3 K
2 | pred | Linear       | 130   
--------------------------------------
62.7 K    Trainable params
0         Non-trainable params
62.7 K    Total params
0.251     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  data.pos[:,1:3] = data.pos[:,1:3] // self.scale
  loss = nn.functional.l1_loss(out, y)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  loss = nn.functional.l1_loss(out, y)
  loss = nn.functional.l1_loss(out, y)


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [11]:
model

GNNLSTM(
  (gnn): GraphEncoder(
    (conv1): SplineConv(1, 8, dim=3)
    (norm1): BatchNorm(8)
    (conv2): SplineConv(8, 16, dim=3)
    (norm2): BatchNorm(16)
    (pool2): MaxPooling2(voxel_size=[3, 3])
    (conv3): SplineConv(16, 16, dim=3)
    (norm3): BatchNorm(16)
    (pool3): MaxPooling2(voxel_size=[3, 3])
    (conv4): SplineConv(16, 16, dim=3)
    (norm4): BatchNorm(16)
    (pool4): MaxPooling2(voxel_size=[3, 3])
    (conv5): SplineConv(16, 32, dim=3)
    (norm5): BatchNorm(32)
    (conv6): SplineConv(32, 64, dim=3)
    (norm6): BatchNorm(64)
  )
  (lstm): LSTM(64, 64)
  (pred): Linear(in_features=64, out_features=2, bias=True)
)