# Ultrasound Frame Corner Prediction with Optical Flow and IMU Data


## 1. Import Libraries


In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision.models.optical_flow import (
    Raft_Large_Weights,
    Raft_Small_Weights,
    raft_large,
    raft_small,
)
from tqdm.notebook import tqdm

## 2. Load and Preprocess Data


In [None]:
def load_data(file_path):
    """Read the recording stored in the HDF5 file at ``file_path``.

    Args:
        file_path: path to the ``.h5`` file.

    Returns:
        Whatever ``pandas.read_hdf`` returns for the file (used as a DataFrame
        by the rest of this notebook).
    """
    recording = pd.read_hdf(file_path)
    return recording


# Location of the recorded scan (HDF5); adjust this path for your machine.
file_path = "/home/varun/xia_lab/repos/ABUSFusion/scans/20240826/wrist_data.h5"
df = load_data(file_path=file_path)

## 3. Calculate Frame Corners


In [None]:
def calculate_frame_corners(df, width, height, marker_to_probe_bottom):
    """Compute the four image-plane corners of every frame in world coordinates.

    The corners of a ``width`` x ``height`` rectangle are defined in probe
    coordinates (x: lateral, y: axial, z: elevational). They are re-expressed
    in the tracker-marker axes, shifted by ``marker_to_probe_bottom`` along the
    marker z axis, then rotated by each row's tracker quaternion and translated
    by its tracker position.

    Args:
        df: table with columns ``ot_pos_x/y/z`` (position) and
            ``ot_qw/qx/qy/qz`` (orientation quaternion, scalar ``w`` plus
            vector ``x, y, z``).
        width: lateral extent of the image plane, in the same units as the
            tracker position (the notebook passes mm).
        height: axial extent of the image plane, same units.
        marker_to_probe_bottom: offset added along the marker z axis; may be
            an int or a float.

    Returns:
        DataFrame with ``df``'s index and 12 columns ``corner_{i}_{x,y,z}``
        for corners 0..3 (top-left, top-right, bottom-left, bottom-right).
    """
    corners_columns = [
        f"corner_{i}_{axis}" for i in range(4) for axis in ["x", "y", "z"]
    ]
    if len(df) == 0:
        return pd.DataFrame(columns=corners_columns, index=df.index, dtype=float)

    # Frame corners in probe coordinates (y: axial, x: lateral, z: elevational).
    # Use floats so that a fractional offset can be added below.
    probe_corners = np.array(
        [
            [0, 0, 0],  # Top-left
            [width, 0, 0],  # Top-right
            [0, height, 0],  # Bottom-left
            [width, height, 0],  # Bottom-right
        ],
        dtype=float,
    )

    # Re-express the probe axes in the optical-tracker marker axes.
    transform_matrix = np.array(
        [
            [0, 0, -1],  # Tracker X -> -Probe Z (elevational)
            [1, 0, 0],  # Tracker Y -> Probe X (lateral)
            [0, -1, 0],  # Tracker Z -> -Probe Y (axial)
        ],
        dtype=float,
    )
    local_corners = probe_corners @ transform_matrix.T

    # Offset from the tracked marker to the probe bottom, along marker z.
    local_corners[:, 2] += marker_to_probe_bottom

    positions = df[["ot_pos_x", "ot_pos_y", "ot_pos_z"]].to_numpy(dtype=float)
    # SciPy's Rotation.from_quat expects scalar-LAST (x, y, z, w) order by
    # default, so the scalar ``w`` column goes last here.
    quaternions = df[["ot_qx", "ot_qy", "ot_qz", "ot_qw"]].to_numpy(dtype=float)
    rotation_matrices = Rotation.from_quat(quaternions).as_matrix()  # (N, 3, 3)

    # world[n, k] = R[n] @ local_corners[k] + position[n]
    world_corners = (
        np.einsum("nij,kj->nki", rotation_matrices, local_corners)
        + positions[:, None, :]
    )

    # Keep df's index so the result can be concatenated column-wise with df.
    return pd.DataFrame(
        world_corners.reshape(len(df), 12), columns=corners_columns, index=df.index
    )

In [None]:
# Probe specifications
probe_specs = {
    "width": 38,  # mm
    "height": 50,  # mm
    "marker_to_probe_bottom": 54,  # mm
}

