# In [None]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as A
import os
import cv2
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Select the compute device: CUDA when PyTorch can see a GPU, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm2d and an in-place ReLU.

    Both convolutions use ``padding=1``, so height and width are unchanged;
    only the channel count goes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolution stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Only the first stage sees `in_channels`; the second one maps
        # `out_channels` onto itself.
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        # Kept under the attribute name `conv` (indices 0-5) so that
        # state_dict keys stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` after both conv -> batch norm -> ReLU stages."""
        features = self.conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        refine = DoubleConv(in_channels, out_channels)
        # Pooling comes first so the convolutions run on the reduced map.
        self.down = nn.Sequential(pool, refine)

    def forward(self, x):
        """Return the pooled and convolved feature map."""
        reduced = self.down(x)
        return reduced

class Up(nn.Module):
    """Decoder stage: upsample, merge with a skip connection, then DoubleConv.

    Args:
        in_channels: Channels of the tensor to upsample. The upsampling path
            halves this, and the concatenation with the skip tensor must add
            up to ``in_channels`` again because it feeds
            ``DoubleConv(in_channels, out_channels)``.
        out_channels: Channels produced by the final DoubleConv.
        bilinear: When true, upsample with bilinear interpolation followed by
            a 1x1 convolution; otherwise use a stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        half_channels = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
        else:
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, half_channels, kernel_size=1),
            )

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size and fuse the two.

        Args:
            x1: Tensor to upsample.
            x2: Skip-connection tensor; it is placed first in the channel
                concatenation.
        """
        upsampled = self.up(x1)

        # The skip tensor may be larger than the upsampled one; split the gap
        # as evenly as possible between the two sides of each axis.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes pads for the last dimension first: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Output head: a 1x1 convolution followed by a sigmoid.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the 1x1 convolution of ``x`` (values in [0, 1])."""
        projected = self.conv(x)
        return self.sigmoid(projected)

class UNet(nn.Module):
    """U-Net built from DoubleConv, Down, Up and OutConv.

    The encoder is ``inc`` followed by four ``Down`` stages (64 -> 1024
    channels). The decoder is four ``Up`` stages, each of which merges the
    running decoder output with the encoder feature map of the matching
    level. ``OutConv`` then maps the 64-channel result to ``n_classes``
    channels and applies a sigmoid.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output map.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder/decoder and return the sigmoid-activated output map.

        The return value has already been passed through the sigmoid inside
        ``OutConv``, so it holds values in [0, 1] rather than raw logits.
        """
        # Encoder: keep every level's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: every stage must consume the previous decoder output `x`.
        # Passing the encoder maps (x4, x3, x2) here instead still matches in
        # shape, but silently throws away the results of the deeper stages.
        # NOTE(review): parameter names are unchanged, so existing checkpoints
        # still load; weights trained with a different decoder wiring should
        # be re-validated or retrained.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)
# Load the trained U-Net and predict a binary mask for a single image.

# Fall back to the CPU when CUDA is unavailable instead of hard-coding 'cuda';
# map_location below then also moves the checkpoint onto that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(3, 1).to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/River networks/water_bodies_model.pth', map_location=device))
# Switch to inference mode: BatchNorm must use its stored running statistics
# rather than the statistics of this one-image batch.
model.eval()

# Load image
image_path = '/content/drive/MyDrive/River networks/15fb2909-1c07-4f14-9b1f-d636182020cf.jpeg'
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports an unreadable or missing file by returning None.
    raise FileNotFoundError(f'Could not read image: {image_path}')
image = cv2.resize(image, (256, 256))

# NOTE(review): the network receives raw 0-255 pixels in OpenCV's BGR channel
# order. denormalise() in this file works with ImageNet mean/std, which hints
# that training inputs were normalised - confirm that this preprocessing
# matches what the checkpoint was trained with.

# Convert H x W x C to C x H x W, cast to float and add the batch dimension
image_tensor = torch.from_numpy(image.transpose((2, 0, 1))).float().unsqueeze(0).to(device)

# Make prediction
with torch.no_grad():
    pred = model(image_tensor).cpu()
    # The model output is already sigmoid-activated; threshold it into a
    # boolean mask.
    pred = pred > 0.5

def denormalise(image):
    """Undo a per-channel mean/std normalisation on a channels-last image.

    Args:
        image: NumPy array of shape (H, W, 3).

    Returns:
        float32 NumPy array of shape (H, W, 3) holding ``image * std + mean``
        for the fixed per-channel statistics below.
    """
    stats_shape = (3, 1, 1)
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(stats_shape)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(stats_shape)

    # Channels-first layout lets the (3, 1, 1) statistics broadcast over H and W.
    channels_first = torch.from_numpy(image.transpose(2, 0, 1)).float()
    restored = channels_first.mul(channel_std).add(channel_mean)

    # Back to channels-last for plotting.
    return restored.permute(1, 2, 0).numpy()

def display_batch(image, pred):
    """Plot an input image above the predicted mask(s).

    Args:
        image: NumPy image of shape (H, W, 3). A uint8 image is treated as
            raw pixel data and shown unchanged; any other dtype is passed
            through ``denormalise`` and clipped to [0, 1] for display.
        pred: 4-D tensor of predictions, presumably (N, C, H, W) - it is
            permuted to channels-last and the N samples are laid out side by
            side along the width.
    """
    if image.dtype == np.uint8:
        # Raw 0-255 pixels cannot have been mean/std-normalised; running them
        # through denormalise() would push them far outside the displayable
        # range and produce a saturated picture.
        shown = image
    else:
        # imshow only accepts float RGB in [0, 1]; clip explicitly instead of
        # relying on matplotlib's warning-and-clip fallback.
        shown = np.clip(denormalise(image), 0.0, 1.0)

    # .cpu() makes this safe for predictions that still live on the GPU.
    masks = pred.detach().cpu().permute(0, 2, 3, 1).numpy()
    # Iterating the 4-D array yields one (H, W, C) mask per sample; axis=1
    # joins them horizontally.
    masks = np.concatenate(masks, axis=1)

    fig, ax = plt.subplots(2, 1, figsize=(20, 6))
    ax[0].imshow(shown)
    ax[0].set_title('Images')
    ax[1].imshow(masks, cmap='gray')
    ax[1].set_title('Predictions')
    # Compute the layout only after images and titles exist, so that it
    # accounts for them.
    fig.tight_layout()

# `image` comes from cv2.imread, which stores channels as BGR; matplotlib
# expects RGB, so swap the channel order for display only.
display_batch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), pred)