In [1]:
import torch
from segment_anything import sam_model_registry

# Load SAM model
sam_checkpoint = "segment-anything/sam_vit_h_4b8939.pth"  # Path to SAM weights
model_type = "vit_h"  # Use "vit_b" for smaller models
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)


Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

In [11]:
import os
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class PlantDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.mask_paths = []

        # Load all image and mask paths
        for class_name in os.listdir(root_dir):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                mask_path = img_path.replace("images", "masks")  # Assuming masks are stored separately
                self.image_paths.append(img_path)
                self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
    
        # Load image and mask
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Load as grayscale
    
        # Apply transformations
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']  # Already a tensor due to ToTensorV2()
            mask = augmented['mask']    # Already a tensor due to ToTensorV2()
    
        # Add channel dimension to mask
        mask = mask.unsqueeze(0)  # Shape: (1, H, W)
    
        return image, mask

In [12]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.Resize(256, 256),  # Resize to 256x256
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet normalization
    ToTensorV2(),
])

In [13]:
from torch.utils.data import DataLoader

# Define dataset paths
train_dataset = PlantDiseaseDataset(root_dir="Plant_Disease_Dataset_Unified/train", transform=transform)
valid_dataset = PlantDiseaseDataset(root_dir="Plant_Disease_Dataset_Unified/valid", transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

In [17]:
dataset = PlantDiseaseDataset(root_dir="Plant_Disease_Dataset_Unified/train", transform=transform)
image, mask = dataset[0]  # Get the first sample
print("Image shape:", image.shape)  # Should be (3, H, W)
print("Mask shape:", mask.shape)   # Should be (1, H, W)

Image shape: torch.Size([3, 256, 256])
Mask shape: torch.Size([1, 256, 256])


In [15]:
for images, masks in train_loader:
    print("Images shape:", images.shape)
    print("Masks shape:", masks.shape)
    break  # Stop after one batch

Images shape: torch.Size([8, 3, 256, 256])
Masks shape: torch.Size([8, 1, 256, 256])


In [20]:
import cv2
import numpy as np
from PIL import Image

def generate_mask(image_path, output_path):
    # Load an image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not load image at path: {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Initialize SAM predictor
    predictor.set_image(image)

    # Define a point prompt
    input_point = np.array([[100, 100]])  # Point on diseased region
    input_label = np.array([1])  # 1 for foreground

    # Generate mask
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,  # Enable multi-mask output
    )

    # Save the mask
    mask_image = Image.fromarray((masks[0] * 255).astype(np.uint8))
    mask_image.save(output_path)

# Example usage
try:
    generate_mask("path/to/image.jpg", "path/to/output/mask.png")
except FileNotFoundError as e:
    print(e)

Could not load image at path: path/to/image.jpg


[ WARN:0@871.622] global loadsave.cpp:268 findDecoder imread_('path/to/image.jpg'): can't open/read file: check file path/integrity
