In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import albumentations as A
from tqdm.auto import tqdm
import glob
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import time

from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import cv2

import segmentation_models_pytorch as smp
from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [3]:
# a list to collect paths of images
images_path = []
labels_path = []
masks_path = []

# Get the paths of the images and sort them
images_path = sorted(glob.glob('data/original_images/*.jpg'))
labels_path = sorted(glob.glob('data/label_images_semantic/*.png'))
rgb_masks_path = sorted(glob.glob('data/RGB_color_image_masks/*.png'))

paths = np.column_stack((images_path, labels_path))
print(paths.shape)
print(paths[0])

# Apply 80-10-10 split
train_split, valtest_split = train_test_split(paths, test_size=0.2, random_state=69420)
val_split, test_split = train_test_split(valtest_split, test_size=0.5, random_state=69420)

(400, 2)
['data/original_images/000.jpg' 'data/label_images_semantic/000.png']


In [4]:
import os
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from PIL import Image
import torchvision.transforms as transforms

def create_and_save_tiles(img_path, mask_path, tiles_dim=512, final_dim=256, output_dir="data"):
    # Extract the original index from the image path
    original_index = os.path.basename(img_path)[:3]

    # Load image and mask
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # Convert to tensor
    transform = transforms.ToTensor()
    img = transform(img)
    mask = transform(mask)

    # Check if resizing is necessary
    if img.shape[1] % tiles_dim != 0 or img.shape[2] % tiles_dim != 0:
        # Round down to the nearest multiple of tiles_dim
        new_height = img.shape[1] // tiles_dim * tiles_dim
        new_width = img.shape[2] // tiles_dim * tiles_dim
        new_shp = (new_height, new_width)

        # Resize the image and mask
        img = F.interpolate(img.unsqueeze(0), size=new_shp, mode='bilinear', align_corners=False).squeeze(0)
        mask = F.interpolate(mask.unsqueeze(0), size=new_shp, mode='nearest').squeeze(0)
  
    img_tiles = img.unfold(1, tiles_dim, tiles_dim).unfold(2, tiles_dim, tiles_dim)
    img_tiles = img_tiles.contiguous().view(3, -1, tiles_dim, tiles_dim).permute(1, 0, 2, 3)
    
    mask_tiles = mask.unfold(1, tiles_dim, tiles_dim).unfold(2, tiles_dim, tiles_dim)
    mask_tiles = mask_tiles.contiguous().view(-1, tiles_dim, tiles_dim)

    # Resize tiles to 256x256
    resize_dim = final_dim
    img_tiles = F.interpolate(img_tiles, size=(resize_dim, resize_dim), mode='bilinear', align_corners=False)
    mask_tiles = F.interpolate(mask_tiles.unsqueeze(1), size=(resize_dim, resize_dim), mode='nearest').squeeze(1)

    # Create output directories if they don't exist
    img_output_dir = os.path.join(output_dir, f"{tiles_dim}x{tiles_dim}", "images")
    mask_output_dir = os.path.join(output_dir, f"{tiles_dim}x{tiles_dim}", "masks")
    os.makedirs(img_output_dir, exist_ok=True)
    os.makedirs(mask_output_dir, exist_ok=True)

    # Save tiles
    for i, (img_tile, mask_tile) in enumerate(zip(img_tiles, mask_tiles)):
        img_tile_path = os.path.join(img_output_dir, f"{original_index}_{i}.png")
        mask_tile_path = os.path.join(mask_output_dir, f"{original_index}_{i}.png")
        save_image(img_tile, img_tile_path)
        save_image(mask_tile, mask_tile_path)

    return img_tiles, mask_tiles




already_created = True

if not already_created:
    resize_dim = 256
    for path in tqdm(paths):
        img_path, mask_path = path
        create_and_save_tiles(img_path, mask_path, tiles_dim=2000, final_dim=256, output_dir=f'data/tiles_{resize_dim}')

