In [None]:
import rooms.dataset
import render_optimized as render
import torch
import torch.nn as nn
import metrics
import train
import os
import numpy as np
import evaluate
import trace1

import matplotlib.pyplot as plt

In [None]:
# Make float32 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float32)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Name of the dataset to load; it is also used further down to build `load_dir`.
dataset_name = "espoo_S2_amb"

# Dataset object the rest of the notebook reads from (surfaces, RIRs, indices, ...).
D = rooms.dataset.dataLoader(dataset_name)

In [None]:
# Training parameters.
# The trailing comment on each line appears to record the reference value of
# that setting -- TODO confirm.

n_fibonacci = 128  # 128
late_stage_model = "UniformResidual"  # "UniformResidual"
toa_perturb = True  # True
model_transmission = False  # False

skip_train = False  # False
continue_train = False  # False

n_epochs = 200  # 1000
# Reference value: 4 (4 in the test). If the dataset contains different
# microphone types, it is better for batch_size to divide both N_train and the
# number of microphones of each type.
batch_size = 7
lr = 1e-2  # 1e-2
pink_noise_supervision = True  # True
# NOTE(review): pink_start_epoch (250) is larger than n_epochs (200), so
# pink-noise supervision presumably never starts with these values -- confirm
# that this is intended.
pink_start_epoch = 250  # 500
fs = 48000  # 48000

load_dir = 'precomputed/' + dataset_name
# '~' is only a shell convention: os.path.join / open / torch.load do not expand
# it, so it has to be resolved to the home directory explicitly. Without this the
# weights end up in a directory literally named '~' under the working directory.
save_dir = os.path.expanduser('~/espoo_s2_amb_200epochs')

skip_inference = True  # False
skip_music = True  # False
skip_eval = True  # False
skip_binaural = True  # False

valid = False  # False -- evaluate on valid instead of test

In [None]:
# Build the renderer from the settings chosen in the parameter cell (one surface
# entry per surface in the dataset) and place it on the selected device.
R = render.Renderer(
    n_surfaces=len(D.all_surfaces),
    n_fibonacci=n_fibonacci,
    late_stage_model=late_stage_model,
    toa_perturb=toa_perturb,
    model_transmission=model_transmission,
).to(device)

In [None]:
# Use multiple GPUs if available
# NOTE(review): `nn.DataParallel(R).module` hands back the module that was just
# wrapped, so R appears to stay the plain renderer and the work is presumably
# not split across GPUs despite the message printed below -- confirm whether
# real data parallelism is wanted here.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    R = nn.DataParallel(R).module

In [None]:
# Directional case

loss_fcn = metrics.training_loss_directional
#loss_fcn = metrics.training_loss_directional_with_decay
#loss_fcn = metrics.training_loss_directional_rates

# Crop every measured response to the renderer's RIR length and store it as a
# float32 tensor on `device`.
for listener_position in D.RIRs:
    for response in listener_position:
        # Tensor.to() returns a new tensor instead of moving the existing one,
        # so conversion and device placement are done in a single assignment.
        # as_tensor also accepts an input that is already a tensor, which keeps
        # this cell safe to re-run.
        response['t_response'] = torch.as_tensor(
            response['t_response'][:R.RIR_length],
            dtype=torch.float32,
            device=device,
        )

# Ground truth used for training, and the (unbound) rendering method to use.
gt_audio = D.RIRs
rendering_method = render.Renderer.render_RIR_directional

The following cells are only needed for training.

In [None]:
def initialize(indices, source_xyz, listener_xyzs, surfaces, load_dir,
               rendering_methods,
               mic_orientations,
               mic_0_gains,
               mic_180_loss,
               cardioid_exponents):
    """Build one listener per entry of ``indices`` via ``render.get_listener``.

    The per-listener arguments (``listener_xyzs``, ``rendering_methods``,
    ``mic_orientations``, ``mic_0_gains``, ``mic_180_loss``,
    ``cardioid_exponents``) are each looked up with the listener index, which is
    also passed on as ``load_num``. ``source_xyz``, ``surfaces`` and ``load_dir``
    are forwarded unchanged for every listener.

    Note: ``speed_of_sound``, ``max_order``, ``parallel_surface_pairs`` and
    ``max_axial_order`` are read from the notebook-level dataset object ``D``,
    not from the arguments, so ``D`` must exist before this is called.

    Returns:
        A list with the results of ``render.get_listener``, in the order of
        ``indices``.
    """
    return [
        render.get_listener(
            source_xyz=source_xyz,
            listener_xyz=listener_xyzs[listener_idx],
            surfaces=surfaces,
            load_dir=load_dir,
            load_num=listener_idx,
            speed_of_sound=D.speed_of_sound,
            max_order=D.max_order,
            parallel_surface_pairs=D.parallel_surface_pairs,
            max_axial_order=D.max_axial_order,
            # Per-listener rendering method and microphone description.
            rendering_method=rendering_methods[listener_idx],
            mic_orientation=mic_orientations[listener_idx],
            mic_0_gains=mic_0_gains[listener_idx],
            mic_180_loss=mic_180_loss[listener_idx],
            cardioid_exponents=cardioid_exponents[listener_idx],
        )
        for listener_idx in indices
    ]

In [None]:
"""
Training
"""
# Checkpoint location shared by "continue training" and "load only".
weights_path = os.path.join(save_dir, "weights.pt")

if not skip_train:
    print("Training")

    # Initialize listeners for the training positions. Each listener gets its own
    # rendering method and microphone parameters, taken from the dataset.
    Ls = initialize(indices=D.train_indices,
                    listener_xyzs=D.xyzs,
                    source_xyz=D.speaker_xyz,
                    surfaces=D.all_surfaces,
                    load_dir=load_dir,
                    rendering_methods=D.rendering_methods,
                    mic_orientations=D.mic_orientations,
                    mic_0_gains=D.mic_0_gains,
                    mic_180_loss=D.mic_180_loss,
                    cardioid_exponents=D.cardioid_exponents)

    if continue_train:
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session is using (including CPU-only machines).
        R.load_state_dict(torch.load(weights_path, map_location=device)['model_state_dict'])

    losses = train.train_loop(R=R, Ls=Ls, train_gt_audio=gt_audio[D.train_indices], D=D,
                              n_epochs=n_epochs, batch_size=batch_size, lr=lr, loss_fcn=loss_fcn,
                              save_dir=save_dir,
                              pink_noise_supervision=pink_noise_supervision,
                              pink_start_epoch=pink_start_epoch,
                              continue_train=continue_train, fs=fs)

else:
    # Training is skipped: restore the saved weights instead.
    R.load_state_dict(torch.load(weights_path, map_location=device)['model_state_dict'])
    # eval() is the way to leave training mode; assigning `R.train = False` would
    # only replace the nn.Module.train method with a bool and leave the mode as is.
    R.eval()
    R.toa_perturb = False