In [2]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import random

# Data paths
data_path = r"/kaggle/input/monusegdataset/data"

# ============================================================================
# DATASET PROCESSING
# ============================================================================

# Data Augmentation Class
class NucleiAugmentation:
    def __init__(self, image_size=256, augmentation_prob=0.7, sparse_mode=False):
        self.image_size = image_size
        self.augmentation_prob = augmentation_prob
        self.sparse_mode = sparse_mode
    
    def __call__(self, image, mask):
        if torch.is_tensor(image):
            image = TF.to_pil_image(image)
        if torch.is_tensor(mask):
            mask = TF.to_pil_image(mask)
        
        # Random horizontal flip
        if random.random() > self.augmentation_prob:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        
        # Random vertical flip
        if random.random() > self.augmentation_prob:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        
        # Random rotation (-45 to +45 degrees)
        if random.random() > self.augmentation_prob:
            angle = random.uniform(-45, 45)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
        
        # Random brightness/contrast adjustment
        if random.random() > self.augmentation_prob:
            brightness_factor = random.uniform(0.8, 1.2)
            image = TF.adjust_brightness(image, brightness_factor)
        
        if random.random() > self.augmentation_prob:
            contrast_factor = random.uniform(0.8, 1.2)
            image = TF.adjust_contrast(image, contrast_factor)
        
        # Random Gaussian blur
        if random.random() > self.augmentation_prob:
            image = TF.gaussian_blur(image, kernel_size=3)
        
        # For sparse mode: random crop
        if self.sparse_mode and random.random() > 0.8:
            width, height = image.size
            crop_size = random.randint(128, 256)
            i = random.randint(0, height - crop_size)
            j = random.randint(0, width - crop_size)
            image = TF.crop(image, i, j, crop_size, crop_size)
            mask = TF.crop(mask, i, j, crop_size, crop_size)
        
        # Ensure final size
        image = image.resize((self.image_size, self.image_size))
        mask = mask.resize((self.image_size, self.image_size))
        
        return image, mask

