In [None]:
import shutil
import os

import argparse
import yaml
import torch

from audioldm_train.utilities.data.dataset import MusicDataset

from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from audioldm_train.utilities.tools import get_restore_step
from audioldm_train.utilities.model_util import instantiate_from_config
from audioldm_train.utilities.tools import build_dataset_json_from_list

from sklearn.metrics import top_k_accuracy_score

In [None]:
class Dummy(object):
    """Bare attribute container used in place of a parsed command-line namespace."""


# Notebook "arguments": the experiment config and the checkpoint to evaluate.
args = Dummy()
args.config_yaml = "audioldm_train/config/2023_08_23_reproduce_audioldm/audioldm_original_medium_with_clip_clap_music.yaml"
args.reload_from_ckpt = "data/checkpoints/audioldm-m-full_new.ckpt"

# Later cells move the models and tensors to "cuda" unconditionally, so fail early without a GPU.
assert torch.cuda.is_available(), "CUDA is not available"

config_yaml_path = args.config_yaml
# Experiment name = config file name without its extension. splitext() strips only the final
# extension, so a dot elsewhere in the path (e.g. "./configs/x.yaml") cannot truncate the name.
exp_name = os.path.splitext(os.path.basename(config_yaml_path))[0]
# Experiment group = name of the directory that contains the config file.
exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))

# Read the config inside a context manager so the file handle is closed deterministically.
# NOTE(review): FullLoader is kept as in the original; only load config files you trust.
with open(config_yaml_path, "r") as config_file:
    config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

# An explicitly supplied checkpoint overrides whatever the config file specifies.
if args.reload_from_ckpt is not None:
    config_yaml["reload_from_ckpt"] = args.reload_from_ckpt
dataset_json = None
configs = config_yaml

In [None]:
# Seed all RNGs from the config, falling back to 0 when the config has no "seed" entry.
if "seed" in configs:
    seed_everything(configs["seed"])
else:
    print("SEED EVERYTHING TO 0")
    seed_everything(0)

if "precision" in configs:
    torch.set_float32_matmul_precision(configs["precision"])

log_path = configs["log_directory"]

# Optional extra dataloader features; default to none when the config does not list any.
dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

# NOTE(review): the dataset is named "val" but is built from split="train" — confirm this is
# intentional before interpreting the accuracies below as validation numbers.
val_dataset = MusicDataset(
    configs, split="train", add_ons=dataloader_add_ons, dataset_json=dataset_json
)

# No shuffle argument is passed, so batches follow dataset order.
val_loader = DataLoader(
    val_dataset,
    batch_size=100,
)

# A missing key simply means "no checkpoint configured"; dict.get expresses that without a
# bare except that would also hide unrelated errors.
config_reload_from_ckpt = configs.get("reload_from_ckpt")

resume_from_checkpoint = config_reload_from_ckpt
print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)

# Fail fast with a clear message instead of an opaque error from torch.load(None) after the
# (expensive) model has already been instantiated.
if resume_from_checkpoint is None:
    raise ValueError(
        "No checkpoint to load: set args.reload_from_ckpt or 'reload_from_ckpt' in the config"
    )

latent_diffusion = instantiate_from_config(configs["model"])
latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

# Evaluation hyper-parameters read from the config (kept as notebook globals).
evaluation_params = configs["model"]["params"]["evaluation_params"]
guidance_scale = evaluation_params["unconditional_guidance_scale"]
ddim_sampling_steps = evaluation_params["ddim_sampling_steps"]
n_candidates_per_samples = evaluation_params["n_candidates_per_samples"]

# The model is still on the CPU when the weights are copied in, so load the checkpoint onto the
# CPU as well; this avoids depending on the device the checkpoint was saved from.
checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")

# Try an exact state-dict match first; on failure report the error and retry non-strictly.
try:
    latent_diffusion.load_state_dict(checkpoint["state_dict"])
except Exception as e:
    print(e)
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

latent_diffusion.eval()
latent_diffusion = latent_diffusion.cuda()

In [None]:
# Create an iterator over val_loader so batches can be pulled one at a time with next().
data_iter = iter(val_loader)

