**Create predictions (pseudo-labels) for the unlabeled images using the trained model obtained after the supervised phase.**

In [None]:
# Load the local project/data directories from config.json.
# The relative path means the file is looked up in the current working directory.
import json

# JSON is read explicitly as UTF-8 rather than with the platform's locale
# encoding, so non-ASCII characters in the configured paths decode the same
# way on every OS.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# A missing key raises KeyError here, before any other cell runs.
BASE_PATH = config["PROJECT_DIR"]
# Other cells append sub-folder names to DATA_PATH by string concatenation
# (e.g. DATA_PATH + 'Unlabeled/'), so the configured value is expected to end
# with a path separator.
DATA_PATH = config["DATA_DIR"]

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu
import cv2
import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import metrics, losses, base
import random
import matplotlib.pyplot as plt
import os
from copy import deepcopy
from datetime import datetime
import torch.nn.functional as F

%matplotlib inline

## Dataloader

In [None]:
class Dataset(BaseDataset):
    """Image-only dataset used for inference on unlabeled images.

    Each item is read from disk with OpenCV, converted from BGR to RGB,
    optionally resized, and passed through the optional preprocessing
    pipeline. No masks are read and no augmentation is applied.

    Args:
        list_IDs (list[str]): image file names, relative to ``images_dir``
        images_dir (str): path to images folder
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        resize (tuple): ``(enabled, size)``; the image is resized only when
            ``enabled`` is truthy. ``size`` is forwarded to ``cv2.resize`` as
            its ``dsize`` argument, i.e. ``(width, height)``.
        n_classes (int): number of classes; stored on the instance but not
            used by this class itself.

    """

    def __init__(
            self,
            list_IDs,
            images_dir,
            preprocessing=None,
            resize=(False, (256, 256)), # To resize, the first value has to be True
            n_classes:int=4,
    ):
        self.ids = list_IDs
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]

        self.preprocessing = preprocessing
        self.resize = resize
        self.n_classes = n_classes

    def __getitem__(self, i):
        """Return the (optionally resized and preprocessed) image at index ``i``.

        Raises:
            FileNotFoundError: if the file is missing or cannot be decoded
                as an image.
        """
        image_path = self.images_fps[i]

        # cv2.imread does not raise on failure, it returns None; fail here with
        # the offending path instead of an opaque error inside cvtColor.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.resize[0]:
            image = cv2.resize(image, self.resize[1], interpolation=cv2.INTER_NEAREST)

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image)
            image = sample['image']

        return image

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.ids)

## Augmentation

In [None]:
def get_training_augmentation():
    """Build the training augmentation pipeline.

    The pipeline is four ``albu.OneOf`` groups applied in order (flips,
    scale/rotate/shift, geometric or noise/blur degradations, colour and
    contrast changes), wrapped in an ``albu.Compose`` with ``p=0.9``.

    Return:
        transform: albumentations.Compose
    """

    def _affine(scale, rotate, shift, prob):
        # Every ShiftScaleRotate variant uses border_mode=0.
        return albu.ShiftScaleRotate(
            scale_limit=scale,
            rotate_limit=rotate,
            shift_limit=shift,
            p=prob,
            border_mode=0,
        )

    flip_group = albu.OneOf(
        [albu.HorizontalFlip(p=0.5), albu.VerticalFlip(p=0.5)],
        p=0.8,
    )

    affine_group = albu.OneOf(
        [
            _affine(0.5, 0, 0, 0.1),     # scale only
            _affine(0, 30, 0, 0.1),      # rotate only
            _affine(0, 0, 0.1, 0.6),     # shift only
            _affine(0.5, 30, 0.1, 0.2),  # affine transform
        ],
        p=0.9,
    )

    degradation_group = albu.OneOf(
        [
            albu.Perspective(p=0.2),
            albu.GaussNoise(p=0.2),
            albu.Sharpen(p=0.2),
            albu.Blur(blur_limit=3, p=0.2),
            albu.MotionBlur(blur_limit=3, p=0.2),
        ],
        p=0.5,
    )

    colour_group = albu.OneOf(
        [
            albu.CLAHE(p=0.25),
            albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.25),
            albu.RandomGamma(p=0.25),
            albu.HueSaturationValue(p=0.25),
        ],
        p=0.3,
    )

    pipeline = [flip_group, affine_group, degradation_group, colour_group]

    # 90% augmentation probability
    return albu.Compose(pipeline, p=0.9)


