In [1]:
import sys
# Add the parent directory of the notebook's working directory to the module
# search path (at position 1, right after the first entry) and print the
# resulting path. Presumably this is what makes the project-local packages
# imported in the next cell (`datasets`, `models`) resolvable — verify against
# the repository layout.
sys.path.insert(1, '../')
print(sys.path)

['/home/cboned/Projects/Koopman-Learning/notebooks', '../', '/home/cboned/miniconda3/envs/graphocr/lib/python310.zip', '/home/cboned/miniconda3/envs/graphocr/lib/python3.10', '/home/cboned/miniconda3/envs/graphocr/lib/python3.10/lib-dynload', '', '/home/cboned/miniconda3/envs/graphocr/lib/python3.10/site-packages']


In [2]:
import os
# Environment variables are written before `torch` is imported below.
# TOKENIZERS_PARALLELISM is the HuggingFace tokenizers switch; "false" turns
# its internal parallelism off.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Restrict this process to physical GPU 3. NOTE(review): the index is
# machine-specific — adjust it when running on a different host.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


import torch
import torch.nn.functional as F
# Use CUDA when a device is available, otherwise fall back to the CPU; the
# chosen device is printed so it is recorded in the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


from collections import defaultdict
from datasets import Collator

import os

from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import DataLoader

from transformers import ViTForImageClassification,  CLIPModel, ViTModel, ViTImageProcessor, AutoModel, AutoModelForImageClassification, Dinov2ForImageClassification, AutoImageProcessor
from models.ode_transformer_gpt  import ViTNeuralODE
import numpy as np
import scienceplots

from torchdiffeq import odeint
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from matplotlib.lines import Line2D
import seaborn as sns
from sklearn.manifold import TSNE

from PIL import Image

cuda


In [3]:
# Matplotlib look-and-feel: start from the "science" style sheet, then apply
# the notebook's overrides in a single rcParams update. `text.usetex` is
# switched off first so that no LaTeX installation is needed for rendering.
plt.style.use("science")

plt.rcParams.update(
    {
        "text.usetex": False,
        "font.size": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 13,
        "legend.fontsize": 11,
        "figure.dpi": 300,
    }
)

In [4]:
DATASET_PATH = "/data/users/cboned/cifar"

# Build the train split first, then the held-out split, from the copy already
# on disk (download=False). Swap CIFAR10 for CIFAR100 here to change dataset.
train_dataset, validation_dataset = (
    CIFAR10(root=DATASET_PATH, download=False, train=is_train_split)
    for is_train_split in (True, False)
)

In [5]:
# Image pre-processing comes from the DINOv2-with-registers processor; the
# project Collator wraps it and provides the batch collate function.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-with-registers-large')
collator = Collator(processor)

# Training split: one image per batch, served in dataset order.
train_dloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
    collate_fn=collator.classification_collate_fn,
)

