In [19]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from sklearn.model_selection import train_test_split
import random

In [20]:
IMAGE_SIZE = 512
NUM_CLASSES = 11  
BATCH_SIZE = 4
EPOCHS = 10
LEARNING_RATE = 0.0001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
DATASET_PATH = "/kaggle/input/helen-face-segmentation-dataset/helenstar_release"
TRAIN_PATH = os.path.join(DATASET_PATH, "train")
TEST_PATH = os.path.join(DATASET_PATH, "test")

In [22]:
CLASS_NAMES = ["bg", "face", "lb", "rb", "le", "re", "nose", "ulip", "imouth", "llip", "hair"]

In [23]:
class HelenDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images_path = []
        self.masks_path = []
        
        # Get all image files
        image_files = glob(os.path.join(data_path, "*_image.jpg"))
        
        for img_path in image_files:
            base_name = os.path.basename(img_path).replace("_image.jpg", "")
            mask_path = os.path.join(data_path, f"{base_name}_label.png")
            
            if os.path.exists(mask_path):
                self.images_path.append(img_path)
                self.masks_path.append(mask_path)
    
    def __len__(self):
        return len(self.images_path)
    
    def __getitem__(self, idx):
        img_path = self.images_path[idx]
        mask_path = self.masks_path[idx]
        
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        image = np.stack((image,)*3, axis=-1)  # Convert to 3 channels
        
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        
        image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
        mask = torch.from_numpy(mask).long()
        
        return image, mask

In [24]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

In [25]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=NUM_CLASSES, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Downsampling (encoder)
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        
        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        
        # Upsampling (decoder)
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature*2, feature))
        
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
    
    def forward(self, x):
        skip_connections = []
        
        
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # Reverse for decoder
        
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d
            skip_connection = skip_connections[idx//2]
            
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
                
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)  # DoubleConv
        
        return self.final_conv(x)

In [26]:
class Augmentation:
    def __init__(self):
        pass
    
    def __call__(self, image, mask):
        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
        mask = cv2.resize(mask, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_NEAREST)
        
        if random.random() > 0.5:
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()
        
        if random.random() > 0.5:
            angle = random.randint(-15, 15)
            M = cv2.getRotationMatrix2D((IMAGE_SIZE/2, IMAGE_SIZE/2), angle, 1)
            image = cv2.warpAffine(image, M, (IMAGE_SIZE, IMAGE_SIZE))
            mask = cv2.warpAffine(mask, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_NEAREST)
        
        return {"image": image, "mask": mask}


In [27]:
def train_fn(model, train_loader, optimizer, criterion, device):
    model.train()
    epoch_loss = 0
    
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        predictions = model(data)
        loss = criterion(predictions, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(train_loader)}: Loss = {loss.item():.4f}")
    
    return epoch_loss / len(train_loader)

