In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset
import numpy as np
from dataloader import PASTIS_Dataset
from collate import pad_collate
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

Порядок комментариев к коду: китайский / русский / английский. (Code comments are given in Chinese / Russian / English, in that order.)

数据集加载 / Загрузка набора данных / loading dataset

In [2]:
class DynamicTimePointDataset(Dataset):
    """Flatten a time-series dataset so that every (patch, time point) pair is one sample.

    Each item of ``dataset`` must unpack as ``((data, dates), target)`` where
    ``data['S2']`` is indexable along its first axis by time point. Sample ``k`` of
    this dataset is ``(data['S2'][t].unsqueeze(0), target)`` for the k-th
    ``(patch, t)`` pair; all time points of a patch share that patch's target.

    Args:
        dataset: the underlying indexable dataset.
        indices: patch indices of ``dataset`` to expose.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        self.index_mapping = self._create_index_mapping()

    def _create_index_mapping(self):
        """Return the list of (patch index, time index) pairs, in `indices` order."""
        pairs = []
        for patch_idx in self.indices:
            # Each selected patch is loaded once here only to read its series length.
            (sample, _dates), _label = self.dataset[patch_idx]
            n_steps = sample['S2'].shape[0]
            pairs.extend((patch_idx, t) for t in range(n_steps))
        return pairs

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        patch_idx, t = self.index_mapping[idx]
        (sample, _dates), label = self.dataset[patch_idx]
        # Keep a leading singleton axis; the training loop squeezes it out after collation.
        return sample['S2'][t].unsqueeze(0), label


In [None]:
# Getting and processing the dataset.
SEED = 42
torch.manual_seed(SEED)  # makes the patch subset, the split and loader shuffling reproducible

path_to_dataset = 'E:/Research/Newdata/PASTIS'
dataset = PASTIS_Dataset(path_to_dataset, norm=True, target='semantic')  # semantic segmentation labels

NUM_PATCHES = 1500
subset_indices = torch.randperm(len(dataset))[:NUM_PATCHES].tolist()

# Split by *patch* before expanding into time points. DynamicTimePointDataset gives every
# time point of a patch the same target, so splitting the expanded samples would put other
# dates of the same patch into both sets and inflate the validation scores.
train_size = int(0.8 * len(subset_indices))
valid_size = len(subset_indices) - train_size
# subset_indices is already a random permutation, so a slice is a random split.
train_indices = subset_indices[:train_size]
valid_indices = subset_indices[train_size:]
train_dataset = DynamicTimePointDataset(dataset, train_indices)
valid_dataset = DynamicTimePointDataset(dataset, valid_indices)
total_samples = len(train_dataset) + len(valid_dataset)
print(f"Total number of data samples: {total_samples}")
print(f"Patches: {train_size} train / {valid_size} validation")

# Creating the DataLoaders.
train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=pad_collate, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, collate_fn=pad_collate, pin_memory=True)

# Number of classes.
num_classes = 20

UNET 3+ / Модель UNET 3+ / UNET 3+ model

In [4]:
# Double convolution: two (3x3 conv -> BatchNorm -> ReLU) stages.
class DoubleConv(nn.Module):
    """Two 3x3 conv -> BatchNorm -> ReLU stages; spatial size is preserved (padding=1).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy value means
            ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        # Sub-module order fixes the state_dict keys (double_conv.0 ... double_conv.5)
        # and the parameter-init order; keep it so existing checkpoints still load.
        stages = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