corners_df = calculate_frame_corners(
    df,
    probe_specs["width"],
    probe_specs["height"],
    probe_specs["marker_to_probe_bottom"],
)
# pd.concat(axis=1) aligns rows by index LABEL, not position. The corner rows
# are produced in df's row order, so give them df's labels. Otherwise, if df
# has a non-default index (e.g. timestamps or a filtered subset), the two
# tables would misalign and introduce NaN rows. set_axis also raises if the
# lengths differ.
corners_df = corners_df.set_axis(df.index, axis=0)

# Replace the raw tracker pose columns with the derived corner coordinates.
tracker_pose_columns = [
    "ot_pos_x",
    "ot_pos_y",
    "ot_pos_z",
    "ot_qw",
    "ot_qx",
    "ot_qy",
    "ot_qz",
]
new_df = pd.concat(
    [df.drop(columns=tracker_pose_columns), corners_df],
    axis=1,
)

## 4. Create Dataset for Optical Flow and IMU Data


In [None]:
class OpticalFlowDataset(Dataset):
    """Consecutive frame pairs with the matching IMU reading and corner target.

    Item ``idx`` is built from the window
    ``df.iloc[idx : idx + sequence_length]``. The first two rows supply
    ``frame1`` and ``frame2``. The IMU values and the 12 corner coordinates
    are read from the second row.

    Args:
        df: table with a ``frame`` column plus the columns listed in
            ``IMU_COLUMNS`` and ``CORNER_COLUMNS``.
        sequence_length: window length. Must be at least 2 because two frames
            are read per item; only the first two rows of a longer window are
            used.
        transform: stored on the instance but not applied by this class.
        downsample_factor: stored on the instance but not used by this class.

    Raises:
        ValueError: if ``sequence_length`` is less than 2.
    """

    # IMU features returned for each item, in this order.
    IMU_COLUMNS = (
        "imu_acc_x",
        "imu_acc_y",
        "imu_acc_z",
        "imu_orientation_x",
        "imu_orientation_y",
        "imu_orientation_z",
    )
    # Corner coordinates returned as the regression target, in this order.
    CORNER_COLUMNS = (
        "corner_0_x",
        "corner_0_y",
        "corner_0_z",
        "corner_1_x",
        "corner_1_y",
        "corner_1_z",
        "corner_2_x",
        "corner_2_y",
        "corner_2_z",
        "corner_3_x",
        "corner_3_y",
        "corner_3_z",
    )

    def __init__(self, df, sequence_length=2, transform=None, downsample_factor=3):
        if sequence_length < 2:
            raise ValueError(
                f"sequence_length must be >= 2 (two frames per item), got {sequence_length}"
            )
        self.df = df
        self.sequence_length = sequence_length
        self.transform = transform
        self.downsample_factor = downsample_factor

    def __len__(self):
        # Number of complete windows. Clamp at 0: a negative __len__ makes len() raise.
        return max(0, len(self.df) - self.sequence_length + 1)

    @staticmethod
    def _frame_to_tensor(frame):
        """Convert one stored frame into a float tensor with channels first.

        Assumes ``np.stack(frame)`` yields a channels-last (H, W, C) array,
        since ``permute(2, 0, 1)`` needs exactly 3 dims. TODO confirm for the
        recorded data.
        """
        return torch.tensor(np.stack(frame)).float().permute(2, 0, 1)

    @staticmethod
    def _row_to_tensor(row, columns):
        """Select ``columns`` from one row as a float tensor; non-numeric values raise."""
        return torch.tensor(
            pd.to_numeric(row[list(columns)], errors="raise").values
        ).float()

    def __getitem__(self, idx):
        sequence = self.df.iloc[idx : idx + self.sequence_length]

        frame1 = self._frame_to_tensor(sequence["frame"].iloc[0])
        frame2 = self._frame_to_tensor(sequence["frame"].iloc[1])

        # IMU reading and target both come from the second row of the window.
        current = sequence.iloc[1]
        imu_data = self._row_to_tensor(current, self.IMU_COLUMNS)
        target = self._row_to_tensor(current, self.CORNER_COLUMNS)

        return frame1, frame2, imu_data, target

In [None]:
# Create the dataset
optical_flow_dataset = OpticalFlowDataset(new_df)

# Split the dataset
train_indices, val_indices = train_test_split(
    range(len(optical_flow_dataset)), test_size=0.2, random_state=42
)

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(optical_flow_dataset, batch_size=8, sampler=train_sampler)
val_loader = DataLoader(optical_flow_dataset, batch_size=1, sampler=val_sampler)

