In [None]:
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from torch.utils.data.dataset import Dataset 
from torch.utils.data import DataLoader
from torchvision import transforms,models
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Folder whose files TestDataset lists and serves to the models below.
test_dir = '../Data/test/flooded/'

In [None]:
# Pre-processing for the classifier: 224x224 RGB tensor, normalised per channel.
# NOTE(review): presumably this must match the pre-processing used when the
# classifier was trained -- confirm against the training notebook.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        # Crop to the same size the image was just resized to.
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        # Commonly used ImageNet per-channel mean / std.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class TestDataset(Dataset):
    """Unlabelled image dataset backed by the entries of a single folder.

    Every name returned by ``os.listdir(folder_path)`` is treated as an image
    file; no filtering or sorting is applied, so item order follows the
    directory listing taken at construction time.
    """

    def __init__(self, folder_path, transform=None):
        """Record the folder, the optional transform and the folder's file names.

        Args:
            folder_path: Directory containing the images.
            transform: Optional callable applied to each loaded RGB PIL image.
        """
        self.image_files = os.listdir(folder_path)
        self.transform = transform
        self.folder_path = folder_path

    def __len__(self):
        """Return how many directory entries were found at construction time."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load the ``index``-th file as RGB and return it, transformed if configured."""
        file_name = self.image_files[index]
        full_path = os.path.join(self.folder_path, file_name)
        rgb_image = Image.open(full_path).convert('RGB')
        # A falsy transform (e.g. None) means "return the PIL image untouched".
        return self.transform(rgb_image) if self.transform else rgb_image

In [None]:
# Dataset over the test folder, pre-processed with the classifier transforms.
test_dataset = TestDataset(folder_path=test_dir, transform=transforms_test)

In [None]:
# One image per batch and no shuffling: later cells index `results` and the
# dataset's file list by the position of each batch in this loader.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Sanity check: display the path of the image the dataset serves at index 0.
# The dataset's own file list is used (instead of listing the directory again)
# so the path shown is exactly the one TestDataset.__getitem__(0) opens.
os.path.join(test_dataset.folder_path, test_dataset.image_files[0])

In [None]:
# Rebuild the classifier architecture: a RegNetX-1.6GF whose final fully
# connected layer is replaced by a 2-output linear head, then restore the
# trained weights from disk.
classification_model = models.regnet_x_1_6gf()
num_features = classification_model.fc.in_features
classification_model.fc = nn.Linear(num_features, 2)

# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
classification_model.load_state_dict(
    torch.load("classification_model.pt", map_location=device)
)
classification_model.to(device)

In [None]:
# Classify every test image in loader order and collect the predicted class
# index of each one into a flat NumPy array.
classification_model.eval()
results = []
with torch.no_grad():
    for inputs in tqdm(test_dataloader, total=len(test_dataloader)):
        inputs = inputs.to(device)
        outputs = classification_model(inputs)
        # Index of the highest-scoring class along the class dimension.
        preds = torch.max(outputs, dim=1).indices
        results.append(preds)

# Class index meaning: 0 = flooded, 1 = non flooded.
results = torch.cat(results, dim=0).to('cpu').numpy().flatten()

In [None]:
# preds.txt receives one line per image, in loader order: '1' when the
# classifier predicted class 0 (flooded) and '0' for any other prediction.
with open('preds.txt', 'w') as f:
    for value in results:
        f.write('1\n' if value == 0 else '0\n')

In [87]:
# Second pass over the test images: every image the classifier labelled as
# flooded (results[index] == 0) is segmented and saved as a new image in which
# the two segments are painted in different colours.

# Encoder architecture and initial encoder weights used to build the U-Net.
ENCODER = 'resnet18'
WEIGHTS = 'imagenet'

# Destination folder for the coloured segmentation images.
output_dir = 'segmented_images'
os.makedirs(output_dir, exist_ok=True)

