In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import random
from torchmetrics import JaccardIndex

from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler,random_split, ConcatDataset, Subset
import torchvision.transforms as transforms
import torchvision.models.segmentation as segmentation
import torch.optim as optim
import os
import torch.nn as nn
from sklearn.model_selection import train_test_split

from torch.cuda.amp import GradScaler, autocast
from sklearn.model_selection import StratifiedShuffleSplit

from tqdm import tqdm

In [2]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# (useful when the notebook is re-run in a kernel that already used the GPU).
torch.cuda.empty_cache()


In [3]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and,
    when available, all CUDA devices), then switches cuDNN to deterministic
    mode with auto-tuning disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels + no benchmark-based algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Fix the seed once, before any dataset split or model initialisation.
set_seed(0)

In [4]:
class WildScenesDataset(Dataset):
    """Image / label-mask pairs read from two parallel directories.

    The sample list is the sorted directory listing of ``image_dir``; the mask
    of a sample is looked up under the *same file name* inside ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the single-channel
            ("L" mode) PIL mask. Its result is squeezed along dim 0 and cast
            to ``long``.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.transform, self.target_transform = transform, target_transform
        # Sorting makes the index -> file mapping independent of listing order.
        file_names = os.listdir(image_dir)
        file_names.sort()
        self.images = file_names

    def __len__(self):
        """Number of files listed in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Without the corresponding transform the element is returned as a PIL
        image (RGB for the image, "L" for the mask).
        """
        file_name = self.images[idx]

        image = Image.open(os.path.join(self.image_dir, file_name)).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, file_name)).convert("L")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            # Drop a leading singleton dim (if any) and store labels as int64.
            mask = torch.squeeze(self.target_transform(mask), 0).long()

        return image, mask

In [5]:
class ToLabelTensor:
    """Convert a label mask (PIL image or array-like) to an int64 tensor.

    Unlike an image-style conversion, pixel values are kept as-is, so each
    value remains a class index. Values must be below 19; otherwise the
    offending maximum is printed and an ``AssertionError`` is raised.
    """

    def __call__(self, pic):
        labels = torch.from_numpy(np.array(pic, dtype=np.int32)).long()

        largest_label = labels.max().item()
        if not largest_label < 19:
            # Print the value first so it is visible even when the assertion
            # message is swallowed (e.g. inside a DataLoader worker).
            print(f"Label value out of range: {largest_label}")
        assert largest_label < 19, "Label value out of range"

        return labels


In [6]:
import time
from multiprocessing import Pool

# Root folders of the sequences to load. Each root is expected to contain an
# 'image' and an 'indexLabel' sub-folder (see image_dirs / mask_dirs below).
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
base_paths = [
    # '/root/autodl-tmp/K-01/',
    # '/root/autodl-tmp/K-03/',
    '/root/autodl-tmp/V-01/',
    # '/root/autodl-tmp/V-02/',
    # '/root/autodl-tmp/V-03/'
]
image_dirs = [os.path.join(base_path, 'image') for base_path in base_paths]
mask_dirs = [os.path.join(base_path, 'indexLabel') for base_path in base_paths]


# Input images: resize, convert to a float tensor and normalise with the
# ImageNet statistics expected by the pretrained backbone.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Label masks: nearest-neighbour resizing keeps pixel values valid class ids
# (any interpolating filter would invent in-between labels). ToTensor is
# deliberately NOT used here because it would rescale the ids to [0, 1].
# torchvision expects an InterpolationMode here rather than a PIL int constant.
target_transform = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
        ToLabelTensor(),
    ]
)

# One dataset per sequence folder, concatenated into a single dataset.
datasets = [
    WildScenesDataset(image_dir=image_dir, mask_dir=mask_dir, transform=transform, target_transform=target_transform)
    for image_dir, mask_dir in zip(image_dirs, mask_dirs)
]
full_dataset = ConcatDataset(datasets)


# 70 / 15 / 15 split; the test split absorbs the rounding remainder.
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# A dedicated, explicitly seeded generator makes the split identical every
# time this cell runs, regardless of how much of the global RNG stream has
# already been consumed (otherwise re-running the cell silently moves samples
# between train / val / test).
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader drops the last incomplete batch; validation and
# test must see every sample so the reported metrics cover the whole split.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=4, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4, drop_last=False)



# Human-readable class names; the position in the list is the class index
# (used when logging per-class IoU).
class_names = [
    "unlabelled", "asphalt", "dirt", "mud", "water",                 # 0-4
    "gravel", "other-terrain", "tree-trunk", "tree-foliage", "bush",  # 5-9
    "fence", "structure", "pole", "vehicle", "rock",                  # 10-14
    "log", "other-object", "sky", "grass",                            # 15-18
]



In [7]:
from collections import Counter
import matplotlib.pyplot as plt


# def check_class_distribution(dataset, num_classes):
#     class_counts = np.zeros(num_classes)
#     for _, mask in dataset:
#         mask = mask.numpy().flatten()
#         for cls in range(num_classes):
#             class_counts[cls] += np.sum(mask == cls)
#     return class_counts

# class_counts = check_class_distribution(train_dataset, 19)
# # print("Class distribution in the dataset:")
# # for i, count in enumerate(class_counts):
# #     print(f'{class_names[i]}: {count}')

# def plot_class_distribution(class_counts, class_names):
#     counts = [class_counts[i] for i in range(len(class_names))]
#     plt.figure(figsize=(10, 5))
#     plt.bar(class_names, counts)
#     plt.xlabel('Class')
#     plt.ylabel('Number of Pixels')
#     plt.title('Class Distribution in Dataset')
#     plt.xticks(rotation=90)
#     plt.show()

# plot_class_distribution(class_counts, class_names)

In [8]:
class EarlyStopping:
    """Stop training once a "higher is better" score stops improving.

    Call the instance once per epoch with the validation score and the model.
    Whenever the score reaches at least ``best_score + delta`` the model
    weights are checkpointed and the patience counter is reset; otherwise the
    counter is incremented and ``early_stop`` becomes ``True`` once it reaches
    ``patience``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        verbose: If true, print the counter on every non-improving epoch.
        delta: Minimum increase over the best score that counts as improvement.
        path: File the best ``state_dict`` is written to.

    Attributes:
        best_score: Best score seen so far (``None`` before the first valid one).
        counter: Consecutive non-improving epochs.
        early_stop: Set to ``True`` when training should stop.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, val_iou, model):
        """Register this epoch's validation score.

        A NaN score is treated as "no improvement": it never becomes
        ``best_score`` and never overwrites the checkpoint. (Previously a NaN
        fell through to the "improved" branch, replaced the best checkpoint and
        made every later comparison succeed.) If the very first score is NaN,
        no checkpoint is written for that epoch.
        """
        score = val_iou
        # NaN is the only value that is not equal to itself.
        score_is_nan = bool(score != score)

        improved = (not score_is_nan) and (
            self.best_score is None or not (score < self.best_score + self.delta)
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, model):
        """Write the model's ``state_dict`` to ``self.path``."""
        torch.save(model.state_dict(), self.path)

