### Plot model comparison

In [31]:
import matplotlib
matplotlib.use('ps')
from matplotlib import rc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as mpatches
from matplotlib import colors


# Render all figure text through LaTeX; the `color` package provides the
# \textcolor commands used in the panel titles.
matplotlib.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{color}',
})

def plot_model_predictions(npz_names, fname=None, labels=None, titles=None):
    """Scatter predicted vs. true duration, one panel per group of prediction files.

    Parameters
    ----------
    npz_names : list of list of str
        One inner list per panel. Each entry is an ``.npz`` file containing
        ``targets`` and ``predictions`` arrays; column 0 of each is plotted.
        Model ``i`` of every panel is drawn in the same colour.
    fname : str, optional
        Save the figure here (dpi=128); if falsy, ``plt.show()`` is called.
    labels : list of str, optional
        Legend label per model position. Defaults to the file names of the
        first panel (previously the raw nested list was handed to the legend).
    titles : list of str, optional
        One title per panel.

    Raises
    ------
    ValueError
        If ``npz_names`` is empty.
    """
    if not npz_names:
        raise ValueError("npz_names must contain at least one panel")

    n_panels = len(npz_names)
    # Models per panel = longest group (was len(npz_names), i.e. the panel count).
    n_models = max(len(group) for group in npz_names)
    model_colors = ['c', 'm', 'y']

    def color(i):
        # Cycle so that more than three models no longer raise IndexError.
        return model_colors[i % len(model_colors)]

    if labels:
        dlabels = labels
    else:
        dlabels = [str(name).rsplit('/', 1)[-1] for name in npz_names[0]]

    ticks = np.arange(0.0, 3.5, 0.5)
    ideal = np.linspace(0.0, 3.0, 10)

    # squeeze=False keeps `axs` 2-D for any panel count (the old code unpacked exactly 3).
    fig, axs = plt.subplots(1, n_panels, sharey=True, figsize=(10 / 3 * n_panels, 4),
                            layout='tight', squeeze=False)
    axs = axs[0]
    axs[0].set_ylabel('Predicted Duration')
    axs[0].set_yticks(ticks)

    for j, (ax, group) in enumerate(zip(axs, npz_names)):
        ax.set_xlabel(r'True Duration')
        ax.set_xticks(ticks)
        ax.grid(True)
        if titles:
            ax.set_title(titles[j])

        # Proxy patches so the legend shows one colour swatch per model.
        recs = [mpatches.Rectangle((0, 0), 0.5, 0.5, fc=color(i)) for i in range(n_models)]
        ax.legend(recs, dlabels)

        for i, name in enumerate(group):
            # `with` closes the NpzFile handle (it used to be leaked).
            with np.load(name) as result:
                model_targets = np.array(result['targets'][:, 0])
                model_pred = np.array(result['predictions'][:, 0])
            ax.scatter(model_targets, model_pred, alpha=0.7, s=1.5, c=color(i))

        ax.set_xlim(0.0, 3.0)
        ax.set_ylim(0.0, 3.0)
        # y = x reference line: perfect predictions lie on it.
        ax.plot(ideal, ideal, 'k--', alpha=0.3, linewidth=1.5)
        ax.set_aspect('equal')

    if fname:
        fig.savefig(fname, dpi=128)
    else:
        plt.show()
    plt.close(fig)


# Each tuple: (training sim A, training sim B, sim whose test split is predicted).
pairs = [("p21c", "ctrpx", "zreion"), ("p21c", "zreion", "ctrpx"), ("zreion", "ctrpx", "p21c")]
titles = [
    r"Trained on \textcolor{green}{21cmFAST} and \textcolor{blue}{Central Pixel}",
    r"Trained on \textcolor{green}{21cmFAST} and \textcolor{red}{RLS}",
    r"Trained on \textcolor{blue}{Central Pixel} and \textcolor{red}{RLS}"]

path = "/users/jsolt/FourierNN/trained_models"
labels = ["Adversarial", "Non-Adversarial"]

# One panel per pair: the adversarial model, then the mixed (non-adversarial)
# model, both evaluated on the test split of the third sim.
# Disabled alternative: f"adversarial_null_v01_{m1}_{m2}_ws0.0_s03" — if it is
# re-enabled, give it a trailing comma to avoid implicit string concatenation.
npz_names = [
    [
        f"{path}/{name}/pred_{name}_on_{m3}_test.npz"
        for name in (
            f"adversarial_v02_{m1}_{m2}_alpha0.01_lr0.001_ws0.0_s02",
            f"mixed_{m1}_{m2}_m256_dur_ws0.0_lr0.003_bs64_v02",
        )
    ]
    for m1, m2, m3 in pairs
]

