In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
import torchmetrics
from torchinfo import summary
from pathlib import Path
import os
import pandas as pd
import matplotlib.pyplot as plt

# Internal libs
# The project modules are imported with ../scripts as the working directory.
# The starting directory is remembered and restored in `finally`, so a failed
# import cannot leave the kernel stranded in ../scripts (which would make every
# later relative path, and a re-run of this cell, resolve from the wrong place).
_notebook_dir = os.getcwd()
os.chdir("../scripts")
try:
    import data, engine, utils
    from constants import *
finally:
    # Return to wherever the cell started instead of assuming "../notebooks".
    os.chdir(_notebook_dir)

In [2]:
# Choose the compute device once: "cuda" when a GPU is usable, otherwise "cpu".
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    # Initialise CUDA up front and clear PyTorch's cached GPU memory.
    torch.cuda.init()
    torch.cuda.empty_cache()
device = "cuda" if cuda_ready else "cpu"

# For M1 Mac
# if torch.backends.mps.is_available() and torch.backends.mps.is_built():
#     device = "mps"
# else:
#     device = "cpu"

### Data

In [3]:
# ------------------ Data ------------------
# Images: resize to the configured resolution, then convert to a tensor.
image_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
    ]
)

# Masks: Resize defaults to bilinear interpolation, which blends neighbouring
# values and produces in-between values along object borders. Nearest-neighbour
# keeps every resized pixel equal to one of the original mask values.
# NOTE(review): assumes masks hold discrete class labels; keep this in sync
# with the transform used when the model was trained -- TODO confirm.
mask_transform = transforms.Compose(
    [
        transforms.Resize(
            (IMAGE_HEIGHT, IMAGE_WIDTH),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.ToTensor(),
    ]
)

# Build the train/dev/test loaders; the split directories come from the
# project constants and are resolved relative to the notebook's parent folder.
(
    train_dataloader,
    dev_dataloader,
    test_dataloader,
    class_names,
) = data.create_dataloaders(
    train_dir=Path(f"../{TRAIN_DIR}"),
    dev_dir=Path(f"../{DEV_DIR}"),
    test_dir=Path(f"../{TEST_DIR}"),
    batch_size=BATCH_SIZE,
    device=device,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

## Model

In [4]:
from typing import OrderedDict

def my_load_state_dict(model: nn.Module, input_state_dict: OrderedDict):
    """Copy the matching entries of ``input_state_dict`` into ``model`` in place.

    Entries whose name does not exist in ``model.state_dict()`` are skipped, and
    entries of the model that are missing from ``input_state_dict`` keep their
    current values. All matching entries are shape-checked before anything is
    copied, so a bad checkpoint never leaves the model partially updated.

    Args:
        model (nn.Module): model whose state_dict is updated in place
        input_state_dict (OrderedDict): name -> tensor mapping that overwrites
            the model's current state_dict entries of the same name

    Raises:
        RuntimeError: if a matching entry's shape differs from the shape of the
            model's tensor with the same name.
    """
    own_state = model.state_dict()

    pending = []
    for name, param in input_state_dict.items():
        if name not in own_state:
            continue
        if isinstance(param, nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        target = own_state[name]
        # Tensor.copy_ broadcasts its source, so a wrong-but-broadcastable shape
        # would otherwise be accepted silently; require an exact match instead.
        if target.shape != param.shape:
            raise RuntimeError(
                f"size mismatch for {name}: checkpoint has shape "
                f"{tuple(param.shape)}, model expects {tuple(target.shape)}"
            )
        pending.append((target, param))

    # The state_dict tensors share storage with the model's parameters and
    # buffers, so copying into them updates the model itself.
    with torch.no_grad():
        for target, param in pending:
            target.copy_(param)


def load_model(model_name: str, device: torch.device) -> nn.Module:
    """Build a DeepLabV3-ResNet50 and load weights from ../models/{model_name}/{model_name}.pth.

    The checkpoint path is relative to the current working directory.

    Args:
        model_name (str): name of both the model folder and the .pth file in it
        device (torch.device): device the model is moved to and the checkpoint
            is mapped to; this notebook passes the strings "cuda" / "cpu"

    Returns:
        nn.Module: the model, entirely on ``device``, with a classifier head
            for NUM_CLASSES classes and the checkpoint's matching entries loaded
    """
    # ------------------ Model ------------------
    # instantiate DeepLabV3 model
    model = torchvision.models.segmentation.deeplabv3_resnet50()

    # modify classifier layer for desired number of classes
    model.classifier = DeepLabHead(in_channels=2048, num_classes=NUM_CLASSES)

    # Move to the device only after the head swap: a freshly built DeepLabHead
    # lives on the CPU, so moving earlier would split the model across devices.
    model = model.to(device)

    new_state_dict_path = Path(f"../models/{model_name}/{model_name}.pth")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    new_state_dict = torch.load(new_state_dict_path, map_location=device)
    my_load_state_dict(model, new_state_dict)

    return model

In [5]:
# Name of the saved model: both the folder under ../models and the .pth file inside it.
MODEL_NAME = "Universal Resnet50 23_06_04"

model = load_model(model_name=MODEL_NAME, device=device)

In [6]:
# Per-layer overview of the loaded model for a dummy batch of two
# 3-channel 1024x1024 inputs; rows are labelled with their variable names.
summary(
    model,
    input_size=(2, 3, 1024, 1024),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
DeepLabV3 (DeepLabV3)                              [2, 3, 1024, 1024]   [2, 2, 1024, 1024]   --                   True
├─IntermediateLayerGetter (backbone)               [2, 3, 1024, 1024]   [2, 2048, 128, 128]  --                   True
│    └─Conv2d (conv1)                              [2, 3, 1024, 1024]   [2, 64, 512, 512]    9,408                True
│    └─BatchNorm2d (bn1)                           [2, 64, 512, 512]    [2, 64, 512, 512]    128                  True
│    └─ReLU (relu)                                 [2, 64, 512, 512]    [2, 64, 512, 512]    --                   --
│    └─MaxPool2d (maxpool)                         [2, 64, 512, 512]    [2, 64, 256, 256]    --                   --
│    └─Sequential (layer1)                         [2, 64, 256, 256]    [2, 256, 256, 256]   --                   True
│    │    └─Bottleneck (0)                     