# Code for DAVEnet Inferences

In [1]:
import utils.util as u
import DAVEnet_models
import torch.nn as nn
import torch
import os
# Remote directory holding the exp_4 DAVEnet checkpoints
# (image_model.<epoch>.pth / audio_model.<epoch>.pth).
# Set the DAVENET_REMOTE_MODELS_DIR environment variable to point elsewhere
# instead of editing this hard-coded absolute path.
# NOTE: load_DAVEnet's `remote_dir` default repeats the same literal path.
dir_DAVEnet_models_remote = os.environ.get(
    "DAVENET_REMOTE_MODELS_DIR", "/home/asantos/DAVEnet_Dump/exp_4/models"
)

def download_DAVEnet_weights(epoch, remote_dir, local_dir):
    """Fetch the image and audio DAVEnet checkpoints for one epoch.

    Downloads ``image_model.{epoch}.pth`` and then ``audio_model.{epoch}.pth``
    from ``remote_dir`` into ``local_dir`` using ``u.download_remote_file``.

    Args:
        epoch: Checkpoint epoch number, embedded in the file names.
        remote_dir: Directory passed as the remote side of each download.
        local_dir: Local destination directory; created if it does not exist.
    """
    # Make sure the destination exists so the first download into a fresh
    # directory has somewhere to land. An empty string means "current dir".
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
    for model_kind in ("image", "audio"):
        file_name = f"{model_kind}_model.{epoch}.pth"
        u.download_remote_file(
            os.path.join(remote_dir, file_name),
            os.path.join(local_dir, file_name),
        )

def check_if_DAVEnet_in_local(epoch, local_dir):
    """Report whether both checkpoints for ``epoch`` already exist locally.

    Returns True only if ``image_model.{epoch}.pth`` and
    ``audio_model.{epoch}.pth`` are both present in ``local_dir``.
    """
    expected_files = (f"image_model.{epoch}.pth", f"audio_model.{epoch}.pth")
    return all(os.path.exists(os.path.join(local_dir, name)) for name in expected_files)

