In [None]:
import render_optimized as render #################à
import rooms.dataset
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [None]:
# Use float32 as the default tensor dtype, and pick the compute device:
# the first CUDA GPU when one is available, otherwise the CPU.
torch.set_default_dtype(torch.float32)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load the "classroomBase" room. D_0 supplies the speaker/listener positions, surfaces,
# and music clips that the cells below read.
ROOM_NAME = "classroomBase"
D_0 = rooms.dataset.dataLoader(ROOM_NAME)

In [None]:
# Build the listener for the dataset's speaker and its first listener position (xyzs[0]).
# The room geometry and the max_order / max_axial_order settings are passed straight
# through from the dataset.
listener_kwargs = dict(
    source_xyz=D_0.speaker_xyz,
    listener_xyz=D_0.xyzs[0],
    surfaces=D_0.all_surfaces,
    speed_of_sound=D_0.speed_of_sound,
    parallel_surface_pairs=D_0.parallel_surface_pairs,
    max_order=D_0.max_order,
    max_axial_order=D_0.max_axial_order,
)
L_0 = render.get_listener(**listener_kwargs)

In [None]:
# Renderer sized to the number of surfaces in the loaded room.
n_surfaces_0 = len(D_0.all_surfaces)
R_0 = render.Renderer(n_surfaces=n_surfaces_0)

In [None]:
# Multi-GPU note: `nn.DataParallel(R_0).module` evaluates to R_0 itself, so the previous
# wrapping never spread any work across GPUs. DataParallel only splits calls to forward().
# This notebook instead renders through R_0.render_RIR_directional and assigns attributes
# on R_0 directly, so wrapping is not used and everything runs on `device`.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"{n_gpus} GPUs detected; rendering uses only {device}")

In [None]:
# Restore the trained parameters for this room. Each listed checkpoint tensor replaces the
# attribute of the same name on R_0 as a fresh nn.Parameter. map_location already places
# these tensors on `device`.
pt_file = torch.load('./models/classroomBase/weights.pt', map_location=device)
trained_state = pt_file['model_state_dict']
for param_name in ('energy_vector', 'source_response', 'directivity_sphere',
                   'decay', 'RIR_residual', 'spline_values'):
    setattr(R_0, param_name, nn.Parameter(trained_state[param_name]))

# Tensor.to() returns a new tensor rather than moving the original, so the result must be
# assigned back. The previous bare `R_0.bp_ord_cut_freqs.to(device)` had no effect.
R_0.bp_ord_cut_freqs = R_0.bp_ord_cut_freqs.to(device)

In [None]:
import fibonacci_utilities as fib

# Direction grid for the directional render below. The argument is presumably the number
# of directions (the results are later indexed as r[0] and r[1]) — TODO confirm.
N_DIRECTIONS = 2
azimuths, elevations = fib.fibonacci_azimuths_and_elevations(N_DIRECTIONS)

In [None]:
# Render the responses for the requested directions; r is indexed per direction below.
# This is inference only, and downstream cells .detach() the results anyway. no_grad
# therefore avoids building an autograd graph through R_0's nn.Parameter attributes.
with torch.no_grad():
    r = R_0.render_RIR_directional(L_0, azimuths, elevations)

# Compact summary (shape of each direction's 't_response') instead of dumping every tensor.
[tuple(r[i]['t_response'].shape) for i in range(len(r))]

In [None]:
import evaluate

# Render the first music clip in D_0 through the time-domain response for direction 0
# using evaluate.render_music, keeping element [0][0] of its result.
rir_dir0 = r[0]['t_response'].detach().unsqueeze(0)  # detached copy with a new leading dim
clip_0 = D_0.music_dls[0:1, ...]
music_0 = evaluate.render_music(rir_dir0, clip_0, device=device)[0][0]

In [None]:
# Render the first music clip in D_0 through the time-domain response for direction 1,
# keeping element [0][0] of evaluate.render_music's result.
rir_dir1 = r[1]['t_response'].detach().unsqueeze(0)  # detached copy with a new leading dim
music_1 = evaluate.render_music(rir_dir1, D_0.music_dls[0:1, ...], device=device)[0][0]


In [None]:
# Waveform of the music rendered through the direction-0 response.
fig, ax = plt.subplots()
ax.plot(music_0)
ax.set(title='Music 0', xlabel='Sample index', ylabel='Amplitude')
plt.show()

In [1]:
# soundfile is required by the .wav export cells below.
# NOTE(review): this cell was executed first (In [1]) but sits mid-notebook. Move it to the
# top and pin a version so Restart Kernel -> Run All reproduces the notebook.
%pip install soundfile

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
import soundfile as sf

SAMPLE_RATE = 48000  # Hz; assumed to be the rate the music was rendered at — TODO confirm

# Write as 32-bit float WAV. soundfile's default WAV subtype is 16-bit PCM, which cannot
# represent samples outside [-1, 1], and the rendered music is not normalised here.
sf.write('music0.wav', music_0.astype(np.float32), SAMPLE_RATE, subtype='FLOAT')

In [None]:
# Waveform of the music rendered through the direction-1 response.
fig, ax = plt.subplots()
ax.plot(music_1)
ax.set(title='Music 1', xlabel='Sample index', ylabel='Amplitude')
plt.show()

In [None]:
# Write as 32-bit float WAV so samples outside [-1, 1] are kept. soundfile's default 16-bit
# PCM cannot represent them, and the rendered music is not normalised.
# 48000 Hz is assumed to be the rendering rate — TODO confirm.
sf.write('music1.wav', music_1.astype(np.float32), 48000, subtype='FLOAT')