In [28]:
def eval_fn(model, val_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for data, targets in val_loader:
            data = data.to(device)
            targets = targets.to(device)
            
            predictions = model(data)
            loss = criterion(predictions, targets)
            
            epoch_loss += loss.item()
    
    return epoch_loss / len(val_loader)

In [29]:
def calculate_metrics(model, data_loader, device, num_classes=NUM_CLASSES):
    model.eval()
    
    confusion_matrix = np.zeros((num_classes, num_classes))
    
    with torch.no_grad():
        for data, targets in data_loader:
            data = data.to(device)
            targets = targets.to(device)
            
            outputs = model(data)
            _, predictions = torch.max(outputs, dim=1)
            
            for i in range(targets.shape[0]):
                for t, p in zip(targets[i].flatten(), predictions[i].flatten()):
                    confusion_matrix[t.item(), p.item()] += 1
    
    ious = []
    f1_scores = []
    
    for i in range(num_classes):
        tp = confusion_matrix[i, i]
        fp = confusion_matrix[:, i].sum() - tp
        fn = confusion_matrix[i, :].sum() - tp
        
        iou = tp / (tp + fp + fn + 1e-10)
        ious.append(iou)
        
        # Calculate F1 score
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        f1 = 2 * precision * recall / (precision + recall + 1e-10)
        f1_scores.append(f1)
    
    return ious, f1_scores

In [30]:
def visualize_predictions(model, data_loader, device, num_samples=3):
    model.eval()
    
    samples = []
    counter = 0
    
    with torch.no_grad():
        for data, targets in data_loader:
            if counter >= num_samples:
                break
            
            data = data.to(device)
            
            outputs = model(data)
            _, predictions = torch.max(outputs, dim=1)
            
            data = data.cpu()
            targets = targets.cpu()
            predictions = predictions.cpu()
            
            for i in range(data.shape[0]):
                if counter >= num_samples:
                    break
                
                samples.append((data[i], targets[i], predictions[i]))
                counter += 1
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))
    
    for i, (image, target, prediction) in enumerate(samples):
        img = image.permute(1, 2, 0).numpy()
        axes[i, 0].imshow(img[:,:,0], cmap='gray')
        axes[i, 0].set_title("Original Image")
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(target, cmap='nipy_spectral', vmin=0, vmax=NUM_CLASSES-1)
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(prediction, cmap='nipy_spectral', vmin=0, vmax=NUM_CLASSES-1)
        axes[i, 2].set_title("Prediction")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.savefig("prediction_samples.png")
    plt.close()

In [31]:
def create_visualization_mask(pred_mask, output_path):
    colors = [
        [0, 0, 0],       # bg - black
        [255, 0, 0],     # face - red
        [0, 255, 0],     # lb - green
        [0, 0, 255],     # rb - blue
        [255, 255, 0],   # le - yellow
        [255, 0, 255],   # re - magenta
        [0, 255, 255],   # nose - cyan
        [128, 0, 0],     # ulip - maroon
        [0, 128, 0],     # imouth - dark green
        [0, 0, 128],     # llip - navy
        [128, 128, 0]    # hair - olive
    ]
    
    h, w = pred_mask.shape
    viz_mask = np.zeros((h, w, 3), dtype=np.uint8)
    
    for i in range(NUM_CLASSES):
        viz_mask[pred_mask == i] = colors[i]
    
    cv2.imwrite(output_path, cv2.cvtColor(viz_mask, cv2.COLOR_RGB2BGR))
    
    return viz_mask

In [32]:
def predict_on_image(model, image_path, device):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    img = np.stack((img,)*3, axis=-1)  # Convert to 3 channels
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
    img_tensor = img_tensor.unsqueeze(0).to(device)  # Add batch dimension
    
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.imshow(img[:,:,0], cmap='gray')
    plt.title("Input Image")
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='nipy_spectral', vmin=0, vmax=NUM_CLASSES-1)
    plt.title("Predicted Segmentation")
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig("new_prediction.png")
    plt.close()
    
    viz_mask = create_visualization_mask(mask, "colorized_prediction.png")
    
    return mask


In [33]:
torch.cuda.empty_cache()

In [34]:
def main():
    train_transform = Augmentation()
    val_transform = Augmentation()  # No augmentation for validation
    
    train_dataset = HelenDataset(TRAIN_PATH, transform=train_transform)
    
    train_indices, val_indices = train_test_split(
        range(len(train_dataset)),
        test_size=0.2,
        random_state=42
    )
    
    train_subset = torch.utils.data.Subset(train_dataset, train_indices)
    val_subset = torch.utils.data.Subset(train_dataset, val_indices)
    
    test_dataset = HelenDataset(TEST_PATH, transform=val_transform)
    
    train_loader = DataLoader(
        train_subset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    print(f"Training samples: {len(train_subset)}")
    print(f"Validation samples: {len(val_subset)}")
    print(f"Test samples: {len(test_dataset)}")
    
    model = UNet().to(DEVICE)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
        verbose=True
    )
    
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    print(f"Starting training on device: {DEVICE}")
    
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")
        
        train_loss = train_fn(model, train_loader, optimizer, criterion, DEVICE)
        
        val_loss = eval_fn(model, val_loader, criterion, DEVICE)
        
        scheduler.step(val_loss)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")
            print("Saved best model!")
    
    model.load_state_dict(torch.load("best_model.pth"))
    
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("training_history.png")
    plt.close()
    
    print("Calculating metrics on test set...")
    ious, f1_scores = calculate_metrics(model, test_loader, DEVICE)
    
    print("\nPer-class metrics:")
    print("------------------")
    
    for i in range(NUM_CLASSES):
        print(f"Class {i} ({CLASS_NAMES[i]}):")
        print(f"  IoU: {ious[i]:.4f}")
        print(f"  F1 Score: {f1_scores[i]:.4f}")
    
    mean_iou = np.mean(ious)
    mean_f1 = np.mean(f1_scores)
    
    print("\nOverall metrics:")
    print("----------------")
    print(f"Mean IoU: {mean_iou:.4f}")
    print(f"Mean F1 Score: {mean_f1:.4f}")
    
    print("Generating prediction visualizations...")
    visualize_predictions(model, test_loader, DEVICE)
    
    print("Training and evaluation complete!")
    print("Trained model saved as: best_model.pth")
    print("Training history plot saved as: training_history.png")
    print("Prediction samples saved as: prediction_samples.png")