In [5]:
# for img in tqdm(paths[:, 0]):
#     shape = cv2.imread(img).shape
#     print(f'shape for img {img} is {shape}')
#     if shape != (4000, 6000, 3):
#         print(f'wrong shape for img {img}')

In [6]:
# Read number of classes
labels_colors = pd.read_csv('data/class_dict_seg.csv')
columns = ['class', 'r', 'g', 'b']
labels_colors.columns = columns
# Extract RGB values
labels_colors['RGB'] = labels_colors[['r', 'g', 'b']].apply(tuple, axis=1)
colors = labels_colors['RGB'].values
labels_colors

Unnamed: 0,class,r,g,b,RGB
0,unlabeled,0,0,0,"(0, 0, 0)"
1,paved-area,128,64,128,"(128, 64, 128)"
2,dirt,130,76,0,"(130, 76, 0)"
3,grass,0,102,0,"(0, 102, 0)"
4,gravel,112,103,87,"(112, 103, 87)"
5,water,28,42,168,"(28, 42, 168)"
6,rocks,48,41,30,"(48, 41, 30)"
7,pool,0,50,89,"(0, 50, 89)"
8,vegetation,107,142,35,"(107, 142, 35)"
9,roof,70,70,70,"(70, 70, 70)"


In [7]:
# # Check if there are any conflicting labels in the masks
# for img in tqdm(paths[:, 1]):
#     mask = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
#     unlabelled = np.sum(mask == 23)
#     if unlabelled > 0:
#         print(f'Unlabelled pixels in img {img}: {unlabelled} ({unlabelled / mask.size * 100:.2f}%)')

In [8]:
# No conflicting labels found
# Therefore there are 23 classes in the dataset
nr_classes = len(labels_colors) - 1
print(f'Number of classes: {nr_classes}')

Number of classes: 23


### Dataset class

In [9]:
class TilesDataset(Dataset):
    def __init__(self, image_paths, transform=None, tiles=True, tiles_dim=512):
        self.image_paths = image_paths
        self.transform = transform
        self.tiles = tiles
        self.tiles_dim = tiles_dim

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path, mask_path = self.image_paths[idx]
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        if self.tiles:
            img, mask = self.create_tiles(img, mask, self.tiles_dim)

        return img, mask
    
    def create_tiles(self, img, mask, tiles_dim):
        # Check if resizing is necessary
        if img.shape[1] % tiles_dim != 0 or img.shape[2] % tiles_dim != 0:
            # Round down to the nearest multiple of tiles_dim
            new_height = img.shape[1] // tiles_dim * tiles_dim
            new_width = img.shape[2] // tiles_dim * tiles_dim
            new_shp = (new_height, new_width)

            # Resize the image and mask
            img = F.interpolate(img.unsqueeze(0), size=new_shp, mode='bilinear', align_corners=False).squeeze(0)
            mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=new_shp, mode='nearest').squeeze(0).squeeze(0)

            # Print the shape of the image and mask
            print(f"Image shape: {img.shape}")
            print(f"Mask shape: {mask.shape}")
        
        # Create img tiles and mask tiles
        img_tiles = img.unfold(1, tiles_dim, tiles_dim).unfold(2, tiles_dim, tiles_dim)
        img_tiles = img_tiles.contiguous().view(3, -1, tiles_dim, tiles_dim).permute(1, 0, 2, 3)
        
        mask_tiles = mask.unfold(0, tiles_dim, tiles_dim).unfold(1, tiles_dim, tiles_dim)
        mask_tiles = mask_tiles.contiguous().view(-1, tiles_dim, tiles_dim)

        # Resize tiles if necessary
        if tiles_dim > 256:
            resize_dim = 256
            img_tiles = F.interpolate(img_tiles, size=(resize_dim, resize_dim), mode='bilinear', align_corners=False)
            mask_tiles = F.interpolate(mask_tiles.unsqueeze(1), size=(resize_dim, resize_dim), mode='nearest').squeeze(1)

        return img_tiles, mask_tiles

