# Versi 1: 1 Patch untuk 1 Token

In [None]:
# Cell 1
# Import semua library dasar yang dibutuhkan

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Cell 2
# 3D encoder: extracts spatial + spectral features from a hyperspectral patch
# and compresses them into one embedding vector.

class SpectralSpatialEncoder3D(nn.Module):
    """3D-CNN encoder mapping one patch to a single embedding vector.

    Two Conv3d -> BatchNorm3d -> ReLU stages, each with kernel == stride (so the
    convolution windows do not overlap), followed by global average pooling and
    a linear projection.

    Args:
        embedding_dim: length of the output embedding.
        init_channels: channels of the first conv stage; the second stage uses
            twice as many.

    Input: a 5-D tensor (B, 1, D, H, W), as nn.Conv3d with one input channel expects.
    Output: tensor of shape (B, embedding_dim).
    """

    def __init__(self, embedding_dim=256, init_channels=32):
        super().__init__()
        wide_channels = init_channels * 2

        # Stage 1: non-overlapping (3, 3, 20) windows over the input's (D, H, W) axes.
        # NOTE(review): nn.Conv3d orders kernel dims as (D, H, W). The "3x3 spatial x
        # 20 spectral" intent only holds if the spectral axis is W (last); if the data
        # pipeline puts bands on D, the 20-wide window lands on a spatial axis instead.
        # Confirm against the actual patch layout.
        self.conv1 = nn.Conv3d(1, init_channels, kernel_size=(3, 3, 20), stride=(3, 3, 20), padding=0)
        self.bn1 = nn.BatchNorm3d(init_channels)
        self.relu = nn.ReLU(inplace=True)

        # Stage 2: non-overlapping (3, 3, 1) windows; shrinks the first two axes by 3x
        # (intended to reduce the spatial map to 1x1).
        self.conv2 = nn.Conv3d(init_channels, wide_channels, kernel_size=(3, 3, 1), stride=(3, 3, 1), padding=0)
        self.bn2 = nn.BatchNorm3d(wide_channels)

        # Final projection from pooled channels to the embedding size.
        self.fc = nn.Linear(wide_channels, embedding_dim)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))

        # Average over whatever (D, H, W) extent remains -> one value per channel.
        pooled = F.adaptive_avg_pool3d(x, (1, 1, 1)).flatten(1)  # (B, C, 1, 1, 1) -> (B, C)
        return self.fc(pooled)


In [None]:
# Cell 3
# Adds Gaussian noise to encoder embeddings to produce the positive keys
# for contrastive learning.

class LatentAugmentor:
    """Create positive keys by perturbing embeddings with isotropic Gaussian noise.

    Args:
        sigma: standard deviation of the added noise.
        device: stored on the instance but not used; the noise always follows the
            device (and dtype) of the tensor passed to ``__call__``.
    """

    def __init__(self, sigma=0.1, device='cpu'):
        self.sigma = sigma
        self.device = device

    def __call__(self, features):
        """Return ``features`` plus N(0, sigma^2) noise of identical shape.

        features: tensor, typically (B, D); the result has the same shape,
        dtype and device.
        """
        return features + self.sigma * torch.randn_like(features)


In [None]:
# Cell 4
# Simple transformer encoder over token sequences shaped (batch, num_tokens, embed_dim).
# In the current design num_tokens = 1 (one patch = one token).

