In [None]:
# Mount Google Drive at /content/drive; the dataset and model paths used below live there.
# This import only exists inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Checking, Modelling, Training, Prediction

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from glob import glob
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Select the compute device once; later code reads the module-level `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {device}")
    # Report the first GPU's name, total memory (GiB) and the CUDA build of torch.
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    device = torch.device('cpu')
    print(f"Using device: {device}")
    print("‚ö†Ô∏è WARNING: CUDA not available! Will use CPU (very slow)")

# ==================== STEP 2: Dataset Analysis ====================
def analyze_dataset(mask_dir, image_dir):
    """Inspect the dataset folders and infer the mask format and class count.

    Images and masks are the ``*.png`` files directly inside ``image_dir`` and
    ``mask_dir``. Both lists are sorted by path; callers pair them by position,
    so the two folders must hold the same number of files.

    For grayscale (2-D) masks every mask file is scanned, so ``unique_values``
    contains each label value that occurs anywhere in the dataset, not just in
    the first file. For multi-channel masks only the first mask is inspected.

    Args:
        mask_dir: Folder containing the mask PNG files.
        image_dir: Folder containing the image PNG files.

    Returns:
        dict with the keys ``num_classes``, ``is_rgb_mask``, ``mask_files``,
        ``image_files``, ``image_shape``, ``mask_shape`` and ``unique_values``
        (``None`` for multi-channel masks).

    Raises:
        ValueError: If either folder has no PNG files, or the number of images
            differs from the number of masks.
    """
    print("=" * 60)
    print("ANALYZING DATASET...")
    print("=" * 60)

    mask_files = sorted(glob(os.path.join(mask_dir, "*.png")))
    image_files = sorted(glob(os.path.join(image_dir, "*.png")))

    print(f"Found {len(mask_files)} mask files")
    print(f"Found {len(image_files)} image files")

    if not mask_files or not image_files:
        raise ValueError("No images found! Check your paths.")

    # Pairing is positional, so a count mismatch would silently pair an image
    # with the wrong mask (or fail much later during the train/val split).
    if len(mask_files) != len(image_files):
        raise ValueError(
            f"Found {len(image_files)} images but {len(mask_files)} masks; "
            "every image needs exactly one mask."
        )

    # Analyze first mask
    with Image.open(mask_files[0]) as mask_handle:
        sample_mask = np.array(mask_handle)
    print(f"\nMask shape: {sample_mask.shape}")
    print(f"Mask dtype: {sample_mask.dtype}")

    # A 3-D array means a multi-channel (colour) mask, 2-D means grayscale.
    if sample_mask.ndim == 3:
        print(f"Mask format: RGB (channels: {sample_mask.shape[2]})")
        unique_colors = np.unique(sample_mask.reshape(-1, sample_mask.shape[2]), axis=0)
        num_classes = len(unique_colors)
        print(f"Unique colors found: {num_classes}")
        print(f"Colors (first 10): \n{unique_colors[:10]}")
        is_rgb = True
        unique_values = None
    else:
        print("Mask format: Grayscale")
        # Collect label values from every mask: a value missing from the first
        # file would otherwise be unknown to the caller's remapping table.
        per_mask_values = [np.unique(sample_mask)]
        for mask_path in mask_files[1:]:
            with Image.open(mask_path) as mask_handle:
                per_mask_values.append(np.unique(np.array(mask_handle)))
        unique_values = np.unique(np.concatenate(per_mask_values))
        num_classes = int(unique_values.max()) + 1  # max value + 1 covers every class index
        print(f"Scanned {len(mask_files)} masks")
        print(f"Unique values: {unique_values}")
        print(f"Max class value: {unique_values.max()}")
        print(f"Number of classes (max+1): {num_classes}")
        is_rgb = False

    # Sample image
    with Image.open(image_files[0]) as image_handle:
        sample_image = np.array(image_handle)
    print(f"\nImage shape: {sample_image.shape}")
    print(f"Image dtype: {sample_image.dtype}")

    print("=" * 60)

    return {
        'num_classes': num_classes,
        'is_rgb_mask': is_rgb,
        'mask_files': mask_files,
        'image_files': image_files,
        'image_shape': sample_image.shape,
        'mask_shape': sample_mask.shape,
        'unique_values': unique_values
    }

# ==================== STEP 3: Custom Dataset ====================
class SegmentationDataset(Dataset):
    """Image/mask pairs for semantic segmentation.

    ``image_paths[i]`` and ``mask_paths[i]`` form one sample. Each item is
    returned as ``(image, mask)``: the image as a float tensor in CHW layout
    scaled to [0, 1], the mask as a long tensor of class indices.

    Notes:
        * ``augment`` and ``num_classes`` are stored but not used by this class.
        * For RGB masks the colour -> class table is built from the first mask
          in ``mask_paths`` only; colours not in that table stay at class 0.
        * When ``value_map`` is given, mask values missing from it become 0.
    """

    def __init__(self, image_paths, mask_paths, num_classes, is_rgb_mask=False,
                 img_size=(256, 256), augment=False, value_map=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.num_classes = num_classes
        self.is_rgb_mask = is_rgb_mask
        self.img_size = img_size
        self.augment = augment
        # Optional {original mask value: class index} table applied in __getitem__.
        self.value_map = value_map

        # The colour table only exists for RGB masks.
        if is_rgb_mask:
            self.color_map = self._build_color_map()

    def _build_color_map(self):
        """Return {colour tuple: class index} from the first mask's unique colours."""
        first_mask = np.array(Image.open(self.mask_paths[0]))
        pixels = first_mask.reshape(-1, first_mask.shape[2])
        return {tuple(color): idx for idx, color in enumerate(np.unique(pixels, axis=0))}

    def _rgb_to_class(self, mask):
        """Translate an (H, W, C) colour mask into an (H, W) int64 class-index mask."""
        class_mask = np.zeros(mask.shape[:2], dtype=np.int64)
        for color, class_idx in self.color_map.items():
            # A pixel matches only when every channel equals the colour.
            class_mask[(mask == color).all(axis=-1)] = class_idx
        return class_mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Image: force RGB, resize with bilinear filtering, scale to [0, 1].
        rgb = Image.open(self.image_paths[idx]).convert('RGB')
        rgb = rgb.resize(self.img_size, Image.BILINEAR)
        image = np.array(rgb) / 255.0

        # Mask: nearest-neighbour resize so no new label values are invented.
        raw_mask = Image.open(self.mask_paths[idx]).resize(self.img_size, Image.NEAREST)
        mask = np.array(raw_mask)

        if self.is_rgb_mask:
            mask = self._rgb_to_class(mask)

        if self.value_map is not None:
            remapped = np.zeros_like(mask)
            for old_val, new_val in self.value_map.items():
                remapped[mask == old_val] = new_val
            mask = remapped

        # HWC -> CHW for the image; the mask stays (H, W).
        return torch.FloatTensor(image).permute(2, 0, 1), torch.LongTensor(mask)

# ==================== STEP 4: U-Net Architecture ====================
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> batch norm -> ReLU.

    With ``padding=1`` the spatial size is unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage changes the channel count, second keeps it.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name and layer order define the checkpoint keys
        # (double_conv.0 ... double_conv.4), so they must stay as they are.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """Four-level U-Net that outputs per-pixel class logits.

    The encoder halves the resolution four times (64 -> 512 channels), the
    bottleneck has 1024 channels, and the decoder upsamples with transposed
    convolutions and concatenates the matching encoder features.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output channels (one logit map per class).

    The output has shape ``(N, num_classes, H, W)`` for an input of shape
    ``(N, in_channels, H, W)``. ``H`` and ``W`` no longer need to be multiples
    of 16: an upsampled map that comes out smaller than its skip connection
    is zero-padded to the skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each DoubleConv sees upsampled + skip channels concatenated.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # Output: 1x1 convolution to class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the map
        can be one pixel short of the encoder feature it is concatenated with.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for 4-D input: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder
        dec4 = self._match_size(self.upconv4(bottleneck), enc4)
        dec4 = self.dec4(torch.cat([dec4, enc4], dim=1))

        dec3 = self._match_size(self.upconv3(dec4), enc3)
        dec3 = self.dec3(torch.cat([dec3, enc3], dim=1))

        dec2 = self._match_size(self.upconv2(dec3), enc2)
        dec2 = self.dec2(torch.cat([dec2, enc2], dim=1))

        dec1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = self.dec1(torch.cat([dec1, enc1], dim=1))

        return self.out(dec1)

# ==================== STEP 5: Training Functions ====================
def train_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable of ``(images, masks)`` batches that supports ``len()``.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Current epoch number, used only in the progress-bar label.
        num_epochs: Total number of epochs, used only in the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        pixel accuracy in percent over the whole epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError / unbound variable.
        raise ValueError("train_epoch received an empty loader; nothing to train on.")

    # The progress bar is cosmetic, so training still runs without tqdm.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    model.train()
    total_loss = 0
    correct_pixels = 0
    total_pixels = 0

    batches = loader
    if tqdm is not None:
        batches = tqdm(loader, desc=f"Epoch {epoch}/{num_epochs} [TRAIN]")

    for batch_idx, (images, masks) in enumerate(batches):
        images = images.to(device)
        masks = masks.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Pixel accuracy: arg-max over the class dimension vs. the label mask.
        preds = torch.argmax(outputs, dim=1)
        correct_pixels += (preds == masks).sum().item()
        total_pixels += masks.numel()

        if tqdm is not None:
            batches.set_postfix({
                'loss': f'{total_loss / (batch_idx + 1):.4f}',
                'acc': f'{100.0 * correct_pixels / max(total_pixels, 1):.2f}%'
            })

    accuracy = 100.0 * correct_pixels / total_pixels if total_pixels else 0.0
    return total_loss / num_batches, accuracy

def validate(model, loader, criterion, device, epoch, num_epochs):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable of ``(images, masks)`` batches that supports ``len()``.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.
        epoch: Current epoch number, used only in the progress-bar label.
        num_epochs: Total number of epochs, used only in the progress-bar label.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        pixel accuracy in percent over the whole loader.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError / unbound variable.
        raise ValueError("validate received an empty loader; nothing to evaluate.")

    # The progress bar is cosmetic, so validation still runs without tqdm.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    model.eval()
    total_loss = 0
    correct_pixels = 0
    total_pixels = 0

    batches = loader
    if tqdm is not None:
        batches = tqdm(loader, desc=f"Epoch {epoch}/{num_epochs} [VAL]  ")

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(batches):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            # Pixel accuracy: arg-max over the class dimension vs. the label mask.
            preds = torch.argmax(outputs, dim=1)
            correct_pixels += (preds == masks).sum().item()
            total_pixels += masks.numel()

            if tqdm is not None:
                batches.set_postfix({
                    'loss': f'{total_loss / (batch_idx + 1):.4f}',
                    'acc': f'{100.0 * correct_pixels / max(total_pixels, 1):.2f}%'
                })

    accuracy = 100.0 * correct_pixels / total_pixels if total_pixels else 0.0
    return total_loss / num_batches, accuracy

def calculate_iou(pred, target, num_classes):
    """Return the mean Intersection over Union across classes.

    Args:
        pred: Integer tensor of predicted class indices (any shape).
        target: Integer tensor of ground-truth class indices, same number of
            elements as ``pred``.
        num_classes: Classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        Mean IoU over the classes that occur in ``pred`` or ``target``.
        Classes absent from both are left out of the mean. ``nan`` is returned
        when no class occurs at all.
    """
    # reshape (not view) so non-contiguous tensors are accepted too.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_cls = pred == cls
        target_cls = target == cls

        union = (pred_cls | target_cls).sum().item()
        if union == 0:
            # Class not present in this batch: undefined IoU, skip it.
            # (The original stored a Python float here and then called
            # .item() on it, which raised AttributeError.)
            continue

        intersection = (pred_cls & target_cls).sum().item()
        ious.append(intersection / union)

    if not ious:
        return np.float64('nan')
    return np.mean(ious)

# ==================== STEP 6: Visualization ====================
def visualize_predictions(model, dataset, device, num_samples=4):
    """Plot image / ground truth / prediction for random samples of ``dataset``.

    One row is drawn per sample. The figure is saved as ``predictions.png``
    in the working directory and then shown.

    Args:
        model: Network producing ``(N, num_classes, H, W)`` logits; it is
            switched to evaluation mode.
        dataset: Indexable dataset returning ``(image, mask)`` tensors, the
            image in CHW layout.
        device: Device used for inference.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.

    Raises:
        ValueError: If there is nothing to draw (empty dataset or
            ``num_samples`` below 1).
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("visualize_predictions needs at least one sample to draw.")

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples * 4), squeeze=False)

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    with torch.no_grad():
        for row, idx in enumerate(indices):
            image, mask = dataset[idx]

            # Predict: add the batch dimension, take the arg-max class per pixel.
            output = model(image.unsqueeze(0).to(device))
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

            # CHW -> HWC for matplotlib.
            image_np = image.permute(1, 2, 0).cpu().numpy()
            mask_np = mask.cpu().numpy()

            panels = (
                (image_np, 'Original Image', None),
                (mask_np, 'Ground Truth', 'tab20'),
                (pred, 'Prediction', 'tab20'),
            )
            for col, (data, title, cmap) in enumerate(panels):
                axes[row, col].imshow(data, cmap=cmap)
                axes[row, col].set_title(title)
                axes[row, col].axis('off')

    plt.tight_layout()
    plt.savefig('predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

# ==================== STEP 7: MAIN EXECUTION ====================
def main():
    """Run the whole training pipeline end to end.

    Steps: analyse the dataset folders, split into train/validation sets,
    verify the remapped masks, build the U-Net, train for ``NUM_EPOCHS``
    epochs while saving the weights with the lowest validation loss to
    ``best_unet_model.pth``, plot the training history, visualise predictions
    and report the mean IoU on the validation set.

    The function returns early (after printing a message) when the masks are
    invalid, the model cannot be initialised, or the test forward pass fails.
    The module-level ``device`` is rebound to the CPU if moving the model to
    CUDA fails.
    """
    global device
    # Enable CUDA debugging
    import os
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Clear CUDA cache to recover from previous errors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("üîÑ Cleared CUDA cache\n")

    # CONFIGURE THESE PATHS
    MASK_DIR = "/content/drive/MyDrive/datasets/masks"  # Path to mask folder
    IMAGE_DIR = "/content/drive/MyDrive/datasets/images"  # Path to image folder

    # Hyperparameters
    IMG_SIZE = (256, 256)
    BATCH_SIZE = 8
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    MAX_REPORTED_ERRORS = 10  # stop scanning masks after this many invalid ones

    # Step 1: Analyze dataset
    info = analyze_dataset(MASK_DIR, IMAGE_DIR)
    num_classes = info['num_classes']
    is_rgb_mask = info['is_rgb_mask']

    print(f"\nDetected {num_classes} classes")
    print(f"Mask format: {'RGB' if is_rgb_mask else 'Grayscale'}")

    # Create value mapping for grayscale masks
    value_map = None
    if info['unique_values'] is not None:
        unique_vals = info['unique_values']
        print(f"Unique mask values: {unique_vals}")

        # Remap to continuous values 0, 1, 2, ..., n-1
        value_map = {int(old_val): new_val for new_val, old_val in enumerate(unique_vals)}
        num_classes = len(unique_vals)  # Use actual number of unique values

        print(f"\nRemapping mask values:")
        for old_val, new_val in value_map.items():
            print(f"  {old_val} -> {new_val}")
        print(f"\nFinal number of classes: {num_classes}")

    # Step 2: Split dataset
    train_imgs, val_imgs, train_masks, val_masks = train_test_split(
        info['image_files'],
        info['mask_files'],
        test_size=0.2,
        random_state=42
    )

    print(f"\nTrain samples: {len(train_imgs)}")
    print(f"Validation samples: {len(val_imgs)}")

    # Step 3: Create datasets
    train_dataset = SegmentationDataset(
        train_imgs, train_masks, num_classes, is_rgb_mask, IMG_SIZE, augment=True, value_map=value_map
    )
    val_dataset = SegmentationDataset(
        val_imgs, val_masks, num_classes, is_rgb_mask, IMG_SIZE, augment=False, value_map=value_map
    )

    if is_rgb_mask:
        # Each dataset builds its colour -> class table from its own first mask,
        # so the same colour could get different indices in train and validation.
        # Use the training table for both and size the model from it, so every
        # index the datasets can produce is a valid class.
        val_dataset.color_map = train_dataset.color_map
        num_classes = len(train_dataset.color_map)
        train_dataset.num_classes = num_classes
        val_dataset.num_classes = num_classes
        print(f"\nFinal number of classes: {num_classes}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # VERIFY: Check remapped mask values
    print("\n" + "=" * 60)
    print("VERIFYING REMAPPED MASKS...")
    print("=" * 60)

    # Check first sample
    sample_img, sample_mask = train_dataset[0]
    unique_remapped = torch.unique(sample_mask)
    print(f"Sample mask shape: {sample_mask.shape}")
    print(f"Unique values after remapping: {unique_remapped.numpy()}")
    print(f"Min value: {sample_mask.min().item()}, Max value: {sample_mask.max().item()}")

    # Check ALL masks in dataset
    print("\nChecking all masks in training set...")
    invalid_count = 0
    for i in range(len(train_dataset)):
        _, mask = train_dataset[i]
        mask_min = mask.min().item()
        mask_max = mask.max().item()
        if mask_max >= num_classes or mask_min < 0:
            print(f"‚ö†Ô∏è  Invalid mask at index {i}: min={mask_min}, max={mask_max}")
            invalid_count += 1
            # Count errors, not sample indices, before giving up.
            if invalid_count >= MAX_REPORTED_ERRORS:
                break

    if invalid_count:
        print("\n‚ö†Ô∏è  ERROR: Some masks have invalid values!")
        return
    else:
        print(f"‚úì All {len(train_dataset)} training masks are valid (0 to {num_classes-1})")

    print("=" * 60)

    # Step 4: Initialize model
    print("\n" + "=" * 60)
    print("INITIALIZING MODEL...")
    print("=" * 60)

    # Try CUDA first, fallback to CPU if error
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for previous operations
            torch.cuda.empty_cache()  # Clear cache

        model = UNet(in_channels=3, num_classes=num_classes)

        # Move to device with error handling
        try:
            model = model.to(device)
            print(f"‚úì Model successfully moved to {device}")
        except RuntimeError as e:
            print(f"‚ö†Ô∏è  CUDA Error: {e}")
            print("üîÑ Falling back to CPU...")
            device_fallback = torch.device('cpu')
            model = model.to(device_fallback)
            print(f"‚úì Model running on CPU")
            # Update device for rest of training
            device = device_fallback

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        print(f"Model initialized with {num_classes} classes")
        print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    except Exception as e:
        print(f"‚ö†Ô∏è  Fatal error during model initialization: {e}")
        print("\nüí° SOLUTION: Go to Runtime ‚Üí Restart Runtime and run again")
        return

    # Test forward pass with dummy data
    print("\nTesting forward pass...")
    try:
        # Evaluation mode + no_grad: the random probe must not update the
        # BatchNorm running statistics or build an autograd graph.
        # train_epoch() switches the model back to training mode.
        model.eval()
        # PIL's resize takes (width, height), so tensors are (H, W) = (IMG_SIZE[1], IMG_SIZE[0]).
        probe_h, probe_w = IMG_SIZE[1], IMG_SIZE[0]
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, probe_h, probe_w).to(device)
            dummy_output = model(dummy_input)
            print(f"‚úì Forward pass successful! Output shape: {dummy_output.shape}")

            # Test loss calculation
            dummy_target = torch.randint(0, num_classes, (1, probe_h, probe_w)).to(device)
            dummy_loss = criterion(dummy_output, dummy_target)
            print(f"‚úì Loss calculation successful! Loss: {dummy_loss.item():.4f}")
    except Exception as e:
        print(f"‚ö†Ô∏è  Error during forward pass: {e}")
        return

    print("=" * 60)

    # Step 5: Training loop
    print("\n" + "=" * 60)
    print("STARTING TRAINING...")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Epochs: {NUM_EPOCHS}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Image Size: {IMG_SIZE}")
    print("=" * 60 + "\n")

    best_val_loss = float('inf')
    best_val_acc = 0.0
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    import time
    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch+1, NUM_EPOCHS
        )

        # Validate
        val_loss, val_acc = validate(
            model, val_loader, criterion, device, epoch+1, NUM_EPOCHS
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        epoch_time = time.time() - epoch_start
        elapsed_time = time.time() - start_time
        # ETA = average time per finished epoch * epochs still to run.
        eta = (elapsed_time / (epoch + 1)) * (NUM_EPOCHS - epoch - 1)

        # Print summary
        print(f"\n{'='*60}")
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Summary:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"  Time: {epoch_time:.1f}s | Elapsed: {elapsed_time/60:.1f}m | ETA: {eta/60:.1f}m")

        # Save best model (lowest validation loss so far)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_unet_model.pth')
            print(f"  ‚úì Best model saved! (Val Loss: {best_val_loss:.4f}, Val Acc: {best_val_acc:.2f}%)")

        # GPU Memory usage
        if torch.cuda.is_available():
            print(f"  GPU Memory: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

        print(f"{'='*60}\n")

    total_time = time.time() - start_time
    print(f"\nüéâ Training completed in {total_time/60:.1f} minutes!")
    print(f"Best Val Loss: {best_val_loss:.4f}")
    print(f"Best Val Acc: {best_val_acc:.2f}%")

    # Step 6: Plot training history
    plt.figure(figsize=(15, 5))

    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', marker='o', markersize=3)
    plt.plot(val_losses, label='Val Loss', marker='s', markersize=3)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training & Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc', marker='o', markersize=3)
    plt.plot(val_accs, label='Val Acc', marker='s', markersize=3)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training & Validation Accuracy')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Step 7: Visualize predictions
    print("\n" + "=" * 60)
    print("GENERATING PREDICTIONS...")
    print("=" * 60)

    # Reload the best weights onto whichever device is in use now.
    model.load_state_dict(torch.load('best_unet_model.pth', map_location=device))
    visualize_predictions(model, val_dataset, device, num_samples=4)

    # Step 8: Calculate IoU (mean of the per-batch mean IoU values)
    model.eval()
    total_iou = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            total_iou += calculate_iou(preds, masks, num_classes)

    mean_iou = total_iou / len(val_loader)
    print(f"\nMean IoU: {mean_iou:.4f}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED!")
    print("=" * 60)
    print(f"Best model saved as: best_unet_model.pth")
    print(f"Training history saved as: training_history.png")
    print(f"Predictions saved as: predictions.png")

# ==================== STEP 8: PREDICTION FUNCTION ====================
def predict_single_image(model_path, image_path, num_classes, device, img_size=(256, 256)):
    """Segment one image with a saved U-Net and show the result.

    Args:
        model_path: Path to a state dict saved from ``UNet``.
        image_path: Path to the image to segment.
        num_classes: Number of classes the saved model was built with.
        device: Device used for inference.
        img_size: ``(width, height)`` the image is resized to before it is fed
            to the model. Defaults to the size used elsewhere in this file.

    Returns:
        2-D ``uint8`` array of class indices at the original image resolution.
    """
    # Load model. map_location lets a checkpoint saved on GPU load on CPU too.
    model = UNet(in_channels=3, num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Load and preprocess image; the context manager closes the file handle.
    with Image.open(image_path) as image_file:
        original = image_file.convert('RGB')
    original_size = original.size  # (width, height)
    resized = original.resize(img_size, Image.BILINEAR)
    image_np = np.array(resized) / 255.0
    # HWC -> CHW, then add the batch dimension.
    image_tensor = torch.FloatTensor(image_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(image_tensor)
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Resize prediction back to original size; nearest-neighbour keeps labels intact.
    pred_resized = cv2.resize(pred.astype(np.uint8), original_size, interpolation=cv2.INTER_NEAREST)

    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].imshow(original)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(pred_resized, cmap='tab20')
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

    return pred_resized

# RUN TRAINING
# Runs the whole pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()

    # EXAMPLE: Predict on new image (replace YOUR_NUM_CLASSES with the value used for training)
    # pred = predict_single_image(
    #     model_path='best_unet_model.pth',
    #     image_path='path/to/test/image.png',
    #     num_classes=YOUR_NUM_CLASSES,
    #     device=device
    # )

CHECK CLASSES

In [None]:
# Find out how many classes a saved model was trained with
import torch

def detect_num_classes(model_path):
    """Read the class count from a saved U-Net state dict and report it.

    The count is the first dimension of the ``out.weight`` tensor stored in
    the checkpoint, which is loaded onto the CPU.

    Args:
        model_path: Path to the saved state dict.

    Returns:
        Number of output classes.
    """
    rule = '=' * 60

    print(rule)
    print(f"üîç DETECTING NUMBER OF CLASSES")
    print(rule)
    print(f"Model: {model_path}\n")

    # Load the checkpoint on CPU so no GPU is needed just to inspect it.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Output layer weight: its first dimension is the number of classes.
    head_shape = checkpoint['out.weight'].shape
    num_classes = head_shape[0]

    print(f"‚úÖ Detected: {num_classes} classes")
    print(f"   Output layer shape: {head_shape}")
    print(f"{rule}\n")

    return num_classes

# ========== USAGE ==========
# Checkpoint to inspect.
MODEL_PATH = '/content/best_unet_model.pth'

# Auto-detect the class count from the checkpoint's output layer.
NUM_CLASSES = detect_num_classes(MODEL_PATH)

print(f"‚ú® Use this value for prediction:")
print(f"   NUM_CLASSES = {NUM_CLASSES}")

Predict


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
import torch.nn as nn

# ==================== U-Net Model ====================
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> batch norm -> ReLU.

    Padding of 1 keeps the spatial size; only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def stage(channels_in):
            # One conv/norm/activation stage producing `out_channels` maps.
            return [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # Attribute name and layer order must match the saved checkpoints
        # (keys double_conv.0 ... double_conv.4).
        self.double_conv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x):
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """Four-level U-Net that outputs per-pixel class logits.

    Layer names and shapes must match the checkpoints written by the training
    cell, because the state dict is loaded by key.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output channels (one logit map per class).

    The output has shape ``(N, num_classes, H, W)`` for an input of shape
    ``(N, in_channels, H, W)``. ``H`` and ``W`` do not need to be multiples of
    16: an upsampled map that comes out smaller than its skip connection is
    zero-padded to the skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder: each DoubleConv sees upsampled + skip channels concatenated.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        # Output: 1x1 convolution to class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the map
        can be one pixel short of the encoder feature it is concatenated with.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for 4-D input: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        # Encoder path; each level is kept for its skip connection.
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder path: upsample, align with the skip feature, concatenate, convolve.
        dec4 = self._match_size(self.upconv4(bottleneck), enc4)
        dec4 = self.dec4(torch.cat([dec4, enc4], dim=1))
        dec3 = self._match_size(self.upconv3(dec4), enc3)
        dec3 = self.dec3(torch.cat([dec3, enc3], dim=1))
        dec2 = self._match_size(self.upconv2(dec3), enc2)
        dec2 = self.dec2(torch.cat([dec2, enc2], dim=1))
        dec1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = self.dec1(torch.cat([dec1, enc1], dim=1))
        return self.out(dec1)

# ==================== AUTO-DETECT NUM_CLASSES ====================
def detect_num_classes(model_path):
    """Return the class count stored in a saved U-Net state dict.

    The count is the first dimension of the checkpoint's ``out.weight``
    tensor; the checkpoint is loaded onto the CPU.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    return checkpoint['out.weight'].shape[0]

# ==================== PREDICT FUNCTION ====================
def predict_image(model_path, image_path, num_classes=None,
                  img_size=(256, 256), save_path='prediction_result.png'):
    """Segment one image with a saved U-Net, plot and save the result.

    A three-panel figure (original, predicted mask with colour bar, overlay)
    is saved to ``save_path`` and shown, and per-class pixel statistics are
    printed.

    Args:
        model_path: Path to a state dict saved from ``UNet``.
        image_path: Path to the image to segment.
        num_classes: Number of classes of the saved model. When ``None`` it is
            read from the first dimension of the checkpoint's ``out.weight``.
        img_size: ``(width, height)`` the image is resized to before it is fed
            to the model.
        save_path: File the figure is written to.

    Returns:
        2-D ``uint8`` array of class indices at the original image resolution.
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Read the checkpoint once; it serves both class detection and weight loading.
    state_dict = torch.load(model_path, map_location=device)

    # Auto-detect num_classes if not provided
    if num_classes is None:
        print("üîç Auto-detecting number of classes...")
        num_classes = state_dict['out.weight'].shape[0]
        print(f"‚úÖ Detected: {num_classes} classes\n")

    print(f"{'='*60}")
    print(f"üîÆ PREDICTION MODE")
    print(f"{'='*60}")
    print(f"Device: {device}")
    print(f"Model: {model_path}")
    print(f"Image: {image_path}")
    print(f"Classes: {num_classes}")
    print(f"{'='*60}\n")

    # Load model
    print("üì¶ Loading model...")
    model = UNet(in_channels=3, num_classes=num_classes).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    print("‚úÖ Model loaded!\n")

    # Load image; the context manager closes the file once it is converted.
    print("üñºÔ∏è  Loading image...")
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')
    original_size = image.size  # (width, height)
    print(f"Original size: {original_size[0]}x{original_size[1]}")

    # Preprocess: resize, scale to [0, 1], HWC -> CHW, add batch dimension.
    image_resized = image.resize(img_size, Image.BILINEAR)
    image_np = np.array(image_resized) / 255.0
    image_tensor = torch.FloatTensor(image_np).permute(2, 0, 1).unsqueeze(0).to(device)

    # Predict
    print("üöÄ Running inference...")
    with torch.no_grad():
        output = model(image_tensor)
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Resize back to original; nearest-neighbour keeps the labels intact.
    pred_resized = cv2.resize(pred.astype(np.uint8), original_size, interpolation=cv2.INTER_NEAREST)

    unique_classes = np.unique(pred_resized)
    print(f"‚úÖ Prediction completed!")
    print(f"Detected classes in image: {unique_classes}")
    print(f"Output shape: {pred_resized.shape}\n")

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Original
    axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=16, fontweight='bold')
    axes[0].axis('off')

    # Prediction
    im = axes[1].imshow(pred_resized, cmap='tab20', vmin=0, vmax=num_classes-1)
    axes[1].set_title('Prediction Mask', fontsize=16, fontweight='bold')
    axes[1].axis('off')

    # Add colorbar for mask
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(axes[1])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(im, cax=cax)
    cbar.set_label('Class ID', fontsize=12)

    # Overlay: 50/50 blend of the image and the mask colours. The class index
    # is scaled to [0, 1] the same way the vmin/vmax normalisation above does.
    overlay = np.array(image)
    mask_colored = plt.cm.tab20(pred_resized / max(num_classes-1, 1))[:, :, :3]
    overlay = (overlay * 0.5 + mask_colored * 255 * 0.5).astype(np.uint8)
    axes[2].imshow(overlay)
    axes[2].set_title('Overlay (Original + Mask)', fontsize=16, fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"üíæ Result saved to: {save_path}")
    plt.show()

    # Statistics: one pass over the mask counts the pixels of every class.
    print(f"\n{'='*60}")
    print(f"üìä PREDICTION STATISTICS")
    print(f"{'='*60}")
    pixel_counts = np.bincount(pred_resized.ravel(), minlength=num_classes)
    for class_id in range(num_classes):
        pixels = int(pixel_counts[class_id])
        percentage = (pixels / pred_resized.size) * 100
        if pixels > 0:  # Only show classes that appear
            print(f"Class {class_id:2d}: {pixels:8,} pixels ({percentage:6.2f}%)")

    print(f"{'='*60}")
    print(f"‚ú® DONE!")
    print(f"{'='*60}")

    return pred_resized


# ==================== CONFIGURATION ====================
# Change only these two values (NUM_CLASSES is detected automatically):

MODEL_PATH = '/content/drive/MyDrive/best_unet_model.pth'                                 # Path to the .pth model
IMAGE_PATH = '/content/drive/MyDrive/download (1).jpg'   # Path to the image

# ==================== RUN PREDICTION ====================
# NUM_CLASSES is detected automatically from the model!
prediction = predict_image(MODEL_PATH, IMAGE_PATH)