# Get features from downloaded data

To iterate quickly, we extract features from the images we have already downloaded. Ideally we would download new images in a more appropriate format (e.g. per-pixel median values over three years), but these will do for now: we want to stress-test the rest of the pipeline before improving this part.

In [1]:
cd ../

/cephyr/users/markpett/Alvis/ImputeAwareATE


## Save Landsat geotiffs as RGB pngs

Read DHS data from CSV

In [3]:
import os
import pandas as pd
import configparser

# Read the config file. The path is relative to the current working
# directory (the notebook changes directory with `cd ../` first).
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot find and returns the
# list of files it did parse; fail loudly here rather than with a cryptic
# KeyError on config['PATHS'] below.
if not config.read('config.ini'):
    raise FileNotFoundError(
        f"config.ini not found in working directory {os.getcwd()!r}"
    )

DATA_DIR = config['PATHS']['DATA_DIR']

# Some of the cluster_ids in this file contain a '/'.
slash_df = pd.read_csv(os.path.join(DATA_DIR, 'dhs_data_with_slash.csv'))
slash_df

Unnamed: 0,cluster_id,lon,lat,rural,region_id,country,survey,month,year,iwi
0,AO.Bengo.71.135,13.640789,-8.589805,False,AO.Bengo,Angola,Angola 2015-16 Standard DHS,11,2015,62.334459
1,AO.Bengo.71.158,14.122619,-7.718385,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,2,2016,8.226589
2,AO.Bengo.71.169,13.654425,-8.592545,False,AO.Bengo,Angola,Angola 2015-16 Standard DHS,10,2015,62.760211
3,AO.Bengo.71.203,13.517859,-8.652260,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,1,2016,68.211697
4,AO.Bengo.71.208,13.721998,-7.852511,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,11,2015,14.825944
...,...,...,...,...,...,...,...,...,...,...
69944,ZW.Midlands.72.37,30.008579,-20.911177,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,9,2015,27.791567
69945,ZW.Midlands.72.52,29.860028,-20.402214,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,10,2015,36.929878
69946,ZW.Midlands.72.69,30.172833,-20.724753,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,10,2015,24.406326
69947,ZW.Midlands.72.91,29.820084,-19.453466,False,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,7,2015,59.887344


Code for reading geotiff files

In [4]:
import os
import io
import zipfile
import numpy as np
import rasterio as rio
import multiprocessing as mp
from retry import retry
from rasterio import MemoryFile
from typing import Tuple, Iterable
from time import time
from einops import rearrange

# Six band names: blue, green, red, near-infrared and two short-wave infrared.
# NOTE(review): not referenced in this section; get_landsat_frame only reads
# 'RED', 'GREEN' and 'BLUE'.
MS_BANDS = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2']

def parse_landsat_file_name(file_path: str) -> Tuple[str, str, str]:
    """
    Extract the band, date and collection from a single-band image file name.

    The "stem" is the second-to-last '.'-separated piece of ``file_path``.
    Its last two '_'-separated tokens are taken as the date and the band, and
    the collection is the four characters starting at the first 'L' in it.

    Args:
        file_path (str): Path (or zip member name) of a single-band image file.

    Returns:
        Tuple[str, str, str]: ``(band, date, collection)``.

    Raises:
        IndexError: If the name has no '.' or its stem has no '_'.
    """
    # Work only on the piece just before the extension, so that earlier
    # parts of the path don't disturb the indexing below.
    pieces = file_path.split('.')
    stem = pieces[-2]

    tokens = stem.split('_')
    band = tokens[-1]
    date = tokens[-2]

    # The collection is denoted by four characters starting with 'L'.
    # NOTE(review): if the stem contains no 'L', find() returns -1 and the
    # slice yields an unexpected value instead of raising.
    start = stem.find('L')
    collection = stem[start:start + 4]

    return band, date, collection

def get_img_series_from_landsat_zip(zip_file_path: str) -> np.ndarray:
    """
    Read every image frame in a Landsat zip archive and stack them.

    Each member of the archive is parsed with ``parse_landsat_file_name``.
    Members are grouped into frames by ``(date, collection)``, and each frame
    is decoded into an RGB array by ``get_landsat_frame``.

    Args:
        zip_file_path (str): Path to the zip archive.

    Returns:
        np.ndarray: Array of shape ``(n_frames, height, width, 3)``. Frames
        appear in the order in which their first band file occurs in the
        archive listing (NOTE(review): not necessarily chronological).

    Raises:
        ValueError: If the archive contains no files.
    """
    # frame id ("<date>.<collection>") -> {band name -> raw GeoTIFF bytes}
    frames = {}

    with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
        for file_name in zip_file.namelist():
            band, date, collection = parse_landsat_file_name(file_name)
            frame_id = f'{date}.{collection}'
            frames.setdefault(frame_id, {})[band] = zip_file.read(file_name)

    if not frames:
        raise ValueError(f'No band files found in {zip_file_path!r}')

    # Decode each frame to a (height, width, 3) array and stack them along a
    # new leading frame axis.
    img_frames = [get_landsat_frame(band_data) for band_data in frames.values()]
    return np.stack(img_frames)

