In [None]:
import os
import sys
import cv2
import torch
import numpy as np
import random
from utils.dataset import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from torch.utils.data import DataLoader
from models.autoencoder import *
import matplotlib.gridspec as gridspec
from new_model.Translation import *
from torch.utils.data import DataLoader, TensorDataset
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from operator import itemgetter


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CKPT_PATH = 'logs_tran/AE_4D_2023_07_15__01_00_19/ckpt_3.195090_10000.pt'

# The checkpoint stores ckpt['args'] (an arbitrary pickled object) next to the weights, so it
# cannot be read with weights_only=True. weights_only=False unpickles arbitrary objects:
# only load checkpoints from trusted sources.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
model = Translation(ckpt['args']).to(device)
# This model is only used for inference below: switch dropout / batch-norm layers (if any)
# to evaluation behaviour instead of the training mode a freshly built module starts in.
model.eval()
model.load_state_dict(ckpt['state_dict'])

In [None]:
def predictx(model, input_data):
    """Run the x -> y -> x translation cycle on a batch without tracking gradients.

    Args:
        model: object exposing ``xtoyencode``, ``xtoydecode(z, n)``, ``ytoxencode``
            and ``ytoxdecode(z, n)``.
        input_data: batch tensor in the x domain; its dimension 1 is passed to the
            decoder as ``n``.

    Returns:
        Tuple ``(gen_y, gen_x)`` of CPU tensors: the translation into the y domain,
        and that translation mapped back into the x domain.
    """
    with torch.no_grad():
        y_translated = model.xtoydecode(model.xtoyencode(input_data), input_data.size(1))
        x_cycled = model.ytoxdecode(model.ytoxencode(y_translated), y_translated.size(1))
    return y_translated.detach().cpu(), x_cycled.detach().cpu()

In [None]:
def predicty(model, input_data):
    """Run the y -> x -> y translation cycle on a batch without tracking gradients.

    Args:
        model: object exposing ``ytoxencode``, ``ytoxdecode(z, n)``, ``xtoyencode``
            and ``xtoydecode(z, n)``.
        input_data: batch tensor in the y domain; its dimension 1 is passed to the
            decoder as ``n``.

    Returns:
        Tuple ``(gen_y, gen_x)`` of CPU tensors: ``gen_x`` is the translation into the
        x domain, and ``gen_y`` is that translation mapped back into the y domain.
    """
    with torch.no_grad():
        x_translated = model.ytoxdecode(model.ytoxencode(input_data), input_data.size(1))
        y_cycled = model.xtoydecode(model.xtoyencode(x_translated), x_translated.size(1))
    return y_cycled.detach().cpu(), x_translated.detach().cpu()

In [None]:
def predictae(model, input_data):
    """Reconstruct a batch with an autoencoder-style model, without tracking gradients.

    Args:
        model: object exposing ``encode`` and ``decode(z, n)``.
        input_data: batch tensor; its dimension 1 is passed to the decoder as ``n``.

    Returns:
        The reconstruction as a CPU tensor.
    """
    with torch.no_grad():
        latent = model.encode(input_data)
        reconstruction = model.decode(latent, input_data.size(1))
    return reconstruction.detach().cpu()

In [None]:
def generate(model, input_data, batch_idx, i, save_path=None):
    """Plot one event: the original cloud, its translation, and the cycle reconstruction.

    Runs ``predicty`` (y -> x -> y) on the whole batch and draws event ``i`` as three 3D
    scatter panels. All panels share one colour scale, so a single colour bar is shown.

    Args:
        model: translation model accepted by ``predicty``.
        input_data: batch indexed ``[event, point, feature]`` (any device). Features 0-2
            are plotted as coordinates and feature 3 as colour.
        batch_idx: index of the batch ``input_data`` came from (shown in the figure title).
        i: index of the event within the batch to plot.
        save_path: optional file path. When given, the figure is also saved there and
            parent directories are created. Default: not saved.
    """
    yprime, xprime = predicty(model, input_data)
    # For the x -> y -> x direction use: yprime, xprime = predictx(model, input_data)

    # Matplotlib cannot read CUDA tensors; predicty already returns CPU tensors.
    original = input_data.detach().cpu()

    fig = plt.figure(figsize=(12, 4))
    # One row: three plot panels plus a narrow column for the shared colour bar.
    gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 0.05], figure=fig)
    panels = [(original, 'Original'), (xprime, 'Translation'), (yprime, 'Reconstruction')]

    scatter = None
    for col, (cloud, title) in enumerate(panels):
        ax = fig.add_subplot(gs[0, col], projection='3d')
        scatter = ax.scatter(cloud[i, :, 0], cloud[i, :, 1], cloud[i, :, 2],
                             c=cloud[i, :, 3], cmap='cool', s=1, vmin=-1.5, vmax=0.5)
        ax.set_title(title)
        # NOTE(review): column 0 is labelled "Z-axis" here, while generate_plot puts
        # column 2 under "Z" -- confirm which convention is correct.
        ax.set_xlabel("Z-axis")
        ax.set_ylabel("X-axis")
        ax.set_zlabel("Y-axis")
        ax.set_xlim([-2.5, 2.5])
        ax.set_ylim([-2.5, 2.5])
        ax.set_zlim([-2.5, 2.5])

    # Every panel uses the same cmap/vmin/vmax, so one colour bar describes all of them.
    fig.colorbar(scatter, cax=fig.add_subplot(gs[0, 3]), label='Fourth dimension')
    fig.suptitle(f'Batch {batch_idx}, event {i}')
    fig.tight_layout()

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        fig.savefig(save_path)
    plt.show()


