# U-Net for image segmentation

In [57]:
import torch
from torch import nn
from torch.nn import functional as F
import cv2
import numpy as np
from torch.utils.data import Dataset
from torch import optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import argparse

## Model

In [58]:
class UNet(nn.Module):
    def __init__(self, n_class, in_channels=3):
        super().__init__()
        
        # Encoder
        self.encoder1 = self._block(in_channels, 64)
        self.encoder2 = self._block(64, 128)
        self.encoder3 = self._block(128, 256)
        self.encoder4 = self._block(256, 512)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Bottleneck
        self.bottleneck = self._block(512, 1024)
        
        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = self._block(1024, 512) 
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self._block(512, 256)
        
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self._block(256, 128)
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self._block(128, 64)
        
        # Output
        self.conv_out = nn.Conv2d(64, n_class, kernel_size=1)
        
    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))
        
        # Decoder
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.decoder4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.decoder1(dec1)
        
        return self.conv_out(dec1)

## Synthetic dataset for simple testing

In [59]:
class ShapesDataset(Dataset):    
    def __init__(self, num_samples=500, img_size=128):
        self.num_samples = num_samples
        self.img_size = img_size
        self.shapes = ['circle', 'square', 'triangle']
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # empy image
        img = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        mask = np.zeros((self.img_size, self.img_size), dtype=np.uint8)
        
        # random backgrounfd color
        bg_color = np.random.randint(50, 200, 3)
        img[:, :] = bg_color
        
        shape = np.random.choice(self.shapes)
        size = np.random.randint(20, 50)
        center_x = np.random.randint(size, self.img_size - size)
        center_y = np.random.randint(size, self.img_size - size)
        
        shape_color = np.random.randint(0, 255, 3)
        while np.all(np.abs(shape_color - bg_color) < 50):
            shape_color = np.random.randint(0, 255, 3)
        
        if shape == 'circle':
            cv2.circle(img, (center_x, center_y), size, 
                      shape_color.tolist(), -1)
            cv2.circle(mask, (center_x, center_y), size, 1, -1)
            
        elif shape == 'square':
            pt1 = (center_x - size, center_y - size)
            pt2 = (center_x + size, center_y + size)
            cv2.rectangle(img, pt1, pt2, shape_color.tolist(), -1)
            cv2.rectangle(mask, pt1, pt2, 1, -1)
            
        elif shape == 'triangle':
            pts = np.array([
                [center_x, center_y - size],
                [center_x - size, center_y + size],
                [center_x + size, center_y + size]
            ])
            cv2.fillPoly(img, [pts], shape_color.tolist())
            cv2.fillPoly(mask, [pts], 1)
        
        # add some noise
        noise = np.random.randint(-20, 20, img.shape, dtype=np.int32)
        img = np.clip(img + noise, 0, 255).astype(np.uint8)
        
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).float()
        
        return img_tensor, mask_tensor

In [60]:
dataset = ShapesDataset(num_samples=200, img_size=256)

## Training

In [70]:
class Config:
    SAVE_DIR = "./unet_training_results"
    MODEL_SAVE_PATH = os.path.join(SAVE_DIR, "best_unet.pth")
    
    BATCH_SIZE = 4
    NUM_EPOCHS = 10
    LEARNING_RATE = 1e-4
    IMG_SIZE = 256
    NUM_SAMPLES_TRAIN = 1000
    NUM_SAMPLES_VAL = 200
    
    # Dataset params
    SHAPES = ['circle', 'square', 'triangle', 'pentagon', 'star']
    MIN_SHAPE_SIZE = 30
    MAX_SHAPE_SIZE = 70
    
    # Vizsualization
    VISUALIZE_EVERY = 5  # epochs
    SAVE_MODEL_EVERY = 10  # epochs

os.makedirs(Config.SAVE_DIR, exist_ok=True)

In [71]:
class DiceBCELoss(nn.Module):    
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        predictions = torch.sigmoid(predictions)
        predictions = predictions.view(-1)
        targets = targets.view(-1)
        
        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth)
        dice_loss = 1 - dice
        
        bce_loss = nn.functional.binary_cross_entropy(
            predictions, targets, reduction='mean')
    
        return dice_loss + bce_loss

