In [None]:
# RSSM-based Delta Pose Predictor for Ultrasound Images
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import csv

In [None]:
# Directory scanned by combine_pose_csvs_with_foldername for "*_final_data.csv" pose files.
data_root = '.'  # /path/to/data

# Sequence folders: sessions 06-26 are used for training, sessions 01-05 are held out for testing.
train_dirs = {f"frames_0513_{session:02d}" for session in range(6, 27)}

test_dirs = {f"frames_0513_{session:02d}" for session in range(1, 6)}


In [None]:
# ========= Combine CSVs ============
def combine_pose_csvs_with_foldername(root_folder, output_csv="poses_combined.csv"):
    """Concatenate every ``*_final_data.csv`` in ``root_folder`` into one CSV.

    Each source file ``<name>_final_data.csv`` is paired with the frame folder
    ``frames_<name>``: every entry of its ``Filename`` column is rewritten as
    ``frames_<name>/<original filename>`` so that the combined table can be
    grouped by folder later. Files are processed in sorted name order.

    Args:
        root_folder: directory whose entries are scanned (non-recursively).
        output_csv: path the combined table is written to (without the index).

    Returns:
        None. If no matching file exists, a warning is printed and nothing is written.
    """
    suffix = "_final_data.csv"
    tables = []

    for entry in sorted(os.listdir(root_folder)):
        if entry.endswith(suffix):
            table = pd.read_csv(os.path.join(root_folder, entry))
            # e.g. 0513_01_final_data.csv -> frames_0513_01
            prefix = "frames_" + entry.replace(suffix, "")
            # e.g. frame_0000.png -> frames_0513_01/frame_0000.png
            table["Filename"] = table["Filename"].apply(lambda name: f"{prefix}/{name}")
            tables.append(table)

    if len(tables) == 0:
        print("⚠️ No Valid File Found")
        return

    pd.concat(tables, ignore_index=True).to_csv(output_csv, index=False)
    print(f"✅ Saved Combined CSV to：{output_csv}")
    

In [None]:
# Build the combined pose table consumed by the dataset below.
combine_pose_csvs_with_foldername(data_root, output_csv="poses_combined.csv")

In [None]:
def save_predictions_to_csv(preds_list, targets_list, epoch):
    """Write paired predictions and targets to ``predictions_epoch_<epoch+1>.csv``.

    The file is created (or overwritten) in the current working directory. Each
    row is ``pred.tolist() + target.tolist()``. Pairs come from ``zip``, so any
    surplus items in the longer list are ignored. The header names six
    prediction columns followed by six target columns.

    Args:
        preds_list: iterable of per-sample prediction tensors (anything with ``.tolist()``).
        targets_list: iterable of per-sample target tensors, aligned with ``preds_list``.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
    """
    out_path = f"predictions_epoch_{epoch+1}.csv"
    header = [
        "Pred_X", "Pred_Y", "Pred_Z", "Pred_Roll", "Pred_Pitch", "Pred_Yaw",
        "True_X", "True_Y", "True_Z", "True_Roll", "True_Pitch", "True_Yaw"
    ]
    with open(out_path, mode='w', newline='') as handle:
        csv_writer = csv.writer(handle)
        csv_writer.writerow(header)
        csv_writer.writerows(
            pred.tolist() + target.tolist()
            for pred, target in zip(preds_list, targets_list)
        )
    print(f"✅ Predictions saved to {out_path}")
    

