In [None]:
# RSSM-based Delta Pose Predictor for Ultrasound Images
import csv
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Root directory that is scanned for the per-recording pose CSVs
data_root = '.'  # /path/to/data

# Recording folders follow the pattern frames_0513_NN.
# Recordings 06..26 are used for training, 01..05 are held out for testing.
train_dirs = {f"frames_0513_{rec:02d}" for rec in range(6, 27)}

test_dirs = {f"frames_0513_{rec:02d}" for rec in range(1, 6)}


In [None]:
# ========= Combine CSVs ============
def combine_pose_csvs_with_foldername(root_folder, output_csv="poses_combined.csv"):
    """Merge every ``*_final_data.csv`` in ``root_folder`` into one CSV.

    For a file such as ``0513_01_final_data.csv`` every value of its
    ``Filename`` column is rewritten to ``frames_0513_01/<original value>``
    before the tables are concatenated in sorted file-name order.

    Args:
        root_folder: Directory that is listed for pose CSV files.
        output_csv: Path of the combined CSV to write.

    Returns:
        None. Prints a warning and writes nothing when no matching file exists.
    """
    suffix = "_final_data.csv"
    pose_files = [name for name in sorted(os.listdir(root_folder)) if name.endswith(suffix)]

    tables = []
    for name in pose_files:
        # ex: 0513_01_final_data.csv -> frames_0513_01
        frame_folder = "frames_" + name.replace(suffix, "")
        table = pd.read_csv(os.path.join(root_folder, name))
        # Prefix each frame name with its folder: frames_0513_01/frame_0000.png
        table["Filename"] = table["Filename"].map(lambda frame, folder=frame_folder: f"{folder}/{frame}")
        tables.append(table)

    if len(tables) == 0:
        print("⚠️ No Valid File Found")
        return

    pd.concat(tables, ignore_index=True).to_csv(output_csv, index=False)
    print(f"✅ Saved Combined CSV to：{output_csv}")
    

In [None]:
# Merge every per-recording pose CSV found in data_root into poses_combined.csv
# (the output path is relative, so it is written to the current working directory).
combine_pose_csvs_with_foldername(data_root, "poses_combined.csv") 

