In [None]:
import torch
import torchvision.models as models
import timm

# Utility: freeze params
def freeze_model(model):
    """Turn off gradient tracking for every parameter of `model`.

    The module is modified in place and the same object is returned, so the
    call can wrap a model constructor directly.
    """
    # requires_grad_ walks model.parameters() and returns the module itself.
    return model.requires_grad_(False)

def _load_frozen(label, build, *build_args, **build_kwargs):
    """Build a model, freeze it with freeze_model, report its size, return it.

    `build` is called with the remaining arguments; the parameter count is
    printed in millions under the given `label`.
    """
    net = freeze_model(build(*build_args, **build_kwargs))
    n_params = sum(weight.numel() for weight in net.parameters())
    print(label + " loaded:", n_params/1e6, "M params")
    return net

# 1. ResNet-34
resnet34 = _load_frozen(
    "ResNet-34", models.resnet34,
    weights=models.ResNet34_Weights.IMAGENET1K_V1,
)

# 2. InceptionV3
inception = _load_frozen(
    "InceptionV3", models.inception_v3,
    weights=models.Inception_V3_Weights.IMAGENET1K_V1,
)

# 3. SqueezeNet 1.1
squeezenet = _load_frozen(
    "SqueezeNet", models.squeezenet1_1,
    weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
)

# 4. EfficientNetV2-S (via timm)
efficientnetv2s = _load_frozen(
    "EfficientNetV2-S", timm.create_model,
    "tf_efficientnetv2_s_in21k", pretrained=True,
)

# 5. MobileNetV3-Small
mobilenetv3s = _load_frozen(
    "MobileNetV3-Small", models.mobilenet_v3_small,
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
)

print("\n✅ All models loaded and frozen successfully!")


In [None]:
import torch
from torchvision import transforms
from torch import nn
from PIL import Image

# Assuming models are already loaded and frozen:
# resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s

# Preprocessing transforms
# Per-channel mean/std applied by Normalize in both pipelines (the values
# commonly used with ImageNet-pretrained weights).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

def _center_crop_pipeline(resize_to, crop_to):
    """Resize -> center crop -> tensor -> normalize, as one Compose."""
    steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
    return transforms.Compose(steps)

# 224x224 crop pipeline (resize to 256 first)
preprocess_224 = _center_crop_pipeline(256, 224)

# 299x299 crop pipeline (resize to 320 first)
preprocess_299 = _center_crop_pipeline(320, 299)

def get_embeddings(image_path, device='cpu'):
    """
    Run an image through all five frozen models and return embeddings.

    Parameters
    ----------
    image_path : str or path-like
        Image file to embed. It is opened with PIL and converted to RGB.
    device : str or torch.device, default 'cpu'
        Device used for the forward passes. The module-level models are moved
        to this device (in place) so that weights and inputs always live on
        the same device.

    Returns
    -------
    dict
        Tensors keyed by 'resnet34', 'inceptionv3', 'squeezenet',
        'efficientnetv2s' and 'mobilenetv3_small'. Each tensor has a leading
        batch dimension of 1 and lives on `device`.

    Notes
    -----
    Relies on the module-level models (resnet34, inception, squeezenet,
    efficientnetv2s, mobilenetv3s) and transforms (preprocess_224,
    preprocess_299) being defined already.
    """
    # Load image; the context manager releases the file handle as soon as the
    # pixel data has been copied into the RGB image.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # Put every model on the requested device and in eval mode. Without the
    # .to(device) call, any non-CPU device would fail with a device mismatch
    # between the input tensor and the model weights.
    for net in (resnet34, inception, squeezenet, efficientnetv2s, mobilenetv3s):
        net.to(device)
        net.eval()

    embeddings = {}

    with torch.no_grad():
        # The transforms are deterministic, so each input size is prepared once
        # and shared by every model that consumes it.
        x_224 = preprocess_224(img).unsqueeze(0).to(device)
        x_299 = preprocess_299(img).unsqueeze(0).to(device)

        # ResNet-34 - drop the last child module (the final FC layer)
        resnet_trunk = nn.Sequential(*list(resnet34.children())[:-1])
        embeddings['resnet34'] = torch.flatten(resnet_trunk(x_224), 1)

        # InceptionV3 - temporarily replace the FC layer with Identity so the
        # forward pass returns the features that would feed it. try/finally
        # guarantees the shared model gets its classifier back even if the
        # forward pass raises.
        original_fc = inception.fc
        inception.fc = nn.Identity()
        try:
            embeddings['inceptionv3'] = inception(x_299)
        finally:
            inception.fc = original_fc  # Restore original

        # NOTE(review): the three extractors below flatten the raw feature
        # output with no pooling step, so these vectors are presumably much
        # larger than the ResNet/Inception ones - confirm this is intended.

        # SqueezeNet 1.1 - Use features only
        embeddings['squeezenet'] = torch.flatten(squeezenet.features(x_224), 1)

        # EfficientNetV2-S - Use forward_features method
        # NOTE(review): the timm `tf_` weights may expect a different
        # mean/std and input size than preprocess_224 provides - check the
        # model's pretrained config before relying on these features.
        embeddings['efficientnetv2s'] = torch.flatten(
            efficientnetv2s.forward_features(x_224), 1
        )

        # MobileNetV3-Small - Use features only
        embeddings['mobilenetv3_small'] = torch.flatten(
            mobilenetv3s.features(x_224), 1
        )

    return embeddings

# Smoke test: embed one sample spectrogram and show two embedding shapes
filename = "../Results/Dev/rock.00069/aligned_spectrogram_21.png"
embeds = get_embeddings(filename, device='cpu')
for model_key in ('resnet34', 'squeezenet'):
    print(embeds[model_key].shape)