In [None]:
def generateae(model, input_data, i, save_path=None):
    """Plot event ``i`` of a batch next to its autoencoder reconstruction.

    Args:
        model: autoencoder-style model accepted by ``predictae`` (``encode``/``decode``).
        input_data: batch indexed ``[event, point, feature]`` (any device). Features 0-2
            are plotted as coordinates and feature 3 as colour.
        i: index of the event within the batch to plot.
        save_path: optional file path. When given, the figure is also saved there and
            parent directories are created. Default: not saved.
    """
    generated = predictae(model, input_data)
    # Matplotlib cannot read CUDA tensors; predictae already returns CPU tensors.
    original = input_data.detach().cpu()

    fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
    for ax, cloud, title in zip(axs, (original, generated), ('Original', 'Generated')):
        ax.scatter(cloud[i, :, 0], cloud[i, :, 1], cloud[i, :, 2],
                   c=cloud[i, :, 3], s=1, cmap='cool')
        ax.set_xlabel("Z-axis")
        ax.set_ylabel("X-axis")
        ax.set_zlabel("Y-axis")
        ax.set_title(title)

    fig.tight_layout()
    if save_path is not None:
        # Save only to an explicit path; no implicit index-based filename is built here.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        fig.savefig(save_path)
    plt.show()

In [None]:
# Unpaired translation data
# Read both .npy files and turn each into a float32 tensor.
sim_data, exp_data = (
    torch.from_numpy(np.load(npy_path)).float()
    for npy_path in (
        'data/Mg22_Unpair/Mg22_simulated_deleted.npy',
        'data/Mg22_Unpair/Mg22_simulated_undeleted.npy',
    )
)

# Wrap each tensor as a single-tensor dataset.
sim_dset, exp_dset = TensorDataset(sim_data), TensorDataset(exp_data)

# Batches of 4 from the "undeleted" set, loaded in the main process with default ordering.
train_loader = DataLoader(exp_dset, batch_size=4, num_workers=0)


In [None]:
# Plot original / translation / reconstruction for every event of every batch.
for batch_idx, (test_data,) in enumerate(train_loader):
    print("Batch:", batch_idx)
    test_data = test_data.to(device)

    for event_idx in range(len(test_data)):
        print("Event:", event_idx)
        generate(model, test_data, batch_idx, event_idx)

## Video

In [None]:
models = []
path = "logs_tran/AE_4D_2023_07_07__20_14_06"
CKPT_EVERY = 500  # keep only checkpoints whose iteration is a multiple of this


def _ckpt_iteration(file_name):
    """Return the iteration encoded as the last '_' field, e.g. 'ckpt_3.195090_10000.pt' -> 10000."""
    return int(os.path.splitext(file_name)[0].rsplit('_', 1)[-1])


# All checkpoint files in the run directory
files = [f for f in os.listdir(path) if f.endswith('.pt')]

# (file name, last modification time, iteration number) for each checkpoint
files_info = [(f, os.path.getmtime(os.path.join(path, f)), _ckpt_iteration(f)) for f in files]

# Keep one checkpoint every CKPT_EVERY iterations
files_filtered = [info for info in files_info if info[2] % CKPT_EVERY == 0]

# Order by iteration number (training order). Unlike modification time, this does not
# change when the files are copied or synced.
files_sorted = sorted(files_filtered, key=itemgetter(2))
print(files_sorted)

for file_name, mtime, iter_num in files_sorted:
    # weights_only=False is needed because ckpt['args'] is an arbitrary pickled object.
    # Only load trusted checkpoints. Locals are named so that the notebook-level
    # `model` / `ckpt` loaded earlier are not overwritten.
    frame_ckpt = torch.load(os.path.join(path, file_name), map_location=device, weights_only=False)
    frame_model = Translation(frame_ckpt['args']).to(device)
    frame_model.load_state_dict(frame_ckpt['state_dict'])
    frame_model.eval()  # inference only
    models.append(frame_model)