plot_model_predictions(npz_names, fname="adv_vs.jpeg", titles=titles, labels=labels)


[0.  0.5 1.  1.5 2.  2.5 3. ]


In [None]:

# Each adversarial model (one per training pair) predicted on the test split of all three sims.
labels = ["21cmFAST", "Central Pixel", "RLS"]

npz_names = [
    [f"{path}/{name}/pred_{name}_on_{m3}_test.npz" for m3 in ("p21c", "ctrpx", "zreion")]
    for name in (f"adversarial_v02_{m1}_{m2}_alpha0.01_lr0.003_ws0.0_s03" for m1, m2, _ in pairs)
]

plot_model_predictions(npz_names, titles=titles, labels=labels)


### Testing Adversarial Model
Load packages and initialize parameters:

In [None]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from EoR_Dataset import EORImageDataset
from adversarial_model import encoder, discriminator, regressor
from hyperparams import Model_Hyperparameters, Dataset_Hyperparameters

def get_path(model_sim, ws):
    """Return the HDF5 data file path for simulation ``model_sim``.

    Parameters
    ----------
    model_sim : str
        One of ``"p21c"``, ``"zreion"`` or ``"ctrpx"``.
    ws : float
        Value interpolated into the file name as ``..._ws{ws}...``.

    Raises
    ------
    ValueError
        For an unknown ``model_sim``. The old if/elif chain fell through and
        raised an opaque ``UnboundLocalError`` on ``dp``.
    """
    templates = {
        "p21c": "/users/jsolt/data/jsolt/21cmFAST_sims/p21c14/p21c14_ws{ws}_trnspsd_meanremoved_norm.hdf5",
        "zreion": "/users/jsolt/data/jsolt/zreion_sims/zreion21/zreion21_transposed_ws{ws}.hdf5",
        "ctrpx": "/users/jsolt/data/jsolt/21cmFAST_centralpix_v04/21cmFAST_centralpix_v04_transposed_ws{ws}.hdf5",
    }
    try:
        template = templates[model_sim]
    except KeyError:
        raise ValueError(
            f"Unknown simulation {model_sim!r}; expected one of {sorted(templates)}"
        ) from None
    return template.format(ws=ws)

# Dataset configuration for the adversarial test run (p21c only).
batchsize = 4
scale = 256
zindices = list(range(0, 17 * 30, 17))   # 30 channels: every 17th z-index
ws = 0.0

param = 'dur'
param_index = {"mdpt": 0}.get(param, 1)   # 0 for 'mdpt', 1 otherwise

train_sims = ["p21c"]
data_paths = [get_path(sim, ws) for sim in train_sims]

hp_data = Dataset_Hyperparameters(
    train_sims,
    data_paths,
    zindices,
    batchsize,
    subsample_scale=scale,
    param=param_index,
    n_limit=2 * batchsize,
)


Load data:

In [None]:
# Test split of the training sims. The old cell iterated undefined
# `train_data` / `train_dataloader` and raised NameError on a fresh kernel.
print("Loading test data...")

test_data = EORImageDataset("test", hp_data)
test_dataloader = DataLoader(test_data, batch_size=batchsize, shuffle=True)
print(len(test_dataloader))  # number of batches, printed once rather than per batch
for X, y in test_dataloader:
    print(y.shape)
    

Test in context: train a small adversarial model for a couple of epochs, then run predictions and plot the results:

In [None]:
from train_Fourier_NN import train_adversarial_NN
from predict_Fourier_NN import predict_adversarial_NN
from plot_model_results import plot_model_predictions


model_name = "test_adversarial"
model_dir = "models/" + model_name
epochs = 2
lr = 0.01
alpha = 1.0

# This section's dataset hyperparameters are `hp_data` (built above);
# `hp_train_data` is only defined in later sections, so the old call raised NameError.
hp_model = Model_Hyperparameters(model_name, hp_data, epochs=epochs, init_lr=lr, alpha=alpha)

train_adversarial_NN(hp_model)

print("Prediction Loop")
psdirs = [f"models/plots/{hp_model.MODEL_NAME}", f"{hp_model.MODEL_DIR}/pred_plots"]

for plot_save_dir in psdirs:
    # makedirs also creates a missing `models/plots` parent and tolerates re-runs.
    os.makedirs(plot_save_dir, exist_ok=True)

