In [None]:
import os
import torch
from torch.utils.data import Dataset
import rasterio
import numpy as np
from rasterio.enums import Resampling
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models import resnet50
from torch.utils.data import DataLoader
import torchvision.transforms as T
from lightning.pytorch import Trainer
from torchvision import transforms  
from torchgeo.trainers.moco import MoCoTask
from torchgeo.models import ResNet50_Weights
import kornia.augmentation as K
import torch.nn.functional as F
import torchgeo.transforms as T
from lightning.pytorch.loggers import CSVLogger
import glob
import os
import shutil
import random
# Seed every RNG the notebook relies on: Python's `random` (scene sub-sampling),
# NumPy, and PyTorch (DataLoader shuffling, kornia augmentations, module init).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Info from metadata.json
# B1	1830 × 1830	60 m
# B2	10980 × 10980	10 m
# B3	10980 × 10980	10 m
# B4	10980 × 10980	10 m
# B5	5490 × 5490	20 m
# B6	5490 × 5490	20 m
# B7	5490 × 5490	20 m
# B8	10980 × 10980	10 m
# B8A	5490 × 5490	20 m
# B9	1830 × 1830	60 m
# B11	5490 × 5490	20 m
# B12	5490 × 5490	20 m

### Sub-sample 3k scenes (one-time copy step, kept commented out)

In [None]:
# root_dir = "/Volumes/WD_Rabina/competition/extracted_data/s2a"
# # List all folders
# scenes = sorted(glob.glob(os.path.join(root_dir, "*/")))
# # print(scenes)
# print(len(scenes))

# no_of_files=3000

# # Randomly select 3000 scenes (without replacement)
# selected_scenes = random.sample(scenes, k=3000)

# print(f"Total selected scenes: {len(selected_scenes)}")
# # print(selected_scenes[:10])  # show first 10 for sanity check

# # Path to new folder where selected scenes will be copied
# destination_root = "data/s2a_3k_sample"
# os.makedirs(destination_root, exist_ok=True)  # create folder if it doesn't exist

# # Copy each selected folder
# count=0
# for scene_path in selected_scenes:
#     # Get folder name only (e.g., "000015")
#     folder_name = os.path.basename(os.path.normpath(scene_path))
    
#     # Destination path
#     dest_path = os.path.join(destination_root, folder_name)
#     count=count+1
#     print(count)
#     # Copy folder and all its contents
#     shutil.copytree(scene_path, dest_path)

# print(f"Copied {len(selected_scenes)} folders to {destination_root}")



### Settings

In [None]:
# Image side length (pixels) used by the Resize transform and the MoCo crops.
target_size = 224

# DataLoader settings. Larger batches (128 or 256) are preferable when memory allows.
target_batch_size = 64
target_num_workers = 0

# Training length and fine-tuning mode:
# use_peft=True trains only backbone "layer4" (+ projection head); False trains everything.
target_max_epoch = 3
use_peft = False

# CSV metric logs go to the "logs" directory under the run name "moco_run".
logger = CSVLogger("logs", name="moco_run")

# torchgeo's default MoCo augmentation (SimCLR-style, https://arxiv.org/abs/2002.05709)
# uses scale=(0.2, 1) crops, strong brightness/contrast jitter, random grayscale and
# heavier blur. A milder variant is used here: larger minimum crop scale, gentler
# brightness jitter, lighter blur, no contrast/grayscale, plus vertical flips.
#
# MoCoTask applies these to DataLoader batches whose samples are already z-score
# normalised by `transform` (Normalize), so values are not confined to [0, 1].
# kornia's RandomBrightness clamps its output to [0, 1] unless clip_output=False,
# which would squash the normalised data; clipping is therefore disabled.
aug = K.AugmentationSequential(
    K.RandomResizedCrop(size=(target_size, target_size), scale=(0.4, 1.0)),
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomGaussianBlur(kernel_size=(7, 7), sigma=(0.1, 1.5), p=0.3),
    K.RandomBrightness(brightness=(0.85, 1.15), clip_output=False, p=0.5),
    data_keys=['input'],
)

