In [1]:
import os
from glob import glob
import shutil
from pathlib import Path, PurePath
import json
import time

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

from PIL import Image
import cv2

# Import the new library
import segmentation_models_pytorch as smp

from torchvision import transforms as T
from tqdm import tqdm
import albumentations as A
from sklearn.model_selection import train_test_split


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%matplotlib inline
# Seed every RNG this notebook draws from: torch (weight init, DataLoader
# shuffling) and NumPy (random tile retention in porosity_Dataset.tiles).
torch.manual_seed(42)
np.random.seed(42)

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
# `torch.backends.mps` does not exist on older torch builds, hence the getattr.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: mps


In [3]:
# Dataset locations, relative to the notebook's working directory.
IMAGES = './train/images'  # directory holding the input images
MASKS = './train/masks'  # directory holding the segmentation masks
classes_csv = './train/_classes.csv'  # not read by any cell visible in this notebook


In [4]:
class porosity_Dataset(Dataset):
    """Image/mask segmentation dataset with optional tiling into square patches.

    Sample ``idx`` is read from ``<image_path>/<x[idx]><image_ext>`` and
    ``<mask_path>/<x[idx]>_mask<mask_ext>``.

    Args:
        image_path: directory containing the images.
        mask_path: directory containing the masks.
        x: sequence of sample names (file stems without extension).
        mean, std: per-channel statistics passed to ``T.Normalize``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        patch: when True, ``__getitem__`` returns stacks of tiles (see ``tiles``)
            instead of the full image and mask.
        image_ext, mask_ext: file extensions appended to the sample name.
    """

    def __init__(self, image_path, mask_path, x, mean, std, transform=None, patch=False,
                 image_ext=".jpg", mask_ext=".png"):
        self.img_path = image_path
        self.mask_path = mask_path
        self.x = x
        self.mean = mean
        self.std = std
        self.transform = transform
        self.patch = patch
        self.image_ext = image_ext
        self.mask_ext = mask_ext

    def __len__(self):
        """Number of samples (one per entry of ``x``)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Load, augment and normalise sample ``idx``.

        Returns:
            ``(img, mask)``. Without patching: a normalised float image tensor
            and a ``long`` mask tensor. With ``patch=True``: the tile stacks
            returned by ``tiles``.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        image_filename = self.x[idx] + self.image_ext
        mask_filename = self.x[idx] + "_mask" + self.mask_ext
        img_path = os.path.join(self.img_path, image_filename)
        mask_path = os.path.join(self.mask_path, mask_filename)

        # cv2.imread reports failure by returning None instead of raising, which
        # would otherwise surface later as an obscure error far from the cause.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # Nearest-neighbour keeps mask values as valid class ids when resizing.
        if img.shape[:2] != mask.shape[:2]:
            mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Normalization and tensor conversion
        t = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = t(img)
        mask = torch.from_numpy(mask).long()

        if self.patch:
            img, mask = self.tiles(img, mask)

        return img, mask

    def tiles(self, img, mask, size=256, threshold=0.01, min_classes=2):
        """Cut an image/mask pair into non-overlapping ``size`` x ``size`` tiles.

        Rows/columns that do not fill a whole tile are dropped. A tile is kept
        when its mask holds at least ``min_classes`` distinct values, or more
        than ``threshold`` of its pixels are non-zero, or with probability 0.1
        otherwise. If nothing qualifies, one random tile is returned.

        Args:
            img: tensor of shape ``(C, H, W)``.
            mask: tensor of shape ``(H, W)``.
            size: tile edge length in pixels.
            threshold: minimum non-zero pixel fraction for a tile to be kept.
            min_classes: minimum number of distinct mask values for a tile to be kept.

        Returns:
            ``(img_tiles, mask_tiles)`` of shapes ``(K, C, size, size)`` and
            ``(K, size, size)``.

        Raises:
            ValueError: if the image is smaller than one tile.
        """
        channels, height, width = img.shape
        if height < size or width < size:
            raise ValueError(
                f"Image of size {height}x{width} is smaller than the tile size {size}"
            )

        # (C, nH, nW, size, size) -> (nH * nW, C, size, size); the channel count is
        # read from the tensor rather than assumed to be 3.
        img_patches = img.unfold(1, size, size).unfold(2, size, size)
        img_patches = img_patches.contiguous().view(channels, -1, size, size).permute(1, 0, 2, 3)

        mask_patches = mask.unfold(0, size, size).unfold(1, size, size)
        mask_patches = mask_patches.contiguous().view(-1, size, size)

        keep_indices = []
        for i, patch in enumerate(mask_patches):
            unique_classes = torch.unique(patch)
            fg = (patch != 0).sum().item()
            ratio = fg / (size * size)
            # The random draw only happens for tiles that fail both content tests.
            if len(unique_classes) >= min_classes or ratio > threshold or np.random.rand() < 0.1:
                keep_indices.append(i)

        if not keep_indices:
            keep_indices.append(np.random.randint(0, len(img_patches)))

        return img_patches[keep_indices], mask_patches[keep_indices]


In [5]:
def custom_collate(batch):
    """Collate samples into two parallel lists without stacking them.

    Each sample contributes its element 0 to the first list and its element 1
    to the second, so entries of differing shapes are passed through untouched.

    Returns:
        ``(images, masks)``: two lists with one entry per sample.
    """
    images, masks = [], []
    for sample in batch:
        images.append(sample[0])
        masks.append(sample[1])
    return images, masks


In [6]:
class FocalLoss(nn.Module):
    """Multi-class focal loss evaluated per element (per pixel).

    For every non-ignored target element the loss is
    ``-w[y] * (1 - p_y) ** gamma * log(p_y)`` where ``p_y`` is the softmax
    probability of the true class. The result is the weighted mean over the
    non-ignored elements (sum divided by the sum of the applied weights), so
    ``gamma=0`` reproduces ``nn.CrossEntropyLoss(weight, ignore_index)``.

    Applying the focal factor to the already averaged cross-entropy would
    rescale the whole batch by one number and would not down-weight easy
    elements, which is why everything here is computed element-wise.

    Args:
        weight: optional 1-D tensor of per-class weights.
        gamma: focusing exponent.
        ignore_index: target value excluded from the loss.
    """

    def __init__(self, weight=None, gamma=2, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        # Kept as a sub-module: it stores `weight` as a buffer (so it follows
        # `.to(device)`) and holds `ignore_index`; forward() reads both from it.
        self.ce = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

    def forward(self, input, target):
        """Return the scalar focal loss.

        Args:
            input: logits with the class dimension at dim 1.
            target: integer class ids shaped like ``input`` without dim 1.

        Returns:
            0-dim tensor; exactly zero (still attached to the graph) when every
            target element is ignored.
        """
        class_weight = self.ce.weight
        valid = target != self.ce.ignore_index

        # ignore_index may lie outside [0, C) (e.g. -100); substitute a legal
        # index for gather and cancel those positions through their weight.
        safe_target = torch.where(valid, target, torch.zeros_like(target))

        log_probs = F.log_softmax(input, dim=1)
        logpt = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()
        focal = -((1.0 - pt) ** self.gamma) * logpt

        element_weight = valid.to(focal.dtype)
        if class_weight is not None:
            element_weight = element_weight * class_weight.to(focal.dtype)[safe_target]

        # The clamp turns the "nothing to score" case into 0 / tiny == 0, not NaN.
        denom = element_weight.sum().clamp_min(torch.finfo(focal.dtype).tiny)
        return (focal * element_weight).sum() / denom

class DiceLoss(nn.Module):
    """Soft Dice loss over all classes and elements jointly.

    The Dice score is computed once over the flattened softmax probabilities
    and the flattened one-hot targets (not per class), and the loss is
    ``1 - dice``.

    Args:
        n_classes: number of classes; must equal the size of dim 1 of the logits.
        ignore_index: optional target value whose elements are excluded.
    """

    def __init__(self, n_classes, ignore_index=None):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes
        self.ignore_index = ignore_index

    def forward(self, input, target, smooth=1e-6):
        """Return ``1 - dice`` as a 0-dim tensor.

        Args:
            input: logits with the class dimension at dim 1.
            target: integer class ids shaped like ``input`` without dim 1.
            smooth: additive constant in numerator and denominator; it also
                makes the loss 0 when every element is ignored.
        """
        # Move classes last before flattening so that row i of `probs` and
        # element i of `labels` describe the same pixel. Flattening the
        # channel-first tensor directly would misalign it with the one-hot
        # target, whose class axis is last.
        probs = F.softmax(input, dim=1).movedim(1, -1).reshape(-1, self.n_classes)
        labels = target.reshape(-1)

        if self.ignore_index is not None:
            keep = labels != self.ignore_index
            probs = probs[keep]
            labels = labels[keep]

        target_one_hot = F.one_hot(labels, num_classes=self.n_classes).float()

        intersection = (probs * target_one_hot).sum()
        dice_score = (2. * intersection + smooth) / (probs.sum() + target_one_hot.sum() + smooth)
        return 1 - dice_score

class DiceFocalLoss(nn.Module):
    """Convex combination of a focal term and a Dice term.

    ``loss = alpha * focal + (1 - alpha) * dice``, where both terms are built
    with the same ``ignore_index``.

    Args:
        n_classes: forwarded to ``DiceLoss``.
        weight: per-class weights forwarded to ``FocalLoss``.
        gamma: focusing exponent forwarded to ``FocalLoss``.
        alpha: mixing factor applied to the focal term.
        ignore_index: forwarded to both terms.
    """

    def __init__(self, n_classes, weight=None, gamma=2, alpha=0.5, ignore_index=-100):
        super().__init__()
        self.focal_loss = FocalLoss(
            weight=weight,
            gamma=gamma,
            ignore_index=ignore_index,
        )
        self.dice_loss = DiceLoss(
            n_classes=n_classes,
            ignore_index=ignore_index,
        )
        self.alpha = alpha
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Evaluate both terms on ``(input, target)`` and blend them."""
        focal_term = self.focal_loss(input, target)
        dice_term = self.dice_loss(input, target)
        mix = self.alpha
        return mix * focal_term + (1 - mix) * dice_term