all_sims = ['zreion', 'p21c', 'ctrpx']
for pred_sims in [train_sims, all_sims]:
    pred_set = [
        Dataset_Hyperparameters([x], [get_path(x, ws)], zindices, batchsize,
                                subsample_scale=scale, param=param_index, n_limit=8)
        for x in pred_sims
    ]
    pred_files_test = [predict_adversarial_NN(hp_model=hp_model, hp_test=p, mode="test") for p in pred_set]

    print('Plotting...')
    title = model_name
    labels = pred_sims
    sim_suffix = "".join(f"_{sim}" for sim in pred_sims)
    for plot_save_dir in psdirs:
        fig_name_test = f"{plot_save_dir}/duration_{hp_model.MODEL_NAME}{sim_suffix}"
        # NOTE: this is plot_model_results.plot_model_predictions (imported above),
        # not the notebook-local function of the same name.
        plot_model_predictions(pred_files_test, fig_name_test, param_index, labels, title)


### Loss Visualization

In [None]:
import numpy as np
from plot_model_results import plot_loss

path = "models/adversarial_p21c_zreion_ws0.0_alpha0.1_v01/adversarial_p21c_zreion_ws0.0_alpha0.1_v01"
encloss = f"{path}_enc_loss.npz"
regloss = f"{path}_reg_loss.npz"
disloss = f"{path}_dis_loss.npz"

# (npz file, png tag, title) for each sub-network, all plotted on a linear scale.
loss_plots = [
    (encloss, "enc", "adversarial_p21c_zreion_ws0.0_alpha0.1_v01 Encoder Loss (linear)"),
    (regloss, "reg", "adversarial_p21c_zreion_ws0.0_alpha0.1_v01 Regressor Loss (linear)"),
    (disloss, "dis", "adversarial_p21c_zreion_ws0.0_alpha0.1_v01 Discriminator Loss (linear)"),
]
for npz_path, tag, plot_title in loss_plots:
    with np.load(npz_path) as loss:
        plot_loss(loss, f"{path}_{tag}_loss.png", title=plot_title, logloss=False)


In [None]:
import numpy as np
from plot_model_results import plot_loss

run_name = "adversarial_zreion_ctrpx_ws0.0_alpha0.1_beta0.1_v01"
path = f"models/{run_name}/{run_name}"
encloss = f"{path}_enc_loss.npz"

# Clip to a narrow band around 0.1, presumably to zoom in on small fluctuations.
# NOTE(review): bounds were picked by hand for this run — confirm.
clip_lo, clip_hi = 0.09, 0.11

loss = {}
with np.load(encloss) as data:
    loss["train"] = np.clip(data["train"], clip_lo, clip_hi)
    loss["val"] = np.clip(data["val"], clip_lo, clip_hi)

# Title is derived from the run actually loaded; it used to name the p21c/zreion run.
plot_loss(loss, f"{path}_enc_loss.png", title=f"{run_name} Encoder Loss (linear)", logloss=False)


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import os
import h5py
import pandas as pd
%matplotlib inline

modules = ["enc", "reg", "dis"]
sims = ["zreion", "ctrpx"]
allsims = ["p21c", "zreion", "ctrpx"]
start = 10  # leading epochs to skip
for alpha in [0.1]:
    for ws in [0.0]:
        name = f"adversarial_v02_{sims[0]}_{sims[1]}_alpha{alpha}_lr0.003_ws{ws}"
        lossdict = {}
        for module in modules:
            path = f"models/{name}/{name}_{module}_loss.npz"
            with np.load(path) as data:
                lossdict[module] = {
                    'train': pd.DataFrame(data['train'][start:]),
                    'val': pd.DataFrame(data['val'][start:]),
                }

        fname = f"models/{name}/{name}_all_loss.png"
        title = f"{name} Loss"

        fig, axs = plt.subplots(len(lossdict), 1, sharex=True, tight_layout=True, figsize=(6, 12))
        fig.suptitle(title)
        for ax, (label, loss) in zip(axs, lossdict.items()):
            ax.grid(True)
            ax.set_ylabel('MSE Loss (tanh + exp running avg)')

            # 1-based epoch numbers offset by `start` (named so it no longer clobbers the global `epochs`).
            epoch_axis = np.arange(start + 1, len(loss["train"]) + start + 1)
            # tanh bounds the loss; an EWM with com=5 smooths the curve.
            ax.plot(epoch_axis, np.tanh(loss["val"]).ewm(com=5.0).mean(), label='Validation loss', linewidth=0.7)
            ax.plot(epoch_axis, np.tanh(loss["train"]).ewm(com=5.0).mean(), label='Training loss', linewidth=0.7)
            ax.set_title(label)
            ax.legend()
            # NOTE(review): hard-coded marker at epoch 1000 — presumably a schedule change; confirm.
            # The placeholder label="label" never reached the legend and is dropped.
            ax.axvline(1000, color='red', ls=":")
        axs[-1].set_xlabel('Epochs')
        # `fname` was computed but never used; save the figure to it.
        fig.savefig(fname)
        plt.show()

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import os
import h5py
import pandas as pd
%matplotlib inline

