In [None]:
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import os
import random
import timm
from timm.utils import accuracy
import numpy as np
import copy
from utils.data_manager import DataManager

# --- Device ------------------------------------------------------------------
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so setting it has no effect if CUDA was already used in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Reproducibility ---------------------------------------------------------
# Seed torch, NumPy and Python's `random`, in that order.
seed = 1
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# --- Data --------------------------------------------------------------------
# Options forwarded to DataManager; 'increment' is passed as both init_cls and
# increment below.
args = {'sample_number': 200, 'increment': 100}

# The class-order seed is drawn from NumPy's global RNG, which was seeded just
# above, so it is the same on every Restart & Run All.
class_order_seed = np.random.randint(100)
data_manager = DataManager(
    dataset_name='imagenet1000',
    shuffle=True,
    seed=class_order_seed,
    init_cls=args['increment'],
    increment=args['increment'],
    args=args,
)

class_order = data_manager._class_order

# Training split for the classes listed in class_order (presumably all classes).
# NOTE(review): mode="train" presumably applies training-time augmentation; if
# these features are meant as clean per-class statistics, a deterministic
# ("test") transform may be intended — confirm against DataManager.get_dataset.
train_dataset = data_manager.get_dataset(class_order, source="train", mode="train")
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Build ViT-B/16 from the .npz weights, then overlay the meta-trained weights
# from meta_epoch_10.pth. Entries whose key starts with "fc." are dropped before
# loading (presumably a classifier head whose shape differs from timm's).
meta_model = timm.create_model(
    'vit_base_patch16_224_in21k',
    pretrained=False,
    checkpoint_path="[your path]/ViT-B_16.npz",
)

# map_location="cpu": a checkpoint saved from GPU tensors must also load when
# `device` fell back to CPU; the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model_infos = torch.load("[your path]/meta_epoch_10.pth", map_location="cpu")

filtered_model_infos = {k: v for k, v in model_infos.items() if not k.startswith("fc.")}
load_result = meta_model.load_state_dict(filtered_model_infos, strict=False)

# strict=False silently skips keys that do not match the model (e.g. a
# "module." prefix). If none matched, the model would still be the plain .npz
# backbone and everything computed below would be mislabelled — fail loudly.
loaded_keys = set(filtered_model_infos) - set(load_result.unexpected_keys)
if not loaded_keys:
    raise RuntimeError(
        "No parameters from the meta checkpoint matched the model; "
        f"first unexpected keys: {load_result.unexpected_keys[:5]}"
    )
print(
    f"Loaded {len(loaded_keys)} tensors from the meta checkpoint; "
    f"{len(load_result.missing_keys)} model tensors kept their .npz values."
)

meta_model.eval()
meta_model.to(device)


# Collect per-sample features from the meta backbone, grouped by label:
# class_features = {label (int): [feature vector (np.ndarray), ...]}
class_features = {}
with torch.no_grad():
    for _, images, labels in train_loader:
        images = images.to(device)
        # Token 0 of forward_features' (batch, tokens, dim) output; assumed to be
        # the CLS token of a timm ViT — TODO confirm for the timm version in use.
        features = meta_model.forward_features(images)[:, 0, :].cpu().numpy()

        # Convert the label tensor to Python ints once per batch instead of
        # calling .item() three times per sample.
        for label, feature in zip(labels.tolist(), features):
            class_features.setdefault(label, []).append(feature)

# Class prototypes: the mean feature vector of each class, stacked in ascending
# label order (row i belongs to the i-th smallest label seen).
if not class_features:
    # An empty dict would become np.array([]) and yield a NaN covariance below.
    raise ValueError("class_features is empty: the feature-extraction loop saw no samples.")
sorted_labels = sorted(class_features)
class_avg_features = np.stack(
    [np.mean(class_features[label], axis=0) for label in sorted_labels]
)  # shape: (num_classes, feature_dim); expected (1000, 768) here

# Covariance of the class prototypes: each row of class_avg_features is one
# observation (rowvar=False) and each column a feature; np.cov normalises by N - 1.
if class_avg_features.ndim != 2 or class_avg_features.shape[0] < 2:
    # With fewer than two prototypes np.cov only warns and returns NaNs, which
    # would otherwise be written to disk without complaint.
    raise ValueError(
        "Need a 2-D array with at least two class prototypes to estimate a "
        f"covariance, got shape {class_avg_features.shape}."
    )
cov_matrix_meta10 = np.cov(class_avg_features, rowvar=False)  # (feature_dim, feature_dim)

# NOTE(review): the file is named 'backbone10' while the variable says 'meta10';
# confirm which name downstream consumers expect.
save_path = '[your path]/cov_matrix_backbone10.npy'
np.save(save_path, cov_matrix_meta10)