In [None]:
import sys
from pathlib import Path

# Make the repository root (the parent of the notebook's working directory)
# importable, so the `src.*` imports in later cells resolve.
root = Path().resolve().parent
sys.path.append(str(root))

In [None]:
from dotenv import load_dotenv
import os

# Populate os.environ from a local .env file (if present), then read the
# wandb key. The value is kept for reference only: it is not passed to
# wandb.init explicitly anywhere in this notebook.
load_dotenv()
api_key = os.environ.get("WANDB_API_KEY")

In [None]:
from sklearn.model_selection import train_test_split
from src.data.timeseries_dataset import TimeSeriesDataset
from src.config import SENTINEL_DIR, MASK_DIR

# Sort the glob result: Path.glob yields entries in filesystem order, which is
# not guaranteed to be stable across machines, and train_test_split with a fixed
# random_state is only reproducible when its input order is fixed too.
all_sentinel_files = sorted(SENTINEL_DIR.glob("*_RGBNIRRSWIRQ_Mosaic.tif"))
if not all_sentinel_files:
    # Fail with an actionable message instead of sklearn's generic
    # "n_samples=0" error.
    raise FileNotFoundError(
        f"No '*_RGBNIRRSWIRQ_Mosaic.tif' files found in {SENTINEL_DIR}"
    )
train_ids, val_ids = train_test_split(all_sentinel_files, test_size=0.1, random_state=0)

print(train_ids[0])

In [None]:
from src.data.transform import ComposeTS, NormalizeBy, RandomCropTS, CenterCropTS

# Spatial crop size in pixels: 64 suits tiny tiles; use 128 or 256 only if the
# tiles are large enough.
CROP = 64

# Both pipelines apply NormalizeBy(10000.0) first (presumably scales raw
# reflectance values — confirm in src.data.transform). Training crops at a
# random position for augmentation; validation crops the centre so evaluation
# is deterministic.
train_transform = ComposeTS([NormalizeBy(10000.0), RandomCropTS(CROP)])
val_transform = ComposeTS([NormalizeBy(10000.0), CenterCropTS(CROP)])

dataset_kwargs = dict(sensor="sentinel", slice_mode="first_half")
train_ds = TimeSeriesDataset(train_ids, transform=train_transform, **dataset_kwargs)
val_ds = TimeSeriesDataset(val_ids, transform=val_transform, **dataset_kwargs)


In [None]:
from torch.utils.data import DataLoader

# Identical batching/worker settings for both loaders; only training shuffles.
loader_kwargs = dict(batch_size=4, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)


In [None]:
import torch

# Sanity check: which label values appear in a single training batch?
first_batch = next(iter(train_loader))
x, mask = first_batch
print(torch.unique(mask))


In [None]:
import torch
# Scan the un-transformed training tiles to confirm the labels contain both
# background (0) and change (>0) pixels. Each dataset access presumably loads a
# tile from disk, so stop as soon as both kinds have been seen rather than
# always reading every tile.
ds_raw = TimeSeriesDataset(train_ids, sensor="sentinel", slice_mode="first_half", transform=None)
has_zero = False
has_nonzero = False

for i in range(len(ds_raw)):
    _, m = ds_raw[i]
    # Elementwise checks directly on the mask; no need to build torch.unique first.
    has_zero = has_zero or bool((m == 0).any())
    has_nonzero = has_nonzero or bool((m > 0).any())
    if has_zero and has_nonzero:
        break

print("any tiles with background (0)?", has_zero)
print("any tiles with change (>0)?", has_nonzero)


In [None]:
import torch
from src.models.external.torchrs_fc_cd import FCEF

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pull one training batch only to read the temporal length T and the channel
# count C that the early-fusion model must be built for.
sample_x, _ = next(iter(train_loader))
_, T, C, H, W = sample_x.shape  # (B, T, C, H, W)

model = FCEF(channels=C, t=T, num_classes=2).to(device)
model

In [None]:
import torch

