In [1]:
import torch
import numpy as np
import torch.nn as nn
import random
import os
import cv2
import logging
import torch.nn.functional as F
import matplotlib.pyplot as plt
import albumentations as A

from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from res_att_unet import ResAttUnet
from eval_metrics import calc_scores
from utils import plot_graph

from datetime import datetime
from dataset import Datasets_1Step, Datasets_2Step, Datasets_3Step

from res_att_unet import ResAttEncoder, ResAttUnet, MLP
from nt_xent import NTXentLoss
from supcon_loss import BlockConLoss
from utils import EarlyStopping, format_time, plot_graph, plot_loss
from validation import validate_model
from dice_loss import SoftDiceLoss

**Hyperparameters and Data Loading paths**

In [2]:
# Optimisation hyper-parameters (PRE_* apply to the pre-training steps)
PRE_LEARNING_RATE = 5e-4
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PRE_BATCH_SIZE = 2
BATCH_SIZE = 4
PRE_NUM_EPOCHS = 100
NUM_EPOCHS = 150
SEED = 42

# Seed every global RNG once, up front. Python's `random` (used by
# `visualize`) and, presumably, the Dataset augmentations draw from the
# Python / NumPy generators rather than torch's, so seeding torch alone
# does not make a fresh Restart-&-Run-All reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data paths (NOTE: despite the `_DIR` suffix, each constant is a single .npz file)
# Step 1 (encoder pre-training): all training images (100%), used without labels
PRE_TRAIN_IMG_DIR = "./data/train_imgs.npz"

# Step 2 (U-Net pre-training): labeled subset + unlabeled remainder with pseudo labels.
# NOTE(review): split sizes are noted as "10" labeled vs "90" unlabeled training
# images — presumably percent; confirm against how the files were generated.
LABELED_IMG_DIR = "./data/imgs/x_train_1.npz"            # images treated as labeled
LABELED_GTS_DIR = "./data/gts/y_train_1.npz"             # masks for the labeled images
UNLABELED_IMG_DIR = "./data/unlabeled_imgs/x_ul_1.npz"   # images treated as unlabeled
PSEUDO_GTS_DIR = "./data/pseudo_gts/pseudo_label_1.npz"  # pseudo labels for the unlabeled images

# Full train / validation split
TRAIN_IMG_DIR = "./data/train_imgs.npz"  # 80% of whole dataset for training
TRAIN_GTS_DIR = "./data/train_gts.npz"   # 80% of whole dataset for training
VAL_IMG_DIR = "./data/val_imgs.npz"      # 20% of whole dataset for validation
VAL_GTS_DIR = "./data/val_gts.npz"       # 20% of whole dataset for validation

**Data preprocessing**

In [None]:
# Data preprocessing: load every split from its .npz archive and wrap each
# training stage's arrays in the matching Dataset class.

def _load_npz(path, key='data'):
    """Return array ``key`` from the .npz archive at ``path`` and close the file.

    ``np.load`` on an .npz returns an ``NpzFile`` that keeps its file handle
    open until garbage-collected; the ``with`` block reads the array into
    memory and closes the handle immediately.

    NOTE(review): ``allow_pickle=True`` can execute arbitrary code when object
    arrays are loaded — only use it with trusted files.
    """
    with np.load(path, allow_pickle=True) as archive:
        return archive[key]


# Step 1 pre-training (encoder): images only, no labels
train_imgs_1step = _load_npz(PRE_TRAIN_IMG_DIR)
train_data_1step = Datasets_1Step(train_imgs_1step)
print(train_imgs_1step.shape)

# Step 2 pre-training: labeled images + unlabeled images with pseudo labels
labeled_imgs = _load_npz(LABELED_IMG_DIR)
labeled_gts = _load_npz(LABELED_GTS_DIR)
unlabeled_imgs = _load_npz(UNLABELED_IMG_DIR)
pseudo_gts = _load_npz(PSEUDO_GTS_DIR)
assert len(labeled_imgs) == len(labeled_gts), "labeled images / masks count mismatch"
assert len(unlabeled_imgs) == len(pseudo_gts), "unlabeled images / pseudo labels count mismatch"
train_imgs_2step = np.concatenate([labeled_imgs, unlabeled_imgs], axis=0)
train_gts_2step = np.concatenate([labeled_gts, pseudo_gts], axis=0)
train_data_2step = Datasets_2Step(train_imgs_2step, train_gts_2step, transform=True)
print(labeled_imgs.shape, labeled_gts.shape, unlabeled_imgs.shape, pseudo_gts.shape)