In [72]:
class IoULoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        predictions = torch.sigmoid(predictions)
        
        predictions = predictions.view(-1)
        targets = targets.view(-1)
        
        intersection = (predictions * targets).sum()
        union = predictions.sum() + targets.sum() - intersection
        
        iou = (intersection + self.smooth) / (union + self.smooth)
        return 1 - iou

In [73]:
def calculate_metrics(predictions, targets, threshold=0.5):
    """Вычисление метрик качества"""
    with torch.no_grad():
        pred_binary = (torch.sigmoid(predictions) > threshold).float()
        targets = targets.float()
        
        tp = (pred_binary * targets).sum()
        fp = (pred_binary * (1 - targets)).sum()
        fn = ((1 - pred_binary) * targets).sum()
        tn = ((1 - pred_binary) * (1 - targets)).sum()
        
        accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-6)
        precision = tp / (tp + fp + 1e-6)
        recall = tp / (tp + fn + 1e-6)
        f1 = 2 * precision * recall / (precision + recall + 1e-6)
        
        dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
        iou = tp / (tp + fp + fn + 1e-6)
        
    return {
        'accuracy': accuracy.item(),
        'precision': precision.item(),
        'recall': recall.item(),
        'f1': f1.item(),
        'dice': dice.item(),
        'iou': iou.item()
    }

In [74]:
def train_unet_on_shapes():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    
    print("\nCreate dataset...")
    train_dataset = ShapesDataset(
        num_samples=Config.NUM_SAMPLES_TRAIN,
        img_size=Config.IMG_SIZE,
    )
    
    val_dataset = ShapesDataset(
        num_samples=Config.NUM_SAMPLES_VAL,
        img_size=Config.IMG_SIZE,
    )
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=Config.BATCH_SIZE, 
        shuffle=True,
        num_workers=0,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=Config.BATCH_SIZE, 
        shuffle=False,
        num_workers=0,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Val dataset size: {len(val_dataset)}")
    print(f"Batch size: {Config.BATCH_SIZE}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    
    print("\nU-Net initializing...")
    model = UNet(n_class=1, in_channels=3)  # binary segmentation
    model.to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    criterion = DiceBCELoss()  # or IoULoss() or nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5, 
        patience=5
    )
    
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    metrics_history = {
        'train': {'accuracy': [], 'f1': [], 'dice': [], 'iou': []},
        'val': {'accuracy': [], 'f1': [], 'dice': [], 'iou': []}
    }
    
    print("\n" + "="*50)
    print("Start training!")
    print("="*50)
    
    for epoch in range(Config.NUM_EPOCHS):
        model.train()
        epoch_train_loss = 0
        train_metrics = {'accuracy': 0, 'f1': 0, 'dice': 0, 'iou': 0}
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{Config.NUM_EPOCHS} [Train]')
        
        for batch_idx, (images, masks) in enumerate(train_pbar):
            images = images.to(device)
            masks = masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            
            epoch_train_loss += loss.item()
            
            batch_metrics = calculate_metrics(outputs, masks)
            for key in train_metrics.keys():
                train_metrics[key] += batch_metrics[key]
            
            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
            })
        
        avg_train_loss = epoch_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        for key in train_metrics.keys():
            train_metrics[key] /= len(train_loader)
            metrics_history['train'][key].append(train_metrics[key])
        
        model.eval()
        epoch_val_loss = 0
        val_metrics = {'accuracy': 0, 'f1': 0, 'dice': 0, 'iou': 0}
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{Config.NUM_EPOCHS} [Val]')
            
            for images, masks in val_pbar:
                images = images.to(device)
                masks = masks.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                epoch_val_loss += loss.item()
                
                batch_metrics = calculate_metrics(outputs, masks)
                for key in val_metrics.keys():
                    val_metrics[key] += batch_metrics[key]
                
                val_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_val_loss = epoch_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        for key in val_metrics.keys():
            val_metrics[key] /= len(val_loader)
            metrics_history['val'][key].append(val_metrics[key])
        
        scheduler.step(avg_val_loss)
        

        print(f"\Epoch {epoch+1}/{Config.NUM_EPOCHS}:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"  Train Metrics: Acc={train_metrics['accuracy']:.4f}, "
              f"F1={train_metrics['f1']:.4f}, Dice={train_metrics['dice']:.4f}")
        print(f"  Val Metrics:   Acc={val_metrics['accuracy']:.4f}, "
              f"F1={val_metrics['f1']:.4f}, Dice={val_metrics['dice']:.4f}")
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'metrics': val_metrics,
            }, Config.MODEL_SAVE_PATH)
            print(f"  Best model saved (Val Loss: {avg_val_loss:.4f})")

        if (epoch + 1) % Config.SAVE_MODEL_EVERY == 0:
            checkpoint_path = os.path.join(
                Config.SAVE_DIR, f"checkpoint_epoch_{epoch+1}.pth"
            )
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  Checkpoint saved: {checkpoint_path}")
        
        if (epoch + 1) % Config.VISUALIZE_EVERY == 0 or epoch == 0:
            visualize_results(
                model, val_dataset, device, epoch+1,
                save_dir=Config.SAVE_DIR,
                num_samples=4
            )
    
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'metrics_history': metrics_history,
        'config': vars(Config)
    }
    
    history_path = os.path.join(Config.SAVE_DIR, "training_history.npy")
    np.save(history_path, history, allow_pickle=True)
    
    plot_training_history(train_losses, val_losses, metrics_history, Config.SAVE_DIR)
    
    print("\n" + "="*50)
    print("Train finished!")
    print(f"Best model saved: {Config.MODEL_SAVE_PATH}")
    print(f"Train history: {history_path}")
    print("="*50)
    
    return model, history