# Encoder step: halve the spatial resolution, then apply a DoubleConv.
class Down(nn.Module):
    """Downscaling block: 2x2 max-pool followed by DoubleConv(in_channels, out_channels)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # The attribute name and child order set the state_dict keys
        # (maxpool_conv.1.double_conv...), which saved checkpoints rely on.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
# Decoder step: upsample, pad to the skip tensor's size, concatenate, DoubleConv.
class Up(nn.Module):
    """Upscaling block.

    Args:
        in_channels: channels of the concatenated tensor fed to the DoubleConv
            (and, for the transposed-conv variant, channels of ``x1``).
        out_channels: channels produced by the DoubleConv.
        bilinear: if true, use parameter-free bilinear upsampling (align_corners=True)
            and a DoubleConv with ``in_channels // 2`` middle channels; otherwise a
            stride-2 ConvTranspose2d halving the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            mid = None
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            mid = in_channels // 2
        self.conv = DoubleConv(in_channels, out_channels, mid)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s H/W, concatenate ``[x2, x1]`` on dim 1, convolve."""
        upsampled = self.up(x1)
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        # F.pad takes (left, right, top, bottom) for the last two dims; an odd positive
        # difference puts the extra pixel on the right/bottom.
        upsampled = F.pad(upsampled, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to per-class output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name `conv` is part of the checkpoint keys (outc.conv.weight / .bias).
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


In [5]:
class UNet3Plus(nn.Module):
    """UNet-style segmentation network with full-scale skip features and deep supervision.

    The encoder is ``inc`` plus four ``Down`` stages with 64, 128, 256, 512 and 512
    channels. Every encoder level is projected to 64 channels by a 1x1 conv
    (``full_scale1``..``full_scale5``). Each decoder stage concatenates a subset of those
    projections, resized to the stage's resolution, with the upsampled decoder tensor.

    Args:
        n_channels: number of input channels.
        n_classes: number of output classes.
        bilinear: passed to every ``Up`` block.

    ``forward`` returns ``(logits, ds1, ds2, ds3, ds4, ds5)``. ``ds1``..``ds4`` are
    upsampled by the fixed factors 16, 8, 4 and 2. They match the input size only when
    H and W are divisible by 16 (TODO: confirm inputs satisfy this).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Attribute names and order define the state_dict keys used by saved checkpoints.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

        # in_channels = upsampled decoder channels + 64 per concatenated full-scale feature.
        self.up1 = Up(768, 256, bilinear)  # 512 + 4 * 64
        self.up2 = Up(512, 128, bilinear)  # 256 + 4 * 64
        self.up3 = Up(320, 64, bilinear)   # 128 + 3 * 64
        self.up4 = Up(192, 64, bilinear)   # 64 + 2 * 64

        self.outc = OutConv(64, n_classes)

        # 1x1 projections of each encoder level to 64 channels.
        self.full_scale1 = nn.Conv2d(64, 64, kernel_size=1)
        self.full_scale2 = nn.Conv2d(128, 64, kernel_size=1)
        self.full_scale3 = nn.Conv2d(256, 64, kernel_size=1)
        self.full_scale4 = nn.Conv2d(512, 64, kernel_size=1)
        self.full_scale5 = nn.Conv2d(512, 64, kernel_size=1)

        # Auxiliary heads: class logits from coarser stages, upsampled to the input size.
        self.deep_supervision1 = nn.Sequential(
            nn.Conv2d(512, n_classes, kernel_size=1),
            nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)
        )
        self.deep_supervision2 = nn.Sequential(
            nn.Conv2d(256, n_classes, kernel_size=1),
            nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        )
        self.deep_supervision3 = nn.Sequential(
            nn.Conv2d(128, n_classes, kernel_size=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        )
        self.deep_supervision4 = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )
        self.deep_supervision5 = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1)
        )

    @staticmethod
    def _gather(features, reference):
        """Resize each tensor in `features` to `reference`'s H/W and concatenate on dim 1."""
        size = tuple(reference.shape[-2:])
        return torch.cat(
            [F.interpolate(f, size=size, mode='bilinear', align_corners=False) for f in features],
            dim=1,
        )

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Project every encoder level once. The projections stay at their native resolution
        # and are resized directly to each decoder stage. Re-resizing the previous stage's
        # result instead, as before, pushed fs1 through H/8 on its way back to full
        # resolution and destroyed the fine-scale skip information.
        # NOTE(review): checkpoints trained with the old chained resizing load unchanged
        # (same parameters) but see different skip features; re-validate or fine-tune.
        fs1 = self.full_scale1(x1)
        fs2 = self.full_scale2(x2)
        fs3 = self.full_scale3(x3)
        fs4 = self.full_scale4(x4)
        fs5 = self.full_scale5(x5)

        ds1 = self.deep_supervision1(x5)

        x = self.up1(x5, self._gather([fs1, fs2, fs3, fs4], x4))
        ds2 = self.deep_supervision2(x)

        x = self.up2(x, self._gather([fs1, fs2, fs3, fs5], x3))
        ds3 = self.deep_supervision3(x)

        x = self.up3(x, self._gather([fs1, fs2, fs5], x2))
        ds4 = self.deep_supervision4(x)

        x = self.up4(x, self._gather([fs1, fs5], x1))
        ds5 = self.deep_supervision5(x)

        logits = self.outc(x)

        return logits, ds1, ds2, ds3, ds4, ds5