In [7]:
def pixel_accuracy(output, mask):
    """Fraction of elements whose arg-max class equals the target.

    Args:
        output: logits with the class dimension at dim 1.
        mask: integer class ids shaped like ``output`` without dim 1.

    Returns:
        Python float in [0, 1].
    """
    with torch.no_grad():
        class_probs = F.softmax(output, dim=1)
        predicted = class_probs.argmax(dim=1)
        hits = (predicted == mask).int()
        return float(hits.sum()) / float(hits.numel())

In [8]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=12):
    """Mean intersection-over-union across the classes present in ``mask``.

    Classes with no element in ``mask`` contribute NaN and are skipped by the
    final ``nanmean``.

    Args:
        pred_mask: logits with the class dimension at dim 1.
        mask: integer class ids shaped like ``pred_mask`` without dim 1.
        smooth: constant added to both intersection and union.
        n_classes: class ids ``0 .. n_classes - 1`` are evaluated.

    Returns:
        The mean IoU as a NumPy float.
    """
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        labels = mask.contiguous().view(-1)

        def _class_iou(class_id):
            # IoU for a single class, or NaN when the class is absent from the labels.
            in_pred = predicted == class_id
            in_label = labels == class_id
            if in_label.long().sum().item() == 0:
                return np.nan
            overlap = torch.logical_and(in_pred, in_label).sum().float().item()
            combined = torch.logical_or(in_pred, in_label).sum().float().item()
            return (overlap + smooth) / (combined + smooth)

        return np.nanmean([_class_iou(class_id) for class_id in range(n_classes)])

