In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import yaml
from app.vjepa.utils import (
    load_checkpoint_cpu,
    init_video_model,
    init_opt,
)

In [2]:
# Path to the YAML config consumed by the cells below.
fname = 'one_bit_logs/params-pretrain.yaml'
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default text encoding.
with open(fname, 'r', encoding='utf-8') as y_file:
    # NOTE(review): FullLoader can construct more than plain YAML types; if this
    # file only holds scalars/lists/dicts, yaml.safe_load would be the safer choice.
    params = yaml.load(y_file, Loader=yaml.FullLoader)

# The following cells read the configuration through the name `args`.
args = params

In [3]:
# -- META
# Values from the `meta` section of the config; the second argument to `.get`
# is the fallback used when the key is absent.
cfgs_meta = args.get('meta')
save_every_freq = cfgs_meta.get('save_every_freq', -1)
skip_batches = cfgs_meta.get('skip_batches', -1)
use_sdpa = cfgs_meta.get('use_sdpa', False)
# A missing or null `dtype` entry used to crash on `.lower()`; treat it like
# any other unrecognised value, i.e. full precision.
which_dtype = str(cfgs_meta.get('dtype') or 'float32')
# Map the configured name to a torch dtype; the two half-precision names also
# switch `mixed_precision` on, everything else falls back to float32.
if which_dtype.lower() == 'bfloat16':
    dtype = torch.bfloat16
    mixed_precision = True
elif which_dtype.lower() == 'float16':
    dtype = torch.float16
    mixed_precision = True
else:
    dtype = torch.float32
    mixed_precision = False

# The sections below unpack the remaining config groups into module-level
# names. `.get(key)` without a fallback yields None when the key is absent;
# `.get(key, default)` yields `default` instead.

# -- MASK
# Kept as the raw section; a later cell takes len() of it.
cfgs_mask = args.get('mask')

# -- MODEL
cfgs_model = args.get('model')
model_name = cfgs_model.get('model_name')
pred_depth = cfgs_model.get('pred_depth')
pred_embed_dim = cfgs_model.get('pred_embed_dim')
uniform_power = cfgs_model.get('uniform_power', True)
use_mask_tokens = cfgs_model.get('use_mask_tokens', True)
zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True)

# -- DATA
cfgs_data = args.get('data')
dataset_type = cfgs_data.get('dataset_type', 'videodataset')
mask_type = cfgs_data.get('mask_type', 'multiblock3d')
dataset_paths = cfgs_data.get('datasets', [])
datasets_weights = cfgs_data.get('datasets_weights', None)
# When sampling weights are given there must be exactly one per dataset path.
if datasets_weights is not None:
    assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset'
batch_size = cfgs_data.get('batch_size')
num_clips = cfgs_data.get('num_clips')
num_frames = cfgs_data.get('num_frames')
tubelet_size = cfgs_data.get('tubelet_size')
sampling_rate = cfgs_data.get('sampling_rate')
# Note the config key is `clip_duration`, not `duration`.
duration = cfgs_data.get('clip_duration', None)
crop_size = cfgs_data.get('crop_size', 224)
patch_size = cfgs_data.get('patch_size')
pin_mem = cfgs_data.get('pin_mem', False)
num_workers = cfgs_data.get('num_workers', 1)
filter_short_videos = cfgs_data.get('filter_short_videos', False)
decode_one_clip = cfgs_data.get('decode_one_clip', True)
# Note the config key is `log_resource_utilization`.
log_resource_util_data = cfgs_data.get('log_resource_utilization', False)

# -- DATA AUGS
cfgs_data_aug = args.get('data_aug')
ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3])
rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0])
motion_shift = cfgs_data_aug.get('motion_shift', False)
reprob = cfgs_data_aug.get('reprob', 0.)
use_aa = cfgs_data_aug.get('auto_augment', False)

# -- LOSS
cfgs_loss = args.get('loss')
loss_exp = cfgs_loss.get('loss_exp')
reg_coeff = cfgs_loss.get('reg_coeff')