### Helper Functions

In [None]:
class SSLDataset(Dataset):
    """Self-supervised dataset yielding one multi-band image per scene timestamp.

    Each scene folder is expected to contain one sub-folder per timestamp. Each
    timestamp folder holds one ``<band>.tif`` per requested band plus ``B2.tif``,
    whose grid is used as the common output grid. Bands stored on a different grid
    are resampled bilinearly onto it. A zero-filled B10 channel is inserted (unless
    "B10" itself is requested) so the channel layout matches a 13-band Sentinel-2 stack.
    """

    def __init__(self, scenes, bands, transforms=None):
        """
        Args:
            scenes (list): List of scene folder paths.
            bands (list): Band names to read, in output channel order
                (e.g. ["B1", "B2"]).
            transforms (callable, optional): Transform applied to the stacked
                (C, H, W) float32 tensor.
        """
        self.scenes = scenes
        self.bands = bands
        self.transforms = transforms
        # Never assigned by this class; kept so existing attribute access keeps working.
        self.target_h = None
        self.target_w = None

        # Every timestamp sub-folder of every scene is treated as an independent sample.
        self.samples = []
        for scene_path in scenes:
            timestamps = sorted(
                d for d in os.listdir(scene_path)
                if os.path.isdir(os.path.join(scene_path, d))
            )
            self.samples.extend(os.path.join(scene_path, ts) for ts in timestamps)

    def __len__(self):
        return len(self.samples)

    def _b10_insert_index(self):
        """Return the channel index for the zero-filled B10 placeholder, or None to skip it.

        No placeholder is added when "B10" is already requested. Otherwise it goes
        directly before "B11" (Sentinel-2 band order), falling back to index 10 when
        "B11" is not requested.
        """
        if "B10" in self.bands:
            return None
        if "B11" in self.bands:
            return self.bands.index("B11")
        return 10

    @staticmethod
    def _read_band(path, height, width):
        """Read band 1 of ``path`` as float32, bilinearly resampled to (height, width) if needed."""
        with rasterio.open(path) as src:
            if (src.height, src.width) == (height, width):
                arr = src.read(1)
            else:
                arr = src.read(
                    1,
                    out_shape=(height, width),
                    resampling=Resampling.bilinear,
                )
        return arr.astype(np.float32)

    def __getitem__(self, idx):
        ts_path = self.samples[idx]

        # B2 (10 m per the metadata table) defines the grid every band is brought onto.
        with rasterio.open(os.path.join(ts_path, "B2.tif")) as src:
            target_h, target_w = src.height, src.width

        band_arrays = [
            self._read_band(os.path.join(ts_path, f"{b}.tif"), target_h, target_w)
            for b in self.bands
        ]

        insert_idx = self._b10_insert_index()
        if insert_idx is not None:
            band_arrays.insert(insert_idx, np.zeros((target_h, target_w), dtype=np.float32))

        img = np.stack(band_arrays, axis=0)
        # img is already float32; from_numpy avoids the extra copy torch.tensor() makes.
        patch_tensor = torch.from_numpy(img)

        if self.transforms is not None:
            patch_tensor = self.transforms(patch_tensor)

        return {"image": patch_tensor}