def plot_loss_grid(lossdict, fname, title="", start=10, steps=None, steplabels=None):
    """Plot tanh-squashed, EWM-smoothed train/val loss curves, one row per entry.

    Parameters
    ----------
    lossdict : dict
        Maps a panel label (e.g. module name) to a mapping with ``'train'`` and
        ``'val'`` 1-D loss arrays (one value per epoch, same length).
    fname : str
        If non-empty, the figure is also saved to this path (it used to be ignored).
    title : str
        Figure suptitle.
    start : int
        Number of leading epochs to skip.
    steps : sequence of int, optional
        Epoch positions at which to draw vertical dotted markers.
    steplabels : sequence of str, optional
        Text for each marker in ``steps``; markers beyond its length get no text
        (previously this raised IndexError).

    Raises
    ------
    ValueError
        If ``lossdict`` is empty.
    """
    if not lossdict:
        raise ValueError("lossdict is empty")
    # None defaults instead of shared mutable [] defaults.
    steps = list(steps) if steps is not None else []
    steplabels = list(steplabels) if steplabels is not None else []

    # squeeze=False so a single-entry lossdict still yields indexable axes.
    fig, axs = plt.subplots(len(lossdict), 1, sharex=True, tight_layout=True,
                            figsize=(6, 12), squeeze=False)
    axs = axs[:, 0]
    fig.suptitle(title)

    for ax, (label, loss) in zip(axs, lossdict.items()):
        ax.grid(True)
        ax.set_ylabel('MSE Loss (tanh + exp running avg)')

        train_loss = pd.DataFrame(np.tanh(loss['train'][start:])).ewm(com=5.0).mean()
        val_loss = pd.DataFrame(np.tanh(loss['val'][start:])).ewm(com=5.0).mean()

        epochs = np.arange(start + 1, len(train_loss) + start + 1)

        ax.plot(epochs, val_loss, label='Validation', linewidth=0.7)
        ax.plot(epochs, train_loss, label='Training', linewidth=1.0)

        for i, step in enumerate(steps):
            ax.axvline(step, color='red', ls=":", alpha=0.5)
            if i >= len(steplabels):
                continue
            # Read limits after axvline, which may have widened them.
            xmax = ax.get_xlim()[1]
            ymin, ymax = ax.get_ylim()
            # Right-aligned just left of the marker, near the top of the panel.
            ax.text(
                step - 0.02 * xmax,
                ymax - 0.03 * (ymax - ymin),
                steplabels[i],
                horizontalalignment='right',
                verticalalignment='top',
                color='#000000',
                backgroundcolor='#eeeeeec0',
            )

        ax.set_title(label)
        ax.legend(loc='upper left')
    axs[-1].set_xlabel('Epochs')

    if fname:
        fig.savefig(fname)
    plt.show()


modules = ["enc", "reg", "dis"]
sims = ["zreion", "ctrpx"]
allsims = ["p21c", "zreion", "ctrpx"]
steplabels = ["lr=0.003", "lr=0.001", "lr=0.003"]


for alpha in [0.01]:
    for ws in [0.0]:
        base = f"adversarial_v02_{sims[0]}_{sims[1]}_alpha{alpha}"
        # Presumably consecutive stages of one run, concatenated in order;
        # `steplabels` gives each stage's learning rate.
        names = [
            f"{base}_lr0.003_ws{ws}",
            f"{base}_lr0.001_ws{ws}_s02",
            f"{base}_lr0.003_ws{ws}_s03",
        ]
        lossdict = {}
        steps = []
        for module in modules:
            train_parts, val_parts = [], []
            for run in names:
                with np.load(f"models/{run}/{run}_{module}_loss.npz") as data:
                    train_parts.append(data["train"])
                    val_parts.append(data["val"])
            lossdict[module] = {
                "train": np.concatenate(train_parts),
                "val": np.concatenate(val_parts),
            }
            # Cumulative epoch count at the end of each stage -> marker positions.
            steps = np.cumsum([len(part) for part in train_parts]).tolist()

        # Built from this cell's run names. The old line used a stale global `name`,
        # because list-comprehension variables do not leak in Python 3.
        title = f"{base}_ws{ws} Loss"
        print(steps)
        plot_loss_grid(lossdict, "", title, steps=steps, steplabels=steplabels)


### Autoencoder

In [1]:
import torch
from torch.utils.data import DataLoader
from EoR_Dataset import EORImageDataset
from autoencoder import autoencoder
from hyperparams import Dataset_Hyperparameters
import numpy as np

