In [1]:
from collections import defaultdict
import torch
import tqdm
import wandb
import os



import utils
import lkis as LKIS
from datasets import Collator

from torch_pca import PCA
import os

from torchvision.datasets.imagenet import ImageFolder
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader

from models.wrapper_models import ResNetForKoopmanEstimation
from transformers import ViTForImageClassification, AutoImageProcessor , ViTImageProcessor

from omegaconf import DictConfig
from hydra import initialize, compose
import matplotlib.pyplot as plt
import numpy as np


In [2]:
# Set Matplotlib style
plt.style.use("seaborn-v0_8-deep")
plt.rcParams.update(
    {
        "font.size": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 13,
        "legend.fontsize": 11,
        "figure.dpi": 300,
    }
)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
STUDENT_CHECKPOINT_PATH = ""
TEACHER_CHECKPOINT_PATH = "/data/users/cboned/checkpoints/Vit_CIFAR100_first_train.pt"
DATASET_PATH = "/data/users/cboned/data/Generic/cifar"

In [4]:
train_dataset = CIFAR100(root=DATASET_PATH, download=False, train=True)
validation_dataset = CIFAR100(root=DATASET_PATH, download=False, train=False)


In [5]:
teacher_model = ViTForImageClassification.from_pretrained(TEACHER_CHECKPOINT_PATH)


processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
collator = Collator(processor)
teacher_model.to(device)
train_dloader = DataLoader(train_dataset, 
                           shuffle=True,
                           batch_size=64,
                           pin_memory=False,
                           num_workers=0,
                           collate_fn=collator.classification_collate_fn
                           )


test_dloader = DataLoader(validation_dataset,
                          batch_size=1,
                          pin_memory=False,
                          num_workers=0,
                          shuffle=True,
                          collate_fn=collator.classification_collate_fn

)


## Computing the UEV decomposition

In [6]:
import os
import seaborn as sns
import imageio
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_trajectory_with_vector_field(traj_batch, out_folder, out_name:str="merged_trajectories.png",  image_indices=None, merged=False):
    os.makedirs(out_folder, exist_ok=True)

    n_images = traj_batch.shape[0]
    steps = traj_batch.shape[1]

    if merged:
        plt.figure(figsize=(8, 6))
        for img_idx in range(n_images):
            traj = traj_batch[img_idx]
            x, y = traj[:, 0], traj[:, 1]
            dx = x[1:] - x[:-1]
            dy = y[1:] - y[:-1]

            plt.plot(x, y, label=f"Img {img_idx}", alpha=0.7)
            plt.quiver(x[:-1], y[:-1], dx, dy, angles='xy', scale_units='xy', scale=0.5, alpha=0.1)

        plt.title("Merged Koopman Trajectories + Phase Arrows")
        plt.xlabel("PCA Dim 1")
        plt.ylabel("PCA Dim 2")
        plt.axis("equal")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(out_folder, out_name))
        plt.close()

    else:
        for img_idx in (image_indices if image_indices is not None else range(n_images)):
            traj = traj_batch[img_idx]
            x, y = traj[:, 0], traj[:, 1]
            dx = x[1:] - x[:-1]
            dy = y[1:] - y[:-1]

            plt.figure(figsize=(6, 5))
            plt.plot(x, y, marker="o", label="Trajectory")
            plt.quiver(x[:-1], y[:-1], dx, dy, angles='xy', scale_units='xy', scale=1, color="red", label="Flow")

            for step_idx in range(len(x)):
                plt.text(x[step_idx], y[step_idx], str(step_idx), fontsize=8)

            plt.title(f"Trajectory + Phase Flow (Image {img_idx})")
            plt.xlabel("PCA Dim 1")
            plt.ylabel("PCA Dim 2")
            plt.axis("equal")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(out_folder, f"trajectory_image_{img_idx}.png"))
            plt.close()


In [7]:
Vt_dictionary = defaultdict()
cls_vt_dictionary = defaultdict(dict)
features_to_keep = defaultdict(list)
features_to_keep_per_class = defaultdict(list)

fet_to_keep_general = []

In [8]:
for batch_idx, data in tqdm.tqdm(
    enumerate(train_dloader),
    desc="Extracting the PCA Vt for computing the trajectories",
    total=len(train_dloader),
):
    
    inputs: dict = data["pixel_values"].to(device)
    labels: list = data["labels"].to(device)
    with torch.no_grad():
        features = teacher_model(**inputs, output_hidden_states=True)["hidden_states"]
        for idx, feat in enumerate(features):
            #feat2 = feat.reshape(-1, 768)
            features_to_keep[idx].append(feat[:, 0])
            fet_to_keep_general.append(feat[:, 0])
    
    if (batch_idx * feat.shape[0]) >= 2000:
        break

