In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from pipelines import pipeline
import numpy as np
from utils import *
from tqdm import trange
from loaders import DatasetLoader
import os, random, string
from omegaconf import OmegaConf

In [None]:
# Hardware: prefer a CUDA device when PyTorch reports one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

In [None]:

# Pin the global NumPy and PyTorch RNG seeds to the same fixed value.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
def main(n_iters, cfg, dataset_fname, file_dir, ckpt = None):
    """Train the coarse and fine models, validating, plotting and checkpointing on the way.

    Each iteration draws `cfg.training.batch_size` ids from `loader.trainset`
    and runs `pipeline` in 'Train' mode. Every `cfg.training.display_rate`
    iterations the whole `loader.valset` is run in 'Eval' mode under
    `torch.no_grad()`, the scheduler (if any) is stepped with the summed
    validation losses, and a three-panel figure is shown. Checkpoints are
    written every `cfg.training.save_rate` iterations, on the last iteration,
    and whenever an exception interrupts the loop.

    Args:
        n_iters: Number of training iterations to run.
        cfg: Configuration object. This function reads `cfg.training.train_split`,
            `batch_size`, `save_rate`, `display_rate` and `overwrite`; the object
            is also forwarded to `init_models` and `pipeline`.
        dataset_fname: Path handed to `DatasetLoader`.
        file_dir: Directory in which checkpoint files are written.
        ckpt: Forwarded to `init_models` (presumably a checkpoint path to resume
            from; confirm against `init_models`).

    Returns:
        None.

    Raises:
        Exception: Any exception raised inside the training loop is re-raised
            unchanged after an emergency checkpoint has been saved.
    """
    ###################### Load Datasets #######################
    loader = DatasetLoader(dataset_fname)
    loader.split_train_val(ratio = cfg.training.train_split)
    ############################################################

    # Initialize models
    (model,
     fine_model,
     encode,
     encode_viewdirs,
     optimizer,
     scheduler,
     synthesizer
    ) = init_models(cfg, device, ckpt)

    def loss_fn(pred, true):
        """NMSE loss: squared error energy normalised by the target's energy."""
        return torch.sum(torch.abs(pred-true)**2)/torch.sum(torch.abs(true)**2)

    def snr(loss):
        """Convert an NMSE value into dB (-10 * log10(loss))."""
        return -10. * np.log10(loss)

    def checkpoint_path(iteration):
        """Return the checkpoint file for `iteration`, honouring cfg.training.overwrite."""
        if cfg.training.overwrite:
            return os.path.join(file_dir, "ckpt.pt")
        return os.path.join(file_dir, f'ckpt_iter_{iteration}.pt')

    # Training iterations
    train_snrs = []
    val_coarse_snrs = []
    val_fine_snrs = []
    iternum = []
    for i in trange(n_iters):
        # NOTE(review): `logging` is not imported explicitly in the visible cells;
        # presumably it arrives through `from utils import *` -- confirm.
        logging.debug(f"Iteration: {i}")
        # Run the training pipeline
        try:
            sta_id = np.random.choice(loader.trainset, cfg.training.batch_size)
            # Only the fine loss is tracked during training.
            (
                _,
                total_loss_fine,
                _,
                _,
                _,
                _,
            ) = pipeline(cfg,
                        sta_id,
                        loader,
                        model,
                        fine_model,
                        encode,
                        encode_viewdirs,
                        optimizer,
                        loss_fn,
                        synthesizer,
                        device,
                        mode = 'Train')

            train_snrs.append(snr(total_loss_fine))
            # Save a checkpoint at the given rate and on the final iteration.
            # The final iteration is derived from `n_iters` (the value that
            # actually drives this loop), not from cfg.training.n_iters.
            if i % cfg.training.save_rate == 0 or i == n_iters - 1:
                save_ckpt(model, fine_model, optimizer, checkpoint_path(i))

            # Evaluate at given display rate.
            if i % cfg.training.display_rate == 0:
                with torch.no_grad():
                    sta_id = loader.valset
                    (
                        total_loss_coarse,
                        total_loss_fine,
                        cfr_pred_coarse,
                        cfr_pred_fine,
                        total_weights_coarse,
                        total_weights_fine
                    ) = pipeline(cfg,
                                sta_id,
                                loader,
                                model,
                                fine_model,
                                encode,
                                encode_viewdirs,
                                optimizer,
                                loss_fn,
                                synthesizer,
                                device,
                                mode = 'Eval')

                    val_coarse_snrs.append(snr(total_loss_coarse))
                    val_fine_snrs.append(snr(total_loss_fine))
                    iternum.append(i)
                    if scheduler is not None:
                        scheduler.step(total_loss_fine+total_loss_coarse)
                    #-------------- Plot results-------------------------------
                    fig, ax = plt.subplots(1, 3, figsize=(12,4))
                    target_cfr = loader.get_cfr_batch(1, sta_id).flatten()

                    ax[0].plot(np.real(target_cfr), np.imag(target_cfr), "ro", label = "GT")
                    ax[0].plot(np.real(cfr_pred_coarse.cpu()), np.imag(cfr_pred_coarse.cpu()), "bo", label = "Coarse")
                    ax[0].set_title(f"SNR: {snr(total_loss_coarse):.2f} dB")
                    ax[0].legend()

                    ax[1].plot(np.real(target_cfr), np.imag(target_cfr), "ro", label = "GT")
                    ax[1].plot(np.real(cfr_pred_fine.cpu()), np.imag(cfr_pred_fine.cpu()), "bo", label = "Fine")
                    ax[1].set_title(f"SNR: {snr(total_loss_fine):.2f} dB")
                    ax[1].legend()

                    # train_snrs holds one entry per completed iteration, i.e. i + 1 values.
                    ax[2].plot(range(0, i + 1), train_snrs, 'r', label = "Train Fine")
                    ax[2].plot(iternum, val_coarse_snrs, 'y', label = 'Val Coarse')
                    ax[2].plot(iternum, val_fine_snrs, 'b', label = 'Val Fine')
                    ax[2].set_xlabel("Iteration")
                    ax[2].set_ylabel("SNR (dB)")
                    ax[2].legend()
                    plt.show()
                    # Release the figure explicitly so repeated evaluations do not
                    # accumulate open figures over a long run.
                    plt.close(fig)
                    del cfr_pred_coarse, cfr_pred_fine, total_weights_coarse, total_weights_fine, total_loss_coarse, total_loss_fine

        except Exception:
            # Emergency checkpoint, then re-raise the ORIGINAL exception so its
            # type, message and traceback are preserved for the caller.
            save_ckpt(model, fine_model, optimizer, checkpoint_path(i))
            raise

