In [21]:
import yaml
import functools

import torch
from torchsummary import summary

from Prithvi import MaskedAutoencoderViT

In [22]:
yaml_file_path = '../Prithvi-100M/Prithvi_100M_config.yaml'
with open(yaml_file_path, 'r') as f:
    params = yaml.safe_load(f)


In [23]:
mask_ratio = None

# data related
num_frames = 3
img_size = params['img_size']
bands = params['bands']
mean = params['data_mean']
std = params['data_std']

# model related
depth = params['depth']
patch_size = params['patch_size']
embed_dim = params['embed_dim']
num_heads = params['num_heads']
tubelet_size = params['tubelet_size']
decoder_embed_dim = params['decoder_embed_dim']
decoder_num_heads = params['decoder_num_heads']
decoder_depth = params['decoder_depth']

batch_size = params['batch_size']

mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio

In [24]:
model = MaskedAutoencoderViT(
            img_size=img_size,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            in_chans=len(bands),
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=4.,
            norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
            norm_pix_loss=False)

In [25]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> Model has {total_params:,} parameters.\n")


--> Model has 112,639,488 parameters.



In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [27]:
checkpoint = 'Prithvi_100M.pt'

state_dict = torch.load(checkpoint, map_location=device)
# discard fixed pos_embedding weight
del state_dict['pos_embed']
del state_dict['decoder_pos_embed']
model.load_state_dict(state_dict, strict=False)
print(f"Loaded checkpoint from {checkpoint}")

FileNotFoundError: [Errno 2] No such file or directory: 'Prithvi_100M.pt'

In [None]:
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 432.562MB


In [None]:
model

MaskedAutoencoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv3d(6, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

In [None]:
batch_size = 1
n_channels = 6
temporal_length = 3
height = 224
width = 224
input_size = (n_channels, temporal_length, height, width)
summary(model, input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 768, 3, 14, 14]       1,180,416
          Identity-2             [-1, 588, 768]               0
        PatchEmbed-3             [-1, 588, 768]               0
         LayerNorm-4             [-1, 148, 768]           1,536
            Linear-5            [-1, 148, 2304]       1,771,776
          Identity-6          [-1, 12, 148, 64]               0
          Identity-7          [-1, 12, 148, 64]               0
            Linear-8             [-1, 148, 768]         590,592
           Dropout-9             [-1, 148, 768]               0
        Attention-10             [-1, 148, 768]               0
         Identity-11             [-1, 148, 768]               0
         Identity-12             [-1, 148, 768]               0
        LayerNorm-13             [-1, 148, 768]           1,536
           Linear-14            [-1, 14

  return F.conv3d(