In [75]:
def visualize_results(model, dataset, device, epoch, save_dir, num_samples=4):
    model.eval()
    
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, num_samples*4))
    
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    for i, idx in enumerate(indices):
        image, true_mask = dataset[idx]
        
        with torch.no_grad():
            input_tensor = image.unsqueeze(0).to(device)
            prediction = model(input_tensor)
            prediction_prob = torch.sigmoid(prediction).cpu().squeeze()
            prediction_binary = (prediction_prob > 0.5).float()
        
        image_np = image.permute(1, 2, 0).numpy()
        true_mask_np = true_mask.squeeze().numpy()
        pred_mask_np = prediction_binary.numpy()
        
        axes[i, 0].imshow(image_np)
        axes[i, 0].set_title("Image")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(true_mask_np, cmap='gray')
        axes[i, 1].set_title("True mask")
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(pred_mask_np, cmap='gray')
        axes[i, 2].set_title("Predicted mask")
        axes[i, 2].axis('off')
        
        im = axes[i, 3].imshow(prediction_prob.numpy(), cmap='hot', vmin=0, vmax=1)
        axes[i, 3].set_title("Probability map")
        axes[i, 3].axis('off')
        plt.colorbar(im, ax=axes[i, 3], fraction=0.046, pad=0.04)
    
    plt.suptitle(f"Эпоха {epoch}", fontsize=16, y=1.02)
    plt.tight_layout()
    
    save_path = os.path.join(save_dir, f"predictions_epoch_{epoch}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"  Image saved: {save_path}")



In [76]:
def plot_training_history(train_losses, val_losses, metrics_history, save_dir):
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    epochs = range(1, len(train_losses) + 1)
    
    # Loss
    axes[0, 0].plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
    axes[0, 0].plot(epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[0, 1].plot(epochs, metrics_history['train']['accuracy'], 'b-', 
                    label='Train Accuracy', linewidth=2)
    axes[0, 1].plot(epochs, metrics_history['val']['accuracy'], 'r-', 
                    label='Val Accuracy', linewidth=2)
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # F1 Score
    axes[0, 2].plot(epochs, metrics_history['train']['f1'], 'b-', 
                    label='Train F1', linewidth=2)
    axes[0, 2].plot(epochs, metrics_history['val']['f1'], 'r-', 
                    label='Val F1', linewidth=2)
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].set_ylabel('F1 Score')
    axes[0, 2].set_title('F1 Score')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    axes[1, 0].plot(epochs, metrics_history['train']['dice'], 'b-', 
                    label='Train Dice', linewidth=2)
    axes[1, 0].plot(epochs, metrics_history['val']['dice'], 'r-', 
                    label='Val Dice', linewidth=2)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Dice Coefficient')
    axes[1, 0].set_title('Dice Coefficient')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # IoU
    axes[1, 1].plot(epochs, metrics_history['train']['iou'], 'b-', 
                    label='Train IoU', linewidth=2)
    axes[1, 1].plot(epochs, metrics_history['val']['iou'], 'r-', 
                    label='Val IoU', linewidth=2)
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('IoU')
    axes[1, 1].set_title('Intersection over Union (IoU)')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # Learning Rate 
    axes[1, 2].axis('off')  
    
    plt.suptitle('U-Net Training history', fontsize=16, y=1.02)
    plt.tight_layout()
    
    history_plot_path = os.path.join(save_dir, "training_history.png")
    plt.savefig(history_plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"Plots saved: {history_plot_path}")


In [77]:
def test_trained_model(model_path, num_test_samples=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = UNet(n_class=1, in_channels=3)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    print(f"\nLoaded model from: {model_path}")
    print(f"Epoch: {checkpoint['epoch']}")
    print(f"Val Loss: {checkpoint['val_loss']:.4f}")
    
    test_dataset = ShapesDataset(
        num_samples=num_test_samples,
        img_size=Config.IMG_SIZE,
        shapes=Config.SHAPES,
        min_size=Config.MIN_SHAPE_SIZE,
        max_size=Config.MAX_SHAPE_SIZE
    )
    
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
    
    total_metrics = {'accuracy': 0, 'f1': 0, 'dice': 0, 'iou': 0}
    
    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)
            
            outputs = model(images)
            batch_metrics = calculate_metrics(outputs, masks)
            
            for key in total_metrics.keys():
                total_metrics[key] += batch_metrics[key] * images.size(0)
    
    for key in total_metrics.keys():
        total_metrics[key] /= len(test_dataset)
    
    print("\nTest results:")
    print(f"Accuracy: {total_metrics['accuracy']:.4f}")
    print(f"F1 Score: {total_metrics['f1']:.4f}")
    print(f"Dice Coefficient: {total_metrics['dice']:.4f}")
    print(f"IoU: {total_metrics['iou']:.4f}")
    
    print("\nVisualizing...")
    visualize_results(
        model, test_dataset, device, 
        epoch='test', 
        save_dir=Config.SAVE_DIR,
        num_samples=min(6, num_test_samples)
    )
    
    return total_metrics

In [78]:
trained_model, history = train_unet_on_shapes()

Device: cpu

Create dataset...
Train dataset size: 1000
Val dataset size: 200
Batch size: 4
Train batches: 250
Val batches: 50

U-Net initializing...
Total parameters: 31,043,521
Trainable parameters: 31,043,521

Start training!


Epoch 1/10 [Train]: 100%|██████████| 250/250 [13:53<00:00,  3.34s/it, loss=0.7698, lr=1.00e-04]
Epoch 1/10 [Val]: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, loss=2.0700]


