In [1]:
import os
from os.path import join
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from astropy.io import fits
import pyxis.torch as pxt
import scipy.optimize as so

from networks import *
from train import *
import config

# All inputs/outputs live under one project root. Set the KL_NN_ROOT
# environment variable to relocate them (e.g. when not running on /xdisk).
project_root = os.environ.get('KL_NN_ROOT', '/xdisk/timeifler/wxs0703/kl_nn')
data_dir = join(project_root, 'test_data', 'test_database')
samp_dir = join(project_root, 'samples', 'samples_massive.csv')  # a CSV file, not a directory
# The trailing '' component preserves the trailing '/' of the original paths,
# in case later cells build file paths by plain string concatenation.
fig_dir = join(project_root, 'figures', '')
model_dir = join(project_root, 'model', '')
results_dir = join(project_root, 'results', '')

In [2]:
# Single-process (world_size=1) process group: DistributedSampler, used for the
# test loader below, needs an initialized default process group.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12356"
torch.cuda.set_device(0)
# Guard so re-running this cell does not fail with
# "trying to initialize the default process group twice".
if not torch.distributed.is_initialized():
    init_process_group(backend='nccl', rank=0, world_size=1)

In [3]:
# Restore the "Deconv21" checkpoint into a DeconvNN. load_model forwards
# strict/assign to load_state_dict, so with strict=True any missing or
# unexpected key raises instead of silently loading a partial model.
model_file = join(model_dir, 'Deconv21')
model = load_model(
    DeconvNN,
    path=model_file,
    strict=True,
    assign=True,
)

  model.load_state_dict(torch.load(path), strict=strict, assign=assign)


In [4]:
# Get data loader
# When a sampler is given, the DataLoader's own `shuffle` flag does not control
# ordering -- the sampler does. DistributedSampler shuffles by default, so
# shuffle=False must be passed to it for a deterministic, in-order test pass.
test_args = list(config.test.values())  # NOTE(review): unused in this cell; presumably consumed later
test_ds = pxt.TorchDataset(data_dir)
test_dl = DataLoader(test_ds,
                     batch_size=100,
                     pin_memory=True,
                     shuffle=False,
                     sampler=DistributedSampler(test_ds, shuffle=False))

In [None]:
def predict(nfeatures, test_data, model, criterion=nn.MSELoss(), gpu_id=0,
            image_shape=(48, 48)):
    """Evaluate ``model`` on every batch of ``test_data``.

    Each batch must be a mapping with ``'img'`` (targets), ``'fid_pars'``
    (model inputs) and ``'id'`` (a CPU tensor of sample ids). The model maps
    the fiducial parameters to an image, which is scored against ``'img'``
    with ``criterion``.

    Args:
        nfeatures: Number of input features; ``'fid_pars'`` is reshaped to
            ``(-1, nfeatures)`` before being fed to the model.
        test_data: Iterable of batches (e.g. a ``DataLoader``).
        model: Module to evaluate; it is switched to eval mode.
        criterion: Loss called as ``criterion(outputs, img)``. Assumed to
            return a per-sample *mean* (as ``nn.MSELoss()`` does by default),
            because batch losses are weighted by batch size.
        gpu_id: Device passed to ``Tensor.to`` (CUDA index, or e.g. ``'cpu'``).
        image_shape: Per-sample shape that predictions and targets are
            reshaped to. Defaults to ``(48, 48)``.

    Returns:
        ``(ids, preds, targets, epoch_loss)``: the concatenated ``'id'``
        values, predictions and targets of shape ``(N, *image_shape)`` (in the
        order batches were yielded), and the square root of the
        sample-weighted mean loss (the RMSE when ``criterion`` is MSE).

    Raises:
        ValueError: If ``test_data`` yields no samples.
    """
    model.eval()
    id_chunks, pred_chunks, target_chunks = [], [], []
    weighted_loss = 0.0
    n_samples = 0

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for batch in test_data:
            img = batch['img'].float().to(gpu_id)
            fid = batch['fid_pars'].float().view(-1, nfeatures).to(gpu_id)
            outputs = model(fid)
            loss = criterion(outputs, img)

            # Weight by batch size so a short final batch does not bias the mean.
            batch_size = img.shape[0]
            weighted_loss += loss.item() * batch_size
            n_samples += batch_size

            # Collect per-batch chunks; concatenate once at the end (avoids O(n^2) vstack).
            id_chunks.append(batch['id'].numpy())
            pred_chunks.append(outputs.reshape(-1, *image_shape).cpu().numpy())
            target_chunks.append(img.reshape(-1, *image_shape).cpu().numpy())

    if n_samples == 0:
        raise ValueError("predict: test_data yielded no samples")

    epoch_loss = weighted_loss / n_samples
    epoch_loss = np.sqrt(epoch_loss)  # comment out if not using MSE

    ids = np.concatenate(id_chunks)
    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    return ids, preds, targets, epoch_loss