In [None]:
import os
import random
from statistics import mean

import matplotlib.pyplot as plt
import monai
import numpy as np
import torch
import torchvision.transforms as transforms  # Import for resizing
from datasets import Dataset
from PIL import Image
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
from transformers import SamModel
from transformers import SamProcessor

In [None]:
# ------------------ Hyperparameters ------------------ #
image_dir = r"E:/Random Python Scripts/Tata HaxS/SAM/Dataset/train/images"  # Image directory path
mask_dir = r"E:/Random Python Scripts/Tata HaxS/SAM/Dataset/train/masks"    # Mask directory path

batch_size = 2            # Batch size for DataLoader
learning_rate = 1e-5      # Learning rate for optimizer
weight_decay = 0          # Weight decay for optimizer
num_epochs = 10           # Number of epochs for training
loss_fn_type = 'DiceCELoss'  # Loss function: Choose between DiceFocalLoss, FocalLoss, DiceCELoss
model_save_path = r"E:/Random Python Scripts/Tata HaxS/Models/Models/SAM/Lmao/lmao2.pth"  # Path to save model
seed = 42                 # Seed for every RNG used below (reproducible runs)

# ------------------------------------------------------ #

# Seed each RNG the notebook draws from, so a Restart & Run All reproduces
# the same run:
#   - `random`: picks the sample shown in the visualization cell
#   - NumPy's global RNG: bounding-box jitter in get_bounding_box
#   - torch: DataLoader shuffling
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# os.listdir returns entries in arbitrary order, and the pairing below relies
# on the i-th image matching the i-th mask, so sort both listings.
# NOTE(review): assumes each mask's filename sorts to the same position as
# its image (e.g. identical stems) — confirm against the dataset layout.
image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg')))
mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith(('.png', '.jpg', '.jpeg')))

if len(image_files) != len(mask_files):
    raise ValueError(
        f"Found {len(image_files)} images in {image_dir!r} but {len(mask_files)} masks "
        f"in {mask_dir!r}; every image needs exactly one mask."
    )

all_images = []
all_masks = []

# Masks are resized to match the model's output mask shape (256x256).
# NEAREST keeps label values discrete; the default bilinear interpolation
# would blur object edges into intermediate grey levels.
resize_transform = transforms.Resize(
    (256, 256), interpolation=transforms.InterpolationMode.NEAREST
)  # Adjust as necessary

# Load each image together with its mask. If either file fails, the whole
# pair is skipped, rather than shifting every later mask onto the wrong image.
for img_file, mask_file in zip(image_files, mask_files):
    try:
        # `with` closes the underlying file handles as soon as the pixels are read.
        with Image.open(os.path.join(image_dir, img_file)) as img:
            large_image = img.convert('RGB')
        with Image.open(os.path.join(mask_dir, mask_file)) as msk:
            large_mask = resize_transform(msk.convert('L'))  # Load as grayscale, then resize
    except Exception as e:
        print(f"Error loading pair: {img_file} / {mask_file} - {e}")
        continue
    all_images.append(np.array(large_image))
    all_masks.append(np.array(large_mask))

# Convert lists to NumPy arrays.
# NOTE(review): np.array(all_images) assumes every image has the same size.
images = np.array(all_images)
masks = np.array(all_masks)

In [None]:
# Keep only samples whose mask has at least one foreground pixel. Training
# prompts SAM with a box around the mask, and an empty mask has no box.
if len(images) != len(masks):
    raise ValueError(
        f"images ({len(images)}) and masks ({len(masks)}) differ in length; "
        "they must be paired one-to-one."
    )

valid_indices = [i for i, mask in enumerate(masks) if mask.max() != 0]
if not valid_indices:
    raise ValueError("Every mask is empty; there is nothing to train on.")

filtered_images = images[valid_indices]
filtered_masks = masks[valid_indices]

In [None]:
# Wrap the filtered arrays in a Hugging Face `datasets.Dataset` with one
# column of PIL images ("image") and one column of PIL masks ("label").
pil_images = list(map(Image.fromarray, filtered_images))
pil_labels = list(map(Image.fromarray, filtered_masks))
dataset_dict = {"image": pil_images, "label": pil_labels}

dataset = Dataset.from_dict(dataset_dict)

In [None]:
# Visualization: show one randomly chosen (image, mask) pair side by side.
img_num = random.randint(0, filtered_images.shape[0] - 1)
example_image = dataset[img_num]["image"]
example_mask = dataset[img_num]["label"]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (np.array(example_image), "Image"),  # RGB: matplotlib ignores cmap for 3-channel data
    (example_mask, "Mask"),              # single channel: rendered with the gray colormap
]
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture, cmap='gray')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

