In [None]:
import pdb
import os
import boto3
import rasterio
import rasterio.warp
from rasterio.enums import Resampling
from rasterio.io import MemoryFile
import numpy as np
from io import BytesIO
import tarfile
import gc
from multiprocessing import Pool, cpu_count

# Initialize S3 client
# Module-level boto3 client used by every S3 helper defined in this file.
s3_client = boto3.client('s3')

# Bucket and key prefixes ("folders") used by the processing steps below.
bucket_name = 'agrisense3'
cropscape_folder = 'cropscape_monterey/'  # not referenced by the code visible in this file
landsat_folder = 'landsat/'  # prefix listed for incoming .tar.gz archives
landsat_parsed_folder = 'landsat_parsed/'  # archives are moved here after processing
converted_folder = 'converted/'  # reprojected GeoTIFFs are uploaded under this prefix
masked_folder = 'landsat_masked/'  # masked GeoTIFFs are uploaded under this prefix

def download_from_s3(bucket, key):
    """Fetch the object stored at ``bucket``/``key`` and return its content as bytes.

    The object is streamed into an in-memory buffer through the module-level
    ``s3_client``; nothing is written to disk.
    """
    buffer = BytesIO()
    try:
        s3_client.download_fileobj(bucket, key, buffer)
        # getvalue() returns the whole buffer regardless of the current position.
        return buffer.getvalue()
    finally:
        buffer.close()

def upload_to_s3(file_bytes, bucket, key):
    """Store ``file_bytes`` in S3 at ``bucket``/``key`` and print a confirmation line.

    The bytes are wrapped in an in-memory buffer and sent through the
    module-level ``s3_client``.
    """
    payload = BytesIO(file_bytes)
    try:
        s3_client.upload_fileobj(payload, bucket, key)
        print(f"Uploaded to {bucket}/{key}")
    finally:
        payload.close()

def resample_raster(src_raster, target_raster_profile, resampling_method=Resampling.nearest):
    """Warp an in-memory raster onto the grid described by ``target_raster_profile``.

    Args:
        src_raster: Bytes of a raster file readable by ``rasterio.open``.
        target_raster_profile: Mapping providing ``'height'``, ``'width'``,
            ``'transform'`` and ``'crs'`` of the destination grid.
        resampling_method: ``rasterio`` resampling algorithm; nearest by default.

    Returns:
        A ``(band, transform)`` tuple: the first warped band as a 2-D array of
        shape ``(height, width)`` and the destination transform. Every band is
        warped, but only the first one is returned.
    """
    with rasterio.open(BytesIO(src_raster)) as dataset:
        out_shape = (
            dataset.count,
            target_raster_profile['height'],
            target_raster_profile['width'],
        )
        # The first band's dtype is used for the whole output stack.
        warped = np.empty(out_shape, dtype=dataset.dtypes[0])
        dst_transform = target_raster_profile['transform']

        for offset, band_number in enumerate(dataset.indexes):
            rasterio.warp.reproject(
                source=rasterio.band(dataset, band_number),
                destination=warped[offset],
                src_transform=dataset.transform,
                src_crs=dataset.crs,
                dst_transform=dst_transform,
                dst_crs=target_raster_profile['crs'],
                resampling=resampling_method,
            )
        return warped[0], dst_transform

def reproject_to_wgs84(src_bytes):
    """Reproject an in-memory raster to EPSG:4326 and return the new file as bytes.

    The output grid (transform, width, height) comes from
    ``rasterio.warp.calculate_default_transform``; all other profile entries
    are copied from the input. Each band is warped with nearest-neighbour
    resampling into an in-memory dataset.
    """
    target_crs = 'EPSG:4326'
    with rasterio.open(BytesIO(src_bytes)) as source:
        dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform(
            source.crs, target_crs, source.width, source.height, *source.bounds
        )
        out_profile = source.profile
        out_profile.update(
            crs=target_crs,
            transform=dst_transform,
            width=dst_width,
            height=dst_height,
        )

        with MemoryFile() as scratch:
            with scratch.open(**out_profile) as sink:
                for band_number in source.indexes:
                    rasterio.warp.reproject(
                        source=rasterio.band(source, band_number),
                        destination=rasterio.band(sink, band_number),
                        src_transform=source.transform,
                        src_crs=source.crs,
                        dst_transform=dst_transform,
                        dst_crs=target_crs,
                        resampling=Resampling.nearest,
                    )
            # Read only after the writer is closed so the file is fully flushed.
            return scratch.read()

