# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import math
import pandas as pd
import joblib
from torch.utils.data import Dataset, DataLoader
from mpl_toolkits.mplot3d import Axes3D

# Function to check conditions for the ply layer, input shape is (1, 65)
def check_conditions(random_data):
    """
    Score the stacking-sequence rules of one encoded laminate row.

    Entries equal to -90 are dropped as padding. The last remaining entry is the
    symmetry flag and the rest is one half of the stack. With flag == 1 the half
    is mirrored in full; otherwise it is mirrored without repeating its last ply.

    NOTE(review): dropping every -90 also drops genuine -90 deg plies; confirm
    that the dataset never contains them.

    Returns
    -------
    (contiguity_constraint, gathering_index, adjacent_angle_index)
        contiguity_constraint : 1 if gathering_index exceeds a threshold drawn
            uniformly from {2, 3, 4, 5} on every call, else 0.
        gathering_index : the longest run of identical consecutive plies,
            clamped to [2, 6].
        adjacent_angle_index : the largest angle change between neighbouring
            plies of the half stack, folded into [0, 90] and floored at 45.
        An empty stack yields (1, 0, 45).
    """
    kept = random_data[random_data != -90]
    half = kept[:-1]
    mirrored = half[::-1] if kept[-1] == 1 else half[:-1][::-1]
    stack = np.concatenate((half, mirrored))

    if len(stack) == 0:
        return 1, 0, 45

    # Random contiguity threshold in [2, 5]; one draw per non-empty call.
    threshold = np.random.randint(2, 6)

    # Length of the longest run of equal consecutive angles in the full stack.
    longest = current = 1
    for previous, value in zip(stack[:-1], stack[1:]):
        current = current + 1 if value == previous else 1
        longest = max(longest, current)

    # Windows of 2..6 plies are inspected, so the reported run length is clamped to [2, 6].
    gathering_index = min(max(longest, 2), 6)
    contiguity_constraint = 1 if gathering_index > threshold else 0

    # Angle change between neighbours, using the shorter way round a 180-degree period.
    steps = [0] + [min(abs(a - b), 180 - abs(a - b)) for a, b in zip(half[:-1], half[1:])]
    adjacent_angle_index = max(max(steps), 45)

    return contiguity_constraint, gathering_index, adjacent_angle_index

# Calculate Laminate Parameters, input shape is (1, 65)
def LaminateParameters(random_data):
    """
    Compute the lamination parameters of one encoded laminate row.

    The stack is rebuilt exactly as in ``check_conditions``:
      * -90 entries are dropped as padding;
      * the last remaining entry is the symmetry flag;
      * the half stack is mirrored in full when the flag is 1, otherwise
        without repeating its last ply.

    With unit ply thickness and z_k = k - n/2, each parameter is a sum over
    plies of a harmonic (cos2, cos4, sin2, sin4) times a weight:
      * a1..a4 use weight (z_k - z_{k-1}), scaled by 1/n;
      * d1..d4 use weight (z_k**3 - z_{k-1}**3), scaled by 4/n**3.
    Every value is rounded to 4 decimals.

    Returns
    -------
    list
        [a1, a2, a3, a4, d1, d2, d3, d4, n], where n is the ply count.
        An empty stack yields nine -1 values.
    """
    kept = random_data[random_data != -90]
    half = kept[:-1]
    mirrored = half[::-1] if kept[-1] == 1 else half[:-1][::-1]
    stack = np.concatenate((half, mirrored))
    n = len(stack)

    if n == 0:
        return [-1] * 9

    theta = torch.deg2rad(torch.Tensor(stack))
    z_top = torch.cumsum(torch.ones(n), dim=0) - n / 2
    z_bottom = z_top - 1
    in_plane_weight = z_top - z_bottom
    bending_weight = z_top**3 - z_bottom**3

    harmonics = ((torch.cos, 2), (torch.cos, 4), (torch.sin, 2), (torch.sin, 4))
    a_terms = [round(1/n * torch.sum(fn(k * theta) * in_plane_weight).item(), 4) for fn, k in harmonics]
    d_terms = [round(4/(n**3) * torch.sum(fn(k * theta) * bending_weight).item(), 4) for fn, k in harmonics]
    return a_terms + d_terms + [n]