In [None]:
######################## Configurations ###################
cfg = OmegaConf.load('./config/default.yaml')
env = "conference"

# The "office" environment overrides the sampling settings loaded from the YAML file.
if env == "office":
    cfg.sampling.n_samples = 256+128
    cfg.sampling.n_samples_hierarchical = 128
    cfg.sampling.far = 24
###########################################################
datadir = "./data/"
dataset_fname = os.path.join(datadir, f"dataset_{env}_ch1_rt_image_fc.pkl")
# makedirs(exist_ok=True) creates missing parent directories and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(cfg.training.save_dir, exist_ok=True)
# Each run gets its own sub-directory named with 5 random characters.
file_dir = os.path.join(cfg.training.save_dir, ''.join(random.choices(string.ascii_uppercase + string.digits, k=5)))
# Plain mkdir on purpose: a name collision with an existing run raises
# FileExistsError instead of silently reusing that run's directory.
os.mkdir(file_dir)
print("Checkpoint will be saved at: ", file_dir)

# Specify the checkpoint file name to load if retraining
# ckpt_fname = r"ckpt\conference_ckpt_1R.pt"
ckpt_fname = None
# main's first positional parameter is n_iters; passing cfg alone would bind it
# to n_iters and leave cfg missing (TypeError).
main(cfg.training.n_iters, cfg, dataset_fname = dataset_fname, file_dir = file_dir, ckpt=ckpt_fname)