<a href="https://colab.research.google.com/github/JensH-2157843/AML_Project/blob/main/src/neural_networks/NN1(segmentation).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Library imports

In [1]:
!pip install segmentation-models==1.0.1 albumentations==1.3.1 --quiet
import os
import numpy as np
from PIL import Image
from glob import glob
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import tensorflow as tf
import matplotlib.pyplot as plt
from transformers import SegformerFeatureExtractor

import time
import copy
from torchvision import transforms

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.7/125.7 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/50.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.7/50.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h

# Dataset import

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
## DATASET IMPORT ##
deepglobe_dir = "/content/drive/MyDrive/train"
import os

deepglobe_images = sorted(glob(os.path.join(deepglobe_dir, '*_sat.jpg')))
deepglobe_masks = sorted(glob(os.path.join(deepglobe_dir, '*_mask.png')))

for tile in sorted(os.listdir(deepglobe_dir)):
    tile_path = os.path.join(deepglobe_dir, tile)
    if not os.path.isdir(tile_path):
        continue
    img_folder = os.path.join(tile_path, "images")
    mask_folder = os.path.join(tile_path, "masks")
    deepglobe_images.extend(sorted(glob(os.path.join(img_folder, '*.jpg'))))
    deepglobe_masks.extend(sorted(glob(os.path.join(mask_folder, '*.png'))))

all_images = deepglobe_images
all_masks = deepglobe_masks

train_imgs, val_imgs, train_masks, val_masks = train_test_split(
    all_images, all_masks, test_size=0.2, random_state=42
)

In [4]:
IMG_SIZE = (256, 256)

def rgb_to_binary_mask(mask_image, suitable_rgbs):
    mask = np.array(mask_image)
    binary_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
    for rgb in suitable_rgbs:
        matches = np.all(mask == rgb, axis=-1)
        binary_mask[matches] = 1
    return binary_mask

image_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

mask_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.NEAREST)
])

In [5]:
# --- Define Your PyTorch Dataset (This IS your combined loader + preprocessor) ---
class SolarPanelDataset(Dataset):
    def __init__(self, img_paths, mask_paths, suitable_rgbs, img_transform=None, mask_transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.suitable_rgbs = suitable_rgbs
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # --- Start of Preprocessing Logic (PyTorch version) ---
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]
        image = Image.open(img_path).convert("RGB")
        mask_rgb = Image.open(mask_path).convert("RGB")
        if self.img_transform:
            image = self.img_transform(image) # Applies resize, ToTensor, Normalize
        if self.mask_transform:
            mask_rgb = self.mask_transform(mask_rgb) # Applies resize (NEAREST)
        mask_binary = rgb_to_binary_mask(mask_rgb, self.suitable_rgbs) # Converts mask
        mask = torch.from_numpy(mask_binary) # To PyTorch Tensor
        # --- End of Preprocessing Logic ---
        return image, mask

In [6]:
# RGB values for classes we consider 'Suitable' (Class 1)
SUITABLE_RGB_VALUES = [
    (255, 255, 0),  # Agriculture land
    (255, 0, 255),  # Rangeland
    (255, 255, 255),# Barren land
    (60, 16, 152),  # Building
    (132, 41, 246)  # Unpaved land
]

In [7]:
val_loader = SolarPanelDataset(train_imgs, train_masks, SUITABLE_RGB_VALUES,  image_transforms, mask_transforms)
train_loader = SolarPanelDataset(val_imgs, val_masks, SUITABLE_RGB_VALUES,  image_transforms, mask_transforms)

val_loader = DataLoader(val_loader, batch_size=20, shuffle=False)
train_loader = DataLoader(train_loader, batch_size=20, shuffle=True)

# Model

