In [9]:
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import FileResponse
import os
import subprocess
import uuid
from pathlib import Path
import aiofiles
import os
import glob
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as TF
import segmentation_models_pytorch as smp
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
from fastapi.responses import PlainTextResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, Request, Response
import uvicorn
from pyngrok import ngrok



In [None]:
app = FastAPI()


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

In [10]:

BASE_DIR = "ecw_processing"
INPUT_DIR = f"{BASE_DIR}/input"
OUTPUT_DIR = f"{BASE_DIR}/output"
GRID_DIR = f"{OUTPUT_DIR}/grids"
MASK_DIR = f"{OUTPUT_DIR}/masks"

os.mkdir(BASE_DIR)
os.mkdir(INPUT_DIR)
os.mkdir(OUTPUT_DIR)
os.mkdir(GRID_DIR)
os.mkdir(MASK_DIR)

FileExistsError: [Errno 17] File exists: 'ecw_processing'

In [11]:
class RoadDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, scale_factor=0.5):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif")))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.tif")))
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        mask = (mask > 0).astype(np.uint8)

        new_h = int(image.shape[0] * self.scale_factor)
        new_w = int(image.shape[1] * self.scale_factor)

        new_h += (32 - new_h % 32) % 32
        new_w += (32 - new_w % 32) % 32

        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        mask = np.expand_dims(mask, axis=0)

        return {
            "image": TF.to_tensor(image),
            "mask": torch.tensor(mask, dtype=torch.int64),
        }

In [12]:
class RoadSegmentationModel(pl.LightningModule):
    def __init__(self, encoder_name="resnet34", in_channels=3, classes=1):
        super().__init__()
        self.model = smp.create_model("FPN", encoder_name=encoder_name, in_channels=in_channels, classes=classes)
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, x):
        return self.model(x)

    def shared_step(self, batch):
        images, masks = batch["image"], batch["mask"]
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        return loss, logits, masks

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, masks = self.shared_step(batch)
        prob_masks = logits.sigmoid()
        pred_masks = (prob_masks > 0.5).float()

        pred_masks = pred_masks.to(torch.int64)

        tp, fp, fn, tn = smp.metrics.get_stats(pred_masks, masks, mode="binary")
        iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_iou", iou, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=25, eta_min=1e-5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

In [None]:
road_model = RoadSegmentationModel()
road_model.load_state_dict(torch.load("roads_trained_FPN.pth", weights_only=True))

In [None]:
def single_image_inference(model, image_path, scale_factor=0.25):
    model.eval()
    with torch.no_grad():
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        new_h = int(image.shape[0] * scale_factor)
        new_w = int(image.shape[1] * scale_factor)
        new_h = (new_h + 31) // 32 * 32
        new_w = (new_w + 31) // 32 * 32
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

        image_tensor = TF.to_tensor(image).unsqueeze(0)

        logits = model(image_tensor)
        prob_mask = logits.sigmoid()
        pred_mask = (prob_mask > 0.5).float().squeeze().cpu().numpy()

        pred_mask = cv2.resize(pred_mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

        return image, pred_mask

In [None]:
@app.post("/process-ecw/")
async def process_ecw(file: UploadFile):
    try:
        file_id = str(uuid.uuid4())
        ecw_file_path = INPUT_DIR / f"{file_id}.ecw"
        async with aiofiles.open(ecw_file_path, 'wb') as f:
            while chunk := await file.read(1024):
                await f.write(chunk)

        tiff_file_path = OUTPUT_DIR / f"{file_id}.tif"
        gdal_translate_command = [
            "gdal_translate", "-of", "GTiff", "-co", "COMPRESS=LZW",
            str(ecw_file_path), str(tiff_file_path)
        ]
        subprocess.run(gdal_translate_command, check=True)

        TILE_WIDTH, TILE_HEIGHT = 3000, 3000
        ROWS, COLS = 5, 5
        for row in range(ROWS):
            for col in range(COLS):
                x_offset = col * TILE_WIDTH
                y_offset = row * TILE_HEIGHT
                grid_file_path = GRID_DIR / f"tile_{row}_{col}.tif"
                gdal_translate_grid_command = [
                    "gdal_translate", "-of", "GTiff", "-srcwin",
                    str(x_offset), str(y_offset), str(TILE_WIDTH), str(TILE_HEIGHT),
                    str(tiff_file_path), str(grid_file_path)
                ]
                subprocess.run(gdal_translate_grid_command, check=True)

        for grid_file in GRID_DIR.iterdir():
            image, pred_mask = single_image_inference(None, str(grid_file))
            mask_file_path = MASK_DIR / f"mask_{grid_file.name}"
            cv2.imwrite(str(mask_file_path), (pred_mask * 255).astype(np.uint8))

        combined_mask_tiff_path = OUTPUT_DIR / "combined_masks.tif"
        gdalwarp_command = [
            "gdalwarp", "-of", "GTiff", str(MASK_DIR / "*.tif"), str(combined_mask_tiff_path)
        ]
        subprocess.run(gdalwarp_command, check=True)

        shapefile_path = OUTPUT_DIR / "output_shapefile.shp"
        gdal_polygonize_command = [
            "gdal_polygonize.py", str(combined_mask_tiff_path), "-f", "ESRI Shapefile",
            str(shapefile_path)
        ]
        subprocess.run(gdal_polygonize_command, check=True)

        return FileResponse(shapefile_path, media_type="application/x-shapefile", filename="output_shapefile.zip")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

In [None]:
ngrok.set_auth_token("2WfNAfRBJZK6oe2x0Sl9QVRl5Zv_5wxr2UnS1CBuJmycrFk2k")

In [None]:
ngrok_tunnel = ngrok.connect(addr = 8000, domain = "helpful-boxer-wrongly.ngrok-free.app")
print("Public URL:", ngrok_tunnel.public_url)
uvicorn.run(app, port=8000)