In [None]:
import torch
import torchvision.models as models
import timm

# Utility: freeze params
def freeze_model(model):
    """
    Disable gradient tracking for every parameter of `model`.

    The module is modified in place and the same object is returned so the
    call can wrap a constructor expression.
    """
    # Module.requires_grad_ flips the flag on each parameter and returns the module.
    model.requires_grad_(False)
    return model

def _param_count_millions(model):
    """Return the total number of parameters in `model`, expressed in millions."""
    return sum(p.numel() for p in model.parameters()) / 1e6


# Each backbone is built with pretrained weights and frozen right away; the
# parameter count is printed as a quick sanity check that loading worked.

# 1. ResNet-34
resnet34 = freeze_model(
    models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
)
print("ResNet-34 loaded:", _param_count_millions(resnet34), "M params")

# 2. InceptionV3
inception = freeze_model(
    models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
)
print("InceptionV3 loaded:", _param_count_millions(inception), "M params")

# 3. SqueezeNet 1.1
squeezenet = freeze_model(
    models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
)
print("SqueezeNet loaded:", _param_count_millions(squeezenet), "M params")

# 4. EfficientNetV2-S (via timm)
efficientnetv2s = freeze_model(
    timm.create_model("tf_efficientnetv2_s_in21k", pretrained=True)
)
print("EfficientNetV2-S loaded:", _param_count_millions(efficientnetv2s), "M params")

# 5. MobileNetV3-Small
mobilenetv3s = freeze_model(
    models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
)
print("MobileNetV3-Small loaded:", _param_count_millions(mobilenetv3s), "M params")

print("\n✅ All models loaded and frozen successfully!")


In [None]:
import torch
from torchvision import transforms
from torch import nn
from PIL import Image

# Assuming models are already loaded and frozen:
# resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s

# Preprocessing transforms
def _center_crop_pipeline(resize_to, crop_to):
    """Build a resize -> center-crop -> tensor -> normalize transform pipeline."""
    return transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.ToTensor(),
        # Same per-channel mean/std for both input sizes.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


# 224x224 crops (resized to 256 first).
preprocess_224 = _center_crop_pipeline(256, 224)

# 299x299 crops (resized to 320 first); this is the pipeline fed to InceptionV3.
preprocess_299 = _center_crop_pipeline(320, 299)

def get_embeddings(image_path, device='cpu'):
    """
    Run an image through all five frozen models and return embeddings.

    Args:
        image_path: Path of the image to embed (anything ``PIL.Image.open`` accepts).
        device: Device used for inference. The five module-level models are
            moved to this device before the forward passes and remain there
            afterwards, so inputs and weights always live on the same device.

    Returns:
        A dictionary of tensors keyed by model name ('resnet34', 'inceptionv3',
        'squeezenet', 'efficientnetv2s', 'mobilenetv3_small'). Each tensor has
        a leading batch dimension of 1 and is flattened to [1, feature_dim].
    """
    # Load image; convert() returns an in-memory copy, so the file handle can
    # be released as soon as the with-block ends.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # Inference mode, and weights on the same device as the inputs
    # (previously only the inputs were moved, which fails for non-CPU devices).
    for net in (resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s):
        net.to(device).eval()

    # Preprocess once per input size instead of once per model.
    x_224 = preprocess_224(img).unsqueeze(0).to(device)
    x_299 = preprocess_299(img).unsqueeze(0).to(device)

    def global_avg_pool(feature_map):
        """Collapse a [B, C, H, W] feature map to [B, C]."""
        pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        return torch.flatten(pooled, 1)

    embeddings = {}

    with torch.no_grad():
        # ResNet-34 - every child module except the final FC layer
        resnet_trunk = torch.nn.Sequential(*list(resnet34.children())[:-1])
        embeddings['resnet34'] = torch.flatten(resnet_trunk(x_224), 1)

        # InceptionV3 - temporarily swap the classifier for Identity so the
        # regular forward pass returns the pre-FC features. The finally block
        # guarantees the shared model gets its head back even if forward fails.
        original_fc = inception.fc
        inception.fc = nn.Identity()
        try:
            embeddings['inceptionv3'] = inception(x_299)
        finally:
            inception.fc = original_fc  # Restore original

        # SqueezeNet 1.1 - convolutional features + global average pooling
        embeddings['squeezenet'] = global_avg_pool(squeezenet.features(x_224))

        # EfficientNetV2-S - forward_features + global average pooling
        # NOTE(review): the timm "tf_efficientnetv2_s_in21k" weights appear to be
        # published with mean/std of 0.5 rather than the ImageNet statistics in
        # preprocess_224 - confirm against the model's pretrained config. Left
        # unchanged here because switching would alter the embedding values.
        embeddings['efficientnetv2s'] = global_avg_pool(
            efficientnetv2s.forward_features(x_224)
        )

        # MobileNetV3-Small - convolutional features + global average pooling
        embeddings['mobilenetv3_small'] = global_avg_pool(mobilenetv3s.features(x_224))

    return embeddings

# Smoke test: embed a single spectrogram and list the shape each model produces
filename = "../Results/Dev/rock.00069/aligned_spectrogram_21.png"
embeddings = get_embeddings(filename, device='cpu')

print(f"{'Model Name':<20} | {'Embedding Shape'}")
print("-" * 30)
for model_name in embeddings:
    embedding = embeddings[model_name]
    print(f"{model_name:<20} | {embedding.shape}")

