In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

Mount drive

In [None]:
# Mount Google Drive so the project folder under /content/drive is reachable.
# The import is guarded because `google.colab` only exists inside Colab; when the
# notebook runs elsewhere the mount is skipped instead of aborting "Run All".
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

Installing needed libraries

In [None]:
!pip install --upgrade datasets transformers
!pip install --upgrade pytorch-lightning
!pip install latentis

Imports

In [None]:
import os
import shutil
import torch
from torch import nn, optim
from pytorch_lightning import seed_everything
from transformers import (
    AutoModel,
    AutoConfig,
    AutoImageProcessor,
    CLIPVisionConfig,
    CLIPImageProcessor,
    CLIPVisionModel,
)
from datasets import DatasetDict, load_dataset, load_from_disk, DownloadConfig, VerificationMode
from tqdm import tqdm
from torch.utils.data import DataLoader
import functools
from pathlib import Path
import pandas as pd
from typing import List
import itertools
import sys

from dictionaries import DATASET2IMAGE_COLUMN, DATASET2LABEL_COLUMN, DATASET_NAME2HF_NAME, MODEL2CONFIGS, DATASET2NUM_CLASSES
from utils import image_encode, extract_representations
from module import SkipModel, HFwrapper, NoEncoder
from train_NN import train_classifier

In [None]:
# Change the project path if not running in Google Colab from Google Drive / if needed
project_path = '/content/drive/MyDrive/tba-moss'

# Register the project folder on the module search path exactly once, so that
# re-running this cell does not keep stacking duplicate entries onto sys.path.
if project_path not in sys.path:
    sys.path.append(project_path)

# NOTE(review): the imports cell above already imports the local modules
# (dictionaries, utils, module, train_NN) before this path is registered; on a
# fresh kernel this cell probably has to run before those imports - confirm.

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Encode data

In [None]:
@torch.no_grad()
def encode_data(loader, skip_encoder):
    """Run ``skip_encoder`` over every batch of ``loader`` and gather its outputs.

    Args:
        loader: Iterable of dict-like batches. Each batch must provide the image
            tensor under ``"pixel_values"`` (preferred) or ``"images"``; an
            optional ``"attention_mask"`` entry is forwarded to the model.
        skip_encoder: Callable model, switched to eval mode here, invoked as
            ``skip_encoder(images, attention_mask=mask)``.

    Returns:
        A flat Python list holding, in loader order, the rows of each batch
        output after it was moved to the CPU and converted with ``tolist()``.

    Raises:
        KeyError: If a batch carries neither ``"pixel_values"`` nor ``"images"``.
    """
    skip_encoder.eval()
    collected = []

    for batch in tqdm(loader, desc="Encoding Batches with SkipModel"):
        # "pixel_values" wins; "images" is only the fallback key.
        fallback = batch.get("images")
        pixels = batch.get("pixel_values", fallback)
        if pixels is None:
            raise KeyError("Batch missing required key 'pixel_values' or 'images'")
        pixels = pixels.to(device)

        # The mask is optional; only move it to the device when one is present.
        mask = batch.get("attention_mask", None)
        mask = mask.to(device) if mask is not None else None

        features = skip_encoder(pixels, attention_mask=mask)
        collected += features.cpu().tolist()

    return collected

def _save_dataset_atomically(data, target_dir):
    """Write ``data`` to ``target_dir`` by way of a sibling ``<name>_temp`` directory.

    The dataset is first saved next to the target and only then swapped in, so
    a failure while writing does not destroy the previously saved copy.
    Any exception is re-raised after cleanup.
    """
    temp_dir = target_dir.parent / f"{target_dir.name}_temp"
    if temp_dir.exists():
        # Left over from an earlier, interrupted save.
        shutil.rmtree(temp_dir)
    try:
        data.save_to_disk(str(temp_dir))
        if target_dir.exists():
            shutil.rmtree(target_dir)
        shutil.move(str(temp_dir), str(target_dir))
    except Exception:
        # Discard the partial temp copy only while the previous save is still in
        # place; if the target is already gone the temp copy is all that is left.
        if temp_dir.exists() and target_dir.exists():
            shutil.rmtree(temp_dir)
        raise