Extracting the PCA Vt for computing the trajectories:   4%|▍         | 32/782 [00:08<03:26,  3.63it/s]


In [9]:
for i in range(len(features_to_keep)):
    embeddings_to_compute_pca = torch.cat(features_to_keep[idx], dim=0)
    U_i, S_i, Vt_i = utils.perform_pca_lowrank(embeddings_to_compute_pca, n_eigenvectors=2, center=False)

    Vt_dictionary[i] = Vt_i

In [10]:
teacher_model.eval()
merged_traj = []
for batch_idx, data in tqdm.tqdm(
    enumerate(test_dloader),
    desc="Training Procedure",
    total=len(test_dloader),
):
    
    inputs: dict = data["pixel_values"].to(device)
    images = data["raw_images"]
    with torch.no_grad():
        features = teacher_model(**inputs, output_hidden_states=True)["hidden_states"]
    features_cls_trajectory = torch.cat(
        [feat.unsqueeze(1) for feat in features], dim=1
    )[:, :, 0]
    b, seq, d = features_cls_trajectory.shape
    trajectories_observation = []
    trajectories_from_vision = []
    for i in range(seq):
        trajectory_from_vision_encoder = utils.project_onto_subspace(
            A=features_cls_trajectory[:, i], Vt=Vt_dictionary[i], k=2
    )
        trajectories_from_vision.append(trajectory_from_vision_encoder)
    
    traj_teacher = torch.cat(trajectories_from_vision, dim=0).cpu().numpy()
    
    merged_traj.append(traj_teacher[None, :])

    if batch_idx >= 1000:
        break


merged_traj = np.concatenate(merged_traj, axis=0)

Training Procedure:  10%|█         | 1000/10000 [00:06<00:56, 158.99it/s]


In [11]:

def plot_vector_field_from_koopman(traj_batch, out_path="vector_field.png", density=1.2):
    traj = traj_batch.reshape(-1, 2)
    x, y = traj[:, 0], traj[:, 1]

    # Build coarse grid
    x_grid = np.linspace(x.min()-1, x.max()+1, 30)
    y_grid = np.linspace(y.min()-1, y.max()+1, 30)
    X, Y = np.meshgrid(x_grid, y_grid)

    # Estimate velocity vectors at grid points via local average
    U, V = np.zeros_like(X), np.zeros_like(Y)
    for i in range(len(x) - 1):
        xi, yi = x[i], y[i]
        dxi, dyi = x[i+1] - x[i], y[i+1] - y[i]

        dist = (X - xi)**2 + (Y - yi)**2
        weight = np.exp(-dist / 0.5)
        U += weight * dxi
        V += weight * dyi

    # Compute magnitude
    magnitude = np.sqrt(U**2 + V**2)

    # Find possible attractors (lowest 5% magnitude)
    print(np.percentile(magnitude, 99))
    attractor_mask = magnitude > np.percentile(magnitude, 99)
    attractor_x = X[attractor_mask]
    attractor_y = Y[attractor_mask]

    # Plotting
    plt.figure(figsize=(10, 8))

    # 1. Magnitude heatmap
    plt.contourf(X, Y, magnitude, levels=40, cmap="Blues", alpha=0.6)

    # 2. Streamlines
    plt.streamplot(X, Y, U, V, color=magnitude, linewidth=1.5, cmap="viridis", density=density)

    # 4. Attractor candidates
    plt.scatter(attractor_x, attractor_y, color="red", s=40, label="Possible Attractors")

    # Labels and final formatting
    plt.xlabel("PCA Dim 1")
    plt.ylabel("PCA Dim 2")
    plt.colorbar(label="Vector Magnitude")
    plt.title("Koopman-Inferred Vector Field and Attractors")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()



In [12]:
plot_out_folder = "./koopman_phase_plots_vit"
out_name = f"merge_trajectories_cls.png"
#plot_trajectory_with_vector_field(traj_batch=merged_traj, out_folder=plot_out_folder, merged=False)
plot_trajectory_with_vector_field(traj_batch=merged_traj, out_folder=plot_out_folder, out_name=out_name, merged=True)

In [13]:
plot_vector_field_from_koopman(merged_traj, out_path=f"koopman_phase_plots_vit/vector_field_cls.png")


16.147421866351873
