In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))
sys.path.append(os.path.abspath("../"))

In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from skimage import color
from sklearn.base import clone
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from utils import load_and_normalize_tiff, load_mask
from visualization import plot_image_and_mask_and_prediction

In [None]:
def extract_features(img):
    """Build a per-pixel feature cube from a multi-band, channel-first image.

    Bands 0..3 of ``img`` are read as R, G, B and IR respectively (the band
    order is assumed from the original variable names — confirm against the
    TIFF layout produced by ``load_and_normalize_tiff``). Any extra bands are
    ignored.

    Args:
        img: array indexed as ``img[band]``, each band being a 2-D (H, W) plane.

    Returns:
        float32 array of shape (H, W, 8) with channels, in order:
        R, G, B, IR, NDVI, NDSI, brightness (mean of R, G, B), HSV saturation.
    """
    red, green, blue, nir = (img[band].astype(np.float32) for band in range(4))
    eps = 1e-5  # keeps the normalized-difference indices finite when both bands are 0

    # Normalized-difference spectral indices.
    ndvi = (nir - red) / (nir + red + eps)
    ndsi = (green - nir) / (green + nir + eps)

    # Intensity-style features.
    brightness = (red + green + blue) / 3.0
    # NOTE(review): the /255.0 assumes RGB on a 0-255 scale — TODO confirm what
    # load_and_normalize_tiff returns. HSV saturation is scale-invariant, so the
    # division only matters through the clip(0, 1) below.
    rgb = np.stack([red, green, blue], axis=-1) / 255.0
    saturation = color.rgb2hsv(rgb.clip(0, 1))[:, :, 1]

    return np.stack(
        [red, green, blue, nir, ndvi, ndsi, brightness, saturation],
        axis=-1,
    )

def dice_coefficient(mask1, mask2):
    """Dice similarity between two masks: 2*|A∩B| / (|A| + |B|).

    Both inputs are flattened before comparison, so (H, W), (H*W,) and
    (H*W, 1) layouts of the same mask compare correctly instead of silently
    broadcasting into an outer product. Values are accumulated in float64, so
    soft (probability) masks also work.

    Args:
        mask1: array-like mask (e.g. binary predictions).
        mask2: array-like mask with the same number of elements as ``mask1``.

    Returns:
        Dice score as np.float64. Two empty masks score 1.0 because of the
        1e-9 smoothing terms.

    Raises:
        ValueError: if the masks have different numbers of elements.
    """
    a = np.asarray(mask1, dtype=np.float64).ravel()
    b = np.asarray(mask2, dtype=np.float64).ravel()
    if a.size != b.size:
        raise ValueError(
            f"mask sizes differ: {a.size} vs {b.size} elements"
        )
    intersection = np.sum(a * b)
    return (2.0 * intersection + 1e-9) / (np.sum(a) + np.sum(b) + 1e-9)

In [None]:
# Collect (image, mask) pairs from every cloud-cover category, then split train/test.
PROCESSED_DATA = Path("../data/processed2")
CATEGORIES = ['cloud_free', 'partially_clouded', 'fully_clouded']
SPLIT_SEED = 42

image_files = []
mask_files = []
for category in CATEGORIES:
    img_dir = PROCESSED_DATA / "data" / category
    mask_dir = PROCESSED_DATA / "masks" / category

    # sorted(): glob order depends on the filesystem; sorting makes the seeded
    # split below reproducible across machines and runs.
    for img_file in sorted(img_dir.glob('*.tif')):
        mask_path = mask_dir / img_file.name
        if mask_path.exists():  # images without a mask cannot be trained/evaluated
            image_files.append(img_file)
            mask_files.append(mask_path)

if not image_files:
    raise FileNotFoundError(
        f"No image/mask pairs found under {PROCESSED_DATA.resolve()}"
    )

# train_test_split shuffles with its own random_state, so no separate pre-shuffle
# (or global np.random seeding) is needed.
train_imgs, test_imgs, train_masks, test_masks = train_test_split(
    image_files, mask_files, test_size=0.2, random_state=SPLIT_SEED
)

print(f"Train: {len(train_imgs)} images, Test: {len(test_imgs)} images")

In [None]:
# Label set for the per-pixel classifier; required by the first partial_fit call.
classes = np.array([0, 1])

# Logistic-loss linear classifier trained incrementally, one partial_fit call per image.
# max_iter / tol only affect .fit(); each partial_fit call runs a single epoch regardless.
# NOTE(review): SGD is sensitive to feature scale and extract_features output is not
# standardized — consider a StandardScaler fitted incrementally alongside the model.
sgd_model = SGDClassifier(
    loss='log_loss',
    learning_rate='optimal',
    max_iter=1,
    tol=None,
    random_state=42,
)