# Encode ply angles into binary representation
# Ply angle -> 13-bit one-hot string; the angle at position i in this order sets bit i.
_ENCODED_ANGLE_ORDER = (0, -15, -30, -45, -60, -75, -90, 90, 75, 60, 45, 30, 15)
angle_mapping = {
    angle: '0' * position + '1' + '0' * (len(_ENCODED_ANGLE_ORDER) - 1 - position)
    for position, angle in enumerate(_ENCODED_ANGLE_ORDER)
}

def normalize_data(tensor, min_vals, max_vals):
    """
    Min-max scale ``tensor`` column by column.

    ``min_vals`` and ``max_vals`` broadcast against ``tensor``. Values inside
    [min_vals, max_vals] map to [0, 1]. There is no guard against
    max_vals == min_vals, which yields inf or nan.
    """
    value_range = max_vals - min_vals
    return (tensor - min_vals) / value_range

# Raw ply-angle table; batch_encode_angles below treats the last column as a
# pass-through flag and every other column as an angle.
# NOTE(review): the row layout (64 angles + symmetry flag) is inferred from the
# model's 833-wide input; confirm against the script that wrote this file.
loaded_ply_angle = torch.load('0304ply_angle.pt')

def batch_encode_angles(angles_batch, mapping=None):
    """
    One-hot encode a batch of ply-angle rows.

    Parameters
    ----------
    angles_batch : array-like (numpy array or CPU tensor), shape (batch, num_angles + 1)
        Ply angles in degrees. The last column is not encoded; it is copied to
        the last output column, truncated to int.
    mapping : dict, optional
        Maps each angle to its bit string (all strings have the same width).
        Defaults to the module-level ``angle_mapping``.

    Returns
    -------
    torch.Tensor
        Integer tensor of shape (batch, num_angles * code_width + 1).

    Raises
    ------
    ValueError
        If an angle outside the last column has no entry in ``mapping``.
    """
    if mapping is None:
        mapping = angle_mapping
    angles_batch = np.asarray(angles_batch)
    batch_size = angles_batch.shape[0]
    num_angles = angles_batch.shape[1] - 1  # The last column is not included in the encoding

    known_angles = np.array(list(mapping.keys()))
    code_table = np.array([[int(bit) for bit in code] for code in mapping.values()], dtype=int)
    code_width = code_table.shape[1]

    # matches[b, i, k] is True when angle i of row b equals the k-th mapping key.
    matches = angles_batch[:, :num_angles, np.newaxis] == known_angles
    unknown = ~matches.any(axis=2)
    if unknown.any():
        row, col = np.argwhere(unknown)[0]
        raise ValueError(
            f"angle {angles_batch[row, col]!r} at row {row}, column {col} has no entry in the angle mapping"
        )

    encoded_bits = code_table[matches.argmax(axis=2)].reshape(batch_size, num_angles * code_width)
    # The last column is copied through, truncated to int like the original integer result array.
    last_column = angles_batch[:, -1].astype(int).reshape(-1, 1)
    return torch.tensor(np.concatenate((encoded_bits, last_column), axis=1))

# One-hot encode every laminate: 13 bits per angle plus the trailing pass-through column.
# batch_encode_angles already returns a tensor. Re-wrapping it with torch.tensor()
# only made a redundant copy and triggered PyTorch's copy-construct UserWarning.
loaded_ply_01encoded = batch_encode_angles(loaded_ply_angle.numpy())

loaded_normalized_LP_tensor = torch.load('0304normalized_LP_tensor.pt')
loaded_normalized_Index_tensor = torch.load('0304normalized_Index_tensor.pt')
loaded_Index_tensor = torch.load('0304Index_tensor.pt')
loaded_LP_tensor = torch.load('0304LP_tensor.pt')

# Column-wise ranges of the raw targets. The training code uses them to normalize
# values recomputed from decoded laminates.
# NOTE(review): presumably the same ranges that produced the stored *_normalized_* tensors; confirm.
index_min_vals = torch.min(loaded_Index_tensor, dim=0)[0]
index_max_vals = torch.max(loaded_Index_tensor, dim=0)[0]

LP_min_vals = torch.min(loaded_LP_tensor, dim=0)[0]
LP_max_vals = torch.max(loaded_LP_tensor, dim=0)[0]