GPU / Проверка доступности CUDA / Checking whether CUDA is available

In [6]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
print("CUDA is available. GPU is used now." if _cuda_ok else "CUDA is not available. Using CPU.")
device = torch.device("cuda" if _cuda_ok else "cpu")

CUDA is available. GPU support enabled.


早停 / Ранняя остановка / Early stopping

In [7]:
class EarlyStopping:
    """Flag a stop after `patience` consecutive calls without sufficient improvement.

    A call counts as an improvement when `val_loss` is lower than the best loss seen so
    far by strictly more than `min_delta`. After a stop is flagged, `early_stop` stays True.

    Args:
        patience: number of non-improving calls tolerated before `early_stop` is set.
        min_delta: minimum decrease of the loss that counts as an improvement.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss - val_loss > self.min_delta:
            # Improvement: remember it and reset the streak.
            self.best_loss, self.counter = val_loss, 0
            return
        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

训练模型 / обучение модели / model training

In [None]:
# Initialize the model from a previous checkpoint (fine-tuning run) and the optimizer.
model = UNet3Plus(n_channels=10, n_classes=20)
# map_location lets a checkpoint saved on a GPU also load on a CPU-only machine.
model.load_state_dict(torch.load('best_unet3+_original_1792.pth', map_location=device))
model.to(device)

LEARNING_RATE = 5e-7
# Must stay below LEARNING_RATE: ReduceLROnPlateau clamps each reduced LR to min_lr and
# skips the update when the clamp would not lower it, so the previous min_lr=1e-6
# (> 5e-7) left the scheduler permanently inactive.
MIN_LR = 1e-8
DEEP_SUPERVISION_WEIGHT = 0.5  # weight of the summed auxiliary-head losses
VALIDATE_EVERY = 2  # validation, logging, checkpointing and early stopping run every N epochs
BEST_MODEL_PATH = 'best_unet3+_original.pth'

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()
criterion = nn.CrossEntropyLoss()

# Patience counts validation rounds, i.e. patience * VALIDATE_EVERY epochs.
early_stopping = EarlyStopping(patience=5, min_delta=0.0001)

# LR scheduler; like early stopping it is stepped once per validation round.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1, verbose=True, min_lr=MIN_LR)


def save_model(model, path):
    """Save `model`'s state_dict to `path`."""
    torch.save(model.state_dict(), path)


def deep_supervision_loss(outputs, targets):
    """Main-head loss plus DEEP_SUPERVISION_WEIGHT times the sum of the auxiliary-head losses.

    `outputs` is the model's (logits, ds1, ..., ds5) tuple; every head is scored against
    the same `targets` with `criterion`.
    """
    logits, *aux_outputs = outputs
    aux_loss = sum(criterion(aux, targets) for aux in aux_outputs)
    return criterion(logits, targets) + DEEP_SUPERVISION_WEIGHT * aux_loss


writer = SummaryWriter()
# Training history: one entry per validation round.
train_losses = []
val_losses = []
overall_accuracies = []
precision_scores = []
f1_scores = []
recall_scores = []
best_val_loss = float('inf')  # best validation loss seen so far

