In [1]:
import torch, pickle, time, os, random
import numpy as np
import os.path as osp
import matplotlib.pyplot as plt
import torch_geometric as tg
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
# accelerate huggingface to GPU
if torch.cuda.is_available():
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device

torch.manual_seed(42)
random.seed(42)

In [2]:
cols_t=['mstar stellar mass [1.0E09 Msun](0)',
 ' v_disk rotation velocity of disk [km/s] (1)',
 ' r_bulge 3D effective radius of bulge [kpc](2)',
 ' mcold cold gas mass in disk [1.0E09 Msun](3)',
 ' mHI cold gas mass [1.0E09 Msun](4)',
 ' mH2 cold gas mass [1.0E09 Msun](5)',
 ' mHII cold gas mass [1.0E09 Msun](6)',
 ' sfrave100myr SFR averaged over 100 Myr [Msun/yr](7)']

cols_t=['mstar stellar mass [1.0E09 Msun](0)',
 ' v_disk rotation velocity of disk [km/s] (1)',
 ' mcold cold gas mass in disk [1.0E09 Msun](3)',
 ' sfrave100myr SFR averaged over 100 Myr [Msun/yr](7)']

all_cols=np.array([0,2,4,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,35]+list(range(37,60)))
targets=[8,11,14,15,16,17,18,23]
targets=[8,11,15,23]

In [3]:
os.listdir(osp.expanduser('~/../../../scratch/gpfs/cj1223/GraphStorage/'))

