### Imports

In [1]:
import time
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
# from cords.utils.data.datasets.SL import gen_dataset
from torch.utils.data import Subset
from cords.utils.config_utils import load_config_data
from cords.utils.data.data_utils import WeightedSubset
from ray import tune
import random
import torch.backends.cudnn as cudnn

from tqdm import tqdm
import logging
import os.path as osp
import sys

  from .autonotebook import tqdm as notebook_tqdm
2023-12-20 01:20:52,295	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2023-12-20 01:20:52,361	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
from Light.dataset import Sampler_Loaders, SubDatasets, SubScriptDatasets, SubLoaders
from hyperpyyaml import load_hyperpyyaml
from Light.model import SpeakerLoss
from Define_Model.Optimizer import EarlyStopping

from TrainAndTest.common_func import create_classifier, create_optimizer, create_scheduler, create_model, verification_test, verification_extract, \
    args_parse, args_model, save_model_args

from cords.utils.data.dataloader.SL.adaptive import GLISTERDataLoader, GradMatchDataLoader

#, OLRandomDataLoader, \
    # CRAIGDataLoader, GradMatchDataLoader, RandomDataLoader
from dotmap import DotMap


In [3]:
def __get_logger(results_dir):
    """Create (or reconfigure) the notebook's results logger.

    The logger writes INFO-and-above records to stdout and to
    ``<results_dir>/results.log``. The file handler's own level is DEBUG, but the
    logger itself is set to INFO, so DEBUG records are filtered out before they
    reach either handler.

    Calling this again (e.g. re-running the cell) returns the same logger object.
    Any handlers attached by a previous call are removed and closed first, so
    each message is emitted once and old log-file handles are released.

    Args:
        results_dir: directory for ``results.log``; created if it does not exist.

    Returns:
        The ``logging.Logger`` named after this module, with exactly one stream
        handler and one file handler, and propagation to the root logger disabled.
    """
    os.makedirs(results_dir, exist_ok=True)
    plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                        datefmt="%m/%d %H:%M:%S")
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # logging.getLogger returns a process-wide singleton: drop handlers left by an
    # earlier call, otherwise every record is printed/written once per call.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    s_handler = logging.StreamHandler(stream=sys.stdout)
    s_handler.setFormatter(plain_formatter)
    s_handler.setLevel(logging.INFO)
    logger.addHandler(s_handler)

    f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
    f_handler.setFormatter(plain_formatter)
    f_handler.setLevel(logging.DEBUG)
    logger.addHandler(f_handler)

    # Keep records out of the root logger (avoids a second copy via its handlers).
    logger.propagate = False
    return logger

def generate_cumulative_timing(mod_timing):
    """Convert per-epoch durations (seconds) into a running total in hours.

    Args:
        mod_timing: sequence of per-epoch wall-clock durations in seconds
            (the training loop appends ``time.time() - start_time`` per epoch).

    Returns:
        1-D float ``np.ndarray`` of the same length; element ``i`` equals
        ``sum(mod_timing[:i + 1]) / 3600``. Empty input gives an empty array.
    """
    # Running sum; the previous loop never added to its accumulator and so
    # returned all zeros.
    return np.cumsum(np.asarray(mod_timing, dtype=float)) / 3600

def save_ckpt(state, ckpt_path):
    """Serialize a checkpoint to disk with ``torch.save``.

    Args:
        state: object to serialize; the training loop passes a dict holding the
            epoch, model/optimizer state dicts, the loss object and metric lists.
        ckpt_path: destination passed straight to ``torch.save`` (overwritten).
    """
    destination = ckpt_path
    torch.save(state, destination)

def load_ckpt(ckpt_path, model, optimizer):
    """Restore a checkpoint with the keys written by the training loop's ``save_ckpt`` call.

    Args:
        ckpt_path: path of the checkpoint file.
        model: module whose weights are replaced from ``'state_dict'``.
        optimizer: optimizer whose state is replaced from ``'optimizer'``.

    Returns:
        ``(start_epoch, model, optimizer, loss, metrics)``. ``model`` and
        ``optimizer`` are the objects passed in, updated in place.
    """
    # NOTE(review): the training loop stores the loss *module* under 'loss'; newer
    # PyTorch defaults torch.load(weights_only=True), which refuses arbitrary
    # pickled objects. Loading such a file may need weights_only=False, and that
    # must only be done for trusted checkpoints.
    ckpt = torch.load(ckpt_path)
    start_epoch = ckpt['epoch']
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return start_epoch, model, optimizer, ckpt['loss'], ckpt['metrics']

