In [1]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '3' #cannot work

In [2]:
import pandas as pd
import os
import pathlib
import torch
import rasterio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

In [3]:
import numpy as np
from scipy.ndimage import label
from PIL import Image
import rioxarray
import xarray as xr
import matplotlib.pyplot as plt
import os
from pathlib import Path
import torch
from tqdm import tqdm

In [4]:
from tifffile import imwrite

In [5]:
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(0)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

In [6]:
# from https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb
def get_smaller_bounding_box(ground_truth_map):
  # get bounding box from mask
  z, y_indices, x_indices = np.where(ground_truth_map > 0)
  x_min, x_max = np.min(x_indices), np.max(x_indices)
  y_min, y_max = np.min(y_indices), np.max(y_indices)
  x_range = x_max-x_min
  y_range = y_max-y_min
  x_factor, y_factor = int(x_range/20), int(y_range/20)
  bbox = [x_min+x_factor, 
          y_min+y_factor, 
          x_max-x_factor, 
          y_max-y_factor]

  return bbox

def get_bounding_box(ground_truth_map):
  # get bounding box from mask
  z, y_indices, x_indices = np.where(ground_truth_map > 0)
  x_min, x_max = np.min(x_indices), np.max(x_indices)
  y_min, y_max = np.min(y_indices), np.max(y_indices)
  bbox = [x_min, y_min, x_max, y_max]

  return bbox

def get_bboxes(mask_tensor):
    crop_bboxes = []
    building_bboxes = []
    building_groundtruth = np.where(mask_tensor % 2 == 1, 1, 0)[0,:,:,:]
    field_groundtruth = np.where(mask_tensor==2, 1, 0)[0,:,:,:]

    building_masks, building_num_labels = label(building_groundtruth)
    for i in range(1, building_num_labels + 1):
        building_object_mask = np.where(building_masks == i, 1, 0)
        bbox = get_bounding_box(building_object_mask)
        building_bboxes.append(bbox)
        
    field_masks, field_num_labels = label(field_groundtruth)
    for i in range(1, field_num_labels + 1):
        field_object_mask = np.where(field_masks == i, 1, 0)
        bbox = get_smaller_bounding_box(field_object_mask)
        crop_bboxes.append(bbox)
    return building_bboxes, crop_bboxes

In [7]:
imagedir = '/home/data/kenya/images/'
maskdir = '/home/data/kenya/labels/'

In [8]:
mask_paths = [Path(maskdir).joinpath(i) for i in os.listdir(maskdir) if i.endswith('.tif')]

In [9]:
mask_paths[3]

PosixPath('/home/data/kenya/labels/kenol1_1929.tif')

In [10]:
class SatelliteData(Dataset):
    
    def __init__(self, 
                 imagedir, maskdir):
        self.image_dir = Path(imagedir)
        self.mask_dir = Path(maskdir)
        self.tif_paths = self._get_tif_paths()
        self.mask_paths = self._get_mask_paths()


    def _get_tif_paths(self):
        tif_paths = [self.image_dir.joinpath(i) for i in os.listdir(self.mask_dir) if i.endswith('.tif')]
        return tif_paths

    def _get_mask_paths(self):
        mask_paths = [self.mask_dir.joinpath(i) for i in os.listdir(self.mask_dir) if i.endswith('.tif')]
        return mask_paths
    
    def __len__(self):
        return len(self.tif_paths)

    def __getitem__(self, index):
        def read_tif_as_np_array(path):
            with rasterio.open(path) as src:
                    return src.read()
                
        # Read in merged tif as ground truth
        groundtruth = read_tif_as_np_array(self.mask_paths[index])
        groundtruth = torch.tensor(groundtruth, dtype=torch.uint8)
        image = read_tif_as_np_array(self.tif_paths[index])
        image = torch.tensor(image, dtype=torch.float32)
        chip_name = str(self.tif_paths[index]).split('/')[-1]
        
        return chip_name, groundtruth, image 

In [11]:
image_dataset = SatelliteData(imagedir = imagedir, maskdir = maskdir)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = 1, 
                          shuffle     = False)

# display images
for batch_idx, inputs in enumerate(image_loader):
    print(inputs[2].shape)
    break

torch.Size([1, 4, 512, 512])


In [17]:
save_dir = pathlib.Path('/home/data/kenya-sam/labels/')

