In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split, KFold
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchsummary import summary
from torchviz import make_dot
from torch.cuda.amp import GradScaler, autocast

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """3D residual unit: Conv3d -> BatchNorm3d -> ReLU -> Conv3d -> BatchNorm3d, plus a shortcut, then ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size shared by the two convolutions.
        stride: Stride of the first convolution only; the second always uses stride 1.
        padding: Padding shared by the two convolutions.
        downsample: Optional module applied to the input to build the shortcut
            branch. When ``None`` the unmodified input is added back, so the
            caller must make sure its shape already matches the main branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        # Main branch, first stage.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm3d(out_channels)
        # A single in-place ReLU module is reused for both activations.
        self.relu = nn.ReLU(inplace=True)
        # Main branch, second stage.
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=padding)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # Shortcut branch (may be None).
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # In-place add onto the freshly computed branch output; ``x`` itself is untouched.
        residual += shortcut
        return self.relu(residual)
    
# Utility Function
def crop_to_match(tensor, target_tensor):
    """
    Center-crop ``tensor`` so that its spatial dimensions (D, H, W) equal those of ``target_tensor``.

    Only axes 2, 3 and 4 are cropped; the first two axes are returned unchanged.
    When the size difference along an axis is odd, the extra element is removed
    from the trailing side.

    Args:
        tensor: 5D tensor to crop.
        target_tensor: 5D tensor whose sizes along axes 2-4 define the output size.

    Returns:
        A view of ``tensor`` whose sizes along axes 2-4 equal those of ``target_tensor``.

    Raises:
        ValueError: If ``tensor`` is smaller than ``target_tensor`` along any
            spatial axis (cropping cannot enlarge a tensor).
    """
    index = [slice(None), slice(None)]
    for axis in (2, 3, 4):
        target_size = target_tensor.size(axis)
        surplus = tensor.size(axis) - target_size
        if surplus < 0:
            # A negative surplus used to yield a 1-element slice that silently
            # broadcast in the caller's skip-connection addition.
            raise ValueError(
                f"Cannot crop axis {axis}: tensor size {tensor.size(axis)} "
                f"is smaller than target size {target_size}."
            )
        start = surplus // 2
        # Anchor the end to the target size so odd differences are handled exactly.
        index.append(slice(start, start + target_size))
    return tensor[tuple(index)]

# Cascade3DUNet Class
class Cascade3DUNet(nn.Module):
    """3D encoder/decoder network with residual encoder stages and additive skip connections.

    The input is lifted to 32 channels by a 1x1x1 convolution, passed through
    five ``ResidualBlock`` stages (32, 64, 128, 256, 512 channels) separated by
    2x max pooling, then upsampled four times with transposed convolutions.
    After each upsampling step the decoder features are cropped to the matching
    encoder output and summed with a 1x1x1 projection of it. A final 1x1x1
    convolution maps to ``num_classes`` channels.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Number of output channels.
    """

    def __init__(self, in_channels=2, num_classes=1):
        super().__init__()
        c1, c2, c3, c4, c5 = 32, 64, 128, 256, 512

        # Stem: bring the input to the channel count encoder1 expects.
        self.input_projection = nn.Conv3d(in_channels, c1, kernel_size=1)

        # Encoder stages; stages that change the channel count use a 1x1x1 conv shortcut.
        self.encoder1 = ResidualBlock(c1, c1)
        self.encoder2 = ResidualBlock(c1, c2, downsample=nn.Conv3d(c1, c2, kernel_size=1))
        self.encoder3 = ResidualBlock(c2, c3, downsample=nn.Conv3d(c2, c3, kernel_size=1))
        self.encoder4 = ResidualBlock(c3, c4, downsample=nn.Conv3d(c3, c4, kernel_size=1))
        self.encoder5 = ResidualBlock(c4, c5, downsample=nn.Conv3d(c4, c5, kernel_size=1))

        # Decoder stages: 2x transposed-conv upsampling followed by a 3x3x3 conv.
        self.decoder4 = nn.ConvTranspose3d(c5, c4, kernel_size=2, stride=2)
        self.decoder4_conv = nn.Conv3d(c4, c4, kernel_size=3, padding=1)
        self.decoder3 = nn.ConvTranspose3d(c4, c3, kernel_size=2, stride=2)
        self.decoder3_conv = nn.Conv3d(c3, c3, kernel_size=3, padding=1)
        self.decoder2 = nn.ConvTranspose3d(c3, c2, kernel_size=2, stride=2)
        self.decoder2_conv = nn.Conv3d(c2, c2, kernel_size=3, padding=1)
        self.decoder1 = nn.ConvTranspose3d(c2, c1, kernel_size=2, stride=2)
        self.decoder1_conv = nn.Conv3d(c1, c1, kernel_size=3, padding=1)

        # 1x1x1 projections applied to the encoder outputs before they are added back.
        self.proj4 = nn.Conv3d(c4, c4, kernel_size=1)
        self.proj3 = nn.Conv3d(c3, c3, kernel_size=1)
        self.proj2 = nn.Conv3d(c2, c2, kernel_size=1)
        self.proj1 = nn.Conv3d(c1, c1, kernel_size=1)

        # Output head.
        self.final_conv = nn.Conv3d(c1, num_classes, kernel_size=1)

    def _decode_step(self, features, upsample, refine, project, skip):
        """Upsample and refine ``features``, then add the projected encoder output ``skip``."""
        features = refine(upsample(features))
        return crop_to_match(features, skip) + project(skip)

    def forward(self, x):
        """Map a ``[B, in_channels, D, H, W]`` volume to ``[B, num_classes, D, H, W]``."""
        stem = self.input_projection(x)

        # Contracting path; each stage after the first sees a 2x max-pooled input.
        skip1 = self.encoder1(stem)                        # [B, 32, D, H, W]
        skip2 = self.encoder2(F.max_pool3d(skip1, 2))      # [B, 64, D/2, H/2, W/2]
        skip3 = self.encoder3(F.max_pool3d(skip2, 2))      # [B, 128, D/4, H/4, W/4]
        skip4 = self.encoder4(F.max_pool3d(skip3, 2))      # [B, 256, D/8, H/8, W/8]
        bottleneck = self.encoder5(F.max_pool3d(skip4, 2)) # [B, 512, D/16, H/16, W/16]

        # Expanding path, deepest level first.
        merged = self._decode_step(bottleneck, self.decoder4, self.decoder4_conv, self.proj4, skip4)
        merged = self._decode_step(merged, self.decoder3, self.decoder3_conv, self.proj3, skip3)
        merged = self._decode_step(merged, self.decoder2, self.decoder2_conv, self.proj2, skip2)
        merged = self._decode_step(merged, self.decoder1, self.decoder1_conv, self.proj1, skip1)

        return self.final_conv(merged)


In [None]:
# Step 5: Evaluation
# Denormalize Dose

def denormalize_dose(normalized_dose, min_dose, dynamic_max):
    """Undo a min-max scaling: return ``normalized_dose * (dynamic_max - min_dose) + min_dose``.

    Works on anything that supports ``*`` and ``+`` with the two bounds
    (scalars, NumPy arrays, tensors).
    """
    dose_range = dynamic_max - min_dose
    return normalized_dose * dose_range + min_dose

def _strip_leading_singleton_axes(volume):
    """Drop leading length-1 axes until ``volume`` is at most 3D, e.g. (1, D, H, W) -> (D, H, W)."""
    while volume.ndim > 3 and volume.shape[0] == 1:
        volume = volume[0]
    return volume


def predict_and_save_results(model, test_loader, save_path, dose_save_path, device="cuda", metadata=None):
    """Run inference over ``test_loader``, save denormalized predictions and report per-patient MAE.

    For the i-th batch the patient id is ``patient_{i}``. The prediction is
    denormalized with that patient's ``MinDose`` / ``DynamicMaxDose`` and saved
    as ``{patient_id}_predicted_dose.npy`` in ``dose_save_path``. The ground
    truth is loaded from ``{patient_id}_ground_truth.npy`` in ``save_path``,
    denormalized the same way and compared with the prediction. A comparison
    figure per patient and a ``metrics.csv`` file are written to ``save_path``.

    Args:
        model: Network to evaluate. It is moved to ``device`` and put in eval mode.
        test_loader: Iterable yielding 5D input tensors. Only sample 0 of each
            batch is visualized and only the batch axis is squeezed from the
            output, so a batch size of 1 is assumed.
            NOTE(review): channel 0 is plotted as CT and channel 1 as the
            structure mask — confirm against the dataset.
        save_path: Directory containing the ground-truth ``.npy`` files; figures
            and ``metrics.csv`` are written here.
        dose_save_path: Directory that receives the predicted-dose ``.npy`` files
            (created if missing).
        device: Device used for inference.
        metadata: Mapping ``patient_id -> {"MinDose": ..., "DynamicMaxDose": ...}``.
            Required despite the ``None`` default.

    Raises:
        ValueError: If ``metadata`` is ``None``.
    """
    if metadata is None:
        raise ValueError(
            "metadata is required: it must map each patient id to its "
            "'MinDose' and 'DynamicMaxDose'."
        )

    os.makedirs(dose_save_path, exist_ok=True)

    # Inputs are moved to `device` below, so the model has to live there too.
    model = model.to(device)
    model.eval()

    patient_ids = []
    mae_list = []

    with torch.no_grad():
        for i, inputs in enumerate(tqdm(test_loader, desc="Predicting")):
            ct = inputs[0, 0, :, :, :].cpu().numpy()
            structure = inputs[0, 1, :, :, :].cpu().numpy()
            inputs = inputs.to(device)

            # Predict
            outputs = model(inputs)
            prediction = outputs.squeeze(0).cpu().numpy()

            # Denormalize predictions
            patient_id = f"patient_{i}"
            min_dose = metadata[patient_id]["MinDose"]
            dynamic_max = metadata[patient_id]["DynamicMaxDose"]
            pred_denorm = denormalize_dose(prediction, min_dose, dynamic_max)

            # Save the predicted dose as .npy
            np.save(os.path.join(dose_save_path, f'{patient_id}_predicted_dose.npy'), pred_denorm)

            # Load the ground truth and bring it to the same scale as the prediction.
            # NOTE(review): this assumes the stored ground truth is normalized — confirm.
            ground_truth = np.load(os.path.join(save_path, f"{patient_id}_ground_truth.npy"))
            ground_truth_denorm = denormalize_dose(ground_truth, min_dose, dynamic_max)
            mae = np.mean(np.abs(pred_denorm - ground_truth_denorm))
            patient_ids.append(patient_id)
            mae_list.append(mae)

            save_path_comparison = os.path.join(save_path, f"{patient_id}_comparison.png")
            # save_path must be passed by keyword: the fifth positional parameter of
            # visualize_predictions is slice_indices. Both dose volumes are plotted
            # denormalized, without the leading channel axis, so that slice indexing
            # runs over depth.
            visualize_predictions(
                ct,
                structure,
                _strip_leading_singleton_axes(pred_denorm),
                _strip_leading_singleton_axes(ground_truth_denorm),
                save_path=save_path_comparison,
            )

    # An empty loader yields NaN statistics instead of a NumPy empty-slice warning.
    mae_mean = np.mean(mae_list) if mae_list else float("nan")
    mae_std = np.std(mae_list) if mae_list else float("nan")

    # Save results to CSV
    results_df = pd.DataFrame({"Patient": patient_ids, "MAE": mae_list})
    results_df.to_csv(os.path.join(save_path, "metrics.csv"), index=False)
    print(f"Metrics saved to {os.path.join(save_path, 'metrics.csv')}")
    print(f"Mean MAE: {mae_mean:.4f} Gy, Std Dev: {mae_std:.4f} Gy")

# Visualize Predictions
def visualize_predictions(ct, structure, prediction, ground_truth=None, slice_indices=None, save_path=None):
    """Plot selected slices of the CT, structure mask, predicted dose and (optionally) ground-truth dose.

    One row is drawn per slice index; each volume is indexed as ``volume[slice_idx, :, :]``.

    Args:
        ct: 3D array shown with a gray colormap.
        structure: 3D array shown with a gray colormap.
        prediction: 3D array shown with the 'hot' colormap.
        ground_truth: Optional 3D array shown with the 'hot' colormap in a fourth column.
        slice_indices: Indices along axis 0 to plot. Defaults to 16, 32 and 48,
            restricted to indices that exist in every volume; if none of them
            fits, the middle slice is used instead.
        save_path: If given, the figure is written to this path (the parent
            directory is created when needed) and closed; otherwise it is shown.
    """
    panels = [
        (ct, 'CT', 'gray'),
        (structure, 'Structure', 'gray'),
        (prediction, 'Predicted Dose', 'hot'),
    ]
    if ground_truth is not None:
        panels.append((ground_truth, 'Ground Truth Dose', 'hot'))

    if slice_indices is None:
        # Keep the historical defaults where they are valid, but never index
        # past the end of a shallow volume.
        depth = min(volume.shape[0] for volume, _, _ in panels)
        slice_indices = [idx for idx in (16, 32, 48) if idx < depth]
        if not slice_indices:
            slice_indices = [depth // 2]

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        len(slice_indices),
        len(panels),
        figsize=(15, 5 * len(slice_indices)),
        squeeze=False,
    )

    for row, slice_idx in enumerate(slice_indices):
        for col, (volume, label, cmap) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(volume[slice_idx, :, :], cmap=cmap)
            ax.set_title(f'{label} (Slice {slice_idx})')
            ax.axis('off')

    fig.tight_layout()
    if save_path:
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path)
        plt.close(fig)
        print(f"Saved visualization to {save_path}")
    else:
        plt.show()

In [None]:
# ----------------------
# Test Model
# ----------------------
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

def denormalize_dose(normalized_dose, min_dose, max_dose):
    """
    Undo a min-max scaling of dose values.

    Note: a helper with the same name is defined in an earlier cell of this
    notebook; this later definition is the one in effect once this cell has run.

    Args:
        normalized_dose: Scaled dose values (tensor, array or scalar).
        min_dose (float): Lower bound used for the scaling.
        max_dose (float): Upper bound used for the scaling.

    Returns:
        ``normalized_dose * (max_dose - min_dose) + min_dose``.
    """
    dose_span = max_dose - min_dose
    return normalized_dose * dose_span + min_dose

def _squeeze_to_volume(dose):
    """Drop leading length-1 axes until ``dose`` is at most 3D, e.g. (1, D, H, W) -> (D, H, W)."""
    while dose.dim() > 3 and dose.shape[0] == 1:
        dose = dose[0]
    return dose


def _save_dose_comparison(pred_slice, true_slice, out_path):
    """Save a three-panel figure: predicted slice, ground-truth slice and their absolute difference."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ('Predicted Dose', pred_slice, 'hot'),
        ('Ground Truth Dose', true_slice, 'hot'),
        ('Difference', np.abs(pred_slice - true_slice), 'cool'),
    )
    for ax, (title, image, cmap) in zip(axes, panels):
        ax.set_title(title)
        fig.colorbar(ax.imshow(image, cmap=cmap), ax=ax)
    fig.tight_layout()
    fig.savefig(out_path)
    # Close explicitly: one figure is created per test sample.
    plt.close(fig)


