In [8]:
from UNet import Unet
import utils
import engine
from learning_rate_range_test import LRTest

import os
import numpy as np
import matplotlib.pyplot as plt



import albumentations as A
import gc
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim


  from .autonotebook import tqdm as notebook_tqdm


In [9]:
model=Unet(channels=[3, 64, 128, 256, 512, 1024], no_classes=150)
x   = torch.Tensor(np.random.rand(2, 3, 572, 572))
y=model(x)
print(y.shape)

torch.Size([2, 150, 388, 388])


In [10]:
datadir  = './ADE20K/images/training/'
maskdir = './ADE20K/annotations/training/'
val_datadir  = './ADE20K/images/validation/'
val_maskdir = './ADE20K/annotations/validation/'




In [11]:
import numpy as np
from copy import deepcopy
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import random
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from scipy import ndimage

class BaseDataSet(Dataset):
    def __init__(self, root, split, mean, std, base_size=None, augment=True, val=False,
                crop_size=321, scale=True, flip=True, rotate=False, blur=False, return_id=False):
        self.root = root
        self.split = split
        self.mean = mean
        self.std = std
        self.augment = augment
        self.crop_size = crop_size
        if self.augment:
            self.base_size = base_size
            self.scale = scale
            self.flip = flip
            self.rotate = rotate
            self.blur = blur
        self.val = val
        self.files = []
        self._set_files()
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean, std)
        self.return_id = return_id

        cv2.setNumThreads(0)

    def _set_files(self):
        raise NotImplementedError
    
    def _load_data(self, index):
        raise NotImplementedError

    def _val_augmentation(self, image, label):
        if self.crop_size:
            h, w = label.shape
            # Scale the smaller side to crop size
            if h < w:
                h, w = (self.crop_size, int(self.crop_size * w / h))
            else:
                h, w = (int(self.crop_size * h / w), self.crop_size)

            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
            label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST)
            label = np.asarray(label, dtype=np.int32)

            # Center Crop
            h, w = label.shape
            start_h = (h - self.crop_size )// 2
            start_w = (w - self.crop_size )// 2
            end_h = start_h + self.crop_size
            end_w = start_w + self.crop_size
            image = image[start_h:end_h, start_w:end_w]
            label = label[start_h:end_h, start_w:end_w]
        return image, label

    def _augmentation(self, image, label):
        h, w, _ = image.shape
        # Scaling, we set the bigger to base size, and the smaller 
        # one is rescaled to maintain the same ratio, if we don't have any obj in the image, re-do the processing
        if self.base_size:
            if self.scale:
                longside = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
            else:
                longside = self.base_size
            h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (int(1.0 * longside * h / w + 0.5), longside)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
            label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
    
        h, w, _ = image.shape
        # Rotate the image with an angle between -10 and 10
        if self.rotate:
            angle = random.randint(-10, 10)
            center = (w / 2, h / 2)
            rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
            image = cv2.warpAffine(image, rot_matrix, (w, h), flags=cv2.INTER_LINEAR)#, borderMode=cv2.BORDER_REFLECT)
            label = cv2.warpAffine(label, rot_matrix, (w, h), flags=cv2.INTER_NEAREST)#,  borderMode=cv2.BORDER_REFLECT)

        # Padding to return the correct crop size
        if self.crop_size:
            pad_h = max(self.crop_size - h, 0)
            pad_w = max(self.crop_size - w, 0)
            pad_kwargs = {
                "top": 0,
                "bottom": pad_h,
                "left": 0,
                "right": pad_w,
                "borderType": cv2.BORDER_CONSTANT,}
            if pad_h > 0 or pad_w > 0:
                image = cv2.copyMakeBorder(image, value=0, **pad_kwargs)
                label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
            
            # Cropping 
            h, w, _ = image.shape
            start_h = random.randint(0, h - self.crop_size)
            start_w = random.randint(0, w - self.crop_size)
            end_h = start_h + self.crop_size
            end_w = start_w + self.crop_size
            image = image[start_h:end_h, start_w:end_w]
            label = label[start_h:end_h, start_w:end_w]

        # Random H flip
        if self.flip:
            if random.random() > 0.5:
                image = np.fliplr(image).copy()
                label = np.fliplr(label).copy()

        # Gaussian Blud (sigma between 0 and 1.5)
        if self.blur:
            sigma = random.random()
            ksize = int(3.3 * sigma)
            ksize = ksize + 1 if ksize % 2 == 0 else ksize
            image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT_101)
        return image, label
        
    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        image, label, image_id = self._load_data(index)
        if self.val:
            image, label = self._val_augmentation(image, label)
        elif self.augment:
            image, label = self._augmentation(image, label)

        label = torch.from_numpy(np.array(label, dtype=np.int32)).long()
        image = Image.fromarray(np.uint8(image))
        if self.return_id:
            return  self.normalize(self.to_tensor(image)), label, image_id
        return self.normalize(self.to_tensor(image)), label

    def __repr__(self):
        fmt_str = "Dataset: " + self.__class__.__name__ + "\n"
        fmt_str += "    # data: {}\n".format(self.__len__())
        fmt_str += "    Split: {}\n".format(self.split)
        fmt_str += "    Root: {}".format(self.root)
        return fmt_str

class BaseDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, shuffle, num_workers, val_split = 0.0):
        self.shuffle = shuffle
        self.dataset = dataset
        self.nbr_examples = len(dataset)
        if val_split: self.train_sampler, self.val_sampler = self._split_sampler(val_split)
        else: self.train_sampler, self.val_sampler = None, None

        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'num_workers': num_workers,
            'pin_memory': True
        }
        super(BaseDataLoader, self).__init__(sampler=self.train_sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None
        
        self.shuffle = False

        split_indx = int(self.nbr_examples * split)
        np.random.seed(0)
        
        indxs = np.arange(self.nbr_examples)
        np.random.shuffle(indxs)
        train_indxs = indxs[split_indx:]
        val_indxs = indxs[:split_indx]
        self.nbr_examples = len(train_indxs)

        train_sampler = SubsetRandomSampler(train_indxs)
        val_sampler = SubsetRandomSampler(val_indxs)
        return train_sampler, val_sampler

In [12]:
ADE20K_palette = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,
                    3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,
                    5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,
                    255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,
                    6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,
                    92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,
                    10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,
                    0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,
                    163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224
                    ,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,
                    200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,
                    163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,
                    255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,
                    255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,
                    255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,
                    122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,
                    255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,
                    255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,
                    0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,
                    0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,
                    20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,
                    255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,
                    255,214,0,25,194,194,102,255,0,92,0,255]

In [13]:
# Check what's in the base directory

import numpy as np
import os
import torch
import cv2
from PIL import Image
from glob import glob
from torch.utils.data import Dataset
from torchvision import transforms


class ADE20KDataset(BaseDataSet):
    """
    ADE20K dataset 
    http://groups.csail.mit.edu/vision/datasets/ADE20K/
    """
    def __init__(self, **kwargs):
        self.num_classes = 150
        self.palette = ADE20K_palette
        super(ADE20KDataset, self).__init__(**kwargs)

    def _set_files(self):
        if self.split in  ["training", "validation"]:
            self.image_dir = os.path.join(self.root, 'images', self.split)
            self.label_dir = os.path.join(self.root, 'annotations', self.split)
            self.files = [os.path.basename(path).split('.')[0] for path in glob(self.image_dir + '/*.jpg')]
        else: raise ValueError(f"Invalid split name {self.split}")
    
    def _load_data(self, index):
        image_id = self.files[index]
        image_path = os.path.join(self.image_dir, image_id + '.jpg')
        label_path = os.path.join(self.label_dir, image_id + '.png')
        image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32)
        label = np.asarray(Image.open(label_path), dtype=np.int32) - 1 # from -1 to 149
        return image, label, image_id

class ADE20K(BaseDataLoader):
    def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, val=False,
                    shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False):

        self.MEAN = [0.48897059, 0.46548275, 0.4294]
        self.STD = [0.22861765, 0.22948039, 0.24054667]

        kwargs = {
            'root': data_dir,
            'split': split,
            'mean': self.MEAN,
            'std': self.STD,
            'augment': augment,
            'crop_size': crop_size,
            'base_size': base_size,
            'scale': scale,
            'flip': flip,
            'blur': blur,
            'rotate': rotate,
            'return_id': return_id,
            'val': val
        }

        self.dataset = ADE20KDataset(**kwargs)
        super(ADE20K, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split)

In [None]:
# Training data loader
train_loader = ADE20K(
    data_dir='/path/to/your/ade20k/dataset',
    batch_size=16,
    split='training',
    crop_size=512,
    base_size=520,
    scale=True,
    num_workers=4,
    shuffle=True,
    flip=True,
    rotate=True,
    blur=True,
    augment=True
)

