# ==============================================================
# Compute Embeddings of the EuroSAT-LS Dataset Using Pretrained MOE-MAE Encoder Weights
# ==============================================================

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from models.moe_mae import MOEMAE, build_model
from datasets.eurosat import EuroSATDatasetLS
from transformation.transformer import ToFloat, ZScoreNormalize
from utils.data_config import BigEarthNetInfo
from embed.compute_embed import compute_geomoemae_embeddings
from utils.data_utils import load_model

In [2]:
# Root folder of the EuroSAT-L data; the embedding archives are written here too.
save_path = "/mnt/storage/data/eurosat-l"
# Dataset root directory and the per-split files, all located under that root.
data_path = f"{save_path}/eurosat-l"
data_txt_train = f"{save_path}/eurosat-train.txt"
data_txt_val = f"{save_path}/eurosat-val.txt"
data_txt_test = f"{save_path}/eurosat-test.txt"

In [3]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Encoder hyper-parameters; they are forwarded unchanged to build_model below.
model_size = "S"
img_size = 40
patch_size = 4
in_channels = 7
checkpoint_path = "./weights/moe_mae_bigearthnet_ls/pretrained_S_best.pth"

# Build the encoder, wrap it in MOEMAE on the selected device, and pass it to
# the project helper load_model together with the checkpoint path. The helper's
# return value is the model used from here on.
encoder = build_model(
    size=model_size, img_size=img_size, patch_size=patch_size, in_chans=in_channels
)
model = load_model(MOEMAE(encoder).to(device), checkpoint_path, device)

# The embeddings below are computed with the encoder taken from the loaded model.
# eval() switches it to inference mode; it stays the last expression so the
# notebook cell still displays the module.
encoder = model.encoder
encoder.eval()

  checkpoint = torch.load(checkpoint_path, map_location=device)