# Dataset and DataLoader
class CustomDataset(Dataset):
    """
    Map-style dataset over four index-aligned sequences.

    Each item is a dict with keys "Ply_angle", "Input01", "Norm_LP" and
    "Norm_index", taken from the sequences at the same index. The length is
    that of ``ply_01encoded``.
    """

    def __init__(self, ply_angle, ply_01encoded, normalized_LP_tensor, normalized_Index_tensor):
        self.ply_angle, self.ply_01encoded = ply_angle, ply_01encoded
        self.normalized_LP_tensor, self.normalized_Index_tensor = normalized_LP_tensor, normalized_Index_tensor

    def __len__(self):
        # The one-hot encoding defines how many samples there are.
        return len(self.ply_01encoded)

    def __getitem__(self, index):
        return dict(
            Ply_angle=self.ply_angle[index],
            Input01=self.ply_01encoded[index],
            Norm_LP=self.normalized_LP_tensor[index],
            Norm_index=self.normalized_Index_tensor[index],
        )

# Dataloaders for training, validation, and testing
batch_size = 50


def _split_loader(start, stop):
    """Wrap rows [start, stop) of every loaded tensor in a shuffled DataLoader."""
    subset = CustomDataset(
        loaded_ply_angle[start:stop],
        loaded_ply_01encoded[start:stop],
        loaded_normalized_LP_tensor[start:stop],
        loaded_normalized_Index_tensor[start:stop],
    )
    return DataLoader(dataset=subset, batch_size=batch_size, shuffle=True)


train_loader = _split_loader(0, 35000)
vali_loader = _split_loader(35000, 45000)
test_loader = _split_loader(45000, 50000)
    
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.colors as mcolors

# Function to compute accuracy of the autoencoder by comparing input and output
def accuracy(data1, data2):
    """
    Return the fraction of positions where ``data1`` and ``data2`` agree exactly.

    Agreement means ``|data1 - data2| == 0``. The denominator is ``data1``'s
    element count. The result is a Python float.
    """
    exact_hits = torch.sum(torch.abs(data1 - data2) == 0)
    return (exact_hits / torch.numel(data1)).item()

# Function to decode ply angles from the binary representation
# Inverse of angle_mapping: one-hot bit string -> ply angle. Deriving it keeps both tables in sync.
angle_docode_mapping = {code: angle for angle, code in angle_mapping.items()}

def decode_angles(encoded_angles, num_angles=64, mapping=None):
    """
    Decode one-hot ply-angle codes back into angles in degrees.

    Parameters
    ----------
    encoded_angles : array-like, shape (batch, >= num_angles * code_width + 1)
        ``num_angles`` consecutive one-hot codes per row. The last column (the
        symmetry flag) is passed through unchanged.
    num_angles : int, default 64
        Number of angle slots to decode.
    mapping : dict, optional
        Maps each one-hot bit string to an angle. Defaults to the module-level
        ``angle_docode_mapping``.

    Returns
    -------
    np.ndarray, shape (batch, num_angles + 1)
        The decoded angles followed by the input's last column. A slot that is
        not exactly one of the mapping's codes decodes to -90, the padding value.

    Raises
    ------
    ValueError
        If a mapping key is not a one-hot string of the common width.
    """
    if mapping is None:
        mapping = angle_docode_mapping
    encoded_angles = np.asarray(encoded_angles)
    code_width = len(next(iter(mapping)))

    # Angle for each hot-bit position; positions without a mapping entry fall back to -90.
    position_to_angle = np.full(code_width, -90, dtype=int)
    for code, angle in mapping.items():
        if len(code) != code_width or code.count('1') != 1 or set(code) - {'0', '1'}:
            raise ValueError(f"mapping key {code!r} is not a one-hot code of width {code_width}")
        position_to_angle[code.index('1')] = angle

    batch_size = encoded_angles.shape[0]
    slots = encoded_angles[:, :num_angles * code_width].reshape(batch_size, num_angles, code_width)
    # Bits are read with int() truncation semantics: [1, 2) counts as 1 and (-1, 1) counts as 0.
    # Anything else invalidates the slot. Comparisons avoid an int copy of a potentially huge array.
    is_one = (slots >= 1) & (slots < 2)
    is_zero = (slots > -1) & (slots < 1)
    valid = (is_one | is_zero).all(axis=2) & (is_one.sum(axis=2) == 1)
    decoded_angles = np.where(valid, position_to_angle[is_one.argmax(axis=2)], -90)

    return np.concatenate((decoded_angles, encoded_angles[:, -1].reshape(-1, 1)), axis=1)

