In [None]:
# # Install torch and torchvision
# !pip install torch torchvision

# # Install datasets library
# !pip install datasets

# # Install numpy (already pre-installed in Colab but can be reinstalled if needed)
# !pip install numpy


In [None]:
# # Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')

# # Install required libraries
# !pip install datasets
# !pip install torchvision

In [None]:
from PIL import Image
from torchvision import transforms
import torch
from datasets import Dataset, Features, Image
import os
from sklearn.model_selection import train_test_split
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2

# img = Image.open("./sar_images/images/train/100.png")
# print(img.mode)  # 'L' for grayscale, 'RGB' for color

In [None]:
def preprocess(batch):
    """Convert a raw batch from the Hugging Face dataset into model-ready tensors.

    Used via ``Dataset.set_transform``, so ``batch`` is a dict of columns whose
    ``'image'`` and ``'segmentation_mask'`` entries are lists of decoded PIL images.

    Returns:
        dict with
          ``"pixel_values"``: list of float tensors of shape (3, 512, 512), values in [0, 1];
          ``"labels"``: list of int64 tensors of shape (512, 512) with values {0, 1}
          (any non-zero mask pixel is treated as foreground).
    """
    resize_image = transforms.Resize((512, 512))
    # Masks must be resized with nearest-neighbour: bilinear interpolation would
    # invent intermediate grey levels along object edges, which the `> 0`
    # threshold below would then turn into foreground (dilating every mask).
    resize_mask = transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST)
    to_tensor = transforms.ToTensor()

    # Images: force 3 channels (RGB), resize, convert to float tensors in [0, 1].
    images = [to_tensor(resize_image(img.convert('RGB'))) for img in batch['image']]
    # Optional ImageNet normalisation (disabled in this pipeline):
    # images = [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) for img in images]

    # Masks: force a single channel, resize, then binarise.
    # Both steps must be applied to the tensor that becomes the label; otherwise
    # the labels keep their original size/channels and mismatch the 512x512 logits.
    masks = []
    for mask in batch['segmentation_mask']:
        mask_tensor = transforms.functional.pil_to_tensor(resize_mask(mask.convert('L')))  # (1, H, W) uint8
        masks.append((mask_tensor.squeeze(0) > 0).long())  # (H, W) int64 in {0, 1}

    return {"pixel_values": images, "labels": masks}

In [None]:
# Training-split folders (adjust to match your directory layout).
train_image_dir = "./sar_images/images/train/"
train_mask_dir = "./sar_images/masks/train/"


def _sorted_pngs(folder):
    """Return the full paths of the ``.png`` files in ``folder``, sorted by path.

    Images and masks are later paired purely by position in these sorted lists,
    so the two folders must sort into the same order (e.g. identical file names).
    """
    return sorted(
        os.path.join(folder, fname)
        for fname in os.listdir(folder)
        if fname.endswith('.png')
    )


train_image_files = _sorted_pngs(train_image_dir)
train_mask_files = _sorted_pngs(train_mask_dir)
assert len(train_image_files) == len(train_mask_files), "Number of training images and masks must be equal"

# Hold out 20% of the training pairs for validation; the fixed seed keeps the split reproducible.
(train_image_files, val_image_files,
 train_mask_files, val_mask_files) = train_test_split(
    train_image_files, train_mask_files, test_size=0.2, random_state=42
)

# Both columns hold file paths that `datasets` decodes into images
# (preprocess() later calls PIL's .convert on them).
features = Features({"image": Image(), "segmentation_mask": Image()})

train_dataset_dict = {"image": train_image_files, "segmentation_mask": train_mask_files}
val_dataset_dict = {"image": val_image_files, "segmentation_mask": val_mask_files}


def _make_split(columns):
    """Build a Dataset from a column dict and attach ``preprocess`` as an on-access transform."""
    split = Dataset.from_dict(columns, features=features)
    split.set_transform(preprocess)
    return split


train_dataset = _make_split(train_dataset_dict)
val_dataset = _make_split(val_dataset_dict)

In [None]:
# torchvision provides DeepLabV3 (not DeepLabV3+); this variant uses a ResNet-50
# backbone and is initialised from pretrained weights.
model = deeplabv3_resnet50(pretrained=True)