['vlarge_all_4t_z1.0_standard_quant',
 'vlarge_all_4t_z0.3_quantile_raw',
 'vlarge_4t_quantile_raw_redshift_75_all',
 'vlarge_all_4t_z1.0_quantile_raw',
 'vlarge_all_4t_z0.3_None',
 'vlarge_all_4t_z3.0_quantile_raw',
 'vlarge_all_4t_z2.0_standard_quant',
 'vlarge_all_4t_z0.8_quantile_raw',
 'tvt_idx',
 'vlarge_all_4t_z2.0_None',
 'redshift_scan_0',
 'testid_all_4t_z2.0_None',
 'vlarge_all_4t_z0.0_quantile_stand',
 'vlarge_all_multi_try1',
 'vlarge_4t_quantile_raw_redshift_99_all',
 'vlarge_all_4t_z2.0_quantile_raw',
 'vlarge_all_4t_z0.0_standard_quant',
 'vlarge_all_4t_z0.5_quantile_quant',
 'vlarge_4t_quantile_raw_redshift_50_all',
 'vlarge_all_4t_z2.0_quantile_stand',
 'vlarge_all_4t_z1.0_quantile_quant',
 'transformers',
 'vlarge_all_4t_z0.0_standard_raw',
 'vlarge_all_4t_quantile_raw_final',
 'vlarge_all_4t_z0.5_standard_stand',
 'vlarge_all_4t_z1.8_quantile_raw',
 'vlarge_all_4t_z0.5_standard_quant',
 'vlarge_all_4t_zall_quantile_raw_trainandtest',
 'vlarge_all_4t_z0.0_quantile_ra

In [4]:
case='vlarge_all_multi_try1/vlarge_all_multisimple_z0.0_quantile_raw'
case='vlarge_all_4t_z0.0_quantile_raw'



data=pickle.load(open(osp.expanduser(f'~/../../../scratch/gpfs/cj1223/GraphStorage/{case}/data.pkl'), 'rb'))

In [5]:
# from torch_geometric.data import Data

# cols_t=['mstar stellar mass [1.0E09 Msun](0)',
#  ' mcold cold gas mass in disk [1.0E09 Msun](3)']
# cols_t=['mstar stellar mass [1.0E09 Msun](0)',
#   ' v_disk rotation velocity of disk [km/s] (1)']
# targets=[8,15]
# data=[]
# for d in datat:
#     data.append(Data(x=d.x, edge_index=d.edge_index, edge_attr=d.edge_attr, y=d.y[[0,1]]))

In [6]:
ys=[]
for d in data:
    ys.append(d.y.numpy())
ys=np.vstack(ys)


In [7]:
np.corrcoef(ys[:,0],ys[:,1])

array([[1.        , 0.89590685],
       [0.89590685, 1.        ]])

In [8]:
n_targ=len(data[0].y)
n_feat=len(data[0].x[0])
n_feat, n_targ

(43, 4)

In [9]:
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, LeakyReLU, Module, ReLU, Sequential, ModuleList
from torch_geometric.nn import SAGEConv, global_mean_pool, norm, global_max_pool, global_add_pool, MetaLayer
from torch_scatter import scatter_mean, scatter_sum, scatter_max, scatter_min
from torch import cat, square,zeros, clone, abs, sigmoid, float32, tanh

class MLP(Module):
    def __init__(self, n_in, n_out, hidden=64, nlayers=2, layer_norm=True):
        super().__init__()
        '''Simple two_layer MLP class with ReLU activiation + layernorm to use later'''
        layers = [Linear(n_in, hidden), ReLU()]
        for i in range(nlayers):
            layers.append(Linear(hidden, hidden))
            layers.append(ReLU()) 
        if layer_norm:
            layers.append(LayerNorm(hidden))
        layers.append(Linear(hidden, n_out))
        self.mlp = Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class Sage(Module):
    def __init__(self, hidden_channels=128, in_channels=43, out_channels=4, encode=True, conv_layers=3, conv_activation='relu', 
                    decode_layers=2, decode_activation='leakyrelu', layernorm=True, variance=True, agg='sum', rho=6):
        super(Sage, self).__init__()
        '''Model built upon the GraphSAGE convolutional layer. This is a node only model (no global, no edge).
        Model takes a data object from a dataloader in the forward call and takes out the rest itself. 
        hidden_channels, n_in, n_out must be specified
        Most other things can be customized at wish, e.g. activation functions for which ReLU and LeakyReLU can be used'''
        self.encode=encode
        if self.encode:
            self.node_enc = MLP(in_channels, hidden_channels, layer_norm=True)
        self.decode_activation=decode_activation
        self.conv_activation=conv_activation
        self.layernorm=layernorm
        self.in_channels=in_channels
        self.out_channels=out_channels
        self.hidden_channels=hidden_channels
        self.variance=variance
        self.agg=agg
        self.rho=rho
        ########################
        # Convolutional Layers #
        ######################## 

        self.convs=ModuleList()
        if self.encode:
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        else:
            self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(int(conv_layers-1)):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))

        ##################
        # Decode Layers #
        ##################       

        self.decoders = ModuleList()
        self.norms = ModuleList()
        for _ in range(out_channels):
            self.decoder=ModuleList()
            self.norm=ModuleList()
            for i in range(decode_layers):
                if i==decode_layers-1: ## if final layer, make layer with only one output
                    self.norm.append(LayerNorm(normalized_shape=hidden_channels))
                    self.decoder.append(Linear(hidden_channels, 1))
                else:
                    self.norm.append(LayerNorm(normalized_shape=hidden_channels))
                    self.decoder.append(Linear(hidden_channels, hidden_channels))
            self.decoders.append(self.decoder)
            self.norms.append(self.norm)

        ###################
        # Variance Layers #
        ###################

        if variance:
            self.sigs = ModuleList()
            self.sig_norms = ModuleList()
            for _ in range(out_channels):
                self.sig=ModuleList()
                self.sig_norm=ModuleList()
                for i in range(decode_layers):
                    if i==decode_layers-1:
                        self.sig_norm.append(LayerNorm(normalized_shape=hidden_channels))
                        self.sig.append(Linear(hidden_channels, 1))
                    else:
                        self.sig_norm.append(LayerNorm(normalized_shape=hidden_channels))
                        self.sig.append(Linear(hidden_channels, hidden_channels))
                self.sigs.append(self.sig)
                self.sig_norms.append(self.sig_norm)

        ######################
        # Co-Variance Layers #
        ######################

        if self.rho!=0:
            self.rhos = ModuleList()
            self.rho_norms = ModuleList()
            for _ in range(self.rho):
                self.rho_l=ModuleList()
                self.rho_norm=ModuleList()
                for i in range(decode_layers):
                    if i==decode_layers-1:
                        self.rho_norm.append(LayerNorm(normalized_shape=hidden_channels))
                        self.rho_l.append(Linear(hidden_channels, 1))
                    else:
                        self.rho_norm.append(LayerNorm(normalized_shape=hidden_channels))
                        self.rho_l.append(Linear(hidden_channels, hidden_channels))
                self.rhos.append(self.rho_l)
                self.rho_norms.append(self.rho_norm)
        
        #####################
        # Activation Layers #
        #####################
        
        self.conv_act=self.conv_act_f()
        self.decode_act=self.decode_act_f() ## could apply later

    def conv_act_f(self):
        if self.conv_activation =='relu':
            print('RelU conv activation')
            act = ReLU()
            return act
        if self.conv_activation =='leakyrelu':
            print('LeakyRelU conv activation')
            act=LeakyReLU()
            return act
        if not self.conv_activation:
            raise ValueError("Please specify a conv activation function")

    def decode_act_f(self):
        if self.decode_activation =='relu':
            print('RelU decode activation')
            act = ReLU()
            return act
        if self.decode_activation =='leakyrelu':
            print('LeakyRelU decode activation')
            act=LeakyReLU()
            return act
        if not self.decode_activation:
            print("Please specify a decode activation function")
            return None

    def forward(self, graph):

        #get the data
        x, edge_index, batch = graph.x, graph.edge_index, graph.batch
        if self.encode:
            x = self.node_enc(x)
        
        #convolutions 
        for conv in self.convs:
            x = conv(x, edge_index)
            x=self.conv_act(x)
        if self.agg=='sum':
            x = global_add_pool(x, batch)
        if self.agg=='max':
            x = global_max_pool(x, batch)
        
        #decoder
        
        x_out=[]
        for norm, decode in zip(self.norms, self.decoders):
            x1=clone(x)
            for n, d in zip(norm, decode):
                x1=d(n(x1))
                x1=self.decode_act(x1)
            x_out.append(x1)
        x_out=cat(x_out, dim=1)
        
        # variance
        if self.variance:
            sig=[]
            for norm, decode in zip(self.sig_norms, self.sigs):
                x1=clone(x)
                for n, d in zip(norm, decode):
                    x1=d(n(x1))
                    x1=self.decode_act(x1)
                sig.append(x1)
            sig=abs(cat(sig, dim=1))

        if self.rho!=0:
            rho=[]
            for norm, decode in zip(self.rho_norms, self.rhos):
                x1=clone(x)
                for n, d in zip(norm, decode):
                    x1=d(n(x1))
                    x1=self.decode_act(x1)
                rho.append(x1)
            rho=abs(cat(rho, dim=1)) ### not sure this works with only 1d
        
        if self.variance:
            if self.rho!=0:
                return x_out, sig, tanh(rho)
            else:
                return x_out, sig
        else:
            return x_out

