In [1]:
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import numpy as np
import os
import datetime

In [31]:
device = torch.device("cuda:0" if torch.cuda.device_count() != 0 else "cpu")
print(device)

n_qubits = 50
n_outcomes = 4
batchSize = 32
filename = 'train.txt'
num_epochs = 50
log_interval = 1000

cuda:0


In [32]:
data_array = np.load('numpy_POVM_data.npz')['a']


In [33]:
#np.savez_compressed('numpy_POVM_data', a = data_array)
print(data_array.shape)

(1000000, 200)


In [34]:
# print(data_array[0][1:])
# data_array[data_array == 0] = -1
# print(data_array[0][1:])

In [35]:
torch.nn.functional.binary_cross_entropy.__code__.co_varnames

('input', 'target', 'weight', 'size_average', 'reduce', 'new_size')

In [85]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(n_qubits * n_outcomes, (n_qubits*n_outcomes) * 2)
        self.fc21 = nn.Linear((n_qubits* n_outcomes) * 2, (n_qubits* n_outcomes) // 4)
        self.fc22 = nn.Linear((n_qubits* n_outcomes) * 2, (n_qubits* n_outcomes) // 4)
        self.fc3 = nn.Linear((n_qubits* n_outcomes) // 4, (n_qubits* n_outcomes) * 2)
        self.fc4 = nn.Linear((n_qubits* n_outcomes) * 2, n_qubits * n_outcomes) 
        #self.act = nn.Softmax( 1)

    def encode(self, x):
        h1 = (self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = torch.tanh(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    
    #CE = nn.CrossEntropyLoss()
    #CE_loss = CE(recon_x, x)
    #print(recon_x.shape, x.shape)
    BCE = F.binary_cross_entropy(recon_x, x, reduce=True)
    #MSE = F.mse_loss(recon_x, x.view(-1, 200), size_average=False)
    
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD, BCE, KLD


def train(epoch, train_losses):
    model.train()
    train_loss = 0
    for batch_idx, (data) in enumerate(train_loader,):
        data = data[0].to(device)
        #print(data.shape)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        #print(recon_batch.shape, data.shape, mu.shape, logvar.shape)
        loss, MSE, KLD = loss_function(recon_batch, data, mu, logvar)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tMSE: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data), MSE / len(data)))
    avg_batch_loss = train_loss / len(train_loader.dataset)
    train_losses.append(avg_batch_loss)

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, avg_batch_loss))


def test(epoch, test_losses):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data) in enumerate(test_loader,0):
            data = data[0].to(device)
            #print data.shape
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar)[0].item()
            if i == 0:
                n = min(data.size(0), 8)
                #comparison = torch.cat([data[:n],
                #                      recon_batch.view(batchSize, 1, imageSize, imageSize)[:n]])
                #save_image(comparison.cpu(),
                #         'results/reconstruction_' + str(epoch) + '.png', nrow=n)
                
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('====> Test set loss: {:.4f}'.format(test_loss))

    
print(model)


VAE(
  (fc1): Linear(in_features=200, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=50, bias=True)
  (fc22): Linear(in_features=400, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=200, bias=True)
)


In [86]:
#normed_array, scale = normalize(image_array, norm_scale = norm_scale )

tensor_array_train = torch.stack([torch.Tensor(i) for i in data_array[:100000]])
tensor_data_train = torch.utils.data.TensorDataset(tensor_array_train)
tensor_array_test = torch.stack([torch.Tensor(i) for i in data_array[100000:200000]])
tensor_data_test = torch.utils.data.TensorDataset(tensor_array_test)