### Get image and mask paths

In [10]:
final_dim = 256
tiles_dim = 512
tiles_path = f'data/tiles_{final_dim}/{tiles_dim}x{tiles_dim}'

# Get image paths
for folder in os.listdir(tiles_path):
    if folder == 'images':
        img_paths = sorted(glob.glob(f'{tiles_path}/{folder}/*.png'))
    elif folder == 'masks':
        mask_paths = sorted(glob.glob(f'{tiles_path}/{folder}/*.png'))

# Combine image and mask paths
image_paths = np.array(list(zip(img_paths, mask_paths)))
print(image_paths.shape)
print(image_paths[0])

# Apply 80-10-10 split
train_split, valtest_split = train_test_split(image_paths, test_size=0.2, random_state=69420)
val_split, test_split = train_test_split(valtest_split, test_size=0.5, random_state=69420)

(30800, 2)
['data/tiles_256/512x512/images/000_0.png'
 'data/tiles_256/512x512/masks/000_0.png']


### Transformations

In [11]:
# Define Albumentations transformations
train_transform = A.Compose([
    # A.Resize(new_height, new_width, p=1.0),  # Resize the image to the desired shape
    A.HorizontalFlip(p=0.5),  # Apply horizontal flip with 50% probability
    A.VerticalFlip(p=0.5),  # Apply vertical flip with 50% probability
    A.RandomBrightnessContrast(p=0.2),  # Randomly change brightness and contrast
    A.OneOf([
        A.GaussianBlur(p=1.0),  # Apply Gaussian blur
        A.MotionBlur(p=1.0),  # Apply motion blur
    ], p=0.2),  # Apply one of the blur operations with 20% probability
    A.HueSaturationValue(p=0.2),  # Randomly change hue, saturation, and value
    A.RandomGamma(p=0.2),  # Randomly change gamma
    A.CLAHE(p=0.2),  # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Normalize the image
    ToTensorV2(),  # Convert image and mask to PyTorch tensors
])

valtest_transform = A.Compose([
    A.HorizontalFlip(p=0.5),  # Apply horizontal flip with 50% probability
    A.VerticalFlip(p=0.5),  # Apply vertical flip with 50% probability
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Normalize the image
    ToTensorV2(),  # Convert image and mask to PyTorch tensors
])

# Initialize your custom dataset
train_ds = TilesDataset(train_split, transform=train_transform, tiles_dim=tiles_dim, tiles=False)
val_ds = TilesDataset(val_split, transform=valtest_transform, tiles_dim=tiles_dim, tiles=False)
test_ds = TilesDataset(test_split, transform=valtest_transform, tiles_dim=tiles_dim, tiles=False)

print(f'Train dataset length: {len(train_ds)}, Val dataset length: {len(val_ds)}, Test dataset length: {len(test_ds)}')

Train dataset length: 24640, Val dataset length: 3080, Test dataset length: 3080


In [12]:
# Create a DataLoader
num_workers = 12
batch_size_train = 90
batch_size_valtest = 75
train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=num_workers, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size=batch_size_valtest, shuffle=False, pin_memory=True, num_workers=num_workers, persistent_workers=True)
test_loader = DataLoader(test_ds, batch_size=batch_size_valtest, shuffle=False, pin_memory=True, num_workers=num_workers, persistent_workers=True)

### Metrics

