In [None]:
# RSSM-based Delta Pose Predictor for Ultrasound Images
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import random

In [None]:
# Set Data Root
data_root = '.'  # /path/to/data

# Train/test split by recording sequence:
#   training  -> 0513_06 ... 0513_26 (21 sequences)
#   held out  -> 0513_01 ... 0513_05 (5 sequences)
train_dirs = {f"0513_{idx:02d}" for idx in range(6, 27)}
test_dirs = {f"0513_{idx:02d}" for idx in range(1, 6)}

In [None]:
# ========= Combine CSVs ============
def combine_pose_csvs_with_foldername(root_folder, output_csv="poses_combined.csv"):
    """Merge every ``*_final_data.csv`` found in ``root_folder`` into one CSV.

    Each source file ``<seq>_final_data.csv`` has its ``Filename`` entries
    prefixed with ``frames_<seq>/``. For example, ``frame_0000.png`` from
    ``0513_01_final_data.csv`` becomes ``frames_0513_01/frame_0000.png``.
    Files are processed in sorted name order. Their rows are concatenated
    with a fresh 0..N-1 index.

    Args:
        root_folder: Directory listed (non-recursively) for ``*_final_data.csv``.
        output_csv: Destination path of the combined CSV (written without index).

    Returns:
        None. When no matching file exists, a warning is printed and nothing
        is written.
    """
    suffix = "_final_data.csv"
    matching_names = [name for name in sorted(os.listdir(root_folder)) if name.endswith(suffix)]

    tables = []
    for name in matching_names:
        table = pd.read_csv(os.path.join(root_folder, name))
        prefix = "frames_" + name.replace(suffix, "")
        # Applied immediately, so the closure over `prefix` is safe.
        table["Filename"] = table["Filename"].apply(lambda entry: f"{prefix}/{entry}")
        tables.append(table)

    if not tables:
        print("⚠️ No Valid File Found")
        return None

    pd.concat(tables, ignore_index=True).to_csv(output_csv, index=False)
    print(f"✅ Saved Combined CSV to：{output_csv}")

In [None]:
# Build the combined pose table consumed by the dataset below.
combine_pose_csvs_with_foldername(root_folder=data_root, output_csv="poses_combined.csv")

