In [2]:
!unzip -q /content/data.zip -d /content/custom_data

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import shutil
from tqdm import tqdm
import json

In [32]:
def yolo_to_mask(txt_file, image_shape, multi_class=True):
    """
    Convert YOLO polygon format to segmentation mask

    Args:
        txt_file: Path to YOLO annotation file
        image_shape: (height, width) of the image
        multi_class: If True, different classes get different pixel values

    Returns:
        Binary or multi-class mask
    """
    mask = np.zeros(image_shape[:2], dtype=np.uint8)

    if not os.path.exists(txt_file):
        return mask

    with open(txt_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 6:  # Skip invalid lines
                continue

            class_id = int(parts[0])
            coords = [float(x) for x in parts[1:]]

            # Convert normalized coordinates to pixel coordinates
            points = []
            for i in range(0, len(coords), 2):
                x = int(coords[i] * image_shape[1])
                y = int(coords[i+1] * image_shape[0])
                points.append([x, y])

            if len(points) > 2:  # Need at least 3 points for a polygon
                # Fill polygon - use class_id+1 for multi-class, 255 for binary
                fill_value = (class_id + 1) if multi_class else 255
                cv2.fillPoly(mask, [np.array(points)], fill_value)

    return mask

In [34]:
def process_dataset(source_dir, output_dir, multi_class=True):
    """
    Process entire dataset: convert YOLO annotations to masks

    Expected structure:
    source_dir/
    ├── images/
    │   ├── image1.jpg
    │   └── image2.jpg
    └── labels/
        ├── image1.txt
        └── image2.txt
    """

    # Create output directories
    os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'masks'), exist_ok=True)

    image_dir = os.path.join(source_dir, 'images')
    label_dir = os.path.join(source_dir, 'labels')

    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    print(f"Processing {len(image_files)} images...")

    for img_file in tqdm(image_files):
        # Load image
        img_path = os.path.join(image_dir, img_file)
        image = cv2.imread(img_path)

        if image is None:
            print(f"Warning: Could not load {img_file}")
            continue

        height, width = image.shape[:2]

        # Find corresponding annotation file
        base_name = os.path.splitext(img_file)[0]
        txt_file = os.path.join(label_dir, f"{base_name}.txt")

        # Convert to mask
        mask = yolo_to_mask(txt_file, (height, width), multi_class)

        # Save image and mask
        output_img_path = os.path.join(output_dir, 'images', img_file)
        output_mask_path = os.path.join(output_dir, 'masks', f"{base_name}.png")

        # Copy image
        shutil.copy2(img_path, output_img_path)

        # Save mask
        cv2.imwrite(output_mask_path, mask)

    print(f"Dataset processed! Output saved to: {output_dir}")


In [35]:
def split_dataset(processed_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    """
    Split processed dataset into train/val/test sets
    """

    # Get all image files
    image_dir = os.path.join(processed_dir, 'images')
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Split dataset
    train_files, temp_files = train_test_split(image_files, test_size=(1-train_ratio), random_state=42)
    val_files, test_files = train_test_split(temp_files, test_size=(test_ratio/(val_ratio+test_ratio)), random_state=42)

    print(f"Dataset split:")
    print(f"Train: {len(train_files)} images")
    print(f"Val: {len(val_files)} images")
    print(f"Test: {len(test_files)} images")

    # Create split directories
    for split, files in [('train', train_files), ('val', val_files), ('test', test_files)]:
        # Create directories
        split_img_dir = os.path.join(processed_dir, split, 'images')
        split_mask_dir = os.path.join(processed_dir, split, 'masks')
        os.makedirs(split_img_dir, exist_ok=True)
        os.makedirs(split_mask_dir, exist_ok=True)

        # Copy files
        for file in files:
            base_name = os.path.splitext(file)[0]

            # Copy image
            src_img = os.path.join(processed_dir, 'images', file)
            dst_img = os.path.join(split_img_dir, file)
            shutil.copy2(src_img, dst_img)

            # Copy mask
            src_mask = os.path.join(processed_dir, 'masks', f"{base_name}.png")
            dst_mask = os.path.join(split_mask_dir, f"{base_name}.png")
            if os.path.exists(src_mask):
                shutil.copy2(src_mask, dst_mask)


In [36]:
class ToothDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, multi_class=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.multi_class = multi_class

        # Get all image files
        self.images = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Filter images that have corresponding masks
        valid_images = []
        for img_file in self.images:
            base_name = os.path.splitext(img_file)[0]
            mask_file = os.path.join(mask_dir, f"{base_name}.png")
            if os.path.exists(mask_file):
                valid_images.append(img_file)

        self.images = valid_images
        print(f"Found {len(self.images)} valid image-mask pairs")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Corresponding mask
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, f"{base_name}.png")

        # Load image and mask
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # Convert to numpy arrays
        image = np.array(image)
        mask = np.array(mask)

        # Apply transforms
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Normalize mask
        if self.multi_class:
            # Keep class labels as integers
            mask = mask.long()
        else:
            # Binary segmentation: normalize to 0-1
            mask = mask.float() / 255.0

        return image, mask