In [None]:
## ARCHITECTURE ##
class ConvBlock(nn.Module):
    """
    Convolutional Block: (Conv -> BN -> ReLU) * 2
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

class EncoderBlock(nn.Module):
    """
    Encoder Block: ConvBlock -> MaxPool
    Returns both ConvBlock output (skip) and MaxPool output.
    """
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv_block(x)
        pooled = self.pool(skip)
        return skip, pooled

class DecoderBlock(nn.Module):
    """
    Decoder Block: ConvTranspose -> Concat -> ConvBlock
    """
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        # Upsamples by a factor of 2, halving the channels.
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # ConvBlock takes concatenated input (skip + upconv), so its input channels
        # will be out_channels (from skip) + out_channels (from upconv).
        self.conv_block = ConvBlock(out_channels * 2, out_channels)

    def forward(self, x, skip_connection):
        x = self.upconv(x)

        # Ensure spatial dimensions match before concatenating.
        # If input sizes are powers of 2, they should match.
        # If not, cropping (from skip) or padding (to x) might be needed.
        # Here we assume they match or crop the skip connection if necessary.
        if x.shape != skip_connection.shape:
            # Simple center-cropping (adjust if needed)
            diffY = skip_connection.size()[2] - x.size()[2]
            diffX = skip_connection.size()[3] - x.size()[3]
            skip_connection = skip_connection[:, :, diffY // 2 : skip_connection.size()[2] - diffY // 2 - diffY % 2,
                                                diffX // 2 : skip_connection.size()[3] - diffX // 2 - diffX % 2]

        x = torch.cat([x, skip_connection], dim=1) # Concatenate along channel dimension (dim=1)
        x = self.conv_block(x)
        return x

class DeepUnet(nn.Module):

    def __init__(self, in_channels=3, out_classes=11):
        """
        Initializes the DeepUnet model.

        Args:
            in_channels (int): Number of input channels (e.g., 3 for RGB).
            out_classes (int): Number of output segmentation classes.
        """
        super(DeepUnet, self).__init__()
        self.in_channels = in_channels
        self.out_classes = out_classes

        # Encoder Path
        self.enc1 = EncoderBlock(in_channels, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)

        # Bottleneck
        self.bottleneck = ConvBlock(512, 1024)

        # Decoder Path
        self.dec1 = DecoderBlock(1024, 512)
        self.dec2 = DecoderBlock(512, 256)
        self.dec3 = DecoderBlock(256, 128)
        self.dec4 = DecoderBlock(128, 64)

        # Output Layer
        self.output_conv = nn.Conv2d(64, out_classes, kernel_size=1)

        # Optional: Softmax layer. Often omitted if using CrossEntropyLoss,
        # which combines LogSoftmax and NLLLoss.
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Defines the forward pass of the U-Net.

        Args:
            x (Tensor): The input tensor (N, C, H, W).

        Returns:
            Tensor: The output segmentation map (N, out_classes, H, W).
        """
        # Encoder path
        s1, p1 = self.enc1(x)
        s2, p2 = self.enc2(p1)
        s3, p3 = self.enc3(p2)
        s4, p4 = self.enc4(p3)

        # Bottleneck
        b1 = self.bottleneck(p4)

        # Decoder path
        d1 = self.dec1(b1, s4)
        d2 = self.dec2(d1, s3)
        d3 = self.dec3(d2, s2)
        d4 = self.dec4(d3, s1)

        # Output
        outputs = self.output_conv(d4)

        # Optional: Apply softmax
        # outputs = self.softmax(outputs)

        return outputs

# Learning algorithm

In [None]:
# --- Configuration & Constants ---
IMG_SIZE = (256, 256)
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50 # A good starting point, adjust as needed
IN_CHANNELS = 3
OUT_CLASSES = 2 # 0: Not Suitable, 1: Suitable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# EarlyStopping Configuration
EARLY_STOPPING_PATIENCE = 7 # Number of epochs to wait for improvement before stopping
EARLY_STOPPING_MIN_DELTA = 0.0001 # Minimum change in monitored quantity to qualify as improvement

# --- 4. Model, Loss, Optimizer ---
print("Setting up model, loss, and optimizer...")
model = DeepUnet(in_channels=IN_CHANNELS, out_classes=OUT_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- 5. Training Loop with EarlyStopping and History ---
print("Starting training...")
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')
epochs_no_improve = 0
best_model_weights = copy.deepcopy(model.state_dict()) # Store best model

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    model.train()
    running_train_loss = 0.0
    for i, (images, masks) in enumerate(train_loader):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
        if (i + 1) % 20 == 0: # Print training progress more frequently
             print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], Batch Loss: {loss.item():.4f}")

    avg_train_loss = running_train_loss / len(train_loader)
    history['train_loss'].append(avg_train_loss)

    # Validation
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, masks)
            running_val_loss += loss.item()
    avg_val_loss = running_val_loss / len(val_loader)
    history['val_loss'].append(avg_val_loss)
    epoch_time = time.time() - start_time

    print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} Finished ---")
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    print(f"Epoch Duration: {epoch_time:.2f} seconds")

    # EarlyStopping Check
    if avg_val_loss < best_val_loss - EARLY_STOPPING_MIN_DELTA:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        best_model_weights = copy.deepcopy(model.state_dict())
        print(f"Validation loss improved. Saving model weights.")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        model.load_state_dict(best_model_weights) # Restore best model weights
        break
    print("-" * 30)

