In [1]:
import torch
from torch_geometric.nn import SplineConv, voxel_grid, max_pool, max_pool_x
import torch.nn.functional as F
import torch_geometric.transforms as T
from imports.TactileDataset import TactileDataset
class nvs_no_skip(torch.nn.Module):
    def __init__(self):
        super(nvs_no_skip, self).__init__()
        self.conv1 = SplineConv(1, 64, dim=3, kernel_size=3)
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.conv2 = SplineConv(64, 128, dim=3, kernel_size=3)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.conv3 = SplineConv(128, 256, dim=3, kernel_size=3)
        self.bn3 = torch.nn.BatchNorm1d(256)
        self.conv4 = SplineConv(256, 512, dim=3, kernel_size=3)
        self.bn4 = torch.nn.BatchNorm1d(512)
        self.fc1 = torch.nn.Linear(64 * 512, 1024)
        self.fc2 = torch.nn.Linear(1024, 2)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        data.x = self.bn1(data.x)
        cluster = voxel_grid(data.pos, batch=data.batch, size=0.05)
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        data.x = self.bn2(data.x)
        cluster = voxel_grid(data.pos, batch=data.batch, size=0.1)
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))
        
        data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))
        data.x = self.bn3(data.x)
        cluster = voxel_grid(data.pos,batch= data.batch, size=0.15)
        data = max_pool(cluster, data, transform=T.Cartesian(cat=False))

        data.x = F.elu(self.conv4(data.x, data.edge_index, data.edge_attr))
        data.x = self.bn4(data.x)
        cluster = voxel_grid(data.pos, batch=data.batch, size=0.25)
        x,_ = max_pool_x(cluster, data.x, batch=data.batch, size=64)

        x = x.view(-1, self.fc1.weight.size(1))
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        
        return x


In [2]:
import json
import torch
import torch_geometric as pyg
from tqdm.auto import tqdm, trange  
from pathlib import Path
from numpy import pi
from pandas import DataFrame