## Loss implementation

I'll first try the 2d version formulated in the below fashion since $\rho$ fits nicely with a sigmoid

$$
f(x, y)=\frac{1}{2 \pi \sigma_{X} \sigma_{Y} \sqrt{1-\rho^{2}}} \exp \left(-\frac{1}{2\left(1-\rho^{2}\right)}\left[\left(\frac{x-\mu_{X}}{\sigma_{X}}\right)^{2}-2 \rho\left(\frac{x-\mu_{X}}{\sigma_{X}}\right)\left(\frac{y-\mu_{Y}}{\sigma_{Y}}\right)+\left(\frac{y-\mu_{Y}}{\sigma_{Y}}\right)^{2}\right]\right)
$$
Where $\rho$ is the correlation between $X$ and $Y$ and where $\sigma_{X}>0$ and $\sigma_{Y}>0 .$ In this case,
$$
\boldsymbol{\mu}=\left(\begin{array}{l}
\mu_{X} \\
\mu_{Y}
\end{array}\right), \quad \boldsymbol{\Sigma}=\left(\begin{array}{cc}
\sigma_{X}^{2} & \rho \sigma_{X} \sigma_{Y} \\
\rho \sigma_{X} \sigma_{Y} & \sigma_{Y}^{2}
\end{array}\right)
$$