\Epoch 1/10:
  Train Loss: 1.0573 | Val Loss: 2.0273
  Train Metrics: Acc=0.9585, F1=0.8320, Dice=0.8320
  Val Metrics:   Acc=0.6965, F1=0.3503, Dice=0.3503
  Best model saved (Val Loss: 2.0273)
  Image saved: ./unet_training_results\predictions_epoch_1.png


Epoch 2/10 [Train]: 100%|██████████| 250/250 [13:08<00:00,  3.16s/it, loss=0.8251, lr=1.00e-04]
Epoch 2/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, loss=3.9556]


\Epoch 2/10:
  Train Loss: 0.7931 | Val Loss: 2.0797
  Train Metrics: Acc=0.9904, F1=0.9168, Dice=0.9168
  Val Metrics:   Acc=0.7429, F1=0.3918, Dice=0.3918


Epoch 3/10 [Train]: 100%|██████████| 250/250 [13:23<00:00,  3.22s/it, loss=0.4984, lr=1.00e-04]
Epoch 3/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, loss=2.7692]


\Epoch 3/10:
  Train Loss: 0.5864 | Val Loss: 2.2671
  Train Metrics: Acc=0.9932, F1=0.9424, Dice=0.9424
  Val Metrics:   Acc=0.7399, F1=0.3922, Dice=0.3922