###
# Load Data
###
ws = 0.0
sims = ["ctrpx"]
data_paths = [f"/users/jsolt/data/jsolt/21cmFAST_centralpix_v05/21cmFAST_centralpix_v05_transposed_ws{ws}.hdf5"]

# Keyword settings gathered first, then splatted after the positional sims/paths.
dataset_kwargs = dict(
    zindices=np.linspace(0, 511, 30, dtype=int),
    batchsize=1,
    subsample_scale=256,
    param=1,
    n_limit=10,
)
hp = Dataset_Hyperparameters(sims, data_paths, **dataset_kwargs)

data = EORImageDataset("test", hp)

print(data[0][0].shape)

Sim 0: 10 samples
Total number of samples: 10
Loading cube 0 of 10 from sim 0 (pointer = 0)...
torch.Size([30, 256, 256])


In [2]:
###
# Load Model(s)
###
versions = [24.0, 25.0, 26.0, 27.0]
names = {v: f"autoencoder_v{v:0>4}_ctrpx_ws{ws}" for v in versions}
paths = {v: f"models/{name}/{name}.pth" for v, name in names.items()}
models = {}

for v, path in paths.items():
    ae = autoencoder()
    # Keep every model on the CPU. The evaluation cells below pass CPU tensors
    # straight from `data`, so the former `.cuda()` call caused a device
    # mismatch on GPU machines.
    ae.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    models[v] = ae

print(paths)

{24.0: 'models/autoencoder_v24.0_ctrpx_ws0.0/autoencoder_v24.0_ctrpx_ws0.0.pth', 25.0: 'models/autoencoder_v25.0_ctrpx_ws0.0/autoencoder_v25.0_ctrpx_ws0.0.pth', 26.0: 'models/autoencoder_v26.0_ctrpx_ws0.0/autoencoder_v26.0_ctrpx_ws0.0.pth', 27.0: 'models/autoencoder_v27.0_ctrpx_ws0.0/autoencoder_v27.0_ctrpx_ws0.0.pth'}


In [3]:
%matplotlib inline
import numpy as np
from plot_model_results import plot_image_rows

zind = np.linspace(0, 29, 30, dtype=int)
# Row order: the raw input, then one reconstruction per entry of `versions`.
rowlabels = ["input", "control", "shufflez", "zoomin", "shufflez + zoomin"]

for lci in (0, 1, 2, 8):
    with torch.no_grad():
        X, _, _, _ = data[lci]
        inpt = X[None, :, :, :]
        rows = [inpt[0, zind]]
        for v in versions:
            model = models[v]
            model.eval()
            outpt = model(inpt)
            rows.append(outpt[0, zind])

    title = f"lightcone {lci}"
    fname = f"model_compare_lightcone_{lci}.png"
    plot_image_rows(rows, title=title, fname=fname, collabels=zind, rowlabels=rowlabels)




In [76]:
%matplotlib inline
import matplotlib.pyplot as plt
from plot_model_results import plot_image_grid, plot_image_rows


# Pick the model explicitly. This cell used to rely on `model` left over from
# the comparison loop above, i.e. the last entry of `versions`.
model = models[versions[-1]]
model.eval()

zsample = np.linspace(0, 29, 6, dtype=int)
rowlabels = ['input', 'output', 'corrcoef']

for lci in range(9):
    X, *_ = data[lci]
    with torch.no_grad():
        inpt = X[None, :, :, :]
        outpt = model(inpt)

    n_rows = inpt.shape[-2]
    metric = np.zeros_like(inpt.numpy())
    for i in range(inpt.shape[1]):
        # np.corrcoef treats every image row as a variable and stacks both images
        # (2 * n_rows variables). The lower-left block is corr(output row, input row).
        # Rows with zero variance give NaN.
        metric[0, i] = np.corrcoef(inpt[0, i].numpy(), outpt[0, i].numpy())[n_rows:, :n_rows]

    rows = [inpt[0, zsample], outpt[0, zsample], metric[0, zsample]]
    plot_image_rows(rows, rowlabels=rowlabels, collabels=zsample, title=f'ctrpx lightcone {lci}',
                    fname=f'ctrpx_lightcone{lci}_corrcoef.png', vmin=-1, vmax=1)
    
    
    


### VAE

In [1]:
import torch
from torch.utils.data import DataLoader
from EoR_Dataset import EORImageDataset
from variational_autoencoder import vae
from hyperparams import DataHyperparameters, ModelHyperparameters
import numpy as np

###
# Load Data
###
ws = 0.0
sims = ["ctrpx"]
data_paths = [f"/users/jsolt/data/jsolt/21cmFAST_centralpix_v05/21cmFAST_centralpix_v05_transposed_ws{ws}.hdf5"]

