# In [None]:  (Jupyter cell marker from notebook export)
from flask import Flask, render_template, request, send_from_directory
from flask_ngrok import run_with_ngrok
import os
from PIL import Image
import tifffile as tiff
from torchvision import transforms
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

# Define the model architecture
class WaterSegmentationModel(torch.nn.Module):
    """U-Net water segmenter for multispectral imagery.

    A learned 1x1 convolution first projects the ``in_channels`` input bands
    down to 3 channels, which is what the U-Net below is built for
    (``in_channels=3``). The U-Net then emits a single-channel map of raw
    logits (``activation=None``); callers apply the sigmoid themselves.

    Args:
        in_channels: number of input bands (12 for the checkpoint this app
            loads).
        encoder_name: segmentation_models_pytorch encoder identifier.
        encoder_weights: pretrained weights used to initialise the encoder,
            or ``None`` to skip them (e.g. when a full checkpoint will be
            loaded over the model anyway).
    """

    def __init__(self, in_channels=12, encoder_name="resnet34", encoder_weights="imagenet"):
        super().__init__()
        # Reduce the input bands to the 3 channels the U-Net encoder expects.
        self.pre_conv = nn.Conv2d(in_channels, 3, kernel_size=1)
        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
            activation=None,  # return logits; thresholding happens in the caller
        )

    def forward(self, x):
        """Run the band-reducing 1x1 conv, then the U-Net; returns raw logits."""
        return self.unet(self.pre_conv(x))

# Initialize Flask app
app = Flask(__name__)
# flask_ngrok hooks the app so that calling app.run() also opens an ngrok
# tunnel. This must happen after `app` exists and before app.run().
run_with_ngrok(app)  # Start ngrok when app is run

# Load the trained model
# NOTE(review): this path is resolved against the current working directory,
# not app.root_path. Run the notebook/script from the directory that holds it.
model_path = 'best_model.pth'
# Building the model first initialises the encoder with
# encoder_weights="imagenet". Every parameter is then replaced by the
# load_state_dict calls below.
model = WaterSegmentationModel()
# The checkpoint is a dict holding two separate state dicts: one for the 1x1
# band-reducing conv ('pre_conv_state_dict') and one for the U-Net
# ('model_state_dict'). map_location places every tensor on the CPU, so no
# GPU is needed to serve requests.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.pre_conv.load_state_dict(checkpoint['pre_conv_state_dict'])
model.unet.load_state_dict(checkpoint['model_state_dict'])
# Evaluation mode for inference (e.g. normalization layers use running stats).
model.eval()

# Define the preprocessing transformation
# NOTE(review): `preprocess` is not referenced by the request handlers in this
# cell. The uploaded TIFF array is fed to the model without resizing. Kept in
# case other cells rely on it; confirm before removing.
preprocess = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

@app.route('/', methods=['GET'])
def home():
    """Render ``index.html`` with no result image (the initial page view)."""
    template_name = 'index.html'
    page = render_template(template_name)
    return page

@app.route('/images/<path:filename>')
def serve_image(filename):
    """Return ``filename`` from the ``images`` directory next to the app.

    Flask's ``send_from_directory`` is documented to refuse paths that would
    resolve outside the given directory.
    """
    images_root = os.path.join(app.root_path, 'images')
    response = send_from_directory(images_root, filename)
    return response

def _safe_upload_name(filename):
    """Return only the final path component of a client-supplied filename.

    Browsers can send arbitrary strings as the upload filename; keeping only
    the basename stops names such as ``../../app.py`` from escaping the upload
    directory. Returns ``''`` when nothing usable remains.
    """
    # Normalise Windows separators so basename also strips them on POSIX.
    name = os.path.basename((filename or '').replace('\\', '/'))
    return '' if name in ('', '.', '..') else name


def _to_nchw_tensor(array, channels=12):
    """Convert an HWC or CHW image array to a float32 (1, C, H, W) tensor.

    Channels-last is checked first (the layout the app originally required).
    Returns ``None`` if the array is not 3-D with ``channels`` bands on its
    last or first axis.
    """
    if array.ndim != 3:
        return None
    if array.shape[-1] == channels:
        chw = np.moveaxis(array, -1, 0)
    elif array.shape[0] == channels:
        chw = array
    else:
        return None
    # Convert through NumPy: older torch releases cannot build tensors directly
    # from some unsigned dtypes (e.g. uint16).
    chw = np.ascontiguousarray(chw, dtype=np.float32)
    return torch.from_numpy(chw).unsqueeze(0)


def _pad_to_multiple(x, multiple=32):
    """Pad a (N, C, H, W) tensor on the bottom/right so H and W divide ``multiple``.

    The smp U-Net (default encoder depth 5) needs spatial sizes divisible by 32.
    Returns the padded tensor and the original ``(H, W)`` for cropping back.
    """
    height, width = x.shape[-2:]
    pad_h = -height % multiple
    pad_w = -width % multiple
    if pad_h or pad_w:
        # pad order for the last two dims is (left, right, top, bottom).
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
    return x, (height, width)


def _mask_to_rgb(mask):
    """Render a 2-D boolean mask as an (H, W, 3) uint8 image: white=True, black=False."""
    gray = mask.astype(np.uint8) * 255
    return np.repeat(gray[:, :, None], 3, axis=2)


@app.route('/', methods=['POST'])
def predict():
    """Segment water in an uploaded 12-band TIFF and re-render the page.

    The upload is saved under ``images/TifImage``. The predicted mask is
    written as a PNG under ``images/output_``. Both directories sit under
    ``app.root_path``, the same root that ``serve_image`` reads from.

    Returns:
        The rendered ``index.html`` with ``result_image`` set, or a plain-text
        message with HTTP 400 for a missing filename or an image that does
        not have 12 bands.
    """
    imagefile = request.files['imagefile']
    filename = _safe_upload_name(imagefile.filename)
    if not filename:
        return "No file was uploaded.", 400

    images_dir = os.path.join(app.root_path, 'images')
    upload_dir = os.path.join(images_dir, 'TifImage')
    output_dir = os.path.join(images_dir, 'output_')
    os.makedirs(upload_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    image_path = os.path.join(upload_dir, filename)
    imagefile.save(image_path)

    # Load the .tif image
    tif_image = tiff.imread(image_path)
    print("Original image shape:", tif_image.shape)

    img_tensor = _to_nchw_tensor(tif_image)
    if img_tensor is None:
        return "The image doesn't have 12 channels!", 400

    # Perform inference on a size-compatible input, then crop back.
    padded, (height, width) = _pad_to_multiple(img_tensor)
    with torch.no_grad():
        logits = model(padded)[..., :height, :width]
    # Threshold the sigmoid probabilities at 0.5; keep batch 0, channel 0.
    mask = (torch.sigmoid(logits) > 0.5)[0, 0].cpu().numpy()

    # splitext handles .tif/.tiff/.TIF alike, unlike str.replace('.tif', ...).
    output_name = os.path.splitext(filename)[0] + '_highlighted.png'
    Image.fromarray(_mask_to_rgb(mask)).save(os.path.join(output_dir, output_name))

    # Path relative to the images/ directory; always '/' because it goes into a URL.
    result_image = 'output_/' + output_name
    return render_template('index.html', result_image=result_image)

def main():
    """Start the Flask server. run_with_ngrok(app) also opens an ngrok tunnel on run."""
    app.run()


if __name__ == '__main__':
    main()