In [1]:
import numpy as np
import muram as muram
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
# Test the dataset and the dataloader
from dataset import MURAMVelocityDataset
from torch.utils.data import DataLoader
import json, os
import torch
from deepvel_unet2 import UDeepVel

In [2]:
# ---- Evaluation configuration ----
# Directory holding the MURaM snapshot files.
path = 'data'

# Intensity snapshots sorted by file name. The iteration number is parsed from
# the last six characters of each name (assumes a fixed-width 6-digit suffix — confirm).
intensity_files = sorted(glob(f'{path}/I_out*'))
iterations = [int(fname[-6:]) for fname in intensity_files]

# Patch / plotting parameters.
img_size = 64  # NOTE(review): appears unused below; the dataset patch size is read from config["dataset"]["patch_size"]
stride = 5     # subsampling step (pixels) for the velocity arrows in the quiver plots

# Checkpoint directory: holds config.json and the weights; the test figure is written back into it.
base_dir = 'checkpoints_full_32_dropout_lr'
config_file = f'{base_dir}/config.json'
model_file = f'{base_dir}/deepvel_best_4992.pth'
output_dir = base_dir

In [3]:
# Load the first `n_snapshots` MURaM snapshots: the intensity image, plus the two
# velocity components and the simulation time taken from a tau slice.
# The loop variable is `iteration`, not `iter`, so the builtin `iter` is not
# shadowed for every later cell in this kernel.
n_snapshots = 10
Intensities = []
velocities = []
times = []
for iteration in tqdm(iterations[:n_snapshots], desc="Loading data"):
    # load the intensity data
    Intensities.append(muram.MuramIntensity(path, iteration))
    # NOTE(review): the third argument (1) selects which tau slice is read — confirm its meaning in muram.MuramTauSlice.
    tau_slice = muram.MuramTauSlice(path, iteration, 1)
    # The (vx, vy) training target is built from the slice's (vy, vz) attributes;
    # presumably MURaM's x axis is the vertical direction, so y/z are horizontal — confirm.
    vx, vy = tau_slice.vy, tau_slice.vz
    velocities.append((vx, vy))
    times.append(tau_slice.time)

Loading data: 100%|██████████| 10/10 [00:00<00:00, 12.73it/s]


In [4]:
# Evaluate on every loaded snapshot. These names are aliases of the lists built
# during loading (no copy, no subsampling); slice here to evaluate on a subset.
Intensities_subsample, velocities_subsample, times_subsample = Intensities, velocities, times

In [5]:
# Load the training configuration saved next to the checkpoint.
with open(config_file, 'r') as f:
    config = json.load(f)

# Create the output directory. exist_ok=True makes this a no-op when it already
# exists and avoids the race between an os.path.exists check and os.makedirs.
os.makedirs(output_dir, exist_ok=True)

In [6]:
# Wrap the loaded snapshots in the training dataset class, configured for
# evaluation: patch size from the training config, one sample per batch and
# no augmentation.
# NOTE(review): the training seed (config["dataset"]["seed"]) is not passed;
# confirm whether patch extraction is random and the seed should be restored.
dataset = MURAMVelocityDataset(
    Intensities_subsample,
    velocities_subsample,
    times_subsample,
    img_size=config["dataset"]["patch_size"],
    batch_size=1,
    aug=None,
)

# batch_size=None turns off DataLoader's automatic batching, so every item the
# dataset yields is passed through unchanged (presumably already a batch, since
# the dataset takes its own batch_size — confirm).
dataloader = DataLoader(dataset, batch_size=None)

In [None]:
# Build the U-Net with the architecture recorded in the training config and
# restore the saved weights. Keyword names (including the "chanels" spelling)
# must match UDeepVel's signature and the config keys, so they stay verbatim.
device = config["training"]["used_gpu"]
model = UDeepVel(
    input_channels=config["model"]["input_channels"],
    output_channels=config["model"]["output_channels"],
    n_latent_chanels=config["model"]["n_latent_chanels"],
    chanels_multiples=config["model"]["chanels_multiples"],
    is_attn_encoder=config["model"]["is_attn_encoder"],
    is_attn_decoder=config["model"]["is_attn_decoder"],
    n_blocks=config["model"]["n_blocks"],
    time_emb_dim=config["model"]["time_emb_dim"],
    padding_mode=config["model"]["padding_mode"],
    bilinear=config["model"]["bilinear"]
)
# map_location="cpu" lets a checkpoint saved on any device (e.g. a specific GPU)
# load on this machine; the model is moved to `device` right after.
model.load_state_dict(torch.load(model_file, map_location="cpu")["model"])
model = model.to(device)
# eval() disables training-only behaviour such as dropout; the trailing ';'
# suppresses the long module repr in the notebook output.
model.eval();