# Step 3 segmentation training on the labeled subset only
train_dataset_3step = Datasets_3Step(labeled_imgs, labeled_gts, transform=True)


# Full train / validation split (validation is not augmented)
train_imgs = _load_npz(TRAIN_IMG_DIR)
train_gts = _load_npz(TRAIN_GTS_DIR)
val_imgs = _load_npz(VAL_IMG_DIR)
val_gts = _load_npz(VAL_GTS_DIR)
train_dataset = Datasets_3Step(train_imgs, train_gts, transform=True)
val_dataset = Datasets_3Step(val_imgs, val_gts, transform=False)

**EXTRA FUNCTIONS FOR VISUALIZATION**

In [4]:
# Extra functions for visualization
# Class ids and the RGB colour drawn for each: id i is painted selected_class_rgb[i].
selected_class_indice = [0, 1, 2, 3, 4]
selected_class_rgb = [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]]

def color_code_segment(image):
    """Map a label map to an RGB image using ``selected_class_rgb``.

    Args:
        image: NumPy array of class ids of any shape; values are cast to int
            and used as row indices into ``selected_class_rgb``.

    Returns:
        Integer array of shape ``image.shape + (3,)`` with values in 0-255.
    """
    palette = np.array(selected_class_rgb)
    class_ids = image.astype(int)
    return palette[class_ids]


def create_overlay(image, mask, alpha=0.5):
    """Alpha-blend ``mask`` over ``image`` and clip the result to [0, 1].

    2-D (grayscale) inputs are replicated to three channels first, so a
    grayscale image can be combined with an RGB mask.

    Args:
        image: 2-D or 3-D array, expected in [0, 1].
        mask: 2-D or 3-D array blended on top of ``image``.
        alpha: weight of ``mask``; ``image`` gets ``1 - alpha``.

    Returns:
        ``alpha * mask + (1 - alpha) * image`` clipped to [0, 1].
    """
    def _to_rgb(arr):
        # Grayscale -> 3 identical channels on the last axis.
        return np.stack([arr] * 3, axis=-1) if arr.ndim == 2 else arr

    mask_rgb = _to_rgb(mask)
    image_rgb = _to_rgb(image)
    blended = alpha * mask_rgb + (1 - alpha) * image_rgb
    return np.clip(blended, 0, 1)

import random

def visualize(images, masks, max_images=5):
    """Show random samples in three rows: image, colour-coded mask, overlay.

    Args:
        images: indexable collection of images with values in [0, 255]; each
            is 2-D ``(H, W)``, ``(H, W, C)``, or channels-first ``(C, H, W)``
            with ``C`` in {1, 3}.
        masks: indexable collection of integer label maps aligned with
            ``images``, either ``(H, W)`` or ``(1, H, W)``.
        max_images: maximum number of samples (columns) to draw.
    """
    fontsize = 18
    num_images = min(max_images, len(images))  # Limit to `max_images` columns
    if num_images <= 0:
        return  # nothing to draw; plt.subplots rejects 0 columns

    # Randomly select distinct indices from the whole dataset
    indices = random.sample(range(len(images)), num_images)

    # squeeze=False keeps `ax` 2-D (3 x num_images) even for a single column.
    fig, ax = plt.subplots(3, num_images, figsize=(num_images * 4, 10), squeeze=False)

    for col, idx in enumerate(indices):
        image = np.asarray(images[idx]) / 255.0
        if image.ndim == 3 and image.shape[0] in (1, 3) and image.shape[-1] not in (1, 3):
            image = image.transpose(1, 2, 0)  # channels-first -> channels-last for matplotlib
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # imshow does not accept (H, W, 1)

        mask = np.asarray(masks[idx])
        if mask.ndim == 3 and mask.shape[0] == 1:
            mask = mask[0]
        # color_code_segment yields 0-255 RGB; rescale to [0, 1] so it is on the
        # same scale as `image` and the alpha blend is not saturated by clipping.
        mask_rgb = color_code_segment(mask) / 255.0

        ax[0, col].imshow(image, cmap='gray')
        ax[0, col].set_title(f'Image Index: {idx}', fontsize=fontsize)
        ax[0, col].axis('off')

        ax[1, col].imshow(mask_rgb)
        ax[1, col].set_title(f'Ground Truth Index: {idx}', fontsize=fontsize)
        ax[1, col].axis('off')

        ax[2, col].imshow(create_overlay(image, mask_rgb))
        ax[2, col].set_title(f"Overlay Index: {idx}", fontsize=fontsize)
        ax[2, col].axis('off')

    plt.tight_layout()
    plt.show()

