In [None]:
import yaml
from pytorch_lightning.loggers import NeptuneLogger
from Modules.train import DataModule, Model, get_trainer

# Training

In [None]:
%%script false --no-raise-error

# NOTE: the `%%script false --no-raise-error` magic above hands this cell's text
# to the shell command `false`, so nothing below runs as Python while it is present.

# Parse the YAML configuration and keep only its 'classifier' section.
with open("Modules/config.yaml", 'r') as stream:
    full_config = yaml.safe_load(stream)
PARAMS = full_config['classifier']
print(PARAMS)

# Neptune experiment logger; model checkpoints are not uploaded.
neptune_logger = NeptuneLogger(
    log_model_checkpoints=False,
    # with_id="AIS-113",
    project="kaori/AISeed",
)

In [None]:
%%script false --no-raise-error

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Augmentation pipeline kept for reference only; it is commented out and not
# passed to the data module below.
# img_size = PARAMS['dataset_settings']['img_size']
# transforms_img = A.Compose([
#     A.RandomResizedCrop(img_size, img_size),  # Random crop and resize
#     A.HorizontalFlip(),  # Random horizontal flip
#     A.VerticalFlip(),  # Random vertical flip
#     A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
#     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     ToTensorV2(),  # Convert to tensor
# ])

# Build the data module from the dataset and training sections of the config.
# NOTE(review): the third argument is presumably a pair of transform slots left
# empty -- confirm against DataModule's signature.
dataset_cfg = PARAMS['dataset_settings']
training_cfg = PARAMS['training_settings']
data = DataModule(dataset_cfg, training_cfg, [None, None])

In [None]:
%%script false --no-raise-error
# Record the full configuration on the logger before any model is built.
neptune_logger.log_hyperparams(params=PARAMS)

# Model and trainer are both driven by the same configuration dictionary.
model = Model(PARAMS=PARAMS)
training_cfg = PARAMS['training_settings']
trainer = get_trainer(training_cfg, neptune_logger)

# Fit on the data module first ...
trainer.fit(model, data)
# ... then evaluate the fitted model with the same trainer.
trainer.test(model, data)

In [None]:
# model = Model.load_from_checkpoint(".neptune/Untitled/AIS-113/checkpoints/best_model_001.ckpt")
# trainer = get_trainer(PARAMS['training_settings'], neptune_logger)
# # 
# trainer.test(model, data)

# Testing

In [None]:
# %%script false --no-raise-error
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
import os
from torchvision import transforms

# Class labels, one per line of the text file; list index k is the label for class k.
with open("models/class_name.txt", mode="r", encoding="utf-8") as names_file:
    names_text = names_file.read()
class_names = names_text.splitlines()

# Configuration template handed to `Model` when a built-in architecture is
# selected in `predict`; the backbone fields are overwritten there per request.
PARAMS = {
    "architect_settings": {
        "task": "None",
        "name": "model-test",
        "backbone": {
            "name": "selectnet-s",
            "is_full": False,
            "is_pretrained": True,
            "is_freeze": False,
        },
        "n_cls": 2,
    },
    # Left empty: no dataset or training options are set in this template.
    "dataset_settings": {},
    "training_settings": {},
}

# Names offered in the model dropdown: every checkpoint found under models/
# (file name without the ".ckpt" suffix), followed by the built-in architectures.
# os.path.splitext keeps dots inside the stem, so "a.b.ckpt" is listed as "a.b"
# and matches the f"models/{model_choice}.ckpt" path that predict() loads.
# Sorting makes the dropdown order independent of the directory listing order.
model_list = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir("models")
    if name.endswith('.ckpt')
)
model_list += ["fcn-m", "resnet-s", "fasterrcnn-s", "maskrcnn-s"]
# Inverse of Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
# applying Normalize(-mean/std, 1/std) to a normalized tensor gives x * std + mean.
# The third-channel std is 0.225 (it was previously mistyped as 0.255, which
# restored the last channel with the wrong scale and offset).
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)

def _builtin_task(model_choice):
    """Map a built-in model name to its task string; return None for any other name."""
    return {
        "fcn-m": "segmentation",
        "resnet-s": "classification",
        "fasterrcnn-s": "detection",
        "maskrcnn-s": "detection",
    }.get(model_choice)


def _as_uint8_image(tensor_image):
    """Scale a float image tensor by 255 and cast to uint8, clamping to [0, 1] first.

    Without the clamp, values slightly outside [0, 1] (e.g. after de-normalizing)
    would be out of range for the uint8 cast.
    """
    return (tensor_image.clamp(0, 1) * 255.).to(torch.uint8)