In [None]:
# Pull the next batch from the iterator. Re-running this cell advances to a later batch, so
# everything computed from `data` below refers to whichever batch was fetched last.
data = next(data_iter)

In [None]:
# Take the first conditioning-stage model of the diffusion model; the cells below call its
# three_modal_contrastive_loss method.
# NOTE(review): presumably the joint CLIP/CLAP encoder named in the config — confirm index 0.
clip_clap = latent_diffusion.cond_stage_models[0]

In [None]:
# Display the raw batch for inspection (this prints the full contents of the batch).
data

In [None]:
# Run the three-modality contrastive evaluation on the current batch. This is inference only,
# so disable autograd (as the CLIP cells in this notebook do) to avoid building a graph for the
# whole batch on the GPU.
with torch.no_grad():
    results = clip_clap.three_modal_contrastive_loss(
        dict(
            image=data["image"].to(memory_format=torch.contiguous_format, device="cuda"),
            audio=data["waveform"]
            .to(memory_format=torch.contiguous_format, device="cuda")
            .float(),
            text=list(data["text"]),
        )
    )

In [None]:
from sklearn.metrics import top_k_accuracy_score

# Unpack the ground-truth labels and the three pairwise probability matrices
# (image->audio, image->text, audio->text) returned by the contrastive evaluation.
labels = results["labels"]
i2a_probs, i2t_probs, a2t_probs = (
    results[key] for key in ("i2a_probs", "i2t_probs", "a2t_probs")
)

# NOTE(review): the variables are called *_top3 but the scores are computed with k=1
# (top-1 accuracy). The names are kept because a later cell reads them.
i2a_top3, i2t_top3, a2t_top3 = (
    top_k_accuracy_score(labels, pair_probs, k=1)
    for pair_probs in (i2a_probs, i2t_probs, a2t_probs)
)

In [None]:
# Show the image->audio, image->text and audio->text scores computed in the previous cell
# (despite the "top3" names they were computed with k=1 there).
(i2a_top3, i2t_top3, a2t_top3)

In [None]:
import torch
import clip
from PIL import Image

# `device` is not used for loading in this cell; a later cell reads it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Baseline: a stock CLIP ViT-B/16 model plus its matching image preprocessing transform.
# NOTE(review): the weights are loaded on the CPU and moved to the GPU afterwards, unlike the
# later cells that pass `device` straight to clip.load — confirm this difference is intended.
model, preprocess = clip.load("ViT-B/16", device="cpu")
model = model.eval().to("cuda")

In [None]:
# Prepare the current batch for stock CLIP: tokenize the captions and move both the tokens and
# the batch's image tensor to the GPU.
text = clip.tokenize(data["text"]).to("cuda")
image = data["image"].to("cuda")

In [None]:
# Score every image in the batch against every caption with stock CLIP.
# `image` and `text` were already moved to the GPU in the previous cell, so no further
# transfers are needed here (the stray bare `model` expression was a no-op and is removed).
with torch.no_grad():
    # Kept as notebook globals for interactive inspection; model(...) below encodes again.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the caption axis: row i holds the caption probabilities for image i.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)

In [None]:
# Image->text accuracy of stock CLIP on the batch, scored against the same `labels` as the
# project model above. NOTE(review): named "top3" but computed with k=1.
c_i2t_top3 = top_k_accuracy_score(labels, probs, k=1)

In [None]:
# Display the stock-CLIP image->text score computed in the previous cell.
c_i2t_top3

In [None]:
import numpy as np

# Rebuild the image batch from the source files using CLIP's own preprocessing transform.
# The count is derived from the batch instead of a hard-coded 100, so a short batch (dataset
# smaller than the DataLoader batch size) stays aligned with `text` and `labels`.
# NOTE(review): this assumes `data` is the FIRST batch of the unshuffled loader, i.e. dataset
# items 0..N-1 — re-fetch the batch if the `next(data_iter)` cell was run more than once.
num_batch_items = len(data["text"])
image = torch.stack(
    [
        preprocess(Image.open(val_dataset.data[i]["images"][0]))
        for i in range(num_batch_items)
    ]
)
image = image.to("cuda")

