In [None]:
# Cell 1: Imports, Drive mount, device and seeds
import math
import os
import random
import re
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms

from segment_anything import sam_model_registry, SamPredictor

# Mount Google Drive when running in Colab; outside Colab this step is skipped.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not in Colab: the /content/drive/... paths used below must exist locally

if drive is not None:
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        # Best-effort: keep the notebook running, but surface why the mount failed
        # instead of silently continuing with missing data paths.
        print(f"Warning: Google Drive mount failed: {exc}")

# Device and reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Use the cuDNN autotuner because input sizes are constant (good for training speed).
# This gives up bitwise run-to-run reproducibility despite the seeds above.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


In [None]:
# Cell 2: Load SAM and create predictor
# Point the checkpoint path to your model checkpoint in Drive
SAM_CHECKPOINT = "/content/drive/MyDrive/plantation_data/models/sam_vit_b.pth"
if not os.path.exists(SAM_CHECKPOINT):
    # Without the pretrained weights there is nothing to fine-tune, so stop here with an
    # actionable message instead of continuing into the model builder with a bad path.
    raise FileNotFoundError(
        f"SAM checkpoint not found at {SAM_CHECKPOINT}. Update SAM_CHECKPOINT before running training."
    )

sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
sam.to(device)
# Put in train mode for fine-tuning
sam.train()

# The predictor wraps the same `sam` instance; later cells use `predictor.transform`
# to resize images into the model's input frame.
predictor = SamPredictor(sam)
print("SAM model and predictor ready")


In [None]:
# Cell 3: Dataset transform and ImageFolder preview (optional)
class ConvertToRGB:
    """Transform that coerces a PIL-style image to "RGB" mode.

    An image whose ``mode`` is already "RGB" is returned unchanged (the same
    object); any other mode is converted via ``img.convert("RGB")``.
    """

    def __call__(self, img):
        return img if img.mode == "RGB" else img.convert("RGB")

# Preview-only pipeline with ImageNet statistics; in this notebook `transform` and
# `folder` are not consumed by the SAM training cells.
transform = transforms.Compose(
    [
        ConvertToRGB(),
        transforms.Resize(1024),      # int size: the smaller edge is matched to 1024
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

images_root = r"/content/drive/MyDrive/plantation_data/train_dat"
if not os.path.exists(images_root):
    print("images_root not found; skipping ImageFolder preview")
else:
    folder = datasets.ImageFolder(images_root, transform=transform)
    print("Found classes:", folder.classes)


In [None]:
# Cell 4: Build image-mask pairs (robustly across subfolders)
DATA_ROOT = r"/content/drive/MyDrive/plantation_data"
IMAGES_DIR = os.path.join(DATA_ROOT, "train_dat")
MASKS_DIR = os.path.join(DATA_ROOT, "masks")

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp')


def list_subdirs(root):
    """Return the sorted names of the immediate subdirectories of ``root``."""
    return [d for d in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, d))]


def list_mask_files(mask_dir):
    """Recursively list ``(path, stem)`` for every file under ``mask_dir`` ending in an IMAGE_EXTS suffix."""
    found = []
    for root, _, files in os.walk(mask_dir):
        for f in files:
            if f.lower().endswith(IMAGE_EXTS):
                found.append((os.path.join(root, f), os.path.splitext(f)[0]))
    return found


def mask_name_matches(base, mask_stem):
    """True if the image stem ``base`` occurs in ``mask_stem`` as a whole token.

    A plain substring test would pair ``img1`` with ``img10_mask`` or ``img11``.
    Here the match must not be glued to another letter or digit, so ``img1``
    matches ``img1``, ``img1_mask`` and ``mask-img1`` but not ``img10_mask``.
    """
    pattern = r'(?<![0-9A-Za-z])' + re.escape(base) + r'(?![0-9A-Za-z])'
    return re.search(pattern, mask_stem) is not None