In [13]:
def compute_metrics_torch(y_true, y_pred, num_classes):
    # Flatten the arrays for metric computation
    y_true_flat = y_true.view(-1)
    y_pred_flat = y_pred.view(-1)

    # Compute overall accuracy
    acc = (y_true_flat == y_pred_flat).float().mean().item()

    # Helper function to compute IoU for a single class
    def compute_iou(cls):
        intersection = ((y_true_flat == cls) & (y_pred_flat == cls)).float().sum().item()
        union = ((y_true_flat == cls) | (y_pred_flat == cls)).float().sum().item()
        return intersection / union if union != 0 else 0

    # Helper function to compute Dice score for a single class
    def compute_dice(cls):
        intersection = 2 * ((y_true_flat == cls) & (y_pred_flat == cls)).float().sum().item()
        total = (y_true_flat == cls).float().sum().item() + (y_pred_flat == cls).float().sum().item()
        return intersection / total if total != 0 else 0

    # Compute IoU
    iou_list = [compute_iou(cls) for cls in range(num_classes)]
    mean_iou = np.mean(iou_list)

    # Compute Dice
    dice_list = [compute_dice(cls) for cls in range(num_classes)]
    mean_dice = np.mean(dice_list)

    # Return the metrics
    return {
        'mean_iou': mean_iou,
        'per_class_iou': iou_list,
        'accuracy': acc,
        'mean_dice': mean_dice,
        'per_class_dice': dice_list
    }

### Create network

In [14]:
config = {
    'arch': 'unet',
    'encoder_name': 'resnet34',
    'encoder_weights': 'imagenet',
    'in_channels': 3,
    'classes': nr_classes
}

model = smp.create_model(**config)
# model

### Tensorboard

In [15]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import os

# Create a TensorBoard callback
logs_dir = f'logs/{config["arch"]}/{config["encoder_name"]}/tiles_{final_dim}/{tiles_dim}x{tiles_dim}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
os.makedirs(logs_dir, exist_ok=True)  # Ensure the logs directory exists
# Get full path to the logs directory
logs_dir = os.path.abspath(logs_dir)
print(f"TensorBoard logs directory: {logs_dir}")

writer = SummaryWriter(log_dir=logs_dir)

2024-07-30 17:25:14.038262: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-30 17:25:14.054619: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-30 17:25:14.059659: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-30 17:25:14.071642: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


TensorBoard logs directory: /home/andrea/Documents/unimib/BigImaging/Exam/logs/unet/resnet34/tiles_256/512x512/2024-07-30_17-25-15


### Train

In [16]:
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast

def reshape_imgs_masks(imgs, masks):
    imgs, masks = imgs.to(device), masks.to(device)

    # Reshape images: [batch_size, num_tiles, channels, height, width] -> [batch_size * num_tiles, channels, height, width]
    imgs = imgs.view(-1, imgs.shape[2], imgs.shape[3], imgs.shape[4])
    # Reshape masks: [batch_size, num_tiles, height, width] -> [batch_size * num_tiles, height, width]
    masks = masks.view(-1, masks.shape[2], masks.shape[3])

    # Convert masks to Long() type
    masks = masks.to(torch.long)

    return imgs,masks

def train(train_loss, imgs, masks, scaler, optimizer, criterion, iteration, accumulation_steps=1, use_amp=True, tiles=False):
    
    # Only reshape images and masks if tiles are being computed by the Dataset class
    # Else the source is the already tiled images and masks
    if tiles:
        imgs, masks = reshape_imgs_masks(imgs, masks)
    else:
        imgs, masks = imgs.to(device), masks.to(device)
        masks = masks.to(torch.long)

    optimizer.zero_grad()

    if use_amp:
        with autocast(device_type='cuda'):
            outputs = model(imgs)
            loss = criterion(outputs, masks) / accumulation_steps
        scaler.scale(loss).backward()
    else:
        outputs = model(imgs)
        loss = criterion(outputs, masks) / accumulation_steps
        loss.backward()

    if (iteration + 1) % accumulation_steps == 0:
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    train_loss += loss.detach() * accumulation_steps  # Adjust for scaled loss

    if (iteration + 1) % 64 == 0:
        print(f'Train loss at iteration {iteration + 1}: {train_loss.item() / iteration :.3f}')

    return train_loss