In [9]:
import torchvision.models.segmentation as models

# Start from torchvision's pretrained DeepLabV3 (ResNet-101 backbone) and
# replace the last layer of its classifier head with a 1x1 convolution that
# outputs one channel per class.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` — confirm against the installed version before changing.
num_classes = 19
model = models.deeplabv3_resnet101(pretrained=True)
model.classifier[4] = nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=(1, 1))

# # self-defined classifier
# original_classifier = model.classifier

# class CustomClassifier(nn.Sequential):
#     def __init__(self, original_classifier, num_classes):
#         super(CustomClassifier, self).__init__()
#         self.original_classifier = original_classifier

#         # new convolution layer
#         self.additional_conv = nn.Conv2d(256, 256, kernel_size=3, padding=1)
#         self.relu = nn.ReLU()

#         self.classifier = nn.Conv2d(256, num_classes, kernel_size=(1, 1))

#     def forward(self, x):
#         x = self.original_classifier[0](x)
#         x = self.original_classifier[1](x)
#         x = self.original_classifier[2](x)
#         x = self.original_classifier[3](x)
#         x = self.relu(self.additional_conv(x))
#         x = self.classifier(x)
#         return x

# model.classifier = CustomClassifier(original_classifier, num_classes)



# Choose the device first so that the model and the metric always live on the
# same device; the notebook then also runs on machines without a GPU instead
# of failing on a hard-coded 'cuda'.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Per-class IoU (average=None returns one value per class).
iou_metric = JaccardIndex(task='multiclass', num_classes=num_classes, average=None).to(device)




In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau


# class DiceLoss(nn.Module):
#     def __init__(self, smooth=1.e-5):
#         super(DiceLoss, self).__init__()
#         self.smooth = smooth

#     def forward(self, outputs, targets):
#         outputs = torch.softmax(outputs, dim=1) 
#         targets = F.one_hot(targets, num_classes=outputs.shape[1]).permute(0, 3, 1, 2).float()

#         intersection = torch.sum(outputs * targets, dim=(2, 3))
#         union = torch.sum(outputs, dim=(2, 3)) + torch.sum(targets, dim=(2, 3))

