In [1]:
import os
import argparse
import json

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


from util import rescale, find_max_epoch, print_size
from util import training_loss, calc_diffusion_hyperparams

from distributed_util import init_distributed, apply_gradient_allreduce, reduce_tensor

from WaveNet import WaveNet_vocoder as WaveNet
from CustomDatasetPytorch import CustomAllLoadRawWaveDataset

from pathlib import Path
import glob

# Absolute location of the challenge's task2_regression folder; its util/config.json
# provides "dataset_folder" and "raw_stimuli_folder".
# TODO: make this relative / configurable instead of a machine-specific path.
EXPERIMENTS_FOLDER = "C:/Users/YLY/Documents/eegAudChallenge/auditory-eeg-challenge-2023-code/task2_regression"


def _load_checkpoint(net, optimizer, output_directory, ckpt_iter):
    """Restore model/optimizer state from ``<output_directory>/<ckpt_iter>.pkl``.

    ``ckpt_iter == 'max'`` selects the checkpoint chosen by ``find_max_epoch``.
    Returns the iteration that was loaded, or -1 when training should start
    from initialization (no checkpoint requested, or loading failed).
    """
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(output_directory)
    if ckpt_iter < 0:
        print('No valid checkpoint model found, start training from initialization.')
        return -1
    try:
        model_path = os.path.join(output_directory, '{}.pkl'.format(ckpt_iter))
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except Exception as exc:  # not a bare except: let KeyboardInterrupt/SystemExit through
        print('Could not load checkpoint {}: {!r}'.format(ckpt_iter, exc))
        print('No valid checkpoint model found, start training from initialization.')
        return -1
    print('Successfully loaded model at iteration {}'.format(ckpt_iter))
    return ckpt_iter


def _build_train_loader(window_length, hop_length):
    """Build a shuffled DataLoader over the training-split raw-stimulus ``.npz`` files.

    Raises ValueError if no training windows are found, which would otherwise
    make the training loop spin forever.
    """
    config_path = os.path.join(EXPERIMENTS_FOLDER, 'util', 'config.json')
    with open(config_path) as fp:
        config = json.load(fp)

    data_folder = Path(config["dataset_folder"]) / config["raw_stimuli_folder"]
    train_files = []
    for path in data_folder.resolve().glob("*.npz"):
        parts = path.stem.split("_")
        # Files whose second "_"-separated field is "1" are taken as the training split
        # (presumably the dataset's naming convention -- verify).
        if len(parts) > 1 and parts[1] == "1":
            train_files.append(path)
    print(len(train_files))

    # The 4th argument (8000) is presumably the sample rate -- confirm against the dataset class.
    train_dataset = CustomAllLoadRawWaveDataset(train_files, int(window_length), int(hop_length), int(8000))
    # NOTE(review): batch_size is fixed to 1 (batch_size_per_gpu is ignored); the
    # squeeze(1) in the training loop may rely on this shape -- confirm before changing.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    if len(train_loader) == 0:
        raise ValueError("Training DataLoader is empty; check dataset_folder / raw_stimuli_folder "
                         "in {}".format(config_path))
    return train_loader