In [None]:
# === Dataset ===
class SequenceUltrasoundDataset(Dataset):
    """Sliding-window dataset of ultrasound frames and frame-to-frame pose deltas.

    Every sample is built from ``sequence_length + 1`` consecutive rows of one
    recording folder. The first ``sequence_length`` frames are returned as
    images; the targets are the differences between consecutive poses.

    Args:
        csv_path: CSV with a ``Filename`` column of the form ``<folder>/<frame>``
            and the pose columns ``X (mm)``, ``Y (mm)``, ``Z (mm)``,
            ``Roll (deg)``, ``Pitch (deg)``, ``Yaw (deg)``.
        root_dirs: Iterable of folder names to include.
        sequence_length: Number of frames (and deltas) per sample; must be >= 1.
        image_size: Target size handed to ``transforms.Resize``.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, csv_path, root_dirs, sequence_length=5, image_size=(256, 256)):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

        self.samples = []
        df = pd.read_csv(csv_path)

        df['folder'] = df['Filename'].str.split('/').str[0]
        df = df.rename(columns={
            'Filename': 'img_path',
            'X (mm)': 'tx', 'Y (mm)': 'ty', 'Z (mm)': 'tz',
            'Roll (deg)': 'rx', 'Pitch (deg)': 'ry', 'Yaw (deg)': 'rz'
        })

        # root_dirs may be a set; sorting makes the sample order reproducible
        # across runs (set iteration order of strings is not stable).
        for dir_ in sorted(root_dirs):
            # Rows are ordered by the string sort of their path.
            # NOTE(review): this assumes zero-padded frame numbers — confirm.
            group = df[df['folder'] == dir_].sort_values('img_path')
            frames = group.to_dict('records')
            # One window per start index that still has sequence_length + 1 rows.
            for i in range(len(frames) - sequence_length):
                self.samples.append(frames[i:i + sequence_length + 1])

        self.sequence_length = sequence_length
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of sliding windows."""
        return len(self.samples)

    def _load_frame(self, img_path):
        """Read one frame as 8-bit grayscale and apply ``self.transform``.

        Raises:
            FileNotFoundError: If ``img_path`` does not exist.
            ValueError: If the file exists but cannot be decoded.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        try:
            # PIL is already imported by this notebook; the result is an
            # (H, W) uint8 array, which ToPILImage accepts.
            # NOTE(review): for colour or 16-bit files the grayscale values may
            # differ slightly from other decoders — confirm the frame format.
            with Image.open(img_path) as handle:
                pixels = np.array(handle.convert('L'))
        except OSError as exc:
            raise ValueError(f"Cannot read image: {img_path}") from exc
        return self.transform(pixels)

    def __getitem__(self, idx):
        """Return ``(images, delta_poses)`` for window ``idx``.

        Returns:
            Tuple of a ``[T, 1, H, W]`` image tensor and a ``[T, 6]`` float32
            tensor with ``pose[t + 1] - pose[t]`` for each returned frame.
        """
        window = self.samples[idx]

        poses = torch.tensor(
            np.array(
                [[row['tx'], row['ty'], row['tz'], row['rx'], row['ry'], row['rz']] for row in window],
                dtype=np.float32,
            ),
            dtype=torch.float32,
        )
        delta_poses = poses[1:] - poses[:-1]

        # The final row only contributes its pose, so its image is never decoded.
        imgs = [self._load_frame(row['img_path']) for row in window[:-1]]

        return (
            torch.stack(imgs),           # [T, 1, H, W]
            delta_poses                  # [T, 6]
        )

# === Encoder ===
class Encoder(nn.Module):
    """Convolutional encoder producing a reparameterised latent sample.

    Four stride-2 convolution stages (each followed by instance norm and ReLU)
    reduce a 1-channel image by a factor of 16 per side. The linear heads
    expect ``256 * 16 * 16`` features, i.e. a 256x256 input.

    Args:
        latent_dim: Width of the returned latent, mean and log-variance.
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        channel_plan = (1, 32, 64, 128, 256)  # 256 -> 128 -> 64 -> 32 -> 16 pixels per side
        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(),
            ])
        self.conv = nn.Sequential(*stages)
        self.flatten = nn.Flatten()

        flat_features = 256 * 16 * 16
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``.

        ``z`` is drawn as ``mu + eps * exp(0.5 * logvar)`` with fresh Gaussian
        noise on every call, regardless of train/eval mode.
        """
        features = self.flatten(self.conv(x))
        mu, logvar = self.fc_mu(features), self.fc_logvar(features)
        scale = (0.5 * logvar).exp()
        noise = torch.randn_like(scale)
        return mu + noise * scale, mu, logvar

# === RSSM Core ===
class RSSMCore(nn.Module):
    """Recurrent state-space core with a deterministic GRU state and a stochastic latent.

    At every step the previous latent ``z`` and the current action update the
    GRU state ``h``. A prior over the next latent is computed from
    ``(h, action)``; when observation embeddings are given a posterior is
    computed from ``(h, embedding)`` and its sample becomes the next ``z``.

    Args:
        action_dim: Width of each action vector.
        z_dim: Width of the stochastic latent.
        h_dim: Width of the GRU hidden state.
        embed_dim: Width of each observation embedding.
    """

    def __init__(self, action_dim, z_dim, h_dim, embed_dim):
        super().__init__()
        self.z_dim = z_dim
        self.h_dim = h_dim

        self.project_action_z = nn.Linear(z_dim + action_dim, h_dim)
        self.gru = nn.GRUCell(h_dim, h_dim)

        self.project_hidden_action = nn.Linear(h_dim + action_dim, h_dim)
        self.prior = nn.Linear(h_dim, z_dim * 2)

        self.project_hidden_obs = nn.Linear(h_dim + embed_dim, h_dim)
        self.posterior = nn.Linear(h_dim, z_dim * 2)

        self.activation = nn.ReLU()

    def forward(self, prev_z, prev_h, actions, embeddings=None, dones=None):
        """Roll the model over a sequence of actions.

        Args:
            prev_z: Initial latent, ``[B, z_dim]``.
            prev_h: Initial hidden state, ``[B, h_dim]``.
            actions: ``[B, T, action_dim]``.
            embeddings: Optional ``[B, T, embed_dim]``; when omitted the prior
                sample and prior statistics are used as the posterior.
            dones: Optional episode-end mask of shape ``[B, T]`` or
                ``[B, T, 1]`` (float or bool). Where it is 1 the carried latent
                ``z`` is zeroed before the step; ``h`` is left untouched.

        Returns:
            Dict of ``[B, T, ...]`` tensors with keys ``h``, ``z``,
            ``prior_mean``, ``prior_std``, ``post_mean``, ``post_std``.
        """
        B, T, _ = actions.size()
        h, z = prev_h, prev_z

        h_seq, z_seq = [], []
        prior_mean_seq, prior_std_seq = [], []
        post_mean_seq, post_std_seq = [], []

        for t in range(T):
            a = actions[:, t]

            # Reset z if done
            if dones is not None:
                # Cast so a bool mask works, and add a trailing axis to a
                # [B] slice so it broadcasts across z_dim instead of against it.
                done_t = dones[:, t].to(z.dtype)
                if done_t.dim() == 1:
                    done_t = done_t.unsqueeze(-1)
                z = z * (1.0 - done_t)

            x = torch.cat([z, a], dim=-1)
            x = self.activation(self.project_action_z(x))
            h = self.gru(x, h)

            # Prior: the second half of the head goes through softplus, so it
            # parameterises the std directly (it is not a log-std).
            ha = torch.cat([h, a], dim=-1)
            ha = self.activation(self.project_hidden_action(ha))
            prior_mean, prior_raw_std = torch.chunk(self.prior(ha), 2, dim=-1)
            prior_std = F.softplus(prior_raw_std)
            # Sampled even when a posterior is available, which keeps the
            # random-number stream the same in both modes.
            prior_z = torch.distributions.Normal(prior_mean, prior_std).rsample()

            # Posterior
            if embeddings is not None:
                he = torch.cat([h, embeddings[:, t]], dim=-1)
                he = self.activation(self.project_hidden_obs(he))
                post_mean, post_raw_std = torch.chunk(self.posterior(he), 2, dim=-1)
                post_std = F.softplus(post_raw_std)
                post_z = torch.distributions.Normal(post_mean, post_std).rsample()
            else:
                post_z = prior_z
                post_mean, post_std = prior_mean, prior_std

            z = post_z

            # Collect for each timestep
            h_seq.append(h)
            z_seq.append(z)
            prior_mean_seq.append(prior_mean)
            prior_std_seq.append(prior_std)
            post_mean_seq.append(post_mean)
            post_std_seq.append(post_std)

        return {
            'h': torch.stack(h_seq, dim=1),
            'z': torch.stack(z_seq, dim=1),
            'prior_mean': torch.stack(prior_mean_seq, dim=1),
            'prior_std': torch.stack(prior_std_seq, dim=1),
            'post_mean': torch.stack(post_mean_seq, dim=1),
            'post_std': torch.stack(post_std_seq, dim=1),
        }

# === Pose Decoder ===
class PoseDecoder(nn.Module):
    """Two-layer MLP head that maps a hidden state to six output values.

    Args:
        h_dim: Width of the incoming hidden state.
    """

    def __init__(self, h_dim):
        super().__init__()
        hidden_width, out_width = 128, 6
        head_layers = [
            nn.Linear(h_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, h):
        """Apply the MLP to the last dimension of ``h``."""
        pose = self.fc(h)
        return pose

# === Frame Decoder ===
class FrameDecoder(nn.Module):
    """Decode a hidden state into a single-channel 256x256 image.

    A linear layer expands the state to a 128x16x16 feature map, which four
    stride-2 transposed convolutions upsample by a factor of 16 per side.
    No activation is applied to the final output.

    Args:
        h_dim: Width of the incoming hidden state.
    """

    def __init__(self, h_dim):
        super().__init__()
        self.fc = nn.Linear(h_dim, 128 * 16 * 16)

        channel_plan = (128, 64, 32, 16, 1)
        last_stage = len(channel_plan) - 2
        layers = [nn.ReLU()]
        for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            if stage < last_stage:
                layers.append(nn.ReLU())
        self.deconv = nn.Sequential(*layers)

    def forward(self, h):
        """Return the decoded image batch for hidden states ``h``."""
        feature_map = self.fc(h).view(-1, 128, 16, 16)
        return self.deconv(feature_map)

# === RSSM Model ===
class RSSMGoalDeltaPoseModel(nn.Module):
    """Encoder + RSSM core + pose and frame decoders.

    Args:
        z_dim: Width of the RSSM stochastic latent.
        h_dim: Width of the RSSM hidden state (input to both decoders).
        action_dim: Width of each action / delta-pose vector.
        embed_dim: Width of the per-frame embedding produced by the encoder
            and consumed by the RSSM posterior.
    """

    def __init__(self, z_dim=64, h_dim=128, action_dim=6, embed_dim=64):
        super().__init__()
        # The encoder output is what RSSMCore receives as `embeddings`, so it
        # has to be embed_dim wide. Sizing it with z_dim only worked while both
        # happened to be equal.
        self.encoder = Encoder(embed_dim)
        self.rssm = RSSMCore(action_dim, z_dim, h_dim, embed_dim)
        self.pose_decoder = PoseDecoder(h_dim)
        self.frame_decoder = FrameDecoder(h_dim)
        self.z_dim = z_dim
        self.h_dim = h_dim

    def forward(self, obs_seq, delta_pose_seq):  # obs_seq: [B, T, 1, H, W], delta_pose_seq: [B, T, 6]
        """Run the model over one sequence batch.

        Args:
            obs_seq: Frames, ``[B, T, 1, H, W]``.
            delta_pose_seq: Per-step actions fed to the RSSM, ``[B, T, 6]``.

        Returns:
            ``(pose_preds, img_preds, kl_loss)`` where ``pose_preds`` is
            ``[B, T, 6]``, ``img_preds`` is ``[B, T, 1, H, W]`` and ``kl_loss``
            is the scalar KL of the encoder's Gaussian against a standard normal.

        NOTE(review): ``delta_pose_seq[:, t]`` is the action that produces
        ``h[:, t]``, and ``pose_preds[:, t]`` is decoded from that same state.
        ``train_model`` regresses ``pose_preds`` onto ``delta_pose_seq``, so the
        target is visible to the predictor — confirm this is intended.
        """
        B, T, C, H, W = obs_seq.shape
        device = obs_seq.device
        h = torch.zeros(B, self.h_dim, device=device)
        z = torch.zeros(B, self.z_dim, device=device)

        embeddings, mus, logvars = [], [], []
        for t in range(T):
            z_t, mu_t, logvar_t = self.encoder(obs_seq[:, t])
            embeddings.append(z_t)
            mus.append(mu_t)
            logvars.append(logvar_t)
        embeddings = torch.stack(embeddings, dim=1)  # [B, T, embed_dim]
        mus = torch.stack(mus, dim=1)
        logvars = torch.stack(logvars, dim=1)

        rssm_out = self.rssm(z, h, delta_pose_seq, embeddings=embeddings)

        pose_preds = self.pose_decoder(rssm_out['h'])  # [B, T, 6]
        # FrameDecoder always emits 256x256 images, so this reshape only
        # succeeds when H == W == 256.
        img_preds = self.frame_decoder(rssm_out['h'].reshape(-1, self.h_dim)).reshape(B, T, 1, H, W)

        # KL of the encoder Gaussian vs N(0, I). The RSSM prior/posterior
        # statistics in rssm_out are not part of this loss.
        kl_loss = -0.5 * torch.sum(1 + logvars - mus.pow(2) - logvars.exp(), dim=-1).mean()
        return pose_preds, img_preds, kl_loss

# === Train Model ===
def _weighted_rssm_loss(pose_preds, img_preds, kl_loss, delta_poses_seq, obs_seq):
    """Combine pose, reconstruction and KL terms with one fixed set of weights."""
    loss_pose = F.smooth_l1_loss(pose_preds, delta_poses_seq)
    loss_recon = F.mse_loss(img_preds, obs_seq)
    return 2.0 * loss_pose + 0.5 * loss_recon + 0.1 * kl_loss


def _relative_band_mask(pred, true, epsilon=1e-6):
    """Mark predictions lying within 50%..150% of the ground truth, element-wise.

    The band is ordered by value, so it is also valid for negative ground
    truth. Pairs where both values are (near) zero count as correct too.
    """
    bound_a, bound_b = 0.5 * true, 1.5 * true
    within_band = (pred >= np.minimum(bound_a, bound_b)) & (pred <= np.maximum(bound_a, bound_b))
    both_zero = (np.abs(true) < epsilon) & (np.abs(pred) < epsilon)
    return within_band | both_zero


def train_model(csv_path, train_dirs, test_dirs, image_size=(256, 256), batch_size=32, epochs=20, lr=1e-4,
                model_path="rssm_model.pth"):
    """Train the RSSM delta-pose model and evaluate it after every epoch.

    After each epoch the test predictions and ground truth are written to
    ``new_test_pred_true_epoch_<n>.csv`` and per-component accuracy is printed.

    Args:
        csv_path: Combined pose CSV read by ``SequenceUltrasoundDataset``.
        train_dirs: Folder names used for training.
        test_dirs: Folder names used for evaluation.
        image_size: Frame size handed to the dataset.
        batch_size: Batch size of both data loaders.
        epochs: Number of passes over the training set.
        lr: Adam learning rate.
        model_path: Where the trained ``state_dict`` is saved once training
            finishes; pass ``None`` to skip saving. The default matches the
            path ``predict_action`` loads from.

    Raises:
        ValueError: If either split yields no sequences.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sequence_length = 10

    train_dataset = SequenceUltrasoundDataset(csv_path, train_dirs, sequence_length=sequence_length,
                                              image_size=image_size)
    test_dataset = SequenceUltrasoundDataset(csv_path, test_dirs, sequence_length=sequence_length,
                                             image_size=image_size)
    # Fail early with a readable message instead of a ZeroDivisionError or
    # np.concatenate error further down.
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        raise ValueError(
            f"Empty dataset: {len(train_dataset)} training and {len(test_dataset)} test sequences "
            f"found in {csv_path}"
        )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = RSSMGoalDeltaPoseModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    component_names = ['tx', 'ty', 'tz', 'rx', 'ry', 'rz']

    for epoch in range(epochs):
        # ==== Training ====
        model.train()
        total_train_loss = 0.0
        for obs_seq, delta_poses_seq in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            obs_seq = obs_seq.to(device)
            delta_poses_seq = delta_poses_seq.to(device)

            optimizer.zero_grad()
            pose_preds, img_preds, kl_loss = model(obs_seq, delta_poses_seq)
            loss = _weighted_rssm_loss(pose_preds, img_preds, kl_loss, delta_poses_seq, obs_seq)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # ==== Testing ====
        model.eval()
        total_test_loss = 0.0
        all_test_pred = []
        all_test_true = []

        with torch.no_grad():
            for obs_seq, delta_poses_seq in tqdm(test_loader, desc=f"Epoch {epoch+1} Testing"):
                obs_seq = obs_seq.to(device)
                delta_poses_seq = delta_poses_seq.to(device)

                pose_preds, img_preds, kl_loss = model(obs_seq, delta_poses_seq)
                # Same weighting as in training so both losses are comparable.
                loss = _weighted_rssm_loss(pose_preds, img_preds, kl_loss, delta_poses_seq, obs_seq)
                total_test_loss += loss.item()

                all_test_pred.append(pose_preds.cpu().numpy())
                all_test_true.append(delta_poses_seq.cpu().numpy())

        avg_test_loss = total_test_loss / len(test_loader)

        # ==== Save test predictions and ground truth as CSV ====
        # [num_sequences, T, 6] -> [N, 6]
        pred = np.concatenate(all_test_pred, axis=0).reshape(-1, 6)
        true = np.concatenate(all_test_true, axis=0).reshape(-1, 6)

        result_columns = {f'pred_{name}': pred[:, i] for i, name in enumerate(component_names)}
        result_columns.update({f'true_{name}': true[:, i] for i, name in enumerate(component_names)})
        df_test_results = pd.DataFrame(result_columns)

        # === Accuracy Calculation for each of the 6 components ===
        correct_mask = _relative_band_mask(pred, true)  # shape [N, 6]
        component_accuracies = {
            name: correct_mask[:, i].mean()
            for i, name in enumerate(component_names)
        }

        csv_filename = f'new_test_pred_true_epoch_{epoch+1}.csv'
        df_test_results.to_csv(csv_filename, index=False)

        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
        print("Per-component Accuracy (within 50%~150% of GT):")
        for name in component_names:
            print(f"  {name}: {component_accuracies[name]:.4f}")
        print(f"Test predictions and true values saved to {csv_filename}")

    if model_path:
        torch.save(model.state_dict(), model_path)
        print(f"Model weights saved to {model_path}")

