In [None]:
# Import libraries
import torch as th
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from tqdm.auto import trange
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Import datasets

sep_data = th.load('../datasets/sep_states1.pt')

ent_data = th.load('../datasets/ent_states1.pt')

sep_data_pair = th.stack((sep_data.real, sep_data.imag), dim=-1)
ent_data_pair = th.stack((ent_data.real, ent_data.imag), dim=-1)

sep_data_pair.shape, ent_data_pair.shape

In [None]:
# Split the data into training and testing sets

sep_train, sep_test = train_test_split(sep_data_pair, test_size=0.3)
ent_train, ent_test = train_test_split(ent_data_pair, test_size=0.3)

sep_train.shape, sep_test.shape, ent_train.shape, ent_test.shape

In [None]:
BATCH_SIZE = 64
ent_train_loader = DataLoader(ent_train, batch_size=BATCH_SIZE, shuffle=True)
ent_test_loader = DataLoader(ent_test, batch_size=BATCH_SIZE, shuffle=True)
sep_train_loader = DataLoader(sep_train, batch_size=BATCH_SIZE, shuffle=True)
sep_test_loader = DataLoader(sep_test, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
## Import train and test dataset, scale them and convert them to data loaders

from torchvision import datasets, transforms

BATCH_SIZE = 64


MNIST_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform= transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0, 1)]),
    download=True
)


In [None]:
data, labels = MNIST_dataset.data, MNIST_dataset.targets

mask = labels == 6

MNIST_6 = data[mask].float()

In [None]:
train_loader = DataLoader(dataset = MNIST_6,
                          batch_size = BATCH_SIZE,
                          shuffle = True)

In [None]:
images = next(iter(train_loader)).numpy()

fig = plt.figure(figsize=(25, 4))

