# Meta-Atom Defect Classification: Layered Clustering Analysis

**Hierarchical defect separation pipeline for metasurface quality control**

Layers:
1. **Missing** - Center intensity contrast
2. **Collapsed** - Darkness + area features  
3. **Stitching** - Optical measurement artifacts
4. **Irregular** - Contextual deviation (Local Outlier Factor, LOF)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import Counter
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy import stats

# Import our module
from layered_clustering import (
    load_tiles, segment_arrays, run_layered_pipeline, export_dataset, get_summary_stats,
    extract_center_intensity, extract_darkness_area, extract_stitching_features,
    extract_rotation_symmetry, extract_neighbor_deviation, extract_anisotropy,
    extract_curvature_irregularity, extract_radial_deviation, cluster_layer
)

# Fix the global NumPy RNG so shuffled samples (e.g. the gallery) repeat run to run.
np.random.seed(42)

# Notebook-wide matplotlib defaults: figure, font, axes, ticks, legend.
plt.rcParams.update({
    'figure.dpi': 120,
    'figure.titlesize': 14,
    'font.size': 11,
    'axes.titlesize': 12,
    'axes.labelsize': 11,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
})

# One colour per final defect label; reused by every figure below.
DEFECT_COLORS = dict(
    Good='#27ae60',
    Missing='#e74c3c',
    Collapsed='#8e44ad',
    Stitching='#f39c12',
    Irregular='#3498db',
)

# Input tiles live under data/Meta_Atoms; figures and CSVs go to results/.
BASE_DIR = Path('.')
META_ATOMS_DIR = BASE_DIR.joinpath('data', 'Meta_Atoms')
OUTPUT_DIR = BASE_DIR.joinpath('results')
OUTPUT_DIR.mkdir(exist_ok=True)

## 1. Load Data & Run Pipeline

In [None]:
# Segment the raw array images into per-atom tiles. Whether this skips work
# that already exists on disk is up to layered_clustering.segment_arrays.
segment_arrays(BASE_DIR, META_ATOMS_DIR)

# Load tiles
tiles = load_tiles(META_ATOMS_DIR)
print(f"Loaded {len(tiles)} meta-atoms")
if not tiles:
    # Every later cell divides by len(tiles) or reduces over the tiles
    # (t-SNE, max row/col, ...), so fail here with an actionable message
    # instead of an obscure error several cells later.
    raise RuntimeError(
        f"No meta-atom tiles found in {META_ATOMS_DIR}; check the segmentation output."
    )

# Run the layered pipeline. Later cells read t.defect_type, t.features and
# t.layer_assignments, presumably populated by this call (verify in module).
results = run_layered_pipeline(tiles, verbose=True)

# Summary: count per final defect label, largest first.
summary = get_summary_stats(tiles)
print("\nFinal Classification:")
for defect, count in sorted(summary.items(), key=lambda item: -item[1]):
    print(f"  {defect}: {count} ({100 * count / len(tiles):.1f}%)")

## 2. Classification Overview

In [None]:
# Three views of the final labels: share (pie), absolute counts (bar), and the
# order in which the sequential layers peel tiles off (stacked horizontal bar).
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Pie chart
ax = axes[0]
labels = list(summary.keys())
sizes = [summary[l] for l in labels]
colors = [DEFECT_COLORS.get(l, '#95a5a6') for l in labels]
explode = [0.02 if l != 'Good' else 0 for l in labels]
wedges, texts, autotexts = ax.pie(sizes, labels=labels, colors=colors, explode=explode,
                                   autopct='%1.1f%%', startangle=90, pctdistance=0.75)
ax.set_title('Defect Type Distribution')

# Bar chart
ax = axes[1]
sorted_items = sorted(summary.items(), key=lambda x: -x[1])
labels = [x[0] for x in sorted_items]
values = [x[1] for x in sorted_items]
colors = [DEFECT_COLORS.get(l, '#95a5a6') for l in labels]
bars = ax.bar(labels, values, color=colors, edgecolor='black', linewidth=0.5)
ax.set_ylabel('Count')
ax.set_title('Defect Counts')
# Pad the value labels by a fraction of the tallest bar rather than a fixed
# 5 data units, which is far too large for small datasets and too small for
# large ones.
label_offset = 0.01 * max(values, default=0)
for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset, str(val),
            ha='center', va='bottom', fontsize=10, fontweight='bold')
# Text is not part of autoscaling; add headroom so the tallest label is not clipped.
ax.margins(y=0.12)