In [4]:
# Make GPU 2 the current CUDA device. A bare 'cuda' device string and .cuda()
# calls later in the notebook resolve to this device.
GPU_INDEX = 2
torch.cuda.set_device(GPU_INDEX)

In [5]:
# Absolute project roots on the original workstation; adjust for your machine.
# NOTE(review): neither is used by live code in this notebook section (data_dir
# only appears in a commented-out config path).
data_dir, lstm_dir = (
    '/home/yangwenhao/project/SpeakerVerification-pytorch',
    '/home/yangwenhao/project/lstm_speaker_verification/data',
)

In [6]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU and CUDA).
seed = 1234

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    for _seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_cuda(seed)
    # NOTE(review): benchmark mode lets cuDNN pick kernels by timing them, which
    # can make GPU results differ between runs despite the seeds above.
    cudnn.benchmark = True

In [18]:
# HyperPyYAML config; the rest of the notebook reads its settings (embedding
# model, loss type, optimiser, batch size, augmentation, ...) from config_args.
# Original absolute location, kept for reference:
# data_dir + '/Data/checkpoint/ECAPA_brain/Mean_batch96_SASP2_em192_official_2sesmix8/arcsoft_adam_cyclic/vox1/wave_fb80_band05_aug5/123456/model.2023.12.17.yaml'
# NOTE(review): HyperPyYAML can instantiate Python objects from tags, so only load
# config files you trust.
train_config = 'model.2023.12.17.yaml'

with open(train_config, 'r') as config_stream:
    config_args = load_hyperpyyaml(config_stream)

### Dataset Loader

In [8]:
# Build the train / validation / train-extract datasets from the config and wrap
# each in a DataLoader. train_loader and valid_loader feed the GradMatch loader
# further down. (For a quick debug run, truncate train_dir.base_utts before
# calling SubLoaders.)
train_dir, valid_dir, train_extract_dir = SubScriptDatasets(config_args)
train_loader, valid_loader, train_extract_loader = SubLoaders(
    train_dir, valid_dir, train_extract_dir, config_args)

==> Generating 1 lengths with Average: 32000.


In [19]:
# Device for the model and every batch; same GPU index as torch.cuda.set_device above.
device = 'cuda:2'

# The embedding network comes ready-built from the YAML config. Attach the
# speaker loss with reduction='none': the training loop takes torch.dot of its
# output with the subset-selection weights, so it presumably needs per-sample losses.
speaker_loss = SpeakerLoss(config_args)
speaker_loss.reduction = 'none'
model = config_args['embedding_model']
model.loss = speaker_loss
model = model.to(device)

In [10]:
# Parameter groups handed to the optimiser.
model_para = [{'params': model.parameters()}]

# Losses with their own trainable parameters get a scaled learning rate.
# NOTE(review): if SpeakerLoss is an nn.Module, assigning it to model.loss makes
# these parameters part of model.parameters() as well, and torch.optim rejects a
# parameter that appears in two groups. Confirm before using these loss types.
if config_args['loss_type'] in ['center', 'variance', 'mulcenter', 'gaussian', 'coscenter', 'ring']:
    assert config_args['lr_ratio'] > 0
    model_para.append({'params': model.loss.xe_criterion.parameters(),
                       'lr': config_args['lr'] * config_args['lr_ratio']})

# Optional separate lr / weight decay for the classifier head. This replaces
# model_para wholesale, so the loss group appended above is discarded.
if 'second_wd' in config_args and config_args['second_wd'] > 0:
    classifier_ids = {id(p) for p in model.classifier.parameters()}
    rest_params = (p for p in model.parameters() if id(p) not in classifier_ids)

    if config_args['lr_ratio'] > 0:
        init_lr = config_args['lr'] * config_args['lr_ratio']
    else:
        init_lr = config_args['lr']
    # second_wd > 0 is guaranteed by the enclosing condition.
    init_wd = config_args['second_wd']
    print('Set the lr and weight_decay of classifier to %f and %f' % (init_lr, init_wd))

    model_para = [
        {'params': rest_params},
        {'params': model.classifier.parameters(), 'lr': init_lr, 'weight_decay': init_wd},
    ]

Set the lr and weight_decay of classifier to 0.001000 and 0.000200


