In [1]:
import os
import zipfile
from tqdm import tqdm  # For progress bar

In the following we define the LoveDA dataset class

In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class LoveDADataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        """
        Args:
            image_dir (str): Path to the directory with input images.
            mask_dir (str): Path to the directory with masks.
            transform (callable, optional): Transformations for the input images.
            target_transform (callable, optional): Transformations for the masks.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.transform = transform
        self.target_transform = target_transform
        self.class_weights = None

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Load image and mask
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # Use PIL to load images
        image = Image.open(image_path).convert("RGB")  # Convert to 3-channel RGB
        mask = Image.open(mask_path)  # Grayscale mask (single channel)

        # Apply transforms
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask


In [3]:
import os
import numpy as np
from PIL import Image
import cv2
import torch
from datasets.base_dataset import BaseDataset

class LoveDA(BaseDataset):
    def __init__(self, 
                 image_dir, 
                 mask_dir,
                 num_classes=7,
                 multi_scale=True, 
                 flip=True, 
                 ignore_label=-1, 
                 base_size=1024, 
                 crop_size=(128, 128),
                 scale_factor=16,
                 mean=[0.485, 0.456, 0.406], 
                 std=[0.229, 0.224, 0.225],
                 bd_dilate_size=4):

        super(LoveDA, self).__init__(ignore_label, base_size,
                crop_size, scale_factor, mean, std,)

        self.mask_dir = mask_dir
        self.image_dir = image_dir
        self.num_classes = num_classes

        self.multi_scale = multi_scale
        self.flip = flip
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        
        #self.img_list = [line.strip().split() for line in open(root + list_path)]

        # Define label mapping for LoveDA
        self.label_mapping = {
            0: ignore_label,  # Background
            1: 0,  # Urban land
            2: 1,  # Agriculture
            3: 2,  # Rangeland
            4: 3,  # Forest
            5: 4,  # Water
            6: 5,  # Barren land
            7: 6   # Unknown
        }

        # Class weights for LoveDA (example values; adjust as needed)
        self.class_weights = torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda()

        self.bd_dilate_size = bd_dilate_size
    
        
    def convert_label(self, label, inverse=False):
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.image_filenames[index])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[index])
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        size = image.shape

        '''if 'test' in self.list_path:
            image = self.input_transform(image)
            image = image.transpose((2, 0, 1))

            return image.copy(), np.array(size), self.image_filenames[index]'''

        label = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        label = self.convert_label(label)

        image, label, edge = self.gen_sample(image, label, 
                                self.multi_scale, self.flip, edge_size=self.bd_dilate_size)

        return image.copy(), label.copy(), edge.copy(), np.array(size), self.image_filenames[index]

    def single_scale_inference(self, config, model, image):
        pred = self.inference(config, model, image)
        return pred

    def save_pred(self, preds, sv_path, name):
        preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
        for i in range(preds.shape[0]):
            pred = self.convert_label(preds[i], inverse=True)
            save_img = Image.fromarray(pred)
            save_img.save(os.path.join(sv_path, name[i] + '.png'))


Verifica codice

In [4]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

image_path = 'data\\LoveDA\\train\\Urban\\images_png'
mask_path = 'data\\LoveDA\\train\\Urban\\masks_png'

assert os.path.exists(image_path), f"Image path does not exist: {image_path}"
assert os.path.exists(mask_path), f"Mask path does not exist: {mask_path}"

image_path = os.path.join(image_path, '1366.png')
print(image_path)
mask_path = os.path.join(mask_path, '1366.png')
print(mask_path)
try:
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path)
except Exception as e:
    raise RuntimeError(f"Error loading image or mask at index : {e}")


data\LoveDA\train\Urban\images_png\1366.png
data\LoveDA\train\Urban\masks_png\1366.png


Load pretrained PIDNet-S on ImageNet

In [5]:
import torch
import sys
import logging
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  # Import tqdm for progress bars
from torchmetrics.classification import JaccardIndex
from torch.utils.data import DataLoader
from torchvision import transforms
from models.pidnet import PIDNet
from configs import config
from configs import update_config
from utils.criterion import CrossEntropy, OhemCrossEntropy, BondaryLoss
from utils.function import train, validate
from utils.utils import create_logger, FullModel
import timeit
import torch.backends.cudnn as cudnn
import pprint
import argparse
import numpy as np
from tensorboardX import SummaryWriter


def squeeze_channel(tensor):
    return tensor.squeeze(0).long()

def rescale_labels(tensor):
    """
    Rescales labels:
      - Class 0 -> -1 (ignore)
      - Classes 1–7 -> 0–6
    """
    tensor = tensor.squeeze(0).long()
    tensor = tensor - 1  # Shift labels down by 1
    tensor[tensor == -1] = -1  # Ensure 0 becomes -1
    return tensor

def parse_args():
    parser = argparse.ArgumentParser(description='Train segmentation network')
    
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default="configs/LoveDA/pidnet_small_loveda_pretrained.yaml",
                        type=str)
    parser.add_argument('--seed', type=int, default=0)    
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args, _ = parser.parse_known_args()  # This ignores unrecognized arguments
    update_config(config, args)

    return args

def main():

    args = parse_args()

    if args.seed > 0:
        import random
        print('Seeding with', args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)        

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    pretrained_state = torch.load('pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar')['state_dict']

    model = PIDNet(m=2, n=3, num_classes=7, planes=32, ppm_planes=96, head_planes=128, augment=False)
    model_dict = model.state_dict()
    pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
    model_dict.update(pretrained_state)
    msg = 'Loaded {} parameters!'.format(len(pretrained_state))
    #logging.info('Attention!!!')
    #logging.info(msg)
    #logging.info('Over!!!')
    model.load_state_dict(model_dict, strict = False)



    # Define transformations for images and masks
    img_transforms = transforms.Compose([
        transforms.Resize((128, 128)),  # Resize to 512x512
        transforms.ToTensor(),          # Convert image to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
    ])

    mask_transforms = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),  # Resize mask
        transforms.PILToTensor(),  # Convert mask to tensor
        transforms.Lambda(rescale_labels)  # Remove channel dimension
    ])


    # Dataset paths
    train_dataset = LoveDA(
        image_dir="data\\LoveDA\\train\\Urban\\images_png",
        mask_dir="data\\LoveDA\\train\\Urban\\masks_png",
        #transform=img_transforms,
        #target_transform=mask_transforms
    )

    val_dataset = LoveDA(
        image_dir="data\\LoveDA\\val\\Urban\\images_png",
        mask_dir="data\\LoveDA\\val\\Urban\\masks_png",
        #transform=img_transforms,
        #target_transform=mask_transforms
    )

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)

    num_classes = 7  # Update based on LoveDA
    
    model = PIDNet(m=2, n=3, num_classes=num_classes, planes=32, ppm_planes=96, head_planes=128, augment=True)
    model_dict = model.state_dict()
    pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
    model_dict.update(pretrained_state)
    msg = 'Loaded {} parameters!'.format(len(pretrained_state))
    #logging.info('Attention!!!')
    #logging.info(msg)
    #logging.info('Over!!!')
    model.load_state_dict(model_dict, strict = False)
    

    # criterion
    if config.LOSS.USE_OHEM:
        sem_criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                        thres=config.LOSS.OHEMTHRES,
                                        min_kept=config.LOSS.OHEMKEEP,
                                        weight=train_dataset.class_weights)
    else:
        sem_criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                    weight=train_dataset.class_weights)

    bd_criterion = BondaryLoss()
    
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    if torch.cuda.device_count() != len(gpus):
        print("The gpu numbers do not match!")
        return 0
    epoch_iters = int(train_dataset.__len__() / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
        
    model = FullModel(model, sem_criterion, bd_criterion)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        params_dict = dict(model.named_parameters())
        params = [{'params': list(params_dict.values()), 'lr': config.TRAIN.LR}]

        optimizer = torch.optim.SGD(params,
                                lr=config.TRAIN.LR,
                                momentum=config.TRAIN.MOMENTUM,
                                weight_decay=config.TRAIN.WD,
                                nesterov=config.TRAIN.NESTEROV,
                                )
    else:
        raise ValueError('Only Support SGD optimizer')

    # Training loop
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    num_epochs = 20  # Number of epochs
    best_mIoU = 0
    last_epoch = 0
    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    real_end = 120+1 if 'camvid' in config.DATASET.TRAIN_SET else end_epoch
    

    for epoch in range(num_epochs):

        current_trainloader = train_loader
        if current_trainloader.sampler is not None and hasattr(current_trainloader.sampler, 'set_epoch'):
            current_trainloader.sampler.set_epoch(epoch)

        train(config, epoch, config.TRAIN.END_EPOCH, 
                  epoch_iters, config.TRAIN.LR, num_iters,
                  train_loader, optimizer, model, writer_dict)

        if flag_rm == 1 or (epoch % 5 == 0 and epoch < num_epochs - 5) or (epoch >= num_epochs - 5):
            valid_loss, mean_IoU, IoU_array = validate(config, 
                        val_loader, model, writer_dict)
        if flag_rm == 1:
            flag_rm = 0

        logger.info('=> saving checkpoint to {}'.format(
            final_output_dir + 'checkpoint.pth.tar'))
        torch.save({
            'epoch': epoch+1,
            'best_mIoU': best_mIoU,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(final_output_dir,'checkpoint.pth.tar'))
        if mean_IoU > best_mIoU:
            best_mIoU = mean_IoU
            torch.save(model.module.state_dict(),
                    os.path.join(final_output_dir, 'best.pt'))
        msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                    valid_loss, mean_IoU, best_mIoU)
        #logging.info(msg)
        #logging.info(IoU_array)



    torch.save(model.module.state_dict(),
            os.path.join(final_output_dir, 'final_state.pt'))

    writer_dict['writer'].close()
    end = timeit.default_timer()
    logger.info('Hours: %d' % int((end-start)/3600))
    logger.info('Done')

if __name__ == '__main__':
    main()

Namespace(cfg='configs/LoveDA/pidnet_small_loveda_pretrained.yaml', seed=0, opts=[])


=> creating output\loveda\pidnet_small_loveda_pretrained
=> creating log\loveda\pidnet_small\pidnet_small_loveda_pretrained_2024-12-23-18-10
-1
-1


  pretrained_state = torch.load('pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar')['state_dict']


AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA


RuntimeError: DataLoader worker (pid(s) 26772, 30044) exited unexpectedly