In [None]:
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.xception import preprocess_input as preprocess_xception
from tensorflow.keras.applications import Xception
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.callbacks import EarlyStopping
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn
from torchvision.transforms import Compose, ToTensor, Normalize
import torch
import pandas as pd

In [None]:
def apply_gaussian_blur_and_edge_detection(
    image_np,
    *,
    alpha=1.1,
    beta=-100,
    blur_ksize=(5, 5),
    blur_sigma=10,
    canny_low=10,
    canny_high=95,
    edge_weight=0.5,
):
    """Contrast-adjust an RGB image and overlay its (closed, dilated) Canny edges.

    Args:
        image_np: 3-channel RGB image as a NumPy array (it is converted with
            ``cv2.COLOR_RGB2GRAY``, so it must have 3 channels).
        alpha, beta: scale and offset passed to ``cv2.convertScaleAbs``.
        blur_ksize, blur_sigma: Gaussian blur kernel size and sigmaX applied
            before edge detection.
        canny_low, canny_high: hysteresis thresholds for ``cv2.Canny``.
        edge_weight: weight of the edge overlay in ``cv2.addWeighted``.

    Returns:
        uint8 array shaped like the input: the adjusted image plus
        ``edge_weight`` times the edge map (OpenCV saturates at 255).
    """
    # NOTE(review): convertScaleAbs computes |alpha*x + beta|, so with beta=-100
    # pixels darker than ~91 are reflected upward (0 -> 100) rather than clipped
    # to 0. Kept as-is because the segmentation weights were presumably trained
    # on this exact output — confirm before changing.
    adjusted = cv2.convertScaleAbs(image_np, alpha=alpha, beta=beta)
    blurred = cv2.GaussianBlur(adjusted, blur_ksize, sigmaX=blur_sigma)
    gray = cv2.cvtColor(blurred, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, threshold1=canny_low, threshold2=canny_high, L2gradient=True)
    # Close small gaps in the edge map with a 3x3 structuring element.
    closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
    # Dilation is applied per channel, so thickening the single-channel map and
    # then replicating it gives the same result as dilating the 3-channel stack,
    # with a third of the work. `None` selects OpenCV's default 3x3 kernel.
    thick_edges = cv2.dilate(closed, None, iterations=1)
    edges_rgb = np.stack((thick_edges,) * 3, axis=-1)
    return cv2.addWeighted(adjusted, 1, edges_rgb, edge_weight, 0)

class ImageAugmentation:
    """Stateless transform object wrapping ``apply_gaussian_blur_and_edge_detection``.

    torchvision's ``Compose`` chains callables, so an instance of this class can
    be dropped directly into a transform pipeline.
    """

    def __call__(self, x):
        transformed = apply_gaussian_blur_and_edge_detection(x)
        return transformed

In [None]:
# Run PyTorch inference on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def get_model_instance_segmentation(num_classes, hidden_layer=256,
                                    weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT):
    """Build a Mask R-CNN (ResNet-50 FPN) whose box and mask heads output ``num_classes``.

    Args:
        num_classes: number of output classes, including the background class.
        hidden_layer: channel width of the hidden layer in the new mask head
            (default 256, as before).
        weights: pretrained weights for the whole detector (default: torchvision's
            DEFAULT Mask R-CNN weights). Pass ``None`` to build without them, e.g.
            when a fine-tuned state dict is loaded immediately afterwards.

    Returns:
        The Mask R-CNN model with freshly initialised box and mask predictors.
    """
    model = maskrcnn_resnet50_fpn(weights=weights)

    # Replace the box classifier so it predicts our label set.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Replace the mask predictor too, keeping the input width of the existing head.
    mask_in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, hidden_layer, num_classes)

    return model

In [None]:
# 5 object classes + 1 background class.
MASK_NUM_CLASSES = 5 + 1
mask_model = get_model_instance_segmentation(MASK_NUM_CLASSES).to(device)
# Load previously fine-tuned parameters. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute code.
model_state_dict = torch.load('model_weights.pth', map_location=device, weights_only=True)
mask_model.load_state_dict(model_state_dict)
# Eval mode makes forward() return predictions (used by apply_mask); the trailing
# ';' keeps Jupyter from printing the very large model repr.
mask_model.eval();

In [None]:
def get_transform():
    """Build the preprocessing pipeline applied before Mask R-CNN inference.

    Steps, in order: blur + edge overlay (``ImageAugmentation``), conversion to
    a tensor (``ToTensor``), then per-channel normalization with the ImageNet
    mean/std.
    """
    # NOTE(review): torchvision's detection models normalize their inputs
    # internally with these same ImageNet statistics, so this Normalize is
    # applied on top of that. Keep it only if model_weights.pth was trained
    # with this exact pipeline — confirm.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        ImageAugmentation(),
        ToTensor(),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return Compose(steps)

In [None]:
def apply_mask(image_np, model, threshold=0.9):
    """Paint the region of the model's first predicted mask white.

    Args:
        image_np: image array accepted by ``get_transform()`` — presumably an
            RGB uint8 array of shape (H, W, 3); confirm against callers. It is
            NOT modified.
        model: detection model (``torch.nn.Module``) that, called on a batch of
            one image, returns a list whose first element has a ``'masks'``
            tensor of shape (N, 1, H, W).
        threshold: per-pixel mask probability above which a pixel is painted.

    Returns:
        A copy of ``image_np`` with pixels of the first predicted mask set to 255,
        or an unmodified copy when the model predicts no masks.
    """
    # Send the input to wherever the model's weights live instead of relying on
    # a separately defined global device.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    image_tensor = get_transform()(image_np).unsqueeze(0).to(target_device)
    with torch.no_grad():
        detections = model(image_tensor)[0]

    # Work on a copy so the caller's array is never changed as a side effect.
    masked = image_np.copy()
    if detections['masks'].size(0) == 0:
        return masked

    # NOTE(review): only the first detection is painted; any further detections
    # are ignored — confirm that is intended.
    pixel_mask = (detections['masks'][0, 0] > threshold).cpu().numpy()
    masked[pixel_mask] = 255
    return masked
    