def calculate_stats(dataset, n_samples=500):
    """Estimate per-channel mean and standard deviation over a subset of ``dataset``.

    Pixels from up to ``n_samples`` samples are pooled. The returned std is
    therefore the std of all sampled pixel values, including variation between
    images, not the average of per-image stds. Sample indices are spread evenly
    over the whole dataset rather than taken from the front: SSLDataset orders its
    samples scene by scene, so the first ``n_samples`` would cover only a few scenes.

    Args:
        dataset: Indexable object supporting ``len()``; ``dataset[i]["image"]`` is a
            (C, H, W) tensor. H and W may differ between samples; C must not.
        n_samples (int): Maximum number of samples to read.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, std)``, each float32 of shape
        (C,). The std uses the unbiased (N - 1) estimator, like ``Tensor.std``.

    Raises:
        ValueError: If ``dataset`` is empty or ``n_samples`` is not positive.
    """
    total = len(dataset)
    n = min(total, n_samples)
    if n <= 0:
        raise ValueError("calculate_stats needs at least one sample to read")

    # Evenly spaced, distinct indices running from the first to the last sample.
    if n == 1:
        indices = [0]
    else:
        indices = [i * (total - 1) // (n - 1) for i in range(n)]

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for i in indices:
        # float64 accumulation keeps the sum of squares of large pixel values accurate.
        img = dataset[i]["image"].double()
        flat = img.reshape(img.shape[0], -1)
        if channel_sum is None:
            channel_sum = flat.sum(dim=1)
            channel_sq_sum = (flat * flat).sum(dim=1)
        else:
            channel_sum += flat.sum(dim=1)
            channel_sq_sum += (flat * flat).sum(dim=1)
        pixel_count += flat.shape[1]

    mean = channel_sum / pixel_count
    # Unbiased pooled variance; clamp guards against tiny negatives from rounding.
    var = (channel_sq_sum - channel_sum * mean) / max(pixel_count - 1, 1)
    std = var.clamp_min(0).sqrt()
    return mean.float(), std.float()

In [None]:
root_dir = "data/s2a"

# Every immediate sub-directory of root_dir is one scene; the trailing "/" in the
# glob pattern restricts matches to directories.
scene_pattern = os.path.join(root_dir, "*/")
scenes = sorted(glob.glob(scene_pattern))
# For a quick debug run, hard-code a couple of scenes instead, e.g.:
# scenes = ["data/s2a/000015", "data/s2a/000016"]

# Band files read per timestamp, in output channel order. B10 is not read from disk.
bands = "B1 B2 B3 B4 B5 B6 B7 B8 B8A B9 B11 B12".split()

In [None]:
# One-time computation of per-channel mean/std. This reads up to 500 samples
# (calculate_stats' default), which is slow, and the resulting values are hard-coded
# in the next cell. It is therefore skipped unless explicitly requested.
# Set RECOMPUTE_STATS = True to regenerate them (e.g. after changing the data or bands).
RECOMPUTE_STATS = False
if RECOMPUTE_STATS:
    temp_dataset = SSLDataset(scenes, bands)
    mean, std = calculate_stats(temp_dataset)
    print(mean)
    print(std)

In [None]:
# Per-channel statistics in SSLDataset channel order (B1..B9, zero-filled B10, B11, B12),
# presumably copied from the printed output of the one-time stats cell above.
mean = [
    2358.7412,  # B1
    2402.7629,  # B2
    2580.9255,  # B3
    2614.2227,  # B4
    3057.6877,  # B5
    3578.1008,  # B6
    3796.8345,  # B7
    3795.6868,  # B8
    3947.5913,  # B8A
    4833.6362,  # B9
    0.0000,     # B10 (zero padding)
    3379.1743,  # B11
    2666.4465,  # B12
]
std = [
    2994.4861,  # B1
    2847.0354,  # B2
    2542.9307,  # B3
    2411.1196,  # B4
    2399.0249,  # B5
    2137.6804,  # B6
    2036.8357,  # B7
    2042.7140,  # B8
    1957.9615,  # B8A
    3559.4121,  # B9
    0.0000,     # B10 (zero padding)
    1535.7960,  # B11
    1393.8278,  # B12
]

# The zero-filled B10 channel has std 0; floor every std at 1e-5 so Normalize never
# divides by zero.
std = [max(value, 1e-5) for value in std]

# Resize every sample to target_size x target_size, then z-score each channel with the
# per-channel mean/std above. antialias=True is passed explicitly because older
# torchvision versions default to no antialiasing for tensor inputs, which causes
# aliasing when downsampling large tiles.
transform = transforms.Compose([
    transforms.Resize((target_size, target_size), antialias=True),
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
dataset = SSLDataset(scenes, bands, transforms=transform)
print(len(dataset))
# Fail fast here, rather than deep inside trainer.fit, if the data path is wrong.
assert len(dataset) > 0, f"No timestamp folders found in {len(scenes)} scene(s); check root_dir"

sample_image = dataset[0]['image']
print(sample_image.shape)
# MoCoTask below is built with in_channels=13, and the transform resizes to target_size.
assert tuple(sample_image.shape) == (13, target_size, target_size), sample_image.shape

In [None]:
# drop_last=True keeps every training batch at full size, as the reference MoCo
# implementation does. pin_memory speeds up host-to-GPU copies and is only enabled
# when CUDA is available.
data_loader = DataLoader(
    dataset,
    batch_size=target_batch_size,
    shuffle=True,
    num_workers=target_num_workers,
    drop_last=True,
    pin_memory=torch.cuda.is_available(),
)
num_batches = len(data_loader)
print("Number of batches:", num_batches)
assert num_batches > 0, "Dataset is smaller than one batch; lower target_batch_size"

In [None]:
task = MoCoTask(
    model="resnet50",
    weights=ResNet50_Weights.SENTINEL2_ALL_MOCO,
    in_channels=13,
    version=2,             # MoCo v2
    size=target_size,
    augmentation1=aug,
    augmentation2=aug,
)

# Start the momentum (key) branch as an exact copy of the gradient-trained branch, as the
# reference MoCo implementation does. The pretrained `weights` are loaded into
# task.backbone; copying ensures the momentum branch starts from them as well.
# NOTE(review): confirm whether the installed torchgeo already initialises
# backbone_momentum from `weights`. If it does, this copy is a harmless no-op.
task.backbone_momentum.load_state_dict(task.backbone.state_dict())
if hasattr(task, "projection_head_momentum"):
    task.projection_head_momentum.load_state_dict(task.projection_head.state_dict())

# -----------------------------
# PEFT / Full Fine-Tuning Logic
# -----------------------------
def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad = flag


if not use_peft:
    print("Full fine-tuning: backbone and projection head trainable...")
    _set_requires_grad(task.backbone, True)
else:
    print("Using PEFT: freezing backbone except last block, training projection head...")
    # Only parameters whose name contains "layer4" (the last residual stage) stay trainable.
    for param_name, p in task.backbone.named_parameters():
        p.requires_grad = "layer4" in param_name

# The momentum backbone is never trained by gradients.
_set_requires_grad(task.backbone_momentum, False)

# The projection head is always trained.
_set_requires_grad(task.projection_head, True)

# fp16 mixed precision is only requested on CUDA. Lightning does not support fp16 AMP
# on CPU (depending on the version it errors or silently substitutes bf16), so plain
# 32-bit precision is requested explicitly there.
use_cuda = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=target_max_epoch,
    enable_progress_bar=True,
    log_every_n_steps=1,
    precision="16-mixed" if use_cuda else "32-true",
    accelerator="gpu" if use_cuda else "cpu",
    logger=logger,
)
trainer.fit(task, data_loader)

[W205 17:00:02.187859000 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.


In [None]:
# Metrics logged during the most recent fit.
logged_metrics = task.trainer.logged_metrics
print(logged_metrics)

In [None]:
# Persist only the gradient-trained ResNet-50 backbone (not backbone_momentum or the
# projection heads) for downstream use.
encoder_state = task.backbone.state_dict()
torch.save(encoder_state, "ssl_encoder.pth")

In [None]:
# Full Lightning checkpoint of the trained task, for resuming or reloading later.
checkpoint_path = "ssl_full_ckpt.ckpt"
trainer.save_checkpoint(checkpoint_path)