# Training loop.
epochs = 50
try:
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}', leave=False):
            targets = targets.to(device).long()
            # Samples carry a leading singleton time axis (DynamicTimePointDataset);
            # presumably collated to (B, 1, C, H, W) -- drop that axis.
            inputs = torch.squeeze(inputs, dim=1).to(device)
            optimizer.zero_grad()

            with autocast():
                loss = deep_supervision_loss(model(inputs), targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()

        train_loss /= len(train_loader)  # average training loss over batches

        if (epoch + 1) % VALIDATE_EVERY != 0:
            continue

        # Validation phase.
        model.eval()
        val_loss = 0.0
        correct_pixels = 0
        total_pixels = 0
        all_predictions = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in valid_loader:
                targets = targets.to(device).long()
                inputs = torch.squeeze(inputs, dim=1).to(device)

                outputs = model(inputs)
                val_loss += deep_supervision_loss(outputs, targets).item()

                # Accuracy is measured on the main head only.
                _, predicted = torch.max(outputs[0], 1)
                correct_pixels += (predicted == targets).sum().item()
                total_pixels += targets.nelement()
                all_predictions.append(predicted.cpu().numpy())
                all_targets.append(targets.cpu().numpy())

        # Module-level on purpose: the plotting cell reads the last round's flattened arrays.
        all_predictions_flattened = np.concatenate(all_predictions).reshape(-1)
        all_targets_flattened = np.concatenate(all_targets).reshape(-1)

        val_loss /= len(valid_loader)
        overall_accuracy = correct_pixels / total_pixels
        precision = precision_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)
        recall = recall_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)
        f1 = f1_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)

        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}, Overall Accuracy: {overall_accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        overall_accuracies.append(overall_accuracy)
        precision_scores.append(precision)
        recall_scores.append(recall)
        f1_scores.append(f1)

        # Record to TensorBoard. Gradients are those left by the last training batch.
        for name, param in model.named_parameters():
            writer.add_histogram(f'Weights/{name}', param, epoch)
            if param.grad is not None:
                writer.add_histogram(f'Gradients/{name}', param.grad, epoch)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/overall', overall_accuracy, epoch)
        writer.add_scalar('Precision', precision, epoch)
        writer.add_scalar('Recall', recall, epoch)
        writer.add_scalar('F1', f1, epoch)
        writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, BEST_MODEL_PATH)
            print(f"Model saved at Epoch {epoch+1}: Improved validation loss to {best_val_loss:.4f}")

        # LR scheduler step, based on validation loss.
        scheduler.step(val_loss)

        # Check whether early stopping is needed.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break
finally:
    # Flush TensorBoard events even if training is interrupted.
    writer.close()

# Later cells evaluate `model`: restore the best checkpoint rather than the last epoch's weights.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

计算mIoU / Рассчитать mIoU / Calculate mIoU

In [None]:
def calculate_iou(predicted, target, num_classes):
    """Mean IoU over classes ``0 .. num_classes - 1`` for one pair of label maps.

    Args:
        predicted: predicted class ids (tensor or array supporting ==, & and .sum()).
        target: ground-truth class ids, same shape as ``predicted``.
        num_classes: number of classes to score.

    Returns:
        Mean of per-class IoU over the classes that occur in ``predicted`` or ``target``
        (classes with an empty union are skipped), or nan if none occurs.
    """
    per_class = []
    for cls in range(num_classes):
        pred_mask = predicted == cls
        true_mask = target == cls
        inter = (pred_mask & true_mask).sum().item()
        union = pred_mask.sum().item() + true_mask.sum().item() - inter
        # A class absent from both maps has no defined IoU; leave it out of the mean.
        if union != 0:
            per_class.append(inter / union)
    return sum(per_class) / len(per_class) if per_class else float('nan')

