In [1]:
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from pathlib import Path
from timm.models import create_model

import utils
import modeling_pretrain
from datasets import DataAugmentationForVideoMAE
from torchvision.transforms import ToPILImage
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from decord import VideoReader, cpu
from torchvision import transforms
from transforms import *
from masking_generator import  TubeMaskingGenerator

In [2]:
class DataAugmentationForVideoMAE(object):
    def __init__(self, args):
        self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupCenterCrop(args.input_size)
        self.transform = transforms.Compose([                            
            self.train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ])
        if args.mask_type == 'tube':
            self.masked_position_generator = TubeMaskingGenerator(
                args.window_size, args.mask_ratio
            )

    def __call__(self, images):
        process_data , _ = self.transform(images)
        return process_data, self.masked_position_generator()

    def __repr__(self):
        repr = "(DataAugmentationForVideoMAE,\n"
        repr += "  transform = %s,\n" % str(self.transform)
        repr += "  Masked position generator = %s,\n" % str(self.masked_position_generator)
        repr += ")"
        return repr


In [5]:
model = create_model(
    'pretrain_videomae_base_patch16_224',
    pretrained=False,
    drop_path_rate=0.1,
    drop_block_rate=None,
    decoder_depth=4
)

In [8]:
device = torch.device('cuda:0')
cudnn.benchmark = True
patch_size = model.encoder.patch_embed.patch_size
print("Patch size = %s" % str(patch_size))
window_size = (16 // 2, 224 // patch_size[0], 224 // patch_size[1])
patch_size = patch_size

Patch size = (16, 16)


In [9]:
model.to(device)

FileNotFoundError: [Errno 2] No such file or directory: './checkpoint.pth'

In [19]:
checkpoint = torch.load('../../checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])

RuntimeError: Error(s) in loading state_dict for PretrainVisionTransformer:
	Unexpected key(s) in state_dict: "encoder.blocks.12.norm1.weight", "encoder.blocks.12.norm1.bias", "encoder.blocks.12.attn.q_bias", "encoder.blocks.12.attn.v_bias", "encoder.blocks.12.attn.qkv.weight", "encoder.blocks.12.attn.proj.weight", "encoder.blocks.12.attn.proj.bias", "encoder.blocks.12.norm2.weight", "encoder.blocks.12.norm2.bias", "encoder.blocks.12.mlp.fc1.weight", "encoder.blocks.12.mlp.fc1.bias", "encoder.blocks.12.mlp.fc2.weight", "encoder.blocks.12.mlp.fc2.bias", "encoder.blocks.13.norm1.weight", "encoder.blocks.13.norm1.bias", "encoder.blocks.13.attn.q_bias", "encoder.blocks.13.attn.v_bias", "encoder.blocks.13.attn.qkv.weight", "encoder.blocks.13.attn.proj.weight", "encoder.blocks.13.attn.proj.bias", "encoder.blocks.13.norm2.weight", "encoder.blocks.13.norm2.bias", "encoder.blocks.13.mlp.fc1.weight", "encoder.blocks.13.mlp.fc1.bias", "encoder.blocks.13.mlp.fc2.weight", "encoder.blocks.13.mlp.fc2.bias", "encoder.blocks.14.norm1.weight", "encoder.blocks.14.norm1.bias", "encoder.blocks.14.attn.q_bias", "encoder.blocks.14.attn.v_bias", "encoder.blocks.14.attn.qkv.weight", "encoder.blocks.14.attn.proj.weight", "encoder.blocks.14.attn.proj.bias", "encoder.blocks.14.norm2.weight", "encoder.blocks.14.norm2.bias", "encoder.blocks.14.mlp.fc1.weight", "encoder.blocks.14.mlp.fc1.bias", "encoder.blocks.14.mlp.fc2.weight", "encoder.blocks.14.mlp.fc2.bias", "encoder.blocks.15.norm1.weight", "encoder.blocks.15.norm1.bias", "encoder.blocks.15.attn.q_bias", "encoder.blocks.15.attn.v_bias", "encoder.blocks.15.attn.qkv.weight", "encoder.blocks.15.attn.proj.weight", "encoder.blocks.15.attn.proj.bias", "encoder.blocks.15.norm2.weight", "encoder.blocks.15.norm2.bias", "encoder.blocks.15.mlp.fc1.weight", "encoder.blocks.15.mlp.fc1.bias", "encoder.blocks.15.mlp.fc2.weight", "encoder.blocks.15.mlp.fc2.bias", "encoder.blocks.16.norm1.weight", "encoder.blocks.16.norm1.bias", "encoder.blocks.16.attn.q_bias", "encoder.blocks.16.attn.v_bias", "encoder.blocks.16.attn.qkv.weight", "encoder.blocks.16.attn.proj.weight", "encoder.blocks.16.attn.proj.bias", "encoder.blocks.16.norm2.weight", "encoder.blocks.16.norm2.bias", "encoder.blocks.16.mlp.fc1.weight", "encoder.blocks.16.mlp.fc1.bias", "encoder.blocks.16.mlp.fc2.weight", "encoder.blocks.16.mlp.fc2.bias", "encoder.blocks.17.norm1.weight", "encoder.blocks.17.norm1.bias", "encoder.blocks.17.attn.q_bias", "encoder.blocks.17.attn.v_bias", "encoder.blocks.17.attn.qkv.weight", "encoder.blocks.17.attn.proj.weight", "encoder.blocks.17.attn.proj.bias", "encoder.blocks.17.norm2.weight", "encoder.blocks.17.norm2.bias", "encoder.blocks.17.mlp.fc1.weight", "encoder.blocks.17.mlp.fc1.bias", "encoder.blocks.17.mlp.fc2.weight", "encoder.blocks.17.mlp.fc2.bias", "encoder.blocks.18.norm1.weight", "encoder.blocks.18.norm1.bias", "encoder.blocks.18.attn.q_bias", "encoder.blocks.18.attn.v_bias", "encoder.blocks.18.attn.qkv.weight", "encoder.blocks.18.attn.proj.weight", "encoder.blocks.18.attn.proj.bias", "encoder.blocks.18.norm2.weight", "encoder.blocks.18.norm2.bias", "encoder.blocks.18.mlp.fc1.weight", "encoder.blocks.18.mlp.fc1.bias", "encoder.blocks.18.mlp.fc2.weight", "encoder.blocks.18.mlp.fc2.bias", "encoder.blocks.19.norm1.weight", "encoder.blocks.19.norm1.bias", "encoder.blocks.19.attn.q_bias", "encoder.blocks.19.attn.v_bias", "encoder.blocks.19.attn.qkv.weight", "encoder.blocks.19.attn.proj.weight", "encoder.blocks.19.attn.proj.bias", "encoder.blocks.19.norm2.weight", "encoder.blocks.19.norm2.bias", "encoder.blocks.19.mlp.fc1.weight", "encoder.blocks.19.mlp.fc1.bias", "encoder.blocks.19.mlp.fc2.weight", "encoder.blocks.19.mlp.fc2.bias", "encoder.blocks.20.norm1.weight", "encoder.blocks.20.norm1.bias", "encoder.blocks.20.attn.q_bias", "encoder.blocks.20.attn.v_bias", "encoder.blocks.20.attn.qkv.weight", "encoder.blocks.20.attn.proj.weight", "encoder.blocks.20.attn.proj.bias", "encoder.blocks.20.norm2.weight", "encoder.blocks.20.norm2.bias", "encoder.blocks.20.mlp.fc1.weight", "encoder.blocks.20.mlp.fc1.bias", "encoder.blocks.20.mlp.fc2.weight", "encoder.blocks.20.mlp.fc2.bias", "encoder.blocks.21.norm1.weight", "encoder.blocks.21.norm1.bias", "encoder.blocks.21.attn.q_bias", "encoder.blocks.21.attn.v_bias", "encoder.blocks.21.attn.qkv.weight", "encoder.blocks.21.attn.proj.weight", "encoder.blocks.21.attn.proj.bias", "encoder.blocks.21.norm2.weight", "encoder.blocks.21.norm2.bias", "encoder.blocks.21.mlp.fc1.weight", "encoder.blocks.21.mlp.fc1.bias", "encoder.blocks.21.mlp.fc2.weight", "encoder.blocks.21.mlp.fc2.bias", "encoder.blocks.22.norm1.weight", "encoder.blocks.22.norm1.bias", "encoder.blocks.22.attn.q_bias", "encoder.blocks.22.attn.v_bias", "encoder.blocks.22.attn.qkv.weight", "encoder.blocks.22.attn.proj.weight", "encoder.blocks.22.attn.proj.bias", "encoder.blocks.22.norm2.weight", "encoder.blocks.22.norm2.bias", "encoder.blocks.22.mlp.fc1.weight", "encoder.blocks.22.mlp.fc1.bias", "encoder.blocks.22.mlp.fc2.weight", "encoder.blocks.22.mlp.fc2.bias", "encoder.blocks.23.norm1.weight", "encoder.blocks.23.norm1.bias", "encoder.blocks.23.attn.q_bias", "encoder.blocks.23.attn.v_bias", "encoder.blocks.23.attn.qkv.weight", "encoder.blocks.23.attn.proj.weight", "encoder.blocks.23.attn.proj.bias", "encoder.blocks.23.norm2.weight", "encoder.blocks.23.norm2.bias", "encoder.blocks.23.mlp.fc1.weight", "encoder.blocks.23.mlp.fc1.bias", "encoder.blocks.23.mlp.fc2.weight", "encoder.blocks.23.mlp.fc2.bias", "decoder.blocks.4.norm1.weight", "decoder.blocks.4.norm1.bias", "decoder.blocks.4.attn.q_bias", "decoder.blocks.4.attn.v_bias", "decoder.blocks.4.attn.qkv.weight", "decoder.blocks.4.attn.proj.weight", "decoder.blocks.4.attn.proj.bias", "decoder.blocks.4.norm2.weight", "decoder.blocks.4.norm2.bias", "decoder.blocks.4.mlp.fc1.weight", "decoder.blocks.4.mlp.fc1.bias", "decoder.blocks.4.mlp.fc2.weight", "decoder.blocks.4.mlp.fc2.bias", "decoder.blocks.5.norm1.weight", "decoder.blocks.5.norm1.bias", "decoder.blocks.5.attn.q_bias", "decoder.blocks.5.attn.v_bias", "decoder.blocks.5.attn.qkv.weight", "decoder.blocks.5.attn.proj.weight", "decoder.blocks.5.attn.proj.bias", "decoder.blocks.5.norm2.weight", "decoder.blocks.5.norm2.bias", "decoder.blocks.5.mlp.fc1.weight", "decoder.blocks.5.mlp.fc1.bias", "decoder.blocks.5.mlp.fc2.weight", "decoder.blocks.5.mlp.fc2.bias", "decoder.blocks.6.norm1.weight", "decoder.blocks.6.norm1.bias", "decoder.blocks.6.attn.q_bias", "decoder.blocks.6.attn.v_bias", "decoder.blocks.6.attn.qkv.weight", "decoder.blocks.6.attn.proj.weight", "decoder.blocks.6.attn.proj.bias", "decoder.blocks.6.norm2.weight", "decoder.blocks.6.norm2.bias", "decoder.blocks.6.mlp.fc1.weight", "decoder.blocks.6.mlp.fc1.bias", "decoder.blocks.6.mlp.fc2.weight", "decoder.blocks.6.mlp.fc2.bias", "decoder.blocks.7.norm1.weight", "decoder.blocks.7.norm1.bias", "decoder.blocks.7.attn.q_bias", "decoder.blocks.7.attn.v_bias", "decoder.blocks.7.attn.qkv.weight", "decoder.blocks.7.attn.proj.weight", "decoder.blocks.7.attn.proj.bias", "decoder.blocks.7.norm2.weight", "decoder.blocks.7.norm2.bias", "decoder.blocks.7.mlp.fc1.weight", "decoder.blocks.7.mlp.fc1.bias", "decoder.blocks.7.mlp.fc2.weight", "decoder.blocks.7.mlp.fc2.bias", "decoder.blocks.8.norm1.weight", "decoder.blocks.8.norm1.bias", "decoder.blocks.8.attn.q_bias", "decoder.blocks.8.attn.v_bias", "decoder.blocks.8.attn.qkv.weight", "decoder.blocks.8.attn.proj.weight", "decoder.blocks.8.attn.proj.bias", "decoder.blocks.8.norm2.weight", "decoder.blocks.8.norm2.bias", "decoder.blocks.8.mlp.fc1.weight", "decoder.blocks.8.mlp.fc1.bias", "decoder.blocks.8.mlp.fc2.weight", "decoder.blocks.8.mlp.fc2.bias", "decoder.blocks.9.norm1.weight", "decoder.blocks.9.norm1.bias", "decoder.blocks.9.attn.q_bias", "decoder.blocks.9.attn.v_bias", "decoder.blocks.9.attn.qkv.weight", "decoder.blocks.9.attn.proj.weight", "decoder.blocks.9.attn.proj.bias", "decoder.blocks.9.norm2.weight", "decoder.blocks.9.norm2.bias", "decoder.blocks.9.mlp.fc1.weight", "decoder.blocks.9.mlp.fc1.bias", "decoder.blocks.9.mlp.fc2.weight", "decoder.blocks.9.mlp.fc2.bias", "decoder.blocks.10.norm1.weight", "decoder.blocks.10.norm1.bias", "decoder.blocks.10.attn.q_bias", "decoder.blocks.10.attn.v_bias", "decoder.blocks.10.attn.qkv.weight", "decoder.blocks.10.attn.proj.weight", "decoder.blocks.10.attn.proj.bias", "decoder.blocks.10.norm2.weight", "decoder.blocks.10.norm2.bias", "decoder.blocks.10.mlp.fc1.weight", "decoder.blocks.10.mlp.fc1.bias", "decoder.blocks.10.mlp.fc2.weight", "decoder.blocks.10.mlp.fc2.bias", "decoder.blocks.11.norm1.weight", "decoder.blocks.11.norm1.bias", "decoder.blocks.11.attn.q_bias", "decoder.blocks.11.attn.v_bias", "decoder.blocks.11.attn.qkv.weight", "decoder.blocks.11.attn.proj.weight", "decoder.blocks.11.attn.proj.bias", "decoder.blocks.11.norm2.weight", "decoder.blocks.11.norm2.bias", "decoder.blocks.11.mlp.fc1.weight", "decoder.blocks.11.mlp.fc1.bias", "decoder.blocks.11.mlp.fc2.weight", "decoder.blocks.11.mlp.fc2.bias". 
	size mismatch for mask_token: copying a param with shape torch.Size([1, 1, 512]) from checkpoint, the shape in current model is torch.Size([1, 1, 384]).
	size mismatch for encoder.patch_embed.proj.weight: copying a param with shape torch.Size([1024, 3, 2, 16, 16]) from checkpoint, the shape in current model is torch.Size([768, 3, 2, 16, 16]).
	size mismatch for encoder.patch_embed.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.0.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.0.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.1.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.1.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.1.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.2.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.2.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.2.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.3.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.3.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.3.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.4.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.4.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.4.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.4.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.4.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.4.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.5.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.5.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.5.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.5.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.5.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.5.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.6.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.6.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.6.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.6.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.6.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.6.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.7.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.7.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.7.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.7.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.7.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.7.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.8.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.8.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.8.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.8.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.8.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.8.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.9.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.9.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.9.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.9.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.9.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.9.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.10.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.10.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.10.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.10.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.10.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.10.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.attn.q_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.attn.v_bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.attn.qkv.weight: copying a param with shape torch.Size([3072, 1024]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for encoder.blocks.11.attn.proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for encoder.blocks.11.attn.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.blocks.11.mlp.fc1.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for encoder.blocks.11.mlp.fc1.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for encoder.blocks.11.mlp.fc2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for encoder.blocks.11.mlp.fc2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.norm.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for encoder.norm.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for decoder.blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.attn.q_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.attn.v_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for decoder.blocks.0.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for decoder.blocks.0.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for decoder.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for decoder.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for decoder.blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.attn.q_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.attn.v_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for decoder.blocks.1.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for decoder.blocks.1.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for decoder.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for decoder.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for decoder.blocks.1.mlp.fc2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.attn.q_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.attn.v_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for decoder.blocks.2.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for decoder.blocks.2.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for decoder.blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for decoder.blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for decoder.blocks.2.mlp.fc2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.attn.q_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.attn.v_bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for decoder.blocks.3.attn.proj.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for decoder.blocks.3.attn.proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for decoder.blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for decoder.blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for decoder.blocks.3.mlp.fc2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for decoder.head.weight: copying a param with shape torch.Size([1536, 512]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for encoder_to_decoder.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([384, 768]).

In [None]:


    device = torch.device('cuda:0')
    cudnn.benchmark = True
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    window_size = (16 // 2, 224 // patch_size[0], 224 // patch_size[1])
    patch_size = patch_size

    model.to(device)
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['module'])
    model.eval()

    if args.save_path:
        Path(args.save_path).mkdir(parents=True, exist_ok=True)

    with open(args.img_path, 'rb') as f:
        vr = VideoReader(f, ctx=cpu(0))
    duration = len(vr)
    new_length  = 1 
    new_step = 1
    skip_length = new_length * new_step
    # frame_id_list = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61]

    
    tmp = np.arange(0,32, 2) + 60
    frame_id_list = tmp.tolist()
    # average_duration = (duration - skip_length + 1) // args.num_frames
    # if average_duration > 0:
    #     frame_id_list = np.multiply(list(range(args.num_frames)),
    #                             average_duration)
    #     frame_id_list = frame_id_list + np.random.randint(average_duration,
    #                                             size=args.num_frames)

    video_data = vr.get_batch(frame_id_list).asnumpy()
    print(video_data.shape)
    img = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]

    transforms = DataAugmentationForVideoMAE(args)
    img, bool_masked_pos = transforms((img, None)) # T*C,H,W
    # print(img.shape)
    img = img.view((args.num_frames , 3) + img.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
    # img = img.view(( -1 , args.num_frames) + img.size()[-2:]) 
    bool_masked_pos = torch.from_numpy(bool_masked_pos)

    with torch.no_grad():
        # img = img[None, :]
        # bool_masked_pos = bool_masked_pos[None, :]
        img = img.unsqueeze(0)
        print(img.shape)
        bool_masked_pos = bool_masked_pos.unsqueeze(0)
        
        img = img.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
        outputs = model(img, bool_masked_pos)

        #save original video
        mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
        std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
        ori_img = img * std + mean  # in [0, 1]
        imgs = [ToPILImage()(ori_img[0,:,vid,:,:].cpu()) for vid, _ in enumerate(frame_id_list)  ]
        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/ori_img{id}.jpg")

        img_squeeze = rearrange(ori_img, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size[0], p2=patch_size[0])
        img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
        img_patch[bool_masked_pos] = outputs

        #make mask
        mask = torch.ones_like(img_patch)
        mask[bool_masked_pos] = 0
        mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
        mask = rearrange(mask, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2) ', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)

        #save reconstruction video
        rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
        # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch.
        rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean(dim=-2, keepdim=True)
        rec_img = rearrange(rec_img, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
        imgs = [ ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0,0.996)) for vid, _ in enumerate(frame_id_list)  ]

        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/rec_img{id}.jpg")

        #save masked video 
        img_mask = rec_img * mask
        imgs = [ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)]
        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/mask_img{id}.jpg")