<a href="https://colab.research.google.com/github/Swapneel642/U-net-Architecture-for-waste/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Make Google Drive visible to this runtime; the dataset directories used below live under it.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Directories for training and validation.
# All four share one dataset root on the mounted Drive, so relocating the dataset
# only requires editing DATASET_ROOT.
DATASET_ROOT = '/content/drive/MyDrive/U-Net_Dataset'
TRAIN_IMG_DIR = f'{DATASET_ROOT}/train_image'
TRAIN_MASK_DIR = f'{DATASET_ROOT}/train_mask'
VAL_IMG_DIR = f'{DATASET_ROOT}/val_image'
VAL_MASK_DIR = f'{DATASET_ROOT}/val_mask'

# Libraries

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision

# U-Net Model

In [4]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size unchanged; only the channel
    count changes, from ``in_channels`` to ``out_channels`` in the first stage.
    The convolutions are created with ``bias=False``; the BatchNorm that follows
    each one carries its own learnable shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Stored as ``conv`` so state_dict keys (conv.0.weight, conv.1.bias, ...) stay
        # compatible with existing checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UNET(nn.Module):
    """U-Net producing one logit map per output channel.

    Encoder: one DoubleConv per entry of ``features`` (shallowest first), each
    followed by 2x2 max pooling. A bottleneck DoubleConv doubles the deepest width.
    Decoder: per level, a 2x2 stride-2 transposed convolution, concatenation with
    the matching encoder output (skip connection), then a DoubleConv. A final 1x1
    convolution maps to ``out_channels``; no activation is applied to the output.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output map (1 for a binary mask).
        features: encoder widths from shallowest to deepest. Any sequence is
            accepted; the default is a tuple so it cannot be shared and mutated.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        features = list(features)
        # Creation/registration order matches the original layout: it fixes the
        # state_dict key order and the order in which layers draw from the RNG.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling path
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling path: stored flat as (ConvTranspose2d, DoubleConv) pairs.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # Input is skip (feature ch) concatenated with upsampled map (feature ch).
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d
            skip_connection = skip_connections[idx // 2]

            # Max pooling floors odd sizes, so the upsampled map can be smaller than
            # the skip; only the spatial dims (H, W) need to agree for the concat.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode="bilinear", align_corners=True)

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)  # DoubleConv

        return self.final_conv(x)

# Custom Dataset Class

In [5]:
class SegmentationDataset(Dataset):
    """Pairs each image file in ``img_dir`` with its mask in ``mask_dir``.

    For an image ``<stem>.<ext>`` the mask is looked up as ``<stem>_Waste.png``
    (e.g. ``frame_00000001.png`` -> ``frame_00000001_Waste.png``).

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing the masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so sample order is reproducible (os.listdir order is arbitrary);
        # only regular files are kept, so folders like .ipynb_checkpoints are skipped.
        self.images = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def _mask_path(self, image_name):
        """Return the mask path for ``image_name``: ``<mask_dir>/<stem>_Waste.png``."""
        # splitext only strips the final extension, unlike str.replace(".png", ...),
        # which also rewrote ".png" occurring elsewhere and ignored other extensions.
        stem, _ext = os.path.splitext(image_name)
        return os.path.join(self.mask_dir, f"{stem}_Waste.png")

    def __getitem__(self, idx):
        image_name = self.images[idx]
        img_path = os.path.join(self.img_dir, image_name)
        mask_path = self._mask_path(image_name)

        # Load inside `with` so the underlying file handles are released promptly.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)

        # Normalize mask (convert 255 -> 1). Assumes masks are stored as 0/255;
        # any other grey levels are left unchanged.
        mask[mask == 255.0] = 1.0

        # Apply transformations if provided
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        return image, mask

# Transformations