# Model validation with dataset-level mean IoU.
def validate_and_calculate_iou(model, loader, device, num_classes):
    """Evaluate `model` on `loader`; print and return mIoU, overall accuracy and macro P/R/F1.

    Args:
        model: network whose output tuple starts with the class logits (class axis = 1),
            e.g. UNet3Plus's ``(logits, ds1, ..., ds5)``.
        loader: yields ``(inputs, targets)``; a size-1 dim 1 of ``inputs`` is squeezed out.
        device: device to run the model on.
        num_classes: classes ``0 .. num_classes - 1`` are scored for IoU.

    Returns:
        dict with keys 'mean_iou', 'overall_accuracy', 'precision', 'recall', 'f1'.

    Raises:
        ValueError: if `loader` yields no pixels.
    """
    model.eval()
    # Per-class intersection/union pixel counts summed over the whole loader. mIoU is then
    # computed over the dataset. Averaging per-batch mIoUs would weight every batch
    # equally regardless of size, and would count a class's IoU only in the batches
    # where that class appears.
    intersection = [0] * num_classes
    union = [0] * num_classes
    correct_pixels = 0
    total_pixels = 0
    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for inputs, targets in loader:
            targets = targets.to(device).long()
            inputs = torch.squeeze(inputs, dim=1).to(device)

            # Only the main head is evaluated; the deep-supervision outputs are ignored.
            # (The previous torch.max(outputs, 1) on the whole output tuple raised TypeError.)
            logits = model(inputs)[0]
            _, predicted = torch.max(logits, 1)

            for cls in range(num_classes):
                pred_mask = predicted == cls
                true_mask = targets == cls
                inter = (pred_mask & true_mask).sum().item()
                intersection[cls] += inter
                union[cls] += pred_mask.sum().item() + true_mask.sum().item() - inter

            correct_pixels += (predicted == targets).sum().item()
            total_pixels += targets.nelement()
            all_predictions.append(predicted.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    if total_pixels == 0:
        raise ValueError("loader yielded no samples; cannot compute validation metrics")

    # Flatten the prediction and target arrays for the sklearn metrics.
    all_predictions_flattened = np.concatenate(all_predictions).reshape(-1)
    all_targets_flattened = np.concatenate(all_targets).reshape(-1)

    # Classes absent from both predictions and targets (union == 0) are skipped, as in calculate_iou.
    ious = [i / u for i, u in zip(intersection, union) if u > 0]
    mean_iou = sum(ious) / len(ious) if ious else float('nan')
    overall_accuracy = correct_pixels / total_pixels
    precision = precision_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)
    recall = recall_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)
    f1 = f1_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)
    print(f"Mean IoU on validation set: {mean_iou}, Overall Accuracy: {overall_accuracy:.4f}", f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    return {
        'mean_iou': mean_iou,
        'overall_accuracy': overall_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

验证 / Проверить модель / Validate model

In [None]:
# Evaluate the current model on the validation split.
validate_and_calculate_iou(model, valid_loader, device, num_classes)

# Count model parameters: all of them, and those the optimizer can update.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params}")
print(f"Total parameters: {total_params}")


可视化 / Визуализация / Visualization

In [None]:
# Training curves. Metrics are recorded only on validation epochs (every 2nd epoch in the
# training loop), so plot them against the actual epoch numbers.
EVAL_INTERVAL = 2  # must match the `(epoch + 1) % 2` validation check in the training cell
eval_epochs = [EVAL_INTERVAL * (i + 1) for i in range(len(val_losses))]


def _plot_curves(ax, series, ylabel, title):
    """Plot each (values, label) pair of `series` against `eval_epochs` on `ax`."""
    for values, label in series:
        ax.plot(eval_epochs[:len(values)], values, label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()


# Use one 2x3 grid. Mixing subplot(1, 3, k) with subplot(2, 2, k) made the Recall/F1
# panels overlap or replace the Loss/Accuracy panels.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
_plot_curves(axes[0, 0], [(train_losses, 'Train Loss'), (val_losses, 'Validation Loss')], 'Loss', 'Loss Over Epochs')
_plot_curves(axes[0, 1], [(overall_accuracies, 'Overall Accuracy')], 'Overall Accuracy', 'Overall Accuracy Over Epochs')
_plot_curves(axes[0, 2], [(precision_scores, 'Precision')], 'Precision', 'Precision Over Epochs')
_plot_curves(axes[1, 0], [(recall_scores, 'Recall')], 'Recall', 'Recall Over Epochs')
_plot_curves(axes[1, 1], [(f1_scores, 'F1')], 'F1', 'F1 Over Epochs')
axes[1, 2].axis('off')  # unused sixth panel
fig.tight_layout()
plt.show()

# Confusion matrix of the last validation round. labels= fixes it to all num_classes
# rows/columns, so heatmap tick k is class k even when some class never occurs.
class_ids = list(range(num_classes))
conf_mat = confusion_matrix(all_targets_flattened, all_predictions_flattened, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

print(classification_report(all_targets_flattened, all_predictions_flattened))

In [None]:
# Load a pre-saved validation set for visualization (presumably a list of collatable
# samples -- confirm against the code that wrote valid_data.pth).
# NOTE(review): torch.load unpickles arbitrary Python objects; only load files you trust.
valid_data = torch.load('valid_data.pth')

class SimpleDataset(torch.utils.data.Dataset):
    """Thin Dataset over an indexable, sized container: item i is ``data[i]`` unchanged."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# NOTE(review): this rebinds `valid_dataset` (previously the training-time validation split)
# to the saved samples. Later cells read the loader under the name `valida_loader`.
valid_dataset = SimpleDataset(valid_data)
valida_loader = DataLoader(valid_dataset, batch_size=16, collate_fn=pad_collate, pin_memory=True)


In [None]:
# Load the checkpoint used for visualization. map_location keeps this working when the
# checkpoint was saved on a GPU but the notebook now runs on a CPU-only machine.
model = UNet3Plus(n_channels=10, n_classes=20)
model.load_state_dict(torch.load('best_unet3_204.pth', map_location=device))
model.to(device)

def visualize_overlay(images, labels, predictions, alpha=0.5, num_images=3):
    """Plot, per image: the RGB composite, the true-label overlay and the prediction overlay.

    Args:
        images: batch of multi-band images; ``images[i][[1, 2, 3]]`` must be a (3, H, W) tensor.
        labels: ground-truth class-id maps; ``labels[i]`` is (H, W).
        predictions: predicted class-id maps; ``predictions[i]`` is (H, W).
        alpha: opacity of the class overlays.
        num_images: number of rows to draw; clamped to the number of available images.
    """
    colors = [
        '#FFFFFF',  # white for background class 0
        '#E6194B',  # red for class 1
        '#3CB44B',  # green for class 2
        '#FFE119',  # yellow for class 3
        '#4363D8',  # blue for class 4
        '#F58231',  # orange for class 5
        '#911EB4',  # purple for class 6
        '#46F0F0',  # cyan for class 7
        '#F032E6',  # magenta for class 8
        '#BCF60C',  # lime green for class 9
        '#FABEBE',  # light pink for class 10
        '#008080',  # teal for class 11
        '#E6BEFF',  # lavender for class 12
        '#9A6324',  # brown for class 13
        '#FFFAC8',  # cream for class 14
        '#800000',  # maroon for class 15
        '#AAFFC3',  # mint green for class 16
        '#808000',  # olive for class 17
        '#FFD8B1',  # apricot for class 18
        '#000075',  # navy for class 19
    ]

    cmap_custom = ListedColormap(colors)
    # One bin per integer class id 0..len(colors)-1.
    norm = BoundaryNorm(np.arange(len(colors) + 1), cmap_custom.N)

    # Asking for more rows than the batch holds used to raise IndexError.
    num_images = min(num_images, len(images), len(labels), len(predictions))
    if num_images <= 0:
        return

    # squeeze=False keeps `axs` 2-D, so a single row needs no special case.
    fig, axs = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)
    for i in range(num_images):
        ax1, ax2, ax3 = axs[i]

        # NOTE(review): assumes bands 1, 2, 3 form a displayable RGB triplet -- confirm
        # against the dataset's band order.
        img_display = images[i][[1, 2, 3]].permute(1, 2, 0).cpu().numpy()
        lo, hi = img_display.min(), img_display.max()
        # Min-max scale to [0, 1]; a constant image would divide by zero, so show it as black.
        img_display = (img_display - lo) / (hi - lo) if hi > lo else np.zeros_like(img_display)

        ax1.imshow(img_display)
        ax1.set_title("Original Image - RGB")
        ax1.axis('off')

        ax2.imshow(img_display)
        ax2.imshow(labels[i].cpu().numpy(), cmap=cmap_custom, norm=norm, alpha=alpha)
        ax2.set_title("True Label Overlay")
        ax2.axis('off')

        ax3.imshow(img_display)
        ax3.imshow(predictions[i].cpu().numpy(), cmap=cmap_custom, norm=norm, alpha=alpha)
        ax3.set_title("Prediction Overlay")
        ax3.axis('off')

    plt.show()


# Run the model on a single validation batch and visualize its predictions.
model.eval()
with torch.no_grad():
    for inputs, targets in valida_loader:
        targets = torch.squeeze(targets, dim=1).to(device).long()
        # Two squeezes as before: each drops dim 1 only when it has size 1, otherwise no-op.
        inputs = torch.squeeze(torch.squeeze(inputs, dim=1), dim=1).to(device)

        # UNet3Plus returns (logits, ds1, ..., ds5); unpacking exactly four values raised
        # ValueError. Only the main logits are needed here.
        logits, *_ = model(inputs)
        _, predicted = torch.max(logits, 1)

        visualize_overlay(inputs, targets, predicted, num_images=10)
        break  # one batch is enough for a visual check