In [None]:
def evaluate(model, img_list, mask_list):
    """Compute the mean per-image Dice score of ``model`` on image/mask pairs.

    Each image goes through the same pipeline as training
    (load -> extract_features -> flatten to (pixels, features)), and the
    pixel predictions are compared to the flattened mask.

    Args:
        model: fitted classifier exposing ``predict``.
        img_list: sequence of image paths.
        mask_list: sequence of mask paths, aligned one-to-one with ``img_list``.

    Returns:
        Mean Dice score over images (also printed), or nan when the lists are empty.

    Raises:
        ValueError: if ``img_list`` and ``mask_list`` have different lengths
            (``zip`` would otherwise silently drop the unmatched tail).
    """
    if len(img_list) != len(mask_list):
        raise ValueError(
            f"img_list and mask_list differ in length ({len(img_list)} vs {len(mask_list)})"
        )
    if len(img_list) == 0:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning; report explicitly instead.
        print("Average Dice Score: n/a (no images)")
        return float("nan")

    scores = []
    for img_path, mask_path in tqdm(zip(img_list, mask_list), total=len(img_list), desc="Evaluation"):
        img = load_and_normalize_tiff(img_path)
        mask = load_mask(mask_path)

        features = extract_features(img)
        features = features.reshape(-1, features.shape[-1])

        preds = model.predict(features)
        scores.append(dice_coefficient(preds, mask.reshape(-1)))

    avg_score = np.mean(scores)
    print(f"Average Dice Score: {avg_score:.4f}")
    return avg_score

In [None]:
def predict_full_image(model, img_path, mask_path):
    """Predict a full-resolution mask for one image and plot it next to the ground truth.

    The prediction is reshaped to the image's own spatial size (taken from the
    feature cube) rather than a hard-coded 512x512, so tiles of any size work.

    Args:
        model: fitted classifier exposing ``predict``.
        img_path: path of the image, also forwarded to the plotting helper.
        mask_path: path of the ground-truth mask.
    """
    img = load_and_normalize_tiff(img_path)
    mask = load_mask(mask_path)

    features = extract_features(img)  # (H, W, n_features)
    height, width = features.shape[:2]

    preds = model.predict(features.reshape(-1, features.shape[-1]))
    preds = preds.reshape((height, width))

    plot_image_and_mask_and_prediction(img, mask, preds, img_path)

In [None]:
# Train incrementally: one partial_fit call per training image (all its pixels at once).
# clone() returns an unfitted copy with the same hyper-parameters, so re-running this
# cell retrains from scratch instead of silently continuing from the previous run's weights.
sgd_model = clone(sgd_model)

for img_path, mask_path in tqdm(zip(train_imgs, train_masks), total=len(train_imgs), desc="Training"):
    img = load_and_normalize_tiff(img_path)
    mask = load_mask(mask_path)

    features = extract_features(img)
    features = features.reshape(-1, features.shape[-1])
    mask = mask.reshape(-1)

    # `classes` is required on the first partial_fit call. Passing it on every call is
    # allowed (scikit-learn checks it matches) and removes the first-iteration special case.
    sgd_model.partial_fit(features, mask, classes=classes)

In [None]:
# Persist the trained classifier. open(..., 'wb') does not create parent directories,
# so make sure the output folder exists (first run on a fresh checkout).
model_path = Path('../outputs/temp/sgd_model.pkl')
model_path.parent.mkdir(parents=True, exist_ok=True)
with open(model_path, 'wb') as f:
    pickle.dump(sgd_model, f)

In [None]:
# Reload the classifier written by the cell above.
# pickle can execute arbitrary code on load: only point this at files this notebook produced.
model = None
with open('../outputs/temp/sgd_model.pkl', mode='rb') as model_fh:
    model = pickle.load(model_fh)

In [None]:
# Mean per-image Dice on the held-out split (left as the last expression so the score is displayed).
evaluate(model=model, img_list=test_imgs, mask_list=test_masks)

In [None]:
# Visually inspect predictions on the first N_PREVIEW images of a held-out folder.
N_PREVIEW = 30  # number of images to plot
# NOTE(review): relative to the working directory, unlike ../data/processed2 used above — confirm.
input_dir = 'test'
data_dir = os.path.join(input_dir, 'data')
mask_dir = os.path.join(input_dir, 'masks')

# sorted(): os.listdir order is arbitrary, so "first N" would otherwise change between runs.
# Keep only regular files with a matching mask so one missing mask doesn't abort the loop.
data_files = [
    name for name in sorted(os.listdir(data_dir))
    if os.path.isfile(os.path.join(data_dir, name))
    and os.path.exists(os.path.join(mask_dir, name))
]

for name in data_files[:N_PREVIEW]:
    predict_full_image(model, os.path.join(data_dir, name), os.path.join(mask_dir, name))