def collect_pairs(images_dir, masks_dir):
    """Pair every image with all matching masks from the same-named subfolder.

    Returns a list of ``{"image": path, "mask": path}`` dicts, ordered by
    subfolder, then image file name, then mask path. The list is empty when
    either root directory is missing.
    """
    found_pairs = []
    if not (os.path.isdir(images_dir) and os.path.isdir(masks_dir)):
        return found_pairs
    matched = sorted(set(list_subdirs(images_dir)).intersection(list_subdirs(masks_dir)))
    print(f"Matched subdirectories: {matched}")

    for sub in matched:
        img_dir = os.path.join(images_dir, sub)
        # Walk the mask folder once per subfolder instead of once per image.
        mask_files = list_mask_files(os.path.join(masks_dir, sub))
        for img_file in sorted(os.listdir(img_dir)):
            if not img_file.lower().endswith(IMAGE_EXTS):
                continue
            img_path = os.path.join(img_dir, img_file)
            base = os.path.splitext(img_file)[0]
            for m in sorted(path for path, stem in mask_files if mask_name_matches(base, stem)):
                found_pairs.append({"image": img_path, "mask": m})
    return found_pairs


pairs = collect_pairs(IMAGES_DIR, MASKS_DIR)
print(f"Total pairs found: {len(pairs)}")

# Sanity: split into train/val by *image*. An image with several masks then never
# appears in both splits, which would otherwise leak validation data into training.
# With one mask per image this yields exactly the same split as splitting the pairs.
pairs_by_image = {}
for p in pairs:
    pairs_by_image.setdefault(p["image"], []).append(p)
image_paths = list(pairs_by_image)  # insertion order == order of first appearance in `pairs`
if len(image_paths) > 1:
    train_images, val_images = train_test_split(image_paths, test_size=0.2, random_state=42)
    train_pairs = [p for img in train_images for p in pairs_by_image[img]]
    val_pairs = [p for img in val_images for p in pairs_by_image[img]]
else:
    train_pairs, val_pairs = pairs, []
print(f"Train: {len(train_pairs)}, Val: {len(val_pairs)}")


In [None]:
# Cell 5: Utility functions - padding and batch reader (produces pixel coords, int64 labels)
def pad_to_square(image_tensor, size=1024):
    """Zero-pad a (B, C, H, W) tensor on the bottom/right up to ``size`` x ``size``.

    Spatial dimensions already >= ``size`` are left untouched (nothing is
    cropped). When no padding is needed the input tensor object is returned.
    """
    _, _, height, width = image_tensor.shape
    extra_rows = max(0, size - height)
    extra_cols = max(0, size - width)
    if not (extra_rows or extra_cols):
        return image_tensor
    # F.pad takes pairs starting from the last dim: (left, right, top, bottom)
    return F.pad(image_tensor, (0, extra_cols, 0, extra_rows))