# Plot two ply layer arrays for comparison
def plot_two_ply(array1, array2, save=False, filename='my_plot.png'):
    """
    Show two encoded laminates side by side as bit images.

    Each flat array holds 64 x 13 bits followed by one flag value. It is drawn
    as a 13-row grid with one column per ply, plus a final column filled with
    the flag. When ``save`` is true the figure is written to ``filename`` at
    900 dpi before being shown.
    """
    def as_image(flat):
        # (64*13 bits + flag) -> (13, 65): bits transposed to one column per ply, flag appended.
        bit_grid = flat[:-1].reshape(64, 13).T
        flag_column = np.full((13, 1), flat[-1])
        return np.concatenate((bit_grid, flag_column), axis=1)

    images = [as_image(array1), as_image(array2)]

    fig, axs = plt.subplots(1, 2, figsize=(8, 1))  # 1 row, 2 columns

    # Custom colormap running from light blue to a near-black navy.
    palette = [(222/255.0, 239/255.0, 251/255.0), (180/255.0, 210/255.0, 217/255.0), (50/255.0, 97/255.0, 115/255.0), (3/255.0, 23/255.0, 64/255.0), (1/255.0, 14/255.0, 38/255.0)]
    cm = mcolors.LinearSegmentedColormap.from_list('my_colormap', palette, N=100)

    for ax, image in zip(axs, images):
        im = ax.imshow(image, cmap=cm, aspect='auto', origin='lower')
        ax.grid(True)
        ax.set_xticks([])
        ax.set_yticks([])

    # `im` is the second panel's image, as the colorbar belongs to that axis.
    cbar = fig.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04, ticks=[0, 1])
    cbar.ax.set_yticklabels(['0', '1'])

    plt.tight_layout()
    if save:
        plt.savefig(filename, dpi=900)  # Save the image
    plt.show()
    
def plot_two_arrays(array1, array2):
    """Overlay two 1-D arrays on a wide, short line plot: input in red, output in blue."""
    assert array1.ndim == 1 and array2.ndim == 1, "Both arrays must be 1-D."

    x_axis = np.arange(max(len(array1), len(array2)))
    plt.figure(figsize=(10, 1))
    for values, color, label in ((array1, 'red', 'Input'), (array2, 'blue', 'Output')):
        plt.plot(x_axis[:len(values)], values, color=color, label=label, linewidth=0.5, alpha=0.8)

    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.grid(True)
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.show()

# Define the Encoder module
class Encoder(nn.Module):
    """
    Convolutional encoder mapping an encoded laminate row to (mean, logvar).

    The input row is 832 angle bits followed by one symmetry flag (column 832).
    A kernel/stride-13 convolution turns the bits into 64 per-ply feature
    columns. Two further convolutions then look at neighbouring plies.
    ``seq_length`` is accepted for signature parity with ``Decoder`` but is not
    used; the layer sizes fix the input width.

    The first 9 entries of ``mean`` pass through a sigmoid, keeping them in
    [0, 1] like the normalized targets they are regressed onto in the loss.
    """

    def __init__(self, seq_length, z_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=13, stride=13, padding=0)
        self.bn1 = nn.BatchNorm1d(16)

        self.conv2_1 = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2, stride=1, padding=0)
        self.bn2_1 = nn.BatchNorm1d(8)

        self.conv2_3 = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=4, stride=1, padding=2)
        self.bn2_3 = nn.BatchNorm1d(8)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1025, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(1025, 256)
        self.bn4 = nn.BatchNorm1d(256)
        self.fc_mean = nn.Linear(512, z_dim)
        self.fc_logvar = nn.Linear(512, z_dim)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        angle_bits = x[:, :832].unsqueeze(1)  # (B, 1, 832)
        sym_flag = x[:, 832].unsqueeze(1)     # (B, 1)

        def conv_block(conv, bn, inputs):
            # conv -> batch norm -> ReLU -> dropout
            return self.dropout(F.relu(bn(conv(inputs))))

        per_ply = conv_block(self.conv1, self.bn1, angle_bits)   # (B, 16, 64)
        pairs = conv_block(self.conv2_1, self.bn2_1, per_ply)    # (B, 8, 63)
        quads = conv_block(self.conv2_3, self.bn2_3, per_ply)    # (B, 8, 65)
        neighbourhood = torch.cat((pairs, quads), dim=2)         # (B, 8, 128)

        branch1 = torch.cat((self.flatten(per_ply), sym_flag), dim=1)        # (B, 1025)
        branch2 = torch.cat((self.flatten(neighbourhood), sym_flag), dim=1)  # (B, 1025)
        hidden1 = F.relu(self.bn3(self.fc1(branch1)))                        # (B, 256)
        hidden2 = F.relu(self.bn4(self.fc2(branch2)))                        # (B, 256)
        hidden = torch.cat((hidden1, hidden2), dim=1)                        # (B, 512)

        raw_mean = self.fc_mean(hidden)                                      # (B, z_dim)
        mean = torch.cat((torch.sigmoid(raw_mean[:, :9]), raw_mean[:, 9:]), dim=1)
        logvar = self.fc_logvar(hidden)                                      # (B, z_dim)
        return mean, logvar