# -- OPTIMIZATION
cfgs_opt = args.get('optimization')
ipe = cfgs_opt.get('ipe', None)
ipe_scale = cfgs_opt.get('ipe_scale', 1.0)
clip_grad = cfgs_opt.get('clip_grad', None)
# float() raises TypeError if either weight-decay key is missing from the config.
wd = float(cfgs_opt.get('weight_decay'))
final_wd = float(cfgs_opt.get('final_weight_decay'))
num_epochs = cfgs_opt.get('epochs')
warmup = cfgs_opt.get('warmup')
start_lr = cfgs_opt.get('start_lr')
lr = cfgs_opt.get('lr')
final_lr = cfgs_opt.get('final_lr')
ema = cfgs_opt.get('ema')
betas = cfgs_opt.get('betas', (0.9, 0.999))
eps = cfgs_opt.get('eps', 1.e-8)

# -- LOGGING
cfgs_logging = args.get('logging')
folder = cfgs_logging.get('folder')
# Note the config key is `write_tag`.
tag = cfgs_logging.get('write_tag')

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the encoder and predictor from the values parsed out of the config.
# `num_mask_tokens` is the number of entries in the `mask` config section
# (presumably one entry per mask specification — TODO confirm).
encoder, predictor = init_video_model(
    uniform_power=uniform_power,
    use_mask_tokens=use_mask_tokens,
    num_mask_tokens=len(cfgs_mask),
    zero_init_mask_tokens=zero_init_mask_tokens,
    device=device,
    patch_size=patch_size,
    num_frames=num_frames,
    tubelet_size=tubelet_size,
    model_name=model_name,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_embed_dim=pred_embed_dim,
    use_sdpa=use_sdpa,
)

