In [1]:
import os
import time
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from numpy.typing import NDArray
from osgeo import gdal

from sklearn.calibration import CalibratedClassifierCV
from sklearn.kernel_approximation import RBFSampler
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from sklearn.semi_supervised import SelfTrainingClassifier

In [2]:
from components.dataloader import AbstractDataLoader
from components.fcm import FCM
from components.ssl import SemiSupervisedModel

In [3]:
import warnings
warnings.filterwarnings("ignore")

In [4]:
class CropDataLoader(AbstractDataLoader):
    """Per-pixel data loader for the NASA marine debris GeoTIFF pictures.

    ``set_all_ds`` reads the first ``n_pics`` pictures and their bounding-box
    labels into one long pixel table ``all_ds``. That table has one row per
    pixel, with columns ``r``, ``g``, ``b``, ``label`` and ``id``.

    ``set_sub_dss`` then selects picture ``idx``. For each label value in that
    picture, it marks a seeded ``labeled_frac`` sample of the pixels as
    labeled; every other pixel is treated as unlabeled.
    """
    bands = ['r', 'g', 'b']
    initial_size = (256, 256)
    # Dataset root; override on the class or an instance to relocate the data.
    main_dir = '/Users/dimignatiev/Documents/HSE/Deeplom/SSL/nasa_marine_debris'
    # Number of (labels, source) directory pairs read by ``set_all_ds``.
    n_pics = 100
    # Fraction of each label value's pixels that ``set_sub_dss`` marks as labeled.
    labeled_frac = 0.1
    # Label written to pixels not covered by any bounding box.
    background_label = -1

    def __init__(self, idx=0):
        self.all_ds = None
        self.idx = idx
        self.le = LabelEncoder()
        self.labeled = None
        self.labels = None
        self.unlabeled = None
        # Filled by ``set_sub_dss``; initialised so ``get_unique_labels`` never
        # raises AttributeError.
        self.pic_classes = None

    def set_all_ds(self):
        """Load the first ``n_pics`` pictures into ``self.all_ds``.

        Source and label directories are matched by their position in sorted
        order. Each picture contributes ``initial_size[0] * initial_size[1]``
        rows. Its ``id`` column is its index in that order.

        Raises:
            FileNotFoundError: if GDAL cannot open a source GeoTIFF.
        """
        root = Path(self.main_dir)
        sources = sorted(
            str(p) for p in (root / 'nasa_marine_debris_source').iterdir()
            if 'source' in p.name
        )
        labels = sorted(
            str(p) for p in (root / 'nasa_marine_debris_labels').iterdir()
            if 'labels' in p.name
        )
        pairs = zip(labels[:self.n_pics], sources[:self.n_pics])
        self.all_ds = pd.concat(
            [self._read_picture(id_, l, s) for id_, (l, s) in enumerate(pairs)]
        )

    def _read_picture(self, id_, label_dir, source_dir):
        """Build the pixel table (bands, label, id) for one picture."""
        tif_path = f'{source_dir}/image_geotiff.tif'
        ds = gdal.Open(tif_path, gdal.GA_ReadOnly)
        if ds is None:  # GDAL returns None (instead of raising) by default
            raise FileNotFoundError(f'GDAL could not open {tif_path}')
        try:
            # GDAL band numbers are 1-based: band 1 -> 'r', 2 -> 'g', 3 -> 'b'.
            columns = [
                ds.GetRasterBand(band_no).ReadAsArray().reshape(-1, 1)
                for band_no in range(1, len(self.bands) + 1)
            ]
        finally:
            ds = None  # dropping the reference releases the GDAL dataset

        label_arr = np.full(self.initial_size, self.background_label)
        for bb in np.load(f'{label_dir}/pixel_bounds.npy'):
            # Rows bb[1]:bb[3], columns bb[0]:bb[2] get class bb[4].
            # NOTE(review): presumably (x_min, y_min, x_max, y_max, class) — confirm.
            col0, row0, col1, row1, cls = (int(v) for v in bb[:5])
            label_arr[row0:row1, col0:col1] = cls
        columns.append(label_arr.reshape(-1, 1))

        pic_df = pd.DataFrame(np.concatenate(columns, axis=1), columns=[*self.bands, 'label'])
        pic_df['id'] = id_
        return pic_df

    def set_sub_dss(self):
        """Split picture ``self.idx`` into labeled and unlabeled pixels.

        For every label value in the picture, a ``labeled_frac`` sample of its
        pixels (``random_state=42``) is marked as labeled. ``self.labels``
        holds those pixels' labels, encoded with ``self.le``.

        Raises:
            RuntimeError: if ``set_all_ds`` has not been called yet.
        """
        if self.all_ds is None:
            raise RuntimeError('set_all_ds() must be called before set_sub_dss()')
        pic_df = self.all_ds[self.all_ds['id'] == self.idx].copy()
        pic_df['is_labeled'] = False
        self.pic_classes = pic_df['label'].unique()
        for cl in self.pic_classes:
            cl_df = pic_df[pic_df['label'] == cl].sample(frac=self.labeled_frac, random_state=42)
            pic_df.loc[cl_df.index, 'is_labeled'] = True
        is_labeled = pic_df['is_labeled']
        self.labeled = pic_df.loc[is_labeled, self.bands]
        self.labels = self.le.fit_transform(pic_df.loc[is_labeled, 'label'])
        self.unlabeled = pic_df.loc[~is_labeled, self.bands]

    def get_labeled(self) -> Tuple[pd.DataFrame, NDArray]:
        """Return copies of the labeled band values and their encoded labels."""
        return self.labeled.copy(), self.labels.copy()

    def get_unlabeled(self) -> pd.DataFrame:
        """Return a copy of the band values of the pixels not marked as labeled."""
        return self.unlabeled.copy()

    def get_unlabeled_raw(self):
        """Return band values for every pixel of picture ``idx``, labeled ones included."""
        return self.all_ds[self.all_ds['id'] == self.idx][self.bands].copy()

    def get_raw_labels(self):
        """Return the (unencoded) label of every pixel of picture ``idx``."""
        return self.all_ds[self.all_ds['id'] == self.idx]['label'].copy()

    def get_w_idx(self):
        """Return all columns of ``all_ds`` for picture ``idx``."""
        return self.all_ds[self.all_ds['id'] == self.idx].copy()

    def get_unique_labels(self) -> NDArray:
        """Return the label values found in picture ``idx`` (``None`` before ``set_sub_dss``)."""
        return self.pic_classes

