In [1]:
import torch

from models_v2 import AVSal_IB4LconAEM_2SP
from Modules import (
    SpherConvLSTMCell,
    SpherConvLSTM_EncoderCell,
    SpherConvLSTM_DecoderCell,
    SphereMaxPool2D
)
from spherenet.sphere_cnn import SphereConv2D
from torch.nn import MaxPool2d, Upsample, Sequential, Linear, BatchNorm1d, LeakyReLU, Dropout, Sigmoid, BatchNorm2d

# Allowlist all classes used in the model
torch.serialization.add_safe_globals([
    AVSal_IB4LconAEM_2SP,
    SpherConvLSTMCell,
    SpherConvLSTM_EncoderCell,
    SpherConvLSTM_DecoderCell,
    SphereConv2D,
    MaxPool2d,
    Upsample,
    Sequential,
    Linear,
    BatchNorm1d,
    LeakyReLU,
    Dropout,
    Sigmoid,
    BatchNorm2d,
    SphereMaxPool2D
])

In [2]:
%pip install -q git+https://github.com/facebookresearch/ImageBind.git || true
%pip install -q timm ftfy einops librosa soundfile torchaudio || true
%pip install -q ftfy transformers sentencepiece || true
print("Install finished (failures are OK if packages already exist).")

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Install finished (failures are OK if packages already exist).


In [3]:
import os
from pathlib import Path
from PIL import Image
import numpy as np
import torch
import torchaudio
import torchvision.transforms as T
import librosa

# Simple AEM: convert waveform -> log-mel -> small CNN -> embedding vector (1024-D to match model)
class SimpleAEM(torch.nn.Module):
    def __init__(self, emb_dim=1024, n_mels=64):
        super().__init__()
        self.n_mels = n_mels
        self.emb_dim = emb_dim
        # small conv stack
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1,1)),
        )
        self.fc = torch.nn.Linear(64, emb_dim)
    def forward(self, waveform, sr=16000):
        # waveform: (samples,) or (1, samples)
        if waveform.dim()==1:
            wav = waveform.unsqueeze(0)
        else:
            wav = waveform
        wav = wav.float()
        # compute log-mel with librosa for robustness
        wav_np = wav.cpu().numpy()[0]
        mel = librosa.feature.melspectrogram(y=wav_np, sr=sr, n_mels=self.n_mels, fmin=20, fmax=sr//2)
        log_mel = librosa.power_to_db(mel, ref=np.max)
        # normalize
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-9)
        x = torch.from_numpy(log_mel).unsqueeze(0).unsqueeze(0)  # (1,1,n_mels,frames)
        x = x.to(wav.device if hasattr(wav,'device') else torch.device('cpu'))
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out.squeeze(0)

def get_image_embedder(device):
    from imagebind.models import ImageBindModel
    from imagebind import data, models
    import torchvision.transforms as T

    model = ImageBindModel(device=device)
    model.eval()

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])

    def embed_images(imgs):
        xs = torch.stack([transform(img) for img in imgs]).to(device)
        with torch.no_grad():
            out = model.forward({"vision": xs})["vision"]
        return out.cpu()

    return embed_images, model, "ImageBind"


In [4]:
# User parameters: update these paths to your frames folder and wav file
frames_folder = r'./TestVid'  # replace with your frames folder path
wav_path = r'./TestVid/5020.wav'   # replace with your .wav file path
num_frames = 60
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load frames (sorted) and keep first num_frames
def load_first_n_frames(folder, n):
    p = Path(folder)
    files = sorted([x for x in p.iterdir() if x.suffix.lower() in ['.jpg','.jpeg','.png']])[:n]
    imgs = [Image.open(str(f)).convert('RGB') for f in files]
    return imgs, files

imgs, img_files = load_first_n_frames(frames_folder, num_frames)
print(f'Loaded {len(imgs)} frames from', frames_folder)

# Load waveform (mono) at 16kHz
def load_wav(path, target_sr=16000):
    wav, sr = librosa.load(path, sr=target_sr, mono=True)
    wav_t = torch.from_numpy(wav)
    return wav_t, target_sr

if not Path(wav_path).exists():
    print('Warning: wav file not found at', wav_path)
    wav = None
    sr = 16000
else:
    wav, sr = load_wav(wav_path)
    print('Loaded wav:', wav_path, 'sr=', sr, 'length (s)=', len(wav)/sr)

# Instantiate AEM (1024-D) and embedder
aem_processor = SimpleAEM(emb_dim=1024, n_mels=64).to(device)
embed_images, image_model, image_backend = get_image_embedder(device)
print(f'Using {image_backend} for image embeddings')

# Convert PIL images to tensor frames (B, T, C, H, W) format
tf_frames = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
frame_tensors = torch.stack([tf_frames(img) for img in imgs])  # (T, C, H, W)
input_frames = frame_tensors.unsqueeze(0).to(device)  # (1, T, C, H, W) - batch of 1
print('Input frames tensor shape:', input_frames.shape)