def train(window_length, hop_length, num_gpus, rank, group_name, output_directory, tensorboard_directory,
          ckpt_iter, n_iters, iters_per_ckpt, iters_per_logging,
          learning_rate, batch_size_per_gpu):
    """
    Train the WaveNet diffusion model on (EEG, audio) windows.

    Reads the module-level globals ``wavenet_config``, ``diffusion_config``,
    ``dist_config`` and ``diffusion_hyperparams``, which must be set before calling.

    Parameters:
    window_length (int):            window length passed to CustomAllLoadRawWaveDataset
    hop_length (int):               hop length passed to CustomAllLoadRawWaveDataset
    num_gpus, rank, group_name:     parameters for distributed training
    output_directory (str):         save model checkpoints to exp/<local_path>/<output_directory>
    tensorboard_directory (str):    save tensorboard events to exp/<local_path>/<tensorboard_directory>
    ckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded;
                                    automatically selects the maximum iteration if 'max' is selected
    n_iters (int):                  train up to and including iteration n_iters
    iters_per_ckpt (int):           save a checkpoint every this many iterations
    iters_per_logging (int):        log the training loss every this many iterations
    learning_rate (float):          Adam learning rate
    batch_size_per_gpu (int):       currently unused; the DataLoader uses batch_size=1
    """
    # Experiment folder name encodes the main model / diffusion hyperparameters.
    local_path = "ch{}_T{}_betaT{}".format(wavenet_config["res_channels"],
                                           diffusion_config["T"],
                                           diffusion_config["beta_T"])
    # TensorBoard logging only happens on the first process.
    tb = SummaryWriter(os.path.join('exp', local_path, tensorboard_directory)) if rank == 0 else None

    try:
        # distributed running initialization
        if num_gpus > 1:
            init_distributed(rank, num_gpus, group_name, **dist_config)

        # Get shared output_directory ready
        output_directory = os.path.join('exp', local_path, output_directory)
        if rank == 0:
            if not os.path.isdir(output_directory):
                os.makedirs(output_directory)
                os.chmod(output_directory, 0o775)
            print("output directory", output_directory, flush=True)

        # Move every diffusion hyperparameter except the integer step count "T" to the GPU.
        for key in diffusion_hyperparams:
            if key != "T":
                diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

        # predefine model
        net = WaveNet(**wavenet_config).cuda()
        print_size(net)

        # apply gradient all reduce
        if num_gpus > 1:
            net = apply_gradient_allreduce(net)

        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        ckpt_iter = _load_checkpoint(net, optimizer, output_directory, ckpt_iter)

        train_loader = _build_train_loader(window_length, hop_length)
        criterion = nn.MSELoss()  # stateless; built once instead of every iteration

        # training: resume right after the loaded checkpoint (0 when starting fresh)
        n_iter = ckpt_iter + 1
        while n_iter <= n_iters:
            for data in train_loader:
                eeg = data[0].squeeze(1).cuda()
                # NOTE(review): casting to LongTensor truncates float samples toward zero; if
                # the audio is a float waveform in [-1, 1] this zeroes it -- confirm the dtype.
                audio = data[1].type(torch.LongTensor).cuda()

                # back-propagation
                optimizer.zero_grad()
                X = (eeg.float(), audio.float())
                loss = training_loss(net, criterion, X, diffusion_hyperparams)
                if num_gpus > 1:
                    reduced_loss = reduce_tensor(loss.detach(), num_gpus).item()
                else:
                    reduced_loss = loss.item()
                loss.backward()
                optimizer.step()

                # output to log -- only on the first gpu
                if n_iter % iters_per_logging == 0 and rank == 0:
                    print("iteration: {} \treduced loss: {} \tloss: {}".format(n_iter, reduced_loss, loss.item()))
                    tb.add_scalar("Log-Train-Loss", torch.log(loss).item(), n_iter)
                    tb.add_scalar("Log-Train-Reduced-Loss", np.log(reduced_loss), n_iter)

                # save checkpoint
                if n_iter > 0 and n_iter % iters_per_ckpt == 0 and rank == 0:
                    checkpoint_name = '{}.pkl'.format(n_iter)
                    torch.save({'model_state_dict': net.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict()},
                               os.path.join(output_directory, checkpoint_name))
                    print('model at iteration %s is saved' % n_iter)

                n_iter += 1
                if n_iter > n_iters:
                    break  # stop mid-epoch once the iteration budget is spent
    finally:
        # Close TensorBoard even when training is interrupted (e.g. KeyboardInterrupt).
        if tb is not None:
            tb.close()




In [2]:
# Load the experiment configuration. train() reads dist_config, wavenet_config,
# diffusion_config and diffusion_hyperparams as module-level names, so they are
# assigned at top level here (a `global` statement at module level is a no-op).
CONFIG_FILE = 'rawWave-train-middle.json'
with open(CONFIG_FILE) as f:
    config = json.load(f)

train_config = config["train_config"]          # training parameters
dist_config = config["dist_config"]            # to initialize distributed training
wavenet_config = config["wavenet_config"]      # to define wavenet
diffusion_config = config["diffusion_config"]  # basic hyperparameters
trainset_config = config["trainset_config"]    # to load trainset (not read by train())
# dictionary of all diffusion hyperparameters
diffusion_hyperparams = calc_diffusion_hyperparams(**diffusion_config)

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Arguments for train(); values are unchanged from the original call.
WINDOW_LENGTH = 8000 * 2
HOP_LENGTH = 4000
NUM_GPUS = 1
RANK = 0
GROUP_NAME = "test"

train(WINDOW_LENGTH, HOP_LENGTH, NUM_GPUS, RANK, GROUP_NAME, **train_config)

output directory exp\ch128_T200_betaT0.025\logs/checkpoint
WaveNet_vocoder Parameters: 7.465217M
No valid checkpoint model found, start training from initialization.
6
file  0 load:  1870
file  1 load:  911
file  2 load:  911
file  3 load:  834
file  4 load:  834
file  5 load:  1904
data load finished:  6  files in total  7264  samples in total
iteration: 0 	reduced loss: 0.9764248132705688 	loss: 0.9764248132705688
iteration: 100 	reduced loss: 0.3443465530872345 	loss: 0.3443465530872345
iteration: 200 	reduced loss: 0.11258603632450104 	loss: 0.11258603632450104
iteration: 300 	reduced loss: 0.04035026952624321 	loss: 0.04035026952624321
model at iteration 300 is saved
iteration: 400 	reduced loss: 0.01575336791574955 	loss: 0.01575336791574955
iteration: 500 	reduced loss: 0.006740289274603128 	loss: 0.006740289274603128
iteration: 600 	reduced loss: 0.015200371854007244 	loss: 0.015200371854007244
model at iteration 600 is saved
iteration: 700 	reduced loss: 0.004433088470250368 	

KeyboardInterrupt: 