In [1]:
import os
import torch
import numpy as np
import cv2
from tqdm import tqdm

In [2]:
# Render every 4-channel `.pt` tensor in `tensor_folder` as a PNG in `draw_folder`.
# Channel layout, as indexed below (OpenCV writes images in BGR order):
#   0 = start_note    -> blue
#   1 = music_symbols -> red
#   2 = staff_lines   -> green
#   3 = end_note      -> blue
# Pixels where a start/end note overlaps a music symbol come out magenta (red + blue).
# NOTE(review): earlier comments labelled start/end notes "Magenta (Red + Blue)", but only the
# blue channel was ever written; the blue rendering is kept here — confirm which was intended.
tensor_folder = "tensors"
draw_folder = "draw_tensors"
BINARIZE = False  # True: treat any value > 0 as 1 in every channel before drawing
os.makedirs(draw_folder, exist_ok=True)


def tensor_to_color_image(tensor_np: np.ndarray, binarize: bool = BINARIZE) -> np.ndarray:
    """Convert a (4, H, W) array into an (H, W, 3) uint8 image in OpenCV's BGR order.

    Args:
        tensor_np: array whose first axis holds start_note, music_symbols,
            staff_lines, end_note (in that order).
        binarize: if True, map every value > 0 to 1 before scaling.

    Returns:
        uint8 array with blue = start_note + end_note, green = staff_lines,
        red = music_symbols, each scaled by 255.

    Work is done in float64 and clipped to [0, 255] before the uint8 cast, so
    overlapping start/end notes (sum 2 * 255) saturate at 255 instead of wrapping
    around, and negative / >1 values cannot overflow.
    """
    planes = np.asarray(tensor_np, dtype=np.float64)
    if binarize:
        planes = (planes > 0).astype(np.float64)
    start_note, music_symbols, staff_lines, end_note = planes

    def to_uint8(channel: np.ndarray) -> np.ndarray:
        return np.clip(channel * 255, 0, 255).astype(np.uint8)

    return np.stack(
        [
            to_uint8(start_note + end_note),  # Blue (start + end notes)
            to_uint8(staff_lines),            # Green (staff lines)
            to_uint8(music_symbols),          # Red (music symbols)
        ],
        axis=-1,
    )


# sorted() makes processing order (and therefore the order of skip messages) deterministic.
for tensor_file in tqdm(sorted(os.listdir(tensor_folder)), desc="Visualizing Tensors"):
    if not tensor_file.endswith(".pt"):
        continue  # Skip non-tensor files

    tensor_path = os.path.join(tensor_folder, tensor_file)
    # torch.load is pickle-based: weights_only=True restricts it to tensors / plain containers.
    # map_location="cpu" lets tensors saved from a GPU load on a CPU-only machine.
    tensor = torch.load(tensor_path, map_location="cpu", weights_only=True)

    # Require a (4, H, W) tensor; anything else would break the channel unpacking.
    if not isinstance(tensor, torch.Tensor) or tensor.ndim != 3 or tensor.shape[0] != 4:
        found = tensor.shape if isinstance(tensor, torch.Tensor) else type(tensor).__name__
        print(f"Skipping {tensor_file}: Unexpected shape {found}")
        continue

    # detach(): .numpy() refuses tensors that require grad.
    color_image = tensor_to_color_image(tensor.detach().cpu().numpy())

    stem, _ = os.path.splitext(tensor_file)  # strip only the trailing ".pt"
    output_path = os.path.join(draw_folder, f"v_{stem}.png")
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(output_path, color_image):
        raise OSError(f"Failed to write {output_path}")

print("Tensor visualization complete.")

Visualizing Tensors: 100%|██████████████████████████████████████████████████████████| 540/540 [00:02<00:00, 239.64it/s]

Tensor visualization complete.