class SimpleTransformerEncoder(nn.Module):
    """Stack of standard PyTorch transformer encoder layers in batch-first layout.

    Args:
        embed_dim: token embedding size (d_model).
        num_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
        mlp_dim: hidden width of each layer's feed-forward block.
        dropout: dropout probability inside each layer.

    Note: with a single token, softmax attention has only one key to attend to,
    so each attention block reduces to its value/output projections.
    """

    def __init__(self, embed_dim=256, num_heads=8, num_layers=2, mlp_dim=512, dropout=0.1):
        super().__init__()
        layer_kwargs = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        template_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.transformer = nn.TransformerEncoder(template_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode ``x`` of shape (B, T, D); returns a tensor of the same shape."""
        encoded = self.transformer(x)
        return encoded


In [None]:
# Cell 5
# Projection head: maps transformer outputs into the latent space where
# cosine similarities for the contrastive loss are computed.

class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear(in_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> proj_dim).

    The linear layers act on the last axis, so (B, in_dim) -> (B, proj_dim) and
    (B, T, in_dim) -> (B, T, proj_dim).
    """

    def __init__(self, in_dim=256, proj_dim=128, hidden_dim=512):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected


In [None]:
# Cell 6
# InfoNCE loss: each query is contrasted against every key in the batch;
# the key with the same index is its positive, all others are negatives.

def info_nce_loss(q, k, temperature=0.1):
    """Compute the InfoNCE loss for a batch of (query, positive key) pairs.

    Args:
        q: queries, shape (B, D). Inputs with extra non-batch axes, such as the
            (B, 1, D) output of a single-token transformer + projection head,
            are flattened to (B, -1) first.
        k: positive keys, same shape as ``q``; ``k[i]`` is the positive for ``q[i]``.
        temperature: divisor applied to the cosine similarities before softmax.

    Returns:
        Scalar tensor: mean cross-entropy of each row of the (B, B) similarity
        matrix against its diagonal entry.
    """
    # Flatten non-batch axes so 3-D inputs like (B, 1, D) work: the normalization
    # below must run over the feature axis, and Tensor.t() accepts at most 2-D.
    q = q.reshape(q.size(0), -1)
    k = k.reshape(k.size(0), -1)

    # Unit-normalize so the dot product is cosine similarity.
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)

    # (B, B) similarity matrix; row i's correct "class" is column i.
    logits = torch.matmul(q, k.t()) / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)


In [None]:
# Cell 7
# Load the preprocessed patch arrays.
patch_class0, patch_class1 = (
    np.load("patch_class0.npy"),  # class 0: non-oil-spill
    np.load("patch_class1.npy"),  # class 1: oil-spill
)

# Sanity check: report each array's shape.
print("Class 0 shape:", patch_class0.shape)
print("Class 1 shape:", patch_class1.shape)

In [None]:
# Cell 8
# Stack both classes along the sample axis: class-0 patches first, then class-1.
X_all = np.concatenate((patch_class0, patch_class1), axis=0)

# One float label per sample, in the same order as X_all:
# 0.0 for every class-0 patch, 1.0 for every class-1 patch.
y_all = np.repeat([0.0, 1.0], [len(patch_class0), len(patch_class1)])

print("Total samples:", X_all.shape[0])
print("Labels shape:", y_all.shape)


In [None]:
# Cell 9
# numpy -> torch: float32 patches and int64 (long) labels; torch.tensor copies the data.
X_tensor = torch.tensor(X_all, dtype=torch.float32)
y_tensor = torch.tensor(y_all, dtype=torch.long)

# Insert a singleton channel axis, then move the last data axis to position 2:
# (N, A, B, C) -> (N, 1, A, B, C) -> (N, 1, C, A, B), i.e. the (N, C, D, H, W)
# layout nn.Conv3d consumes, with the original last axis as depth D.
# NOTE(review): presumably X_all is (N, H, W, bands), which puts bands on D;
# confirm this matches the kernel axis order used by the encoder.
X_tensor = X_tensor[:, None].permute(0, 1, 4, 2, 3)

print("Tensor shape setelah permute:", X_tensor.shape)


In [None]:
# Cell 10
# Split into train and validation sets (80% / 20%).

# Fixed seed so the split is identical every time the notebook is re-run.
# The training cell resumes from a checkpoint; with an unseeded permutation a
# restarted kernel would draw a fresh split and leak samples the resumed model
# already trained on into the validation set.
SPLIT_SEED = 42

# Train / validation sizes
train_size = int(0.8 * len(X_tensor))
val_size = len(X_tensor) - train_size

# One permutation shared by X and y so samples and labels stay aligned.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(X_tensor), generator=split_generator)
train_idx = indices[:train_size]
val_idx = indices[train_size:]

# Slice samples and labels with the same indices
train_X = X_tensor[train_idx]
train_y = y_tensor[train_idx]
val_X = X_tensor[val_idx]
val_y = y_tensor[val_idx]

# Datasets
train_dataset = TensorDataset(train_X, train_y)
val_dataset = TensorDataset(val_X, val_y)

# DataLoaders: shuffle training batches each epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Train samples: {len(train_dataset)} | Validation samples: {len(val_dataset)}")


In [None]:
# Cell 11
# Initialize the device, model components and optimizer.

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Model components. The augmentor holds no parameters; it derives the noise
# device from the tensor it receives, so it is not moved to `device`.
encoder = SpectralSpatialEncoder3D(embedding_dim=256).to(device)
augmentor = LatentAugmentor(sigma=0.08)
transformer = SimpleTransformerEncoder(embed_dim=256).to(device)
proj_head = ProjectionHead(in_dim=256, proj_dim=128).to(device)