In [None]:
# =========================== Dataset ===========================
class UltrasoundGoalPoseDataset(Dataset):
    """Frames paired with the offset from their pose to the sequence's goal pose.

    The CSV must provide a ``Filename`` column of the form ``<folder>/<frame file>``
    and the pose columns ``X (mm)``, ``Y (mm)``, ``Z (mm)``, ``Roll (deg)``,
    ``Pitch (deg)``, ``Yaw (deg)``. Rows are grouped by ``<folder>`` (one sequence
    per folder) and sorted by ``Filename``.

    For every sequence with more than ``num_goal_frames`` rows, the element-wise
    mean pose of the last ``num_goal_frames`` rows is the goal pose. Every earlier
    row becomes one sample whose target is ``goal_pose - current_pose``,
    standardised as ``(delta - self.mean) / self.std``. ``mean``/``std`` are
    computed from this dataset's own deltas. They are read at item-access time,
    so callers may overwrite them, e.g. with training-split statistics.

    Args:
        csv_path: path of the combined pose CSV.
        image_size: size passed to PIL ``Image.resize`` (PIL interprets it as (width, height)).
        transform: optional callable applied to the grayscale PIL image; when
            None, ``transforms.ToTensor()`` is used.
        include_dirs: optional collection of folder names to keep; None keeps all.
        num_goal_frames: number of trailing frames per sequence that define the
            goal pose (default 15). Must be >= 1.
        root_dir: optional directory prepended to each ``Filename`` when loading
            images; None opens the path as stored (relative to the working directory).

    Raises:
        ValueError: if ``num_goal_frames`` < 1.
    """

    _POSE_COLUMNS = ['X (mm)', 'Y (mm)', 'Z (mm)', 'Roll (deg)', 'Pitch (deg)', 'Yaw (deg)']

    def __init__(self, csv_path, image_size=(256, 256), transform=None, include_dirs=None,
                 num_goal_frames=15, root_dir=None):
        if num_goal_frames < 1:
            # seq[-0:] would be the whole sequence, silently producing no samples.
            raise ValueError(f"num_goal_frames must be >= 1, got {num_goal_frames}")

        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.image_size = image_size
        self.root_dir = root_dir

        # One sequence per top-level folder, optionally restricted to include_dirs.
        self.sequences = {}
        for _, row in self.data.iterrows():
            folder = row['Filename'].split("/")[0]
            if include_dirs is not None and folder not in include_dirs:
                continue
            self.sequences.setdefault(folder, []).append(row)

        for key in self.sequences:
            self.sequences[key] = sorted(self.sequences[key], key=lambda x: x['Filename'])

        # NOTE(review): angles are averaged and subtracted linearly; if Roll/Pitch/Yaw
        # can wrap around ±180° the goal pose and deltas will be wrong — confirm range.
        self.pairs = []
        for seq in self.sequences.values():
            # At least one non-goal frame is needed to produce a sample.
            if len(seq) <= num_goal_frames:
                continue
            goal_views = seq[-num_goal_frames:]
            mean_pose = np.mean([self._pose(g) for g in goal_views], axis=0)
            for curr in seq[:-num_goal_frames]:
                self.pairs.append((curr, mean_pose))

        # Un-normalised targets, computed once and reused by __getitem__.
        self._raw_deltas = [mean_pose - self._pose(curr) for curr, mean_pose in self.pairs]

        n_dims = len(self._POSE_COLUMNS)
        if self._raw_deltas:
            mean = np.mean(self._raw_deltas, axis=0)
            std = np.std(self._raw_deltas, axis=0)
            # A constant component has std 0; dividing by it would give inf/NaN targets.
            std = np.where(std > 0, std, 1.0)
        else:
            # Empty split: neutral statistics instead of NaN from np.mean([]).
            mean = np.zeros(n_dims)
            std = np.ones(n_dims)
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)

    @classmethod
    def _pose(cls, row):
        """Return the six pose columns of ``row`` as a float array."""
        return np.array([row[col] for col in cls._POSE_COLUMNS], dtype=float)

    def __len__(self):
        return len(self.pairs)

    def _load_img(self, row):
        """Load the frame for ``row`` as a grayscale, resized, transformed tensor."""
        path = row['Filename']
        if self.root_dir is not None:
            path = os.path.join(self.root_dir, path)
        # Context manager guarantees the file handle is closed.
        with Image.open(path) as raw:
            img = raw.convert('L').resize(self.image_size)
        return self.transform(img) if self.transform else transforms.ToTensor()(img)

    def __getitem__(self, idx):
        """Return ``(image_stack, normalised_delta_pose)`` for sample ``idx``.

        ``image_stack`` concatenates the current frame and a "goal" channel along
        dim 0.
        """
        curr, _ = self.pairs[idx]

        curr_img = self._load_img(curr)
        # NOTE(review): the goal channel is a copy of the current frame, not an image
        # of the goal view, so it adds no information. Kept to preserve the model's
        # 2-channel input contract.
        goal_img = curr_img.clone()

        delta_pose = torch.tensor(self._raw_deltas[idx], dtype=torch.float32)
        delta_pose = (delta_pose - self.mean) / self.std

        return torch.cat([curr_img, goal_img], dim=0), delta_pose

