In [1]:
import os
import numpy as np 
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

In [2]:
# Per-axis coordinate bounds (columns x, y, z); the files carry no header row.
min_matrix = pd.read_csv(
    './data/min_matrix.csv',
    header=None,
    names=['x', 'y', 'z'],
)
max_matrix = pd.read_csv(
    './data/max_matrix.csv',
    header=None,
    names=['x', 'y', 'z'],
)
# Table with four columns t1..t4 (the regression targets used further down).
cable_matrix = pd.read_csv(
    './data/cable_matrix.csv',
    header=None,
    names=['t1', 't2', 't3', 't4'],
)

In [3]:
class ShapeToTension(Dataset):
    """Dataset pairing a voxelised shape with a row of four tension values.

    Item ``idx`` is a dict with two entries:

    * ``'shape_tensor'``: float32 tensor of shape
      ``(1, grid_size, grid_size, grid_size)`` that is 1 at every voxel hit by
      a point of ``shape_files[idx]`` and 0 elsewhere.
    * ``'tensions'``: float32 tensor holding row ``idx`` of ``cable_matrix``.

    Args:
        shape_files: indexable sequence of shape CSV file names; entry ``idx``
            is paired with row ``idx`` of ``cable_matrix``.
        root_dir: directory that contains the shape files.
        min_matrix: DataFrame with columns ``'x'``, ``'y'``, ``'z'``; the
            column-wise minimum is used as the lower normalisation bound.
        max_matrix: DataFrame with columns ``'x'``, ``'y'``, ``'z'``; the
            column-wise maximum is used as the upper normalisation bound.
        cable_matrix: DataFrame whose rows are the target values.
        grid_size: edge length of the cubic voxel grid (default 300, the value
            that was previously hard-coded).

    Note:
        ``len(dataset)`` is the number of rows of ``cable_matrix``;
        ``shape_files`` must provide at least that many entries.
    """

    def __init__(self, shape_files, root_dir, min_matrix, max_matrix, cable_matrix,
                 grid_size=300):
        super().__init__()
        self.shape_files = shape_files
        self.root_dir = root_dir
        self.grid_size = int(grid_size)

        # Global lower bound of each coordinate.
        self.min_matrix = min_matrix
        self.x_min = self.min_matrix['x'].min()
        self.y_min = self.min_matrix['y'].min()
        self.z_min = self.min_matrix['z'].min()

        # Global upper bound of each coordinate.
        self.max_matrix = max_matrix
        self.x_max = self.max_matrix['x'].max()
        self.y_max = self.max_matrix['y'].max()
        self.z_max = self.max_matrix['z'].max()

        self.cable_matrix = cable_matrix
        # Kept for backward compatibility; the dataset itself always returns
        # CPU tensors and leaves device placement to the caller.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        """Return the number of rows in ``cable_matrix``."""
        return len(self.cable_matrix)

    def __getitem__(self, idx):
        """Return the ``{'shape_tensor', 'tensions'}`` sample at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        shape_tensor = self.shape_matrix_to_tensor(self.shape_files[idx])
        tensions = torch.tensor(self.cable_matrix.iloc[idx].to_numpy(),
                                dtype=torch.float32)
        return {'shape_tensor': shape_tensor, 'tensions': tensions}

    def shape_matrix_to_tensor(self, shape_file):
        """Read one shape CSV and rasterise its points into an occupancy grid.

        The file is read without a header and transposed, so it must contain
        three rows (x, y, z) with one column per point. Every coordinate is
        min-max normalised with the dataset bounds, scaled by
        ``grid_size * 0.99``, rounded to the nearest integer and used as a
        voxel index.

        Args:
            shape_file: file name relative to ``root_dir``.

        Returns:
            float32 tensor of shape ``(1, grid_size, grid_size, grid_size)``
            indexed as ``[0, y, z, x]``.

        Raises:
            ValueError: if the file does not have exactly three rows or holds
                non-finite coordinates.
        """
        matrix_path = os.path.join(self.root_dir, shape_file)
        shape_matrix = pd.read_csv(matrix_path, header=None).T
        shape_matrix.columns = ['x', 'y', 'z']
        points = shape_matrix.to_numpy(dtype=np.float64)

        lower = np.array([self.x_min, self.y_min, self.z_min], dtype=np.float64)
        upper = np.array([self.x_max, self.y_max, self.z_max], dtype=np.float64)
        span = upper - lower
        # A degenerate axis (min == max) would divide by zero; map it to index 0.
        span[span == 0] = 1.0

        scaled = (points - lower) / span * self.grid_size * 0.99
        if not np.isfinite(scaled).all():
            raise ValueError(
                'non-finite coordinate in shape file {!r}'.format(shape_file))

        voxels = np.rint(scaled).astype(np.int64)
        # Clamp so points outside the supplied bounds can neither wrap around
        # (negative index) nor fall off the grid.
        np.clip(voxels, 0, self.grid_size - 1, out=voxels)

        shape_tensor = torch.zeros(
            (1, self.grid_size, self.grid_size, self.grid_size),
            dtype=torch.float32)
        index = torch.from_numpy(voxels)
        width, depth, height = index[:, 0], index[:, 1], index[:, 2]
        # One vectorised assignment instead of a Python loop over the points.
        shape_tensor[0, depth, height, width] = 1.0

        return shape_tensor

In [4]:
# Directory with one file per shape; the listing is sorted lexicographically so
# that its order is deterministic.
# NOTE(review): the dataset pairs shape_files[i] with row i of cable_matrix, so
# this order presumably has to match the row order of cable_matrix.csv - confirm
# the file names sort as intended (e.g. zero-padded indices, no stray files).
shape_files_dir = './data/shape/'
shape_files = sorted(os.listdir(shape_files_dir))

In [5]:
# Build the full dataset from the shape file list and the loaded tables.
ShapeToTension_ds = ShapeToTension(
    shape_files=shape_files,
    root_dir=shape_files_dir,
    min_matrix=min_matrix,
    max_matrix=max_matrix,
    cable_matrix=cable_matrix,
)

In [6]:
# Single-sample, shuffled loader over the full dataset (used for quick checks).
ShapeToTension_dl = DataLoader(
    dataset=ShapeToTension_ds,
    batch_size=1,
    shuffle=True,
)

In [7]:
# Draw one batch and show the shape of its voxel tensor.
data_itr = iter(ShapeToTension_dl)
sample = next(data_itr)
print(sample['shape_tensor'].size())

torch.Size([1, 1, 300, 300, 300])


In [8]:
class Conv3D(nn.Module):
    """3-D convolutional regressor mapping a one-channel volume to four outputs.

    Three identical stages (Conv3d with kernel 4 and stride 2, BatchNorm3d,
    ReLU, MaxPool3d(2)) widen the channels 1 -> 4 -> 8 -> 16. The result is
    flattened and passed through a 1024 -> 512 -> 128 -> 4 fully connected
    head, so the input volume must reduce to exactly 1024 features.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint layout.
        self.conv1_set = self._conv_stage(1, 4)
        self.conv2_set = self._conv_stage(4, 8)
        self.conv3_set = self._conv_stage(8, 16)
        self.fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(512, 128), nn.ReLU())
        self.fc3 = nn.Linear(128, 4)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d stage."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=4, stride=2),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.MaxPool3d(2),
        )

    def forward(self, x):
        """Run the three conv stages, flatten per sample, then apply the head."""
        for stage in (self.conv1_set, self.conv2_set, self.conv3_set):
            x = stage(x)
        features = torch.flatten(x, start_dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            features = layer(features)
        return features

In [9]:
# Smoke test: push one batch through a freshly initialised network and look at
# the output shape.
data_itr = iter(ShapeToTension_dl)
# Advance the iterator with the built-in next(); the Python-2 style `.next()`
# method is not available on current DataLoader iterators.
sample = next(data_itr)
model = Conv3D()
out = model(sample['shape_tensor'])
out.shape

torch.Size([1, 4])

In [10]:
# train the model
def train_model(train_dl, model, epochs=20, lr=0.01, log_every=10):
    """Train ``model`` with Adam on an MSE loss.

    Args:
        train_dl: iterable of batches; each batch is a mapping with a
            ``'shape_tensor'`` input and a ``'tensions'`` target tensor.
        model: the ``nn.Module`` to optimise. Batches are moved to the device
            of its first parameter, so place the model before calling.
        epochs: number of passes over ``train_dl`` (default 20).
        lr: Adam learning rate (default 0.01).
        log_every: print the mean loss every ``log_every`` epochs, starting
            with epoch 0; a falsy value disables printing.

    Returns:
        list of float: mean batch loss of every epoch (NaN for an epoch that
        yielded no batches).
    """
    model.train()
    # Use the model's own device instead of a notebook-global `device`, so the
    # function does not depend on a variable defined in another cell.
    device = next(model.parameters()).device
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    for epoch in range(epochs):
        batch_losses = []
        # Iterating the loader directly lets tqdm read its length for the bar.
        for sample_batch in tqdm(train_dl):
            inputs = sample_batch['shape_tensor'].to(device)
            targets = sample_batch['tensions'].to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        mean_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
        epoch_losses.append(mean_loss)
        if log_every and epoch % log_every == 0:
            print('Epoch: {}, mean loss: {}'.format(epoch, mean_loss))
    return epoch_losses

In [11]:
# Size of the full dataset (ShapeToTension.__len__ returns the row count of cable_matrix).
len(ShapeToTension_ds)

250000

In [12]:
# Take a 100k-sample working subset of the full dataset and split it 80/20 into
# train and test. A seeded generator keeps the membership of both splits
# identical across kernel restarts, so a reloaded checkpoint is never evaluated
# on samples it was trained on.
split_generator = torch.Generator().manual_seed(0)
subset_size = 100000
ShapeToTension_ds_sub_A, ShapeToTension_ds_sub_B = random_split(
    ShapeToTension_ds,
    lengths=[subset_size, len(ShapeToTension_ds) - subset_size],
    generator=split_generator,
)
ShapeToTension_ds_train, ShapeToTension_ds_test = random_split(
    ShapeToTension_ds_sub_A,
    lengths=[80000, 20000],
    generator=split_generator,
)

ShapeToTension_dl_train = DataLoader(ShapeToTension_ds_train, batch_size=16, shuffle=True)
# The error statistics do not depend on evaluation order, so the test loader
# does not shuffle.
ShapeToTension_dl_test = DataLoader(ShapeToTension_ds_test, batch_size=1, shuffle=False)

In [13]:
# Choose the device first and move the model to it, so the cell also runs on a
# machine without CUDA instead of failing in an unconditional .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Conv3D().to(device)
# train the model
train_model(ShapeToTension_dl_train, model)

5000it [2:04:14,  1.49s/it]
0it [00:00, ?it/s]

Epoch: 0, mean loss: 3.8326760459303855


5000it [2:00:55,  1.45s/it]
5000it [2:01:12,  1.45s/it]
5000it [2:00:20,  1.44s/it]
5000it [2:00:15,  1.44s/it]
5000it [2:00:23,  1.44s/it]
5000it [2:00:32,  1.45s/it]
5000it [2:01:02,  1.45s/it]
5000it [2:00:58,  1.45s/it]
5000it [1:59:30,  1.43s/it]
5000it [2:00:19,  1.44s/it]
0it [00:00, ?it/s]

Epoch: 10, mean loss: 1.819358343565464


5000it [2:02:48,  1.47s/it]
5000it [2:03:45,  1.49s/it]
5000it [1:56:55,  1.40s/it]
5000it [1:56:11,  1.39s/it]
5000it [1:56:19,  1.40s/it]
5000it [1:56:27,  1.40s/it]
5000it [1:56:04,  1.39s/it]
5000it [1:56:30,  1.40s/it]
5000it [1:56:26,  1.40s/it]


In [14]:
# Checkpoint file name
PATH = "model300_wBN_DrBerkeCode_80KS_0-1.pt"

# Serialise the whole module object (not only its state_dict) to PATH.
torch.save(obj=model, f=PATH)

# Testing #

In [15]:
# test the model
def test_model(test_dl, model):
    model.eval()
    abs_error = torch.zeros([len(test_dl),4]).to('cuda')
    for i, sample_batch in tqdm(enumerate(test_dl)):
        # compute the model output
        yhat = model(sample_batch['shape_tensor'].to('cuda'))
        # retrieve numpy array
        yhat = yhat.detach()
        targets = sample_batch['tensions'].to('cuda')
        abs_error[i,:] = torch.absolute(targets - yhat)
    return abs_error

In [17]:
# Evaluate the trained model on the held-out loader and keep the absolute
# errors returned by test_model (earlier notes label this run "0-1" / "IK").
abs_error_accumulated = test_model(test_dl=ShapeToTension_dl_test, model=model)

20000it [32:26, 10.27it/s]


In [18]:
# Mean absolute error of each output column, averaged over the rows.
abs_error_accumulated.mean(dim=0)

tensor([1.0254, 1.0657, 1.0362, 1.0562], device='cuda:0')

In [19]:
# Standard deviation of the absolute error of each output column.
abs_error_accumulated.std(dim=0)

tensor([0.7901, 0.8275, 0.7977, 0.8048], device='cuda:0')