In [None]:
%pip install torchinfo
%pip install segment-anything
%pip install torchvision
%pip install torch



In [None]:
import os

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import sklearn.model_selection
import torch
import torch.nn as nn
import torch.optim as optim
import torchinfo
import torchvision
from segment_anything import SamAutomaticMaskGenerator
from segment_anything import SamPredictor, sam_model_registry
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
from tqdm import tqdm

In [None]:
# Colab-only: mount Google Drive at /content/drive so the checkpoint and data
# paths used in later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Build the "vit_b" SAM model from the checkpoint stored on Drive, move it to
# the selected device, and wrap it in a SamPredictor.
sam = sam_model_registry["vit_b"](checkpoint="/content/drive/MyDrive/plantation_data/models/sam_vit_b.pth")
sam.to(device)
predictor = SamPredictor(sam)

In [None]:
# Display the module structure of SAM's image encoder.
sam.image_encoder

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LayerNorm2d()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): LayerNorm2d()
  )
)

In [None]:
# Root directory of the training images; it is passed to ImageFolder and to
# generate_and_save_masks in later cells.
data_dir = r"/content/drive/MyDrive/plantation_data/train_dat"
print("Data directory:", data_dir)

Data directory: /content/drive/MyDrive/plantation_data/train_dat


In [None]:
class ConvertToRGB:
    """Transform that guarantees its output image is in "RGB" mode.

    An image whose ``mode`` is already ``"RGB"`` is returned unchanged (the
    very same object); any other mode is converted with ``img.convert("RGB")``.
    """

    def __call__(self, img):
        """Return ``img`` in RGB mode.

        Args:
            img: Object exposing ``mode`` and ``convert`` (presumably a PIL
                image as loaded by the dataset -- confirm against the caller).

        Returns:
            ``img`` itself when already RGB, otherwise the converted image.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    def __repr__(self):
        # Gives a readable entry when the enclosing Compose / dataset is
        # printed, instead of the default "<... object at 0x...>" form.
        return f"{self.__class__.__name__}()"

In [None]:
# Preprocessing applied to every image before it reaches the encoder:
#   1. force RGB mode,
#   2. Resize(1024) followed by CenterCrop(1024), giving a 1024x1024 image,
#   3. convert to a tensor,
#   4. per-channel normalisation.
# NOTE(review): the mean/std values look like the standard ImageNet channel
# statistics -- confirm they match what the SAM image encoder expects.
transform = transforms.Compose([
    ConvertToRGB(),
    transforms.Resize(1024),
    transforms.CenterCrop(1024),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Load the images under data_dir with the transform pipeline applied, then
# display the dataset summary.
dataset = datasets.ImageFolder(data_dir,transform)
dataset

Dataset ImageFolder
    Number of datapoints: 619
    Root location: /content/drive/MyDrive/plantation_data/train_dat
    StandardTransform
Transform: Compose(
               <__main__.ConvertToRGB object at 0x7d693457ae40>
               Resize(size=1024, interpolation=bilinear, max_size=None, antialias=True)
               CenterCrop(size=(1024, 1024))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [None]:
# Class names exposed by the ImageFolder dataset.
classes = dataset.classes
classes

['brown_',
 'dark green_',
 'farm',
 'farm2',
 'green_',
 'grid_with_lines',
 'useless_']

In [None]:
# One image per batch; no shuffle argument is passed, so the loader keeps its
# default ordering.
batch_size = 1
dataset_loader = DataLoader(dataset,batch_size=batch_size)

# Pull a single batch to confirm the tensor shape produced by the transforms.
print(f"Batch shape: {next(iter(dataset_loader))[0].shape}")

Batch shape: torch.Size([1, 3, 1024, 1024])


In [None]:
# Take the image tensor of the first batch (index 0 of the batch tuple) and
# move it to the selected device; it is reused as a probe input below.
test_batch = next(iter(dataset_loader))[0].to(device)
batch_shape = test_batch.shape

# Layer-by-layer summary of the image encoder for an input of this shape.
summary(sam.image_encoder, input_size=batch_shape)

Layer (type:depth-idx)                   Output Shape              Param #
ImageEncoderViT                          [1, 256, 64, 64]          3,145,728
├─PatchEmbed: 1-1                        [1, 64, 64, 768]          --
│    └─Conv2d: 2-1                       [1, 768, 64, 64]          590,592
├─ModuleList: 1-2                        --                        --
│    └─Block: 2-2                        [1, 64, 64, 768]          --
│    │    └─LayerNorm: 3-1               [1, 64, 64, 768]          1,536
│    │    └─Attention: 3-2               [25, 14, 14, 768]         2,365,824
│    │    └─LayerNorm: 3-3               [1, 64, 64, 768]          1,536
│    │    └─MLPBlock: 3-4                [1, 64, 64, 768]          4,722,432
│    └─Block: 2-3                        [1, 64, 64, 768]          --
│    │    └─LayerNorm: 3-5               [1, 64, 64, 768]          1,536
│    │    └─Attention: 3-6               [25, 14, 14, 768]         2,365,824
│    │    └─LayerNorm: 3-7               [1

In [None]:
# Sanity-check the encoder's output shape on one batch. Only the shape is
# needed, so run the forward pass without autograd to avoid building a graph
# and holding intermediate activations in memory.
with torch.no_grad():
    print(sam.image_encoder(test_batch).shape)

torch.Size([1, 256, 64, 64])


In [None]:
# Automatic mask generator wrapping the SAM model loaded earlier; it is used
# by generate_and_save_masks as a module-level object.
mask_generator = SamAutomaticMaskGenerator(sam)

def generate_and_save_masks(image_folder, output_folder):
    """Generate SAM masks for every image under ``image_folder`` and save them.

    The tree below ``image_folder`` is walked recursively and its layout is
    mirrored under ``output_folder``. Every mask returned by the module-level
    ``mask_generator`` is written as a 0/255 ``uint8`` PNG named
    ``<image name without extension>_mask_<index>.png``.

    Processing is best-effort: files that cannot be read, masks that cannot be
    written, and any exception raised while handling a single image are
    reported with ``print`` and skipped, so one bad file does not stop the run.

    Args:
        image_folder: Root directory containing the input images.
        output_folder: Root directory for the mask files (created if missing).

    Returns:
        None. Results are written to disk and progress is printed.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp')
    os.makedirs(output_folder, exist_ok=True)

    # Walk through all subdirectories
    for root, _dirs, files in os.walk(image_folder):
        # Relative path is reused to mirror the folder structure in the output
        rel_path = os.path.relpath(root, image_folder)
        out_subdir = os.path.join(output_folder, rel_path)

        for fname in files:
            # Skip non-image files
            if not fname.lower().endswith(image_extensions):
                continue

            try:
                img_path = os.path.join(root, fname)

                # Only create the output subdirectory once an image is found
                os.makedirs(out_subdir, exist_ok=True)

                # cv2.imread returns None (rather than raising) on failure
                image_bgr = cv2.imread(img_path)
                if image_bgr is None:
                    print(f"Failed to load image: {img_path}")
                    continue

                # OpenCV loads BGR; convert to RGB before mask generation
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                masks = mask_generator.generate(image_rgb)

                # splitext handles extensions of any length; slicing off a
                # fixed 4 characters left a trailing "." for .jpeg / .tiff
                stem = os.path.splitext(fname)[0]

                for i, mask_dict in enumerate(masks):
                    mask = mask_dict['segmentation'].astype(np.uint8) * 255
                    out_path = os.path.join(out_subdir, f"{stem}_mask_{i}.png")
                    # cv2.imwrite signals failure via its return value
                    if not cv2.imwrite(out_path, mask):
                        print(f"Failed to write mask: {out_path}")

                print(f"Processed: {img_path} - Generated {len(masks)} masks")

            except Exception as e:
                # Best-effort batch job: report the failure and move on
                print(f"Error processing {fname}: {str(e)}")

In [None]:
# Generate masks for every image under data_dir and write them to the masks
# folder on Drive.
generate_and_save_masks(data_dir,r"/content/drive/MyDrive/plantation_data/masks" )

Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_15.png - Generated 22 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_14.png - Generated 13 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_13.png - Generated 24 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_10.png - Generated 19 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_9.png - Generated 8 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_8.png - Generated 17 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_1_7.png - Generated 13 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_0_15.png - Generated 9 masks
Processed: /content/drive/MyDrive/plantation_data/train_dat/brown_/subimage_0_14.png - Generated 6 masks
Processed: /content/drive/MyDrive/plantation_data/tr

In [None]:
# Choose which parts of SAM receive gradients. Order matters: the whole image
# encoder is frozen first, then its neck is re-enabled, and finally the mask
# decoder is marked trainable.
# NOTE(review): sam.prompt_encoder is not touched here, so its parameters keep
# whatever requires_grad value they already have -- confirm this is intended.
for module, trainable in (
    (sam.image_encoder, False),
    (sam.image_encoder.neck, True),
    (sam.mask_decoder, True),
):
    for param in module.parameters():
        param.requires_grad = trainable