In [1]:
#!/usr/bin/env python
# ============================================================
#  SITS-BERT  ·  California-Labeled  ·  6-band data prep
# ------------------------------------------------------------
#  Reads  : /kaggle/input/california-labeled/{Train,Validate,Test}.csv
#  Writes : /kaggle/working/data/6_features_{Train,Validate,Test}.csv
# ------------------------------------------------------------
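#  Each output row = 24 × 7 floats  (B1-6 + DOY)  +  label
# ------------------------------------------------------------
#  NOTE: the fine-tuning cell below reads the original CSVs directly
#  with --num_features 10 and --max_length 64; the 6-band files
#  described above are not written by the cells shown here.
# ============================================================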

import csv, subprocess, sys, shutil, random
from pathlib import Path
import numpy as np

In [2]:
# ───────────────────────────────────
# 0 · CONFIG
# ───────────────────────────────────
# Seed value kept in one place for any step that needs determinism.
SEED = 42

# Kaggle locations: the labeled input dataset, the output folder for
# prepared data, and the SITS-BERT checkout.
SRC_DIR = Path("/kaggle/input") / "california-labeled"
DST_DIR = Path("/kaggle/working") / "data"
SITS_REPO = Path("/kaggle/working") / "SITS-BERT"


In [3]:
!git clone -q https://github.com/linlei1214/SITS-BERT.git

In [4]:
!ls SITS-BERT/checkpoints/pretrain/

checkpoint.bert.pth


In [5]:
!mkdir checkpoints_finetune

In [6]:
!python /kaggle/working/SITS-BERT/code/finetuning.py \
  --file_path /kaggle/input/california-labeled/ \
  --pretrain_path SITS-BERT/checkpoints/pretrain/ \
  --finetune_path /kaggle/working/checkpoints_finetune/ \
  --num_features 10 \
  --max_length 64 \
  --num_classes 13 \
  --epochs 100 \
  --batch_size 128 \
  --hidden_size 256 \
  --layers 3 \
  --attn_heads 8 \
  --learning_rate 2e-4 \
  --dropout 0.1

2025-06-17 18:47:17.997347: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750186038.199726      79 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750186038.254220      79 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Loading Data sets...
training samples: 1300, validation samples: 1300, testing samples: 318588
Creating Dataloader...
Initialing SITS-BERT...
Loading pre-trained model parameters...
Creating Downstream Task Trainer...
Fine-tuning SITS-BERT...
EP0, train_OA=23.38, train_Kappa=0.170, validate_OA=43.46, validate_Kappa=0.388
EP:0 Model Saved on: /kaggle/working/checkpoints_finetune/checkpoint.tar
EP1, train_OA=53.31, train_Kappa=0.494, v

In [7]:
# %matplotlib inline

# import sys, inspect, csv, numpy as np, torch
# from pathlib import Path
# from torch.utils.data import Dataset, DataLoader
# from sklearn.metrics import classification_report, accuracy_score, cohen_kappa_score, confusion_matrix
# import matplotlib.pyplot as plt

# # ───────────────────────────── Configuration ──────────────────────────────
# ROOT       = Path("/kaggle/working")
# CSV_TEST   = Path("/kaggle/input/california-labeled/Test.csv")
# CKPT_FILE  = ROOT / "checkpoints_finetune" / "checkpoint.tar"
# SITS_REPO  = ROOT / "SITS-BERT" / "code"
# BATCH_SIZE = 128
# WINDOW_LEN = 64
# STRIDE     = 1   # same as during training
# DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # ─────────── Add repo to path & import ────────────────────────────────────
# sys.path.append(str(SITS_REPO))
# from model.classification_model import SBERTClassification
# from model.bert                 import SBERT

# # ─────────── Build SBERT helper ──────────────────────────────────────────
# _sig = inspect.signature(SBERT.__init__)
# def make_sbert(num_features: int) -> torch.nn.Module:
#     base = {
#         "num_features": num_features,
#         "hidden":       256,
#         "n_layers":     3,
#         "attn_heads":   8,
#         "dropout":      0.1,
#     }
#     kwargs = {k: v for k, v in base.items() if k in _sig.parameters}
#     if (missing := [p.name for p in _sig.parameters.values()
#                     if p.default is inspect._empty
#                        and p.name not in kwargs
#                        and p.name != "self"]):
#         raise RuntimeError(f"❌ Missing SBERT init args: {missing}")
#     return SBERT(**kwargs)

# # ─────────── Sliding-window Dataset ────────────────────────────────────
# class WindowedDataset(Dataset):
#     def __init__(self, path: Path, win_len: int = WINDOW_LEN, stride: int = STRIDE):
#         rows = list(csv.reader(path.open()))
#         self.x, self.doy, self.mask, self.y = [], [], [], []

#         for r in rows:
#             if not r: 
#                 continue
#             arr   = np.asarray(r[:-1], dtype=np.float32)
#             label = int(r[-1])
#             n_feat = 11       # 10 bands + DOY
#             spec   = n_feat-1
#             T      = arr.size // n_feat

#             # compute start indices
#             if T < win_len:
#                 starts = [0]
#             else:
#                 starts = list(range(0, T - win_len + 1, stride))
#                 if (T - win_len) % stride != 0:
#                     starts.append(T - win_len)

#             for s in starts:
#                 block = arr[s*n_feat : s*n_feat + win_len*n_feat]
#                 real = min(win_len, block.size//n_feat)
#                 if block.size < win_len*n_feat:
#                     pad = np.zeros(win_len*n_feat - block.size, dtype=np.float32)
#                     block = np.concatenate([block, pad], axis=0)

#                 seq = block.reshape(win_len, n_feat)
#                 self.x.append(seq[:, :spec] / 10000.0)
#                 self.doy.append(seq[:, spec].astype(np.int64))
#                 m = np.zeros(win_len, bool)
#                 m[:real] = True
#                 self.mask.append(m)
#                 self.y.append(label)

#         # tensorify
#         self.x    = torch.tensor(np.stack(self.x),    dtype=torch.float32)
#         self.doy  = torch.tensor(np.stack(self.doy),  dtype=torch.long)
#         self.mask = torch.tensor(np.stack(self.mask), dtype=torch.bool)
#         self.y    = torch.tensor(self.y,              dtype=torch.long)
#         print(f"✅ Loaded {len(self.y):,} windows "
#               f"(each {win_len} steps of 10 bands+DOY)")

#     def __len__(self):
#         return len(self.y)

#     def __getitem__(self, i):
#         return self.x[i], self.doy[i], self.mask[i], self.y[i]

# # ─────────── Build dataset & loader ─────────────────────────────────────
# ds     = WindowedDataset(CSV_TEST, win_len=WINDOW_LEN, stride=STRIDE)
# loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

# # ─────────── Build model & load checkpoint ───────────────────────────────
# sbert = make_sbert(ds.x.size(-1))
# model = SBERTClassification(sbert, int(ds.y.max()+1)).to(DEVICE)
# ckpt  = torch.load(CKPT_FILE, map_location=DEVICE)
# model.load_state_dict(ckpt["model_state_dict"], strict=True)
# model.eval()

# # ─────────── Inference at window-level ───────────────────────────────────
# y_true, y_pred = [], []
# with torch.no_grad():
#     pe_len = model.sbert.embedding.position.pe.size(0)
#     for xb, doy, mask, yb in loader:
#         xb   = xb.to(DEVICE)
#         doy  = torch.clamp(doy.to(DEVICE) - 1, 0, pe_len - 1)
#         mask = mask.to(DEVICE)
#         logits = model(xb, doy, mask)  # [B, C]
#         preds  = logits.argmax(dim=1).cpu().tolist()
#         y_pred.extend(preds)
#         y_true.extend(yb.tolist())

# # ─────────── Metrics ──────────────────────────────────────────────────────
# oa    = accuracy_score(y_true, y_pred) * 100
# kappa = cohen_kappa_score(y_true, y_pred)
# report = classification_report(y_true, y_pred, digits=4, zero_division=0)
# print(f"\nTest OA   = {oa:.2f}%")
# print(f"Test κ    = {kappa:.3f}")
# print(report)

# # ─────────── Confusion matrix ─────────────────────────────────────────────
# cm = confusion_matrix(y_true, y_pred)
# plt.figure(figsize=(10, 8))
# try:
#     import seaborn as sns
#     sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
# except ImportError:
#     plt.imshow(cm, interpolation="nearest", aspect="auto")
# plt.title("Window-level SITS-BERT Confusion Matrix")
# plt.xlabel("Predicted"); plt.ylabel("True")
# plt.tight_layout()
# plt.show()