In [1]:
import rooms.dataset
import render
import torch
import torch.nn as nn
import metrics
import train
import os

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Work in single precision; use the first CUDA device when one is available,
# otherwise fall back to the CPU.
torch.set_default_dtype(torch.float32)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Load the room dataset by name. Later cells read its surfaces
# (D.all_surfaces), positions (D.xyzs, D.speaker_xyz), RIRs (D.RIRs) and the
# training split (D.train_indices).
dataset_name = "prova"
D = rooms.dataset.dataLoader(dataset_name)

In [4]:
# Training parameters.
# The trailing "reference" comments keep the values originally noted next to
# each setting (presumably the full-scale run); several active values are
# reduced for a quick test run.

n_fibonacci = 12                       # reference: 128
late_stage_model = "UniformResidual"   # reference: "UniformResidual"
toa_perturb = True                     # reference: True
model_transmission = False             # reference: False

skip_train = False       # reference: False; True skips training and loads save_dir/weights.pt
continue_train = False   # reference: False; True reloads save_dir/weights.pt before training

n_epochs = 3                     # reference: 1000
batch_size = 3                   # reference: 4
lr = 1e-2                        # reference: 1e-2
pink_noise_supervision = False   # reference: True
pink_start_epoch = 500           # reference: 500
fs = 4000                        # reference: 48000 (presumably the sample rate in Hz)

load_dir = None
# os.path.join / torch.load do not expand "~"; without expanduser the path
# would point at a directory literally named "~" inside the working directory.
save_dir = os.path.expanduser('~/prova_training')

In [5]:
# Differentiable renderer sized by the number of dataset surfaces and
# configured from the training parameters, then moved to the selected device.
R = render.Renderer(
    n_surfaces=len(D.all_surfaces),
    n_fibonacci=n_fibonacci,
    late_stage_model=late_stage_model,
    toa_perturb=toa_perturb,
    model_transmission=model_transmission,
).to(device)

In [6]:
# Multi-GPU note: `nn.DataParallel(R).module` returns the very same `R`
# object, so wrapping and immediately unwrapping never parallelized anything.
# Later cells use Renderer attributes and methods directly (R.RIR_length,
# R.load_state_dict, R.toa_perturb, Renderer.render_RIR_learned_beampattern),
# which a DataParallel wrapper would not expose. R is therefore intentionally
# left unwrapped and runs on a single device; report that accurately.
n_visible_gpus = torch.cuda.device_count()
if n_visible_gpus > 1:
    print(f"{n_visible_gpus} GPUs visible; the Renderer runs on a single device (DataParallel not applied)")

In [7]:
# Directional case: learned-beampattern training loss and rendering method.
loss_fcn = metrics.training_loss_for_learned_bp

# Truncate every ground-truth RIR to the renderer's output length and store it
# as a tensor of the default dtype. This mutates D.RIRs in place.
# torch.as_tensor is the documented replacement for the legacy torch.Tensor
# constructor; it also accepts an already-converted tensor without copying.
rir_length = R.RIR_length
for listener_position in D.RIRs:
    for response in listener_position:
        response['t_response'] = torch.as_tensor(
            response['t_response'][:rir_length],
            dtype=torch.get_default_dtype(),
        )

# gt_audio aliases D.RIRs (same object, not a copy).
gt_audio = D.RIRs
rendering_method = render.Renderer.render_RIR_learned_beampattern