Epoch 4/10 [Train]: 100%|██████████| 250/250 [13:24<00:00,  3.22s/it, loss=0.2945, lr=1.00e-04]
Epoch 4/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.06s/it, loss=0.6472]


\Epoch 4/10:
  Train Loss: 0.4370 | Val Loss: 1.7231
  Train Metrics: Acc=0.9936, F1=0.9393, Dice=0.9393
  Val Metrics:   Acc=0.8045, F1=0.4904, Dice=0.4904
  Best model saved (Val Loss: 1.7231)


Epoch 5/10 [Train]: 100%|██████████| 250/250 [13:25<00:00,  3.22s/it, loss=0.1431, lr=1.00e-04]
Epoch 5/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.06s/it, loss=4.2105]


\Epoch 5/10:
  Train Loss: 0.2944 | Val Loss: 2.7383
  Train Metrics: Acc=0.9950, F1=0.9554, Dice=0.9554
  Val Metrics:   Acc=0.6798, F1=0.3749, Dice=0.3749
  Image saved: ./unet_training_results\predictions_epoch_5.png


Epoch 6/10 [Train]: 100%|██████████| 250/250 [13:08<00:00,  3.15s/it, loss=0.1514, lr=1.00e-04]
Epoch 6/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, loss=3.3415]


\Epoch 6/10:
  Train Loss: 0.2127 | Val Loss: 3.0233
  Train Metrics: Acc=0.9949, F1=0.9585, Dice=0.9585
  Val Metrics:   Acc=0.6819, F1=0.3655, Dice=0.3655


Epoch 7/10 [Train]: 100%|██████████| 250/250 [13:26<00:00,  3.23s/it, loss=0.1332, lr=1.00e-04]
Epoch 7/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, loss=3.5729]


\Epoch 7/10:
  Train Loss: 0.1459 | Val Loss: 3.1840
  Train Metrics: Acc=0.9963, F1=0.9691, Dice=0.9691
  Val Metrics:   Acc=0.7229, F1=0.3595, Dice=0.3595


Epoch 8/10 [Train]: 100%|██████████| 250/250 [13:24<00:00,  3.22s/it, loss=0.4370, lr=1.00e-04]
Epoch 8/10 [Val]: 100%|██████████| 50/50 [00:53<00:00,  1.07s/it, loss=0.0968]


\Epoch 8/10:
  Train Loss: 0.1290 | Val Loss: 3.6408
  Train Metrics: Acc=0.9957, F1=0.9632, Dice=0.9632
  Val Metrics:   Acc=0.5879, F1=0.2671, Dice=0.2671


Epoch 9/10 [Train]: 100%|██████████| 250/250 [14:41<00:00,  3.53s/it, loss=0.0365, lr=1.00e-04]
Epoch 9/10 [Val]: 100%|██████████| 50/50 [00:45<00:00,  1.09it/s, loss=3.1565]


\Epoch 9/10:
  Train Loss: 0.1020 | Val Loss: 3.3449
  Train Metrics: Acc=0.9966, F1=0.9691, Dice=0.9691
  Val Metrics:   Acc=0.6782, F1=0.3657, Dice=0.3657


Epoch 10/10 [Train]: 100%|██████████| 250/250 [21:19<00:00,  5.12s/it, loss=0.0336, lr=1.00e-04]
Epoch 10/10 [Val]: 100%|██████████| 50/50 [01:52<00:00,  2.25s/it, loss=2.7353]


\Epoch 10/10:
  Train Loss: 0.0905 | Val Loss: 4.2428
  Train Metrics: Acc=0.9960, F1=0.9671, Dice=0.9671
  Val Metrics:   Acc=0.5906, F1=0.2747, Dice=0.2747
  Checkpoint saved: ./unet_training_results\checkpoint_epoch_10.pth
  Image saved: ./unet_training_results\predictions_epoch_10.png


TypeError: cannot pickle 'mappingproxy' object