# Dataset Class
class KMMSDataset(Dataset):
    def __init__(self, image_paths, mask_paths, image_size=256, augment=False, sparse_mode=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_size = image_size
        self.augment = augment
        self.sparse_mode = sparse_mode
        self.augmentor = NucleiAugmentation(image_size, sparse_mode=sparse_mode) if augment else None
        
        self.valid_pairs = list(zip(image_paths, mask_paths))
        
        print(f"Created dataset with {len(self.valid_pairs)} valid image-mask pairs")
        if augment:
            mode = "sparse" if sparse_mode else "normal"
            print(f"✓ Data augmentation ENABLED ({mode} mode)")
        else:
            print("✗ Data augmentation DISABLED")
    
    def __len__(self):
        return len(self.valid_pairs)
    
    def __getitem__(self, idx):
        img_path, mask_path = self.valid_pairs[idx]
        
        try:
            # Load image
            image = Image.open(img_path).convert('RGB')
            mask = Image.open(mask_path).convert('L')
            
            # Apply augmentation
            if self.augment and self.augmentor:
                image, mask = self.augmentor(image, mask)
            else:
                image = image.resize((self.image_size, self.image_size))
                mask = mask.resize((self.image_size, self.image_size))
            
            # Convert to numpy arrays
            image = np.array(image).astype(np.float32) / 255.0
            mask = np.array(mask).astype(np.float32) / 255.0
            
            # Ensure mask is binary
            mask = (mask > 0.5).astype(np.float32)
            
            # Convert to tensors
            image = torch.from_numpy(image).permute(2, 0, 1).float()
            mask = torch.from_numpy(mask).unsqueeze(0).float()
            
            return image, mask, os.path.basename(img_path)
        
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            dummy_image = torch.zeros(3, self.image_size, self.image_size)
            dummy_mask = torch.zeros(1, self.image_size, self.image_size)
            return dummy_image, dummy_mask, "error"

# Data loading and balancing
def find_kmms_data():
    train_path = os.path.join(data_path, "kmms_training")
    test_path = os.path.join(data_path, "kmms_test")
    
    train_images, train_masks, test_images, test_masks = [], [], [], []
    
    # Load training data
    if os.path.exists(train_path):
        train_images_dir = os.path.join(train_path, "images")
        train_masks_dir = os.path.join(train_path, "masks")
        if os.path.exists(train_images_dir):
            train_images = sorted([os.path.join(train_images_dir, f) for f in os.listdir(train_images_dir) 
                                 if f.lower().endswith(('.tiff', '.tif', '.png', '.jpg', '.jpeg'))])
        if os.path.exists(train_masks_dir):
            train_masks = sorted([os.path.join(train_masks_dir, f) for f in os.listdir(train_masks_dir) 
                                if f.lower().endswith(('.tiff', '.tif', '.png', '.jpg', '.jpeg'))])
    
    # Load test data
    if os.path.exists(test_path):
        test_images_dir = os.path.join(test_path, "images")
        test_masks_dir = os.path.join(test_path, "masks")
        if os.path.exists(test_images_dir):
            test_images = sorted([os.path.join(test_images_dir, f) for f in os.listdir(test_images_dir) 
                                if f.lower().endswith(('.tiff', '.tif', '.png', '.jpg', '.jpeg'))])
        if os.path.exists(test_masks_dir):
            test_masks = sorted([os.path.join(test_masks_dir, f) for f in os.listdir(test_masks_dir) 
                               if f.lower().endswith(('.tiff', '.tif', '.png', '.jpg', '.jpeg'))])
    
    print(f"Found: {len(train_images)} training images, {len(test_images)} test images")
    return train_images, train_masks, test_images, test_masks

# Improved data splitting function
def create_balanced_split(all_images, all_masks, test_size=0.15, val_size=0.15, random_state=42):
    """
    Create a balanced train/val/test split from all available data
    """
    from sklearn.model_selection import train_test_split
    
    # First split: Train+Val vs Test
    train_val_images, test_images, train_val_masks, test_masks = train_test_split(
        all_images, all_masks, test_size=test_size, random_state=random_state
    )
    
    # Second split: Train vs Val from train_val set
    val_ratio = val_size / (1 - test_size)
    train_images, val_images, train_masks, val_masks = train_test_split(
        train_val_images, train_val_masks, test_size=val_ratio, random_state=random_state
    )
    
    total_samples = len(all_images)
    print(f"✅ Balanced Data Split:")
    print(f"   Training: {len(train_images)} images ({len(train_images)/total_samples*100:.1f}%)")
    print(f"   Validation: {len(val_images)} images ({len(val_images)/total_samples*100:.1f}%)")
    print(f"   Test: {len(test_images)} images ({len(test_images)/total_samples*100:.1f}%)")
    
    return train_images, val_images, test_images, train_masks, val_masks, test_masks

# Main dataset preparation function
def prepare_datasets(image_size=256, batch_size=4, balanced_split=True):
    """
    Main function to prepare datasets with optional balanced splitting
    """
    # Load all data
    train_images, train_masks, test_images, test_masks = find_kmms_data()
    
    if balanced_split:
        # Combine all data for balanced splitting
        all_images = train_images + test_images
        all_masks = train_masks + test_masks
        
        # Create balanced split (70-15-15 recommended)
        train_images, val_images, test_images, train_masks, val_masks, test_masks = create_balanced_split(
            all_images, all_masks, test_size=0.15, val_size=0.15
        )
    else:
        # Use original split (80-20 from training data for validation)
        from sklearn.model_selection import train_test_split
        train_images, val_images, train_masks, val_masks = train_test_split(
            train_images, train_masks, test_size=0.2, random_state=42
        )
        print(f"Using original split: Train={len(train_images)}, Val={len(val_images)}, Test={len(test_images)}")
    
    # Create datasets
    train_dataset = KMMSDataset(train_images, train_masks, image_size=image_size, augment=True, sparse_mode=True)
    val_dataset = KMMSDataset(val_images, val_masks, image_size=image_size, augment=False, sparse_mode=False)
    test_dataset = KMMSDataset(test_images, test_masks, image_size=image_size, augment=False, sparse_mode=False)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    return train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset

# Example usage
if __name__ == "__main__":
    # Prepare datasets with balanced splitting (recommended)
    train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = prepare_datasets(
        image_size=256, 
        batch_size=4, 
        balanced_split=True  # Set to False to use original split
    )
    
    print(f"\nFinal dataset sizes:")
    print(f"Training: {len(train_dataset)} samples")
    print(f"Validation: {len(val_dataset)} samples") 
    print(f"Test: {len(test_dataset)} samples")
    
    # Test data loading
    print(f"\nTesting data loader...")
    for images, masks, filenames in train_loader:
        print(f"Batch - Images: {images.shape}, Masks: {masks.shape}")
        break

Found: 24 training images, 58 test images
✅ Balanced Data Split:
   Training: 56 images (68.3%)
   Validation: 13 images (15.9%)
   Test: 13 images (15.9%)
Created dataset with 56 valid image-mask pairs
✓ Data augmentation ENABLED (sparse mode)
Created dataset with 13 valid image-mask pairs
✗ Data augmentation DISABLED
Created dataset with 13 valid image-mask pairs
✗ Data augmentation DISABLED

Final dataset sizes:
Training: 56 samples
Validation: 13 samples
Test: 13 samples

Testing data loader...
Batch - Images: torch.Size([4, 3, 256, 256]), Masks: torch.Size([4, 1, 256, 256])


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ============================================================================
# FIXED VM-UNET MODEL ARCHITECTURE
# ============================================================================

class VMBlock(nn.Module):
    """Vision Mamba Block with residual connection"""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.activation = nn.GELU()
        
    def forward(self, x):
        residual = x
        # LayerNorm and channel-first for conv
        x = x.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        
        return x + residual

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
    
    def forward(self, x):
        return self.double_conv(x)

class VMEncoderBlock(nn.Module):
    """Encoder block with VMBlock"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.vm_block = VMBlock(out_channels)
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, x):
        # Feature extraction
        x = self.conv(x)
        # Mamba-style processing
        x = self.vm_block(x)
        # Downsample
        pooled = self.pool(x)
        return pooled, x

class VMDecoderBlock(nn.Module):
    """FIXED Decoder block with proper channel handling"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsample reduces channels by 2
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # After concatenation: (in_channels//2 + skip_channels) -> out_channels
        self.conv = DoubleConv(in_channels // 2 + out_channels, out_channels)  # FIXED
        self.vm_block = VMBlock(out_channels)
    
    def forward(self, x, skip):
        # Upsample
        x = self.up(x)
        # Skip connection - ensure spatial dimensions match
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        
        # Concatenate along channel dimension
        x = torch.cat([x, skip], dim=1)
        # Feature extraction
        x = self.conv(x)
        # Mamba-style processing
        x = self.vm_block(x)
        return x

class VMUNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        
        # Encoder path
        self.enc1 = VMEncoderBlock(n_channels, 64)    # 64 channels
        self.enc2 = VMEncoderBlock(64, 128)           # 128 channels  
        self.enc3 = VMEncoderBlock(128, 256)          # 256 channels
        self.enc4 = VMEncoderBlock(256, 512)          # 512 channels
        
        # Bridge
        self.bridge = DoubleConv(512, 1024)           # 1024 channels
        
        # Decoder path - FIXED channel dimensions
        self.dec1 = VMDecoderBlock(1024, 512)         # 1024->512 after up, then 512+512=1024->512
        self.dec2 = VMDecoderBlock(512, 256)          # 512->256 after up, then 256+256=512->256
        self.dec3 = VMDecoderBlock(256, 128)          # 256->128 after up, then 128+128=256->128
        self.dec4 = VMDecoderBlock(128, 64)           # 128->64 after up, then 64+64=128->64
        
        # Output
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        
    def forward(self, x):
        # Encoder
        x1, skip1 = self.enc1(x)    # skip1: 64 channels
        x2, skip2 = self.enc2(x1)   # skip2: 128 channels
        x3, skip3 = self.enc3(x2)   # skip3: 256 channels
        x4, skip4 = self.enc4(x3)   # skip4: 512 channels
        
        # Bridge
        x5 = self.bridge(x4)        # 1024 channels
        
        # Decoder with proper skip connections
        x = self.dec1(x5, skip4)    # 512 channels
        x = self.dec2(x, skip3)     # 256 channels
        x = self.dec3(x, skip2)     # 128 channels
        x = self.dec4(x, skip1)     # 64 channels
        
        # Output
        return torch.sigmoid(self.outc(x))

# ============================================================================
# SIMPLIFIED ENHANCED VERSION (Easier to debug)
# ============================================================================

class SimpleVMUNet(nn.Module):
    """Simplified version that's easier to debug"""
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        
        # Encoder
        self.enc1 = DoubleConv(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        
        # Bridge with VMBlock
        self.bridge = DoubleConv(512, 1024)
        self.vm_bridge = VMBlock(1024)
        
        # Decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.dec1 = DoubleConv(1024, 512)  # 512 (up) + 512 (skip)
        self.vm1 = VMBlock(512)
        
        self.up2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec2 = DoubleConv(512, 256)   # 256 (up) + 256 (skip)
        self.vm2 = VMBlock(256)
        
        self.up3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec3 = DoubleConv(256, 128)   # 128 (up) + 128 (skip)
        self.vm3 = VMBlock(128)
        
        self.up4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec4 = DoubleConv(128, 64)    # 64 (up) + 64 (skip)
        self.vm4 = VMBlock(64)
        
        self.outc = nn.Conv2d(64, n_classes, 1)
        
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        
        # Bridge
        b = self.pool4(e4)
        b = self.bridge(b)
        b = self.vm_bridge(b)
        
        # Decoder
        d1 = self.up1(b)
        d1 = torch.cat([d1, e4], dim=1)
        d1 = self.dec1(d1)
        d1 = self.vm1(d1)
        
        d2 = self.up2(d1)
        d2 = torch.cat([d2, e3], dim=1)
        d2 = self.dec2(d2)
        d2 = self.vm2(d2)
        
        d3 = self.up3(d2)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3(d3)
        d3 = self.vm3(d3)
        
        d4 = self.up4(d3)
        d4 = torch.cat([d4, e1], dim=1)
        d4 = self.dec4(d4)
        d4 = self.vm4(d4)
        
        return torch.sigmoid(self.outc(d4))

# ============================================================================
# TEST THE MODEL WITH YOUR DATA
# ============================================================================

def test_model_with_data():
    """Test the model with your data loader to verify it works"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Testing on device: {device}")
    
    # Load your data
    train_loader, val_loader, test_loader, _, _, _ = prepare_datasets(
        image_size=256, 
        batch_size=4, 
        balanced_split=True
    )
    
    # Test both models
    models = {
        "Fixed VM-UNet": VMUNet(n_channels=3, n_classes=1),
        "Simple VM-UNet": SimpleVMUNet(n_channels=3, n_classes=1)
    }
    
    for name, model in models.items():
        print(f"\n{'='*50}")
        print(f"Testing {name}")
        print(f"{'='*50}")
        
        model.to(device)
        
        # Test with a batch
        try:
            with torch.no_grad():
                for images, masks, _ in train_loader:
                    images = images.to(device)
                    print(f"Input shape: {images.shape}")
                    
                    outputs = model(images)
                    print(f"Output shape: {outputs.shape}")
                    print(f"✅ {name} works correctly!")
                    break
                    
        except Exception as e:
            print(f"❌ Error in {name}: {e}")
    
    return models

# ============================================================================
# UPDATED TRAINING FUNCTION
# ============================================================================

def train_fixed_vm_unet(model, train_loader, val_loader, num_epochs=50, device='cuda'):
    model.to(device)
    
    criterion = CombinedLoss(alpha=0.7, beta=0.3)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    
    train_losses = []
    val_losses = []
    train_f1_scores = []
    val_f1_scores = []
    
    best_val_f1 = 0.0
    best_model_state = None
    
    print("🚀 Starting Fixed VM-UNet Training...")
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        epoch_train_loss = 0.0
        train_preds = []
        train_targets = []
        
        for images, masks, _ in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]'):
            images = images.to(device)
            masks = masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            epoch_train_loss += loss.item()
            train_preds.append(outputs.detach())
            train_targets.append(masks.detach())
        
        # Validation
        model.eval()
        epoch_val_loss = 0.0
        val_preds = []
        val_targets = []
        
        with torch.no_grad():
            for images, masks, _ in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]'):
                images = images.to(device)
                masks = masks.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                epoch_val_loss += loss.item()
                
                val_preds.append(outputs)
                val_targets.append(masks)
        
        # Calculate metrics
        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_val_loss = epoch_val_loss / len(val_loader)
        
        train_preds = torch.cat(train_preds)
        train_targets = torch.cat(train_targets)
        train_precision, train_recall, train_accuracy, train_f1 = calculate_metrics(train_preds, train_targets, threshold=0.3)
        
        val_preds = torch.cat(val_preds)
        val_targets = torch.cat(val_targets)
        val_precision, val_recall, val_accuracy, val_f1 = calculate_metrics(val_preds, val_targets, threshold=0.3)
        
        # Store results
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_f1_scores.append(train_f1)
        val_f1_scores.append(val_f1)
        
        scheduler.step(avg_val_loss)
        
        print(f'\n📈 Epoch {epoch+1}/{num_epochs}:')
        print(f'   Train Loss: {avg_train_loss:.4f} | Prec: {train_precision:.4f} | Rec: {train_recall:.4f} | F1: {train_f1:.4f}')
        print(f'   Val Loss: {avg_val_loss:.4f} | Prec: {val_precision:.4f} | Rec: {val_recall:.4f} | F1: {val_f1:.4f}')
        
        # Save best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = model.state_dict().copy()
            torch.save(best_model_state, 'best_fixed_vm_unet_model.pth')
            print(f'💾 New best model saved! Val F1: {best_val_f1:.4f}')
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, train_losses, val_losses, train_f1_scores, val_f1_scores

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    # First test the models
    models = test_model_with_data()
    
    # Use the simple model for training (more stable)
    model = models["Simple VM-UNet"]
    
    # Load datasets
    train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = prepare_datasets(
        image_size=256, 
        batch_size=4, 
        balanced_split=True
    )
    
    print(f"\n📊 Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
    
    # Train the model
    print("\n🎯 Starting training...")
    model, train_losses, val_losses, train_f1_scores, val_f1_scores = train_fixed_vm_unet(
        model, train_loader, val_loader, num_epochs=50, device='cuda'
    )
    
    # Test the model
    test_metrics = test_vm_unet(model, test_loader, device='cuda')
    
    # Save final model
    torch.save(model.state_dict(), 'final_fixed_vm_unet_model.pth')
    print("\n💾 Final model saved!")
    
    return model, test_metrics

if __name__ == "__main__":
    model, metrics = main()

Testing on device: cuda
Found: 24 training images, 58 test images
✅ Balanced Data Split:
   Training: 56 images (68.3%)
   Validation: 13 images (15.9%)
   Test: 13 images (15.9%)
Created dataset with 56 valid image-mask pairs
✓ Data augmentation ENABLED (sparse mode)
Created dataset with 13 valid image-mask pairs
✗ Data augmentation DISABLED
Created dataset with 13 valid image-mask pairs
✗ Data augmentation DISABLED

Testing Fixed VM-UNet
Input shape: torch.Size([4, 3, 256, 256])
Output shape: torch.Size([4, 1, 256, 256])
✅ Fixed VM-UNet works correctly!

Testing Simple VM-UNet
Input shape: torch.Size([4, 3, 256, 256])
Output shape: torch.Size([4, 1, 256, 256])
✅ Simple VM-UNet works correctly!
Found: 24 training images, 58 test images
✅ Balanced Data Split:
   Training: 56 images (68.3%)
   Validation: 13 images (15.9%)
   Test: 13 images (15.9%)
Created dataset with 56 valid image-mask pairs
✓ Data augmentation ENABLED (sparse mode)
Created dataset with 13 valid image-mask pairs
✗ D

Epoch 1/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.95it/s]
Epoch 1/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.23it/s]



📈 Epoch 1/50:
   Train Loss: 0.4892 | Prec: 0.1542 | Rec: 0.9865 | F1: 0.2667
   Val Loss: 0.5331 | Prec: 0.0911 | Rec: 0.0056 | F1: 0.0105
💾 New best model saved! Val F1: 0.0105


Epoch 2/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.91it/s]
Epoch 2/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.06it/s]



📈 Epoch 2/50:
   Train Loss: 0.3721 | Prec: 0.3575 | Rec: 0.8770 | F1: 0.5079
   Val Loss: 0.6211 | Prec: 0.8502 | Rec: 0.1170 | F1: 0.2058
💾 New best model saved! Val F1: 0.2058


Epoch 3/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.90it/s]
Epoch 3/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  4.18it/s]



📈 Epoch 3/50:
   Train Loss: 0.3209 | Prec: 0.5388 | Rec: 0.7533 | F1: 0.6283
   Val Loss: 0.3588 | Prec: 0.6709 | Rec: 0.5795 | F1: 0.6218
💾 New best model saved! Val F1: 0.6218


Epoch 4/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]
Epoch 4/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.37it/s]



📈 Epoch 4/50:
   Train Loss: 0.2892 | Prec: 0.5858 | Rec: 0.7652 | F1: 0.6636
   Val Loss: 0.3137 | Prec: 0.6830 | Rec: 0.6315 | F1: 0.6562
💾 New best model saved! Val F1: 0.6562


Epoch 5/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]
Epoch 5/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.36it/s]



📈 Epoch 5/50:
   Train Loss: 0.3177 | Prec: 0.5361 | Rec: 0.7772 | F1: 0.6345
   Val Loss: 0.2893 | Prec: 0.5336 | Rec: 0.8384 | F1: 0.6521


Epoch 6/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.00it/s]
Epoch 6/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.28it/s]



📈 Epoch 6/50:
   Train Loss: 0.2749 | Prec: 0.6149 | Rec: 0.7672 | F1: 0.6827
   Val Loss: 0.2801 | Prec: 0.6690 | Rec: 0.7137 | F1: 0.6906
💾 New best model saved! Val F1: 0.6906


Epoch 7/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 7/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 7/50:
   Train Loss: 0.2476 | Prec: 0.6698 | Rec: 0.7889 | F1: 0.7245
   Val Loss: 0.2672 | Prec: 0.6770 | Rec: 0.7464 | F1: 0.7100
💾 New best model saved! Val F1: 0.7100


Epoch 8/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.93it/s]
Epoch 8/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.44it/s]



📈 Epoch 8/50:
   Train Loss: 0.2627 | Prec: 0.5652 | Rec: 0.8381 | F1: 0.6751
   Val Loss: 0.3039 | Prec: 0.5126 | Rec: 0.8269 | F1: 0.6329


Epoch 9/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]
Epoch 9/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.40it/s]



📈 Epoch 9/50:
   Train Loss: 0.2614 | Prec: 0.6273 | Rec: 0.7838 | F1: 0.6969
   Val Loss: 0.2822 | Prec: 0.6863 | Rec: 0.6899 | F1: 0.6881


Epoch 10/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  2.00it/s]
Epoch 10/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.23it/s]



📈 Epoch 10/50:
   Train Loss: 0.2419 | Prec: 0.6374 | Rec: 0.8168 | F1: 0.7160
   Val Loss: 0.2506 | Prec: 0.6237 | Rec: 0.8287 | F1: 0.7118
💾 New best model saved! Val F1: 0.7118


Epoch 11/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 11/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.43it/s]



📈 Epoch 11/50:
   Train Loss: 0.2542 | Prec: 0.6312 | Rec: 0.7939 | F1: 0.7033
   Val Loss: 0.2652 | Prec: 0.6382 | Rec: 0.7626 | F1: 0.6949


Epoch 12/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.01it/s]
Epoch 12/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.26it/s]



📈 Epoch 12/50:
   Train Loss: 0.2614 | Prec: 0.6240 | Rec: 0.7933 | F1: 0.6985
   Val Loss: 0.2608 | Prec: 0.6426 | Rec: 0.7964 | F1: 0.7113


Epoch 13/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.93it/s]
Epoch 13/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 13/50:
   Train Loss: 0.2692 | Prec: 0.6214 | Rec: 0.7927 | F1: 0.6967
   Val Loss: 0.3073 | Prec: 0.7287 | Rec: 0.6279 | F1: 0.6745


Epoch 14/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 14/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.39it/s]



📈 Epoch 14/50:
   Train Loss: 0.2515 | Prec: 0.6435 | Rec: 0.7957 | F1: 0.7116
   Val Loss: 0.2664 | Prec: 0.6617 | Rec: 0.7511 | F1: 0.7036


Epoch 15/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.01it/s]
Epoch 15/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.14it/s]



📈 Epoch 15/50:
   Train Loss: 0.2726 | Prec: 0.6194 | Rec: 0.7521 | F1: 0.6793
   Val Loss: 0.3136 | Prec: 0.7124 | Rec: 0.6223 | F1: 0.6643


Epoch 16/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.87it/s]
Epoch 16/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.29it/s]



📈 Epoch 16/50:
   Train Loss: 0.2418 | Prec: 0.6412 | Rec: 0.8153 | F1: 0.7178
   Val Loss: 0.2551 | Prec: 0.6706 | Rec: 0.7730 | F1: 0.7182
💾 New best model saved! Val F1: 0.7182


Epoch 17/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]
Epoch 17/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 17/50:
   Train Loss: 0.2647 | Prec: 0.6197 | Rec: 0.8069 | F1: 0.7010
   Val Loss: 0.2690 | Prec: 0.6800 | Rec: 0.7278 | F1: 0.7031


Epoch 18/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.03it/s]
Epoch 18/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.25it/s]



📈 Epoch 18/50:
   Train Loss: 0.2468 | Prec: 0.6699 | Rec: 0.7898 | F1: 0.7249
   Val Loss: 0.2628 | Prec: 0.6849 | Rec: 0.7331 | F1: 0.7082


Epoch 19/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 19/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.28it/s]



📈 Epoch 19/50:
   Train Loss: 0.2315 | Prec: 0.6795 | Rec: 0.8039 | F1: 0.7365
   Val Loss: 0.2550 | Prec: 0.6836 | Rec: 0.7488 | F1: 0.7147


Epoch 20/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.93it/s]
Epoch 20/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.35it/s]



📈 Epoch 20/50:
   Train Loss: 0.2237 | Prec: 0.6878 | Rec: 0.8188 | F1: 0.7476
   Val Loss: 0.2572 | Prec: 0.6614 | Rec: 0.7660 | F1: 0.7099


Epoch 21/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 21/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.23it/s]



📈 Epoch 21/50:
   Train Loss: 0.2298 | Prec: 0.6739 | Rec: 0.8198 | F1: 0.7397
   Val Loss: 0.2581 | Prec: 0.6914 | Rec: 0.7374 | F1: 0.7137


Epoch 22/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.95it/s]
Epoch 22/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 22/50:
   Train Loss: 0.2097 | Prec: 0.6842 | Rec: 0.8300 | F1: 0.7501
   Val Loss: 0.2610 | Prec: 0.7078 | Rec: 0.7220 | F1: 0.7148


Epoch 23/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  2.00it/s]
Epoch 23/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.33it/s]



📈 Epoch 23/50:
   Train Loss: 0.2145 | Prec: 0.7244 | Rec: 0.8019 | F1: 0.7612
   Val Loss: 0.2501 | Prec: 0.7013 | Rec: 0.7548 | F1: 0.7270
💾 New best model saved! Val F1: 0.7270


Epoch 24/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.95it/s]
Epoch 24/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 24/50:
   Train Loss: 0.2122 | Prec: 0.6785 | Rec: 0.8353 | F1: 0.7488
   Val Loss: 0.2517 | Prec: 0.6391 | Rec: 0.8289 | F1: 0.7217


Epoch 25/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 25/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.25it/s]



📈 Epoch 25/50:
   Train Loss: 0.2242 | Prec: 0.6824 | Rec: 0.8172 | F1: 0.7437
   Val Loss: 0.2593 | Prec: 0.7000 | Rec: 0.7306 | F1: 0.7150


Epoch 26/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.92it/s]
Epoch 26/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.36it/s]



📈 Epoch 26/50:
   Train Loss: 0.2044 | Prec: 0.7083 | Rec: 0.8376 | F1: 0.7675
   Val Loss: 0.2562 | Prec: 0.7062 | Rec: 0.7337 | F1: 0.7197


Epoch 27/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]
Epoch 27/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.35it/s]



📈 Epoch 27/50:
   Train Loss: 0.2135 | Prec: 0.6835 | Rec: 0.8412 | F1: 0.7542
   Val Loss: 0.2663 | Prec: 0.7066 | Rec: 0.7135 | F1: 0.7101


Epoch 28/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.98it/s]
Epoch 28/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.27it/s]



📈 Epoch 28/50:
   Train Loss: 0.2186 | Prec: 0.6994 | Rec: 0.8216 | F1: 0.7556
   Val Loss: 0.2541 | Prec: 0.6801 | Rec: 0.7619 | F1: 0.7187


Epoch 29/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.90it/s]
Epoch 29/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.38it/s]



📈 Epoch 29/50:
   Train Loss: 0.1919 | Prec: 0.7247 | Rec: 0.8489 | F1: 0.7819
   Val Loss: 0.2399 | Prec: 0.6909 | Rec: 0.7968 | F1: 0.7401
💾 New best model saved! Val F1: 0.7401


Epoch 30/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  2.00it/s]
Epoch 30/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.27it/s]



📈 Epoch 30/50:
   Train Loss: 0.2004 | Prec: 0.6935 | Rec: 0.8526 | F1: 0.7649
   Val Loss: 0.2535 | Prec: 0.6959 | Rec: 0.7483 | F1: 0.7212


Epoch 31/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.01it/s]
Epoch 31/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.27it/s]



📈 Epoch 31/50:
   Train Loss: 0.2045 | Prec: 0.7215 | Rec: 0.8302 | F1: 0.7720
   Val Loss: 0.2612 | Prec: 0.6996 | Rec: 0.7304 | F1: 0.7147


Epoch 32/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.97it/s]
Epoch 32/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.31it/s]



📈 Epoch 32/50:
   Train Loss: 0.2128 | Prec: 0.6879 | Rec: 0.8324 | F1: 0.7533
   Val Loss: 0.2534 | Prec: 0.6907 | Rec: 0.7610 | F1: 0.7241


Epoch 33/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.04it/s]
Epoch 33/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.39it/s]



📈 Epoch 33/50:
   Train Loss: 0.2052 | Prec: 0.7079 | Rec: 0.8368 | F1: 0.7670
   Val Loss: 0.2489 | Prec: 0.7106 | Rec: 0.7493 | F1: 0.7294


Epoch 34/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.03it/s]
Epoch 34/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.39it/s]



📈 Epoch 34/50:
   Train Loss: 0.2024 | Prec: 0.7255 | Rec: 0.8335 | F1: 0.7757
   Val Loss: 0.2535 | Prec: 0.7082 | Rec: 0.7375 | F1: 0.7225


Epoch 35/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.00it/s]
Epoch 35/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.28it/s]



📈 Epoch 35/50:
   Train Loss: 0.1928 | Prec: 0.7186 | Rec: 0.8507 | F1: 0.7791
   Val Loss: 0.2352 | Prec: 0.6982 | Rec: 0.8003 | F1: 0.7458
💾 New best model saved! Val F1: 0.7458


Epoch 36/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]
Epoch 36/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.33it/s]



📈 Epoch 36/50:
   Train Loss: 0.1999 | Prec: 0.6961 | Rec: 0.8581 | F1: 0.7686
   Val Loss: 0.2437 | Prec: 0.6879 | Rec: 0.7852 | F1: 0.7333


Epoch 37/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.91it/s]
Epoch 37/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.12it/s]



📈 Epoch 37/50:
   Train Loss: 0.1932 | Prec: 0.6915 | Rec: 0.8666 | F1: 0.7692
   Val Loss: 0.2415 | Prec: 0.7033 | Rec: 0.7868 | F1: 0.7427


Epoch 38/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.02it/s]
Epoch 38/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.36it/s]



📈 Epoch 38/50:
   Train Loss: 0.1945 | Prec: 0.7238 | Rec: 0.8353 | F1: 0.7756
   Val Loss: 0.2476 | Prec: 0.7071 | Rec: 0.7558 | F1: 0.7306


Epoch 39/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.01it/s]
Epoch 39/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 39/50:
   Train Loss: 0.1930 | Prec: 0.7223 | Rec: 0.8456 | F1: 0.7791
   Val Loss: 0.2558 | Prec: 0.7081 | Rec: 0.7357 | F1: 0.7216


Epoch 40/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.03it/s]
Epoch 40/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.37it/s]



📈 Epoch 40/50:
   Train Loss: 0.1967 | Prec: 0.7105 | Rec: 0.8400 | F1: 0.7698
   Val Loss: 0.2393 | Prec: 0.7133 | Rec: 0.7761 | F1: 0.7434


Epoch 41/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.94it/s]
Epoch 41/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.34it/s]



📈 Epoch 41/50:
   Train Loss: 0.2028 | Prec: 0.7179 | Rec: 0.8435 | F1: 0.7757
   Val Loss: 0.2375 | Prec: 0.7009 | Rec: 0.7800 | F1: 0.7383


Epoch 42/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.00it/s]
Epoch 42/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.14it/s]



📈 Epoch 42/50:
   Train Loss: 0.1918 | Prec: 0.7074 | Rec: 0.8601 | F1: 0.7763
   Val Loss: 0.2499 | Prec: 0.7035 | Rec: 0.7445 | F1: 0.7234


Epoch 43/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.87it/s]
Epoch 43/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.43it/s]



📈 Epoch 43/50:
   Train Loss: 0.1934 | Prec: 0.7158 | Rec: 0.8448 | F1: 0.7749
   Val Loss: 0.2449 | Prec: 0.6805 | Rec: 0.7817 | F1: 0.7276


Epoch 44/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.93it/s]
Epoch 44/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.31it/s]



📈 Epoch 44/50:
   Train Loss: 0.2149 | Prec: 0.7172 | Rec: 0.8420 | F1: 0.7746
   Val Loss: 0.2497 | Prec: 0.6866 | Rec: 0.7676 | F1: 0.7249


Epoch 45/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.92it/s]
Epoch 45/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.36it/s]



📈 Epoch 45/50:
   Train Loss: 0.1995 | Prec: 0.7092 | Rec: 0.8424 | F1: 0.7701
   Val Loss: 0.2436 | Prec: 0.6897 | Rec: 0.7786 | F1: 0.7314


Epoch 46/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.92it/s]
Epoch 46/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.24it/s]



📈 Epoch 46/50:
   Train Loss: 0.1884 | Prec: 0.7170 | Rec: 0.8616 | F1: 0.7827
   Val Loss: 0.2373 | Prec: 0.6938 | Rec: 0.7932 | F1: 0.7401


Epoch 47/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.96it/s]
Epoch 47/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]



📈 Epoch 47/50:
   Train Loss: 0.2056 | Prec: 0.7039 | Rec: 0.8449 | F1: 0.7680
   Val Loss: 0.2330 | Prec: 0.6872 | Rec: 0.8100 | F1: 0.7435


Epoch 48/50 [Train]: 100%|██████████| 14/14 [00:06<00:00,  2.03it/s]
Epoch 48/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.30it/s]



📈 Epoch 48/50:
   Train Loss: 0.2126 | Prec: 0.7036 | Rec: 0.8362 | F1: 0.7642
   Val Loss: 0.2378 | Prec: 0.6986 | Rec: 0.7859 | F1: 0.7397


Epoch 49/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.94it/s]
Epoch 49/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.37it/s]



📈 Epoch 49/50:
   Train Loss: 0.2053 | Prec: 0.7157 | Rec: 0.8316 | F1: 0.7693
   Val Loss: 0.2391 | Prec: 0.6957 | Rec: 0.7867 | F1: 0.7384


Epoch 50/50 [Train]: 100%|██████████| 14/14 [00:07<00:00,  1.96it/s]
Epoch 50/50 [Val]: 100%|██████████| 4/4 [00:00<00:00,  5.09it/s]



📈 Epoch 50/50:
   Train Loss: 0.1815 | Prec: 0.7244 | Rec: 0.8650 | F1: 0.7885
   Val Loss: 0.2356 | Prec: 0.6936 | Rec: 0.7974 | F1: 0.7419

🧪 Testing VM-UNet on test set...


Testing: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]



🔍 Finding optimal threshold...
   Threshold 0.3: Precision=0.7251, Recall=0.7229, F1=0.7240
   Threshold 0.4: Precision=0.7513, Recall=0.6867, F1=0.7175
   Threshold 0.5: Precision=0.7751, Recall=0.6499, F1=0.7070
   Threshold 0.6: Precision=0.7977, Recall=0.6106, F1=0.6917
   Threshold 0.7: Precision=0.8210, Recall=0.5648, F1=0.6692

🎯 VM-UNet FINAL TEST RESULTS
Optimal Threshold: 0.3
Test Loss: 0.2495
Precision: 0.7251
Recall:    0.7229
Accuracy:  0.9037
F1-Score:  0.7240

💾 Final model saved!
