In [3]:
# ==========================================
# 1. SETUP & IMPORTS
# ==========================================
!pip install -q segmentation_models_pytorch gdown

import os
import cv2
import glob
import torch
import gdown
import numpy as np
import pandas as pd
import torch.nn.functional as F
import segmentation_models_pytorch as smp

from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from dataclasses import dataclass

# ==========================================
# 2. CONFIGURATION
# ==========================================
@dataclass
class Cfg:
    """Static settings for the terrain-segmentation ensemble inference.

    The class is used as a namespace (``Cfg.MODEL_DIR``, ``Cfg.DEVICE``, ...)
    rather than being instantiated; ``download_weights`` is handed the class
    itself.
    """

    # --- Filesystem layout (Kaggle) ---
    TEST_DIR: str = (
        "/kaggle/input/terra-seg-rugged-terrain-segmentation/"
        "offroad-seg-kaggle/test_images_padded"
    )
    WORK_DIR: str = "/kaggle/working"
    # Derived paths: WORK_DIR is still visible by bare name inside the class body.
    MODEL_DIR: str = os.path.join(WORK_DIR, "weights")
    MASK_OUT_DIR: str = os.path.join(WORK_DIR, "binary_masks")
    SUBMISSION_PATH: str = os.path.join(WORK_DIR, "submission.csv")

    # --- Shared Google Drive folder that holds both checkpoints ---
    DRIVE_URL: str = (
        "https://drive.google.com/drive/folders/"
        "1kApmkblFvkT1kRtbe23zMnnxl02pATmB?usp=sharing"
    )

    # --- Checkpoint file names, looked up inside MODEL_DIR ---
    CKPT_DEEPLAB: str = "deeplab.pth"
    CKPT_HRNET: str = "best_offseg_hrnet_unet.pth"

    # --- Inference parameters; all sizes are (height, width) ---
    ORIG_SIZE: tuple = (540, 960)   # resolution predictions are mapped back to
    MODEL_SIZE: tuple = (544, 960)  # resize target for the HRNet input
    THRESHOLD: float = 0.48         # cut-off on the ensembled probability
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE: int = 1
    NUM_WORKERS: int = 2

