In [1]:
import numpy as np
import h5py
import torch
from haar_noising_script import apply_haar_scrambling

filename_channel_1 = "../../../data/QG1_normalized_16x16_100k"
filename_channel_2 = "../../../data/QG2_normalized_16x16_100k"
filename_channel_3 = "../../../data/QG3_normalized_16x16_100k"

data_X_channel_1 = np.array(h5py.File(filename_channel_1, "r")['X'])
data_X_channel_2 = np.array(h5py.File(filename_channel_2, "r")['X'])
data_X_channel_3 = np.array(h5py.File(filename_channel_3, "r")['X'])

data_X = np.stack([data_X_channel_1, data_X_channel_2, data_X_channel_3], axis=-1)

encoded_data_channel_1 = torch.load("../../../data/Q1_16x16_1k_encoded.pt")
encoded_data_channel_2 = torch.load("../../../data/Q2_16x16_1k_encoded.pt")
encoded_data_channel_3 = torch.load("../../../data/Q3_16x16_1k_encoded.pt")

encoded_data = np.stack([encoded_data_channel_1, encoded_data_channel_2, encoded_data_channel_3], axis=-1)
print(encoded_data.shape)

num_samples = 100
scrambled_states = apply_haar_scrambling(np.array(encoded_data), num_samples, seed=42)
scrambled_states = torch.tensor(scrambled_states, dtype=torch.float32)
print(scrambled_states.shape)

  encoded_data_channel_1 = torch.load("../../../data/Q1_16x16_1k_encoded.pt")
  encoded_data_channel_2 = torch.load("../../../data/Q2_16x16_1k_encoded.pt")
  encoded_data_channel_3 = torch.load("../../../data/Q3_16x16_1k_encoded.pt")


(1000, 8, 8, 4, 3)
torch.Size([100, 8, 8, 4])


  scrambled_states = torch.tensor(scrambled_states, dtype=torch.float32)


In [7]:
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import scipy.linalg
import pennylane as qml

# Assuming `encoded_data` and `scrambled_states` are already loaded and processed
train_encoded_data, val_encoded_data, train_scrambled_states, val_scrambled_states = train_test_split(
    encoded_data[:num_samples], scrambled_states, test_size=0.2, random_state=42, shuffle=True
)