# The main head ends in a 1x1 conv (index 4) that emits the pretrained class
# logits. Replace it with a 2-class (background / foreground) conv that keeps
# the head's 256 input channels.
# NOTE(review): an auxiliary head, if the model has one, still uses the
# pretrained class count; it is unused here because only ['out'] feeds the loss.
head_in_channels = 256
model.classifier[4] = nn.Conv2d(head_in_channels, 2, kernel_size=1)

In [None]:
# Optimisation hyper-parameters
BATCH_SIZE = 4
LEARNING_RATE = 1e-4

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Per-pixel cross-entropy between the (N, 2, H, W) logits and (N, H, W) integer labels.
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training configuration and early-stopping state
num_epochs = 50
best_val_loss = float("inf")  # Any real validation loss beats this, so epoch 1 always checkpoints
patience = 8  # Consecutive epochs without validation improvement before stopping
patience_counter = 0
checkpoint_path = "./deeplabv3/saved_model.pth"  # Path to save the best model

# torch.save does not create missing parent folders. Create the folder up front
# so the first checkpoint write cannot fail after an entire epoch of training.
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0
    for i, batch in enumerate(train_loader):
        images = batch['pixel_values'].to(device)
        masks = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(images)['out']  # The model returns a dict; 'out' holds the main head's logits
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Print loss every 10 batches to prevent Colab disconnection
        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item()}")

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}")

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_pixels = 0
    with torch.no_grad():
        for batch in val_loader:
            images = batch['pixel_values'].to(device)
            masks = batch['labels'].to(device)

            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            val_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            val_correct += (preds == masks).sum().item()
            val_pixels += masks.numel()

    avg_val_loss = val_loss / len(val_loader)
    # Global pixel accuracy (total correct / total pixels). Averaging per-batch
    # ratios would over-weight a smaller final batch.
    val_accuracy = val_correct / val_pixels
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss}, Validation Accuracy: {val_accuracy}")

    # Checkpoint: Save the best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Best model saved with Validation Loss: {avg_val_loss}")
    else:
        patience_counter += 1

    # Early stopping if no improvement
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

In [None]:
# Reload the best checkpoint written by the training loop.
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU machine can be restored on a CPU-only one (and vice versa).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("./deeplabv3/saved_model.pth", map_location=device))
model.eval()
model.to(device)

In [None]:
# Re-evaluate the reloaded best checkpoint on the validation split.
model.eval()
val_loss = 0.0
val_correct = 0
val_pixels = 0
with torch.no_grad():
    for batch in val_loader:
        images = batch['pixel_values'].to(device)
        masks = batch['labels'].to(device)
        outputs = model(images)['out']
        loss = criterion(outputs, masks)
        val_loss += loss.item()

        # Accumulate raw pixel counts and divide once at the end. This gives
        # global pixel accuracy instead of a per-batch mean skewed by a smaller
        # last batch.
        preds = torch.argmax(outputs, dim=1)
        val_correct += (preds == masks).sum().item()
        val_pixels += masks.numel()

val_loss /= len(val_loader)
val_accuracy = val_correct / val_pixels
print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}")

# Testing

In [None]:
# Held-out test split (same folder layout as the training data).
test_image_dir = "./sar_images/images/test/"
test_mask_dir = "./sar_images/masks/test/"

# Sorted lists of test file paths; images and masks are paired by sorted position.
test_image_files = sorted([os.path.join(test_image_dir, f) for f in os.listdir(test_image_dir) if f.endswith('.png')])
test_mask_files = sorted([os.path.join(test_mask_dir, f) for f in os.listdir(test_mask_dir) if f.endswith('.png')])

# Same sanity check as for the training split.
assert len(test_image_files) == len(test_mask_files), "Number of test images and masks must be equal"

test_dataset_dict = {
    "image": test_image_files,
    "segmentation_mask": test_mask_files
}

# Create the test dataset (reuses the `features` schema defined for training).
test_dataset = Dataset.from_dict(test_dataset_dict, features=features)

# Set the preprocessing transform
test_dataset.set_transform(preprocess)

# The loader must wrap the held-out test split (not val_dataset); otherwise the
# "test" metrics below would just repeat the validation metrics.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:


