In [1]:
import rooms.dataset
import render
import torch
import torch.nn as nn
import metrics
import train
import os

import matplotlib.pyplot as plt

In [2]:
# Create all new floating-point tensors in single precision.
torch.set_default_dtype(torch.float32)

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Name of the dataset to load.
dataset_name = "prova"

# Dataset object used throughout the notebook: later cells read its surfaces,
# measured RIRs, source/listener positions, train indices and music clips.
D = rooms.dataset.dataLoader(dataset_name)

In [7]:
# Training parameters.
# Trailing comments are the author's noted values (presumably the full-run defaults).

n_fibonacci = 128  # 128
late_stage_model = "UniformResidual"  # "UniformResidual"
toa_perturb = True  # True
model_transmission = False  # False

skip_train = False  # False
continue_train = False  # False

n_epochs = 2  # 1000
batch_size = 4  # 4 (4 in the test)
lr = 1e-2  # 1e-2
pink_noise_supervision = True  # True
# NOTE(review): larger than n_epochs above, so pink-noise supervision
# presumably never starts in this short run -- confirm this is intended.
pink_start_epoch = 500  # 500
fs = 48000  # 48000

load_dir = None
# "~" is only expanded by the shell; os.path.join / torch.load treat it as a
# literal directory name, so resolve it to the home directory here.
save_dir = os.path.expanduser('~/prova_training_2epochs')

In [8]:
# Collect the renderer settings from the configuration cell, then build the
# renderer and place it on the selected device.
renderer_options = dict(
    n_surfaces=len(D.all_surfaces),
    n_fibonacci=n_fibonacci,
    late_stage_model=late_stage_model,
    toa_perturb=toa_perturb,
    model_transmission=model_transmission,
)
R = render.Renderer(**renderer_options).to(device)

In [9]:
# Use multiple GPUs if available
# NOTE(review): `.module` returns the module wrapped by DataParallel, so R is
# still the plain renderer after this cell and no data-parallel execution
# happens. Later cells rely on R's own attributes (RIR_length, render_RIR, ...),
# so confirm the intent before keeping the wrapper instead.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f"Using {torch.cuda.device_count()} GPUs")
    R = nn.DataParallel(R).module

The following cells are needed only for training.

In [10]:
# Directional case
loss_fcn = metrics.training_loss_directional

# Truncate every measured response to the renderer's RIR length and store it
# as a tensor on the compute device.
# Tensor.to() returns a new tensor, so the result has to be assigned; the
# conversion is done in one step with torch.as_tensor, which also accepts an
# existing tensor and therefore keeps this cell safe to re-run.
for listener_position in D.RIRs:
    for response in listener_position:
        response['t_response'] = torch.as_tensor(
            response['t_response'][:R.RIR_length],
            dtype=torch.get_default_dtype(),
            device=device,
        )

# Ground-truth audio and rendering method used for the directional case.
gt_audio = D.RIRs
rendering_method = render.Renderer.render_RIR_directional

In [11]:
def initialize(indices, source_xyz, listener_xyzs, surfaces, load_dir):
    """Build one listener object per requested index.

    Args:
        indices: iterable of positions into ``listener_xyzs``; each index is
            also passed to ``render.get_listener`` as ``load_num``.
        source_xyz: source location, forwarded unchanged for every listener.
        listener_xyzs: indexable collection of listener locations.
        surfaces: surfaces, forwarded unchanged for every listener.
        load_dir: ``load_dir`` argument forwarded to ``render.get_listener``.

    Returns:
        A list with the objects returned by ``render.get_listener``, in the
        same order as ``indices``.

    Note:
        The speed of sound, reflection orders and parallel surface pairs are
        read from the notebook-level dataset object ``D``.
    """
    return [
        render.get_listener(
            source_xyz=source_xyz,
            listener_xyz=listener_xyzs[listener_idx],
            surfaces=surfaces,
            load_dir=load_dir,
            load_num=listener_idx,
            speed_of_sound=D.speed_of_sound,
            max_order=D.max_order,
            parallel_surface_pairs=D.parallel_surface_pairs,
            max_axial_order=D.max_axial_order,
        )
        for listener_idx in indices
    ]

In [None]:
"""
Training
"""
if not skip_train:
    print("Training")

    # Initialize listeners for the training positions.
    Ls = initialize(indices=D.train_indices,
                    listener_xyzs=D.xyzs,
                    source_xyz=D.speaker_xyz,
                    surfaces=D.all_surfaces,
                    load_dir=load_dir)

    if continue_train:
        # Resume from the checkpoint in save_dir; map_location lets a checkpoint
        # saved on another device load on the current one.
        checkpoint = torch.load(os.path.join(save_dir, "weights.pt"), map_location=device)
        R.load_state_dict(checkpoint['model_state_dict'])

    losses = train.train_loop(R=R, Ls=Ls, train_gt_audio=gt_audio[D.train_indices], D=D,
                              n_epochs=n_epochs, batch_size=batch_size, lr=lr, loss_fcn=loss_fcn,
                              save_dir=save_dir,
                              pink_noise_supervision=pink_noise_supervision,
                              pink_start_epoch=pink_start_epoch,
                              continue_train=continue_train, fs=fs)