def validate(val_loss, imgs, masks, criterion, use_amp=True, tiles=False):
    with torch.no_grad():
        if tiles:
            imgs, masks = reshape_imgs_masks(imgs, masks)
        else:
            imgs, masks = imgs.to(device), masks.to(device)
            masks = masks.to(torch.long)

        if use_amp:
            with autocast(device_type='cuda'):
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                val_loss += loss.detach()
        else:
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            val_loss += loss.detach()

        preds = torch.argmax(outputs, dim=1)

    return val_loss, preds

In [19]:
print(all_y_true.device)
print(all_y_pred.device)

cpu
cuda:0


In [17]:
# Set up mixed precision training
scaler = GradScaler()

# Accumulation steps (adjust based on GPU memory)
accumulation_steps = 1

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 100
num_classes = nr_classes

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    if epoch == 0:
        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(logs_dir),
            record_shapes=True, profile_memory=True, with_stack=True
        ) as prof:
            for i, (imgs, masks) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')):
                train_loss = train(train_loss, imgs, masks, scaler, optimizer, criterion, i, accumulation_steps, use_amp=True)
                # Profile each 2 batches
                if i % 2 == 0:
                    prof.step()
    else:
        for i, (imgs, masks) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')):
            train_loss = train(train_loss, imgs, masks, scaler, optimizer, criterion, i, accumulation_steps, use_amp=True)

    writer.close()

    model.eval()
    val_loss = 0.0
    all_y_true = []
    all_y_pred = []

    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f'Validation Epoch {epoch+1}/{num_epochs}'):
            val_loss, preds = validate(val_loss, imgs, masks, criterion, use_amp=True)
            all_y_true.append(masks.to(device))
            all_y_pred.append(preds)

    all_y_true = torch.cat(all_y_true, dim=0)
    all_y_pred = torch.cat(all_y_pred, dim=0)

    metrics = compute_metrics_torch(all_y_true, all_y_pred, num_classes)

    print(f'Validation Loss: {val_loss.item()/len(val_loader):.3f}, Mean IoU: {metrics["mean_iou"]:.3f}, '
      f'Accuracy: {metrics["accuracy"]:.3f}, Dice Score: {metrics["mean_dice"]:.3f}, '
      f'per-class IoU: {[f"Class {i}: {iou:.3f}" for i, iou in enumerate(metrics["per_class_iou"])]}')

Training Epoch 1/100:   0%|          | 0/274 [00:00<?, ?it/s]

Train loss at iteration 64: 1.696
Train loss at iteration 128: 1.380
Train loss at iteration 192: 1.254
Train loss at iteration 256: 1.173


