# Synthetic Data

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
from torch.distributions import Distribution
import torch.distributions as dist
from torch.utils.data import Dataset, DataLoader

In [3]:
class SynthNN(nn.Module):
    """Small randomly initialised MLP used as a nonlinear mixing function.

    Maps ``in_features``-dimensional inputs to ``out_features``-dimensional
    outputs through a single hidden ReLU layer. The weights come from
    PyTorch's default ``nn.Linear`` initialisation, so the mapping is random;
    seed torch (``torch.manual_seed``) beforehand to make it reproducible.

    Args:
        hidden_size: Width of the hidden layer.
        in_features: Input dimension. The default of 3 matches the three
            latent columns stacked by ``simulate_synthetic_data``.
        out_features: Output dimension.
    """

    def __init__(self, hidden_size=6, in_features=3, out_features=10):
        super().__init__()
        # The attribute is named ``function`` so state_dict keys
        # ("function.0.weight", ...) stay the same as before.
        self.function = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=hidden_size, out_features=out_features),
        )

    def forward(self, Z):
        """Apply the MLP to ``Z`` with trailing dimension ``in_features``."""
        return self.function(Z)


def simulate_synthetic_data(samples, function):
    """Sample a four-environment linear SCM and map its latents to observations.

    Structural equations (``env`` is the per-sample environment value)::

        E   ~ Uniform{0, 1, 2, 3},   env = [0.2, 2, 3, 5][E]
        Z1  = env     + N(0, 1)
        Z2  = 2 * env + N(0, 2)      # variance 2, i.e. std sqrt(2)
        Y   = Z1 + Z2 + N(0, 1)
        Z3  = Y       + N(0, 1)

    The latents ``Z = [Z1, Z2, Z3]`` are then turned into observations ``X``.

    Args:
        samples: Number of i.i.d. samples to draw.
        function: How ``X`` is computed from ``Z``:
            ``'identity'``  -> ``X = Z`` (float64, shape (samples, 3));
            ``'linear'``    -> ``X = Z @ S`` with ``S ~ N(0, 1)`` of shape
            (3, 10) (float64, shape (samples, 10));
            ``'nonlinear'`` -> ``X = SynthNN()(Z)`` with a freshly initialised
            random MLP (float32, shape (samples, 10)).

    Returns:
        Tuple ``(X, Y, E, Z, env)``: observations, target of shape (samples,),
        integer environment indices of shape (samples,), latents of shape
        (samples, 3), and environment values of shape (samples,).

    Raises:
        ValueError: If ``function`` is not one of the options above.

    Note:
        Samples and ``S`` come from NumPy's global RNG; the MLP weights for
        ``'nonlinear'`` come from torch's global RNG. The mixing is redrawn on
        every call, so data that must share one mixing (e.g. train/test
        splits) has to come from a single call.
    """
    if function not in ('identity', 'linear', 'nonlinear'):
        raise ValueError(
            f"function must be 'identity', 'linear' or 'nonlinear', got {function!r}"
        )

    envs = np.array([0.2, 2, 3, 5])
    E = np.random.choice([0, 1, 2, 3], size=samples)

    env = envs[E]
    Z1 = env + np.random.normal(0, 1, size=samples)
    Z2 = 2 * env + np.random.normal(0, np.sqrt(2), size=samples)
    Y = Z1 + Z2 + np.random.normal(0, 1, size=samples)
    # Z3 is a child of Y, not a cause of it.
    Z3 = Y + np.random.normal(0, 1, size=samples)
    Z = np.stack([Z1, Z2, Z3], axis=1)

    if function == 'identity':
        X = Z
    elif function == 'linear':
        # The mixing matrix needs one row per latent column. The previous
        # (2, 10) shape made ``Z @ S`` fail for the 3-column Z.
        S = np.random.normal(size=(Z.shape[1], 10))
        X = Z @ S
    else:  # 'nonlinear'
        synthnn = SynthNN()
        X = synthnn(torch.tensor(Z).float()).detach().numpy()

    return X, Y, E, Z, env

In [6]:
# Quick smoke test: build a mixing network and draw a small nonlinear sample.
# `model` is not used by any visible code; the globals are kept for later cells.
model = SynthNN(hidden_size=6)
x, y, e, z, env = simulate_synthetic_data(samples=100, function="nonlinear")

In [14]:
class EnvDataset(Dataset):
    """Map-style dataset yielding ``(x, y, one_hot_env)`` triples.

    Args:
        X: Array-like of features; the first dimension indexes samples.
        Y: Array-like of targets of shape (n,); stored with shape (n, 1).
        E: Array-like of integer environment indices of shape (n,).
        num_classes: Width of the one-hot environment encoding. The default
            ``-1`` infers it as ``max(E) + 1``, which can differ between
            datasets built from different splits. Pass the true number of
            environments to keep the encoding width consistent.

    Raises:
        ValueError: If ``X``, ``Y`` and ``E`` have different lengths.

    Note:
        ``torch.tensor`` copies X and Y without casting, so they keep their
        input dtypes (e.g. float32 X from the ``'nonlinear'`` path alongside
        float64 Y). Cast before computing losses if the two must match.
    """

    def __init__(self, X, Y, E, num_classes=-1):
        super().__init__()
        self.X = torch.tensor(X)
        self.Y = torch.tensor(Y).unsqueeze(1)
        # one_hot requires an int64 index tensor.
        self.E = torch.nn.functional.one_hot(
            torch.tensor(E).long(), num_classes=num_classes
        )
        if not (len(self.X) == len(self.Y) == len(self.E)):
            raise ValueError(
                f"X, Y and E must have the same length, got "
                f"{len(self.X)}, {len(self.Y)} and {len(self.E)}"
            )

    def __getitem__(self, index):
        return self.X[index], self.Y[index], self.E[index]

    def __len__(self):
        return len(self.X)

## Generate data

In [15]:
# Seed both RNGs. torch only initialises the SynthNN mixing network, while the
# samples themselves are drawn with NumPy's global RNG, so seeding torch alone
# leaves the data non-reproducible.
torch.manual_seed(1)
np.random.seed(1)

n_train = 4000
# NOTE: simulate_synthetic_data returns (X, Y, E, Z, env). Here the integer
# environment indices are bound to `envs` and the per-sample environment
# values to `E`; the names are kept as-is for later cells.
# Both splits come from one call so they share the same random mixing network.
X, Y, envs, Z, E = simulate_synthetic_data(8000, 'nonlinear')

dset = EnvDataset(X[:n_train], Y[:n_train], envs[:n_train])
# DataLoader's default shuffle=False gives identical batches every epoch.
train_loader = DataLoader(dset, batch_size=128, drop_last=True)

# `dset` is rebound here, so after this cell it refers to the test split.
dset = EnvDataset(X[n_train:], Y[n_train:], envs[n_train:])
# NOTE(review): drop_last=True discards the last 4000 % 128 = 32 test samples;
# consider drop_last=False for evaluation.
test_loader = DataLoader(dset, batch_size=128, drop_last=True)