# In [None]:
# -*- coding: utf-8 -*-
"""Enhanced Weakly Supervised Segmentation Pipeline"""

# !pip install torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/cpu

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split, Dataset
import numpy as np
import cv2
import os
import random
import matplotlib.pyplot as plt
import torchvision.transforms.v2 as T
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import OxfordIIITPet
from torchvision.ops import box_convert, box_iou
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import Counter

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
NUM_CLASSES = 37  # Oxford-IIIT Pet has 37 breeds
SEGMENTATION_CLASSES = 2  # Foreground (pet) and background
# Default activation level used by recam_refinement to measure CAM coverage.
CAM_THRESHOLD = 0.3
# CAM values at or above this become foreground in the binary pseudo masks.
PSEUDO_MASK_THRESHOLD = 0.5
# Side length (pixels) of the square model input; CAMs are upsampled to it.
IMAGE_SIZE = 224
# Default Adam hyper-parameters shared by classifier and segmentation training.
DEFAULT_LR = 1e-3
DEFAULT_WEIGHT_DECAY = 1e-4
# None disables gradient-norm clipping in train_classifier.
DEFAULT_CLIP_GRAD_NORM = None

# Data Loading and Preparation
def prepare_datasets():
    """Build the training and validation datasets for Oxford-IIIT Pet.

    The ``trainval`` split is downloaded (if necessary) into
    ``./oxford_iiit_data`` with category labels as targets and then randomly
    divided 85% / 15%. The training part gets random augmentation
    (resized crop, horizontal flip, colour jitter); the validation part is
    only resized. Both are normalised with the ImageNet mean / std.

    Returns:
        tuple: ``(train_ds, val_ds)`` — two ``TransformDataset`` objects
        yielding ``(image_tensor, category)`` pairs.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # ``T.ToTensor()`` is deprecated in transforms.v2; ToImage + ToDtype
        # with scale=True is the documented replacement (PIL -> float [0, 1]).
        return [
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]

    train_transform = T.Compose([
        T.Resize((256, 256)),
        T.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        *to_normalized_tensor(),
    ])

    val_transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        *to_normalized_tensor(),
    ])

    # Load the base dataset without transforms; each split applies its own
    # transform lazily through TransformDataset.
    base_dataset = OxfordIIITPet(
        root="./oxford_iiit_data",
        download=True,
        target_types="category",
        split="trainval",
        transform=None,
    )

    # Split into train and validation
    train_size = int(0.85 * len(base_dataset))
    val_size = len(base_dataset) - train_size
    train_subset, val_subset = random_split(base_dataset, [train_size, val_size])

    train_ds = TransformDataset(train_subset, train_transform)
    val_ds = TransformDataset(val_subset, val_transform)

    return train_ds, val_ds

class TransformDataset(Dataset):
    """Dataset view that applies an image transform lazily at access time.

    ``subset`` must yield ``(image, target)`` pairs; only the image is passed
    through ``transform``. A falsy ``transform`` (e.g. ``None``) leaves the
    image untouched.
    """

    def __init__(self, subset, transform):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, label = self.subset[idx]
        processed = self.transform(sample) if self.transform else sample
        return processed, label

# Model Initialization
def get_classifier_model():
    """Create a pretrained ResNet-50 with a NUM_CLASSES-way head on `device`."""
    backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    # Replace the pretrained head with a freshly initialised linear layer.
    backbone.fc = nn.Linear(backbone.fc.in_features, NUM_CLASSES)
    backbone = backbone.to(device)
    return backbone

def get_segmentation_model():
    """Create a pretrained DeepLabV3-ResNet50 with a two-class output on `device`."""
    segmenter = deeplabv3_resnet50(weights="DEFAULT")
    # Swap the final 1x1 classification conv so it emits SEGMENTATION_CLASSES maps.
    new_head = nn.Conv2d(256, SEGMENTATION_CLASSES, kernel_size=1)
    segmenter.classifier[4] = new_head
    segmenter = segmenter.to(device)
    return segmenter

# Training Functions
def train_classifier(model, train_loader, val_loader,
                   num_epochs=10, lr=DEFAULT_LR,
                   weight_decay=DEFAULT_WEIGHT_DECAY,
                   clip_grad_norm=DEFAULT_CLIP_GRAD_NORM,
                   save_path="classifier.pth"):
    """Train the classifier with Adam + StepLR and save its weights.

    Args:
        model: Classification network; moved to `device` before training.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader evaluated after every epoch.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        clip_grad_norm: Maximum gradient norm; ``None`` disables clipping.
        save_path: File the final ``state_dict`` is written to.

    Returns:
        The trained model.
    """
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Learning rate is multiplied by 0.1 every 5 epochs.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
    clipping_enabled = clip_grad_norm is not None

    for epoch in range(num_epochs):
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_labels)
            batch_loss.backward()
            if clipping_enabled:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()

            # Accumulate per-batch loss and the number of correct predictions.
            loss_sum += batch_loss.item()
            predicted = torch.max(logits, 1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

        scheduler.step()

        train_loss = loss_sum / len(train_loader)
        train_acc = n_correct / n_seen
        val_loss, val_acc = evaluate_classifier(model, val_loader, loss_fn)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Classifier model saved to {save_path}")
    return model

def evaluate_classifier(model, loader, criterion=None):
    """Compute mean per-batch loss and accuracy of `model` over `loader`.

    Args:
        model: Classification network; switched to eval mode.
        loader: Loader yielding ``(images, labels)`` batches.
        criterion: Loss function; defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where the loss is averaged over
        batches and the accuracy over samples.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += loss_fn(logits, batch_labels).item()

            predicted = torch.max(logits, 1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

# CAM Generation and Processing
class GradCAM:
    """Grad-CAM for a model whose target layer is a direct child module.

    A forward hook records the target layer's activations and a full backward
    hook records the gradient of the selected class scores with respect to
    those activations. Calling the object returns one normalised heat map per
    input sample.

    Args:
        model: Network to explain; it is switched to eval mode.
        target_layer_name: Name of a direct child of ``model`` (searched with
            ``named_children``), e.g. ``"layer4"``.

    Raises:
        ValueError: If no direct child has the requested name.
    """

    def __init__(self, model, target_layer_name="layer4"):
        self.model = model
        self.model.eval()

        # Only direct children are searched; nested modules are not visited.
        self.target_layer = dict(self.model.named_children()).get(target_layer_name)
        if self.target_layer is None:
            raise ValueError(f"Layer {target_layer_name} not found")

        self.gradients = None
        self.activations = None

        # register_backward_hook is deprecated; the "full" variant is the
        # supported way to observe gradients w.r.t. a module's output.
        self.forward_hook = self.target_layer.register_forward_hook(self._forward_hook)
        self.backward_hook = self.target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, inputs, output):
        # Detach so the stored copy does not keep the autograd graph alive.
        self.activations = output.detach()

    def _backward_hook(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def remove_hooks(self):
        """Detach both hooks from the target layer (safe to call repeatedly)."""
        self.forward_hook.remove()
        self.backward_hook.remove()

    def __call__(self, x, class_idx=None):
        """Compute Grad-CAM heat maps for a batch.

        Args:
            x: Input batch accepted by the model.
            class_idx: Target class per sample. May be ``None`` (use the
                model's arg-max prediction), a single int applied to every
                sample, or a sequence / tensor with one index per sample.

        Returns:
            Tensor of shape ``(batch, 1, h, w)`` — ``h, w`` being the spatial
            size of the target layer's output — detached from the graph and
            scaled to ``[0, 1]`` per sample.

        Raises:
            RuntimeError: If the backward hook did not capture any gradient.
        """
        self.gradients = None

        # Gradients are needed even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(x)
            if class_idx is None:
                class_idx = torch.argmax(logits, dim=1)

            targets = torch.as_tensor(class_idx, device=logits.device).long().reshape(-1)
            if targets.numel() == 1 and logits.size(0) != 1:
                targets = targets.repeat(logits.size(0))

            # One-hot selector so backward() propagates only the chosen class scores.
            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, targets.unsqueeze(1), 1.0)

            self.model.zero_grad()
            logits.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

        # Channel weights are the spatially averaged gradients.
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))

        # Per-sample min-max normalisation; the epsilon guards all-zero maps.
        batch = cam.size(0)
        cam = cam - cam.view(batch, -1).min(dim=1)[0].view(batch, 1, 1, 1)
        cam = cam / (cam.view(batch, -1).max(dim=1)[0].view(batch, 1, 1, 1) + 1e-8)

        return cam