# Validation data loader
val_loader = ADE20K(
    data_dir='Desktop/UNET_ADEK20K/ade/ADEChallengeData2016/images/training/',
    batch_size=16,
    split='validation',
    crop_size=512,
    base_size=520,
    scale=False,
    num_workers=4,
    shuffle=False,
    augment=False,
    val=True
)

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from torch.cuda.amp import autocast, GradScaler
from typing import Tuple, List, Optional
import logging

def compute_mIoU(pred: torch.Tensor, label: torch.Tensor, num_classes: int, ignore_index: int = -1) -> float:
    """
    Compute mean Intersection over Union (mIoU) for semantic segmentation.
    
    Args:
        pred: Predicted segmentation masks [N, H, W] or [N*H*W]
        label: Ground truth masks [N, H, W] or [N*H*W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation
        
    Returns:
        Mean IoU score
    """
    pred = pred.view(-1)
    label = label.view(-1)
    
    # Remove ignore_index pixels
    if ignore_index is not None:
        valid_mask = label != ignore_index
        pred = pred[valid_mask]
        label = label[valid_mask]
    
    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = label == cls
        
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        
        if union == 0:
            ious.append(float('nan'))  # Ignore classes not present
        else:
            ious.append(intersection / union)
    
    return np.nanmean(ious)

def compute_mIoU_batch_efficient(pred: torch.Tensor, label: torch.Tensor, num_classes: int, 
                                ignore_index: int = -1) -> float:
    """
    More efficient batch-wise mIoU computation using confusion matrix.
    """
    pred = pred.view(-1)
    label = label.view(-1)
    
    # Remove ignore_index pixels
    if ignore_index is not None:
        valid_mask = label != ignore_index
        pred = pred[valid_mask]
        label = label[valid_mask]
    
    # Create confusion matrix
    mask = (label >= 0) & (label < num_classes)
    hist = torch.bincount(
        num_classes * label[mask] + pred[mask],
        minlength=num_classes ** 2
    ).reshape(num_classes, num_classes).float()
    
    # Compute IoU for each class
    diag = torch.diag(hist)
    union = hist.sum(dim=1) + hist.sum(dim=0) - diag
    
    # Avoid division by zero
    ious = diag / torch.clamp(union, min=1e-8)
    
    # Return mean IoU, ignoring classes not present
    valid_ious = ious[union > 0]
    return valid_ious.mean().item() if len(valid_ious) > 0 else 0.0

