In [None]:
import sys
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import geopandas as gpd
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy import stats

sys.path.append('..')

from src.engineer.geowiki import GeoWikiDataInstance
from src.engineer.nigeria import NigeriaDataInstance
from src.exporters.sentinel.cloudfree import BANDS

%reload_ext autoreload
%autoreload 2

## Load the pickled Nigeria data instances from the training/validation/testing splits

In [None]:
# Root of the Nigeria split export; the cells below expect `training/`, `validation/`
# and `testing/` sub-folders containing one `.pkl` file per labelled instance.
data_path = Path('..') / 'data_split_test' / 'nigeria'

In [None]:
# Collect one record per pickled NigeriaDataInstance across all three splits.
# NOTE(review): pickle.load can execute arbitrary code - only run this on trusted,
# locally generated files.
subsets = ['training', 'validation', 'testing']
rows = []
for subset in subsets:
    # sorted() makes the row order deterministic; raw glob order depends on the filesystem.
    pickle_files = sorted((data_path / subset).glob('*.pkl'))
    for file in tqdm(pickle_files, desc=subset):
        # File names are split on '_': the first token is the identifier, the rest
        # (up to the first '.') is kept as the `date` string.
        identifier = file.name.split('_')[0]
        # Parse the date from *this* file (previously it always used pickle_files[0]).
        date = '_'.join(file.name.split('_')[1:]).split('.')[0]
        with file.open("rb") as f:
            target_datainstance = pickle.load(f)
        assert isinstance(target_datainstance, NigeriaDataInstance), 'Pickle file is not an instance of Nigeria data'
        label = target_datainstance.is_crop

        rows.append((
            identifier,
            date,
            target_datainstance.instance_lat,
            target_datainstance.instance_lon,
            label,
            file.name,
            subset
            ))

In [None]:
# Tabulate the collected records, then attach point geometries (lon = x, lat = y) in WGS84.
df = pd.DataFrame(
    rows,
    columns=['identifier', 'date', 'lat', 'lon', 'label', 'filename', 'set'],
)
points = gpd.points_from_xy(x=df['lon'], y=df['lat'])
gdf = gpd.GeoDataFrame(df, geometry=points, crs='epsg:4326')
gdf

## Label distribution

In [None]:
# Per-split sample count, number of cropland-labelled samples and their ratio,
# plus an overall 'total' row computed over the full GeoDataFrame.
label_dist = (
    gdf.groupby(['set'])['label']
    .agg(['count', 'sum'])
    .assign(ratio=lambda per_set: per_set['sum'] / per_set['count'])
)
label_dist.loc['total'] = [len(gdf), gdf['label'].sum(), gdf['label'].sum() / len(gdf)]
label_dist = label_dist.rename(columns={'sum': 'cropland_count'})
label_dist

## Visualize spatial distribution and save

In [None]:
# Map every labelled point, coloured by the split it belongs to, so spatial
# leakage or clustering between training/validation/testing is visible.
fig, ax = plt.subplots()
gdf.plot(column='set', legend=True, ax=ax)
ax.set(title='Nigeria labelled points by split', xlabel='Longitude', ylabel='Latitude')
plt.show()

In [None]:
# Persist the labelled points together with their split assignment next to the data.
# Shapefile (DBF) field names are limited to 10 characters; the current columns fit.
shapefile_path = data_path / 'nigeria_stratified_labelled_v1_splits.shp'
gdf.to_file(shapefile_path)