In [37]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Handle size mismatch
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [38]:
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)

        return logits


In [39]:
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)

        # Flatten
        pred = pred.view(-1)
        target = target.view(-1)

        intersection = (pred * target).sum()
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)

        return 1 - dice

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        ce_loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class CombinedLoss(nn.Module):
    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight

    def forward(self, pred, target):
        dice = self.dice_loss(pred, target)
        focal = self.focal_loss(pred, target)
        return self.dice_weight * dice + self.focal_weight * focal



In [40]:
def get_transforms(image_size=512):
    train_transform = A.Compose([
        A.Resize(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.GaussianBlur(blur_limit=3, p=0.1),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
        A.OneOf([
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=0.5),
        ], p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    val_transform = A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    return train_transform, val_transform


In [41]:
def train_model(model, train_loader, val_loader, num_epochs=100, learning_rate=0.001, save_path='best_model.pth'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model.to(device)

    # Loss and optimizer
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for images, masks in train_bar:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_bar.set_postfix(loss=loss.item())

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for images, masks in val_bar:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                val_bar.set_postfix(loss=loss.item())

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        scheduler.step(val_loss)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, save_path)
            print(f'New best model saved! Val Loss: {val_loss:.4f}')

    return train_losses, val_losses


In [42]:
def calculate_iou(pred, target, threshold=0.5):
    """Calculate Intersection over Union (IoU)"""
    pred = (torch.sigmoid(pred) > threshold).float()
    target = target.float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection

    if union == 0:
        return 1.0
    return (intersection / union).item()

def calculate_dice(pred, target, threshold=0.5):
    """Calculate Dice coefficient"""
    pred = (torch.sigmoid(pred) > threshold).float()
    target = target.float()

    intersection = (pred * target).sum()
    dice = (2. * intersection) / (pred.sum() + target.sum())

    return dice.item()

def evaluate_model(model, test_loader, device):
    model.eval()
    ious = []
    dices = []

    with torch.no_grad():
        for images, masks in tqdm(test_loader, desc='Evaluating'):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            for i in range(outputs.size(0)):
                iou = calculate_iou(outputs[i], masks[i])
                dice = calculate_dice(outputs[i], masks[i])
                ious.append(iou)
                dices.append(dice)

    return np.mean(ious), np.mean(dices)


In [4]:
# Memory-Optimized Tooth Segmentation Pipeline
# Designed to run on limited RAM systems

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import shutil
from tqdm import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

# ================================
# MEMORY-OPTIMIZED DATASET PROCESSING
# ================================

def yolo_to_mask(txt_file, image_shape, multi_class=False):
    """
    Convert YOLO polygon format to segmentation mask
    Memory optimized version
    """
    mask = np.zeros(image_shape[:2], dtype=np.uint8)

    if not os.path.exists(txt_file):
        return mask

    try:
        with open(txt_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 6:
                    continue

                class_id = int(parts[0])
                coords = [float(x) for x in parts[1:]]

                # Convert normalized coordinates to pixel coordinates
                points = []
                for i in range(0, len(coords), 2):
                    x = int(coords[i] * image_shape[1])
                    y = int(coords[i+1] * image_shape[0])
                    points.append([x, y])

                if len(points) > 2:
                    fill_value = (class_id + 1) if multi_class else 255
                    cv2.fillPoly(mask, [np.array(points)], fill_value)
    except Exception as e:
        print(f"Error processing {txt_file}: {e}")

    return mask

def process_dataset_memory_efficient(source_dir, output_dir, multi_class=False, max_size=512):
    """
    Process dataset with memory optimization
    """
    # Create output directories
    os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'masks'), exist_ok=True)

    image_dir = os.path.join(source_dir, 'images')
    label_dir = os.path.join(source_dir, 'labels')

    if not os.path.exists(image_dir):
        print(f"Error: {image_dir} does not exist")
        return

    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    print(f"Processing {len(image_files)} images...")
    processed_count = 0

    for i, img_file in enumerate(tqdm(image_files, desc="Converting YOLO to masks")):
        try:
            # Load image
            img_path = os.path.join(image_dir, img_file)
            image = cv2.imread(img_path)

            if image is None:
                continue

            height, width = image.shape[:2]

            # Resize if too large to save memory
            if max(height, width) > max_size:
                scale = max_size / max(height, width)
                new_width = int(width * scale)
                new_height = int(height * scale)
                image = cv2.resize(image, (new_width, new_height))
                height, width = new_height, new_width

            # Find corresponding annotation file
            base_name = os.path.splitext(img_file)[0]
            # Handle different annotation file naming conventions
            possible_txt_files = [
                f"{base_name}.txt",
                f"{base_name}_png.rf.{base_name.split('.')[-1] if '.' in base_name else 'txt'}.txt"
            ]

            txt_file = None
            for txt_name in possible_txt_files:
                txt_path = os.path.join(label_dir, txt_name)
                if os.path.exists(txt_path):
                    txt_file = txt_path
                    break

            if txt_file is None:
                # Try finding any txt file with similar name
                for f in os.listdir(label_dir):
                    if base_name in f and f.endswith('.txt'):
                        txt_file = os.path.join(label_dir, f)
                        break

            # Convert to mask
            mask = yolo_to_mask(txt_file, (height, width), multi_class)

            # Save image and mask
            output_img_path = os.path.join(output_dir, 'images', img_file)
            output_mask_path = os.path.join(output_dir, 'masks', f"{base_name}.png")

            # Save resized image
            cv2.imwrite(output_img_path, image)
            cv2.imwrite(output_mask_path, mask)

            processed_count += 1

            # Clear memory every 10 images
            if i % 10 == 0:
                gc.collect()

        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            continue

    print(f"Successfully processed {processed_count} images")

# ================================
# MEMORY-EFFICIENT DATASET CLASS
# ================================

class ToothDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, multi_class=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.multi_class = multi_class

        # Get all image files
        self.images = []
        if os.path.exists(image_dir):
            all_images = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            # Filter images that have corresponding masks
            for img_file in all_images:
                base_name = os.path.splitext(img_file)[0]
                mask_file = os.path.join(mask_dir, f"{base_name}.png")
                if os.path.exists(mask_file):
                    self.images.append(img_file)

        print(f"Found {len(self.images)} valid image-mask pairs")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Corresponding mask
        base_name = os.path.splitext(img_name)[0]
        mask_path = os.path.join(self.mask_dir, f"{base_name}.png")

        # Load image and mask
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Apply transforms
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Normalize mask
        if self.multi_class:
            mask = mask.long()
        else:
            mask = mask.float() / 255.0

        return image, mask

# ================================
# LIGHTWEIGHT U-NET ARCHITECTURE
# ================================

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class LightweightUNet(nn.Module):
    """Lightweight U-Net for memory-constrained environments"""
    def __init__(self, n_channels=3, n_classes=1):
        super(LightweightUNet, self).__init__()

        # Reduced channel sizes for memory efficiency
        self.inc = DoubleConv(n_channels, 32)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))

        # Bottleneck
        self.bottleneck = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 256))

        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(256 + 128, 128)  # 256 (from x4) + 128 (from up1) = 384

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(128 + 64, 64)    # 128 (from x3) + 64 (from up2) = 192

        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(64 + 32, 32)     # 64 (from x2) + 32 (from up3) = 96

        self.up4 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(32 + 32, 32)

        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.bottleneck(x4)

        # Decoder
        x = self.up1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.up_conv1(x)

        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up_conv2(x)

        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up_conv3(x)

        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up_conv4(x)

        return self.outc(x)

