# Experimental code used for microplastic segmentation and classification with U-Net, VGG16, and LR.
Work done by: Kristupas (https://github.com/KristupasJon)

Mount drive

In [None]:
# Mount Google Drive so that the dataset paths under /content/drive used below are reachable.
# This import is only available inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Set up imports, file paths, and other configuration.

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
import cv2
import matplotlib.pyplot as plt
import random
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, jaccard_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight

# Input resolution and channel count; used as the default U-Net input shape,
# as the resize target when loading images, and as the VGG16 input shape.
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 3
# Training hyperparameters for the segmentation model.
BATCH_SIZE = 8
EPOCHS = 80
# Number of particle classes.
# NOTE(review): not referenced anywhere else in the visible code.
NUM_CLASSES = 3

# Dataset locations on the mounted Google Drive.
IMAGES_DIR = '/content/drive/MyDrive/bachelors/dataset/images'
MASKS_DIR = '/content/drive/MyDrive/bachelors/dataset/masks'
CSV_LABELS = '/content/drive/MyDrive/bachelors/dataset/manual_labels.csv'
SEG_OUTPUT_DIR = '/content/drive/MyDrive/bachelors/dataset/segmented_masks'
VAL_CSV_LABELS = '/content/drive/MyDrive/bachelors/dataset/val_labels.csv'

def load_labels(csv_path, class_map=None):
    """Read a label CSV and add an integer ``label_idx`` column.

    Args:
        csv_path: Path to a CSV file that contains a ``class_name`` column.
        class_map: Optional mapping from class name to integer index. When
            omitted, ``{'oval': 0, 'string': 1, 'other': 2}`` is used, which
            is the same default as ``evaluate_vgg_classifier_with_csv``.

    Returns:
        The loaded DataFrame with an added ``label_idx`` column. Class names
        that are not keys of the mapping produce NaN in that column.
    """
    if class_map is None:
        class_map = {'oval': 0, 'string': 1, 'other': 2}
    df = pd.read_csv(csv_path)
    df['label_idx'] = df['class_name'].map(class_map)
    return df

# Manually assigned class labels; passed to the feature-extraction step further down.
labels_df = load_labels(CSV_LABELS)


Define U-Net Segmentation Model. Prepare validation and training splits from images.

In [None]:
def conv_block(inputs, filters):
    """Apply two consecutive Conv2D(3x3, 'same') -> BatchNormalization -> ReLU stages.

    Args:
        inputs: Tensor fed into the first convolution.
        filters: Number of filters used by both convolutions.

    Returns:
        The output tensor of the second ReLU.
    """
    tensor = inputs
    # Two identical stages; layers are created in the order conv, norm, activation.
    for _ in range(2):
        tensor = layers.Conv2D(filters, 3, padding='same')(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.ReLU()(tensor)
    return tensor

def encoder_block(inputs, filters):
    """Run a convolution block and max-pool its output.

    Args:
        inputs: Tensor entering this encoder stage.
        filters: Number of filters for the convolution block.

    Returns:
        A ``(skip, pooled)`` tuple: the convolution block output (kept for the
        decoder's skip connection) and its max-pooled version.
    """
    skip = conv_block(inputs, filters)
    pooled = layers.MaxPooling2D()(skip)
    return skip, pooled

def decoder_block(inputs, skip, filters):
    """Upsample, concatenate with the skip tensor, and run a convolution block.

    Args:
        inputs: Tensor coming from the previous (deeper) stage.
        skip: Encoder tensor concatenated after upsampling.
        filters: Number of filters for the transposed and regular convolutions.

    Returns:
        The output tensor of the convolution block.
    """
    upsampled = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(inputs)
    merged = layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

def build_unet(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)):
    """Assemble the U-Net: four encoder stages, a bottleneck, four decoder stages.

    Args:
        shape: Input tensor shape passed to ``layers.Input``.

    Returns:
        A Keras ``Model`` whose output is a single-channel sigmoid map.
    """
    model_input = layers.Input(shape)

    # Contracting path: remember each stage's pre-pooling output for the skips.
    skips = []
    tensor = model_input
    for width in (64, 128, 256, 512):
        skip, tensor = encoder_block(tensor, width)
        skips.append(skip)

    tensor = conv_block(tensor, 1024)

    # Expanding path: consume the skips from deepest to shallowest.
    for width in (512, 256, 128, 64):
        tensor = decoder_block(tensor, skips.pop(), width)

    mask = layers.Conv2D(1, 1, activation='sigmoid')(tensor)
    return Model(model_input, mask)

def load_image_mask_pair(image_path, mask_path, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    """Load an image and its mask, resized to ``target_size``.

    Args:
        image_path: Path of the colour image.
        mask_path: Path of the mask, read as grayscale.
        target_size: ``(height, width)`` both files are resized to on load.

    Returns:
        ``(image, mask)`` as float32 arrays. The image is scaled to [0, 1];
        the mask is binarised to exactly 0.0 / 1.0 at the 0.5 threshold.
    """
    image = load_img(image_path, target_size=target_size)
    image = img_to_array(image).astype(np.float32) / 255.0

    mask = load_img(mask_path, color_mode='grayscale', target_size=target_size)
    mask = img_to_array(mask).astype(np.float32) / 255.0
    # Binarise and keep float32: np.where with Python float scalars would
    # silently promote the mask to float64, unlike the image.
    mask = (mask > 0.5).astype(np.float32)
    return image, mask

def data_generator(image_paths, mask_paths, batch_size=BATCH_SIZE):
    """Yield shuffled ``(images, masks)`` batches forever.

    The index order is reshuffled at the start of every pass over the data.
    The last batch of a pass may be smaller than ``batch_size``.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, aligned with ``image_paths``.
        batch_size: Maximum number of samples per batch.

    Yields:
        A tuple of two arrays stacked along axis 0: the images and the masks
        returned by ``load_image_mask_pair``.

    Raises:
        ValueError: On first iteration, if the inputs are empty, have
            different lengths, or ``batch_size`` is not positive.
    """
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and mask_paths "
            f"({len(mask_paths)}) must have the same length."
        )
    # Without these guards the `while True` below would spin forever
    # without ever yielding a batch.
    if len(image_paths) == 0:
        raise ValueError("data_generator needs at least one image/mask pair.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    idxs = np.arange(len(image_paths))
    while True:
        np.random.shuffle(idxs)
        for start in range(0, len(idxs), batch_size):
            batch_idxs = idxs[start:start + batch_size]
            imgs, msks = [], []
            for i in batch_idxs:
                img, msk = load_image_mask_pair(image_paths[i], mask_paths[i])
                imgs.append(img)
                msks.append(msk)
            yield np.stack(imgs, axis=0), np.stack(msks, axis=0)

def augment(image, mask):
    """Randomly flip, rotate and brightness-jitter an image/mask pair.

    The flip and the rotation are applied identically to both tensors; the
    brightness change is applied to the image only.

    NOTE(review): this function is not called anywhere in the visible code.

    Args:
        image: Image tensor.
        mask: Mask tensor matching ``image``.

    Returns:
        The augmented ``(image, mask)`` pair.
    """
    flip_draw = tf.random.uniform(())
    if flip_draw > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)

    # Number of 90-degree rotations, drawn from {0, 1, 2, 3}.
    quarter_turns = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=quarter_turns)
    mask = tf.image.rot90(mask, k=quarter_turns)

    jittered = tf.image.random_brightness(image, max_delta=0.1)
    return jittered, mask

