In [1]:
from PIL import Image
#from cityscapesscripts.helpers import labels
from cityscapesscripts.helpers import labels as cs_labels
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import time
import gc
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from backbones_unet.model.unet import Unet

  from .autonotebook import tqdm as notebook_tqdm


# Setup

In [2]:
# Set seeds for reproducibility
def set_seeds(seed=42):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed
SEED = 42
set_seeds(SEED)
print(f"Seeds set for reproducibility is {SEED}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() 
                      else "mps" if torch.mps.is_available() 
                      else "cpu"
                    )
print(DEVICE)

from contextlib import contextmanager
import ssl
# Disable SSL verification for urllib requests on MacOS
# This is a workaround for the "SSL: CERTIFICATE_VERIFY_FAILED" error on MacOS
@contextmanager
def no_ssl_verification():
    """Temporarily disable SSL verification"""
    old_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        ssl._create_default_https_context = old_context

Seeds set for reproducibility is 42
mps


In [3]:
class SegmentationDataset(Dataset):
    def __init__(self, images_dir, targets_dir, image_transform=None, target_transform=None):
        """
        Args:
            images_dir (string): Directory with all the images.
            targets_dir (string): Directory with all the target masks.
            image_transform (callable, optional): Optional transform to be applied on images.
            target_transform (callable, optional): Optional transform to be applied on targets.
        """
        self.images_dir = images_dir
        self.targets_dir = targets_dir
        self.image_transform = image_transform
        self.target_transform = target_transform
        
        # Get all image filenames
        self.image_filenames = [f for f in os.listdir(images_dir) 
                               if f.lower().endswith(('.png'))]
        self.image_filenames.sort()
        
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        # Load image
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        # Load target mask
        target_name = img_name.replace('.png', '_trainId.png')
        target_path = os.path.join(self.targets_dir, target_name)
        target = Image.open(target_path)
        
        # Apply transforms
        if self.image_transform:
            image = self.image_transform(image)
        
        if self.target_transform:
            target = self.target_transform(target)
       # else:
       #     # Default: convert to tensor
       #     target = torch.from_numpy(target)
        
        return image, target

In [4]:
NUM_CLASSES = 19

In [5]:
def compute_iou(preds, labels, num_classes, ignore_index=255):
    
    preds = torch.argmax(preds, dim=1).detach().cpu()  # [B, H, W]
    
    labels = labels.detach().cpu() 
    
    ious = []
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        
        # Escludi pixel ignorati
        mask = (labels != ignore_index)
        pred_inds = pred_inds & mask
        target_inds = target_inds & mask

        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        
        if union == 0:
            continue  # salta classe non presente
        ious.append(intersection / union)
    
    if len(ious) == 0:
        return float('nan')  # o 0.0 se preferisci
    return sum(ious) / len(ious)

In [6]:
def clear_memory():
    """Clear memory and cache for all device types"""
    # Delete all local variables in the caller's frame
    for obj in list(locals().values()):
        del obj
        
    # Run garbage collector
    gc.collect()
    
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    if DEVICE.type == 'mps':
        torch.mps.empty_cache()

    # Second GC run to make sure everything is cleaned up
    gc.collect()

In [7]:
# Test Function
def run_test(model, test_dataloader):
    model.eval()
    test_total_loss = 0
    test_total_iou = 0
    
    with torch.no_grad():
        for i, (images, masks) in enumerate(tqdm(test_dataloader, desc="Testing")):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            
            outputs = model(images)['out']
            
            # Collect metrics
            test_total_iou += compute_iou(outputs.detach().cpu(), masks, NUM_CLASSES)
            
            # Explicit cleanup
            del images, masks, outputs
            
            # Periodic memory clear
            if (i + 1) % 5 == 0:
                clear_memory()
    
    clear_memory()
    
    test_avg_iou = test_total_iou / len(test_dataloader)
    return test_avg_iou

# DeeplabV3