In [8]:
def initialize(indices, source_xyz, listener_xyzs, surfaces, load_dir):
    """Create one listener per requested index via ``render.get_listener``.

    Each listener uses ``listener_xyzs[idx]`` as its position and ``idx`` as
    its ``load_num``. ``speed_of_sound``, ``max_order``,
    ``parallel_surface_pairs`` and ``max_axial_order`` are read from the
    notebook-global dataset ``D`` on every call to ``get_listener``.

    Args:
        indices: iterable of listener indices to build.
        source_xyz: source position, forwarded unchanged.
        listener_xyzs: indexable collection of listener positions.
        surfaces: room surfaces, forwarded unchanged.
        load_dir: forwarded unchanged as ``load_dir``.

    Returns:
        list of listeners in the same order as ``indices``.
    """
    return [
        render.get_listener(
            source_xyz=source_xyz,
            listener_xyz=listener_xyzs[idx],
            surfaces=surfaces,
            load_dir=load_dir,
            load_num=idx,
            speed_of_sound=D.speed_of_sound,
            max_order=D.max_order,
            parallel_surface_pairs=D.parallel_surface_pairs,
            max_axial_order=D.max_axial_order,
        )
        for idx in indices
    ]

In [9]:
# --- Training -----------------------------------------------------------------
# Either trains R on the training split, or (skip_train=True) loads the weights
# previously saved in save_dir and switches off training-only behavior.

# Expand "~" so a save_dir such as '~/prova_training' points at the home
# directory; os.path.join and torch.load do not expand it themselves.
save_dir_path = os.path.expanduser(save_dir)
weights_path = os.path.join(save_dir_path, "weights.pt")


def _load_renderer_weights(model, path):
    """Load the 'model_state_dict' entry of the checkpoint at `path` into `model`.

    map_location=device lets a checkpoint written on a GPU be restored on a
    CPU-only machine (and vice versa).
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])


if not skip_train:
    print("Training")

    # Initialize one listener per training index.
    Ls = initialize(indices=D.train_indices,
                    listener_xyzs=D.xyzs,
                    source_xyz=D.speaker_xyz,
                    surfaces=D.all_surfaces,
                    load_dir=load_dir)

    if continue_train:
        _load_renderer_weights(R, weights_path)

    losses = train.train_loop(R=R, Ls=Ls, train_gt_audio=gt_audio[D.train_indices], D=D,
                              n_epochs=n_epochs, batch_size=batch_size, lr=lr, loss_fcn=loss_fcn,
                              save_dir=save_dir_path,
                              pink_noise_supervision=pink_noise_supervision,
                              pink_start_epoch=pink_start_epoch,
                              continue_train=continue_train, fs=fs)

else:
    _load_renderer_weights(R, weights_path)
    # NOTE(review): nn.Module defines train() as a method; assigning a bool here
    # shadows it on this instance (R.eval()/R.train() would then fail) unless
    # render.Renderer defines its own `train` flag — confirm in render.py.
    R.train = False
    R.toa_perturb = False

Training
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Considered Paths:	6
Total Considered Paths, after Axial:	6
Valid Paths:	7
Loss:	<function training_loss_for_learned_bp at 0x7cdd383c3ce0>
Late Network Style	UniformResidual
energy_vector
source_response
directivity_sphere
decay
RIR_residual
spline_values
bp_ord_cut_freqs
0
caso direzionale


Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:873.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


31.876358032226562
caso direzionale
35.86842727661133
caso direzionale
30.756473541259766
caso direzionale
26.8070068359375
caso direzionale
26.84660530090332
caso direzionale
30.373106002807617
caso direzionale
27.263751983642578
caso direzionale
31.742412567138672
caso direzionale
27.188127517700195
1
caso direzionale
24.68709373474121
caso direzionale
26.290081024169922
caso direzionale
30.610441207885742
caso direzionale
27.635536193847656
caso direzionale
27.286518096923828
caso direzionale
24.410873413085938
caso direzionale
25.81817626953125
caso direzionale
25.474609375
caso direzionale
29.56269645690918
2
caso direzionale
26.676376342773438
caso direzionale
29.150314331054688
caso direzionale
25.149391174316406
caso direzionale
23.633041381835938
caso direzionale
25.358217239379883
caso direzionale
24.979503631591797
caso direzionale
26.45018196105957
caso direzionale
28.499353408813477
caso direzionale
23.717931747436523