def load_DAVEnet(epoch,local_dir,remote_dir="/home/asantos/DAVEnet_Dump/exp_4/models",device="cpu"):
    """Build the DAVEnet audio/image models and load one epoch's weights.

    Weights are read from ``local_dir``. If either checkpoint is missing, both
    are first fetched from ``remote_dir`` with ``download_DAVEnet_weights``.

    Args:
        epoch: Checkpoint epoch number (``audio_model.{epoch}.pth`` /
            ``image_model.{epoch}.pth``).
        local_dir: Local directory holding (or receiving) the checkpoints;
            created before downloading if it does not exist.
        remote_dir: Remote checkpoint directory used when downloading.
        device: Device the models are moved to and the checkpoint tensors
            are mapped onto.

    Returns:
        ``(audio_model, image_model)``, each wrapped in ``nn.DataParallel``.
        NOTE(review): ``.eval()`` is not called here, so the models are
        returned in their default (training) mode; confirm callers doing
        inference switch modes if needed.

    Raises:
        FileNotFoundError: if a checkpoint is still missing after the
            download attempt.
    """
    print("--DAVEnet--")
    audio_path = os.path.join(local_dir, f"audio_model.{epoch}.pth")
    image_path = os.path.join(local_dir, f"image_model.{epoch}.pth")

    if check_if_DAVEnet_in_local(epoch,local_dir):
        print("Weights in local -> Load models")
    else:
        print("Weights not in local -> Proceed to download")
        if local_dir:
            os.makedirs(local_dir, exist_ok=True)
        download_DAVEnet_weights(epoch, remote_dir,local_dir)
        # Fail with an explicit message (before building the networks) rather
        # than an opaque torch.load error if the download produced nothing.
        missing = [path for path in (audio_path, image_path) if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(f"DAVEnet weights still missing after download: {missing}")
        print("Load models")

    # Wrap in DataParallel *before* loading: the state dicts are loaded into the
    # wrapper, so their keys are presumably saved with the "module." prefix.
    audio_model = nn.DataParallel(DAVEnet_models.Davenet().to(device))
    image_model = nn.DataParallel(DAVEnet_models.VGG16().to(device))

    audio_model.load_state_dict(torch.load(audio_path, map_location=device))
    image_model.load_state_dict(torch.load(image_path, map_location=device))
    print("Loaded parameters from epoch ",epoch)
    return audio_model,image_model

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# audio_model, image_model = load_DAVEnet(56,"garbage",device=device)


In [9]:
from utils.util import MatchmapVideoGeneratorDAVEnet

# mvgDAVEnet = MatchmapVideoGeneratorDAVEnet(
#         audio_model= audio_model,
#         image_model= image_model,
#         nFrames= nFrames,
#         device = device,
#         img = img,
#         spec = spec
# )


def inference_maker_DAVEnet(epoch, split, sample_idx, local_dir_saving,model_name="", json_dir="garbage"):
    """Render (or reuse) a DAVEnet matchmap video for one dataset sample.

    The video path is
    ``<local_dir_saving>/inferences/<sample_idx>/DAVEnet<model_name>-epoch<epoch>_<split>_<sample_idx>.mp4``.
    If that file already exists it is returned without recomputation.

    Args:
        epoch: DAVEnet checkpoint epoch to load.
        split: Split name; selects ``<json_dir>/<split>.json``.
        sample_idx: Sample index passed to ``GetSampleFromJson.download_sample``.
        local_dir_saving: Root directory for the weights (``DAVEnet_weights``)
            and rendered inferences; also handed to ``GetSampleFromJson``.
        model_name: Free-form tag inserted into the video file name.
        json_dir: Directory containing the per-split JSON files. The default
            ``"garbage"`` is the previously hard-coded location.

    Returns:
        Path of the existing or newly created video file.

    Raises:
        FileNotFoundError: if the split JSON file does not exist. This is
            checked before any weights are loaded or downloaded.
    """
    inference_name = f'DAVEnet{model_name}-epoch{epoch}_{split}_{sample_idx}.mp4'
    inference_local_path = os.path.join(local_dir_saving, "inferences", str(sample_idx), inference_name)

    if os.path.exists(inference_local_path):
        print("Inference already in local")
        return inference_local_path
    print("Proceed to make the inference")

    # Validate the sample source first so a bad split fails fast, before the
    # (possibly remote) weight download and model construction.
    json_file = os.path.join(json_dir, f'{split}.json')
    if not os.path.exists(json_file):
        raise FileNotFoundError(f"json file not found for retrieving the samples: {json_file}")

    # Load models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audio_model, image_model = load_DAVEnet(
        epoch=epoch,
        local_dir=os.path.join(local_dir_saving, "DAVEnet_weights"),
        device=device,
    )

    # Fetch (or reuse) the sample's image and audio, then load them.
    gs = u.GetSampleFromJson(json_file, local_dir_saving, padvalue=0)
    img_local_path, aud_local_path = gs.download_sample(sample_idx)
    img = gs.load_image(img_local_path)
    spec = gs.load_audio_to_spec(aud_local_path)

    # FIXME(review): "pepe" is a placeholder, not a frame count. The captured
    # run succeeded with it, which suggests MatchmapVideoGeneratorDAVEnet does
    # not use this argument. Confirm, and pass the real number of (unpadded)
    # spectrogram frames if it does.
    nFrames = "pepe"

    # Create the output directory only once we are about to write into it.
    os.makedirs(os.path.dirname(inference_local_path), exist_ok=True)
    mgv = MatchmapVideoGeneratorDAVEnet(audio_model, image_model, nFrames, device, img, spec)
    mgv.create_video_with_audio(inference_local_path, aud_local_path)

    return inference_local_path


# Render (or reuse) the matchmap video for sample 1 of the "valn" split
# using the epoch-136 weights; the returned path is the cell's output.
inference_maker_DAVEnet(
    epoch=136,
    split="valn",
    sample_idx=1,
    local_dir_saving="dir_exp_PlacesAudio",
    model_name="No",
)



Proceed to make the inference
--DAVEnet--
Weights in local -> Load models




Loaded parameters from epoch  136
Sample already downloaded


  mel_basis = librosa.filters.mel(sr, n_fft, n_mels=num_mel_bins, fmin=fmin)


Video created at: dir_exp_PlacesAudio/inferences/1/DAVEnetNo-epoch136_valn_1.mp4
Success! Video file created: 147996 bytes


'dir_exp_PlacesAudio/inferences/1/DAVEnetNo-epoch136_valn_1.mp4'