In [1]:
import os
import datetime

# Step one directory up from wherever the kernel was started so that the
# relative paths used further down (e.g. "config/...", "notebook/...") resolve
# against the parent directory.
# NOTE(review): this is not idempotent - re-running the cell without restarting
# the kernel moves up one more level each time.
os.chdir(os.pardir)
print(os.getcwd())

C:\Users\Ali\OneDrive - Georgia Institute of Technology\25-5 Summer\CS 7643 - Deep Learning\_Project\mbari-mae


In [17]:
import os
import torch
import mae_ast.tasks.mae_ast_pretraining
import fairseq
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import random
import logging

# Raise the root logger threshold so only WARNING and above are emitted.
logging.root.setLevel(logging.WARNING)



def print_info(message):
    """Print ``message`` to stdout prefixed with the ``[*] `` status marker."""
    print("[*] {}".format(message))

def save_with_colormap(data_tensor, filename, norm_min, norm_max, cmap='viridis'):
    """Save a 2-D tensor as a colour-mapped image file.

    The tensor is moved to the CPU, transposed, linearly rescaled from
    ``[norm_min, norm_max]`` to ``[0, 1]`` (values outside that range are
    clipped) and written with ``matplotlib.pyplot.imsave`` using a
    lower-left origin.

    Args:
        data_tensor: torch tensor holding the image data; it is transposed
            before being written, so rows of the tensor become image columns.
        filename: Destination path of the image file.
        norm_min: Value mapped to the bottom of the colour map.
        norm_max: Value mapped to the top of the colour map. If it is not
            strictly greater than ``norm_min`` an all-zero image is written.
        cmap: Name of the matplotlib colour map (default ``'viridis'``).
    """
    # detach() makes the conversion safe for tensors that are still attached
    # to an autograd graph; .numpy() raises a RuntimeError on those otherwise.
    data_array = data_tensor.detach().cpu().numpy().transpose()
    if norm_max > norm_min:
        normalized_data = (data_array - norm_min) / (norm_max - norm_min)
    else:
        # Degenerate range: avoid dividing by zero and emit a flat image.
        normalized_data = np.zeros_like(data_array)
    normalized_data = np.clip(normalized_data, 0, 1)
    # NOTE(review): this slice is applied AFTER the transpose, so it limits the
    # first axis of the transposed array (the tensor's second axis). Callers in
    # this notebook already slice the tensor's first axis to 992 before calling;
    # confirm which axis was meant to be truncated here.
    plt.imsave(filename, normalized_data[:992], cmap=cmap, vmin=0, vmax=1, origin='lower')
    print_info(f"Successfully saved image to {filename}")

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer-like seed value. Anything convertible with ``int()`` is
            accepted (built-in int, NumPy integer, 0-dim integer tensor, ...).

    Raises:
        TypeError / ValueError: if ``seed`` cannot be converted to ``int``.
    """
    # Normalise to a built-in int first. random.seed() only supports
    # None/int/float/str/bytes/bytearray: other objects raise TypeError on
    # Python 3.11+ and were seeded from hash(obj) before that. A torch.Tensor
    # hashes by identity, so passing a tensor would seed `random` with a value
    # unrelated to the number that is printed below.
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible CUDA device, not just the current one.
        torch.cuda.manual_seed_all(seed)
    print_info(f"Random seed set to {seed} for reproducibility.")

# --- Paths -------------------------------------------------------------------
# Relative paths are resolved against the current working directory.
# NOTE(review): CHECKPOINT_DIR and DATA_DIRECTORY are machine-specific absolute
# paths; adjust them before running this notebook elsewhere.
BASE_CONFIG_PATH = r"config/pretrain/mae_ast - recon.yaml"
CHECKPOINT_DIR = r"C:\Users\Ali\OneDrive - Georgia Institute of Technology\25-5 Summer\CS 7643 - Deep Learning\_Project\mbari-mae\notebook\downstream\load_model"
DATA_DIRECTORY = r"D:\MBARI 2KHz\training\input_dir"
OUTPUT_IMAGE_FOLDER = r"notebook/reconstructions"

# --- Sample selection --------------------------------------------------------
DATASET_SPLIT = 'train'
SAMPLE_INDEX = 29664
# A fresh seed is drawn on every execution and stored as a plain Python int
# (not a 0-dim tensor) so it can be passed safely to random / numpy / torch
# seeding functions. Replace with a literal to reproduce a specific run.
RANDOM_SEED = torch.randint(0, 100000, (1,)).item()

# --- Models to compare -------------------------------------------------------
# 'file' is looked up inside CHECKPOINT_DIR; the layer counts override the
# corresponding entries of the base model config.
MODELS_TO_TEST = [
    {'name': '2en1de', 'file': '2en1de.pt', 'encoder_layers': 2, 'decoder_layers': 1},
    {'name': '4en1de', 'file': '4en1de.pt', 'encoder_layers': 4, 'decoder_layers': 1},
    {'name': '6en1de', 'file': '6en1de.pt', 'encoder_layers': 6, 'decoder_layers': 1},
]

# --- Reconstruction / visualisation settings ---------------------------------
MASKING_RATIO = 0.9        # passed to the model config as 'random_mask_prob'
CONTRAST_FACTOR = 15.0     # scales reconstructed patches around their mean
BRIGHTNESS_FACTOR = 5.0    # subtracted from the contrast-adjusted patches
SAVE_IMAGES = True         # write PNGs to OUTPUT_IMAGE_FOLDER when True


print_info("--- Starting Reconstruction Comparison Process ---")

# Base experiment configuration shared by the dataset setup and every model.
base_cfg = OmegaConf.load(BASE_CONFIG_PATH)

# Build the fairseq task and dataset from a copy of the base config that
# points at the local data directory and the requested split.
temp_cfg = base_cfg.copy()
temp_cfg.merge_with(
    dict(
        task=dict(data=DATA_DIRECTORY),
        dataset=dict(valid_subset=DATASET_SPLIT),
    )
)
task = fairseq.tasks.setup_task(temp_cfg.task)
task.load_dataset(temp_cfg.dataset.valid_subset)
dataset = task.dataset(temp_cfg.dataset.valid_subset)

# Pull out the single sample that all models will reconstruct.
sample = dataset[SAMPLE_INDEX]
original_spectrogram_tensor = sample['source']

# Zero-pad the end of the first axis up to the next multiple of patch_size
# (presumably the time axis - TODO confirm against the dataset).
patch_size = 16
if original_spectrogram_tensor.size(0) % patch_size:
    padding_needed = patch_size - original_spectrogram_tensor.size(0) % patch_size
    original_spectrogram_tensor = F.pad(
        original_spectrogram_tensor,
        (0, 0, 0, padding_needed),
        mode="constant",
        value=0,
    )

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Add a leading batch dimension; the padding mask is all False, i.e. no
# position along the first axis is marked as padding.
batch_tensor = original_spectrogram_tensor[None].to(device)
padding_mask = torch.zeros(
    (1, original_spectrogram_tensor.size(0)), dtype=torch.bool, device=device
)

if SAVE_IMAGES:
    os.makedirs(OUTPUT_IMAGE_FOLDER, exist_ok=True)
    # The value range of the original is reused for every saved image so the
    # colour scales are directly comparable.
    norm_min = original_spectrogram_tensor.min().item()
    norm_max = original_spectrogram_tensor.max().item()
    save_with_colormap(
        data_tensor=original_spectrogram_tensor[:992],
        filename=os.path.join(OUTPUT_IMAGE_FOLDER, f"{SAMPLE_INDEX}-original.png"),
        norm_min=norm_min,
        norm_max=norm_max,
    )

print_info("\n--- Starting Model Comparison Loop ---")

# The patch geometry and the spectrogram size are the same for every model
# (the per-model overrides below only touch the mask probability and the layer
# counts), so the Fold layer that stitches patches back into a 2-D image is
# built exactly once here. This replaces the previous `'folder' not in locals()`
# check, which at cell level inspects the kernel's global namespace and could
# silently reuse a stale Fold left over from an earlier execution.
p_c = base_cfg.model.ast_kernel_size_chan
p_t = base_cfg.model.ast_kernel_size_time
h, w = original_spectrogram_tensor.shape
folder = torch.nn.Fold(output_size=(h, w), kernel_size=(p_c, p_t), stride=(p_c, p_t))

is_first_run = True  # Flag to save the masked image only once

for model_info in MODELS_TO_TEST:
    model_name = model_info['name']
    print_info(f"\n--- Processing model: {model_name} ---")

    # Per-model config: base config plus the architecture/masking overrides.
    cfg = base_cfg.copy()
    cfg.merge_with({
        'task': {'data': DATA_DIRECTORY},
        'model': {
            'random_mask_prob': MASKING_RATIO,
            'encoder_layers': model_info['encoder_layers'],
            'decoder_layers': model_info['decoder_layers'],
        }
    })
    model = task.build_model(cfg.model)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, model_info['file'])
    print_info(f"Loading checkpoint: {checkpoint_path}")
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    # Seed right before the forward pass rather than before build_model():
    # weight initialisation draws a model-size-dependent amount of random
    # numbers, so seeding earlier leaves each model with a different RNG state
    # when the mask is generated. Seeding here gives every model the same RNG
    # state at mask-generation time, which is what the "save the masked image
    # only once" logic below relies on.
    set_random_seed(RANDOM_SEED)

    print_info("Running model forward pass to generate mask and reconstruct...")
    with torch.no_grad():
        model_output = model(
            source=batch_tensor,
            padding_mask=padding_mask,
            mask=True
        )

    # Boolean-style index of the masked patches, and the original spectrogram
    # cut into patches with the model's own unfold so both use the same layout.
    mask_indices = model_output['mask_indices'].squeeze(0)
    all_patches = model.unfold(batch_tensor.unsqueeze(1)).squeeze(0).transpose(0, 1)

    if is_first_run and SAVE_IMAGES:
        print_info("Saving masked image (only once)...")
        # Visualise the mask by painting masked patches with the minimum value.
        masked_patches = all_patches.clone()
        masked_patches[mask_indices] = torch.min(all_patches)
        masked_data = masked_patches.transpose(0, 1).unsqueeze(0)
        masked_tensor = folder(masked_data).squeeze()
        save_with_colormap(
            masked_tensor[:992],
            os.path.join(OUTPUT_IMAGE_FOLDER, f"{SAMPLE_INDEX}-{MASKING_RATIO}-{model_name}-masked.png"),
            norm_min, norm_max
        )
        is_first_run = False

    # Stretch the reconstructed patches around their mean (contrast) and shift
    # them down (brightness) purely for visualisation purposes.
    reconstructed_patches = model_output['logit_m_list_recon'].squeeze(0)
    mean_p = torch.mean(reconstructed_patches)
    adjusted_recons_patches = mean_p + CONTRAST_FACTOR * (reconstructed_patches - mean_p)
    adjusted_recons_patches = adjusted_recons_patches - BRIGHTNESS_FACTOR

    # Unmasked patches come from the original; masked ones from the model.
    recon_patches = all_patches.clone()
    recon_patches[mask_indices] = adjusted_recons_patches
    recon_data = recon_patches.transpose(0, 1).unsqueeze(0)
    reconstructed_tensor = folder(recon_data).squeeze()

    if SAVE_IMAGES:
        output_filename = os.path.join(OUTPUT_IMAGE_FOLDER, f"{SAMPLE_INDEX}-{MASKING_RATIO}-{model_name}-recon.png")
        save_with_colormap(reconstructed_tensor[:992], output_filename, norm_min, norm_max)

print_info("\n--- All models processed. Comparison complete. ---")

[*] --- Starting Reconstruction Comparison Process ---
[*] Successfully saved image to notebook/reconstructions\29664-original.png
[*] 
--- Starting Model Comparison Loop ---
[*] 
--- Processing model: 2en1de ---
[*] Random seed set to 80298 for reproducibility.
[*] Loading checkpoint: C:\Users\Ali\OneDrive - Georgia Institute of Technology\25-5 Summer\CS 7643 - Deep Learning\_Project\mbari-mae\notebook\downstream\load_model\2en1de.pt
[*] Running model forward pass to generate mask and reconstruct...
[*] Saving masked image (only once)...
[*] Successfully saved image to notebook/reconstructions\29664-0.9-2en1de-masked.png
[*] Successfully saved image to notebook/reconstructions\29664-0.9-2en1de-recon.png
[*] 
--- Processing model: 4en1de ---
[*] Random seed set to 80298 for reproducibility.
[*] Loading checkpoint: C:\Users\Ali\OneDrive - Georgia Institute of Technology\25-5 Summer\CS 7643 - Deep Learning\_Project\mbari-mae\notebook\downstream\load_model\4en1de.pt
[*] Running model forw