<a href="https://colab.research.google.com/github/SSrirabai/DeepLearning/blob/main/DeepLab3Plus_Base_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Installs
!pip install segmentation-models-pytorch




### DeepLabV3+ base model


<details>
  <summary>Model Architecture</summary>
  
  - Model definition: DeepLabV3Plus with a ResNet-50 encoder (ImageNet weights), 3 input channels and 2 output classes.
  
  - Training Setup
        - num_epochs = 50
        - early stopping: a decrease in test loss smaller than improvement_tolerance = 0.0001 does not count as an improvement.
        - learning_rate = 0.001
        - patience = 3 # number of consecutive epochs without improvement before early stopping
        - criterion = nn.CrossEntropyLoss()
        - optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  
  - Save the initial state of the model before training
        - Reloading it before each experiment resets the model weights; the optimizer and learning-rate scheduler are re-created on every call to `train_and_evaluate`.
          
  - Evaluation Metrics
        - Dice Coefficient (F1 Score)
        - Jaccard
        - Precision
        - Recall
        - Accuracy
        - Loss
</details>




In [2]:
# imports

import csv
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score
from albumentations import Compose as AlbCompose, Resize as AlbResize, Normalize as AlbNormalize
import matplotlib.pyplot as plt

# Import DeepLabV3+ Model
from segmentation_models_pytorch import DeepLabV3Plus

In [3]:
# Model, device and reproducibility setup

# Seed the global RNGs (PyTorch: weight initialisation and DataLoader shuffling; NumPy)
# so that a Restart-Kernel -> Run-All reproduces comparable results.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Define device and initialize DeepLabV3+ Model:
# ResNet-50 encoder with ImageNet weights, 3-channel (RGB) input and 2 output classes
# (background / foreground, matching the binarised 0/1 masks produced by KvasirDataset).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeepLabV3Plus(encoder_name="resnet50", encoder_weights="imagenet", in_channels=3, classes=2)
model.to(device)

# Function to save the initial state of the model
def save_initial_model_state(net=None, path='initial_deeplab_model.pth'):
    """Save a model's ``state_dict`` so it can be reloaded to reset the model before an experiment.

    Args:
        net: Module whose weights are saved. Defaults to the notebook-level ``model``
            (the only behaviour of the original zero-argument version).
        path: Destination file. Defaults to ``'initial_deeplab_model.pth'``.
    """
    if net is None:
        net = model  # backward compatible: calling with no arguments saves the global model
    torch.save(net.state_dict(), path)
    print("Initial model state saved!")

# Save the initial state of the model before training.
# The checkpoint holds the weights of `model` as they are before any optimisation step,
# so reloading it later resets the model for a new experiment.
save_initial_model_state()

# Training setup: hyperparameters read by the training loop.
num_epochs = 50  # normally 50 epochs; lower it (e.g. to 5) for a quick smoke test
learning_rate = 0.001
improvement_tolerance = 0.0001  # minimum decrease in test loss that counts as an improvement
patience = 3  # epochs without such an improvement before early stopping
criterion = nn.CrossEntropyLoss()

# Per-epoch metric history; the training loop appends one value per epoch to each list.
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
train_jaccards = []
test_jaccards = []
train_precisions = []
test_precisions = []
train_recalls = []
test_recalls = []
train_dices = []
test_dices = []

