In [1]:
%matplotlib inline
import os
import torch
import fairseq
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import librosa.display
import numpy as np
import mae_ast.tasks.mae_ast_pretraining



def print_info(message):
    print(f"[*] {message}")

print_info("All libraries imported successfully.")

  from .autonotebook import tqdm as notebook_tqdm


[*] All libraries imported successfully.


In [2]:
config_dir = "config/pretrain"
config_name = "mae_ast - recon"
data_dir = r"D:\MBARI 2KHz\training\input_dir"  # The folder with your train.tsv and valid.tsv
output_dir = r"D:\MBARI 2KHz\training\250701 4en 2de\output_model"  # The folder where checkpoints are saved
checkpoint_file = "checkpoint_last.pt"  # The specific checkpoint you want to load

index_to_visualize = 1327

no_enc = 4
no_dec = 2
im_sample = 'valid'

overrides = {
    'task': {
        'data': data_dir
    },
    'dataset': {
        'valid_subset': im_sample
    },
    'hydra': {
        'run': {
            'dir': output_dir
        }
    },
    'model': {
        'encoder_layers': no_enc,
        'decoder_layers': no_dec
    }
}

config_path = os.path.join(config_dir, f"{config_name}.yaml")
cfg = OmegaConf.load(config_path)
cfg.merge_with(overrides)

print_info("Configuration loaded and ready.")
print_info(f"Will visualize sample #{index_to_visualize} from the validation set.")

[*] Configuration loaded and ready.
[*] Will visualize sample #1327 from the validation set.


In [3]:

print_info("Setting up the Fairseq task and building the model...")
task = fairseq.tasks.setup_task(cfg.task)
model = task.build_model(cfg.model)

checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint_file)
print_info(f"Loading checkpoint from: {checkpoint_path}")

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
model.eval()  # Set the model to evaluation mode

if torch.cuda.is_available():
    device = torch.device("cuda")
    model.to(device)
    print_info("Model moved to GPU.")
else:
    device = torch.device("cpu")
    print_info("CUDA not available. Using CPU.")

print_info(f"Loading dataset for split: {cfg.dataset.valid_subset}")
task.load_dataset(cfg.dataset.valid_subset)
dataset = task.dataset(cfg.dataset.valid_subset)
print_info(f"Dataset loaded with {len(dataset)} samples.")

2025-07-01 15:50:34 | INFO | mae_ast.tasks.mae_ast_pretraining | current directory is C:\Users\Ali\OneDrive - Georgia Institute of Technology\25-5 Summer\CS 7643 - Deep Learning\_Project\MAE-AST-Public
2025-07-01 15:50:34 | INFO | mae_ast.tasks.mae_ast_pretraining | MAEPretrainingTask Config {'_name': 'mae_ast_pretraining', 'data': 'D:\\MBARI 2KHz\\training\\input_dir', 'sample_rate': 2000, 'normalize': False, 'enable_padding': False, 'max_keep_size': None, 'max_sample_size': 40000, 'min_sample_size': 5000, 'random_crop': True, 'pad_audio': False, 'feature_type': 'fbank', 'feature_rate': 100, 'feature_dim': 128, 'deltas': False, 'mask_spans': False, 'mask_type': random_mask}


[*] Setting up the Fairseq task and building the model...


