# %% [markdown]

 # DCL-Net Training for Sensorless Freehand 3D Ultrasound Reconstruction

 This notebook implements the training process for DCL-Net as described in the paper. We'll go through the following steps:
 1. Import required libraries
 2. Read and preprocess data
 3. Define the DCL-Net model
 4. Implement the loss function
 5. Set up the training loop
 6. Train the model
 7. Evaluate the model
 8. Reconstruct the 3D volume

# %% [markdown]

 ## 1. Import required libraries

In [None]:
# %%

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import cv2
import SimpleITK as sitk
from tqdm import tqdm
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'cv2'

# %% [markdown]

 # DCL-Net Training for Sensorless Freehand 3D Ultrasound Reconstruction

 This notebook implements the training process for DCL-Net as described in the paper. We'll go through the following steps:
 1. Import required libraries
 2. Read and preprocess data
 3. Define the DCL-Net model
 4. Implement the loss function
 5. Set up the training loop
 6. Train the model
 7. Evaluate the model
 8. Reconstruct the 3D volume

# %% [markdown]

 ## 1. Import required libraries

In [None]:
# %%

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import cv2
import SimpleITK as sitk
from tqdm import tqdm
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'cv2'

# %% [markdown]

 ## 2. Read and preprocess data

In [None]:
# %%

class USDataset(Dataset):
    def __init__(self, video_path, imu_csv_path, tracker_csv_path, num_frames=5):
        self.num_frames = num_frames

        # Read video
        self.cap = cv2.VideoCapture(video_path)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Read IMU data
        self.imu_data = pd.read_csv(imu_csv_path)

        # Read tracker data
        self.tracker_data = pd.read_csv(tracker_csv_path)

        # Synchronize data (assuming timestamps are available)
        self.sync_data()

    def sync_data(self):
        # Implement synchronization logic here
        # This should align video frames with IMU and tracker data
        pass

    def __len__(self):
        return self.total_frames - self.num_frames + 1

    def __getitem__(self, idx):
        frames = []
        for i in range(self.num_frames):
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx + i)
            ret, frame = self.cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                frame = cv2.resize(frame, (224, 224))
                frames.append(frame)

        frames = np.stack(frames, axis=0)

        # Get corresponding tracker data
        tracker_data = self.tracker_data.iloc[idx : idx + self.num_frames]

        # Calculate mean transformation parameters
        mean_params = self.calculate_mean_params(tracker_data)

        return (
            torch.from_numpy(frames).float().unsqueeze(0),
            torch.from_numpy(mean_params).float(),
        )

    def calculate_mean_params(self, tracker_data):
        # Implement logic to calculate mean transformation parameters
        # This should return a 6-element array (tx, ty, tz, rx, ry, rz)
        pass

: 

# %% [markdown]

 ## 3. Define the DCL-Net model

In [None]:
# %%