# Define the Decoder module
class Decoder(nn.Module):
    """
    Decoder mapping a latent vector back to a probabilistic laminate encoding.

    Output columns:
      * 0-831: 64 groups of 13 softmax probabilities, one group per angle slot;
      * 832: sigmoid probability of the symmetry flag.
    ``seq_length`` sets the final layer width. The forward pass assumes it is
    833.
    """

    def __init__(self, seq_length, z_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, 1024)
        self.bn1 = nn.BatchNorm1d(1024)

        self.deconv1 = nn.ConvTranspose1d(in_channels=16, out_channels=4, kernel_size=13, stride=13, padding=0)
        self.bn_deconv1 = nn.BatchNorm1d(4)
        self.flatten = nn.Flatten()
        self.fc2 = nn.Linear(52*64, 176)
        self.bn2 = nn.BatchNorm1d(176)

        self.fc3 = nn.Linear(1200, seq_length)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))                    # (B, 1024)
        grid = hidden.view(hidden.size(0), 16, 64)                              # (B, 16, 64)
        upsampled = self.dropout(F.relu(self.bn_deconv1(self.deconv1(grid))))   # (B, 4, 832)
        summary = self.dropout(F.relu(self.bn2(self.fc2(self.flatten(upsampled)))))  # (B, 176)

        logits = self.fc3(torch.cat((hidden, summary), dim=1))                  # (B, seq_length)
        batch = logits.size(0)
        angle_probs = self.softmax(logits[:, :832].view(batch, 64, 13)).view(batch, -1)
        sym_prob = self.sigmoid(logits[:, 832].unsqueeze(1))
        return torch.cat((angle_probs, sym_prob), dim=1)

