In [3]:
!pip install scikit-image


Collecting scikit-image
  Downloading scikit_image-0.25.2-cp310-cp310-win_amd64.whl.metadata (14 kB)
Collecting imageio!=2.35.0,>=2.33 (from scikit-image)
  Downloading imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)
Collecting tifffile>=2022.8.12 (from scikit-image)
  Downloading tifffile-2025.5.10-py3-none-any.whl.metadata (31 kB)
Collecting lazy-loader>=0.4 (from scikit-image)
  Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Downloading scikit_image-0.25.2-cp310-cp310-win_amd64.whl (12.8 MB)
   ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
   - -------------------------------------- 0.5/12.8 MB 5.6 MB/s eta 0:00:03
   ---- ----------------------------------- 1.6/12.8 MB 5.6 MB/s eta 0:00:03
   -------- ------------------------------- 2.6/12.8 MB 6.0 MB/s eta 0:00:02
   ------------ --------------------------- 3.9/12.8 MB 5.9 MB/s eta 0:00:02
   ---------------- ----------------------- 5.2/12.8 MB 5.7 MB/s eta 0:00:02
   ------------------- ----

In [None]:
import rasterio
from torchvision.transforms.functional import to_tensor
from skimage.util import view_as_windows
import numpy as np
import torch
import sys
sys.path.append("scripts")  # Make scripts/ importable
from model import UNet
from tqdm import tqdm
import matplotlib.pyplot as plt



# Load full image
full_img_path =r"C:\Users\khuza\OneDrive\Desktop\Data science course\DATA SCIENCE COURSE\PROJECTS\AI-DRIVEN SATELLITE ANALYSIS\data\raw\bangalore_rgb.tif"
with rasterio.open(full_img_path) as src:
    full_image = src.read().astype(np.float32)  # Shape: (C, H, W)

# Normalize
full_image = (full_image - full_image.min()) / (full_image.max() - full_image.min() + 1e-8)

# Create patches (same as training size)
PATCH_SIZE = 64
STRIDE = 64
C, H, W = full_image.shape

windows = view_as_windows(full_image, (C, PATCH_SIZE, PATCH_SIZE), step=STRIDE)
windows = windows.reshape(-1, C, PATCH_SIZE, PATCH_SIZE)

print(f"[INFO] Extracted {len(windows)} patches from full image.")

# Load the model
model = UNet(in_channels=3, out_channels=1)
model.load_state_dict(torch.load(r"C:\Users\khuza\OneDrive\Desktop\Data science course\DATA SCIENCE COURSE\PROJECTS\AI-DRIVEN SATELLITE ANALYSIS\models\unet_trained.pth", map_location="cuda" if torch.cuda.is_available() else "cpu"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Inference on all patches
preds = []
model.eval()
with torch.no_grad():
    for patch in tqdm(windows, desc="Predicting patches"):
        patch_tensor = torch.from_numpy(patch).unsqueeze(0).to(device)
        output = model(patch_tensor)
        output = torch.sigmoid(output).squeeze().cpu().numpy()
        preds.append(output)

# Reconstruct full prediction map
preds = np.array(preds).reshape((H // STRIDE, W // STRIDE, PATCH_SIZE, PATCH_SIZE))
reconstructed = np.block([[preds[i, j] for j in range(preds.shape[1])] for i in range(preds.shape[0])])

# Save or visualize
plt.figure(figsize=(10, 8))
plt.imshow(reconstructed, cmap="gray")
plt.title("Full Image Prediction")
plt.axis("off")
os.makedirs("reports/figures", exist_ok=True)
plt.savefig("reports/figures/full_image_prediction.png", dpi=300)
print("[INFO] Full-size prediction saved to reports/figures/full_image_prediction.png")
plt.show()


[INFO] Extracted 4692 patches from full image.


Predicting patches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 4692/4692 [07:02<00:00, 11.10it/s]


NameError: name 'os' is not defined