# AutoMM for Semantic Segmentation

Semantic segmentation is a computer vision task that assigns every pixel of an image to a specific class or object, producing a detailed pixel-wise segmentation map. It is crucial in applications such as autonomous driving, where a vehicle must identify other vehicles, pedestrians, traffic signs, pavement, and other road features.
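To make "pixel-wise segmentation map" concrete, here is a minimal sketch with NumPy. The class ids, class names, and the tiny 4x6 grid are made up for illustration and are not part of any dataset used in this tutorial.

```python
import numpy as np

# Hypothetical class ids for a tiny "street scene"; the names are illustrative only.
CLASS_NAMES = {0: "road", 1: "vehicle", 2: "pedestrian"}

# A segmentation map has the same height and width as the image,
# and each entry holds the class id of the corresponding pixel.
seg_map = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 2, 0],
    [0, 1, 1, 0, 2, 0],
    [0, 0, 0, 0, 0, 0],
])

# Count how many pixels were assigned to each class.
class_ids, counts = np.unique(seg_map, return_counts=True)
pixel_counts = {CLASS_NAMES[int(i)]: int(c) for i, c in zip(class_ids, counts)}
print(pixel_counts)
```

In the foreground-background setting used later in this tutorial, the map reduces to a binary mask with a single foreground class.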

The Segment Anything Model (SAM) is a foundation model pretrained on a large dataset of 11 million images and 1 billion masks. While SAM performs exceptionally well on generic scenes, it struggles in specialized domains such as remote sensing, medical imagery, agriculture, and manufacturing. Fortunately, AutoMM makes it easy to fine-tune SAM on domain-specific data.

In [None]:
from autogluon.multimodal import MultiModalPredictor
import pandas as pd
import os

## Prepare Data

For demonstration purposes, we use the Leaf Disease Segmentation dataset from Kaggle. This dataset is a good example for automating disease detection in plants, which can speed up plant pathology work. Segmenting specific regions on leaves or plants can be quite challenging, particularly when the diseased areas are small or several types of disease are present.

In [None]:
train_data = pd.read_csv('leaf_disease_segmentation/train.csv', index_col=0)
val_data = pd.read_csv('leaf_disease_segmentation/val.csv', index_col=0)
test_data = pd.read_csv('leaf_disease_segmentation/test.csv', index_col=0)
image_col = 'image'
label_col = 'label'

In [None]:
def path_expander(path, base_folder):
    # A cell may hold several ';'-separated relative paths; make each one absolute.
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

# Expand the image and label paths in every split.
for per_col in [image_col, label_col]:
    train_data[per_col] = train_data[per_col].apply(lambda ele: path_expander(ele, base_folder='leaf_disease_segmentation'))
    val_data[per_col] = val_data[per_col].apply(lambda ele: path_expander(ele, base_folder='leaf_disease_segmentation'))
    test_data[per_col] = test_data[per_col].apply(lambda ele: path_expander(ele, base_folder='leaf_disease_segmentation'))

print(train_data[image_col].iloc[0])
print(train_data[label_col].iloc[0])
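The path expansion above can be tried in isolation. The following is a small standalone sketch of the same idea; the helper name `expand_paths` and the file names are hypothetical, since the real paths come from the dataset's CSV files.

```python
import os

def expand_paths(paths, base_folder):
    """Join each ';'-separated relative path with base_folder and make it absolute."""
    return ";".join(
        os.path.abspath(os.path.join(base_folder, p)) for p in paths.split(";")
    )

# Hypothetical relative paths, used only to show the transformation.
example = "images/leaf_001.jpg;images/leaf_002.jpg"
expanded = expand_paths(example, base_folder="leaf_disease_segmentation")

for p in expanded.split(";"):
    print(p)
```

Absolute paths keep the data frames usable regardless of the working directory from which later cells are run.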

In [None]:
# Load the pretrained base SAM checkpoint without fine-tuning, for zero-shot prediction.
predictor_zero_shot = MultiModalPredictor(
    problem_type="semantic_segmentation",
    label=label_col,
    hyperparameters={
        "model.sam.checkpoint_name": "facebook/sam-vit-base",
    },
    num_classes=1,  # foreground-background segmentation
)

In [None]:
pred_zero_shot = predictor_zero_shot.predict({'image': [test_data.iloc[0]['image']]})

In [None]:
from autogluon.multimodal.utils import SemanticSegmentationVisualizer
visualizer = SemanticSegmentationVisualizer()
visualizer.plot_mask(pred_zero_shot)

Note that without prompts, SAM outputs a rough mask of the whole leaf rather than the diseased regions, because it has no context about the domain task. SAM can perform better with suitable click prompts, but a model that depends on prompts may not be an ideal end-to-end solution for applications that require a standalone model for deployment.

You can also conduct a zero-shot evaluation on the test data.  
As expected, the test score of the zero-shot SAM is relatively low.

In [None]:
scores = predictor_zero_shot.evaluate(test_data, metrics=["iou"])
print(scores)
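For intuition about the `iou` (intersection over union) metric, here is a minimal NumPy sketch for a pair of binary masks. It is an illustration under simplifying assumptions, not AutoMM's implementation, which may aggregate over images or handle edge cases differently; the helper `binary_iou` and the toy masks are hypothetical.

```python
import numpy as np

def binary_iou(pred_mask, true_mask):
    """Intersection over union of two binary masks of the same shape."""
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        # Both masks are empty; treating this as perfect agreement is a convention.
        return 1.0
    return float(np.logical_and(pred, true).sum() / union)

# Toy 4x4 masks: the prediction covers a large "leaf" region (12 pixels),
# while the ground truth marks only a small "diseased" spot inside it (4 pixels).
pred = np.zeros((4, 4), dtype=bool)
pred[1:4, 0:4] = True
true = np.zeros((4, 4), dtype=bool)
true[2:4, 1:3] = True

iou = binary_iou(pred, true)
print(f"IoU = {iou:.3f}")
```

A prediction that covers the whole leaf overlaps the diseased region but also includes many extra pixels, which inflates the union and lowers the score.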

## Fine-tune SAM

Next, let's explore how to fine-tune SAM for enhanced performance.  
Initialize a new predictor and fit it with the training and validation data.

In [None]:
predictor = MultiModalPredictor(
    problem_type="semantic_segmentation",
    label=label_col,
    hyperparameters={
        "model.sam.checkpoint_name": "facebook/sam-vit-base",
    },
)
predictor.fit(
    train_data=train_data,
    tuning_data=val_data,
)

After fine-tuning, evaluate SAM on the test data.

In [None]:
scores = predictor.evaluate(test_data, metrics=["iou"])
print(scores)

To visualize the impact, let's examine the predicted mask after fine-tuning.

In [None]:
pred = predictor.predict({'image': [test_data.iloc[0]['image']]})
visualizer.plot_mask(pred)