class SegmentationTrainer:
    def __init__(self, model, train_loader, val_loader, num_classes=150, 
                 lr=1e-4, ignore_index=-1, use_amp=True, log_interval=10):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.use_amp = use_amp
        self.log_interval = log_interval
        
        # Setup optimizer and scheduler
        self.optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.PolynomialLR(
            self.optimizer, total_iters=100, power=0.9
        )
        
        # Setup loss function with class weights if needed
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        
        # Setup mixed precision training
        self.scaler = GradScaler() if use_amp else None
        
        # Tracking variables
        self.best_miou = 0.0
        self.train_losses = []
        self.val_mious = []
        
    def train_epoch(self, epoch: int) -> Tuple[float, float]:
        """Train for one epoch."""
        self.model.train()
        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        
        running_loss = 0.0
        num_batches = 0
        
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images, labels = images.cuda(non_blocking=True), labels.cuda(non_blocking=True)
            
            self.optimizer.zero_grad()
            
            if self.use_amp:
                with autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)
                
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
            
            running_loss += loss.item()
            num_batches += 1
            
            # Log progress
            if batch_idx % self.log_interval == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}')
        
        epoch_time = time.time() - start_time
        peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 2)
        avg_loss = running_loss / num_batches
        
        return avg_loss, epoch_time, peak_memory
    
    def validate(self) -> float:
        """Validate the model and compute mIoU."""
        self.model.eval()
        all_ious = []
        
        with torch.no_grad():
            for val_images, val_labels in self.val_loader:
                val_images = val_images.cuda(non_blocking=True)
                val_labels = val_labels.cuda(non_blocking=True)
                
                if self.use_amp:
                    with autocast():
                        preds = self.model(val_images)
                else:
                    preds = self.model(val_images)
                
                preds = torch.argmax(preds, dim=1)
                
                # Use the more efficient mIoU computation
                iou = compute_mIoU_batch_efficient(
                    preds, val_labels, self.num_classes, self.ignore_index
                )
                all_ious.append(iou)
        
        return np.mean(all_ious)
    
    def save_checkpoint(self, epoch: int, miou: float, filepath: str):
        """Save model checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_miou': self.best_miou,
            'miou': miou,
            'train_losses': self.train_losses,
            'val_mious': self.val_mious
        }
        
        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()
        
        torch.save(checkpoint, filepath)
        print(f'Checkpoint saved to {filepath}')
    
    def train(self, num_epochs: int, save_dir: str = './checkpoints'):
        """Main training loop."""
        import os
        os.makedirs(save_dir, exist_ok=True)
        
        for epoch in range(num_epochs):
            # Training
            avg_loss, epoch_time, peak_memory = self.train_epoch(epoch)
            self.train_losses.append(avg_loss)
            
            # Validation
            avg_miou = self.validate()
            self.val_mious.append(avg_miou)
            
            # Learning rate scheduling
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']
            
            # Print epoch summary
            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"  Train Loss: {avg_loss:.4f}")
            print(f"  Val mIoU: {avg_miou:.4f}")
            print(f"  Time: {epoch_time:.2f}s")
            print(f"  Peak Memory: {peak_memory:.2f} MB")
            print(f"  Learning Rate: {current_lr:.2e}")
            print("-" * 50)
            
            # Save best model
            if avg_miou > self.best_miou:
                self.best_miou = avg_miou
                self.save_checkpoint(
                    epoch, avg_miou, 
                    os.path.join(save_dir, 'best_model.pth')
                )
                print(f"New best mIoU: {self.best_miou:.4f}")
            
            # Save regular checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                self.save_checkpoint(
                    epoch, avg_miou,
                    os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth')
                )
        
        print(f"Training completed! Best mIoU: {self.best_miou:.4f}")
        return self.best_miou

# Example usage:
def main():
    # Assuming you have model, train_loader, val_loader defined
    # model = YourSegmentationModel(num_classes=150)
    # train_loader = create_ade20k_dataloader(...)
    # val_loader = create_ade20k_dataloader(...)
    
    trainer = SegmentationTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=150,
        lr=1e-4,
        ignore_index=-1,  # or 255 for some datasets
        use_amp=True,
        log_interval=10
    )
    
    best_miou = trainer.train(num_epochs=100, save_dir='./checkpoints')
    print(f"Final best mIoU: {best_miou:.4f}")

# For backward compatibility, here's the improved standalone version
def improved_training_loop(model, train_loader, val_loader, num_epochs=100, num_classes=150):
    """Improved version of your original training loop."""
    
    # Use AdamW with weight decay
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    
    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=num_epochs, power=0.9)
    
    # Use ignore_index=255 for ADE20K (common convention)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    
    # Mixed precision training
    scaler = GradScaler()
    
    best_miou = 0.0
    
    for epoch in range(num_epochs):
        model.train()
        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        
        running_loss = 0.0
        num_batches = 0
        
        for images, labels in train_loader:
            images, labels = images.cuda(non_blocking=True), labels.cuda(non_blocking=True)
            
            optimizer.zero_grad()
            
            # Mixed precision forward pass
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            # Mixed precision backward pass
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            running_loss += loss.item()
            num_batches += 1
        
        # Update learning rate
        scheduler.step()
        
        end_time = time.time()
        peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 2)
        avg_loss = running_loss / num_batches
        
        # Validation
        model.eval()
        iou_scores = []
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images = val_images.cuda(non_blocking=True)
                val_labels = val_labels.cuda(non_blocking=True)
                
                with autocast():
                    preds = model(val_images)
                preds = torch.argmax(preds, dim=1)
                
                # Use the more efficient mIoU computation
                iou = compute_mIoU_batch_efficient(preds, val_labels, num_classes, ignore_index=255)
                iou_scores.append(iou)
        
        avg_mIoU = np.mean(iou_scores)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Save best model
        if avg_mIoU > best_miou:
            best_miou = avg_mIoU
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f} | mIoU = {avg_mIoU:.4f} | "
              f"Time = {end_time - start_time:.2f}s | Peak Mem = {peak_memory:.2f} MB | "
              f"LR = {current_lr:.2e}")

# if __name__ == "__main__":
#     main()