### Estimating the covariance for dcGAN generated data

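In this notebook we show how to estimate the population covariances $(\Psi, \Phi, \Omega)$ with a Monte Carlo algorithm for a given data generator $\mathcal{G}$, and teacher-student feature maps $(\varphi_{t}, \varphi_{s})$. Writing $u=\varphi_{t}(x)$ and $v=\varphi_{s}(x)$ for $x$ sampled from $\mathcal{G}$, these are $\Psi=\mathrm{Cov}[u,u]\in\mathbb{R}^{p\times p}$, $\Phi=\mathrm{Cov}[u,v]\in\mathbb{R}^{p\times d}$ and $\Omega=\mathrm{Cov}[v,v]\in\mathbb{R}^{d\times d}$. Concretely, in this notebook we will look at the following setting: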

- **Generator:** Our generator will be the dcGAN from [Radford et al.](https://arxiv.org/abs/1511.06434) trained to map i.i.d. Gaussian noise $z\sim\mathcal{N}(0,\rm{I}_{100})\mapsto x\in\mathbb{R}^{D}$ into CIFAR10-looking images. For more details, check notebook `synthetic_data_pipeline.ipynb`.
- **Teacher features:** The teacher feature map $\varphi_{t}:x\in\mathbb{R}^{D}\mapsto u\in\mathbb{R}^{p}$ will be a fully-connected neural network trained to classify odd (+1) vs even (-1) real CIFAR10 images. The feature map is obtained by keeping all but the last layer; the weights of that last layer define the teacher weights $\theta_{0}\in\mathbb{R}^{p}$.
- **Student features:** The student feature map $\varphi_{s}:x\in\mathbb{R}^{D}\mapsto v\in\mathbb{R}^{d}$ will be a fully-connected neural network trained on 30k fake CIFAR10-like images sampled from the generator above, with labels assigned by the teacher described above.

## Loading the generator

In [59]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
from dcgan import Generator

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

For more info on the generator, see `synthetic_data_pipeline.ipynb`

In [60]:
# Load the generator

latent_dim = 100 # generator latent dimension

generator = Generator(ngpu=1)
generator.load_state_dict(torch.load("./data/weights/dcgan_cifar10_weights.pth", map_location=device))
# Put the weights on the same device the latent noise is sampled on in the
# Monte-Carlo loop (`torch.randn(...).to(device)`). Without this, the forward
# pass fails with a device mismatch whenever CUDA is available.
generator = generator.to(device)
# NOTE(review): the module is presumably left in its default (train) mode, so
# its BatchNorm2d layers normalise with per-batch statistics: each sample then
# depends on the rest of its batch, and the statistics depend on batch_size.
# Confirm this matches how the student's training images were generated
# (see `synthetic_data_pipeline.ipynb`) before switching to `generator.eval()`.

print(generator)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (13): Tanh()
  )
)


## Loading the teacher feature map

In [7]:
import teachers
import teacherutils

For more info on the teacher, see `synthetic_data_pipeline.ipynb`

In [33]:
dx, dy, c = 32, 32, 3 # teacher input dimension
D = dx*dy*c  # flattened image dimension (3072 for CIFAR10)
p = D  # teacher feature dimension (the teacher's hidden layers are D -> D, see printout below)
# Load teacher vector
# NOTE(review): input_dim lists 1 channel although c = 3; presumably unused by
# the "mlp" model, which takes flattened D-dimensional inputs. Confirm in teacherutils.
kwargs = {"input_dim": [1, dx, dy]}
teacher_mlp = teacherutils.get_model("mlp", "erf", D, 1, **kwargs)
teacher_mlp.load_state_dict(torch.load("./data/weights/mlp_erf_cifar10.pt", 
                                       map_location=device))
# Run the teacher on the same device as the generated images it will featurise.
teacher_mlp = teacher_mlp.to(device)

print(teacher_mlp)

MLP(
  (preprocess1): Linear(in_features=3072, out_features=3072, bias=False)
  (preprocess2): Linear(in_features=3072, out_features=3072, bias=False)
  (preprocess3): Linear(in_features=3072, out_features=3072, bias=False)
  (bnz): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (fc): Linear(in_features=3072, out_features=1, bias=False)
  (v): Linear(in_features=1, out_features=1, bias=False)
)


In [34]:
# Extracting the feature map. Note inputs are assumed to be flattened
def teacher_map(x):
    """Return the teacher features of a batch of flattened images as a NumPy array.

    x: torch tensor of flattened images, expected on the same device as `teacher_mlp`
       (the Monte-Carlo loop passes a (batch_size, D) tensor).
    The features come from `teacher_mlp.preprocess`, presumably every layer before
    the final readout `fc` (TODO confirm in teacherutils).
    """
    # `.cpu()` is required: `.numpy()` raises on CUDA tensors.
    return teacher_mlp.preprocess(x).detach().cpu().numpy()

## Loading student feature maps

In [32]:
# Load student
d = D  # student feature dimension
# Size the layers with the student's own dimension d (the feature map is
# R^D -> R^d). This is numerically identical to the previous use of p, since d = p = D.
student = torch.nn.Sequential(
    torch.nn.Linear(D, d, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(d, d, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(d, 1, bias=False))

# Load weights.
student.load_state_dict(torch.load('./data/weights/weights_mlp_student_epochs=200.pth', map_location=device))
# Run the student on the same device as the generated images it will featurise.
student = student.to(device)

# Extract feature map
def student_map(x):
    """Return the student features (every layer except the final readout) as a NumPy array.

    x: torch tensor of flattened images on `device`; the output has d columns.
    """
    # `student[:-1]` drops the last Linear(d, 1) readout. `.cpu()` is needed
    # because `.numpy()` raises on CUDA tensors.
    return student[:-1](x).detach().cpu().numpy()

## Monte-Carlo estimation

In [57]:
from tqdm import tqdm

In [58]:
# Running Monte-Carlo estimates, updated by the sampling loop below.
# u = teacher features (dimension p), v = student features (dimension d).
data = {
    'mean_u': np.zeros(p),       # running mean of the teacher features
    'mean_v': np.zeros(d),       # running mean of the student features
    'psi': np.zeros((p, p)),     # teacher-teacher covariance estimate
    'omega': np.zeros((d, d)),   # student-student covariance estimate
    'phi': np.zeros((p, d)),     # teacher-student cross-covariance estimate
}

# Accumulated co-moment updates (one term per batch); divided by the number of
# batches in the loop to obtain the covariance estimates.
M2_psi, M2_omega, M2_phi = np.zeros((p, p)), np.zeros((d, d)), np.zeros((p, d))

# Values at the previous checkpoint, used to report how much each estimate moved.
data_last = {name: np.zeros_like(estimate) for name, estimate in data.items()}

mc_mean_v_old = mc_mean_u_old = 0

In [None]:
# Monte-Carlo loop: draw batches of generated images, featurise them with the
# teacher (U) and student (V) maps, and update running means and co-moments.
#
# For batch k = step + 1, with old mean mu_{k-1} and new mean mu_k:
#   (V - mu_{k-1}).T @ (V - mu_k) / batch_size
#       = Cov_batch(V) + (k-1)/k * delta delta^T,   delta = mean_batch(V) - mu_{k-1}.
# Summing over batches and dividing by k therefore gives the (1/N) covariance of
# all samples seen so far (Welford's update applied to equal-size batches).
#
# NOTE: this cell continues from the accumulators created in the previous cell.
# If it is interrupted and restarted, re-run that cell first; otherwise `step`
# resets while the moments keep their old values.
step = -1
mc_steps = int(1e6) # Maximum number of steps (batches)
batch_size = 1000 # Number of samples at every batch
checkpoint = 100 # Save partial results every checkpoint loops
out_dir = "./data/matrices/covariances"
os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories

with torch.no_grad():
    while step + 1 < mc_steps:
        # Cap the last chunk so that exactly mc_steps batches are processed in total.
        n_batches = min(checkpoint, mc_steps - (step + 1))
        for _ in tqdm(range(n_batches)):

            step += 1

            # Generate CIFAR10-like images
            Z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            X = generator(Z).reshape(batch_size, -1)

            # Compute student features
            V = student_map(X)

            # Compute teacher features
            U = teacher_map(X)

            # Save old means. `.copy()` is essential: the `+=` below updates these
            # arrays in place, so a plain reference would already hold the new mean.
            mc_mean_v_old = data["mean_v"].copy()
            mc_mean_u_old = data["mean_u"].copy()

            # Update means (running average of the batch means)
            dmean_u = np.mean(U, axis=0) - data["mean_u"]
            data["mean_u"] += dmean_u / (step + 1)

            dmean_v = np.mean(V, axis=0) - data["mean_v"]
            data["mean_v"] += dmean_v / (step + 1)

            # Update co-moments: (x - old mean)^T (y - new mean), averaged over the batch
            M2_omega += (V - mc_mean_v_old).T @ (V - data["mean_v"]) / batch_size
            M2_psi += (U - mc_mean_u_old).T @ (U - data["mean_u"]) / batch_size
            M2_phi += (U - mc_mean_u_old).T @ (V - data["mean_v"]) / batch_size

        data["omega"] = M2_omega / (step + 1)
        data["phi"] = M2_phi / (step + 1)
        data["psi"] = M2_psi / (step + 1)

        # Samples processed so far (step is 0-based, so step + 1 batches are done).
        n_samples = (step + 1) * batch_size

        # Build status message: sample count, then the RMS change of each
        # estimate since the previous checkpoint.
        status = "{}".format(n_samples)

        for name in data.keys():
            diff = np.sqrt(np.mean((data[name] - data_last[name]) ** 2))
            status += ", {}".format(diff)

            # Update last. Take a snapshot rather than an alias, because mean_u and
            # mean_v are mutated in place by the loop above.
            data_last[name] = data[name].copy()

        print(status)

        for name in data.keys():
            fname = os.path.join(out_dir, "{}_t=mlp_s=mlp_epoch=200_n={}.npy".format(name, n_samples))
            np.save(fname, data[name])