def predict(image, model_choice):
    """Run one image through the selected model and format the result for display.

    Args:
        image: The input image; it is passed unchanged to the model's
            ``model.preprocess`` callable.
        model_choice: Either one of the built-in names ("fcn-m", "resnet-s",
            "fasterrcnn-s", "maskrcnn-s") or the stem of a checkpoint file
            stored as ``models/<model_choice>.ckpt``.

    Returns:
        A ``(labels, detection)`` tuple.
        ``labels`` is a dict mapping class name to softmax score when
        ``model.task`` is "classification", otherwise None.
        ``detection`` is a channel-last array scaled to [0, 1] with masks or
        boxes drawn on the image when ``model.task`` is "segmentation" or
        "detection", otherwise None.
        Both are None when no image or no model was supplied.
    """
    import copy

    labels, detection = None, None
    # Nothing to do when the UI submits without an image or a selected model.
    if image is None or not model_choice:
        return labels, detection

    task = _builtin_task(model_choice)
    if task is not None:
        # Work on a private copy so requests never alter the shared template.
        params = copy.deepcopy(PARAMS)
        backbone_cfg = params['architect_settings']['backbone']
        backbone_cfg['is_full'] = True
        backbone_cfg['name'] = model_choice
        # The template stores "task" under 'architect_settings', while this
        # function historically wrote it at the top level; both are filled.
        # NOTE(review): confirm which of the two keys Model actually reads.
        params['task'] = task
        params['architect_settings']['task'] = task
        model = Model(params)
    else:
        # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
        model = Model.load_from_checkpoint(
            f"models/{model_choice}.ckpt", map_location="cpu"
        ).cpu()
    model.eval()

    preprocess = model.model.preprocess
    tensor_image = preprocess(image)
    with torch.no_grad():
        # Add a leading batch dimension of size one for the forward pass.
        y_hat = model(tensor_image.unsqueeze(0))
        canvas = None
        if model.task == "classification":
            preds = torch.softmax(y_hat, dim=-1).tolist()
            # The last score is dropped before labelling.
            # NOTE(review): presumably an extra/background class -- confirm.
            labels = {class_names[k]: float(v) for k, v in enumerate(preds[0][:-1])}
        elif model.task == "segmentation":
            num_classes = y_hat.shape[1]
            # One boolean mask per class: True where that class has the top score.
            class_masks = y_hat[0].argmax(0) == torch.arange(num_classes)[:, None, None]
            restored = inv_normalize(tensor_image)
            canvas = draw_segmentation_masks(_as_uint8_image(restored),
                                             masks=class_masks, alpha=.6)
        elif model.task == "detection":
            first = y_hat[0]
            base_image = _as_uint8_image(tensor_image)
            if "maskrcnn" in model_choice:
                # Keep instances scoring above 0.75 and binarize their masks at 0.5.
                boolean_masks = first['masks'][first['scores'] > .75] > 0.5
                canvas = draw_segmentation_masks(base_image,
                                                 boolean_masks.squeeze(1), alpha=0.8)
            else:
                # Draw only the first five predicted boxes.
                canvas = draw_bounding_boxes(base_image,
                                             boxes=first["boxes"][:5],
                                             colors="red",
                                             width=5)
        if canvas is not None:
            # Channel-first uint8 tensor -> channel-last array in [0, 1].
            detection = canvas.numpy().transpose(1, 2, 0) / 255.

    return labels, detection

In [None]:
# %%script false --no-raise-error

import gradio as gr

title = "Application Demo "
description = "# A Demo of Wrapping Pretrained Networks"

# One single-element row per entry found in the examples directory.
example_list = []
for example in os.listdir("examples"):
    example_list.append(["examples/" + example])


# Two-column layout: model picker, input image and class scores on the left;
# the annotated image and the trigger button on the right.
with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            model = gr.Dropdown(model_list, interactive=True, label="Select Model")
            im = gr.Image(label="input image", type="pil")
            label_conv = gr.Label(num_top_classes=4, label="Predictions")
        with gr.Column():
            im_detection = gr.Image(label="Detection", type="pil")
            btn = gr.Button(value="predict")
    # Clicking the button feeds (image, model name) to predict() and routes its
    # two return values to the label widget and the detection image.
    btn.click(
        predict,
        inputs=[im, model],
        outputs=[label_conv, im_detection],
    )
    gr.Examples(
        examples=example_list,
        inputs=[im, model],
        outputs=[label_conv, im_detection],
    )


# share=True asks gradio for a public share link in addition to the local server.
demo.launch(share=True)