In [11]:
# Subset fraction handed to the data-subset-selection loader (dss_args below).
fraction = config_args['coreset_percent']

# Optimiser hyper-parameters are read verbatim from the config under the same names.
_OPT_KEYS = ('lr', 'lr_decay', 'weight_decay', 'dampening', 'momentum', 'nesterov')
opt_kwargs = {key: config_args[key] for key in _OPT_KEYS}

optimizer = create_optimizer(model_para, config_args['optimizer'], **opt_kwargs)
scheduler = create_scheduler(optimizer, config_args, train_dir)
# NOTE(review): not consulted by the training loop in this notebook.
early_stopping_scheduler = EarlyStopping(patience=config_args['early_patience'],
                                         min_delta=config_args['early_delta'])

In [16]:
#Results logging directory
result_dname = 'results_gradmatch_{:.2f}_batch{}'.format(fraction, config_args['batch_size'])

# if 'num_pipes' in config_args:
#     result_dname += '_aug{}'.format(config_args['num_pipes'])
    
results_dir = osp.abspath(osp.expanduser(result_dname))
logger = __get_logger(results_dir)
logger.info("hello")

[12/20 01:21:58] __main__ INFO: hello
[12/20 01:21:58] __main__ INFO: hello


In [13]:
# import copy
# copy.deepcopy(dss_args.model.loss)

# # model

In [20]:
# Data-subset-selection (DSS) settings for cords' adaptive GradMatch loader, which
# yields (inputs, targets, weights) batches from a weighted subset of the training
# data (presumably re-selected every `select_every` epochs; see cords docs).
# selection_strategy labels the run in the results log, so it must name the
# loader actually used. It was 'GLISTER' while GradMatch was running.
# To try GLISTER instead, swap in GLISTERDataLoader (imported above) with
# matching GLISTER arguments.
selection_strategy = 'GradMatch'

dss_args = DotMap(dict(
    type=selection_strategy,
    fraction=fraction,
    select_every=6,
    lam=0.5,
    selection_type='PerBatch',
    v1=True,
    valid=False,
    kappa=0,
    eps=1e-100,
    linear_layer=False,
    model=model,
    loss=model.loss,
    eta=0.001,
    # NOTE(review): hard-coded class count (presumably the 1211 VoxCeleb1 training
    # speakers); keep in sync with the classifier defined by the config.
    num_classes=1211,
    # 'cuda' resolves to the current device, pinned by torch.cuda.set_device above.
    device='cuda',
))

dataloader = GradMatchDataLoader(train_loader, valid_loader, dss_args, logger,
                                 batch_size=config_args['batch_size'],
                                 shuffle=True,
                                 pin_memory=False)


In [None]:
# ---- Training / logging / checkpoint settings ----
num_epochs = 30

print_every = 1  # evaluate and log every N epochs
# Metrics to evaluate and report; see the results-printing section of the
# training loop for the other recognised keys.
print_args = ["val_loss", "val_acc", "time"]

save_every = 3   # write a checkpoint every N epochs
is_save = True

# Per-epoch metric histories, appended to by the training loop.
trn_losses, val_losses, tst_losses, subtrn_losses = [], [], [], []
timing = []
trn_acc, val_acc, tst_acc, subtrn_acc = [], [], [], []

### Training

In [24]:
# Move each configured waveform augmenter to the GPU once, up front.
# num_pipes is how many distinct augmenters the training loop samples per batch.
if 'augment_pipeline' in config_args:
    num_pipes = 1
    if 'num_pipes' in config_args:
        num_pipes = config_args['num_pipes']
    augment_pipeline = [augment.cuda() for augment in config_args['augment_pipeline']]

In [25]:
"""
################################################# Training Loop #################################################
"""
# ---- Helpers used by the loop below ----


