In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# OpenCV is only used for optional visualization. Don't let a missing install stop
# the rest of the notebook from running on a fresh kernel; `cv2` is None when absent.
try:
    import cv2
except ImportError:
    cv2 = None

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
checkpoint_path = 'xxx.pth'  # Placeholder: point this at a checkpoint file from your checkpoints directory

import segmentation_models_pytorch as smp

# Rebuild the architecture that was trained. The checkpoint's state dict is loaded
# strictly below, so these arguments must match the training configuration.
# encoder_weights=None: every parameter is overwritten by the checkpoint, so fetching
# ImageNet encoder weights first would only cost time and require network access.
model = smp.Unet(
    encoder_name="resnet34",   # backbone architecture (must match training)
    encoder_weights=None,      # weights come from the checkpoint, not ImageNet
    in_channels=1,             # single input channel (grayscale spectrogram)
    classes=1,                 # one output channel: binary segmentation
    activation="sigmoid",      # model outputs probabilities
)
model.to(device)

# Load the trained weights. This is not optional: running on without them would
# silently produce predictions from an untrained decoder, so fail with a clear message.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint not found: {checkpoint_path!r} - set checkpoint_path to a valid .pth file"
    )
# NOTE: torch.load unpickles the file, which can execute arbitrary code; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)
# Training checkpoints store the weights under 'model_state_dict'; a bare state dict is accepted too.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
print("Loaded best model from checkpoint.")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from data_preprocessing.read_csv_file import read_csv_file

# CSV to run inference on. Set the BURST_CSV_PATH environment variable to use another
# file without editing this machine-specific absolute path.
file_path = os.environ.get(
    'BURST_CSV_PATH',
    '/Users/remiliascarlet/Desktop/MDP/transfer_learning/burst_data/csv/done/240420202002-Peach Mountain.csv',
)

# NOTE: on_bad_lines='skip' silently drops malformed rows, so those time samples are
# simply missing from the spectrogram rather than reported.
df = pd.read_csv(file_path, on_bad_lines='skip')

# Layout: 'Date' and 'Time' columns, then one column per frequency channel whose
# header is the (integer) frequency value.
dates = df['Date']
times = df['Time']
frequency = df.columns[2:].astype(int)
data = df.iloc[:, 2:]

# Transpose to (frequency channels x time samples) and reverse the row order.
# `frequency` itself is NOT reversed: row i of data_correct_shape corresponds to frequency[::-1][i].
data_correct_shape = data.T[::-1]
print("Spectrogram shape (frequency channels x time samples):", data_correct_shape.shape)

In [None]:
from data_preprocessing.data_slicing import SpectrogramSlicer

tile_size = 256        # side length of the square tiles fed to the model
overlap_ratio = 0.25   # fractional overlap between neighbouring tiles

# Size the slicer from tile_size so the tiles always agree with the
# reconstruct_mask(..., tile_size) call later on. The step between tiles is
# determined by overlap_ratio inside the slicer, so no separate stride constant is kept.
slicer = SpectrogramSlicer(
    target_size=(tile_size, tile_size),
    overlap_ratio=overlap_ratio,
    random_offset=False,
)
tiles, positions = slicer.slice_entire_spectrogram(data_correct_shape)

# Summarize instead of dumping every position (the list grows with spectrogram length).
print("Number of tiles:", len(tiles), "| number of positions:", len(positions))

In [None]:
model.eval()  # Set model to evaluation mode

PRED_BATCH_SIZE = 16   # tiles per forward pass; lower this if the GPU runs out of memory
MASK_THRESHOLD = 0.5   # probability cut-off for the binary mask (model ends in a sigmoid)

# Stack tiles into one contiguous float32 array. The model weights are float32, so
# float64/int tiles (e.g. straight from a pandas DataFrame) would fail in the first
# convolution. unsqueeze(1) adds the single channel axis: (N, H, W) -> (N, 1, H, W).
tiles_tensor = torch.from_numpy(np.ascontiguousarray(tiles, dtype=np.float32)).unsqueeze(1)
if tiles_tensor.shape[0] == 0:
    raise ValueError("No tiles to predict on - check the slicing step.")

# Run predictions in mini-batches so GPU memory does not grow with spectrogram length.
# Tiles stay on the CPU until their batch is sent to `device`, and each batch's output
# is moved straight back, so only one batch is resident on the device at a time.
loader = DataLoader(TensorDataset(tiles_tensor), batch_size=PRED_BATCH_SIZE, shuffle=False)
batch_outputs = []
with torch.no_grad():
    for (batch,) in loader:
        batch_outputs.append(model(batch.to(device)).cpu())
preds = torch.cat(batch_outputs)

# Threshold probabilities into binary masks (1.0 = predicted burst pixel).
binary_preds = (preds > MASK_THRESHOLD).float()

binary_preds_np = binary_preds.numpy()
print("Predicted masks shape:", binary_preds_np.shape)

In [None]:
from prediction.prediction_utils import reconstruct_mask
from data_preprocessing.data_label import apply_morphological_operations, apply_rolling_median_filter

# Stitch the per-tile binary predictions back into a single mask, using the tile
# positions from the slicer and the shape of the full spectrogram.
reconstructed_mask = reconstruct_mask(
    binary_preds_np,
    positions,
    data_correct_shape.shape,
    tile_size,
)
print("Reconstructed mask shape:", reconstructed_mask.shape)

# Post-processing: morphological clean-up (operation_sequence lists erode, then dilate),
# followed by a rolling median filter.
# NOTE(review): erosion_radius=20 vs dilation_radius=1 is strongly asymmetric rather than
# a standard opening - confirm this is the intended post-processing.
morph_mask = apply_morphological_operations(
    reconstructed_mask,
    erosion_radius=20,
    dilation_radius=1,
    operation_sequence=['erode', 'dilate'],
)
final_mask = apply_rolling_median_filter(
    morph_mask,
    window_size=5,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison: input spectrogram, raw stitched prediction, post-processed mask.
panels = [
    (data_correct_shape, "Original Spectrogram"),
    (reconstructed_mask, "Predicted Burst Mask"),
    (final_mask, "Final Burst Mask"),
]

fig, axes = plt.subplots(1, 3, figsize=(12, 6))
for ax, (image, title) in zip(axes, panels):
    # aspect='auto' stretches each image to fill its panel. With the default square
    # pixels, a spectrogram much wider than it is tall renders as a thin strip.
    ax.imshow(image, cmap='gray', aspect='auto')
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()