In [None]:
# install required libraries
!pip install -q rasterio==1.2.10
!pip install -q geopandas==0.10.2
!pip install -q git+https://github.com/tensorflow/examples.git
!pip install -q -U tfds-nightly
!pip install -q focal-loss
!pip install -q tensorflow-addons==0.8.3
#!pip install -q matplotlib==3.5 # UNCOMMENT if running on LOCAL
!pip install -q scikit-learn==1.0.1
!pip install -q scikit-image==0.18.3
!pip install -q boto3

In [39]:
# import required libraries
# -- standard library --
import os, glob, functools, fnmatch
import io
from zipfile import ZipFile
from itertools import product
from time import sleep
from configparser import ConfigParser

# -- third-party (original relative order kept) --
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.image as mpimg
from matplotlib.colors import ListedColormap

import pandas as pd
from PIL import Image
import geopandas as gpd
from IPython.display import clear_output

import skimage.io as skio # lighter dependency than tensorflow for working with our tensors/arrays
from sklearn.metrics import confusion_matrix, f1_score

import boto3

# -- plotting defaults applied to every figure in this notebook --
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12, 12)

In [13]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
# Colab only: expose Google Drive under /content/gdrive (prompts for authorisation).
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [10]:
# authenticate with AWS credentials
# Credentials are read from `access_keys.csv`. Despite the extension, ConfigParser needs
# INI syntax: a `[credentials]` section holding AWS_ACCESS_KEY_ID and
# AWS_SECRET_ACCESS_KEY. If that file is absent, fall back to the environment variables
# of the same names, so keys never have to live next to the notebook.
config = ConfigParser()
configFilePath = 'access_keys.csv'
if os.path.exists(configFilePath):
    with open(configFilePath) as f:
        config.read_file(f)
    AWS_ACCESS_KEY_ID = config.get('credentials', 'AWS_ACCESS_KEY_ID')
    AWS_SECRET_ACCESS_KEY = config.get('credentials', 'AWS_SECRET_ACCESS_KEY')
else:
    try:
        AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
        AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
    except KeyError as err:
        raise RuntimeError(
            f"No credentials file at {configFilePath!r} and environment variable "
            f"{err.args[0]} is not set"
        ) from err

# S3 resource shared by the download helpers in this notebook.
s3 = boto3.resource('s3',
                    aws_access_key_id=AWS_ACCESS_KEY_ID,
                    aws_secret_access_key=AWS_SECRET_ACCESS_KEY)

In [11]:
def image_from_s3(bucket, key):
    """Download object ``key`` from S3 bucket ``bucket`` and open it as a PIL image.

    The whole object body is read into memory, so nothing is written to disk.
    Relies on the module-level ``s3`` boto3 resource for credentials.
    """
    s3_object = s3.Bucket(bucket).Object(key)
    response = s3_object.get()
    payload = response.get('Body').read()
    return Image.open(io.BytesIO(payload))