# Pair images and masks by file name: only names present in BOTH directories are kept.
img_files = set(os.listdir(IMAGES_DIR))
mask_files = set(os.listdir(MASKS_DIR))
# Sorting makes the pairing (and therefore the seeded split below) reproducible.
common = sorted(img_files.intersection(mask_files))
all_img_paths = [os.path.join(IMAGES_DIR, f) for f in common]
all_mask_paths = [os.path.join(MASKS_DIR, f) for f in common]

# 80/20 train/validation split with a fixed seed; image and mask lists are split in lockstep.
train_imgs, val_imgs, train_msks, val_msks = train_test_split(all_img_paths, all_mask_paths, test_size=0.2, random_state=42)

Instantiate and compile the model, then prepare the data generators and callbacks for training.

In [None]:
# Keras metrics tracked during U-Net training. The Lithuanian names
# ("Tikslumas" = accuracy, "Preciziškumas" = precision, "Atkūrimas" = recall)
# are the keys later read from `history.history` by the plotting cell.
accuracy = tf.keras.metrics.BinaryAccuracy(name='Tikslumas')
precision = tf.keras.metrics.Precision(name='Preciziškumas')
recall    = tf.keras.metrics.Recall(name='Atkūrimas')


In [None]:
# Build a fresh U-Net with the default input shape.
unet = build_unet()

# Binary cross-entropy on the single sigmoid output channel.
unet.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[accuracy, precision, recall]
)

# The generators never terminate, so Keras is given explicit step counts below.
train_gen = data_generator(train_imgs, train_msks)
val_gen = data_generator(val_imgs, val_msks)
# Floor division: a trailing partial batch is not counted towards an epoch.
# NOTE(review): vsteps is 0 when there are fewer than BATCH_SIZE validation images — confirm the dataset is large enough.
steps = len(train_imgs)//BATCH_SIZE
vsteps = len(val_imgs)//BATCH_SIZE
# Checkpoint only the best model seen so far, and stop early with patience 20,
# restoring the best weights. Both callbacks use their default monitored quantity.
callbacks = [ModelCheckpoint(f'unet{EPOCHS}.h5', save_best_only=True), EarlyStopping(patience=20, restore_best_weights=True)]

Train the model.

In [None]:
# Train the U-Net; `history` is consumed by the metric-plotting cell below.
history = unet.fit(train_gen, steps_per_epoch=steps, validation_data=val_gen, validation_steps=vsteps, epochs=EPOCHS, callbacks=callbacks)

# Save the trained model under a relative path (i.e. in the current working directory).
unet.save(f'unet_model_{EPOCHS}.h5')


Plot metrics during segmentation model training

In [None]:


def plot_train_metrics_separate(history, metrics_other, metric_acc='accuracy'):
    """Plot training curves in two separate figures.

    The first figure shows ``metric_acc`` per epoch; the second shows every
    metric listed in ``metrics_other`` on shared axes.

    Args:
        history: Object exposing a ``history`` dict that maps metric names to
            per-epoch value lists (as returned by ``Model.fit``).
        metrics_other: Iterable of metric names drawn in the second figure.
        metric_acc: Name of the metric drawn in the first figure; it also
            determines the number of epochs on the x axis.
    """
    epochs = range(1, len(history.history[metric_acc]) + 1)
    font_kwargs = dict(fontsize=12, fontweight='light')
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    plt.figure(figsize=(16, 6))
    plt.plot(
        epochs,
        history.history[metric_acc],
        label=metric_acc,
        linewidth=1.0,
        color=colors[0]
    )
    # NOTE(review): "kas" in this title looks like a stray word — confirm the intended wording.
    plt.title("Tikslumas kas treniravimo metu")
    plt.xlabel("Epocha")
    plt.ylabel("Tikslumas")
    plt.xticks(epochs)
    plt.grid(True)
    plt.legend(loc="lower right", frameon=False)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(16, 6))
    for idx, m in enumerate(metrics_other):
        plt.plot(
            epochs,
            history.history[m],
            label=m,
            linewidth=1.0,
            # Skip the colour used for the accuracy curve, and wrap around so
            # that a long metric list cannot index past the colour cycle.
            color=colors[(idx + 1) % len(colors)]
        )
    plt.title("Preciziškumas ir Atkūrimas treniravimo metu")
    plt.xlabel("Epocha")
    plt.ylabel("Vertė", **font_kwargs)
    plt.xticks(epochs, **font_kwargs)
    plt.grid(True)
    plt.legend(loc="best", frameon=False)
    plt.tight_layout()
    plt.show()


# These names must match the `name=` given to the Keras metric objects at compile time.
metrics_to_plot = ['Preciziškumas', 'Atkūrimas']
plot_train_metrics_separate(history, metrics_to_plot, metric_acc='Tikslumas')


Load model weights to skip training if needed.

In [None]:
# Alternative checkpoint, kept for reference:
#unet.load_weights('/content/drive/MyDrive/bachelors/unet.h5')
# NOTE(review): the file name refers to 60 epochs while EPOCHS is 80 and the training cell
# saves to a relative path — confirm this is the intended checkpoint.
unet.load_weights('/content/drive/MyDrive/bachelors/unet_model_60.h5')

U-net model segmentation evaluation

In [None]:
def evaluate_segmentation_model(model, image_paths, mask_paths, batch_size=8, steps=None, threshold=0.5):
    """Compute pixel-level binary segmentation metrics over a set of images.

    Images are processed in their given order, in batches of ``batch_size``.
    Predictions are thresholded with ``>= threshold`` and compared with the
    ground-truth masks over all pixels of all evaluated images.

    Args:
        model: Object with a ``predict`` method returning per-pixel scores.
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths aligned with ``image_paths``.
        batch_size: Number of images per prediction batch.
        steps: Number of batches to evaluate. ``None`` evaluates everything;
            values larger than the number of available batches are capped.
        threshold: Score at or above which a pixel counts as positive.

    Returns:
        ``(metrics, counts, raw)`` where ``metrics`` holds accuracy, precision,
        recall, jaccard and f1_score; ``counts`` holds TP/TN/FP/FN; ``raw``
        holds the flattened ``y_true``, ``y_score`` and ``y_pred`` arrays.

    Raises:
        ValueError: If there is nothing to evaluate, the path lists differ in
            length, or ``batch_size`` / ``steps`` is not positive.
    """
    n = len(image_paths)
    if n == 0:
        raise ValueError("No images supplied for evaluation.")
    if len(mask_paths) != n:
        raise ValueError(
            f"image_paths ({n}) and mask_paths ({len(mask_paths)}) must have the same length."
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    available_steps = int(np.ceil(n / batch_size))
    if steps is None:
        steps = available_steps
    elif steps < 1:
        raise ValueError(f"steps must be positive, got {steps}.")
    else:
        # Asking for more batches than exist used to leak StopIteration.
        steps = min(steps, available_steps)

    y_true_all = []
    y_score_all = []

    for step in range(steps):
        start = step * batch_size
        stop = start + batch_size
        imgs, msks = [], []
        for img_path, msk_path in zip(image_paths[start:stop], mask_paths[start:stop]):
            img, msk = load_image_mask_pair(img_path, msk_path)
            imgs.append(img)
            msks.append(msk)
        imgs = np.stack(imgs, 0)
        msks = np.stack(msks, 0)

        preds = model.predict(imgs)

        y_true_all.append(msks.reshape(-1))
        y_score_all.append(preds.reshape(-1))

    y_true = np.concatenate(y_true_all)
    y_score = np.concatenate(y_score_all)

    y_pred = (y_score >= threshold).astype(np.uint8)

    # labels=[0, 1] keeps the matrix 2x2 even when one class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    jaccard = jaccard_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    metrics = {
        'accuracy':  accuracy,
        'precision': precision,
        'recall':    recall,
        'jaccard':   jaccard,
        'f1_score':  f1
    }
    counts = {
        'TP': tp,
        'TN': tn,
        'FP': fp,
        'FN': fn
    }
    raw = {
        'y_true':  y_true,
        'y_score': y_score,
        'y_pred':  y_pred
    }

    print(f"Accuracy : {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1 Score : {f1:.4f}")
    print(f"Jaccard  : {jaccard:.4f}")
    print(f"TP={tp}, TN={tn}, FP={fp}, FN={fn}")

    return metrics, counts, raw

