In [3]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px
import torch
from models.autoencoder import ConvAutoencoder
import os

# Load the labeled data first
data = np.load('data/tga_afm/data.npz')
X = data['X']  # Shape: (33, 4, 1024)
Y = data['Y']  # Shape: (33, 5, 5)
samples = data['samples']

print(f"Data loaded:")
print(f"  X shape: {X.shape}")
print(f"  Y shape: {Y.shape}")
print(f"  Samples: {len(samples)}")

# Define characteristic and statistic names
characteristics = ['Min Ferret', 'Max Ferret', 'Height', 'Area', 'Volume']
statistics = ['Mean', 'Variance', 'Skewness', 'Kurtosis', 'Median']

Data loaded:
  X shape: (33, 4, 1024)
  Y shape: (33, 5, 5)
  Samples: 112


In [4]:
# Load the attention U-Net model
model = ConvAutoencoder()
checkpoint_path = 'checkpoints/attention_unet/31000.pth'  # Use the latest checkpoint
model.load_state_dict(torch.load('checkpoints/unet/14500.pth', weights_only=True))
model.to('cpu')
model.double()
model.eval()

print(f"Model loaded from: {checkpoint_path}")

# Select which TGA curves to use (W and dW/dT)
X_input = X[:, 1:3, :]  # Shape: (33, 2, 1024)

# Generate encodings using the model
encodings = []
with torch.no_grad():
    for i in range(X_input.shape[0]):
        x_tensor = torch.tensor(X_input[i]).unsqueeze(0)  # Add batch dimension
        encoding = model.encode(x_tensor).squeeze().numpy()
        encodings.append(encoding)

encodings = np.array(encodings)
print(f"\nEncodings generated: {encodings.shape}")

Model loaded from: checkpoints/attention_unet/31000.pth

Encodings generated: (33, 64)


In [5]:
# Apply PCA to reduce to 3 dimensions
pca = PCA(n_components=3)
encodings_3d = pca.fit_transform(encodings)

# Print explained variance
print(f"PCA Analysis:")
print(f"Explained variance ratio:")
for i, var in enumerate(pca.explained_variance_ratio_):
    print(f"  PC{i+1}: {var:.4f} ({var*100:.2f}%)")
print(f"Total explained variance: {pca.explained_variance_ratio_.sum():.4f} ({pca.explained_variance_ratio_.sum()*100:.2f}%)")

PCA Analysis:
Explained variance ratio:
  PC1: 0.4753 (47.53%)
  PC2: 0.2504 (25.04%)
  PC3: 0.0961 (9.61%)
Total explained variance: 0.8218 (82.18%)


In [6]:
# Create directory for saving plots
os.makedirs('figures/pca_by_metric', exist_ok=True)

# Helper function to create a PCA plot colored by a metric
def create_pca_plot(metric_values, char_name, stat_name, samples, encodings_3d, pca):
    fig = go.Figure(data=[go.Scatter3d(
        x=encodings_3d[:, 0],
        y=encodings_3d[:, 1],
        z=encodings_3d[:, 2],
        mode='markers',
        marker=dict(
            size=6,
            color=metric_values,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title=f"{char_name}<br>{stat_name}"),
            opacity=0.8,
            line=dict(width=0.5, color='white')
        ),
        text=[f'{samples[i]}<br>{char_name} {stat_name}: {metric_values[i]:.4f}' 
              for i in range(len(encodings_3d))],
        hovertemplate='<b>%{text}</b><br>' +
                      'PC1: %{x:.3f}<br>' +
                      'PC2: %{y:.3f}<br>' +
                      'PC3: %{z:.3f}<br>' +
                      '<extra></extra>'
    )])
    
    fig.update_layout(
        title=f'PCA colored by {char_name} - {stat_name}<br>' +
              f'Explained variance: {pca.explained_variance_ratio_.sum():.2%}',
        scene=dict(
            xaxis_title='First Principal Component',
            yaxis_title='Second Principal Component',
            zaxis_title='Third Principal Component',
        ),
        width=900,
        height=700,
        hovermode='closest'
    )
    
    # Save the plot
    filename = f'figures/pca_by_metric/{char_name.replace(" ", "_").lower()}_{stat_name.lower()}.html'
    fig.write_html(filename)
    
    return fig