def generate_and_refine_cams(model, data_loader, output_dir="cams"):
    """Generate a refined Grad-CAM heat map for every sample of a loader.

    Each heat map is upsampled to ``IMAGE_SIZE x IMAGE_SIZE``, passed through
    ``recam_refinement`` and written to ``<output_dir>/cam_<k>.npy``, where
    ``k`` counts samples in loader iteration order. For ``k`` to equal the
    dataset index, the loader must iterate the dataset in order (no
    shuffling).

    Args:
        model: Trained classifier handed to ``GradCAM``.
        data_loader: Loader yielding ``(images, labels)`` batches.
        output_dir: Directory for the ``.npy`` files (created if missing).

    Returns:
        str: ``output_dir``.
    """
    gradcam = GradCAM(model)
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    # A running counter stays correct for a short last batch and for loaders
    # whose ``batch_size`` attribute is None (custom batch samplers).
    sample_idx = 0
    try:
        for images, _ in data_loader:
            images = images.to(device)

            # Grad-CAM needs autograd even if the caller disabled it.
            with torch.enable_grad():
                cams = gradcam(images)
            cams = cams.detach()

            # Upsample to the network input resolution.
            cams = F.interpolate(cams, size=(IMAGE_SIZE, IMAGE_SIZE),
                                 mode='bilinear', align_corners=False)
            cams = cams.squeeze(1).cpu().numpy()

            for cam in cams:
                refined = recam_refinement(cam)
                np.save(os.path.join(output_dir, f"cam_{sample_idx}.npy"), refined)
                sample_idx += 1
    finally:
        # Always unhook, so the classifier is left without Grad-CAM hooks
        # that would keep capturing activations on later forward passes.
        gradcam.forward_hook.remove()
        gradcam.backward_hook.remove()

    print("CAM generation and refinement complete!")
    return output_dir

