In [None]:
# RSSM-based Delta Pose Predictor for Ultrasound Images
import csv
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Set Data Root
data_root = '.'  # /path/to/data

# Recording folders are named frames_0513_NN:
# sessions 06-26 are used for training, sessions 01-05 are held out for testing.
train_dirs = {f"frames_0513_{session:02d}" for session in range(6, 27)}
test_dirs = {f"frames_0513_{session:02d}" for session in range(1, 6)}


In [None]:
# ========= Combine CSVs ============
def combine_pose_csvs_with_foldername(root_folder, output_csv="poses_combined.csv"):
    """Merge every ``*_final_data.csv`` in ``root_folder`` into one CSV at ``output_csv``.

    Files are processed in sorted name order. Each file's ``Filename`` entries are
    prefixed with a folder name derived from the file name, e.g. rows of
    ``0513_01_final_data.csv`` become ``frames_0513_01/<original Filename>``.

    Args:
        root_folder: directory scanned (non-recursively) for ``*_final_data.csv`` files.
        output_csv: destination path of the combined CSV (written without an index).

    Returns:
        None. When no matching file exists, a warning is printed and nothing is written.
    """
    suffix = "_final_data.csv"
    tables = []

    for name in sorted(os.listdir(root_folder)):
        if name.endswith(suffix):
            table = pd.read_csv(os.path.join(root_folder, name))
            # ex: 0513_01_final_data.csv -> frames_0513_01
            prefix = "frames_" + name.replace(suffix, "")
            table["Filename"] = table["Filename"].apply(lambda entry: f"{prefix}/{entry}")
            tables.append(table)

    if len(tables) == 0:
        print("⚠️ No Valid File Found")
        return

    pd.concat(tables, ignore_index=True).to_csv(output_csv, index=False)
    print(f"✅ Saved Combined CSV to：{output_csv}")
    

In [None]:
# Merge every *_final_data.csv under data_root into the single pose table used by the dataset.
combine_pose_csvs_with_foldername(root_folder=data_root, output_csv="poses_combined.csv")