def evaluate_model(model, test_dataset, metadata_dict, device='cuda'):
    """
    Evaluate the model on the test dataset.

    For every sample the model output and the stored dose (loaded from
    ``test_dataset.dose_paths[idx]``) are denormalized with the sample's
    ``MinDose`` / ``DynamicMaxDose`` and compared by MAE and MSE. A mid-slice
    comparison figure is saved per sample, per-sample metrics go to
    ``test_results.csv`` and the aggregate metrics to ``overall_metrics.txt``
    (all in the current working directory).

    Args:
        model (torch.nn.Module): Trained model; moved to ``device`` and put in eval mode.
        test_dataset (Dataset): Dataset exposing ``ct_paths`` and ``dose_paths``
            and returning the input tensor (without batch axis) when indexed.
            NOTE(review): the stored dose is assumed to be normalized — confirm.
        metadata_dict (dict): Maps the CT file name to
            ``{'MinDose': ..., 'DynamicMaxDose': ...}``.
        device (str): Computing device.

    Returns:
        dict: ``mean_mae``, ``std_mae``, ``mean_mse`` and ``std_mse`` over all samples.

    Raises:
        ValueError: If prediction and ground truth differ in shape after
            leading singleton axes have been removed.
    """
    model.eval()
    model = model.to(device)

    results = {
        'patient_ids': [],
        'mae_values': [],
        'mse_values': []
    }

    with torch.no_grad():
        for idx in range(len(test_dataset)):
            # Add the batch axis and move the sample to the inference device.
            input_tensor = test_dataset[idx].unsqueeze(0).to(device)

            # File name (without directories) is the metadata lookup key.
            filename = os.path.basename(test_dataset.ct_paths[idx])
            min_dose = metadata_dict[filename]['MinDose']
            max_dose = metadata_dict[filename]['DynamicMaxDose']

            prediction = model(input_tensor)

            # Remove batch/channel axes so that axis 0 is depth for both volumes.
            pred_dose = _squeeze_to_volume(prediction.squeeze(0).cpu())
            true_dose = _squeeze_to_volume(torch.tensor(np.load(test_dataset.dose_paths[idx])))
            if pred_dose.shape != true_dose.shape:
                # Broadcasting would otherwise hide the mismatch and corrupt the metrics.
                raise ValueError(
                    f"Shape mismatch for {filename}: prediction {tuple(pred_dose.shape)} "
                    f"vs ground truth {tuple(true_dose.shape)}."
                )

            pred_dose_denorm = denormalize_dose(pred_dose, min_dose, max_dose)
            true_dose_denorm = denormalize_dose(true_dose, min_dose, max_dose)

            error = pred_dose_denorm - true_dose_denorm
            mae = torch.mean(torch.abs(error)).item()
            mse = torch.mean(error ** 2).item()

            results['patient_ids'].append(filename)
            results['mae_values'].append(mae)
            results['mse_values'].append(mse)

            # Visual check on the central slice along axis 0.
            mid = pred_dose_denorm.shape[0] // 2
            _save_dose_comparison(
                pred_dose_denorm.numpy()[mid],
                true_dose_denorm.numpy()[mid],
                f'dose_comparison_{filename}.png',
            )

    # Compute overall metrics
    overall_metrics = {
        'mean_mae': np.mean(results['mae_values']),
        'std_mae': np.std(results['mae_values']),
        'mean_mse': np.mean(results['mse_values']),
        'std_mse': np.std(results['mse_values'])
    }

    # Save results to CSV
    results_df = pd.DataFrame(results)
    results_df.to_csv('test_results.csv', index=False)

    # Save overall metrics
    with open('overall_metrics.txt', 'w') as f:
        for key, value in overall_metrics.items():
            f.write(f"{key}: {value}\n")

    return overall_metrics