In [9]:
def compute_per_class_iou(preds, labels, num_classes=12):
    """Per-class intersection-over-union between two label tensors.

    Args:
        preds: tensor of predicted class ids (any shape).
        labels: tensor of target class ids with the same number of elements.
        num_classes: class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        List of ``num_classes`` floats; NaN for a class that occurs in neither
        ``preds`` nor ``labels``.
    """
    # reshape (not view) so non-contiguous inputs such as transposed or sliced
    # tensors are accepted as well.
    preds = preds.reshape(-1)
    labels = labels.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = preds == cls
        label_inds = labels == cls
        intersection = (pred_inds & label_inds).sum().item()
        union = (pred_inds | label_inds).sum().item()
        ious.append(float('nan') if union == 0 else intersection / union)
    return ious

In [10]:
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

In [11]:
def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=False, num_classes=12,
        patience=10, checkpoint_path='model_transfer_learning.pt'):
    """Train ``model`` with per-batch LR scheduling, validation, checkpointing and early stopping.

    Every loader batch must be a pair ``(image_tiles, mask_tiles)`` of sequences
    of tensors; each sequence is concatenated along dim 0 to form the actual
    mini-batch, which is then moved to the module-level ``device``.

    Args:
        epochs: maximum number of epochs to run.
        model: network to train; moved to ``device``.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss.
        optimizer: optimizer stepped once per training batch.
        scheduler: LR scheduler, also stepped once per training batch.
        patch: accepted for backward compatibility; not used by this function.
        num_classes: number of classes used for every IoU metric.
        patience: number of consecutive epochs without a lower validation loss
            after which training stops.
        checkpoint_path: file that receives ``model.state_dict()`` whenever the
            validation loss improves.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_miou``,
        ``val_miou``, ``train_acc``, ``val_acc`` and the per-batch list ``lrs``.

    Raises:
        ValueError: if a loader yields no batches during an epoch.
    """
    train_losses, val_losses = [], []
    train_iou, val_iou = [], []
    train_acc, val_acc = [], []
    lrs = []
    min_loss = np.inf
    no_improve = 0

    model.to(device)
    start_time = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()

        # ---- training pass ----
        model.train()
        running_loss, running_iou, running_acc, total_batches = 0, 0, 0, 0

        for image_tiles, mask_tiles in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            images = torch.cat(image_tiles, dim=0).to(device)
            masks = torch.cat(mask_tiles, dim=0).to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            # Forward num_classes so the metric agrees with the caller's class count.
            running_iou += mIoU(outputs, masks, n_classes=num_classes)
            running_acc += pixel_accuracy(outputs, masks)
            total_batches += 1
            lrs.append(get_lr(optimizer))

        if total_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # ---- validation pass ----
        model.eval()
        val_loss, val_iou_score, val_accuracy, val_batches = 0, 0, 0, 0
        all_class_ious = []
        with torch.no_grad():
            for image_tiles, mask_tiles in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                images = torch.cat(image_tiles, dim=0).to(device)
                masks = torch.cat(mask_tiles, dim=0).to(device)

                outputs = model(images)
                val_loss += criterion(outputs, masks).item()
                val_iou_score += mIoU(outputs, masks, n_classes=num_classes)
                val_accuracy += pixel_accuracy(outputs, masks)

                preds = torch.argmax(outputs, dim=1)
                for pred, true_mask in zip(preds, masks):
                    all_class_ious.append(compute_per_class_iou(pred, true_mask, num_classes))
                val_batches += 1

        if val_batches == 0:
            raise ValueError("val_loader yielded no batches")

        avg_train_loss = running_loss / total_batches
        avg_val_loss = val_loss / val_batches
        avg_train_iou = running_iou / total_batches
        avg_val_iou = val_iou_score / val_batches
        avg_train_acc = running_acc / total_batches
        avg_val_acc = val_accuracy / val_batches

        train_losses.append(avg_train_loss); val_losses.append(avg_val_loss)
        train_iou.append(avg_train_iou); val_iou.append(avg_val_iou)
        train_acc.append(avg_train_acc); val_acc.append(avg_val_acc)

        # Average each class over the tiles where it is defined. Classes that are
        # NaN on every tile stay NaN, without nanmean's "empty slice" warning.
        class_iou_matrix = np.array(all_class_ious, dtype=float)
        has_values = ~np.all(np.isnan(class_iou_matrix), axis=0)
        mean_class_ious = np.full(class_iou_matrix.shape[1], np.nan)
        mean_class_ious[has_values] = np.nanmean(class_iou_matrix[:, has_values], axis=0)

        print(f"Epoch [{epoch + 1}/{epochs}] - "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
              f"Train mIoU: {avg_train_iou:.4f}, Val mIoU: {avg_val_iou:.4f}, "
              f"Train Acc: {avg_train_acc:.4f}, Val Acc: {avg_val_acc:.4f}, "
              f"LR: {lrs[-1]:.6f}, "
              f"Time: {(time.time() - epoch_start)/60:.2f} mins")

        print("\n📊 Per-class IoU (Validation):")
        for i, iou in enumerate(mean_class_ious):
            print(f"  Class {i}: IoU = {iou:.4f}")

        # ---- checkpointing / early stopping on validation loss ----
        if avg_val_loss < min_loss:
            print(f"Validation loss decreased ({min_loss:.4f} -> {avg_val_loss:.4f}). Saving model.")
            min_loss = avg_val_loss
            no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            no_improve += 1
            print(f"No improvement in validation loss for {no_improve} epochs.")
            if no_improve >= patience:
                print("Early stopping triggered.")
                break

    total_time = (time.time() - start_time) / 60
    print(f"\nTraining completed in {total_time:.2f} minutes.")
    history = {'train_loss': train_losses, 'val_loss': val_losses, 'train_miou': train_iou, 'val_miou': val_iou, 'train_acc': train_acc, 'val_acc': val_acc, 'lrs': lrs}
    return history

