Connected to abusfusion (Python 3.12.5)


# 6DoF Pose Estimation with Ultrasound Frames and IMU Data Using Mamba SSM


## 1. Import Libraries


In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from mambassm import MambaSSM  # Import Mamba SSM

ModuleNotFoundError: No module named 'mambassm'

## 2. Load and Preprocess Data


In [None]:
def load_data(file_path):
    """Read the HDF5 store at ``file_path`` with pandas and return the result."""
    loaded = pd.read_hdf(file_path)
    return loaded


# Location of the HDF5 recording; replace the placeholder with the real path.
file_path = "/path/to/your/data.h5"
df = load_data(file_path)

# Ensure the scalar columns are numeric. The "frame" column is skipped on
# purpose: it holds one image array per row (the dataset stacks it with
# np.stack), and pd.to_numeric cannot parse array-valued cells, so coercing it
# would wipe out the frames and the dropna() below would discard every row.
scalar_columns = [column for column in df.columns if column != "frame"]
df[scalar_columns] = df[scalar_columns].apply(pd.to_numeric, errors="coerce")

# Drop rows where a scalar failed to parse (now NaN) or a value is missing.
df = df.dropna()

: 

## 3. Create Dataset for 6DoF Pose Estimation


In [None]:
class PoseEstimationDataset(Dataset):
    """Dataset yielding ``(frame, imu_data, target)`` for each DataFrame row.

    Args:
        df: DataFrame with a "frame" column plus the columns named in
            ``IMU_COLUMNS`` and ``TARGET_COLUMNS``.
        transform: Optional callable applied to the frame tensor after the
            optional downsampling. ``None`` (the default) leaves the frame
            unchanged.
        downsample_factor: When greater than 1, both spatial sides of the
            frame are integer-divided by this factor via ``F.resize``.
    """

    # Order of the six IMU inputs in the returned ``imu_data`` vector.
    IMU_COLUMNS = (
        "imu_acc_x",
        "imu_acc_y",
        "imu_acc_z",
        "imu_orientation_x",
        "imu_orientation_y",
        "imu_orientation_z",
    )

    # Order of the seven pose values in ``target``: position, then quaternion.
    TARGET_COLUMNS = (
        "ot_pos_x",
        "ot_pos_y",
        "ot_pos_z",
        "ot_qw",
        "ot_qx",
        "ot_qy",
        "ot_qz",
    )

    def __init__(self, df, transform=None, downsample_factor=3):
        self.df = df
        self.transform = transform
        self.downsample_factor = downsample_factor

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the sample for positional row ``idx``.

        Returns:
            A tuple ``(frame, imu_data, target)`` of float tensors: the frame
            with its last axis moved to the front, the 6 IMU values, and the
            7 pose values.
        """
        row = self.df.iloc[idx]

        # The stored frame is treated as height x width x channels: the last
        # axis is moved to the front to get the channel-first layout that the
        # resize below indexes as (channels, height, width).
        frame = torch.tensor(np.stack(row["frame"])).float().permute(2, 0, 1)
        if self.downsample_factor > 1:
            height, width = frame.shape[1], frame.shape[2]
            frame = F.resize(
                frame,
                [height // self.downsample_factor, width // self.downsample_factor],
            )

        # Previously the transform was stored but silently ignored.
        if self.transform is not None:
            frame = self.transform(frame)

        imu_data = torch.tensor([row[name] for name in self.IMU_COLUMNS]).float()
        target = torch.tensor([row[name] for name in self.TARGET_COLUMNS]).float()

        return frame, imu_data, target

: 

In [None]:
# Wrap the cleaned DataFrame so that each row becomes one training sample.
pose_dataset = PoseEstimationDataset(df)

# Hold out 20% of the row indices for validation; the fixed random_state
# keeps the split identical between runs.
train_indices, val_indices = train_test_split(
    range(len(pose_dataset)),
    test_size=0.2,
    random_state=42,
)

# Each sampler draws, in random order, only from its own index subset.
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Both loaders share the same dataset object; the samplers decide which rows
# each one sees.
train_loader = DataLoader(dataset=pose_dataset, batch_size=8, sampler=train_sampler)
val_loader = DataLoader(dataset=pose_dataset, batch_size=1, sampler=val_sampler)

: 

## 4. Define Model for 6DoF Pose Estimation


In [None]:
class PoseEstimationModel(nn.Module):
    """Fuse a small CNN over the ultrasound frame with an SSM over the IMU vector.

    The flattened CNN features and the SSM output are concatenated and mapped
    to 7 values: 3 for translation and 4 for rotation (quaternion).

    Args:
        input_channels: Number of channels of the input frame.
        downsample_factor: Factor by which the frames were shrunk before
            reaching the model; must match the dataset's setting. Values
            below 1 are treated as 1 (no downsampling), mirroring the dataset.
        frame_height: Height of the raw frame before downsampling. Defaults to
            1000, the value that used to be hard-coded.
        frame_width: Width of the raw frame before downsampling. Defaults to
            657, the value that used to be hard-coded.

    NOTE(review): the notebook's ``mambassm`` import fails with
    ModuleNotFoundError in the recorded output, so the ``MambaSSM`` constructor
    arguments and its ``(output, state)`` return value are taken on trust from
    this code -- confirm against the real package.
    """

    def __init__(
        self,
        input_channels=3,
        downsample_factor=3,
        frame_height=1000,
        frame_width=657,
    ):
        super().__init__()

        # CNN for ultrasound frames
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Mamba SSM for IMU data (6 inputs, see the dataset's IMU vector)
        ssm_hidden_dim = 64
        self.ssm = MambaSSM(
            input_dim=6, hidden_dim=ssm_hidden_dim, num_layers=2, bidirectional=False
        )

        # The 3x3 convolutions use padding=1 and keep the spatial size; each of
        # the two 2x2 poolings halves it. Together with the upstream
        # downsampling, every side shrinks by 4 * downsample_factor.
        shrink = 4 * max(downsample_factor, 1)
        conv_output_height = frame_height // shrink
        conv_output_width = frame_width // shrink

        # Fully connected layers for 6DoF pose
        self.fc = nn.Sequential(
            nn.Linear(
                64 * conv_output_height * conv_output_width + ssm_hidden_dim, 256
            ),
            nn.ReLU(),
            # 7 values: 3 for translation, 4 for rotation (quaternion)
            nn.Linear(256, 7),
        )

    def forward(self, frame, imu_data):
        """Predict the 7-value pose from a frame batch and an IMU batch.

        Args:
            frame: Batch of frames, channel-first (fed straight into Conv2d).
            imu_data: Batch of IMU vectors; a sequence axis of length 1 is
                inserted before the SSM.

        Returns:
            Tensor with one row of 7 pose values per batch element.
        """
        # CNN for ultrasound frames
        x = self.pool(torch.relu(self.conv1(frame)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten everything but the batch axis

        # Mamba SSM for IMU data
        imu_data = imu_data.unsqueeze(1)  # Add sequence length dimension (assumed 1)
        # presumably returns (sequence_output, state) -- TODO confirm
        ssm_out, _ = self.ssm(imu_data)
        ssm_out = ssm_out[:, -1, :]  # Take the last output of the sequence

        # Concatenate CNN and SSM outputs along the feature axis
        combined = torch.cat((x, ssm_out), dim=1)

        # Fully connected layers for 6DoF pose
        return self.fc(combined)

: 

## 5. Train the Model


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PoseEstimationModel().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Cut the learning rate by 10x after 5 epochs without a lower validation loss.
# The `verbose` flag is deprecated for LR schedulers since PyTorch 2.2, so it
# is no longer passed; read optimizer.param_groups[0]["lr"] to log the rate.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5
)

: 

In [None]:
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=50,
    patience=10,
):
    """Train ``model`` with early stopping and return it with its best weights.

    Args:
        model: Network called as ``model(frames, imu_data)``.
        train_loader: Iterable of ``(frames, imu_data, targets)`` training batches.
        val_loader: Iterable of ``(frames, imu_data, targets)`` validation batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        scheduler: Scheduler stepped once per epoch with the validation loss.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops early.

    Returns:
        The same ``model`` object. When at least one epoch produced a best
        validation loss, the weights from that epoch are loaded back into it;
        otherwise the model is returned as it is after the last step.
    """
    # Local import keeps the block self-contained; `copy` is standard library.
    import copy

    # Move batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    best_val_loss = float("inf")
    best_model_weights = None
    epochs_no_improve = 0

    for epoch in tqdm(range(num_epochs), desc="Epochs", position=0):
        model.train()
        train_loss = 0.0
        # Defined up front so the epoch summary works even with no batches.
        train_loss_avg = float("nan")
        epoch_progress = tqdm(
            train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", position=1, leave=False
        )

        # Count batches explicitly: the progress bar's own counter is only
        # refreshed periodically and would skew the running average.
        for batch_count, (frames, imu_data, targets) in enumerate(
            epoch_progress, start=1
        ):
            frames = frames.to(run_device)
            imu_data = imu_data.to(run_device)
            targets = targets.to(run_device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(frames, imu_data)
            loss = criterion(outputs, targets)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_loss_avg = train_loss / batch_count
            epoch_progress.set_postfix({"Avg Loss": f"{train_loss_avg:.4f}"})

        # Validation step
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for frames, imu_data, targets in val_loader:
                frames = frames.to(run_device)
                imu_data = imu_data.to(run_device)
                targets = targets.to(run_device)
                outputs = model(frames, imu_data)
                val_loss += criterion(outputs, targets).item()
                val_batches += 1

        # An empty validation loader yields NaN instead of a ZeroDivisionError;
        # NaN never compares as an improvement, so no snapshot is taken.
        val_loss = val_loss / val_batches if val_batches else float("nan")
        scheduler.step(val_loss)

        tqdm.write(
            f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_avg:.4f}, Val Loss: {val_loss:.4f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() hands out references to the live tensors,
            # so a shallow dict copy would keep changing as training continues.
            best_model_weights = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                tqdm.write("Early stopping triggered")
                break

    # Nothing to restore when no epoch ran or none improved on infinity.
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
    return model

: 

In [None]:
# Fit the network; train_model hands back the same model object after
# restoring the weights it recorded as best.
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

# Persist only the weights (state dict), not the whole module object.
torch.save(trained_model.state_dict(), "best_model.pth")

: 

## 6. Evaluate the Model


In [None]:
def evaluate_model(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients and collect results.

    Relies on the notebook-level ``device`` and ``criterion`` variables.

    Returns:
        A tuple ``(avg_loss, predictions, targets)``: the per-batch mean of the
        loss, and the model outputs and targets of all batches concatenated
        into NumPy arrays.
    """
    model.eval()
    running_loss = 0.0
    prediction_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch in data_loader:
            # Every batch is a (frames, imu_data, targets) triple.
            frames, imu_data, targets = (tensor.to(device) for tensor in batch)
            outputs = model(frames, imu_data)
            running_loss += criterion(outputs, targets).item()

            # Move to host memory before converting to NumPy.
            prediction_chunks.append(outputs.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    mean_loss = running_loss / len(data_loader)
    stacked_predictions = np.concatenate(prediction_chunks)
    stacked_targets = np.concatenate(target_chunks)

    return mean_loss, stacked_predictions, stacked_targets

: 

In [None]:
# Score the trained model on the validation loader and keep the raw
# predictions and targets for plotting.
val_loss, val_predictions, val_targets = evaluate_model(
    model=trained_model,
    data_loader=val_loader,
)
print(f"Validation Loss: {val_loss:.4f}")

: 

## 7. Visualize Predictions


In [None]:
def visualize_predictions(predictions, targets):
    """Show a 3D scatter plot of predicted versus ground-truth positions.

    Args:
        predictions: 2-D array of model outputs; only columns 0-2 are plotted.
        targets: 2-D array of reference poses; only columns 0-2 are plotted.
    """
    # The notebook never imports matplotlib, so `plt` used to be undefined
    # here; importing locally makes the function work on a fresh kernel.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection="3d")

    # Plot ground truth positions
    ax.scatter(
        targets[:, 0], targets[:, 1], targets[:, 2], label="Ground Truth", marker="o"
    )

    # Plot predicted positions
    ax.scatter(
        predictions[:, 0],
        predictions[:, 1],
        predictions[:, 2],
        label="Predicted",
        marker="x",
    )

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("Predicted vs Ground Truth Positions")
    ax.legend()
    plt.show()


# Compare predicted and reference positions on the validation split.
visualize_predictions(predictions=val_predictions, targets=val_targets)

: 