In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from torchvision.utils import make_grid
import torchvision.transforms as T

from models._styleextraction import StyleExtractor

In [None]:
def plot_style_embedding(style_stats, use_tsne=True, perplexity=5.0, random_state=None):
    """
    Scatter-plot a 2-D embedding of the per-(layer, domain) style vectors.

    For every (layer, domain) pair the ``mu`` and ``sig`` vectors are
    concatenated into one style vector, and the resulting L*D vectors are
    reduced to two dimensions with t-SNE (default) or PCA. Each point is
    labelled ``L{layer}-D{domain}`` in the legend.

    Args:
        style_stats: object exposing ``mu`` and ``sig`` tensors which, after
            squeezing up to three trailing singleton dims, must be 3-D
            [L, D, C] (layers, domains, channels per the original shape
            comment — TODO confirm against StyleExtractor).
        use_tsne: reduce with t-SNE when True, with PCA when False.
        perplexity: t-SNE perplexity; clamped to ``n_points - 1`` because
            scikit-learn requires perplexity < n_samples. Ignored for PCA.
        random_state: seed forwarded to t-SNE; pass an int for a reproducible
            layout. ``None`` keeps t-SNE's default (non-deterministic) seeding.

    Raises:
        ValueError: if there are fewer than two (layer, domain) points, since
            neither reducer can produce a 2-D embedding from a single point.
    """
    # NOTE(review): squeeze(-1) only drops singleton axes; if C == 1 the third
    # squeeze also removes the channel axis and the unpacking below fails —
    # confirm the stored layout.
    mu = style_stats.mu.squeeze(-1).squeeze(-1).squeeze(-1)  # [L, D, C]
    sig = style_stats.sig.squeeze(-1).squeeze(-1).squeeze(-1)

    L, D, C = mu.shape
    # [L, D, C] ++ [L, D, C] -> [L*D, 2C]; rows are layer-major, matching the
    # label order below. detach/cpu so GPU or grad-tracking tensors also work.
    style_vectors = torch.cat([mu, sig], dim=-1).reshape(L * D, -1)
    X = style_vectors.detach().cpu().numpy()
    labels = [f"L{layer}-D{domain}" for layer in range(L) for domain in range(D)]

    n_points = X.shape[0]
    if n_points < 2:
        raise ValueError(
            f"need at least 2 (layer, domain) points to embed, got {n_points}"
        )

    if use_tsne:
        reducer = TSNE(
            n_components=2,
            perplexity=min(perplexity, n_points - 1),
            random_state=random_state,
        )
        method = "t-SNE"
    else:
        reducer = PCA(n_components=2)
        method = "PCA"
    X_embedded = reducer.fit_transform(X)

    fig, ax = plt.subplots(figsize=(8, 6))
    # One scatter call per point so each (layer, domain) gets its own legend entry.
    for (x, y), label in zip(X_embedded, labels):
        ax.scatter(x, y, label=label)
    ax.set_title(f"{method} Embedding: mu + sig")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
def visualize_feature_maps(before, after, num_channels=4):
    """
    Visualises feature maps before and after MixStyle.

    Only the first sample of each batch is shown. Its first ``num_channels``
    channels are laid out side by side as individual tiles, one row per panel
    (top: before, bottom: after). Each panel is min-max normalised on its own
    by ``make_grid(normalize=True)``, so colours are not directly comparable
    between the two panels.

    Args:
        before: Tensor [B, C, H, W] captured before the MixStyle layer.
        after: Tensor [B, C, H, W] captured after the MixStyle layer.
        num_channels: how many leading channels of the first sample to show.
    """
    def _channel_grid(feature_map):
        # First sample, first channels -> [c, 1, H, W]: one tile per channel.
        # detach() so outputs that still track gradients can be converted.
        tiles = feature_map[0, :num_channels].detach().unsqueeze(1)
        # make_grid expands 1-channel tiles to 3 identical channels, so its
        # result is [3, H', W']; squeeze(0) cannot remove that axis and imshow
        # rejects a channel-first array. Take a single plane instead.
        grid = make_grid(tiles, normalize=True, nrow=num_channels)
        return grid[0].cpu().numpy()

    panels = (
        (_channel_grid(before), "Before MixStyle"),
        (_channel_grid(after), "After MixStyle"),
    )

    fig, axs = plt.subplots(2, 1, figsize=(12, 5))
    for ax, (grid, title) in zip(axs, panels):
        ax.imshow(grid, cmap='viridis')
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
def plot_mu_heatmap(style_stats, channel=0):
    """
    Plot a Domain x Layer heatmap of the style mean ``mu`` for one channel.

    Each cell is annotated with its value (2 decimals).

    Args:
        style_stats: object exposing a ``mu`` tensor that, after squeezing up
            to three trailing singleton dims, is indexed as [L, D, C]
            (layers, domains, channels per the original shape comment —
            TODO confirm against StyleExtractor).
        channel: index into the channel axis to visualise.
    """
    mu = style_stats.mu.squeeze(-1).squeeze(-1).squeeze(-1)  # [L, D, C]
    # Transpose so rows are domains and columns are layers; detach/cpu so
    # GPU or grad-tracking tensors can be converted to numpy as well.
    heatmap_data = mu[:, :, channel].T.detach().cpu().numpy()  # [Domain, Layer]

    # Draw on a fresh figure rather than whatever pyplot's "current" axes is.
    fig, ax = plt.subplots()
    sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="magma", ax=ax)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Domain")
    ax.set_title(f"mu Heatmap (Channel {channel})")
    plt.show()

In [None]:
def plot_std_heatmap(style_stats, channel=0):
    """
    Plot a Domain x Layer heatmap of the style std ``sig`` for one channel.

    Each cell is annotated with its value (2 decimals).

    Args:
        style_stats: object exposing a ``sig`` tensor that, after squeezing up
            to three trailing singleton dims, is indexed as [L, D, C]
            (layers, domains, channels per the original shape comment —
            TODO confirm against StyleExtractor).
        channel: index into the channel axis to visualise.
    """
    sig = style_stats.sig.squeeze(-1).squeeze(-1).squeeze(-1)  # [L, D, C]
    # Transpose so rows are domains and columns are layers; detach/cpu so
    # GPU or grad-tracking tensors can be converted to numpy as well.
    heatmap_data = sig[:, :, channel].T.detach().cpu().numpy()  # [Domain, Layer]

    # Draw on a fresh figure rather than whatever pyplot's "current" axes is.
    fig, ax = plt.subplots()
    sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="magma", ax=ax)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Domain")
    ax.set_title(f"std Heatmap (Channel {channel})")
    plt.show()

In [None]:
# Inputs a reader may want to change.
STYLE_STATS_PATH = "style_stats.json"
HEATMAP_CHANNEL = 0

# load() is invoked on the StyleExtractor class itself (not an instance), as
# before; presumably it is a classmethod/staticmethod — TODO confirm.
style_stats = StyleExtractor.load(STYLE_STATS_PATH)

# t-SNE embedding of the per-(layer, domain) style vectors
plot_style_embedding(style_stats)

# Feature-map comparison (e.g. from one training run). Needs a real feature
# map and MixStyle layer, so it is kept as a usage example only:
# before = some_feature_map.clone()
# after = mixstyle_layer(some_feature_map)
# visualize_feature_maps(before, after)

# Heatmaps of the style mean and std for a single channel
plot_mu_heatmap(style_stats, channel=HEATMAP_CHANNEL)
plot_std_heatmap(style_stats, channel=HEATMAP_CHANNEL)