print("Training finished!")
if epoch < NUM_EPOCHS -1 and epochs_no_improve < EARLY_STOPPING_PATIENCE : # If not early stopped
    print("Completed all epochs.")
    model.load_state_dict(best_model_weights) # Ensure best model is loaded if early stopping wasn't triggered but patience was > 0

# --- 6. Print Loss History ---
print("\n--- Training History ---")
for i in range(len(history['train_loss'])):
    print(f"Epoch {i+1}: Train Loss = {history['train_loss'][i]:.4f}, Val Loss = {history['val_loss'][i]:.4f}")

Setting up model, loss, and optimizer...
Starting training...
--- Epoch 1/50 Finished ---
Train Loss: 0.7818 | Val Loss: 0.6749
Epoch Duration: 613.67 seconds
Validation loss improved. Saving model weights.
------------------------------
--- Epoch 2/50 Finished ---
Train Loss: 0.6248 | Val Loss: 0.6331
Epoch Duration: 179.33 seconds
Validation loss improved. Saving model weights.
------------------------------
--- Epoch 3/50 Finished ---
Train Loss: 0.5680 | Val Loss: 0.5611
Epoch Duration: 179.05 seconds
Validation loss improved. Saving model weights.
------------------------------
--- Epoch 4/50 Finished ---
Train Loss: 0.5083 | Val Loss: 0.5130
Epoch Duration: 174.86 seconds
Validation loss improved. Saving model weights.
------------------------------
--- Epoch 5/50 Finished ---
Train Loss: 0.5259 | Val Loss: 0.4906
Epoch Duration: 177.66 seconds
Validation loss improved. Saving model weights.
------------------------------
--- Epoch 6/50 Finished ---
Train Loss: 0.4511 | Val Loss:

In [None]:
# Define the path where you want to save the model
model_save_path = "solar_unet_model.pth"

# Save only the model's state dictionary (recommended for inference/retraining)
torch.save(model.state_dict(), model_save_path)

print(f"Neural network model saved successfully to {model_save_path}")

# Pretained one

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50
import copy
import time

# --- Configuration & Constants ---
IMG_SIZE = (256, 256)
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
IN_CHANNELS = 3
OUT_CLASSES = 2  # 0: Not Suitable, 1: Suitable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Transfer Learning Configuration
FREEZE_ENCODER = True  # Whether to freeze encoder weights initially
UNFREEZE_AFTER_EPOCHS = 10  # Unfreeze encoder after this many epochs
PRETRAINED_LR_FACTOR = 0.1  # Learning rate multiplier for pretrained layers

# EarlyStopping Configuration
EARLY_STOPPING_PATIENCE = 7
EARLY_STOPPING_MIN_DELTA = 0.0001