class TrainModel():

    def __init__(
        self, 
        extraction_case_dir, 
        model,
        n_epochs = 150,
        optimizer = 'adam',
        lr = 0.01,
        loss_func = torch.nn.L1Loss(),
        transform = None,
        features = 'all',
        weight_decay=0,
        patience=10,
        batch = 1,
        augment=False
        ):

        self.extraction_case_dir = Path(extraction_case_dir)
        self.transform = transform

        self.train_data = TactileDataset(self.extraction_case_dir / 'train', transform=transform, features=features, augment=augment)
        self.val_data = TactileDataset(self.extraction_case_dir / 'val', features=features)
        self.test_data = TactileDataset(self.extraction_case_dir / 'test', features=features)

        self.train_loader = pyg.loader.DataLoader(self.train_data, shuffle=True, batch_size=batch)
        self.val_loader = pyg.loader.DataLoader(self.val_data)
        self.test_loader = pyg.loader.DataLoader(self.test_data)

        self.model = model
        self.n_epochs = n_epochs


        if optimizer == 'adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise NotImplementedError('use tm.optimizer = torch.optim.<optimizer>')
        
        #self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', min_lr=1e-5, patience=patience)

        self.loss_func = loss_func

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        

    def train(self):
        self.train_losses = []
        self.val_losses = []
        self.lr = []

        name = str(type(self.model)).split('.')[-1][:-2]
        path = Path('results') / name

        for epoch in trange(self.n_epochs, desc='training', unit='epoch'):
            
            if (epoch == 5):
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = 0.001

            if epoch == 60:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = 0.0001

            if epoch == 100:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = 0.00001
            
            
            #bunny(epoch)
            epoch_loss = 0
            lr = self.optimizer.param_groups[0]['lr']
            self.lr.append(lr)
            val_loss = torch.inf
            with tqdm(self.train_loader, unit="batch") as tepoch:
                for i, data in enumerate(tepoch):
                    tepoch.set_description(f"Epoch {epoch}")
                    with torch.autograd.detect_anomaly():
                        data = data.to(self.device)
                        self.optimizer.zero_grad()
                        end_point = self.model(data)
                        loss = self.loss_func(end_point, data.y)
                        loss.backward()
                        self.optimizer.step()
                        lr = self.optimizer.param_groups[0]['lr']

                        epoch_loss += loss.detach().item()
                    
                        tepoch.set_postfix({
                            'train_loss': epoch_loss / (i + 1), 
                            'train_loss_degrees': epoch_loss / (i + 1) * 180/pi, 
                            'val_loss': self.val_losses[epoch - 1] if epoch > 0 else 'na',
                            'val_loss_degrees': self.val_losses[epoch - 1] * 180/pi if epoch > 0 else 'na',
                            'lr': lr
                            })

                #self.scheduler.step(val_loss)
                epoch_loss /= len(self.train_data)
                val_loss = self.validate()
                tepoch.set_postfix({'train_loss': epoch_loss, 'val_loss': val_loss})
                self.train_losses.append(epoch_loss)
                self.val_losses.append(val_loss)
            if (epoch + 1) % 1 == 0:
                self.log(current_epoch=epoch)
        torch.save(self.model, path / 'model.pt')

    def validate(self):
        loss = 0
        for i, data in enumerate(self.val_loader):      
            data = data.to(self.device)
            end_point = self.model(data)

            loss += self.loss_func(end_point, data.y).detach().item()
        loss /= len(self.val_data)
        return loss
    
    def test(self):
        loss = 0
        for i, data in enumerate(self.test_loader):      
            data = data.to(self.device)
            end_point = self.model(data)

            loss += self.loss_func(end_point, data.y).detach().item()
        loss /= len(self.train_data)
        return loss

    def augment(self, batch):
        pass

    def log(self, current_epoch):
        #find model name
        print('logging')
        name = str(type(self.model)).split('.')[-1][:-2]
        path = Path('results') / name
        if not path.exists():
            path.mkdir(parents=True)

        with open(path / 'training_params.json', 'w') as f:
            params = {
                'model': name,
                'extraction_used': str(self.extraction_case_dir),
                'n_epochs': self.n_epochs,
                'final_val_loss_degrees': self.val_losses[-1] * 180 / pi,
            }
            json.dump(params, f, indent=4)

        train_log = { 
            'epoch': [i for i in range(1, current_epoch+2)],
            'train_loss': self.train_losses,
            'val_loss': self.val_losses,
            'lr': self.lr
        }
        DataFrame(train_log).to_csv(path / 'train_log.csv', index=False)

In [3]:
from torch_geometric.loader import DataLoader
from torch_geometric import seed_everything
from torch_geometric.transforms import Distance, Cartesian
#from imports.TrainModel import TrainModel
seed_everything(0)

model = nvs_no_skip().cuda()
#!rm ../data/contact_extraction10/{train,test,val}/processed/*

tm = TrainModel('/home/hussain/tactile/data/contact_extraction10/', model, n_epochs=150, transform=Cartesian(cat=False), features='pol', augment=False, lr=0.01)

In [4]:
tm.train()

training:   0%|          | 0/150 [00:00<?, ?epoch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  with torch.autograd.detect_anomaly():


logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

logging


  0%|          | 0/386 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
min(tm.val_losses)

In [None]:
tm.test()

In [None]:
import time 
from tqdm.auto import tqdm
t = []
def validate(self):
    loss = 0
    for i, data in enumerate(tqdm(self.val_loader)):
        now = time.time()
        data = data.to('cuda')
        end_point = self.model(data)
        end = time.time() - now
        t.append(end)
        print(end)
        loss += self.loss_func(end_point[0], data.y).detach().item()
    loss /= len(self.val_data)
    return loss

validate(tm)

In [None]:
import numpy as np
np.mean(t)

In [None]:
tm.model.to('cuda')

In [None]:
from imports.ExtractContactCases import ExtractContactCases

ex = ExtractContactCases('/home/hussain/tactile/data/contact_extraction10/', '/home/hussain/tactile/dataset_ENVTACT_new2.bag')
ex.extract()