In [None]:
# =========================== Dataset ===========================
class UltrasoundGoalPoseDataset(Dataset):
    """Pairs of (current frame, goal frame) images with the pose delta between them.

    Rows of the pose CSV are grouped by sequence folder (the first path
    component of ``Filename``) and sorted by ``Filename`` within each folder.
    For every frame except the last ``goal_window`` frames of a sequence, one
    goal frame is drawn with the global ``random`` module from those last
    ``goal_window`` frames. The draw happens once, at construction time, so
    the pairs are fixed for the lifetime of the dataset.

    Args:
        csv_path: CSV with a ``Filename`` column (``<folder>/<image>``) and the
            pose columns listed in ``POSE_COLUMNS``.
        image_size: Size passed to ``PIL.Image.resize`` for every loaded image.
        transform: Optional callable applied to the grayscale PIL image. When
            falsy, ``transforms.ToTensor()`` is used instead.
        include_dirs: Optional collection of sequence folders to keep. A folder
            named ``frames_0513_01`` is selected by either ``"frames_0513_01"``
            or the bare ``"0513_01"``. ``None`` keeps every folder.
        goal_window: Number of trailing frames per sequence used as goal views
            (default 15). Sequences with ``goal_window`` or fewer frames are
            skipped.

    Raises:
        ValueError: if ``goal_window`` is smaller than 1.
    """

    # Prefix carried by sequence folder names in ``Filename`` (e.g. ``frames_0513_01``).
    _FOLDER_PREFIX = "frames_"
    # Pose columns; the target is goal minus current for each, in this order.
    POSE_COLUMNS = ('X (mm)', 'Y (mm)', 'Z (mm)', 'Roll (deg)', 'Pitch (deg)', 'Yaw (deg)')

    def __init__(self, csv_path, image_size=(256, 256), transform=None, include_dirs=None, goal_window=15):
        if goal_window < 1:
            raise ValueError(f"goal_window must be >= 1, got {goal_window}")
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.image_size = image_size

        # Group rows by sequence folder (first path component of Filename).
        self.sequences = {}
        for _, row in self.data.iterrows():
            folder = row['Filename'].split("/")[0]
            if self._folder_selected(folder, include_dirs):
                self.sequences.setdefault(folder, []).append(row)

        # Lexical sort; presumably frame names are zero-padded (e.g. frame_0000.png)
        # so this is temporal order -- TODO confirm for all recordings.
        for key in self.sequences:
            self.sequences[key] = sorted(self.sequences[key], key=lambda x: x['Filename'])

        self.pairs = []
        for seq in self.sequences.values():
            if len(seq) <= goal_window:
                continue
            goal_views = seq[-goal_window:]
            for i in range(len(seq) - goal_window):
                self.pairs.append((seq[i], random.choice(goal_views)))

    @classmethod
    def _folder_selected(cls, folder, include_dirs):
        """Return True if ``folder`` (with or without the ``frames_`` prefix) is in ``include_dirs``."""
        if include_dirs is None:
            return True
        if folder.startswith(cls._FOLDER_PREFIX):
            bare = folder[len(cls._FOLDER_PREFIX):]
        else:
            bare = folder
        return folder in include_dirs or bare in include_dirs

    def _load_img(self, row):
        """Load ``row['Filename']`` as a resized grayscale image and convert it to a tensor."""
        # The context manager closes the underlying file handle deterministically.
        with Image.open(row['Filename']) as src:
            img = src.convert('L').resize(self.image_size)
        return self.transform(img) if self.transform else transforms.ToTensor()(img)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(cat([curr_img, goal_img], dim=0), delta_pose)`` for pair ``idx``."""
        curr, goal = self.pairs[idx]

        curr_img = self._load_img(curr)
        goal_img = self._load_img(goal)

        # NOTE(review): rotation deltas are raw Euler-angle differences with no wrap
        # into [-180, 180); a pair straddling +/-180 deg would give a ~360 deg target.
        # Confirm whether the recorded angles can cross that boundary.
        delta_pose = torch.tensor(
            [goal[col] - curr[col] for col in self.POSE_COLUMNS],
            dtype=torch.float32,
        )

        return torch.cat([curr_img, goal_img], dim=0), delta_pose

# =========================== RSSM Components ===========================
class Encoder(nn.Module):
    """Convolutional encoder mapping a 2-channel (current, goal) image stack to a 256-d vector.

    Three Conv2d layers (kernel 4, stride 2, padding 1) with ReLU are followed by
    a linear projection of the flattened feature map. The projection's input
    size is derived from ``image_size`` instead of being hard-coded.

    Args:
        image_size: ``(height, width)`` of the input tensors. The default
            ``(256, 256)`` yields the same ``128 * 32 * 32 -> 256`` projection
            as before.

    Raises:
        ValueError: if ``image_size`` is too small to survive three stride-2 convolutions.
    """

    def __init__(self, image_size=(256, 256)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(2, 32, 4, 2, 1),  # curr + goal
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
        )
        height, width = image_size
        for _ in range(3):
            height, width = self._conv_out(height), self._conv_out(width)
            if height < 1 or width < 1:
                raise ValueError(
                    f"image_size {image_size} is too small for three stride-2 convolutions"
                )
        self.linear = nn.Linear(128 * height * width, 256)

    @staticmethod
    def _conv_out(size, kernel=4, stride=2, padding=1):
        """Spatial output size of one Conv2d layer (PyTorch's formula, dilation 1)."""
        return (size + 2 * padding - kernel) // stride + 1

    def forward(self, x):
        return self.linear(self.conv(x))

class RSSM(nn.Module):
    """Recurrent head: a 256-unit GRU cell plus a linear readout of a 6-d delta pose.

    ``forward(z, h)`` advances the hidden state with the encoder feature ``z``
    and returns ``(delta_pose, new_hidden)``.
    """

    def __init__(self):
        super().__init__()
        # Creation order (GRU cell, then readout) is kept so parameter init and
        # state_dict keys stay the same.
        self.rnn = nn.GRUCell(256, 256)
        self.fc = nn.Linear(256, 6)  # output delta pose

    def forward(self, z, h):
        next_hidden = self.rnn(z, h)
        delta = self.fc(next_hidden)
        return delta, next_hidden

class RSSMModel(nn.Module):
    """Full predictor: ``Encoder`` features fed through the ``RSSM`` recurrent head.

    ``forward(x, h)`` takes the stacked (current, goal) image tensor and a
    hidden state. It returns the ``(delta_pose, new_hidden)`` pair produced by
    the RSSM head.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.rssm = RSSM()

    def forward(self, x, h):
        features = self.encoder(x)
        return self.rssm(features, h)

# =========================== Training ===========================
def train_model(csv_path, train_dirs, test_dirs, image_size=(256, 256), epochs=30, batch_size=32, lr=1e-4):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    train_dataset = UltrasoundGoalPoseDataset(csv_path, include_dirs=train_dirs, image_size=image_size, transform=transform)
    test_dataset = UltrasoundGoalPoseDataset(csv_path, include_dirs=test_dirs, image_size=image_size, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"✅ Using device: {device}")
    model = RSSMModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.SmoothL1Loss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        h = torch.zeros(batch_size, 256).to(device)

        for imgs, deltas in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            imgs, deltas = imgs.to(device), deltas.to(device)

            preds, h = model(imgs, h.detach())
            loss = loss_fn(preds, deltas)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * imgs.size(0)

        avg_loss = total_loss / len(train_dataset)
        print(f"Epoch {epoch+1} Train Loss: {avg_loss:.4f}")

        # Evaluate
        model.eval()
        test_loss = 0.0
        h_test = torch.zeros(batch_size, 256).to(device)

        with torch.no_grad():
            for imgs, deltas in test_loader:
                imgs, deltas = imgs.to(device), deltas.to(device)
                preds, h_test = model(imgs, h_test.detach())
                loss = loss_fn(preds, deltas)
                test_loss += loss.item() * imgs.size(0)

        avg_test_loss = test_loss / len(test_dataset)
        print(f"Epoch {epoch+1} Test Loss: {avg_test_loss:.4f}")

        torch.save(model.state_dict(), f"rssm_model_epoch_{epoch+1}.pth")

    print("✅ Training Completed!")


In [None]:
# =========================== Run ===========================
if __name__ == "__main__":
    # Hyperparameters for the full training run.
    run_settings = {
        "image_size": (256, 256),
        "epochs": 20,
        "batch_size": 64,
        "lr": 1e-4,
    }
    train_model(
        csv_path="poses_combined.csv",
        train_dirs=train_dirs,
        test_dirs=test_dirs,
        **run_settings,
    )