In [826]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch import optim
import pyro
import pyro.distributions as dist

from tensorboardX import SummaryWriter

from tqdm import tqdm
import time


import numpy as np
import random
from numpy.random import *
import matplotlib.pyplot as plt 
#np.random.seed(100)
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

import csv

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

seed = 1234
rng = np.random.RandomState(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    device='cuda'
else:
    device='cpu'

In [909]:
x1 = np.random.binomial(1,0.5,size=70000)
x2 = np.random.normal(x1*0.4+0.5,0.2)
effect = 1.5
beta2 = 0.8
eps = 1e-10
b = np.random.normal(0,0.2)


def kaiki(x1=x1,x2=x2,effect=effect,beta2=beta2,b=b):
    return effect*x1+beta2*x2+b

yogo = kaiki(x1,x2)
data = np.array((x1,x2,yogo)).T

In [987]:
data[:60000].shape
train = data[:60000]#train[0]:conditional, train[1]:feature, train[2]:outcome
test  = data[60000:]

In [988]:
len(train)

60000

In [989]:
batch_size = 100
num_iters = 24000
num_epochs = int(batch_size* num_iters/len(train))

In [926]:
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size = batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset = test, batch_size = batch_size, shuffle = False)

In [990]:
num_epochs

40

In [991]:
from pixyz.distributions import Bernoulli, Normal
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE

In [992]:
class Inference(Normal):
    def __init__(self, input_dim,cond_dim, hidden_dim, latent_dim):
        super(Inference, self).__init__(cond_var=["x","y"], var=["z"], name="q")
        
        self.fc1 = nn.Linear(input_dim+cond_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc31 = nn.Linear(hidden_dim, latent_dim)
        self.fc32 = nn.Linear(hidden_dim, latent_dim)
    def forward(self,x,y):
        h = F.relu(self.fc1(torch.cat([x,y],1)))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

class Generator(Bernoulli):
    def __init__(self, latent_dim,cond_dim, hidden_dim, output_dim):
        super(Generator, self).__init__(cond_var=["z","y"], var=["x"], name="p")
        
        self.fc1 = nn.Linear(latent_dim+cond_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self,z,y):
        h = F.relu(self.fc1(torch.cat([z,y],1)))
        h = F.relu(self.fc2(h))
        return{"probs": F.sigmoid(self.fc3(h))}

class Estimator(Normal):
    def __init__(self, cond_dim, hidden_dim, latent_dim, estimate_dim):
        super(Estimator, self).__init__(cond_var=["z","y"], var=["e"], name="e")
        
        self.fc1 = nn.Linear(latent_dim+cond_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, estimate_dim)
        
    def forward(self,z,y):
        h = F.relu(self.fc1(torch.cat([z,y],1)))
        h = F.relu(self.fc2(h))
        return{"value": self.fc3(h)}

In [993]:
input_dim = 1
hidden_dim = 200
latent_dim = 10
output_dim = 1
cond_dim = 2
estimate_dim =1

In [994]:
p = Generator(latent_dim,cond_dim, hidden_dim, output_dim)
q = Inference(input_dim,cond_dim, hidden_dim, latent_dim)
e = Estimator(cond_dim, hidden_dim, latent_dim, estimate_dim)

p.to(device)
q.to(device)
e.to(device)

Estimator(
  (fc1): Linear(in_features=12, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=1, bias=True)
)

In [995]:
loc = torch.tensor(0.)
scale = torch.tensor(1.)
prior = Normal(loc=loc, scale=scale, var=["z"], dim= latent_dim, name="p_prior")

In [996]:
kl = KullbackLeibler(q,prior)

In [997]:
model = VAE(q,p, regularizer = kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

In [998]:
def data_shaper(x,y, input_dim):
    x = x.view(-1, input_dim).float()
    y = torch.eye(2)[y.long()].float()
    return x,y

def train(epoch, input_dim):
    train_loss = 0
    for data in tqdm(train_loader):
        y = data[:,0]
        x = data[:,1]
        x, y = data_shaper(x, y, input_dim)
        x = x.to(device)
        y = y.to(device)
       
        loss = model.train({"x":x, "y":y})
        train_loss += loss
    #pytorchのlossはデフォルトでは平均の値になる。このためバッチ長をかけて
    #全体で割って全体のlossを計算する
    
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch:{} Train loss:{:.4f}'.format(epoch, train_loss))
    return train_loss

In [999]:
def test(epoch, input_dim):
    test_loss = 0
    for data in test_loader:
        y = data[:,0]
        x = data[:,1]
        x, y = data_shaper(x,y, input_dim)
        x = x.to(device)
        y = y.to(device)
        
        loss = model.test({"x":x, "y":y})
        test_loss += loss
    test_loss = test_loss * test_loader.batch_size/ len(test_loader.dataset)
    print('Test loss:{:.4f}'.format(test_loss))
    return test_loss

In [1000]:
def plot_reconstrunction(x,y,output_dim):
    with torch.no_grad():
        z = q.sample({"x":x, "y":y}, return_all=False)
        z.update({"y":y})
        recon_batch = p.sample_mean(z).view(-1, output_dim)
        recon = torch.cat([x.view(-1, output_dim), recon_batch]).cpu()
        return recon

def plot_image_from_latent(z,y, output_dim):
    with torch.no_grad():
        sample = p.sample_mean({"z":z, "y":y}).view(-1, output_dim).cpu()
        return sample

def plot_reconstraction_changing_y(x,y, output_dim):
    y_change = torch.eye(2)[range(2)].to(device)
    batch_dummy = torch.ones(x.size(0))[:, None].to(device)
    recon_all = []
    
    with torch.no_grad():
        for _y in y_change:
            z = q.sample({"x": x, "y": y}, return_all=False)
            z.update({"y": batch_dummy * _y[None,:]})
            recon_batch = p.sample_mean(z).view(-1, output_dim)
            recon_all.append(recon_batch)
    
        recon_changing_y = torch.cat(recon_all)
        recon_changing_y = torch.cat([x.view(-1, output_dim), recon_changing_y]).cpu()
        return recon_changing_y

In [1001]:
writer = SummaryWriter()

plot_number = 1

z_sample = 0.5 * torch.randn(64, latent_dim).to(device)
y_sample = torch.eye(2)[[plot_number]*64].to(device)

_y, _x = iter(x_test_loader).next()
_x, _y = data_shaper(_x, _y, input_dim)
_x = _x.to(device)
_y = _y.to(device)

for epoch in range(1, num_epochs + 1):
    train_loss = train(epoch, input_dim)
    test_loss = test(epoch, input_dim)
    
    recon = plot_reconstrunction(_x[:8], _c[:8], output_dim)
    sample = plot_image_from_latent(z_sample, y_sample, output_dim)
    recon_changing_y = plot_reconstraction_changing_y(_x[:8], _y[:8], output_dim)
    

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    """
    writer.add_image('Image_from_latent', sample, epoch)
    writer.add_image('Image_reconstrunction', recon, epoch)
    writer.add_image('Image_reconstrunction_change_y', recon_changing_y, epoch)
    """
writer.close()

100%|██████████| 600/600 [00:03<00:00, 174.86it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:1 Train loss:0.5270
Test loss:0.5084


100%|██████████| 600/600 [00:03<00:00, 187.66it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:2 Train loss:0.5115
Test loss:0.5093


100%|██████████| 600/600 [00:02<00:00, 211.48it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:3 Train loss:0.5113
Test loss:0.5081


100%|██████████| 600/600 [00:03<00:00, 169.88it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:4 Train loss:0.5114
Test loss:0.5080


100%|██████████| 600/600 [00:03<00:00, 169.82it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:5 Train loss:0.5112
Test loss:0.5094


100%|██████████| 600/600 [00:02<00:00, 205.40it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:6 Train loss:0.5113
Test loss:0.5082


100%|██████████| 600/600 [00:03<00:00, 208.84it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:7 Train loss:0.5112
Test loss:0.5088


100%|██████████| 600/600 [00:03<00:00, 186.62it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:8 Train loss:0.5112
Test loss:0.5078


100%|██████████| 600/600 [00:03<00:00, 181.49it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:9 Train loss:0.5110
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 172.41it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:10 Train loss:0.5111
Test loss:0.5089


100%|██████████| 600/600 [00:03<00:00, 170.08it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:11 Train loss:0.5109
Test loss:0.5080


100%|██████████| 600/600 [00:02<00:00, 206.71it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:12 Train loss:0.5109
Test loss:0.5080


100%|██████████| 600/600 [00:03<00:00, 184.74it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:13 Train loss:0.5109
Test loss:0.5089


100%|██████████| 600/600 [00:03<00:00, 169.71it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:14 Train loss:0.5111
Test loss:0.5077


100%|██████████| 600/600 [00:03<00:00, 169.74it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:15 Train loss:0.5108
Test loss:0.5078


100%|██████████| 600/600 [00:03<00:00, 171.54it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:16 Train loss:0.5109
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 170.12it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:17 Train loss:0.5108
Test loss:0.5086


100%|██████████| 600/600 [00:02<00:00, 206.75it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:18 Train loss:0.5108
Test loss:0.5077


100%|██████████| 600/600 [00:02<00:00, 211.38it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:19 Train loss:0.5108
Test loss:0.5083


100%|██████████| 600/600 [00:03<00:00, 171.61it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:20 Train loss:0.5108
Test loss:0.5086


100%|██████████| 600/600 [00:03<00:00, 172.62it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:21 Train loss:0.5109
Test loss:0.5076


100%|██████████| 600/600 [00:03<00:00, 169.34it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:22 Train loss:0.5107
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 169.79it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:23 Train loss:0.5108
Test loss:0.5086


100%|██████████| 600/600 [00:03<00:00, 186.81it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:24 Train loss:0.5108
Test loss:0.5076


100%|██████████| 600/600 [00:03<00:00, 188.74it/s]


Epoch:25 Train loss:0.5108


  3%|▎         | 17/600 [00:00<00:03, 168.06it/s]

Test loss:0.5083


100%|██████████| 600/600 [00:03<00:00, 165.95it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:26 Train loss:0.5107
Test loss:0.5080


100%|██████████| 600/600 [00:03<00:00, 169.71it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:27 Train loss:0.5107
Test loss:0.5080


100%|██████████| 600/600 [00:03<00:00, 169.19it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:28 Train loss:0.5106
Test loss:0.5082


100%|██████████| 600/600 [00:03<00:00, 169.61it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:29 Train loss:0.5106
Test loss:0.5083


100%|██████████| 600/600 [00:03<00:00, 171.57it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:30 Train loss:0.5109
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 170.35it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:31 Train loss:0.5106
Test loss:0.5081


100%|██████████| 600/600 [00:03<00:00, 173.30it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:32 Train loss:0.5107
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 170.42it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:33 Train loss:0.5107
Test loss:0.5078


100%|██████████| 600/600 [00:03<00:00, 168.26it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:34 Train loss:0.5107
Test loss:0.5079


100%|██████████| 600/600 [00:03<00:00, 174.97it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:35 Train loss:0.5107
Test loss:0.5076


100%|██████████| 600/600 [00:03<00:00, 171.52it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:36 Train loss:0.5106
Test loss:0.5077


100%|██████████| 600/600 [00:03<00:00, 169.84it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:37 Train loss:0.5106
Test loss:0.5080


100%|██████████| 600/600 [00:03<00:00, 166.41it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:38 Train loss:0.5107
Test loss:0.5077


100%|██████████| 600/600 [00:03<00:00, 170.32it/s]
  0%|          | 0/600 [00:00<?, ?it/s]

Epoch:39 Train loss:0.5107
Test loss:0.5077


100%|██████████| 600/600 [00:03<00:00, 180.80it/s]


Epoch:40 Train loss:0.5107
Test loss:0.5082


In [1003]:
pt_size = 10000
z_sample = 0.5 * torch.randn(pt_size, latent_dim).to(device)

# 非治療群を生成
plot_number = 0
y_sample = torch.eye(2)[[plot_number]*pt_size].to(device)
untreated =  plot_image_from_latent(z_sample,y_sample , output_dim)
# 治療群を生成
plot_number = 1
y_sample = torch.eye(2)[[plot_number]*pt_size].to(device)
treated =  plot_image_from_latent(z_sample,y_sample , output_dim)




In [804]:
criterion = nn.MSELoss()
model2 = Estimator(cond_dim, hidden_dim, latent_dim, estimate_dim)
model2.to(device)
optimizer = torch.optim.SGD(model2.parameters(),lr=1e-3)

In [1013]:
for epoch in range(1, num_epochs+1):
    for data in tqdm(train_loader):
        y = data[:,0]
        x = data[:,1]
        o = data[:,2]
        x, y = data_shaper(x, y, input_dim)
        o = o.view(-1, estimate_dim).float()
        x = x.to(device)
        y = y.to(device)
        o = o.to(device)
        
        optimizer.zero_grad()
        
        z = q.sample({"x":x, "y":y}, return_all=False)
        output = e.sample({"z":z,"y":y},return_all=False)
    
        
        loss = criterion(output,o)
        
        loss.backward()
        
        optimizer.step()
    
    total = 0
    total_loss = 0
    for data in test_loader:
        y = data[:,0]
        x = data[:,1]
        o = data[:,2]
        x, y = data_shaper(x, y, input_dim)
        o = o(-1, estimate_dim).float()
        x = x.to(device)
        y = y.to(device)
        o = o.to(device)
        
        z = q.sample({"x":x, "y":y}, return_all=False)
        e = model2(z,y)
        
        loss = criterion(e,o)
        
        total_loss += loss *len(data)
        total += len(data)
    test_loss = total_loss/total.cpu()
    
    print('Epochs:{}, train_Loss:{}, test_Loss:{:.2f}'.format(epoch,loss.item(), test_loss.item()))
    

  0%|          | 0/600 [00:00<?, ?it/s]


AttributeError: 'dict' object has no attribute 'ToTensor'

In [1004]:
data[:0]

array([], shape=(0, 3), dtype=float64)

In [1006]:
data=iter(test_loader).next()

In [1007]:
len(data)

100

In [1029]:
np.array(z).shape

()