# Keyword config collected first, then splatted into the hyperparameter object.
train_data_config = dict(
    sims=sims,
    data_paths=data_paths,
    zindices=np.linspace(0, 511, 30, dtype=int).tolist(),
    boxlength=256,
    param_index=1,
    ztransform=["zoomin"],
    lenlimit=8,
)
hp_train_data = DataHyperparameters(**train_data_config)

data = EORImageDataset("test", hp_train_data)

print(data[0][0].shape)
print(len(data))


Sim 0: 8 samples
Total number of samples: 8
Loading cube 0 of 8 from sim 0 (pointer = 0)...
torch.Size([30, 256, 256])
8


In [2]:
###
# Load Model(s)
###
hd1, hd2, ld = 16, 32, 64   # hidden_dim_1, hidden_dim_2, latent_dim
models = {}

# Per-version convolution settings.
ks_dict = {3.0: 5, 4.0: 7}
pad_dict = {3.0: 2, 4.0: 3}

names = {i: f"single_channel_vae_v{i:0>4}_ctrpx_ws0.0" for i in (3.0, 4.0)}

for i, name in names.items():
    hp_model = ModelHyperparameters(
        model_name=name,
        device="cuda" if torch.cuda.is_available() else "cpu",
        training_data_hp=hp_train_data,
        batchsize=64,
        epochs=1000,
        initial_lr=1e-3,
        lr_milestones=[],
        lr_gamma=0.1,
        parent_model=None,
        input_dim=1,
        hidden_dim_1=hd1,
        hidden_dim_2=hd2,
        latent_dim=ld,
        kernel_size=ks_dict[i],
        stride=2,
        padding=pad_dict[i],
    )
    path = f"{hp_model.model_dir}/{hp_model.model_name}.pth"

    vae_model = vae(hp_model)
    vae_model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    models[i] = vae_model

print(models.keys())

dict_keys([3.0, 4.0])


In [3]:
%matplotlib inline
import importlib
import numpy as np
import plot_model_results
importlib.reload(plot_model_results)

zind = np.linspace(0, 29, 8, dtype=int)

for lci in [2]:
    rows = {}
    with torch.no_grad():
        X, *_, = data[lci]
        inpt = X[None, zind, :, :]
        rows['input'] = inpt[0]

        for i in [3.0, 4.0]:
            model = models[i]
            model.eval()

            outpt = torch.zeros_like(inpt)
            for j in range(len(zind)):
                slice = inpt[0,j]
                outpt[0,j] = model.decode(model.encode(slice[None,None,:,:])[0])
            rows[f"(v{i:0>4})"] = outpt[0]

            
    title = f"Single-Channel VAE: Lightcone {lci}"
    fname = f"single_channel_vae_ctrpx_ws0.0_lci{lci}_result.png"
    plot_model_results.plot_image_rows(rows, collabels=zind, title=title, fname=fname)


In [6]:
import importlib  # previously relied on an import left over from another cell
import numpy as np
import plot_model_results
importlib.reload(plot_model_results)

names = [
    "single_channel_vae_v01.0_ctrpx_ws0.0_16_32_64_lr0.0001",
    "single_channel_vae_v03.0_ctrpx_ws0.0",
    "single_channel_vae_v04.0_ctrpx_ws0.0",
]

n_epochs = 500  # compare only the first 500 epochs of each run

keys = ['train', 'val']
for key in keys:
    loss = {}
    for name in names:
        loss_path = f"trained_models/{name}/{name}_loss.npz"
        with np.load(loss_path) as f:
            loss[name] = f[key][:n_epochs]
    # Include the split in the file name and title. Before, the 'val' plot
    # overwrote the 'train' plot under the same file name.
    fname = f"single_channel_vae_kernel_size_loss_comparison_{key}.png"
    title = f"Single-Channel VAE: Kernel Size Loss Comparison ({key})"
    plot_model_results.plot_loss_comparison(loss, fname, title=title, transform=np.log10, ylabel="Loss (log)")

Loss plot saved.
Loss plot saved.


### Hyperparameter encoding

In [4]:
import json
import jsonpickle
import numpy as np