# Per-class intersection-over-union for a single batch
def compute_iou(preds, targets, num_classes):
    """
    Return the IoU of every class id in ``range(num_classes)``.

    Args:
        preds (torch.Tensor): Predicted class ids, shape [batch_size, height, width].
        targets (torch.Tensor): Ground-truth class ids, same shape as ``preds``.
        num_classes (int): Number of classes (e.g. 2 for background and object).

    Returns:
        list: One float per class, ``|pred ∩ target| / |pred ∪ target|``.
        The entry is ``nan`` when the class occurs in neither tensor, so that
        ``np.nanmean`` can skip it.
    """
    def _class_iou(class_id):
        predicted = preds == class_id
        actual = targets == class_id
        overlap = (predicted & actual).sum().item()  # true positives
        union = (predicted | actual).sum().item()  # TP + FP + FN
        if not union:
            return float('nan')
        return overlap / union

    return [_class_iou(class_id) for class_id in range(num_classes)]

# Evaluation on the test set
def evaluate(model, test_loader, criterion, device, num_classes=2):
    """
    Evaluate the model on a data loader and report loss, pixel accuracy and IoU.

    Per-class intersections and unions are summed over the whole loader before
    dividing. This gives the dataset-level IoU; averaging per-batch IoUs would
    weight batches equally regardless of how many pixels of each class they contain.

    Args:
        model: Segmentation model whose output dict contains logits under ``'out'``.
        test_loader: Iterable of batches with ``'pixel_values'`` and ``'labels'``; must support ``len()``.
        criterion: Loss function applied to (logits, labels).
        device: Device to run the model on (e.g. 'cuda' or 'cpu').
        num_classes (int): Number of classes in the segmentation task.

    Returns:
        tuple: (mean per-batch loss, global pixel accuracy,
        per-class IoU as a numpy array with nan for classes absent everywhere,
        mean IoU over the classes that are not nan).

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    if len(test_loader) == 0:
        raise ValueError("test_loader is empty; nothing to evaluate")

    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    correct_pixels = 0
    total_pixels = 0
    intersection = np.zeros(num_classes, dtype=np.int64)
    union = np.zeros(num_classes, dtype=np.int64)

    with torch.no_grad():  # Disable gradient computation
        for batch in test_loader:
            images = batch['pixel_values'].to(device)
            masks = batch['labels'].to(device)

            outputs = model(images)['out']  # Model output (logits)
            total_loss += criterion(outputs, masks).item()

            preds = torch.argmax(outputs, dim=1)  # Shape: [batch_size, height, width]

            correct_pixels += (preds == masks).sum().item()
            total_pixels += masks.numel()

            # Accumulate raw counts; the ratios are formed once at the end.
            for cls in range(num_classes):
                pred_cls = preds == cls
                target_cls = masks == cls
                intersection[cls] += (pred_cls & target_cls).sum().item()
                union[cls] += (pred_cls | target_cls).sum().item()

    test_loss = total_loss / len(test_loader)
    test_accuracy = correct_pixels / total_pixels

    # Classes that never appear in predictions or targets get nan (0/0), and
    # np.nanmean excludes them from the mean IoU.
    with np.errstate(divide='ignore', invalid='ignore'):
        class_wise_iou = np.where(union > 0, intersection / union, np.nan)
    mean_iou = np.nanmean(class_wise_iou)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Class-wise IoU: {class_wise_iou}")
    print(f"Mean IoU: {mean_iou:.4f}")

    return test_loss, test_accuracy, class_wise_iou, mean_iou

# Evaluate the reloaded model on the held-out test loader
# (uses `model`, `test_loader` and `criterion` from earlier cells).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
test_results = evaluate(model, test_loader, criterion, device, num_classes=2)
test_loss, test_accuracy, class_wise_iou, mean_iou = test_results

In [None]:
# Folder that receives the binary prediction masks written further below.
output_dir = "./GEE_Output/DeepLabOutputs/Outputs"

# NOTE: preprocess() leaves ImageNet normalisation commented out, so images can
# be plotted directly without an inverse-normalisation step.
def get_prediction(model, image):
    """Return the per-pixel argmax class map for one image.

    ``image`` is a single, un-batched tensor; a batch dimension is added, the
    model runs on the module-level ``device`` without gradient tracking, and
    the result comes back as a NumPy array on the CPU with the batch axis removed.
    """
    batched = image.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batched)['out']
    class_map = logits.argmax(dim=1).squeeze(0)
    return class_map.cpu().numpy()

# Visualization function
def visualize(image, mask, pred, idx, name,
              save_dir="./GEE_Output/DeepLabOutputs/Visualizations"):
    """Plot input image, ground truth and prediction side by side; save and show the figure.

    Args:
        image (torch.Tensor): Channels-first image tensor; permuted to
            channels-last for ``imshow``. Assumed to be on the CPU (``.numpy()`` is called).
        mask (np.ndarray): Ground-truth label map.
        pred (np.ndarray): Predicted label map.
        idx (int): Sample index, shown in the figure title.
        name (str): File name of the saved figure inside ``save_dir``.
        save_dir (str): Output folder for the figure; created if it does not exist.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Sample {idx}', fontsize=16)

    # Fix the grey-scale range for label maps: with autoscaling, an all-foreground
    # map would render black, exactly like an all-background one.
    label_vmax = max(1, int(np.max(mask)), int(np.max(pred)))

    axes[0].imshow(image.permute(1, 2, 0).numpy())
    axes[0].set_title('Original Image')

    axes[1].imshow(mask, cmap='gray', vmin=0, vmax=label_vmax)
    axes[1].set_title('Ground Truth')

    axes[2].imshow(pred, cmap='gray', vmin=0, vmax=label_vmax)
    axes[2].set_title('Prediction')

    for ax in axes:
        ax.axis('off')

    # savefig does not create missing folders.
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, name))
    plt.show()
    # This is called once per test sample; close the figure so open figures
    # do not accumulate in memory.
    plt.close(fig)