# Make sure the weight and mask output folders exist before anything writes
# into them; exist_ok=True keeps this safe when the cell is re-run.
for _out_dir in (Cfg.MODEL_DIR, Cfg.MASK_OUT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ==========================================
# 3. UTILITIES
# ==========================================
def download_weights(cfg: "Cfg", skip_existing: bool = False):
    """Download the model checkpoints from Google Drive and verify them.

    The whole shared Drive folder at ``cfg.DRIVE_URL`` is fetched into
    ``cfg.MODEL_DIR`` with ``gdown.download_folder``. Afterwards, both
    ``cfg.CKPT_DEEPLAB`` and ``cfg.CKPT_HRNET`` must exist there as files.

    Args:
        cfg: Configuration holder (in this script the ``Cfg`` class itself).
            Reads ``MODEL_DIR``, ``DRIVE_URL``, ``CKPT_DEEPLAB`` and
            ``CKPT_HRNET``.
        skip_existing: When True and every required checkpoint is already a
            file in ``cfg.MODEL_DIR``, skip the download. This is handy when
            re-running the notebook cell. Defaults to False, which always
            downloads, as before.

    Raises:
        FileNotFoundError: If any required checkpoint is absent after the
            download. Every missing name is listed, not just the first.
    """
    required = [cfg.CKPT_DEEPLAB, cfg.CKPT_HRNET]

    def _missing():
        # isfile (rather than listdir membership) rejects a same-named directory.
        return [
            name
            for name in required
            if not os.path.isfile(os.path.join(cfg.MODEL_DIR, name))
        ]

    if skip_existing and not _missing():
        print(f" Weights already present in {cfg.MODEL_DIR}, skipping download.")
        return

    print(f" Downloading weights to {cfg.MODEL_DIR}...")
    gdown.download_folder(url=cfg.DRIVE_URL, output=cfg.MODEL_DIR, quiet=False, use_cookies=False)

    missing = _missing()
    if missing:
        raise FileNotFoundError(f" Missing file(s): {', '.join(missing)}")
    print(" Download complete.")

def rle_encode(mask: np.ndarray) -> str:
    """Return the column-major run-length encoding of a binary mask.

    Pixels are read in Fortran (column-major) order. The result is a
    space-separated list of ``start length`` pairs in which ``start`` is
    1-based; an all-zero mask yields an empty string.
    """
    column_major = mask.flatten(order="F")
    # Zero sentinels on both ends so every run of ones has an opening and a closing edge.
    bordered = np.concatenate(([0], column_major, [0]))
    # 1-based positions where the value changes: alternately run starts and (exclusive) run ends.
    edges = np.flatnonzero(bordered[1:] != bordered[:-1]) + 1
    # Turn each exclusive end into a run length.
    edges[1::2] -= edges[::2]
    return " ".join(str(edge) for edge in edges)

def load_model(arch: str, path: str, device: str):
    """Build a single-output segmentation model, load its weights and freeze it.

    Args:
        arch: ``"deeplab"`` (smp DeepLabV3+ with a ResNet-34 encoder) or
            ``"hrnet"`` (smp U-Net with a ``tu-hrnet_w32`` encoder and scSE
            decoder attention).
        path: Checkpoint file name, resolved relative to ``Cfg.MODEL_DIR``.
        device: Torch device string; used as ``map_location`` and as the
            target of ``model.to``.

    Returns:
        The model on ``device``, in eval mode, with ``requires_grad`` disabled
        for every parameter.

    Raises:
        ValueError: If ``arch`` is not recognised. This is checked before
            any file is read.
    """
    builders = {
        "deeplab": lambda: smp.DeepLabV3Plus("resnet34", encoder_weights=None, classes=1),
        "hrnet": lambda: smp.Unet(
            encoder_name="tu-hrnet_w32",
            encoder_weights=None,
            classes=1,
            decoder_attention_type="scse",
        ),
    }
    if arch not in builders:
        raise ValueError(f"Unknown architecture: {arch}")
    model = builders[arch]()

    ckpt_path = os.path.join(Cfg.MODEL_DIR, path)
    # NOTE(review): torch.load unpickles the file. These checkpoints come from
    # an external Drive folder, so only load trusted files. Consider
    # weights_only=True if the checkpoint contents allow it.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # A training checkpoint may wrap the weights under "model_state_dict" (the
    # DeepLab file does); a bare state dict (the HRNet file) is used as-is.
    # Accepting both layouts for either architecture removes the per-arch
    # coupling to a specific checkpoint format.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # Inference only: freeze every parameter.
    model.requires_grad_(False)
    return model

# ==========================================
# 4. DATASET PIPELINE
# ==========================================
class TerrainDataset(Dataset):
    """Test-time dataset yielding two differently-shaped views of every PNG.

    Each item is a dict with:
      * ``image_id``     -- file name without extension
      * ``path``         -- full path of the source image
      * ``input_pad``    -- image zero-padded by 2 px above and below
                            (DeepLab input)
      * ``input_resize`` -- image bilinearly resized to ``model_size``
                            (HRNet input)

    Both tensors are produced by ``ToTensor`` and then normalised with the
    ImageNet mean/std.

    Args:
        root_dir: Directory searched (non-recursively) for ``*.png`` files;
            the paths are sorted.
        model_size: ``(height, width)`` for the resized view.
    """

    def __init__(self, root_dir, model_size):
        png_pattern = os.path.join(root_dir, "*.png")
        self.image_paths = sorted(glob.glob(png_pattern))
        self.model_h, self.model_w = model_size

        self.norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def _normalized(self, pil_img):
        """Convert a PIL image to a tensor and apply the ImageNet normalisation."""
        return self.norm(self.to_tensor(pil_img))

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        stem, _ext = os.path.splitext(os.path.basename(image_path))

        with Image.open(image_path) as source:
            rgb = source.convert("RGB")

        # torchvision pad order is (left, top, right, bottom): 2 px top and bottom only.
        padded = transforms.functional.pad(rgb, (0, 2, 0, 2))
        # PIL resize takes (width, height).
        resized = rgb.resize((self.model_w, self.model_h), Image.BILINEAR)

        return {
            "image_id": stem,
            "path": image_path,
            "input_pad": self._normalized(padded),      # For DeepLab
            "input_resize": self._normalized(resized),  # For HRNet
        }

# ==========================================
# 5. MAIN INFERENCE LOOP
# ==========================================
def _restore_deeplab_size(prob, orig_size):
    """Map DeepLab output from its padded input grid back to ``orig_size``.

    ``TerrainDataset`` feeds DeepLab a *padded* image (``pad=(0, 2, 0, 2)``,
    i.e. 2 px above and 2 px below), so the geometric inverse is a crop, not
    a resize. Resizing (e.g. 544 -> 540 rows for a 540 px tall input)
    squashes the real pixels and blends the padded border into the
    top/bottom rows.

    The crop offset is derived from the shape difference, assuming the
    padding is split evenly (``top = diff // 2``); that holds for the
    ``(0, 2, 0, 2)`` pad. If the output is smaller than ``orig_size`` in
    either dimension, this falls back to bilinear resizing.

    Args:
        prob: 4-D probability tensor ``(B, C, H, W)``.
        orig_size: Target ``(H, W)``.
    """
    orig_h, orig_w = orig_size
    out_h, out_w = prob.shape[-2:]
    if out_h >= orig_h and out_w >= orig_w:
        top = (out_h - orig_h) // 2
        left = (out_w - orig_w) // 2
        return prob[..., top:top + orig_h, left:left + orig_w]
    return F.interpolate(prob, size=orig_size, mode="bilinear", align_corners=False)


def run_inference(w_hrnet: float = 0.52, w_deeplab: float = 0.48):
    """Run the DeepLabV3+ / HRNet-UNet ensemble over ``Cfg.TEST_DIR``.

    The function works through these stages in order:

    1. Check that the test directory exists.
    2. Download and load both models.
    3. Take the sigmoid probabilities from each model and map them back to
       ``Cfg.ORIG_SIZE``. DeepLab output is cropped, HRNet output is
       resized.
    4. Threshold the weighted sum at ``Cfg.THRESHOLD``.
    5. Write one 0/255 PNG per image to ``Cfg.MASK_OUT_DIR`` and an RLE CSV
       to ``Cfg.SUBMISSION_PATH``.

    NOTE(review): assumes the test images are ``Cfg.ORIG_SIZE`` (540x960).
    The directory name ``test_images_padded`` hints they might already be
    padded, so confirm this.

    Args:
        w_hrnet: Ensemble weight of the HRNet probabilities.
        w_deeplab: Ensemble weight of the DeepLab probabilities. The
            defaults are the previously hard-coded 0.52 / 0.48.

    Returns:
        None. Progress is printed. Returns early, before any download, if
        ``Cfg.TEST_DIR`` does not exist.
    """
    # Fail fast: don't spend minutes downloading/loading models for nothing.
    if not os.path.exists(Cfg.TEST_DIR):
        print(f" Test directory not found: {Cfg.TEST_DIR}")
        return

    download_weights(Cfg)

    print(" Loading models...")
    net_deeplab = load_model("deeplab", Cfg.CKPT_DEEPLAB, Cfg.DEVICE)
    net_hrnet = load_model("hrnet", Cfg.CKPT_HRNET, Cfg.DEVICE)

    dataset = TerrainDataset(Cfg.TEST_DIR, Cfg.MODEL_SIZE)
    loader = DataLoader(dataset, batch_size=Cfg.BATCH_SIZE, shuffle=False, num_workers=Cfg.NUM_WORKERS)
    print(f" Found {len(dataset)} images.")

    results = []

    print(" Starting inference...")
    with torch.no_grad():
        for batch in tqdm(loader):
            x_pad = batch["input_pad"].to(Cfg.DEVICE)
            x_resize = batch["input_resize"].to(Cfg.DEVICE)

            # DeepLab saw a padded image -> crop; HRNet saw a resized image -> resize back.
            prob_deeplab = _restore_deeplab_size(torch.sigmoid(net_deeplab(x_pad)), Cfg.ORIG_SIZE)
            prob_hrnet = F.interpolate(
                torch.sigmoid(net_hrnet(x_resize)),
                size=Cfg.ORIG_SIZE, mode="bilinear", align_corners=False,
            )

            # Weighted ensemble + threshold for the whole batch at once.
            # Models are built with classes=1, so drop the channel dim -> (B, H, W).
            ensemble_prob = w_hrnet * prob_hrnet + w_deeplab * prob_deeplab
            masks = (ensemble_prob > Cfg.THRESHOLD).squeeze(1).to(torch.uint8).cpu().numpy()

            for img_id, path, mask in zip(batch["image_id"], batch["path"], masks):
                results.append({
                    "image_id": img_id,
                    "encoded_pixels": rle_encode(mask),
                })
                cv2.imwrite(
                    os.path.join(Cfg.MASK_OUT_DIR, os.path.basename(path)),
                    mask * 255,
                )

    if results:
        df = pd.DataFrame(results)
        df.to_csv(Cfg.SUBMISSION_PATH, index=False)
        print(f"\n Success! Submission saved to: {Cfg.SUBMISSION_PATH}")
        print(f" Masks saved to: {Cfg.MASK_OUT_DIR}")
    else:
        print(" No predictions generated.")


if __name__ == "__main__":
    run_inference()

 Downloading weights to /kaggle/working/weights...


Retrieving folder contents


Processing file 1PlPTGIruNOI6U6jSVzCMMo5ARNJvhGS4 best_offseg_hrnet_unet.pth
Processing file 1aAnR3fwK8e6Nc7AK7R6LPHi2buHnNN__ deeplab.pth


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From (original): https://drive.google.com/uc?id=1PlPTGIruNOI6U6jSVzCMMo5ARNJvhGS4
From (redirected): https://drive.google.com/uc?id=1PlPTGIruNOI6U6jSVzCMMo5ARNJvhGS4&confirm=t&uuid=b032c2f3-88a4-4969-8568-5ef7abf49b81
To: /kaggle/working/weights/best_offseg_hrnet_unet.pth
100%|██████████| 146M/146M [00:00<00:00, 161MB/s]  
Downloading...
From (original): https://drive.google.com/uc?id=1aAnR3fwK8e6Nc7AK7R6LPHi2buHnNN__
From (redirected): https://drive.google.com/uc?id=1aAnR3fwK8e6Nc7AK7R6LPHi2buHnNN__&confirm=t&uuid=28a49eb7-3131-410c-b76f-743690d44018
To: /kaggle/working/weights/deeplab.pth
100%|██████████| 270M/270M [00:03<00:00, 77.7MB/s] 
Download completed


 Download complete.
 Loading models...
 Found 1002 images.
 Starting inference...


  0%|          | 0/1002 [00:00<?, ?it/s]


 Success! Submission saved to: /kaggle/working/submission.csv
 Masks saved to: /kaggle/working/binary_masks
