In [2]:
import time
from pathlib import Path
from typing import List, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from osgeo import gdal

from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.semi_supervised import SelfTrainingClassifier

In [1]:
from components.dataloader import AbstractDataLoader
from components.fcm import FCM
from components.ssl import SemiSupervisedModel

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [109]:
# source files
SOURCES = []

# label files
LABELS = []

In [112]:
class FloodDataLoder(AbstractDataLoader):
    """Data loader for one flood scene: a VV source raster plus its label raster.

    Both rasters are flattened to 1-D. For every class except "no data", a
    fraction of its pixels is drawn as the labelled set; the remaining valid
    pixels form the unlabelled pool. Everything is computed lazily on first
    access and cached on the instance.
    """

    # Raw label values found in `raster_labels.tif`. The last entry ("no data")
    # is excluded from sampling and from `get_unique_labels`.
    classes = [
        255, # water
        0, # no water
        15, # no data
    ]

    # Source pixel value treated as "no data" and dropped from the unlabelled pool.
    NODATA_VALUE = 9999

    def __init__(self, sources: List[str], labels: List[str], idx=0, random_state=42):
        """
        Args:
            sources: directories, each expected to contain a ``VV.tif`` raster.
            labels: directories, each expected to contain ``raster_labels.tif``.
            idx: index of the scene (into both lists) to load.
            random_state: seed for the generator used to sample labelled pixels.
        """
        self.source = [Path(f) for f in sources]
        self.labels = [Path(f) for f in labels]
        self.idx = idx
        self.rng = np.random.default_rng(random_state)
        self.le = LabelEncoder()

        # Populated lazily by `_extract_all` (through the getters below).
        self.initial_size = None  # 2-D shape of the most recently read raster
        self.labeled = None
        self.labels_arr = None
        self.raw_labels = None
        self.unlabeled = None
        self.unlabeled_raw = None

    def _load_tif(self, source=True):
        """Read band 1 of the source (VV) or label raster and return it flattened.

        Also records the raster's 2-D shape in ``self.initial_size``.

        Raises:
            FileNotFoundError: if GDAL cannot open the raster.
        """
        if source:
            path = self.source[self.idx] / 'VV.tif'
        else:
            path = self.labels[self.idx] / 'raster_labels.tif'
        ds = gdal.Open(str(path), gdal.GA_ReadOnly)
        if ds is None:
            # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
            raise FileNotFoundError(f'GDAL could not open raster: {path}')
        try:
            img_array = ds.GetRasterBand(1).ReadAsArray()
        finally:
            ds = None  # release the GDAL dataset handle
        self.initial_size = img_array.shape
        return img_array.reshape(-1)

    @staticmethod
    def sample(X, y, samples):
        """Draw `samples` rows of `X` without replacement for every distinct value of `y`.

        The per-class draws are concatenated in the order of ``np.unique(y)``.
        numpy raises ValueError if a class has fewer than `samples` rows.
        """
        result = []
        for unique_y in np.unique(y, axis=0):
            val_indices = np.argwhere(y == unique_y).flatten()
            random_samples = np.random.choice(val_indices, samples, replace=False)
            result.append(X[random_samples])
        return np.concatenate(result)

    def _extract_all(self, percentage: float = 0.1):
        """Split the current scene into labelled and unlabelled pixels.

        For every class except the last ("no data"), ``percentage`` of its pixels
        are drawn without replacement (using ``self.rng``) as labelled samples.

        Returns:
            (labeled, labels, unlabeled, unlabeled_raw, raw_labels):
            labeled       -- DataFrame of the sampled source values;
            labels        -- 1-D LabelEncoder-encoded classes for those rows;
            unlabeled     -- DataFrame of the remaining source values, with
                             no-data pixels removed and the index reset;
            unlabeled_raw -- DataFrame of *all* source pixels cast to int;
            raw_labels    -- the flattened label raster.
        """
        source_img = self._load_tif(source=True)
        label_img = self._load_tif(source=False)

        labeled, labels = [], []
        labeled_indices = np.array([], dtype=int)

        value_counts = dict(zip(*np.unique(label_img, return_counts=True)))
        for klass in self.classes[:-1]:
            # .get: a class absent from this scene contributes no samples instead of a KeyError.
            n_labeled = int(value_counts.get(klass, 0) * percentage)
            class_indices = np.argwhere(label_img == klass).flatten()
            indices = self.rng.choice(class_indices, n_labeled, replace=False, shuffle=False)

            labeled.append(source_img[indices])
            labels.append(label_img[indices])
            labeled_indices = np.concatenate((labeled_indices, indices), axis=0)

        labeled = np.concatenate(labeled, axis=0)
        labels = self.le.fit_transform(np.concatenate(labels, axis=0)).reshape(-1)

        unlabeled = pd.DataFrame(np.delete(source_img, labeled_indices))
        # Workaround: drop no-data source pixels from the unlabelled pool.
        unlabeled = unlabeled[unlabeled[0] != self.NODATA_VALUE].reset_index(drop=True)
        # NOTE(review): dtype=int truncates a floating-point VV raster, so predictions
        # made on this frame may see different values than training did -- confirm.
        unlabeled_raw = pd.DataFrame(source_img, dtype=int)
        return pd.DataFrame(labeled), labels, unlabeled, unlabeled_raw, label_img

    def _ensure_loaded(self, attr: str) -> None:
        """Populate every cached array via `_extract_all` if `attr` is still unset."""
        if getattr(self, attr) is None:
            (self.labeled, self.labels_arr, self.unlabeled,
             self.unlabeled_raw, self.raw_labels) = self._extract_all()

    def get_labeled(self) -> Tuple[pd.DataFrame, NDArray]:
        """Return copies of the labelled feature frame and its encoded labels."""
        self._ensure_loaded('labeled')
        return self.labeled.copy(), self.labels_arr.copy()

    def get_raw_labels(self) -> NDArray:
        """Return a copy of the flattened label raster."""
        self._ensure_loaded('raw_labels')
        return self.raw_labels.copy()

    def get_unlabeled(self) -> pd.DataFrame:
        """Return a copy of the unlabelled pool (no-data pixels removed)."""
        self._ensure_loaded('unlabeled')
        return self.unlabeled.copy()

    def get_unlabeled_raw(self) -> pd.DataFrame:
        """Return a copy of all source pixels (including no-data) as a one-column frame."""
        self._ensure_loaded('unlabeled_raw')
        return self.unlabeled_raw.copy()

    def get_unique_labels(self) -> List[int]:
        """Return the raw label values used for sampling (every class except "no data")."""
        return self.classes[:-1]

