# In [None]:
# config.py
"""Configuration parameters for the cattle segmentation project"""
import os
import torch

class Config:
    """Project-wide settings for cattle segmentation, exposed as class attributes.

    No state is created per instance, so ``Config`` and ``Config()`` expose
    exactly the same values.
    """

    # ---- Filesystem --------------------------------------------------------
    # Dataset root; get_data_paths() looks for ``images`` and ``annotations``
    # sub-folders beneath it.
    BASE_DIR = r'C:\Users\andrey\.cache\kagglehub\datasets\sadhliroomyprime\cattle-weight-detection-model-dataset-12k\versions\3\www.acmeai.tech Dataset - BMGF-LivestockWeight-CV\Pixel\B3'
    # Where Trainer writes the best weights (via torch.save, despite the
    # ``.keras`` suffix).
    SAVED_MODEL_PATH = 'best_cattle_segmentation_model_7.keras'
    # Checkpoint that main() loads for the prediction demo.
    MODEL_TESTING_PATH = 'best_cattle_segmentation_model_6.keras'

    # ---- Optimisation ------------------------------------------------------
    BATCH_SIZE = VAL_BATCH_SIZE = 16
    # DataLoader workers (reduced from 8 to avoid potential threading issues).
    NUM_WORKERS = 0
    LEARNING_RATE = 2e-3
    NUM_EPOCHS = 25

    # ---- Model -------------------------------------------------------------
    # Class indices: 0 = background, 1 = cattle, 2 = sticker.
    NUM_CLASSES = 3
    IMAGE_SIZE = 512

    # ---- Runtime -----------------------------------------------------------
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Maximum number of image/mask pairs kept before the train/val split;
    # main() applies no cap when this value is falsy.
    LIMIT = 350



# dataset.py
"""Dataset classes and related utilities"""
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import albumentations as A

