In [1]:
import torch
import os
import pandas as pd
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
import re
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision.transforms import v2
from torchvision.io import read_image, ImageReadMode
from torchvision.models import vgg16, VGG16_Weights
#from plyfile import PlyData
from pytorch_tcn import TCN, TemporalConv1d, TemporalConvTranspose1d
#from torch_points3d.core import data_transform as trf
import torch_geometric.transforms as T
from torch_geometric.data import Data
import torch.nn.functional as F
#from torch_pointnet import PointNetfeat as PNet
#from pointnet import get_model as PNet
#from pointnet2 import get_model as PNet2
from pointnet2_ssg import get_model as PNet2

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(device)
print(torch.cuda.is_available())

cuda
True


In [3]:
def read_ply(path):
    pcd = o3d.io.read_point_cloud(path)
    #offset_tensor = torch.tensor([0, 0.325, 0]) # offset between the Point Cloud and RGB
    downsampled_pc = pcd.farthest_point_down_sample(2048)
    
    pos = np.asarray(downsampled_pc.points)
    pos = torch.tensor(pos, dtype=torch.float32)
    #pos += offset_tensor
    
    # Return point cloud data as PyG Data object (no faces for point clouds)
    data = Data(pos=pos)

    return data

def extract_image_number(filename):
    match = re.search(r'frame_(\d+)\.png', filename)
    return int(match.group(1)) if match else -1  # Extract number, fallback -1 for safety

def extract_pc_number(filename):
    match = re.search(r'frame_(\d+)\.ply', filename)
    return int(match.group(1)) if match else -1  # Extract number, fallback -1 for safety

In [4]:
def visualize_rgb(rgb_image):
    # Convert to numpy array and transpose for display (from CxHxW to HxWxC)
    rgb_image_np = rgb_image.permute(1, 2, 0).numpy()
    
    # Plot the RGB image using matplotlib
    plt.imshow(rgb_image_np)
    plt.title("RGB Image")
    plt.axis('off')  # Hide the axes
    plt.show()

# Step 2: Visualize the Point Cloud
def visualize_point_cloud(point_cloud):
    # Convert the point cloud tensor to numpy array
    point_cloud_np = point_cloud.numpy()

    # Create an Open3D point cloud object
    pcd = o3d.geometry.PointCloud()
    
    # Convert the point cloud array to Open3D format
    pcd.points = o3d.utility.Vector3dVector(point_cloud_np)

    # Create a visualizer and add the point cloud
    o3d.visualization.draw_geometries([pcd], window_name="Translated Point Cloud")

