In [None]:
import numpy as np
import os, sys

# Resolve the repository root from the kernel's current working directory:
# use the directory itself when it already is the `LyoSavin2023` checkout,
# otherwise assume the notebook lives one level below the root.
project_root = os.path.abspath("")
base_dir = project_root if project_root.endswith('LyoSavin2023') else os.path.dirname(project_root)

# Make the project's `core` and `core/utils` modules importable (appended in that order).
sys.path.extend(os.path.join(base_dir, sub_dir) for sub_dir in ('core', 'core/utils'))

import torch
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
import zarr

from utils import remove_all_ticks_and_labels

# training a model during the oscillation

In [None]:

def train_model(model, 
                model_name, 
                model_number, 
                num_steps, 
                forward_schedule,
                num_hidden, 
                num_dims,
                num_epochs,
                batch_size,
                lr,
                device,
                dataset,
                pretrained_model):
    """Train a noise-estimation diffusion model on ``dataset`` and return it.

    Optionally initialises the model (and, when resuming from a checkpoint, the
    optimizer) from a previous run, then runs ``num_epochs`` passes of minibatch
    Adam over ``noise_estimation_loss``. The last-batch loss of every epoch is
    logged to TensorBoard under ``<base_dir>/demos/runs/<model_name>_<model_number>``
    and ``save_checkpoint`` is called periodically.

    Args:
        model: torch module to train. Replaced by the loaded model when
            pretrained weights are requested.
        model_name: run name, used for the TensorBoard directory and checkpoints.
        model_number: run number, used alongside ``model_name``.
        num_steps: number of diffusion steps given to ``forward_process`` and the
            loss. When loading from a checkpoint, the value stored in the
            checkpoint takes precedence.
        forward_schedule: schedule identifier passed to ``forward_process``.
        num_hidden: only printed here.
        num_dims: not used by this function.
        num_epochs: number of passes over ``dataset`` (converted with ``int``).
        batch_size: minibatch size.
        lr: Adam learning rate.
        device: torch device for the model, data and schedule tensors.
        dataset: tensor whose first dimension indexes training samples.
        pretrained_model: dict with key ``'use_pretrained_model_weights'``. When
            that flag is truthy, it must also contain ``'use_checkpoint_weights'``,
            ``'model_name'`` and ``'model_num'``, plus ``'checkpoint_epoch'`` when
            resuming from a checkpoint.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``dataset`` has no samples (no batch would yield a loss).
    """
    from prior_utils import forward_process
    from utils import noise_estimation_loss
    # Imported here so the function also works on a fresh kernel, where no
    # earlier cell provides these names.
    import time
    from torch import optim
    from torch.utils.tensorboard import SummaryWriter

    # training set
    dataset = dataset.to(device)
    num_samples = dataset.size(0)
    if num_samples == 0:
        raise ValueError('dataset is empty; there is nothing to train on')

    print('model_name:', model_name)
    print('model_number:', model_number)
    print('num_steps:', num_steps)
    print('forward_schedule:', forward_schedule)
    print('num_hidden:', num_hidden)
    print('num_epochs:', num_epochs)
    print('dataset shape:', dataset.shape)

    # define model; the checkpoint key is only read when pretrained weights are used
    use_pretrained = pretrained_model['use_pretrained_model_weights']
    use_checkpoint = bool(use_pretrained and pretrained_model['use_checkpoint_weights'])
    if use_pretrained:
        if use_checkpoint:
            from utils import load_model_weights_from_chkpt
            model, num_steps, _ambient_dims = load_model_weights_from_chkpt(pretrained_model['model_name'], pretrained_model['model_num'], epoch_number=pretrained_model['checkpoint_epoch'], device=device)
            print('model weights loaded from checkpoint!', flush=True)
        else:
            from utils import load_model_weights
            pretrained_model_name = pretrained_model['model_name']
            pretrained_model_num = pretrained_model['model_num']
            print(f'taking weights from pretrained model {pretrained_model_name}_{pretrained_model_num}!')
            model = load_model_weights(model, pretrained_model_name, pretrained_model_num, device)

    model.to(device)

    # beta-related parameters. Built only after model loading because a checkpoint
    # may have replaced num_steps, and the schedule must agree with the num_steps
    # handed to the loss (which presumably indexes these tensors by timestep).
    coefs = forward_process(num_steps, device, forward_schedule)
    _betas, _alphas, _alphas_prod, _alphas_prod_p, alphas_bar_sqrt, _one_minus_alphas_prod_log, one_minus_alphas_prod_sqrt = coefs
    alphas_bar_sqrt = alphas_bar_sqrt.to(device)
    one_minus_alphas_prod_sqrt = one_minus_alphas_prod_sqrt.to(device)

    # training parameters
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if use_checkpoint:
        from utils import load_optimizer_state_dict
        optimizer = load_optimizer_state_dict(optimizer, pretrained_model['model_name'], pretrained_model['model_num'], epoch_number=pretrained_model['checkpoint_epoch'], device=device)
        print('optimizer state dict loaded from checkpoint!', flush=True)

    run_dir = os.path.join(base_dir, 'demos/runs', f'{model_name}_{model_number}')
    n_epochs = int(num_epochs)
    tb = SummaryWriter(run_dir)
    start_time = time.time()

    # start training; the writer is always flushed and closed, even if training
    # is interrupted.
    try:
        model.train()
        epoch_iter = tqdm(range(n_epochs), total=n_epochs, desc='Training model', unit='epochs',
                          miniters=n_epochs / 1000, maxinterval=float("inf"))
        for t in epoch_iter:
            permutation = torch.randperm(num_samples, device=device)

            for i in range(0, num_samples, batch_size):
                batch_x = dataset[permutation[i:i + batch_size]]

                loss = noise_estimation_loss(model, batch_x, num_steps, alphas_bar_sqrt, one_minus_alphas_prod_sqrt, device, norm='l2', has_class_label=False)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # loss of the last minibatch of this epoch
            loss_value = loss.item()

            # Checkpoint every 10k epochs up to epoch 200k, then every 100k epochs.
            # NOTE(review): save_checkpoint is not imported here; it is assumed to be
            # defined in the notebook's global scope — TODO confirm where it comes from.
            checkpoint_every = int(1e4) if t <= int(2e5) else int(1e5)
            if t % checkpoint_every == 0:
                save_checkpoint(t, model.state_dict(), optimizer.state_dict(), loss_value, model_name, model_number)

            # write to tensorboard
            tb.add_scalar('Loss', loss_value, t)
    finally:
        tb.flush()
        tb.close()

    duration = time.time() - start_time
    duration_mins = duration / 60
    print(f'training took {duration:.0f} seconds, which is {duration_mins:.2f} minutes.')
    return model