def _augment_batch(inputs, label, weights):
    """Augment one training batch with randomly chosen waveform augmenters.

    Picks ``num_pipes`` distinct augmenters from ``augment_pipeline``. If
    ``config_args['concat_augment']`` is truthy, each augmented copy is appended
    to the clean batch, with labels and weights repeated for the copies.
    Otherwise the chosen augmenters are applied one after another and the
    result replaces the clean batch.

    Args:
        inputs: waveform batch; the original code notes its shape as
            [batch, 1, 1, time] (NOTE(review): confirm against the loader).
        label: class labels for the batch.
        weights: per-sample subset-selection weights.

    Returns:
        ``(inputs, label, weights)``, each concatenated along dim 0, on the GPU.
    """
    with torch.no_grad():
        wavs_aug_tot = [inputs.cuda()]
        labels_aug_tot = [label.cuda()]
        weights_aug_tot = [weights.cuda()]

        wavs = inputs.squeeze().cuda()
        wav_label = label.squeeze().cuda()
        wav_weights = weights.squeeze().cuda()

        augs_idx = set(np.random.choice(len(augment_pipeline), size=num_pipes, replace=False))
        augs = [augment_pipeline[i] for i in augs_idx]
        sample_idxs = [np.arange(len(wavs))] * len(augs_idx)

        for data_idx, augment in zip(sample_idxs, augs):
            wavs_aug = augment(wavs[data_idx], torch.tensor([1.0] * len(wavs)).cuda())
            # Speed changes alter the length: crop or zero-pad back to the input length.
            if wavs_aug.shape[1] > wavs.shape[1]:
                wavs_aug = wavs_aug[:, 0:wavs.shape[1]]
            else:
                zero_sig = torch.zeros_like(wavs)
                zero_sig[:, 0:wavs_aug.shape[1]] = wavs_aug
                wavs_aug = zero_sig

            if 'concat_augment' in config_args and config_args['concat_augment']:
                wavs_aug_tot.append(wavs_aug.unsqueeze(1).unsqueeze(1))
                labels_aug_tot.append(wav_label[data_idx])
                weights_aug_tot.append(wav_weights[data_idx])
            else:
                # Chain augmenters: the next one is applied to this output.
                wavs = wavs_aug
                wavs_aug_tot[0] = wavs_aug.unsqueeze(1).unsqueeze(1)
                labels_aug_tot[0] = wav_label[data_idx]

        return (torch.cat(wavs_aug_tot, dim=0),
                torch.cat(labels_aug_tot),
                torch.cat(weights_aug_tot, dim=0))


def _evaluate(loader):
    """Run ``model`` over ``loader`` without gradients (no augmentation, no weights).

    Returns:
        ``(loss_sum, correct, total)``: per-batch mean losses summed over all
        batches, plus top-1 correct and total sample counts.
    """
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device, non_blocking=True)
            outputs, feats = model(inputs)
            # model.loss uses reduction='none'; reduce the per-sample losses here.
            loss_sum += model.loss((outputs, feats), targets).mean().item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return loss_sum, correct, total


# print_args key -> (history list, display format). The lists are only ever
# appended to, so these references stay valid across epochs.
_METRIC_HISTORIES = {
    "val_loss": val_losses, "val_acc": val_acc,
    "tst_loss": tst_losses, "tst_acc": tst_acc,
    "trn_loss": trn_losses, "trn_acc": trn_acc,
    "subtrn_loss": subtrn_losses, "subtrn_acc": subtrn_acc,
    "time": timing,
}
_REPORT_FORMATS = {
    "val_loss": "Valid Loss: {:.8f}",
    "val_acc": "Valid Accuracy: {:.4f}",
    "tst_loss": "Test Loss: {:.8f}",
    "tst_acc": "Test Accuracy: {:.4f}",
    "trn_loss": "Train Loss: {:.8f}",
    "trn_acc": "Train Accuracy: {:.4f}",
    "subtrn_loss": "Subset Loss: {:.8f}",
    "subtrn_acc": "Subset Accuracy: {:.4f}",
    "time": "Timing: {:.2f}",
}