def get_transforms(phase, image_size=512):
    """Build the albumentations pipeline for the given phase.

    Args:
        phase: ``'train'`` selects the augmenting pipeline (random resized
            crop, horizontal flip, brightness/contrast jitter). Any other
            value selects the deterministic evaluation pipeline (plain
            resize).
        image_size: Side length in pixels of the square output. Defaults to
            512, matching ``Config.IMAGE_SIZE``.

    Returns:
        An ``A.Compose`` whose output is ``image_size`` x ``image_size``,
        finished with ``A.Normalize()``. When called with ``mask=``, the mask
        receives the same spatial transforms as the image.
    """
    size = (image_size, image_size)
    if phase == 'train':
        return A.Compose([
            A.RandomResizedCrop(size=size, scale=(0.8, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Normalize()
        ])
    return A.Compose([
        A.Resize(height=image_size, width=image_size),
        A.Normalize()
    ])

class CattleSegmentationDataset(Dataset):
    """Image/mask dataset for 3-class cattle segmentation.

    Annotation masks are RGB colour images. Each pixel is mapped to a class
    index:

    - 0 = background, colour (0, 255, 193)
    - 1 = cattle, colour (255, 30, 249)
    - 2 = sticker, colour (0, 117, 255)

    Pixels matching none of these colours are labelled background (0).

    Args:
        img_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, aligned index-by-index with
            ``img_paths``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries (e.g. ``get_transforms(...)``).
            Without it the image is converted to float but not normalised.
    """

    def __init__(self, img_paths, mask_paths, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is a float tensor of shape (3, H, W). ``mask`` is a long
        tensor of shape (H, W) holding class indices. If reading or
        converting the files fails, the error is printed and an all-zero
        (3, 512, 512) / (512, 512) placeholder is returned instead.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Resolve the paths before the try block. An out-of-range index must
        # surface as IndexError instead of being swallowed, and img_path must
        # be bound before the error handler below prints it.
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            # Read image and mask (OpenCV loads BGR; convert to RGB)
            image = cv2.imread(img_path)
            if image is None:
                raise ValueError(f"Could not read image at {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(mask_path)
            if mask is None:
                raise ValueError(f"Could not read mask at {mask_path}")
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

            # Convert RGB mask to class indices
            sticker_mask = np.all(mask == [0, 117, 255], axis=-1).astype(np.uint8)  # Sticker (0, 117, 255)
            cattle_mask = np.all(mask == [255, 30, 249], axis=-1).astype(np.uint8)  # Cattle (255, 30, 249)
            background_mask = np.all(mask == [0, 255, 193], axis=-1).astype(np.uint8)  # Background (0, 255, 193)

            # Combine into a single mask where each pixel has a class index;
            # unmatched colours stay 0 (background).
            final_mask = np.zeros(cattle_mask.shape, dtype=np.uint8)
            final_mask[cattle_mask == 1] = 1  # Cattle class
            final_mask[sticker_mask == 1] = 2  # Sticker class
            final_mask[background_mask == 1] = 0  # Background class

            if self.transform:
                augmented = self.transform(image=image, mask=final_mask)
                image = augmented['image']
                final_mask = augmented['mask']

            # HWC -> CHW tensors (normalisation, if any, happened in the transform)
            image = torch.from_numpy(image.transpose(2, 0, 1)).float()
            final_mask = torch.from_numpy(final_mask).long()

            return image, final_mask
        except Exception as e:
            print(f"Error loading sample {idx} - {img_path}: {str(e)}")
            # Best-effort: return a placeholder so one bad file does not abort
            # a whole epoch. 512 matches the default transform output size.
            dummy_image = torch.zeros((3, 512, 512), dtype=torch.float32)
            dummy_mask = torch.zeros((512, 512), dtype=torch.long)
            return dummy_image, dummy_mask


def get_data_paths(base_dir):
    """Collect aligned image/mask path lists from a dataset directory.

    Expects images in ``<base_dir>/images`` and, in
    ``<base_dir>/annotations``, one mask per image named
    ``<image file name>___fuse.png``. Images without a matching mask are
    skipped.

    File names are visited in sorted order because ``os.listdir`` returns
    them in arbitrary order. Sorting makes any later slicing (``Config.LIMIT``)
    or seeded train/validation split reproducible. Extension matching is
    case-insensitive, so ``.JPG`` / ``.PNG`` files are not silently dropped.

    Args:
        base_dir: Dataset root directory.

    Returns:
        ``(img_paths, mask_paths)``: two equal-length lists where
        ``mask_paths[i]`` is the mask for ``img_paths[i]``. Both are empty if
        either sub-directory is missing.
    """
    img_paths = []
    mask_paths = []
    img_dir = os.path.join(base_dir, 'images')
    mask_dir = os.path.join(base_dir, 'annotations')

    if not (os.path.isdir(img_dir) and os.path.isdir(mask_dir)):
        return img_paths, mask_paths

    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        # Mask name is the full image file name plus a fixed suffix.
        mask_path = os.path.join(mask_dir, img_name + '___fuse.png')
        if os.path.exists(mask_path):
            img_paths.append(os.path.join(img_dir, img_name))
            mask_paths.append(mask_path)

    return img_paths, mask_paths

                    
# model.py
"""Model definition and related functions"""
import torch.nn as nn
import segmentation_models_pytorch as smp

# model.py

def create_model(num_classes=3, encoder_name="resnet18", encoder_weights="imagenet"):
    """Create a segmentation_models_pytorch U-Net for 3-channel input.

    Args:
        num_classes: Number of output classes (output channels).
        encoder_name: smp encoder identifier. Defaults to ``"resnet18"``.
        encoder_weights: Pretrained encoder weights to initialise from.
            Defaults to ``"imagenet"``. Pass ``None`` when a full state_dict
            will be loaded right afterwards, since pretrained encoder weights
            would be overwritten anyway.

    Returns:
        The constructed ``smp.Unet`` model.

    Raises:
        Re-raises any exception from model construction after printing it.
    """
    try:
        model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
        )
        print("Model created successfully")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
        return model
    except Exception as e:
        print(f"Error creating model: {e}")
        raise


# trainer.py
"""Training functionality for the segmentation model"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# trainer.py
class Trainer:
    """Train and validate a segmentation model with pixel-wise cross-entropy.

    Args:
        model: ``nn.Module`` producing per-class logits with the class
            dimension at index 1 (predicted labels are taken with
            ``torch.argmax(outputs, dim=1)``).
        train_dataset: Dataset yielding ``(image, mask)`` training pairs.
        val_dataset: Dataset yielding ``(image, mask)`` validation pairs.
        config: Object exposing ``DEVICE``, ``BATCH_SIZE``,
            ``VAL_BATCH_SIZE``, ``NUM_WORKERS``, ``LEARNING_RATE``,
            ``NUM_EPOCHS`` and ``SAVED_MODEL_PATH``.
    """

    def __init__(self, model, train_dataset, val_dataset, config):
        self.model = model
        self.config = config
        self.device = config.DEVICE

        # Move model to the configured device
        self.model.to(self.device)

        # Pinned host memory only pays off when batches are copied to a CUDA
        # device, so enable it only then (works for str or torch.device).
        pin_memory = str(self.device).startswith('cuda')

        # Create dataloaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            num_workers=config.NUM_WORKERS,
            pin_memory=pin_memory,
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.VAL_BATCH_SIZE,
            shuffle=False,
            num_workers=config.NUM_WORKERS,
            pin_memory=pin_memory,
        )

        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

        # Per-epoch metric history
        self.history = {'train_loss': [], 'val_loss': [], 'val_iou': []}

    def _validate_batch(self, images, masks):
        """Return False (and print why) if a batch contains NaNs or is empty."""
        # Check for NaNs in images and masks
        if torch.isnan(images).any() or torch.isnan(masks).any():
            print("Error: NaN values found in batch.")
            return False

        # Check if batch is empty
        if images.shape[0] == 0:
            print("Error: Empty batch.")
            return False

        return True

    @staticmethod
    def _iou(intersection, union):
        """Intersection-over-union, defined as 0.0 when the union is empty."""
        return intersection / union if union > 0 else 0.0

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over the training loader.

        Invalid batches (see ``_validate_batch``) are skipped.

        Args:
            epoch: Zero-based epoch index (used for logging only).

        Returns:
            Mean training loss over the processed batches (0.0 if none were
            processed). The value is also appended to
            ``history['train_loss']``.
        """
        self.model.train()
        total_loss = 0
        batch_count = 0

        print(f"Starting epoch {epoch+1}/{self.config.NUM_EPOCHS}")
        train_pbar = tqdm(self.train_loader, desc=f"Training Epoch {epoch+1}")

        for i, (images, masks) in enumerate(train_pbar):
            # Validate batch
            if not self._validate_batch(images, masks):
                print(f"Skipping invalid batch {i}")
                continue

            # Transfer to device
            images = images.to(self.device)
            masks = masks.to(self.device)

            # Forward pass
            outputs = self.model(images)
            loss = self.criterion(outputs, masks)

            # Backward and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update metrics
            current_loss = loss.item()
            total_loss += current_loss
            batch_count += 1

            # Update progress bar
            train_pbar.set_postfix(loss=f"{current_loss:.4f}")

        # Calculate average loss
        avg_loss = total_loss / max(batch_count, 1)
        self.history['train_loss'].append(avg_loss)

        return avg_loss

    def validate(self, epoch):
        """Evaluate on the validation set and record loss and cattle IoU.

        Cattle IoU (class index 1) is computed over the whole validation set.
        Intersections and unions are summed across batches and divided once.
        Averaging per-batch IoUs would score any batch without cattle in
        either prediction or target as 0 and give small batches undue weight.
        Batches that raise are reported and skipped.

        Args:
            epoch: Zero-based epoch index (used for the progress-bar label).

        Returns:
            ``(avg_loss, iou)``, also appended to ``history``. Both are NaN if
            no batch was evaluated successfully, so a failed validation pass
            can never be mistaken for a new best model. ``iou`` is 0.0 when
            cattle is absent from both predictions and targets.
        """
        self.model.eval()
        total_loss = 0.0
        total_intersection = 0
        total_union = 0
        batch_count = 0

        val_pbar = tqdm(self.val_loader, desc=f"Validation Epoch {epoch+1}")

        with torch.no_grad():
            for images, masks in val_pbar:
                try:
                    # Transfer to device
                    images = images.to(self.device)
                    masks = masks.to(self.device)

                    # Forward pass
                    outputs = self.model(images)
                    loss = self.criterion(outputs, masks)

                    # Pixel counts for the cattle class (class index 1)
                    pred_cattle = torch.argmax(outputs, dim=1) == 1
                    true_cattle = masks == 1
                    intersection = torch.logical_and(pred_cattle, true_cattle).sum().item()
                    union = torch.logical_or(pred_cattle, true_cattle).sum().item()

                    # Update metrics
                    total_loss += loss.item()
                    total_intersection += intersection
                    total_union += union
                    batch_count += 1

                    # Progress bar shows the running dataset-level IoU
                    val_pbar.set_postfix(
                        loss=f"{loss.item():.4f}",
                        iou=f"{self._iou(total_intersection, total_union):.4f}",
                    )

                except Exception as e:
                    print(f"Error in validation batch: {str(e)}")
                    continue

        if batch_count == 0:
            avg_loss = float('nan')
            avg_iou = float('nan')
        else:
            avg_loss = total_loss / batch_count
            avg_iou = self._iou(total_intersection, total_union)

        self.history['val_loss'].append(avg_loss)
        self.history['val_iou'].append(avg_iou)

        return avg_loss, avg_iou

    def train(self):
        """Train for ``config.NUM_EPOCHS`` epochs, checkpointing the best model.

        After each epoch the model's ``state_dict`` is written to
        ``config.SAVED_MODEL_PATH`` whenever validation loss improves. An
        exception in one epoch is printed with its traceback and training
        continues with the next epoch.

        Returns:
            The ``history`` dict of per-epoch metrics.
        """
        print(f"Starting training for {self.config.NUM_EPOCHS} epochs")
        best_val_loss = float('inf')

        for epoch in range(self.config.NUM_EPOCHS):
            try:
                # Train for one epoch
                train_loss = self.train_one_epoch(epoch)

                # Validate
                val_loss, val_iou = self.validate(epoch)

                # Print epoch summary
                print(f"Epoch {epoch+1}/{self.config.NUM_EPOCHS} - "
                      f"Train Loss: {train_loss:.4f}, "
                      f"Val Loss: {val_loss:.4f}, "
                      f"Cattle IoU: {val_iou:.4f}")

                # Save best model (a NaN loss never compares as smaller)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(self.model.state_dict(), self.config.SAVED_MODEL_PATH)
                    print(f"Saved best model with validation loss: {val_loss:.4f}")

            except Exception as e:
                print(f"Error in epoch {epoch+1}: {str(e)}")
                import traceback
                traceback.print_exc()
                continue

        return self.history

    def plot_history(self):
        """Plot loss curves and cattle IoU, save to training_history.png, and show."""
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['val_loss'], label='Val Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(self.history['val_iou'], label='Cattle IoU')
        plt.xlabel('Epochs')
        plt.ylabel('IoU')
        plt.legend()

        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.show()

# prediction.py
"""Functionality for making predictions with the trained model"""
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
import albumentations as A
import os

def predict_segmentation(image_path, model, device):
    """Predict a per-pixel class-index mask for one image file.

    The image is read with OpenCV and converted BGR -> RGB. It is then
    preprocessed with the same evaluation pipeline used for validation
    (``get_transforms('val')``) and run through ``model`` in eval mode
    without gradient tracking.

    Args:
        image_path: Path of the image to segment.
        model: Segmentation model returning per-class logits along dim 1.
            It is moved to ``device`` and switched to eval mode in place.
        device: Device on which inference runs.

    Returns:
        ``(pred_mask, image)``. ``pred_mask`` is a NumPy array of class
        indices at the transform's output size (512 x 512). ``image`` is the
        RGB image at its original resolution. Returns ``(None, None)`` if any
        step fails; the error and traceback are printed.
    """
    try:
        # Load image (OpenCV reads BGR)
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image at {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Reuse the validation pipeline so inference matches evaluation exactly
        processed_image = get_transforms('val')(image=image)['image']

        # HWC -> CHW, then add a batch dimension
        image_tensor = torch.from_numpy(np.transpose(processed_image, (2, 0, 1))).float().unsqueeze(0)

        model.to(device)
        model.eval()

        with torch.no_grad():
            output = model(image_tensor.to(device))
            pred_mask = torch.argmax(output, dim=1).cpu().numpy()[0]

        return pred_mask, image

    except Exception as e:
        print(f"Error in prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None

# prediction.py
def visualize_segmentation(image_path, model, device, alpha=0.5):
    """Predict, plot and save the segmentation of a single image.

    Draws three panels: the original image, the predicted class-index mask,
    and the class colours blended over the image. The figure is saved to
    ``./predictions/segmentation_<image file name>.png``, shown, then closed.
    Nothing is drawn if prediction fails.

    Args:
        image_path: Path of the image to segment.
        model: Segmentation model passed to ``predict_segmentation``.
        device: Device on which inference runs.
        alpha: Weight of the class colours in the overlay panel. 0 shows only
            the image; 1 shows only the colours.
    """
    pred_mask, original_img = predict_segmentation(image_path, model, device)

    if pred_mask is None:
        print(f"Failed to predict for {image_path}")
        return

    output_folder = './predictions'
    os.makedirs(output_folder, exist_ok=True)
    prediction_file_path = os.path.join(output_folder, f'segmentation_{os.path.basename(image_path)}.png')

    # Bring the mask back to the original resolution. Nearest-neighbour keeps
    # values as valid class indices.
    original_height, original_width = original_img.shape[:2]
    resized_mask = cv2.resize(pred_mask, (original_width, original_height),
                              interpolation=cv2.INTER_NEAREST)

    # Paint each class in its annotation colour, then blend with the photo so
    # the underlying image stays visible.
    color_layer = np.zeros_like(original_img)
    color_layer[resized_mask == 1] = [255, 30, 249]   # Cattle in pink
    color_layer[resized_mask == 2] = [0, 117, 255]    # Sticker in blue
    color_layer[resized_mask == 0] = [0, 255, 193]    # Background in green
    overlay = cv2.addWeighted(original_img, 1.0 - alpha, color_layer, alpha, 0)

    fig = plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(original_img)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(resized_mask, cmap='viridis')
    plt.title('Segmentation Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(overlay)
    plt.title('Segmentation Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(prediction_file_path)
    print(f"Saved segmentation to {prediction_file_path}")
    plt.show()
    # Release the figure; this function is called in a loop over many images.
    plt.close(fig)


# visualization.py
"""Visualization utilities for the dataset and results"""
import matplotlib.pyplot as plt
import numpy as np

def visualize_sample_masks(dataset, num_samples=3):
    """Plot the first ``num_samples`` dataset items next to their masks.

    Each row shows one image, min-max rescaled to [0, 1] for display, and its
    class-index mask. A sample that fails to load gets an error note in its
    row. The figure is saved to ``sample_masks.png`` and shown.

    Args:
        dataset: Indexable dataset returning ``(image, mask)`` tensors, with
            the image as (C, H, W).
        num_samples: Number of leading samples (rows) to plot; must be >= 1.
    """
    # squeeze=False keeps ``axs`` 2-D even when num_samples == 1, so the
    # ``axs[i, col]`` indexing below is always valid.
    fig, axs = plt.subplots(num_samples, 2, figsize=(12, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        try:
            img, mask = dataset[i]

            # Convert tensors back to numpy for visualization
            img = img.numpy().transpose(1, 2, 0)
            # Rescale for display. A constant image (e.g. an all-zero
            # placeholder sample) would otherwise divide by zero.
            lo, hi = img.min(), img.max()
            img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
            mask = mask.numpy()

            axs[i, 0].imshow(img)
            axs[i, 0].set_title('Original Image')
            axs[i, 0].axis('off')

            axs[i, 1].imshow(mask, cmap='viridis')
            axs[i, 1].set_title('Segmentation Mask')
            axs[i, 1].axis('off')

        except Exception as e:
            print(f"Error visualizing sample {i}: {str(e)}")
            # Leave this row empty apart from an error note
            axs[i, 0].text(0.5, 0.5, f"Error loading sample {i}",
                           ha='center', va='center')
            axs[i, 0].axis('off')
            axs[i, 1].axis('off')

    plt.tight_layout()
    plt.savefig('sample_masks.png')
    plt.show()

# main.py
"""Main execution script for the cattle segmentation project"""
import os
import torch
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Import our modules

def main():
    """Run the end-to-end pipeline.

    Steps:
        1. Load image/mask pairs, optionally capped by ``Config.LIMIT``.
        2. Split 80/20 with a fixed seed.
        3. Preview a few validation samples.
        4. Train and plot the training history.
        5. Segment up to 10 validation images with the checkpoint at
           ``Config.MODEL_TESTING_PATH``.

    Any exception is caught at this top level and its traceback printed.
    """
    try:
        # Initialize configuration
        config = Config()
        print(f"Using device: {config.DEVICE}")

        # Get data paths
        img_paths, mask_paths = get_data_paths(config.BASE_DIR)
        print(f"Found {len(img_paths)} image-mask pairs")

        if not img_paths:
            raise ValueError("No images found. Check the dataset path.")

        # Limit the number of images if `LIMIT` is set
        if config.LIMIT:
            img_paths = img_paths[:config.LIMIT]
            mask_paths = mask_paths[:config.LIMIT]
            print(f"Limiting to {config.LIMIT} images for training and validation")

        # Split data into training and validation sets
        train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
            img_paths, mask_paths, test_size=0.2, random_state=42
        )

        print(f"Training samples: {len(train_img_paths)}")
        print(f"Validation samples: {len(val_img_paths)}")

        # Create datasets
        train_dataset = CattleSegmentationDataset(
            train_img_paths, train_mask_paths, transform=get_transforms('train')
        )

        val_dataset = CattleSegmentationDataset(
            val_img_paths, val_mask_paths, transform=get_transforms('val')
        )

        # Visualize samples to verify dataset
        print("Visualizing sample masks to verify dataset loading...")
        visualize_sample_masks(val_dataset)

        # Create model
        model = create_model(num_classes=config.NUM_CLASSES)

        # Create trainer and train
        trainer = Trainer(model, train_dataset, val_dataset, config)
        trainer.train()

        # Plot training history
        trainer.plot_history()

        # Load the checkpoint configured for testing. NOTE(review): this is
        # MODEL_TESTING_PATH, not necessarily the SAVED_MODEL_PATH checkpoint
        # written by the training run above.
        best_model = create_model(num_classes=config.NUM_CLASSES)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine (and vice versa).
        state_dict = torch.load(config.MODEL_TESTING_PATH, map_location=config.DEVICE)
        best_model.load_state_dict(state_dict)
        print(f"Loaded model for segmentation prediction from {config.MODEL_TESTING_PATH}")

        # Test segmentation on validation samples
        test_samples = val_img_paths[:10]  # Test on 10 samples
        for img_path in test_samples:
            print(f"Segmenting image: {os.path.basename(img_path)}")
            visualize_segmentation(img_path, best_model, config.DEVICE)

    except Exception as e:
        print(f"Error in main execution: {str(e)}")
        import traceback
        traceback.print_exc()

# Execute the main function if this is the main script
if __name__ == "__main__":
    # Run the pipeline only when executed directly, not when imported.
    main()