**FIRST STEP PRE_TRAINING**

In [None]:
# Model for 1Step pre-training (encoder pre-training)
class Encoder(nn.Module):
    """Encoder plus two heads for step-1 self-supervised pre-training.

    ``backbone`` (a ``ResAttEncoder``) encodes each view. Its first output
    goes to an ``MLP`` projection head (contrastive embeddings) and, for the
    first view only, to a transposed-convolution reconstruction head.

    Args:
        num_classes: output width of the projection head (passed to ``MLP``
            as ``num_class``) — presumably an embedding size rather than a
            number of segmentation classes.
    """

    def __init__(self, num_classes=128):
        super().__init__()
        # Attribute names are kept stable: `backbone` is what gets checkpointed.
        self.backbone = ResAttEncoder(in_channels=3)
        self.projection_head = MLP(512, num_class=num_classes)
        # NOTE(review): kernel_size == stride == 8 upsamples 8x; presumably this
        # matches the encoder's total downsampling factor — confirm.
        self.reconstruction_head = nn.ConvTranspose2d(
            in_channels=512, out_channels=3, kernel_size=8, stride=8
        )

    def forward(self, x1, x2):
        """Return ``(x1, reconstruction of x1, projection of view 1, projection of view 2)``."""
        feat_a, _ = self.backbone(x1)
        feat_b, _ = self.backbone(x2)
        recon_a = self.reconstruction_head(feat_a)
        return x1, recon_a, self.projection_head(feat_a), self.projection_head(feat_b)

# --- Step 1: self-supervised pre-training of the encoder --------------------
# Each batch yields two views (x1, x2). The loss is a beta-weighted mix of an
# NT-Xent contrastive term on the projected views and an MSE reconstruction
# term on x1. Only `model.backbone` is checkpointed; step 2 loads it into the
# U-Net encoder.

# Checkpoint / log paths
NUM_EPOCHS_1Step = 200
PRE_1STEP_BATCH_SIZE = 8
current_date = str(datetime.now().strftime("%Y_%m_%d"))
checkpoint_name = 'checkpoints/1step_pretraining_checkpoint.pth.tar'
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
os.makedirs('log_file', exist_ok=True)

# Dataloader (drop_last discards a final partial batch)
train_loader = DataLoader(train_data_1step, batch_size=PRE_1STEP_BATCH_SIZE, shuffle=True, drop_last=True)
if len(train_loader) == 0:
    raise ValueError(
        f"Step-1 loader is empty: {len(train_data_1step)} samples < batch size {PRE_1STEP_BATCH_SIZE}"
    )

# Model for training — seed every RNG the augmentations / shuffling may use
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
model = Encoder(num_classes=128).to(DEVICE)
loss_nt = NTXentLoss(DEVICE, temperature=0.1, use_cosine_similarity=True)
loss_mse = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.SGD(model.parameters(), lr=PRE_LEARNING_RATE, momentum=0.9)

# force=True: without it basicConfig is a silent no-op once the root logger
# already has handlers (e.g. when this or another cell configured it before).
logging.disable(logging.NOTSET)
logging.basicConfig(filename=os.path.join('log_file', 'pre_training.log'), level=logging.DEBUG, force=True)
early_stopping = EarlyStopping(patience=25, verbose=True, indicator='loss')

loss_trend = []
loss_1_trend = []
loss_2_trend = []
beta = 0.1  # weight of the contrastive term; (1 - beta) weights reconstruction
num_batches = len(train_loader)
nan_detected = False