In [138]:
def metrics(y_true, y_pred):
    """Return scikit-learn's classification report for the labels as a nested dict."""
    return classification_report(y_true, y_pred, output_dict=True)

def train_ssl(dl):
    """Fit the project's SemiSupervisedModel (SGD classifier + FCM clustering) on `dl`.

    `increase_labeled(0.9, 0.9, min_perc=0.0)` is called before fitting.
    NOTE(review): the meaning of the two 0.9 arguments is defined in
    components.ssl -- confirm there before tuning.
    """
    base_clf = SGDClassifier(
        max_iter=10_000,
        random_state=42,
        class_weight='balanced',
        n_jobs=-1,
    )
    model = SemiSupervisedModel(base_clf, FCM(), dl)
    model.increase_labeled(0.9, 0.9, min_perc=0.0)
    model.fit()
    return model

def train_st(dl):
    """Fit a SelfTrainingClassifier on the labelled plus unlabelled pixels of `dl`.

    Unlabelled rows get the target -1, which SelfTrainingClassifier treats as
    "no label". The base estimator is an SGD classifier wrapped in isotonic
    CalibratedClassifierCV (presumably to provide predict_proba, which
    self-training relies on).
    """
    X_lab, y_lab = dl.get_labeled()
    X_unlab = dl.get_unlabeled()
    y_unlab = np.array([-1] * X_unlab.shape[0])

    X_all = np.concatenate((X_lab, X_unlab))
    y_all = np.concatenate((y_lab, y_unlab))

    sgd = SGDClassifier(max_iter=10_000, random_state=42, class_weight='balanced', n_jobs=-1)
    calibrated = CalibratedClassifierCV(sgd, method='isotonic', cv=3, n_jobs=-1)
    model = SelfTrainingClassifier(calibrated, max_iter=3)
    model.fit(X_all, y_all)
    return model