In [None]:
# === Dataset ===
class SequenceUltrasoundDataset(Dataset):
    """Sliding-window dataset of consecutive ultrasound frames and per-step pose deltas.

    Each sample is a window of ``sequence_length + 1`` consecutive CSV rows from one
    folder (rows sorted by image path). ``__getitem__`` returns the first
    ``sequence_length`` frames together with the ``sequence_length`` pose differences
    between neighbouring rows, so frame ``t`` is paired with the motion from row ``t``
    to row ``t + 1``.

    Args:
        csv_path: combined pose CSV with columns ``Filename``, ``X (mm)``, ``Y (mm)``,
            ``Z (mm)``, ``Roll (deg)``, ``Pitch (deg)``, ``Yaw (deg)``. ``Filename`` is
            ``<folder>/<image>``; its first path component selects the folder.
        root_dirs: folder names to include (any iterable of str, e.g. a set).
        sequence_length: number of frames / pose deltas per sample.
        image_size: size passed to ``transforms.Resize``.
        data_root: optional directory that ``Filename`` paths are relative to. ``None``
            (default) resolves them against the current working directory.
    """

    def __init__(self, csv_path, root_dirs, sequence_length=5, image_size=(256, 256), data_root=None):
        self.samples = []
        df = pd.read_csv(csv_path)

        df['folder'] = df['Filename'].apply(lambda x: x.split('/')[0])
        df = df.rename(columns={
            'Filename': 'img_path',
            'X (mm)': 'tx', 'Y (mm)': 'ty', 'Z (mm)': 'tz',
            'Roll (deg)': 'rx', 'Pitch (deg)': 'ry', 'Yaw (deg)': 'rz'
        })

        # Iterate folders in sorted order: ``root_dirs`` is often a set, and the
        # iteration order of a set of strings changes between interpreter runs,
        # which would make the (unshuffled) test sample order non-reproducible.
        for dir_ in sorted(root_dirs):
            group = df[df['folder'] == dir_].sort_values('img_path')
            frames = group.to_dict('records')
            # Windows of sequence_length + 1 rows: sequence_length frames + one extra pose.
            for i in range(len(frames) - sequence_length):
                self.samples.append(frames[i:i + sequence_length + 1])

        self.sequence_length = sequence_length
        self.image_size = image_size
        self.data_root = data_root
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.samples)

    def _load_frame(self, rel_path):
        """Read one image as 8-bit grayscale and apply ``self.transform``.

        Raises:
            FileNotFoundError: the image file does not exist.
            ValueError: the file exists but cannot be decoded as an image.
        """
        img_path = rel_path if self.data_root is None else os.path.join(self.data_root, rel_path)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        try:
            with Image.open(img_path) as pil_img:
                # HxW uint8 array, the same input the transform's ToPILImage expects.
                img = np.array(pil_img.convert('L'))
        except OSError as exc:  # PIL raises OSError subclasses for undecodable files
            raise ValueError(f"Cannot read image: {img_path}") from exc
        return self.transform(img)

    def __getitem__(self, idx):
        """Return ``(frames, delta_poses)`` for window ``idx``.

        frames: stacked transformed images of the first T = sequence_length rows,
            shape [T, 1, H, W].
        delta_poses: [T, 6] float32 tensor, ``pose[t + 1] - pose[t]`` for
            (tx, ty, tz, rx, ry, rz). The rotation deltas (degrees) are wrapped into
            [-180, 180), so crossing the +/-180 deg boundary is not read as a ~360 deg turn.
        """
        seq = self.samples[idx]

        # Only the first T frames are returned, so the final row's image is never loaded.
        imgs = [self._load_frame(item['img_path']) for item in seq[:-1]]

        poses = torch.from_numpy(np.array([
            [item['tx'], item['ty'], item['tz'], item['rx'], item['ry'], item['rz']]
            for item in seq
        ], dtype=np.float32))
        delta_poses = poses[1:] - poses[:-1]
        # torch.remainder takes the sign of the divisor, so the result lies in [-180, 180).
        delta_poses[:, 3:] = torch.remainder(delta_poses[:, 3:] + 180.0, 360.0) - 180.0

        return (
            torch.stack(imgs),  # [T, 1, H, W]
            delta_poses,        # [T, 6]
        )