for epoch in range(num_epochs):
    subtrn_loss = 0
    subtrn_correct = 0
    subtrn_total = 0
    model.train()
    start_time = time.time()
    # Iterate the loader directly so tqdm can use its length for the progress bar.
    for inputs, targets, weights in tqdm(dataloader, ncols=50):
        inputs = inputs.to(device)
        label = targets.to(device)
        weights = weights.to(device)

        if 'augment_pipeline' in config_args:
            inputs, label, weights = _augment_batch(inputs, label, weights)

        optimizer.zero_grad()
        outputs, feats = model(inputs)
        # Use `label`, not the CPU-side `targets`: it lives on the model's device
        # and, with concat_augment, has one entry per row of `outputs`.
        losses = model.loss((outputs, feats), label)
        # Weighted mean of per-sample losses with the normalised selection weights.
        loss = torch.dot(losses, weights / weights.sum())
        loss.backward()

        subtrn_loss += loss.item()
        optimizer.step()
        _, predicted = outputs.max(1)
        subtrn_total += label.size(0)
        subtrn_correct += predicted.eq(label).sum().item()

        scheduler.step()  # stepped once per batch

    epoch_time = time.time() - start_time
    timing.append(epoch_time)

    # ---------------------------- Evaluation ----------------------------
    if (epoch + 1) % print_every == 0:
        model.eval()

        if ("trn_loss" in print_args) or ("trn_acc" in print_args):
            # Evaluate over train_loader (the full training data), not the validation set.
            trn_loss, trn_correct, trn_total = _evaluate(train_loader)
            trn_losses.append(trn_loss)
            if "trn_acc" in print_args:
                trn_acc.append(trn_correct / trn_total)

        if ("val_loss" in print_args) or ("val_acc" in print_args):
            val_loss, val_correct, val_total = _evaluate(valid_loader)
            val_losses.append(val_loss)
            if "val_acc" in print_args:
                val_acc.append(val_correct / val_total)

        if ("tst_loss" in print_args) or ("tst_acc" in print_args):
            # NOTE(review): no separate test loader exists in this notebook, so
            # "test" metrics are computed on valid_loader as well.
            tst_loss, tst_correct, tst_total = _evaluate(valid_loader)
            tst_losses.append(tst_loss)
            if "tst_acc" in print_args:
                tst_acc.append(tst_correct / tst_total)

        if "subtrn_acc" in print_args:
            subtrn_acc.append(subtrn_correct / subtrn_total)

        # Must match the "subtrn_loss" key used for printing/checkpointing.
        if "subtrn_loss" in print_args:
            subtrn_losses.append(subtrn_loss)

        # ------------------------- Results printing -------------------------
        print_str = "Epoch: " + str(epoch + 1)
        for arg in print_args:
            if arg in _REPORT_FORMATS:
                print_str += " , " + _REPORT_FORMATS[arg].format(_METRIC_HISTORIES[arg][-1])
        logger.info(print_str)

    # ------------------------- Checkpoint saving -------------------------
    if ((epoch + 1) % save_every == 0) and is_save:
        metric_dict = {arg: _METRIC_HISTORIES[arg] for arg in print_args if arg in _METRIC_HISTORIES}

        ckpt_state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': model.loss,
            'metrics': metric_dict
        }

        save_ckpt(ckpt_state, osp.join(results_dir, 'model.pt'))
        logger.info("Model checkpoint saved at epoch: {0:d}".format(epoch + 1))

# ---------------------------- Results summary ----------------------------
logger.info("{0:s} Selection Run---------------------------------".format(selection_strategy))
logger.info("Final SubsetTrn: {0:f}".format(subtrn_loss))
if "val_loss" in print_args:
    if "val_acc" in print_args:
        logger.info("Valid Loss: %.2f , Validation Accuracy: %.2f", val_loss, val_acc[-1])
    else:
        logger.info("Valid Loss: %.2f", val_loss)

if "tst_loss" in print_args:
    if "tst_acc" in print_args:
        logger.info("Test Loss: %.2f, Test Accuracy: %.2f", tst_loss, tst_acc[-1])
    else:
        logger.info("Test Data Loss: %f", tst_loss)

logger.info('---------------------------------------------------------------------')
logger.info(selection_strategy)
logger.info('---------------------------------------------------------------------')

NameError: name 'num_epochs' is not defined

In [None]:
"""
################################################# Final Results Logging #################################################
"""

# Per-epoch histories as comma-separated lines, for easy copy into a spreadsheet.
if "val_acc" in print_args:
    val_str = "Valid Accuracy, " + " , ".join(str(val) for val in val_acc)
    logger.info(val_str)

if "tst_acc" in print_args:
    tst_str = "Test Accuracy, " + " , ".join(str(tst) for tst in tst_acc)
    logger.info(tst_str)

if "time" in print_args:
    time_str = "Time, " + " , ".join(str(t) for t in timing)
    # Log the formatted line (the raw list was logged before, leaving time_str unused).
    logger.info(time_str)

timing_array = np.array(timing)
# Running total of epoch durations, in hours.
cum_timing = list(generate_cumulative_timing(timing_array))
if cum_timing:
    logger.info("Total time taken by %s = %.4f hours", selection_strategy, cum_timing[-1])
else:
    # No completed epochs: avoid an IndexError on cum_timing[-1].
    logger.info("No epochs were timed; total time for %s unavailable", selection_strategy)