# GeoSeg Inference 

In [None]:
import os as os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from models.DCSwin import *
from data_processing.Urban_dataset import *
from data_processing.transform import *

## Load dataset

In [None]:
# NOTE(review): the original comment called this the "LoveDA test data root",
# but the path points at the Urban *train* split while the dataset is opened
# in mode='test' — confirm which split is intended.
data_root = "/Users/dean/code/geospatial/sat_stuff/dcswin_sam_sessrs/geoseg/data/Urban/train"
n_classes = 7  # number of classes the segmentation head predicts

# Use the requested GPU when CUDA is available, otherwise fall back to CPU.
gpu_index = 0
device = torch.device(f"cuda:{gpu_index}" if torch.cuda.is_available() else "cpu")
print(device)

# Inference dataset.
# NOTE(review): `train_aug` is passed here. If it applies random transforms
# (flips, crops, ...), the predicted masks will not line up with the source
# images. `val_aug` may be the intended choice — verify what each does in
# data_processing.transform before relying on these outputs.
infer_dataset = UrbanDataset(data_root=data_root,
                             mode='test',
                             augmentation=train_aug)

# Pinned host memory only speeds up host->GPU copies. On CPU it brings no
# benefit (recent PyTorch versions warn about it), so enable it only for CUDA.
infer_loader = DataLoader(infer_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=4,
                          pin_memory=device.type == "cuda")

## Load the model

In [None]:
# Build DCSwin-Tiny from the local checkpoint, move it onto the selected
# device, and put it in inference mode (affects layers such as dropout and
# batch norm).
checkpoint_path = '/Users/dean/code/geospatial/sat_stuff/dcswin_sam_sessrs/geoseg/models/ckpt/stseg_tiny.pth'
model = dcswin_tiny(pretrained=True,
                    num_classes=n_classes,
                    weight_path=checkpoint_path)
model.to(device)
model.eval()

## Run inference and save predicted masks

In [None]:
# True  -> save colourised RGB masks for visualisation.
# False -> save single-channel indexed PNG masks (for SESSRS).
rgb_output = True

output_path = "data/output_masks_loveda"  # base output folder

# Sub-folders for the two output flavours. img_writer only receives the base
# path, so presumably it picks the sub-folder from `rgb_output` — TODO confirm.
os.makedirs(os.path.join(output_path, 'pre_rgb'), exist_ok=True)
os.makedirs(os.path.join(output_path, 'pre_p'), exist_ok=True)

mask_count = 0  # running index so mask names stay unique for any batch size
with torch.no_grad():
    for batch in infer_loader:
        img = batch['img'].to(device)
        logits = model(img)                       # assumed (B, C, H, W)
        # Take argmax over the class axis directly. Softmax is monotonic, so
        # applying it first would not change the predicted class.
        masks = logits.argmax(dim=1)              # (B, H, W)
        masks_np = masks.cpu().numpy().astype(np.uint8)
        # Iterate over the batch axis explicitly. The old squeeze(0) only
        # produced a 2-D mask when batch_size == 1.
        for mask_np in masks_np:
            # NOTE(review): names are positional. If the dataset exposes an
            # image id, using it would keep masks traceable to their inputs.
            mask_name = f"mask_{mask_count}"      # pass name without ".png"
            img_writer((mask_np, output_path, mask_name, rgb_output))
            mask_count += 1