# Define the VAE module
class VAE(nn.Module):
    """
    Conditional VAE over encoded laminates.

    The first ``N_COND`` latent dimensions are passed to the decoder unchanged.
    They are the entries the loss regresses onto normalized laminate parameters
    and indices. Only the remaining dimensions are sampled with the
    reparameterization trick.
    """

    # Number of leading, deterministic (conditioning) latent dimensions.
    N_COND = 9

    def __init__(self, seq_length, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(seq_length, z_dim)
        self.decoder = Decoder(seq_length, z_dim)

    def reparameterize(self, mean, logvar):
        """
        Return the latent sample z = [mean[:, :N_COND], mean[:, N_COND:] + eps * std].

        ``mean`` is NOT modified. The previous in-place update turned the
        returned ``mean`` into z, so the KL term was computed on the sample
        instead of on the posterior mean.
        """
        n_cond = self.N_COND
        std = torch.exp(0.5 * logvar[:, n_cond:])
        eps = torch.randn_like(std)
        return torch.cat((mean[:, :n_cond], mean[:, n_cond:] + eps * std), dim=1)

    def forward(self, x):
        """Encode ``x``, sample z, and decode it. Returns (x_recon, mean, logvar)."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decoder(z)
        return x_recon, mean, logvar

# Model input width: 64 angle slots x 13 one-hot bits + 1 symmetry flag.
seq_length = 64 * 13 + 1
z_dim = 128  # latent size; vae_loss treats the first 9 dims as conditioning targets
vae_model = VAE(seq_length=seq_length, z_dim=z_dim)

# Define hybrid loss function (MAE + MRE)
def hybrid_loss(y_true, y_pred, alpha=0.9, epsilon=0.01):
    """
    Blend mean absolute error with mean relative error.

    Returns ``alpha * MAE + (1 - alpha) * MRE``. The relative error divides by
    ``|y_true| + epsilon`` so that zero targets do not blow up.
    """
    residual = y_true - y_pred
    absolute_term = F.l1_loss(y_true, y_pred)
    relative_term = (residual / (y_true.abs() + epsilon)).abs().mean()
    return alpha * absolute_term + (1 - alpha) * relative_term

def vae_loss(x_recon, x, mean, logvar, norm_prop, norm_index):
    """
    Return the three weighted VAE loss terms as (recon, kld, reg).

    recon : summed binary cross-entropy between ``x_recon`` and ``x``, x 0.3.
    kld   : Gaussian KL divergence over latent dims 9 onward, x 0.2.
    reg   : MSE tying ``mean[:, :7]`` to ``norm_prop`` (weighted 5) plus MSE
            tying ``mean[:, 7:9]`` to ``norm_index``, all x 5000.
    """
    beta0, beta1, beta2 = 0.3, 0.2, 5000
    mu_free, logvar_free = mean[:, 9:], logvar[:, 9:]

    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kld_loss = -0.5 * torch.sum(1 + logvar_free - mu_free.pow(2) - logvar_free.exp())

    # Columns 0-5 and column 6 of the property target are averaged separately, then weighted equally.
    lp_loss = F.mse_loss(mean[:, :6], norm_prop[:, :6]) + F.mse_loss(mean[:, 6], norm_prop[:, 6])
    index_loss = F.mse_loss(mean[:, 7:9], norm_index)

    return beta0 * recon_loss, beta1 * kld_loss, beta2 * (lp_loss * 5 + index_loss)

# Training setup
optimizer = optim.Adam(vae_model.parameters(), lr=0.001, weight_decay=1e-5)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm against the installed version.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=5, threshold=0.1, verbose=True, min_lr=1e-12,
)

# Per-epoch training curves.
train_losses, recon_losses, kld_losses, reg_losses, noise_losses = [], [], [], [], []
# Validation curves: one entry every `validate_every` epochs.
vali_losses, vali_recon_losses, vali_kld_losses, vali_reg_losses, vali_noise_losses = [], [], [], [], []

num_epochs = 500
validate_every = 5  # Validate every 5 epochs

# Start training timer
start_time = time.time()
print("Start Training! Time:", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time)))


def _decode_one_hot(recon):
    """
    Snap a decoder output batch to hard codes and decode it with decode_angles.

    Each 13-wide angle slot becomes the one-hot code of its argmax. The trailing
    symmetry probability is rounded to 0/1, so it can match the ground-truth
    flag exactly.
    """
    angle_probs = recon[:, :832].detach().cpu().numpy().reshape(-1, 64, 13)
    one_hot = np.eye(13)[np.argmax(angle_probs, axis=2)].reshape(-1, 64 * 13)
    sym_flag = recon[:, -1].detach().cpu().numpy().round().reshape(-1, 1)
    return decode_angles(np.concatenate((one_hot, sym_flag), axis=1))


def _noise_consistency_loss(mean_batch):
    """
    Tie the conditioning latents to the properties of designs decoded from them.

    Steps:
      1. Keep the first 9 latent dims of ``mean_batch`` and replace the rest
         with standard-normal noise.
      2. Decode the result and snap it to a concrete laminate.
      3. Recompute its laminate parameters and gathering/adjacent-angle indices
         with numpy, then compare them with the kept dims.

    The targets are not differentiable, so gradients reach only ``mean_batch[:, :9]``.
    """
    hidden_with_noise = mean_batch.clone()
    hidden_with_noise[:, 9:] = torch.randn_like(hidden_with_noise[:, 9:])
    decoded = _decode_one_hot(vae_model.decoder(hidden_with_noise))

    noise_loss = 0
    for j, row in enumerate(decoded):
        row_angles = row.reshape(1, -1)
        LP = LaminateParameters(row_angles)
        # a1-a3, d1-d3 and the ply count (a4/d4 are skipped) match the 7 normalized LP targets.
        recon_LP = torch.Tensor(LP[0:3] + LP[4:7] + [LP[8]])
        normalized_recon_LP = normalize_data(recon_LP, LP_min_vals, LP_max_vals)
        noise_loss += (F.mse_loss(normalized_recon_LP[:6], hidden_with_noise[j, :6])
                       + F.mse_loss(normalized_recon_LP[-1], hidden_with_noise[j, 6])) * 300

        _, gathering_index, adjacent_angle_index = check_conditions(row_angles)
        recon_index = torch.Tensor([gathering_index, adjacent_angle_index])
        normalized_recon_index = normalize_data(recon_index, index_min_vals, index_max_vals)
        noise_loss += F.mse_loss(hidden_with_noise[j, 7:9], normalized_recon_index) * 30
    return noise_loss


for epoch in range(num_epochs):
    vae_model.train()
    loss_train = loss_recon = loss_kld = loss_reg = loss_noise = 0.0

    for batch in train_loader:
        raw_data_batch = batch["Input01"].to(torch.float32)
        norm_LP_batch = batch["Norm_LP"].to(torch.float32)
        norm_index_batch = batch["Norm_index"].to(torch.float32)
        optimizer.zero_grad()

        recon_batch, mean, logvar = vae_model(raw_data_batch)
        recon_loss, kld_loss, reg_loss = vae_loss(recon_batch, raw_data_batch, mean, logvar, norm_LP_batch, norm_index_batch)
        noise_loss = _noise_consistency_loss(mean)
        train_loss = recon_loss + kld_loss + reg_loss + noise_loss

        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(vae_model.parameters(), max_norm=20.0)
        optimizer.step()
        loss_train += train_loss.item()
        loss_recon += recon_loss.item()
        loss_kld += kld_loss.item()
        loss_reg += reg_loss.item()
        loss_noise += float(noise_loss)

    if (epoch + 1) % validate_every == 0:
        vae_model.eval()
        with torch.no_grad():
            loss_vali = loss_recon_vali = loss_kld_vali = loss_reg_vali = loss_noise_vali = 0.0
            acc = 0.0  # sum of per-batch element accuracies

            for batch in vali_loader:
                raw_data_vali_batch = batch["Input01"].to(torch.float32)
                norm_LP_batch_vali = batch["Norm_LP"].to(torch.float32)
                norm_index_batch_vali = batch["Norm_index"].to(torch.float32)
                ply_angle_vali_batch = batch["Ply_angle"].to(torch.float32)

                recon_vali_batch, mean_vali, logvar_vali = vae_model(raw_data_vali_batch)
                recon_loss_vali, kld_loss_vali, reg_loss_vali = vae_loss(recon_vali_batch, raw_data_vali_batch, mean_vali, logvar_vali, norm_LP_batch_vali, norm_index_batch_vali)

                # The symmetry flag is rounded as well, so it can equal the 0/1 label.
                decoded_recon_vali_batch = _decode_one_hot(recon_vali_batch)
                acc += accuracy(ply_angle_vali_batch, torch.as_tensor(decoded_recon_vali_batch))

                noise_loss_vali = _noise_consistency_loss(mean_vali)

                # Accumulate plain floats so the recorded curves are numbers, not tensors.
                loss_vali += float(recon_loss_vali + kld_loss_vali + reg_loss_vali + noise_loss_vali)
                loss_recon_vali += float(recon_loss_vali)
                loss_kld_vali += float(kld_loss_vali)
                loss_reg_vali += float(reg_loss_vali)
                loss_noise_vali += float(noise_loss_vali)

            n_vali = len(vali_loader.dataset)
            vali_losses.append(loss_vali / n_vali)
            vali_recon_losses.append(loss_recon_vali / n_vali)
            vali_kld_losses.append(loss_kld_vali / n_vali)
            vali_reg_losses.append(loss_reg_vali / n_vali)
            vali_noise_losses.append(loss_noise_vali / n_vali)

            acc /= len(vali_loader)  # mean over batches

            print(f"Epoch [{epoch + 1}/{num_epochs}] Validation Loss: {loss_vali / n_vali:.4f}, "
                  f"Recon Loss: {loss_recon_vali / n_vali:.4f}, "
                  f"KLD Loss: {loss_kld_vali / n_vali:.4f}, "
                  f"Reg Loss: {loss_reg_vali / n_vali:.4f}, "
                  f"Noise Loss: {loss_noise_vali / n_vali:.4f}, "
                  f"Accuracy: {acc:.4f}")

            # Adjust learning rate based on validation loss
            scheduler.step(loss_vali / n_vali)

    n_train = len(train_loader.dataset)
    train_losses.append(loss_train / n_train)
    recon_losses.append(loss_recon / n_train)
    kld_losses.append(loss_kld / n_train)
    reg_losses.append(loss_reg / n_train)
    noise_losses.append(loss_noise / n_train)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {loss_train / n_train:.4f} "
          f"- Recon Loss: {loss_recon / n_train:.4f} - KLD Loss: {loss_kld / n_train:.4f} - Reg Loss: {loss_reg / n_train:.4f}- Noise Loss: {loss_noise / n_train:.4f}")

# End training timer
end_time = time.time()
print("Training completed!")
print(f"Training duration: {end_time - start_time:.2f} seconds")

# Loss curves: training metrics every epoch, validation metrics every `validate_every` epochs.
train_epochs = range(1, num_epochs + 1)
vali_epochs = range(validate_every, num_epochs + 1, validate_every)
for epochs, values, label in (
    (train_epochs, train_losses, 'Train Loss'),
    (train_epochs, recon_losses, 'Train Recon Loss'),
    (train_epochs, kld_losses, 'Train KLD Loss'),
    (train_epochs, reg_losses, 'Train Reg Loss'),
    (train_epochs, noise_losses, 'Train Noise Loss'),
    (vali_epochs, vali_losses, 'Vali Loss'),
    (vali_epochs, vali_recon_losses, 'Vali Recon Loss'),
    (vali_epochs, vali_kld_losses, 'Vali KLD Loss'),
    (vali_epochs, vali_reg_losses, 'Vali Reg Loss'),
    (vali_epochs, vali_noise_losses, 'Vali Noise Loss'),
):
    plt.plot(epochs, values, label=label)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Losses vs. Epoch')
plt.legend()
plt.show()

# Persist the trained weights, then reload them into the model; this mirrors a later
# notebook session that starts from the checkpoint.
model_path = '0304vae_model.pth'
torch.save(vae_model.state_dict(), model_path)
state_dict = torch.load(model_path)
vae_model.load_state_dict(state_dict)

# Design function to generate random designs based on fixed and random components
def Design(fixed_components, random_components, num_total, random_G, random_A, model=None):
    """
    Decode ``num_total`` candidate laminates for a fixed conditioning target.

    Parameters
    ----------
    fixed_components : 1-D tensor
        The conditioning latents, repeated for every candidate. A value of -1
        in either of the last two entries means "sample it":
          * last entry: replaced by ``random_A``;
          * second-to-last entry: replaced by ``random_G``.
    random_components : tensor, shape (num_total, z_dim - len(fixed_components))
        The free latent dims, appended after the conditioning part.
    num_total : int
        Number of candidates.
    random_G, random_A : tensors with num_total elements
        Per-candidate replacements for the -1 entries.
    model : VAE, optional
        Defaults to the module-level ``vae_model``.

    Returns
    -------
    np.ndarray
        Output of ``decode_angles``: one row of angles plus a symmetry flag per
        candidate.

    The model is switched to eval mode for decoding, so BatchNorm uses its
    running statistics and dropout is off. Its previous mode is restored
    afterwards. Previously the result depended on whichever mode the model had
    been left in.
    """
    if model is None:
        model = vae_model
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            conditions = fixed_components.repeat(num_total, 1)
            if fixed_components[-1] == -1:
                conditions[:, -1] = random_A.view(-1)
            if fixed_components[-2] == -1:
                conditions[:, -2] = random_G.view(-1)
            hidden_with_noise = torch.cat((conditions, random_components), dim=1)
            recon_with_noise = model.decoder(hidden_with_noise)
    finally:
        model.train(was_training)

    # Hard decode: argmax one-hot per angle slot, rounded symmetry flag.
    angle_probs = recon_with_noise[:, :832].detach().cpu().numpy().reshape(-1, 64, 13)
    one_hot = np.eye(13)[np.argmax(angle_probs, axis=2)].reshape(-1, 64 * 13)
    sym_flag = recon_with_noise[:, -1].detach().cpu().numpy().round().reshape(-1, 1)
    return decode_angles(np.concatenate((one_hot, sym_flag), axis=1))

def denormalize_data(normalized_tensor, min_vals, max_vals):
    """Undo min-max scaling: map [0, 1] back onto [min_vals, max_vals] (broadcast per column)."""
    span = max_vals - min_vals
    return normalized_tensor * span + min_vals

# Example: ask the trained decoder for candidate designs of a 100-ply laminate.

# Target (a1, a2, a3, d1, d2, d3, ply count), normalized with the training ranges.
target_LP = torch.Tensor([0, 0, 0, 0, 0, 0, 100])
norm_target_LP = normalize_data(target_LP, LP_min_vals, LP_max_vals)
# The two index slots are -1, so Design fills them per candidate from random_G / random_A.
fixed_components = torch.cat((norm_target_LP, torch.Tensor([-1, -1])))

# Free latent noise plus random index levels drawn from {0, 1/3, 2/3, 1} (A first, then G).
num_total = 100000
random_components = torch.randn(num_total, z_dim - 9)
random_A, random_G = (torch.randint(0, 4, size=(num_total, 1)).float() / 3 for _ in range(2))

start_time = time.time()
start_time_readable = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time))
print("Start Solving! Start Time:", start_time_readable)

design = Design(fixed_components, random_components, num_total, random_G, random_A)
end_time = time.time()  # End timing
print("Design solving time:", end_time - start_time)