In [None]:
# GET JUST BUILDINGS
for inputs in tqdm(image_loader):
    path, pred, image = inputs
    building_bboxes, crop_bboxes = get_bboxes(pred.numpy())
    if crop_bboxes:
        crop_inputs = processor(image[0,:3,:,:], input_boxes=[crop_bboxes], return_tensors="pt").to(0)
        crop_outputs = model(**crop_inputs)
        crop_masks = processor.image_processor.post_process_masks(crop_outputs.pred_masks.cpu(), crop_inputs["original_sizes"].cpu(), crop_inputs["reshaped_input_sizes"].cpu())

        del crop_inputs, crop_outputs
    if building_bboxes:
        building_inputs = processor(image[0,:3,:,:], input_boxes=[building_bboxes], return_tensors="pt").to(0)
        building_outputs = model(**building_inputs)
        building_masks = processor.image_processor.post_process_masks(building_outputs.pred_masks.cpu(), building_inputs["original_sizes"].cpu(), building_inputs["reshaped_input_sizes"].cpu())
    
        del building_inputs, building_outputs

    crops = pred[0,:,:,:]
    
    if building_bboxes:
        building_mask = torch.any(building_masks[0], 0)[:1,:,:]
        building_binary = torch.where(building_mask, 1, 0)
        crop_mask = torch.where(crops==2, 2, 0)
        sam_mask = torch.where(building_binary==1, 1, crop_mask).numpy()
    else:
        sam_mask = crops.numpy()
        
    imwrite(save_dir / path[0], sam_mask)

In [12]:
# GET CROPS AND BUILDINGS
for inputs in tqdm(image_loader):
    path, pred, image = inputs
    building_bboxes, crop_bboxes = get_bboxes(pred.numpy())
    if crop_bboxes:
        crop_inputs = processor(image[0,:3,:,:], input_boxes=[crop_bboxes], return_tensors="pt").to(0)
        crop_outputs = model(**crop_inputs)
        crop_masks = processor.image_processor.post_process_masks(crop_outputs.pred_masks.cpu(), crop_inputs["original_sizes"].cpu(), crop_inputs["reshaped_input_sizes"].cpu())

        del crop_inputs, crop_outputs
        
    if building_bboxes:
        building_inputs = processor(image[0,:3,:,:], input_boxes=[building_bboxes], return_tensors="pt").to(0)
        building_outputs = model(**building_inputs)
        building_masks = processor.image_processor.post_process_masks(building_outputs.pred_masks.cpu(), building_inputs["original_sizes"].cpu(), building_inputs["reshaped_input_sizes"].cpu())
    
        del building_inputs, building_outputs

    crops = pred[0,:,:,:]
    
    if building_bboxes and crop_bboxes:
        building_mask = torch.any(building_masks[0], 0)[:1,:,:]
        building_binary = torch.where(building_mask, 1, 0)
        crop_mask = torch.any(crop_masks[0], 0)[:1,:,:]
        crop_binary = torch.where(crop_mask, 2, 0)
        sam_mask = torch.where(building_binary==1, 1, crop_binary).numpy()
    elif crop_bboxes:
        crop_mask = torch.any(crop_masks[0], 0)[:1,:,:]
        sam_mask = torch.where(crop_mask, 2, 0).numpy()
    else:
        building_mask = torch.any(building_masks[0], 0)[:1,:,:]
        sam_mask = torch.where(building_mask, 1, 0).numpy()
        
    imwrite(save_dir / path[0], sam_mask)

  0%|                                                                                          | 0/2041 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 15.74 GiB total capacity; 14.50 GiB already allocated; 384.62 MiB free; 14.96 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [22]:
unetdir = Path('/home/workdir/solar_test_output/hardened_prob/')
imagedir = Path('/home/data/kenya-sam/images/')
maskdir = Path('/home/data/kenya-sam/labels/')

In [119]:
rasters = [i for i in os.listdir(save_dir) if i.endswith('.tif')]
for p in tqdm(rasters):
    with rasterio.open(unetdir / ('crisp_id_'+p)) as src:
        profile = src.profile
    with rasterio.open(save_dir / p) as src:
        array = src.read()
    with rasterio.open((save_dir / p), 'w', **profile) as dst:
        dst.write(array.astype(rasterio.uint8))

100%|████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 111.66it/s]


In [29]:
rasters = [i for i in os.listdir(save_dir) if i.endswith('.tif')]
for p in tqdm(rasters):
    with rasterio.open(maskdir / p ) as src:
        profile = src.profile
    with rasterio.open(save_dir / p) as src:
        array = src.read()
    with rasterio.open((save_dir / p), 'w', **profile) as dst:
        dst.write(array.astype(rasterio.uint8))

100%|██████████████████████████████████████████████████████████████████████████████| 2041/2041 [00:13<00:00, 146.86it/s]
