### Imports and Dependencies

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import zarr
import cv2
from tqdm import tqdm
import pandas as pd


### Preprocessing functions for SAR Data

In [None]:
def bilateral_denoise(image):
    """Smooth ``image`` with OpenCV's bilateral filter.

    The input is cast to float32 first; the filter runs with a pixel
    neighbourhood diameter of 9 and colour/space sigmas of 75 each.
    """
    as_float = image.astype(np.float32)
    # Positional arguments after the image: d, sigmaColor, sigmaSpace.
    return cv2.bilateralFilter(as_float, 9, 75, 75)

def apply_clahe(image):
    """Return a contrast-equalised uint8 version of ``image``.

    The image is min-max rescaled to the 0..255 range, cast to uint8, and
    passed through CLAHE (clip limit 2.0, 8x8 tile grid).
    """
    stretched = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)
    as_bytes = stretched.astype(np.uint8)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equalizer.apply(as_bytes)

def sobel_edges(image):
    """Return the Sobel gradient magnitude of ``image`` as float32.

    Horizontal and vertical first derivatives are computed with a 3x3
    kernel and combined via ``cv2.magnitude``.
    """
    # (dx, dy) derivative orders: x-gradient first, then y-gradient.
    gradients = [
        cv2.Sobel(image, cv2.CV_32F, dx, dy, ksize=3)
        for dx, dy in ((1, 0), (0, 1))
    ]
    return cv2.magnitude(*gradients)

def downsample(image, target_shape=(64, 64)):
    """Resize ``image`` to ``target_shape`` with scikit-image's ``resize``.

    Uses reflect-mode boundary handling and anti-aliasing, and keeps the
    original value range instead of rescaling it.
    """
    resize_options = {
        "mode": 'reflect',
        "preserve_range": True,
        "anti_aliasing": True,
    }
    return resize(image, target_shape, **resize_options)


### Dataset for Preprocessed SAR Cubes