def recam_refinement(cam, expansion_factor=1.2, threshold=CAM_THRESHOLD):
    """Amplify a CAM whose above-threshold region covers under 10% of the map.

    Args:
        cam: 2-D heat map array.
        expansion_factor: Multiplier applied when coverage is too small.
        threshold: Activation level counted as "covered".

    Returns:
        The original ``cam`` when coverage is sufficient, otherwise the
        scaled map clipped to ``[0, 1]``.
    """
    active = (cam >= threshold).astype(np.uint8)
    n_pixels = cam.shape[0] * cam.shape[1]
    coverage = active.sum() / n_pixels

    if coverage < 0.1:
        boosted = cam * expansion_factor
        return np.clip(boosted, 0, 1)
    return cam

# Pseudo Mask Generation
def generate_pseudo_masks(cam_dir, output_dir="pseudo_masks"):
    """Threshold every saved CAM into a binary PNG pseudo mask.

    Each ``.npy`` file in ``cam_dir`` is binarised at
    ``PSEUDO_MASK_THRESHOLD``, cleaned with ``morphological_refinement`` and
    written to ``output_dir`` as a 0/255 PNG with the same base name.

    Returns:
        str: ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for name in os.listdir(cam_dir):
        if not name.endswith('.npy'):
            continue
        heatmap = np.load(os.path.join(cam_dir, name))
        binary = (heatmap >= PSEUDO_MASK_THRESHOLD).astype(np.uint8)
        binary = morphological_refinement(binary)

        # Stored as 0/255 so the PNG is viewable as an ordinary image.
        target_path = os.path.join(output_dir, name.replace('.npy', '.png'))
        cv2.imwrite(target_path, binary*255)

    print("Pseudo mask generation complete!")
    return output_dir

def morphological_refinement(mask, kernel_size=3):
    """Return `mask` after a morphological closing with a square all-ones kernel."""
    structuring_element = np.full((kernel_size, kernel_size), 1, dtype=np.uint8)
    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring_element)
    return closed

# Segmentation Training
class PseudoSegDataset(Dataset):
    """Pairs each image of a dataset with its pseudo segmentation mask.

    Mask ``k`` (in numeric file order) is matched with ``image_dataset[k]``.

    Args:
        image_dataset: Dataset yielding ``(image, label)`` pairs.
        mask_dir: Directory containing the ``.png`` pseudo masks, expected to
            be named ``<prefix>_<index>.png`` (e.g. ``cam_12.png``).
    """

    def __init__(self, image_dataset, mask_dir):
        self.image_dataset = image_dataset
        self.mask_dir = mask_dir
        # Sort by the numeric index: a plain string sort would put
        # ``cam_10.png`` before ``cam_2.png`` and pair images with wrong masks.
        self.mask_files = sorted(
            (f for f in os.listdir(mask_dir) if f.endswith('.png')),
            key=self._mask_sort_key,
        )

    @staticmethod
    def _mask_sort_key(filename):
        """Order ``*_<int>.png`` names numerically; other names go last, by name."""
        stem = os.path.splitext(filename)[0]
        suffix = stem.rsplit('_', 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), filename)
        return (1, 0, filename)

    def __len__(self):
        return len(self.image_dataset)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` with ``mask`` a ``(H, W)`` int64 tensor of 0/1."""
        image, _ = self.image_dataset[idx]  # the classification label is not needed
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read pseudo mask: {mask_path}")

        # Masks are stored as 0/255; map them back to class indices 0/1.
        mask = (mask > 127).astype(np.uint8)
        mask = torch.from_numpy(mask).long()

        return image, mask