In [12]:
# Per-channel normalisation statistics handed to porosity_Dataset, which applies
# them through T.Normalize after T.ToTensor.
# NOTE(review): these look like the usual ImageNet RGB statistics, which would
# match the ImageNet-pretrained encoder used below — confirm.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


In [13]:
# Albumentations pipelines; porosity_Dataset calls them as transform(image=..., mask=...).
# Training: flips, grid distortion, Gaussian noise and brightness/contrast jitter.
t_train = A.Compose([A.VerticalFlip(), A.HorizontalFlip(), A.GridDistortion(p=0.2),
                     A.GaussNoise(), A.RandomBrightnessContrast((0, 0.5), (0, 0.5))])
# NOTE(review): the validation pipeline also applies *random* transforms (flip and
# grid distortion), so validation/test inputs differ from run to run and the
# validation loss that drives checkpointing and early stopping is noisy —
# confirm this is intended.
t_val = A.Compose([A.HorizontalFlip(), A.GridDistortion(p=0.2)])


In [14]:
# Collect sample names (file stems) from the image directory. Only files with the
# extension porosity_Dataset loads by default (".jpg") are kept, so stray entries
# such as ".DS_Store" or sidecar files cannot become samples that later fail to load.
x_all = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(IMAGES)
    if os.path.splitext(file_name)[1].lower() == ".jpg"
)