for epoch in range(NUM_EPOCHS_1Step):
    model.train()
    training_loss = 0.0
    running_loss_1 = 0.0
    running_loss_2 = 0.0
    loop = tqdm(train_loader)
    for i, (x1, x2) in enumerate(loop):
        x1, x2 = x1.float().to(DEVICE), x2.float().to(DEVICE)

        with torch.cuda.amp.autocast():
            z_real, z_recon, z_anchors, z_positive = model(x1, x2)
            loss_1 = loss_nt(z_anchors, z_positive)
            loss_2 = loss_mse(z_real, z_recon)
            loss = loss_1 * beta + loss_2 * (1 - beta)

        # Check every batch, before a NaN loss is back-propagated.
        if torch.isnan(loss).any():
            nan_detected = True
            break

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())
        training_loss += loss.item()
        running_loss_1 += loss_1.item()
        running_loss_2 += loss_2.item()

    if nan_detected:
        print("Loss is NaN. Stopping training.")
        break

    epoch_loss = training_loss / num_batches
    epoch_loss_1 = running_loss_1 / num_batches
    epoch_loss_2 = running_loss_2 / num_batches

    loss_trend.append(epoch_loss)
    loss_1_trend.append(epoch_loss_1)
    loss_2_trend.append(epoch_loss_2)

    logging.debug(f"Epoch: {epoch + 1}\tLoss: {epoch_loss :.8f}")
    print(f"Epoch: {epoch + 1}\tLOSS: {epoch_loss :.4f}\tLOSS_1: {epoch_loss_1 :.4f}\tLOSS_2: {epoch_loss_2 :.4f}")

    checkpoint = {
        'epoch': epoch + 1,
        'loss': epoch_loss,
        'model_state_dict': model.backbone.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }

    early_stopping(epoch_loss, model.backbone, checkpoint, checkpoint_name)

    if early_stopping.early_stop:
        print("Early stopped")
        break

# Wrap-up runs after early stopping AND after the last epoch (previously the
# curves were only plotted on early stop).
if not nan_detected:
    logging.info("Training has finished")
    logging.disable(logging.CRITICAL)
if loss_trend:
    plot_graph(loss_trend, loss_1_trend, loss_2_trend)

**SECOND STEP PRE_TRAINING**

In [None]:
import time

