In [None]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader

from models.attention_unet import AttentionUNet
from models.autoencoder import ConvAutoencoder, AutoencoderDataset

In [None]:
# Notebook configuration.
# USE_ATTENTION_UNET selects which model (and checkpoint) the later cells load.
USE_ATTENTION_UNET = True
# Passed to AttentionUNet as `compressed_dim` when that model is selected.
COMPRESSED_DIM = 64

# Checkpoint file matching the selected model.
if USE_ATTENTION_UNET:
    CHECKPOINT_PATH = './checkpoints/attention_unet/31500.pth'
else:
    CHECKPOINT_PATH = './checkpoints/unet/14500.pth'

In [None]:
# Read the archive and pull out the three arrays used by the rest of the notebook.
data = np.load('./data/tga_afm/data.npz')
X, Y, samples = (data[key] for key in ('X', 'Y', 'samples'))

# Keep only indices 1 and 2 along axis 1; these two channels are what the
# model is fed (it is constructed with ch_in=2 further down).
X_input = X[:, 1:3, :]

In [None]:
# Build the selected model, restore its trained weights and put it in
# double-precision inference mode.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

if not USE_ATTENTION_UNET:
    model = ConvAutoencoder()
    # This checkpoint file is loaded directly as the state dict.
    state_dict = torch.load(CHECKPOINT_PATH, weights_only=True, map_location=device)
else:
    model = AttentionUNet(ch_in=2, ch_out=2, compressed_dim=COMPRESSED_DIM)
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # this must only ever be pointed at a trusted checkpoint file.
    checkpoint = torch.load(CHECKPOINT_PATH, weights_only=False, map_location=device)
    # Here the weights sit under the 'model_state_dict' key of the checkpoint.
    state_dict = checkpoint['model_state_dict']

model.load_state_dict(state_dict)

# Each call returns the module itself, so the chain ends by displaying the
# model repr exactly like a trailing `model.eval()` would.
model.to(device).double().eval()

In [None]:
# Encode every sample individually (batch size 1) and stack the results
# into a single array with one row per sample.
encodings = []
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for sample in X_input:
        # The model was cast with .double(), so force float64 input regardless
        # of the dtype stored in the .npz file; unsqueeze adds the batch axis.
        x = torch.tensor(sample, dtype=torch.float64).unsqueeze(0).to(device)
        encoding = model.encode(x).cpu().numpy()
        # Drop the batch axis again before collecting.
        encodings.append(encoding[0])

encodings = np.array(encodings)
# Show the shape of the stacked array (not of the last single-sample encoding).
encodings.shape

In [None]:
# Standardise every encoding dimension so that no single dimension
# dominates the principal components purely because of its scale.
scaler = StandardScaler()
encodings_scaled = scaler.fit_transform(encodings)

# Project the standardised encodings onto their first three principal components.
pca = PCA(n_components=3)
encodings_pca = pca.fit_transform(encodings_scaled)

# Report how much of the variance the three components retain.
variance_ratio = pca.explained_variance_ratio_
print(f"PCA explained variance ratio: {variance_ratio}")
print(f"Total variance explained: {variance_ratio.sum():.2%}")

In [None]:
# Labels used to colour the PCA scatter plots.
#
# Y is indexed as (Sample, Characteristic, Statistic):
#   Characteristics: 0=Min Ferret, 1=Max Ferret, 2=Height, 3=Area, 4=Volume
#   Statistics:      0=Mean, 1=Variance, 2=Skewness, 3=Kurtosis, 4=Median
# NOTE(review): "Ferret" presumably refers to the Feret diameter; the spelling
# is kept because it is part of the displayed label names.

# One row per plot: (display name, characteristic index, statistic index, colorscale)
label_specs = [
    ('Height Mean', 2, 0, 'Viridis'),
    ('Min Ferret Skewness', 0, 2, 'Plasma'),
    ('Area Mean', 3, 0, 'RdBu'),
    ('Volume Mean', 4, 0, 'Magma'),
]

label_configs = [
    {'data': Y[:, characteristic, statistic], 'name': name, 'colorscale': colorscale}
    for name, characteristic, statistic, colorscale in label_specs
]

In [None]:
# Interactive 2x2 grid of 3D scatter plots: the same PCA projection of the
# encodings in every panel, coloured by a different label from `label_configs`.
model_name = "Attention U-Net" if USE_ATTENTION_UNET else "Autoencoder"
output_path = 'figures/encoding_pca_visualization_interactive.html'
variance_ratio = pca.explained_variance_ratio_

fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=[config['name'] for config in label_configs],
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}],
           [{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
    vertical_spacing=0.1,
    horizontal_spacing=0.05
)

for idx, config in enumerate(label_configs):
    # Fill the grid row by row (1-based indices, as plotly expects).
    row = idx // 2 + 1
    col = idx % 2 + 1

    labels = config['data']

    # One hover string per point: sample id, label value and PCA coordinates.
    hover_text = [
        f"Sample: {samples[i]}<br>" +
        f"{config['name']}: {labels[i]:.4f}<br>" +
        f"PC1: {encodings_pca[i, 0]:.4f}<br>" +
        f"PC2: {encodings_pca[i, 1]:.4f}<br>" +
        f"PC3: {encodings_pca[i, 2]:.4f}"
        for i in range(len(samples))
    ]

    trace = go.Scatter3d(
        x=encodings_pca[:, 0],
        y=encodings_pca[:, 1],
        z=encodings_pca[:, 2],
        mode='markers',
        marker=dict(
            size=8,
            color=labels,
            colorscale=config['colorscale'],
            showscale=True,
            # Each panel gets its own colorbar, placed next to that panel
            # in paper coordinates.
            colorbar=dict(
                title=dict(text=config['name']),
                len=0.4,
                x=1.0 if col == 2 else 0.45,
                y=0.75 if row == 1 else 0.25,
                thickness=15,
                tickfont=dict(size=10)
            ),
            line=dict(color='black', width=0.5),
            opacity=0.8
        ),
        text=hover_text,
        hovertemplate='%{text}<extra></extra>',
        name=config['name']
    )

    fig.add_trace(trace, row=row, col=col)

# Axis titles are identical for every panel, so set them once for all scenes.
# Inside a scene the axes are always named xaxis/yaxis/zaxis; the numeric
# suffix belongs to the scene itself (scene, scene2, ...), not to its axes,
# so keys such as 'xaxis2' are not valid scene properties.
fig.update_scenes(
    xaxis=dict(title=dict(text=f'PC1 ({variance_ratio[0]:.1%})')),
    yaxis=dict(title=dict(text=f'PC2 ({variance_ratio[1]:.1%})')),
    zaxis=dict(title=dict(text=f'PC3 ({variance_ratio[2]:.1%})')),
)

fig.update_layout(
    title=dict(
        text=f'<b>Encoding Space Visualization - {model_name}</b><br>' +
             f'<sub>3D PCA of Learned Encodings (Total variance: {variance_ratio.sum():.1%})</sub>',
        x=0.5,
        xanchor='center',
        font=dict(size=16)
    ),
    height=900,
    width=1400,
    showlegend=False,
    margin=dict(l=0, r=0, t=100, b=0)
)

# Make sure the output directory exists so a fresh checkout can run this cell.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_html(output_path)
fig.show()

print(f"\nInteractive visualization saved to: {output_path}")
print(f"Total variance explained by 3 PCs: {variance_ratio.sum():.2%}")