# %% [markdown]

 # 3D Ultrasound Volume Reconstruction with DCL-Net

In [None]:
# %%

import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
from natsort import natsorted
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import Dataset, DataLoader
from torchvision.models import efficientnet_b1
from scipy.ndimage import zoom

# Import custom modules
from freehand.network import build_model
from freehand.transform import (
    LabelTransform,
    PredictionTransform,
    TransformAccumulation,
)
from freehand.utils import pair_samples, reference_image_points, type_dim
from freehand.loss import PointDistance

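# NOTE: a saved run of this cell failed with `ModuleNotFoundError: No module named 'torchvision'` — install torchvision (a version matching the installed torch) before running this notebook.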

# %% [markdown]

 ## Data Loading and Preprocessing

In [None]:
# %%


class UltrasoundDataset(Dataset):
    """Sliding-window dataset over one tracked ultrasound scan.

    Loads ``<data_dir>/<scan_name>/<scan_name>_frames.npy`` (image frames) and
    ``<data_dir>/<scan_name>/<scan_name>_pos.txt`` (tracker rows, presumably
    one per frame — TODO confirm the column layout). Each item is a window of
    ``sample_range`` consecutive frames plus the matching tracker rows.

    Args:
        data_dir: Directory containing one sub-directory per scan. A trailing
            path separator is optional.
        scan_name: Name of the scan sub-directory and the file prefix.
        num_samples: Stored for callers (e.g. reconstruction code); it is not
            used to slice windows here.
        sample_range: Number of consecutive frames per window.
    """

    def __init__(self, data_dir, scan_name, num_samples=10, sample_range=10):
        self.data_dir = data_dir
        self.scan_name = scan_name
        self.num_samples = num_samples
        self.sample_range = sample_range

        # os.path.join works whether or not data_dir ends with a separator.
        scan_dir = os.path.join(data_dir, scan_name)
        self.frames = np.load(os.path.join(scan_dir, f"{scan_name}_frames.npy"))
        # ndmin=2 keeps a single-row tracker file 2-D, so the row/column
        # slicing in __getitem__ behaves the same as for longer files.
        self.tracker_data = np.loadtxt(
            os.path.join(scan_dir, f"{scan_name}_pos.txt"), ndmin=2
        )

        # Bring frames to (C, N, H, W): channel axis first, then frame index.
        if self.frames.ndim == 3:
            # Assumed (N, H, W) grayscale -> add a singleton channel axis.
            self.frames = np.expand_dims(self.frames, axis=0)
        elif self.frames.ndim == 4 and self.frames.shape[-1] in [1, 3]:
            # Assumed (N, H, W, C) channels-last -> move channels to the front.
            self.frames = np.moveaxis(self.frames, -1, 0)
        # NOTE(review): any other 4-D layout is assumed to already be (C, N, H, W).

        # Divide by 255 (assumes 8-bit source intensities) to get float32 in [0, 1].
        self.frames = self.frames.astype(np.float32) / 255.0

    def __len__(self):
        """Number of full windows covered by both the frames and the tracker rows."""
        num_frames = min(self.frames.shape[1], len(self.tracker_data))
        # Clamp at 0: a scan shorter than one window yields an empty dataset
        # rather than a negative length (which DataLoader/samplers reject).
        return max(0, num_frames - self.sample_range + 1)

    def __getitem__(self, idx):
        """Return ``(frames, tracker_data)`` for the window starting at ``idx``.

        ``frames`` has shape (C, sample_range, H, W). ``tracker_data`` holds the
        matching ``sample_range`` tracker rows with the first two columns dropped.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also lets plain ``for``/``list()`` iteration stop.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"window index out of range for {length} windows")

        end = idx + self.sample_range
        frames = torch.tensor(self.frames[:, idx:end], dtype=torch.float32)
        tracker_data = torch.tensor(
            self.tracker_data[idx:end, 2:], dtype=torch.float32
        )

        return frames, tracker_data

: 

In [None]:
# %%

DATA_DIR = "data/"
SCAN_NAME = "phantom"
BATCH_SIZE = 32
# Load in the notebook process: under the "spawn" start method (the default on
# Windows and macOS) worker processes cannot import a Dataset class defined
# interactively in a notebook. Raise this when running from a .py module.
NUM_WORKERS = 0
# Seeded generator so the shuffle order is reproducible on Restart & Run All.
LOADER_SEED = 0

dataset = UltrasoundDataset(DATA_DIR, SCAN_NAME)
# Fail early with a clear message; an empty dataset would otherwise surface as
# an opaque sampler error when shuffle=True.
assert len(dataset) > 0, f"no complete windows found for scan {SCAN_NAME!r}"

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

: 

# %% [markdown]

 ## Model Definition and Training

In [None]:
# %%

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model parameters
NUM_SAMPLES = 10  # frames per input window; must equal the dataset's sample_range
NUM_PRED = 9  # passed to pair_samples together with NUM_SAMPLES
PRED_TYPE = "parameter"
LABEL_TYPE = "point"  # NOTE(review): not referenced by any visible cell

# Training parameters
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100

# The model is built with in_frames=NUM_SAMPLES, so the dataset created in an
# earlier cell must yield windows of exactly that many frames.
assert dataset.sample_range == NUM_SAMPLES, (
    f"dataset windows have {dataset.sample_range} frames, "
    f"but the model expects NUM_SAMPLES={NUM_SAMPLES}"
)

# Create model
# UltrasoundDataset stores frames as (C, N, H, W) (for 3-D or channels-last
# input), so shape[2:] is the per-frame (H, W) size.
frame_size = dataset.frames.shape[2:]
# NOTE(review): meaning of the second argument (2) is defined by
# freehand.utils.reference_image_points — TODO document.
image_points = reference_image_points(frame_size, 2).to(device)
data_pairs = pair_samples(NUM_SAMPLES, NUM_PRED).to(device)
pred_dim = type_dim(PRED_TYPE, image_points.shape[1], data_pairs.shape[0])

# train_model reshapes the tracker differences to the model's output shape, so
# the output must have as many elements per example as
# (NUM_SAMPLES - 1) * (tracker columns - 2).
model = build_model(efficientnet_b1, in_frames=NUM_SAMPLES, out_dim=pred_dim).to(device)

# Define loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

: 

In [None]:
# %%


def train_model(model, dataloader, criterion, optimizer, num_epochs, device):
    """Train ``model`` to regress differences between consecutive tracker rows.

    For each batch, the target is ``tracker_data[:, 1:] - tracker_data[:, :-1]``,
    reshaped to the shape of the model's output.

    Args:
        model: Network mapping a batch of frame windows to predictions.
        dataloader: Iterable of ``(frames, tracker_data)`` batches. It does not
            need to define ``len()``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device each batch is moved to.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for frames, tracker_data in dataloader:
            frames, tracker_data = frames.to(device), tracker_data.to(device)

            optimizer.zero_grad()
            outputs = model(frames)

            # Target: element-wise difference between consecutive tracker rows.
            # NOTE(review): this is a raw difference of pose vectors, not a
            # composed relative transform; if some tracker columns encode a
            # rotation it is only an approximation — confirm intended.
            relative_transformations = tracker_data[:, 1:] - tracker_data[:, :-1]
            # reshape only requires equal element counts; it does not check that
            # each target element lines up with the meaning of each output.
            relative_transformations = relative_transformations.reshape(outputs.shape)

            loss = criterion(outputs, relative_transformations)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over batches actually seen: nan instead of ZeroDivisionError
        # when an epoch produced no batches, and no len() requirement.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

    return model

: 

In [None]:
# %%

# Train in place; train_model hands back the same model object it was given.
trained_model = train_model(
    model,
    dataloader,
    criterion,
    optimizer,
    num_epochs=NUM_EPOCHS,
    device=device,
)

: 

# %% [markdown]

 ## 3D Volume Reconstruction and Visualization

In [None]:
# %%


def reconstruct_volume(model, dataset, device, num_frames=100):
    """Place frames in space by accumulating the model's per-window predictions.

    Windows start every ``max(1, dataset.num_samples - 1)`` frames, while the
    start index is below ``min(len(dataset), num_frames - dataset.num_samples + 1)``.
    For each window, the first 3 model outputs are added to a running position
    and outputs ``3:`` are added to a running orientation. Both start at zero.

    Args:
        model: Network mapping a ``(1, C, T, H, W)`` window batch to predictions.
        dataset: Indexable returning ``(frames, _)`` with ``frames`` of shape
            ``(C, T, H, W)``; must expose ``num_samples`` and ``len()``.
        device: Device the model runs on.
        num_frames: Caps how far into the scan windows are started.

    Returns:
        List of ``(middle_frame, position, orientation)`` numpy tuples, where
        ``middle_frame`` is the middle frame of the window with all channels,
        ``position`` has shape (1, 3) and ``orientation`` has shape (1, 4).
        Each entry is an independent snapshot of the running pose.
    """
    model.eval()
    volume = []
    current_position = torch.zeros(1, 3, device=device)
    # NOTE(review): intended as a quaternion, but it starts at all zeros (not
    # the identity quaternion) and is updated by addition rather than
    # quaternion multiplication — confirm this is the intended approximation.
    current_orientation = torch.zeros(1, 4, device=device)

    # A step of 0 (num_samples == 1) would make range() raise ValueError.
    step = max(1, dataset.num_samples - 1)
    stop = min(len(dataset), num_frames - dataset.num_samples + 1)

    with torch.no_grad():
        for start in range(0, stop, step):
            frames, _ = dataset[start]
            frames = frames.unsqueeze(0).to(device)
            predicted_params = model(frames)

            # Accumulate transformations.
            current_position += predicted_params[0, :3]
            # NOTE(review): broadcasts only if the model emits exactly 4 (or 1)
            # values after the first 3 — check against the model's out_dim.
            current_orientation += predicted_params[0, 3:]

            # Axis 2 is the frame axis after unsqueeze; use the window length
            # actually returned rather than assuming it equals num_samples.
            middle_frame = frames[0, :, frames.shape[2] // 2].cpu().numpy()
            volume.append(
                (
                    middle_frame,
                    # .copy(): on CPU, .cpu().numpy() shares memory with the
                    # accumulator, so without it every stored pose would show
                    # the final accumulated value.
                    current_position.cpu().numpy().copy(),
                    current_orientation.cpu().numpy().copy(),
                )
            )

    return volume

: 

In [None]:
# %%


def visualize_3d_volume(volume):
    """Draw each reconstructed frame as a textured plane in a 3-D plot.

    Args:
        volume: Sequence of ``(frame, position, orientation)`` tuples, such as
            ``reconstruct_volume`` returns. ``frame`` is (C, H, W), of which the
            first channel is drawn, or a 2-D (H, W) image. Values are expected
            in [0, 1] for the gray colormap. ``position`` is indexed ``[0, k]``.

    Only the translation is applied; ``orientation`` is ignored (no rotation
    matrices). Each plane is also offset along Z by its index in ``volume``.
    """
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection="3d")

    for i, (frame, position, _orientation) in enumerate(volume):
        # Draw a single (H, W) grayscale image. The surface grid must match the
        # image shape; building it from frame.shape[0:2] of a (C, H, W) frame
        # would give a (C, H) grid that does not match the facecolors array.
        image = frame[0] if frame.ndim == 3 else frame
        height, width = image.shape
        X, Y = np.meshgrid(np.arange(width), np.arange(height))
        Z = np.full_like(X, i)

        # Translation only (simplified — no rotation from the orientation).
        X_trans = X + position[0, 0]
        Y_trans = Y + position[0, 1]
        Z_trans = Z + position[0, 2]

        ax.plot_surface(
            X_trans, Y_trans, Z_trans, facecolors=plt.cm.gray(image), shade=False
        )

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("Reconstructed 3D Ultrasound Volume")

    plt.show()

: 

In [None]:
# %%

# Accumulate predicted poses over the start of the scan (default num_frames),
# then render the frames at their accumulated positions.
reconstructed_volume = reconstruct_volume(
    trained_model,
    dataset,
    device=device,
)
visualize_3d_volume(volume=reconstructed_volume)

: 