train_loader = torch.utils.data.DataLoader(tensor_data_train, batch_size=batchSize, num_workers=8)
test_loader = torch.utils.data.DataLoader(tensor_data_test, batch_size = batchSize, num_workers = 8)
print(train_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f03840741d0>


In [87]:
epochs = []
train_losses = []
test_losses = []

mydir = os.path.join('/Users/Mike_Laptop/Documents/Postgraduate/Perimeter/RESEARCH/datapovmtfim/runs/', 
                     datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "/")
mydir = None
if mydir != None:
    try:
        os.makedirs(mydir)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    with open(mydir + 'NetworkInfo.txt', 'w') as f:
        f.write('model:' + str(model) + '\n')
        #print >> f, 'model:', model
        #print >> f, 'normalization scale:', norm_scale
        #print >> f, 'latent dimensionality:', zdim
        f.write('loss function: MSE + KL div' + '\n')
        f.write('batch size:' +  str(batchSize) + '\n')
        f.write('epochs:' +  str(num_epochs) + '\n')
        f.write('number of qubits: ' + str(n_qubits) + '\n')
        f.write('number of measurement outcomes: ' +  str(n_outcomes) + '\n')
        #print >> f, 'beta:', beta



num_epochs = 300

for epoch in range(1, + num_epochs +1):
    epochs.append(epoch)
   # if epoch % 2 == 1:
    #    plot_reconstructions(model, save_dir=mydir, conv=False, simple=False, n=6, epoch = epoch)
    #    plot_avg_both(data_avg[0],data_avg[1], n_events, save_dir = mydir,  epoch = epoch, norm_scale=norm_scale)
       # plot_difference(data_avg, epoch = epoch, save_dir=mydir, norm_scale = norm_scale, n_events = n_events)
   #     samples(model, epoch = epoch, save_dir = mydir)
    #if epoch % 4 == 1:
        #plot_all_metrics(genx_FWHMs, realx_FWHMs, genx_means, realx_means,epochs = epochs, n_epochs=num_epochs, save_dir = mydir)
        #plot_sum_difference(sum_differences, epochs=epochs)
        #xsection_multi(data_avg, epoch, savedir=mydir, imageSize=imageSize)
    train(epoch, train_losses)
    test(epoch, test_losses)
    #if epoch % 4 == 1:
    #    plot_losses(train_losses, test_losses, epochs = epochs, n_epochs=num_epochs, save_dir = mydir)
    #if epoch == num_epochs:
    #    samples(model, rows=1,columns=1, save_dir=mydir,  epoch = epoch, beta = beta)
        
#     if epoch > 1:
#         stats_r, stats_f = calc_means_stds(imageSize=imageSize, n_times = 10)
#         plot_means_stds(stats_r, stats_f, save_dir = mydir, imageSize=imageSize, trim=10)
        
    

====> Epoch: 1 Average loss: 0.0408
====> Test set loss: 0.0138
====> Epoch: 2 Average loss: 0.0138
====> Test set loss: 0.0138
====> Epoch: 3 Average loss: 0.0139


Process Process-341:
Process Process-344:
Process Process-343:
Process Process-339:
Process Process-342:
Process Process-340:
Process Process-337:
Process Process-338:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiproce

KeyboardInterrupt: 

  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/albergo/anaconda2/envs/pytorch04/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  Fil

In [103]:
sample = torch.randn(200, 50).to(device)
sample = model.decode(sample).cpu()
sample[10]

tensor([ 0.1855,  0.0356,  0.1849,  0.5955,  0.1443,  0.0539,  0.1887,
         0.6169,  0.1902,  0.0601,  0.1624,  0.6173,  0.1536,  0.0605,
         0.1956,  0.6004,  0.1859,  0.0507,  0.1793,  0.6032,  0.1408,
         0.0584,  0.1876,  0.6216,  0.2292,  0.0583,  0.1773,  0.5443,
         0.1560,  0.0652,  0.1613,  0.6370,  0.2383,  0.0563,  0.1805,
         0.5656,  0.1099,  0.0610,  0.1694,  0.6869,  0.2109,  0.0605,
         0.1762,  0.5945,  0.1286,  0.0669,  0.1696,  0.6793,  0.2284,
         0.0659,  0.1769,  0.5568,  0.1330,  0.0657,  0.1753,  0.6499,
         0.2129,  0.0639,  0.1923,  0.5540,  0.1188,  0.0654,  0.1795,
         0.6630,  0.2214,  0.0701,  0.1618,  0.5553,  0.1105,  0.0597,
         0.1901,  0.6318,  0.2360,  0.0680,  0.1920,  0.5237,  0.1091,
         0.0774,  0.1708,  0.6415,  0.2512,  0.0682,  0.1663,  0.5372,
         0.1179,  0.0680,  0.1717,  0.6743,  0.2404,  0.0610,  0.1768,
         0.5450,  0.1285,  0.0690,  0.1922,  0.6470,  0.2275,  0.0685,
      