def read_batch_random(data_pairs, predictor, visualize=False, target_size=1024, max_pos_points=10, neg_ratio=0.5):
    """Load one random (image, mask) pair and build a point-prompted SAM sample.

    Args:
        data_pairs: non-empty list of ``{"image": path, "mask": path}`` dicts.
        predictor: SamPredictor. Its ``transform`` resizes the image; when
            ``predictor.model`` exposes ``pixel_mean``/``pixel_std`` they are used
            to normalize the image.
        visualize: if True, plot the image, GT mask and sampled points.
        target_size: side of the square canvas that image and mask are padded to.
        max_pos_points: maximum number of positive (foreground) points.
        neg_ratio: number of negative points as a fraction of positive points.

    Returns:
        A 4-tuple ``(input_tensor, binary_mask, points, labels)``:

        - ``input_tensor``: (1, 3, S, S) float32, normalized and zero-padded.
        - ``binary_mask``: (S, S) uint8 in {0, 1}.
        - ``points``: (N, 2) float32, (x, y) pixel coords in the resized frame.
        - ``labels``: (N,) int64, 1 = foreground and 0 = background.

        ``points``/``labels`` are None when no foreground point could be
        sampled. All four are None when either file cannot be read.

    Raises:
        ValueError: if ``data_pairs`` is empty.
    """
    if not data_pairs:
        raise ValueError("read_batch_random: data_pairs is empty; no image/mask pairs to sample from")

    # pick a random entry
    ent = data_pairs[random.randint(0, len(data_pairs) - 1)]
    img_bgr = cv2.imread(ent['image'])
    ann = cv2.imread(ent['mask'], cv2.IMREAD_GRAYSCALE)
    if img_bgr is None or ann is None:
        print('Could not read', ent)
        # Same 4-tuple arity as the success path so callers can always unpack it.
        return None, None, None, None
    img = img_bgr[..., ::-1].copy()  # BGR->RGB

    # Resize with SAM's transform. The mask size is taken from the resized image itself,
    # so image and mask are guaranteed to be pixel-aligned (no independent rounding).
    input_image = predictor.transform.apply_image(img)
    new_h, new_w = input_image.shape[:2]
    input_tensor = torch.as_tensor(input_image).permute(2, 0, 1).unsqueeze(0).float()

    # Callers feed this tensor straight into sam.image_encoder, so normalize it like
    # Sam.preprocess does (normalize first, then zero-pad) instead of passing raw 0..255 pixels.
    model = getattr(predictor, "model", None)
    if model is not None and hasattr(model, "pixel_mean") and hasattr(model, "pixel_std"):
        pixel_mean = model.pixel_mean.detach().to(device=input_tensor.device, dtype=input_tensor.dtype)
        pixel_std = model.pixel_std.detach().to(device=input_tensor.device, dtype=input_tensor.dtype)
        input_tensor = (input_tensor - pixel_mean) / pixel_std
    input_tensor = pad_to_square(input_tensor, size=target_size)

    # -- GT mask: nearest-neighbour resize keeps label values integral
    mask_resized = cv2.resize(ann, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    pad_h = max(0, target_size - new_h)
    pad_w = max(0, target_size - new_w)
    mask_padded = np.pad(mask_resized, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)

    # Binarize mask (foreground > 0)
    binary_mask = (mask_padded > 0).astype(np.uint8)

    # Positive points come from a slightly eroded mask so they sit inside the object.
    eroded = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    pos_coords = np.argwhere(eroded > 0)  # (y, x)
    # Negatives come only from real image background, never from the zero padding.
    # The slice starts at (0, 0), so its coordinates equal full-canvas coordinates.
    neg_coords = np.argwhere(binary_mask[:new_h, :new_w] == 0)  # (y, x)

    points = []
    labels = []
    # positive points
    if len(pos_coords) > 0:
        n_pos = min(len(pos_coords), max_pos_points)
        for i in np.random.choice(len(pos_coords), n_pos, replace=False):
            y, x = pos_coords[i]
            points.append([int(x), int(y)])  # prompts are (x, y)
            labels.append(1)
    # negative points (random background) to teach the model boundaries
    if len(neg_coords) > 0 and len(points) > 0:
        n_neg = min(int(len(points) * neg_ratio), len(neg_coords))
        if n_neg > 0:
            for i in np.random.choice(len(neg_coords), n_neg, replace=False):
                y, x = neg_coords[i]
                points.append([int(x), int(y)])
                labels.append(0)

    if len(points) == 0:
        return input_tensor, binary_mask, None, None

    if visualize:
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1); plt.imshow(img); plt.axis('off'); plt.title('Original')
        plt.subplot(1, 3, 2); plt.imshow(binary_mask, cmap='gray'); plt.axis('off'); plt.title('GT binary')
        plt.subplot(1, 3, 3); plt.imshow(binary_mask, cmap='gray')
        for p, l in zip(points, labels):
            c = 'r' if l == 1 else 'b'
            plt.scatter(p[0], p[1], s=80, c=c)
        plt.axis('off'); plt.title('Points (red=pos, blue=neg)')
        plt.show()

    return input_tensor, binary_mask, np.array(points, dtype=np.float32), np.array(labels, dtype=np.int64)


In [None]:
# Cell 6: Freeze/unfreeze params and optimizer setup
# Freeze the whole model, then re-enable gradients only on the parts being fine-tuned.
sam.requires_grad_(False)

# Unfreeze prompt_encoder and mask_decoder (and optionally image_encoder.neck if present)
tuned_modules = [getattr(sam, name) for name in ('prompt_encoder', 'mask_decoder') if hasattr(sam, name)]
if hasattr(sam.image_encoder, 'neck'):
    tuned_modules.append(sam.image_encoder.neck)
for module in tuned_modules:
    module.requires_grad_(True)

trainable = list(filter(lambda prm: prm.requires_grad, sam.parameters()))
print(f"Trainable parameter groups: {len(trainable)} tensors")

optimizer = optim.AdamW(trainable, lr=1e-4, weight_decay=1e-4)
# NOTE(review): the training loop calls scheduler.step() once per optimizer step. With its
# defaults (max_steps=1000, accumulation_steps=4) that is at most 250 calls, so
# step_size=500 never decays the LR. Confirm the intended schedule.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2)
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

# Losses
bce_loss = nn.BCEWithLogitsLoss()