mLiT(
  (patch_proj): Conv2d(7, 144, kernel_size=(4, 4), stride=(4, 4))
  (week_proj): Linear(in_features=2, out_features=144, bias=True)
  (hour_proj): Linear(in_features=2, out_features=144, bias=True)
  (lat_proj): Linear(in_features=2, out_features=144, bias=True)
  (lon_proj): Linear(in_features=2, out_features=144, bias=True)
  (layers): ModuleList(
    (0): MoETransformerEncoderLayer(
      (norm1): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (attn): GroupedQueryAttention(
        (q_proj): Linear(in_features=144, out_features=144, bias=True)
        (k_proj): Linear(in_features=144, out_features=72, bias=True)
        (v_proj): Linear(in_features=144, out_features=72, bias=True)
        (out_proj): Linear(in_features=144, out_features=144, bias=True)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (proj_dropout): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((144,), eps=1e-05, elementwise_affine=True)
      (moe): MoELayer(
       

In [5]:
# Count the encoder's parameters: every element, then only those with requires_grad set.
total_params = sum(param.numel() for param in encoder.parameters())
trainable_params = sum(
    param.numel() for param in encoder.parameters() if param.requires_grad
)

print(f"Total mmLiT Encoder parameters: {total_params:,}")
print(f"Trainable mmLiT Encoder parameters: {trainable_params:,}")

Total mmLiT Encoder parameters: 2,366,798
Trainable mmLiT Encoder parameters: 2,366,798


In [6]:
# Same count for the full model (the MOEMAE wrapper, which includes the encoder).
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print(f"Total mmLiT parameters: {total_params:,}")
print(f"Trainable mmLiT parameters: {trainable_params:,}")

Total mmLiT parameters: 2,537,562
Trainable mmLiT parameters: 2,537,562


In [7]:
# Preprocessing applied to every sample before it reaches the encoder.
bigearth_transforms = transforms.Compose(
    [
        # Resize to the encoder input size; img_size is the same value that was
        # passed to build_model, so the two cannot drift apart.
        transforms.Resize((img_size, img_size)),
        ToFloat(),
        # Z-score normalization with the statistics stored in BigEarthNetInfo.
        ZScoreNormalize(
            BigEarthNetInfo.STATISTICS["mean"],
            BigEarthNetInfo.STATISTICS["std"],
        ),
    ]
)


def _make_split(split_file):
    """Build the EuroSAT-LS dataset and its DataLoader for one split.

    All three splits share the same dataset options and loader settings, so
    they are defined once here instead of being repeated per split.

    Args:
        split_file: Path handed to EuroSATDatasetLS as ``split_file``.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    dataset = EuroSATDatasetLS(
        root_dir=data_path,
        split_file=split_file,
        transform=bigearth_transforms,
        return_one_hot=True,
        strict=False,
    )
    loader = DataLoader(
        dataset,
        batch_size=64,
        # No shuffling: batches are drawn in dataset order on every pass.
        shuffle=False,
        num_workers=4,
        prefetch_factor=4,
        persistent_workers=False,
        pin_memory=True,
    )
    return dataset, loader


train_dataset, train_dataloader = _make_split(data_txt_train)
val_dataset, val_dataloader = _make_split(data_txt_val)
test_dataset, test_dataloader = _make_split(data_txt_test)

In [None]:
# One output archive per split; the three paths differ only in the split name.
npz_path_train, npz_path_val, npz_path_test = (
    f"{save_path}/x_y_{split}_geomoemae_{model_size}_embed_pos_500epochs.npz"
    for split in ("train", "val", "test")
)

In [None]:
# Embed the training split; the helper returns a pair that is unpacked into
# the embeddings (x_train) and the labels (y_train).
x_train, y_train = compute_geomoemae_embeddings(encoder, train_dataloader, device)

Computing embeddings:   0%|          | 0/254 [00:00<?, ?it/s]

Computing embeddings: 100%|██████████| 254/254 [01:17<00:00,  3.28it/s]


In [10]:
# Sanity check: show the shapes of the training embeddings and labels.
print("X train shape: ", x_train.shape)
print("Y train shape: ", y_train.shape)

X train shape:  (16200, 105, 144)
Y train shape:  (16200, 10)


In [11]:
# Flatten everything after the first axis so each sample becomes a single row.
x_train = x_train.reshape((x_train.shape[0], -1))

In [12]:
# Store the training embeddings and labels under the keys "x_train" and
# "y_train"; the labels are cast to int16 before saving.
np.savez(npz_path_train, x_train=x_train, y_train=y_train.astype(np.int16))

In [13]:
# Drop the references so the training arrays can be reclaimed before the next split.
del x_train, y_train

In [None]:
# Embed the validation split; the returned pair is unpacked into embeddings and labels.
x_val, y_val = compute_geomoemae_embeddings(encoder, val_dataloader, device)

Computing embeddings:   0%|          | 0/85 [00:00<?, ?it/s]

Computing embeddings: 100%|██████████| 85/85 [00:23<00:00,  3.59it/s]


In [15]:
# Flatten everything after the first axis so each sample becomes a single row.
x_val = x_val.reshape((x_val.shape[0], -1))

In [16]:
# Store the validation embeddings and labels under the keys "x_val" and
# "y_val"; the labels are cast to int16 before saving.
np.savez(npz_path_val, x_val=x_val, y_val=y_val.astype(np.int16))

In [17]:
# Drop the references so the validation arrays can be reclaimed before the next split.
del x_val, y_val

In [None]:
# Embed the test split; the returned pair is unpacked into embeddings and labels.
x_test, y_test = compute_geomoemae_embeddings(encoder, test_dataloader, device)

Computing embeddings:   0%|          | 0/85 [00:00<?, ?it/s]

Computing embeddings: 100%|██████████| 85/85 [00:24<00:00,  3.50it/s]


In [19]:
# Flatten everything after the first axis so each sample becomes a single row.
x_test = x_test.reshape((x_test.shape[0], -1))

In [20]:
# Store the test embeddings and labels under the keys "x_test" and "y_test";
# the labels are cast to int16 before saving.
np.savez(npz_path_test, x_test=x_test, y_test=y_test.astype(np.int16))

In [None]:
# Drop the references to the test arrays now that they have been written to disk.
del x_test, y_test