In [1]:
import os, sys

sys.path.insert(0, os.path.abspath('..'))
os.environ['USE_PYGEOS'] = '0'

%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

from google.cloud import storage
from google.oauth2 import service_account
import json

import gc
gc.enable()

import glob

from rasterio.plot import show
import rasterio
from rasterio.merge import merge
from rasterio.transform import from_origin
from rasterio.enums import Resampling

from experiment_configs.configs import unet_config, satmae_large_config, satmae_large_inf_config


from utils.rastervision_pipeline import create_s2_image_source, create_scene_s2, scene_to_inference_ds
from ml.learner import BinarySegmentationPredictor
from models.model_factory import model_factory
from project_config import CLASS_CONFIG
from ml.eval_utils import save_predictions

import subprocess

In [3]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME, DATASET_JSON_PATH
import json
import matplotlib.pyplot as plt
from utils import gcp_utils

# Shared Google Cloud Storage client, bound to the project named in project_config.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)



In [4]:

# Root of the conda environment "rio-cog-env" under the current user's HOME;
# create_cogs() runs its command with "<ENVBIN>/bin" as the working directory.
ENVBIN = f"{os.environ['HOME']}/.conda/envs/rio-cog-env"
print(ENVBIN)





def create_cogs(source_file, out_file):
    """Run ``rio cogeo create`` to convert ``source_file`` into ``out_file``.

    The command is executed through the shell with the working directory set
    to ``{ENVBIN}/bin``. A non-zero exit status is caught and printed instead
    of being raised, so a failed conversion never interrupts the caller.

    Args:
        source_file: Path of the input raster handed to ``rio cogeo create``.
        out_file: Path the command is asked to write its output to.

    NOTE(review): only the working directory points at the conda env; the
    ``rio`` executable itself is still resolved by the shell. Confirm that
    this picks up the intended installation.
    """
    import shlex  # local import so the block does not depend on a new file-level import

    # Quote both paths so spaces or shell metacharacters cannot split or
    # alter the command line.
    cog_create = (
        f"rio cogeo create {shlex.quote(str(source_file))} {shlex.quote(str(out_file))}"
    )
    try:
        # check=True is what makes a failing command raise CalledProcessError;
        # without it the except branch below can never run and failures are silent.
        subprocess.run(cog_create,
                       cwd=f"{ENVBIN}/bin",
                       shell=True,
                       check=True)
    except subprocess.CalledProcessError as e:
        print(e)
        


/home/suraj.nair/.conda/envs/rio-cog-env


In [12]:
def save_json(storage_client, bucket_name, out_file_name, data_for_json):
    """Serialize ``data_for_json`` to JSON and upload it as a blob.

    Args:
        storage_client: Client object exposing ``get_bucket``.
        bucket_name: Name of the bucket to upload into.
        out_file_name: Name of the blob to write inside the bucket.
        data_for_json: JSON-serializable object; dumped with a 2-space indent.
    """
    target_bucket = storage_client.get_bucket(bucket_name)
    serialized = json.dumps(data_for_json, indent=2)
    target_bucket.blob(out_file_name).upload_from_string(
        data=serialized,
        content_type='application/json',
    )
    
def read_json(storage_client, bucket_name, file_name):
    """Download a JSON blob from a bucket and return the parsed content.

    Args:
        storage_client: Client object exposing ``get_bucket``.
        bucket_name: Name of the bucket that holds the blob.
        file_name: Name of the blob to read.

    Returns:
        The Python object produced by ``json.loads`` on the blob's content.

    Raises:
        FileNotFoundError: If ``get_blob`` returns ``None`` for ``file_name``.
    """
    bucket = storage_client.get_bucket(bucket_name)
    # Look the blob up on the bucket fetched above, not on a module-level global.
    blob = bucket.get_blob(file_name)
    if blob is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise FileNotFoundError(
            f"Blob {file_name!r} not found in bucket {bucket_name!r}"
        )
    # load blob using json
    return json.loads(blob.download_as_string())


def delete_files(filepath):
    """Delete the non-hidden entries directly inside ``filepath``.

    Implemented with ``glob`` and ``os.remove`` rather than a shell ``rm``, so
    directory names containing spaces or shell metacharacters are handled
    correctly. Entries that cannot be removed (for example sub-directories)
    are reported and left in place; the remaining entries are still deleted.
    A directory that does not exist or is empty is treated as success.

    Args:
        filepath: Directory whose direct entries should be removed.
    """
    failures = []
    # glob.escape keeps special characters in the directory name literal;
    # the "*" pattern skips dot-files.
    pattern = os.path.join(glob.escape(str(filepath)), "*")
    for entry in glob.glob(pattern):
        try:
            os.remove(entry)
        except OSError as e:
            failures.append(e)

    if failures:
        for e in failures:
            print(f"Error: {e}")
    else:
        print("Existing file deleted")