# 90/10 split into (train + val) / test, then 90/10 again into train / val.
# The fixed random_state keeps the split reproducible.
x_temp, x_test = train_test_split(x_all, test_size=0.10, random_state=42)
x_train, x_val = train_test_split(x_temp, test_size=0.10, random_state=42)


In [15]:
# One dataset per split. All three use patch=True, so each item is a stack of
# tiles rather than a full image; validation and test share the t_val pipeline.
train_dataset = porosity_Dataset(IMAGES, MASKS, x_train, mean, std, t_train, patch=True)
val_dataset = porosity_Dataset(IMAGES, MASKS, x_val, mean, std, t_val, patch=True)
test_dataset = porosity_Dataset(IMAGES, MASKS, x_test, mean, std, t_val, patch=True)


In [16]:
# batch_size counts source images, not tiles: custom_collate returns the per-image
# tile stacks as lists, and fit() concatenates them, so the number of tiles per
# optimisation step varies from batch to batch.
batch_size = 3
# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

In [17]:
def get_class_distribution(dataset, num_classes=12):
    """Count how many mask elements of each class a dataset yields.

    The dataset is iterated once through a ``DataLoader`` with ``batch_size=1``;
    the second element of every item is taken as the mask. Mask values outside
    ``[0, num_classes)`` are not counted. Masks are assumed to hold integer
    class ids.

    If the dataset applies random augmentation or random tile selection, the
    counts describe one random pass and vary between calls.

    Args:
        dataset: dataset whose items are ``(image, mask)`` pairs.
        num_classes: number of classes to count.

    Returns:
        Integer NumPy array of length ``num_classes``.
    """
    class_counts = np.zeros(num_classes, dtype=int)
    for _, mask in DataLoader(dataset, batch_size=1):
        labels = mask.reshape(-1)
        # bincount rejects negative values and would grow past num_classes for
        # larger ones, so restrict to the counted range first.
        in_range = (labels >= 0) & (labels < num_classes)
        # One histogram pass replaces num_classes separate comparisons and syncs.
        counts = torch.bincount(labels[in_range].long(), minlength=num_classes)
        class_counts += counts.cpu().numpy()
    return class_counts

