# Inference

In [None]:
import sys
sys.path.append('../cnn-land-cover/scripts/')
import land_cover_models as lcm
import torch
import pytorch_lightning as pl
import glob
import logging
import os
import warnings

from tqdm.notebook import tqdm

import utils
from osgeo import gdal
from osgeo import ogr
from osgeo import osr

In [None]:
# Quiet PyTorch Lightning: only ERROR-level records are emitted, and they are
# not forwarded to the root logger (avoids duplicate output in the notebook).
log = logging.getLogger('pytorch_lightning')
log.propagate = False
log.setLevel(logging.ERROR)

# Ask GDAL's GeoTIFF driver to take the CRS from the EPSG code
# (GTIFF_SRS_SOURCE=EPSG). Set both as an environment variable and as a
# runtime config option; the original author noted the related GDAL
# messages persisted with only one of them set.
os.environ["GTIFF_SRS_SOURCE"] = 'EPSG'

# Suppress the deprecation warning about softmax that is emitted during
# inference. `warnings`/`gdal` are already imported in the first cell.
# The pattern is matched against the start of the message, so a leading/trailing
# '.*' means "anything".
warnings.filterwarnings('ignore', '.*softmax has been deprecated.*')

gdal.SetConfigOption('GTIFF_SRS_SOURCE', 'EPSG')

In [None]:
# Checkpoints used for inference, keyed by model class. 'main' is the base
# model; 'c', 'd' and 'e' are the per-class models run after it.
ROOT = './logs'  # Root of the training logs; update if the checkpoints move.

# NOTE(review): the original comment on this dict read "new only v2" — its
# meaning is unclear; confirm which training runs these checkpoints belong to.
models_v4 = {
    'main': os.path.join(
        ROOT,
        'DA_main_v1/version_0/checkpoints/epoch=59-step=16380.ckpt'
    ),
    'c': os.path.join(
        ROOT,
        'DA_C/version_0/checkpoints/epoch=59-step=7980.ckpt'
    ),
    'd': os.path.join(
        ROOT,
        'DA_D/version_0/checkpoints/epoch=59-step=10200.ckpt'
    ),
    'e': os.path.join(
        ROOT,
        'DA_E/version_1/checkpoints/epoch=59-step=9480.ckpt'
    ),
}

# Mosaic VRT and the directory of GeoTIFF tiles it references.
# (Fixed typo: the tiles path previously read ',,/ADP/AP'.)
tiled_vrt = {
    'src_path': '../ADP/AP/combined_mosaic.vrt',
    'tiles_path': '../ADP/AP/',
}

# Sanity check: report whether every checkpoint file exists before inference.
for mclass, ckpt_path in models_v4.items():
    print(mclass, os.path.exists(ckpt_path))

In [None]:
# File paths
tiled_vrt = {
    'src_path': '../ADP/AP/combined_mosaic.vrt',  # Mosaic VRT over the tiles below
    'tiles_path': '../ADP/AP/'  # Directory containing the *.tif tiles to classify
}

out_dir_root = '../predictions/'
out_suffix = '_v4'

# For each sub-model, the class value in the 'main' model's output tiles that
# it is restricted to (passed as classes_to_keep to the mask).
mask_lookup = {
    'c': 1,
    'd': 2,
    'e': 3
}
mask_dir = None  # Set to the 'main' output directory once that model has run

# Every model classifies the same tiles, so list them once. Sorting gives a
# deterministic processing order (glob order depends on the filesystem).
tiles = sorted(glob.glob(os.path.join(tiled_vrt['tiles_path'], '*.tif')))
print(f"Using src_path: {tiled_vrt['src_path']}")
print(f"Using tiles_path: {tiled_vrt['tiles_path']}")
print(f"Number of tiles found: {len(tiles)}")

if not tiles:
    # Nothing to do for any model; skip loading checkpoints entirely.
    print(f"No tiles found in path {tiled_vrt['tiles_path']}. Check the path.")
else:
    # 'main' must run first: its output directory is the mask for 'c'/'d'/'e'.
    for mclass in ['main', 'c', 'd', 'e']:
        if mclass != 'main' and mask_dir is None:
            raise RuntimeError(
                f"Model '{mclass}' needs the 'main' predictions as a mask, "
                "but 'main' has not been run yet."
            )

        ckpt_path = models_v4[mclass]
        print(f"Loading model for {mclass} from {ckpt_path}")
        model = lcm.LandCoverUNet.load_from_checkpoint(ckpt_path)

        # NOTE(review): out_dir_root ends in '/', so this yields e.g.
        # '../predictions/_main_v4' — confirm that is the intended layout.
        outpath = f"{out_dir_root}_{mclass}{out_suffix}"
        if os.path.exists(outpath):
            print(f'Overwriting tiles in {outpath}')
        os.makedirs(outpath, exist_ok=True)

        writer = utils.VrtTileClassifier(
            model,
            tiled_vrt['src_path'],
            tiles,
            outpath,
            add_overlap=False
        )
        writer.set_batch_size(15)

        if mclass == 'main':
            writer.write_tiles(tqdm, mask=False)
            mask_dir = outpath  # Sub-models read their mask from here
        else:
            writer.set_tile_mask_dir(
                mask_dir,
                classes_to_keep=[mask_lookup[mclass]],
                suffix=''
            )
            writer.write_tiles(tqdm, mask=True)