In [8]:
path_images = "syn_resized_images"
path_target = "syn_resized_gt"

image_transform = transforms.Compose([
    transforms.Resize((256, 466)), # We maintain the og aspect ratio
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet parameters
])

target_transform = transforms.Compose([
    transforms.Resize((256, 466), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])
syn_dataset = SegmentationDataset(images_dir=path_images, targets_dir=path_target, image_transform=image_transform, target_transform=target_transform)

In [9]:
# Get total dataset size
total_size = len(syn_dataset)

# Calculate split sizes (60% train, 10% val, 30% test)
train_size = int(0.6 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# Create random splits
syn_train_dataset, syn_val_dataset, syn_test_dataset = random_split(
    syn_dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)  # For reproducibility
)

# Create DataLoaders
batch_size = 8  


syn_train_dataloader = DataLoader(
    syn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED) 
)


syn_val_dataloader = DataLoader(
    syn_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

syn_test_dataloader = DataLoader(
    syn_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

## DeeplabV3 Coco

In [10]:
if DEVICE.type == 'mps': 
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = deeplabv3_resnet50(
            weights='COCO_WITH_VOC_LABELS_V1', 
        )
else:
    model = deeplabv3_resnet50(
        weights='COCO_WITH_VOC_LABELS_V1', 
    )

model.classifier[4] = nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
model = model.to(DEVICE)

# Load the best model checkpoint
best_model_path = "models/deeplabv3_resnet50_best_model.pth"
print(f"Loading best model from: {best_model_path}")

checkpoint = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

mps detected, using no_ssl_verification
Loading best model from: models/deeplabv3_resnet50_best_model.pth
Best model loaded successfully!
Best validation IoU: 0.5830
From epoch: 25


In [11]:
test_avg_iou = run_test(model=model,test_dataloader=syn_test_dataloader)
print(f"Test Average IoU: {test_avg_iou:.4f}")

Testing: 100%|██████████| 469/469 [07:14<00:00,  1.08it/s]

Test Average IoU: 0.5900





## DeeplabV3 Imagenet1k

In [14]:
if DEVICE.type == 'mps': 
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = deeplabv3_resnet50(
            weights='COCO_WITH_VOC_LABELS_V1', 
            weights_backbone='IMAGENET1K_V1'
        )
else:
    model = deeplabv3_resnet50(
        weights='COCO_WITH_VOC_LABELS_V1', 
        weights_backbone='IMAGENET1K_V1'
    )

model.classifier[4] = nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
model = model.to(DEVICE)

# Load the best model checkpoint
best_model_path = "models/deeplabv3_imagenet1k_best_model.pth"
print(f"Loading best model from: {best_model_path}")

checkpoint = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

mps detected, using no_ssl_verification
Loading best model from: models/deeplabv3_imagenet1k_best_model.pth
Best model loaded successfully!
Best validation IoU: 0.5857
From epoch: 24


In [15]:
test_avg_iou = run_test(model=model,test_dataloader=syn_test_dataloader)
print(f"Test Average IoU: {test_avg_iou:.4f}")

Testing: 100%|██████████| 469/469 [07:11<00:00,  1.09it/s]

Test Average IoU: 0.5892





## DeeplabV3 Imagent1k 480x256

In [17]:
image_transform = transforms.Compose([
    transforms.Resize((256, 480)), # We augment from 466x256 to 480x256 in order to avoid the padding artifacts
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet parameters
])

target_transform = transforms.Compose([
    transforms.Resize((256, 480), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])
syn_dataset = SegmentationDataset(images_dir=path_images, targets_dir=path_target, image_transform=image_transform, target_transform=target_transform)

# Get total dataset size
total_size = len(syn_dataset)
print(f"Total dataset size: {total_size}")

# Calculate split sizes (60% train, 10% val, 30% test)
train_size = int(0.6 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print(f"Train size: {train_size} ({train_size/total_size*100:.1f}%)")
print(f"Validation size: {val_size} ({val_size/total_size*100:.1f}%)")
print(f"Test size: {test_size} ({test_size/total_size*100:.1f}%)")

# Create random splits
syn_train_dataset, syn_val_dataset, syn_test_dataset = random_split(
    syn_dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)  # For reproducibility
)

# Create DataLoaders
batch_size = 8  


syn_train_dataloader = DataLoader(
    syn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED) 
)


syn_val_dataloader = DataLoader(
    syn_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

syn_test_dataloader = DataLoader(
    syn_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

print(f"\nDataLoaders created:")
print(f"Train batches: {len(syn_train_dataloader)}")
print(f"Validation batches: {len(syn_val_dataloader)}")
print(f"Test batches: {len(syn_test_dataloader)}")

if DEVICE.type == 'mps': 
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = deeplabv3_resnet50(
            weights='COCO_WITH_VOC_LABELS_V1', 
            weights_backbone='IMAGENET1K_V1'
        )
else:
    model = deeplabv3_resnet50(
        weights='COCO_WITH_VOC_LABELS_V1', 
        weights_backbone='IMAGENET1K_V1'
    )

model.classifier[4] = nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
model = model.to(DEVICE)

Total dataset size: 12500
Train size: 7500 (60.0%)
Validation size: 1250 (10.0%)
Test size: 3750 (30.0%)

DataLoaders created:
Train batches: 938
Validation batches: 157
Test batches: 469
mps detected, using no_ssl_verification


In [18]:
# Load the best model checkpoint
best_model_path = "models/deeplabv3_imagenet1k_480x256_best_model.pth"
print(f"Loading best model from: {best_model_path}")

checkpoint = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

Loading best model from: models/deeplabv3_imagenet1k_480x256_best_model.pth
Best model loaded successfully!
Best validation IoU: 0.5857
From epoch: 25


In [19]:
test_avg_iou = run_test(model=model,test_dataloader=syn_test_dataloader)
print(f"Test Average IoU: {test_avg_iou:.4f}")

Testing: 100%|██████████| 469/469 [07:21<00:00,  1.06it/s]

Test Average IoU: 0.5448





# Unet

## Unet Imagenet1k

In [21]:
path_images = "syn_resized_images"
path_target = "syn_resized_gt"

image_transform = transforms.Compose([
    transforms.Resize((256, 466)), # We maintain the og aspect ratio
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet parameters
])

target_transform = transforms.Compose([
    transforms.Resize((256, 466), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])
syn_dataset = SegmentationDataset(images_dir=path_images, targets_dir=path_target, image_transform=image_transform, target_transform=target_transform)

# Get total dataset size
total_size = len(syn_dataset)
print(f"Total dataset size: {total_size}")

# Calculate split sizes (60% train, 10% val, 30% test)
train_size = int(0.6 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print(f"Train size: {train_size} ({train_size/total_size*100:.1f}%)")
print(f"Validation size: {val_size} ({val_size/total_size*100:.1f}%)")
print(f"Test size: {test_size} ({test_size/total_size*100:.1f}%)")

# Create random splits
syn_train_dataset, syn_val_dataset, syn_test_dataset = random_split(
    syn_dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)  # For reproducibility
)

# Create DataLoaders
batch_size = 8  


syn_train_dataloader = DataLoader(
    syn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED) 
)


syn_val_dataloader = DataLoader(
    syn_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

syn_test_dataloader = DataLoader(
    syn_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

print(f"\nDataLoaders created:")
print(f"Train batches: {len(syn_train_dataloader)}")
print(f"Validation batches: {len(syn_val_dataloader)}")
print(f"Test batches: {len(syn_test_dataloader)}")

Total dataset size: 12500
Train size: 7500 (60.0%)
Validation size: 1250 (10.0%)
Test size: 3750 (30.0%)

DataLoaders created:
Train batches: 938
Validation batches: 157
Test batches: 469


In [22]:
class PaddedUnet(nn.Module):
    def __init__(self, backbone='resnet50', pretrained=True, in_channels=3, num_classes=19):
        super(PaddedUnet, self).__init__()
        # Initialize the original UNet model
        self.unet = Unet(
            backbone=backbone,
            pretrained=pretrained,
            in_channels=in_channels,
            num_classes=num_classes,
        )
        
    def forward(self, x):
        # Save original dimensions for reference
        orig_dim = x.shape[2:]
        
        # Encoder gives us a list of embeddings, one for each level
        embeddings = self.unet.encoder(x)
        embeddings.reverse()  # Reverse order to mach decoding 
        
        # Manual decoding with padding fixes
        x = embeddings[0]
        
        for i, block in enumerate(self.unet.decoder.blocks):
            # Get skip connection if available
            skip = embeddings[i+1] if i+1 < len(embeddings) else None
            
            # Handle upscaling if needed
            if block.scale_factor > 1:
                x = F.interpolate(x, scale_factor=block.scale_factor, mode='nearest')
            
            # Apply padding to match dimensions instead of resizing
            # Needed since input shape is not divisible by 32
            if skip is not None:
                # Check if shapes need adjustment
                if x.shape[2:] != skip.shape[2:]:
                    # Calculate padding needed
                    h_diff = x.shape[2] - skip.shape[2]
                    w_diff = x.shape[3] - skip.shape[3]
                    
                    # Apply padding to match dimensions
                    if h_diff > 0 or w_diff > 0:
                        # Calculate padding values
                        pad_h = (h_diff // 2, h_diff - h_diff // 2)
                        pad_w = (w_diff // 2, w_diff - w_diff // 2)
                        skip = F.pad(skip, [pad_w[0], pad_w[1], pad_h[0], pad_h[1]])
                    elif h_diff < 0 or w_diff < 0:
                        # Need to pad the x tensor instead
                        h_diff = -h_diff
                        w_diff = -w_diff
                        pad_h = (h_diff // 2, h_diff - h_diff // 2)
                        pad_w = (w_diff // 2, w_diff - w_diff // 2)
                        x = F.pad(x, [pad_w[0], pad_w[1], pad_h[0], pad_h[1]])
                
                # Concatenate skip connection with current features
                x = torch.cat([x, skip], dim=1)
                
            # Apply convolutions
            x = block.conv1(x)
            x = block.conv2(x)
        
        # Apply final convolution
        x = self.unet.decoder.final_conv(x)
        
        # Ensure output matches input dimensions if needed
        if x.shape[2:] != orig_dim:
            x = F.interpolate(x, size=orig_dim, mode='bilinear', align_corners=False)
        
        return {'out': x}

if DEVICE.type == 'mps':
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = PaddedUnet(
            backbone='resnet50',
            pretrained=True,
            in_channels=3,
            num_classes=NUM_CLASSES,
        )
else:
    model = PaddedUnet(
        backbone='resnet50',
        pretrained=True,
        in_channels=3,
        num_classes=NUM_CLASSES,
    )

model = model.to(DEVICE)

# Load the best model checkpoint
best_model_path = "models/unet_imagenet1k_best_model.pth"
print(f"Loading best model from: {best_model_path}")

checkpoint = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

mps detected, using no_ssl_verification
Loading best model from: models/unet_imagenet1k_best_model.pth
Best model loaded successfully!
Best validation IoU: 0.5690
From epoch: 25


In [23]:
test_avg_iou = run_test(model=model,test_dataloader=syn_test_dataloader)
print(f"Test Average IoU: {test_avg_iou:.4f}")

Testing: 100%|██████████| 469/469 [02:56<00:00,  2.66it/s]

Test Average IoU: 0.5760





## Unet Imagenet1k 480x256

In [25]:
image_transform = transforms.Compose([
    transforms.Resize((256, 480)), # We augment from 466x256 to 480x256 in order to avoid the padding artifacts
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet parameters
])

target_transform = transforms.Compose([
    transforms.Resize((256, 480), interpolation=Image.NEAREST), # This interpolation ensure that all pixels have a correct value of their class
    transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
])
syn_dataset = SegmentationDataset(images_dir=path_images, targets_dir=path_target, image_transform=image_transform, target_transform=target_transform)

# Get total dataset size
total_size = len(syn_dataset)
print(f"Total dataset size: {total_size}")

# Calculate split sizes (60% train, 10% val, 30% test)
train_size = int(0.6 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

print(f"Train size: {train_size} ({train_size/total_size*100:.1f}%)")
print(f"Validation size: {val_size} ({val_size/total_size*100:.1f}%)")
print(f"Test size: {test_size} ({test_size/total_size*100:.1f}%)")

# Create random splits
syn_train_dataset, syn_val_dataset, syn_test_dataset = random_split(
    syn_dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)  # For reproducibility
)

# Create DataLoaders
batch_size = 8  


syn_train_dataloader = DataLoader(
    syn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED) 
)


syn_val_dataloader = DataLoader(
    syn_val_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

syn_test_dataloader = DataLoader(
    syn_test_dataset,
    batch_size=batch_size,
    shuffle=True,  
    #num_workers=2,
    generator=torch.Generator().manual_seed(SEED)
)

print(f"\nDataLoaders created:")
print(f"Train batches: {len(syn_train_dataloader)}")
print(f"Validation batches: {len(syn_val_dataloader)}")
print(f"Test batches: {len(syn_test_dataloader)}")

Total dataset size: 12500
Train size: 7500 (60.0%)
Validation size: 1250 (10.0%)
Test size: 3750 (30.0%)

DataLoaders created:
Train batches: 938
Validation batches: 157
Test batches: 469


In [None]:
# Test Function for classic unet
def run_test(model, test_dataloader):
    model.eval()
    test_total_loss = 0
    test_total_iou = 0
    
    with torch.no_grad():
        for i, (images, masks) in enumerate(tqdm(test_dataloader, desc="Testing")):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            
            outputs = model(images)
            
            # Collect metrics
            test_total_iou += compute_iou(outputs.detach().cpu(), masks, NUM_CLASSES)
            
            # Explicit cleanup
            del images, masks, outputs
            
            # Periodic memory clear
            if (i + 1) % 5 == 0:
                clear_memory()
    
    clear_memory()
    
    test_avg_iou = test_total_iou / len(test_dataloader)
    return test_avg_iou

In [26]:
if DEVICE.type == 'mps':
    print("mps detected, using no_ssl_verification")
    with no_ssl_verification():
        model = Unet(
            backbone='resnet50',
            pretrained=True,
            in_channels=3,
            num_classes=NUM_CLASSES,
        )
else:
    model = Unet(
        backbone='resnet50',
        pretrained=True,
        in_channels=3,
        num_classes=NUM_CLASSES,
    )

model = model.to(DEVICE)

# Load the best model checkpoint
best_model_path = "models/unet_imagenet1k_480x256_best_model.pth"
print(f"Loading best model from: {best_model_path}")

checkpoint = torch.load(best_model_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded successfully!")
print(f"Best validation IoU: {checkpoint['val_iou']:.4f}")
print(f"From epoch: {checkpoint['epoch']}")

mps detected, using no_ssl_verification
Loading best model from: models/unet_imagenet1k_480x256_best_model.pth
Best model loaded successfully!
Best validation IoU: 0.5983
From epoch: 24


In [29]:
test_avg_iou = run_test(model=model,test_dataloader=syn_test_dataloader)
print(f"Test Average IoU: {test_avg_iou:.4f}")

Testing: 100%|██████████| 469/469 [02:56<00:00,  2.66it/s]

Test Average IoU: 0.6040