if __name__ == "__main__":
    main()

Training samples: 1599
Validation samples: 400
Test samples: 100
Starting training on device: cuda
Epoch 1/10
Batch 0/400: Loss = 2.6170
Batch 10/400: Loss = 2.4566
Batch 20/400: Loss = 2.2720
Batch 30/400: Loss = 2.1051
Batch 40/400: Loss = 2.0392
Batch 50/400: Loss = 1.8862
Batch 60/400: Loss = 1.8740
Batch 70/400: Loss = 1.8732
Batch 80/400: Loss = 1.8154
Batch 90/400: Loss = 1.7794
Batch 100/400: Loss = 1.7641
Batch 110/400: Loss = 1.7637
Batch 120/400: Loss = 1.6827
Batch 130/400: Loss = 1.7167
Batch 140/400: Loss = 1.5849
Batch 150/400: Loss = 1.6691
Batch 160/400: Loss = 1.5847
Batch 170/400: Loss = 1.4836
Batch 180/400: Loss = 1.5329
Batch 190/400: Loss = 1.4410
Batch 200/400: Loss = 1.4458
Batch 210/400: Loss = 1.3357
Batch 220/400: Loss = 1.3806
Batch 230/400: Loss = 1.5001
Batch 240/400: Loss = 1.3326
Batch 250/400: Loss = 1.3940
Batch 260/400: Loss = 1.3125
Batch 270/400: Loss = 1.2669
Batch 280/400: Loss = 1.2646
Batch 290/400: Loss = 1.1549
Batch 300/400: Loss = 1.0789
Ba

  model.load_state_dict(torch.load("best_model.pth"))


Calculating metrics on test set...

Per-class metrics:
------------------
Class 0 (bg):
  IoU: 0.8903
  F1 Score: 0.9420
Class 1 (face):
  IoU: 0.7718
  F1 Score: 0.8712
Class 2 (lb):
  IoU: 0.1979
  F1 Score: 0.3305
Class 3 (rb):
  IoU: 0.2126
  F1 Score: 0.3507
Class 4 (le):
  IoU: 0.2570
  F1 Score: 0.4090
Class 5 (re):
  IoU: 0.2208
  F1 Score: 0.3618
Class 6 (nose):
  IoU: 0.7769
  F1 Score: 0.8744
Class 7 (ulip):
  IoU: 0.4533
  F1 Score: 0.6238
Class 8 (imouth):
  IoU: 0.5382
  F1 Score: 0.6998
Class 9 (llip):
  IoU: 0.4511
  F1 Score: 0.6217
Class 10 (hair):
  IoU: 0.6310
  F1 Score: 0.7737

Overall metrics:
----------------
Mean IoU: 0.4910
Mean F1 Score: 0.6235
Generating prediction visualizations...
Training and evaluation complete!
Trained model saved as: best_model.pth
Training history plot saved as: training_history.png
Prediction samples saved as: prediction_samples.png