# Element counts per class for the train and validation datasets. Both datasets
# tile their images and apply random transforms, so these numbers describe the
# tiles produced by this particular pass.
train_class_counts = get_class_distribution(train_dataset)
val_class_counts = get_class_distribution(val_dataset)

print("Train class distribution:", train_class_counts)
print("Val class distribution:", val_class_counts)

Train class distribution: [149394631    494843    479811  30599976     72652    165127         0
         0         0         0         0         0]
Val class distribution: [16748549     4748     7403  3139462     8449    14333        0        0
        0        0        0        0]


In [18]:
# U-Net from segmentation_models_pytorch with a ResNet-34 encoder.
model = smp.Unet(
    encoder_name="resnet34",        # A powerful and well-tested encoder
    encoder_weights="imagenet",     # Use weights pre-trained on ImageNet
    in_channels=3,                  # Number of input channels (3 for RGB)
    classes=12,                     # Number of output classes
)
# As the last expression of the cell, this also makes the notebook echo the
# full module repr.
model.to(device)

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [19]:
def get_class_weights_from_masks(
    masks_dir,
    num_classes=12,
    background_class_id=0,
    background_weight_multiplier=0.1
):
    """Derive inverse-frequency class weights from the ``.png`` masks in a directory.

    Every ``.png`` file in ``masks_dir`` is read as a grayscale image and its
    pixel values are counted as class ids. Each class that occurs gets the
    weight ``1 / frequency``; the background weight is additionally scaled by
    ``background_weight_multiplier``. Classes that never occur get weight 0,
    so they cannot dominate the normalisation. The weights are finally rescaled
    so that their mean over the occurring classes is 1.

    Args:
        masks_dir: directory containing the mask images.
        num_classes: number of classes; pixel values ``>= num_classes`` are
            not counted for any class.
        background_class_id: class id treated as background.
        background_weight_multiplier: factor applied to the background weight.

    Returns:
        ``torch.float32`` tensor of shape ``(num_classes,)``.

    Raises:
        ValueError: if no readable, non-empty mask is found or none of the
            requested classes occurs.
    """
    class_counts = np.zeros(num_classes, dtype=np.int64)
    total_pixels = 0

    for mask_name in os.listdir(masks_dir):
        if not mask_name.endswith(".png"):
            continue
        mask = cv2.imread(os.path.join(masks_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        if mask is None:  # unreadable file: skipped
            continue
        total_pixels += mask.size
        # A single histogram pass instead of one full comparison per class.
        class_counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]

    present = class_counts > 0
    if total_pixels == 0 or not present.any():
        raise ValueError(f"No usable mask pixels found in {masks_dir!r}")

    class_frequencies = class_counts / total_pixels
    class_weights = np.zeros(num_classes, dtype=np.float64)
    class_weights[present] = 1.0 / class_frequencies[present]
    class_weights[background_class_id] *= background_weight_multiplier

    weight_sum = class_weights.sum()
    if weight_sum > 0:  # can only be 0 if background is the sole class and its multiplier is 0
        class_weights = class_weights * (present.sum() / weight_sum)
    return torch.tensor(class_weights, dtype=torch.float32)

In [20]:
# Class weights from every .png mask in MASKS (all splits together, not just train).
class_weights = get_class_weights_from_masks(MASKS, num_classes=12)

In [21]:
# Move the weights to the training device so the loss can combine them with the logits.
class_weights = class_weights.to(device)


In [22]:
# Optimisation hyper-parameters.
max_lr = 1e-3  # AdamW learning rate and peak learning rate of the OneCycleLR schedule
epoch = 30  # total number of epochs (passed to both OneCycleLR and fit)
weight_decay = 1e-4  # AdamW weight decay


In [23]:
# NOTE(review): ignore_index=0 makes both the focal and the Dice term skip every
# pixel labelled 0 — the id get_class_weights_from_masks treats as background by
# default — so the model is never penalised on those pixels, while mIoU and pixel
# accuracy in fit() still score class 0. Confirm that this is intended.
criterion = DiceFocalLoss(n_classes=12, weight=class_weights, gamma=2, alpha=0.5, ignore_index=0)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# The schedule spans epoch * len(train_loader) steps, matching fit(), which steps
# the scheduler once per training batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epoch, steps_per_epoch=len(train_loader))


In [24]:
# Run training; `history` is consumed by the plotting cell below.
# fit() accepts `patch` but never reads it in its body.
history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=True)


