In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
import random
import numpy as np
from tqdm import tqdm
import argparse
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from LF_library import *
from LF_deep_utils import *
from dataset import *
from LF_utils import *
from sklearn.metrics import precision_score, recall_score, f1_score

np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [44]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torch

class MultiImageTileDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.tiles = {}

        # Organize images by tile_number
        for filename in os.listdir(image_dir):
            if filename.endswith(".tif") and "_tile_" in filename and "_rgb_" in filename:
                # Extract the tile number and image index
                parts = filename.split("_")
                image_index = int(parts[2])  # Extracts the number after "rgb_"
                tile_number = int(parts[-1].split(".tif")[0])  # Extracts the tile number after "tile_"

                # Initialize a list for each tile if not already done
                if tile_number not in self.tiles:
                    self.tiles[tile_number] = [None] * 6  # Placeholder for 6 images

                # Store the filename in the correct index slot
                self.tiles[tile_number][image_index] = filename

        # Convert the dictionary to a sorted list of (tile_number, image_list)
        self.tiles = sorted(self.tiles.items())

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the item in the dataset.

        Returns:
            images (list of PIL Images or transformed Tensors): List of 6 images for the tile.
            tile_number (int): The tile number of the images.
        """
        tile_number, image_filenames = self.tiles[idx]
        images = []

        # Load all 6 images for this tile
        for filename in image_filenames:
            if filename is not None:  # Check if the file exists in the list
                image_path = os.path.join(self.image_dir, filename)
                image = Image.open(image_path).convert("RGB")  # Convert to RGB if needed

                if self.transform:
                    image = self.transform(image)

                images.append(image)

        # Stack images into a tensor if they are transformed (Tensor format)
#         if isinstance(images[0], torch.Tensor):
#             images = torch.stack(images)  # Shape: (6, C, H, W) where 6 is the number of images per tile

        return images, tile_number, image_filenames


In [50]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Define image transformations if needed
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images if necessary
    transforms.ToTensor()           # Convert images to tensors
])

# Create the dataset
image_dir = "/home/macula/SMATousi/Gullies/ground_truth/Labeling_Tool/MO+IA_test_data_numbered/"
dataset = MultiImageTileDataset(image_dir=image_dir, transform=transform)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Iterate through the DataLoader
# for images, tile_numbers in dataloader:
#     print("Batch of images:", images.shape)
#     print("Tile numbers:", tile_numbers)

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet_extractor = ResNetFeatureExtractor()
mlp_classifier = MLPClassifier(input_size=6*2048, hidden_size=512, output_size=1)

model = Gully_Classifier(input_size=6*2048, hidden_size=512, output_size=1).to(device)

state_dict = torch.load('../weak-supervision/trained_models/model_epoch_100.pth')
# state_dict_new = torch.load('./artifacts/new_loss/model_epoch_600.pth')

new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

model.eval()

Gully_Classifier(
  (feature_extractor): ResNetFeatureExtractor(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [76]:
all_preds = {}


for images, tile_number, file_names in tqdm(dataloader):
    
    list_of_images = [image.to(device) for image in images]
    
    deep_learning_output = model(list_of_images)
    preds = torch.round(deep_learning_output.squeeze()).detach().cpu().numpy()
    
    all_preds[str(tile_number.numpy()[0])] = float(str(preds))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 900/900 [00:41<00:00, 21.73it/s]


In [77]:
import json 

with open('./Labeling_Results/test_dataset_300/DL_preds.json', 'w') as json_file:
    json.dump(all_preds, json_file, indent=4)

In [78]:
with open('./Labeling_Results/test_dataset_300/DL_preds.json', 'r') as json_file:
    data_dict = json.load(json_file)

In [96]:
import json
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Load ground truth and predictions
with open('./Labeling_Results/test_dataset_300/GT_L6_Tol.json') as f:
    ground_truth_data = json.load(f)
    
with open('./Labeling_Results/test_dataset_300/DL_preds.json') as f:
    predictions_data = json.load(f)

# Prepare ground truth and prediction lists based on the common tile numbers
tile_numbers = ground_truth_data.keys()
y_true = []
y_pred = []

for tile in tile_numbers:
    if tile in predictions_data:
        # Convert the labels to integer for calculation purposes
        y_true.append(int(ground_truth_data[tile]["label"]))
        y_pred.append(int(predictions_data[tile]))
#         print(tile)
    else:
        print(f'This tile is not in the prediction = {tile}')

y_true = np.array(y_true)/4

# Calculate metrics
precision = precision_score(y_true, y_pred, average='binary', pos_label=1)
recall = recall_score(y_true, y_pred, average='binary', pos_label=1)
f1 = f1_score(y_true, y_pred, average='binary', pos_label=1)
accuracy = accuracy_score(y_true, y_pred)

# Print the metrics in a formatted output
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")
print(f"Accuracy: {accuracy:.2f}")


Precision: 0.87
Recall: 0.85
F1 Score: 0.86
Accuracy: 0.84


In [92]:
y_true

array([0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
       1., 1., 0., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0.,
       0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0.,
       1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1.,
       1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0.,
       0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.,
       1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1.,
       0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1.,
       1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,
       1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0.,
       1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0.,
       0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 1.,
       1., 0., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0.,
       1., 1., 1., 1., 1.

In [93]:
np.array(y_pred)

array([0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0,
       0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1,
       1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0,
       1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1,
       0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0,
       1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
       1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
       0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0,
       1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1,
       1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0,
       1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,