def mask_tif(target_bytes, source_bytes):
    """Zero out every target pixel whose resampled source value is not 221.

    Args:
        target_bytes: Bytes of the raster to be masked.
        source_bytes: Bytes of the raster providing the mask; it is resampled
            onto the target grid with ``resample_raster`` first.

    Returns:
        Bytes of a new raster written with the target's profile, where pixels
        are kept if the resampled source equals 221 and set to 0 otherwise.
    """
    with rasterio.open(BytesIO(target_bytes)) as target:
        bands = target.read()
        profile = target.profile

    mask_layer, _ = resample_raster(source_bytes, profile)

    # NOTE(review): 221 is presumably the CropScape class code of interest
    # (the caller passes a "cropscape-strawberries" raster) - confirm.
    threshold = 221
    keep = mask_layer == threshold
    # The 2-D mask broadcasts across the band axis of the 3-D target stack.
    masked = np.where(keep, bands, 0)

    with MemoryFile() as memfile:
        with memfile.open(**profile) as dst:
            for band_number, layer in enumerate(masked, start=1):
                dst.write(layer.astype('float32'), band_number)
        return memfile.read()

def extract_tar_gz(tar_gz_bytes):
    """Read a gzip-compressed tar archive from memory and collect its ``.tif`` members.

    Returns:
        A list of ``(member_name, content_bytes)`` tuples, in archive order,
        one per regular file whose name ends with ``.tif``.
    """
    archive_stream = BytesIO(tar_gz_bytes)
    with tarfile.open(fileobj=archive_stream, mode='r:gz') as archive:
        return [
            (entry.name, archive.extractfile(entry).read())
            for entry in archive.getmembers()
            if entry.isfile() and entry.name.endswith('.tif')
        ]

def process_tar_gz_file(obj):
    """Convert every ``.tif`` inside one Landsat ``.tar.gz`` archive and archive the tarball.

    The archive is downloaded, each ``.tif`` member is reprojected to WGS84 and
    uploaded under ``converted_folder`` (members whose destination key already
    exists are skipped), and finally the archive is moved from
    ``landsat_folder`` to ``landsat_parsed_folder`` by copy-then-delete.

    Args:
        obj: One entry of an S3 listing; only ``obj['Key']`` is read.
    """
    key = obj['Key']
    tar_gz_bytes = download_from_s3(bucket_name, key)
    extracted_files = extract_tar_gz(tar_gz_bytes)

    # list_objects_v2 returns at most 1000 keys per call, so page through the
    # whole prefix; otherwise files beyond the first page look "missing" and
    # are needlessly reprojected and re-uploaded.
    existing_keys = set()
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name, Prefix=converted_folder):
        existing_keys.update(item['Key'] for item in page.get('Contents', []))

    for file_name, file_bytes in extracted_files:
        # Compare the full destination key rather than a basename so that
        # archive members stored under a directory are matched correctly.
        converted_key = f"{converted_folder}{file_name}"
        if converted_key in existing_keys:
            print(f"{key}, {file_name} previously processed")
            continue
        reprojected_bytes = reproject_to_wgs84(file_bytes)
        upload_to_s3(reprojected_bytes, bucket_name, converted_key)
        del file_bytes, reprojected_bytes  # Free memory
        gc.collect()  # Trigger garbage collection

    # Move the parsed tar.gz file to the landsat_parsed_folder
    parsed_key = landsat_parsed_folder + key[len(landsat_folder):]
    s3_client.copy_object(
        Bucket=bucket_name,
        CopySource={'Bucket': bucket_name, 'Key': key},
        Key=parsed_key,
    )
    s3_client.delete_object(Bucket=bucket_name, Key=key)
    del tar_gz_bytes, extracted_files  # Free memory
    gc.collect()  # Trigger garbage collection
    print(f"{bucket_name} {key} Completed")