## 5. Load Pretrained RAFT Model for Optical Flow


In [None]:
# Load pretrained RAFT model from torchvision
weights = Raft_Small_Weights.DEFAULT
raft_model = raft_small(weights=weights).eval().cuda()

# RAFT normalization and utility function
transforms_raft = weights.transforms()

In [None]:
class InputPadder:
    """Pad images so that their height and width become multiples of 8.

    The padding for each dimension is split between both sides. When the
    amount is odd, the extra pixel goes to the right / bottom side.
    ``unpad`` crops padded tensors back to the original size.

    Args:
        dims: shape whose last two entries are the (height, width) to pad.
        mode: padding mode forwarded to ``torch.nn.functional.pad``.
    """

    def __init__(self, dims, mode="constant"):
        self.ht, self.wd = dims[-2:]
        # Smallest non-negative amount that reaches the next multiple of 8.
        extra_rows = -self.ht % 8
        extra_cols = -self.wd % 8
        left = extra_cols // 2
        top = extra_rows // 2
        # Order expected by nn.functional.pad for the last two dims:
        # (left, right, top, bottom).
        self._pad = [left, extra_cols - left, top, extra_rows - top]
        self.mode = mode

    def pad(self, *inputs):
        """Return a list with every input padded by the stored amounts."""
        padded = []
        for tensor in inputs:
            padded.append(nn.functional.pad(tensor, self._pad, mode=self.mode))
        return padded

    def unpad(self, *inputs):
        """Return a list with every input cropped back to the original (ht, wd)."""
        left, _, top, _ = self._pad
        row_window = slice(top, top + self.ht)
        col_window = slice(left, left + self.wd)
        return [tensor[..., row_window, col_window] for tensor in inputs]

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def compute_optical_flow(
    frame1, frame2, model, device, downsample_factor=3, raft_transforms=None
):
    """Estimate optical flow from ``frame1`` to ``frame2`` at reduced resolution.

    Both frames are processed as follows:
      1. Resize to ``(H // downsample_factor, W // downsample_factor)``,
         rounded up to multiples of 8 (RAFT needs spatial dims divisible by 8).
      2. Run the RAFT preprocessing transform.
      3. Move to ``device``.
      4. Feed to ``model`` under ``torch.no_grad()``.

    Args:
        frame1, frame2: image batches whose last two dims are (H, W). Both are
            resized to the size derived from ``frame1``. Presumably
            (B, 3, H, W) float tensors from ``OpticalFlowDataset``; TODO confirm.
        model: RAFT-style model called as ``model(frame1, frame2)``. It returns
            a sequence of flow predictions, of which the last is used.
        device: device the preprocessed frames are moved to.
        downsample_factor: integer factor by which H and W are reduced.
        raft_transforms: optional callable ``(img1, img2) -> (img1, img2)``
            used for preprocessing. Defaults to the module-level
            ``transforms_raft``.

    Returns:
        The last element of the model's output (the final flow estimate), at
        the downsampled, multiple-of-8 resolution.
    """
    if raft_transforms is None:
        raft_transforms = transforms_raft

    original_height, original_width = frame1.shape[-2:]

    # Round the downsampled size UP to a multiple of 8. Never go below 8, so
    # inputs smaller than downsample_factor don't collapse to a 0-sized image.
    adjusted_height = max(8, (original_height // downsample_factor + 7) // 8 * 8)
    adjusted_width = max(8, (original_width // downsample_factor + 7) // 8 * 8)
    target_size = [adjusted_height, adjusted_width]

    frame1 = F.resize(frame1, size=target_size, antialias=False)
    frame2 = F.resize(frame2, size=target_size, antialias=False)

    # Apply the RAFT-specific transformation to both images together.
    # NOTE(review): torchvision's RAFT preset converts to float and then
    # normalizes assuming values in [0, 1]. Float frames holding 0-255 values
    # (e.g. uint8 images cast with .float()) would be mis-scaled. Confirm the
    # range of the stored frames.
    frame1, frame2 = raft_transforms(frame1, frame2)

    frame1, frame2 = frame1.to(device), frame2.to(device)

    with torch.no_grad():
        flow_list = model(frame1, frame2)

    # Last element = final (most refined) flow prediction.
    return flow_list[-1]

## 6. Define Frame Corner Prediction Model


In [None]:
class FrameCornerPredictionModel(nn.Module):
    """Regress the 12 frame-corner coordinates from an optical-flow field and IMU data.

    The flow goes through two conv + ReLU + 2x2 max-pool stages and is
    flattened. The 6-value IMU vector goes through a 2-layer MLP (128 units).
    The concatenated features feed a 2-layer head that outputs 12 values
    (x, y, z for each of 4 corners).

    Args:
        input_channels: number of channels in the flow input (2 by default).

    Forward inputs:
        flow: (B, input_channels, H, W) tensor.
        imu_data: (B, 6) tensor; moved to ``flow``'s device.

    The first head layer is an ``nn.LazyLinear``: its input width (which
    depends on H and W) is inferred on the first forward pass. H and W must
    then stay the same on later calls.
    """

    def __init__(self, input_channels=2):
        super().__init__()

        # Optical flow encoder
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # IMU encoder
        self.fc_imu = nn.Sequential(
            nn.Linear(6, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU()
        )

        # Corner head. A layer created later inside forward() would not be
        # registered in time: its weights would be missing from an optimizer
        # built from model.parameters() before the first forward pass (so they
        # would never train), and a fresh model could not load the saved
        # state_dict. LazyLinear is registered now and only infers its input
        # size on the first call.
        self.fc_corners = nn.Sequential(
            nn.LazyLinear(512),
            nn.ReLU(),
            nn.Linear(512, 12),  # 12 values for 4 corners (x, y, z) * 4
        )

    def forward(self, flow, imu_data):
        imu_data = imu_data.to(flow.device)

        x = self.pool(torch.relu(self.conv1(flow)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)

        imu_features = self.fc_imu(imu_data)
        combined_features = torch.cat((x, imu_features), dim=1)

        return self.fc_corners(combined_features)

## 7. Train the Model


In [None]:
# Corner regressor on the selected device, trained with L1 loss and Adam.
# NOTE: the optimizer captures the parameters registered at this point; a
# layer created only later (e.g. inside forward) would not be optimized.
model = FrameCornerPredictionModel()
model.to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005)
# Divide the learning rate by 10 after 5 epochs without improvement of the
# loss passed to scheduler.step(). NOTE(review): recent PyTorch releases warn
# that the `verbose` argument is deprecated.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=5,
    verbose=True,
)

In [None]:
def _snapshot_state_dict(model):
    """Return a deep copy of ``model.state_dict()`` that later training cannot modify."""
    import copy

    return copy.deepcopy(model.state_dict())


def _predict_batch(model, batch, downsample_factor):
    """Move one (frame1, frame2, imu, targets) batch to ``device``, run RAFT + model; return (outputs, targets)."""
    frame1, frame2, imu_data, targets = (tensor.to(device) for tensor in batch)
    flow = compute_optical_flow(frame1, frame2, raft_model, device, downsample_factor)
    return model(flow, imu_data), targets


def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=100000,
    patience=10,
    downsample_factor=3,
):
    """Train ``model`` with early stopping on the validation loss.

    Each batch's flow comes from the module-level ``raft_model`` via
    ``compute_optical_flow``, on the module-level ``device``. After every
    epoch the mean per-batch validation loss is passed to ``scheduler.step``.
    Training stops after ``patience`` consecutive epochs without a new best.

    Args:
        model: corner-prediction model called as ``model(flow, imu_data)``.
        train_loader, val_loader: yield ``(frame1, frame2, imu_data, targets)``.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler whose ``step`` takes the validation loss.
        num_epochs: maximum number of epochs.
        patience: epochs without improvement before stopping.
        downsample_factor: forwarded to ``compute_optical_flow``.

    Returns:
        ``model``, loaded with the weights of the best validation epoch. If no
        epoch produced a finite improvement (e.g. NaN losses), the final
        weights are kept and a warning is written.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if len(val_loader) == 0:
        raise ValueError(
            "val_loader is empty; the validation loss used for early stopping is undefined"
        )

    best_val_loss = float("inf")
    best_model_weights = None
    epochs_no_improve = 0

    for epoch in tqdm(range(num_epochs), desc="Epochs", position=0):
        model.train()
        train_loss = 0.0
        num_train_batches = 0
        train_loss_avg = float("nan")  # stays NaN if train_loader is empty
        epoch_progress = tqdm(
            train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", position=1, leave=False
        )

        for batch in epoch_progress:
            optimizer.zero_grad()

            outputs, targets = _predict_batch(model, batch, downsample_factor)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_train_batches += 1
            train_loss_avg = train_loss / num_train_batches
            epoch_progress.set_postfix({"Avg Loss": f"{train_loss_avg:.4f}"})

        # Validation step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                outputs, targets = _predict_batch(model, batch, downsample_factor)
                val_loss += criterion(outputs, targets).item()

        val_loss /= len(val_loader)
        scheduler.step(val_loss)

        tqdm.write(
            f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_avg:.4f}, Val Loss: {val_loss:.4f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # A shallow dict copy would still reference the live parameter
            # tensors and silently track later (worse) epochs; deep-copy instead.
            best_model_weights = _snapshot_state_dict(model)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                tqdm.write("Early stopping triggered")
                break

    if best_model_weights is None:
        tqdm.write("No improving validation epoch; keeping the final weights")
    else:
        model.load_state_dict(best_model_weights)
    return model

In [14]:
# Train the model
trained_model = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler
)

# Save the best model
torch.save(trained_model.state_dict(), "best_model.pth")

## 8. Evaluate Model


In [None]:
def evaluate_model(model, data_loader, downsample_factor=3, loss_fn=None):
    """Run ``model`` over ``data_loader``; return the mean loss, predictions and targets.

    Each batch's flow comes from the module-level ``raft_model`` via
    ``compute_optical_flow``; tensors are moved to the module-level ``device``.

    Args:
        model: corner-prediction model called as ``model(flow, imu_data)``.
        data_loader: yields ``(frame1, frame2, imu_data, targets)`` batches.
        downsample_factor: forwarded to ``compute_optical_flow``.
        loss_fn: loss callable ``loss_fn(outputs, targets)``. Defaults to the
            module-level ``criterion``.

    Returns:
        A tuple ``(avg_loss, predictions, targets)``:
          - ``avg_loss``: the per-batch loss weighted by batch size. This is
            the exact per-sample mean for a mean-reduced loss such as
            ``nn.L1Loss()``, even when the last batch is smaller.
          - ``predictions`` and ``targets``: NumPy arrays concatenated along
            the batch axis.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()
    weighted_loss = 0.0
    num_samples = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for frame1, frame2, imu_data, targets in data_loader:
            frame1, frame2, imu_data, targets = (
                frame1.to(device),
                frame2.to(device),
                imu_data.to(device),
                targets.to(device),
            )
            flow = compute_optical_flow(
                frame1, frame2, raft_model, device, downsample_factor
            )
            outputs = model(flow, imu_data)
            loss = loss_fn(outputs, targets)

            batch_size = targets.shape[0]
            weighted_loss += loss.item() * batch_size
            num_samples += batch_size

            all_predictions.append(outputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    if num_samples == 0:
        raise ValueError("data_loader yielded no batches; cannot compute a loss")

    avg_loss = weighted_loss / num_samples
    return avg_loss, np.concatenate(all_predictions), np.concatenate(all_targets)


# Evaluate the trained model on the validation split
val_results = evaluate_model(trained_model, val_loader)
val_loss, val_predictions, val_targets = val_results
print(f"Validation Loss: {val_loss:.4f}")

## 9. Visualize Predictions


In [None]:
def visualize_predictions(predictions, targets):
    """Scatter ground-truth and predicted frame corners in one 3D plot.

    In both 2D arrays, column ``3 * i + k`` holds axis ``k`` (x, y, z) of
    corner ``i`` (i = 0..3). Ground truth is drawn with "o" markers and
    predictions with "x" markers, one labelled series per corner. The figure
    is displayed with ``plt.show()``.
    """
    figure = plt.figure(figsize=(15, 10))
    axes3d = figure.add_subplot(111, projection="3d")

    # Ground-truth series first, then predicted series.
    series = (
        ("Ground Truth Corner", targets, "o"),
        ("Predicted Corner", predictions, "x"),
    )
    for label_prefix, points, marker_style in series:
        for corner in range(4):
            first_col = 3 * corner
            axes3d.scatter(
                points[:, first_col],
                points[:, first_col + 1],
                points[:, first_col + 2],
                label=f"{label_prefix} {corner + 1}",
                marker=marker_style,
            )

    axes3d.set_xlabel("X")
    axes3d.set_ylabel("Y")
    axes3d.set_zlabel("Z")
    axes3d.set_title("Predicted vs Ground Truth Corners")
    axes3d.legend()
    plt.show()


visualize_predictions(predictions=val_predictions, targets=val_targets)