In [6]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Define transformations
def _unit_scale_and_tensor():
    """Return fresh instances of the final steps shared by both pipelines.

    With mean 0, std 1 and max_pixel_value=255, A.Normalize reduces to dividing
    pixel values by 255; ToTensorV2 then converts the numpy image to a torch tensor.
    """
    return [
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]


# Training pipeline: resize to 256x256, random rotation (always applied, limit 35)
# and random flips, followed by the shared scaling + tensor conversion.
train_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
    ]
    + _unit_scale_and_tensor()
)

# Validation pipeline: resize only (no random augmentation), then the shared tail.
val_transform = A.Compose([A.Resize(height=256, width=256)] + _unit_scale_and_tensor())



# Create Data Loaders

In [7]:
from torch.utils.data import DataLoader

def get_loaders(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, batch_size, train_transform, val_transform, num_workers=2, pin_memory=True):
    """Build the training and validation DataLoaders.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``; only the
    training loader shuffles. ``train_transform`` / ``val_transform`` are handed to
    the respective SegmentationDataset.

    Returns:
        ``(train_loader, val_loader)``.
    """
    loader_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(
        SegmentationDataset(img_dir=train_img_dir, mask_dir=train_mask_dir, transform=train_transform),
        shuffle=True,
        **loader_opts,
    )
    val_loader = DataLoader(
        SegmentationDataset(img_dir=val_img_dir, mask_dir=val_mask_dir, transform=val_transform),
        shuffle=False,
        **loader_opts,
    )
    return train_loader, val_loader


In [8]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# RGB input, a single output channel (one logit per pixel for the binary mask).
model = UNET(in_channels=3, out_channels=1)
model = model.to(DEVICE)

# Implement the Training Function

In [9]:
from tqdm import tqdm

def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Train ``model`` for one epoch over ``loader`` with mixed precision on CUDA.

    NOTE: ``train_fn`` is redefined in later cells of this notebook; whichever
    definition ran last is the one callers get.

    Args:
        loader: iterable of ``(images, masks)`` batches; a channel dimension is
            added to the masks here (``unsqueeze(1)``).
        model: network whose output is compared with the targets by ``loss_fn``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: gradient scaler used to scale the loss before ``backward``.
        device: device to train on; ``None`` falls back to the notebook-level ``DEVICE``.
    """
    if device is None:
        device = DEVICE
    # Autocast only where it applies (CUDA). This replaces the deprecated
    # torch.cuda.amp.autocast(), which merely warned and disabled itself on CPU.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    model.train()  # Set model to training mode
    loop = tqdm(loader, leave=True)  # Progress bar

    for data, targets in loop:
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)  # Add channel dimension

        # Forward pass with mixed precision
        with torch.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update progress bar
        loop.set_postfix(loss=loss.item())

# Implement Model Evaluation

In [10]:
def check_accuracy(loader, model, device="cuda"):
    """Checks model accuracy and Dice score on the dataset.

    Predictions are ``sigmoid(model(x)) > 0.5``. Pixel accuracy is computed over
    all pixels; Dice is computed per batch and averaged over batches.

    Args:
        loader: iterable of ``(images, masks)`` batches; a channel dimension is
            added to the masks here (``unsqueeze(1)``). Masks are presumably 0/1.
        model: network returning one logit per pixel.
        device: device to run evaluation on.

    Returns:
        ``(accuracy_percent, mean_dice)`` as Python floats, or ``(0.0, 0.0)``
        when ``loader`` yields no batches.

    The model is always switched back to training mode afterwards, even on error.
    """
    eps = 1e-8
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0

    model.eval()  # Set model to evaluation mode
    try:
        with torch.no_grad():  # Disable gradient calculation
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)  # Ensure proper shape

                preds = (torch.sigmoid(model(x)) > 0.5).float()  # Binary mask

                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                # eps in the numerator as well: an empty prediction on an empty mask
                # is a perfect match (Dice 1), not 0.
                intersection = (preds * y).sum()
                dice_total += ((2 * intersection + eps) / ((preds + y).sum() + eps)).item()
                num_batches += 1
    finally:
        model.train()  # Switch back to training mode

    if num_batches == 0 or num_pixels == 0:
        print(" No validation data: accuracy and Dice are undefined.")
        return 0.0, 0.0

    acc = (num_correct / num_pixels) * 100
    dice = dice_total / num_batches
    print(f" Validation Accuracy: {acc:.2f}%")
    print(f" Dice Score: {dice:.4f}")
    return acc, dice

# Saving Model Checkpoints

In [11]:
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    ``state`` is typically a dict such as ``{"state_dict": model.state_dict()}``.
    Progress messages are printed before and after writing.
    """
    print(" Saving checkpoint...")
    torch.save(obj=state, f=filename)
    print(" Checkpoint saved!")

