In [22]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [23]:
class TextSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_filenames = sorted(os.listdir(image_dir))  # Ensure order consistency

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.image_filenames[idx])  # Same name for masks

        # Open Image & Mask
        image = Image.open(image_path).convert("RGB")  
        mask = Image.open(mask_path).convert("L")  # Convert mask to grayscale

        # Apply transformations (if any)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


In [24]:
# Define correct dataset path
DATASET_PATH = "../artifacts/output_dataset_dir"  # Change if needed

# Correct paths for train, val, and test sets
train_images = os.path.join(DATASET_PATH, "train/images")
train_masks = os.path.join(DATASET_PATH, "train/masks")
test_images = os.path.join(DATASET_PATH, "test/images")
test_masks = os.path.join(DATASET_PATH, "test/masks")
val_images = os.path.join(DATASET_PATH, "val/images")
val_masks = os.path.join(DATASET_PATH, "val/masks")


# Define transformations (normalize, resize)
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to model input size
    transforms.ToTensor(),          # Convert to tensor
])

# Create Dataset Instances
train_dataset = TextSegmentationDataset(train_images, train_masks, transform=transform)
val_dataset = TextSegmentationDataset(val_images, val_masks, transform=transform)
test_dataset = TextSegmentationDataset(test_images, test_masks, transform=transform)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


Define U-Net Model

In [25]:
import torch
import torch.nn as nn
import torchvision.models as models

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        # Encoder: Use a pre-trained ResNet34 as feature extractor
        resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])  # Remove FC layer

        # Decoder: Upsampling layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=1),  # Output layer
            nn.Upsample(size=(256, 256), mode="bilinear", align_corners=False)  # Ensure output shape
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# ✅ Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)

Define Training Pipeline: 
Loss Function
We'll use Dice Loss, which is great for segmentation tasks.

In [26]:
class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target, smooth=1):
        pred = torch.sigmoid(pred)  # Convert logits to probabilities
        target = target.float()  # Ensure target is float

        # ✅ Ensure shapes match
        if pred.shape != target.shape:
            target = torch.nn.functional.interpolate(target, size=pred.shape[2:], mode="nearest")

        intersection = (pred * target).sum()
        return 1 - ((2. * intersection + smooth) / (pred.sum() + target.sum() + smooth))

criterion = DiceLoss()


 What this does:
✔️ Uses Dice Loss to handle segmentation better than Binary Cross Entropy
✔️ Ensures pred and target have the same shape

Optimizer & Learning Rate Scheduler
We'll use Adam optimizer with a learning rate scheduler to adjust learning over time.

In [27]:
import torch.optim as optim

# Define optimizer and learning rate scheduler
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3, verbose=True)


What this does:
✔️ Uses Adam optimizer (good for deep networks)
✔️ Lowers LR when validation loss stops improving

Data Loader (Augmentations & Batch Processing)

In [30]:
import os
import cv2
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# ✅ Define Dataset Class
class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_list = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        mask_path = os.path.join(self.mask_dir, self.image_list[idx])

        # ✅ Load Image
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

        # ✅ Load Mask (Ensure single channel)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Grayscale

        # ✅ Resize to match model input size (256x256)
        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256))

        # ✅ Convert to Tensor
        image = transforms.ToTensor()(image)
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        return image, mask

# ✅ Set Paths
image_dir = "../artifacts/output_dataset_dir/train/images"
mask_dir = "../artifacts/output_dataset_dir/train/masks"

# ✅ Check Dataset Size
print("Number of images:", len(os.listdir(image_dir)))
print("Number of masks:", len(os.listdir(mask_dir)))

# ✅ Create DataLoader
train_dataset = CustomDataset(image_dir, mask_dir)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


Number of images: 7
Number of masks: 7


Train the Model: Define Training Loop

In [31]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0

    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        # ✅ Forward Pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # ✅ Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(dataloader)


What this does:
✔️ Runs a full training epoch
✔️ Uses forward + backward pass
✔️ Updates weights using optimizer

Train for Multiple Epochs

In [32]:
num_epochs = 10
best_loss = float('inf')

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

    # ✅ Adjust learning rate
    scheduler.step(train_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")

    # ✅ Save best model
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save(model.state_dict(), "best_model.pth")
        print("✅ Model Saved!")


Epoch [1/10], Loss: 0.1689
✅ Model Saved!
Epoch [2/10], Loss: 0.1433
✅ Model Saved!
Epoch [3/10], Loss: -0.0644
✅ Model Saved!
Epoch [4/10], Loss: -0.3976
✅ Model Saved!
Epoch [5/10], Loss: -0.6002
✅ Model Saved!
Epoch [6/10], Loss: -0.6499
✅ Model Saved!
Epoch [7/10], Loss: -0.6744
✅ Model Saved!
Epoch [8/10], Loss: -0.6794
✅ Model Saved!
Epoch [9/10], Loss: -0.6845
✅ Model Saved!
Epoch [10/10], Loss: -0.6890
✅ Model Saved!


Evaluate the Model

In [35]:
def evaluate(model, dataloader, device):
    model.eval()
    total_iou = 0
    num_samples = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            preds = torch.sigmoid(outputs) > 0.5  # Convert to binary mask
            preds = preds.float()  # Convert boolean to float (0s and 1s)
            masks = masks.float()

            # ✅ IoU Calculation (Fix bitwise operation issue)
            intersection = (preds * masks).sum()
            union = (preds + masks).clamp(0, 1).sum()  # Ensures values remain 0 or 1
            iou = intersection / (union + 1e-6)  # Add small value to prevent division by zero

            total_iou += iou.item()
            num_samples += 1

    return total_iou / num_samples  # Average IoU

# ✅ Load the best model correctly
model.load_state_dict(torch.load("best_model.pth", weights_only=True))  # Fix warning

# ✅ Run Evaluation
test_iou = evaluate(model, train_loader, device)
print(f"Test IoU: {test_iou:.4f}")


Test IoU: 3.8702