# Compute image embeddings (for the first num_frames)
with torch.no_grad():
    image_embs = embed_images(imgs)  # Tensor (T, D)
    print('Image embeddings shape:', image_embs.shape)
    image_embs = image_embs.unsqueeze(0).to(device)  # (1, T, D)

# Compute audio embedding for the whole wav using AEM
if wav is not None:
    aem_processor.eval()
    wav = wav.to(device)
    with torch.no_grad():
        audio_emb_single = aem_processor(wav, sr)  # (1024,)
    print('Audio embedding shape:', audio_emb_single.shape)
    # Expand to batch: (1, 1024)
    audio_emb = audio_emb_single.unsqueeze(0).to(device)
else:
    audio_emb = torch.zeros(1, 1024).to(device)
    print('No audio file; using zero audio embedding')

# Load pretrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AVSal_IB4LconAEM_2SP(input_dim=3, hidden_dim=18, output_dim=1, option='train')
model = torch.load('AViSal360_train_fold4.pth', map_location=device)
model.to(device)
model.eval()

print('Model loaded. Running inference...')
print(f'Model forward signature: forward(img, aem, emb)')
print(f'  img shape: {input_frames.shape}')
print(f'  aem shape: {audio_emb.shape}')
print(f'  emb shape: {image_embs.shape}')

out = None
try:
    with torch.no_grad():
        out = model(input_frames, audio_emb, image_embs)
    print('Model forward pass successful!')
except Exception as e:
    print(f'Model forward pass failed with error: {e}')
    import traceback
    traceback.print_exc()

print('Model output:', type(out), getattr(out, 'shape', None))

# Visualize and save saliency maps
if out is not None and isinstance(out, torch.Tensor):
    import matplotlib.pyplot as plt
    from matplotlib import cm
    
    out_np = out.numpy()
    print('Output shape:', out_np.shape)
    
    # Create output directory for saliency maps
    output_dir = Path('saliency_maps')
    output_dir.mkdir(exist_ok=True)
    
    # Determine if output is per-frame or single map
    # Typically: (1, 1, H, W) or (B, C, H, W)
    if out_np.ndim == 4:
        if out_np.shape[0] == 1 and out_np.shape[1] == 1:
            # Single saliency map (B=1, C=1, H, W)
            print('Single saliency map output')
            saliency = out_np[0, 0]
            
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))
            axes[0].imshow(imgs[0])
            axes[0].set_title('First Frame')
            axes[0].axis('off')
            
            im = axes[1].imshow(saliency, cmap='hot')
            axes[1].set_title('Saliency Map')
            axes[1].axis('off')
            plt.colorbar(im, ax=axes[1])
            
            plt.tight_layout()
            plt.savefig(output_dir / 'saliency_map_combined.png', dpi=100, bbox_inches='tight')
            print(f'Saved combined view to {output_dir / "saliency_map_combined.png"}')
            plt.close()
        else:
            # Per-frame saliency maps (B, C, H, W) where B=num_frames
            print(f'Per-frame saliency maps: {out_np.shape}')
            num_saliency_frames = out_np.shape[0]
            
            for frame_idx in range(num_saliency_frames):
                saliency = out_np[frame_idx, 0] if out_np.shape[1] > 1 else out_np[frame_idx]
                
                fig, axes = plt.subplots(1, 2, figsize=(12, 5))
                
                # Original frame
                if frame_idx < len(imgs):
                    axes[0].imshow(imgs[frame_idx])
                axes[0].set_title(f'Frame {frame_idx}')
                axes[0].axis('off')
                
                # Saliency map with colorbar
                im = axes[1].imshow(saliency, cmap='hot')
                axes[1].set_title(f'Saliency Map {frame_idx}')
                axes[1].axis('off')
                plt.colorbar(im, ax=axes[1])
                
                plt.tight_layout()
                plt.savefig(output_dir / f'saliency_map_{frame_idx:04d}.png', dpi=100, bbox_inches='tight')
                plt.close()
                
                if (frame_idx + 1) % 10 == 0:
                    print(f'Saved saliency map {frame_idx + 1}/{num_saliency_frames}')
            
            print(f'Saved all saliency maps to {output_dir}/')
    
    # Also save raw output tensor
    np.save(output_dir / 'test_output.npy', out_np)
    print(f'Saved output tensor to {output_dir / "test_output.npy"}')
else:
    print('No tensor output to save')

Loaded 60 frames from ./TestVid
Loaded wav: ./TestVid/5020.wav sr= 16000 length (s)= 60.0


  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'ImageBindModel' from 'imagebind.models' (c:\Users\mahd\AppData\Local\Programs\Python\Python311\Lib\site-packages\imagebind\models\__init__.py)