# =========================== RSSM Components ===========================
class Encoder(nn.Module):
    """Convolutional encoder mapping an image stack to a flat latent vector.

    Three Conv2d layers with kernel 4, stride 2 and padding 1 each halve the
    spatial size (rounding down). An ``H x W`` input therefore yields a
    ``128 x (H // 8) x (W // 8)`` feature map, which is flattened and projected
    to ``latent_dim`` features.

    Args:
        in_channels: channels of the input tensor (default 2, matching the
            current-frame + goal-frame stack).
        input_size: (H, W) of the input tensors passed to ``forward``. The
            default (256, 256) reproduces the original fixed layer size.
        latent_dim: size of the output embedding (default 256).

    Raises:
        ValueError: if either spatial dimension is smaller than 8.
    """

    def __init__(self, in_channels=2, input_size=(256, 256), latent_dim=256):
        super().__init__()
        height, width = input_size
        if height < 8 or width < 8:
            # Three halvings of a dimension < 8 leave an empty feature map.
            raise ValueError(f"input_size must be at least (8, 8), got {tuple(input_size)}")
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Derived from input_size instead of hard-coding 128 * 32 * 32 (256x256 only).
        self.linear = nn.Linear(128 * (height // 8) * (width // 8), latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, in_channels, H, W) to (batch, latent_dim)."""
        return self.linear(self.conv(x))

class RSSM(nn.Module):
    """Recurrent core: one GRU-cell step followed by a linear 6-output head.

    ``forward(z, h)`` updates the 256-d hidden state ``h`` with the 256-d input
    ``z`` and returns ``(head(new_h), new_h)``.

    NOTE(review): despite the name, this is a deterministic GRU. There is no
    stochastic latent as in a classic RSSM.
    """

    def __init__(self):
        super().__init__()
        # Creation order (rnn, then fc) fixes parameter initialisation order under a seed.
        self.rnn = nn.GRUCell(256, 256)
        self.fc = nn.Linear(256, 6)

    def forward(self, z, h):
        next_h = self.rnn(z, h)
        prediction = self.fc(next_h)
        return prediction, next_h

class RSSMModel(nn.Module):
    """Full predictor: ``Encoder`` followed by ``RSSM``.

    ``forward(x, h)`` encodes the image stack ``x`` to a latent, advances the
    recurrent core with hidden state ``h`` and returns
    ``(prediction, new_hidden_state)``. Callers in this notebook pass a fresh
    zero hidden state of width 256 for every batch.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.rssm = RSSM()

    def forward(self, x, h):
        latent = self.encoder(x)
        return self.rssm(latent, h)

# =========================== Accuracy ===========================
def compute_pose_accuracy(preds, targets):
    """Score predictions as ``1 / (1 + MAE)``, a value in (0, 1].

    The absolute error is first averaged over dim 0 (the batch) for each output
    component, then over the components. 1.0 means a perfect match. Because
    targets are standardised, the error is in standardised units.

    Returns:
        Python float.
    """
    per_component = (preds - targets).abs().mean(dim=0)
    mae = per_component.mean().item()
    return 1.0 / (1.0 + mae)

# =========================== Training ===========================
def _train_one_epoch(model, loader, optimizer, loss_fn, device, desc):
    """Run one optimisation pass over ``loader``.

    Returns:
        (loss_sum, acc_sum): batch loss and ``compute_pose_accuracy`` values, each
        multiplied by the batch size and summed (divide by the sample count to average).
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    for imgs, deltas in tqdm(loader, desc=desc):
        imgs, deltas = imgs.to(device), deltas.to(device)
        batch = imgs.size(0)
        # Fresh zero hidden state per batch (256 = RSSM GRU hidden size), so each
        # sample is predicted from a single frame: no recurrence across frames.
        h = torch.zeros(batch, 256, device=device)
        preds, _ = model(imgs, h)
        loss = loss_fn(preds, deltas)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch
        acc_sum += compute_pose_accuracy(preds.detach(), deltas) * batch
    return loss_sum, acc_sum


def _evaluate(model, loader, loss_fn, device):
    """Score ``model`` on ``loader`` without gradients.

    Returns:
        (loss_sum, acc_sum, preds, targets): sample-weighted sums as in
        ``_train_one_epoch`` plus per-sample prediction/target tensors on CPU.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for imgs, deltas in loader:
            imgs, deltas = imgs.to(device), deltas.to(device)
            batch = imgs.size(0)
            h = torch.zeros(batch, 256, device=device)
            preds, _ = model(imgs, h)
            loss_sum += loss_fn(preds, deltas).item() * batch
            acc_sum += compute_pose_accuracy(preds, deltas) * batch
            all_preds.extend(preds.cpu())
            all_targets.extend(deltas.cpu())
    return loss_sum, acc_sum, all_preds, all_targets


def train_model(csv_path, train_dirs, test_dirs, image_size=(256, 256), epochs=30, batch_size=32, lr=1e-4, seed=None):
    """Train ``RSSMModel`` to predict standardised delta poses and evaluate every epoch.

    Each epoch trains on ``train_dirs`` and evaluates on ``test_dirs``. It then
    writes ``predictions_epoch_<n>.csv`` (via ``save_predictions_to_csv``) and
    ``rssm_model_epoch_<n>.pth`` to the working directory. Test targets are
    standardised with the training split's mean/std, so losses on both splits
    share one target scale.

    Args:
        csv_path: combined pose CSV (see ``UltrasoundGoalPoseDataset``).
        train_dirs / test_dirs: folder names selecting each split.
        image_size: resize target for the frames.
        epochs, batch_size, lr: optimisation settings (Adam, SmoothL1 loss).
        seed: optional int; when given, seeds Python ``random``, NumPy and PyTorch
            so shuffling and weight initialisation are repeatable.

    Raises:
        ValueError: if the training split contains no samples.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # NOTE(review): torchvision Resize takes (height, width) while the dataset's PIL
    # resize takes (width, height); they only agree for square sizes. RSSMModel() as
    # built below expects 256x256 inputs (its encoder's Linear size is fixed for that).
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    train_dataset = UltrasoundGoalPoseDataset(csv_path, include_dirs=train_dirs, image_size=image_size, transform=transform)
    test_dataset = UltrasoundGoalPoseDataset(csv_path, include_dirs=test_dirs, image_size=image_size, transform=transform)
    if len(train_dataset) == 0:
        raise ValueError(f"No training samples found in {csv_path!r} for the given train_dirs")

    # Standardise test targets with the training statistics: otherwise the test
    # split is scaled by its own labels, making its loss incomparable and leaking
    # test-label information into the normaliser.
    test_dataset.mean = train_dataset.mean
    test_dataset.std = train_dataset.std

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"✅ Using device: {device}")
    model = RSSMModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.SmoothL1Loss()

    n_train = len(train_dataset)
    n_test = len(test_dataset)

    for epoch in range(epochs):
        train_loss_sum, train_acc_sum = _train_one_epoch(
            model, train_loader, optimizer, loss_fn, device, desc=f"Epoch {epoch+1}/{epochs}"
        )
        avg_loss = train_loss_sum / n_train
        avg_acc = train_acc_sum / n_train
        print(f"Epoch {epoch+1} Train Loss: {avg_loss:.4f} | Train Accuracy: {avg_acc:.4f}")

        test_loss_sum, test_acc_sum, all_preds, all_targets = _evaluate(model, test_loader, loss_fn, device)
        # An empty test split (e.g. unknown test_dirs) reports NaN instead of dividing by zero.
        avg_test_loss = test_loss_sum / n_test if n_test else float("nan")
        avg_test_acc = test_acc_sum / n_test if n_test else float("nan")
        print(f"Epoch {epoch+1} Test Loss: {avg_test_loss:.4f} | Test Accuracy: {avg_test_acc:.4f}")

        # Saved values are in standardised units (training mean/std), not mm/deg.
        save_predictions_to_csv(all_preds, all_targets, epoch)

        torch.save(model.state_dict(), f"rssm_model_epoch_{epoch+1}.pth")

    print("✅ Training Completed!")


In [None]:
# =========================== Run ===========================
if __name__ == "__main__":
    # Full training run on the folder splits defined in the config cell.
    train_model(
        "poses_combined.csv",
        train_dirs,
        test_dirs,
        image_size=(256, 256),
        epochs=40,
        batch_size=64,
        lr=1e-4,
    )