def dice_loss(pred_probs: torch.Tensor, target: torch.Tensor, eps=1e-6):
    """Soft Dice loss, averaged over the batch.

    Args:
        pred_probs: predicted foreground probabilities in [0, 1], shape (B, ...).
            Typically (B, H, W), but any trailing layout works.
        target: ground truth (binary or soft) with the same shape as ``pred_probs``.
        eps: smoothing term added to numerator and denominator. It keeps the ratio
            defined (and equal to 1) when prediction and target are both empty.

    Returns:
        Scalar float32 tensor ``1 - mean_b(dice_b)``.
    """
    # Flatten everything but the batch dim so (B, H, W) and (B, 1, H, W) both work.
    # Reduce in float32: fp16 sums over a 1024x1024 mask exceed fp16's max (~65504).
    pred_flat = pred_probs.flatten(1).float()
    target_flat = target.flatten(1).float()
    intersection = (pred_flat * target_flat).sum(dim=1)
    denominator = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice.mean()


In [None]:
# Cell 7: Training loop (with autocast, combined BCE+Dice, and negative sampling)
sam.to(device)
sam.train()  # the evaluation cell switches to eval(); make re-running this cell safe
max_steps = 1000
accumulation_steps = 4
save_every = 500
use_bce_dice = True
dice_weight = 1.0
use_amp = device.type == "cuda"  # mixed precision only on GPU; CPU runs stay in float32


def _forward_logits(input_tensor, points, labels):
    """Run SAM on one image with point prompts and return the best mask's logits.

    ``input_tensor`` is the (1, 3, H, W) model input already on ``device``.
    ``points`` is an (N, 2) array of (x, y) pixel coords in that frame, and
    ``labels`` is an (N,) array of 1 (foreground) / 0 (background).
    Returns float32 logits of shape (B, H, W), upsampled to the input's size.
    """
    input_points_t = torch.as_tensor(points, dtype=torch.float32, device=device).unsqueeze(0)  # (1,N,2)
    input_labels_t = torch.as_tensor(labels, dtype=torch.int64, device=device).unsqueeze(0)    # (1,N)
    with torch.autocast(device_type=device.type, enabled=use_amp):
        image_embeddings = sam.image_encoder(input_tensor)
        if image_embeddings.dim() == 3:
            image_embeddings = image_embeddings.unsqueeze(0)

        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=(input_points_t, input_labels_t),
            boxes=None,
            masks=None,
        )
        low_res_masks, iou_preds = sam.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )

        # Train on the candidate mask that SAM itself scores highest.
        # NOTE(review): no loss term uses iou_preds (argmax is not differentiable),
        # so the IoU prediction head receives no training signal.
        best_idx = iou_preds.argmax(dim=1)  # (B,)
        rows = torch.arange(low_res_masks.shape[0], device=device)
        chosen = low_res_masks[rows, best_idx]  # (B,H',W')

        logits = F.interpolate(
            chosen.unsqueeze(1), size=input_tensor.shape[2:], mode='bilinear', align_corners=False
        ).squeeze(1)  # (B,H,W)
    # Losses and metrics are computed in float32, outside autocast.
    return logits.float()


def _segmentation_loss(logits, target):
    """BCE-with-logits, plus ``dice_weight`` x soft Dice when ``use_bce_dice`` is set."""
    bce = bce_loss(logits, target)
    if not use_bce_dice:
        return bce
    return bce + dice_weight * dice_loss(torch.sigmoid(logits), target)


def _apply_accumulated_gradients():
    """Clip the accumulated gradients, take one optimizer + scheduler step, reset grads."""
    if scaler is not None:
        # Gradients are still multiplied by the AMP loss scale; unscale them so the
        # clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()


def _is_usable(sample):
    """True when read_batch_random returned an image, a mask and at least one point."""
    return (
        sample is not None
        and len(sample) == 4
        and sample[0] is not None
        and sample[1] is not None
        and sample[2] is not None
        and len(sample[2]) > 0
    )


optimizer.zero_grad(set_to_none=True)  # drop gradients left over from an earlier run
micro_steps = 0  # number of samples whose gradients were accumulated in this run