# Held-out split: one image per batch, served in random order.
test_dloader = DataLoader(
    validation_dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    num_workers=1,
    pin_memory=False,
    collate_fn=collator.classification_collate_fn,
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
EDO_CHECKPOINT_PATH = "/data/users/cboned/checkpoints/EDO_DISTILLATION_VITDINO_ON_CIFAR10_ATTENTION_DISTILL_FULL_PATH_MSE_PE_REGISTERS.pt"

# Student model: a ViT whose depth is replaced by a neural ODE.
edo_model = ViTNeuralODE(
    img_size=224,
    patch_size=16,
    in_chans=3,
    mlp_ratio=4,
    num_classes=10,
    embed_dim=768,
    num_heads=12,
    emulate_depth=12.0,
    # NOTE(review): the previous comment here said the ODE is integrated over
    # [0, 12], but the configured value is 1.0 — confirm the intended horizon.
    time_interval=1.0,
    num_eval_steps=24,
    solver="euler",
    register_tokens=10,
    pos_embed_register_tokens=True
)

# Read the checkpoint once, onto the CPU, so loading does not depend on the GPU
# index the file was saved from (only one device is visible in this process).
checkpoint = torch.load(EDO_CHECKPOINT_PATH, map_location="cpu", weights_only=True)
# Accept both layouts: weights wrapped under "state_dict", or a bare state dict.
checkpoint_state = checkpoint.get("state_dict", checkpoint)

try:
    edo_model.load_state_dict(checkpoint_state)
except RuntimeError as e:
    # Strict loading failed (missing/unexpected keys or size mismatch): fall
    # back to copying every tensor whose name AND shape match the model.
    print(e)
    print("Matching individual Weights")
    model_state = edo_model.state_dict()   # tensors share storage with the parameters
    copied = 0
    skipped = []
    with torch.no_grad():
        for name, weight in checkpoint_state.items():
            target = model_state.get(name)
            if (
                target is not None
                and isinstance(weight, torch.Tensor)
                and target.shape == weight.shape
            ):
                target.copy_(weight)
                copied += 1
            else:
                skipped.append(name)
    # Make a partial load visible instead of silently keeping random weights.
    print(f"Copied {copied}/{len(model_state)} tensors; skipped {len(skipped)} checkpoint entries")


edo_model = edo_model.to(device)
edo_model = edo_model.eval()

In [7]:
# Pull one batch (a single image) from the held-out loader.
data = next(iter(test_dloader))
inputs = data["pixel_values"].to(device)

# Keep the raw image and a 224x224 bicubic-resized copy of it.
original_image = data["raw_images"][0]
original_image_resized = original_image.resize((224, 224), resample=Image.BICUBIC)

# Forward pass of the student without autograd, asking for attentions, hidden
# states and the attention trajectory.
with torch.no_grad():
    outputs_edo = edo_model(
        inputs["pixel_values"],
        output_attentions=True,
        output_hidden_states=True,
        output_attention_trajectory=True,
    )


# Computation of the Lyapunov exponents

In [8]:
import torch
import math
from torchdiffeq import odeint

def jvp_single(func, t, x, v):
    """Jacobian-vector product of ``x -> func(t, x)`` at ``x`` along ``v``.

    ``v`` is a single tangent vector with the same shape as ``x``. The product
    is computed on a detached copy of ``x`` and without building a graph, so
    the returned tensor does not require grad.
    """
    primal = x.detach().requires_grad_(True)

    def field_at_t(state):
        # Freeze the time argument so only the state is differentiated.
        return func(t, state)

    result = torch.autograd.functional.jvp(field_at_t, primal, v, create_graph=False)
    # `result` is (func output, jvp); only the tangent part is wanted here.
    return result[1]

class AugmentedODE(torch.nn.Module):
    """State + tangent dynamics for a vector field ``func(t, x)``.

    The augmented state is the pair ``(x, v)``; ``forward`` returns
    ``(dx/dt, dv/dt)`` where ``dx/dt = func(t, x)`` and ``dv/dt`` is the
    Jacobian of ``func`` with respect to ``x`` applied to ``v``.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, t, y):
        """Evaluate the augmented right-hand side at time ``t`` for ``y = (x, v)``.

        ``torch.autograd.functional.jvp`` returns both the function value and
        the Jacobian-vector product, so the vector field is evaluated once per
        call instead of once for ``dx/dt`` and again inside the JVP. Both
        returned tensors are detached (``create_graph=False``).
        """
        x, v = y
        dxdt, dvdt = torch.autograd.functional.jvp(
            lambda state: self.func(t, state),
            x.detach().requires_grad_(True),
            v,
            create_graph=False,
        )
        return (dxdt, dvdt)

def largest_cls_lyapunov_odeint(func, x0, t_span, renorm_every=50, device=None, verbose=True, method="rk4"):
    """
    Estimate the largest Lyapunov exponent for a perturbation seeded on the CLS token.

    The state and a tangent vector are integrated together (see AugmentedODE);
    after every ``renorm_every`` grid steps the norm of the tangent's CLS slice
    is logged and that slice is rescaled to unit norm. The estimate is the sum
    of the logged norms divided by the integrated time.

    Args:
        func: callable(t, x) -> dx/dt (expects x shape [B, 1+N, D]).
        x0: initial tokens (batch dim allowed; only the first sample is used).
        t_span: 1-D time grid, e.g. from torch.linspace.
        renorm_every: number of grid steps between renormalisations (>= 1).
        device: device to run on; defaults to ``x0.device``.
        verbose: print the running estimate after each segment.
        method: solver name forwarded to ``odeint`` (default "rk4").

    Returns:
        float: sum(log ||v_cls||) / total integrated time.

    Raises:
        ValueError: if ``renorm_every < 1`` or no segment contributed.

    Notes:
        - The initial direction is random and unseeded, so repeated calls give
          slightly different estimates unless the torch RNG is seeded.
        - NOTE(review): only the CLS slice of the tangent is measured and
          rescaled; components that leak into other tokens are carried over
          unscaled. Confirm this is the intended definition.
    """
    if renorm_every < 1:
        raise ValueError("renorm_every must be >= 1")
    if device is None:
        device = x0.device
    x = x0.detach().to(device)
    if x.shape[0] > 1:
        x = x[0:1].clone()          # work with the first batch element

    # Tangent vector: zero everywhere except a random unit vector on the CLS token (index 0).
    v = torch.zeros_like(x, device=device)
    cls_pert = torch.randn_like(x[0, 0:1, :], device=device)  # shape [1, D]
    cls_pert = cls_pert / (cls_pert.view(-1).norm() + 1e-12)
    v[0, 0:1, :] = cls_pert

    aug = AugmentedODE(func).to(device)
    total_log = 0.0
    total_time = 0.0

    # Integrate in segments of `renorm_every` grid steps.
    for i in range(0, len(t_span) - 1, renorm_every):
        t_segment = t_span[i : i + renorm_every + 1].to(device)
        x_sol, v_sol = odeint(aug, (x, v), t_segment, method=method)
        x = x_sol[-1].detach()
        # Own copy, so the in-place edits below never touch the solver output.
        v = v_sol[-1].detach().clone()

        norm_cls = v[0, 0, :].norm().item()
        if norm_cls < 1e-12:
            # Numerical collapse (rare): restart from a fresh random unit
            # direction. The segment is discarded; logging the norm of the
            # freshly drawn vector would add a spurious ~log(sqrt(D)) term.
            fresh = torch.randn_like(v[0, 0, :])
            v[0, 0, :] = fresh / (fresh.norm() + 1e-12)
            continue

        total_log += math.log(norm_cls)
        total_time += t_segment[-1].item() - t_segment[0].item()

        # Renormalise only the CLS components to unit norm.
        v[0, 0, :] = v[0, 0, :] / (norm_cls + 1e-12)

        if verbose:
            est = total_log / max(total_time, 1e-12)
            print(f"[t={t_segment[-1].item():.4f}] CLS largest λ ≈ {est:.6f}")

    if total_time == 0:
        raise ValueError("No renormalizations performed: increase len(t_span) or decrease renorm_every")
    return total_log / total_time


In [9]:
# Token embeddings of the sample image: the initial state for the ODE.
# The printed shape is [1, 207, 768]; presumably 207 = 1 CLS + 10 register
# + 196 patch tokens (224/16 = 14 per side) — TODO confirm against patch_embed.
# This call runs with autograd enabled; the Lyapunov routine detaches its input.
tokens = edo_model.patch_embed(inputs["pixel_values"])   # shape [1, 1+N, D]
print(tokens.shape)

torch.Size([1, 207, 768])


In [10]:
# Vector field of the student's ODE; it is passed to the Lyapunov routine,
# which calls it as func(t, x).
func = edo_model.odefunc

In [21]:
# Define time span: 24 grid points on [0, 1], i.e. 23 integration intervals.
# NOTE(review): the model is configured with num_eval_steps=24 and an Euler
# solver, while the Lyapunov routine integrates with RK4 by default — confirm
# the grid and solver are meant to differ from the model's own forward pass.
t_span = torch.linspace(0, 1.0, 24).to(device)

# renorm_every=1 renormalises the tangent after every grid interval.
lyap = largest_cls_lyapunov_odeint(func, tokens, t_span, renorm_every=1, verbose=False)
# The printed value is the estimate divided by emulate_depth (12.0), presumably
# to express it per emulated layer — verify the intended normalisation.
print("Largest Lyapunov exponent ≈", lyap/edo_model.emulate_depth)

Largest Lyapunov exponent ≈ 0.061828176344969794


## Compute the per-class Lyapunov exponent

In [36]:
from transformers import ViTForImageClassification
# Teacher classifier, loaded with the eager attention implementation and put
# in eval mode.
# NOTE(review): despite the ".pt" suffix the path is handed to from_pretrained —
# confirm it points to a saved-model directory.
teacher_model_checkpoint = "../checkpoints/Vit_CIFAR10_DINO.pt" # path relative to the notebook directory
model = ViTForImageClassification.from_pretrained(teacher_model_checkpoint, attn_implementation='eager').to(device)
# As the last expression of the cell, this also displays the module structure.
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [39]:
from sklearn.metrics import confusion_matrix
from collections import defaultdict
import numpy as np

import tqdm
# Evaluate the student on the held-out loader and, for every image, record the
# teacher prediction, the student prediction, the label and a Lyapunov estimate.
feats = []                # never appended to in this cell
acc_edo = []              # per-batch accuracy of the student (batch size is 1)
all_labels= []            # ground-truth labels, in loader order
all_preds_edo = []        # student predictions, in loader order
# Same time grid as the single-sample estimate: 24 points on [0, 1].
t_span = torch.linspace(0, 1.0, 24).to(device)

# store preds and gts
# Maps label -> list of [teacher_pred, edo_pred, label, lyap / emulate_depth].
labels_prediction_lv = defaultdict(list)
func = edo_model.odefunc

for batch_idx, data in tqdm.tqdm(
    enumerate(test_dloader),
    desc="Evaluating EDO",
    total=len(test_dloader),
):
    inputs: dict = data["pixel_values"].to(device)
    labels: torch.Tensor = data["labels"].to(device)

    with torch.no_grad():
        # --- student/edo ---
        output_edo = edo_model(**inputs, labels=labels, output_hidden_states=True)
        outputs_teacher = model(**inputs)
        # Initial ODE state for this image, then its Lyapunov estimate.
        # NOTE(review): the initial tangent direction is random and no seed is
        # set in this notebook, so the estimates vary between runs.
        tokens = edo_model.patch_embed(inputs["pixel_values"])
        lyap = largest_cls_lyapunov_odeint(func, tokens, t_span, renorm_every=1, verbose=False)

        # .item() requires a single-element tensor, i.e. batch_size == 1.
        lab = labels.cpu().item()
        
        states = output_edo["states"]   # not used further in this cell
        logits_edo = output_edo["logits"]
        pred_edo = logits_edo.argmax(-1)
        pred_teacher = outputs_teacher.logits.argmax(-1)
        
        
        acc_edo.append((pred_edo == labels).float().mean().cpu().item())

        # Record layout: [teacher_pred, edo_pred, label, lyap / emulate_depth].
        labels_prediction_lv[lab].append([pred_teacher.item(), pred_edo.item(), lab, lyap/edo_model.emulate_depth])

        # store preds/labels for confusion matrix
        all_labels.extend(labels.cpu().numpy().tolist())
        all_preds_edo.extend(pred_edo.cpu().numpy().tolist())

# Mean of the per-batch accuracies; with one image per batch this is the
# overall accuracy.
print(f"Accuracy EDO: {np.mean(acc_edo):.4f}")

Evaluating EDO: 100%|██████████| 10000/10000 [39:13<00:00,  4.25it/s] 

Accuracy EDO: 0.8721





In [40]:
import pickle as pkl
# Persist the per-class records next to the notebook so the evaluation loop
# (about 39 minutes in the recorded run) does not have to be repeated.
with open("./EDO_DISTILLATION_VITDINO_ON_CIFAR10_ATTENTION_DISTILL_FULL_PATH_MSE_PE_REGISTERS_LV.pkl", "wb") as file:
    pkl.dump(labels_prediction_lv, file)

In [42]:
# Records for label 0, sorted ascending by their last field (the scaled
# Lyapunov estimate). NOTE(review): this displays every record of the class;
# consider slicing to keep the output short.
sorted(labels_prediction_lv[0], key=lambda x: x[-1])

[[0, 0, 0, 0.042180686197459384],
 [0, 0, 0, 0.04771771527378136],
 [0, 9, 0, 0.0483287667017902],
 [1, 1, 0, 0.04940662838718673],
 [0, 1, 0, 0.04945369511149075],
 [0, 0, 0, 0.05064372385793847],
 [0, 0, 0, 0.051833827244021635],
 [0, 0, 0, 0.052799664822544],
 [0, 0, 0, 0.05294809488832786],
 [0, 0, 0, 0.05312856920948816],
 [8, 0, 0, 0.053545356503685894],
 [0, 0, 0, 0.05374164024256445],
 [0, 0, 0, 0.053782121878063964],
 [0, 7, 0, 0.053863622080684785],
 [0, 0, 0, 0.0540569900155659],
 [0, 0, 0, 0.054210988026674324],
 [0, 7, 0, 0.0543043411545392],
 [0, 0, 0, 0.0547518500376186],
 [0, 9, 0, 0.055044055784478484],
 [0, 8, 0, 0.055136905057493434],
 [0, 0, 0, 0.05514234599051659],
 [0, 0, 0, 0.05533187781100776],
 [0, 0, 0, 0.055381730620906074],
 [0, 0, 0, 0.0555796827934541],
 [0, 5, 0, 0.055728759216726785],
 [0, 0, 0, 0.055762725028624734],
 [0, 0, 0, 0.05581814145076738],
 [0, 0, 0, 0.05604490422999266],
 [3, 3, 0, 0.05609897571249398],
 [0, 0, 0, 0.056413981812920205],
 [0, 

In [44]:
# Reload the records written by the dump cell above. pickle.load executes
# arbitrary code from the file, so only use it on files produced by this
# notebook.
with open("./EDO_DISTILLATION_VITDINO_ON_CIFAR10_ATTENTION_DISTILL_FULL_PATH_MSE_PE_REGISTERS_LV.pkl", "rb") as file:
    ob = pkl.load(file)

In [45]:
# Same view as before the round trip: label-0 records from the reloaded
# object, sorted ascending by their last field.
sorted(ob[0], key=lambda x: x[-1])

[[0, 0, 0, 0.042180686197459384],
 [0, 0, 0, 0.04771771527378136],
 [0, 9, 0, 0.0483287667017902],
 [1, 1, 0, 0.04940662838718673],
 [0, 1, 0, 0.04945369511149075],
 [0, 0, 0, 0.05064372385793847],
 [0, 0, 0, 0.051833827244021635],
 [0, 0, 0, 0.052799664822544],
 [0, 0, 0, 0.05294809488832786],
 [0, 0, 0, 0.05312856920948816],
 [8, 0, 0, 0.053545356503685894],
 [0, 0, 0, 0.05374164024256445],
 [0, 0, 0, 0.053782121878063964],
 [0, 7, 0, 0.053863622080684785],
 [0, 0, 0, 0.0540569900155659],
 [0, 0, 0, 0.054210988026674324],
 [0, 7, 0, 0.0543043411545392],
 [0, 0, 0, 0.0547518500376186],
 [0, 9, 0, 0.055044055784478484],
 [0, 8, 0, 0.055136905057493434],
 [0, 0, 0, 0.05514234599051659],
 [0, 0, 0, 0.05533187781100776],
 [0, 0, 0, 0.055381730620906074],
 [0, 0, 0, 0.0555796827934541],
 [0, 5, 0, 0.055728759216726785],
 [0, 0, 0, 0.055762725028624734],
 [0, 0, 0, 0.05581814145076738],
 [0, 0, 0, 0.05604490422999266],
 [3, 3, 0, 0.05609897571249398],
 [0, 0, 0, 0.056413981812920205],
 [0, 