In [1]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter

# Add the src directory to the Python path
sys.path.append(str(Path().resolve().parent / "src"))

import config
from config import FM_MODEL_DIR
from model.train import LandslideDataset

# Import necessary libraries
import torch
from torch import nn
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader, random_split, Subset

# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# --- Configuration ---
# Path to the pretrained model file inside FM_MODEL_DIR.
# Change the file name below if your downloaded model is stored under a different name.
model_path = FM_MODEL_DIR / 'reBEN_resnet18-all-v0.2.0'


Using device: cuda


In [2]:
# Cell 2: Load the pretrained model stored at `model_path` (the repr below shows a ResNet).
# NOTE(security): weights_only=False unpickles arbitrary Python objects, which can execute
# code on load. Only use it with checkpoint files from a trusted source.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model = torch.load(model_path, weights_only=False, map_location=device)
# Inference only: put BatchNorm (and any dropout) layers into evaluation mode.
model.eval()
model.to(device)


ResNet(
  (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, 

In [3]:
from torchinfo import summary

# Example input size: (batch_size, channels, height, width).
# 12 channels matches the first convolution of the loaded model (Conv2d(12, 64, ...)).
summary(model, input_size=(1, 12, 120, 120))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 19]                   --
├─Conv2d: 1-1                            [1, 64, 60, 60]           37,632
├─BatchNorm2d: 1-2                       [1, 64, 60, 60]           128
├─ReLU: 1-3                              [1, 64, 60, 60]           --
├─MaxPool2d: 1-4                         [1, 64, 30, 30]           --
├─Sequential: 1-5                        [1, 64, 30, 30]           --
│    └─BasicBlock: 2-1                   [1, 64, 30, 30]           --
│    │    └─Conv2d: 3-1                  [1, 64, 30, 30]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 30, 30]           128
│    │    └─Identity: 3-3                [1, 64, 30, 30]           --
│    │    └─ReLU: 3-4                    [1, 64, 30, 30]           --
│    │    └─Identity: 3-5                [1, 64, 30, 30]           --
│    │    └─Conv2d: 3-6                  [1, 64, 30, 30]           36,864
│

In [4]:
class LandslideDataset(Dataset):
    """In-memory dataset of processed landslide images, expanded to a 12-channel input.

    All ``<ID>.npy`` files listed in the CSV are loaded once in ``__init__`` and kept
    as ``(C, H, W)`` float32 tensors. ``__getitem__`` scatters the source channels
    into a zero-initialised 12-channel tensor according to ``BAND_MAP``.

    NOTE: defining this class here shadows the ``LandslideDataset`` imported from
    ``model.train`` at the top of the notebook.
    """

    # Number of channels in the tensor returned by __getitem__.
    NUM_MODEL_BANDS = 12

    # Source channel index -> channel index in the 12-channel model input.
    # Destination channels that are not listed here stay zero.
    BAND_MAP = {
        0: 3,
        1: 2,
        2: 1,
        3: 7,
        4: 11,
        5: 10,
    }

    def __init__(self, image_dir, csv_path, transform=None, device="cpu"):
        """
        Args:
            image_dir: Directory containing processed .npy files
            csv_path: Path to CSV file with 'ID' and 'label' columns
            transform: Optional transform applied to the 12-channel image
            device: Device to keep the loaded images on ("cpu", "cuda" or a torch.device)
        """
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.device = device

        # Load CSV data
        self.df = pd.read_csv(csv_path)
        self.image_ids = self.df['ID'].values
        self.labels = self.df['label'].values.astype(np.float32)

        # images[k] belongs to CSV row valid_indices[k]; rows whose file is missing are skipped.
        self.images = []
        self.valid_indices = []

        # Normalise once so that both strings and torch.device objects are handled alike.
        target_device = torch.device(device)

        for i, img_id in enumerate(tqdm(self.image_ids, desc="Loading images")):
            img_path = self.image_dir / f"{img_id}.npy"
            if not img_path.exists():
                print(f"Warning: {img_path} not found")
                continue

            # Stored arrays are permuted from (H, W, C) to (C, H, W).
            image = np.load(img_path).astype(np.float32)
            image = torch.from_numpy(image).permute(2, 0, 1)

            if target_device.type != "cpu":
                image = image.to(target_device)

            self.images.append(image)
            self.valid_indices.append(i)

        print(f"Loaded {len(self.valid_indices)} valid images out of {len(self.image_ids)}")
        print(f"Total memory usage: {sum(img.element_size() * img.nelement() for img in self.images) / 1024**3:.2f} GB")

    def __len__(self):
        """Return the number of images that were actually loaded."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th loaded image.

        The image is a ``(12, H, W)`` tensor on the same device as the stored image,
        with the source channels placed according to ``BAND_MAP`` and the optional
        transform applied afterwards. The label is the float32 value from the CSV row
        the image came from.
        """
        # Labels are indexed by CSV row, images by load order.
        label = self.labels[self.valid_indices[idx]]
        image = self.images[idx]  # Already loaded in memory

        # Allocate on the image's own device so GPU-resident data is not copied back to the host.
        img_12ch = torch.zeros(
            (self.NUM_MODEL_BANDS, image.shape[1], image.shape[2]),
            dtype=image.dtype,
            device=image.device,
        )

        # Copy all mapped channels in one indexed assignment.
        src_bands = list(self.BAND_MAP.keys())
        dst_bands = list(self.BAND_MAP.values())
        img_12ch[dst_bands] = image[src_bands]

        # Apply optional transform (e.g., resizing)
        if self.transform:
            img_12ch = self.transform(img_12ch)

        return img_12ch, label