# Hierarchical breakdown: Missing -> Collapsed -> Stitching -> remainder
# (Good + Irregular), drawn as consecutive segments of one bar.
ax = axes[2]
layer_names = ['Missing', 'Collapsed', 'Stitching', 'Remaining']
layer_sizes = [summary.get('Missing', 0), summary.get('Collapsed', 0),
               summary.get('Stitching', 0), summary.get('Good', 0) + summary.get('Irregular', 0)]
cumulative = np.cumsum([0] + layer_sizes[:-1])  # left edge of each segment
layer_colors = ['#e74c3c', '#8e44ad', '#f39c12', '#27ae60']

for name, size, cum, color in zip(layer_names, layer_sizes, cumulative, layer_colors):
    ax.barh(0, size, left=cum, color=color, label=f'{name}: {size}', edgecolor='white', height=0.6)

ax.set_xlim(0, len(tiles))
ax.set_yticks([])
ax.set_xlabel('Number of Meta-Atoms')
ax.set_title('Sequential Layer Extraction')
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=4)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'classification_overview.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Sample Gallery by Defect Type

In [None]:
# Canonical ordering of final labels; later cells reuse this list for
# axis ordering, masks and group construction.
defect_types = ['Good', 'Missing', 'Collapsed', 'Stitching', 'Irregular']
n_gallery_cols = 12  # sample thumbnails per defect type

fig, axes = plt.subplots(len(defect_types), n_gallery_cols, figsize=(14, 6))

for row, defect in enumerate(defect_types):
    defect_tiles = [t for t in tiles if t.defect_type == defect]
    # Random (seeded via np.random.seed) subset of this class.
    np.random.shuffle(defect_tiles)

    for col in range(n_gallery_cols):
        ax = axes[row, col]
        if col < len(defect_tiles):
            ax.imshow(defect_tiles[col].image, cmap='gray')
        # Hide ticks and frame but keep the axis itself visible:
        # ax.axis('off') would also hide the y-label used as the row title.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        if col == 0:
            ax.set_ylabel(f'{defect}\n(n={len(defect_tiles)})', fontsize=10, rotation=0,
                          labelpad=40, va='center', color=DEFECT_COLORS.get(defect, 'black'))

plt.suptitle('Sample Meta-Atoms by Defect Type', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'sample_gallery.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Feature Space Analysis (t-SNE)

In [None]:
# Extract all features for visualization
feature_names = ['center_mean', 'center_std', 'center_edge_contrast', 'mean_intensity',
                 'dark_ratio', 'contour_area', 'stitching_score', 'max_intensity_jump',
                 'rotation_asymmetry', 'neighbor_deviation', 'anisotropy',
                 'curvature_irregularity', 'radial_deviation']

# Feature matrix: one row per tile, one column per feature name. Missing
# features default to 0. dtype=float turns None into NaN, and all non-finite
# values are then zeroed so the scaler/PCA never see NaN or inf.
X_all = np.array([[t.features.get(f, 0) for f in feature_names] for t in tiles], dtype=float)
X_all = np.nan_to_num(X_all, nan=0, posinf=0, neginf=0)

# Scale and reduce
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_all)

n_samples, n_features = X_scaled.shape
# PCA cannot request more components than min(n_samples, n_features).
pca = PCA(n_components=min(10, n_samples, n_features))
X_pca = pca.fit_transform(X_scaled)

print("Computing t-SNE (this may take a moment)...")
# `n_iter` was renamed to `max_iter` in scikit-learn 1.5 and removed in 1.7;
# both default to 1000 iterations, which is what was requested, so it is left
# implicit. scikit-learn also requires perplexity < n_samples.
tsne = TSNE(n_components=2, perplexity=min(30, n_samples - 1), random_state=42)
X_tsne = tsne.fit_transform(X_pca[:, :5])

In [None]:
# Side-by-side embeddings of the same tiles, coloured by final defect label:
# left = t-SNE (with per-class counts in the legend), right = first two PCs.
tile_labels = np.array([t.defect_type for t in tiles])
defect_masks = {d: tile_labels == d for d in defect_types}

fig, (ax_tsne, ax_pca) = plt.subplots(1, 2, figsize=(14, 6))

for defect, sel in defect_masks.items():
    point_color = DEFECT_COLORS.get(defect, 'gray')
    ax_tsne.scatter(X_tsne[sel, 0], X_tsne[sel, 1], c=point_color,
                    label=f'{defect} (n={sel.sum()})', alpha=0.6, s=15, edgecolors='none')
    ax_pca.scatter(X_pca[sel, 0], X_pca[sel, 1], c=point_color,
                   label=f'{defect}', alpha=0.6, s=15, edgecolors='none')

