In [1]:
# ###################################################
# Change current directory to the root of the project.
# The root is resolved only once: two levels above the working directory at the
# time this cell first runs (presumably the notebook's own folder). It is then
# cached in PROJECT_ROOT, so re-running the cell does not keep climbing the
# directory tree.
import os
from pathlib import Path

if "PROJECT_ROOT" not in globals():
    # Directory the kernel was in before the first chdir.
    current_dir = Path.cwd()
    PROJECT_ROOT = current_dir.parents[1]
os.chdir(PROJECT_ROOT)
# ###################################################

In [2]:
import torch
from omegaconf.omegaconf import OmegaConf
from hyper_cl.models import get_model


### ResNet18LatentReplay: memory of the extracted latent features

In [3]:
# Measure how much memory the features returned by `model.extract_feat` occupy
# for one probe batch. This is presumably the per-batch cost of storing them for
# latent replay. `latent_depth` presumably selects which layer's output is
# returned (TODO confirm in hyper_cl.models).
N_SAMPLES = 200            # number of inputs in the probe batch
INPUT_SHAPE = (3, 32, 32)  # (channels, height, width) of a single input

config = {"n_classes": 100,
          "model": "ResNet18LatentReplay",
          "multi_head": True,
          "latent_depth": 5,
          "seed": 0
          }

config = OmegaConf.create(config)

model = get_model(config)

x = torch.randn(N_SAMPLES, *INPUT_SHAPE)
# Only the tensor's shape and size are needed, so skip autograd bookkeeping.
with torch.no_grad():
    feats = model.extract_feat(x)

# Bytes held by the feature tensor: element count times bytes per element.
mem_consumption = feats.numel() * feats.element_size()
# Divides by 1024**2, so the "MB" value printed below is strictly MiB.
mem_consumption_MB = mem_consumption / (1024 * 1024)

# Memory consumption
print(f"Memory consumption: {mem_consumption_MB:.2f} MB")
# Shape of the extracted features (activations, not model parameters)
print(f"Feature shape: {feats.shape}")

torch.Size([200, 160, 4, 4])
Memory consumption: 1.95 MB
Parameter shape: torch.Size([200, 160, 4, 4])


### HyperResNet18SH: weight-generator (hypernetwork) parameter memory

In [5]:
# Estimate the memory footprint of the model's weight generator (the "HN" in the
# printout, presumably the hypernetwork) from its parameter count.
# Alternative values for "model" that were kept commented out in this config:
#   HyperResNet18SPv1SH, HyperResNet18SPv2SH, HyperResNet18SPv3SH,
#   HyperResNet18SPv4SH
config = {"n_classes": 100,
          "model": "HyperResNet18SH",
          "model_params": {
              "embd_dim": 32,
              "hidden_size_1": 50,
              "hidden_size_2": 32,
              "head_emb_dim": 32
          },
          "bnch_params": {
              "n_experiences": 20,
              "return_task_id": True,
          },
          "seed": 0
          }

config = OmegaConf.create(config)

model = get_model(config)

# Assumes the weight-generator parameters are stored as float32 (4 bytes each).
# NOTE(review): confirm the dtype if mixed precision is ever used.
BYTES_PER_PARAM = torch.finfo(torch.float32).bits // 8
n_hn_params = model.weight_generator.n_params
# Divides by 1024**2, so the "MB" value printed below is strictly MiB.
mem_consumption_MB = (n_hn_params * BYTES_PER_PARAM) / (1024 * 1024)
print(f"Memory consumption: {mem_consumption_MB:.2f} MB")
print(f"Number of HN parameters: {n_hn_params}")

Memory consumption: 4.86 MB
Number of HN parameters: 1272877