# Evaluate the U-Net on the held-out validation split (all batches, default 0.5 threshold).
metrics, counts, raw = evaluate_segmentation_model(
    model=unet,
    image_paths=val_imgs,
    mask_paths=val_msks,
    batch_size=8
)

Segment images at random for verification

In [None]:
def show_predictions(model, image_paths, mask_paths, num_samples=3):
    """Show image, ground-truth mask and predicted mask for random samples.

    Each sample occupies one row of three panels. The predicted mask is
    binarised at 0.5 before display.

    Args:
        model: Object with a ``predict`` method returning per-pixel scores.
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths aligned with ``image_paths``.
        num_samples: Number of random samples to draw. Capped at the number
            of available images; nothing is drawn when that leaves zero.
    """
    # random.sample raises when asked for more items than exist.
    num_samples = min(num_samples, len(image_paths))
    if num_samples <= 0:
        return

    indices = random.sample(range(len(image_paths)), num_samples)

    plt.figure(figsize=(15, num_samples * 3))
    for i, idx in enumerate(indices):
        image, mask = load_image_mask_pair(image_paths[idx], mask_paths[idx])
        pred_mask = model.predict(np.expand_dims(image, axis=0))[0]

        pred_mask_bin = (pred_mask > 0.5).astype(np.float32)

        panels = (
            ("Vaizdas", image, None),
            ("Tikroji kaukė", mask.squeeze(), 'gray'),
            ("Spėjama kaukė", pred_mask_bin.squeeze(), 'gray'),
        )
        for col, (title, pixels, cmap) in enumerate(panels, start=1):
            plt.subplot(num_samples, 3, i * 3 + col)
            plt.imshow(pixels, cmap=cmap)
            plt.title(title)
            plt.axis("off")

    plt.tight_layout()
    plt.show()

# Visual spot check on random validation images (default: 3 samples).
show_predictions(unet, val_imgs, val_msks)

Process and save all segmented masks (Optional)