# ================================
# MEMORY-EFFICIENT TRAINING
# ================================

def train_model_memory_efficient(model, train_loader, val_loader, num_epochs=50, learning_rate=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model.to(device)

    # Use mixed precision training if available
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Simple BCE loss for memory efficiency
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0

        for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images, masks = images.to(device), masks.to(device)
            if len(masks.shape) == 3:  # [batch, height, width]
             masks = masks.unsqueeze(1)

            optimizer.zero_grad()

            if scaler:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, masks)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()

            # Clear cache
            del outputs, loss
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)

                if len(masks.shape) == 3:  # [batch, height, width]
                   masks = masks.unsqueeze(1)

                if scaler:
                    with torch.cuda.amp.autocast():
                        outputs = model(images)
                        loss = criterion(outputs, masks)
                else:
                    outputs = model(images)
                    loss = criterion(outputs, masks)

                val_loss += loss.item()

                # Clear cache
                del outputs, loss
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        scheduler.step()

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_tooth_model.pth')
            print(f'New best model saved!')

        # Force garbage collection
        gc.collect()

# ================================
# SIMPLE TRANSFORMS
# ================================

def get_simple_transforms(image_size=256):
    """Simple transforms to reduce memory usage"""
    train_transform = A.Compose([
        A.Resize(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    val_transform = A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    return train_transform, val_transform

# ================================
# MAIN EXECUTION (MEMORY OPTIMIZED)
# ================================

def main():
    # Conservative configuration for low memory
    CONFIG = {
        'source_dir': '/content/custom_data/train',
        'processed_dir': '/content/processed_data',
        'image_size': 256,  # Reduced from 512
        'batch_size': 2,    # Reduced from 8
        'num_epochs': 10,   # Reduced from 100
        'learning_rate': 0.001,
        'multi_class': False,
        'max_size': 512,    # Max image size during processing
    }

    print("=== MEMORY-OPTIMIZED TOOTH SEGMENTATION ===")

    # Step 1: Process dataset
    print("Step 1: Converting YOLO annotations to masks...")
    process_dataset_memory_efficient(
        CONFIG['source_dir'],
        CONFIG['processed_dir'],
        CONFIG['multi_class'],
        CONFIG['max_size']
    )

    # Step 2: Split dataset
    print("\nStep 2: Splitting dataset...")
    image_dir = os.path.join(CONFIG['processed_dir'], 'images')
    if os.path.exists(image_dir):
        image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        train_files, temp_files = train_test_split(image_files, test_size=0.3, random_state=42)
        val_files, test_files = train_test_split(temp_files, test_size=0.5, random_state=42)

        # Create split directories
        for split, files in [('train', train_files), ('val', val_files), ('test', test_files)]:
            split_img_dir = os.path.join(CONFIG['processed_dir'], split, 'images')
            split_mask_dir = os.path.join(CONFIG['processed_dir'], split, 'masks')
            os.makedirs(split_img_dir, exist_ok=True)
            os.makedirs(split_mask_dir, exist_ok=True)

            for file in files:
                base_name = os.path.splitext(file)[0]

                # Copy image
                src_img = os.path.join(CONFIG['processed_dir'], 'images', file)
                dst_img = os.path.join(split_img_dir, file)
                if os.path.exists(src_img):
                    shutil.copy2(src_img, dst_img)

                # Copy mask
                src_mask = os.path.join(CONFIG['processed_dir'], 'masks', f"{base_name}.png")
                dst_mask = os.path.join(split_mask_dir, f"{base_name}.png")
                if os.path.exists(src_mask):
                    shutil.copy2(src_mask, dst_mask)

    # Step 3: Create datasets
    print("\nStep 3: Creating datasets...")
    train_transform, val_transform = get_simple_transforms(CONFIG['image_size'])

    train_dataset = ToothDataset(
        os.path.join(CONFIG['processed_dir'], 'train', 'images'),
        os.path.join(CONFIG['processed_dir'], 'train', 'masks'),
        train_transform
    )

    val_dataset = ToothDataset(
        os.path.join(CONFIG['processed_dir'], 'val', 'images'),
        os.path.join(CONFIG['processed_dir'], 'val', 'masks'),
        val_transform
    )

    # Use num_workers=0 to avoid multiprocessing issues
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)

    # Step 4: Train model
    print("\nStep 4: Training model...")
    model = LightweightUNet(n_channels=3, n_classes=1)

    train_model_memory_efficient(
        model, train_loader, val_loader,
        CONFIG['num_epochs'], CONFIG['learning_rate']
    )

    print("\n=== TRAINING COMPLETED ===")
    print("Model saved as 'best_tooth_model.pth'")

if __name__ == "__main__":
    main()

=== MEMORY-OPTIMIZED TOOTH SEGMENTATION ===
Step 1: Converting YOLO annotations to masks...
Processing 252 images...


Converting YOLO to masks: 100%|██████████| 252/252 [00:09<00:00, 25.22it/s]


Successfully processed 252 images

Step 2: Splitting dataset...

Step 3: Creating datasets...
Found 176 valid image-mask pairs
Found 50 valid image-mask pairs

Step 4: Training model...
Using device: cpu


Epoch 1/10: 100%|██████████| 88/88 [03:06<00:00,  2.11s/it]


Epoch 1: Train Loss: 0.3944, Val Loss: 0.2831
New best model saved!


Epoch 2/10: 100%|██████████| 88/88 [03:07<00:00,  2.13s/it]


Epoch 2: Train Loss: 0.2122, Val Loss: 0.1805
New best model saved!


Epoch 3/10: 100%|██████████| 88/88 [03:05<00:00,  2.10s/it]


Epoch 3: Train Loss: 0.1712, Val Loss: 0.1707
New best model saved!


Epoch 4/10: 100%|██████████| 88/88 [03:09<00:00,  2.15s/it]


Epoch 4: Train Loss: 0.1593, Val Loss: 0.2462


Epoch 5/10: 100%|██████████| 88/88 [03:03<00:00,  2.09s/it]


Epoch 5: Train Loss: 0.1514, Val Loss: 0.4193


Epoch 6/10: 100%|██████████| 88/88 [03:03<00:00,  2.09s/it]


Epoch 6: Train Loss: 0.1449, Val Loss: 0.1435
New best model saved!


Epoch 7/10: 100%|██████████| 88/88 [03:02<00:00,  2.07s/it]


Epoch 7: Train Loss: 0.1318, Val Loss: 0.1280
New best model saved!


Epoch 8/10: 100%|██████████| 88/88 [03:04<00:00,  2.09s/it]


Epoch 8: Train Loss: 0.1247, Val Loss: 0.1260
New best model saved!


Epoch 9/10: 100%|██████████| 88/88 [03:03<00:00,  2.08s/it]


Epoch 9: Train Loss: 0.1223, Val Loss: 0.1228
New best model saved!


Epoch 10/10: 100%|██████████| 88/88 [03:03<00:00,  2.08s/it]


Epoch 10: Train Loss: 0.1208, Val Loss: 0.1214
New best model saved!

=== TRAINING COMPLETED ===
Model saved as 'best_tooth_model.pth'


In [5]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import cv2
from PIL import Image

# ================================
# VISUALIZATION FUNCTIONS
# ================================

def visualize_predictions(model, dataloader, device, num_samples=4):
    """Visualize model predictions vs ground truth"""
    model.eval()

    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    with torch.no_grad():
        for i, (images, masks) in enumerate(dataloader):
            if i >= num_samples:
                break

            images, masks = images.to(device), masks.to(device)

            # Fix mask shape if needed
            if len(masks.shape) == 3:
                masks = masks.unsqueeze(1)

            # Get predictions
            outputs = model(images)
            predictions = torch.sigmoid(outputs)  # Convert to probabilities

            # Take first image in batch
            image = images[0].cpu().numpy().transpose(1, 2, 0)
            mask = masks[0, 0].cpu().numpy()
            pred = predictions[0, 0].cpu().numpy()

            # Denormalize image for display
            image = denormalize_image(image)

            # Plot results
            axes[i, 0].imshow(image)
            axes[i, 0].set_title('Original Image')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(mask, cmap='gray')
            axes[i, 1].set_title('Ground Truth')
            axes[i, 1].axis('off')

            axes[i, 2].imshow(pred, cmap='gray')
            axes[i, 2].set_title('Prediction (Prob)')
            axes[i, 2].axis('off')

            # Binary prediction (threshold at 0.5)
            binary_pred = (pred > 0.5).astype(np.uint8)
            axes[i, 3].imshow(binary_pred, cmap='gray')
            axes[i, 3].set_title('Binary Prediction')
            axes[i, 3].axis('off')

    plt.tight_layout()
    plt.savefig('segmentation_results.png', dpi=150, bbox_inches='tight')
    plt.show()

def denormalize_image(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Denormalize image for visualization"""
    image = image.copy()
    for i in range(3):
        image[:, :, i] = image[:, :, i] * std[i] + mean[i]
    return np.clip(image, 0, 1)

def overlay_prediction_on_image(image, prediction, alpha=0.5):
    """Overlay prediction mask on original image"""
    # Convert to numpy if needed
    if torch.is_tensor(image):
        image = image.cpu().numpy().transpose(1, 2, 0)
    if torch.is_tensor(prediction):
        prediction = prediction.cpu().numpy()

    # Denormalize image
    image = denormalize_image(image)

    # Create colored overlay (red for predictions)
    overlay = np.zeros_like(image)
    overlay[:, :, 0] = prediction  # Red channel

    # Blend with original image
    result = image * (1 - alpha) + overlay * alpha
    return np.clip(result, 0, 1)

# ================================
# EVALUATION METRICS
# ================================

def calculate_iou(pred, target, threshold=0.5):
    """Calculate Intersection over Union (IoU)"""
    pred_binary = (pred > threshold).float()
    target_binary = target.float()

    intersection = (pred_binary * target_binary).sum()
    union = pred_binary.sum() + target_binary.sum() - intersection

    iou = intersection / (union + 1e-8)
    return iou.item()

def calculate_dice(pred, target, threshold=0.5):
    """Calculate Dice coefficient"""
    pred_binary = (pred > threshold).float()
    target_binary = target.float()

    intersection = (pred_binary * target_binary).sum()
    dice = (2 * intersection) / (pred_binary.sum() + target_binary.sum() + 1e-8)

    return dice.item()

def evaluate_model(model, dataloader, device):
    """Evaluate model performance"""
    model.eval()
    total_iou = 0
    total_dice = 0
    total_samples = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)

            # Fix mask shape if needed
            if len(masks.shape) == 3:
                masks = masks.unsqueeze(1)

            outputs = model(images)
            predictions = torch.sigmoid(outputs)

            # Calculate metrics for each image in batch
            for i in range(images.size(0)):
                iou = calculate_iou(predictions[i], masks[i])
                dice = calculate_dice(predictions[i], masks[i])

                total_iou += iou
                total_dice += dice
                total_samples += 1

    avg_iou = total_iou / total_samples
    avg_dice = total_dice / total_samples

    print(f"Average IoU: {avg_iou:.4f}")
    print(f"Average Dice: {avg_dice:.4f}")

    return avg_iou, avg_dice

# ================================
# PREDICTION ON NEW IMAGES
# ================================

def predict_on_image(model, image_path, device, transform=None):
    """Predict segmentation mask for a single image"""
    model.eval()

    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')

    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        prediction = torch.sigmoid(output)

    # Convert to numpy
    pred_mask = prediction[0, 0].cpu().numpy()

    return pred_mask

def save_prediction_results(model, image_path, output_dir, device):
    """Save prediction results as images"""
    import os
    os.makedirs(output_dir, exist_ok=True)

    # Get prediction
    pred_mask = predict_on_image(model, image_path, device)

    # Load original image
    original_image = Image.open(image_path).convert('RGB')
    original_image = original_image.resize((256, 256))

    # Save results
    base_name = os.path.splitext(os.path.basename(image_path))[0]

    # Save original
    original_image.save(os.path.join(output_dir, f"{base_name}_original.png"))

    # Save prediction mask
    pred_img = Image.fromarray((pred_mask * 255).astype(np.uint8))
    pred_img.save(os.path.join(output_dir, f"{base_name}_prediction.png"))

    # Save binary mask
    binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
    binary_img = Image.fromarray(binary_mask)
    binary_img.save(os.path.join(output_dir, f"{base_name}_binary.png"))

    # Save overlay
    overlay = overlay_prediction_on_image(
        np.array(original_image) / 255.0,
        pred_mask,
        alpha=0.3
    )
    overlay_img = Image.fromarray((overlay * 255).astype(np.uint8))
    overlay_img.save(os.path.join(output_dir, f"{base_name}_overlay.png"))

    print(f"Results saved to {output_dir}")

# ================================
# USAGE EXAMPLES
# ================================

def main_evaluation():
    """Example usage of evaluation functions"""

    # Load your model
    model = LightweightUNet(n_channels=3, n_classes=1)
    model.load_state_dict(torch.load('best_tooth_model.pth'))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Assume you have val_loader ready
    # val_loader = your_validation_dataloader

    # 1. Visualize predictions
    print("Visualizing predictions...")
    # visualize_predictions(model, val_loader, device, num_samples=4)

    # 2. Evaluate model performance
    print("Evaluating model...")
    # avg_iou, avg_dice = evaluate_model(model, val_loader, device)

    # 3. Predict on a single image
    print("Predicting on single image...")
    # pred_mask = predict_on_image(model, 'path/to/your/image.jpg', device)

    # 4. Save results for multiple images
    print("Saving prediction results...")
    # save_prediction_results(model, 'path/to/image.jpg', 'output_results/', device)

if __name__ == "__main__":
    main_evaluation()

Visualizing predictions...
Evaluating model...
Predicting on single image...
Saving prediction results...