def get_landsat_frame(
    band_to_data_dict: dict,
    bands: Iterable[str] = ('RED', 'GREEN', 'BLUE'),
) -> np.ndarray:
    """
    Decode selected single-band GeoTIFFs of one frame and stack them as channels.

    Args:
        band_to_data_dict (dict): Maps band name (e.g. ``'RED'``) to the raw
            bytes of that band's GeoTIFF file.
        bands (Iterable[str]): Band names to read, in output channel order.
            Defaults to ``('RED', 'GREEN', 'BLUE')``.

    Returns:
        np.ndarray: Array of shape ``(height, width, n_bands)`` holding the
        first raster band of each requested file, in the order of ``bands``.

    Raises:
        KeyError: If a requested band is missing from ``band_to_data_dict``.
        ValueError: If ``bands`` is empty.
    """
    img_values = []

    for band in bands:
        band_data = band_to_data_dict[band]
        # Open the GeoTIFF bytes in memory instead of writing them to disk.
        with MemoryFile(band_data) as band_memfile:
            with band_memfile.open() as src:
                img_values.append(src.read(1))

    # Channels last: (height, width, n_bands).
    return np.stack(img_values, axis=2)

Code for writing PNG files

In [5]:
from PIL import Image

def get_last_img(zip_file_path: str) -> np.array:
    """
    Return the final frame of a Landsat zip archive, rescaled to [0, 1].

    Raw values are mapped with ``value * 0.0000275 - 0.2`` (NOTE(review):
    these look like the USGS Landsat Collection 2 Level-2 surface-reflectance
    scale/offset — confirm against the export settings). The result is then
    clipped to ``[0, 0.3]`` and divided by 0.3.

    Args:
        zip_file_path (str): Path to the Landsat zip archive.

    Returns:
        np.array: Float array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    series = get_img_series_from_landsat_zip(zip_file_path)
    last_frame = series[-1]

    # Values in [0, 0.3] are stretched to span the full [0, 1] output range.
    scaled = last_frame * 0.0000275 - 0.2
    return np.clip(scaled, 0.0, 0.3) / 0.3

def write_png(img: np.array, write_path: str):
    """
    Save an image with values expected in [0, 1] as an 8-bit PNG.

    Args:
        img (np.array): Image array. Values are multiplied by 255 and cast
            to uint8 (truncating, with no clipping of out-of-range values).
        write_path (str): Destination path; the format is inferred by PIL
            from the file extension.
    """
    pixels = (img * 255).astype(np.uint8)
    Image.fromarray(pixels).save(write_path)
    
    
SOURCE_DIR = '/mimer/NOBACKUP/groups/globalpoverty1/markus/temporal-vit/dhs_images'
OUT_DIR = '/mimer/NOBACKUP/groups/globalpoverty1/markus/impute_aware_ate'

def save_row_as_png(row):
    """
    Render the last Landsat frame of one DHS cluster as a PNG in ``OUT_DIR/dhs_images``.

    The source archive is ``SOURCE_DIR/<cluster_id>/landsat.zip``, built from
    the raw cluster ID, so an ID containing '/' maps to nested directories.
    The output file name uses the ID with '/' replaced by '.'. Clusters whose
    PNG already exists are skipped, so the job can be re-run to resume.

    Args:
        row: Mapping-like record (e.g. a pandas Series or a dict) with a
            ``'cluster_id'`` entry.
    """
    raw_cluster_id = row['cluster_id']
    geotiff_path = os.path.join(SOURCE_DIR, raw_cluster_id, 'landsat.zip')

    cluster_id = raw_cluster_id.replace('/', '.')  # Remove '/' from cluster IDs
    out_dir = os.path.join(OUT_DIR, 'dhs_images')
    base_path = os.path.join(out_dir, cluster_id)
    png_path = base_path + '.png'

    if os.path.isfile(png_path):
        return

    img = get_last_img(geotiff_path)

    # exist_ok=True tolerates the directory already existing (e.g. created
    # by another worker).
    os.makedirs(out_dir, exist_ok=True)

    # Write to a per-process temporary file and rename it into place, so an
    # interrupted write never leaves a truncated PNG that the existence check
    # above would skip on every later run. The temporary name keeps the '.png'
    # suffix because write_png lets PIL infer the format from the extension.
    tmp_path = f'{base_path}.{os.getpid()}.tmp.png'
    try:
        write_png(img, tmp_path)
        os.replace(tmp_path, png_path)
    finally:
        # Only still present if the write or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

Write all GeoTIFFs as RGB PNGs to the new location

In [7]:
start = time()

# Pass plain dicts to the workers rather than the DataFrame itself; each task
# only needs the row's 'cluster_id', and dicts are lightweight to pickle.
clusters = slash_df.to_dict('records')

n_workers = 40
# The context manager tears the pool down even if a task raises. pool.map only
# returns after every task has finished, so nothing is cut short on success.
with mp.Pool(n_workers) as pool:
    pool.map(save_row_as_png, clusters)

time_len = time() - start
print(f'Processed {len(clusters)} clusters in {time_len:.1f} s')

In [9]:
# Replace '/' with '.' in cluster IDs so they match the PNG file names.
# regex=False: treat '/' as a literal string (vectorized and NaN-safe).
slash_df['cluster_id'] = slash_df['cluster_id'].str.replace('/', '.', regex=False)

In [10]:
# Write next to the PNGs, reusing OUT_DIR instead of repeating the path.
# NOTE(review): the CLIP section reads this file from DATA_DIR; both point to
# the same directory today — keep them in sync.
slash_df.to_csv(os.path.join(OUT_DIR, 'dhs_data.csv'), index=False)

## Turn pngs to CLIP encodings

Get CLIP embeddings from the PNG images.

In [6]:
# DHS table whose cluster IDs match the PNG file names.
dhs_csv_path = os.path.join(DATA_DIR, 'dhs_data.csv')
df = pd.read_csv(dhs_csv_path)
df

Unnamed: 0,cluster_id,lon,lat,rural,region_id,country,survey,month,year,iwi
0,AO.Bengo.71.135,13.640789,-8.589805,False,AO.Bengo,Angola,Angola 2015-16 Standard DHS,11,2015,62.334459
1,AO.Bengo.71.158,14.122619,-7.718385,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,2,2016,8.226589
2,AO.Bengo.71.169,13.654425,-8.592545,False,AO.Bengo,Angola,Angola 2015-16 Standard DHS,10,2015,62.760211
3,AO.Bengo.71.203,13.517859,-8.652260,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,1,2016,68.211697
4,AO.Bengo.71.208,13.721998,-7.852511,True,AO.Bengo,Angola,Angola 2015-16 Standard DHS,11,2015,14.825944
...,...,...,...,...,...,...,...,...,...,...
69944,ZW.Midlands.72.37,30.008579,-20.911177,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,9,2015,27.791567
69945,ZW.Midlands.72.52,29.860028,-20.402214,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,10,2015,36.929878
69946,ZW.Midlands.72.69,30.172833,-20.724753,True,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,10,2015,24.406326
69947,ZW.Midlands.72.91,29.820084,-19.453466,False,ZW.Midlands,Zimbabwe,Zimbabwe 2015 Standard DHS,7,2015,59.887344


In [7]:
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm

OUTPUT_EMBEDDINGS_FILE = os.path.join(DATA_DIR, "clip_embeddings.npy")
MODEL_NAME = "flax-community/clip-rsicd"
BATCH_SIZE = 32  # Adjust this based on your GPU memory capacity
# Fall back to the CPU when no GPU is visible (slower, but still runs).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained model and processor from huggingface
model = CLIPModel.from_pretrained(MODEL_NAME).to(DEVICE)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

embeddings = []


def process_batch(image_batch):
    """
    Compute CLIP image embeddings for a batch of PIL images.

    Args:
        image_batch (list): PIL images to encode.

    Returns:
        np.ndarray: One embedding row per input image, in input order.
    """
    inputs = processor(images=image_batch, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        batch_features = model.get_image_features(**inputs).cpu().numpy()  # Move back to CPU for numpy
    return batch_features


def load_rgb_image(png_path):
    """Open a PNG, decode it as RGB, and close the underlying file."""
    with Image.open(png_path) as image:
        # convert() loads the pixels and returns a separate copy, so the
        # file handle can be released as soon as the `with` block ends.
        return image.convert("RGB")


# Images are processed in the row order of `df`, so row i of the saved
# array corresponds to df.iloc[i].
image_batch = []
for cluster_id in tqdm(df['cluster_id'], total=len(df), desc="Processing Images"):
    png_path = os.path.join(DATA_DIR, 'dhs_images', cluster_id + '.png')
    image_batch.append(load_rgb_image(png_path))

    # Process the batch once it reaches the batch size
    if len(image_batch) == BATCH_SIZE:
        embeddings.append(process_batch(image_batch))
        image_batch = []  # Reset batch list

# Process any remaining images that didn't fill a whole batch
if image_batch:
    embeddings.append(process_batch(image_batch))

# Save the embeddings array as a .npy file
embeddings_array = np.vstack(embeddings)
np.save(OUTPUT_EMBEDDINGS_FILE, embeddings_array)

print(f"Embeddings saved to {OUTPUT_EMBEDDINGS_FILE}.")


  from .autonotebook import tqdm as notebook_tqdm
Processing Images: 100%|██████████| 69949/69949 [19:15<00:00, 60.54it/s]


Embeddings saved to /mimer/NOBACKUP/groups/globalpoverty1/markus/impute_aware_ate/clip_embeddings.npy.