ax_tsne.set_xlabel('t-SNE 1')
ax_tsne.set_ylabel('t-SNE 2')
ax_tsne.set_title('Feature Space Visualization (t-SNE)')
ax_tsne.legend(loc='upper right', framealpha=0.9)

pc1_pct, pc2_pct = 100 * pca.explained_variance_ratio_[:2]
ax_pca.set_xlabel(f'PC1 ({pc1_pct:.1f}%)')
ax_pca.set_ylabel(f'PC2 ({pc2_pct:.1f}%)')
ax_pca.set_title('Principal Component Analysis')
ax_pca.legend(loc='upper right', framealpha=0.9)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'feature_space_tsne_pca.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Mean Feature Profile by Defect Type

In [None]:
# Standardised features plus the final label; reused by the violin, ANOVA
# and correlation cells below.
df_features = pd.DataFrame(X_scaled, columns=feature_names)
df_features['defect_type'] = [t.defect_type for t in tiles]

# Compute mean feature values per defect
mean_by_defect = df_features.groupby('defect_type')[feature_names].mean()
# groupby sorts labels alphabetically. Reorder to the defect_types order used
# by the other figures, keeping any unexpected labels at the end.
present_labels = list(mean_by_defect.index)
column_order = ([d for d in defect_types if d in present_labels]
                + [d for d in present_labels if d not in defect_types])
mean_by_defect = mean_by_defect.loc[column_order]

# Heatmap
fig, ax = plt.subplots(figsize=(14, 5))
sns.heatmap(mean_by_defect.T, cmap='RdBu_r', center=0, annot=True, fmt='.2f',
            linewidths=0.5, cbar_kws={'label': 'Z-Score'}, ax=ax)
ax.set_xlabel('Defect Type')
ax.set_ylabel('Feature')
ax.set_title('Mean Feature Values by Defect Type (Standardized)')
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'feature_importance_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Feature Distribution Violin Plots

In [None]:
key_features = ['center_edge_contrast', 'dark_ratio', 'stitching_score',
                'rotation_asymmetry', 'neighbor_deviation', 'radial_deviation']

fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes = axes.flatten()

for ax, feat in zip(axes, key_features):
    # violinplot raises on an empty dataset, and its KDE fails on a single or
    # constant sample. Only draw violins for classes with spread; show the raw
    # points for degenerate classes, and leave absent classes as a bare tick.
    positions, data, body_colors = [], [], []
    for pos, defect in enumerate(defect_types):
        vals = df_features.loc[df_features['defect_type'] == defect, feat].to_numpy()
        color = DEFECT_COLORS.get(defect, 'gray')
        if vals.size >= 2 and vals.max() > vals.min():
            positions.append(pos)
            data.append(vals)
            body_colors.append(color)
        elif vals.size:
            ax.scatter(np.full(vals.size, pos), vals, color=color, s=12, zorder=3)

    if data:
        parts = ax.violinplot(data, positions=positions, showmeans=True, showmedians=True)
        for body, color in zip(parts['bodies'], body_colors):
            body.set_facecolor(color)
            body.set_alpha(0.7)

    ax.set_xticks(range(len(defect_types)))
    ax.set_xticklabels(defect_types, rotation=45, ha='right')
    ax.set_ylabel('Z-Score')
    ax.set_title(feat.replace('_', ' ').title())
    ax.axhline(0, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)