In [None]:
# Get bounding boxes from mask
def get_bounding_box(ground_truth_map, max_jitter=20):
    """Return a randomly enlarged bounding box around a mask's foreground.

    The tight box around all pixels > 0 is grown on each side by an
    independent random offset in ``[0, max_jitter)`` pixels, drawn from
    NumPy's global RNG, and clipped to the map. The prompt is therefore not
    pixel-perfect during training.

    Args:
        ground_truth_map: 2-D array; pixels > 0 count as foreground.
        max_jitter: exclusive upper bound of the per-side random padding.
            Values <= 0 disable the jitter and draw no random numbers.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints in the map's pixel
        coordinates, or ``[0, 0, 0, 0]`` when the map has no foreground.
    """
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:  # No foreground: return a dummy box.
        return [0, 0, 0, 0]

    def _pad():
        # One random offset per side, drawn in the order x_min, x_max, y_min,
        # y_max, so seeded runs reproduce the same boxes.
        return int(np.random.randint(0, max_jitter)) if max_jitter > 0 else 0

    H, W = ground_truth_map.shape
    x_min = max(0, int(x_indices.min()) - _pad())
    x_max = min(W, int(x_indices.max()) + _pad())
    y_min = max(0, int(y_indices.min()) - _pad())
    y_max = min(H, int(y_indices.max()) + _pad())
    return [x_min, y_min, x_max, y_max]

In [None]:
# Dataset class for SAM model
class SAMDataset(TorchDataset):
    """PyTorch dataset turning (image, mask) pairs into SAM training inputs.

    Each item is the processor output for the image, prompted with a
    jittered bounding box around the mask's foreground. The item also holds
    the binary (0/1) ground-truth mask under ``"ground_truth_mask"``.

    Args:
        dataset: indexable collection whose items are dicts holding a PIL
            ``"image"`` and a PIL ``"label"`` mask.
        processor: a ``SamProcessor`` used to build the model inputs.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"].convert('RGB')
        # The losses apply a sigmoid to the predictions, so targets must be
        # 0/1 rather than raw 0..255 grey levels.
        ground_truth_mask = (np.array(item["label"]) > 0).astype(np.uint8)

        # The box is found in mask pixels, but SamProcessor reads input_boxes
        # in the original image's pixel coordinates. The masks were resized
        # independently of the images, so rescale the box into image space.
        box = get_bounding_box(ground_truth_mask)
        mask_h, mask_w = ground_truth_mask.shape
        img_w, img_h = image.size  # PIL reports (width, height)
        scale_x = img_w / mask_w
        scale_y = img_h / mask_h
        prompt = [
            float(box[0] * scale_x),
            float(box[1] * scale_y),
            float(box[2] * scale_x),
            float(box[3] * scale_y),
        ]

        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # Drop the leading batch dimension the processor adds for one image.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = ground_truth_mask

        return inputs

In [None]:
def _print_shapes(named_arrays):
    """Print ``name shape`` for every entry of a dict of arrays/tensors."""
    for name, array in named_arrays.items():
        print(name, array.shape)


# Build the SAM pre-processing pipeline and wrap the HF dataset for PyTorch.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
train_dataset = SAMDataset(dataset=dataset, processor=processor)

# Shapes of a single processed sample (no batch dimension)...
example = train_dataset[0]
_print_shapes(example)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# ...and of one collated batch (leading batch dimension added by the DataLoader).
batch = next(iter(train_dataloader))
_print_shapes(batch)

In [None]:
# Load pretrained SAM ViT-Base weights (same checkpoint the processor came from).
model = SamModel.from_pretrained(pretrained_model_name_or_path="facebook/sam-vit-base")

In [None]:
# Fine-tune only the mask decoder: freeze the image and prompt encoders so no
# gradients are computed for them.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# Optimizer and loss, chosen by the hyperparameters.
optimizer = Adam(model.mask_decoder.parameters(), lr=learning_rate, weight_decay=weight_decay)

# sigmoid=True: MONAI applies the sigmoid itself, so the model's raw mask
# outputs are passed to the loss directly.
if loss_fn_type == 'DiceCELoss':
    seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
elif loss_fn_type == 'FocalLoss':
    seg_loss = monai.losses.FocalLoss(sigmoid=True)
elif loss_fn_type == 'DiceFocalLoss':
    seg_loss = monai.losses.DiceFocalLoss(sigmoid=True, squared_pred=True)
else:
    # Fail loudly. Otherwise a typo would leave `seg_loss` undefined, or
    # silently reuse a loss left over from an earlier run of this cell.
    raise ValueError(
        f"Unknown loss_fn_type {loss_fn_type!r}; "
        "expected 'DiceCELoss', 'FocalLoss' or 'DiceFocalLoss'."
    )

In [None]:
# Training loop
# `device` is used for every tensor transfer below, so define it once here.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}"):
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # Drop the per-prompt axis. Assumes one box per image and one mask
        # per box (multimask_output=False) — TODO confirm pred_masks layout.
        predicted_masks = outputs.pred_masks.squeeze(1)
        # Binarise defensively: the sigmoid-based losses need 0/1 targets.
        ground_truth_masks = (batch["ground_truth_mask"] > 0).float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [None]:
# Save the model's state dictionary to a file. Create the target directory
# first, so a finished training run is not lost to a missing folder.
save_dir = os.path.dirname(model_save_path)
if save_dir:  # empty when model_save_path is a bare filename
    os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), model_save_path)