2025-07-01 15:50:34 | INFO | mae_ast.models.mae_ast | MAEModel Config: {'_name': 'mae_ast', 'ast_kernel_size_chan': 16, 'ast_kernel_size_time': 16, 'ast_kernel_stride_chan': 16, 'ast_kernel_stride_time': 16, 'encoder_layers': 4, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_norm_first': False, 'feature_grad_mult': 0.1, 'use_post_enc_proj': False, 'decoder_embed_dim': 768, 'decoder_layers': 2, 'decoder_layerdrop': 0.0, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'random_mask_prob': 0.75, 'mask_length': 10, 'mask_selection': static, 'mask_other': 0.0, 'no_mask_overlap': False, 'mask_min_space': 0, 'conv_pos': 128, 'conv_pos_groups': 16, 'checkpoint_activations': False, 'max_token_length': 48000, 'enc_sine_pos': True, 'enc_conv_pos': False, 'dec_sine_pos': True, 'dec_conv_pos': False}


[*] Loading checkpoint from: D:\MBARI 2KHz\training\250701 4en 2de\output_model\checkpoints\checkpoint_last.pt


2025-07-01 15:50:36 | INFO | mae_ast.data.mae_ast_dataset | max_keep=None, min_keep=5000, loaded 10374, skipped 0 short and 0 long, longest-loaded=20000, shortest-loaded=10000
2025-07-01 15:50:36 | INFO | mae_ast.data.mae_ast_dataset | pad_audio=False, random_crop=True, normalize=False, max_sample_size=40000


[*] Model moved to GPU.
[*] Loading dataset for split: valid
[*] Dataset loaded with 10374 samples.


In [4]:

print_info(f"Extracting preprocessed sample at index {index_to_visualize}...")
sample = dataset[index_to_visualize]
spectrogram_tensor = sample['source']

print_info("Creating tensors for model input...")

padding_mask = torch.zeros(1, spectrogram_tensor.shape[0], dtype=torch.bool)

batch_tensor = spectrogram_tensor.unsqueeze(0).to(device)
padding_mask = padding_mask.to(device)

print("Inputs are ready.")
print(f"Shape of input tensor: {batch_tensor.shape}")
print(f"Shape of padding mask: {padding_mask.shape}")

[*] Extracting preprocessed sample at index 1327...
[*] Creating tensors for model input...
Inputs are ready.
Shape of input tensor: torch.Size([1, 1001, 128])
Shape of padding mask: torch.Size([1, 1001])


In [5]:
print_info("Running model forward pass...")

with torch.no_grad():
    model_output = model.forward(source=batch_tensor, padding_mask=padding_mask)

print_info("--- Inspection Results ---")
print(f"The model returned an object of type: {type(model_output)}")

if isinstance(model_output, (list, tuple)):
    print(f"It contains {len(model_output)} items.")
    for i, item in enumerate(model_output):
        if hasattr(item, 'shape'):
            print(f"  - Item #{i} has shape: {item.shape}")
elif hasattr(model_output, 'shape'):
     print(f"The output has shape: {model_output.shape}")

[*] Running model forward pass...
[*] --- Inspection Results ---
The model returned an object of type: <class 'dict'>


In [6]:
print_info("Keys in the output dictionary:")
print(list(model_output.keys()))

print_info("\nShapes of tensors in the output dictionary:")
for key, value in model_output.items():
    if isinstance(value, torch.Tensor) and hasattr(value, 'shape'):
        print(f"  - Key '{key}' has a tensor with shape: {value.shape}")

[*] Keys in the output dictionary:
['logit_m_list_recon', 'logit_m_list_class', 'target_m_list', 'padding_mask', 'mask_indices']
[*] 
Shapes of tensors in the output dictionary:
  - Key 'logit_m_list_recon' has a tensor with shape: torch.Size([1, 372, 256])
  - Key 'logit_m_list_class' has a tensor with shape: torch.Size([1, 372, 256])
  - Key 'target_m_list' has a tensor with shape: torch.Size([1, 372, 256])
  - Key 'padding_mask' has a tensor with shape: torch.Size([1, 496])
  - Key 'mask_indices' has a tensor with shape: torch.Size([1, 496])


In [7]:
# Final Cell: Correctly Oriented Image Reconstruction

# --- Imports for this cell ---
from PIL import Image
import numpy as np

# --- 1. Get a Sample and Run Inference ---
print_info(f"Extracting preprocessed sample at index {index_to_visualize}...")
sample = dataset[index_to_visualize]
spectrogram_tensor = sample['source']

# Prepare tensors for the model
batch_tensor = spectrogram_tensor.unsqueeze(0).to(device)
padding_mask = torch.zeros(1, spectrogram_tensor.shape[0], dtype=torch.bool).to(device)

print_info("Running model forward pass...")
with torch.no_grad():
    model_output = model.forward(source=batch_tensor, padding_mask=padding_mask, mask=True)

# --- 2. Extract Tensors from the Model Output ---
print_info("Extracting tensors from the output dictionary...")
reconstructed_patches = model_output['logit_m_list_recon'].squeeze(0)
mask_indices = model_output['mask_indices'].squeeze(0)

# --- 3. Reassemble the Full Images ---
print_info("Reassembling full spectrograms...")

# Get original patches and parameters
all_patches = model.unfold(batch_tensor.unsqueeze(1)).squeeze(0).transpose(0, 1)
p_c, p_t = cfg.model.ast_kernel_size_chan, cfg.model.ast_kernel_size_time
h, w = spectrogram_tensor.shape[0], spectrogram_tensor.shape[1]
n_h, n_w = h // p_c, w // p_t
folder = torch.nn.Fold(output_size=(h, w), kernel_size=(p_c, p_t), stride=(p_c, p_t))

# Reassemble the masked image
masked_input_patches = all_patches.clone()
masked_input_patches[mask_indices] = torch.min(all_patches)
masked_image_data = masked_input_patches.transpose(0, 1).reshape(1, p_c * p_t, n_h * n_w)
masked_image_tensor = folder(masked_image_data).squeeze(0)

# Reassemble the globally normalized reconstruction
recons_global_patches = all_patches.clone()
recons_global_patches[mask_indices] = reconstructed_patches
recons_global_data = recons_global_patches.transpose(0, 1).reshape(1, p_c * p_t, n_h * n_w)
recons_global_tensor = folder(recons_global_data).squeeze(0)

# Reassemble the per-patch normalized reconstruction
'''recons_per_patch_patches = all_patches.clone()
normalized_recons_patches = torch.zeros_like(reconstructed_patches)
for i in range(reconstructed_patches.shape[0]):
    patch = reconstructed_patches[i]
    min_p, max_p = torch.min(patch), torch.max(patch)
    if max_p > min_p:
        normalized_recons_patches[i] = (patch - min_p) / (max_p - min_p)
    else:
        normalized_recons_patches[i] = torch.ones_like(patch) * 0.5
recons_per_patch_patches[mask_indices] = normalized_recons_patches
recons_per_patch_data = recons_per_patch_patches.transpose(0, 1).reshape(1, p_c * p_t, n_h * n_w)
recons_per_patch_tensor = folder(recons_per_patch_data).squeeze(0)'''

# --- 4. Corrected Helper Function to Save Images ---
def save_tensor_as_image(data_tensor, filename):
    """Normalizes a tensor and saves it as a correctly oriented grayscale PNG."""
    # Convert to NumPy and TRANSPOSE to make Time the X-axis (wide image)
    data_array = data_tensor.cpu().numpy().transpose()

    # Normalize the data to the 0-255 range for image saving
    min_val, max_val = np.min(data_array), np.max(data_array)
    if max_val > min_val:
        normalized_data = (data_array - min_val) / (max_val - min_val)
    else:
        normalized_data = np.zeros_like(data_array)

    image_data = (normalized_data * 255).astype(np.uint8)
    img = Image.fromarray(image_data, 'L')
    img.save(filename)
    print_info(f"Successfully saved image to {filename}")

# --- 5. Save All Comparison Images ---
print_info("Saving all comparison images with correct orientation...")

save_tensor_as_image(spectrogram_tensor[:, :cfg.task.feature_dim], f"{index_to_visualize}-{im_sample}-{no_enc}enc{no_dec}dec-original.png")
#save_tensor_as_image(masked_image_tensor.squeeze(0)[:, :cfg.task.feature_dim], f"{index_to_visualize}-masked.png")
save_tensor_as_image(recons_global_tensor.squeeze(0)[:, :cfg.task.feature_dim], f"{index_to_visualize}-{im_sample}-{no_enc}enc{no_dec}dec-recon.png")
#save_tensor_as_image(recons_per_patch_tensor.squeeze(0)[:, :cfg.task.feature_dim], "comparison_04_recon_per_patch_norm.png")

[*] Extracting preprocessed sample at index 1327...
[*] Running model forward pass...
[*] Extracting tensors from the output dictionary...
[*] Reassembling full spectrograms...
[*] Saving all comparison images with correct orientation...
[*] Successfully saved image to 1327-valid-4enc2dec-original.png
[*] Successfully saved image to 1327-valid-4enc2dec-recon.png