In [14]:
def iterate_bucket_items(bucket, prefix='', client=None):
    """
    Generator that iterates over all objects in a given s3 bucket.

    Pages through ``list_objects_v2`` so the full listing is returned, not just the
    first page. See
    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/paginator/ListObjectsV2.html
    for the per-object metadata format.

    :param bucket: name of s3 bucket
    :param prefix: only list keys starting with this prefix (default: every key)
    :param client: optional pre-built S3 client; by default one is created from the
        AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY globals
    :return: dict of metadata for an object (e.g. ``item["Key"]``)
    """
    if client is None:
        client = boto3.client('s3',
                              aws_access_key_id=AWS_ACCESS_KEY_ID,
                              aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
    paginator = client.get_paginator('list_objects_v2')

    paginate_kwargs = {'Bucket': bucket}
    if prefix:
        # Only send Prefix when requested so the default request is unchanged.
        paginate_kwargs['Prefix'] = prefix

    for page in paginator.paginate(**paginate_kwargs):
        # Empty pages (KeyCount == 0) carry no 'Contents' key at all.
        yield from page.get('Contents', [])


In [15]:
# Index every object in the melt-pond bucket, keeping three parallel views:
# raw keys, s3:// URLs, and (key, url) pairs.
bucket = 'veda-ai-supraglacial-meltponds'  # defined once, outside the loop
items_s3 = []
urls_s3 = []
items_urls_s3 = []

for obj in iterate_bucket_items(bucket=bucket):
    key = obj["Key"]
    url = f's3://{bucket}/{key}'
    items_s3.append(key)
    urls_s3.append(url)
    items_urls_s3.append((key, url))

In [28]:
# Local directory holding the PlanetScope scene PNGs. The trailing slash matters
# wherever paths are built by plain string concatenation.
ps_imgs_dir = "./planetscope/"

In [None]:
# All local PlanetScope PNGs. glob returns files in arbitrary filesystem order,
# so the list is sorted to make downstream iteration reproducible.
ps_imgs = sorted(glob.glob(os.path.join(ps_imgs_dir, '*.png')))

In [37]:
# Class index -> human-readable label for the predicted segmentation masks
# (used as colorbar tick labels when plotting).
label_dict = {0: 'background',
              1: 'snow',
              2: 'dark_ice',
              3: 'melt_ponds',
              4: 'open_water',
              5: 'ridge_shadows'}
# Class index -> matplotlib colour used for that class in the plotted masks.
color_dict = {0: 'black',
              1: 'white',
              2: 'grey',
              3: 'red',
              4: 'blue',
              5: 'darkgrey'}
# Labels and colours are paired by index, so both dicts must cover the same classes.
assert label_dict.keys() == color_dict.keys(), \
    "label_dict and color_dict must define the same class indices"

In [46]:
# Number of segmentation classes (one colour per class).
num_labels = len(color_dict)
# NOTE(review): not referenced in the cells visible here. The hard-coded 6-element row
# only broadcasts against (num_labels, num_labels) when num_labels == 6 — TODO confirm
# its purpose, or derive it from num_labels.
incorrect_color_data = np.array([2, 3, 4, 2, 3, 4]) * np.ones((num_labels, num_labels))

# Discrete colormap for class-index masks. `list(...)` snapshots the colours so the
# colormap does not keep a live view into `color_dict`.
cmapm = ListedColormap(list(color_dict.values()))
# Colour range spans the smallest and largest class index.
imin = min(label_dict)
imax = max(label_dict)

In [None]:
from IPython.display import display

# For every prediction PNG in the bucket that contains at least one melt-pond pixel,
# save a side-by-side figure (PlanetScope scene | predicted class mask) to `plot_dir`.
plot_dir = './plots_model_v1/'
os.makedirs(plot_dir, exist_ok=True)

PRED_BUCKET = 'veda-ai-supraglacial-meltponds'
PRED_KEY_SUBSTRING = 'predictions_test_focal_loss_planetscope_96'
MELT_POND_LABEL = 3  # class index of 'melt_ponds' in label_dict
SHOW_INLINE = True   # also display each figure in the notebook output


def _plot_prediction(img_ps, img_pred, basename):
    """Draw the scene next to its predicted mask and save it as `<basename>_vals.png`.

    Returns the Figure; the caller is responsible for closing it.
    """
    fig, (ax_0, ax_1) = plt.subplots(1, 2, figsize=(12, 12))

    ax_0.imshow(img_ps)
    ax_0.set_title('PlanetScope image')

    im_pred = ax_1.imshow(img_pred, cmap=cmapm, interpolation='none', vmin=imin, vmax=imax)
    ax_1.set_title('Predicted classes')
    # One tick at the centre of each of the num_labels colour bins between imin and imax.
    tick_locs = np.linspace(imin, imax, 2 * num_labels + 1)[1::2]
    cbar = fig.colorbar(im_pred, ticks=tick_locs, shrink=0.45, ax=ax_1)
    cbar.ax.set_yticklabels(list(label_dict.values()))

    fig.savefig(os.path.join(plot_dir, f'{basename}_vals.png'))
    return fig


for key, _url in items_urls_s3:
    # endswith (not a substring test) so sidecar files such as `*.png.aux.xml` are skipped.
    if PRED_KEY_SUBSTRING not in key or not key.endswith('.png'):
        continue

    img_pred = np.array(image_from_s3(PRED_BUCKET, key))
    if not (img_pred == MELT_POND_LABEL).any():
        continue  # only plot scenes where melt ponds were predicted

    basename = os.path.basename(os.path.splitext(key)[0])
    # The local PlanetScope scene is assumed to share the prediction's basename.
    with Image.open(os.path.join(ps_imgs_dir, f'{basename}.png')) as scene:
        img_ps = np.array(scene)
    print(basename)

    fig = _plot_prediction(img_ps, img_pred, basename)
    if SHOW_INLINE:
        display(fig)
    # Without closing, every figure stays alive for the whole loop (memory growth).
    plt.close(fig)