# Loading Model Checkpoints

In [12]:
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    ``checkpoint`` is an already-loaded mapping (for example the result of
    ``torch.load``), not a file path.
    """
    print("Loading checkpoint...")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)
    print("Checkpoint loaded!")

# Training Function (redefinition: replaces the `train_fn` defined earlier)

In [13]:
from tqdm import tqdm
import torch.nn.functional as F

def train_fn(loader, model, optimizer, loss_fn, scaler, device="cuda"):
    """Trains the model for one epoch.

    NOTE: this redefines the earlier ``train_fn`` and is itself redefined in a
    later cell; whichever definition ran last is the one callers get.

    Args:
        loader: iterable of ``(images, masks)`` batches; a channel dimension is
            added to the masks here (``unsqueeze(1)``).
        model: network whose output is compared with the targets by ``loss_fn``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: gradient scaler used to scale the loss before ``backward``.
        device: device to train on.
    """
    # Autocast only on CUDA; replaces the deprecated torch.cuda.amp.autocast().
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    # Ensure dropout/BatchNorm are in training mode even if the model was left in eval.
    model.train()
    loop = tqdm(loader, leave=True)

    for data, targets in loop:
        data = data.to(device)
        targets = targets.float().unsqueeze(1).to(device)

        # Forward pass
        with torch.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update tqdm progress bar
        loop.set_postfix(loss=loss.item())

# Accuracy & Dice Score Calculation

In [14]:
def check_accuracy(loader, model, device="cuda"):
    """Evaluates model accuracy and Dice score on a dataset.

    Predictions are ``sigmoid(model(x)) > 0.5``. Pixel accuracy is computed over
    all pixels; Dice is computed per batch and averaged over batches.

    Args:
        loader: iterable of ``(images, masks)`` batches; a channel dimension is
            added to the masks here (``unsqueeze(1)``). Masks are presumably 0/1.
        model: network returning one logit per pixel.
        device: device to run evaluation on.

    Returns:
        ``(accuracy_percent, mean_dice)`` as Python floats, or ``(0.0, 0.0)``
        when ``loader`` yields no batches.

    The model is always switched back to training mode afterwards, even on error.
    """
    eps = 1e-8
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0

    model.eval()  # Set model to evaluation mode
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)  # Ensure proper shape

                preds = (torch.sigmoid(model(x)) > 0.5).float()  # Binary mask

                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                # eps in the numerator as well: an empty prediction on an empty mask
                # is a perfect match (Dice 1), not 0.
                intersection = (preds * y).sum()
                dice_total += ((2 * intersection + eps) / ((preds + y).sum() + eps)).item()
                num_batches += 1
    finally:
        model.train()  # Switch back to training mode

    if num_batches == 0 or num_pixels == 0:
        print("No validation data: accuracy and Dice are undefined.")
        return 0.0, 0.0

    acc = (num_correct / num_pixels) * 100
    dice = dice_total / num_batches

    print(f"Accuracy: {acc:.2f}%")
    print(f"Dice Score: {dice:.4f}")

    return acc, dice

# Save Model Predictions as Images

In [15]:
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Saves model predictions as images for visualization.

    For each batch ``idx`` of ``loader`` this writes ``pred_{idx}.png`` (predictions
    thresholded at p > 0.5) and ``gt_{idx}.png`` (the ground-truth masks) into
    ``folder`` using ``torchvision.utils.save_image``. ``folder`` is created if missing.
    The model is switched back to training mode afterwards, even on error.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        with torch.no_grad():
            for idx, (x, y) in enumerate(loader):
                preds = (torch.sigmoid(model(x.to(device))) > 0.5).float()  # Binary mask

                # Save predicted mask and the ground truth for comparison.
                torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
                torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"gt_{idx}.png"))
    finally:
        model.train()  # Switch back to training mode

    print(f"Saved predictions in '{folder}'")

# Training Function (final redefinition: this is the `train_fn` that `main()` calls)

In [16]:
def train_fn(loader, model, optimizer, loss_fn, scaler, device="cuda"):
    """Trains the model for one epoch on the given data loader.

    This is the last ``train_fn`` definition in the notebook.

    Args:
        loader: iterable of ``(images, masks)`` batches; a channel dimension is
            added to the masks here (``unsqueeze(1)``).
        model: network whose output is compared with the targets by ``loss_fn``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: gradient scaler used to scale the loss before ``backward``.
        device: device to train on; pass the notebook's ``DEVICE`` on CPU runtimes.
    """
    # Mixed precision only applies on CUDA; torch.autocast replaces the deprecated
    # torch.cuda.amp.autocast() and is simply disabled on other devices.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    # Do not rely on a previous call (e.g. check_accuracy) to restore train mode.
    model.train()
    loop = tqdm(loader, leave=True)  # Progress bar

    for data, targets in loop:
        data = data.to(device)
        targets = targets.float().unsqueeze(1).to(device)

        # Forward pass
        with torch.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update tqdm loop with loss value
        loop.set_postfix(loss=loss.item())

# Main Training Loop

In [None]:
# Define DEVICE before initializing the model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model
model = UNET(in_channels=3, out_channels=1).to(DEVICE)

# Hyperparameters
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 15
NUM_WORKERS = 2
IMAGE_WIDTH = 256   # NOTE: the transforms cell hard-codes 256x256; keep these in sync
IMAGE_HEIGHT = 256
PIN_MEMORY = DEVICE == "cuda"  # pinned host memory only helps host -> GPU copies
LOAD_MODEL = False  # NOTE(review): not read anywhere in the visible notebook

# Create Data Loaders.
# Everything after batch_size is passed by keyword: positionally, NUM_WORKERS and
# PIN_MEMORY would land in the train_transform / val_transform parameters and the
# datasets would then try to call them as transforms.
train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform=train_transform,
    val_transform=val_transform,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Define Loss Function & Optimizer
loss_fn = nn.BCEWithLogitsLoss()  # Binary cross-entropy on raw logits
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# GradScaler for mixed precision training; loss scaling only matters on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))


def main():
    """Main training loop to train U-Net model for multiple epochs.

    Runs NUM_EPOCHS epochs. After every 5th epoch it evaluates on ``val_loader``
    and writes ``checkpoint_epoch_<N>.pth.tar`` to the working directory.
    """
    for epoch in range(NUM_EPOCHS):
        print(f"\n🔹 Epoch {epoch + 1}/{NUM_EPOCHS}")

        # Train for one epoch. DEVICE is passed explicitly: train_fn's own default
        # is "cuda", which fails on CPU-only runtimes.
        train_fn(train_loader, model, optimizer, loss_fn, scaler, device=DEVICE)

        # Every 5 epochs: check accuracy, then save a checkpoint.
        if (epoch + 1) % 5 == 0:
            check_accuracy(val_loader, model, DEVICE)
            save_checkpoint({"state_dict": model.state_dict()}, filename=f"checkpoint_epoch_{epoch+1}.pth.tar")

# Run training
if __name__ == "__main__":
    main()