# Predict, save and visualise every sample of the test set.
# cv2.imwrite does not create folders (it just returns False), so create the folder first.
os.makedirs(output_dir, exist_ok=True)
model.eval()

for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    image = sample['pixel_values']
    mask = sample['labels'].numpy()
    pred = get_prediction(model, image)

    if isinstance(pred, torch.Tensor):
        pred = pred.cpu().numpy()

    # Scale class ids {0, 1} to {0, 255} so the saved PNG is viewable.
    # NOTE(review): assumes binary output; class ids above 1 would overflow uint8.
    pred_norm = (pred * 255).astype(np.uint8)

    # test_dataset was built from test_image_files in the same order.
    name = os.path.basename(test_image_files[idx])
    pred_path = os.path.join(output_dir, f"pred_{name}")
    if not cv2.imwrite(pred_path, pred_norm):
        raise OSError(f"cv2.imwrite failed to write {pred_path}")

    visualize(image, mask, pred, idx, name)

In [None]:
# import matplotlib.pyplot as plt
# import torch
# import numpy as np
# from torchvision import transforms

# # Define inverse normalization (adjust if your normalization differs)
# #inv_normalize = transforms.Normalize(
#  #   mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
# #    std=[1/0.229, 1/0.224, 1/0.225]
# #)

# # Set up the plot
# num_samples = 3  # Number of samples to visualize
# fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples))

# for i in range(num_samples):
#     # Access the dictionary and extract tensors
#     sample = train_dataset[i]
#     image = sample["pixel_values"]  # Image tensor (C, H, W)
#     mask = sample["labels"]        # Mask tensor (H, W)

#     # Unnormalize the image and convert to NumPy array
#     #image = inv_normalize(image)
#     image = np.clip(image.permute(1, 2, 0).numpy(), 0, 1)  # (H, W, C)

#     # Convert mask to NumPy array
#     mask = mask.numpy()  # (H, W)

#     # Plot original image
#     axes[i, 0].imshow(image)
#     axes[i, 0].set_title('Original Image')
#     axes[i, 0].axis('off')

#     # Plot ground truth mask
#     axes[i, 1].imshow(mask, cmap='gray')
#     axes[i, 1].set_title('Ground Truth Mask')
#     axes[i, 1].axis('off')

# plt.tight_layout()
# plt.show()