for step in range(1, max_steps + 1):
    sample = read_batch_random(train_pairs, predictor, visualize=False, target_size=1024, max_pos_points=8, neg_ratio=0.5)
    if _is_usable(sample):
        input_tensor, gt_mask_np, points, labels = sample
        input_tensor = input_tensor.to(device)
        gt_mask = torch.as_tensor((gt_mask_np > 0).astype(np.float32), device=device).unsqueeze(0)  # (1,H,W)

        prd_logits_resized = _forward_logits(input_tensor, points, labels)
        seg_loss = _segmentation_loss(prd_logits_resized, gt_mask)
        loss = seg_loss / accumulation_steps

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Count only samples that produced gradients, so every optimizer step really
        # accumulates `accumulation_steps` samples even when some reads are skipped.
        micro_steps += 1
        if micro_steps % accumulation_steps == 0:
            _apply_accumulated_gradients()

        if step % 50 == 0:
            # IoU of the thresholded prediction, for logging only.
            with torch.no_grad():
                prd = (torch.sigmoid(prd_logits_resized) > 0.5).float()
                inter = (gt_mask * prd).sum()
                union = gt_mask.sum() + prd.sum() - inter
                iou = inter / (union + 1e-6)
            print(f"Step {step}: loss={seg_loss.item():.6f} iou={iou.item():.6f}")

    # Checkpoint on schedule even when this particular step's sample was skipped.
    if step % save_every == 0:
        out_path = "/content/drive/MyDrive/plantation_data/models/fine_tuned_sam_vit_b_step_{}.pth".format(step)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        torch.save(sam.state_dict(), out_path)
        print("Saved checkpoint to", out_path)

# Apply gradients accumulated after the last complete accumulation window.
if micro_steps % accumulation_steps != 0:
    _apply_accumulated_gradients()


In [None]:
# Cell 8: Quick evaluation helper (visualize prediction on a val sample)
# Leaves `sam` in eval mode; call `sam.train()` before resuming training.
sam.eval()
with torch.no_grad():
    if len(val_pairs) == 0:
        print('No val samples to visualize')
    else:
        sample = read_batch_random(val_pairs, predictor, visualize=False, target_size=1024, max_pos_points=8, neg_ratio=0.5)
        # An unreadable image/mask comes back with a None image tensor; catch it
        # before unpacking or moving anything to the device.
        if sample is None or sample[0] is None:
            print('Val sample read failed')
        else:
            input_tensor, gt_mask_np, points, labels = sample
            if points is None or points.shape[0] == 0:
                print('No points available in this val sample')
            else:
                # Run the (expensive) image encoder only once we know there is a prompt.
                input_tensor = input_tensor.to(device)
                image_embeddings = sam.image_encoder(input_tensor)
                if image_embeddings.dim() == 3:
                    image_embeddings = image_embeddings.unsqueeze(0)

                pts = torch.as_tensor(points, dtype=torch.float32, device=device).unsqueeze(0)  # (1,N,2)
                labs = torch.as_tensor(labels, dtype=torch.int64, device=device).unsqueeze(0)   # (1,N)
                sparse_embeddings, dense_embeddings = sam.prompt_encoder(points=(pts, labs), boxes=None, masks=None)
                low_res_masks, iou_preds = sam.mask_decoder(
                    image_embeddings=image_embeddings,
                    image_pe=sam.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                )

                # Same selection as training: the candidate with the highest predicted IoU.
                best_idx = iou_preds.argmax(dim=1)
                chosen = low_res_masks[torch.arange(low_res_masks.shape[0], device=device), best_idx]
                prd_logits_resized = torch.nn.functional.interpolate(
                    chosen.unsqueeze(1), size=input_tensor.shape[2:], mode='bilinear', align_corners=False
                ).squeeze(1)
                prd = (torch.sigmoid(prd_logits_resized) > 0.5).cpu().numpy()[0]

                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1); plt.imshow(gt_mask_np, cmap='gray'); plt.title('GT'); plt.axis('off')
                plt.subplot(1, 2, 2); plt.imshow(prd, cmap='gray'); plt.title('Pred'); plt.axis('off')
                plt.show()


In [None]:
# Cell 9: Save final model
# Best-effort save: a failure (e.g. an unwritable path) is reported, not raised.
final_out = "/content/drive/MyDrive/plantation_data/models/fine_tuned_sam_vit_b_final.pth"
try:
    torch.save(sam.state_dict(), final_out)
except Exception as e:
    print('Saving final model failed:', e)
else:
    print('Saved final model to', final_out)