Epoch 1 Training:   0%|          | 0/236 [00:00<?, ?it/s]

Epoch 1 Training:   6%|▌         | 14/236 [02:00<31:54,  8.62s/it]


KeyboardInterrupt: 

In [None]:
def _plot_history_curves(history, curves, title, ylabel):
    """Draw one line per ``(history_key, legend_label, marker)`` entry and show the figure."""
    for key, label, marker in curves:
        plt.plot(history[key], label=label, marker=marker)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()


def plot_loss(history):
    """Plot validation and training loss per epoch from a ``fit`` history dict."""
    _plot_history_curves(
        history,
        [('val_loss', 'val', 'o'), ('train_loss', 'train', 'o')],
        title='Loss per epoch',
        ylabel='loss',
    )


def plot_score(history):
    """Plot training and validation mean IoU per epoch from a ``fit`` history dict."""
    _plot_history_curves(
        history,
        [('train_miou', 'train_mIoU', '*'), ('val_miou', 'val_mIoU', '*')],
        title='Score per epoch',
        ylabel='mean IoU',
    )


def plot_acc(history):
    """Plot training and validation pixel accuracy per epoch from a ``fit`` history dict."""
    _plot_history_curves(
        history,
        [('train_acc', 'train_accuracy', '*'), ('val_acc', 'val_accuracy', '*')],
        title='Accuracy per epoch',
        ylabel='Accuracy',
    )

# Requires `history` from a completed fit() run; the training cell above was
# interrupted (KeyboardInterrupt), so on a fresh kernel `history` is undefined here.
plot_loss(history)
plot_score(history)
plot_acc(history)