In [9]:
class RPC_Dataset(Dataset):
    def __init__(self, image_dir, pointcloud_dir, forces, image_transform=None, pointcloud_transform=None):
        self.image_dir = image_dir
        self.pointcloud_dir = pointcloud_dir
        self.labels = pd.read_csv(forces)

        rgb_transform = v2.Compose([
            v2.Resize((224, 224)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

        self.image_transform = rgb_transform
        self.image_augment = image_transform
        self.pointcloud_transform = pointcloud_transform

        vgg = vgg16(weights=VGG16_Weights.DEFAULT)
        vgg.eval()
        vgg.to(device)
        self.rgb_feature_extractor = nn.Sequential(
            *list(vgg.features.children()),   # Convolutional layers
            nn.AdaptiveAvgPool2d((7, 7)),     # Ensure 7x7 output size
            nn.Flatten(),                     # Flatten to (batch_size, 25088)
            *list(vgg.classifier.children())[:-1]  # Use up to the second to last FC layer (4096 output)
        )
        
        pointnet = PNet2(40, normal_channel=False)
        #checkpoint = torch.load("pnet2_weights.pth")
        checkpoint = torch.load("pnet2_weights_ssg.pth")
        pointnet.load_state_dict(checkpoint['model_state_dict'], strict=False)
        pointnet.eval()
        pointnet.to(device)
        self.pc_feature_extractor = pointnet
        '''
        self.pc_feature_extractor = nn.Sequential(
            *list(pointnet.children())[:-1],
            nn.Linear(40, 512)
        )
        '''
        self.seq_length = 15

        # Get a list of image and point cloud file names (assuming they match)
        self.image_filenames = sorted(os.listdir(image_dir), key=extract_image_number)
        self.pointcloud_filenames = sorted(os.listdir(pointcloud_dir), key=extract_pc_number)

    def __len__(self):
        #print(len(self.image_filenames))
        return int(len(self.image_filenames) / self.seq_length)

    def __getitem__(self, idx):
        # Load the image
        if idx >= len(self.labels):
            raise IndexError(f"Index {idx} is out of bounds for labels of size {len(self.labels)}")
        
        cat_tensors = []
        #print(self.pc_feature_extractor)
        for i in range(self.seq_length):
            img_path = os.path.join(self.image_dir, self.image_filenames[idx + i])
            image = read_image(img_path, ImageReadMode.RGB)
            # Apply base transform
            image = self.image_transform(image)
            pointcloud_path = os.path.join(self.pointcloud_dir, self.pointcloud_filenames[idx + i])
            pointcloud = read_ply(pointcloud_path)
            # Apply augmentation
            if self.pointcloud_transform or self.image_augment:
                if self.image_augment:
                    image = self.image_augment(image)
                if self.pointcloud_transform:
                    pointcloud = self.pointcloud_transform(pointcloud)
            pointcloud_tensor = pointcloud.pos.clone().detach()
            image, pointcloud_tensor = image.to(device), pointcloud_tensor.to(device)
            pointcloud_tensor = pointcloud_tensor.transpose(0, 1)
            image = image.unsqueeze(0)
            pointcloud_tensor = pointcloud_tensor.unsqueeze(0)
            #print(pointcloud_tensor.size())
            rgb_features = self.rgb_feature_extractor(image)
            pc_features = self.pc_feature_extractor(pointcloud_tensor)[0]
            #print(i)
            rgb_features = rgb_features.squeeze(0)
            pc_features = pc_features.squeeze(0)
            cat = torch.cat((rgb_features, pc_features), dim=0)
            # Stack the frames into separate sequences
            cat_tensors.append(cat)
            
        #images = torch.stack(images, dim=0)  # Shape: (seq_length, C, H, W)
        #pointclouds = torch.stack(pointclouds, dim=0)  # Shape: (seq_length, N, 3)
        cat_tensors = torch.stack(cat_tensors, dim=0)
        label = self.labels.iloc[idx + self.seq_length - 2, 1]
        contact = self.labels.iloc[idx + self.seq_length - 2, 2]
        
        return cat_tensors, torch.tensor(label, dtype=torch.float32), torch.tensor(contact, dtype=torch.bool)

rgb_transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

# Define transform for the images (if needed)
rgb_augment = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.AutoAugment()
])

pc_augment = T.Compose([
    T.RandomJitter(0.01),
    T.RandomRotate(30),
    T.RandomScale((0.8, 1.2))
])

# Initialize the dataset
train_image_dir = 'data/train_images'          # Path to folder containing images
train_pointcloud_dir = 'data/train_pcs'  # Path to folder containing .ply point clouds
#train_forces = 'data/train_labels.csv'
train_forces = 'data/updated_train.csv'

test_image_dir = 'data/test_images'          # Path to folder containing images
test_pointcloud_dir = 'data/test_pcs'  # Path to folder containing .ply point clouds
#test_forces = 'data/test_labels.csv'
test_forces = 'data/updated_test.csv'

train_dataset = RPC_Dataset(train_image_dir, train_pointcloud_dir, train_forces)
test_dataset = RPC_Dataset(test_image_dir, test_pointcloud_dir, test_forces)
#augmented_test_dataset = RPC_Dataset(test_image_dir, test_pointcloud_dir, test_forces, image_transform=rgb_augment, pointcloud_transform=pc_augment)
#augmented_test_dataset = ConcatDataset([test_dataset, augmented_test_dataset])

'''
data = dataset.__getitem__(5)
image, pointcloud = data[0], data[1]
visualize_rgb(image)
visualize_point_cloud(pointcloud)
'''
# Use DataLoader to batch and do not shuffle the data to keep it sequential
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

'''
for batch in train_dataloader:
    cats, labels = batch  # Unpack all three: images, pointclouds, and labels
    print("Cats batch size:", cats.size())
    #print("Pointcloud batch size:", pointclouds.size())
    print("Labels:", labels)
    break
'''

  checkpoint = torch.load("pnet2_weights_ssg.pth")
  checkpoint = torch.load("pnet2_weights_ssg.pth")


'\nfor batch in train_dataloader:\n    cats, labels = batch  # Unpack all three: images, pointclouds, and labels\n    print("Cats batch size:", cats.size())\n    #print("Pointcloud batch size:", pointclouds.size())\n    print("Labels:", labels)\n    break\n'

In [6]:
class RPC_TCN(nn.Module):
    def __init__(self, input_size=4608, output_size=1, num_channels=[64, 128, 256], kernel_size=3, dropout=0.2):
        super(RPC_TCN, self).__init__()
        # Spatial Block
        
        #self.full = nn.Linear(1024, 512)

        # Temporal Block
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 1280)
        self.fc3 = nn.Linear(1280, 1280)
        self.fc4 = nn.Linear(1280, output_size)
        self.relu = nn.ReLU()
        self.tconv1 = TemporalConv1d(256, 64, kernel_size)
        self.tconv2 = TemporalConv1d(64, 128, kernel_size)
        self.tconv3 = TemporalConv1d(128, 256, kernel_size)
        #self.tconvt1 = TemporalConvTranspose1d(256, 64, kernel_size, stride=2)
        #self.tconvt2 = TemporalConvTranspose1d(64, 128, kernel_size, stride=2)
        #self.tconvt3 = TemporalConvTranspose1d(128, 256, kernel_size, stride=2)
        self.bnorm1 = nn.BatchNorm1d(64)
        self.bnorm2 = nn.BatchNorm1d(128)
        self.bnorm3 = nn.BatchNorm1d(256)
        
        self.tcn = TCN(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.fc = nn.Linear(num_channels[-1], output_size)

    def forward(self, cat_input):
        
        # TB forward prop
        #cat = self.fc1(cat)
        #print(cat.size())
        #cat = F.relu(self.bnorm1(self.tconv1(cat)))
        #cat = F.relu(self.bnorm2(self.tconv2(cat)))
        #cat = F.relu(self.bnorm3(self.tconv3(cat)))
        #cat = F.relu(self.bnorm1(self.tconvt1(cat)))
        #cat = F.relu(self.bnorm2(self.tconvt2(cat)))
        #cat = F.relu(self.bnorm3(self.tconvt3(cat)))
        #force = F.relu(self.fc4(self.fc3(self.fc2(cat))))
        cat_input = cat_input.transpose(1, 2)
        #print(cat.size())
        #print(cat)
        force = self.tcn(cat_input)  # Apply TCN layer
        #print(force.size())
        force = self.fc(force[:, :, -1])  # Take the last time step and pass through fully connected layer
        #print(force.size())
        return force

model = RPC_TCN().to(device)
print(model)

RPC_TCN(
  (fc1): Linear(in_features=4608, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1280, bias=True)
  (fc3): Linear(in_features=1280, out_features=1280, bias=True)
  (fc4): Linear(in_features=1280, out_features=1, bias=True)
  (relu): ReLU()
  (tconv1): TemporalConv1d(
    256, 64, kernel_size=(3,), stride=(1,)
    (padder): ConstantPad1d(padding=(2, 0), value=0.0)
  )
  (tconv2): TemporalConv1d(
    64, 128, kernel_size=(3,), stride=(1,)
    (padder): ConstantPad1d(padding=(2, 0), value=0.0)
  )
  (tconv3): TemporalConv1d(
    128, 256, kernel_size=(3,), stride=(1,)
    (padder): ConstantPad1d(padding=(2, 0), value=0.0)
  )
  (bnorm1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (tcn): TCN(
    (network): ModuleList(
      (0): Te

In [17]:
# Mean Squared Error
loss_fn = nn.MSELoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

def contact_loss(pred, f, contact, func='MSE', weight=10):
    loss_fn = nn.MSELoss() if func == 'MSE' else nn.L1Loss()
    loss = loss_fn(pred, f)

    weighted_loss = torch.where(contact == 0, weight*loss, loss)

    return torch.mean(weighted_loss)

def train(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (cat, f, contact) in enumerate(dataloader):
        cat, f = cat.float(), f.float()
        cat, f, contact = cat.to(device), f.to(device), contact.to(device)
        #f = f * 0.001
        #print(f)

        # Compute prediction error
        pred = model(cat)
        pred = pred.squeeze(0)
        #print(pred)
        #loss = loss_fn(pred, f)
        loss = contact_loss(pred, f, contact)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 1 == 0:
            loss, current = loss.item(), (batch + 1) * len(cat)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    total_absolute_error, test_loss = 0, 0

    with torch.no_grad():
        for cat, f, contact in dataloader:
            cat, f = cat.float(), f.float()
            cat, f, contact = cat.to(device), f.to(device), contact.to(device)
            
            # Forward pass: Compute predictions
            pred = model(cat)
            pred = pred.squeeze(0)  # Ensure correct shape if necessary
            
            # Accumulate Mean Squared Error loss (MSE) as used in training
            test_loss += contact_loss(pred, f, contact).item()
            
            # Calculate Mean Absolute Error (MAE) for this batch
            #total_absolute_error += torch.sum(torch.abs(pred - f)).item()
            total_absolute_error += torch.sum(contact_loss(pred, f, contact, func='MAE')).item()

    # Calculate average losses
    avg_mse_loss = test_loss / num_batches
    mae = total_absolute_error / size

    print(f"Test Results: \n MAE: {mae:.5f}, Avg MSE Loss: {avg_mse_loss:.5f} \n")

In [18]:
epochs = 20
for t in range(epochs):
    torch.cuda.empty_cache()
    print(f"Epoch {t+1}\n-------------------------------")
    #train(train_dataloader, model, loss_fn, optimizer)
    #test(test_dataloader, model, loss_fn)
    train(train_dataloader, model, optimizer)
    test(test_dataloader, model)
    if (t+1) % 2 == 0:
        scheduler.step()
print("Done!")

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")


Epoch 1
-------------------------------
loss: 0.307230  [    1/   33]
loss: 0.017039  [    2/   33]
loss: 0.005657  [    3/   33]
loss: 0.233500  [    4/   33]
loss: 0.494319  [    5/   33]
loss: 0.041492  [    6/   33]
loss: 0.009757  [    7/   33]
loss: 0.024964  [    8/   33]
loss: 0.062175  [    9/   33]
loss: 0.125688  [   10/   33]
loss: 0.022470  [   11/   33]
loss: 0.022873  [   12/   33]
loss: 0.000522  [   13/   33]
loss: 0.097092  [   14/   33]
loss: 0.001549  [   15/   33]
loss: 0.000032  [   16/   33]
loss: 0.037148  [   17/   33]
loss: 0.091034  [   18/   33]
loss: 0.034727  [   19/   33]
loss: 0.166097  [   20/   33]
loss: 0.019003  [   21/   33]
loss: 0.012995  [   22/   33]
loss: 0.415239  [   23/   33]
loss: 0.067076  [   24/   33]
loss: 0.007780  [   25/   33]
loss: 0.001712  [   26/   33]
loss: 0.062856  [   27/   33]
loss: 0.249283  [   28/   33]
loss: 0.082959  [   29/   33]
loss: 0.006768  [   30/   33]
loss: 0.023765  [   31/   33]
loss: 0.048874  [   32/   33]


KeyboardInterrupt: 