Validation Epoch 1/100:   0%|          | 0/42 [00:01<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

### Predict masks

In [None]:
import torchvision
import matplotlib.pyplot as plt

def test(testloader, model, criterion):
    model.eval()
    test_loss = 0.0
    all_y_true = []
    all_y_pred = []
    
    with torch.no_grad():
        for imgs, masks in tqdm(testloader, desc=f'Test'):
            imgs, masks = reshape_imgs_masks(imgs, masks)

            with autocast(device_type='cuda'):
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                test_loss += loss.detach()

            preds = torch.argmax(outputs, dim=1)
            all_y_true.append(masks)
            all_y_pred.append(preds)

    all_y_true_flattened = torch.cat(all_y_true, dim=0)
    all_y_pred_flattened = torch.cat(all_y_pred, dim=0)

    print(f'All y true shape: {all_y_true_flattened.shape}, All y pred shape: {all_y_pred_flattened.shape}')

    metrics = compute_metrics_torch(all_y_true_flattened, all_y_pred_flattened, num_classes)

    print(f'Test Loss: {test_loss.item()/len(testloader):.3f}, Mean IoU: {metrics["mean_iou"]:.3f}, '
        f'Accuracy: {metrics["accuracy"]:.3f}, Dice Score: {metrics["mean_dice"]:.3f}, '
        f'per-class IoU: {[f"Class {i}: {iou:.3f}" for i, iou in enumerate(metrics["per_class_iou"])]}')
    
    return all_y_true_flattened, all_y_pred_flattened

def convert_to_rgb(masks, colors):
    """
    Convert a 4D tensor of masks to a 5D tensor of RGB masks.

    Args:
        masks (Tensor): A 4D tensor of masks with shape [batch_size, num_tiles, height, width].
        colors (Tensor): A tensor of RGB color values for each class.

    Returns:
        Tensor: A 5D tensor of RGB masks with shape [batch_size, num_tiles, 3, height, width].
    """
    batch_size, num_tiles, height, width = masks.shape
    masks_rgb = torch.zeros((batch_size, num_tiles, 3, height, width), dtype=torch.uint8).to(device)

    for i, color in enumerate(colors):
        color_tensor = torch.tensor(color, dtype=torch.uint8).view(1, 1, 3, 1, 1).to(device)
        masks_rgb += (masks == i).unsqueeze(2) * color_tensor

    return masks_rgb
    
def visualize_predictions(true_masks, pred_masks, dims=(256, 256), images_to_visualize=3, batch_size=1):
    '''
    Visualize the predictions of a model.

    Args:
        true_masks (torch.Tensor): Ground truth masks of shape [batch_size * num_tiles, 256, 256].
        pred_masks (torch.Tensor): Predicted masks of shape [batch_size * num_tiles, 256, 256].
        dims (tuple): Dimensions of the images (default: (256, 256)).
        batch_size (int): Batch size of the test loader (default: 1).
    '''
    # Compute nr of images per row
    img_per_col = 4000 // dims[0]
    imgs_per_row = 6000 // dims[0]
    num_tiles = img_per_col * imgs_per_row

    # Clip the number of images to visualize
    num_images = images_to_visualize * num_tiles * batch_size
    true_masks = true_masks[:num_images]
    pred_masks = pred_masks[:num_images]
    
    # Resizing the masks to dims
    true_masks = F.interpolate(true_masks.unsqueeze(1).float(), size=dims, mode='nearest').squeeze(1).to(torch.uint8)
    pred_masks = F.interpolate(pred_masks.unsqueeze(1).float(), size=dims, mode='nearest').squeeze(1).to(torch.uint8)

    # Reshape masks for batching
    # [batch_size * num_tiles, 256, 256] -> [batch_size, num_tiles, dims[0], dims[1]]
    true_masks = true_masks.view(batch_size, -1, *true_masks.shape[1:])
    pred_masks = pred_masks.view(batch_size, -1, *pred_masks.shape[1:])

    # Convert masks to RGB
    true_masks_rgb = convert_to_rgb(true_masks, colors)
    pred_masks_rgb = convert_to_rgb(pred_masks, colors)

    for i in range(batch_size):
        # Create a grid of predictions
        pred_grid = torchvision.utils.make_grid(pred_masks_rgb[i], nrow=imgs_per_row, normalize=False, pad_value=1)
        # Create grid of true masks
        true_grid = torchvision.utils.make_grid(true_masks_rgb[i], nrow=imgs_per_row, normalize=False, pad_value=1)
        
        # Display the grids side by side
        fig, axes = plt.subplots(1, 2, figsize=(15, 7))
        
        # Display the true masks grid
        axes[0].imshow(true_grid.permute(1, 2, 0).cpu().numpy())
        axes[0].axis('off')
        axes[0].set_title('True Masks')
        
        # Display the predicted masks grid
        axes[1].imshow(pred_grid.permute(1, 2, 0).cpu().numpy())
        axes[1].axis('off')
        axes[1].set_title('Predicted Masks')
        
        plt.show()

all_y_true_flattened, all_y_pred_flattened = test(test_loader, model, criterion)
visualize_predictions(all_y_true_flattened, all_y_pred_flattened, dims=(tiles_dim, tiles_dim), batch_size=batch_size_valtest)