In [None]:
import csv
import os
import random
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from tqdm import tqdm

import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import models

from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision
from torchmetrics.segmentation import MeanIoU
from torchmetrics.segmentation import MeanIoU


In [None]:
# Utilities Section
# Class index -> label color, in RGB order.
index_color_mapping={0:(0,0,0),         # void
                     1:(108,64,20),     # dirt
                     2:(255,229,204),   # sand
                     3:(0,102,0),       # grass
                     4:(0,255,0),       # tree
                     5:(0,153,153),     # pole
                     6:(0,128,255),     # water
                     7:(0,0,255),       # sky
                     8:(255,255,0),     # vehicle
                     9:(255,0,127),     # container/generic-object
                     10:(64,64,64),     # asphalt
                     11:(255,128,0),    # gravel
                     12:(255,0,0),      # building
                     13:(153,76,0),     # mulch
                     14:(102,102,0),    # rock-bed
                     15:(102,0,0),      # log
                     16:(0,255,128),    # bicycle
                     17:(204,153,255),  # person
                     18:(102,0,204),    # fence
                     19:(255,153,204),  # bush
                     20:(0,102,102),    # sign
                     21:(153,204,255),  # rock
                     22:(102,255,255),  # bridge
                     23:(101,101,11),   # concrete
                     24:(114,85,47)}    # picnic-table

# Label color (RGB) -> class index.
# Derived from index_color_mapping instead of being maintained by hand, so the
# two tables can never drift apart.
color_index_mapping={color:index for index,color in index_color_mapping.items()}
# Every class must have a distinct color, otherwise the inverse would lose entries.
assert len(color_index_mapping)==len(index_color_mapping), "duplicate colors in index_color_mapping"

def index_lookup(color:tuple)->int:
    """
    Return the class index that corresponds to an RGB label color.

    ``color`` must be an ``(r, g, b)`` tuple in RGB order; images read with
    OpenCV are BGR and have to be reordered before calling this.
    A color missing from ``color_index_mapping`` raises KeyError.
    """
    class_index = color_index_mapping[color]
    return class_index

def to_color_label(index_label:np.ndarray, mapping:dict=None) -> np.ndarray:
    """
    Convert a 2-D index label to a color image for showing the result of prediction.

    The colors are emitted in BGR channel order (OpenCV convention).

    Args:
        index_label: 2-D array of class indices, shape (h, w).
        mapping: optional ``{index: (r, g, b)}`` table; defaults to
            ``index_color_mapping``.

    Returns:
        uint8 array of shape (h, w, 3), channels in BGR order.

    Raises:
        ValueError: if ``index_label`` is not 2-D.
        KeyError: if ``index_label`` contains an index that is not in ``mapping``.
    """
    if mapping is None:
        mapping=index_color_mapping
    index_label=np.asarray(index_label)
    if index_label.ndim!=2:
        raise ValueError(f"index_label must be 2-D, got shape {index_label.shape}")
    known=np.fromiter(mapping.keys(),dtype=np.int64)
    unknown=np.setdiff1d(np.unique(index_label),known)
    if unknown.size:
        raise KeyError(f"Unknown class indices in label: {unknown.tolist()}")
    # Lookup table indexed by class id, storing BGR so the result can be
    # displayed/saved with cv2. One fancy-indexing op replaces the per-pixel loop.
    lut=np.zeros((max(mapping)+1,3),dtype=np.uint8)
    for idx,(r,g,b) in mapping.items():
        lut[idx]=(b,g,r)
    return lut[index_label.astype(np.intp)]

def get_dirs_list(dir_path:Path)->list:
    """
    Return the immediate subdirectories of ``dir_path``, sorted by path.

    ``Path.iterdir`` yields entries in arbitrary, filesystem-dependent order;
    sorting makes everything built from this list reproducible across machines.
    """
    return sorted(p for p in dir_path.iterdir() if p.is_dir())

def get_files_list(dir_path:Path,suffix:str)->list:
    """
    Return the regular files directly inside ``dir_path`` whose suffix equals
    ``suffix`` (e.g. ".png"), sorted by path.

    Sorting removes the filesystem-dependent ordering of ``Path.iterdir`` so
    the generated split CSVs are identical across runs and machines.
    """
    return sorted(f for f in dir_path.iterdir() if f.is_file() and f.suffix==suffix)

def process_row(row:np.ndarray)->list:
    """
    Translate one row of BGR pixels (as read by OpenCV) into class indices.

    Each pixel is reordered to RGB and looked up with ``index_lookup``.
    """
    def _to_rgb(pixel):
        blue_green_red=pixel.tolist()
        return (blue_green_red[2],blue_green_red[1],blue_green_red[0])
    return [index_lookup(_to_rgb(pixel)) for pixel in row]

def convert_color2index(file_path:str)->np.ndarray:
    """
    Read a color-coded annotation image and convert it to a class-index map.

    The image is read with OpenCV (BGR order); each pixel's color is matched
    against ``color_index_mapping`` (RGB keys).

    Args:
        file_path: path of the color annotation image.

    Returns:
        uint8 array of shape (h, w) with one class index per pixel.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
        KeyError: if the image contains a color that is not in the palette.
    """
    color=cv2.imread(file_path,cv2.IMREAD_COLOR)
    if color is None:
        raise FileNotFoundError(f"Could not read annotation image: {file_path}")
    # Pack every pixel into one 24-bit integer 0xRRGGBB so the palette lookup
    # is a vectorized searchsorted instead of a per-pixel Python dict lookup
    # (which threads cannot speed up because of the GIL).
    channels=color.astype(np.uint32)
    packed=(channels[...,2]<<16)|(channels[...,1]<<8)|channels[...,0]
    palette_keys=np.array([(red<<16)|(green<<8)|blue for (red,green,blue) in color_index_mapping],dtype=np.uint32)
    palette_values=np.array(list(color_index_mapping.values()),dtype=np.uint8)
    order=np.argsort(palette_keys)
    palette_keys=palette_keys[order]
    palette_values=palette_values[order]
    positions=np.searchsorted(palette_keys,packed)
    # searchsorted can return len(keys) for values above the largest key.
    positions=np.minimum(positions,len(palette_keys)-1)
    matched=palette_keys[positions]==packed
    if not matched.all():
        row,col=np.argwhere(~matched)[0]
        b,g,r=color[row,col].tolist()
        raise KeyError(f"Unknown label color (r,g,b)={(r,g,b)} at pixel (row={row}, col={col}) in {file_path}")
    return palette_values[positions]

class CustomDataset(Dataset):
    """
    Segmentation dataset of (color image, index label) pairs listed in a CSV.

    Each non-empty CSV row is ``<color image path>,<index label path>``. Both
    files are read with ``cv2.IMREAD_UNCHANGED`` (color images therefore stay in
    OpenCV's BGR order). If ``target_size`` is given they are resized with
    ``cv2.resize``: bilinear for the image, nearest-neighbour for the label so
    class indices are never blended.

    Args:
        workspace_path: folder that contains the CSV file.
        csv_file_name: CSV file name inside ``workspace_path``.
        color_transforms: optional callable applied to the color image.
        target_size: optional ``(width, height)`` passed to ``cv2.resize``.
        loading_mode: ``"pre"`` loads every pair into memory up front; any
            other value loads each pair lazily in ``__getitem__``.

    ``__getitem__`` returns ``(color_image, index_label)`` where ``index_label``
    is always an int64 ``torch.Tensor``.
    """
    def __init__(self, workspace_path:str, csv_file_name:str, color_transforms=None, target_size=None, loading_mode="pre"):
        workspace_path=Path(workspace_path)
        csv_file_path=workspace_path/csv_file_name
        # csv.reader instead of np.genfromtxt: genfromtxt squeezes a one-row
        # file into a 1-D array (whose two strings would then be iterated as
        # bogus "pairs") and treats '#' inside a path as a comment start.
        with open(csv_file_path,newline='',encoding='utf-8') as csv_file:
            rows=[row for row in csv.reader(csv_file) if row]
        self.num_files=len(rows)
        self.color_transforms=color_transforms
        self.target_size=target_size
        self.loading_mode=loading_mode

        self.data_file_full_path_list=[[Path(row[0]),Path(row[1])] for row in rows]

        self.color_index_list=None
        if loading_mode=="pre":
            with ThreadPoolExecutor() as executor:
                self.color_index_list=list(tqdm(executor.map(self.__load_color_and_index__,self.data_file_full_path_list),desc="Load images and labels",total=self.__len__()))

    def __load_color_and_index__(self,path_pair)->list:
        """Read (and optionally resize) one [color image, index label] pair."""
        color_image_path=path_pair[0]
        index_label_path=path_pair[1]
        color_image=cv2.imread(color_image_path.as_posix(),cv2.IMREAD_UNCHANGED)
        if color_image is None:
            raise FileNotFoundError(f"Could not read color image: {color_image_path}")
        index_label=cv2.imread(index_label_path.as_posix(),cv2.IMREAD_UNCHANGED)
        if index_label is None:
            raise FileNotFoundError(f"Could not read index label: {index_label_path}")
        if self.target_size is not None:
            color_image=cv2.resize(color_image,self.target_size,interpolation=cv2.INTER_LINEAR)
            # Nearest neighbour keeps labels as valid class indices.
            index_label=cv2.resize(index_label,self.target_size,interpolation=cv2.INTER_NEAREST)
        return [color_image,index_label]

    def __len__(self)->int:
        return self.num_files

    def __getitem__(self, index):
        if self.loading_mode=="pre":
            color_image,index_label=self.color_index_list[index]
        else:
            color_image,index_label=self.__load_color_and_index__(self.data_file_full_path_list[index])

        if self.color_transforms is not None:
            color_image=self.color_transforms(color_image)
        # Labels must always be int64 class targets (nn.CrossEntropyLoss and the
        # torchmetrics metrics require integer targets); previously this only
        # happened when a color transform was supplied. .long() also copies, so
        # the cached array in "pre" mode is never aliased.
        index_label=torch.from_numpy(index_label).long()
        return color_image,index_label

In [None]:
# Before running anything, create a workspace folder (e.g. RUGD_ws).
# Move the folders RUGD_frames-with-annotations and RUGD_annotations into it,
# and create a folder named "model" inside it.
# Set the workspace path here, or override it with the RUGD_WS environment variable.
workspace_path=os.environ.get("RUGD_WS","C:/Users/SenGao/Downloads/RUGD_ws")

frames_folder_name="RUGD_frames-with-annotations"
annotations_folder_name="RUGD_annotations"
annotations_index_folder_name="RUGD_annotations_index"
model_folder_name="model"

# Seed every RNG in use (python, numpy, torch incl. CUDA) so shuffling,
# dropout and weight initialisation are reproducible.
random_state=42
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

training_folder_list=["park-2","trail","trail-3","trail-4","trail-6","trail-9","trail-10","trail-11","trail-12","trail-14","trail-15","village"]
val_folder_list=["park-8","trail-5"]
testing_folder_list=["creek","park-1","trail-7","trail-13"]
# The splits are by sequence folder and must not share any folder.
assert not (set(training_folder_list)&set(val_folder_list)), "train/val folders overlap"
assert not (set(training_folder_list)&set(testing_folder_list)), "train/test folders overlap"
assert not (set(val_folder_list)&set(testing_folder_list)), "val/test folders overlap"

In [None]:
# Index labels generation
# Converts every color annotation PNG into a single-channel PNG of class
# indices. Re-runnable: existing folders are reused and labels that were
# already converted are skipped.
annotations_path=Path(workspace_path)/annotations_folder_name
annotation_index_path=Path(workspace_path)/annotations_index_folder_name
annotation_index_path.mkdir(parents=True,exist_ok=True)
print("Start generating index labels.")
dirs_list=get_dirs_list(annotations_path)
for p in dirs_list:
    new_p=annotation_index_path/p.name
    new_p.mkdir(exist_ok=True)
    files_list=get_files_list(p,".png")
    for f in tqdm(files_list,desc=f"folder '{p.name}' is being processed.",total=len(files_list),leave=False):
        out_path=new_p/f.name
        if out_path.exists():
            continue
        # Write to a temporary name first and rename afterwards, so an
        # interrupted run never leaves a truncated file that a re-run would
        # then skip. The temp name keeps the .png suffix because cv2.imwrite
        # picks the encoder from the extension.
        tmp_path=out_path.with_name(f"{out_path.stem}.partial{out_path.suffix}")
        if not cv2.imwrite(tmp_path.as_posix(),convert_color2index(f.as_posix())):
            raise OSError(f"Failed to write index label: {tmp_path}")
        tmp_path.replace(out_path)
print("Generating index labels finishes.")

In [None]:
# Dataset splitting
# Writes one CSV per split; each row is "<frame path>,<index label path>".
frames_path=Path(workspace_path)/frames_folder_name
annotations_index_path=Path(workspace_path)/annotations_index_folder_name
saving_path=Path(workspace_path)

def collect_data_paths(folder_names:list, frames_root:Path, labels_root:Path)->list:
    """
    Pair every frame PNG in the given sequence folders with its index label.

    Args:
        folder_names: sequence folder names under ``frames_root``.
        frames_root: folder containing the frame sequence folders.
        labels_root: folder containing the index-label sequence folders.

    Returns:
        List of ``[frame_path, index_label_path]`` posix strings.

    Raises:
        FileNotFoundError: if a frame has no index label, so the problem is
            reported here instead of deep inside training.
    """
    pairs=[]
    for folder_name in folder_names:
        for frame_path in sorted(get_files_list(frames_root/folder_name,".png")):
            index_path=labels_root/frame_path.parent.name/frame_path.name
            if not index_path.is_file():
                raise FileNotFoundError(f"Missing index label for {frame_path}: {index_path}")
            pairs.append([frame_path.as_posix(),index_path.as_posix()])
    return pairs

print("Start splitting dataset.")
training_data_path_str_list=collect_data_paths(training_folder_list,frames_path,annotations_index_path)
val_data_path_str_list=collect_data_paths(val_folder_list,frames_path,annotations_index_path)
testing_data_path_str_list=collect_data_paths(testing_folder_list,frames_path,annotations_index_path)
for csv_name,pairs in (("train_data_path.csv",training_data_path_str_list),
                       ("val_data_path.csv",val_data_path_str_list),
                       ("test_data_path.csv",testing_data_path_str_list)):
    np.savetxt((saving_path/csv_name).as_posix(),np.array(pairs),delimiter=',',fmt='%s')
print("Finish.")

In [None]:
# Compute the per-channel mean and standard deviation of the training images (for input normalization)
class TrainingDataset(Dataset):
    """
    Minimal dataset over the training frames, used to compute normalization
    statistics.

    Reads the first column (frame paths) of ``training_data_csv_file_name``
    inside ``workspace_path``. Each item is the frame read with
    ``cv2.IMREAD_UNCHANGED`` (BGR channel order) and converted by
    ``transforms.ToTensor()`` to a (C, H, W) float tensor.
    """
    def __init__(self,workspace_path,training_data_csv_file_name="train_data_path.csv") -> None:
        workspace_path=Path(workspace_path)
        csv_file_path=workspace_path/training_data_csv_file_name
        # csv.reader rather than np.genfromtxt: genfromtxt squeezes a one-row
        # file to 1-D (breaking [:, 0] indexing) and treats '#' as a comment.
        with open(csv_file_path,newline='',encoding='utf-8') as csv_file:
            self.frame_path_list=[Path(row[0]) for row in csv.reader(csv_file) if row]

    def __len__(self) -> int:
        return len(self.frame_path_list)

    def __getitem__(self, index):
        frame_path=self.frame_path_list[index]
        image=cv2.imread(frame_path.as_posix(),cv2.IMREAD_UNCHANGED)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {frame_path}")
        return transforms.ToTensor()(image)

csv_file_name="train_data_path.csv"
dataset=TrainingDataset(workspace_path,csv_file_name)
data_loader=DataLoader(dataset=dataset,batch_size=32)

# One pass over the data: accumulate per-channel sum and sum of squares in
# float64, then std = sqrt(E[x^2] - E[x]^2). This reads every image once
# instead of twice; float64 accumulation keeps the result accurate.
sum_bgr=torch.zeros(3,dtype=torch.float64)
sum_sq_bgr=torch.zeros(3,dtype=torch.float64)
num_pixel=0
for images in tqdm(data_loader,desc="Calculate mean/std: "):
    images=images.double()
    sum_bgr+=images.sum(dim=(0,2,3))
    sum_sq_bgr+=(images*images).sum(dim=(0,2,3))
    num_pixel+=images.size(0)*images.size(2)*images.size(3)
if num_pixel==0:
    raise ValueError("No training images found; cannot compute normalization statistics.")
mean=sum_bgr/num_pixel
# clamp_min guards against tiny negative values from floating-point rounding.
std=torch.sqrt((sum_sq_bgr/num_pixel-mean*mean).clamp_min(0))
print(f"mean: {mean}")
print(f"std: {std}")

In [None]:
class ConvBlock(nn.Module):
    """
    Conv2d followed by BatchNorm2d and an in-place ReLU.

    The constructor arguments are forwarded positionally to ``nn.Conv2d``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable so saved
        # checkpoints still load.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)


class SpatialPath(nn.Module):
    """
    Spatial path: three stride-2 ConvBlocks (7x7, 3x3, 3x3) mapping 3 -> 64 ->
    128 -> 256 channels, each followed by channel dropout with rates 0.1,
    0.2 and 0.3.
    """
    def __init__(self):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.conv1 = ConvBlock(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv3 = ConvBlock(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.dropout1 = nn.Dropout2d(0.1)
        self.dropout2 = nn.Dropout2d(0.2)
        self.dropout3 = nn.Dropout2d(0.3)

    def forward(self, x):
        stages = ((self.conv1, self.dropout1), (self.conv2, self.dropout2), (self.conv3, self.dropout3))
        for conv_block, channel_dropout in stages:
            x = channel_dropout(conv_block(x))
        return x


class AttentionRefinementModule(nn.Module):
    """
    Attention refinement: project the input with a 1x1 conv + BN + ReLU and
    scale it channel-wise by a sigmoid gate. The gate is computed from the
    globally average-pooled input passed through the same 1x1 conv and BN.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """
    def __init__(self, in_channels,out_channels):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # NOTE(review): self.conv and self.bn are shared between the pooled
        # (N, C, 1, 1) gate branch and the full-resolution feature branch.
        # BatchNorm running statistics therefore mix both distributions. In
        # training mode, BN on the 1x1 branch fails for a batch of size 1.
        # Giving each branch its own BN would change the state_dict keys and
        # break loading of existing checkpoints.
        # The gate branch must run first: BN updates its running stats in call order.
        gate = self.sigmoid(self.bn(self.conv(self.global_pool(x))))
        features = self.relu(self.bn(self.conv(x)))
        return features * gate


class ContextPath(nn.Module):
    """
    Context path of BiSeNet on a torchvision ResNet-18 backbone.

    Runs the ResNet-18 stem and layer1-layer3 (layer4/avgpool/fc are unused).
    The layer3 features are refined by an attention refinement module and a
    global-average context term is added. The result is upsampled to the
    layer2 resolution and fused with the refined layer2 features.

    Args:
        pretrained: load the default ImageNet weights for the backbone
            (downloaded on first use). Pass False to build the network offline,
            e.g. before loading a saved state_dict.

    ``forward(x)`` returns ``(features, x)``. ``features`` is a 256-channel map
    at layer2's spatial size. ``x`` is the unchanged input; the caller uses its
    spatial size for the final upsampling.
    """
    def __init__(self, pretrained=True):
        super(ContextPath, self).__init__()
        self.resnet18 = self._build_backbone(pretrained)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.arm8x = AttentionRefinementModule(128,512)
        self.arm16x = AttentionRefinementModule(256,256)
        self.conv1=ConvBlock(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv2=ConvBlock(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)
        # dropout1 is not used in forward. It has no parameters and is kept
        # only so the module layout stays unchanged.
        self.dropout1=nn.Dropout2d(0.1)
        self.dropout2=nn.Dropout2d(0.2)
        self.dropout3=nn.Dropout2d(0.3)

    @staticmethod
    def _build_backbone(pretrained):
        """Create ResNet-18 via the weights API, falling back to the legacy flag on old torchvision."""
        if hasattr(models, "ResNet18_Weights"):
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            return models.resnet18(weights=weights)
        return models.resnet18(pretrained=pretrained)

    def forward(self, x):
        x2 = self.resnet18.conv1(x)
        x2 = self.resnet18.bn1(x2)
        x2 = self.resnet18.relu(x2)
        x4 = self.resnet18.maxpool(x2)
        x4 = self.resnet18.layer1(x4)
        x8 = self.resnet18.layer2(x4)
        x16 = self.resnet18.layer3(x8)

        # Global context: pooled layer3 features broadcast back to layer3 size.
        avg=self.avg_pool(x16)
        avg_up=F.interpolate(avg, size=x16.size()[2:], mode='bilinear', align_corners=True)

        x16_arm=self.arm16x(x16)
        x16_arm=x16_arm+avg_up
        x16_arm_up=F.interpolate(x16_arm, size=x8.size()[2:], mode='bilinear', align_corners=True)
        x16_arm_up=self.conv1(x16_arm_up)
        x16_arm_up=self.dropout2(x16_arm_up)

        x8_arm=self.arm8x(x8)
        x8_arm=x8_arm+x16_arm_up
        x8_arm=self.conv2(x8_arm)
        x8_arm=self.dropout3(x8_arm)

        return x8_arm,x


class FeatureFusionModule(nn.Module):
    """
    Fuse spatial-path and context-path features.

    The inputs are concatenated on the channel axis and projected by a 1x1
    ConvBlock to ``out_channels``, then passed through channel dropout. The
    result is re-weighted by a channel gate: global average pool, 1x1 conv
    down to out_channels // 4, ReLU, 1x1 conv back up, sigmoid. The gated
    features are added to the un-gated ones.

    Args:
        in_channels: combined channel count of ``sp`` and ``cp``.
        out_channels: channel count of the fused output.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        bottleneck = out_channels // 4
        self.conv = ConvBlock(in_channels, out_channels, 1, 1, 0)
        self.dropout = nn.Dropout2d(0.4)
        # The Sequential layout defines state_dict keys (attention.1 / attention.3).
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, bottleneck, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, out_channels, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, sp, cp):
        fused = self.dropout(self.conv(torch.cat([sp, cp], dim=1)))
        channel_gate = self.attention(fused)
        return fused + fused * channel_gate


class BiSeNet(nn.Module):
    """
    Bilateral segmentation network: a spatial path and a context path whose
    outputs are fused by a FeatureFusionModule and classified per pixel.

    Args:
        num_classes: number of output logits per pixel.

    ``forward(x)`` returns logits of shape (N, num_classes, H, W), where
    (H, W) is the spatial size of ``x``.
    """
    def __init__(self, num_classes):
        super(BiSeNet, self).__init__()
        self.spatial_path = SpatialPath()
        self.context_path = ContextPath()
        self.feature_fusion = FeatureFusionModule(512,128)
        self.final_conv=nn.Conv2d(128, num_classes, 1,1,0)

    def forward(self, x):
        sp = self.spatial_path(x)
        cp,x = self.context_path(x)
        out = self.feature_fusion(sp,cp)
        # Classify at the fused (low) resolution, then upsample the logits.
        # A 1x1 conv and bilinear interpolation are both linear, and the
        # interpolation weights sum to 1, so this equals upsampling first and
        # classifying afterwards (up to float rounding). It only interpolates
        # num_classes channels instead of 128 at full resolution, which saves
        # a large amount of activation memory.
        out = self.final_conv(out)
        out = F.interpolate(out, size=x.size()[2:], mode='bilinear', align_corners=True)
        return out

In [None]:
# Parameters of training
# Normalization statistics in BGR channel order (images are read with OpenCV).
# NOTE(review): presumably copied from the output of the mean/std cell — re-check if the training split changes.
color_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4044, 0.4067, 0.4069],std=[0.2750, 0.2738, 0.2710])
])
loading_mode="pre"      # "pre": load all resized pairs into RAM up front; any other value: load lazily
num_classes=25          # matches the 25 entries of index_color_mapping (index 0 = void)
target_size=(512,512)   # (width, height) passed to cv2.resize
epochs=3
batch_size=8
learning_rate=0.001

model=BiSeNet(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)

# Class 0 (void) is excluded from all validation metrics.
# Macro average: mean of the per-class accuracies.
metric_val_acc = MulticlassAccuracy(num_classes=num_classes, ignore_index=0)
metric_val_precision = MulticlassPrecision(num_classes=num_classes, average="macro", ignore_index=0)
# Pixel-wise accuracy = correctly classified pixels / all counted pixels, i.e.
# micro averaging. With the default (macro) it would be identical to metric_val_acc.
metric_pixel_wise_acc= MulticlassAccuracy(num_classes=num_classes, average="micro", ignore_index=0)
metric_mIoU=MeanIoU(num_classes=num_classes,include_background=False,input_format="index")

In [None]:
# Training
training_dataset=CustomDataset(workspace_path,"train_data_path.csv",color_transforms,target_size,loading_mode)
val_dataset=CustomDataset(workspace_path,"val_data_path.csv",color_transforms,target_size,loading_mode)
training_loader=DataLoader(training_dataset,batch_size,shuffle=True,drop_last=True)
# Validation must score every sample exactly once: no shuffling and no dropped
# final batch. The model is in eval mode there, so a smaller last batch does
# not affect BatchNorm.
val_loader=DataLoader(val_dataset,batch_size,shuffle=False,drop_last=False)
if len(training_loader)==0 or len(val_dataset)==0:
    raise ValueError("Training split is smaller than one batch or validation split is empty.")

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Current used device is {device}")
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model=model.to(device)

metric_val_acc=metric_val_acc.to(device)
metric_val_precision=metric_val_precision.to(device)
metric_pixel_wise_acc=metric_pixel_wise_acc.to(device)
metric_mIoU=metric_mIoU.to(device)

model_dir=Path(workspace_path)/model_folder_name
model_dir.mkdir(parents=True,exist_ok=True)

y_training_loss=np.zeros(epochs,dtype=np.float32)
y_val_loss=np.zeros(epochs,dtype=np.float32)
y_val_accuracy=np.zeros(epochs,dtype=np.float32)
y_val_precision=np.zeros(epochs,dtype=np.float32)
y_val_pixel_wise_acc=np.zeros(epochs,dtype=np.float32)
y_val_mIoU=np.zeros(epochs,dtype=np.float32)

for epoch in range(epochs):
    message=f"Current epoch:{epoch+1}/{epochs} "
    model.train()
    training_loss_sum = 0.0
    for images, labels in tqdm(training_loader, desc=message+"Training: ", leave=False):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        training_loss_sum += loss.item()
    # drop_last=True: every training batch has the same size, so a per-batch mean is exact.
    y_training_loss[epoch]=training_loss_sum / len(training_loader)

    model.eval()
    val_loss_sum = 0.0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=message+"validating: ", leave=False):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Weight by batch size because the last validation batch may be smaller.
            val_loss_sum += loss.item()*images.size(0)
            pred_labels = torch.argmax(outputs, dim=1)
            pred_labels_flat = pred_labels.view(pred_labels.size(0), -1)
            labels_flat = labels.view(labels.size(0), -1)
            metric_val_acc.update(pred_labels,labels)
            metric_val_precision.update(pred_labels,labels)
            metric_pixel_wise_acc.update(pred_labels_flat,labels_flat)
            metric_mIoU.update(pred_labels,labels)

    y_val_loss[epoch]=val_loss_sum/len(val_dataset)
    # .item() turns each 0-d metric tensor (possibly on the GPU) into a Python
    # float before it is stored in the numpy array.
    y_val_accuracy[epoch]=metric_val_acc.compute().item()
    y_val_precision[epoch]=metric_val_precision.compute().item()
    y_val_pixel_wise_acc[epoch]=metric_pixel_wise_acc.compute().item()
    y_val_mIoU[epoch]=metric_mIoU.compute().item()
    metric_val_acc.reset()
    metric_val_precision.reset()
    metric_pixel_wise_acc.reset()
    metric_mIoU.reset()
    print(f"Epoch: {epoch+1}/{epochs}")
    print(f"Training loss: {y_training_loss[epoch]:.4f} Validation loss: {y_val_loss[epoch]:.4f}")
    print(f"accuracy: {y_val_accuracy[epoch]:.4f} Precision: {y_val_precision[epoch]:.4f}")
    print(f"Pixel-wise Acc: {y_val_pixel_wise_acc[epoch]:.4f} MeanIoU: {y_val_mIoU[epoch]:.4f}")
torch.save(model.state_dict(), model_dir/"model.pth")

for csv_name,values in (("training_loss.csv",y_training_loss),
                        ("val_loss.csv",y_val_loss),
                        ("accuracy.csv",y_val_accuracy),
                        ("precision.csv",y_val_precision),
                        ("pixel_wise_acc.csv",y_val_pixel_wise_acc),
                        ("mIoU.csv",y_val_mIoU)):
    np.savetxt((model_dir/csv_name).as_posix(),values,delimiter=',',fmt='%f')
print("Finish!")

In [None]:
# Show results of training
model_dir=Path(workspace_path)/model_folder_name

def load_curve(file_name:str)->np.ndarray:
    """Load one per-epoch metric CSV written by the training cell as a 1-D float array."""
    # ndmin=1: a single-epoch run still yields a 1-element array instead of a 0-d scalar.
    return np.loadtxt((model_dir/file_name).as_posix(),delimiter=',',ndmin=1)

def plot_curve(ax, values:np.ndarray, color:str, label:str) -> None:
    """Plot a per-epoch curve on ``ax`` against 1-based epoch numbers."""
    ax.plot(np.arange(1,len(values)+1), values, color=color, label=label)

training_loss=load_curve("training_loss.csv")
val_loss=load_curve("val_loss.csv")
val_accuracy=load_curve("accuracy.csv")
val_precision=load_curve("precision.csv")
val_pixel_wise_accuracy=load_curve("pixel_wise_acc.csv")
val_mIoU=load_curve("mIoU.csv")

fig,(ax_loss,ax_acc,ax_iou)=plt.subplots(1,3,figsize=(12,5))
plot_curve(ax_loss,training_loss,'blue','Training loss')
plot_curve(ax_loss,val_loss,'orange','Validation loss')
ax_loss.set_title("Loss")
plot_curve(ax_acc,val_accuracy,'black','Accuracy')
plot_curve(ax_acc,val_precision,'red','Precision')
ax_acc.set_title("Accuracy & Precision")
plot_curve(ax_iou,val_pixel_wise_accuracy,'yellow','pixel_acc')
plot_curve(ax_iou,val_mIoU,'green','mIoU')
ax_iou.set_title("Pixel_acc & Mean IoU")
for ax in (ax_loss,ax_acc,ax_iou):
    ax.set_xlabel("Epoch")
    ax.legend(loc='best')
fig.tight_layout()
plt.show()

In [None]:
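# # Testing parameters (uncomment this cell to evaluate a saved checkpoint on the test split)
# # NOTE: the training cell above saves its checkpoint as "model.pth"; set model_file_name to the checkpoint you want to test.
# model_file_name="model_params_1_15.pth"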
# # Load testing data
# testing_dataset=CustomDataset(workspace_path,"test_data_path.csv",color_transforms,target_size,loading_mode)
# testing_loader=DataLoader(testing_dataset,batch_size,shuffle=False,drop_last=True)
# # Testing
# torch.cuda.empty_cache()
# model=BiSeNet(num_classes)
# model.load_state_dict(torch.load((Path(workspace_path)/model_folder_name/model_file_name)))
# model.to(device)
# model.eval()

# correct_pixel_sum=0
# pixel_sum=0
# total_intersection=np.zeros(num_classes,dtype=np.int64)
# total_union=np.zeros(num_classes,dtype=np.int64)
# with torch.no_grad():
#     for images, labels in tqdm(testing_loader, desc="testing: ", leave=False):
#         images = images.to(device)
#         labels = labels.to(device)
#         outputs = model(images)
#         probs = torch.softmax(outputs, dim=1)
#         pred_labels = torch.argmax(probs, dim=1)
#         correct_pixel_sum+=(pred_labels==labels).sum().item()
#         pixel_sum+=labels.numel()
#         for pred, label in zip(pred_labels, labels):
#             for index in range(num_classes):
#                 pred_inds = (pred == index)
#                 label_inds = (label == index)
#                 intersection = (pred_inds & label_inds).sum().item()
#                 union = (pred_inds | label_inds).sum().item()
#                 total_intersection[index] += intersection
#                 total_union[index] += union
# pixel_accuracy=correct_pixel_sum/pixel_sum
# iou=total_intersection[1:]/(total_union+1e-6)[1:]
# mIoU=iou.mean().item()
# print(f"pixel-wise accuracy: {pixel_accuracy:.4f} mIoU: {mIoU:.4f}")