def main():
    """Load metadata, dataset and checkpoint, evaluate the model and print the aggregate metrics.

    This is a template: the ``...`` placeholders (file lists and model
    constructor arguments) must be replaced with real values before it can run.
    NOTE(review): ``DosePredictionDataset`` is not defined in the visible part
    of the notebook — confirm it is available when this cell runs.
    """
    # Load metadata: FileName -> dose normalization bounds.
    metadata_df = pd.read_csv('metadata.csv')
    metadata_dict = {
        row['FileName']: {
            'MinDose': row['MinDose'],
            'DynamicMaxDose': row['DynamicMaxDose']
        }
        for _, row in metadata_df.iterrows()
    }

    # Prepare test dataset
    test_dataset = DosePredictionDataset(
        ct_paths=[...],  # List of CT file paths
        structure_paths=[...],  # List of structure file paths
        dose_paths=[...],  # List of dose file paths
        csv_path='metadata.csv',
        mode='test',
        normalize=True
    )

    # Resolve the device first so the checkpoint can be mapped onto it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load trained model. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only machine.
    model = Cascade3DUNet(...)  # Your model architecture
    model.load_state_dict(torch.load('best_model.pth', map_location=device))

    # Evaluate model
    metrics = evaluate_model(
        model,
        test_dataset,
        metadata_dict,
        device
    )

    # Print metrics
    print("Overall Test Metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value}")