In [None]:
class SimCLRSARDataset(Dataset):
    """Dataset yielding pairs of preprocessed temporal-statistics tensors.

    Every ``*.zarr`` entry under ``root_dir`` is opened and each index along
    the first axis of its ``'bands'`` array becomes one sample.

    Args:
        root_dir: Directory containing the ``.zarr`` stores.
        folder_limit: Maximum number of ``.zarr`` stores to index (in sorted
            name order). ``None`` indexes all of them. The limit is applied
            after non-``.zarr`` entries are discarded, so stray files in the
            directory do not consume slots.
        transform: Optional callable applied independently to produce each of
            the two returned views. When ``None`` (the default) both views are
            identical copies of the same preprocessed tensor, i.e. no data
            augmentation happens in this dataset.
    """

    def __init__(self, root_dir, folder_limit=None, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Filter first, then limit: slicing the raw listing would let
        # non-zarr entries count against ``folder_limit``.
        zarr_folders = [
            name for name in sorted(os.listdir(root_dir))
            if name.endswith(".zarr")
        ]
        if folder_limit is not None:
            zarr_folders = zarr_folders[:folder_limit]

        for folder in zarr_folders:
            folder_path = os.path.join(root_dir, folder)
            z = zarr.open(folder_path, mode='r')
            num_cubes = z['bands'].shape[0]
            self.samples.extend((folder_path, i) for i in range(num_cubes))

    def __len__(self):
        """Return the number of indexed (store, sample index) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return two float32 views of sample ``idx``.

        With the module's default ``downsample`` target each view has 12
        channels of 64x64 pixels (4 per-timestep features x mean/std/diff).
        """
        zarr_path, sample_idx = self.samples[idx]
        # The store is reopened on every access rather than kept open.
        z = zarr.open(zarr_path, mode='r')
        sar_cube = z['bands'][sample_idx]  # [T=4, 2, H, W] per original note -- TODO confirm

        features = self._temporal_features(sar_cube)
        view = torch.tensor(features, dtype=torch.float32)

        if self.transform is None:
            # Two independent tensors holding the same values.
            return view, view.clone()
        return self.transform(view), self.transform(view)

    @staticmethod
    def _temporal_features(sar_cube):
        """Turn one SAR cube into a stack of downsampled temporal statistics."""
        processed = []
        for t in range(sar_cube.shape[0]):
            # Channel 0 / 1 are treated as VV / VH polarisations.
            vv_denoised = bilateral_denoise(sar_cube[t, 0])
            vh_denoised = bilateral_denoise(sar_cube[t, 1])
            vv_clahe = apply_clahe(vv_denoised)
            edge_map = sobel_edges(vv_clahe)
            processed.append(
                np.stack([vv_denoised, vh_denoised, vv_clahe, edge_map], axis=0)
            )

        processed = np.stack(processed)  # [T, 4, H, W]
        mean_img = np.mean(processed, axis=0)
        std_img = np.std(processed, axis=0)
        diff_img = processed[-1] - processed[0]  # last minus first timestep

        temporal_stack = np.concatenate([mean_img, std_img, diff_img], axis=0)
        # Each channel is resized on its own.
        return np.stack([downsample(channel) for channel in temporal_stack])


### CNN + Transformer-Based Encoder Model

In [None]:
class CNNTransformerEncoder(nn.Module):
    """Convolutional stem followed by a small transformer encoder.

    Four stride-2 3x3 convolutions (each followed by batch norm and ReLU)
    widen the channels to 64, 128, 256 and 512. Every spatial position of the
    resulting feature map becomes a token, projected to 128 dimensions and
    passed through a 2-layer transformer encoder (8 heads, feed-forward size
    256). The tokens are mean-pooled and mapped to ``compressed_dim``.

    Args:
        input_channels: Number of channels in the input image tensor.
        compressed_dim: Size of the output embedding.
    """

    def __init__(self, input_channels=12, compressed_dim=1024):
        super().__init__()

        # Build the stem stage by stage; layer order (conv, bn, relu) and
        # therefore the Sequential indices match a hand-written listing.
        widths = [input_channels, 64, 128, 256, 512]
        stem_layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stem_layers.append(nn.Conv2d(in_ch, out_ch, 3, 2, 1))
            stem_layers.append(nn.BatchNorm2d(out_ch))
            stem_layers.append(nn.ReLU())
        self.cnn = nn.Sequential(*stem_layers)

        self.proj = nn.Linear(512, 128)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=128, nhead=8, dim_feedforward=256, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        self.fc = nn.Linear(128, compressed_dim)

    def forward(self, x):
        """Encode a batch of images into ``(batch, compressed_dim)`` embeddings."""
        feature_map = self.cnn(x)
        batch, channels, _height, _width = feature_map.shape
        # Flatten the spatial grid into a token sequence: (B, H*W, C).
        tokens = feature_map.view(batch, channels, -1).transpose(1, 2)
        tokens = self.transformer(self.proj(tokens))
        pooled = tokens.mean(dim=1)
        return self.fc(pooled)


### NT-Xent Contrastive Loss Function

In [None]:
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Compute the NT-Xent (normalised temperature-scaled cross-entropy) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row in the concatenated batch acts as a negative.

    Args:
        z_i: Tensor of shape ``(B, D)`` with embeddings of the first views.
        z_j: Tensor of shape ``(B, D)`` with embeddings of the second views.
        temperature: Softmax temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: mean loss over all ``2 * B`` anchors.
    """
    batch_size = z_i.size(0)
    total = 2 * batch_size

    # L2-normalise so the dot products below are cosine similarities.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    similarity = torch.mm(z, z.T) / temperature

    # Exclude only self-similarity from the softmax. The positive must appear
    # exactly once per row; prepending it to a "negatives" matrix that still
    # contains it would count it twice in the denominator.
    self_mask = torch.eye(total, dtype=torch.bool, device=z.device)
    logits = similarity.masked_fill(self_mask, float("-inf"))

    # The positive for anchor k is at k + B (first half) or k - B (second half).
    targets = (torch.arange(total, device=z.device) + batch_size) % total

    return F.cross_entropy(logits, targets)


### SimCLR Training Loop

In [None]:
def train_simclr(model, dataset, epochs=10, batch_size=64, lr=1e-4):
    """Train ``model`` with the NT-Xent contrastive loss and save its weights.

    Args:
        model: Encoder module mapping a batch of inputs to embeddings.
        dataset: Dataset yielding ``(view_1, view_2)`` pairs.
        epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size for the shuffled loader.
        lr: Learning rate for the Adam optimiser.

    Side effects:
        Moves ``model`` to CUDA when available (CPU otherwise), writes the
        final ``state_dict`` to ``simclr_encoder.pth`` in the working
        directory, and prints a confirmation message.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        # Count steps explicitly: tqdm only refreshes ``pbar.n`` when it
        # redraws, so it can lag behind the real number of batches and
        # distort the running average.
        for step, (x1, x2) in enumerate(pbar, start=1):
            x1, x2 = x1.to(device), x2.to(device)
            z1 = model(x1)
            z2 = model(x2)
            loss = nt_xent_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({"loss": total_loss / step})

    # Only the weights after the final epoch are persisted.
    torch.save(model.state_dict(), "simclr_encoder.pth")
    print("✅ Model saved to simclr_encoder.pth")


### Extract Embeddings to CSV

In [None]:
def extract_embeddings(model, dataset, output_csv="sar_embeddings.csv"):
    """Embed every sample of ``dataset`` and write the result to a CSV file.

    Args:
        model: Encoder module; it is moved to CUDA when available (CPU
            otherwise) and switched to eval mode.
        dataset: Dataset yielding pairs; only the first element of each pair
            is embedded.
        output_csv: Destination path. One row per sample in dataset order,
            one column per embedding dimension, no index column.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    model.eval()

    # Unshuffled so CSV rows line up with dataset indices.
    batches = DataLoader(dataset, batch_size=64, shuffle=False)

    with torch.no_grad():
        chunks = [
            model(first_view.to(device)).cpu().numpy()
            for first_view, _ in tqdm(batches, desc="Extracting Embeddings")
        ]

    table = pd.DataFrame(np.concatenate(chunks, axis=0))
    table.to_csv(output_csv, index=False)
    print(f"📄 Embeddings saved to {output_csv}")


### Main Script: Training + Embedding Extraction

In [None]:
if __name__ == "__main__":
    import multiprocessing

    # Force the 'spawn' start method, overriding any method already set.
    multiprocessing.set_start_method('spawn', force=True)

    base_path = "D:/IVP _ project/data/SSL4EO-S12-v1.1/train/S1GRD"

    # Stage 1: contrastive pre-training on a limited number of stores.
    print("🚀 Training on 100 folders")
    model = CNNTransformerEncoder()
    train_dataset = SimCLRSARDataset(base_path, folder_limit=100)
    train_simclr(model, train_dataset)

    # Stage 2: reload the saved weights and embed a larger set of stores.
    print("📥 Extracting embeddings from 398 folders")
    full_dataset = SimCLRSARDataset(base_path, folder_limit=398)
    saved_weights = torch.load("simclr_encoder.pth")
    model.load_state_dict(saved_weights)
    extract_embeddings(model, full_dataset)