def compute_confusion_binary(y_pred, y_true, positive_class=1):
    """Count binary confusion-matrix cells for one class versus all others.

    Args:
        y_pred: tensor of predicted labels, e.g. (B, H, W).
        y_true: tensor of ground-truth labels with the same shape as ``y_pred``.
        positive_class: label value treated as "positive"; every other value
            counts as negative.

    Returns:
        Tuple ``(tp, fp, tn, fn)`` of Python scalars.
    """
    pred_pos = y_pred == positive_class
    true_pos = y_true == positive_class
    pred_neg = ~pred_pos
    true_neg = ~true_pos

    cells = (
        pred_pos & true_pos,  # true positives
        pred_pos & true_neg,  # false positives
        pred_neg & true_neg,  # true negatives
        pred_neg & true_pos,  # false negatives
    )
    tp, fp, tn, fn = (cell.sum().item() for cell in cells)
    return tp, fp, tn, fn

def compute_metrics_from_confusion(tp, fp, tn, fn, eps=1e-8):
    """Derive standard binary-segmentation metrics from confusion counts.

    ``eps`` is added to every denominator so empty cases (e.g. no predicted
    positives) yield 0.0 instead of raising ZeroDivisionError.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``iou`` (IoU of the positive class), in that order.
    """
    total = tp + tn + fp + fn
    predicted_positive = tp + fp
    actual_positive = tp + fn

    precision = tp / (predicted_positive + eps)
    recall = tp / (actual_positive + eps)

    return dict(
        accuracy=(tp + tn) / (total + eps),
        precision=precision,
        recall=recall,
        f1=2 * precision * recall / (precision + recall + eps),
        iou=tp / (tp + fp + fn + eps),
    )

def compute_batch_metrics(logits, mask, positive_class=1):
    """Compute binary metrics for a whole batch directly from model logits.

    Args:
        logits: class scores with the class axis at dim 1, e.g. (B, 2, H, W).
        mask: ground-truth labels, e.g. (B, H, W) with values in {0, 1}.
        positive_class: label treated as "positive".

    Returns:
        The metrics dict produced by ``compute_metrics_from_confusion``.
    """
    with torch.no_grad():
        predicted_labels = logits.argmax(dim=1)  # (B, H, W)
        confusion = compute_confusion_binary(predicted_labels, mask, positive_class)
    # The counts are plain Python scalars, so no autograd context is needed here.
    return compute_metrics_from_confusion(*confusion)


In [None]:
from src.models.external.torchrs_fc_cd import FCEF  # or your local file

# Smoke test on a freshly initialised model (separate from `model`): check the
# forward pass produces no NaNs and that the loss can be computed on one batch.
model1 = FCEF(channels=C, t=T, num_classes=2).to(device)  # C, T taken from a probe batch
criterion = torch.nn.CrossEntropyLoss()

batch_x, batch_mask = next(iter(train_loader))
x = batch_x.to(device)
mask = batch_mask.to(device)

model1.eval()
with torch.no_grad():
    logits = model1(x)
    nan_present = torch.isnan(logits).any().item()
    print("logits NaN:", nan_present)
    loss = criterion(logits, mask)
    print("single-batch loss:", loss.item())


In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
import wandb

# Hyperparameters (also recorded in the wandb config below).
lr = 1e-3
epochs = 10
# Mixed precision is only used on a GPU. On CPU, the scaler and autocast are
# explicitly disabled rather than relying on torch's "CUDA not available"
# fallback warnings.
use_amp = device == "cuda"