# Script-style entry point: call main() when this code runs as the main module.
# NOTE(review): inside a notebook kernel __name__ is normally "__main__", so this
# presumably runs main() every time the cell is executed — confirm that is intended.
if __name__ == "__main__":
    main()

# ----------------------
# Convert Model Output to DICOM
# ----------------------
def save_dicom(pred_rd, dicom_template_path, output_dicom_path):
    """Write a predicted dose volume into a copy of an existing DICOM dose file.

    The template file supplies all metadata. The prediction is quantized to the
    template's integer pixel type, ``DoseGridScaling`` is set so that
    ``stored_value * DoseGridScaling`` reproduces the predicted dose, and new
    SOP / Study / Series instance UIDs are generated.

    Args:
        pred_rd: Predicted dose as a tensor or array. After removing length-1
            axes its shape must equal the shape of the template's pixel array.
        dicom_template_path: Path of the DICOM file used as template.
        output_dicom_path: Path the new DICOM file is written to.

    Raises:
        ValueError: If the template does not store integer pixel data, or if
            the prediction shape does not match the template grid.
    """
    # Imported locally: the notebook's import cells do not import pydicom.
    import pydicom
    from pydicom.uid import generate_uid

    dicom_template = pydicom.dcmread(dicom_template_path)

    # Decode the template grid BEFORE PixelData is replaced; afterwards
    # pixel_array would be decoded from the new bytes.
    template_grid = dicom_template.pixel_array
    stored_dtype = template_grid.dtype
    if not np.issubdtype(stored_dtype, np.integer):
        raise ValueError(f"Template stores non-integer pixel data ({stored_dtype}).")

    if torch.is_tensor(pred_rd):
        # detach() so tensors that still require grad can be converted.
        dose = pred_rd.detach().cpu().numpy()
    else:
        dose = np.asarray(pred_rd)
    dose = dose.squeeze().astype(np.float64)

    if dose.shape != template_grid.shape:
        raise ValueError(
            f"Predicted dose shape {dose.shape} does not match "
            f"template grid shape {template_grid.shape}."
        )

    # Map the largest dose magnitude onto the largest storable integer.
    # NOTE(review): assumes BitsStored equals BitsAllocated and a native
    # (little-endian, uncompressed) transfer syntax — confirm for the template.
    limits = np.iinfo(stored_dtype)
    peak = float(np.max(np.abs(dose))) if dose.size else 0.0
    # Fixed 16-character formatting keeps the value inside the length limit of
    # a DICOM decimal string; the rounded value is the one used for quantizing.
    scaling_text = f"{peak / limits.max:.10e}" if peak > 0 else f"{1.0:.10e}"
    scaling = float(scaling_text)

    stored = np.clip(np.rint(dose / scaling), limits.min, limits.max).astype(stored_dtype)
    dicom_template.PixelData = stored.tobytes()
    dicom_template.DoseGridScaling = scaling_text
    dicom_template.RescaleSlope = 1.0
    dicom_template.RescaleIntercept = 0.0

    # generate_uid keeps the UID within the 64-character limit, which a raw
    # uuid4 integer appended to this root can exceed.
    uid_root = "1.2.826.0.1.3680043.10.511."
    dicom_template.SOPInstanceUID = generate_uid(prefix=uid_root)
    dicom_template.StudyInstanceUID = generate_uid(prefix=uid_root)
    dicom_template.SeriesInstanceUID = generate_uid(prefix=uid_root)
    # Keep the file meta header consistent with the new SOP instance UID.
    file_meta = getattr(dicom_template, "file_meta", None)
    if file_meta is not None and "MediaStorageSOPInstanceUID" in file_meta:
        file_meta.MediaStorageSOPInstanceUID = dicom_template.SOPInstanceUID

    dicom_template.save_as(output_dicom_path)

# Example Usage
# NOTE(review): `pred_rd` is not defined anywhere in the visible notebook; it has to
# be created (for example from a model prediction) before this call, otherwise the
# call below raises NameError.
dicom_template_path = "original_dose.dcm"
output_dicom_path = "predicted_dose.dcm"
save_dicom(pred_rd, dicom_template_path, output_dicom_path)

# ----------------------
# Visualize Dose Map
# ----------------------
# Imported here because the notebook's import cells do not import pydicom.
import pydicom

dicom_new = pydicom.dcmread(output_dicom_path)
# Stored pixel values are scaled integers; multiplying by DoseGridScaling
# converts them to dose. Falls back to 1.0 if the attribute is absent.
dose_array = dicom_new.pixel_array * float(getattr(dicom_new, "DoseGridScaling", 1.0))
# Show the central slice along axis 0.
plt.imshow(dose_array[dose_array.shape[0] // 2], cmap="jet")
plt.colorbar(label="Dose")
plt.title("Predicted Dose Distribution")
plt.show()