def process_mask_file(obj):
    """Mask one converted GeoTIFF with a CropScape raster and upload the result.

    The output key is ``masked_folder`` + the input's base name + ``_masked.tiff``.
    Nothing is done when that key already exists or when the CropScape source
    raster for the derived year cannot be downloaded.

    Args:
        obj: One entry of an S3 listing; only ``obj['Key']`` is read.
    """
    key = obj['Key']
    proc_cnt = 0
    base_name = os.path.splitext(key.split('/')[-1])[0]
    masked_key = f"{masked_folder}{base_name}_masked.tiff"

    # Check if the masked file already exists. Listing with the exact key as
    # the prefix needs one small request and is not affected by the
    # 1000-key page limit of listing the whole masked folder. S3 returns keys
    # in lexicographic order, so an exact match is always the first result.
    masked_response = s3_client.list_objects_v2(
        Bucket=bucket_name, Prefix=masked_key, MaxKeys=1
    )
    if any(item['Key'] == masked_key for item in masked_response.get('Contents', [])):
        print(f"Masked file {masked_key} already exists. Skipping.")
        return

    # NOTE(review): proc_cnt is local to this call, so it is always 1 here and
    # this limit never triggers; a cap across calls would need shared state.
    proc_cnt += 1
    if proc_cnt > 100:
        print("Proc Cnt limit reached")
        return

    # NOTE(review): for names like LC08_L2SP_043035_20140122_20200912_...,
    # characters 26-29 fall in the second date field, not the first - confirm
    # which year is meant to select the CropScape raster.
    year = key.split('/')[-1][26:30]
    source_key = f"{converted_folder}cropscape-strawberries-06053-{year}.tif"

    # Fetch the mask source first so the target is not downloaded when there
    # is nothing to mask it with.
    try:
        source_bytes = download_from_s3(bucket_name, source_key)
    except Exception as exc:  # best-effort skip, but keep Ctrl-C/SystemExit working
        print(f"Source file not found for {key}: {source_key} ({exc})")
        return

    target_bytes = download_from_s3(bucket_name, key)
    masked_bytes = mask_tif(target_bytes, source_bytes)
    upload_to_s3(masked_bytes, bucket_name, masked_key)
    del target_bytes, source_bytes, masked_bytes  # Free memory
    gc.collect()  # Trigger garbage collection
    print(f"{proc_cnt} {bucket_name} {key} Completed")

def _list_objects_with_suffix(prefix, suffix):
    """Return every S3 listing entry under ``prefix`` whose key ends with ``suffix``."""
    # A single list_objects_v2 call returns at most 1000 keys; the paginator
    # follows continuation tokens so nothing past the first page is dropped.
    paginator = s3_client.get_paginator('list_objects_v2')
    return [
        obj
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
        for obj in page.get('Contents', [])
        if obj['Key'].endswith(suffix)
    ]


def process_files():
    """Run both pipeline stages, each across a pool of two worker processes.

    Stage 1 hands every ``.tar.gz`` under ``landsat_folder`` to
    ``process_tar_gz_file``. Stage 2 then lists ``converted_folder`` (after
    stage 1 has finished) and hands every ``.tif`` to ``process_mask_file``.
    """
    # Download and reproject Landsat files
    archives = _list_objects_with_suffix(landsat_folder, '.tar.gz')
    if archives:
        with Pool(2) as pool:
            pool.map(process_tar_gz_file, archives)

    # Masking
    rasters = _list_objects_with_suffix(converted_folder, '.tif')
    if rasters:
        with Pool(2) as pool:
            pool.map(process_mask_file, rasters)

    print("Processing complete.")

# Run the pipeline only when this file is executed as the main program,
# not when it is imported as a module.
if __name__ == "__main__":
    process_files()


Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_SR_B6_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20130527_20200913_02_T1_ETA_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_SR_B7_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20130527_20200913_02_T1_ETF_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_SR_NDMI_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_SR_QA_AEROSOL_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_ST_ATRAN_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_ST_B10_masked.tiff already exists. Skipping.
Masked file landsat_masked/LC08_L2SP_043035_20140122_20200912_02_T1_ST_CDIST_masked.tiff alrea