## Imports

In [1]:
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from model_merging.data.dataset import HFImageClassification
from model_merging.model.image_classifier import ImageClassifier
import open_clip
import wandb

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import model_merging  # noqa
from model_merging.model.encoder import ClassificationHead, ImageEncoder
from model_merging.model.heads import (
    get_classification_head,
)
from model_merging.utils.io_utils import (
    boilerplate,
    load_model_from_hf,
)
from model_merging.utils.plots import plot_interactive_radar_chart
from model_merging.utils.utils import (
    build_callbacks,
    get_finetuning_accuracies,
    compute_avg_accuracy,
    print_memory,
)
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


  import pkg_resources
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [2]:
from hydra import compose, initialize

# Import GlobalHydra explicitly instead of reaching it through the
# `hydra.core.global_hydra` attribute chain, which only resolves when that
# submodule happens to have been loaded already.
from hydra.core.global_hydra import GlobalHydra

# Clear Hydra's global instance first so this cell can be re-run in the same
# kernel without `initialize` failing on an already-initialized Hydra.
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="layer_analysis")

# Compose the "multitask" config with the N2 benchmark selected.
cfg = compose(config_name="multitask", overrides=["benchmark=N2"])

'hydra/launcher/basic' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  coro.send(None)


## Boilerplate

In [3]:
# Reload imported modules automatically before each cell execution, so edits
# to the project's source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Logger used by the later cells for their progress / result messages.
pylogger = logging.getLogger(__name__)

In [5]:
# Project helper called with the full config before anything else runs
# (presumably seeds the RNGs from `cfg` -- confirm in nn_core).
seed_index_everything(cfg)

# Project helper returning the experiment logger and the template-core object
# that the evaluation cell passes to `build_callbacks`.
logger, template_core = boilerplate(cfg)

num_tasks = len(cfg.benchmark.datasets)

# `cfg` is in struct mode, so adding a new key is rejected by default.
# `open_dict` lifts the restriction only inside the `with` block and restores
# the previous flag on exit -- including when the assignment raises, which the
# manual set_struct(False) / set_struct(True) pair did not guarantee.
with omegaconf.open_dict(cfg):
    cfg.num_tasks = num_tasks

# Upper-bound accuracies, used for logging the normalized accuracy. The file
# is indexed by encoder name; the selected entry is then indexed by dataset.
all_finetuned_accuracies = get_finetuning_accuracies(cfg.misc.finetuned_accuracy_path)
finetuned_accuracies: Dict[str, float] = all_finetuned_accuracies[
    cfg.nn.encoder.model_name
]