In [None]:
import glob
import os

def get_aligned_spectrograms(prefix):
    """
    Given a prefix like '../Results/Dev/rock.00069/aligned_spectrogram',
    return all matching files of the form aligned_spectrogram_<n>.png,
    sorted numerically by <n>.

    Files that match ``<prefix>_*.png`` but whose text after the last
    underscore is not an integer (e.g. ``<prefix>_notes.png``) are skipped
    instead of aborting the whole listing with a ValueError. Files with the
    same index are ordered by path so the result is deterministic.

    Returns:
        A list of paths as produced by ``glob.glob``; empty when nothing matches.
    """
    indexed = []
    for path in glob.glob(f"{prefix}_*.png"):
        suffix = os.path.splitext(path)[0].split('_')[-1]
        try:
            indexed.append((int(suffix), path))
        except ValueError:
            # Matched by the glob pattern but not a numbered frame.
            continue
    # Tuples sort by index first, then by path to break ties.
    indexed.sort()
    return [path for _, path in indexed]


# Example: list every numbered spectrogram frame of one development track
prefix = "../Results/Dev/rock.00069/aligned_spectrogram"
filenames = get_aligned_spectrograms(prefix=prefix)
print(filenames)


In [None]:
import torch
from torchvision import transforms
from torch import nn
from PIL import Image

# Assuming models are already loaded and frozen:
# resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s

# Preprocessing transforms
# One Normalize step shared by both pipelines (identical per-channel mean/std).
_shared_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Resize to 256, then center-crop to 224x224.
preprocess_224 = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), _shared_normalize]
)

# Resize to 320, then center-crop to 299x299 (the batch fed to InceptionV3).
preprocess_299 = transforms.Compose(
    [transforms.Resize(320), transforms.CenterCrop(299), transforms.ToTensor(), _shared_normalize]
)


def get_embeddings_batch(filenames, device="cpu"):
    """
    Run a batch of images through all five frozen models and return embeddings.

    Args:
        filenames: Iterable of image paths. It is consumed exactly once, so a
            generator works as well as a list.
        device: Device used for inference. The five module-level models are
            moved to this device before the forward passes and remain there
            afterwards, so inputs and weights always live on the same device.

    Returns:
        A dictionary of tensors keyed by model name ('resnet34', 'inceptionv3',
        'squeezenet', 'efficientnetv2s', 'mobilenetv3_small').
        Each tensor has shape [batch_size, feature_dim].

    Raises:
        ValueError: If `filenames` yields no paths (there is nothing to stack).
    """
    # Load and preprocess images: open each file once, release the handle as
    # soon as the RGB copy exists, and derive both input sizes from that copy.
    imgs_224, imgs_299 = [], []
    for path in filenames:
        with Image.open(path) as raw_img:
            rgb = raw_img.convert("RGB")
        imgs_224.append(preprocess_224(rgb))
        imgs_299.append(preprocess_299(rgb))

    if not imgs_224:
        raise ValueError("get_embeddings_batch() requires at least one image path")

    batch_224 = torch.stack(imgs_224).to(device)  # [B, 3, 224, 224]
    batch_299 = torch.stack(imgs_299).to(device)  # [B, 3, 299, 299]

    # Inference mode, and weights on the same device as the inputs
    # (previously only the inputs were moved, which fails for non-CPU devices).
    for net in (resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s):
        net.to(device).eval()

    def global_avg_pool(feature_map):
        """Collapse a [B, C, H, W] feature map to [B, C]."""
        pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        return torch.flatten(pooled, 1)

    embeddings = {}

    with torch.no_grad():
        # ResNet-34 - every child module except the final FC layer
        resnet_trunk = torch.nn.Sequential(*list(resnet34.children())[:-1])
        embeddings["resnet34"] = torch.flatten(resnet_trunk(batch_224), 1)

        # InceptionV3 - temporarily swap the classifier for Identity so the
        # regular forward pass returns the pre-FC features. The finally block
        # guarantees the shared model gets its head back even if forward fails.
        original_fc = inception.fc
        inception.fc = nn.Identity()
        try:
            embeddings["inceptionv3"] = inception(batch_299)
        finally:
            inception.fc = original_fc

        # SqueezeNet - features + GAP
        embeddings["squeezenet"] = global_avg_pool(squeezenet.features(batch_224))

        # EfficientNetV2-S - forward_features + GAP
        # NOTE(review): the timm "tf_efficientnetv2_s_in21k" weights appear to be
        # published with mean/std of 0.5 rather than the ImageNet statistics in
        # preprocess_224 - confirm against the model's pretrained config. Left
        # unchanged here because switching would alter the embedding values.
        embeddings["efficientnetv2s"] = global_avg_pool(
            efficientnetv2s.forward_features(batch_224)
        )

        # MobileNetV3-Small - features + GAP
        embeddings["mobilenetv3_small"] = global_avg_pool(mobilenetv3s.features(batch_224))

    return embeddings



# Embed the whole list of spectrograms in one batch and report each model's output shape
embeddings = get_embeddings_batch(filenames, device="cpu")

print(f"{'Model Name':<20} | {'Embedding Shape'}")
print("-" * 40)
for model_name in embeddings:
    embedding = embeddings[model_name]
    print(f"{model_name:<20} | {embedding.shape}")