@torch.no_grad()
def run_encoding(
    dataset_name: str,
    encoder_name: str,
    translator_name: str,
    seed: int,
    samples_to_extract: int,
    batch_size: int,
    skips: list = None,
    mode: int = 1,
    num_workers: int = 8,
):
    """Encode a dataset with one ``SkipModel`` per skip configuration and save the result.

    For every entry of ``skips`` the train and test splits are encoded and the
    embeddings are stored as a new dataset column named ``str(skip)``. After each
    skip the whole ``DatasetDict`` is written to
    ``<project_path>/embeddings/<dataset_name>/<last segment of encoder_name>``.

    Args:
        dataset_name: Key into ``DATASET_NAME2HF_NAME``, ``DATASET2IMAGE_COLUMN``
            and ``DATASET2LABEL_COLUMN``.
        encoder_name: Hugging Face model id; must be a key of ``MODEL2CONFIGS``.
        translator_name: Passed to ``SkipModel`` as ``translator_factory_name``.
        seed: Seed given to ``seed_everything`` and ``extract_representations``.
        samples_to_extract: ``max_samples`` for ``extract_representations``.
        batch_size: Batch size of both data loaders.
        skips: List of skip configurations; ``None`` means ``[[], [(10, 11)]]``.
        mode: Forwarded unchanged to ``SkipModel``.
        num_workers: Worker processes used by both data loaders.

    Raises:
        ValueError: If ``encoder_name`` has no entry in ``MODEL2CONFIGS``.
    """
    # Build the default per call instead of sharing one mutable default list.
    if skips is None:
        skips = [[], [(10, 11)]]

    seed_everything(seed)

    if encoder_name not in MODEL2CONFIGS:
        raise ValueError(f"Model configuration not found for {encoder_name}. Please add it to MODEL2CONFIGS.")

    model_config = MODEL2CONFIGS[encoder_name]

    print(f"Dataset: {dataset_name}, Encoder: {encoder_name}, Translator: {translator_name}, Skips: {skips}")

    # Use the last path segment, matching how `classification` locates this folder;
    # it also avoids an IndexError for encoder names that contain no "/".
    dataset_dir = Path(project_path) / "embeddings" / dataset_name / encoder_name.split("/")[-1]
    dataset_dir.mkdir(parents=True, exist_ok=True)

    hf_name = DATASET_NAME2HF_NAME[dataset_name]
    data = DatasetDict(
        train=load_dataset(hf_name, split="train"),
        test=load_dataset(hf_name, split="test"),
    )

    print(f"Loading HF AutoModel: {encoder_name}")
    config = AutoConfig.from_pretrained(encoder_name, output_hidden_states=True, return_dict=True)
    processor = AutoImageProcessor.from_pretrained(encoder_name)
    encoder = AutoModel.from_pretrained(encoder_name, config=config)
    collate_fn = functools.partial(
        image_encode,
        processor=processor,
        image_name=DATASET2IMAGE_COLUMN[dataset_name],
        label_name=DATASET2LABEL_COLUMN[dataset_name],
    )

    encoder.eval().to(device)

    # Both loaders share every setting; only the split differs.
    make_loader = functools.partial(
        DataLoader,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    # The loaders keep referencing the original splits, so columns added to
    # `data` below are never fed back into the encoder.
    train_loader = make_loader(data["train"])
    test_loader = make_loader(data["test"])

    all_layer_embeddings = extract_representations(
        encoder=encoder,
        max_samples=samples_to_extract,
        loader=train_loader,
        model_config=model_config,
        model_is_open_clip=encoder_name.startswith("open_clip:"),
        seed=seed,
    )
    print(f"Captured embeddings for layers: {list(all_layer_embeddings.keys())}")

    for skip in tqdm(skips, desc="Encoding Different Skips"):
        print(f"\nProcessing skip: {skip}")

        skip_encoder = SkipModel(
            encoder=encoder,
            skips=skip,
            mode=mode,
            precomputed_embeddings=all_layer_embeddings,
            translator_factory_name=translator_name,
            **model_config,
        )
        skip_encoder = skip_encoder.to(device).eval()

        split2encoding = {
            "train": encode_data(loader=train_loader, skip_encoder=skip_encoder),
            "test": encode_data(loader=test_loader, skip_encoder=skip_encoder),
        }

        print("Saving results to disk...")
        for split, encoding in split2encoding.items():
            if not encoding:
                print(f"Warning: No embeddings generated for split '{split}', skip '{skip}'. Skipping saving.")
                continue
            # Validate the length before choosing a column name, so the "_new"
            # fallback below is covered by the same check.
            if len(encoding) != len(data[split]):
                print(
                    f"Error: Encoding length ({len(encoding)}) does not match dataset length ({len(data[split])}) for split '{split}', skip '{skip}'."
                )
                continue
            column_name = str(skip)
            if column_name in data[split].column_names:
                final_column_name = f"{column_name}_new"
                print(f"Column '{column_name}' already exists. Saving them with a new name: {final_column_name}")
                column_name = final_column_name
            data[split] = data[split].add_column(column_name, encoding)

        del skip_encoder
        torch.cuda.empty_cache()

        # Persist after every skip so a later failure does not lose finished work.
        # A failed save is reported and the loop carries on with the next skip.
        try:
            _save_dataset_atomically(data, dataset_dir)
            print(f"Saved intermediate results for skip {skip} to {dataset_dir}")
        except Exception as e:
            print(f"Error saving intermediate results: {e}")

In [None]:
# Encode CIFAR-10 with DINOv2-small for two skip configurations.
run_encoding(
    dataset_name="cifar10",
    encoder_name="facebook/dinov2-small",
    translator_name="linear",
    seed=0,
    samples_to_extract=500,
    batch_size=256,
    skips=[[], [(10, 11)]],  # [] is the original model
)

## Classification

In [None]:
def _load_results_table(results_path, columns):
    """Return the accumulated results as a DataFrame, or an empty one with ``columns``."""
    if not results_path.exists():
        results_path.parent.mkdir(parents=True, exist_ok=True)
        return pd.DataFrame(columns=columns)
    try:
        return pd.read_csv(results_path)
    except pd.errors.EmptyDataError:
        print(f"Results file {results_path} is empty. Initializing DataFrame.")
    except Exception as e:
        # NOTE(review): the caller later rewrites the file, so an unreadable
        # results file ends up replaced by the new row only.
        print(f"Error reading results file {results_path}: {e}. Initializing DataFrame.")
    return pd.DataFrame(columns=columns)


def _as_classifier_split(split_dataset, embedding_col_name, label_col_name):
    """Keep only the embedding and label columns, renamed to ``images`` / ``labels``."""
    subset = split_dataset.select_columns([embedding_col_name, label_col_name])
    # Skip no-op renames so a column that already has the target name is left alone.
    if embedding_col_name != "images":
        subset = subset.rename_column(embedding_col_name, "images")
    if label_col_name != "labels":
        subset = subset.rename_column(label_col_name, "labels")
    return subset


def classification(
    dataset_name: str,
    model_name: str,
    layers_to_approximate: List,
    seed: int,
    batch_size: int,
    lr: float = 0.01,
    num_epochs: int = 5,
    num_workers: int = 2,
):
    """Train a linear classifier on stored embeddings and append the score to ``results.csv``.

    The embeddings are read from
    ``<project_path>/embeddings/<dataset_name>/<last segment of model_name>``, using
    the column named ``str(layers_to_approximate)``. One row (seed, dataset, model,
    approx_layer, accuracy, delta_acc) is appended to ``<project_path>/results.csv``.
    ``delta_acc`` is the baseline accuracy (the ``[]`` configuration for the same
    dataset, model and seed) minus this run's accuracy, or NaN when no baseline
    row exists yet.

    Args:
        dataset_name: Key into ``DATASET2LABEL_COLUMN`` and ``DATASET2NUM_CLASSES``.
        model_name: Model id; must be a key of ``MODEL2CONFIGS``.
        layers_to_approximate: Skip configuration; ``[]`` denotes the baseline.
        seed: Seed passed to ``seed_everything`` and recorded in the results.
        batch_size: Batch size of the train and test loaders.
        lr: Learning rate of the Adam optimizer.
        num_epochs: Number of epochs passed to ``train_classifier``.
        num_workers: Worker processes used by both data loaders.

    Raises:
        FileNotFoundError: If the embeddings directory does not exist.
        ValueError: If ``model_name`` has no entry in ``MODEL2CONFIGS``.
        KeyError: If the embedding column is missing from the train or test split.
    """
    seed_everything(seed)

    model_name_slug = model_name.split("/")[-1]
    embeddings_dir = str(Path(project_path) / "embeddings" / dataset_name / model_name_slug)

    print(f"Loading embeddings from: {embeddings_dir}")

    if not os.path.exists(embeddings_dir):
        raise FileNotFoundError(f"Embeddings not found: {embeddings_dir}.")
    embeddings = DatasetDict.load_from_disk(embeddings_dir)
    embeddings.set_format("torch")

    if model_name not in MODEL2CONFIGS:
        raise ValueError(f"Model configuration not found for '{model_name}' in MODEL2CONFIGS.")

    print(f'Approximating {layers_to_approximate}')
    # The string form of the configuration is both the embedding column name and
    # the value stored in (and later matched against) the "approx_layer" column.
    embedding_col_name = str(layers_to_approximate)

    if (embedding_col_name not in embeddings["train"].column_names) or (
        embedding_col_name not in embeddings["test"].column_names
    ):
        raise KeyError(f"Skip '{embedding_col_name}' not found in loaded embeddings.")

    label_col_name = DATASET2LABEL_COLUMN[dataset_name]
    hf_train_embeddings = _as_classifier_split(embeddings["train"], embedding_col_name, label_col_name)
    hf_test_embeddings = _as_classifier_split(embeddings["test"], embedding_col_name, label_col_name)

    num_classes = DATASET2NUM_CLASSES[dataset_name]

    train_dataloader = DataLoader(
        hf_train_embeddings, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True
    )
    test_dataloader = DataLoader(
        hf_test_embeddings, shuffle=False, batch_size=batch_size, num_workers=num_workers, pin_memory=True
    )

    # The classifier input width is read off the last dimension of one stored embedding.
    sample_embedding = embeddings["train"][0][embedding_col_name]
    hidden_size = sample_embedding.shape[-1]

    classifier = nn.Linear(hidden_size, num_classes)
    no_encoder = NoEncoder(embeddings=None)
    skip_model = HFwrapper(encoder=no_encoder, classifier=classifier)
    skip_model.to(device)
    skip_model.freeze_encoder()

    optimizer = optim.Adam(skip_model.parameters(), lr=lr)

    print("Starting classifier training...")
    _, _, _, eval_accuracies, _ = train_classifier(
        model=skip_model,
        train_data_loader=train_dataloader,
        test_data_loader=test_dataloader,
        optimizer=optimizer,
        criterion=nn.CrossEntropyLoss(),
        label_column_name="labels",
        num_epochs=num_epochs,
    )
    # Coerce to a plain float so the CSV holds a number whatever scalar type
    # train_classifier reports.
    accuracy = float(eval_accuracies[-1])
    print(f"Training finished. Final accuracy: {accuracy:.4f}")

    columns = [
        "seed",
        "dataset",
        "model",
        "approx_layer",
        "accuracy",
        "delta_acc",
    ]

    results_path = Path(project_path) / "results.csv"
    results_df = _load_results_table(results_path, columns)

    baseline_key = str([])
    if embedding_col_name == baseline_key:
        original_accuracy = accuracy
    else:
        baseline_rows = results_df[
            (results_df["approx_layer"].astype(str) == baseline_key)
            & (results_df["dataset"] == dataset_name)
            & (results_df["model"] == model_name)
            & (results_df["seed"] == seed)
        ]
        # The file is append-only, so the last match is the most recent baseline.
        original_accuracy = float(baseline_rows["accuracy"].iloc[-1]) if not baseline_rows.empty else None

    if original_accuracy is None:
        # Record NaN rather than 0.0: a missing baseline is not "no degradation".
        print(
            f"Warning: no baseline ('{baseline_key}') result found for dataset '{dataset_name}', "
            f"model '{model_name}', seed {seed}; delta_acc is recorded as NaN."
        )
        delta_acc = float("nan")
    else:
        delta_acc = original_accuracy - accuracy

    new_row = pd.DataFrame(
        [
            {
                "seed": seed,
                "dataset": dataset_name,
                "model": model_name,
                "approx_layer": embedding_col_name,
                "accuracy": accuracy,
                "delta_acc": delta_acc,
            }
        ],
        columns=columns,
    )

    # Concatenating with an empty frame is deprecated in recent pandas versions.
    if results_df.empty:
        results_df = new_row
    else:
        results_df = pd.concat([results_df, new_row], ignore_index=True)
    results_df.to_csv(results_path, index=False)

In [None]:
# Important: these layers have to be the same as in the encoding part.
approximations = [[], [(10, 11)]]
seeds = [0, 1, 2]

# Every approximation is evaluated with every seed; the baseline [] comes first.
for approximation_config in approximations:
    for seed_value in seeds:
        classification(
            dataset_name="cifar10",
            model_name="facebook/dinov2-small",
            layers_to_approximate=approximation_config,
            seed=seed_value,
            batch_size=256,
        )

# Summarise the accumulated results: mean and std over seeds per configuration.
results = pd.read_csv(Path(project_path) / 'results.csv')
(
    results
    .drop(columns=['seed'])
    .groupby(["model", "dataset", "approx_layer"])
    .agg(["mean", "std"])
    .round(3)
)