In [None]:
# === Run ===
# Train and evaluate on the folder split configured at the top of the notebook.
# csv_path is the file written by the combine_pose_csvs_with_foldername call above;
# batch_size=64 overrides train_model's default of 32.
if __name__ == "__main__":
    train_model(
        csv_path="poses_combined.csv",
        train_dirs=train_dirs,
        test_dirs=test_dirs,
        epochs=20,
        batch_size=64,
        lr=1e-4,
        image_size=(256, 256)
    )

In [None]:
# === Predict ===
def predict_action(current_img, goal_img, model_path="rssm_model.pth"):
    """Load saved weights and return a 6-value delta-pose prediction.

    The two images are stacked into a two-step observation sequence
    ``[current, goal]`` and run through the model with all-zero actions; the
    pose output of the last step is returned.

    Args:
        current_img: Current frame, in a form the torchvision transforms
            below accept (presumably a PIL image — confirm against callers).
        goal_img: Goal frame, same form as ``current_img``.
        model_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        NumPy array of shape ``(6,)``.

    NOTE(review): the pose head reads only the RSSM hidden state, and a
    frame's embedding reaches the hidden state one step later. With a
    two-frame sequence the returned prediction therefore depends on
    ``current_img`` but not on ``goal_img``. Goal conditioning needs a model
    change; confirm the intended inference procedure.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RSSMGoalDeltaPoseModel().to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    current_frame = transform(current_img).to(device)  # (1, H, W)
    goal_frame = transform(goal_img).to(device)

    # The model expects [B, T, 1, H, W] frames and [B, T, 6] actions.
    obs_seq = torch.stack([current_frame, goal_frame], dim=0).unsqueeze(0)  # (1, 2, 1, H, W)
    # 6 matches the default action_dim of RSSMGoalDeltaPoseModel.
    actions = torch.zeros(1, obs_seq.size(1), 6, device=device)

    with torch.no_grad():
        # forward returns (pose_preds, img_preds, kl_loss)
        pose_preds, _, _ = model(obs_seq, actions)

    return pose_preds[0, -1].cpu().numpy()