# === Encoder ===
class Encoder(nn.Module):
    """Convolutional VAE-style encoder: grayscale frame -> latent ``z`` plus Gaussian parameters.

    Four stride-2 convolutions (kernel 4, padding 1) halve each spatial dimension,
    so an H x W input becomes a 256-channel (H // 16) x (W // 16) feature map that
    is flattened and projected to ``mu`` and ``logvar``.

    Args:
        latent_dim: size of the latent vector.
        image_size: (H, W) of the input frames. The default (256, 256) gives the
            same ``256 * 16 * 16`` projection inputs as before.
    """

    def __init__(self, latent_dim=64, image_size=(256, 256)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1), nn.ReLU(),  # 256→128
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),  # 128→64
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),  # 64→32
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU()  # 32→16
        )
        self.flatten = nn.Flatten()
        height, width = image_size
        flat_dim = 256 * (height // 16) * (width // 16)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

    def forward(self, x):
        """Encode a batch of single-channel frames.

        Args:
            x: [N, 1, H, W] tensor with H, W matching ``image_size``.

        Returns:
            (z, mu, logvar), each [N, latent_dim]. In training mode ``z`` is the
            reparameterised sample ``mu + eps * exp(0.5 * logvar)``. In eval mode
            ``z = mu``, so inference and evaluation are deterministic.
        """
        features = self.flatten(self.conv(x))
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        if not self.training:
            return mu, mu, logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z, mu, logvar

# === RSSM Core ===
class RSSMCore(nn.Module):
    """Recurrent state update of the RSSM, implemented as a single GRU cell.

    Args:
        z_dim: size of the latent input consumed at each step.
        h_dim: size of the recurrent hidden state.
    """

    def __init__(self, z_dim, h_dim):
        super().__init__()
        self.rnn = nn.GRUCell(input_size=z_dim, hidden_size=h_dim)

    def forward(self, z, h):
        """Advance one step: combine latent ``z`` [B, z_dim] with state ``h`` [B, h_dim] into the next state."""
        next_h = self.rnn(z, h)
        return next_h

# === Pose Decoder ===
class PoseDecoder(nn.Module):
    """Two-layer MLP head mapping a hidden state [B, h_dim] to a 6-D output [B, 6].

    Args:
        h_dim: size of the incoming hidden state.
    """

    def __init__(self, h_dim):
        super().__init__()
        hidden_units = 128
        pose_dims = 6
        layers = [
            nn.Linear(h_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, pose_dims),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, h):
        prediction = self.fc(h)
        return prediction

# === Frame Decoder ===
class FrameDecoder(nn.Module):
    """Decode a hidden state [B, h_dim] into a single-channel frame [B, 1, 256, 256].

    The hidden state is projected to a 128 x 16 x 16 grid. Four stride-2 transposed
    convolutions (kernel 4, padding 1) then double the resolution each time while
    reducing channels 128 -> 64 -> 32 -> 16 -> 1. No output activation is applied.

    Args:
        h_dim: size of the incoming hidden state.
    """

    def __init__(self, h_dim):
        super().__init__()
        self.fc = nn.Linear(h_dim, 128 * 16 * 16)

        channels = (128, 64, 32, 16, 1)
        last_stage = len(channels) - 2
        layers = [nn.ReLU()]
        for stage, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            if stage != last_stage:
                layers.append(nn.ReLU())
        self.deconv = nn.Sequential(*layers)

    def forward(self, h):
        grid = self.fc(h).view(-1, 128, 16, 16)
        return self.deconv(grid)

# === RSSM Model ===
class RSSMGoalDeltaPoseModel(nn.Module):
    """Recurrent model that predicts per-step delta poses and frame reconstructions from frame sequences.

    For each step t: frame t is encoded to a latent ``z_t``, a GRU updates the hidden
    state ``h_t = GRU(z_t, h_{t-1})`` with ``h_0 = 0``, and two heads decode ``h_t``
    into a 6-D delta pose and a reconstructed frame.

    NOTE(review): despite the "Goal" in the name, ``forward`` takes no goal input.

    Args:
        z_dim: latent size produced by the encoder.
        h_dim: recurrent hidden-state size.
    """

    def __init__(self, z_dim=64, h_dim=128):
        super().__init__()
        self.encoder = Encoder(z_dim)
        self.rssm = RSSMCore(z_dim, h_dim)
        self.pose_decoder = PoseDecoder(h_dim)
        self.frame_decoder = FrameDecoder(h_dim)
        self.z_dim = z_dim
        self.h_dim = h_dim

    def forward(self, obs_seq):  # [B, T, 1, H, W]
        """Run the model over a batch of frame sequences.

        Args:
            obs_seq: [B, T, 1, H, W] frames. The encoder/decoder layer sizes assume
                H = W = 256.

        Returns:
            pose_preds: [B, T, 6] predicted delta poses, one per step.
            img_preds: [B, T, *frame_decoder output shape] reconstructions from h_t.
            total_kl: scalar tensor. It sums over steps the batch-mean
                KL(N(mu, exp(logvar)) || N(0, I)), where each KL is summed over latent dims.
        """
        B, T = obs_seq.shape[:2]

        # Only the GRU update is sequential. The encoder and both decoders are
        # per-frame, so run them once over all B*T frames instead of T separate passes.
        z_flat, mu, logvar = self.encoder(obs_seq.flatten(0, 1))
        z_seq = z_flat.reshape(B, T, -1)

        h = obs_seq.new_zeros(B, self.h_dim)
        hidden_states = []
        for t in range(T):
            h = self.rssm(z_seq[:, t], h)
            hidden_states.append(h)
        h_flat = torch.stack(hidden_states, dim=1).reshape(B * T, self.h_dim)

        pose_flat = self.pose_decoder(h_flat)
        img_flat = self.frame_decoder(h_flat)
        pose_preds = pose_flat.reshape(B, T, *pose_flat.shape[1:])
        img_preds = img_flat.reshape(B, T, *img_flat.shape[1:])

        # Per-frame KL summed over latent dims -> [B, T]; mean over batch, then sum over time.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).reshape(B, T)
        total_kl = kl.mean(dim=0).sum()
        return pose_preds, img_preds, total_kl

# === Train Model ===
def _batch_loss(model, obs_seq, delta_poses_seq, pose_weight, recon_weight, kl_weight):
    """Forward one batch; return ``(weighted_loss, pose_preds)``.

    loss = pose_weight * SmoothL1(pose) + recon_weight * MSE(frames) + kl_weight * KL
    """
    pose_preds, img_preds, kl_loss = model(obs_seq)
    loss_pose = F.smooth_l1_loss(pose_preds, delta_poses_seq)
    loss_recon = F.mse_loss(img_preds, obs_seq)
    loss = pose_weight * loss_pose + recon_weight * loss_recon + kl_weight * kl_loss
    return loss, pose_preds


def _save_pose_predictions(pred_batches, true_batches, csv_filename):
    """Write predicted and true delta poses to ``csv_filename``, one row per (sample, step).

    ``pred_batches`` / ``true_batches`` are lists of [B, T, 6] arrays. The columns are
    pred_tx..pred_rz followed by true_tx..true_rz.
    """
    axes = ("tx", "ty", "tz", "rx", "ry", "rz")
    if pred_batches:
        pred_flat = np.concatenate(pred_batches, axis=0).reshape(-1, 6)
        true_flat = np.concatenate(true_batches, axis=0).reshape(-1, 6)
    else:
        # Empty test split: write a header-only CSV instead of failing in np.concatenate.
        pred_flat = true_flat = np.empty((0, 6), dtype=np.float32)
    columns = {f"pred_{axis}": pred_flat[:, i] for i, axis in enumerate(axes)}
    columns.update({f"true_{axis}": true_flat[:, i] for i, axis in enumerate(axes)})
    pd.DataFrame(columns).to_csv(csv_filename, index=False)


def train_model(csv_path, train_dirs, test_dirs, image_size=(256, 256), batch_size=32, epochs=20, lr=1e-4,
                sequence_length=10, pose_weight=2.0, recon_weight=0.5, kl_weight=0.1,
                model_path="rssm_model.pth", seed=None):
    """Train ``RSSMGoalDeltaPoseModel`` and evaluate it on the test folders after every epoch.

    The same weighted loss (``pose_weight``, ``recon_weight``, ``kl_weight``) is used
    for training and testing, so the printed losses are comparable. After each epoch,
    test predictions and targets are written to ``test_pred_true_epoch_<n>.csv``.

    Args:
        csv_path: combined pose CSV (see ``SequenceUltrasoundDataset``).
        train_dirs / test_dirs: folder names used for training / evaluation.
        image_size: resize target for frames. The model's layers are sized for (256, 256).
        batch_size, epochs, lr: optimisation settings (Adam).
        sequence_length: frames per training window.
        pose_weight, recon_weight, kl_weight: loss term weights.
        model_path: where the final ``state_dict`` is saved (the same default path
            ``predict_action`` loads). ``None`` skips saving.
        seed: if not None, seeds ``random``, NumPy and PyTorch for reproducible runs.

    Returns:
        The trained model.

    Raises:
        ValueError: no training sequences could be built from ``train_dirs``.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = SequenceUltrasoundDataset(csv_path, train_dirs, sequence_length=sequence_length, image_size=image_size)
    test_dataset = SequenceUltrasoundDataset(csv_path, test_dirs, sequence_length=sequence_length, image_size=image_size)
    if len(train_dataset) == 0:
        raise ValueError(
            f"No training sequences of {sequence_length + 1} frames found in {csv_path} for the given train_dirs"
        )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = RSSMGoalDeltaPoseModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ==== Training ====
        model.train()
        total_train_loss = 0.0
        for obs_seq, delta_poses_seq in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            obs_seq = obs_seq.to(device)
            delta_poses_seq = delta_poses_seq.to(device)

            optimizer.zero_grad()
            loss, _ = _batch_loss(model, obs_seq, delta_poses_seq, pose_weight, recon_weight, kl_weight)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # ==== Testing ====
        model.eval()
        total_test_loss = 0.0
        all_test_pred, all_test_true = [], []

        with torch.no_grad():
            for obs_seq, delta_poses_seq in tqdm(test_loader, desc=f"Epoch {epoch+1} Testing"):
                obs_seq = obs_seq.to(device)
                delta_poses_seq = delta_poses_seq.to(device)

                loss, pose_preds = _batch_loss(model, obs_seq, delta_poses_seq, pose_weight, recon_weight, kl_weight)
                total_test_loss += loss.item()

                all_test_pred.append(pose_preds.cpu().numpy())
                all_test_true.append(delta_poses_seq.cpu().numpy())

        # NaN rather than ZeroDivisionError when the test split is empty.
        avg_test_loss = total_test_loss / len(test_loader) if len(test_loader) else float("nan")

        # ==== Save test predictions and ground truth as CSV ====
        csv_filename = f'test_pred_true_epoch_{epoch+1}.csv'
        _save_pose_predictions(all_test_pred, all_test_true, csv_filename)

        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
        print(f"Test predictions and true values saved to {csv_filename}")

    if model_path is not None:
        torch.save(model.state_dict(), model_path)
        print(f"Model weights saved to {model_path}")
    return model

In [None]:
# === Run ===
# A Jupyter kernel runs with __name__ == "__main__", so executing this cell starts training.
if __name__ == "__main__":
    train_model(
        "poses_combined.csv",
        train_dirs=train_dirs,
        test_dirs=test_dirs,
        image_size=(256, 256),
        batch_size=64,
        epochs=20,
        lr=1e-4,
    )

In [None]:
# === Predict ===
def predict_action(current_img, goal_img, model_path="rssm_model.pth", model=None):
    """Predict the next-step delta pose (6 values) for one ultrasound frame.

    Args:
        current_img: PIL image of the current frame. It is converted to grayscale and
            resized to 256x256, then fed to the model as a length-1 sequence.
        goal_img: goal frame. NOTE(review): ``RSSMGoalDeltaPoseModel.forward`` takes
            only an observation sequence and has no goal pathway, so this argument is
            accepted for interface compatibility but cannot influence the prediction.
            A warning is emitted when it is not None.
        model_path: state-dict checkpoint to load when ``model`` is not given.
        model: optional already-loaded model, which avoids re-reading the checkpoint
            on every call. It is switched to eval mode.

    Returns:
        numpy array of shape (6,): the model's predicted delta pose for the current frame.
    """
    import warnings

    if goal_img is not None:
        warnings.warn("RSSMGoalDeltaPoseModel has no goal input; goal_img is ignored.", stacklevel=2)

    if model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = RSSMGoalDeltaPoseModel().to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (or pass weights_only=True on torch versions that support it).
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        device = next(model.parameters()).device
    model.eval()

    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    # The model expects [B, T, 1, H, W]: one batch element holding a one-frame sequence.
    obs_seq = transform(current_img).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        pose_preds, _, _ = model(obs_seq)

    return pose_preds[0, -1].cpu().numpy()