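# Model Selection

Interactively load a trained MemSeg checkpoint from `./saved_model` and visualize its predicted anomaly masks on the corresponding test split.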

In [1]:
from data import create_dataset

import os
import torch
import yaml
from timm import create_model
from models import MemSeg

import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

In [2]:
# Load the experiment configuration. The context manager closes the file handle.
# FullLoader (rather than safe_load) is kept because the config presumably uses a
# Python-tuple tag: 'resize' below is displayed as a tuple (256, 256).
with open('./configs/bottle.yaml', 'r') as config_file:
    cfg = yaml.load(config_file, Loader=yaml.FullLoader)
cfg

{'EXP_NAME': 'MemSeg',
 'SEED': 42,
 'DATASET': {'datadir': './datasets/MVTec',
  'texture_source_dir': './datasets/dtd/images',
  'target': 'bottle',
  'resize': (256, 256),
  'structure_grid_size': 8,
  'transparency_range': [0.15, 1.0],
  'perlin_scale': 6,
  'min_perlin_scale': 0,
  'perlin_noise_threshold': 0.5},
 'DATALOADER': {'batch_size': 8, 'num_workers': 0},
 'MEMORYBANK': {'nb_memory_sample': 30},
 'MODEL': {'feature_extractor_name': 'resnet18'},
 'TRAIN': {'batch_size': 8,
  'num_training_steps': 5000,
  'l1_weight': 0.6,
  'focal_weight': 0.4,
  'focal_alpha': None,
  'focal_gamma': 4,
  'use_wandb': False},
 'OPTIMIZER': {'lr': 0.003, 'weight_decay': 0.0005},
 'SCHEDULER': {'min_lr': 0.0001, 'warmup_ratio': 0.1, 'use_scheduler': True},
 'LOG': {'log_interval': 1, 'eval_interval': 100},
 'RESULT': {'savedir': './saved_model'}}

In [3]:
# ====================================
# Select Model
# ====================================
def load_model(model_name):
    """Load a trained MemSeg checkpoint and its test split into the globals
    ``model`` and ``testset``.

    Args:
        model_name: name of a sub-directory of ``cfg['RESULT']['savedir']``
            (e.g. ``'MemSeg-bottle'``) containing ``memory_bank.pt`` and
            ``best_model.pt``. The dataset target is the second '-'-separated
            token of the name.
            NOTE(review): for a name like 'MemSeg-custom-256' this yields
            'custom' — confirm that matches the dataset directory.

    Side effects:
        Rebinds the module-level globals ``model`` (in eval mode, on CPU)
        and ``testset``.
    """
    global model
    global testset

    checkpoint_dir = os.path.join(cfg['RESULT']['savedir'], model_name)

    testset = create_dataset(
        datadir                = cfg['DATASET']['datadir'],
        target                 = model_name.split('-')[1],
        train                  = False,
        resize                 = cfg['DATASET']['resize'],
        texture_source_dir     = cfg['DATASET']['texture_source_dir'],
        structure_grid_size    = cfg['DATASET']['structure_grid_size'],
        transparency_range     = cfg['DATASET']['transparency_range'],
        perlin_scale           = cfg['DATASET']['perlin_scale'],
        min_perlin_scale       = cfg['DATASET']['min_perlin_scale'],
        perlin_noise_threshold = cfg['DATASET']['perlin_noise_threshold']
    )

    # memory_bank.pt is a pickled Python object, so only load checkpoints you trust.
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # this pickled object — confirm against the installed torch version.
    memory_bank = torch.load(os.path.join(checkpoint_dir, 'memory_bank.pt'), map_location='cpu')
    memory_bank.device = 'cpu'
    for k in memory_bank.memory_information.keys():
        memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()

    feature_extractor = create_model(
        cfg['MODEL']['feature_extractor_name'],
        pretrained    = True,
        features_only = True
    )
    model = MemSeg(
        memory_bank       = memory_bank,
        feature_extractor = feature_extractor
    )

    model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, 'best_model.pt'), map_location='cpu')
    )
    # Inference only: freeze BatchNorm statistics / disable Dropout. Otherwise
    # single-image forwards in result_plot use (and overwrite) batch statistics.
    model.eval()

In [4]:
# ====================================
# Visualization
# ====================================
def result_plot(idx):
    """Show input, ground truth, predicted mask and an overlay for one test sample.

    Reads the globals ``testset`` and ``model`` set by ``load_model``.

    Args:
        idx: index into ``testset``; each item unpacks to (image, mask, target),
            where ``target == 0`` is labelled 'Normal'.
    """
    input_i, mask_i, target_i = testset[idx]

    # Pure inference: no_grad avoids building an autograd graph that would only be discarded.
    with torch.no_grad():
        output_i = model(input_i.unsqueeze(0))
    output_i = torch.nn.functional.softmax(output_i, dim=1)
    # assumes output is (1, num_classes, H, W) with channel 1 = anomaly — TODO confirm
    predicted_mask = output_i[0][1]

    def minmax_scaling(img):
        """Rescale ``img`` to uint8 in [0, 255] for display; a constant image maps to all zeros."""
        value_range = img.max() - img.min()
        if value_range == 0:
            # Avoid 0/0 -> NaN, whose uint8 cast is undefined garbage.
            return torch.zeros_like(img, dtype=torch.uint8)
        return (((img - img.min()) / value_range) * 255).to(torch.uint8)

    # assumes input_i is (C, H, W); permute to (H, W, C) for imshow
    input_img = minmax_scaling(input_i.permute(1, 2, 0))

    fig, ax = plt.subplots(1, 4, figsize=(15, 10))

    ax[0].imshow(input_img)
    ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))
    ax[1].imshow(mask_i, cmap='gray')
    ax[1].set_title('Ground Truth')
    ax[2].imshow(predicted_mask, cmap='gray')
    ax[2].set_title('Predicted Mask')
    ax[3].imshow(input_img, alpha=1)
    ax[3].imshow(predicted_mask, cmap='gray', alpha=0.5)
    ax[3].set_title('Input X Predicted Mask')

    plt.show()

In [5]:
# ====================================
# widgets
# ====================================
SAVED_MODEL_DIR = cfg['RESULT']['savedir']
DEFAULT_MODEL = 'MemSeg-bottle'

# Only sub-directories hold checkpoints; sort so the dropdown order is stable
# (os.listdir order is arbitrary and may include stray files).
available_models = sorted(
    name for name in os.listdir(SAVED_MODEL_DIR)
    if os.path.isdir(os.path.join(SAVED_MODEL_DIR, name))
)

model_list = widgets.Dropdown(
    options=available_models,
    # Fall back to the first checkpoint instead of raising when DEFAULT_MODEL is absent.
    value=(
        DEFAULT_MODEL if DEFAULT_MODEL in available_models
        else (available_models[0] if available_models else None)
    ),
    description='Model:',
    disabled=False,
)
button = widgets.Button(description="Model Change")
output = widgets.Output()


@output.capture()
def on_button_clicked(b):
    """Load the selected checkpoint and show an image picker wired to result_plot."""
    clear_output(wait=True)
    if model_list.value is None:
        print(f'No saved models found in {SAVED_MODEL_DIR}')
        return
    load_model(model_name=model_list.value)

    # visualization
    file_list = widgets.Dropdown(
        options=[(file_path, i) for i, file_path in enumerate(testset.file_list)],
        value=0,
        description='image:',
    )

    widgets.interact(result_plot, idx=file_list)

button.on_click(on_button_clicked)

display(widgets.HBox([model_list, button]), output)

HBox(children=(Dropdown(description='Model:', options=('MemSeg-bottle', 'MemSeg-custom-256'), value='MemSeg-bo…

Output()

No such comm: 9c58511231af48709aedd2fc87f668f3
No such comm: 2c1891d8eaf0460da183f406e0b71f98
No such comm: 2c1891d8eaf0460da183f406e0b71f98