class PretrainedEncoderUNet(nn.Module):
    """
    U-Net with pretrained encoder (ResNet backbone)
    """
    def __init__(self, in_channels=3, out_classes=2, pretrained=True):
        super(PretrainedEncoderUNet, self).__init__()

        # Use ResNet50 as encoder backbone
        resnet = models.resnet50(pretrained=pretrained)

        # Remove the final layers (avgpool and fc)
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])

        # Encoder feature dimensions for ResNet50
        encoder_dims = [64, 256, 512, 1024, 2048]

        # Decoder blocks
        self.decoder5 = self._make_decoder_block(2048, 1024)
        self.decoder4 = self._make_decoder_block(1024 + 1024, 512)  # +1024 from skip connection
        self.decoder3 = self._make_decoder_block(512 + 512, 256)   # +512 from skip connection
        self.decoder2 = self._make_decoder_block(256 + 256, 128)   # +256 from skip connection
        self.decoder1 = self._make_decoder_block(128 + 64, 64)     # +64 from skip connection

        # Final output layer
        self.final_conv = nn.Conv2d(64, out_classes, kernel_size=1)

        # Store encoder layer references for skip connections
        self.encoder_layers = self._get_encoder_layers(resnet)

    def _make_decoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _get_encoder_layers(self, resnet):
        """Extract individual layers from ResNet for skip connections"""
        layers = {}
        layers['conv1'] = resnet.conv1
        layers['bn1'] = resnet.bn1
        layers['relu'] = resnet.relu
        layers['maxpool'] = resnet.maxpool
        layers['layer1'] = resnet.layer1
        layers['layer2'] = resnet.layer2
        layers['layer3'] = resnet.layer3
        layers['layer4'] = resnet.layer4
        return layers

    def forward(self, x):
        # Encoder path with skip connections
        # Initial conv block
        x1 = self.encoder_layers['conv1'](x)
        x1 = self.encoder_layers['bn1'](x1)
        x1 = self.encoder_layers['relu'](x1)
        skip1 = x1  # 64 channels

        x2 = self.encoder_layers['maxpool'](x1)
        x2 = self.encoder_layers['layer1'](x2)
        skip2 = x2  # 256 channels

        x3 = self.encoder_layers['layer2'](x2)
        skip3 = x3  # 512 channels

        x4 = self.encoder_layers['layer3'](x3)
        skip4 = x4  # 1024 channels

        x5 = self.encoder_layers['layer4'](x4)  # 2048 channels (bottleneck)

        # Decoder path
        d5 = self.decoder5(x5)
        d4 = torch.cat([d5, skip4], dim=1)
        d4 = self.decoder4(d4)

        d3 = torch.cat([d4, skip3], dim=1)
        d3 = self.decoder3(d3)

        d2 = torch.cat([d3, skip2], dim=1)
        d2 = self.decoder2(d2)

        d1 = torch.cat([d2, skip1], dim=1)
        d1 = self.decoder1(d1)

        # Final output
        output = self.final_conv(d1)

        return output

    def freeze_encoder(self):
        """Freeze encoder parameters"""
        for name, param in self.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
        print("Encoder layers frozen")

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters"""
        for name, param in self.named_parameters():
            if 'encoder' in name:
                param.requires_grad = True
        print("Encoder layers unfrozen")

class DeepLabTransferUNet(nn.Module):
    """
    Alternative: Use DeepLabV3 as base and modify for binary segmentation
    """
    def __init__(self, out_classes=2):
        super(DeepLabTransferUNet, self).__init__()

        # Load pretrained DeepLabV3
        self.backbone = deeplabv3_resnet50(pretrained=True)

        # Modify classifier for our number of classes
        self.backbone.classifier = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, out_classes, kernel_size=1)
        )

        # Modify auxiliary classifier if it exists
        if hasattr(self.backbone, 'aux_classifier'):
            self.backbone.aux_classifier = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Conv2d(256, out_classes, kernel_size=1)
            )

    def forward(self, x):
        return self.backbone(x)['out']

    def freeze_backbone(self):
        """Freeze backbone parameters except classifier"""
        for name, param in self.named_parameters():
            if 'classifier' not in name:
                param.requires_grad = False
        print("Backbone frozen, only classifier trainable")

    def unfreeze_backbone(self):
        """Unfreeze all parameters"""
        for param in self.parameters():
            param.requires_grad = True
        print("All layers unfrozen")

# --- Model Setup with Transfer Learning ---
def setup_transfer_learning_model(model_type="pretrained_unet"):
    """
    Setup model with transfer learning
    """
    if model_type == "pretrained_unet":
        model = PretrainedEncoderUNet(
            in_channels=IN_CHANNELS,
            out_classes=OUT_CLASSES,
            pretrained=True
        ).to(DEVICE)

        if FREEZE_ENCODER:
            model.freeze_encoder()

    elif model_type == "deeplab":
        model = DeepLabTransferUNet(out_classes=OUT_CLASSES).to(DEVICE)
        if FREEZE_ENCODER:
            model.freeze_backbone()

    return model

# --- Custom Optimizer Setup for Transfer Learning ---
def setup_optimizer(model, model_type="pretrained_unet"):
    """
    Setup optimizer with different learning rates for pretrained and new layers
    """
    if model_type == "pretrained_unet":
        # Different learning rates for encoder and decoder
        encoder_params = []
        decoder_params = []

        for name, param in model.named_parameters():
            if param.requires_grad:
                if 'encoder' in name:
                    encoder_params.append(param)
                else:
                    decoder_params.append(param)

        optimizer = optim.Adam([
            {'params': encoder_params, 'lr': LEARNING_RATE * PRETRAINED_LR_FACTOR},
            {'params': decoder_params, 'lr': LEARNING_RATE}
        ])

    else:
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    return optimizer

# --- Enhanced Training Loop with Transfer Learning ---
def train_with_transfer_learning(model, train_loader, val_loader, model_type="pretrained_unet"):
    """
    Training loop optimized for transfer learning
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = setup_optimizer(model, model_type)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    print("Starting transfer learning training...")
    history = {'train_loss': [], 'val_loss': [], 'lr': []}
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_weights = copy.deepcopy(model.state_dict())

    for epoch in range(NUM_EPOCHS):
        start_time = time.time()

        # Unfreeze encoder after specified epochs
        if epoch == UNFREEZE_AFTER_EPOCHS and FREEZE_ENCODER:
            if model_type == "pretrained_unet":
                model.unfreeze_encoder()
            elif model_type == "deeplab":
                model.unfreeze_backbone()

            # Recreate optimizer with unfrozen parameters
            optimizer = setup_optimizer(model, model_type)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=5, verbose=True
            )

        # Training phase
        model.train()
        running_train_loss = 0.0

        for i, (images, masks) in enumerate(train_loader):
            images, masks = images.to(DEVICE), masks.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)

            # Handle different output formats
            if isinstance(outputs, dict):
                outputs = outputs['out']

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()

            if (i + 1) % 20 == 0:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], "
                      f"Batch Loss: {loss.item():.4f}")

        avg_train_loss = running_train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation phase
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(DEVICE), masks.to(DEVICE)
                outputs = model(images)

                if isinstance(outputs, dict):
                    outputs = outputs['out']

                loss = criterion(outputs, masks)
                running_val_loss += loss.item()

        avg_val_loss = running_val_loss / len(val_loader)
        history['val_loss'].append(avg_val_loss)

        # Update learning rate
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        history['lr'].append(current_lr)

        epoch_time = time.time() - start_time

        print(f"--- Epoch {epoch+1}/{NUM_EPOCHS} Finished ---")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        print(f"Epoch Duration: {epoch_time:.2f} seconds")

        # Early stopping check
        if avg_val_loss < best_val_loss - EARLY_STOPPING_MIN_DELTA:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_weights = copy.deepcopy(model.state_dict())
            print("Validation loss improved. Saving model weights.")
        else:
            epochs_no_improve += 1
            print(f"Validation loss did not improve for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            model.load_state_dict(best_model_weights)
            break

        print("-" * 50)

    print("Training finished!")
    model.load_state_dict(best_model_weights)
    return model, history

# --- Usage Example ---
if __name__ == "__main__":
    # Choose model type: "pretrained_unet" or "deeplab"
    MODEL_TYPE = "pretrained_unet"

    # Setup model
    model = setup_transfer_learning_model(MODEL_TYPE)

    print(f"Model type: {MODEL_TYPE}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Train model (assuming train_loader and val_loader are defined)
    trained_model, history = train_with_transfer_learning(model, train_loader, val_loader, MODEL_TYPE)

    # Print training history
    print("\n--- Transfer Learning Training History ---")
    for i in range(len(history['train_loss'])):
         print(f"Epoch {i+1}: Train Loss = {history['train_loss'][i]:.4f}, "
               f"Val Loss = {history['val_loss'][i]:.4f}, LR = {history['lr'][i]:.6f}")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 161MB/s]


Encoder layers frozen
Model type: pretrained_unet
Total parameters: 50,033,474
Trainable parameters: 26,525,442
Starting transfer learning training...




--- Epoch 1/50 Finished ---
Train Loss: 0.6584 | Val Loss: 0.6819
Learning Rate: 0.000010
Epoch Duration: 372.66 seconds
Validation loss improved. Saving model weights.
--------------------------------------------------
--- Epoch 2/50 Finished ---
Train Loss: 0.4918 | Val Loss: 0.6162
Learning Rate: 0.000010
Epoch Duration: 171.36 seconds
Validation loss improved. Saving model weights.
--------------------------------------------------
--- Epoch 3/50 Finished ---
Train Loss: 0.4321 | Val Loss: 0.4583
Learning Rate: 0.000010
Epoch Duration: 167.91 seconds
Validation loss improved. Saving model weights.
--------------------------------------------------
--- Epoch 4/50 Finished ---
Train Loss: 0.3906 | Val Loss: 0.3840
Learning Rate: 0.000010
Epoch Duration: 168.69 seconds
Validation loss improved. Saving model weights.
--------------------------------------------------
--- Epoch 5/50 Finished ---
Train Loss: 0.3605 | Val Loss: 0.3589
Learning Rate: 0.000010
Epoch Duration: 164.87 seconds

In [9]:
import torch
import torch.nn as nn
import os
import json
import pickle
from datetime import datetime
import shutil
import warnings

class ModelSaver:
    """
    Comprehensive model saving and loading utility
    """

    def __init__(self, save_dir="saved_models"):
        self.save_dir = save_dir
        self.create_save_directory()

    def create_save_directory(self):
        """Create directory structure for saving models"""
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "best_models"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "final_models"), exist_ok=True)
        print(f"Model save directory created: {self.save_dir}")

    def save_model_complete(self, model, optimizer, scheduler, history,
                           epoch, val_loss, model_name="unet_model",
                           save_type="final", additional_info=None):
        """
        Save complete model state with all training information

        Args:
            model: The neural network model
            optimizer: The optimizer used for training
            scheduler: Learning rate scheduler (can be None)
            history: Training history dictionary
            epoch: Current epoch number
            val_loss: Current validation loss
            model_name: Base name for the model
            save_type: Type of save ("final", "best", "checkpoint")
            additional_info: Dictionary with additional information to save
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Determine save path based on type
        if save_type == "checkpoint":
            save_path = os.path.join(self.save_dir, "checkpoints", f"{model_name}_epoch_{epoch}_{timestamp}")
        elif save_type == "best":
            save_path = os.path.join(self.save_dir, "best_models", f"{model_name}_best_{timestamp}")
        else:  # final
            save_path = os.path.join(self.save_dir, "final_models", f"{model_name}_final_{timestamp}")

        # Create directory for this specific save
        os.makedirs(save_path, exist_ok=True)

        # 1. Save model state dict (most common format)
        model_state_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
        torch.save(model.state_dict(), model_state_path)

        # 2. Save complete model (entire model architecture + weights)
        complete_model_path = os.path.join(save_path, f"{model_name}_complete.pth")
        torch.save(model, complete_model_path)

        # 3. Save training checkpoint (everything needed to resume training)
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'epoch': epoch,
            'val_loss': val_loss,
            'history': history,
            'model_class': model.__class__.__name__,
            'model_params': self._get_model_params(model),
            'timestamp': timestamp,
            'save_type': save_type
        }

        if additional_info:
            checkpoint_data.update(additional_info)

        checkpoint_path = os.path.join(save_path, f"{model_name}_checkpoint.pth")
        torch.save(checkpoint_data, checkpoint_path)

        # 4. Save model configuration as JSON
        config = {
            'model_name': model_name,
            'model_class': model.__class__.__name__,
            'model_parameters': self._get_model_params(model),
            'total_parameters': sum(p.numel() for p in model.parameters()),
            'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad),
            'epoch': epoch,
            'validation_loss': val_loss,
            'timestamp': timestamp,
            'save_type': save_type,
            'pytorch_version': torch.__version__,
            'model_size_mb': os.path.getsize(complete_model_path) / (1024 * 1024)
        }

        if additional_info:
            config['additional_info'] = additional_info

        config_path = os.path.join(save_path, f"{model_name}_config.json")
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2, default=str)

        # 5. Save training history as JSON
        if history:
            history_path = os.path.join(save_path, f"{model_name}_history.json")
            with open(history_path, 'w') as f:
                json.dump(history, f, indent=2)

        # 6. Save model summary as text
        summary_path = os.path.join(save_path, f"{model_name}_summary.txt")
        self._save_model_summary(model, summary_path, config)

        print(f"Model saved successfully!")
        print(f"Save location: {save_path}")
        print(f"Files saved:")
        print(f"  - State dict: {model_name}_state_dict.pth")
        print(f"  - Complete model: {model_name}_complete.pth")
        print(f"  - Checkpoint: {model_name}_checkpoint.pth")
        print(f"  - Configuration: {model_name}_config.json")
        print(f"  - History: {model_name}_history.json")
        print(f"  - Summary: {model_name}_summary.txt")

        return save_path

    def save_model_state_only(self, model, model_name="unet_model", save_dir=None):
        """
        Save only the model state dict (lightweight option)
        """
        if save_dir is None:
            save_dir = self.save_dir

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{model_name}_state_{timestamp}.pth"
        filepath = os.path.join(save_dir, filename)

        torch.save(model.state_dict(), filepath)
        print(f"Model state dict saved: {filepath}")
        return filepath

    def save_onnx_model(self, model, input_shape, model_name="unet_model", save_dir=None):
        """
        Save model in ONNX format for deployment

        Args:
            model: PyTorch model
            input_shape: Tuple of input dimensions (batch_size, channels, height, width)
            model_name: Name for the model
            save_dir: Directory to save (optional)
        """
        try:
            if save_dir is None:
                save_dir = self.save_dir

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{model_name}_onnx_{timestamp}.onnx"
            filepath = os.path.join(save_dir, filename)

            # Create dummy input for tracing
            dummy_input = torch.randn(input_shape)

            # Export to ONNX
            torch.onnx.export(
                model,
                dummy_input,
                filepath,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }
            )

            print(f"ONNX model saved: {filepath}")
            return filepath

        except ImportError:
            print("ONNX export requires 'onnx' package. Install with: pip install onnx")
            return None
        except Exception as e:
            print(f"Error saving ONNX model: {e}")
            return None

    def load_model_complete(self, load_path, model_class=None, device="cpu"):
        """
        Load complete model with all training information

        Args:
            load_path: Path to the saved model directory or checkpoint file
            model_class: Model class (if loading state dict)
            device: Device to load model on

        Returns:
            Dictionary containing model, optimizer, scheduler, history, etc.
        """
        if os.path.isdir(load_path):
            # Find checkpoint file in directory
            checkpoint_files = [f for f in os.listdir(load_path) if f.endswith('_checkpoint.pth')]
            if not checkpoint_files:
                raise FileNotFoundError(f"No checkpoint file found in {load_path}")
            checkpoint_path = os.path.join(load_path, checkpoint_files[0])
        else:
            checkpoint_path = load_path

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device)

        print(f"Loading model from: {checkpoint_path}")
        print(f"Model class: {checkpoint.get('model_class', 'Unknown')}")
        print(f"Saved at epoch: {checkpoint.get('epoch', 'Unknown')}")
        print(f"Validation loss: {checkpoint.get('val_loss', 'Unknown')}")

        result = {
            'checkpoint_data': checkpoint,
            'epoch': checkpoint.get('epoch', 0),
            'val_loss': checkpoint.get('val_loss', float('inf')),
            'history': checkpoint.get('history', {}),
            'model_params': checkpoint.get('model_params', {}),
            'timestamp': checkpoint.get('timestamp', 'Unknown')
        }

        # Try to load the complete model first
        if os.path.isdir(load_path):
            complete_model_files = [f for f in os.listdir(load_path) if f.endswith('_complete.pth')]
            if complete_model_files:
                complete_model_path = os.path.join(load_path, complete_model_files[0])
                try:
                    model = torch.load(complete_model_path, map_location=device)
                    result['model'] = model
                    print("Complete model loaded successfully")
                except Exception as e:
                    print(f"Could not load complete model: {e}")

        # If complete model loading failed, try to reconstruct from state dict
        if 'model' not in result and model_class is not None:
            try:
                model_params = checkpoint.get('model_params', {})
                model = model_class(**model_params)
                model.load_state_dict(checkpoint['model_state_dict'])
                model.to(device)
                result['model'] = model
                print("Model reconstructed from state dict")
            except Exception as e:
                print(f"Could not reconstruct model: {e}")

        return result

    def load_model_state_only(self, filepath, model_class, model_params=None, device="cpu"):
        """
        Load only model state dict

        Args:
            filepath: Path to state dict file
            model_class: Class of the model
            model_params: Parameters to initialize the model
            device: Device to load on
        """
        if model_params is None:
            model_params = {}

        model = model_class(**model_params)
        state_dict = torch.load(filepath, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)

        print(f"Model state loaded from: {filepath}")
        return model

    def _get_model_params(self, model):
        """Extract model parameters for reconstruction"""
        params = {}

        # Try to get common parameters
        if hasattr(model, 'in_channels'):
            params['in_channels'] = model.in_channels
        if hasattr(model, 'out_classes'):
            params['out_classes'] = model.out_classes
        if hasattr(model, 'num_classes'):
            params['num_classes'] = model.num_classes

        return params

    def _save_model_summary(self, model, filepath, config):
        """Save model summary as text file"""
        with open(filepath, 'w') as f:
            f.write("=" * 50 + "\n")
            f.write("MODEL SUMMARY\n")
            f.write("=" * 50 + "\n\n")

            f.write(f"Model Name: {config['model_name']}\n")
            f.write(f"Model Class: {config['model_class']}\n")
            f.write(f"Total Parameters: {config['total_parameters']:,}\n")
            f.write(f"Trainable Parameters: {config['trainable_parameters']:,}\n")
            f.write(f"Model Size: {config['model_size_mb']:.2f} MB\n")
            f.write(f"PyTorch Version: {config['pytorch_version']}\n")
            f.write(f"Save Timestamp: {config['timestamp']}\n")
            f.write(f"Epoch: {config['epoch']}\n")
            f.write(f"Validation Loss: {config['validation_loss']}\n\n")

            f.write("MODEL ARCHITECTURE:\n")
            f.write("-" * 30 + "\n")
            f.write(str(model) + "\n\n")

            if config.get('model_parameters'):
                f.write("MODEL PARAMETERS:\n")
                f.write("-" * 30 + "\n")
                for key, value in config['model_parameters'].items():
                    f.write(f"{key}: {value}\n")

# --- Training Integration Functions ---
def save_during_training(model, optimizer, scheduler, history, epoch, val_loss,
                        best_val_loss, saver, model_name="unet_model"):
    """
    Function to integrate with training loop for automatic saving
    """
    # Save checkpoint every 10 epochs
    if epoch % 10 == 0:
        saver.save_model_complete(
            model, optimizer, scheduler, history, epoch, val_loss,
            model_name, save_type="checkpoint"
        )

    # Save best model when validation improves
    if val_loss < best_val_loss:
        saver.save_model_complete(
            model, optimizer, scheduler, history, epoch, val_loss,
            model_name, save_type="best"
        )
        print("New best model saved!")

    return val_loss < best_val_loss

def save_final_model(model, optimizer, scheduler, history, epoch, val_loss,
                    saver, model_name="unet_model", additional_info=None):
    """
    Save final model after training completion
    """
    print("Saving final model...")

    # Add training completion info
    final_info = {
        'training_completed': True,
        'final_epoch': epoch,
        'final_val_loss': val_loss
    }

    if additional_info:
        final_info.update(additional_info)

    save_path = saver.save_model_complete(
        model, optimizer, scheduler, history, epoch, val_loss,
        model_name, save_type="final", additional_info=final_info
    )

    # Also save lightweight state dict
    saver.save_model_state_only(model, model_name)

    # Save ONNX if possible (assuming standard input size)
    try:
        saver.save_onnx_model(model, (1, 3, 256, 256), model_name)
    except:
        print("ONNX export skipped (requires onnx package)")

    return save_path

# --- Quick Save Functions ---
def quick_save_model(model, filename="model.pth"):
    """Quick save for model state dict"""
    torch.save(model.state_dict(), filename)
    print(f"Model saved: {filename}")

def quick_save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
    """Quick save for training checkpoint"""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")

def quick_load_model(model, filename="model.pth", device="cpu"):
    """Quick load for model state dict"""
    model.load_state_dict(torch.load(filename, map_location=device))
    model.to(device)
    print(f"Model loaded: {filename}")
    return model

if __name__ == "__main__":
    quick_save_model(model, filename="model.pth")

Model saved: model.pth