def get_validation_augmentation():
    """Return the validation/test-time augmentation pipeline.

    The pipeline is currently empty: the padding step
    (``albu.PadIfNeeded(512, 512)``) is disabled, so no transform is
    registered in the returned ``albu.Compose``.
    """
    return albu.Compose([])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Axes are reordered ``(0, 1, 2) -> (2, 0, 1)``. Extra keyword arguments
    are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied to every sample.

    Two steps run in order: ``preprocessing_fn`` on the image, then
    ``to_tensor`` on the image (and on the mask, when one is supplied).

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose

    """
    normalize_step = albu.Lambda(image=preprocessing_fn)
    layout_step = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize_step, layout_step])

## Parameters

In [None]:
# Parameters
# NOTE(review): several of these (BATCH_SIZE, EPOCHS, SAVE_*, PERIOD, RETRAIN,
# EARLY_STOP, PATIENCE) are not referenced by the inference code visible in this
# notebook; presumably they are carried over from the training notebook.
BASE_MODEL = 'MiT+pscse'
ENCODER = 'mit_b3' # encoder name passed to smp.Unet and to get_preprocessing_fn
ENCODER_WEIGHTS = 'imagenet' # pretrained encoder weights identifier
BATCH_SIZE = 16 # the inference DataLoader in this notebook uses batch_size=1 instead
n_classes = 4 # number of output classes of the segmentation model
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # GPU when available, otherwise CPU
LR = 0.0001 # learning rate
EPOCHS = 500
WEIGHT_DECAY = 1e-5 # weight decay given to the Adam optimizer
SAVE_WEIGHTS_ONLY = True
RESIZE = (False, (256,256)) # if resize needed: (enabled flag, target size)
TO_CATEGORICAL = True # when True, predictions are reduced with argmax over the class dimension
SAVE_BEST_MODEL = True
SAVE_LAST_MODEL = False

PERIOD = 10 # periodically save checkpoints
# NOTE(review): the inference loop appends the prediction *after* the optional
# argmax, so with TO_CATEGORICAL=True the stored values are class indices.
RAW_PREDICTION = False # if true, then stores raw predictions (i.e. before applying threshold)
RETRAIN = False

# For early stopping
EARLY_STOP = True # True to activate early stopping
PATIENCE = 50 # for early stopping

# Replaces the default HTTPS context factory with the unverified one, so HTTPS
# requests made through the standard library skip certificate verification for
# the rest of the session.
# NOTE(review): presumably a workaround for downloading pretrained encoder
# weights behind a proxy / broken CA store; this is a security trade-off --
# confirm it is still needed.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


## Load model

In [None]:
# Create the segmentation model with a pretrained encoder, then restore the
# weights obtained in the supervised phase.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=n_classes,
    activation=ACTIVATION,
    decoder_attention_type='pscse',
)

# Normalization function matching the encoder / pretrained weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

model.to(DEVICE)

# Optimizer (created so that the optimizer state stored in the checkpoint can be restored)
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY),
])

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                              factor=0.1,
                              mode='min',
                              patience=10,
                              min_lr=0.00001,
                              verbose=True,
                              )

# Model name that will be loaded
model_name = 'MiT+pscse_padded_aug_mit_b3_sup_2025-03-17_10-12-44'
print(model_name)

# Checkpoint directory
checkpoint_loc = os.path.join(DATA_PATH, 'checkpoints', model_name)

# Load model====================================================================
# map_location remaps the stored tensors onto DEVICE; without it a checkpoint
# saved on a GPU cannot be deserialized on a CPU-only machine, even though
# DEVICE already falls back to "cpu".
checkpoint = torch.load(
    os.path.join(checkpoint_loc, 'best_model.pth'),
    map_location=DEVICE,
)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# This notebook only runs inference: make sure dropout / batch-norm layers are
# in evaluation mode.
model.eval()

## Read unsupervised image names

In [None]:
# Folder holding the unlabeled images
x_test_dir = DATA_PATH + 'Unlabeled/'

# os.listdir returns entries in arbitrary order and also lists hidden files
# (e.g. .DS_Store, .ipynb_checkpoints) and sub-directories, which cannot be
# read as images. Keep regular, non-hidden files only and sort them so the
# processing order is reproducible.
list_IDs_test = sorted(
    entry
    for entry in os.listdir(x_test_dir)
    if not entry.startswith('.')
    and os.path.isfile(os.path.join(x_test_dir, entry))
)

print('No. of test images: ', len(list_IDs_test))

## Inference

In [None]:
# Test dataloader ==============================================================
test_dataset = Dataset(
    list_IDs_test,
    x_test_dir,
    preprocessing=get_preprocessing(preprocessing_fn),
    resize=RESIZE,
    n_classes=n_classes,
)

# batch_size=1 and shuffle=False: batches come out one image at a time, in the
# same order as list_IDs_test, which the loop below relies on to name outputs.
test_dataloader = DataLoader(test_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=6)

# Prediction ===================================================================
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix
import scipy.io as sio

import warnings
warnings.filterwarnings("ignore")

save_pred = True
threshold = 0.5
ep = 1e-6
raw_pred = []

HARD_LINE = True

# Save directories (created if missing)
save_dir_pred = os.path.join(DATA_PATH, 'predictions', model_name + '_selfsupervised')
os.makedirs(save_dir_pred, exist_ok=True)

# NOTE: the folder name 'predictions_palettte' (three t's) is kept as-is so
# that existing outputs and downstream code keep pointing at the same location.
save_dir_pred_palette = os.path.join(DATA_PATH, 'predictions_palettte', model_name + '_selfsupervised')
os.makedirs(save_dir_pred_palette, exist_ok=True)

# One RGB colour per class index, used for the palettised visualisation
palette = [[128, 128, 128], [255, 0, 0], [0, 255, 0], [0, 0, 255]]
# Flattened once, outside the loop, into the byte sequence putpalette expects
palette_bytes = np.array(palette, dtype=np.uint8).tobytes()

for i, (image_id, image) in enumerate(zip(list_IDs_test, test_dataloader)):

    name = os.path.splitext(image_id)[0] # remove extension

    print('Processing:', i, name)

    # Move the image batch to the target device and run the model
    pr_mask = model.predict(image.to(DEVICE))

    # Convert from per-class scores to class indices
    if TO_CATEGORICAL:
        pr_mask = torch.argmax(pr_mask, dim=1)

    # Move to CPU and convert to numpy
    pred = pr_mask.squeeze().cpu().numpy()

    # Keep the prediction in memory if requested
    if RAW_PREDICTION:
        raw_pred.append(pred)

    if save_pred:
        label_map = np.squeeze(pred).astype(np.uint8)

        # Non-palette prediction: class indices stored as pixel values, under
        # the source image's file name (and therefore its extension).
        # NOTE(review): this is only safe for lossless formats such as PNG --
        # confirm the unlabeled images are not stored in a lossy format.
        label_path = os.path.join(save_dir_pred, image_id)
        # cv2.imwrite reports failure through its return value, not an exception
        if not cv2.imwrite(label_path, label_map):
            print('WARNING: could not write prediction to', label_path)

        # Palette prediction: same class indices, coloured through `palette`
        pal_pred = Image.fromarray(label_map).convert("P")
        pal_pred.putpalette(palette_bytes)
        pal_pred.save(os.path.join(save_dir_pred_palette, image_id))