In [None]:
def generate_plot(model_output, event_idx=0, title='Output'):
    """Render one event of a batched 4-feature point cloud as a standalone 3D scatter.

    Args:
        model_output: batch indexed ``[event, point, feature]`` (a tensor on any device,
            or an array). Features 0-2 are coordinates and feature 3 drives the colour.
        event_idx: which event of the batch to draw (default: the first).
        title: axes title (default ``'Output'``).

    Returns:
        The Matplotlib figure. The caller is responsible for showing or closing it.
    """
    event = model_output[event_idx]
    if torch.is_tensor(event):
        # Matplotlib needs host memory without autograd history.
        event = event.detach().cpu()
    x, y, z, c = event[:, 0], event[:, 1], event[:, 2], event[:, 3]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    # Column 2 is drawn on the horizontal axis and labelled Z.
    ax.scatter(z, x, y, c=c, cmap='cool', s=1, vmin=-1.5, vmax=0.5)
    ax.set_xlabel('Z')
    ax.set_ylabel('X')
    ax.set_zlabel('Y')
    ax.set_title(title)
    ax.set_xlim([-2.5, 2.5])
    ax.set_ylim([-2.5, 2.5])
    ax.set_zlim([-2.5, 2.5])

    return fig

In [None]:
# Step 2: Process the Event
# Take the first batch of the loader. `event_data` comes straight from the DataLoader and is
# reused by the frame-rendering cell below.
data_iter = iter(train_loader)
event_data = next(data_iter)[0]
# The trailing ';' suppresses the Figure repr, so the inline backend shows the plot once.
generate_plot(event_data);

In [None]:
# Step 3: Apply Models and Generate Plots
# Render one frame per checkpoint for the translated (x) and reconstructed (y) clouds.


def _figure_to_rgb(fig):
    """Rasterise ``fig`` with the Agg renderer, close it, and return an (H, W, 3) uint8 RGB array."""
    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb() and always matches the rendered
    # pixel size. .copy() detaches the frame from the canvas buffer and makes it contiguous.
    rgb = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    # Close the figure so the frames do not pile up in pyplot's figure registry.
    plt.close(fig)
    return rgb


plot_imagesx = []
plot_imagesy = []
# The checkpoint models were moved to `device`; move the loader batch there too.
event_input = event_data.to(device)
for frame_model in models:
    # predicty is the direction generate() uses for this loader: y -> x translation, then
    # x -> y reconstruction. Use predictx instead to visualise the x -> y -> x cycle.
    model_outputy, model_outputx = predicty(frame_model, event_input)
    plot_imagesx.append(_figure_to_rgb(generate_plot(model_outputx)))
    plot_imagesy.append(_figure_to_rgb(generate_plot(model_outputy)))


In [None]:
# Step 4: Create Video
videox_filename = 'output_videox2.mp4'
videoy_filename = 'output_videoy2.mp4'
frame_rate = 1  # frames per second; one frame per checkpoint
fourcc = cv2.VideoWriter_fourcc(*'mp4v')


def _write_video(filename, frames, fps, codec):
    """Encode equally sized (H, W, 3) uint8 RGB frames into a video file.

    Raises:
        ValueError: if ``frames`` is empty or the frames differ in size (OpenCV may
            silently skip frames whose size differs from the writer's).
        RuntimeError: if OpenCV cannot open a writer for ``filename``.
    """
    if not frames:
        raise ValueError(f'no frames to write to {filename}')
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(filename, codec, fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f'could not open a video writer for {filename}')
    try:
        for frame in frames:
            if frame.shape[:2] != (height, width):
                raise ValueError(f'frame size {frame.shape[:2]} != {(height, width)}')
            # OpenCV expects BGR channel order; the rendered plots are RGB.
            writer.write(np.ascontiguousarray(frame[..., ::-1]))
    finally:
        writer.release()


_write_video(videox_filename, plot_imagesx, frame_rate, fourcc)
_write_video(videoy_filename, plot_imagesy, frame_rate, fourcc)

print("Video created successfully!")

## AutoEncoder Test

In [None]:
sim_data = np.load('data/Mg22_Unpair/Mg22_simulated_256.npy')
exp_data = np.load('data/Mg22_Unpair/Mg22_experimental_256.npy')

# Convert to PyTorch Tensors
sim_data = torch.from_numpy(sim_data).float()
exp_data = torch.from_numpy(exp_data).float()

# Create TensorDatasets
sim_dset = TensorDataset(sim_data)
exp_dset = TensorDataset(exp_data)


test_loader = DataLoader(
    exp_dset,
    batch_size=8,
    num_workers=0,
)

# NOTE(review): generateae calls model.encode / model.decode, but `model` is bound to a
# Translation model in this notebook -- confirm it exposes that autoencoder API.
for batch_idx, batch in enumerate(test_loader):
    print(batch_idx, ":")
    test_data = batch[0].to(device)

    # The batch index is not an event index (it would overrun the batch from batch 8 on);
    # plot each event actually present in this batch.
    for event_idx in range(test_data.shape[0]):
        generateae(model, test_data, event_idx)