#         dice_score = (2. * intersection + self.smooth) / (union + self.smooth)
#         dice_loss = 1 - dice_score.mean(dim=1)
#         return dice_loss.mean()


# criterion = DiceLoss()
# Pixel-wise multi-class cross-entropy on the raw logits.
criterion = nn.CrossEntropyLoss()

# Optimiser: Adam over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Alternative: SGD with momentum and weight decay.
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

# Cut the learning rate by 10x when the monitored value has not increased for
# 3 epochs; mode="max" because the scheduler is stepped with the mean IoU.
scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=3, verbose=True)



In [11]:
# def calculate_class_iou(pred, target, num_classes):
#     iou_list = []
#     pred = pred.view(-1)
#     target = target.view(-1)

#     for cls in range(num_classes):
#         pred_inds = pred == cls
#         target_inds = target == cls
#         intersection = (pred_inds & target_inds).sum().float().item()
#         union = (pred_inds | target_inds).sum().float().item()
#         if union == 0:
#             iou_list.append(float('nan')) 
#         else:
#             iou_list.append(intersection / union)

#     return torch.tensor(iou_list, device=pred.device)

In [12]:
# scaler = GradScaler()
import logging

# Record the per-epoch training / validation results in a log file.
logging.basicConfig(filename='training12.log', level=logging.INFO)

early_stopping = EarlyStopping(patience=3, verbose=True)

num_epochs = 10

for epoch in range(num_epochs):
    # ------------------------------ training ------------------------------
    model.train()
    running_loss = 0.0

    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
        for batch_idx, (images, masks) in enumerate(train_loader, start=1):
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()

            outputs = model(images)['out']
            loss = criterion(outputs, masks.long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Show the mean loss over the batches processed so far. Dividing
            # by the total number of batches would under-report the loss
            # until the very last step of the epoch.
            pbar.set_postfix(loss=running_loss / batch_idx)
            pbar.update(1)

    logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

    # ----------------------------- validation -----------------------------
    model.eval()
    val_loss = 0.0
    # Start from a clean metric state: IoU is accumulated over the whole
    # validation split and computed once. Averaging per-batch IoUs is not the
    # dataset IoU and penalises classes that are absent from a batch.
    iou_metric.reset()
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            val_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            iou_metric.update(preds, masks)

    val_loss /= len(val_loader)
    val_class_ious = iou_metric.compute().cpu().numpy()
    iou_metric.reset()
    mean_iou = np.nanmean(val_class_ious)

    logging.info(f'Validation Loss: {val_loss:.4f}')
    for i, iou in enumerate(val_class_ious):
        logging.info(f'{class_names[i]} IoU: {iou:.4f}')
    logging.info(f'Mean IoU: {mean_iou:.4f}')

    # Early stopping checkpoints the model whenever the mean IoU improves.
    early_stopping(mean_iou, model)
    if early_stopping.early_stop:
        logging.info("Early stopping")
        break
    scheduler.step(mean_iou)
    torch.cuda.empty_cache()


# load best model (map_location keeps this working if the checkpoint was
# written on a different device than the one currently in use)
model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

Epoch 1/10: 100%|██████████| 52/52 [00:42<00:00,  1.21batch/s, loss=0.888]
Epoch 2/10: 100%|██████████| 52/52 [00:42<00:00,  1.22batch/s, loss=0.46]  
Epoch 3/10: 100%|██████████| 52/52 [00:42<00:00,  1.21batch/s, loss=0.416] 


EarlyStopping counter: 1 out of 3


Epoch 4/10: 100%|██████████| 52/52 [00:42<00:00,  1.21batch/s, loss=0.406] 


EarlyStopping counter: 2 out of 3


Epoch 5/10: 100%|██████████| 52/52 [00:42<00:00,  1.21batch/s, loss=0.386] 


EarlyStopping counter: 3 out of 3


<All keys matched successfully>

In [13]:
# calculate test IoU
model.eval()
# Clear anything the metric accumulated earlier (e.g. during validation) and
# accumulate over the entire test split, so the result is the dataset-level
# per-class IoU rather than a mean of per-batch IoUs.
iou_metric.reset()
with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)['out']
        preds = torch.argmax(outputs, dim=1)
        iou_metric.update(preds, masks)

test_class_ious = iou_metric.compute().cpu().numpy()
iou_metric.reset()
mean_test_iou = np.nanmean(test_class_ious)

for i, iou in enumerate(test_class_ious):
    logging.info(f'{class_names[i]} Test IoU: {iou:.4f}')
logging.info(f'Mean IoU: {mean_test_iou:.4f}')


In [14]:
# Persist the current model weights (state_dict only, not the full module).
torch.save(model.state_dict(), 'trained_model.pth')
