In [None]:
from pathlib import Path
import sys
# Make the repository importable from this notebook so `src.*` resolves.
# The project root is taken to be the parent of the current working directory
# (NOTE(review): this presumes the notebook is launched from a direct
# sub-folder of the repo — confirm).
project_root = Path.cwd().resolve().parent
print(project_root)
# Only register the path once: re-running this cell in the same kernel must not
# keep growing `sys.path` with duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [None]:
from operator import itemgetter
from pathlib import Path
from src.shared.constants import DATASET_DIR
from src.shared.tiler import Tiler, BatchConfig
import numpy as np
import gc
from fastai.data.block import DataBlock
from fastai.data.transforms import RandomSplitter
from fastai.vision.all import *
from fastai.callback.all import EarlyStoppingCallback
from PIL import Image
import torch

In [None]:


class MemoryEfficientTilerDataset:
    """Index-addressable dataset that reads tiles on demand - NO CACHING.

    Every access goes back to the tiler(s) via ``get_tile_by_id``; nothing is
    kept in memory between calls except the list of usable tile ids.

    Args:
        image_tiler: Tiler for the imagery. Must support ``len()`` and
            ``get_tile_by_id(idx, as_numpy=True)``.
        label_tiler: Optional tiler for the label raster, indexed with the same
            tile ids as ``image_tiler``. When ``None`` the dataset yields
            images only.
        min_building_ratio: Minimum fraction of pixels equal to ``1`` a label
            tile must contain to be kept. Ignored when there is no label tiler
            or when the value is ``<= 0`` (all tiles are kept).

    Attributes:
        valid_indices: Tile ids (into the tilers) that this dataset exposes.
    """

    def __init__(self, image_tiler: "Tiler", label_tiler: "Tiler | None" = None, min_building_ratio: float = 0.1):
        self.image_tiler = image_tiler
        self.label_tiler = label_tiler

        # Compare against None instead of relying on truthiness: an object that
        # defines `__len__` is falsy when empty and would silently be treated
        # as "no labels supplied".
        if label_tiler is not None and min_building_ratio > 0:
            self.valid_indices = self._filter_tiles_by_content_efficient(min_building_ratio)
        else:
            self.valid_indices = list(range(len(image_tiler)))

        print(f"Memory-Efficient Tiler Dataset: {len(self.valid_indices)} valid tiles")

    @staticmethod
    def _to_rgb(image_np):
        """Return ``image_np`` as a channel-first array with 3 channels where possible.

        * ``(H, W)``     -> replicated to ``(3, H, W)``
        * ``(1, H, W)``  -> replicated to ``(3, H, W)``
        * ``(C>=4, H, W)`` -> first three channels
        * anything else (e.g. 2 or 3 channels) is returned unchanged.
        """
        if image_np.ndim == 2:
            return np.repeat(image_np[None, ...], 3, axis=0)
        if image_np.shape[0] == 1:
            return np.repeat(image_np, 3, axis=0)
        if image_np.shape[0] >= 4:
            return image_np[:3]
        return image_np

    @staticmethod
    def _to_binary_mask(label_np):
        """Return a 2-D ``uint8`` mask that is 1 where the label equals 1, else 0.

        For a 3-D ``(C, H, W)`` label only the first channel is used.
        """
        if label_np.ndim == 3:
            label_np = label_np[0]
        return (label_np == 1).astype(np.uint8)

    def _filter_tiles_by_content_efficient(self, min_building_ratio):
        """Return the ids of all tiles whose building ratio reaches the threshold.

        The ratio is measured on exactly the mask that ``__getitem__`` hands to
        training (first label channel, pixels equal to 1), so multi-channel
        label rasters are not diluted by their extra channels. Tiles that fail
        to load are reported and skipped (best effort).
        """
        valid_indices = []

        for idx in range(len(self.image_tiler)):  # Check **all** tiles
            try:
                label_tile = self.label_tiler.get_tile_by_id(idx, as_numpy=True)
                building_ratio = self._to_binary_mask(label_tile).mean()
            except Exception as e:
                print(f"Warning: Could not check tile {idx}: {e}")
                continue

            if building_ratio >= min_building_ratio:
                valid_indices.append(idx)

        # Ids are visited in increasing order, so the result is already sorted
        # and free of duplicates.
        return valid_indices

    def __len__(self): return len(self.valid_indices)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(TensorImage, TensorMask)`` when a label tiler is present,
            otherwise just the ``TensorImage``.
        """
        actual_idx = self.valid_indices[idx]

        # ----- IMAGE: channel-first, dtype as delivered by the tiler -----
        image_np = self._to_rgb(self.image_tiler.get_tile_by_id(actual_idx, as_numpy=True))
        # `.copy()` gives torch an owned, contiguous buffer (the channel trim
        # above may be a view into the tiler's array).
        img_t = TensorImage(torch.from_numpy(image_np.copy()))

        # Unlabeled datasets yield the image alone instead of failing on a
        # missing label tiler.
        if self.label_tiler is None:
            return img_t

        # ----- MASK: H,W, long -----
        label_np = self.label_tiler.get_tile_by_id(actual_idx, as_numpy=True)
        mask_t = TensorMask(torch.from_numpy(self._to_binary_mask(label_np)).long())

        return (img_t, mask_t)
    
def create_memory_efficient_tiler_dls(
    image_path: str,
    label_path: str = None,
    tile_size: int = 256,
    overlap: int = 32,
    batch_size: int = 2,
    valid_pct: float = 0.2,
    min_building_ratio: float = 0.1,
    **kwargs
):
    """Build fastai ``DataLoaders`` that stream tiles from raster files.

    Args:
        image_path: Path of the image raster to tile.
        label_path: Optional path of the label raster. When given, a
            segmentation pipeline (image, mask) with augmentation is built;
            otherwise an image-only pipeline.
        tile_size: Tile edge length passed to ``Tiler``.
        overlap: Tile overlap passed to ``Tiler``.
        batch_size: Batch size of the returned DataLoaders (``bs``).
        valid_pct: Fraction of tiles put in the validation split
            (``RandomSplitter`` with ``seed=42``).
        min_building_ratio: Forwarded to ``MemoryEfficientTilerDataset`` to
            drop tiles with too few building pixels.
        **kwargs: Extra keyword arguments for ``DataBlock.dataloaders``. They
            take precedence over this function's defaults
            (``num_workers=8``, ``pin_memory=True``).

    Returns:
        The ``DataLoaders`` produced by ``DataBlock.dataloaders``.

    Raises:
        ValueError: If no tile survives filtering, so there is nothing to load.
    """
    # MEMORY-OPTIMIZED CONFIG (tiler-side reads; caching disabled).
    # NOTE(review): `batch_size=8` here configures the Tiler and is independent
    # of the `batch_size` argument used for the DataLoaders — confirm intended.
    config = BatchConfig(
        batch_size=8,
        prefetch_batches=2,
        max_workers=4,
        enable_cache=False,
        memory_limit_mb=2048
    )

    image_tiler = Tiler(image_path, tile_size=tile_size, overlap=overlap, batch_config=config)
    label_tiler = Tiler(label_path, tile_size=tile_size, overlap=overlap, batch_config=config) if label_path else None

    dataset = MemoryEfficientTilerDataset(
        image_tiler=image_tiler,
        label_tiler=label_tiler,
        min_building_ratio=min_building_ratio
    )

    # Fail early with an actionable message instead of handing fastai an empty
    # item list.
    if len(dataset) == 0:
        raise ValueError(
            "No tiles available after filtering; lower `min_building_ratio` "
            "or check the image/label rasters."
        )

    # Items are plain integer positions into `dataset`; tiles are only read
    # when `get_x` / `get_y` are called.
    def get_items(_): return list(range(len(dataset)))

    # NOTE(review): `get_x` and `get_y` each call `dataset[i]`, so every sample
    # is read from disk twice. Left as is to keep this function independent of
    # the dataset's internals.
    def get_x(i):
        itm = dataset[i]
        return itm[0] if isinstance(itm, tuple) else itm

    # xtra transforms:
    xtra = [
        Dihedral(p=0.7),                 # 8-way: flips + 90° rotations (orthos love this)
        Rotate(max_deg=10, p=0.5),       # small free-angle rotation
        Zoom(min_zoom=1.0, max_zoom=1.20, p=0.7),  # small scale jitter
        RandomErasing(p=0.05, max_count=1),        # tiny occlusions (acts like mild Cutout)
    ]

    batch_tfms = [
        *aug_transforms(                  # only the lighting part is wanted from here
            do_flip=False,                # (flips/rot90 come from Dihedral() above; enabling
                                          #  them here too would apply Dihedral twice)
            max_rotate=0,                 # (we add Rotate separately to keep control)
            min_zoom=1.0, max_zoom=1.0,   # (zoom handled by Zoom() above)
            max_lighting=0.25, p_lighting=0.9,
            max_warp=0.0,                 # no perspective/elastic twists
        ),
        Normalize.from_stats(*imagenet_stats),
        # Optional: a light Gaussian blur/noise transform could be added here
        # for robustness if needed.
    ]

    # The dataset already returns TensorImage / TensorMask, so the blocks add
    # no type conversion; the mask block only attaches the class codes.
    def ImgTBlock():
        return TransformBlock()

    def MaskTBlock(codes):
        return TransformBlock(item_tfms=AddMaskCodes(codes=codes))

    if label_path:
        def get_y(i):
            _, m = dataset[i]
            return m

        # NOTE(review): with overlap > 0, a random split can place overlapping
        # tiles in both train and validation — confirm this is acceptable.
        dblock = DataBlock(
            blocks=(ImgTBlock, MaskTBlock(codes=['background', 'building'])),
            get_items=get_items,
            get_x=get_x,
            get_y=get_y,
            splitter=RandomSplitter(valid_pct=valid_pct, seed=42),
            item_tfms=[],
            batch_tfms=[IntToFloatTensor(), *xtra, *batch_tfms]
        )
    else:
        dblock = DataBlock(
            blocks=(ImgTBlock,),
            get_items=get_items,
            get_x=get_x,
            splitter=RandomSplitter(valid_pct=valid_pct, seed=42),
            batch_tfms=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
        )

    # Defaults first, caller overrides second: passing e.g. `num_workers=0`
    # through **kwargs replaces the default instead of raising
    # "got multiple values for keyword argument".
    # NOTE(review): an earlier comment here recommended num_workers=0; the
    # default stays at 8 to preserve behaviour — pass num_workers=0 if worker
    # processes cannot be used in your environment.
    # Do not pass path=None; omit `path` or use '.'.
    loader_kwargs = {"num_workers": 8, "pin_memory": True}
    loader_kwargs.update(kwargs)

    return dblock.dataloaders(
        source=None,
        bs=batch_size,
        **loader_kwargs
    )
# Memory monitoring utility
def monitor_memory():
    """Print the current process's resident memory and return it in MB.

    Returns:
        float: Resident set size (RSS) of this process, in mebibytes.
    """
    import psutil  # kept function-local, as in the original

    rss_bytes = psutil.Process().memory_info().rss
    memory_mb = rss_bytes / 1024 / 1024  # bytes -> KiB -> MiB
    print(f"Current memory usage: {memory_mb:.1f} MB")
    return memory_mb

def train_with_memory_monitoring(image_path: str, label_path: str, arch=resnet34, **kwargs):
    """Create the tiled DataLoaders and a U-Net learner, logging memory along the way.

    Despite the name, no fitting happens here: the learner is built and moved
    to the available device, then handed back to the caller.

    Args:
        image_path: Image raster forwarded to ``create_memory_efficient_tiler_dls``.
        label_path: Label raster forwarded to ``create_memory_efficient_tiler_dls``.
        arch: Backbone architecture given to ``unet_learner``.
        **kwargs: Forwarded unchanged to ``create_memory_efficient_tiler_dls``.

    Returns:
        tuple: ``(learn, dls)`` — the learner and the DataLoaders it uses.
    """
    def _checkpoint(label):
        # Print a stage label followed by the current process memory.
        print(label)
        monitor_memory()

    _checkpoint("🧠 Initial memory usage:")

    dls = create_memory_efficient_tiler_dls(image_path=image_path, label_path=label_path, **kwargs)

    _checkpoint("🧠 After creating DataLoaders:")
    print(f"🎯 Created DataLoaders with {len(dls.train)} batches")

    learner_options = dict(
        metrics=[DiceMulti()],
        loss_func=CrossEntropyLossFlat(axis=1),
        cbs=[EarlyStoppingCallback(patience=3)],
        pretrained=True,
    )
    learn = unet_learner(dls, arch, **learner_options)

    # Place the model on the GPU when one is present, otherwise on the CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    learn.model.to(torch.device(device_name))

    _checkpoint("🧠 After creating learner:")
    return learn, dls

if __name__ == "__main__":
    # Driver: build the learner and DataLoaders for the ortho image / building
    # mask pair under DATASET_DIR, reporting memory before and after.
    # `learn` and `dls` are left at module level for the following cells.
    print("🧠 Starting memory usage:")
    monitor_memory()

    learn, dls = train_with_memory_monitoring(
        Path(DATASET_DIR / "ortho_cog_cropped.tif"),
        Path(DATASET_DIR / "building_mask.tif"),
        batch_size=4,
        tile_size=1024,
        overlap=256,
    )

    print("🧠 Final memory usage:")
    monitor_memory()

        


In [None]:
# Fine-tune the learner for 5 epochs (fastai `Learner.fine_tune`).
# Relies on `learn` having been created by the previous cell.
learn.fine_tune(5)



In [None]:
# Preview up to 10 samples from the DataLoaders.
# xb, yb = dls.one_batch()
dls.show_batch(max_n=10, figsize=(10,10))
learn.model.eval()
# with torch.inference_mode():
#   logits = learn.model(xb)            # [B, C, H, W]
#   pred = logits.argmax(dim=1)         # [B, H, W]  (0=background, 1=building)

interp = SegmentationInterpretation.from_learner(learn)
# `show_results` selects what to display through `idxs` (a list of item
# indices); it has no `k` parameter, so the first three items are requested
# explicitly.
interp.show_results(idxs=[0, 1, 2])
# print("xb:", xb.shape, "yb:", yb.shape, "pred:", pred.shape)

# # visualize a couple of samples
# import matplotlib.pyplot as plt
# for i in range(min(3, xb.shape[0])):
#   img = xb[i].cpu().permute(1,2,0).clamp(0,1).numpy()
#   gt  = yb[i].cpu().numpy()
#   pr  = pred[i].cpu().numpy()
#   fig, axs = plt.subplots(1,3, figsize=(10,3))
#   axs[0].imshow(img); axs[0].set_title("Image"); axs[0].axis('off')
#   axs[1].imshow(gt, vmin=0, vmax=1); axs[1].set_title("GT"); axs[1].axis('off')
#   axs[2].imshow(pr, vmin=0, vmax=1); axs[2].set_title("Pred"); axs[2].axis('off')
#   plt.show()


In [None]:
# Disabled experiment: custom two-colour colormap for the result overlay.
# from matplotlib.colors import ListedColormap


# cmap = ListedColormap([
#     (0.0, 0.0, 0.0, 1.0),   # black background
#     (0.0, 0.8, 0.2, 0.9),   # bright green buildings
# ])
# interp.show_results(idxs=[1,2,3], cmap=cmap, alpha=0.6, vmin=0, vmax=1)

# Run the learning-rate finder on the current learner.
# NOTE(review): this cell comes after `learn.fine_tune(5)`; the finder is
# normally run before training to pick a learning rate — confirm the intended
# cell order.
learn.lr_find()