# Pre-processing for the segmentation network: 256x256 tensor, no normalisation.
segmentation_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# WARNING: this replaces the transform on the dataset object that
# `test_dataloader` already wraps. Re-running the classification cell after
# this one would feed the classifier with these transforms instead of
# `transforms_test`; restore the transform (or restart and run all) first.
test_dataset.transform = segmentation_transforms

class SegmentationModel(nn.Module):
    """U-Net wrapper that returns single-channel logits and, optionally, losses.

    The encoder name and the initial encoder weights are read from the
    module-level ``ENCODER`` and ``WEIGHTS`` constants.
    """

    def __init__(self):
        super().__init__()

        # activation=None: no final activation is applied, so the network
        # outputs raw logits.
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None
        )

    def forward(self, images, masks=None):
        """Run the U-Net and, when masks are supplied, also compute both losses.

        Args:
            images: Input batch passed straight to the U-Net.
            masks: Optional ground-truth masks. When given, a binary Dice loss
                and a BCE-with-logits loss are computed against the logits.

        Returns:
            ``logits`` alone when ``masks`` is None, otherwise the tuple
            ``(logits, dice_loss, bce_loss)``.
        """
        logits = self.arc(images)

        # Identity test on purpose: `masks != None` only works for tensors via
        # a comparison fallback and is ambiguous for array-like masks.
        if masks is not None:
            loss1 = DiceLoss(mode='binary')(logits, masks)
            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
            return logits, loss1, loss2
        return logits
    
def apply_mask(image, mask, segment_1_color=(0, 0, 255), segment_2_color=(1, 1, 1)):
    """Render a mask as a flat two-colour image.

    ``image`` is used only for the shape of its array form; none of its pixel
    values appear in the result.

    Args:
        image: Image (or array-like) whose array shape defines the output shape.
        mask: Mask (or array-like); positions equal to 1 form the first segment.
        segment_1_color: Colour written where the mask equals 1.
        segment_2_color: Colour written everywhere else.

    Returns:
        A PIL image built from the uint8 colour array.
    """
    target_shape = np.array(image).shape
    in_first_segment = np.array(mask) == 1

    # Paint both segments onto a blank uint8 canvas of the image's shape.
    canvas = np.zeros(target_shape, dtype=np.uint8)
    canvas[in_first_segment] = segment_1_color
    canvas[~in_first_segment] = segment_2_color

    return Image.fromarray(canvas)
    
# Build the segmentation network and restore its trained weights.
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
segmentation_model = SegmentationModel()
segmentation_model.load_state_dict(
    torch.load("segmentation_model.pt", map_location=device)
)
segmentation_model.to(device)

# Segment every image the classifier labelled as flooded and save the coloured
# mask, resized to the original image's size, under the same file name.
segmentation_model.eval()
with torch.no_grad():
    for index, inputs in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
        # Class 1 means "non flooded": nothing to segment for this image.
        if results[index] == 1:
            continue
        inputs = inputs.to(device)
        outputs = segmentation_model(inputs)
        # The network is built with activation=None and therefore emits raw
        # logits; convert them to probabilities before thresholding at 0.5.
        predicted = torch.sigmoid(outputs) > 0.5
        mask = predicted.squeeze().cpu().numpy()

        # Convert the inputs tensor to a PIL Image (apply_mask uses its shape).
        image = to_pil_image(inputs.squeeze().cpu())

        masked_image = apply_mask(image, mask)

        # The loader is unshuffled with batch_size=1, so batch `index` is the
        # dataset's `index`-th file; use the dataset's own list instead of
        # re-listing the directory on every iteration.
        file_name = test_dataset.image_files[index]

        # Only the size is needed; the context manager closes the file handle.
        with Image.open(os.path.join(test_dataset.folder_path, file_name)) as original_image:
            original_size = original_image.size

        # Nearest-neighbour keeps exactly two colours; a smoothing filter would
        # blend them along the segment borders.
        masked_image = masked_image.resize(original_size, resample=Image.NEAREST)

        masked_image.save(os.path.join(output_dir, file_name))

100%|██████████| 70/70 [00:10<00:00,  6.89it/s]