In [None]:
# Evaluate the restored model on up to `num_samples` batches. For each batch,
# tile the input frames side by side and overlay ground-truth arrows on the last
# frame and predicted arrows on the first frame; report the mean MSE.
criterion = torch.nn.MSELoss()

num_samples = 10      # maximum number of batches to test
n_input_frames = 3    # input channels tiled horizontally in each panel
device = config["training"]["used_gpu"]

# Test loop
total_loss = 0.0
n_evaluated = 0  # batches actually processed; can be < num_samples if the dataloader is shorter

# squeeze=False keeps `axs` 2-D (num_samples x 1) even when num_samples == 1.
fig, axs = plt.subplots(num_samples, 1, figsize=(12, 18), dpi=300, squeeze=False)
with torch.no_grad():
    for i, (X_batch, Y_batch, T_batch) in enumerate(tqdm(dataloader, total=num_samples, desc="Testing")):
        if i >= num_samples:
            break

        X_batch = X_batch.to(device)
        Y_batch = Y_batch.to(device)
        T_batch = T_batch.to(device)

        # Forward pass
        outputs = model(X_batch, T_batch)
        loss = criterion(outputs, Y_batch)
        total_loss += loss.item()
        n_evaluated += 1

        # First sample of the batch; channels 0 and 1 are the two velocity components.
        pred = outputs.cpu().numpy()[0]
        true = Y_batch.cpu().numpy()[0]
        vx_pred, vy_pred = pred[0], pred[1]
        vx_true, vy_true = true[0], true[1]

        X_np = X_batch.cpu().numpy()
        T_np = T_batch.cpu().numpy()

        # Tile the first `n_input_frames` input channels horizontally into one image.
        concatenated = np.hstack([X_np[0, j, ...] for j in range(n_input_frames)])

        # Arrow anchor grid, subsampled every `stride` pixels of the target field.
        rows_idx = np.arange(0, true.shape[-2], stride)
        cols_idx = np.arange(0, true.shape[-1], stride)
        Xgrid, Ygrid = np.meshgrid(cols_idx, rows_idx)
        vx_sample_field = vx_true[::stride, ::stride]
        vy_sample_field = vy_true[::stride, ::stride]
        vx_pred_sample_field = vx_pred[::stride, ::stride]
        vy_pred_sample_field = vy_pred[::stride, ::stride]

        # Assuming equal-width frames, each occupies 1/n_input_frames of the width.
        # Ground-truth arrows go over the last frame, predictions over the first,
        # leaving the middle frame unobstructed.
        offset = concatenated.shape[1] // n_input_frames
        Xgrid_offset_true = Xgrid + (n_input_frames - 1) * offset
        Xgrid_offset_pred = Xgrid + 0 * offset

        ax = axs[i, 0]
        ax.imshow(concatenated, cmap='afmhot', origin='lower')
        mag = np.sqrt(vx_sample_field**2 + vy_sample_field**2)
        mag_pred = np.sqrt(vx_pred_sample_field**2 + vy_pred_sample_field**2)
        # quiver's U (horizontal) gets component 1 and V (vertical) gets component 0,
        # presumably because component 0 runs along array rows — confirm.
        # NOTE(review): when C (the magnitude) is passed, matplotlib colours arrows
        # via the colormap and `color=` has no effect.
        ax.quiver(Xgrid_offset_true, Ygrid, vy_sample_field, vx_sample_field, mag, color='blue')
        ax.quiver(Xgrid_offset_pred, Ygrid, vy_pred_sample_field, vx_pred_sample_field, mag_pred, color='red')
        ax.set_title(f"Batch {i+1} - Times: {T_np[0,0]} {T_np[0,1]} {T_np[0,2]}")
        ax.axis("off")

# Hide panels left empty when the dataloader yielded fewer than num_samples batches.
for ax in axs[n_evaluated:, 0]:
    ax.axis("off")

# Lay out once, then save and release the figure.
fig.tight_layout()
fig.savefig(os.path.join(output_dir, "test_samples.png"))
plt.close(fig)

# Average over the batches actually evaluated, not the requested count.
avg_loss = total_loss / n_evaluated if n_evaluated else float("nan")
print(f"Average MSE Loss on {n_evaluated} test samples: {avg_loss:.4f}")