def save_predictions_to_gcp(prediction,
                            prediction_out_path,
                            crs_transformer,
                            prediction_scores_file_path,
                            prediction_cog_file_path,
                            gcp_dest,
                            overwrite = False,
                            move_to_gcp = True):
    """Write predictions to disk, convert them with ``create_cogs`` and optionally copy to GCS.

    Steps, in order:
      1. Prepare ``prediction_out_path``: create it if missing; if it already
         exists, either empty it (``overwrite``) or return without doing anything.
      2. Call ``save_predictions`` with ``CLASS_CONFIG`` and a 0.5 threshold.
      3. Call ``create_cogs`` on ``prediction_scores_file_path`` to produce
         ``prediction_cog_file_path``.
      4. If ``move_to_gcp`` is set, copy the result to ``gcp_dest`` with
         ``gsutil cp -r``; a failing copy is printed, not raised.

    Args:
        prediction: Prediction object forwarded to ``save_predictions``.
        prediction_out_path: Local output directory.
        crs_transformer: Forwarded to ``save_predictions``.
        prediction_scores_file_path: Input file for ``create_cogs``.
        prediction_cog_file_path: Output file of ``create_cogs``; also the upload source.
        gcp_dest: Destination passed to ``gsutil cp``.
        overwrite: Whether an existing output directory may be emptied and reused.
        move_to_gcp: Whether to copy the result to ``gcp_dest``.
    """
    print(f"Saving to {prediction_out_path}", datetime.now())

    # Decide what to do with the output directory before any work is done.
    out_dir_exists = os.path.exists(prediction_out_path)
    if out_dir_exists and not overwrite:
        print("Prediction Files Exist! Set overwrite = True to over write!")
        return
    if out_dir_exists:
        print("Prediction Files Exist! Overwriting..")
        delete_files(prediction_out_path)
    else:
        os.makedirs(prediction_out_path)

    # Persist the predictions, then convert the scores raster.
    save_predictions(
        prediction,
        path=prediction_out_path,
        class_config=CLASS_CONFIG,
        crs_transformer=crs_transformer,
        threshold=0.5,
    )
    create_cogs(prediction_scores_file_path, prediction_cog_file_path)

    if move_to_gcp:
        # Copy the converted file to the bucket; failures are reported, not raised.
        gcp_move = f"""gsutil cp -r {prediction_cog_file_path} {gcp_dest}"""
        try:
            subprocess.run(gcp_move, shell=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error copying files: {e}")
        else:
            print(f"Files copied successfully to GCP Bucket")

    print("Completed", datetime.now())
    
    
def get_river_index(dataset, river_name, date, prediction_id):
    """Return the index of the first dataset entry matching river, date and uid.

    Args:
        dataset: Sequence of dicts with ``'date'``, ``'river'`` and ``'uid'`` keys.
        river_name: Value to match against each entry's ``'river'``.
        date: Value to match against each entry's ``'date'``.
        prediction_id: Value to match against each entry's ``'uid'``.

    Returns:
        The zero-based index of the first matching entry, or ``None`` when no
        entry matches.
    """
    for i, entry in enumerate(dataset):
        # Compare against the `date` argument rather than a notebook-level global.
        if (entry['date'] == date
                and entry['river'] == river_name
                and entry['uid'] == prediction_id):
            return i
    return None

In [13]:
# Identify the Google Sheet and the tab to read, then build a URL that asks
# for that tab as CSV (note the "tqx=out:csv" query parameter).
sheet_id = "1Ov1M_zsb5jYo_dtIjUXco1wvNgasmdoZKtJ877psHL4"
sheet_name = "rivers_to_osm_label"
url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={sheet_name}"

# Fetch the sheet over the network and force the 'osm_id' column to integers.
df_rivers = pd.read_csv(url)
df_rivers['osm_id'] = df_rivers['osm_id'].astype('int')

# Distinct values of the 'River' column, in their original casing.
river_names = df_rivers['River'].unique()


In [22]:
# Notebook-level settings. N_CHANNELS is passed to model_factory as n_channels.
BUCKET_NAME = 'sand_mining_inference'
OVERWRITE = True
N_CHANNELS = 10

# River names to exclude, written in the sheet's original casing.
drop = ['Godavari', 'Godavari (N)',
       'Godavari (S)', 'Sone - South']
# Filter on the original casing first, then lower-case the survivors.
# NOTE(review): after this line `river_names` is lower-case while `drop` is not,
# so a later `name in drop` test on these values will not match — confirm intent.
river_names = [r.lower() for r in river_names if r not in drop]


In [40]:
#### Configuration

# Experiment configuration used for everything below.
config = satmae_large_inf_config
# Last "/"-separated component of the configured wandb id.
PREDICTION_ID = config.wandb_id.split("/")[-1]

# Load model
from ml.model_stats import count_number_of_weights
model = model_factory(
    config,
    n_channels= N_CHANNELS ,
)

# Build the predictor from the config, the model and the configured encoder weights path.
predictor = BinarySegmentationPredictor(
    config, model, config.encoder_weights_path
)

crop_sz = int(config.tile_size // 5) #20% of the tiles at the edges are discarded

# Print trainable and total parameter counts (in millions) and the trainable percentage.
all_params, trainable_params = count_number_of_weights(predictor.model)
print(f"trainable params: {trainable_params/1e6}M || all params: {all_params/1e6}M || trainable%: {100 * trainable_params / all_params:.2f}")


SatMae: Loading encoder weights from /data/sand_mining/checkpoints/finetuned/SatMAE-L_LoRA-bias_LN_160px_mclw-6_B8_E9_SmoothVal-S5-DecOnly-E20.pth
Number of parameters loaded: 299
SatMae: Loading decoder weights from /data/sand_mining/checkpoints/finetuned/SatMAE-L_LoRA-bias_LN_160px_mclw-6_B8_E9_SmoothVal-S5-DecOnly-E20.pth
Temperature scaling set to None
trainable params: 304.273667M || all params: 304.273667M || trainable%: 100.00


In [67]:
# Date string used by the next cell to select the dataset entry and to build output paths.
DATE = '2022-02-01'

In [68]:
### Load River URI dataset


for river_name in ['chambal']:

    # Single source of truth for the dataset path: read here, rewritten below.
    master_dataset = "../dataset/river_oos_dataset_v0.2.json"
    # Reloaded on every iteration; the context manager closes the file handle.
    with open(master_dataset, 'r', encoding='utf-8') as f:
        river_jsons = json.load(f)

    gc.collect()

    # `drop` is written in the sheet's original casing while the names handled
    # here are lower-case, so the exclusion test must ignore case to match.
    if river_name.lower() in {name.lower() for name in drop}:
        continue

    print(river_name, DATE)

    BASE_PATH = f"/data/sand_mining/predictions/outputs/{river_name}"
    RIVER_URI = f"https://storage.googleapis.com/sand_mining_inference/{river_name}/{river_name}.geojson"
    RIVER_OUT_NAME = river_name
    filter_func = lambda x: (x['date'] == DATE) & (x['river'] == river_name) & (x['uid'] == PREDICTION_ID)

    # Locate the dataset entry for this river / date / prediction id.
    river_index = get_river_index(river_jsons, river_name, DATE, PREDICTION_ID)
    if river_index is None:
        # Without this guard list.pop(None) would raise an opaque TypeError.
        print(f"No dataset entry for river={river_name!r}, date={DATE!r}, uid={PREDICTION_ID!r}; skipping.")
        continue

    # The entry is removed here and appended back (updated) below, so it moves
    # to the end of the list when the dataset is rewritten.
    river_json = river_jsons.pop(river_index)
    uri_to_s2 = river_json['uri_to_s2']

    # Ask before replacing a prediction URI that is already recorded.
    if river_json['uri_to_prediction'] != "":
        user_input = input("Prediction URI appears to exist! Would you like to overwrite? (Y/N)")
    else:
        user_input = "Y"

    if user_input == "Y":
        prediction_out_path = os.path.join(BASE_PATH, DATE)
        prediction_scores_file_path = os.path.join(prediction_out_path,  "scores.tif")
        prediction_cog_file_path = os.path.join(prediction_out_path, f"{RIVER_OUT_NAME}_prediction_{DATE}.tif")
        gcp_dest = f"gs://sand_mining_inference/{river_name}/{DATE}/{RIVER_OUT_NAME}_prediction_{DATE}.tif"

        # NOTE(review): the URI recorded below ends in "_{PREDICTION_ID}.tif" while
        # gcp_dest above does not include the prediction id — confirm which is intended.
        prediction_base = f"https://storage.googleapis.com/sand_mining_inference/{RIVER_OUT_NAME}"
        prediction_gcp_path =  f"{prediction_base}/{DATE}/{RIVER_OUT_NAME}_prediction_{DATE}_{PREDICTION_ID}.tif"
        river_json['uri_to_prediction'] = prediction_gcp_path

        river_jsons.append(river_json)

        # Persist the updated dataset back to the same file it was read from.
        with open(master_dataset, 'w', encoding='utf-8') as f:
            json.dump(river_jsons, f, indent=4, default=str)

        # Predict (currently disabled)
        # r_source = create_scene_s2(config, uri_to_s2, label_uri = None, scene_id = 0, rivers_uri = RIVER_URI)
        # r_inference = scene_to_inference_ds(config, r_source, full_image=False, stride=int(config.tile_size/2))
        # crs_transformer = r_inference.scene.raster_source.crs_transformer
        # prediction = predictor.predict_site(r_inference, crop_sz=crop_sz)

        # save_predictions_to_gcp(
        #         prediction,
        #         prediction_out_path,
        #         crs_transformer,
        #         prediction_scores_file_path,
        #         prediction_cog_file_path,
        #         gcp_dest,
        #         overwrite = OVERWRITE,
        #         move_to_gcp = True)

    else:
        print("Completed!")








chambal 2022-02-01