print("Helper function created for generating PCA plots")

Helper function created for generating PCA plots


## Min Ferret Metrics

In [7]:
# Min Ferret - Median
metric_values = Y[:, 0, 4]
fig = create_pca_plot(metric_values, 'Min Ferret', 'Median', samples, encodings_3d, pca)
print(f"Min Ferret - Median range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Min Ferret - Median range: 14341.5703 to 963923.7621


In [8]:
# Max Ferret - Median
metric_values = Y[:, 1, 4]
fig = create_pca_plot(metric_values, 'Max Ferret', 'Median', samples, encodings_3d, pca)
print(f"Max Ferret - Median range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Max Ferret - Median range: 2208080272.9270 to 46799001900077.3750


In [9]:
# Height - Median
metric_values = Y[:, 2, 4]
fig = create_pca_plot(metric_values, 'Height', 'Median', samples, encodings_3d, pca)
print(f"Height - Median range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Height - Median range: 3.3863 to 69.9317


In [10]:
# Area - Median
metric_values = Y[:, 3, 4]
fig = create_pca_plot(metric_values, 'Area', 'Median', samples, encodings_3d, pca)
print(f"Area - Median range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Area - Median range: 13.4309 to 5316.1562


In [11]:
# Volume - Median
metric_values = Y[:, 4, 4]
fig = create_pca_plot(metric_values, 'Volume', 'Median', samples, encodings_3d, pca)
print(f"Volume - Median range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Volume - Median range: 404.3788 to 2767.6314


In [12]:
# Volume - Kurtosis
metric_values = Y[:, 4, 3]
fig = create_pca_plot(metric_values, 'Volume', 'Kurtosis', samples, encodings_3d, pca)
print(f"Volume - Kurtosis range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Volume - Kurtosis range: 1000.0000 to 1600.0000


In [13]:
# Volume - Skewness
metric_values = Y[:, 4, 2]
fig = create_pca_plot(metric_values, 'Volume', 'Skewness', samples, encodings_3d, pca)
print(f"Volume - Skewness range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Volume - Skewness range: 0.2694 to 1.6525


In [14]:
# Volume - Variance
metric_values = Y[:, 4, 1]
fig = create_pca_plot(metric_values, 'Volume', 'Variance', samples, encodings_3d, pca)
print(f"Volume - Variance range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Volume - Variance range: 58.3095 to 76.1577


In [15]:
# Volume - Mean
metric_values = Y[:, 4, 0]
fig = create_pca_plot(metric_values, 'Volume', 'Mean', samples, encodings_3d, pca)
print(f"Volume - Mean range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Volume - Mean range: 10.0000 to 30.0000


## Volume Metrics

In [16]:
# Area - Kurtosis
metric_values = Y[:, 3, 3]
fig = create_pca_plot(metric_values, 'Area', 'Kurtosis', samples, encodings_3d, pca)
print(f"Area - Kurtosis range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Area - Kurtosis range: 16.2978 to 1639.5946


In [17]:
# Area - Skewness
metric_values = Y[:, 3, 2]
fig = create_pca_plot(metric_values, 'Area', 'Skewness', samples, encodings_3d, pca)
print(f"Area - Skewness range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Area - Skewness range: 8.9346 to 1896.9503


In [18]:
# Area - Variance
metric_values = Y[:, 3, 1]
fig = create_pca_plot(metric_values, 'Area', 'Variance', samples, encodings_3d, pca)
print(f"Area - Variance range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Area - Variance range: 3.1448 to 100.0509


In [19]:
# Area - Mean
metric_values = Y[:, 3, 0]
fig = create_pca_plot(metric_values, 'Area', 'Mean', samples, encodings_3d, pca)
print(f"Area - Mean range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Area - Mean range: 3.0599 to 68.8878


## Area Metrics

In [20]:
# Height - Kurtosis
metric_values = Y[:, 2, 3]
fig = create_pca_plot(metric_values, 'Height', 'Kurtosis', samples, encodings_3d, pca)
print(f"Height - Kurtosis range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Height - Kurtosis range: 3.2310 to 34.7921


In [21]:
# Height - Skewness
metric_values = Y[:, 2, 2]
fig = create_pca_plot(metric_values, 'Height', 'Skewness', samples, encodings_3d, pca)
print(f"Height - Skewness range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Height - Skewness range: 2.6386 to 34.6976


In [22]:
# Height - Variance
metric_values = Y[:, 2, 1]
fig = create_pca_plot(metric_values, 'Height', 'Variance', samples, encodings_3d, pca)
print(f"Height - Variance range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Height - Variance range: 1.5112 to 6.6892


In [23]:
# Height - Mean
metric_values = Y[:, 2, 0]
fig = create_pca_plot(metric_values, 'Height', 'Mean', samples, encodings_3d, pca)
print(f"Height - Mean range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Height - Mean range: 1.3610 to 6.7209


## Height Metrics

In [24]:
# Max Ferret - Kurtosis
metric_values = Y[:, 1, 3]
fig = create_pca_plot(metric_values, 'Max Ferret', 'Kurtosis', samples, encodings_3d, pca)
print(f"Max Ferret - Kurtosis range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Max Ferret - Kurtosis range: 39721402.0983 to 14026548166.0273


In [25]:
# Max Ferret - Skewness
metric_values = Y[:, 1, 2]
fig = create_pca_plot(metric_values, 'Max Ferret', 'Skewness', samples, encodings_3d, pca)
print(f"Max Ferret - Skewness range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Max Ferret - Skewness range: 1.7540 to 130.2435


In [26]:
# Max Ferret - Variance
metric_values = Y[:, 1, 1]
fig = create_pca_plot(metric_values, 'Max Ferret', 'Variance', samples, encodings_3d, pca)
print(f"Max Ferret - Variance range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Max Ferret - Variance range: 3946.4162 to 65351.8536


In [27]:
# Max Ferret - Mean
metric_values = Y[:, 1, 0]
fig = create_pca_plot(metric_values, 'Max Ferret', 'Mean', samples, encodings_3d, pca)
print(f"Max Ferret - Mean range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Max Ferret - Mean range: 1455.0725 to 23008.4065


## Max Ferret Metrics

In [28]:
# Min Ferret - Kurtosis
metric_values = Y[:, 0, 3]
fig = create_pca_plot(metric_values, 'Min Ferret', 'Kurtosis', samples, encodings_3d, pca)
print(f"Min Ferret - Kurtosis range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Min Ferret - Kurtosis range: 5806.0687 to 70959.3679


In [29]:
# Min Ferret - Skewness
metric_values = Y[:, 0, 2]
fig = create_pca_plot(metric_values, 'Min Ferret', 'Skewness', samples, encodings_3d, pca)
print(f"Min Ferret - Skewness range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Min Ferret - Skewness range: 1.6867 to 11.0062


In [30]:
# Min Ferret - Variance
metric_values = Y[:, 0, 1]
fig = create_pca_plot(metric_values, 'Min Ferret', 'Variance', samples, encodings_3d, pca)
print(f"Min Ferret - Variance range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Min Ferret - Variance range: 135.6448 to 394.8019


In [31]:
# Min Ferret - Mean
metric_values = Y[:, 0, 0]
fig = create_pca_plot(metric_values, 'Min Ferret', 'Mean', samples, encodings_3d, pca)
print(f"Min Ferret - Mean range: {metric_values.min():.4f} to {metric_values.max():.4f}")
fig.show()

Min Ferret - Mean range: 78.9337 to 228.4697