# A single AdamW over every trainable parameter. Keep the module order
# (encoder, transformer, proj_head) fixed: optimizer.state_dict() refers to
# parameters by position and is reloaded from the checkpoint in the training cell.
params = [p for module in (encoder, transformer, proj_head) for p in module.parameters()]
optimizer = optim.AdamW(params, lr=1e-4, weight_decay=0.01)



In [None]:
# Cell 12
# Training loop with per-epoch validation, checkpointing and resume.

import os
import time
from tqdm import tqdm

# ==== PARAMETERS ====
START_EPOCH = 1          # default; overwritten automatically when a checkpoint exists
NUM_EPOCHS = 20
temperature = 0.1
best_val_loss = float('inf')
checkpoint_path = "checkpoint_sst_transformer.pt"
best_model_path = "best_sst_transformer.pt"
# Fixed seed for the latent noise used during validation, so every epoch is
# scored against the same perturbations and the best-model comparison is fair.
VAL_NOISE_SEED = 1234


def _set_train_mode(is_train):
    """Switch encoder, transformer and projection head to train (True) or eval (False) mode."""
    for module in (encoder, transformer, proj_head):
        module.train(is_train)


def _contrastive_loss(batch_X):
    """Return the InfoNCE loss between a batch's embeddings and their noise-augmented copies."""
    features = encoder(batch_X)
    aug_features = augmentor(features)
    # Each patch is one token: (B, D) -> (B, 1, D) for the transformer. squeeze(1)
    # removes the token axis again so the loss receives (B, proj_dim);
    # info_nce_loss normalizes over dim=1 and calls k.t(), which need 2-D inputs.
    z_orig = proj_head(transformer(features.unsqueeze(1))).squeeze(1)
    z_aug = proj_head(transformer(aug_features.unsqueeze(1))).squeeze(1)
    return info_nce_loss(z_orig, z_aug, temperature=temperature)


# ==== Resume from the checkpoint if one exists ====
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    encoder.load_state_dict(checkpoint["encoder_state"])
    transformer.load_state_dict(checkpoint["transformer_state"])
    proj_head.load_state_dict(checkpoint["proj_head_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    START_EPOCH = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    print(f"[OK] Checkpoint ditemukan. Melanjutkan dari epoch {START_EPOCH}.")
else:
    print("[MAAF] Tidak ditemukan checkpoint. Memulai training dari awal.")

# ==== Training ====
for epoch in range(START_EPOCH, NUM_EPOCHS + 1):  # epochs are 1-based; NUM_EPOCHS is inclusive
    start_time = time.time()

    # ------------------------
    # TRAIN
    # ------------------------
    _set_train_mode(True)
    total_train_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Train]", leave=True)
    for batch_X, _ in train_bar:
        loss = _contrastive_loss(batch_X.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        train_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_train_loss = total_train_loss / len(train_loader)

    # ------------------------
    # VALIDATION
    # ------------------------
    _set_train_mode(False)
    total_val_loss = 0.0

    val_bar = tqdm(val_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Val]", leave=True)
    # fork_rng restores the global RNG state on exit, so seeding the validation
    # noise here does not change the random stream used for training.
    with torch.no_grad(), torch.random.fork_rng():
        torch.manual_seed(VAL_NOISE_SEED)
        for batch_X, _ in val_bar:
            loss = _contrastive_loss(batch_X.to(device))
            total_val_loss += loss.item()
            val_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_val_loss = total_val_loss / len(val_loader)
    elapsed = time.time() - start_time

    print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Time: {elapsed:.2f}s")

    # ------------------------
    # CHECKPOINTS
    # ------------------------
    # Update the best score BEFORE building the checkpoint so the stored
    # "best_val_loss" includes this epoch. Otherwise a resumed run compares
    # against a stale best and can overwrite the best model with a worse one.
    is_best = avg_val_loss < best_val_loss
    if is_best:
        best_val_loss = avg_val_loss

    checkpoint = {
        "epoch": epoch,
        "encoder_state": encoder.state_dict(),
        "transformer_state": transformer.state_dict(),
        "proj_head_state": proj_head.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
    }
    # Latest state (for resuming) always goes to its own file.
    torch.save(checkpoint, checkpoint_path)

    # Best model so far goes to a separate file.
    if is_best:
        torch.save(checkpoint, best_model_path)
        print("OK Model terbaik disimpan.")

print("Alhamdulillah, Training selesai.")

