## Step-by-step Training Notebook for SegFormer on Weld Defect Dataset

In [None]:
# Install Required Libraries
!pip install torch torchvision transformers matplotlib pandas

### Import necessary libraries

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor, AdamW
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from PIL import Image
import os
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd

### Set device (use GPU if available)


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

### Configuration


In [None]:
# --- Model ---------------------------------------------------------------
# Pretrained SegFormer model
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
# Number of classes for segmentation
num_classes = 2

# --- Data ----------------------------------------------------------------
# Resize images to 512x512
data_image_size = (512, 512)
# Batch size for training and validation
batch_size = 16

# Paths to dataset directories (Update paths to your dataset)
train_image_dir = "path_to_train_images"
train_mask_dir = "path_to_train_masks"
valid_image_dir = "path_to_valid_images"
valid_mask_dir = "path_to_valid_masks"

# --- Optimisation --------------------------------------------------------
# Total number of training epochs
num_epochs = 20
# Learning rate for the optimizer
learning_rate = 5e-5

# --- Output --------------------------------------------------------------
# Path to save the best model checkpoint
checkpoint_path = "best_model_checkpoint.pth"

### Define a custom dataset class


In [None]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Images are the ``.jpg`` files in ``image_dir`` and masks are the ``.png``
    files in ``mask_dir``. Both listings are sorted by file name and paired by
    position, so the i-th image is matched with the i-th mask.
    """

    def __init__(self, image_dir, mask_dir, feature_extractor, mask_size=None):
        """
        Initializes the dataset.

        Args:
            image_dir (str): Path to the directory containing images.
            mask_dir (str): Path to the directory containing masks.
            feature_extractor: Pretrained feature extractor from Hugging Face.
            mask_size (tuple, optional): Size passed to ``PIL.Image.resize``
                for the masks. When omitted, the module-level
                ``data_image_size`` is used, as before.

        Raises:
            ValueError: If the number of images and masks differ. Pairing is
                positional, so a missing file would otherwise silently shift
                every following image/mask pair (or raise ``IndexError``
                midway through an epoch).
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.feature_extractor = feature_extractor
        self.mask_size = mask_size
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
        self.masks = sorted(f for f in os.listdir(mask_dir) if f.endswith(".png"))

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in '{image_dir}' but "
                f"{len(self.masks)} masks in '{mask_dir}'; every image needs "
                "exactly one mask."
            )

    def __len__(self):
        """Returns the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Returns a single sample (image and mask).

        Returns:
            dict: ``"pixel_values"`` holds the feature extractor output for
            the image with the leading batch dimension squeezed away;
            ``"labels"`` holds the resized mask as a ``torch.long`` tensor.

        Raises:
            ValueError: If the mask does not decode to a 2-D array.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Load and process the image. The context manager closes the file
        # handle; convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as image_file:
            image = image_file.convert("RGB")
        processed_image = self.feature_extractor(image, return_tensors="pt")['pixel_values'].squeeze(0)

        # Load and process the mask. Nearest-neighbour resampling keeps the
        # original label values instead of blending neighbouring classes.
        target_size = data_image_size if self.mask_size is None else self.mask_size
        with Image.open(mask_path) as mask_file:
            mask = mask_file.resize(target_size, Image.NEAREST)
        mask_array = np.array(mask)
        if mask_array.ndim != 2:
            raise ValueError(
                f"Mask '{mask_path}' decodes to shape {mask_array.shape}; "
                "expected a single-channel image of class indices."
            )
        mask_tensor = torch.tensor(mask_array, dtype=torch.long)

        return {"pixel_values": processed_image, "labels": mask_tensor}

### Initialize feature extractor


In [None]:
# Preprocessing pipeline that ships with the pretrained checkpoint.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    model_name,
    reduce_labels=False,
)

# Load datasets
train_dataset = SegmentationDataset(
    image_dir=train_image_dir,
    mask_dir=train_mask_dir,
    feature_extractor=feature_extractor,
)
valid_dataset = SegmentationDataset(
    image_dir=valid_image_dir,
    mask_dir=valid_mask_dir,
    feature_extractor=feature_extractor,
)

# Create data loaders: training batches are shuffled, validation batches
# keep the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

### Load SegFormer model


In [None]:
# Load the pretrained weights with a segmentation head sized for
# `num_classes`. `ignore_mismatched_sizes=True` lets loading proceed when the
# checkpoint's head has a different number of labels (presumably the case for
# this ADE fine-tuned checkpoint); the mismatched weights are re-initialised.
model = SegformerForSemanticSegmentation.from_pretrained(
    model_name, ignore_mismatched_sizes=True, num_labels=num_classes
)
model.to(device)

# Define optimizer
# Use PyTorch's AdamW: the `transformers.AdamW` class is deprecated and has
# been removed from recent transformers releases. `eps` and `weight_decay`
# are set explicitly to the defaults of the transformers implementation so
# the optimisation behaviour is unchanged (torch defaults are 1e-8 and 1e-2).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0
)

### Training and validation loop


In [None]:
# Per-epoch history: one entry is appended to each list per training epoch.
train_losses, val_mious, val_accuracies = [], [], []

# Highest validation mIoU seen so far; a checkpoint is written whenever an
# epoch beats it.
best_miou = 0

# Define metric computation function
def compute_metrics(preds, labels, num_classes):
    """Compute mean IoU and pixel accuracy for a batch of predictions.

    Args:
        preds: Tensor of predicted class indices (any shape).
        labels: Tensor of ground-truth class indices, same shape as ``preds``.
        num_classes (int): Number of classes; classes ``0 .. num_classes - 1``
            are evaluated.

    Returns:
        tuple[float, float]: ``(miou, pixel_acc)``. The mean IoU is taken only
        over classes that appear in ``preds`` or ``labels``: a class that is
        absent from both has an undefined IoU and would otherwise be counted
        as 0, dragging the mean down (e.g. a perfectly predicted
        all-background batch would score 0.5 with two classes). Both values
        are 0.0 for empty input.
    """
    preds = preds.flatten()
    labels = labels.flatten()

    # Nothing to score; avoid the NaN that mean() of an empty tensor gives.
    if labels.numel() == 0:
        return 0.0, 0.0

    ious = []
    for cls in range(num_classes):
        pred_is_cls = preds == cls
        label_is_cls = labels == cls
        union = (pred_is_cls | label_is_cls).sum().item()
        if union == 0:
            # Class absent from both prediction and ground truth: skip it.
            continue
        intersection = (pred_is_cls & label_is_cls).sum().item()
        ious.append(intersection / union)

    miou = sum(ious) / len(ious) if ious else 0.0
    pixel_acc = (preds == labels).float().mean().item()
    return miou, pixel_acc

# Training loop
# Each epoch: one pass over the training data, then a validation pass whose
# mIoU decides whether the current weights are saved as the best checkpoint.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    # Training step
    for batch in train_loader:
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(pixel_values=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    avg_loss = total_loss / max(len(train_loader), 1)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_loss:.4f}")

    # Validation step
    model.eval()
    weighted_miou = 0.0
    weighted_accuracy = 0.0
    num_val_samples = 0

    with torch.no_grad():
        for batch in valid_loader:
            inputs = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=inputs)
            logits = outputs.logits

            # Resize logits to match labels
            logits_upsampled = F.interpolate(logits, size=labels.shape[1:], mode="bilinear", align_corners=False)
            preds = torch.argmax(logits_upsampled, dim=1)

            # Compute metrics
            miou, accuracy = compute_metrics(preds.cpu(), labels.cpu(), num_classes)

            # Weight each batch by its number of samples so that a smaller
            # final batch does not count as much as a full one. All masks
            # share one size, so the weighted accuracy is the pixel accuracy
            # over the whole validation set.
            batch_samples = labels.shape[0]
            weighted_miou += miou * batch_samples
            weighted_accuracy += accuracy * batch_samples
            num_val_samples += batch_samples

    avg_miou = weighted_miou / max(num_val_samples, 1)
    avg_accuracy = weighted_accuracy / max(num_val_samples, 1)
    val_mious.append(avg_miou)
    val_accuracies.append(avg_accuracy)

    print(f"Validation mIoU: {avg_miou:.4f}, Pixel Accuracy: {avg_accuracy:.4f}")

    # Save best model
    if avg_miou > best_miou:
        best_miou = avg_miou
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Best model saved with mIoU: {best_miou:.4f}")

### Plot metrics


In [None]:
# Build the table from the epochs that were actually recorded rather than
# from `num_epochs`: if training was interrupted, or `num_epochs` was edited
# after the run, the columns would otherwise have different lengths and
# pd.DataFrame would raise.
num_recorded_epochs = min(len(train_losses), len(val_mious), len(val_accuracies))
metrics_df = pd.DataFrame({
    "Epoch": list(range(1, num_recorded_epochs + 1)),
    "Training Loss": train_losses[:num_recorded_epochs],
    "Validation mIoU": val_mious[:num_recorded_epochs],
    "Validation Accuracy": val_accuracies[:num_recorded_epochs],
})
metrics_df.to_csv("training_metrics.csv", index=False)

# All three curves share one axis; the figure is saved before being shown.
plt.figure(figsize=(12, 6))
plt.plot(metrics_df["Epoch"], metrics_df["Training Loss"], label="Loss")
plt.plot(metrics_df["Epoch"], metrics_df["Validation mIoU"], label="mIoU")
plt.plot(metrics_df["Epoch"], metrics_df["Validation Accuracy"], label="Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Metric Value")
plt.title("Training Metrics")
plt.legend()
plt.grid()
plt.savefig("training_metrics_plot.png")
plt.show()

### Testing and visualization


In [None]:
def visualize_predictions(model, feature_extractor, images, masks, device):
    """Plot each image beside its ground-truth mask and the predicted mask.

    The model is switched to evaluation mode. For every (image, mask) path
    pair, one figure with three panels is shown: the RGB image, the mask as
    stored on disk, and the argmax of the model's logits upsampled to the
    mask's resolution.

    Args:
        model: Segmentation model; called with ``pixel_values`` and expected
            to return an object with a ``logits`` attribute.
        feature_extractor: Callable that turns a PIL image into a dict with a
            ``"pixel_values"`` tensor.
        images: Iterable of image file paths.
        masks: Iterable of mask file paths, paired with ``images`` by position.
        device: Device the input tensor is moved to before inference.
    """
    model.eval()

    for image_file, mask_file in zip(images, masks):
        image = Image.open(image_file).convert("RGB")
        original_mask = Image.open(mask_file)

        encoded = feature_extractor(image, return_tensors="pt")
        input_image = encoded["pixel_values"].to(device)

        with torch.no_grad():
            logits = model(pixel_values=input_image).logits

        # PIL reports size as (width, height); reversing it gives the
        # (height, width) order that F.interpolate expects.
        target_hw = original_mask.size[::-1]
        logits_upsampled = F.interpolate(
            logits, size=target_hw, mode="bilinear", align_corners=False
        )
        predicted_mask = logits_upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        plt.figure(figsize=(10, 4))
        panels = (
            ("Original Image", image, {}),
            ("Original Mask", np.array(original_mask), {"cmap": "gray"}),
            ("Predicted Mask", predicted_mask, {"cmap": "gray"}),
        )
        for position, (title, data, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.title(title)
            plt.imshow(data, **imshow_kwargs)

        plt.show()

# Test visualization: replace the placeholder paths with real files. Images
# and masks are paired by position.
sample_images = [
    "path_to_test_image_1",
    "path_to_test_image_2",
]
sample_masks = [
    "path_to_test_mask_1",
    "path_to_test_mask_2",
]
visualize_predictions(
    model=model,
    feature_extractor=feature_extractor,
    images=sample_images,
    masks=sample_masks,
    device=device,
)