INFO:root:MultiMaskWrapper(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed3D(
      (proj): BitConv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
    )
    (blocks): ModuleList(
      (0-23): 24 x Block(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): BitLinear(in_features=1024, out_features=3072, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): BitLinear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (fc1): BitLinear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): BitLinear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((1024,), eps=1e-06, elemen

In [6]:
# Restore encoder and predictor weights from the saved checkpoint. No target
# encoder, optimizer or grad scaler is passed in (all three are None).
# NOTE(review): the captured log for this cell reports the encoder and
# predictor as loaded, followed by "Encountered exception when loading
# checkpoint 'NoneType' object has no attribute 'load_state_dict'" — presumably
# triggered by one of the None arguments. Confirm the returned `epoch` and the
# other returned values are still meaningful after that exception.
load_path = 'one_bit_logs/jepa-latest.pth.tar'
(
encoder,
predictor,
target_encoder,
opt,
scaler,
epoch,
) = load_checkpoint_cpu(r_path=load_path,encoder=encoder,predictor=predictor, target_encoder=None, opt = None, scaler= None)

INFO:root:loaded pretrained encoder from epoch 302 with msg: <All keys matched successfully>
INFO:root:loaded pretrained predictor from epoch 302 with msg: <All keys matched successfully>
INFO:root:Encountered exception when loading checkpoint 'NoneType' object has no attribute 'load_state_dict'


In [7]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np

# Assumes the encoder has already been built and its weights loaded.
# Pull the patch-embedding projection weights off the device as a NumPy array.
conv3d_weight = encoder.backbone.patch_embed.proj.weight.data.cpu().numpy()

def save_kernels_as_images(conv3d_weight, grid_size=20,
                           output_dir="conv3d_kernels", filename='kernels_grid.png'):
    """Tile temporal slices of Conv3d kernels into one image and save it as a PNG.

    Each tile is one (out_channel, t) slice of the weight, min-max normalised on
    its own to [0, 1] and drawn with the input channels as the colour channels.
    Tiles are laid out row-major, ordered by output channel and then by time
    index. Only the first ``grid_size * grid_size`` slices are drawn; unused
    cells of the canvas stay black.

    Args:
        conv3d_weight: NumPy array of shape (out_channels, in_channels, T, H, W).
            The canvas has 3 colour channels, so in_channels must be 3 (or
            broadcastable to 3).
        grid_size: Number of tiles per row and per column; also passed to
            matplotlib as the figure size.
        output_dir: Directory the image is written to; created if missing.
        filename: File name of the image inside ``output_dir``.

    Returns:
        None. The image is written to ``output_dir/filename``.

    Raises:
        ValueError: If ``grid_size`` is smaller than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")

    out_channels, _, T, H, W = conv3d_weight.shape

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    num_kernels_to_plot = min(out_channels * T, grid_size * grid_size)

    # One large RGB canvas holding grid_size x grid_size tiles of H x W pixels.
    grid_img = np.zeros((grid_size * H, grid_size * W, 3))

    # A single flat index replaces the nested loops: slice k belongs to
    # output channel k // T at time index k % T, the same order as before,
    # and the loop stops as soon as the canvas is full.
    for kernel_idx in range(num_kernels_to_plot):
        out_ch, t = divmod(kernel_idx, T)
        kernel_slice = conv3d_weight[out_ch, :, t, :, :]

        # Per-slice min-max normalisation. A constant slice has zero range and
        # would divide by zero; draw it as mid-grey so it stays distinguishable
        # from the black of unused cells.
        lowest = kernel_slice.min()
        value_range = kernel_slice.max() - lowest
        if value_range > 0:
            tile = (kernel_slice - lowest) / value_range
        else:
            tile = np.full(kernel_slice.shape, 0.5)

        # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
        tile = np.transpose(tile, (1, 2, 0))

        row, col = divmod(kernel_idx, grid_size)
        grid_img[row * H:(row + 1) * H, col * W:(col + 1) * W, :] = tile

    fig, ax = plt.subplots(figsize=(grid_size, grid_size))
    try:
        ax.imshow(grid_img)
        ax.axis('off')
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

# Write the kernel grid image for the patch-embedding weights extracted above.
save_kernels_as_images(conv3d_weight)


In [8]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np

from src.models.utils.bit_conv import weight_quant

# Assumes the encoder has already been built and its weights loaded.
# Same patch-embedding weights as before, but passed through `weight_quant`
# first (presumably the quantization the BitConv3d layer applies to its
# weights — TODO confirm against src.models.utils.bit_conv).
conv3d_weight = weight_quant(encoder.backbone.patch_embed.proj.weight.data.cpu()).numpy()


def save_kernels_as_images(conv3d_weight, grid_size=20,
                           output_dir="conv3d_kernels", filename='kernels_grid_one_bit.png'):
    """Tile temporal slices of Conv3d kernels into one image and save it as a PNG.

    Each tile is one (out_channel, t) slice of the weight, min-max normalised on
    its own to [0, 1] and drawn with the input channels as the colour channels.
    Tiles are laid out row-major, ordered by output channel and then by time
    index. Only the first ``grid_size * grid_size`` slices are drawn; unused
    cells of the canvas stay black.

    Args:
        conv3d_weight: NumPy array of shape (out_channels, in_channels, T, H, W).
            The canvas has 3 colour channels, so in_channels must be 3 (or
            broadcastable to 3).
        grid_size: Number of tiles per row and per column; also passed to
            matplotlib as the figure size.
        output_dir: Directory the image is written to; created if missing.
        filename: File name of the image inside ``output_dir``.

    Returns:
        None. The image is written to ``output_dir/filename``.

    Raises:
        ValueError: If ``grid_size`` is smaller than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")

    out_channels, _, T, H, W = conv3d_weight.shape

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    num_kernels_to_plot = min(out_channels * T, grid_size * grid_size)

    # One large RGB canvas holding grid_size x grid_size tiles of H x W pixels.
    grid_img = np.zeros((grid_size * H, grid_size * W, 3))

    # A single flat index replaces the nested loops: slice k belongs to
    # output channel k // T at time index k % T, the same order as before,
    # and the loop stops as soon as the canvas is full.
    for kernel_idx in range(num_kernels_to_plot):
        out_ch, t = divmod(kernel_idx, T)
        kernel_slice = conv3d_weight[out_ch, :, t, :, :]

        # Per-slice min-max normalisation. A slice whose values are all equal
        # has zero range and would divide by zero; draw it as mid-grey so it
        # stays distinguishable from the black of unused cells.
        lowest = kernel_slice.min()
        value_range = kernel_slice.max() - lowest
        if value_range > 0:
            tile = (kernel_slice - lowest) / value_range
        else:
            tile = np.full(kernel_slice.shape, 0.5)

        # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
        tile = np.transpose(tile, (1, 2, 0))

        row, col = divmod(kernel_idx, grid_size)
        grid_img[row * H:(row + 1) * H, col * W:(col + 1) * W, :] = tile

    fig, ax = plt.subplots(figsize=(grid_size, grid_size))
    try:
        ax.imshow(grid_img)
        ax.axis('off')
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

# Write the kernel grid image for the quantized patch-embedding weights.
save_kernels_as_images(conv3d_weight)