# Model for 2Step Pre-training (model pre-training)
class UNET(nn.Module):
    """``ResAttUnet`` backbone followed by a two-layer 1x1-conv projector.

    Used for step-2 contrastive pre-training; the projector maps the
    backbone's 64-channel output to ``num_classes`` channels.

    Args:
        num_classes: projector output channels — presumably an embedding size,
            not a number of segmentation classes.
    """

    def __init__(self, num_classes=128):
        super().__init__()
        self.backbone = ResAttUnet(in_channels=3)
        # `prejector` (sic) is kept as-is so state_dict keys stay stable.
        # NOTE(review): there is no non-linearity between the two 1x1 convs, so
        # they compose to a single linear map — confirm this is intended.
        self.prejector = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return the projected backbone features (no normalisation applied)."""
        return self.prejector(self.backbone(x))

# --- Step 2: supervised-contrastive pre-training of the whole U-Net ---------
# Two augmented views per sample and their (pseudo-)label maps are fed to
# BlockConLoss on 1x1-conv projections L2-normalised along the channel dim.
# The U-Net encoder starts from the step-1 backbone; `model.backbone` is
# checkpointed for step 3. Training runs in full precision (no AMP).

# Checkpoint / log paths
current_date = str(datetime.now().strftime("%Y_%m_%d"))
checkpoint_name = 'checkpoints/2step_pretraining_checkpoint.pth.tar'
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
os.makedirs('log_file', exist_ok=True)

# Dataloader
dataloader = DataLoader(train_data_2step, batch_size=PRE_BATCH_SIZE, shuffle=True)

# Model for training
model = UNET(num_classes=128).to(DEVICE)
loss_fn = BlockConLoss(temperature=0.1, block_size=16)
optimizer = torch.optim.Adam(model.parameters(), lr=PRE_LEARNING_RATE, weight_decay=1e-5)

# Load the step-1 encoder; map_location lets a GPU-saved checkpoint load on a
# CPU-only machine.
checkpoint = 'checkpoints/1step_pretraining_checkpoint.pth.tar'
stats = torch.load(checkpoint, map_location=DEVICE)
model_state = stats['model_state_dict']
model.backbone.encoder.load_state_dict(model_state)

# force=True: basicConfig is otherwise a no-op if logging was configured earlier.
logging.disable(logging.NOTSET)
logging.basicConfig(filename=os.path.join('log_file', 'pre_training_2step.log'), level=logging.DEBUG, force=True)
logging.info(f"\n\nStart SupCon training for {PRE_NUM_EPOCHS} epochs {current_date}")
early_stopping = EarlyStopping(patience=100, verbose=True, indicator='loss')

losses = []
nan_detected = False
for epoch in range(PRE_NUM_EPOCHS):
    epoch_start_time = time.time()  # Start time for the epoch
    model.train()
    training_loss = 0
    print(f"Epoch: {(epoch + 1)} / {PRE_NUM_EPOCHS}")
    loop = tqdm(dataloader)
    for i, (img_1, img_2, gt_1, gt_2) in enumerate(loop):
        # Stack both views along the batch dim: [view-1 batch; view-2 batch].
        # Uses DEVICE (not a hard-coded .cuda()) so the cell also runs on CPU.
        imgs = torch.cat([img_1.float(), img_2.float()], dim=0).to(DEVICE, non_blocking=True)
        labels = torch.cat([gt_1.long(), gt_2.long()], dim=0).squeeze(1).to(DEVICE, non_blocking=True)

        bsz = imgs.shape[0] // 2
        features = F.normalize(model(imgs), p=2, dim=1)
        # Re-pair the two views per sample: (bsz, 2, ...).
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        l1, l2 = torch.split(labels, [bsz, bsz], dim=0)
        labels = torch.cat([l1.unsqueeze(1), l2.unsqueeze(1)], dim=1)
        loss = loss_fn(features, labels)

        # All-zero loss: nothing to learn from this batch.
        if loss.mean() == 0:
            continue
        # Average only over the non-zero loss entries.
        mask = (loss != 0).to(loss.dtype)
        loss = (loss * mask).sum() / mask.sum()

        # Check every batch, before a NaN loss is back-propagated.
        if torch.isnan(loss).any():
            nan_detected = True
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())
        training_loss += loss.item()

    if nan_detected:
        print("Loss in NaN, training stoped")
        break

    # NOTE: skipped all-zero batches still count in the denominator.
    epoch_loss = training_loss / len(dataloader)

    # End time for the epoch
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time  # Time taken for the epoch
    epoch_duration_formatted = format_time(epoch_duration)  # Format the epoch duration into HH:MM:SS

    losses.append(epoch_loss)
    logging.debug(f"Epoch: {epoch + 1}\t\tLoss: {epoch_loss}")
    print(f"Epoch: {epoch + 1}\tLoss: {epoch_loss:.4f}\tEpoch time: {epoch_duration_formatted}")

    checkpoint = {
        'epoch': epoch + 1,
        'loss': epoch_loss,
        'model_state_dict': model.backbone.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    # Pass the backbone — the module whose weights the checkpoint stores —
    # matching the step-1 call.
    early_stopping(epoch_loss, model.backbone, checkpoint, checkpoint_name)
    if early_stopping.early_stop:
        print("Early stopped")
        break

if not nan_detected:
    logging.disable(logging.CRITICAL)

**THIRD STEP PRE_TRAINING (SEGMENTATION TRAINING)**

In [None]:
# Model for downstream training
class FULL_UNET(nn.Module):
    """Segmentation network: ``ResAttUnet`` backbone + 1x1-conv classifier.

    Args:
        num_class: number of output channels (segmentation classes).
    """

    def __init__(self, num_class=5):
        super().__init__()
        # Attribute names form the state_dict keys saved to / loaded from checkpoints.
        self.backbone = ResAttUnet(in_channels=3)
        self.projection_head = nn.Conv2d(64, num_class, kernel_size=1)

    def forward(self, x):
        """Return per-class scores with ``num_class`` channels (no softmax applied)."""
        features = self.backbone(x)
        return self.projection_head(features)
    
# --- Step 3: supervised segmentation training --------------------------------
# Fine-tunes the step-2 U-Net backbone plus a 1x1-conv classifier on the
# labeled subset with cross-entropy + soft Dice; early stopping tracks val Dice.

# Checkpoint paths — must match the path the evaluation cell below loads from.
current_date = str(datetime.now().strftime("%Y_%m_%d"))
checkpoint_name = 'checkpoints/3step_training_checkpoint.pth.tar'
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
os.makedirs('log_file', exist_ok=True)

# Use the project's validation routine explicitly: a later cell in this
# notebook redefines `validate_model`, and re-running this cell after it would
# otherwise silently pick up that version.
from validation import validate_model as validate_epoch

# Dataloader
train_loader = DataLoader(train_dataset_3step, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

# Model for training — seed every RNG the augmentations / shuffling may use
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
model = FULL_UNET(num_class=5).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
ce_loss = nn.CrossEntropyLoss()
dc_loss = SoftDiceLoss(batch_dice=True, do_bg=False, smooth=1.0, apply_nonlin=torch.nn.Softmax(dim=1))
scaler = torch.cuda.amp.GradScaler()

# Load the step-2 backbone; map_location lets a GPU-saved checkpoint load on CPU.
checkpoint = 'checkpoints/2step_pretraining_checkpoint.pth.tar'
stats = torch.load(checkpoint, map_location=DEVICE)
model_state = stats['model_state_dict']
model.backbone.load_state_dict(model_state)

# force=True: basicConfig is otherwise a no-op if logging was configured earlier.
logging.disable(logging.NOTSET)
logging.basicConfig(filename=os.path.join('log_file', 'training.log'), level=logging.DEBUG, force=True)
early_stopping = EarlyStopping(patience=30, verbose=True, indicator='dice')

loss_trend = []
val_loss_trend = []
accuracy_trend = []
dice_trend = []
iou_trend = []
stopped_abnormally = False
for epoch in range(NUM_EPOCHS):
    model.train()
    training_loss = 0
    nan_detected = False
    loop = tqdm(train_loader)
    print(f"============================ Epoch: {epoch + 1}/{NUM_EPOCHS} ============================")
    for batch_idx, (x, y) in enumerate(loop):
        x = x.float().to(DEVICE)
        y = y.long().to(DEVICE).squeeze(1)

        # forward
        with torch.cuda.amp.autocast():
            pred = model(x)
            loss = ce_loss(pred, y) + dc_loss(pred, y)

        # Check every batch, before a NaN loss is back-propagated.
        if torch.isnan(loss).any():
            nan_detected = True
            break

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())
        training_loss += loss.item()

    if nan_detected:
        print("Loss is nan, training stopped")
        stopped_abnormally = True
        break
    epoch_loss = training_loss / len(train_loader)

    # checkpoints — store the epoch's mean loss as a plain float, not the
    # last batch's loss tensor.
    checkpoint = {
        'epoch': epoch + 1,
        'loss': epoch_loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }

    # check accuracy
    stats = validate_epoch(model, val_loader, DEVICE, 5)
    if stats is None:
        print("Validation failed, stats is None")
        stopped_abnormally = True
        break
    acc = stats.get('acc', 0)
    iou = stats.get('iou', 0)
    dice = stats.get('dice', 0)
    val_loss = stats.get('val_loss', 0)
    val_loss_trend.append(val_loss)

    accuracy_trend.append(acc)
    iou_trend.append(iou)
    dice_trend.append(dice)
    loss_trend.append(epoch_loss)
    logging.debug(f"Epoch: {epoch + 1}\tLoss: {epoch_loss}\tIoU: {iou}\tDice: {dice}")
    print(f"train_loss: {epoch_loss:.4f}\tVal_loss: {val_loss:.4f}\tIoU Score: {iou:.2f}\tDice Score: {dice:.2f}")

    early_stopping(dice, model, checkpoint, checkpoint_name)

    if early_stopping.early_stop:
        print("Early stopped")
        break

# Wrap-up for both early stopping and the final epoch.
if not stopped_abnormally:
    logging.disable(logging.CRITICAL)
if loss_trend:
    plot_graph(iou_trend, dice_trend)
    plot_loss(loss_trend, val_loss_trend)

**VALIDATE AND TEST THE TRAINED MODEL**

In [7]:
def calculate_dice_score(output, target, class_id):
    """Dice, IoU and pixel accuracy of one class for a batch.

    The predicted label map is the argmax over dim 1 of ``softmax(output)``.
    Both it and ``target`` are turned into boolean masks for ``class_id``,
    and confusion counts are taken over every element (the masks broadcast
    against each other).

    Returns:
        ``(dice, iou, acc)`` as Python floats. Dice and IoU add 1e-4 to the
        denominator, so a class absent from both prediction and target scores
        0 rather than 1.
        NOTE(review): that lowers per-batch averages for batches without the
        class — confirm this is the intended convention.
    """
    predicted_labels = torch.argmax(F.softmax(output, dim=1), dim=1)
    is_pred = predicted_labels == class_id
    is_true = target == class_id

    # Confusion counts from the two boolean masks.
    tp = torch.sum(is_pred & is_true).item()
    fp = torch.sum(is_pred & ~is_true).item()
    tn = torch.sum(~is_pred & ~is_true).item()
    fn = torch.sum(~is_pred & is_true).item()

    dice_score = (2 * tp) / (2 * tp + fp + fn + 1e-4)
    iou = tp / (tp + fp + fn + 1e-4)
    acc = (tn + tp) / (tn + fp + fn + tp)
    return dice_score, iou, acc

def validate_model(model, dataloader, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` and print foreground Acc / IoU / Dice.

    For each batch, ``calculate_dice_score`` is computed for every foreground
    class ``1 .. num_classes - 1`` (class 0 is excluded). Per-class scores are
    averaged over batches and scaled to percent. The headline numbers are the
    unweighted means over the foreground classes.

    NOTE: this definition shadows the ``validate_model`` imported from the
    ``validation`` module at the top of the notebook. It returns a stats dict
    so callers that expect one still work after this cell has run.

    Args:
        model: segmentation network returning per-class scores along dim 1.
        dataloader: yields ``(inputs, targets)`` batches; a singleton dim 1 of
            ``targets`` is squeezed away.
        device: device the batches are moved to.
        num_classes: total number of classes, including background class 0.

    Returns:
        dict with ``'acc'``, ``'iou'``, ``'dice'`` (foreground means, percent)
        and ``'per_class'`` mapping each foreground class id to its
        ``{'acc', 'iou', 'dice'}`` percentages.

    Raises:
        ValueError: if ``num_classes < 2`` or ``dataloader`` is empty.
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2 (background + foreground), got {num_classes}")
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to validate")

    foreground_ids = range(1, num_classes)
    metric_names = ("acc", "iou", "dice")
    totals = {name: {class_id: 0.0 for class_id in foreground_ids} for name in metric_names}

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader):
            inputs = inputs.float().to(device)
            targets = targets.long().to(device).squeeze(1)

            outputs = model(inputs)

            for class_id in foreground_ids:
                dice_score, iou, acc = calculate_dice_score(outputs, targets, class_id)
                totals["dice"][class_id] += dice_score
                totals["iou"][class_id] += iou
                totals["acc"][class_id] += acc

    # Average over batches, expressed in percent.
    averages = {
        name: {class_id: total / num_batches * 100 for class_id, total in per_class.items()}
        for name, per_class in totals.items()
    }
    means = {name: sum(per_class.values()) / len(per_class) for name, per_class in averages.items()}

    print(f"Acc: {means['acc']:.2f}\tIoU: {means['iou']:.2f}\tDice: {means['dice']:.2f}")

    for class_id in foreground_ids:
        print(f"Acc {class_id}: {averages['acc'][class_id]:.2f}\tIoU {class_id}: {averages['iou'][class_id]:.2f}\tDice {class_id}: {averages['dice'][class_id]:.2f}")

    return {
        "acc": means["acc"],
        "iou": means["iou"],
        "dice": means["dice"],
        "per_class": {
            class_id: {name: averages[name][class_id] for name in metric_names}
            for class_id in foreground_ids
        },
    }


In [None]:
# Validation on data: restore the best step-3 checkpoint into `model` (in
# place, so the plotting cell below also uses these weights) and report scores.
trained_model = model.to(DEVICE)
# Load Checkpoint; map_location lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = 'checkpoints/3step_training_checkpoint.pth.tar'
stats = torch.load(checkpoint, map_location=DEVICE)
model_state = stats['model_state_dict']
model.load_state_dict(model_state)

val_scores = validate_model(trained_model, val_loader, DEVICE, num_classes=5)

**PLOTTING SOME EXAMPLES FOR VISUALIZATION**

In [None]:
# Qualitative check on one validation example: input image, full colour-coded
# ground truth vs. prediction, then one GT / prediction column per foreground class.
idx = 158  # L3 — index of the validation example to display
if not 0 <= idx < len(val_imgs):
    raise IndexError(f"idx={idx} is out of range for {len(val_imgs)} validation images")
trained_model = model.to(DEVICE)
trained_model.eval()  # inference mode for any dropout / batch-norm layers

# Select a specific image, scale, resize and move it to DEVICE.
# NOTE(review): assumes val_imgs[idx] is channels-first in [0, 255] and that
# this matches the Dataset's preprocessing — confirm against Datasets_3Step.
image = val_imgs[idx] / 255.0
image = A.Resize(256, 256, interpolation=cv2.INTER_LINEAR)(image=image.transpose(1, 2, 0))['image']
image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)  # (1, C, 256, 256)
image_gpu = image.to(DEVICE)

# Resize the matching label with nearest-neighbour so class ids stay integral.
# NOTE(review): assumes val_gts[idx] is a 2-D (H, W) label map.
label = val_gts[idx]
label_cpu = A.Resize(256, 256, interpolation=cv2.INTER_NEAREST)(image=label)['image']

# Model prediction (no autograd graph is needed for inference)
with torch.no_grad():
    preds = trained_model(image_gpu)
    preds = F.softmax(preds, dim=1)
    pred_labels = torch.argmax(preds, dim=1)
print(pred_labels.shape)

# Move image and prediction back to CPU for visualization
image_cpu = image_gpu.squeeze().to('cpu').numpy()
pred_cpu = pred_labels.squeeze().to('cpu').numpy()

# Print the shapes
print(f"image shape: {image_cpu.shape}")
print(f"pred shape: {pred_cpu.shape}")
print(f"label shape: {label_cpu.shape}")

class_names = ['MUS', 'IMAT', 'SAT', 'VAT']
class_ids = [1, 2, 3, 4]
num_classes = len(class_ids)

# Visualization: row 0 = ground truth, row 1 = prediction
fig, axs = plt.subplots(2, num_classes+2, figsize=(5*(num_classes+2), 11), dpi=300)
title_font = 20

# Column 0: Input Image
axs[0, 0].imshow(image_cpu.transpose(1, 2, 0), cmap='gray')
axs[0, 0].set_title('Input', fontsize=title_font)
axs[0, 0].axis('off')

axs[1, 0].imshow(image_cpu.transpose(1, 2, 0), cmap='gray')
axs[1, 0].set_title('Input', fontsize=title_font)
axs[1, 0].axis('off')

# Column 1: All Classes
axs[0, 1].imshow(color_code_segment(label_cpu))
axs[0, 1].set_title('Ground Truth', fontsize=title_font)
axs[0, 1].axis('off')

axs[1, 1].imshow(color_code_segment(pred_cpu))
axs[1, 1].set_title('Predicted', fontsize=title_font)
axs[1, 1].axis('off')

# Columns 2..: one column per foreground class
for i, (class_id, class_name) in enumerate(zip(class_ids, class_names)):
    col = i + 2
    label_mask = (label_cpu == class_id)
    pred_mask = (pred_cpu == class_id)

    if not label_mask.any() and not pred_mask.any():
        # Class absent from both maps: leave the column blank (no empty axes).
        axs[0, col].axis('off')
        axs[1, col].axis('off')
        continue

    label_class = label_cpu * label_mask
    pred_class = pred_cpu * pred_mask

    # Row 0: ground truth for this class
    axs[0, col].imshow(color_code_segment(label_class))
    axs[0, col].set_title(f" GT: {class_name}", fontsize=title_font)
    axs[0, col].axis('off')

    # Row 1: prediction for this class
    axs[1, col].imshow(color_code_segment(pred_class))
    axs[1, col].set_title(f" Pred: {class_name}", fontsize=title_font)
    axs[1, col].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Environment diagnostics: release cached GPU memory and report library versions.
import gc

import cv2
import torch
import torchvision

# Collect unreachable Python objects first so the tensors they hold are freed,
# then release the allocator's now-unused cached blocks in a single call.
gc.collect()
torch.cuda.empty_cache()

print(torch.version.cuda)
print(torch.cuda.is_available())
print(cv2.__version__)

print("torch version:", torch.__version__)
print("torchvision version:", torchvision.__version__)