NameError: name 'AbstractDataLoader' is not defined

In [None]:
def metrics(y_true, y_pred):
    """Return sklearn's ``classification_report`` for ``y_pred`` as a nested dict."""
    return classification_report(y_true, y_pred, output_dict=True)
    

def train_ssl(dl):
    """Train the project's semi-supervised model on the data held by ``dl``.

    The model combines an SGD linear classifier, the ``FCM`` clustering model
    from ``components.fcm`` and an RBF kernel approximation. The labeled set
    is grown with ``increase_labeled`` before fitting.

    Returns:
        The fitted ``SemiSupervisedModel``.
    """
    clf = SGDClassifier(max_iter=10_000, random_state=42, class_weight='balanced', n_jobs=-1)
    clustering = FCM()
    ka_ = RBFSampler(gamma='scale', random_state=42)
    # SemiSupervisedModel is the class exported by components.ssl (imported in cell 2).
    ssl = SemiSupervisedModel(clf, clustering, dl, ka_)
    # NOTE(review): meaning of the two 0.7 thresholds and min_perc lives in
    # components.ssl — confirm there before tuning.
    ssl.increase_labeled(0.7, 0.7, min_perc=0.0)
    ssl.fit()
    return ssl


def train_st(dl):
    """Fit scikit-learn's self-training baseline on ``dl``.

    Labeled and unlabeled rows are stacked into one matrix. Unlabeled rows get
    the target ``-1``, which ``SelfTrainingClassifier`` treats as "no label".

    The wrapped estimator is an isotonic-calibrated SGD classifier, so it
    provides the ``predict_proba`` that self-training needs.
    """
    feats_lab, y_lab = dl.get_labeled()
    feats_unlab = dl.get_unlabeled()
    y_unlab = np.array([-1] * feats_unlab.shape[0])

    features = np.concatenate([feats_lab, feats_unlab])
    targets = np.concatenate([y_lab, y_unlab])

    base = CalibratedClassifierCV(
        SGDClassifier(max_iter=10_000, random_state=42, class_weight='balanced', n_jobs=-1),
        method='isotonic',
        cv=3,
        n_jobs=-1,
    )
    self_trainer = SelfTrainingClassifier(base)
    self_trainer.fit(features, targets)
    return self_trainer