In [None]:
def load_and_preprocess_image(image_path, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    """Load an image resized to ``target_size`` and scale it to [0, 1].

    Args:
        image_path: Path of the image file.
        target_size: ``(height, width)`` the image is resized to on load.

    Returns:
        The image as a float32 array divided by 255.
    """
    pil_image = load_img(image_path, target_size=target_size)
    pixels = img_to_array(pil_image).astype(np.float32)
    return pixels / 255.0

def segment_and_save(model, image_paths, output_dir, target_size=(IMG_HEIGHT, IMG_WIDTH), threshold=0.5):
    """Predict a binary mask for every image and write it to ``output_dir``.

    Each mask is saved as ``<image stem>.jpg`` with pixel values 0 or 255.

    NOTE(review): JPEG is a lossy format, so masks read back from disk may
    not be strictly binary — consider whether PNG would be more appropriate.

    Args:
        model: Object with a ``predict`` method returning per-pixel scores.
        image_paths: Iterable of image file paths to segment.
        output_dir: Directory for the masks; created if it does not exist.
        target_size: ``(height, width)`` images are resized to before prediction.
        threshold: Score strictly above which a pixel is marked as foreground.
    """
    os.makedirs(output_dir, exist_ok=True)

    for source_path in image_paths:
        pixels = load_and_preprocess_image(source_path, target_size=target_size)
        scores = model.predict(np.expand_dims(pixels, axis=0))[0]

        mask_u8 = (scores > threshold).astype(np.uint8) * 255

        stem = os.path.splitext(os.path.basename(source_path))[0]
        destination = os.path.join(output_dir, f"{stem}.jpg")

        # scale=False keeps the 0/255 values as they are.
        array_to_img(mask_u8, scale=False).save(destination)

# Collect every PNG/JPEG file in the image directory (extension check is case-insensitive).
image_paths = [
    os.path.join(IMAGES_DIR, fname)
    for fname in os.listdir(IMAGES_DIR)
    if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
]

# Segment all collected images and write the masks to SEG_OUTPUT_DIR.
segment_and_save(unet, image_paths, SEG_OUTPUT_DIR)

Extract VGG16 features from image patches located by the predicted segmentation masks

In [None]:
def extract_patches_and_features(unet_model, images_dir, labels_df, fe):
    """Segment every image, crop each detected region and embed it.

    For each image file in ``images_dir`` the U-Net mask is predicted at
    ``IMG_HEIGHT x IMG_WIDTH``, external contours are extracted, and the
    bounding box of every contour is cropped from the full-resolution image,
    resized to the network input size and passed through ``fe``.

    Args:
        unet_model: Segmentation model with a ``predict`` method.
        images_dir: Directory containing the input images.
        labels_df: DataFrame with ``filename`` and ``label_idx`` columns.
        fe: Feature extractor with a ``predict`` method.

    Returns:
        ``(features, labels, patch_meta)``: a 2-D feature array with one row
        per patch, an integer label array (-1 where the image has no usable
        label), and a list of ``(filename, x, y, w, h)`` tuples. The box
        coordinates are in mask resolution, which is what the visualisation
        helpers use to look patches up again.

    Raises:
        ValueError: If no patch was detected in any image.
    """
    feature_extractor = fe
    features, labels, patch_meta = [], [], []
    image_exts = ('.png', '.jpg', '.jpeg')

    # Sorted listing keeps the patch order reproducible between runs.
    for fname in sorted(os.listdir(images_dir)):
        # Skip stray non-image files instead of crashing in load_img.
        if not fname.lower().endswith(image_exts):
            continue

        img_path = os.path.join(images_dir, fname)
        small = img_to_array(load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH)))
        mask = unet_model.predict(np.expand_dims(small, axis=0) / 255.0)[0, ..., 0]
        mask_bin = (mask > 0.5).astype(np.uint8)
        contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        orig = img_to_array(load_img(img_path))
        orig_h, orig_w = orig.shape[:2]
        # Contours live in mask resolution; the crop is taken from the
        # full-resolution image, so boxes must be rescaled before slicing.
        scale_x = orig_w / IMG_WIDTH
        scale_y = orig_h / IMG_HEIGHT

        # The label depends only on the file, so look it up once per image.
        row = labels_df[labels_df['filename'] == fname]
        label = -1
        if not row.empty:
            value = row['label_idx'].values[0]
            # Unmapped class names arrive as NaN; treat them as unlabelled.
            if not pd.isna(value):
                label = int(value)

        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            x0 = int(np.floor(x * scale_x))
            y0 = int(np.floor(y * scale_y))
            # Guarantee a crop of at least one pixel that stays inside the image.
            x1 = min(max(int(np.ceil((x + w) * scale_x)), x0 + 1), orig_w)
            y1 = min(max(int(np.ceil((y + h) * scale_y)), y0 + 1), orig_h)
            crop = orig[y0:y1, x0:x1]

            crop_resized = tf.image.resize(crop, [IMG_HEIGHT, IMG_WIDTH]).numpy()
            proc = preprocess_input(np.expand_dims(crop_resized, axis=0))
            feat = feature_extractor.predict(proc).flatten()

            features.append(feat)
            labels.append(label)
            patch_meta.append((fname, x, y, w, h))

    if not features:
        raise ValueError("No patches detected. Check your segmentation masks or input images.")
    return np.vstack(features), np.array(labels), patch_meta