# Data transformations
def get_transforms(height=256, width=256,
                   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the preprocessing pipeline shared by the train and test splits.

    Args:
        height: Output height in pixels (default 256).
        width: Output width in pixels (default 256).
        mean: Per-channel normalisation mean (default: the standard ImageNet mean,
            consistent with the ImageNet-pretrained encoder).
        std: Per-channel normalisation standard deviation (default: the standard ImageNet std).

    Returns:
        An albumentations ``Compose`` that resizes the image/mask pair and normalises the image.
    """
    return AlbCompose([
        AlbResize(height=height, width=width, interpolation=1),  # 1 corresponds to 'bilinear'
        AlbNormalize(mean=mean, std=std),
    ])

# Dataset Class
class KvasirDataset(Dataset):
    """Image/mask pairs for binary segmentation, listed by a split file.

    Each non-blank line of ``file_list`` is an image stem: the image is read from
    ``images_dir + <stem>.jpg`` and its mask from ``masks_dir + <stem>.png``.

    Args:
        images_dir: Directory holding the ``.jpg`` images.
        masks_dir: Directory holding the ``.png`` masks.
        file_list: Path to a text file with one image stem per line.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries
            (e.g. an albumentations ``Compose``).
    """

    def __init__(self, images_dir, masks_dir, file_list, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        with open(file_list, 'r') as f:
            # Strip surrounding whitespace and skip blank lines (e.g. trailing empty
            # lines), which would otherwise turn into bogus '.jpg' filenames.
            self.image_filenames = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a float32 C x H x W tensor and an int64 H x W tensor of 0/1 labels."""
        stem = self.image_filenames[idx]
        # Context managers close the file handles promptly; .convert() loads the
        # pixel data into a new image before the file is closed.
        with Image.open(os.path.join(self.images_dir, stem + '.jpg')) as img:
            image_np = np.array(img.convert("RGB"))
        with Image.open(os.path.join(self.masks_dir, stem + '.png')) as msk:
            mask_np = np.array(msk.convert("L"))

        # Work on NumPy arrays in both branches so that transform=None also works
        # (previously the PIL images reached the comparison/tensor conversion below).
        if self.transform:
            augmented = self.transform(image=image_np, mask=mask_np)
            image_np, mask_np = augmented['image'], augmented['mask']

        mask_np = (mask_np > 0).astype(np.int64)  # Convert mask to binary (0 or 1)

        # Convert to PyTorch tensors
        image = torch.tensor(image_np, dtype=torch.float32).permute(2, 0, 1)  # H x W x C -> C x H x W
        mask = torch.tensor(mask_np, dtype=torch.long)  # class indices for CrossEntropyLoss

        return image, mask

# Load datasets and define data loaders.
# Both splits read the same image/mask folders with the same preprocessing; only the
# split file differs. The training loader is shuffled, the test loader keeps file order.
train_dataset = KvasirDataset(
    'kvasir-instrument/images/',
    'kvasir-instrument/masks/',
    'kvasir-instrument/train.txt',
    transform=get_transforms(),
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_dataset = KvasirDataset(
    'kvasir-instrument/images/',
    'kvasir-instrument/masks/',
    'kvasir-instrument/test.txt',
    transform=get_transforms(),
)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Metrics functions
def calculate_accuracy(preds, targets):
    """Return the fraction of positions where ``preds`` equals ``targets`` as a 0-dim float tensor."""
    matches = preds.eq(targets).float()
    return matches.sum() / matches.numel()

def calculate_metrics(preds, targets):
    """Calculate Jaccard, Precision, Recall, Accuracy, and Dice Coefficient for binary classification.

    ``preds`` and ``targets`` are array-likes of 0/1 labels with the same number of
    elements (any shape; both are flattened). Label 1 is the positive class.

    All metrics come from a single set of confusion counts computed with NumPy,
    instead of four separate scikit-learn calls over the full label arrays. Any
    ratio whose denominator is zero (e.g. a batch with no positive pixels in
    either array) is reported as 0.0 for every metric; previously only precision
    and recall set ``zero_division=0``.

    Returns:
        Tuple ``(jaccard, precision, recall, accuracy, dice)`` of Python floats.

    Raises:
        ValueError: if ``preds`` and ``targets`` have different numbers of elements.
    """
    preds_flat = np.asarray(preds).ravel()
    targets_flat = np.asarray(targets).ravel()
    if preds_flat.shape != targets_flat.shape:
        raise ValueError(
            f"preds and targets must have the same number of elements, "
            f"got {preds_flat.size} and {targets_flat.size}"
        )

    pred_pos = preds_flat == 1
    true_pos = targets_flat == 1
    tp = np.count_nonzero(pred_pos & true_pos)
    fp = np.count_nonzero(pred_pos & ~true_pos)
    fn = np.count_nonzero(~pred_pos & true_pos)
    correct = np.count_nonzero(preds_flat == targets_flat)

    def _ratio(num, den):
        # zero_division=0 semantics, applied uniformly to every metric
        return num / den if den else 0.0

    jaccard = _ratio(tp, tp + fp + fn)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    accuracy = _ratio(correct, preds_flat.size)
    # Dice Coefficient equals the F1 score in the binary case
    dice = _ratio(2 * tp, 2 * tp + fp + fn)

    return jaccard, precision, recall, accuracy, dice
# Training function with early stopping using improvement tolerance
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when the list is empty (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float('nan')


def _train_one_epoch(model, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns ``(loss, accuracy, jaccard, precision, recall, dice)``, each the mean of the
    per-batch values (every batch weighs the same, including a shorter final batch).
    """
    model.train()
    losses, accuracies = [], []
    jaccards, precisions, recalls, dices = [], [], [], []

    for images, masks in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        preds = torch.argmax(outputs, dim=1)
        losses.append(loss.item())
        accuracies.append(calculate_accuracy(preds, masks).item())

        # Accuracy is tracked above, so calculate_metrics' 4th value is discarded here.
        jaccard, precision, recall, _, dice = calculate_metrics(preds.cpu().numpy(), masks.cpu().numpy())
        jaccards.append(jaccard)
        precisions.append(precision)
        recalls.append(recall)
        dices.append(dice)

    return (_mean(losses), _mean(accuracies), _mean(jaccards),
            _mean(precisions), _mean(recalls), _mean(dices))


def _evaluate(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` in eval mode without gradient tracking.

    Returns ``(loss, accuracy, jaccard, precision, recall, dice)``, each averaged over batches.
    """
    model.eval()
    losses, accuracies = [], []
    jaccards, precisions, recalls, dices = [], [], [], []

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            losses.append(criterion(outputs, masks).item())

            preds = torch.argmax(outputs, dim=1)
            jaccard, precision, recall, accuracy, dice = calculate_metrics(preds.cpu().numpy(), masks.cpu().numpy())
            jaccards.append(jaccard)
            precisions.append(precision)
            recalls.append(recall)
            accuracies.append(accuracy)
            dices.append(dice)

    return (_mean(losses), _mean(accuracies), _mean(jaccards),
            _mean(precisions), _mean(recalls), _mean(dices))


def train_and_evaluate(model, train_loader, test_loader):
    """Train ``model`` with Adam + StepLR and early stopping on the test loss.

    Reads the module-level settings ``num_epochs``, ``learning_rate``,
    ``improvement_tolerance``, ``patience``, ``criterion`` and ``device``.
    After every epoch the train and test metrics are appended to the module-level
    history lists and printed. Whenever the test loss improves by more than
    ``improvement_tolerance``, a checkpoint (model + optimizer state and metrics) is
    written to ``'deeplab_best_model.pth'``. Training stops after ``patience``
    consecutive epochs without such an improvement.

    Note: on return ``model`` holds the weights of the last epoch that ran, which are
    not necessarily the best ones; those are only in the checkpoint file.

    Returns:
        Tuple of the 12 module-level history lists:
        ``(train_losses, test_losses, train_accuracies, test_accuracies,
        train_jaccards, test_jaccards, train_precisions, test_precisions,
        train_recalls, test_recalls, train_dices, test_dices)``.
    """
    # Reset the shared history so that re-running an experiment (e.g. after reloading
    # the initial weights) does not append to the previous run's curves. The optimizer
    # and scheduler below are also fresh per call, so each call is a new run.
    # Clear in place so other references to these lists stay valid.
    for history in (train_losses, test_losses, train_accuracies, test_accuracies,
                    train_jaccards, test_jaccards, train_precisions, test_precisions,
                    train_recalls, test_recalls, train_dices, test_dices):
        history.clear()

    best_test_loss = float('inf')
    early_stop_counter = 0

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # LR x0.1 every 7 epochs

    for epoch in range(num_epochs):
        (avg_train_loss, avg_train_acc, avg_train_jaccard,
         avg_train_precision, avg_train_recall, avg_train_dice) = _train_one_epoch(model, train_loader, optimizer, epoch)

        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)
        train_jaccards.append(avg_train_jaccard)
        train_precisions.append(avg_train_precision)
        train_recalls.append(avg_train_recall)
        train_dices.append(avg_train_dice)

        (avg_test_loss, avg_test_acc, avg_jaccard,
         avg_precision, avg_recall, avg_dice) = _evaluate(model, test_loader)

        test_losses.append(avg_test_loss)
        test_accuracies.append(avg_test_acc)
        test_jaccards.append(avg_jaccard)
        test_precisions.append(avg_precision)
        test_recalls.append(avg_recall)
        test_dices.append(avg_dice)

        # Print metrics after each epoch
        print(f"Epoch [{epoch + 1}/{num_epochs}]:")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_acc:.4f}, Train Jaccard: {avg_train_jaccard:.4f}, Train Precision: {avg_train_precision:.4f}, Train Recall: {avg_train_recall:.4f}, Train Dice: {avg_train_dice:.4f}")
        print(f"  Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_acc:.4f}, Jaccard: {avg_jaccard:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, Dice: {avg_dice:.4f}")

        # Early stopping: only a decrease larger than improvement_tolerance counts as progress.
        if avg_test_loss < best_test_loss - improvement_tolerance:
            best_test_loss = avg_test_loss
            early_stop_counter = 0
            torch.save({'model_state_dict': model.state_dict(),
                         'optimizer_state_dict': optimizer.state_dict(),
                         'epoch': epoch + 1,
                         'test_loss': avg_test_loss,
                         'test_accuracy': avg_test_acc,
                         'jaccard': avg_jaccard,
                         'precision': avg_precision,
                         'recall': avg_recall,
                         'dice': avg_dice}, 'deeplab_best_model.pth')
            print("Best model saved!")
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1} due to lack of improvement.")
            break
        scheduler.step()

    # Return all metrics for plotting
    return (train_losses, test_losses, train_accuracies, test_accuracies,
            train_jaccards, test_jaccards, train_precisions, test_precisions,
            train_recalls, test_recalls, train_dices, test_dices)


# Plot Loss and Accuracy.
def plot_loss_and_accuracy(train_losses, test_losses, train_accuracies, test_accuracies):
    """Show train/test loss (left panel) and accuracy (right panel) against the 1-based epoch number."""
    epoch_axis = range(1, len(train_losses) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: losses
    loss_ax.plot(epoch_axis, train_losses, label='Training Loss')
    loss_ax.plot(epoch_axis, test_losses, label='Testing Loss')
    loss_ax.set_title('Training and Testing Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Right panel: accuracies
    acc_ax.plot(epoch_axis, train_accuracies, label='Training Accuracy')
    acc_ax.plot(epoch_axis, test_accuracies, label='Testing Accuracy')
    acc_ax.set_title('Training and Testing Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()


# Plotting function update to include Dice Coefficient
def plot_combined_metrics(
    train_jaccard, test_jaccard,
    train_precision, test_precision,
    train_recall, test_recall,
    train_dice, test_dice,
    title='Combined Metrics Plot'
):
    """Plot train/test Jaccard, precision, recall and Dice curves in a 2x2 grid.

    Each metric argument is a per-epoch list. The x-axis is the 1-based epoch
    number, consistent with ``plot_loss_and_accuracy``; previously the curves were
    plotted against the list index, so the first epoch appeared at x=0 on an
    axis labelled "Epochs".

    Args:
        train_jaccard, test_jaccard: Per-epoch Jaccard index.
        train_precision, test_precision: Per-epoch precision.
        train_recall, test_recall: Per-epoch recall.
        train_dice, test_dice: Per-epoch Dice coefficient.
        title: Figure super-title.
    """
    # (panel title, y-label, (train label, colour, values), (test label, colour, values))
    panels = [
        ('Jaccard Index', 'Jaccard Score',
         ('Train Jaccard', 'blue', train_jaccard), ('Test Jaccard', 'orange', test_jaccard)),
        ('Precision', 'Precision Score',
         ('Train Precision', 'green', train_precision), ('Test Precision', 'red', test_precision)),
        ('Recall', 'Recall Score',
         ('Train Recall', 'purple', train_recall), ('Test Recall', 'brown', test_recall)),
        ('Dice Coefficient', 'Dice Score',
         ('Train Dice', 'pink', train_dice), ('Test Dice', 'gray', test_dice)),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    for ax, (panel_title, ylabel, *series) in zip(axes.flat, panels):
        for label, color, values in series:
            ax.plot(range(1, len(values) + 1), values, label=label, color=color)
        ax.set_title(panel_title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


Initial model state saved!


### Train and evaluate the model


In [None]:
# Run training and evaluation. train_and_evaluate saves the best checkpoint to
# 'deeplab_best_model.pth' and returns the per-epoch histories (train/test loss,
# accuracy, Jaccard, precision, recall, Dice) that are saved and plotted below.
results = train_and_evaluate(model, train_loader, test_loader)


Epoch 1/50: 100%|██████████| 59/59 [02:03<00:00,  2.10s/it]


Epoch [1/50]:
  Train Loss: 0.1658, Train Accuracy: 0.9347, Train Jaccard: 0.6198, Train Precision: 0.7683, Train Recall: 0.7616, Train Dice: 0.7486
  Test Loss: 0.0890, Test Accuracy: 0.9666, Jaccard: 0.7098, Precision: 0.7810, Recall: 0.8843, Dice: 0.8266
Best model saved!


Epoch 2/50:  92%|█████████▏| 54/59 [01:37<00:09,  1.86s/it]

### Save the results to a CSV file and plot them

In [None]:
# Unpack the histories returned by train_and_evaluate *before* using them, so this cell
# writes exactly what the training run returned instead of relying on module-level
# lists that happen to be filled as a side effect of training.
(train_losses, test_losses, train_accuracies, test_accuracies,
 train_jaccards, test_jaccards, train_precisions, test_precisions,
 train_recalls, test_recalls, train_dices, test_dices) = results

# Save the metrics to a CSV file (one row per completed epoch)
fieldnames = ["Epoch", "Train Loss", "Test Loss", "Train Accuracy",
              "Test Accuracy", "Train Jaccard", "Test Jaccard",
              "Train Precision", "Test Precision", "Train Recall",
              "Test Recall", "Train Dice", "Test Dice"]
# Same order as fieldnames[1:]
series = [train_losses, test_losses, train_accuracies, test_accuracies,
          train_jaccards, test_jaccards, train_precisions, test_precisions,
          train_recalls, test_recalls, train_dices, test_dices]
assert len(series) == len(fieldnames) - 1, "every metric column needs a history list"

with open('training_results.csv', mode='w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    for epoch, values in enumerate(zip(*series), start=1):
        writer.writerow(dict(zip(fieldnames, (epoch, *values))))

# Plotting
plot_loss_and_accuracy(train_losses, test_losses, train_accuracies, test_accuracies)
plot_combined_metrics(train_jaccards, test_jaccards,
                      train_precisions, test_precisions,
                      train_recalls, test_recalls,
                      train_dices, test_dices,
                      title='Combined Metrics Plot')