Defining $$z_x = \left(\frac{x-\mu_{X}}{\sigma_{X}}\right)$$, we can write the NLL as:

$$
ln(\sigma_{x})+ln(\sigma_{y})+1/2*ln(1-\rho^2)+\frac{1}{2(1-\rho^2)}(z_x^2+z_y^2-2*\rho*(z_x*z_y))
$$


In [10]:
def l_func_2d(pred, ys, sig1, sig2, rho):
#     global z1, z2, sigloss, rholoss, factor
    z1=(pred[:,0]-ys[:,0])/sig1
    z2=(pred[:,1]-ys[:,1])/sig2
    sigloss=torch.sum(torch.log(sig1)+torch.log(sig2))
    rholoss=torch.sum(torch.log(1-rho**2)/2)
    factor=1/(2*(1-rho**2))
    err_loss = torch.sum(factor*(z1**2+z2**2-2*rho*z1*z2))
    
    return err_loss+sigloss+rholoss, err_loss, sigloss, rholoss

In [16]:
def l_func_2d(pred, ys, sig, rho):
#     global delta, bsize, A2, sig_inv, detloss, err_loss
    
    delta=pred-ys
    bsize=delta.shape[0]
    N=4
    #this is messy but it works 
    
    #compute the covariance matrix
    vals = torch.vstack([sig[:,0]**2, rho[:,0]*sig[:,0]*sig[:,1], rho[:,1]*sig[:,0]*sig[:,2], rho[:,2]*sig[:,0]*sig[:,3],\
                 sig[:,1]**2,rho[:,3]*sig[:,1]*sig[:,2], rho[:,4]*sig[:,1]*sig[:,3], \
                sig[:,2]**2,rho[:,5]*sig[:,2]*sig[:,3],\
                 sig[:,3]**2])

    A = torch.zeros(N, N,bsize, device='cuda:0')
    A[0] = vals[:4]
    A[1] = torch.vstack([vals[1], vals[4:7]])
    A[2] = torch.vstack([vals[2], vals[5], vals[7:9]])
    A[3] = torch.vstack([vals[3], vals[6], vals[8], vals[9]])
    
    A2=A.permute(2,0,1)
    
    det=torch.det(A2)
    detloss=torch.sum(torch.log(det))/2
    
    sig_inv = torch.inverse(A2)
    
    err=delta*torch.bmm(sig_inv, delta.unsqueeze(2))[:,:,0]
    
    err_loss = torch.sum(err)/2
    
    return err_loss+detloss, err_loss, detloss

In [17]:
criterion = torch.nn.MSELoss()
n_epochs=200
n_trials=1
batch_size=128
split=0.8
test_data=data[int(len(data)*split):]
train_data=data[:int(len(data)*split)]
l1_lambda = 0
l2_lambda = 0

In [18]:
# Initialize our train function

def train():
    model.train()
    global out, sigs, rho
    for data in train_loader:  
        out, sigs, rho = model(data)  
        loss, err_l, det_l = l_func_2d(out, data.y.view(-1,n_targ), sigs, rho) 
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())


        loss = loss + l1_lambda * l1_norm + l2_lambda * l2_norm
#             loss.backward()
        accelerator.backward(loss)
        optimizer.step() 
        optimizer.zero_grad() 
#     print(loss, l1_norm*l1_lambda, l2_norm*l2_lambda)
 # test function

def test(loader): ##### transform back missing
    model.eval()
    outs = []
    ys = []
    varss = []
    rhos = []
    with torch.no_grad(): ##this solves it!!!
        for dat in loader: 
            out, sigs , rho  = model(dat) 
            ys.append(dat.y.view(-1,n_targ))
            outs.append(out)
            varss.append(sigs)
            rhos.append(rho)
    outss=torch.vstack(outs)
    yss=torch.vstack(ys)
    varss=torch.vstack(varss)
    rhos=torch.vstack(rhos)
    return torch.std(outss - yss, axis=0), outss, yss, varss, rhos