def train_default(dl):
    """Supervised baseline: fit an SGD classifier on the labelled pixels only."""
    X_lab, y_lab = dl.get_labeled()
    model = SGDClassifier(max_iter=10_000, random_state=42, class_weight='balanced', n_jobs=-1)
    model.fit(X_lab, y_lab)
    return model


def predict(dl, model, nodata_value=9999, nodata_label=15):
    """Predict a raw label for every pixel of the loader's current scene.

    Pixels whose source value equals `nodata_value` are never passed to the
    model and receive `nodata_label`. Every other pixel gets the model's
    prediction, decoded back to raw label values with `dl.le`.

    Args:
        dl: loader exposing `get_unlabeled_raw()` (a DataFrame whose column 0
            holds every source pixel) and a fitted label encoder `le`.
        model: fitted estimator with a `predict` method.
        nodata_value: source value that marks pixels without data.
        nodata_label: label written for those pixels.

    Returns:
        ndarray with one row per pixel (same shape as the raw frame) holding
        raw label values.
    """
    # Fetch the (copied) raw frame once; it serves as model input and output buffer.
    x_raw = dl.get_unlabeled_raw()
    valid = x_raw[0] != nodata_value
    if valid.any():  # an all-no-data scene would otherwise call predict on an empty frame
        preds = model.predict(x_raw[valid])
        x_raw.loc[valid, 0] = dl.le.inverse_transform(preds)
    return np.where(x_raw == nodata_value, nodata_label, x_raw)


def _timed_fit(train_fn, dl, label):
    """Call `train_fn(dl)`, print how long it took, and return (model, seconds)."""
    start = time.perf_counter()
    model = train_fn(dl)
    elapsed = time.perf_counter() - start
    print(f'{label} Time: {elapsed:.2f}')
    return model, elapsed


def _safe_metrics(label, y_true, y_pred, elapsed):
    """Return the classification-report dict for one model, with its training time under 'time'.

    If the report cannot be computed, the error is printed and None is returned,
    so one failing model does not abort the whole experiment.
    """
    try:
        report = metrics(y_true, y_pred)
    except Exception as e:
        print(f'{label} Exception: {e}')
        return None
    report['time'] = elapsed
    return report


def _show_maps(shape, panels):
    """Plot each (title, flat array) pair, reshaped to `shape`, in a single row."""
    fig, axes = plt.subplots(1, len(panels), figsize=(12, 12))
    for ax, (title, flat) in zip(np.atleast_1d(axes), panels):
        ax.imshow(flat.reshape(shape))
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # fig.clear() alone leaves the figure registered with pyplot


def run_experiment(idx=0, show=False):
    """Compare self-training, the project's SSL model and a supervised baseline on one scene.

    Args:
        idx: index of the scene in SOURCES / LABELS.
        show: if True, plot the ground truth and the three prediction maps side by side.

    Returns:
        {'st': ..., 'ssl': ..., 'default': ...}. Each value is that model's
        classification-report dict with an extra 'time' key (training seconds),
        or None if the report could not be computed.
    """
    dl = FloodDataLoder(SOURCES, LABELS, idx)

    st, st_time = _timed_fit(train_st, dl, 'ST')
    ssl, ssl_time = _timed_fit(train_ssl, dl, 'SSL')
    clf, clf_time = _timed_fit(train_default, dl, 'Default')

    ssl_preds = predict(dl, ssl)
    clf_preds = predict(dl, clf)
    st_preds = predict(dl, st)

    y_true = dl.get_raw_labels()
    ssl_ms = _safe_metrics('SSL', y_true, ssl_preds, ssl_time)
    clf_ms = _safe_metrics('CLF', y_true, clf_preds, clf_time)
    st_ms = _safe_metrics('ST', y_true, st_preds, st_time)

    if show:
        _show_maps(dl.initial_size, [
            ('Ground truth', y_true),
            ('Self-training', st_preds),
            ('SSL', ssl_preds),
            ('Default', clf_preds),
        ])

    return {'st': st_ms, 'ssl': ssl_ms, 'default': clf_ms}