In [5]:
# Build the in-memory dataset. The resize transform (size=120) is handed to the
# constructor so the dataset is fully configured in one place.
print("Loading dataset into memory...")
dataset = LandslideDataset(
    image_dir=config.PROCESSED_TRAIN_IMAGE_DIR,
    csv_path=config.TRAIN_CSV_PATH,
    transform=v2.Resize(size=120),
    device=device,
)

# Sequential (unshuffled) batches keep predictions aligned with the dataset order.
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    generator=torch.Generator().manual_seed(config.SEED),
)

Loading dataset into memory...


Loading images: 100%|██████████| 7147/7147 [00:55<00:00, 127.77it/s]

Loaded 7147 valid images out of 7147
Total memory usage: 1.31 GB





In [6]:
def flatten(xss):
    """Concatenate the items of every inner iterable of ``xss`` into one flat list."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat
# Run the model over the whole dataset and keep, per sample, the ground-truth label
# and the index of the highest-scoring output class.
all_targets = []
class_preds = []

# Inference only: evaluation mode for BatchNorm layers, and no autograd bookkeeping
# so that activations are not retained for a backward pass.
model.eval()
with torch.no_grad():
    for images, targets in data_loader:
        all_targets.append(targets.cpu().numpy())

        logits = model(images.to(device))
        # argmax over the class dimension, computed on-device for the whole batch.
        class_preds.append(logits.argmax(dim=1).cpu().numpy())

all_targets = np.concatenate(all_targets)
class_preds = np.concatenate(class_preds)

In [7]:
# How often each output class index is predicted for samples whose label is 1.
Counter(class_preds[all_targets == 1])

Counter({np.int64(1): 440,
         np.int64(17): 178,
         np.int64(3): 141,
         np.int64(12): 140,
         np.int64(6): 134,
         np.int64(8): 72,
         np.int64(5): 42,
         np.int64(15): 36,
         np.int64(10): 30,
         np.int64(18): 19,
         np.int64(11): 9,
         np.int64(7): 9,
         np.int64(13): 3,
         np.int64(9): 1,
         np.int64(14): 1})

In [8]:
# How often each output class index is predicted for samples whose label is 0.
Counter(class_preds[all_targets == 0])

Counter({np.int64(1): 1783,
         np.int64(12): 926,
         np.int64(17): 761,
         np.int64(6): 744,
         np.int64(3): 490,
         np.int64(8): 293,
         np.int64(18): 286,
         np.int64(10): 169,
         np.int64(15): 135,
         np.int64(5): 135,
         np.int64(11): 78,
         np.int64(9): 46,
         np.int64(13): 23,
         np.int64(7): 15,
         np.int64(16): 5,
         np.int64(14): 2,
         np.int64(0): 1})

In [9]:
# Remove classification layer:
# `fc` is replaced by a pass-through module, so the forward pass returns whatever
# previously fed into `fc` (presumably the pooled feature vector — TODO confirm).
# NOTE: this mutates `model` in place; re-running the class-prediction cell after
# this one will no longer yield class scores.
model.fc = nn.Identity()


In [10]:
# Collect the model output for every sample into one (num_samples, feature_dim) array.
embedding_batches = []

# Inference only: evaluation mode for BatchNorm layers, and no autograd bookkeeping.
model.eval()
with torch.no_grad():
    for images, _ in data_loader:
        features = model(images.to(device))
        embedding_batches.append(features.cpu().numpy())

embedding_mtx = np.concatenate(embedding_batches)

In [11]:
from sklearn.cluster import KMeans
# Cluster the embeddings into 8 groups with a fixed random_state.
# NOTE: `kmeans` holds the per-sample cluster labels returned by fit_predict,
# not the fitted KMeans estimator.
kmeans = KMeans(n_clusters=8, random_state=0, n_init="auto").fit_predict(embedding_mtx)

In [12]:
# Cluster sizes over all samples, then over the samples whose label is 1 only.
print(Counter(kmeans))
print(Counter(kmeans[all_targets == 1]))

Counter({np.int32(1): 3110, np.int32(5): 954, np.int32(3): 821, np.int32(2): 748, np.int32(6): 666, np.int32(4): 397, np.int32(7): 252, np.int32(0): 199})
Counter({np.int32(1): 715, np.int32(5): 162, np.int32(6): 161, np.int32(3): 127, np.int32(0): 59, np.int32(7): 30, np.int32(2): 1})


In [14]:
# IDs of the images assigned to cluster 4 (swap in 2 or 3 to inspect other clusters;
# AND the mask with `(all_targets == 0)` to keep only samples labelled 0).
# `kmeans` has one entry per *loaded* image while `image_ids` has one per CSV row,
# so restrict the IDs to `valid_indices` first to keep the mask aligned.
dataset.image_ids[dataset.valid_indices][kmeans == 4]

array(['ID_81CCTX', 'ID_K7DZF2', 'ID_THLFOS', 'ID_HXP5ML', 'ID_Z6MT2M',
       'ID_KAE13A', 'ID_PUZLZB', 'ID_K75Q1F', 'ID_1Q4AOO', 'ID_R4XKHT',
       'ID_R6I7Z5', 'ID_8XU9PK', 'ID_3OW9N3', 'ID_K3BMDJ', 'ID_29RVBQ',
       'ID_IK14T0', 'ID_PCT6XN', 'ID_OQP2JZ', 'ID_69AWCF', 'ID_9DB1U9',
       'ID_SV2JH6', 'ID_JRVM9X', 'ID_IELG8M', 'ID_JJRTOB', 'ID_OXHKIL',
       'ID_Y53668', 'ID_QZJO8B', 'ID_6481LY', 'ID_Q1KVLC', 'ID_O9KSMX',
       'ID_8GUH5A', 'ID_E66YLO', 'ID_S6UVD7', 'ID_EN6LWL', 'ID_8K8LZN',
       'ID_V62VPV', 'ID_JGIF2L', 'ID_M5AOEF', 'ID_030E20', 'ID_8NQ8MY',
       'ID_7K66L0', 'ID_V1LZYY', 'ID_SDP8AJ', 'ID_A9PB6J', 'ID_E6KKIQ',
       'ID_40KZ27', 'ID_KNHNZ3', 'ID_7VUJ7O', 'ID_CPPZ3T', 'ID_AUSMPY',
       'ID_BOSI3L', 'ID_K45WOE', 'ID_CHF8N2', 'ID_O7LSKU', 'ID_LM08R7',
       'ID_OUG1MA', 'ID_M8RTRC', 'ID_JKJ7X1', 'ID_5NZDM6', 'ID_2758EP',
       'ID_TTPIYG', 'ID_UHHZES', 'ID_I6GD66', 'ID_OYZQCC', 'ID_5B2Q7W',
       'ID_NJ69O3', 'ID_K428U7', 'ID_HMSG74', 'ID_FFMBFC', 'ID_Q