plt.suptitle('Feature Distributions by Defect Type', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'feature_violin_plots.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Statistical Significance Tests

In [None]:
# One-way ANOVA per standardised feature, with the final defect label as the factor.
print("ANOVA F-test Results (Defect Type as Factor):")
print("="*60)

# Split the frame once by label instead of re-filtering it for every feature.
# groupby only yields non-empty groups, so no explicit length filter is needed.
groups_by_defect = dict(iter(df_features.groupby('defect_type')))
present_defects = [d for d in defect_types if d in groups_by_defect]

anova_results = []
for feat in feature_names:
    groups = [groups_by_defect[d][feat].to_numpy() for d in present_defects]

    if len(groups) >= 2:
        f_stat, p_val = stats.f_oneway(*groups)
        # NaN p-values (e.g. all groups constant) compare False -> no stars.
        sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else ''
        anova_results.append({'Feature': feat, 'F-statistic': f_stat, 'p-value': p_val, 'Significance': sig})
        print(f"{feat:30s} F={f_stat:8.2f}  p={p_val:.2e} {sig}")

# Explicit columns so an empty result (fewer than two classes) still has the
# columns that sort_values and the next cell rely on, instead of a KeyError.
df_anova = pd.DataFrame(
    anova_results, columns=['Feature', 'F-statistic', 'p-value', 'Significance']
).sort_values('F-statistic', ascending=False)
print("\n* p<0.05, ** p<0.01, *** p<0.001")

In [None]:
from matplotlib.patches import Patch


def _significance_color(p):
    """Bar colour for a p-value: green < 0.001, orange < 0.05, grey otherwise (incl. NaN)."""
    if p < 0.001:
        return '#27ae60'
    if p < 0.05:
        return '#f39c12'
    return '#95a5a6'


# Horizontal bars of ANOVA F per feature, weakest at the bottom, coloured by p-value.
ranked = df_anova.sort_values('F-statistic', ascending=True)
bar_positions = range(len(ranked))
bar_colors = [_significance_color(p) for p in ranked['p-value']]

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(bar_positions, ranked['F-statistic'], color=bar_colors, edgecolor='black', linewidth=0.5)
ax.set_yticks(bar_positions)
ax.set_yticklabels([name.replace('_', ' ').title() for name in ranked['Feature']])
ax.set_xlabel('ANOVA F-statistic')
ax.set_title('Feature Discriminative Power (Higher = Better Separation)')

ax.legend(
    handles=[
        Patch(facecolor='#27ae60', label='p < 0.001'),
        Patch(facecolor='#f39c12', label='p < 0.05'),
        Patch(facecolor='#95a5a6', label='Not significant'),
    ],
    loc='lower right',
)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'feature_discriminative_power.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Spatial Distribution Heatmaps

In [None]:
from matplotlib.colors import ListedColormap

# One panel per source array. Sorted so the panel order is stable between
# runs (iteration order of a set of str depends on hash randomisation).
arrays = sorted({t.array for t in tiles})

# Legend order doubles as the integer code written into the grid.
spatial_order = ['Good', 'Irregular', 'Stitching', 'Collapsed', 'Missing']
defect_to_num = {d: i for i, d in enumerate(spatial_order)}
cmap = ListedColormap([DEFECT_COLORS[d] for d in spatial_order])
# Grid cells with no tile or an unknown label stay NaN and render white,
# instead of being silently painted with the "Good" colour.
cmap.set_bad('white')

fig, axes = plt.subplots(1, len(arrays), figsize=(5 * len(arrays), 5), squeeze=False)
axes = list(axes[0])

im = None
for ax, arr in zip(axes, arrays):
    # Row/col are treated as 1-based; tiles with row or col <= 0 are skipped.
    placed = [t for t in tiles if t.array == arr and t.row > 0 and t.col > 0]
    ax.set_title(arr.replace('Crop', ''))
    if not placed:
        ax.set_axis_off()
        continue

    grid = np.full((max(t.row for t in placed), max(t.col for t in placed)), np.nan)
    for t in placed:
        grid[t.row - 1, t.col - 1] = defect_to_num.get(t.defect_type, np.nan)

    # Half-integer limits centre each of the N colour bins on its integer code.
    im = ax.imshow(grid, cmap=cmap, vmin=-0.5, vmax=len(spatial_order) - 0.5)
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')

# Colorbar
if im is not None:
    cbar = fig.colorbar(im, ax=axes, ticks=range(len(spatial_order)), shrink=0.8)
    cbar.ax.set_yticklabels(spatial_order)

plt.suptitle('Spatial Distribution of Defects', fontsize=14, fontweight='bold')
plt.savefig(OUTPUT_DIR / 'spatial_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Cluster Composition by Layer (Confusion-Style Matrix)

In [None]:
# For each sequential layer, show what fraction of every cluster ended up in
# each final defect type (rows = clusters, normalised to sum to 1).
# Named layer_keys so the display list `layer_names` from the overview cell
# is not overwritten with an unrelated meaning.
layer_keys = ['layer_missing', 'layer_collapsed', 'layer_stitching']

fig, axes = plt.subplots(1, len(layer_keys), figsize=(14, 4))

for ax, layer in zip(axes, layer_keys):
    layer_title = layer.replace('layer_', '').title()

    # cluster id -> Counter of final defect labels among its members
    cluster_to_defect = {}
    for t in tiles:
        if layer in t.layer_assignments:
            cluster = t.layer_assignments[layer]
            cluster_to_defect.setdefault(cluster, Counter())[t.defect_type] += 1

    if not cluster_to_defect:
        # No tile carries an assignment for this layer: blank the panel
        # instead of leaving an empty frame with default ticks.
        ax.set_axis_off()
        ax.set_title(f'{layer_title} (no assignments)')
        continue

    clusters = sorted(cluster_to_defect)
    matrix = np.array([[cluster_to_defect[c][d] for d in defect_types] for c in clusters],
                      dtype=float)

    # Normalize by row; epsilon guards rows whose members all have labels
    # outside defect_types.
    matrix_norm = matrix / (matrix.sum(axis=1, keepdims=True) + 1e-10)

    sns.heatmap(matrix_norm, annot=True, fmt='.2f', cmap='Blues', ax=ax,
                xticklabels=defect_types, yticklabels=[f'C{c}' for c in clusters],
                cbar_kws={'label': 'Proportion'})
    ax.set_xlabel('Final Defect Type')
    ax.set_ylabel('Cluster')
    ax.set_title(layer_title)

plt.suptitle('Cluster Composition by Layer', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'cluster_composition.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Feature Correlation Matrix

In [None]:
# Pearson correlations between the standardised features. Only the lower
# triangle is drawn; the upper triangle (including the diagonal) is masked.
corr = df_features[feature_names].corr()
upper_triangle = np.triu(np.ones_like(corr, dtype=bool))

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    corr,
    mask=upper_triangle,
    cmap='RdBu_r',
    center=0,
    annot=True,
    fmt='.2f',
    linewidths=0.5,
    ax=ax,
    square=True,
    cbar_kws={'label': 'Correlation'},
)
ax.set_title('Feature Correlation Matrix')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'feature_correlation_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 11. PCA Variance Explained

In [None]:
# Left: per-component and cumulative explained variance (90% guide line).
# Right: mean absolute loading of each feature over PC1-PC3.
var_exp = pca.explained_variance_ratio_
cum_var = np.cumsum(var_exp)
component_ids = np.arange(1, var_exp.size + 1)

fig, (ax_scree, ax_load) = plt.subplots(1, 2, figsize=(12, 4))

ax_scree.bar(component_ids, var_exp, color='steelblue', alpha=0.7, label='Individual')
ax_scree.plot(component_ids, cum_var, 'ro-', label='Cumulative')
ax_scree.axhline(0.9, color='gray', linestyle='--', alpha=0.5)
ax_scree.set_xlabel('Principal Component')
ax_scree.set_ylabel('Variance Explained')
ax_scree.set_title('PCA Scree Plot')
ax_scree.legend()

loadings = pd.DataFrame(pca.components_[:3].T, columns=['PC1', 'PC2', 'PC3'], index=feature_names)
loadings_abs = loadings.abs().mean(axis=1).sort_values(ascending=True)
row_ids = np.arange(len(loadings_abs))
ax_load.barh(row_ids, loadings_abs.to_numpy(), color='steelblue')
ax_load.set_yticks(row_ids)
ax_load.set_yticklabels([name.replace('_', ' ').title() for name in loadings_abs.index])
ax_load.set_xlabel('Mean Absolute Loading (PC1-3)')
ax_load.set_title('Feature Importance in PCA')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'pca_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 12. Export Dataset

In [None]:
# Write the classified tiles to CSV and preview the resulting table.
export_path = OUTPUT_DIR / 'meta_atoms_classified.csv'
df_export = export_dataset(tiles, export_path)

print(f"Dataset exported: {export_path}")
print(f"Shape: {df_export.shape}")
print(f"\nColumns: {list(df_export.columns)}")
df_export.head()

In [None]:
# Mean / std / min / max of every numeric column in the exported table.
rule = "=" * 70
print(f"\n{rule}")
print("DATASET SUMMARY")
print(rule)
print(df_export.describe().T.loc[:, ['mean', 'std', 'min', 'max']])

---
## Summary

**Defect counts extracted by the pipeline:**

In [None]:
# Text histogram of the final labels, largest first: one '#' per 2% of tiles.
total_tiles = len(tiles)
for defect, count in sorted(summary.items(), key=lambda kv: kv[1], reverse=True):
    share = 100 * count / total_tiles
    print(f"{defect:12s} {count:4d} ({share:5.1f}%) {'#' * int(share / 2)}")

print(f"\nTotal: {total_tiles} meta-atoms")
print(f"\nOutput files saved to: {OUTPUT_DIR.absolute()}")