[34m[1mwandb[0m: Currently logged in as: [33mcrisostomi[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Load models

In [6]:
# Base checkpoint (no dataset name given): it only has the vision encoder,
# there is no text transformer.
base_model: ImageEncoder = load_model_from_hf(model_name=cfg.nn.encoder.model_name)

# One checkpoint per benchmark dataset, keyed by the dataset's name.
finetuned_models = {
    task_cfg.name: load_model_from_hf(
        model_name=cfg.nn.encoder.model_name,
        dataset_name=task_cfg.name,
    )
    for task_cfg in cfg.benchmark.datasets
}

pylogger.info(f"Number of tasks: {cfg.num_tasks}")
pylogger.info(f"Finetuned models: {list(finetuned_models.keys())}")

# Parameter dictionaries of the fine-tuned checkpoints, under the same keys.
finetuned_state_dicts = {
    task_name: task_model.state_dict()
    for task_name, task_model in finetuned_models.items()
}


## Training representation task vectors

### Get base representations

In [7]:
from tqdm import tqdm


# Output of the base encoder on every training batch, concatenated along
# dim 0 and keyed by dataset name.
base_representations: Dict[str, torch.Tensor] = {}

# Loop-invariant: switch the base encoder to eval mode and move it to the GPU
# once, instead of once per dataset.
base_model.eval().cuda()

# NOTE(review): the fitting cell further down pairs row i of these features
# with row i of the fine-tuned features. That is only valid if `train_loader`
# yields samples in the same fixed (unshuffled) order on every pass -- confirm
# against the dataset config.
for dataset_cfg in cfg.benchmark.datasets:

    dataset = instantiate(
        dataset_cfg, preprocess_fn=base_model.val_preprocess
    )

    all_representations: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in tqdm(dataset.train_loader):
            # Labels are not needed for feature extraction, so they are no
            # longer copied to the GPU.
            images, _labels = batch
            images = images.cuda()

            representations = base_model(images)
            # Keep only a CPU copy so GPU memory does not grow with the
            # number of batches.
            all_representations.append(representations.cpu())

    all_representations_tensor = torch.cat(all_representations, dim=0)

    base_representations[dataset_cfg.name] = all_representations_tensor

100%|██████████| 148/148 [00:14<00:00, 10.41it/s]
100%|██████████| 64/64 [00:08<00:00,  7.17it/s]


### Get finetuned representations

In [8]:
from tqdm import tqdm

# Output of each dataset's fine-tuned encoder on that dataset's training
# batches, concatenated along dim 0 and keyed by dataset name.
ft_representations: Dict[str, torch.Tensor] = {}

for dataset_cfg in cfg.benchmark.datasets:

    # Same preprocessing as for the base representations (the base model's
    # validation transform), so both feature sets see identical inputs.
    dataset = instantiate(
        dataset_cfg, preprocess_fn=base_model.val_preprocess
    )

    all_representations: List[torch.Tensor] = []

    # Put the fine-tuned encoder in eval mode, exactly like the base encoder in
    # the previous cell; without it any train-mode-dependent layer would make
    # the two feature sets incomparable.
    ft_model = finetuned_models[dataset_cfg.name].eval().cuda()

    with torch.no_grad():
        for batch in tqdm(dataset.train_loader):
            # Labels are not needed for feature extraction, so they are no
            # longer copied to the GPU.
            images, _labels = batch
            images = images.cuda()

            representations = ft_model(images)
            # Keep only a CPU copy so GPU memory does not grow with the
            # number of batches.
            all_representations.append(representations.cpu())

    # Release the GPU copy of this model before moving to the next task;
    # otherwise every fine-tuned model stays resident on the GPU and memory
    # use grows with the number of tasks in the benchmark.
    ft_model.cpu()

    all_representations_tensor = torch.cat(all_representations, dim=0)

    ft_representations[dataset_cfg.name] = all_representations_tensor

100%|██████████| 148/148 [00:13<00:00, 10.85it/s]
100%|██████████| 64/64 [00:09<00:00,  7.06it/s]


### Fit the transformation

In [9]:
# Per-dataset maps from base features to fine-tuned features:
#   * task_transformations            -- unconstrained least-squares fit
#   * orthogonal_task_transformations -- fit restricted to an orthogonal matrix
task_transformations: Dict[str, torch.nn.Module] = {}
orthogonal_task_transformations: Dict[str, torch.nn.Module] = {}

for dataset_name, base_repr in base_representations.items():
    ft_repr = ft_representations[dataset_name]

    # Rows are paired by index below, so both matrices must have the same shape.
    if not (base_repr.shape == ft_repr.shape):
        raise ValueError(
            f"Base and finetuned representations must have matching shapes, got {base_repr.shape} and {ft_repr.shape} for {dataset_name}."
        )

    # Solve in float64, then cast the result back to the features' dtype.
    base_double = base_repr.to(torch.float64)
    ft_double = ft_repr.to(torch.float64)

    # Least squares: the matrix X minimising ||base @ X - ft||.
    lstsq_output = torch.linalg.lstsq(base_double, ft_double)
    solution = lstsq_output.solution.to(base_repr.dtype)

    # nn.Linear computes x @ weight.T, hence the transpose; the bias is zeroed
    # so the layer applies exactly the fitted matrix.
    linear = torch.nn.Linear(base_repr.shape[1], ft_repr.shape[1])
    with torch.no_grad():
        linear.weight.copy_(solution.t())
        linear.bias.zero_()

    task_transformations[dataset_name] = linear

    # Orthogonal Procrustes: with base.T @ ft = U S Vh, the orthogonal matrix
    # closest to mapping base onto ft is U @ Vh.
    cross_covariance = base_double.t() @ ft_double
    U, _, Vh = torch.linalg.svd(cross_covariance, full_matrices=False)
    orthogonal_matrix = torch.matmul(U, Vh).to(base_repr.dtype)

    orthogonal_linear = torch.nn.Linear(
        base_repr.shape[1], ft_repr.shape[1], bias=False
    )
    with torch.no_grad():
        orthogonal_linear.weight.copy_(orthogonal_matrix.t())

    orthogonal_task_transformations[dataset_name] = orthogonal_linear


In [10]:
# Sanity check of the collected feature matrix. The 'Cars' key is hard-coded,
# so this cell raises KeyError if the selected benchmark has no such dataset.
ft_representations['Cars'].shape

torch.Size([8144, 512])

In [11]:
class TransformedImageClassifier(ImageClassifier):
    """Image classifier that applies an extra module to the encoder output.

    The forward pass is ``classification_head(transform(encoder(x)))``.
    ``self.encoder`` and ``self.classification_head`` are not assigned here;
    they are presumably set by ``ImageClassifier.__init__`` from the
    ``encoder`` and ``classifier`` arguments -- confirm in the parent class.

    Args:
        encoder: Image encoder, forwarded to the parent constructor.
        classifier: Classification head, forwarded to the parent constructor.
        transform: Module applied to the embeddings before the head.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        encoder: ImageEncoder,
        classifier: ClassificationHead,
        transform: torch.nn.Module,
        **kwargs,
    ):
        super().__init__(encoder, classifier, **kwargs)
        # Assigned after the parent constructor has run.
        self.transform = transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits for the input batch ``x``."""
        features = self.encoder(x)
        mapped_features = self.transform(features)
        return self.classification_head(mapped_features)

### Evaluate

In [12]:
def evaluate_transformations(
    transformations: Dict[str, torch.nn.Module], label: str
) -> Dict[str, Any]:
    """Test the base encoder plus a per-dataset transform on every benchmark dataset.

    For each dataset in ``cfg.benchmark.datasets`` this builds a
    ``TransformedImageClassifier`` from the shared ``base_model``, the
    dataset's classification head and ``transformations[<dataset name>]``, and
    runs ``trainer.test`` on the dataset's test loader.

    Reads the notebook-level names ``cfg``, ``base_model``,
    ``finetuned_accuracies``, ``template_core``, ``logger`` and ``pylogger``.

    Args:
        transformations: Transform module for each dataset, keyed by dataset name.
        label: Name of the transform family; used in log messages and as the
            prefix of the averaged metrics sent to the experiment logger.

    Returns:
        The ``trainer.test`` output for each dataset name, plus an ``"avg"``
        entry holding a one-element list with the ``compute_avg_accuracy``
        result.

    Raises:
        KeyError: If ``transformations`` has no entry for a benchmark dataset.
    """
    per_dataset_results = {}

    for dataset_cfg in cfg.benchmark.datasets:

        task_data = instantiate(dataset_cfg, preprocess_fn=base_model.val_preprocess)

        head = get_classification_head(
            cfg.nn.encoder.model_name,
            dataset_cfg.name,
            ckpt_path=cfg.misc.ckpt_path,
            openclip_cachedir=cfg.misc.openclip_cachedir,
            device=cfg.device,
        )

        classifier_model = TransformedImageClassifier(
            encoder=base_model,
            classifier=head,
            transform=transformations[dataset_cfg.name],
            x_key=cfg.conventions.x_key,
            y_key=cfg.conventions.y_key,
        )

        classifier_model.set_metrics(len(task_data.classnames))
        classifier_model.set_task(dataset_cfg.name)

        # The reference accuracy is stored under "<name>Val" when evaluating
        # on the training split, and under the plain dataset name otherwise.
        if cfg.eval_on_train:
            accuracy_key = dataset_cfg.name + "Val"
        else:
            accuracy_key = dataset_cfg.name
        classifier_model.set_finetuning_accuracy(finetuned_accuracies[accuracy_key])

        callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

        # A new Trainer is created for every dataset.
        trainer = pl.Trainer(
            default_root_dir=cfg.core.storage_dir,
            plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
            logger=logger,
            callbacks=callbacks,
            limit_test_batches=(
                cfg.number_of_train_batches if cfg.eval_on_train else None
            ),
            **cfg.train.trainer,
        )
        pylogger.info(f"Evaluating {label} transform on the {dataset_cfg.name} test set!")
        per_dataset_results[dataset_cfg.name] = trainer.test(
            model=classifier_model, dataloaders=task_data.test_loader
        )

    avg = compute_avg_accuracy(per_dataset_results)
    # Wrapped in a list for consistency with how Lightning reports test results.
    per_dataset_results["avg"] = [avg]

    # Prefix every averaged metric with the label so that different transform
    # families do not share keys in the experiment log.
    prefixed_avg = {}
    for metric, value in avg.items():
        prefixed_avg[f"{label}_{metric}"] = value
    logger.experiment.log(prefixed_avg)

    pylogger.info({label: per_dataset_results})

    return per_dataset_results

# Evaluate both transform families on every benchmark dataset: first the
# least-squares maps, then the orthogonal ones.
linear_results = evaluate_transformations(
    transformations=task_transformations, label="linear"
)
orthogonal_results = evaluate_transformations(
    transformations=orthogonal_task_transformations, label="orthogonal"
)


Loading classification head from ./models//ViT-B-32/head_RESISC45.pt


  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


INFO: You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Testing DataLoader 0: 100%|██████████| 50/50 [00:04<00:00, 11.36it/s]


INFO: GPU available: True (cuda), used: True


Loading classification head from ./models//ViT-B-32/head_Cars.pt


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


INFO: You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Testing DataLoader 0: 100%|██████████| 63/63 [00:08<00:00,  7.49it/s]


INFO: GPU available: True (cuda), used: True


Loading classification head from ./models//ViT-B-32/head_RESISC45.pt


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


INFO: You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Testing DataLoader 0: 100%|██████████| 50/50 [00:04<00:00, 11.53it/s]


Loading classification head from ./models//ViT-B-32/head_Cars.pt


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


INFO: You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Testing DataLoader 0: 100%|██████████| 63/63 [00:08<00:00,  7.35it/s]