n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface='torch')
def quantum_circuit(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

class QuantumLayer(nn.Module):
    def __init__(self, n_qubits, n_layers):
        super(QuantumLayer, self).__init__()
        weight_shapes = {"weights": (n_layers, n_qubits, 3)}
        self.qlayer = qml.qnn.TorchLayer(quantum_circuit, weight_shapes)

    def forward(self, x):
        return self.qlayer(x)

class QuantumDiffusionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_qubits, n_layers):
        super(QuantumDiffusionModel, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, n_qubits)  # Input dimension needs to be correct
        self.quantum_layer = QuantumLayer(n_qubits, n_layers)
        self.fc2 = nn.Linear(n_qubits, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        print(f"Input shape to model: {x.shape}")
        x = torch.relu(self.fc1(x))  # First linear layer
        print(f"Shape after fc1: {x.shape}")
        x = self.quantum_layer(x)  # Quantum layer
        print(f"Shape after quantum layer: {x.shape}")
        x = torch.relu(self.fc2(x))  # Second linear layer
        print(f"Shape after fc2: {x.shape}")
        x = self.dropout(x)  # Dropout
        x = self.fc3(x)  # Final layer
        print(f"Shape after fc3 (output): {x.shape}")
        return x

def decode(encoded_data):
    num_samples, encoded_height, encoded_width, num_channels, _ = encoded_data.shape
    decoded_data = np.zeros((num_samples, 16, 16, 3))  # Adjusted for 3 channels

    for sample in range(num_samples):
        for i in range(encoded_height):
            for j in range(encoded_width):
                for c in range(num_channels):
                    # Decode each channel
                    if c == 0:
                        decoded_data[sample, 2*i, 2*j] = encoded_data[sample, i, j, c, :]
                    elif c == 1:
                        decoded_data[sample, 2*i, 2*j+1] = encoded_data[sample, i, j, c, :]
                    elif c == 2:
                        decoded_data[sample, 2*i+1, 2*j] = encoded_data[sample, i, j, c, :]
                    elif c == 3:
                        decoded_data[sample, 2*i+1, 2*j+1] = encoded_data[sample, i, j, c, :]

    return decoded_data

def flip(decoded_data):
    return 1 - decoded_data

def calculate_statistics(data):
    data = data.reshape(data.shape[0], -1)
    mean = np.mean(data, axis=0)
    covariance = np.cov(data, rowvar=False)
    return mean, covariance

def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = scipy.linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return fid

n_layers = 6
input_dim = 8 * 8 * 4 * 3  # Updated input dimension based on data
hidden_dim = 128
output_dim = input_dim  

model = QuantumDiffusionModel(input_dim, hidden_dim, output_dim, n_qubits, n_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

num_epochs = 50
loss_values = []
val_loss_values = []
fid_scores = []

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    # Reshape input correctly to fit into the model
    inputs = train_scrambled_states.view(len(train_scrambled_states), -1)
    print(f"Train scrambled states shape: {inputs.shape}")
    outputs = model(inputs)
    
    loss = criterion(outputs, train_encoded_data.view(len(train_encoded_data), -1))
    loss.backward()
    optimizer.step()
    loss_values.append(loss.item())
    
    model.eval()
    with torch.no_grad():
        val_inputs = val_scrambled_states.view(len(val_scrambled_states), -1)
        print(f"Val scrambled states shape: {val_inputs.shape}")
        val_outputs = model(val_inputs)
        val_loss = criterion(val_outputs, val_encoded_data.view(len(val_encoded_data), -1))
        val_loss_values.append(val_loss.item())
        
        denoised_states = model(val_scrambled_states.view(len(val_scrambled_states), -1))
        denoised_states = denoised_states.view(len(val_scrambled_states), 8, 8, 4, 3).detach().numpy()  # Adjusted for 3 channels
        decoded_data = decode(denoised_states)
        decoded_data = flip(decoded_data)

        mu1, sigma1 = calculate_statistics(data_X[:len(decoded_data)])
        mu2, sigma2 = calculate_statistics(decoded_data)
        fid = calculate_fid(mu1, sigma1, mu2, sigma2)
        fid_scores.append(fid)
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}, FID: {fid:.4f}')

plt.figure(figsize=(18, 5))

plt.subplot(1, 2, 1)
plt.plot(loss_values, label='Training Loss', color='blue')
plt.plot(val_loss_values, label='Validation Loss', color='orange')
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.title('Training and Validation Loss Over Epochs', fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend(fontsize=12)

plt.subplot(1, 2, 2)
plt.plot(fid_scores, label='FID Score', color='green')
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('FID', fontsize=14)
plt.title('FID Score Over Epochs', fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend(fontsize=12)

plt.tight_layout()
plt.show()


Train scrambled states shape: torch.Size([80, 256])
Input shape to model: torch.Size([80, 256])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (80x256 and 768x4)

In [None]:
def generate_new_images(model, num_images, input_dim=8*8*4*3):
    model.eval()  
    with torch.no_grad():
        
        for i in range(num_images):
            random_noise = torch.randn(num_images, input_dim)
            
            generated_data = model(random_noise)
            generated_data = flip(generated_data.view(num_images, 8, 8, 4, 3).detach().numpy())  # Adjusted for 3 channels
            
            decoded_images = decode(generated_data)
            fig, axes = plt.subplots(1, 5, figsize=(10, 2))

            for qubit in range(4):
                im = axes[qubit].imshow(generated_data[i, :, :, qubit, 0], cmap='viridis')  # Showing first channel
                axes[qubit].set_title(f"Encoded Qubit {qubit+1} (Channel 1)")
                fig.colorbar(im, ax=axes[qubit])

            im = axes[4].imshow(decoded_images[i, :, :, 0], cmap='viridis')  # Showing decoded first channel
            axes[4].set_title("Decoded (Channel 1)")
            fig.colorbar(im, ax=axes[4])
            
            plt.tight_layout()
            plt.show()

    return decoded_images

num_samples_to_generate = 5
new_images = generate_new_images(model, num_samples_to_generate)