In [None]:
# Root of the competition data; every CSV and image folder below lives under it.
DATA_DIR = "./bttai-nybg-2024"

filename_train = f"{DATA_DIR}/BTTAIxNYBG-train.csv"
df_train = pd.read_csv(filename_train)
filename_test = f"{DATA_DIR}/BTTAIxNYBG-test.csv"
df_test = pd.read_csv(filename_test)
filename_val = f"{DATA_DIR}/BTTAIxNYBG-validation.csv"
df_val = pd.read_csv(filename_val)

# Image folders are nested one level deep (the folder name is repeated).
train_image_directory = f"{DATA_DIR}/BTTAIxNYBG-train/BTTAIxNYBG-train"
validation_image_directory = f"{DATA_DIR}/BTTAIxNYBG-validation/BTTAIxNYBG-validation"

In [None]:
def preprocess_and_mask(img):
    """Keras ``preprocessing_function``: mask with Mask R-CNN, then apply Xception scaling.

    ``ImageDataGenerator`` is configured without ``rescale``, so each image
    arrives as a float array of raw pixel values in [0, 255]. It must NOT be
    multiplied by 255 here: that overflows the uint8 cast and feeds corrupted
    pixels to the segmentation model.

    Args:
        img: single image array (rank 3) with pixel values in [0, 255].

    Returns:
        Array of the same shape, masked and scaled by Xception's ``preprocess_input``.
    """
    # Round and clamp into the uint8 range expected by the OpenCV / torchvision
    # masking pipeline.
    pixels = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    masked = apply_mask(pixels, mask_model)
    return preprocess_xception(masked)

In [None]:
# `device` is already selected (cuda if available, else cpu) in the setup cell
# near the top of the notebook; recomputing it here would be a duplicate.

In [None]:
input_tensor = Input(shape=(1000, 1000, 3))
xception_base = Xception(include_top=False, weights='imagenet', input_tensor=input_tensor)
# flow_from_dataframe orders the one-hot class columns by sorted label name, so
# keep `classes` in that same order: index i here then matches output unit i.
classes = sorted(df_train["classLabel"].unique())

In [None]:
# Classification head on top of the Xception feature maps.
xception_output = GlobalAveragePooling2D()(xception_base.output)
xception_output = Dense(1024, activation='relu')(xception_output)
# Each image has exactly one label (one-hot targets + categorical_crossentropy),
# so the outputs must form a probability distribution: softmax, not sigmoid.
predictions = Dense(len(classes), activation='softmax')(xception_output)

In [None]:
# Full network: the Xception input through to the new classification head.
model = Model(
    inputs=xception_base.input,
    outputs=predictions,
)

In [None]:
# Freeze the pretrained Xception trunk so only the new Dense head is trained.
# Keras picks up `trainable` changes at compile time, so this must run before
# model.compile().
for layer in xception_base.layers:
    layer.trainable = False

In [None]:
# One-hot targets (class_mode='categorical') pair with categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# No `rescale` is configured: every image is handed to preprocess_and_mask as raw
# pixel values, and that function is solely responsible for masking and scaling.
datagen = ImageDataGenerator(
    preprocessing_function=preprocess_and_mask,
)

In [None]:
# Validation batches in a fixed order (no shuffling), one-hot labels.
validation_generator = datagen.flow_from_dataframe(
    dataframe=df_val,
    directory=validation_image_directory,
    x_col='imageFile',
    y_col='classLabel',
    class_mode='categorical',
    target_size=(1000, 1000),
    batch_size=32,
    shuffle=False,
)

In [None]:
# import matplotlib.pyplot as plt
# def show_image(image, index):
#     # If preprocessing includes normalization, adjust the image to display correctly
#     image = image[index]
#     if np.min(image) < 0:
#         # Rescale to 0-1 if preprocessing involves standardization
#         image = (image - image.min()) / (image.max() - image.min())

#     plt.imshow(image)
#     plt.title("Sample Image")
#     plt.axis('off')
#     plt.show()

# x, _ = next(validation_generator)  # one preprocessed batch of images
# show_image(x, 14)

In [None]:
# Training batches, reshuffled every epoch, one-hot labels.
train_generator = datagen.flow_from_dataframe(
    dataframe=df_train,
    directory=train_image_directory,
    x_col='imageFile',
    y_col='classLabel',
    class_mode='categorical',
    target_size=(1000, 1000),
    batch_size=32,
    shuffle=True,
)

In [None]:
# Stop training once val_loss has not decreased for 3 consecutive epochs, then
# roll the model back to the weights from its lowest-val_loss epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

In [None]:
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=validation_generator,
    # Evaluate at most 50 validation batches per epoch, but never request more
    # batches than the generator holds — otherwise Keras runs out of validation
    # data and val_loss (which EarlyStopping monitors) can go missing.
    # NOTE(review): with shuffle=False this is always the same leading slice of
    # df_val; confirm that slice is representative of all classes.
    validation_steps=min(50, len(validation_generator)),
    callbacks=[early_stopping],
)