In [12]:
import warnings
warnings.simplefilter('ignore')
import itertools
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import itertools
import os

### import from our files
from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet.utilities import log_metrics, plot_results, hyperparam_grid,\
                                     hyperparam_mult_grid, eval_performance
from mliv.neuralnet.mnist_dgps import AbstractMNISTxz
from mliv.neuralnet import AGMM,KernelLayerMMDGMM
from mliv.neuralnet.rbflayer import gaussian, inverse_multiquadric

In [38]:
class CNN_Z_agmm(nn.Module):
    def __init__(self):
        super(CNN_Z_agmm, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.fc3(x)
        output = x.squeeze()  # F.log_softmax(x, dim=1)
        return output


class CNN_Z_kernel(nn.Module):
    def __init__(self, g_features=100):
        super(CNN_Z_kernel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, g_features)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = x  # F.log_softmax(x, dim=1)
        return output


class CNN_X(nn.Module):
    def __init__(self):
        super(CNN_X, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = x  # F.log_softmax(x, dim=1)
        return output.squeeze()


In [39]:
def fc_z_kernel(n_z, n_hidden, g_features, dropout_p):
    FC_Z_kernel = nn.Sequential(
        nn.Dropout(p=dropout_p),
        nn.Linear(n_z, n_hidden),
        nn.LeakyReLU(),
        nn.Dropout(p=dropout_p),
        nn.Linear(n_hidden, g_features),
        nn.ReLU(),
    )
    return FC_Z_kernel


def fc_z_agmm(n_z, n_hidden, dropout_p):
    FC_Z_agmm = nn.Sequential(
        nn.Dropout(p=dropout_p),
        nn.Linear(n_z, n_hidden),
        nn.LeakyReLU(),
        nn.Dropout(p=dropout_p),
        nn.Linear(n_hidden, 1),
    )
    return FC_Z_agmm


def fc_x(n_t, n_hidden, dropout_p):
    FC_X = nn.Sequential(
        nn.Dropout(p=dropout_p),
        nn.Linear(n_t, n_hidden),
        nn.LeakyReLU(),
        nn.Dropout(p=dropout_p),
        nn.Linear(n_hidden, 1),
    )
    return FC_X

In [40]:
def generate_data(
    X_IMAGE=False,
    Z_IMAGE=False,
    tau_fn="abs",
    n_samples=10000,
    n_instruments=2,
    iv_strength=0.5,
    device=None,
):
    mnist_dgp = AbstractMNISTxz(X_IMAGE, Z_IMAGE, tau_fn)
    n_test = n_samples // 10
    n_t = 1

    T, Z, Y, G, _ = mnist_dgp.generate_data(
        n_samples, tau_fn=tau_fn, n_instruments=n_instruments, iv_strength=iv_strength
    )

    T_test, Z_test, Y_test, G_test, _ = mnist_dgp.generate_data(
        n_test, tau_fn=tau_fn, n_instruments=n_instruments, iv_strength=iv_strength,
    )

    Z_train, Z_val, T_train, T_val, Y_train, Y_val, G_train, G_val = train_test_split(
        Z, T, Y, G, test_size=0.1, shuffle=True
    )
    Z_train, T_train, Y_train, G_train = map(
        lambda x: torch.Tensor(x), (Z_train, T_train, Y_train, G_train)
    )
    Z_val, T_val, Y_val, G_val = map(
        lambda x: torch.Tensor(x).to(device), (Z_val, T_val, Y_val, G_val)
    )
    Z_test, T_test, Y_test, G_test = map(
        lambda x: torch.Tensor(x).to(device), (Z_test, T_test, Y_test, G_test)
    )

    data_array = []
    data_array.append((Z_train, T_train, Y_train, G_train))
    data_array.append((Z_val, T_val, Y_val, G_val))
    data_array.append((Z_test, T_test, Y_test, G_test))
    return data_array

In [41]:
def train_agmm(
    Z_train,
    T_train,
    Y_train,
    G_train,
    Z_val,
    T_val,
    Y_val,
    G_val,
    T_test,
    G_test,
    X_IMAGE=False,
    Z_IMAGE=False,
    n_t=1,
    n_instruments=2,
    n_hidden=200,
    dropout_p=0.1,
    learner_lr=1e-4,
    adversary_lr=1e-4,
    learner_l2=1e-4,
    adversary_l2=1e-4,
    adversary_norm_reg=1e-4,
    n_epochs=100,
    batch_size=100,
    train_learner_every=1,
    train_adversary_every=1,
):
    if X_IMAGE:
        learner = CNN_X()
    else:
        learner = fc_x(n_t, n_hidden, dropout_p)
    if Z_IMAGE:
        adversary = CNN_Z_agmm()
    else:
        adversary = fc_z_agmm(n_instruments, n_hidden, dropout_p)

    def logger(learner, adversary, epoch, writer):
        if not X_IMAGE:
            writer.add_histogram("learner", learner[-1].weight, epoch)
        if not Z_IMAGE:
            writer.add_histogram("adversary", adversary[-1].weight, epoch)
        log_metrics(
            Z_val,
            T_val,
            Y_val,
            Z_val,
            T_val,
            Y_val,
            T_test,
            learner,
            adversary,
            epoch,
            writer,
            true_of_T=G_val,
        )

    np.random.seed(12356)
    print("---Hyperparameters---")
    print("Learner Learning Rate:", learner_lr)
    print("Adversary learning rate:", adversary_lr)
    print("Learner_l2:", learner_l2)
    print("Adversary_l2:", adversary_l2)
    print("Number of epochs:", n_epochs)
    print("Batch Size:", batch_size)
    agmm = AGMM(learner, adversary).fit(
        Z_train,
        T_train,
        Y_train,
        learner_lr=learner_lr,
        adversary_lr=adversary_lr,
        learner_l2=learner_l2,
        adversary_l2=adversary_l2,
        n_epochs=n_epochs,
        bs=batch_size,
        logger=logger,
        model_dir="agmm_model",
        device=device,
        train_learner_every=train_learner_every,
        train_adversary_every=train_adversary_every,
    )

    return agmm


#### Train KernelLayerGMM
def train_kernellayergmm(
    Z_train,
    T_train,
    Y_train,
    G_train,
    Z_val,
    T_val,
    Y_val,
    G_val,
    T_test,
    G_test,
    g_features=100,
    kernel_fn=gaussian,
    centers=None,
    sigmas=None,
    X_IMAGE=False,
    Z_IMAGE=False,
    n_t=1,
    n_instruments=2,
    n_hidden=200,
    dropout_p=0.1,
    learner_lr=1e-4,
    adversary_lr=1e-4,
    learner_l2=1e-4,
    adversary_l2=1e-4,
    adversary_norm_reg=1e-4,
    n_epochs=100,
    batch_size=100,
    train_learner_every=1,
    train_adversary_every=1,
):
    if X_IMAGE:
        learner = CNN_X()
    else:
        learner = fc_x(n_t, n_hidden, dropout_p)
    if Z_IMAGE:
        adversary = CNN_Z()
    else:
        adversary = fc_z_kernel(n_instruments, n_hidden, g_features, dropout_p)

    def logger(learner, adversary, epoch, writer):
        if not X_IMAGE:
            writer.add_histogram("learner", learner[-1].weight, epoch)
        # if not Z_IMAGE:
        #  writer.add_histogram('adversary', adversary[-1].weight, epoch)
        writer.add_histogram("adversary", adversary.beta.weight, epoch)
        log_metrics(
            Z_val,
            T_val,
            Y_val,
            Z_val,
            T_val,
            Y_val,
            T_test,
            learner,
            adversary,
            epoch,
            writer,
            true_of_T=G_val,
        )

    np.random.seed(12356)
    print("---Hyperparameters---")
    print("Learner Learning Rate:", learner_lr)
    print("Adversary learning rate:", adversary_lr)
    print("Learner_l2:", learner_l2)
    print("Adversary_l2:", adversary_l2)
    print("Number of epochs:", n_epochs)
    print("Batch Size:", batch_size)
    print("G features", g_features)
    print("Number of centers", n_centers)
    print("Kernel function", kernel_fn.__name__)
    klayermmdgmm = KernelLayerMMDGMM(
        learner,
        adversary,
        g_features,
        n_centers,
        kernel_fn,
        centers=centers,
        sigmas=sigmas,
    )
    klayermmdgmm.fit(
        Z_train,
        T_train,
        Y_train,
        learner_l2=learner_l2,
        adversary_l2=adversary_l2,
        adversary_norm_reg=adversary_norm_reg,
        learner_lr=learner_lr,
        adversary_lr=adversary_lr,
        n_epochs=n_epochs,
        bs=bs,
        logger=logger,
        model_dir="klayer_model",
        device=device,
        train_learner_every=train_learner_every,
        train_adversary_every=train_adversary_every,
    )

    return klayermmdgmm

In [42]:
device = torch.cuda.current_device() if torch.cuda.is_available() else None

In [43]:
X_IMAGE = True
Z_IMAGE = True
tau_fn = "abs"
n_samples = 1000
n_instruments = 2
iv_strength = 0.5
data = generate_data(
    X_IMAGE=X_IMAGE,
    Z_IMAGE=Z_IMAGE,
    tau_fn=tau_fn,
    n_samples=n_samples,
    n_instruments=n_instruments,
    iv_strength=iv_strength,
    device=device,
)
(Z_train, T_train, Y_train, G_train) = data[0]
(Z_val, T_val, Y_val, G_val) = data[1]
(Z_test, T_test, Y_test, G_test) = data[2]

In [44]:
from torchvision import models
from torchsummary import summary


In [45]:
summary(CNN_X().to(device), (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
            Conv2d-2           [-1, 64, 24, 24]          18,496
         Dropout2d-3           [-1, 64, 12, 12]               0
            Linear-4                  [-1, 128]       1,179,776
         Dropout2d-5                  [-1, 128]               0
            Linear-6                    [-1, 1]             129
Total params: 1,198,721
Trainable params: 1,198,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.52
Params size (MB): 4.57
Estimated Total Size (MB): 5.09
----------------------------------------------------------------


In [48]:
%%time
# parameters for networks
dropout_p = 0.1
n_t = 1
n_hidden = 200

# local hyperparam
learner_lr = 1e-4
adversary_lr = 1e-4
learner_l2 = 1e-4
adversary_l2 = 1e-4
adversary_norm_reg = 1e-4
n_epochs = 10
bs = 100
agmm = train_agmm(Z_train, T_train, Y_train, G_train, Z_val, T_val, Y_val, G_val, T_test, G_test,
                  X_IMAGE=True, Z_IMAGE=True, n_t=n_t, n_instruments=n_instruments,
                  n_hidden=n_hidden, dropout_p=dropout_p, learner_lr=learner_lr, adversary_lr=adversary_lr,
                  learner_l2=learner_l2, adversary_l2=adversary_l2, adversary_norm_reg=adversary_norm_reg,
                  n_epochs=n_epochs, batch_size=bs)


#plot_results(agmm, T_test, true_of_T_test=G_test)
#eval_performance(agmm,T_test, true_of_T_test=G_test)

---Hyperparameters---
Learner Learning Rate: 0.0001
Adversary learning rate: 0.0001
Learner_l2: 0.0001
Adversary_l2: 0.0001
Number of epochs: 10
Batch Size: 100
CPU times: user 866 ms, sys: 97.6 ms, total: 964 ms
Wall time: 8.33 s