In [19]:
trains, tests, scatter = [], [], []
yss, preds=[],[]
model = Sage()
train_loader=DataLoader(train_data, batch_size=batch_size, shuffle=1, num_workers=1)
test_loader=DataLoader(test_data, batch_size=batch_size, shuffle=0,num_workers=1)    
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
_, _, test_loader = accelerator.prepare(model, optimizer, test_loader)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
print('GPU ', next(model.parameters()).is_cuda)


RelU conv activation
LeakyRelU decode activation
GPU  True


In [20]:
#this uses about 1 GB of memory on the GPU
tr_acc, te_acc = [], []
start=time.time()
for epoch in range(n_epochs):

    train()

    if (epoch+1)%2==0:
        train_acc, _ , _, _ , _ = test(train_loader)
        test_acc, _ , _ , _ , _ = test(test_loader)
        tr_acc.append(train_acc.cpu().numpy())
        te_acc.append(test_acc.cpu().numpy())
        print(f'Epoch: {epoch+1:03d}, Train scatter: {np.round(train_acc.cpu().numpy(), 4)} \n \
          Test scatter:  {np.round(test_acc.cpu().numpy(), 4)}')
stop=time.time()
spent=stop-start
print(f"{spent:.2f} seconds spent training, {spent/n_epochs:.3f} seconds per epoch. Processed {len(data)*split*n_epochs/spent:.0f} trees per second")


Epoch: 002, Train scatter: [0.935  0.1753 0.5431 0.9949] 
           Test scatter:  [0.9295 0.1734 0.5395 0.9948]
Epoch: 004, Train scatter: [0.935  0.3712 0.5431 0.9949] 
           Test scatter:  [0.9295 0.3662 0.5395 0.9948]
Epoch: 006, Train scatter: [0.9349 0.2094 0.5431 0.9949] 
           Test scatter:  [0.9295 0.2069 0.5395 0.9948]
Epoch: 008, Train scatter: [0.9349 0.1872 0.5431 0.9949] 
           Test scatter:  [0.9295 0.185  0.5395 0.9948]
Epoch: 010, Train scatter: [0.9349 0.2074 0.5431 0.9949] 
           Test scatter:  [0.9295 0.2042 0.5395 0.9947]
Epoch: 012, Train scatter: [0.9349 0.1764 0.5431 0.9949] 
           Test scatter:  [0.9294 0.1744 0.5394 0.9948]
Epoch: 014, Train scatter: [0.9348 0.172  0.5431 0.9948] 
           Test scatter:  [0.9294 0.1699 0.5394 0.9947]
Epoch: 016, Train scatter: [0.9349 0.1609 0.5429 0.9949] 
           Test scatter:  [0.9294 0.1592 0.5393 0.9948]
