# Notebook for SDD Hackathon 2024 wt Vortex.io 

## Library imports and dowloads

In [4]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import pandas as pd

## Training

### Fragmentation

In [5]:
def deaggregate_fragment(mask):
    return [(mask == i).astype(int) for i in range(1, np.max(mask)+1)]

In [16]:
train_images_path = "dataset/trainset/images/"
train_masks_path = "dataset/trainset/masks/"
test_images_path = "dataset/testset/images/"
files = os.listdir(train_images_path)
all_fragments = {}
for file in ["training_fragments_1.pkl",
                "training_fragments_2.pkl", 
                "training_fragments_3.pkl", 
                "training_fragments_4.pkl",
                "training_fragments_5.pkl",
                "training_fragments_6.pkl"]:
    with open(file, 'rb') as f:
        aggregated_fragments = pickle.load(f)
        for file in aggregated_fragments:
            all_fragments[file] = deaggregate_fragment(aggregated_fragments[file])

### Fragment labelization

In [26]:
def get_label(frag, file, show=False):
    mask_file = train_masks_path + file.replace(
        "jpg", "png"
    )
    mask = cv2.imread(mask_file)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)//255
    test = mask * frag
    if show:
        plt.subplot(1,3,1)
        img = cv2.imread(train_images_path + file)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.imshow(img)
        plt.subplot(1,3,2)
        plt.imshow(mask, vmin=0, vmax=1)
        plt.subplot(1,3,3)
        plt.imshow(frag, vmin=0, vmax=1)
        plt.show()
    if np.sum(test) == np.sum(frag):
        return True
    return np.sum(test) / np.sum(frag) > .7


### Fragment encoding

In [27]:
# Load a pretrained ResNet model
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
# Remove the classification layer
model = torch.nn.Sequential(*list(model.children())[:-1])
# Set the model to evaluation mode
model.eval()

# Define preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_features(frag):
    input_tensor = preprocess(Image.fromarray(frag))
    input_batch = input_tensor.unsqueeze(0)
    with torch.no_grad():
        features = model(input_batch)
    return features.view(-1).numpy()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\lilia/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:16<00:00, 6.28MB/s]


### Fragment aggregation

In [33]:
def add_file_fragments(directory, file, fragments, training_set):
    if "rh�ne" in file: 
        print("Replacing rh�ne with rhone in {file}: ")
        file = file.replace("rh�ne", "rhone")
        print(file)
    origin = cv2.imread(directory+file)
    origin = cv2.cvtColor(origin, cv2.COLOR_BGR2RGB)
    for frag in fragments:
        h, w, _ = origin.shape
        resized_frag = cv2.resize(frag.astype('float32') , (w,h))
        fragment = np.zeros(origin.shape)
        for i in range(3):
            fragment[:,:,i] = origin[:,:,i] * resized_frag.astype(int)
        training_set.append(
            {
                "filename": file,
                "label": get_label(resized_frag, file),
                "features": get_features(fragment.astype("uint8"))
            }
        )
    

training_set = []
for file in tqdm(all_fragments):
    fragments = all_fragments[file]
    add_file_fragments(train_images_path, file, fragments, training_set)
print(f"{len(training_set)} fragments in training set.")
print()

  2%|▏         | 10/554 [01:11<1:04:41,  7.13s/it]

Replacing rh�ne with rhone in {file}: 
port-saint-louis-du-rhone_1-2022-08-16T12_48_10.jpg





error: OpenCV(4.9.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'


## Testing

### Fragmentation

In [21]:
test_fragments = {}
print("Fragmenting pictures...")
with open("share_test_set.pkl", 'rb') as f:
    aggregated_fragments = pickle.load(f)
    for file in aggregated_fragments:
        test_fragments[file] = deaggregate_fragment(aggregated_fragments[file])
print("Done!")

Fragmenting pictures...
Done!


### Encoding

In [None]:
print("Getting fragment's features...")
testing_set = []
for file in tqdm(test_fragments):
    fragments = test_fragments[file]
    origin = cv2.imread(test_images_path+file)
    origin = cv2.cvtColor(origin, cv2.COLOR_BGR2RGB)
    h, w, _ = origin.shape
    for frag in fragments:
        fragment = np.zeros(origin.shape)
        for i in range(3):
            fragment[:,:,i] = origin[:,:,i] * cv2.resize(frag.astype("float32"), (w, h)).astype(int)
        testing_set.append(
            {
                "filename": file,
                "mask": frag.astype(int),
                "features": get_features(fragment.astype("uint8"))
            }
        )
print("Done!")
print()

### Labelization

In [None]:
def is_frag_water(frag, show=False):
    scores = [cosine_similarity(frag["features"].reshape(1,-1), 
                               o["features"].reshape(1,-1)
                              )[0][0] 
              for o in training_set]
    best = np.argmax(scores)
    if show:
        plt.imshow(training_set[best]["fragment"])
        plt.show()
    return training_set[best]["label"]


print("Labeling test fragments...")
test_features = [o['features'] for o in testing_set]
train_features = [o['features'] for o in training_set]
similarity = cosine_similarity(test_features, train_features)
for i, frag in tqdm(enumerate(testing_set)):
    frag["label"] = training_set[np.argmax(similarity[i])]["label"]
print("Done!")

### Mask creation 

In [None]:
print("Merging water masks...")
testing_masks = {}
for frag in tqdm(testing_set):
    file = frag["filename"]
    if frag["label"]:
        if file in testing_masks:
            mask = frag["mask"]
            testing_masks[file]["mask"] = testing_masks[file]["mask"] + cv2.resize(mask.astype("float32"), testing_masks[file]["shape"])
        else:
            mask = {}
            h, w, _ = cv2.imread(test_images_path+file).shape
            mask["shape"] = (w, h)
            mask["mask"] = cv2.resize(frag["mask"].astype("float32"), mask["shape"])
            testing_masks[file] = mask
print("Done!")

### Submission file generation

In [None]:
def boolean_array_to_rle(boolean_array):
    boolean_vector = boolean_array.copy().reshape(1,-1)[0]
    rle = []
    current_idx = 0
    current_value = -1
    for i in range(len(boolean_vector)):
        if boolean_vector[i] >= 1:
            if current_value != 1:
                current_value = 1
                current_idx = i+1
        elif current_value >= 1:
            current_value = 0
            rle.append((current_idx, i-current_idx+1))
            current_idx = i+1
    if current_value >= 1:
        rle.append((current_idx, i-current_idx+1))
    return rle

def rle_to_str(rle):
    return " ".join([f"{e[0]} {e[1]}" for e in rle])

print("Creating submission file...")
files = []
rles = []
for file in tqdm(testing_masks):
    mask = testing_masks[file]["mask"]
    name = file.replace("jpg", "png")
    rle = rle_to_str(boolean_array_to_rle(mask))
    files.append(name)
    rles.append(rle)
df = pd.DataFrame({'img_key': files, 'rle_mask': rles})
df.to_csv("../submission.csv", index=False)
print("Done!")

### Visualisation

In [None]:
visu_files = np.random.choice(files, 10)
for i, file in enumerate(visu_files):
    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    origin = cv2.imread(test_images_path+file.replace("png","jpg"))
    origin = cv2.cvtColor(origin, cv2.COLOR_BGR2RGB)
    plt.imshow(origin)
    plt.imshow(testing_masks[file.replace("png","jpg")]['mask'], alpha=.2, cmap='jet', vmax=1, vmin=0)
    plt.subplot(1,2,2)
    plt.imshow(testing_masks[file.replace("png","jpg")]['mask'], cmap='jet', vmin=0, vmax=1)
    plt.show()