<a href="https://colab.research.google.com/github/Naveen11205570/Computer-Vision-and-Deep-Learning/blob/main/Segmentation_of_dataset_SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np
import pandas as pd
import cv2
!pip install segment-anything
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

class MNISTSegmentationDataset(Dataset):
    def __init__(self, transform=None):
        self.dataset = MNIST(root='./data', train=True, download=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = self.transform(img)

        mask = torch.zeros((1, 28, 28))
        mask[0] = (img[0] > 0).float()

        return img, mask

dataset = MNISTSegmentationDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True)

class SmallUNet(nn.Module):
    def __init__(self):
        super(SmallUNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallUNet().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_model(model, loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

train_model(model, loader, optimizer, criterion)

def evaluate(model, loader):
    model.eval()
    precision, recall, f1, accuracy, dice = [], [], [], [], []
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            outputs = (outputs > 0.5).float()
            y_true = masks.cpu().numpy().flatten()
            y_pred = outputs.cpu().numpy().flatten()
            precision.append(precision_score(y_true, y_pred, zero_division=1))
            recall.append(recall_score(y_true, y_pred, zero_division=1))
            f1.append(f1_score(y_true, y_pred, zero_division=1))
            accuracy.append(accuracy_score(y_true, y_pred))
            dice.append(2 * np.sum(y_pred * y_true) / (np.sum(y_pred) + np.sum(y_true)))
    return np.mean(precision), np.mean(recall), np.mean(f1), np.mean(accuracy), np.mean(dice)

unet_metrics = evaluate(model, loader)



Collecting segment-anything
  Downloading segment_anything-1.0-py3-none-any.whl.metadata (487 bytes)
Downloading segment_anything-1.0-py3-none-any.whl (36 kB)
Installing collected packages: segment-anything
Successfully installed segment-anything-1.0
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 21.7MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 606kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 5.48MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.55MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 1, Loss: 0.018488317728042603
Epoch 2, Loss: 0.003987718839198351
Epoch 3, Loss: 0.007677287328988314
Epoch 4, Loss: 0.004723136313259602
Epoch 5, Loss: 0.000586644746363163


In [None]:
import requests

sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
checkpoint_file = "sam_vit_h_4b8939.pth"

response = requests.get(sam_checkpoint_url)
if response.status_code == 200:
    with open(checkpoint_file, "wb") as f:
        f.write(response.content)
    print(f"Downloaded {checkpoint_file} successfully.")
else:
    print("Failed to download the file.")

sam = sam_model_registry["vit_h"](checkpoint=checkpoint_file).to(device)
sam_mask_generator = SamAutomaticMaskGenerator(sam)

def sam_evaluate(loader, sample_size=5):
    precision, recall, f1, accuracy, dice = [], [], [], [], []

    with torch.no_grad():
        for i, (image, mask) in enumerate(loader):
            if i >= sample_size:
                break

            image = image[0].cpu().numpy()

            if image.ndim == 3:
                if image.shape[0] == 1:
                    image = image.squeeze(0)
                    image = np.stack((image,) * 3, axis=-1)
                elif image.shape[0] == 3:
                    image = image.transpose(1, 2, 0)

            elif image.ndim == 2:
                image = np.stack((image,) * 3, axis=-1)


            sam_mask = sam_mask_generator.generate(image)

            if len(sam_mask) == 0:
                print("No masks generated for image index:", i)
                continue

            y_true = mask[0].cpu().numpy().flatten()

            y_pred = (sam_mask[0]['segmentation']).astype(np.uint8).flatten()

            y_pred = np.where(y_pred > 0, 1, 0)

            precision.append(precision_score(y_true, y_pred, zero_division=1))
            recall.append(recall_score(y_true, y_pred, zero_division=1))
            f1.append(f1_score(y_true, y_pred, zero_division=1))
            accuracy.append(accuracy_score(y_true, y_pred))
            dice.append(2 * np.sum(y_pred * y_true) / (np.sum(y_pred) + np.sum(y_true) + 1e-6))

    return np.mean(precision), np.mean(recall), np.mean(f1), np.mean(accuracy), np.mean(dice)

sam_metrics = sam_evaluate(loader)

comparison = pd.DataFrame({
    "Metric": ["Precision", "Recall", "F1-Score", "Accuracy", "Dice-Score"],
    "UNet": unet_metrics,
    "SAM": sam_metrics
})
print(comparison)

Downloaded sam_vit_h_4b8939.pth successfully.


  state_dict = torch.load(f)


       Metric      UNet       SAM
0   Precision  0.997513  0.047740
1      Recall  0.996365  0.229709
2    F1-Score  0.996937  0.078679
3    Accuracy  0.998837  0.073214
4  Dice-Score  0.996937  0.078679