def train_segmentation(model, train_loader, val_loader, num_epochs=10):
    """Train the segmentation model on pseudo masks with Adam + StepLR.

    Args:
        model: Segmentation network whose forward pass returns a dict with
            an ``'out'`` entry of per-pixel class logits.
        train_loader: Loader yielding ``(images, masks)`` training batches.
        val_loader: Loader yielding ``(images, masks)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The trained model.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=DEFAULT_LR, weight_decay=DEFAULT_WEIGHT_DECAY)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)['out']

            # CrossEntropyLoss expects (B, H, W) int64 class-index targets.
            if len(masks.shape) == 4:  # (B, 1, H, W) -> (B, H, W)
                masks = masks.squeeze(1)
            masks = masks.long()

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Advance the schedule once per epoch; without this call the
        # learning rate would never decay.
        scheduler.step()

        train_loss = running_loss / len(train_loader)
        val_loss = evaluate_segmentation(model, val_loader, criterion)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    return model

def evaluate_segmentation(model, loader, criterion):
    """Return the mean per-batch loss of a segmentation model over `loader`.

    Args:
        model: Segmentation network whose forward pass returns a dict with
            an ``'out'`` entry; it is switched to eval mode.
        loader: Loader yielding ``(images, masks)`` batches.
        criterion: Loss applied to ``(logits, masks)``.

    Returns:
        float: Sum of batch losses divided by the number of batches.
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            if len(masks.shape) == 4:  # (B, 1, H, W) -> (B, H, W)
                masks = masks.squeeze(1)
            # Same target preparation as train_segmentation: class-index
            # targets must be int64 for CrossEntropyLoss.
            masks = masks.long()

            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            running_loss += loss.item()

    return running_loss / len(loader)