else:
    # Training is skipped: restore the saved weights and switch to inference.
    checkpoint = torch.load(os.path.join(save_dir, "weights.pt"), map_location=device)
    R.load_state_dict(checkpoint['model_state_dict'])
    # eval() is the supported way to leave training mode; assigning
    # `R.train = False` would replace the nn.Module.train method with a bool.
    # NOTE(review): if the renderer reads a custom `train` attribute, confirm.
    R.eval()
    R.toa_perturb = False

Listening test

In [7]:
import numpy as np
import evaluate

In [8]:
# Read the saved checkpoint onto the active device and replace the renderer's
# parameters, one by one, with the stored tensors.
pt_file = torch.load(save_dir + '/weights.pt', map_location=device)
saved_state = pt_file['model_state_dict']

for param_name in (
    'energy_vector',
    'source_response',
    'directivity_sphere',
    'decay',
    'RIR_residual',
    'spline_values',
    'bp_ord_cut_freqs',
):
    setattr(R, param_name, nn.Parameter(saved_state[param_name]))

In [None]:
# Arguments shared by both listening-test listeners; only the source and
# listener positions differ between them.
common_listener_args = dict(
    surfaces=D.all_surfaces,
    load_dir=load_dir,
    load_num=None,
    speed_of_sound=D.speed_of_sound,
    max_order=D.max_order,
    parallel_surface_pairs=D.parallel_surface_pairs,
    max_axial_order=D.max_axial_order,
)

listener_1 = render.get_listener(
    source_xyz=np.array([5, 3, 1.5]),
    listener_xyz=np.array([5, 7, 1.5]),
    **common_listener_args,
)

listener_2 = render.get_listener(
    source_xyz=np.array([9.9, 9.9, 2.9]),
    listener_xyz=np.array([5, 3, 1.5]),
    **common_listener_args,
)

In [10]:
# Render for listening only: switch off the time-of-arrival perturbation (the
# skip-train path does the same after loading weights) and skip autograd
# bookkeeping, since the result is detached before use.
R.toa_perturb = False
with torch.no_grad():
    RIR_1 = R.render_RIR(listener_1)

In [11]:
# Render for listening only: switch off the time-of-arrival perturbation (the
# skip-train path does the same after loading weights) and skip autograd
# bookkeeping, since the result is detached before use.
R.toa_perturb = False
with torch.no_grad():
    RIR_2 = R.render_RIR(listener_2)

In [None]:
# RIR_1 plot: rendered impulse response against sample index.
fig_rir_1, ax_rir_1 = plt.subplots()
ax_rir_1.plot(RIR_1.detach().cpu())
ax_rir_1.set_title("RIR_1")
ax_rir_1.set_xlabel("Sample")
ax_rir_1.set_ylabel("Value")
ax_rir_1.grid(True)
plt.show()

In [None]:
# RIR_2 plot: rendered impulse response against sample index.
# Title and axis labels follow the RIR_1 plot so the two figures can be told
# apart and compared directly.
plt.plot(RIR_2.detach().cpu())
plt.title("RIR_2")
plt.xlabel("Sample")
plt.ylabel("Value")
plt.grid(True)
plt.show()

In [12]:
def _render_music_for(rir):
    """Pass one rendered RIR and the first music clip to ``evaluate.render_music``.

    Both inputs are wrapped in a one-element array before the call.
    """
    rir_batch = np.array([rir.detach().cpu()])
    music_batch = np.array([D.music_dls[0]])
    return evaluate.render_music(rir_batch, music_batch, device=device)


predicted_music_1 = _render_music_for(RIR_1)
predicted_music_2 = _render_music_for(RIR_2)

In [13]:
import sounddevice as sd

# Clip length in seconds and the matching time axis (neither is needed by the
# playback call below).
duration = predicted_music_1.shape[2] / fs  # Duration in seconds
t = np.linspace(0, duration, int(fs * duration), endpoint=False)

# Play back at the configured sample rate `fs` instead of a hard-coded 48000,
# so playback stays correct if `fs` is changed in the parameter cell.
# presumably indexes (batch, channel, samples) -- TODO confirm
sd.play(predicted_music_1[0][0], samplerate=fs)
sd.wait()  # block until playback has finished


In [14]:
# Clip length in seconds and the matching time axis (neither is needed by the
# playback call below).
duration = predicted_music_2.shape[2] / fs
t = np.linspace(0, duration, int(fs * duration), endpoint=False)

# Play back at the configured sample rate `fs` instead of a hard-coded 48000,
# so playback stays correct if `fs` is changed in the parameter cell.
# presumably indexes (batch, channel, samples) -- TODO confirm
sd.play(predicted_music_2[0][0], samplerate=fs)
sd.wait()  # block until playback has finished

In [16]:
# Play a music clip straight from the dataset, at the configured sample rate
# `fs` instead of a hard-coded 48000.
# NOTE(review): this plays clip index 9, while the renders above use
# D.music_dls[0] -- confirm which clip is the intended reference.
sd.play(D.music_dls[9][0], samplerate=fs)
sd.wait()  # block until playback has finished