In [1]:
import clip
from PIL import Image
import torch
from torchvision import transforms
import wav2clip
from typing import Dict
import torch.nn as nn
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import torch
from pprint import pprint

# 1) Loading a tensor or generic Python object
pretrained_weights = torch.load("Wav2CLIP.pt")

# image_transform_obj = {k: v for k, v in obj.items() if "image_transform" in k}
# print(obj['audio_transform.sequential.3.weight'])
# pprint([el for el in image_transform_obj.keys()])

# Load image

In [3]:
image_path = "images\cat.jpg"
image = Image.open(image_path).convert("RGB")
image = image.resize((512, 512))
# Preprocess: convert to tensor normalized in [-1, 1]
image_tensor = (
    transforms.ToTensor()(image).unsqueeze(0).to(device)
)  # shape [1,3,H,W]
image_tensor = 2.0 * image_tensor - 1.0  # scale from [0,1] to [-1,1]

  image_path = "images\cat.jpg"


# Get CLIP embeddings

In [4]:
def get_preprocessing_for_clip():
    return transforms.Compose([
        # 1) Resize shorter edge to target_size, keep aspect ratio:
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
        
        # 2) Center‐crop to exactly (target_size, target_size):
        transforms.CenterCrop(224),
        
        # 3) Normalize per‐channel:
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ])

In [5]:
# Load CLIP model
clip_model_vit, _ = clip.load("ViT-B/32", device=device)
clip_model_res, _ = clip.load("RN50", device=device)

In [6]:
clip_preprocessing = get_preprocessing_for_clip()
preprocessed_image_tensor = clip_preprocessing(image_tensor)
preprocessed_image_tensor.shape

torch.Size([1, 3, 224, 224])

In [7]:
image_embedding_vit = clip_model_vit.encode_image(preprocessed_image_tensor)
image_embedding_vit.shape

torch.Size([1, 512])

In [8]:
image_embedding_res = clip_model_res.encode_image(preprocessed_image_tensor)
image_embedding_res.shape

torch.Size([1, 1024])

# Wav2CLIP

In [9]:
audio_path = "audio/1-100032-A-0.wav"
import librosa

wav2clip_model = wav2clip.get_model()
wav, sr = librosa.load(audio_path, sr=48000, mono=True)
audio_tensor = wav
audio_emb_batch = wav2clip.embed_audio(audio_tensor, wav2clip_model)
audio_emb_batch.shape

(1, 512)

# Compare

In [10]:
print(torch.nn.CosineSimilarity(dim=-1)(image_embedding_vit, torch.from_numpy(audio_emb_batch).to(device)))
# print(torch.nn.CosineSimilarity(dim=-1)(image_embedding_res, torch.from_numpy(audio_emb_batch).to(device)))

tensor([0.0768], device='cuda:0', grad_fn=<SumBackward1>)


In [11]:
class MLPLayers(nn.Module):
    def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
        super(MLPLayers, self).__init__()
        self.nonlin = nonlin
        self.dropout = dropout

        sequence = []
        for u0, u1 in zip(units[:-1], units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)
            sequence.append(nn.Dropout(self.dropout))
        sequence = sequence[:-2]

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        X = self.sequential(X)
        return X

In [12]:
def filter_and_strip_state_dict(
    state_dict: Dict[str, torch.Tensor],
    prefix: str
) -> Dict[str, torch.Tensor]:
    """
    Keep only items where key starts with prefix, then strip the prefix.
    Example:
      'image_transform.sequential.0.weight' -> 'sequential.0.weight'
    """
    filtered = {}
    for key, value in state_dict.items():
        if key.startswith(prefix + '.'):
            new_key = key[len(prefix) + 1:]
            filtered[new_key] = value
    return filtered

def load_mlp_weights(
    mlp: nn.Module,
    state_dict: Dict[str, torch.Tensor],
    prefix: str
) -> None:
    """
    Load weights from a flat state_dict into the MLP.

    Args:
      mlp: Instance of MLPLayers
      state_dict: Dict mapping prefixed keys to tensors
      prefix: Prefix to filter keys (e.g., 'image_transform')
    """
    # Extract relevant weights
    mlp_weights = filter_and_strip_state_dict(state_dict, prefix)

    # Load into model's state dict
    model_dict = mlp.state_dict()
    model_dict.update(mlp_weights)
    mlp.load_state_dict(model_dict)

In [13]:
# Instantiate MLP and load weights
image_transform = MLPLayers(units=[512, 512, 512]).to(device)
load_mlp_weights(image_transform, pretrained_weights, prefix='image_transform')
image_transform.eval()
# # Verify loaded state
# for name, param in image_transform.named_parameters():
#     print(f"{name}: {param.shape}")

MLPLayers(
  (nonlin): ReLU()
  (sequential): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [14]:
# Instantiate MLP and load weights
audio_transform = MLPLayers(units=[512, 512, 512]).to(device)
load_mlp_weights(audio_transform, pretrained_weights, prefix='audio_transform')
audio_transform.eval()

MLPLayers(
  (nonlin): ReLU()
  (sequential): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [15]:
image_embedding = image_embedding_vit.float()
audio_embedding = torch.from_numpy(audio_emb_batch).to(device)

In [16]:
transformed_image = image_transform(image_embedding)
transformed_audio = audio_transform(audio_embedding)

In [17]:
def compute_loss(image, audio):
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    

print(torch.nn.CosineSimilarity(dim=-1)(image_embedding, audio_embedding))
print(torch.nn.CosineSimilarity(dim=-1)(transformed_image, audio_embedding))
print(torch.nn.CosineSimilarity(dim=-1)(transformed_image, transformed_audio))

tensor([0.0768], device='cuda:0', grad_fn=<SumBackward1>)
tensor([-0.0967], device='cuda:0', grad_fn=<SumBackward1>)
tensor([-0.0080], device='cuda:0', grad_fn=<SumBackward1>)


In [18]:
class CLIPLoss1D(nn.Module):
    def __init__(self):
        super(CLIPLoss1D, self).__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_image = nn.CrossEntropyLoss()
        self.loss_text = nn.CrossEntropyLoss()

    def forward(self, image_features, text_features):
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        batch_size = image_features.shape[0]
        ground_truth = torch.arange(batch_size, dtype=torch.long, device=device)
        return (
            self.loss_image(logits_per_image, ground_truth)
            + self.loss_text(logits_per_text, ground_truth)
        ) / 2

In [19]:
loss = CLIPLoss1D()
loss(image_embedding, audio_embedding)

tensor(0., device='cuda:0', grad_fn=<DivBackward0>)

In [20]:
image_embedding.shape

torch.Size([1, 512])

In [21]:
loss(torch.ones(size=[4, 512]), torch.ones(size=[4, 512]))

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)