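## Overview

Restore the `model01` image classifier from a logged MLflow run (its `config.yaml` and `.pth` weights) and inspect the rebuilt network.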

In [1]:
# default package
import logging
import sys 
import os
import pathlib
import IPython
import random
from urllib.request import urlretrieve
import dataclasses as dc
import tempfile

In [2]:
# third party package
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import seaborn
from tqdm import tqdm
import seaborn as sns
import yaml
from mlflow.tracking import MlflowClient
from matplotlib.font_manager import FontProperties
import matplotlib
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

In [5]:
# my package
# Make the project root importable so `src.*` resolves. The path is built from the
# kernel's current working directory *before* the chdir cell below runs, so it
# assumes the kernel was started one level below the project root (e.g. in
# notebooks/). NOTE(review): if the kernel starts at the project root instead,
# this appends the root's parent — `src` should then still import via the cwd; confirm.
sys.path.append(os.path.join(pathlib.Path().resolve(),"../"))
import src.model.model01.image_classifier_models as icm
import src.model.model01.image_network as image_net

In [6]:
# reload settings
# Enable IPython's autoreload extension; mode 2 re-imports modified modules
# (e.g. the project's src.model.* code) before each cell executes, so edits are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
# logger: configure the root logger at INFO first, then grab this notebook's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [8]:
# graph setting
sns.set()
# Custom plot font (MigMix). The path below is machine-specific: only switch the
# font family when the file is actually present. Otherwise keep the default font
# and warn, rather than letting FontProperties.get_name() raise and abort the cell.
font_path = "/usr/share/fonts/truetype/migmix/migmix-1p-regular.ttf"
if os.path.exists(font_path):
    font_prop = FontProperties(fname=font_path)
    matplotlib.rcParams["font.family"] = font_prop.get_name()
else:
    # Keep `font_prop` defined so later cells that pass it as `fontproperties`
    # still run (with the default font).
    font_prop = FontProperties()
    logger.warning("font file not found, keeping default font: %s", font_path)

In [9]:
# gpu: use the first CUDA device when one is available, otherwise the CPU.
# (The previous standalone `torch.cuda.is_available()` line was a no-op: it was
# not the last expression, so its value was never displayed.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [10]:
# chdir: when the kernel was launched inside notebooks/, move up one level so that
# relative paths used later (such as the "logs/mlruns" tracking URI) resolve from
# the project root. Re-running is harmless: after the move the check is false.
current_dir = pathlib.Path().resolve()
launched_in_notebooks = current_dir.stem == "notebooks"
if launched_in_notebooks:
    os.chdir(current_dir.parent)
logger.info(pathlib.Path().resolve())

INFO:__main__:/workspaces/load_to_goal/GitHub/pytorch-implementation/pytorch_mlflow_hydra_optuna


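## Load model

Fetch the run's config and weights from the local MLflow tracking store and rebuild the classifier.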

In [13]:
# mlflow global parameters
tracking_uri = "logs/mlruns"  # relative path; resolved against the current working directory
experiment_name = "model01"   # not referenced by load_model
run_id = "25a9cb491fa34b428026ba567e9f9e81"  # run whose artifacts are loaded

In [14]:
def load_model():
    """Rebuild the trained classifier for the configured MLflow run.

    Uses the notebook-level ``tracking_uri``, ``run_id`` and ``device`` settings.
    It downloads the run's ``config.yaml`` and its weights artifact into a
    temporary directory. It then constructs ``icm.LitClassifier`` around a fresh
    ``image_net.CNN``, passing the config values as keyword arguments, and loads
    the saved state dict.

    Returns:
        The ``LitClassifier`` moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if the run has no top-level (non-directory) artifact
            whose path contains "pth".
    """
    client = MlflowClient(tracking_uri=tracking_uri)
    with tempfile.TemporaryDirectory() as dname:
        config_path = client.download_artifacts(run_id, "config.yaml", dname)
        with open(config_path) as f:
            # presumably a flat mapping of LitClassifier keyword arguments — TODO confirm
            config = yaml.safe_load(f)

        # list_artifacts(run_id) only lists the run's top-level artifacts; take
        # the first file whose path contains "pth" as the saved weights.
        weight_paths = [
            artifact.path
            for artifact in client.list_artifacts(run_id)
            if "pth" in artifact.path and not artifact.is_dir
        ]
        if not weight_paths:
            raise FileNotFoundError(
                f"no weights artifact containing 'pth' found in run {run_id}"
            )
        model_path = client.download_artifacts(run_id, weight_paths[0], dname)

        model = icm.LitClassifier(
            model=image_net.CNN(),
            # An empty config.yaml loads as None; treat it as "no extra kwargs".
            **(config or {}),
        )
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # torch.load unpickles the file, so only use it on trusted artifacts.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()

    return model

In [15]:
model = load_model()  # restored classifier, on `device`, in eval mode

In [16]:
# Show the restored module tree as this cell's output.
model

LitClassifier(
  (model): CNN(
    (layer1): Sequential(
      (0): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (layer2): Sequential(
      (0): Conv2d(28, 10, kernel_size=(2, 2), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (dropout1): Dropout(p=0.25, inplace=False)
    (fc1): Linear(in_features=250, out_features=18, bias=True)
    (dropout2): Dropout(p=0.08, inplace=False)
    (fc2): Linear(in_features=18, out_features=10, bias=True)
  )
  (accuracy): Accuracy()
)