In [None]:
# Sanity-check the shape of the stacked, CLIP-preprocessed image batch.
image.shape

In [None]:
# Same scoring as before, now on the images that went through CLIP's own preprocessing.
# `image` was moved to the GPU in the cell that built it and `text` in the earlier tokenize
# cell, so no further transfers are needed (the stray bare `model` expression was a no-op and
# is removed).
with torch.no_grad():
    # Kept as notebook globals for interactive inspection; model(...) below encodes again.
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the caption axis: row i holds the caption probabilities for image i.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)

In [None]:
# Stock-CLIP image->text accuracy on the CLIP-preprocessed images, scored against the same
# `labels` as before. NOTE(review): named "top3" but computed with k=1.
c_i2t_top3 = top_k_accuracy_score(labels, probs, k=1)
c_i2t_top3

In [None]:
# Single-image sanity check with whichever CLIP model is currently in `model`: score one image
# against three candidate captions. This overwrites the notebook-level `image` and `text`.
prompts = ["a diagram", "a dog", "a cat"]
image = preprocess(Image.open("CLIP.png"))[None].to(device)  # [None] adds the batch axis
text = clip.tokenize(prompts).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# The sample output formerly quoted on this line is attached verbatim to every CLIP cell in the
# notebook regardless of the model loaded, so it is not a reliable reference for this cell.
print("Label probs:", probs)

In [None]:
# Same single-image check, this time with a freshly loaded ViT-B/32 model.
# This rebinds the notebook-level `model`, `preprocess`, `image` and `text`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

prompts = ["a diagram", "a dog", "a cat"]
image = preprocess(Image.open("CLIP.png"))[None].to(device)  # [None] adds the batch axis
text = clip.tokenize(prompts).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Sample output per the original inline note (unverified here):
# [[0.9927937  0.00421068 0.00299572]]
print("Label probs:", probs)

In [None]:
# Two-image check with ViT-B/32: score two pictures against five candidate captions.
# This rebinds the notebook-level `model`, `preprocess`, `image` and `text`; the following
# cells reuse this `image`/`text` pair with other checkpoints.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image_paths = ["./man.png", "./woman.png"]
# NOTE(review): "a women" looks like a typo for "a woman"; left unchanged because editing the
# prompt would change the scores this cell (and the cells reusing `text`) produce.
prompts = ["a diagram", "a dog", "a cat", "a man", "a women"]

image = torch.stack([preprocess(Image.open(path)).to(device) for path in image_paths])
text = clip.tokenize(prompts).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# One row of probabilities per image, one column per prompt (the single-row sample output
# formerly quoted here did not apply to this cell).
print("Label probs:", probs)

In [None]:
# Repeat the comparison with a freshly loaded ViT-B/16 model.
# `image` and `text` are NOT rebuilt here: they are whatever an earlier cell left in the kernel
# (under a top-to-bottom run, the two images / five prompts from the previous cell).
# NOTE(review): `image` was produced by the previous model's `preprocess`; this assumes both
# checkpoints share the same preprocessing — TODO confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the text axis: one row of probabilities per image.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # the sample output formerly quoted here was copied from another cell

In [None]:
# Repeat the comparison with a freshly loaded ViT-L/14 model.
# `image` and `text` are NOT rebuilt here: they are whatever an earlier cell left in the kernel
# (under a top-to-bottom run, the two images / five prompts from the man/woman cell).
# NOTE(review): `image` was produced by an earlier model's `preprocess`; this assumes both
# checkpoints share the same preprocessing — TODO confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the text axis: one row of probabilities per image.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # the sample output formerly quoted here was copied from another cell
# NOTE(review): `labels` comes from the dataset batch evaluated earlier, while `probs` is built
# from whatever `image`/`text` are in scope. On a top-to-bottom run these presumably have
# different lengths and this call would fail — re-create `image`/`text` from the batch first if
# the intent is ViT-L/14 top-3 accuracy on the dataset. k=3 here does match the "top3" name.
c_i2t_top3 = top_k_accuracy_score(labels, probs, k=3)
c_i2t_top3