In [1]:
import os
from matplotlib import pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
from typing import Callable

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchmetrics

from data_utils import *

In [2]:
# Load raw features and labels for the train and test splits.
# NOTE(review): the positional arguments (-1, 0, True) are interpreted by
# data_utils.getEmbXy — confirm their meaning and consider naming them.
X_tr, y_tr, X_te, y_te = getEmbXy(-1, 0, True)

# Cast both feature tensors to torch.float (float32).
X_tr, X_te = (feats.type(torch.float) for feats in (X_tr, X_te))

# Build the train/test tables from the float features and the labels.
df_tr, df_te = createTables(X_tr, y_tr, X_te, y_te)

## Generate a set of objective functions

The families below are ordered from easiest ($\mathcal{F}_{0}$) to hardest ($\mathcal{F}_{2}$).

### $\mathcal{F}_{0}$: Based on NN embedding

$f_{01}(x) = \mu(x)$

$f_{02}(x) = q_{0.9}(x)$

$f_{03}(x) = \sigma(x)$

$f_{04}(x) = \mu(x)+\sigma(x)$

### $\mathcal{F}_{1}$: Based on Convolution embedding

$f_{11}(x) = a \cdot x$

$f_{12}(x) = \mbox{pool}(a \cdot x)$

$f_{13}(x) = \mbox{pool}(a \cdot x) \cdot b$

$f_{14}(x) = \mbox{pool}(\mbox{pool}(a \cdot x) \cdot b)$

### $\mathcal{F}_{2}$: Based on Convolution *and* Label

$f_{21}(x) = \mbox{norm}(a \cdot x) + \mbox{norm}(y)$

$f_{22}(x) = \mbox{norm}(\mbox{pool}(a \cdot x)) + \mbox{norm}(y)$

$f_{23}(x) = \mbox{norm}(\mbox{pool}(a \cdot x) \cdot b) + \mbox{norm}(y)$

$f_{24}(x) = \mbox{norm}(\mbox{pool}(\mbox{pool}(a \cdot x) \cdot b)) + \mbox{norm}(y)$

In [10]:
# save tables
# Create the output directory first: DataFrame.to_csv raises OSError when the
# parent directory does not exist, which breaks a fresh Restart & Run All.
TABLES_DIR = Path('./tables')
TABLES_DIR.mkdir(parents=True, exist_ok=True)

# index=False keeps the row index out of the CSV so it does not come back
# as an extra column when the file is re-read.
df_tr.to_csv(TABLES_DIR / 'df_tr.csv', index=False)
df_te.to_csv(TABLES_DIR / 'df_te.csv', index=False)

In [11]:
# Reload the training table from disk (this replaces the in-memory df_tr).
# A file saved with the index included carries it as an 'Unnamed: 0' column
# (seen in the stored output); drop it if present so it is not mistaken for
# a data column.
# NOTE(review): in the stored output mu_tr and sig_tr hold identical values,
# although f01 (mean) and f03 (std) should differ — check the sigma
# computation in createTables.
df_tr = pd.read_csv('./tables/df_tr.csv').drop(columns='Unnamed: 0', errors='ignore')

df_tr.head()

Unnamed: 0,mu_tr,q935_tr,sig_tr,mu+sig_tr,conv1_tr,conv2_tr,conv3_tr,conv4_tr,conv1+label_tr,conv2+label_tr,conv3+label_tr,conv4+label_tr
0,35.108418,-253.0,35.108418,70.216835,8076.0,4655.0,828.0,522.0,0.454871,1.01004,0.436427,0.423666
1,39.66199,-252.0,39.66199,79.32398,5371.0,2860.0,0.0,0.0,-2.273893,-2.242747,-3.508076,-3.415902
2,24.799746,-177.0,24.799746,49.59949,4743.0,2612.0,525.0,271.0,-1.121196,-1.068617,-0.719861,-0.936549
3,21.855867,-220.0,21.855867,43.711735,7179.0,3014.0,798.0,292.0,-1.260587,-1.766038,-1.028233,-1.890037
4,29.609694,-251.10498,29.609694,59.219387,7586.0,3264.0,1240.0,736.0,1.658493,1.21486,2.922497,2.672724
