# Setup

In [None]:
%%capture
!pip install tqdm matplotlib rasterio seaborn

In [1]:
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import confusion_matrix
import pandas as pd
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import rasterio
from rasterio.transform import from_origin
from collections import defaultdict
import csv
from pathlib import Path
import seaborn as sns
import math

In [2]:
# Ensuring reproducibility

def set_seed(seed=42, disable_cudnn=True):
    """Seed every RNG used in this notebook and configure cuDNN for reproducibility.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's global RNG and torch (CPU and all
            CUDA devices).
        disable_cudnn: If True (the default, matching the original behaviour), cuDNN is
            switched off entirely. This is the strictest setting, but it makes GPU convolutions
            much slower. Pass False to keep cuDNN enabled while still requesting deterministic
            kernels through ``cudnn.deterministic``.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects Python processes started afterwards. The
        hash seed of the already-running interpreter is fixed at startup.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    # cuDNN reproducibility settings: no autotuning, deterministic kernels only
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = not disable_cudnn


set_seed(42)

In [3]:
# Device used for all pretraining, finetuning and prediction work: the GPU when one is available.

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [4]:
work_dir = "./work_dir"
# Output folder for logs, checkpoints and prediction rasters; re-running this cell is harmless.
Path(work_dir).mkdir(parents=True, exist_ok=True)

# Necessary functions

In [5]:
# This is the class that enables us to read the pickled dataset

class PickledDataset(Dataset):
    """Torch ``Dataset`` over in-memory image/label chips loaded from a pickle.

    Args:
        imgs: Indexable sequence of image chips. Each chip is transposed with ``(2, 0, 1)``, so
            it must be stored as (height, width, channels).
        labels: Optional indexable sequence of label arrays, returned as ``long`` tensors.
            Required when ``usage == "train"``.
        chip_indices: Optional per-chip values returned as ``chip_index`` in test mode.
            Presumably the (row, col) offset of each chip in its tile.
        image_paths: Optional per-chip source image paths, returned as ``image_path`` in test mode.
        metadata: Optional per-chip raster metadata, returned as ``metadata`` in test mode.
        usage: ``"train"`` (case-insensitive) returns ``(image, label)`` tuples. Any other value
            returns a dict with the optional fields that were provided.

    Raises:
        ValueError: if ``usage`` is ``"train"`` but no labels are given.
    """

    def __init__(self, imgs, labels=None, chip_indices=None, image_paths=None, metadata=None, usage="train"):
        self.imgs = imgs
        self.labels = labels
        self.chip_indices = chip_indices
        self.image_paths = image_paths
        self.metadata = metadata
        self.usage = usage.lower()  # "train", or "test"
        if self.usage == "train" and labels is None:
            raise ValueError("labels are required when usage='train'")

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        # (H, W, C) -> (C, H, W), the layout nn.Conv2d expects
        img = torch.from_numpy(self.imgs[idx].transpose(2, 0, 1)).float()

        if self.usage == "train":
            label = torch.from_numpy(self.labels[idx]).long()
            return img, label

        # test/prediction mode: a dict, so optional fields can simply be omitted
        sample = {"image": img}

        if self.labels is not None:
            sample["label"] = torch.from_numpy(self.labels[idx]).long()

        if self.chip_indices is not None:
            sample["chip_index"] = self.chip_indices[idx]

        # These sequences are parallel to `imgs` (one entry per chip), so they must be indexed
        # with `idx`. Indexing with 0 would give every chip the first chip's path and metadata.
        if self.image_paths is not None:
            sample["image_path"] = self.image_paths[idx]

        if self.metadata is not None:
            sample["metadata"] = self.metadata[idx]

        return sample

In [6]:
# Metrics calculation

class Evaluator:
    """Accumulates a segmentation confusion matrix and derives metrics from it.

    ``confusion_matrix[i, j]`` counts pixels whose ground truth is class ``i`` and whose
    prediction is class ``j``. Per-class metrics are arrays of length ``num_classes``. A class
    with an empty denominator gets 0 rather than NaN.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    def reset(self):
        """Zero the accumulated confusion matrix in place."""
        self.confusion_matrix[:] = 0

    def add_batch(self, gt, pred):
        """Add a batch of ground-truth / predicted label maps to the confusion matrix.

        Pixels whose ground truth is outside ``[0, num_classes)`` (e.g. a negative ignore value)
        are skipped.

        Raises:
            ValueError: if ``gt`` and ``pred`` have different sizes, or if a counted pixel has a
                predicted class outside ``[0, num_classes)``. Such a value would otherwise be
                folded silently into a wrong confusion-matrix cell.
        """
        gt = np.asarray(gt).flatten()
        pred = np.asarray(pred).flatten()
        if gt.shape != pred.shape:
            raise ValueError(f"gt and pred must have the same size, got {gt.size} and {pred.size}")

        mask = (gt >= 0) & (gt < self.num_classes)
        gt_valid = gt[mask].astype(np.int64)
        pred_valid = pred[mask].astype(np.int64)
        if pred_valid.size and (pred_valid.min() < 0 or pred_valid.max() >= self.num_classes):
            raise ValueError(
                f"predictions must lie in [0, {self.num_classes}), "
                f"got values in [{pred_valid.min()}, {pred_valid.max()}]"
            )

        # Encode each (gt, pred) pair as gt * num_classes + pred and count all pairs in one pass
        self.confusion_matrix += np.bincount(
            self.num_classes * gt_valid + pred_valid,
            minlength=self.num_classes ** 2
        ).reshape(self.num_classes, self.num_classes)

    def overall_accuracy(self):
        """Fraction of counted pixels predicted correctly (0 if nothing has been counted)."""
        return np.diag(self.confusion_matrix).sum() / np.maximum(self.confusion_matrix.sum(), 1)

    def classwise_overall_accuracy(self):
        """Per-class accuracy: correct pixels / ground-truth pixels of the class (equals recall)."""
        return np.diag(self.confusion_matrix) / np.maximum(self.confusion_matrix.sum(axis=1), 1)

    def intersection_over_union(self):
        """Per-class IoU = TP / (TP + FP + FN)."""
        intersection = np.diag(self.confusion_matrix)
        union = self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0) - intersection
        return intersection / np.maximum(union, 1)

    def precision(self):
        """Per-class precision = TP / (TP + FP)."""
        tp = np.diag(self.confusion_matrix)
        fp = self.confusion_matrix.sum(axis=0) - tp
        return tp / np.maximum(tp + fp, 1)

    def recall(self):
        """Per-class recall = TP / (TP + FN)."""
        tp = np.diag(self.confusion_matrix)
        fn = self.confusion_matrix.sum(axis=1) - tp
        return tp / np.maximum(tp + fn, 1)

    def f1_score(self):
        """Per-class F1, the harmonic mean of precision and recall (0 when both are 0)."""
        precision = self.precision()
        recall = self.recall()
        return 2 * precision * recall / np.maximum(precision + recall, 1e-7)

    def plot_confusion_matrix(self, class_labels, unknown_class_idx=None):
        """Show the row-normalised confusion matrix (percent of each true class) as a heatmap.

        If ``unknown_class_idx`` is given, its row and column are zeroed before normalising.
        """
        cm = self.confusion_matrix.copy()
        if unknown_class_idx is not None:
            cm[unknown_class_idx, :] = 0
            cm[:, unknown_class_idx] = 0

        row_sums = cm.sum(axis=1, keepdims=True)
        # `out` is required: with `where=` alone, np.divide leaves the skipped entries (rows with
        # no pixels) uninitialised, so arbitrary numbers could appear in the heatmap.
        cm_normalized = np.divide(
            cm, row_sums, out=np.zeros(cm.shape, dtype=np.float64), where=row_sums != 0
        ) * 100

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(
            cm_normalized,
            annot=True,
            fmt=".1f",
            cmap="Blues",
            xticklabels=class_labels,
            yticklabels=class_labels,
            ax=ax
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title("Confusion Matrix (%)")
        plt.tight_layout()
        plt.show()

    def evaluate_from_predictions(self, pickled_labels_path, pred_dir, class_mapping, unknown_class_idx, out_path):
        """Compare prediction GeoTIFFs with pickled reference labels and write metric CSVs.

        Predictions and labels are paired by the first two underscore-separated parts of their
        file names (``<tile>_<year>...``).

        Outputs:
            - summary metrics written to ``out_path``;
            - per-class metrics written to ``<out_path stem>_classwise.csv`` in the same folder;
            - the confusion matrix plotted;
            - the summary dict returned.

        ``unknown_class_idx`` (if not None) is excluded from the classwise CSV, the mean metrics
        and the plot. Overall accuracy still counts every pixel.

        NOTE: pickle.load can execute arbitrary code, so only pass trusted label pickles.
        """
        def extract_tile_and_year(path):
            base = os.path.basename(path)
            parts = os.path.splitext(base)[0].split('_')
            if len(parts) >= 2:
                return parts[0], parts[1]
            return None, None

        self.reset()

        with open(pickled_labels_path, 'rb') as f:
            label_dict = pickle.load(f)

        # NOTE(review): if several label chips share a (tile, year), only the last one is kept.
        # Confirm the label pickle holds one full-tile chip per (tile, year).
        label_map = {}
        for idx, label_path in enumerate(label_dict["label_paths"]):
            tile, year = extract_tile_and_year(label_path)
            if tile and year:
                label_map[(tile, year)] = idx

        pred_files = sorted([f for f in os.listdir(pred_dir) if f.endswith('.tif')])

        for pred_file in pred_files:
            tile, year = extract_tile_and_year(pred_file)
            if not tile or not year:
                print(f"Skipping unmatched format: {pred_file}")
                continue

            key = (tile, year)
            if key not in label_map:
                print(f"No label found for prediction: {pred_file}")
                continue

            idx = label_map[key]
            ref_data = label_dict["chips"][idx]
            ref_data = np.squeeze(ref_data)

            pred_path = os.path.join(pred_dir, pred_file)
            with rasterio.open(pred_path) as src:
                pred_data = src.read(1)

            if ref_data.shape != pred_data.shape:
                print(f"Shape mismatch for {pred_file}")
                continue

            self.add_batch(np.expand_dims(ref_data, axis=0), np.expand_dims(pred_data.astype(np.int32), axis=0))

        acc = self.overall_accuracy()
        iou = self.intersection_over_union()
        precision = self.precision()
        recall = self.recall()
        f1 = self.f1_score()
        classwise_acc = self.classwise_overall_accuracy()

        # Average only over the evaluated classes (all of them when unknown_class_idx is None),
        # consistent with the classwise CSV and the confusion-matrix plot.
        evaluated = [i for i in range(self.num_classes) if i != unknown_class_idx]

        metrics = {
            "Overall Accuracy": acc,
            "Mean IoU": np.nanmean(iou[evaluated]),
            "Mean Precision": np.nanmean(precision[evaluated]),
            "Mean Recall": np.nanmean(recall[evaluated]),
            "Mean F1 Score": np.nanmean(f1[evaluated])
        }

        with open(out_path, mode="w", newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Metric", "Value"])
            for name, val in metrics.items():
                writer.writerow([name, val])

        classwise_path = Path(out_path).with_name(Path(out_path).stem + "_classwise.csv")
        with open(classwise_path, mode="w", newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Class", "Accuracy", "IoU", "Precision", "Recall", "F1 Score"])
            for i in evaluated:
                class_name = class_mapping.get(i, f"class_{i}")
                writer.writerow([
                    class_name,
                    classwise_acc[i],
                    iou[i],
                    precision[i],
                    recall[i],
                    f1[i]
                ])

        label_names = [class_mapping.get(i, f"class_{i}") for i in range(self.num_classes)]
        self.plot_confusion_matrix(label_names, unknown_class_idx)
        return metrics

In [7]:
# Plotting function

def plot_preds(merged_df, tile_indices, pred_dir, label_colors):
    """Show true colour, false colour, reference label and prediction side by side.

    Args:
        merged_df: DataFrame with ``label_path``, ``image_chip`` (H, W, C array) and
            ``label_chip`` columns.
        tile_indices: Positional row indices into ``merged_df`` to plot. Out-of-range indices are
            reported and skipped.
        pred_dir: Folder holding the prediction GeoTIFFs. The file is looked up by the base name
            of the row's ``label_path``.
        label_colors: Colormap used for both the label and the prediction (values 0-2).
    """
    for tile_idx in tile_indices:
        if tile_idx >= len(merged_df):
            print(f"Index {tile_idx} out of range.")
            continue

        row = merged_df.iloc[tile_idx]
        tile_name = row["label_path"]

        # Join only the file name: if label_path is absolute, os.path.join would drop pred_dir
        # and point back at the label file itself.
        pred_path = os.path.join(pred_dir, os.path.basename(tile_name))
        if not os.path.exists(pred_path):
            print(f"Prediction file not found for: {tile_name}")
            continue

        # astype returns a copy, so the in-place scaling below does not modify merged_df
        img = row["image_chip"].astype(np.float32)
        # Label chips may carry a singleton band axis (the evaluator squeezes them too);
        # imshow needs a 2-D array.
        label = np.squeeze(row["label_chip"])

        if img.max() > 1.5:
            img /= img.max()

        # Band order B, G, R, NIR is implied by the panel titles below
        if img.shape[2] >= 3:
            rgb = img[:, :, [2, 1, 0]]
        else:
            rgb = np.zeros((*img.shape[:2], 3), dtype=np.float32)

        if img.shape[2] >= 4:
            false = img[:, :, [3, 2, 1]]
        else:
            false = np.zeros_like(rgb)

        with rasterio.open(pred_path) as src:
            pred = src.read(1)

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        axes[0].imshow(np.clip(rgb, 0, 1))
        axes[0].set_title("True Color (RGB)")
        axes[0].axis("off")

        axes[1].imshow(np.clip(false, 0, 1))
        axes[1].set_title("False Color (NIR-R-G)")
        axes[1].axis("off")

        axes[2].imshow(label, cmap=label_colors, vmin=0, vmax=2)
        axes[2].set_title("Reference Label")
        axes[2].axis("off")

        axes[3].imshow(pred, cmap=label_colors, vmin=0, vmax=2)
        axes[3].set_title("Prediction")
        axes[3].axis("off")

        fig.suptitle(f"Tile: {tile_name}", fontsize=16)
        plt.tight_layout()
        plt.show()

In [8]:
# Our Unet class!

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by channel-wise dropout.

    ``padding=1`` keeps height and width unchanged. The layer order inside ``self.conv`` fixes
    the state_dict keys (``conv.0.weight``, ...), so it must stay as is for checkpoints to load.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Dropout2d(p=dropout_rate))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class Bottleneck(nn.Module):
    """U-Net center block: two (1x1 conv -> BatchNorm -> ReLU) stages, without dropout."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)

class TripleConv(nn.Module):
    """Three (3x3 conv -> BatchNorm -> ReLU) stages followed by channel-wise dropout.

    Height and width are preserved (``padding=1``).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels, out_channels):
            layers.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        layers.append(nn.Dropout2d(p=dropout_rate))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class unet_ra(nn.Module):
    """Five-level U-Net producing per-pixel class logits.

    Architecture:
        - Encoder: DoubleConv x3, then TripleConv x2, each followed by 2x2 max-pooling.
        - Center: a 1x1 Bottleneck.
        - Decoder: five ConvTranspose2d (2x upsampling) stages. Each is concatenated with the
          matching encoder output and refined by TripleConv x2, then DoubleConv x3.
        - Output: a final 1x1 convolution maps to ``n_classes`` channels.

    Args:
        n_classes: Number of output channels (class logits).
        in_channels: Number of input bands.
        filter_config: Channel widths of the five encoder levels plus the center block.
            NOTE(review): the last entry (1048) may have been meant as 2048. Confirm, but
            changing it invalidates previously saved checkpoints.
        dropout_rate: Dropout2d probability used in every Double/TripleConv block.

    Input height and width must be divisible by 32 (five 2x poolings). Otherwise an upsampled
    tensor and its skip connection differ in size and ``torch.cat`` fails.

    The attribute definition order below fixes both the state_dict keys and the order of
    ``self.children()``. Reordering it breaks checkpoint loading and index-based layer freezing.
    """
    def __init__(self, n_classes, in_channels,
                 filter_config=(64, 128, 256, 512, 1024, 1048), dropout_rate=0.5):
        super(unet_ra, self).__init__()

        # Encoder: each level keeps H, W in its conv block, then halves them with max-pooling
        self.conv_block_enc_1 = DoubleConv(in_channels, filter_config[0], dropout_rate)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_block_enc_2 = DoubleConv(filter_config[0], filter_config[1], dropout_rate)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_block_enc_3 = DoubleConv(filter_config[1], filter_config[2], dropout_rate)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_block_enc_4 = TripleConv(filter_config[2], filter_config[3], dropout_rate)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_block_enc_5 = TripleConv(filter_config[3], filter_config[4], dropout_rate)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Center: 1x1 convolutions only, so the spatial size is unchanged
        self.center_1 = Bottleneck(filter_config[4], filter_config[5])

        # Decoder: each ConvTranspose2d doubles H, W. The following block sees the upsampled
        # features concatenated with the same-level encoder output, hence the "*2" input channels.
        self.up1 = nn.ConvTranspose2d(filter_config[5], filter_config[4], kernel_size=2, stride=2)
        self.conv_block_dec_1 = TripleConv(filter_config[4]*2, filter_config[4], dropout_rate)

        self.up2 = nn.ConvTranspose2d(filter_config[4], filter_config[3], kernel_size=2, stride=2)
        self.conv_block_dec_2 = TripleConv(filter_config[3]*2, filter_config[3], dropout_rate)

        self.up3 = nn.ConvTranspose2d(filter_config[3], filter_config[2], kernel_size=2, stride=2)
        self.conv_block_dec_3 = DoubleConv(filter_config[2]*2, filter_config[2], dropout_rate)

        self.up4 = nn.ConvTranspose2d(filter_config[2], filter_config[1], kernel_size=2, stride=2)
        self.conv_block_dec_4 = DoubleConv(filter_config[1]*2, filter_config[1], dropout_rate)

        self.up5 = nn.ConvTranspose2d(filter_config[1], filter_config[0], kernel_size=2, stride=2)
        self.conv_block_dec_5 = DoubleConv(filter_config[0]*2, filter_config[0], dropout_rate)

        # 1x1 projection to per-pixel class logits (no softmax; CrossEntropyLoss expects logits)
        self.out = nn.Conv2d(filter_config[0], n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every pre-pooling feature map for the skip connections
        enc1 = self.conv_block_enc_1(x)
        pool1 = self.pool1(enc1)

        enc2 = self.conv_block_enc_2(pool1)
        pool2 = self.pool2(enc2)

        enc3 = self.conv_block_enc_3(pool2)
        pool3 = self.pool3(enc3)

        enc4 = self.conv_block_enc_4(pool3)
        pool4 = self.pool4(enc4)

        enc5 = self.conv_block_enc_5(pool4)
        pool5 = self.pool5(enc5)

        # Center
        center = self.center_1(pool5)

        # Decoder with skip connections
        up1 = self.up1(center)
        dec1 = self.conv_block_dec_1(torch.cat([up1, enc5], dim=1))  # [batch_size, channels, height, width]: dim=1 is the channel dimension

        up2 = self.up2(dec1)
        dec2 = self.conv_block_dec_2(torch.cat([up2, enc4], dim=1))

        up3 = self.up3(dec2)
        dec3 = self.conv_block_dec_3(torch.cat([up3, enc3], dim=1))

        up4 = self.up4(dec3)
        dec4 = self.conv_block_dec_4(torch.cat([up4, enc2], dim=1))

        up5 = self.up5(dec4)
        dec5 = self.conv_block_dec_5(torch.cat([up5, enc1], dim=1))

        # Logits with the same height and width as the input
        out = self.out(dec5)
        return out


# Pretraining Process

### Step 1: Load train/validation dataset pickles

In [10]:
# Current working directory: the relative "./pickles" and "./work_dir" paths resolve against it
os.getcwd()

'/home/airg/rabedi/UCSB_workshop'

In [12]:
# Reading train and validation dataset pickles.
# Paths are relative to the working directory, like every other pickle path in this notebook,
# instead of a hard-coded absolute path that only exists on one machine.
# NOTE: pickle.load can execute arbitrary code, so only load pickles you created or trust.

pretrain_pickle_dir = os.path.join(".", "pickles", "pretrain")

with open(os.path.join(pretrain_pickle_dir, "train_data.pkl"), "rb") as f:
    train_data = pickle.load(f)
train_imgs = train_data["images"]
train_labels = train_data["labels"]

with open(os.path.join(pretrain_pickle_dir, "validate_data.pkl"), "rb") as f:
    val_data = pickle.load(f)
val_imgs = val_data["images"]
val_labels = val_data["labels"]

In [13]:
# Training and validation datasets. Both use "train" mode, so items are (image, label) pairs.
train_dataset = PickledDataset(imgs=train_imgs, labels=train_labels, usage="train")
val_dataset = PickledDataset(imgs=val_imgs, labels=val_labels, usage="train")

# Dataloaders: only the training split is shuffled; validation order stays fixed between epochs
train_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=5, shuffle=False, num_workers=0)

### Step 2: Configuration

In [None]:
# Configuration

in_channels = 4   # bands per image chip (chips are stored as H x W x C)
num_classes = 3   # one per entry in class_mapping
num_epochs = 20
class_mapping = {0: 'non_field', 1: 'field', 2: 'Boundary'}
unknown_class_idx = None

# Fail fast if the hard-coded values disagree with the loaded training data. A label value
# >= num_classes would otherwise surface later as an opaque CrossEntropyLoss/CUDA error.
assert train_imgs[0].shape[2] == in_channels, (
    f"in_channels={in_channels} but training chips have {train_imgs[0].shape[2]} bands"
)
assert len(class_mapping) == num_classes, "class_mapping must name every class"
assert max(int(lbl.max()) for lbl in train_labels) < num_classes, (
    "training labels contain a class index >= num_classes"
)

### Step 3: Model initialization and hyperparameter settings

In [None]:
# Model initialization: a U-Net sized to the band/class configuration, placed on `device`

model = unet_ra(in_channels=in_channels, n_classes=num_classes)
model = model.to(device)

# Pixel-wise cross-entropy on the raw logits, optimised with AdamW (decoupled weight decay)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-4, lr=1e-3)

### Step 4: Training loop

In [None]:
# Pretraining: train for `num_epochs`, validate after each epoch and keep the best-mIoU weights.

train_logs = []
val_logs = []

best_miou = 0.0
best_model_state = None

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_loss = 0.0
    print(f"\nEpoch {epoch + 1}/{num_epochs} - Training...")

    train_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    for i, (images, labels) in train_bar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_bar.set_description(f"Batch {i + 1}")
        train_bar.set_postfix(loss=loss.item())

    train_loss /= len(train_loader)

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    # Build the confusion matrix batch by batch on the CPU instead of concatenating every
    # prediction on the GPU first. The metrics are identical, but memory stays bounded.
    evaluator = Evaluator(num_classes)
    print("Validating...")

    val_bar = tqdm(val_loader, total=len(val_loader), leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            evaluator.add_batch(labels.cpu().numpy(), preds.cpu().numpy())

            val_bar.set_postfix(loss=loss.item())

    val_loss /= len(val_loader)

    acc = evaluator.overall_accuracy()
    miou = np.nanmean(evaluator.intersection_over_union())
    precision = np.nanmean(evaluator.precision())
    recall = np.nanmean(evaluator.recall())
    f1 = np.nanmean(evaluator.f1_score())

    print(f"Epoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  mIoU:       {miou:.4f}")
    print(f"  Accuracy:   {acc:.4f}")
    print(f"  Precision:  {precision:.4f}")
    print(f"  Recall:     {recall:.4f}")
    print(f"  F1 Score:   {f1:.4f}")

    train_logs.append({
        "epoch": epoch + 1,
        "train_loss": train_loss
    })

    val_logs.append({
        "epoch": epoch + 1,
        "val_loss": val_loss,
        "miou": miou,
        "accuracy": acc,
        "mean_precision": precision,
        "mean_recall": recall,
        "mean_f1_score": f1
    })

    if miou > best_miou:
        best_miou = miou
        # model.state_dict() returns references to the live parameter tensors, so later
        # optimizer steps would turn the "best" snapshot into the final weights.
        # Store a detached CPU copy instead.
        best_model_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
        print(f"  Best model updated with validation mIoU {miou:.4f}")

### Step 4: Saving logs/model parameters

In [None]:
# Save logs to CSV: one row per epoch with the training loss and all validation metrics

combined_logs = []
for train, val in zip(train_logs, val_logs):
    log = {"epoch": train["epoch"], "train_loss": train["train_loss"]}
    log.update(val)
    combined_logs.append(log)

log_df = pd.DataFrame(combined_logs)
log_df.to_csv(os.path.join(work_dir, "pretrain_loss.csv"), index=False)

print("Training and validation logs saved to pretrain_loss.csv")


# Save best model (by mIoU). The messages name the files actually written.

if best_model_state is not None:
    torch.save(best_model_state, os.path.join(work_dir, "pretrain_best_model.pth"))
    print("Best model saved to pretrain_best_model.pth")
else:
    print("No epoch improved on mIoU 0.0; pretrain_best_model.pth was not written")

# Save final model (last epoch)

torch.save(model.state_dict(), os.path.join(work_dir, "pretrain_last_model.pth"))
print("Last model saved to pretrain_last_model.pth")

### Step 5: Train and validation loss curve

In [None]:
log_file = os.path.join(work_dir, "pretrain_loss.csv")

df = pd.read_csv(log_file)

# Train vs. validation loss on a single axis
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(df['epoch'], df['train_loss'], label='Train Loss', marker='o')
ax.plot(df['epoch'], df['val_loss'], label='Val Loss', marker='o')
ax.set_title("Train vs Val Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

# One panel per remaining logged metric, two panels per row
metrics = df.columns.difference(['epoch', 'train_loss', 'val_loss'])
num_metrics = len(metrics)
cols = 2
rows = max(1, math.ceil(num_metrics / cols))

fig, axes = plt.subplots(rows, cols, figsize=(12, 4 * rows))
# np.ravel handles both the 1-D (single row) and 2-D axes arrays returned by subplots
axes = np.ravel(axes)

for ax, metric in zip(axes, metrics):
    ax.plot(df['epoch'], df[metric], marker='o', label=metric)
    ax.set_title(metric)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.grid(True)
    ax.legend()

# Remove unused panels. Slicing by num_metrics avoids reusing the loop variable, which is
# undefined, or stale from another cell, when there are no metrics.
for unused_ax in axes[num_metrics:]:
    fig.delaxes(unused_ax)

fig.tight_layout()
plt.show()

# Prediction

### Step 1: Load the test data

In [None]:
# NOTE: only unpickle files you created or trust; pickle.load can execute arbitrary code.
with open("./pickles/predict/test_data_img.pkl", "rb") as f:
    test_data = pickle.load(f)

# Parallel per-chip sequences (one entry per chip). Labels are optional for pure prediction.
test_imgs = test_data["chips"]
test_labels = test_data.get("labels")
chip_indices, image_paths, metadata, tile_ids, years = (
    test_data[key] for key in ("chip_indices", "image_paths", "metadata", "tile_ids", "years")
)

In [None]:
# Test-mode dataset: items are dicts carrying each chip's index, source path and metadata
test_dataset = PickledDataset(
    test_imgs,
    labels=test_labels,
    chip_indices=chip_indices,
    image_paths=image_paths,
    metadata=metadata,
    usage="test",
)
# One chip per batch, in the original chip order
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1, num_workers=0)

### Step 2: Load the best model

In [None]:
# We assume that you didn't restart the kernel; otherwise re-run the model definition and
# initialization cells first.
# Load the checkpoint the pretraining cells actually write ("pretrain_best_model.pth"). If no
# best checkpoint was saved, fall back to the last-epoch weights.

ckpt_path = os.path.join(work_dir, "pretrain_best_model.pth")
if not os.path.exists(ckpt_path):
    ckpt_path = os.path.join(work_dir, "pretrain_last_model.pth")
    print(f"Best checkpoint not found; loading {ckpt_path} instead")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

model.to(device)
model.eval()

### Step 3: Make predictions

In [None]:
save_dir = os.path.join(work_dir, "pred_2cls")
os.makedirs(save_dir, exist_ok=True)

# Group chips by their source image. Keying on the image path rather than the tile id alone
# keeps different images of the same tile (e.g. different years) apart. Each group is
# mosaicked and written to its own GeoTIFF named after that image.
grouped_data = defaultdict(lambda: {
    "chips": [],
    "chip_indices": [],
    "meta": None
})

for chip, chip_idx, meta, path in zip(test_imgs, chip_indices, metadata, image_paths):
    grouped_data[path]["chips"].append(chip)
    grouped_data[path]["chip_indices"].append(chip_idx)
    grouped_data[path]["meta"] = meta

# === Predict and save each source image ===
model.eval()  # inference mode for dropout/BatchNorm, even if this cell runs right after training
for image_path, data in tqdm(grouped_data.items(), desc="Predicting tiles"):
    meta = data["meta"].copy()
    meta.update({
        "count": 1,
        "dtype": rasterio.int8
    })
    meta.pop("nodata", None)  # remove if invalid

    tile_height = meta["height"]
    tile_width = meta["width"]
    full_pred = np.zeros((tile_height, tile_width), dtype=np.int8)

    with torch.no_grad():
        for chip, (chip_row, chip_col) in zip(data["chips"], data["chip_indices"]):
            # (H, W, C) chip -> (1, C, H, W) batch of one
            img = torch.from_numpy(chip.transpose(2, 0, 1)).unsqueeze(0).float().to(device)
            # Take batch item 0 explicitly; squeeze() would also drop a spatial axis of size 1
            pred = model(img).argmax(dim=1)[0].cpu().numpy().astype(np.int8)

            # Convert 3-class prediction → binary:
            # 0 = non-field, 1 = field and boundary (original classes 1 and 2)
            pred = np.where(pred == 2, 1, pred).astype(np.int8)

            # Paste using both chip dimensions, clipped to the tile extent, so non-square chips
            # and chips overhanging the tile edge do not cause a shape mismatch
            rows_fit = min(pred.shape[0], tile_height - chip_row)
            cols_fit = min(pred.shape[1], tile_width - chip_col)
            if rows_fit <= 0 or cols_fit <= 0:
                continue
            full_pred[chip_row:chip_row + rows_fit, chip_col:chip_col + cols_fit] = pred[:rows_fit, :cols_fit]

    tile_name = os.path.splitext(os.path.basename(image_path))[0]
    pred_path = os.path.join(save_dir, f"{tile_name}.tif")

    with rasterio.open(pred_path, "w", **meta) as dst:
        dst.write(full_pred, 1)

### Step 4: Visualization

In [None]:
# Load the pickled test image chips and binary label chips, then pair each image chip with its
# label chip via (chip_index, tile_id, year). Only unpickle trusted files.

with open("./pickles/predict/test_data_img.pkl", "rb") as f:
    img_data = pickle.load(f)

with open("./pickles/predict/test_data_lbl_2cls.pkl", "rb") as f:
    lbl_data = pickle.load(f)

_key_sources = {"chip_index": "chip_indices", "tile_id": "tile_ids", "year": "years"}

img_df = pd.DataFrame({
    **{col: img_data[src] for col, src in _key_sources.items()},
    "image_chip": img_data["chips"],
})

lbl_df = pd.DataFrame({
    **{col: lbl_data[src] for col, src in _key_sources.items()},
    "label_chip": lbl_data["chips"],
    "label_path": lbl_data["label_paths"],
})

merged_df = img_df.merge(lbl_df, on=list(_key_sources))

In [None]:
# Visual check of the pre-finetuning predictions for two test chips.
# Colours for label/prediction values: 0 = yellow, 1 = green, 2 = black.
label_cmap = ListedColormap(["#FFFF00", "#00FF00", "#000000"])
plot_preds(merged_df, [0, 2], "./work_dir/pred_2cls", label_cmap)

### Step 5: Accuracy Assessment

In [None]:
# Binary (non_field / field) evaluation of the pre-finetuning predictions
pickled_labels_path = "./pickles/predict/test_data_lbl_2cls.pkl"
pred_dir = "./work_dir/pred_2cls"
out_path = "./work_dir/eval_metrics_train_only.csv"
class_mapping = {0: 'non_field', 1: 'field'}
unknown_class_idx = None

# Two classes: boundary predictions were merged into "field" when the rasters were written
evaluator = Evaluator(num_classes=len(class_mapping))

# Writes out_path plus a *_classwise.csv next to it, plots the confusion matrix, returns the summary
evaluator.evaluate_from_predictions(
    pickled_labels_path, pred_dir, class_mapping, unknown_class_idx, out_path
)


# Finetuning Process

### Step 1: Loading finetune dataset pickles

In [None]:
# Finetuning data uses the same pickle layout as the pretraining data. Only unpickle trusted files.
with open("./pickles/finetune/train_data.pkl", "rb") as f:
    train_data = pickle.load(f)
train_imgs, train_labels = train_data["images"], train_data["labels"]

with open("./pickles/finetune/validate_data.pkl", "rb") as f:
    val_data = pickle.load(f)
val_imgs, val_labels = val_data["images"], val_data["labels"]

In [None]:
# Finetuning datasets. Both use "train" mode, so items are (image, label) pairs.
train_dataset = PickledDataset(imgs=train_imgs, labels=train_labels, usage="train")
val_dataset = PickledDataset(imgs=val_imgs, labels=val_labels, usage="train")

# Dataloaders: shuffle only the training split
train_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=5, shuffle=False, num_workers=0)

### Step 2: Configuration

In [None]:
# Load pretrained weights (the best-mIoU checkpoint written by the pretraining cells)

checkpoint_path = "./work_dir/pretrain_best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
print("Loaded pretrained model weights.")

# Freeze layers by position in model.children(). For unet_ra, indices 0-6 are
# conv_block_enc_1, pool1, conv_block_enc_2, pool2, conv_block_enc_3, pool3 and
# conv_block_enc_4 (the pooling layers have no parameters).
# NOTE(review): requires_grad=False does not stop BatchNorm in these blocks from updating its
# running statistics, nor Dropout from being active, while model.train() is set during
# finetuning. Call .eval() on the frozen modules after model.train() if they should stay fully fixed.

freeze_indices = list(range(7))  # You can adjust based on which layers you want to freeze!
modules = list(model.children())
for layer_idx in freeze_indices:
    frozen_module = modules[layer_idx]
    frozen_module.requires_grad_(False)
    print(f"Froze layer {layer_idx}: {type(frozen_module).__name__}")

# Only the parameters that are still trainable are handed to the optimizer
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=0.001,
    weight_decay=1e-4,
)

### Step 3: Training (finetuning) loop

In [None]:
# Finetuning: train for `num_epochs`, validate after each epoch and keep the best-mIoU weights.

train_logs = []
val_logs = []
best_miou = 0.0
best_model_state = None
num_epochs = 40

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_loss = 0.0
    print(f"\nEpoch {epoch + 1}/{num_epochs} - Training")

    train_bar = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    for i, (images, labels) in train_bar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_bar.set_description(f"Batch {i + 1}")
        train_bar.set_postfix(loss=loss.item())

    train_loss /= len(train_loader)

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    # Confusion matrix accumulated per batch on the CPU (same metrics, bounded memory)
    evaluator = Evaluator(num_classes)
    print("Validating")

    val_bar = tqdm(val_loader, total=len(val_loader), leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            evaluator.add_batch(labels.cpu().numpy(), preds.cpu().numpy())

            val_bar.set_postfix(loss=loss.item())

    val_loss /= len(val_loader)

    miou = np.nanmean(evaluator.intersection_over_union())
    acc = evaluator.overall_accuracy()

    print(f"Epoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  mIoU:       {miou:.4f}")
    print(f"  Accuracy:   {acc:.4f}")

    train_logs.append({"epoch": epoch + 1, "train_loss": train_loss})
    val_logs.append({"epoch": epoch + 1, "val_loss": val_loss, "miou": miou, "accuracy": acc})

    if miou > best_miou:
        best_miou = miou
        # model.state_dict() returns references to the live parameter tensors, so later
        # optimizer steps would overwrite the "best" snapshot. Keep a detached CPU copy.
        best_model_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
        print(f"Best model updated with validation mIoU {miou:.4f}")

### Step 4: Saving logs/model parameters

In [None]:
# One row per epoch: the training loss merged with that epoch's validation metrics
combined_logs = [
    {
        "epoch": tr["epoch"],
        "train_loss": tr["train_loss"],
        "val_loss": va["val_loss"],
        "miou": va["miou"],
        "accuracy": va["accuracy"],
    }
    for tr, va in zip(train_logs, val_logs)
]

pd.DataFrame(combined_logs).to_csv(os.path.join(work_dir, "finetune_loss.csv"), index=False)
print("Fine-tuning logs saved to finetune_loss.csv")

# Best snapshot (by validation mIoU), only if some epoch improved on the initial 0.0
if best_model_state is not None:
    torch.save(best_model_state, os.path.join(work_dir, "finetuned_best_model.pth"))
    print("Best model saved to finetuned_best_model.pth")

# Weights after the final epoch
torch.save(model.state_dict(), os.path.join(work_dir, "finetuned_last_model.pth"))
print("Last model saved to finetuned_last_model.pth")


### Step 5: Train and validation loss curve

In [None]:
log_file = os.path.join(work_dir, "finetune_loss.csv")

df = pd.read_csv(log_file)

# Train vs. validation loss on a single axis
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(df['epoch'], df['train_loss'], label='Train Loss', marker='o')
ax.plot(df['epoch'], df['val_loss'], label='Val Loss', marker='o')
ax.set_title("Train vs Val Loss (Finetuning)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

# One panel per remaining logged metric, two panels per row
metrics = df.columns.difference(['epoch', 'train_loss', 'val_loss'])
num_metrics = len(metrics)
cols = 2
rows = max(1, math.ceil(num_metrics / cols))

fig, axes = plt.subplots(rows, cols, figsize=(12, 4 * rows))
# np.ravel handles both the 1-D (single row) and 2-D axes arrays returned by subplots
axes = np.ravel(axes)

for ax, metric in zip(axes, metrics):
    ax.plot(df['epoch'], df[metric], marker='o', label=metric)
    ax.set_title(metric)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.grid(True)
    ax.legend()

# Remove unused panels. Slicing by num_metrics avoids reusing a loop variable that may be
# undefined, or a stale chip index from the prediction cell, when there are no metrics.
for unused_ax in axes[num_metrics:]:
    fig.delaxes(unused_ax)

fig.tight_layout()
plt.show()

# Prediction after finetuning

### Step 1: Load the test data

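Same as before: reuse the test data loaded in the first prediction section (no need to reload it unless the kernel was restarted).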

### Step 2: Load the best model

In [None]:
# Assumes the kernel was not restarted; otherwise re-create `model` (definition and
# initialization cells) before loading the finetuned weights.

model.load_state_dict(
    torch.load(
        os.path.join(work_dir, "finetuned_best_model.pth"),
        map_location="cuda" if torch.cuda.is_available() else "cpu",
    )
)

model = model.to(device)
model.eval()

### Step 3: Make prediction

In [None]:
save_dir = os.path.join(work_dir, "pred_with_finetune_2cls")
os.makedirs(save_dir, exist_ok=True)

# Group chips by their source image. Keying on the image path rather than the tile id alone
# keeps different images of the same tile (e.g. different years) apart. Each group is
# mosaicked and written to its own GeoTIFF named after that image.
grouped_data = defaultdict(lambda: {
    "chips": [],
    "chip_indices": [],
    "meta": None
})

for chip, chip_idx, meta, path in zip(test_imgs, chip_indices, metadata, image_paths):
    grouped_data[path]["chips"].append(chip)
    grouped_data[path]["chip_indices"].append(chip_idx)
    grouped_data[path]["meta"] = meta

# Predict and save each source image
model.eval()  # inference mode for dropout/BatchNorm, even if this cell runs right after training
for image_path, data in tqdm(grouped_data.items(), desc="Predicting tiles"):
    meta = data["meta"].copy()
    meta.update({
        "count": 1,
        "dtype": rasterio.int8
    })
    meta.pop("nodata", None)

    tile_height = meta["height"]
    tile_width = meta["width"]
    full_pred = np.zeros((tile_height, tile_width), dtype=np.int8)

    with torch.no_grad():
        for chip, (chip_row, chip_col) in zip(data["chips"], data["chip_indices"]):
            # (H, W, C) chip -> (1, C, H, W) batch of one
            img = torch.from_numpy(chip.transpose(2, 0, 1)).unsqueeze(0).float().to(device)
            # Take batch item 0 explicitly; squeeze() would also drop a spatial axis of size 1
            pred = model(img).argmax(dim=1)[0].cpu().numpy().astype(np.int8)

            # Convert 3-class prediction → binary:
            # 0 = non-field, 1 = field and boundary (original classes 1 and 2)
            pred = np.where(pred == 2, 1, pred).astype(np.int8)

            # Paste using both chip dimensions, clipped to the tile extent, so non-square chips
            # and chips overhanging the tile edge do not cause a shape mismatch
            rows_fit = min(pred.shape[0], tile_height - chip_row)
            cols_fit = min(pred.shape[1], tile_width - chip_col)
            if rows_fit <= 0 or cols_fit <= 0:
                continue
            full_pred[chip_row:chip_row + rows_fit, chip_col:chip_col + cols_fit] = pred[:rows_fit, :cols_fit]

    tile_name = os.path.splitext(os.path.basename(image_path))[0]
    pred_path = os.path.join(save_dir, f"{tile_name}.tif")

    with rasterio.open(pred_path, "w", **meta) as dst:
        dst.write(full_pred, 1)

### Step 4: Visualization

In [None]:
# Load pickled image and label data (the same as before) and pair image chips with label chips
# via (chip_index, tile_id, year). Only unpickle trusted files.

with open("./pickles/predict/test_data_img.pkl", "rb") as f:
    img_data = pickle.load(f)

with open("./pickles/predict/test_data_lbl_2cls.pkl", "rb") as f:
    lbl_data = pickle.load(f)

_key_sources = {"chip_index": "chip_indices", "tile_id": "tile_ids", "year": "years"}

img_df = pd.DataFrame({
    **{col: img_data[src] for col, src in _key_sources.items()},
    "image_chip": img_data["chips"],
})

lbl_df = pd.DataFrame({
    **{col: lbl_data[src] for col, src in _key_sources.items()},
    "label_chip": lbl_data["chips"],
    "label_path": lbl_data["label_paths"],
})

merged_df = img_df.merge(lbl_df, on=list(_key_sources))

In [None]:
# Visual check of the finetuned predictions for the same two test chips.
# Colours for label/prediction values: 0 = yellow, 1 = green, 2 = black.
label_cmap = ListedColormap(["#FFFF00", "#00FF00", "#000000"])
plot_preds(merged_df, [0, 2], "./work_dir/pred_with_finetune_2cls", label_cmap)

### Step 5: Accuracy assessment

In [None]:
# Binary (non_field / field) evaluation of the finetuned predictions
pickled_labels_path = "./pickles/predict/test_data_lbl_2cls.pkl"
pred_dir = "./work_dir/pred_with_finetune_2cls"
out_path = "./work_dir/eval_metrics_finetune.csv"
class_mapping = {0: 'non_field', 1: 'field'}
unknown_class_idx = None

# Two classes: boundary predictions were merged into "field" when the rasters were written
evaluator = Evaluator(num_classes=len(class_mapping))

# Writes out_path plus a *_classwise.csv next to it, plots the confusion matrix, returns the summary
evaluator.evaluate_from_predictions(
    pickled_labels_path, pred_dir, class_mapping, unknown_class_idx, out_path
)