class DCLNet(nn.Module):
    def __init__(self, num_frames=5, cardinality=32):
        super(DCLNet, self).__init__()
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)
        )

        # Add ResNeXt blocks here

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(2048, 6)  # Output 6 parameters (tx, ty, tz, rx, ry, rz)

        # Add attention module
        self.attention = nn.Sequential(
            nn.BatchNorm3d(2048),
            nn.Conv3d(2048, 1024, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv3d(1024, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm3d(1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Add ResNeXt block operations here

        attention = self.attention(x)
        x = x * attention

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

: 

# %% [markdown]

 ## 4. Implement the loss function

In [None]:
# %%

class MSECorrelationLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(MSECorrelationLoss, self).__init__()
        self.alpha = alpha
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mse_loss = self.mse(pred, target)

        # Calculate correlation loss
        pred_mean = pred.mean(dim=0, keepdim=True)
        target_mean = target.mean(dim=0, keepdim=True)
        pred_std = pred.std(dim=0, unbiased=False)
        target_std = target.std(dim=0, unbiased=False)

        correlation = ((pred - pred_mean) * (target - target_mean)).mean(dim=0) / (
            pred_std * target_std
        )
        correlation_loss = 1 - correlation.mean()

        return self.alpha * mse_loss + (1 - self.alpha) * correlation_loss

: 

# %% [markdown]

 ## 5. Set up the training loop

In [None]:
# %%

def train(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for frames, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            frames, labels = frames.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(frames)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for frames, labels in val_loader:
                frames, labels = frames.to(device), labels.to(device)
                outputs = model(frames)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(
            f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
        )

    return model

: 

# %% [markdown]

 ## 6. Train the model

In [None]:
# %%

# Set up dataset and data loaders
dataset = USDataset(
    "path/to/video.mp4", "path/to/imu_data.csv", "path/to/tracker_data.csv"
)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# Initialize model, loss, and optimizer
model = DCLNet()
criterion = MSECorrelationLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 100
trained_model = train(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, device
)

# Save the trained model
torch.save(trained_model.state_dict(), "dcl_net_model.pth")

: 

# %% [markdown]

 ## 7. Evaluate the model

In [None]:
# %%

def evaluate(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for frames, labels in tqdm(test_loader, desc="Evaluating"):
            frames, labels = frames.to(device), labels.to(device)
            outputs = model(frames)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            all_predictions.append(outputs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    test_loss /= len(test_loader)
    all_predictions = np.concatenate(all_predictions)
    all_labels = np.concatenate(all_labels)

    # Calculate metrics
    mean_error = np.mean(np.abs(all_predictions - all_labels), axis=0)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Mean Absolute Error: {mean_error}")

    return all_predictions, all_labels


# Set up test dataset and loader
test_dataset = USDataset(
    "path/to/test_video.mp4",
    "path/to/test_imu_data.csv",
    "path/to/test_tracker_data.csv",
)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Load the trained model
model = DCLNet()
model.load_state_dict(torch.load("dcl_net_model.pth"))
model.to(device)

# Evaluate the model
criterion = MSECorrelationLoss()
predictions, labels = evaluate(model, test_loader, criterion, device)

: 

# %% [markdown]

 ## 8. Reconstruct the 3D volume

In [None]:
# %%

def reconstruct_volume(
    model, video_path, imu_csv_path, tracker_csv_path, output_path, device
):
    # Set up dataset for the entire video
    dataset = USDataset(video_path, imu_csv_path, tracker_csv_path)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    model.eval()
    transformations = []

    with torch.no_grad():
        for frames, _ in tqdm(data_loader, desc="Predicting transformations"):
            frames = frames.to(device)
            output = model(frames)
            transformations.append(output.cpu().numpy()[0])

    transformations = np.array(transformations)

    # Reconstruct 3D volume using predicted transformations
    volume = reconstruct_3d_volume(dataset, transformations)

    # Save the reconstructed volume
    sitk_volume = sitk.GetImageFromArray(volume)
    sitk.WriteImage(sitk_volume, output_path)

    return volume


def reconstruct_3d_volume(dataset, transformations):
    # Initialize an empty volume
    volume = np.zeros((256, 256, len(transformations)), dtype=np.uint8)

    # Set up initial transformation matrix
    current_transform = np.eye(4)

    for i, transformation in enumerate(transformations):
        # Update current transformation
        delta_transform = transformation_vector_to_matrix(transformation)
        current_transform = np.dot(current_transform, delta_transform)

        # Get the current frame
        frame = dataset[i][0].numpy()[0, 0]

        # Apply transformation to the frame and insert it into the volume
        transformed_frame = apply_transformation(frame, current_transform)
        volume[:, :, i] = transformed_frame

    return volume


def transformation_vector_to_matrix(vector):
    # Convert 6-element vector (tx, ty, tz, rx, ry, rz) to 4x4 transformation matrix
    tx, ty, tz, rx, ry, rz = vector

    # Create rotation matrix
    Rx = np.array(
        [[1, 0, 0], [0, np.cos(rx), -np.sin(rx)], [0, np.sin(rx), np.cos(rx)]]
    )
    Ry = np.array(
        [[np.cos(ry), 0, np.sin(ry)], [0, 1, 0], [-np.sin(ry), 0, np.cos(ry)]]
    )
    Rz = np.array(
        [[np.cos(rz), -np.sin(rz), 0], [np.sin(rz), np.cos(rz), 0], [0, 0, 1]]
    )
    R = np.dot(Rz, np.dot(Ry, Rx))

    # Create translation vector
    T = np.array([tx, ty, tz])

    # Combine rotation and translation into 4x4 matrix
    M = np.eye(4)
    M[:3, :3] = R
    M[:3, 3] = T

    return M


def apply_transformation(frame, transform):
    # Apply 3D transformation to a 2D frame
    height, width = frame.shape
    y, x = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    homogeneous_coords = np.stack(
        [
            x.flatten(),
            y.flatten(),
            np.zeros_like(x.flatten()),
            np.ones_like(x.flatten()),
        ],
        axis=-1,
    )

    transformed_coords = np.dot(transform, homogeneous_coords.T).T
    transformed_coords = transformed_coords[:, :2] / transformed_coords[:, 3:]

    transformed_frame = cv2.remap(
        frame,
        transformed_coords[:, 0].reshape(height, width).astype(np.float32),
        transformed_coords[:, 1].reshape(height, width).astype(np.float32),
        cv2.INTER_LINEAR,
    )

    return transformed_frame


# Reconstruct 3D volume
reconstructed_volume = reconstruct_volume(
    model,
    "path/to/test_video.mp4",
    "path/to/test_imu_data.csv",
    "path/to/test_tracker_data.csv",
    "reconstructed_volume.nii.gz",
    device,
)

: 

# %% [markdown]

 ## 9. Visualize results

In [None]:
# %%

def visualize_results(predictions, labels, reconstructed_volume):
    # Plot predicted vs. ground truth transformation parameters
    fig, axs = plt.subplots(2, 3, figsize=(15, 10))
    param_names = ["tx", "ty", "tz", "rx", "ry", "rz"]

    for i, ax in enumerate(axs.flat):
        ax.scatter(labels[:, i], predictions[:, i], alpha=0.5)
        ax.plot(
            [labels[:, i].min(), labels[:, i].max()],
            [labels[:, i].min(), labels[:, i].max()],
            "r--",
            lw=2,
        )
        ax.set_xlabel(f"Ground Truth {param_names[i]}")
        ax.set_ylabel(f"Predicted {param_names[i]}")
        ax.set_title(f"{param_names[i]} Prediction")

    plt.tight_layout()
    plt.show()

    # Visualize reconstructed volume
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].imshow(
        reconstructed_volume[:, :, reconstructed_volume.shape[2] // 2], cmap="gray"
    )
    axs[0].set_title("Axial View")
    axs[0].axis("off")

    axs[1].imshow(
        reconstructed_volume[:, reconstructed_volume.shape[1] // 2, :], cmap="gray"
    )
    axs[1].set_title("Coronal View")
    axs[1].axis("off")

    axs[2].imshow(
        reconstructed_volume[reconstructed_volume.shape[0] // 2, :, :], cmap="gray"
    )
    axs[2].set_title("Sagittal View")
    axs[2].axis("off")

    plt.tight_layout()
    plt.show()


# Visualize results
visualize_results(predictions, labels, reconstructed_volume)

: 

# %% [markdown]

 ## 10. Conclusion and Future Work

# %% [markdown]

 In this notebook, we have implemented the DCL-Net for sensorless freehand 3D ultrasound reconstruction. We have covered the following steps:

 1. Data preprocessing and loading
 2. Model architecture implementation
 3. Custom loss function (MSE + Correlation Loss)
 4. Training and evaluation
 5. 3D volume reconstruction
 6. Results visualization

 The DCL-Net shows promising results in estimating the transformation parameters between consecutive ultrasound frames. However, there are several areas for potential improvement and future work:

 1. Fine-tuning hyperparameters: Experiment with different learning rates, batch sizes, and model architectures to optimize performance.
 2. Data augmentation: Implement more advanced data augmentation techniques to improve model generalization.
 3. Attention mechanism: Further refine the attention module to focus on the most informative regions of the ultrasound frames.
 4. Multi-task learning: Explore the possibility of jointly predicting transformation parameters and segmenting anatomical structures.
 5. Real-time reconstruction: Optimize the model and reconstruction pipeline for real-time 3D volume generation during freehand ultrasound scanning.
 6. Clinical validation: Conduct extensive clinical validation studies to assess the accuracy and reliability of the reconstructed 3D volumes in various anatomical regions and scanning scenarios.

# %% [markdown]

 ## 11. Save and Export Results

In [None]:
# %%

# Save predictions and labels
np.save("predictions.npy", predictions)
np.save("labels.npy", labels)

# Export visualizations
plt.savefig("transformation_predictions.png")
plt.savefig("reconstructed_volume_views.png")

print("Results saved and exported successfully.")

: 