Epoch: 018, Train scatter: [0.9344 0.1374 0.5422 0.9949] 
           Test scatter:  [0.9

Epoch: 146, Train scatter: [0.1066 0.0365 0.2043 0.3931] 
           Test scatter:  [0.108  0.0361 0.2076 0.3871]
Epoch: 148, Train scatter: [0.1057 0.0363 0.2044 0.3924] 
           Test scatter:  [0.1065 0.0359 0.2081 0.3873]
Epoch: 150, Train scatter: [0.1102 0.0369 0.2076 0.3984] 
           Test scatter:  [0.1107 0.0365 0.2111 0.392 ]
Epoch: 152, Train scatter: [0.1081 0.0365 0.2034 0.3925] 
           Test scatter:  [0.1089 0.036  0.2071 0.3864]
Epoch: 154, Train scatter: [0.1068 0.036  0.2022 0.3884] 
           Test scatter:  [0.1068 0.0357 0.2062 0.3831]
Epoch: 156, Train scatter: [0.1049 0.0358 0.2022 0.3884] 
           Test scatter:  [0.1051 0.0355 0.2062 0.3828]
Epoch: 158, Train scatter: [0.1089 0.036  0.203  0.389 ] 
           Test scatter:  [0.1094 0.0357 0.2068 0.3841]
Epoch: 160, Train scatter: [0.1068 0.036  0.2013 0.3893] 
           Test scatter:  [0.1076 0.0356 0.2053 0.3838]
Epoch: 162, Train scatter: [0.1041 0.0355 0.2007 0.3859] 
           Test scatter:  [0.1

In [None]:
n_epochs=120
te_acc, tr_acc = np.array(te_acc), np.array(tr_acc)
fig, ax =plt.subplots(2,2,figsize=(15,7))
ax=ax.flatten()
best = [0.0781, 0.034,0.1834, 0.3608]
for i in range(4):
    ax[i].plot(np.arange(0,n_epochs,2), tr_acc[:,i], 'r-', label=f'training')
    ax[i].plot(np.arange(0,n_epochs,2), te_acc[:,i], 'g-',label=f'validation')
    ax[i].hlines(best[i], 0, 120, color='k', linestyle='dashed', label='best in single training')
    ax[i].set(xlabel='Epoch', ylabel=r'$\sigma$ stellar mass [dex]', \
        title=f'{cols_t[i][:6]} Minimum {np.min(np.array(te_acc)[:,i]):.3f}, median last 10 {np.median(np.array(te_acc)[-5:,i]):.3f}')
    ax[i].legend()

# ax[1].plot(np.arange(0,n_epochs,2), tr_acc[:,1], 'r-', label='training')
# ax[1].plot(np.arange(0,n_epochs,2), te_acc[:,1], 'g-',label='validation')
# ax[1].set(xlabel='Epoch', ylabel=r'$\sigma$ v_disk [dex]', \
#     title=f'Minimum {np.min(np.array(te_acc)[:,1]):.4f}, median last 10 {np.median(np.array(te_acc)[-5:,1]):.4f}')
# ax[1].legend()

fig.suptitle('4D Gaussian Loss with correlation')
fig.tight_layout()

In [None]:
np.mean(np.array(te_acc)[-5:,0])

In [None]:
np.array(te_acc)[-5:]

In [None]:
trainstd, outtrain, ytrain, var, rho = test(train_loader)

In [None]:
fig , ax = plt.subplots(2, figsize=(15,8))
ax = ax.flatten()
l=0
for k in range(n_targ):
#     ax[k].hist(ress.cpu().numpy()[:,k], bins=50, histtype='step', label='res')
    ax[k].hist(outtrain.cpu().numpy()[:,k], bins=50, \
            range=list(np.percentile(outtrain.cpu().numpy()[:,k], [l,100-l])), histtype='step', label='pred')
    ax[k].hist(ytrain.cpu().numpy()[:,k], bins=50, \
            range=list(np.percentile(outtrain.cpu().numpy()[:,k], [l,100-l])), histtype='step', label='true')

    ax[k].legend()
#     print(np.std(ress.cpu().numpy()[:,k]), np.mean(ress.cpu().numpy()[:,k]))

In [None]:
plt.hist(rho.cpu().numpy().flatten(), bins=100);

In [None]:
i=0
plt.hist((outtrain.cpu().numpy()[:,i]-ytrain.cpu().numpy()[:,i])/var[:,i].cpu().numpy(), bins=100);

In [None]:
plt.hist(var.cpu().numpy()[:,0], bins=100);

In [None]:
import matplotlib as mpl
i=0
z0=(outtrain.cpu().numpy()[:,i]-ytrain.cpu().numpy()[:,i])/var[:,i].cpu().numpy()
i=1
z1=(outtrain.cpu().numpy()[:,i]-ytrain.cpu().numpy()[:,i])/var[:,i].cpu().numpy()
vals, x, y, ax=plt.hist2d(z0,z1,bins=25, norm=mpl.colors.LogNorm(), cmap=mpl.cm.magma)
X, Y = np.meshgrid((x[1:]+x[:-1])/2, (y[1:]+y[:-1])/2)
plt.contour(X,Y, np.log(vals.T+1), levels=10, colors='black')

In [None]:
np.corrcoef(z0,z1)

In [None]:
import matplotlib as mpl
vals, x, y, ax=plt.hist2d(outtrain.cpu().numpy()[:,0],outtrain.cpu().numpy()[:,1],bins=25, norm=mpl.colors.LogNorm(), cmap=mpl.cm.magma)
X, Y = np.meshgrid((x[1:]+x[:-1])/2, (y[1:]+y[:-1])/2)
plt.contour(X,Y, np.log(vals.T+1), levels=10, colors='black')

In [None]:
np.corrcoef(outtrain.cpu().numpy()[:,0],outtrain.cpu().numpy()[:,1])

In [None]:
(x[1:]+x[:-1])/2

In [None]:
fig , ax = plt.subplots(2, figsize=(8,10))
ax = ax.flatten()
l=0.5
for k in range(n_targ):
    ytr=ytrain.cpu().numpy()[:,k]
    predtr=outtrain.cpu().numpy()[:,k]
    ax[k].plot(ytr, predtr, 'ro', label='true-pred')
    ax[k].plot([min(ytr),max(ytr)],[min(ytr),max(ytr)], 'k--', label='Perfect correspondance')
    ax[k].set(title=[cols_t[k], np.round(np.std(ytr-predtr),3), np.round(np.mean(ytr-predtr),3)])
    ax[k].legend()
#     print(np.std(ress.cpu().numpy()[:,k]), np.mean(ress.cpu().numpy()[:,k]))

In [None]:
teststd, outtest, ytest, vtest, rhotest = test(test_loader)

In [None]:
fig , ax = plt.subplots(2, figsize=(10,10))
ax = ax.flatten()
l=0.0
for k in range(n_targ):
    ax[k].hist(outtest.cpu().numpy()[:,k], bins=50, \
            range=list(np.percentile(outtest.cpu().numpy()[:,k], [l,100-l])), histtype='step', label='pred')
    ax[k].hist(ytest.cpu().numpy()[:,k], bins=50, \
            range=list(np.percentile(outtest.cpu().numpy()[:,k], [l,100-l])), histtype='step', label='true')
    ax[k].set(title=cols_t[k])
    ax[k].legend()
#     print(np.std(ress.cpu().numpy()[:,k]), np.mean(ress.cpu().numpy()[:,k]))
fig.tight_layout()

In [None]:
# fig , ax = plt.subplots(2, figsize=(7,10))
# ax = ax.flatten()
# for k in range(n_targ):
#     yte=ytest.cpu().numpy()[:,k]
#     predte=outtest.cpu().numpy()[:,k]
#     ax[k].plot(yte, predte, 'ro', label='true-pred')
#     ax[k].plot([min(yte),max(yte)],[min(yte),max(yte)], 'k--', label='Perfect correspondance')
#     ax[k].set(title=[cols_t[k], np.round(np.std(yte-predte),3), np.round(np.mean(yte-predte),3)])
#     ax[k].legend()
# #     print(np.std(ress.cpu().numpy()[:,k]), np.mean(ress.cpu().numpy()[:,k]))

In [None]:
ys=torch.rand(128,4)
pred=torch.rand(128,4)
sig = torch.abs(torch.rand(128,4))
rho = torch.tanh(torch.rand(128,6)-torch.ones(128,6)/2)
# ((pred-ys)/sig)**2

In [None]:
d=ys-pred
d.shape[0]

In [None]:
vals = torch.vstack([sig[:,0]**2, rho[:,0]*sig[:,0]*sig[:,1], rho[:,1]*sig[:,0]*sig[:,2], rho[:,2]*sig[:,0]*sig[:,3],\
                 sig[:,1]**2,rho[:,3]*sig[:,1]*sig[:,2], rho[:,4]*sig[:,1]*sig[:,3], \
                sig[:,2]**2,rho[:,5]*sig[:,2]*sig[:,3],\
                 sig[:,3]**2])

A = torch.zeros(N, N,128)

A[0] = vals[:4]
A[1] = torch.vstack([vals[1], vals[4:7]])
A[2] = torch.vstack([vals[2], vals[5], vals[7:9]])
A[3] = torch.vstack([vals[3], vals[6], vals[8], vals[9]])

A2=A.permute(2,0,1)
det=torch.det(A2)

sig_inv = torch.inverse(A2)

d*torch.bmm(sig_inv, d.unsqueeze(2))[:,:,0]