class DataHyperparameters():
    """Keyword-argument container for dataset hyperparameters.

    Defaults ``tvt_dict`` (train/val/test split fractions) and ``lenlimit=-1``
    are set first and can be overridden by any kwarg. ``n_datasets`` and
    ``n_channels`` are derived from the ``data_paths`` and ``zindices`` kwargs
    (0 when absent).
    """
    def __init__(self, /, **kwargs):
        # Defaults (a fresh dict per instance, so instances never share it)
        self.tvt_dict = {"train": 0.8, "val": 0.10, "test": 0.10}
        self.lenlimit = -1

        # Caller-supplied kwargs override the defaults
        self.__dict__.update(kwargs)

        # Derived attributes for convenience
        self.n_datasets = len(kwargs.get("data_paths", []))
        self.n_channels = len(kwargs.get("zindices", []))

    def __repr__(self):
        keys = sorted(self.__dict__)
        items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
        return "{}({})".format(type(self).__name__, ", ".join(items))

    def __eq__(self, other):
        # Return NotImplemented for foreign types instead of raising
        # AttributeError on objects without a __dict__ (ints, None, object()).
        if not isinstance(other, DataHyperparameters):
            return NotImplemented
        return self.__dict__ == other.__dict__




class ModelHyperparameters():
    """Keyword-argument container for model hyperparameters (no defaults)."""
    def __init__(self, /, **kwargs):
        # Every kwarg becomes an attribute
        self.__dict__.update(kwargs)

    def __repr__(self):
        keys = sorted(self.__dict__)
        items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
        return "{}({})".format(type(self).__name__, ", ".join(items))

    def __eq__(self, other):
        # Return NotImplemented for foreign types instead of raising
        # AttributeError on objects without a __dict__ (ints, None, object()).
        if not isinstance(other, ModelHyperparameters):
            return NotImplemented
        return self.__dict__ == other.__dict__



# Toy hyperparameter objects used to exercise the JSON round-trip below.
data_kwargs = dict(
    sims=["sim1", "sim2"],
    data_paths=["path1", "path2"],
    zindices=np.linspace(0, 511, 30, dtype=int).tolist(),
    batchsize=64,
    boxlength=256,
    param_index=1,
    ztransform=["shufflez", "zoomin"],
    lenlimit=200,
)
hp_train_data = DataHyperparameters(**data_kwargs)

model_kwargs = dict(
    model_name="model_name",
    hp_data=hp_train_data,
    epochs=2000,
    init_lr=0.001,
    parent_model=None,
)
hp_model = ModelHyperparameters(**model_kwargs)


def save_hyperparameters(hp, path="test_hyperparameters.json"):
    """Write ``hp`` to ``path`` as indented jsonpickle JSON.

    Parameters
    ----------
    hp : object
        Any jsonpickle-encodable object (here a hyperparameter container).
    path : str, optional
        Output file. Defaults to the name that used to be hard-coded, so
        existing one-argument calls behave as before.
    """
    # Encode before opening, so a failed encode doesn't truncate an existing file.
    # Named `encoded` rather than `pickle`, which shadowed the stdlib module name.
    encoded = jsonpickle.encode(hp, indent=4)
    with open(path, 'w') as f:
        f.write(encoded)

def load_hyperparameters(path):
    """Read a jsonpickle-encoded file at ``path`` and return the decoded object.

    SECURITY: like pickle, jsonpickle.decode can instantiate arbitrary
    classes. Only load files you created or otherwise trust.
    """
    with open(path) as fh:
        return jsonpickle.decode(fh.read())

# Round-trip check: write hp_model to JSON, read it back, and show the result.
save_hyperparameters(hp_model)
x = load_hyperparameters("test_hyperparameters.json")
print(x)

ModelHyperparameters(epochs=2000, hp_data=DataHyperparameters(batchsize=64, boxlength=256, data_paths=['path1', 'path2'], lenlimit=200, n_channels=30, n_datasets=2, param_index=1, sims=['sim1', 'sim2'], tvt_dict={'train': 0.8, 'val': 0.1, 'test': 0.1}, zindices=[0, 17, 35, 52, 70, 88, 105, 123, 140, 158, 176, 193, 211, 229, 246, 264, 281, 299, 317, 334, 352, 370, 387, 405, 422, 440, 458, 475, 493, 511], ztransform=['shufflez', 'zoomin']), init_lr=0.001, model_name='model_name', parent_model=None)


### Misc

In [8]:
import torch

# Scratch check of repeat_interleave: each row of `a` is repeated 30 times
# consecutively along dim 0.
a = torch.tensor([[8, 2, 8], [9, 1.5, 9.5]], dtype=torch.get_default_dtype())

print(a.shape)
b = a.repeat_interleave(30, dim=0)
print(b.shape)
print(b[:5])

# NOTE: the saved output "tensor([1.])" below is stale; this evaluates to the
# mean of all six entries of `a`.
a.mean()

torch.Size([2, 3])
torch.Size([60, 3])
tensor([[8., 2., 8.],
        [8., 2., 8.],
        [8., 2., 8.],
        [8., 2., 8.],
        [8., 2., 8.]])


tensor([1.])