# Visualization
def visualize_results(model, dataset, num_samples=3):
    """Plot input image, pseudo mask and prediction for random samples.

    Args:
        model: Segmentation network whose forward pass returns a dict with
            an ``'out'`` entry; it is switched to eval mode.
        dataset: Dataset yielding ``(image_tensor, mask)`` pairs.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.
    """
    model.eval()

    # random.sample raises if asked for more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    indices = random.sample(range(len(dataset)), num_samples)

    plt.figure(figsize=(15, 5*num_samples))
    for i, idx in enumerate(indices):
        image, mask = dataset[idx]
        image = image.unsqueeze(0).to(device)

        with torch.no_grad():
            pred = model(image)['out']
            pred_mask = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()

        # Rescale the normalised tensor to [0, 1] for display only.
        image_np = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        lowest = image_np.min()
        value_range = image_np.max() - lowest
        if value_range > 0:
            image_np = (image_np - lowest) / value_range
        else:
            # Constant image: avoid dividing by zero.
            image_np = np.zeros_like(image_np)

        plt.subplot(num_samples, 3, i*3+1)
        plt.imshow(image_np)
        plt.title("Input Image")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3+2)
        plt.imshow(mask, cmap='jet')
        plt.title("Pseudo Mask")
        plt.axis('off')

        plt.subplot(num_samples, 3, i*3+3)
        plt.imshow(pred_mask, cmap='jet')
        plt.title("Prediction")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Main Workflow
def main():
    """Run the full pipeline: classifier -> CAMs -> pseudo masks -> segmentation.

    Pseudo masks are produced for both the training and the validation
    subset so the segmentation model is validated against masks rather than
    classification labels. CAMs are generated from unshuffled,
    deterministically transformed views of the data so that file index ``k``
    refers to dataset sample ``k`` and to the exact pixels later fed to the
    segmentation model.
    """
    # 1. Prepare data
    train_ds, val_ds = prepare_datasets()
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

    # 2. Train classifier
    print("Training classifier...")
    classifier = get_classifier_model()
    classifier = train_classifier(
        classifier, train_loader, val_loader,
        num_epochs=10, lr=DEFAULT_LR,
        weight_decay=DEFAULT_WEIGHT_DECAY,
        clip_grad_norm=DEFAULT_CLIP_GRAD_NORM,
        save_path="classifier.pth"
    )

    # 3. Generate and refine CAMs
    # The training subset is re-wrapped with the deterministic validation
    # transform: a random crop/flip would differ between CAM generation and
    # segmentation training, leaving masks misaligned with their images.
    cam_train_ds = TransformDataset(train_ds.subset, val_ds.transform)
    cam_train_loader = DataLoader(cam_train_ds, batch_size=32, shuffle=False, num_workers=2)

    print("\nGenerating CAMs...")
    train_cam_dir = generate_and_refine_cams(classifier, cam_train_loader)
    val_cam_dir = generate_and_refine_cams(classifier, val_loader, output_dir="cams_val")

    # 4. Generate pseudo masks
    print("\nGenerating pseudo masks...")
    train_mask_dir = generate_pseudo_masks(train_cam_dir)
    val_mask_dir = generate_pseudo_masks(val_cam_dir, output_dir="pseudo_masks_val")

    # 5. Prepare segmentation datasets (images paired with their own masks)
    seg_train_ds = PseudoSegDataset(cam_train_ds, train_mask_dir)
    seg_val_ds = PseudoSegDataset(val_ds, val_mask_dir)
    # drop_last avoids a trailing single-sample batch, which batch-norm
    # layers cannot process in training mode.
    seg_train_loader = DataLoader(seg_train_ds, batch_size=8, shuffle=True,
                                  num_workers=2, drop_last=True)
    seg_val_loader = DataLoader(seg_val_ds, batch_size=8, shuffle=False, num_workers=2)

    # 6. Train segmentation model
    print("\nTraining segmentation model...")
    seg_model = get_segmentation_model()
    seg_model = train_segmentation(seg_model, seg_train_loader, seg_val_loader)
    torch.save(seg_model.state_dict(), "seg_model.pth")

    # 7. Visualize results
    print("\nVisualizing results...")
    visualize_results(seg_model, seg_train_ds)

    print("\nWorkflow complete!")

if __name__ == "__main__":
    main()