In [157]:
# Import libraries
import torch as th
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import MNIST

In [158]:
# Import datasets

sep_data = th.load('../datasets/sep_states.pt')

ent_data = th.load('../datasets/ent_states.pt')

sep_data.shape, ent_data.shape

(torch.Size([10000, 4]), torch.Size([10000, 4]))

In [159]:
sep_data.dtype, ent_data.dtype

(torch.complex64, torch.complex64)

In [160]:
sep_data_pair = th.stack((sep_data.real, sep_data.imag), dim=-1)
ent_data_pair = th.stack((ent_data.real, ent_data.imag), dim=-1)


sep_data_pair.dtype, ent_data_pair.dtype

(torch.float32, torch.float32)

In [161]:
sep_data_pair.shape, ent_data_pair.shape

(torch.Size([10000, 4, 2]), torch.Size([10000, 4, 2]))

In [162]:
# Create data loaders

BATCH_SIZE = 64
ent_loader = DataLoader(sep_data_pair.reshape(-1, 1, 4, 2), batch_size=BATCH_SIZE, shuffle=True)
sep_loader = DataLoader(ent_data_pair.reshape(-1, 1, 4, 2), batch_size=BATCH_SIZE, shuffle=True)

In [163]:
# Define the loss function

def custom_loss(x, x_hat, mean, logvar):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * th.sum(1 + logvar - mean * mean - logvar.exp())
    
    return reproduction_loss + KLD

In [164]:
# create a transofrm to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])

# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, download=True)
test_dataset  = MNIST(path, transform=transform, download=True)

# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

device = th.device("cuda" if th.cuda.is_available() else "cpu")

In [165]:
print(train_dataset.data.shape)
for i, (images, labels) in enumerate(train_loader):
    print(images.shape)
    print(labels.shape)
    break

torch.Size([60000, 28, 28])
torch.Size([100, 1, 28, 28])
torch.Size([100])


In [166]:
class VAE(nn.Module):

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
            )
        
        # latent mean and variance 
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
            )
     
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, var):
        epsilon = th.randn_like(var).to(device)      
        z = mean + var*epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        print("Init", x.shape)
        x = x.view(-1, self.input_dim)
        print("View", x.shape)
        mean, logvar = self.encode(x)
        print("Mean", mean.shape)
        z = self.reparameterization(mean, logvar)
        print("Z", z.shape)
        x_hat = self.decode(z)
        print("X_hat", x_hat.shape)
        x_hat = x_hat.view(-1, 4, 2)
        return x_hat, mean, logvar

In [171]:
model = VAE(input_dim=28 * 28, hidden_dim=32, latent_dim=12)

for i, (x, label) in enumerate(train_loader):
    res = model(x)
    print(res[0].shape)
    break

Init torch.Size([100, 1, 28, 28])
View torch.Size([100, 784])
Mean torch.Size([100, 2])
Z torch.Size([100, 2])
X_hat torch.Size([100, 784])
torch.Size([9800, 4, 2])


In [None]:
class ComplexVAE(nn.Module):
    def __init__(self, input_channels, hidden_dim, kernel_size):
        super(ComplexVAE, self).__init__()
        
        
        
        # self.encoder = nn.Sequential(nn.Conv2d(input_channels, hidden_dim[0], kernel_size[0]),
        #                              nn.LeakyReLU(0.2),
        #                              nn.Conv2d(hidden_dim[0], hidden_dim[1], kernel_size[1]),
        #                              nn.LeakyReLU(0.2),
        #                              nn.Linear(hidden_dim[1], hidden_dim[2]),
        #                              nn.LeakyReLU(0.2))
        
        self.conv1 = nn.Conv2d(input_channels, hidden_dim[0], kernel_size[0])
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(hidden_dim[0], hidden_dim[1], kernel_size[1])
        self.fc = nn.Linear(hidden_dim[1], hidden_dim[2])
        
        # self.decoder = nn.Sequential(nn.Linear(hidden_dim[2], hidden_dim[1]),
        #                              nn.LeakyReLU(0.2),
        #                              nn.ConvTranspose2d(hidden_dim[1], hidden_dim[0], kernel_size[1]),
        #                              nn.LeakyReLU(0.2),
        #                              nn.ConvTranspose2d(hidden_dim[0], input_channels, kernel_size[0]),
        #                              nn.Sigmoid())
        
    def encode(self, x):
        print("Init", x.shape)
        x = self.conv1(x)
        x = self.leakyrelu(x)
        print("Conv1", x.shape)
        x = self.conv2(x)
        x = self.leakyrelu(x)
        print("Conv2", x.shape)
        x = x.view(-1, 4, 2)
        print("View", x.shape)
        x = self.fc(x)
        x = self.leakyrelu(x)
        print("FC", x.shape)
        return x
        
        
    def forward(self, x):
        x = self.encode(x)
        # print("Encoder", x.shape)
        # x = self.decoder(x)
        # print("Decoder", x.shape)
        return x

device = th.device('cuda' if th.cuda.is_available() else 'cpu')
model = ComplexVAE(input_channels=1, hidden_dim=[32, 12, 4], kernel_size=[1, 1]).to(device)

for i, x in enumerate(sep_loader):
    res = model(x)
    print(res.shape)
    break

Init torch.Size([64, 1, 4, 2])
Conv1 torch.Size([64, 32, 4, 2])
Conv2 torch.Size([64, 12, 4, 2])
View torch.Size([768, 4, 2])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3072x2 and 12x4)