In [76]:
import torch
from torch import nn
# Debug aid: autograd errors report the forward op that caused them. This slows
# autograd down, so turn it off once the loss below has been validated.
torch.autograd.set_detect_anomaly(mode=True)

def corrcoef_loss(input, target, reduction='mean'):
    """Row-wise ``1 - Pearson correlation`` between ``input`` and ``target``.

    Each row (second-to-last dimension) is a variable and the last dimension
    holds its observations. Row ``k`` of ``input`` is correlated with row
    ``k`` of ``target``.

    Parameters
    ----------
    input, target : torch.Tensor
        Same shape ``(..., H, W)`` with any number of leading dims
        (the old indexing only worked for 4-D tensors).
    reduction : str
        ``'mean'`` or ``'sum'`` reduce to a scalar. Any other value returns the
        unreduced ``(..., H)`` tensor.

    Rows with zero variance use a std of 1 instead of 0, so their correlation
    is 0 (loss 1) rather than NaN.
    """
    # Stack both sets of rows and centre each row over its W observations.
    X = torch.cat((input, target), dim=-2)
    X = X - torch.mean(X, -1, keepdim=True)
    # (..., 2H, 2H) unbiased covariance matrix.
    c = torch.matmul(X, X.transpose(-2, -1)) / (X.shape[-1] - 1)

    var = torch.diagonal(c, dim1=-2, dim2=-1)
    std = torch.sqrt(torch.where(var == 0, torch.ones_like(var), var))
    # Out-of-place normalisation: avoids in-place ops on autograd tensors, and
    # `...` indexing works for any number of leading dims.
    corr = c / std[..., :, None] / std[..., None, :]

    # Diagonal offset by H picks corr(input row k, target row k).
    ccd = 1 - torch.diagonal(corr, offset=c.shape[-1] // 2, dim1=-2, dim2=-1)

    if reduction == 'mean':
        return ccd.mean()
    elif reduction == 'sum':
        return ccd.sum()
    return ccd



batch_size = 4
channels = 3
imsize = 16
shape = (batch_size, channels, imsize, imsize)
params = batch_size * channels * imsize ** 2   # total element count
print(params)

# Two independent uniform [0, 1) tensors.
a = torch.rand(shape)
b = torch.rand(shape)
a.requires_grad = True
b.requires_grad = True

# Compare summed vs. mean-reduced values of MSE, BCE and the correlation loss.
u, v = (nn.functional.mse_loss(a, b, reduction=r) for r in ('sum', 'mean'))
w, x = (nn.functional.binary_cross_entropy(a, b, reduction=r) for r in ('sum', 'mean'))
y, z = (corrcoef_loss(a, b, reduction=r) for r in ('sum', 'mean'))

report = [
    ("\nMSE:", (u, v)),
    ("\nBCE:", (w, x)),
    ("\nCC:", (y, z)),
    ("\nSummed BCE / Summed CC:", (w / y,)),
    ("\nSummed MSE / Summed CC:", (u / y,)),
]
for heading, values in report:
    print(heading)
    for value in values:
        print(f"{value.item():.3f}")


3072

MSE:
508.662
0.166

BCE:
3068.094
0.999

CC:
192.242
1.001

Summed BCE / Summed CC:
15.960

Summed MSE / Summed CC:
2.646


In [78]:
# Target is a perturbed copy of `a`: entries below 0.5 are scaled by 0.9.
a = torch.rand(shape)
b = torch.where(a < 0.5, a * 0.9, a)
a.requires_grad = True
b.requires_grad = True

u, v = (nn.functional.mse_loss(a, b, reduction=r) for r in ('sum', 'mean'))
w, x = (nn.functional.binary_cross_entropy(a, b, reduction=r) for r in ('sum', 'mean'))
y, z = (corrcoef_loss(a, b, reduction=r) for r in ('sum', 'mean'))

report = [
    ("\nMSE:", (u, v)),
    ("\nBCE:", (w, x)),
    ("\nCC:", (y, z)),
    ("\nSummed BCE / Summed CC:", (w / y,)),
    ("\nSummed MSE / Summed CC:", (u / y,)),
]
for heading, values in report:
    print(heading)
    for value in values:
        print(f"{value.item():.3f}")



MSE:
1.297
0.000

BCE:
1499.751
0.488

CC:
0.235
0.001

Summed BCE / Summed CC:
6388.668

Summed MSE / Summed CC:
5.525


In [4]:


def func(*args):
    """Demo of *args: print the packed tuple, then the values space-separated."""
    packed = args
    for line in (repr(packed), " ".join(map(str, packed))):
        print(line)

func(1, 2, 3)


(1, 2, 3)
1 2 3


In [1]:
import torch
from torch import corrcoef