In [None]:

from model import *

model = train_model(epochs=10, batch_size=8, learning_rate=0.001, num_workers=8, data_dir='./processed_data')

In [None]:
model.save("Unet_model_case2")
model.eval()

In [7]:
from unet import *
from dataset_loader import *

def resize_spectrogram(spectrogram):
    # Convert input list or numpy array to tensor if needed
    if not torch.is_tensor(spectrogram):
        spectrogram = torch.tensor(spectrogram, dtype=torch.float32)
    
    # Ensure it's 4D: (Batch, Channel, Height, Width)
    # The spectrogram from SpectogramEncoder.encode is (Freq, Time).
    # We want 1 batch, 1 channel.
    if spectrogram.ndim == 2:
        spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
    elif spectrogram.ndim == 3:
        spectrogram = spectrogram.unsqueeze(0)
        
    # Resize to (256, 256)
    # Note: F.interpolate expects float input
    resized_spectrogram = F.interpolate(spectrogram, size=(256, 256), mode='bilinear', align_corners=False)
    return resized_spectrogram


model = UNetSpectrogramTranslator()
model.load("Unet_model_cvss_c.pth")
model.eval()
data_dict = dataset_loader.load_data(start_idx=0, num_samples=100, split="test", lang="en")
sample = data_dict['en'][8]
play_audio(sample)

Model loaded from Unet_model_cvss_c.pth
Buffering data into memory...
Buffered 100 samples (0.03 MB). Converting to Dataset...


In [8]:
print(f"Sample ID: {sample['id']}")

# Encode
print("Encoding audio...")
encoder = SpectogramEncoder()
raw_audio = np.array(sample['audio']['array'])
spectrogram = encoder.encode(raw_audio, sample['audio']['sampling_rate'])
print(f"Original shape: {spectrogram.shape}")
# Debug: Print stats of INPUT spectrogram
print(f"Input Spectrogram Stats: min={spectrogram.min():.4f}, max={spectrogram.max():.4f}, mean={spectrogram.mean():.4f}")

# Capture original dimensions
original_height, original_width = spectrogram.shape

print("Resizing spectrogram...")
# Ensure normalization happens here if it's not in resize_spectrogram (it IS NOT in the helper function in case2.ipynb cell 4)
# Wait, in the notebook cell 4 function 'resize_spectrogram', it just resizes. 
# But in UNET.py, the dataset loader DOES normalization: (spec + 80)/80.
# WE MUST DO NORMALIZATION HERE MANUALLY BEFORE INFERENCE!
input_tensor = resize_spectrogram(spectrogram)

# NORMALIZE INPUT
print("Normalizing Input... ((x+80)/80)")
input_tensor = (input_tensor + 80.0) / 80.0

# Move input to Same Device as Model
device = next(model.parameters()).device
input_tensor = input_tensor.to(device)

print("Running inference...")
with torch.no_grad():
    output_tensor = model(input_tensor)

# Resize Back to Original Dimensions
print("Restoring original size...")
output_tensor = F.interpolate(
    output_tensor, 
    size=(original_height, original_width), 
    mode='bilinear', 
    align_corners=False
)

# Decode
print("Decoding output...")
output_spec = output_tensor.squeeze().cpu().numpy()

print(f"Raw Model Output Stats: min={output_spec.min():.4f}, max={output_spec.max():.4f}, mean={output_spec.mean():.4f}")

# SAFETY: Replace NaNs/Infs
if not np.isfinite(output_spec).all():
    print("Warning: Model output contains NaNs or Infs. replacing...")
    output_spec = np.nan_to_num(output_spec, nan=0.5, posinf=1.0, neginf=0.0)

# Denormalize (Reverse the (x+80)/80 done in training)
print("Denormalizing Output... ((x*80)-80)")
output_spec = (output_spec * 80.0) - 80.0

print(f"Final Spectrogram Stats (dB): min={output_spec.min():.4f}, max={output_spec.max():.4f}, mean={output_spec.mean():.4f}")

reconstructed_audio = encoder.decode(output_spec, sample['audio']['sampling_rate'])
rec_audio = Audio(data=reconstructed_audio, rate=sample['audio']['sampling_rate'])
rec_audio

Sample ID: 1846
Encoding audio...
Original shape: (1025, 346)
Input Spectrogram Stats: min=-80.0000, max=0.0000, mean=-58.2143
Resizing spectrogram...
Normalizing Input... ((x+80)/80)
Running inference...
Restoring original size...
Decoding output...
Raw Model Output Stats: min=-0.0072, max=0.6833, mean=0.2916
Denormalizing Output... ((x*80)-80)
Final Spectrogram Stats (dB): min=-80.5746, max=-25.3340, mean=-56.6729