# Start a new wandb run to track this script.
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="nina_prosjektoppgave",
    # Set the wandb project where this run will be logged.
    project="FCEarlyFusion",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": lr,
        "architecture": "FCEF",
        "dataset": "sentinel",
        "epochs": epochs,
    },
)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
# With enabled=False the scaler is a pass-through: scale() returns the loss
# unchanged and step() just calls optimizer.step().
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(epochs):
    # ---- training ----
    model.train()
    total_loss = 0.0
    for x, mask in tqdm(train_loader, desc=f"epoch {epoch+1}"):
        x = x.to(device)          # (B, T, C, H, W)
        mask = mask.to(device)    # (B, H, W)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device, enabled=use_amp):
            logits = model(x)         # (B, 2, H, W)
            loss = criterion(logits, mask)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    # Confusion-matrix counts summed over the whole validation set; they must
    # be reset at the start of every epoch before accumulating.
    sum_tp = sum_fp = sum_tn = sum_fn = 0
    with torch.no_grad():
        for x, mask in val_loader:
            x = x.to(device)
            mask = mask.to(device)
            with torch.amp.autocast(device_type=device, enabled=use_amp):
                logits = model(x)
                loss = criterion(logits, mask)
            val_loss += loss.item()

            pred = torch.argmax(logits, dim=1)
            tp, fp, tn, fn = compute_confusion_binary(pred, mask, positive_class=1)
            sum_tp += tp
            sum_fp += fp
            sum_tn += tn
            sum_fn += fn

    avg_val_loss = val_loss / len(val_loader)
    # Metrics over the pooled counts of the whole validation set, not a mean of
    # per-batch metrics.
    val_metrics = compute_metrics_from_confusion(sum_tp, sum_fp, sum_tn, sum_fn)

    run.log({"avg_train_loss": avg_train_loss,
             "avg_val_loss": avg_val_loss,
             "IoU": val_metrics['iou'],
             "F1": val_metrics['f1'],
             "Precision": val_metrics['precision'],
             "Recall": val_metrics['recall'],
             "Accuracy": val_metrics['accuracy']})

    print(
        f"epoch {epoch+1}: "
        f"train={avg_train_loss:.4f} "
        f"val={avg_val_loss:.4f} "
        f"IoU={val_metrics['iou']:.4f} "
        f"F1={val_metrics['f1']:.4f} "
        f"Prec={val_metrics['precision']:.4f} "
        f"Rec={val_metrics['recall']:.4f} "
        f"Acc={val_metrics['accuracy']:.4f}"
    )

run.finish()

In [None]:
import matplotlib.pyplot as plt

def visualize_batch(model, data_loader, device, num_examples=3):
    """Plot input, ground-truth mask and predicted mask for the first batch of a loader.

    One three-panel matplotlib figure is shown per example.

    Args:
        model: segmentation model mapping a (B, T, C, H, W) batch to per-class
            logits with the class axis at dim 1.
        data_loader: iterable yielding ``(x, mask)`` pairs; only the first batch is used.
        device: device the model lives on; the inputs are moved there.
        num_examples: maximum number of samples from the batch to plot.
    """
    model.eval()
    x, mask = next(iter(data_loader))  # one batch

    with torch.no_grad():
        logits = model(x.to(device))
        pred = torch.argmax(logits, dim=1)  # (B, H, W)

    # Plotting needs CPU tensors. The mask is never fed to the model, so it is
    # not moved to the device at all.
    x_cpu = x.cpu()
    mask_cpu = mask.cpu()
    pred_cpu = pred.cpu()

    B, T, C, H, W = x_cpu.shape
    num_examples = min(num_examples, B)

    for i in range(num_examples):
        # Very simple "input" preview from the last time step, min-max scaled
        # to [0, 1]. Use the first three bands as RGB when available; with fewer
        # bands fall back to a grayscale view of band 0, because imshow cannot
        # render a 2-channel image.
        frame = x_cpu[i, -1]  # (C, H, W)
        if C >= 3:
            img = frame[:3].permute(1, 2, 0)  # (H, W, 3)
            cmap = None
            input_title = "Input (t_last, RGB-ish)"
        else:
            img = frame[0]  # (H, W)
            cmap = "gray"
            input_title = "Input (t_last, band 0)"
        img_vis = (img - img.min()) / (img.max() - img.min() + 1e-8)

        gt = mask_cpu[i]   # (H, W)
        pr = pred_cpu[i]   # (H, W)

        fig, axes = plt.subplots(1, 3, figsize=(10, 4))
        axes[0].imshow(img_vis, cmap=cmap)
        axes[0].set_title(input_title)
        axes[0].axis("off")

        axes[1].imshow(gt, vmin=0, vmax=1)
        axes[1].set_title("Ground truth mask")
        axes[1].axis("off")

        axes[2].imshow(pr, vmin=0, vmax=1)
        axes[2].set_title("Predicted mask")
        axes[2].axis("off")

        fig.tight_layout()
        plt.show()


In [None]:
visualize_batch(model=model, data_loader=val_loader, device=device, num_examples=3)