# ImageNet-pretrained VGG16 without the classification head; pooling='avg' is requested
# so that each patch is reduced to a single feature vector.
feature_extractor = VGG16(weights='imagenet', include_top=False, pooling='avg', input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
features, labels_arr, patch_meta = extract_patches_and_features(unet, IMAGES_DIR, labels_df, feature_extractor)

def assign_pseudo_labels(features, labels_arr, n_clusters=3):
    """Fill in missing labels using the majority label of each KMeans cluster.

    Samples are clustered on ``features``. Every cluster is mapped to the most
    frequent known label (``>= 0``) among its members, or to 0 when it has no
    labelled member. Known labels are returned unchanged; negative labels are
    replaced by their cluster's mapped label.

    Args:
        features: 2-D array with one row per sample.
        labels_arr: 1-D integer array; negative values mark unlabelled samples.
        n_clusters: Number of KMeans clusters.

    Returns:
        Array with one label per sample.

    Raises:
        ValueError: If there are fewer samples than clusters.
    """
    if features.shape[0] < n_clusters:
        raise ValueError(f"Insufficient samples ({features.shape[0]}) for {n_clusters}-cluster KMeans.")

    cluster_ids = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(features)
    is_labelled = labels_arr >= 0

    majority_by_cluster = {}
    for cluster_id in range(n_clusters):
        members = np.flatnonzero((cluster_ids == cluster_id) & is_labelled)
        if len(members) > 0:
            majority_by_cluster[cluster_id] = np.bincount(labels_arr[members]).argmax()
        else:
            # No labelled member in this cluster: fall back to class 0.
            majority_by_cluster[cluster_id] = 0

    resolved = []
    for position, cluster_id in enumerate(cluster_ids):
        known = labels_arr[position]
        resolved.append(majority_by_cluster[cluster_id] if known < 0 else known)
    return np.array(resolved)


# Labels for every patch: manual where available, cluster-derived otherwise.
pseudo_labels = assign_pseudo_labels(features, labels_arr)

Save feature data

In [None]:
def save_features(features, labels, patch_meta, filename):
    """Write features, labels and patch metadata to a compressed ``.npz`` file.

    Args:
        features: Feature array stored under the key ``features``.
        labels: Label array stored under the key ``labels``.
        patch_meta: List of ``(filename, x, y, w, h)`` tuples, stored under the
            key ``patch_meta`` as a structured array with those field names.
        filename: Destination passed to ``np.savez_compressed``.
    """
    # File name as a unicode string of up to 255 characters, box fields as integers.
    record_layout = [('filename', 'U255')] + [(field, int) for field in ('x', 'y', 'w', 'h')]

    arrays = {
        'features': features,
        'labels': labels,
        'patch_meta': np.array(patch_meta, dtype=record_layout),
    }
    np.savez_compressed(filename, **arrays)

# Persist the extracted features (with the original labels, not the pseudo-labels)
# under a relative path so that extraction can be skipped next time.
save_features(features, labels_arr, patch_meta, 'extracted_features.npz')

Load saved feature data to skip re-running feature extraction

In [None]:
def load_features(filename):
    """Load features, labels and patch metadata written by ``save_features``.

    Args:
        filename: Path of the ``.npz`` archive.

    Returns:
        ``(features, labels, patch_meta)`` where ``patch_meta`` is a list of
        ``(filename, x, y, w, h)`` tuples made of plain ``str`` / ``int``
        values, the same shape of data that feature extraction produces.
    """
    # The context manager closes the archive; the arrays are read into
    # memory while it is still open.
    with np.load(filename) as data:
        features = data['features']
        labels = data['labels']
        meta = data['patch_meta']

    # .tolist() converts NumPy scalars to native Python values; fields are
    # selected by name so the stored column order does not matter.
    patch_meta = list(zip(
        meta['filename'].tolist(),
        meta['x'].tolist(),
        meta['y'].tolist(),
        meta['w'].tolist(),
        meta['h'].tolist(),
    ))
    return features, labels, patch_meta

def assign_pseudo_labels(features, labels_arr, n_clusters=3):
    """Fill in missing labels using the majority label of each KMeans cluster.

    Samples are clustered on ``features``. Every cluster is mapped to the most
    frequent known label (``>= 0``) among its members, or to 0 when it has no
    labelled member. Known labels are returned unchanged; negative labels are
    replaced by their cluster's mapped label.

    Args:
        features: 2-D array with one row per sample.
        labels_arr: 1-D integer array; negative values mark unlabelled samples.
        n_clusters: Number of KMeans clusters.

    Returns:
        Array with one label per sample.

    Raises:
        ValueError: If there are fewer samples than clusters.
    """
    if features.shape[0] < n_clusters:
        raise ValueError(f"Insufficient samples ({features.shape[0]}) for {n_clusters}-cluster KMeans.")

    cluster_ids = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(features)
    is_labelled = labels_arr >= 0

    majority_by_cluster = {}
    for cluster_id in range(n_clusters):
        members = np.flatnonzero((cluster_ids == cluster_id) & is_labelled)
        if len(members) > 0:
            majority_by_cluster[cluster_id] = np.bincount(labels_arr[members]).argmax()
        else:
            # No labelled member in this cluster: fall back to class 0.
            majority_by_cluster[cluster_id] = 0

    resolved = []
    for position, cluster_id in enumerate(cluster_ids):
        known = labels_arr[position]
        resolved.append(majority_by_cluster[cluster_id] if known < 0 else known)
    return np.array(resolved)


# Reset the working variables, then restore them from the saved archive.
# NOTE(review): the first line is immediately overwritten for `features` and `patch_meta`.
features, labels, patch_meta = [], [], []
features, labels_arr, patch_meta = load_features('extracted_features.npz')
# Recompute the pseudo-labels, since they are not stored in the archive.
pseudo_labels = assign_pseudo_labels(features, labels_arr)


Find the best hyperparameters for the classifier

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Base classifier whose solver and C are tuned below ("pagrindinis" = base).
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases — confirm against the installed version.
pagrindinis_clf = LogisticRegression(penalty='l2',multi_class='multinomial',class_weight='balanced',max_iter=10000,random_state=42)


# Search space: four solvers crossed with four regularisation strengths.
param_grid = {
    'solver': ['lbfgs', 'sag', 'saga', 'newton-cg'],
    'C':      [0.01, 0.1, 1, 10],
}

# 3-fold cross-validated grid search scored by accuracy ("tinklelis" = grid).
tinklelis = GridSearchCV(
    estimator=pagrindinis_clf,
    param_grid=param_grid,
    cv=3,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# NOTE(review): the targets include cluster-derived pseudo-labels, so the CV score
# is measured partly against labels that were not assigned by hand.
tinklelis.fit(features, pseudo_labels)

# Report the best configuration ("geriausias" = best).
geriausias = tinklelis.best_estimator_
print(f"Geriausias CV tikslumas: {tinklelis.best_score_}")
print(f"Geriausi parametrai: {tinklelis.best_params_}")

# n_iter_ may be a plain int or an indexable sequence; in the latter case the first entry is reported.
iteracijos = None
if hasattr(geriausias, 'n_iter_'):
    it = geriausias.n_iter_
    iteracijos = it if isinstance(it, int) else it[0]
print(f"Iteracijų skaičius: {iteracijos}")


Train the classifier

In [None]:
# Balanced class weights computed from the manually labelled patches only.
# NOTE(review): `class_weight` built here is not passed to the classifier below,
# which uses class_weight='balanced' on the pseudo-labels instead.
labels = labels_arr[labels_arr >= 0]
classes = np.unique(labels)

weights = compute_class_weight(class_weight='balanced',classes=classes,y=labels)

class_weight = dict(zip(classes, weights))

# Earlier variant that used the explicit class-weight dictionary:
#clf = LogisticRegression(class_weight=class_weight, max_iter=100)

# Final classifier with hand-picked hyperparameters.
clf = LogisticRegression(
    penalty='l2',
    C= 1,
    solver='saga',
    multi_class='multinomial',
    class_weight='balanced',
    max_iter=1000,
    random_state=42
)


# Fit on all patches, using manual labels where present and pseudo-labels elsewhere.
clf.fit(features, pseudo_labels)


Evaluate classifier

In [None]:
def evaluate_vgg_classifier_with_csv(clf, features, patch_meta, val_csv_path, class_map=None):
    """Evaluate the patch classifier against labels from a CSV file.

    Every patch whose source file name appears in the CSV is scored against
    that file's label. Overall accuracy, macro precision/recall, per-class
    one-vs-rest accuracy, the confusion matrix and a classification report
    are printed.

    NOTE(review): ``features`` is the same array the classifier is fitted on
    elsewhere in this file — confirm the CSV images were held out of training.

    Args:
        clf: Fitted classifier with a ``predict`` method.
        features: Array with one feature row per entry of ``patch_meta``.
        patch_meta: List of ``(filename, x, y, w, h)`` tuples.
        val_csv_path: CSV file with ``filename`` and ``class_name`` columns.
        class_map: Optional mapping from class name to integer index;
            defaults to ``{'oval': 0, 'string': 1, 'other': 2}``.

    Raises:
        ValueError: If no patch belongs to a file listed in the CSV.
    """
    df = pd.read_csv(val_csv_path)
    if class_map is None:
        class_map = {'oval': 0, 'string': 1, 'other': 2}
    df['label_idx'] = df['class_name'].map(class_map)
    # Rows with a class name outside the mapping would become NaN targets.
    df = df.dropna(subset=['label_idx'])
    true_label_lookup = dict(zip(df['filename'], df['label_idx'].astype(int)))

    class_names = list(class_map.keys())
    class_ids = list(class_map.values())

    y_true = []
    X_feats = []
    for idx, (fname, *_) in enumerate(patch_meta):
        if fname in true_label_lookup:
            y_true.append(true_label_lookup[fname])
            X_feats.append(features[idx])

    if len(y_true) == 0:
        raise ValueError(f"No patches found in CSV {val_csv_path}")

    X = np.vstack(X_feats)
    y_true = np.array(y_true)

    y_pred = clf.predict(X)

    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec  = recall_score(y_true, y_pred, average='macro', zero_division=0)
    # Fix the label set explicitly: otherwise a class missing from the data
    # shrinks the matrix, which breaks the per-class indexing below and the
    # pairing with target_names.
    cm   = confusion_matrix(y_true, y_pred, labels=class_ids)
    report = classification_report(
        y_true, y_pred,
        labels=class_ids,
        target_names=class_names,
        zero_division=0
    )

    # One-vs-rest accuracy per class, derived from the confusion matrix.
    total = cm.sum()
    class_accuracies = {}
    for i, label in enumerate(class_names):
        TP = cm[i, i]
        FP = cm[:, i].sum() - TP
        FN = cm[i, :].sum() - TP
        TN = total - TP - FP - FN
        class_accuracies[label] = (TP + TN) / total

    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall   : {rec:.4f}")
    for label, a in class_accuracies.items():
        print(f"Accuracy of {label:6s}: {a:.4f}")
    print(f"\nConfusion Matrix:\n{cm}")
    print(f"\nClassification Report:\n{report}")

# Score the trained classifier against the labels in the validation CSV (default class mapping).
evaluate_vgg_classifier_with_csv(
    clf=clf,
    features=features,
    patch_meta=patch_meta,
    val_csv_path= VAL_CSV_LABELS
)



Visualize N random samples of segmentation and classification results

In [None]:
def visualize_with_saved_features(unet_model, classifier, image_paths, mask_paths,features, patch_meta, num_samples=5):
    """Show segmentation and classification results for random images.

    Each sample is drawn as a row of four panels: the image, the ground-truth
    mask, the predicted mask, and the image annotated with a bounding box and
    predicted class for each detected region. A region is annotated only if
    its ``(filename, x, y, w, h)`` key exists in ``patch_meta``; the stored
    feature vector of that patch is what gets classified.

    Args:
        unet_model: Segmentation model with a ``predict`` method.
        classifier: Fitted classifier with a ``predict`` method.
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths aligned with ``image_paths``.
        features: Array with one feature row per entry of ``patch_meta``.
        patch_meta: List of ``(filename, x, y, w, h)`` tuples.
        num_samples: Number of random images to show. Capped at the number of
            available images; nothing is drawn when that leaves zero.
    """
    class_names = ['ovalas', 'siulas', 'kita']

    meta_to_idx = {meta: idx for idx, meta in enumerate(patch_meta)}

    # random.sample raises when asked for more items than exist.
    num_samples = min(num_samples, len(image_paths))
    if num_samples <= 0:
        return

    indices = random.sample(range(len(image_paths)), num_samples)
    plt.figure(figsize=(18, num_samples * 4))

    for i, img_idx in enumerate(indices):
        img_path = image_paths[img_idx]
        msk_path = mask_paths[img_idx]

        image, true_mask = load_image_mask_pair(img_path, msk_path)
        # 8-bit copy of the image that OpenCV can draw on.
        disp = (image * 255).astype(np.uint8).copy()

        pred_mask = unet_model.predict(np.expand_dims(image, 0))[0, ..., 0]
        bin_mask = (pred_mask > 0.5).astype(np.uint8)

        contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            key = (os.path.basename(img_path), x, y, w, h)
            # Regions without a stored feature vector are left unannotated.
            if key not in meta_to_idx:
                continue

            feat_idx = meta_to_idx[key]
            feat_vec = features[feat_idx]
            pred_class = classifier.predict([feat_vec])[0]

            cv2.rectangle(disp, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.putText(disp, class_names[pred_class], (x, y - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

        plt.subplot(num_samples, 4, i * 4 + 1)
        plt.imshow(image)
        plt.title("Vaizdas")
        plt.axis("off")
        plt.subplot(num_samples, 4, i * 4 + 2)
        plt.imshow(true_mask.squeeze(), cmap='gray')
        plt.title("Tikroji kaukė")
        plt.axis("off")
        plt.subplot(num_samples, 4, i * 4 + 3)
        plt.imshow(bin_mask, cmap='gray')
        plt.title("Spėjama kaukė")
        plt.axis("off")
        plt.subplot(num_samples, 4, i * 4 + 4)
        plt.imshow(disp)
        plt.title("Klasifikacija")
        plt.axis("off")

    plt.tight_layout()
    plt.show()



# NOTE(review): this re-created extractor is not passed to the call below, which
# relies on the already computed `features` — confirm whether it is still needed.
feature_extractor = VGG16(weights='imagenet', include_top=False, pooling='avg', input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

# Visualise 81 random validation images with their stored patch classifications.
visualize_with_saved_features(unet, clf, val_imgs, val_msks, features, patch_meta, num_samples=81)


Use single image for segmentation and classification

In [None]:
def visualize_single_image_with_saved_features(unet_model, classifier, images_dir, masks_dir, features, patch_meta, image_name):
    """Show segmentation and classification results for one named image.

    Four panels are drawn side by side: the image, the ground-truth mask, the
    predicted mask, and the image annotated with a bounding box and predicted
    class for every detected region whose ``(image_name, x, y, w, h)`` key
    exists in ``patch_meta``.

    Args:
        unet_model: Segmentation model with a ``predict`` method.
        classifier: Fitted classifier with a ``predict`` method.
        images_dir: Directory containing the image.
        masks_dir: Directory containing the mask of the same file name.
        features: Array with one feature row per entry of ``patch_meta``.
        patch_meta: List of ``(filename, x, y, w, h)`` tuples.
        image_name: File name of the image (and of its mask).

    Raises:
        FileNotFoundError: If the image or the mask file does not exist.
    """
    img_path = os.path.join(images_dir, image_name)
    msk_path = os.path.join(masks_dir,  image_name)
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")
    if not os.path.exists(msk_path):
        raise FileNotFoundError(f"Mask not found: {msk_path}")

    class_names = ['oval', 'string', 'other']
    # Maps a patch key to the row of its stored feature vector.
    row_of_patch = {meta: position for position, meta in enumerate(patch_meta)}

    image, true_mask = load_image_mask_pair(img_path, msk_path)
    # 8-bit copy of the image that OpenCV can draw on.
    disp = (image * 255).astype(np.uint8).copy()

    scores = unet_model.predict(np.expand_dims(image, 0))[0, ..., 0]
    bin_mask = (scores > 0.5).astype(np.uint8)

    contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        feat_row = row_of_patch.get((image_name, x, y, w, h))
        if feat_row is None:
            # No stored feature vector for this region: leave it unannotated.
            continue

        predicted = classifier.predict([features[feat_row]])[0]

        cv2.rectangle(disp, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(disp, class_names[predicted], (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    # (title, pixels, colormap) for each panel, left to right.
    panels = (
        ("Vaizdas", image, None),
        ("Tikroji kaukė", true_mask.squeeze(), 'gray'),
        ("Segmentuota kaukė", bin_mask, 'gray'),
        ("Klasifikacija", disp, None),
    )

    plt.figure(figsize=(12, 4))
    for slot, (title, pixels, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, slot)
        plt.imshow(pixels, cmap=cmap)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Inspect one specific image; the same file name is looked up in both IMAGES_DIR and MASKS_DIR.
visualize_single_image_with_saved_features(
    unet_model=unet,
    classifier=clf,
    images_dir=IMAGES_DIR,
    masks_dir=MASKS_DIR,
    features=features,
    patch_meta=patch_meta,
    image_name='20211223110429.jpg'
)
