In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, Cropping2D, ZeroPadding2D

import numpy as np
import matplotlib.pyplot as plt

import os
import tifffile
from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models.segmentation as models


#The dataset


In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


##Image Preprocessing

In [3]:
from torchvision import transforms
import torch

def preprocess_image(image, label):
    # Normalize image by converting it to the range [0, 1]
    image = image / 255.0  # the images are 12-channel, so keep all channels

    # Resize the image and label
    transform_image = transforms.Compose([
        transforms.ToTensor(),  # Convert to Tensor
        transforms.Resize((128, 128))  # Resize to 128x128 if needed
    ])

    transform_label = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((128, 128))  # Resize to 128x128 if needed
    ])

    # Apply transformations
    image = transform_image(image).float()
    label = transform_label(label).long().squeeze(0)  # Convert label to long tensor

    return image, label



In [4]:
class SegmentationDataset(Dataset):
    def __init__(self, image_files, label_files, transform=None):
        self.image_files = image_files
        self.label_files = label_files
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the image and label
        image = tifffile.imread(self.image_files[idx])  # Load 12-channel image
        label = Image.open(self.label_files[idx])  # Load label as grayscale
        label = np.array(label)  # Convert label to numpy array

        if self.transform:
            image, label = self.transform(image, label)

        return image, label

##Loading the data and creating dataLoader

In [5]:
# Define paths
images_path = '/content/drive/MyDrive/data/images'
labels_path = '/content/drive/MyDrive/data/labels'

# Load all image and label files
image_files = [os.path.join(images_path, f) for f in os.listdir(images_path) if f.endswith('.tif')]
label_files = [os.path.join(labels_path, f) for f in os.listdir(labels_path) if f.endswith('.png')]

# Sort files to ensure they match
image_files.sort()
label_files.sort()

# Create the dataset and DataLoader
train_dataset = SegmentationDataset(image_files, label_files, transform=preprocess_image)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)




#Loading DeepLabV3

In [6]:
import torch.nn as nn

# Loading the pretrained DeepLabV3+ model
deeplab = models.deeplabv3_resnet101(pretrained=True)

# Modifying the first convolutional layer to accept 12 channels instead of 3
deeplab.backbone.conv1 = nn.Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Set the model to evaluation mode
deeplab.eval()

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
100%|██████████| 233M/233M [00:02<00:00, 117MB/s]


DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): S

In [21]:
# Assuming image_files and label_files contain validation data as well
val_dataset = SegmentationDataset(image_files, label_files, transform=preprocess_image)

# Create a DataLoader for validation (batch size can be the same as for training or different)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

#Accuracy

In [22]:
def compute_pixel_accuracy(predictions, labels):
    """
    Compute pixel accuracy (the ratio of correctly predicted pixels).

    Args:
        predictions: Predicted segmentation masks (batch_size, H, W)
        labels: Ground truth segmentation masks (batch_size, H, W)

    Returns:
        accuracy: Pixel accuracy
    """
    correct = (predictions == labels).sum().float()
    total = labels.numel()

    accuracy = correct / total
    return accuracy.item()


In [23]:
num_classes = 2

total_accuracy = 0
num_batches = 0

# Loop through the DataLoader for evaluation
for images, labels in val_loader:  # Replace val_loader with your validation/test DataLoader
    with torch.no_grad():
        images = images.float()  # Ensure images are in float32 format

        # Perform the forward pass to get the predictions
        outputs = deeplab(images)['out']
        predictions = torch.argmax(outputs, dim=1)  # Predicted segmentation map

        # Compute metrics for each batch
        accuracy = compute_pixel_accuracy(predictions, labels)

        # Aggregate results
        total_accuracy += accuracy
        num_batches += 1

# Average results across batches
mean_accuracy = total_accuracy / num_batches

# Display the results
print(f'Total Accuracy: {total_accuracy}')
print(f'Mean Pixel Accuracy: {mean_accuracy}')


Total Accuracy: 52.68486022949219
Mean Pixel Accuracy: 0.6842189640193791


#Implementaion on flask

In [10]:
!pip install flask-ngrok
!pip install pyngrok

Collecting flask-ngrok
  Downloading flask_ngrok-0.0.25-py3-none-any.whl.metadata (1.8 kB)
Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)
Installing collected packages: flask-ngrok
Successfully installed flask-ngrok-0.0.25
Collecting pyngrok
  Downloading pyngrok-7.2.0-py3-none-any.whl.metadata (7.4 kB)
Downloading pyngrok-7.2.0-py3-none-any.whl (22 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.0


In [11]:
!ngrok authtoken 2mDwKIRaihphgFobGExw70mYhTZ_25rfWjqRgWhLv91UxzoMw # remove the <> from your authtoken

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [12]:
!lsof -i :5002
!kill -9 <PID>


/bin/bash: -c: line 1: syntax error near unexpected token `newline'
/bin/bash: -c: line 1: `kill -9 <PID>'


In [13]:
from flask import Flask, request, jsonify
from pyngrok import ngrok
import tensorflow as tf
from PIL import Image
import numpy as np
import threading

# Initialize Flask app
app = Flask(__name__)

# Prediction route
@app.route('/predict', methods=['POST'])

def predict():
    if 'image' not in request.files:
        return jsonify({"error": "No image found"}), 400

    # Get the uploaded image
    file = request.files['image']
    image = Image.open(file).resize((128, 128))
    image = np.array(image)
    image = np.expand_dims(image, axis=0)

    # Make prediction (replace with your model prediction code)
    prediction = DeepLabV3.predict(image)
    result = np.argmax(prediction, axis=1)

    return jsonify({"prediction": int(result[0])})

# Function to run Flask app
def run_app():
    app.run(port=6060)

# Start Flask app in a separate thread
threading.Thread(target=run_app).start()



public_url = ngrok.connect(6060)
print(f"Public URL: {public_url}")



 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:6060
INFO:werkzeug:[33mPress CTRL+C to quit[0m


Public URL: NgrokTunnel: "https://8fbe-34-32-167-184.ngrok-free.app" -> "http://localhost:6060"