def train_default(dl):
    """Fully supervised baseline: fit an SGD classifier on the labeled pixels only.

    Returns:
        The fitted ``SGDClassifier``.
    """
    features, targets = dl.get_labeled()
    baseline = SGDClassifier(max_iter=10_000, random_state=42, class_weight='balanced', n_jobs=-1)
    baseline.fit(features, targets)
    return baseline


def predict(dl, model):
    """Predict every pixel returned by ``dl.get_unlabeled_raw()`` and decode to original labels."""
    encoded = model.predict(dl.get_unlabeled_raw())
    return dl.le.inverse_transform(encoded)


def _timed(name, train_fn, dl):
    """Call ``train_fn(dl)``, print ``'<name> Time: <seconds>'`` and return ``(model, seconds)``."""
    start = time.perf_counter()  # monotonic clock, intended for measuring durations
    model = train_fn(dl)
    elapsed = time.perf_counter() - start
    print(f'{name} Time: {elapsed:.2f}')
    return model, elapsed


def _show_label_maps(images, titles):
    """Draw label maps side by side, then close the figure so repeated runs don't accumulate figures."""
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4), squeeze=False)
    for ax, image, title in zip(axes[0], images, titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.set_axis_off()
    plt.show()
    plt.close(fig)


def run_experiment(idx=0, show=False):
    """Compare self-training, the project SSL model and a supervised baseline on one picture.

    Args:
        idx: picture id passed to ``CropDataLoader``.
        show: if True, plot the ground truth next to the three predicted label maps.

    Returns:
        ``{'st': ..., 'ssl': ..., 'default': ...}``. Each value is the
        ``classification_report`` dict computed over every pixel of the
        picture. Each also has an extra ``'time'`` entry holding the model's
        training time in seconds.
    """
    dl = CropDataLoader(idx)
    dl.set_all_ds()
    dl.set_sub_dss()

    st, st_time = _timed('ST', train_st, dl)
    ssl, ssl_time = _timed('SSL', train_ssl, dl)
    clf, clf_time = _timed('Default', train_default, dl)

    # Predictions cover every pixel of the picture (labeled ones included),
    # in the same row order as the ground-truth labels.
    y_true = dl.get_raw_labels().to_numpy()
    ssl_preds = predict(dl, ssl)
    clf_preds = predict(dl, clf)
    st_preds = predict(dl, st)

    if show:
        _show_label_maps(
            [y.reshape(dl.initial_size) for y in (y_true, st_preds, ssl_preds, clf_preds)],
            ['ground truth', 'self-training', 'SSL', 'default'],
        )

    results = {}
    for name, preds, elapsed in (
        ('st', st_preds, st_time),
        ('ssl', ssl_preds, ssl_time),
        ('default', clf_preds, clf_time),
    ):
        report = metrics(y_true, preds)
        report['time'] = elapsed
        results[name] = report
    return results