for idx in np.arange(20):
    ax = fig.add_subplot(2, int(20/2), idx+1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    # print out the correct label for each image
    # .item() gets the value contained in a Tensor

In [None]:
# Define the loss function

def custom_loss(x, x_hat, mean, logvar):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * th.sum(1 + logvar - mean * mean - logvar.exp())
    
    return reproduction_loss + KLD

In [None]:
class VAE_fc(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(VAE_fc, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        
        self.leaky_relu = nn.LeakyReLU(0.2)
        
        self.fc3 = nn.Linear(hidden_size[1], hidden_size[0])
        self.fc4 = nn.Linear(hidden_size[0], input_size)
    
    def encode(self, x):
        
        # print("flatten", x.shape)
        x = self.fc1(x)
        x = self.leaky_relu(x)
        # print("fc1", x.shape)
        
        x = self.fc2(x)
        x = self.leaky_relu(x)
        # print("fc2", x.shape)
        return x
    
    def decode(self, x):
        x = self.fc3(x)
        x = self.leaky_relu(x)
        
        x = self.fc4(x)
        x = self.leaky_relu(x)
        return nn.Sigmoid()(x)
    
    def forward(self, x):
        # print("input", x.shape)
        x = x.flatten(start_dim=1)
        # print("flatten", x.shape)
        encoded = self.encode(x)
        
        decoded = self.decode(encoded)
        return encoded, decoded

In [None]:
def get_batch_accuracy(logit, target):
    corrects = (th.max(logit, 1)[1].view(target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects / target.size(0)
    return accuracy.item()


def get_test_stats(model, criterion, test_loader, device):
    test_acc, test_loss = 0.0, 0.0
    for _, data in enumerate(test_loader):
        data = data.to(device)
        _, decoded = model(data)
        test_loss += criterion(data, decoded).item()
        test_acc += get_batch_accuracy(data, decoded)
        return test_loss, test_acc

def train_model(model, train_loader, epochs, optimizer, criterion, device):
    _trained_model = model.to(device)
    _trained_model.train()
    _train_loss = []
    for _ in range(epochs):
        _total_loss = 0.0
        for _, _data in enumerate(train_loader):
            _data = _data.to(device)
            _, _decoded = _trained_model(_data)
            _loss = criterion(_data.flatten(start_dim=1), _decoded)
            optimizer.zero_grad()
            _loss.backward()
            optimizer.step()
            _total_loss += _loss.item() * data.size(0)
            
        _train_loss.append(_total_loss / len(train_loader.dataset))
        
        # print('Epoch: {}, Loss: {:.4f}'.format(epoch, epoch_loss))
    return _train_loss, _trained_model

In [None]:
model = VAE_fc(input_size=4 * 2, hidden_size=[32, 16])
optimizer = th.optim.Adam(model.parameters(), lr=0.01)

train_loss, model = train_model(model=model,
                    train_loader=ent_train_loader,
                    epochs=100,
                    optimizer=optimizer,
                    criterion=nn.MSELoss(),
                    device='cpu')


In [None]:
plt.plot(np.arange(len(train_loss)), train_loss)
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
_, result = model(ent_train)

ent_data_reconsructed = result.view(-1, 4, 2)

ent_data_reconsructed

In [None]:
model = VAE_fc(input_size=28 * 28, hidden_size=[14 * 28, 14 * 14])


model = train_model(model=model,
                    train_loader=train_loader,
                    epochs=5,
                    optimizer=optimizer,
                    criterion=nn.MSELoss(),
                    device='cpu')

In [None]:
_, reconstructed_data = model(MNIST_6)

reconstructed_data = reconstructed_data.view(-1, 28, 28).detach().numpy()
reconstructed_data.shape

In [None]:
next(iter(reconstructed_data))[0]

In [None]:
images = next(iter(reconstructed_data))

fig = plt.figure(figsize=(25, 4))

for idx in np.arange(20):
    ax = fig.add_subplot(2, int(20/2), idx+1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images), cmap='gray')
    # print out the correct label for each image
    # .item() gets the value contained in a Tensor

In [None]:
sep_data_pair.shape

In [None]:
sep_dataset = TensorDataset(sep_data_pair[:, :, 0], sep_data_pair[:, :, 1])
dataloader = DataLoader(sep_dataset, batch_size=64, shuffle=True)

In [None]:
# Define the Variational Autoencoder (VAE) class
class QuantumVAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(QuantumVAE, self).__init__()

        # Encoder
        # self.encoder = nn.Sequential(
        #     nn.Linear(input_dim, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, 64),
        #     nn.ReLU(),
        #     nn.Linear(64, latent_dim * 2)  # Multiply by 2 for mean and log-variance
        # )
        
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, latent_dim * 2)
        
        

        # # Decoder
        # self.decoder = nn.Sequential(
        #     nn.Linear(latent_dim, 64),
        #     nn.ReLU(),
        #     nn.Linear(64, 128),
        #     nn.ReLU(),
        #     nn.Linear(128, input_dim),
        #     nn.Tanh()  # Assuming quantum states are represented by complex vectors
        # )
        
        self.fc4 = nn.Linear(latent_dim, 64)
        self.fc5 = nn.Linear(64, 128)
        self.fc6 = nn.Linear(128, input_dim)
        
    def encoder(self, x):
        # print("Init", x.shape)
        x = self.fc1(x)
        # print("fc1", x.shape)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        # print("fc2", x.shape)
        x = nn.ReLU()(x)
        x = self.fc3(x)
        # print("fc3", x.shape)
        return x

    def decoder(self, x):
        x = self.fc4(x)
        x = nn.ReLU()(x)
        x = self.fc5(x)
        x = nn.ReLU()(x)
        x = self.fc6(x)
        x = nn.Tanh()(x)
        return x

    def reparameterize(self, mu, logvar):
        std = th.exp(0.5 * logvar)
        eps = th.randn_like(std)
        return mu + eps * std

    def forward(self, x_real, x_imag):
        # Concatenate real and imaginary parts
        # print("real", x_real.shape)
        # print("imag", x_imag.shape)
        x = th.cat([x_real, x_imag], dim=1)
        # print("concat", x.shape)

        # Encode
        encoded = self.encoder(x)

        # Split into mean and log-variance
        mu, logvar = th.chunk(encoded, 2, dim=1)

        # Reparameterize
        z = self.reparameterize(mu, logvar)

        # Decode
        decoded = self.decoder(z)

        # Split the decoded output into real and imaginary parts
        decoded_real, decoded_imag = th.chunk(decoded, 2, dim=1)

        return decoded_real, decoded_imag, mu, logvar

# Example usage
input_dim = 81 * 2  # Assuming complex vectors of length 81
latent_dim = 16  # Adjust as needed

# Instantiate the QuantumVAE model
quantum_vae = QuantumVAE(input_dim, latent_dim)


for x_real, x_imag in dataloader:
    reconstructions_real, reconstructions_imag, mu, logvar = quantum_vae(x_real, x_imag)
    print(reconstructions_real.shape, reconstructions_imag.shape, mu.shape, logvar.shape)
    break


In [32]:
# Define loss function and optimizer
quantum_vae = QuantumVAE(input_dim, latent_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(quantum_vae.parameters(), lr=0.001)

# Training loop (you need to provide your dataset and DataLoader)
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    for data_real, data_imag in dataloader:  # Assuming you have separate real and imaginary parts in your dataset
        # Forward pass
        reconstructions_real, reconstructions_imag, mu, logvar = quantum_vae(data_real, data_imag)

        # Concatenate real and imaginary parts for the loss calculation
        inputs_real = data_real
        inputs_imag = data_imag
        reconstructions = th.cat([reconstructions_real, reconstructions_imag], dim=1)

        # Compute the loss
        loss = criterion(reconstructions, th.cat([inputs_real, inputs_imag], dim=1)) + 0.5 * th.sum(logvar.exp() - logvar - 1 + mu.pow(2))
        train_loss += loss.item() * data_real.size(0)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print the loss at the end of each epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')


Epoch [1/10], Loss: 4025.6628
Epoch [2/10], Loss: 0.7752
Epoch [3/10], Loss: 0.3411
Epoch [4/10], Loss: 0.1947
Epoch [5/10], Loss: 0.1365
Epoch [6/10], Loss: 0.1090
Epoch [7/10], Loss: 0.0916
Epoch [8/10], Loss: 0.0872
Epoch [9/10], Loss: 0.0967
Epoch [10/10], Loss: 0.1282


In [41]:
decoded_real, decoded_imag, _, _ = quantum_vae(sep_data_pair[:, :, 0], sep_data_pair[:, :, 1])

decoded = th.stack([decoded_real, decoded_imag], dim=1)

torch.Size([30000, 2, 81])