# Model

> Everything related to models: building, checkpoint loading, and other utilities.

In [None]:
#| default_exp model

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os

import torch
import torch.nn as nn
from torchvision import transforms, models



In [None]:
#| export
def build_model(backbone="resnet18", num_classes=3, pretrained=True):
    """Build a torchvision classifier whose last linear layer has ``num_classes`` outputs.

    Args:
        backbone: Architecture name; ``"resnet18"`` or ``"efficientnet"``
            (which maps to ``efficientnet_b0``).
        num_classes: Output size of the replacement classification layer.
        pretrained: If truthy, start from torchvision's ImageNet weights;
            otherwise the backbone is randomly initialized.

    Returns:
        The model with its final ``nn.Linear`` swapped for a new one of the
        requested size (the new layer is always freshly initialized).

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """
    # ``pretrained=`` is deprecated in torchvision in favour of ``weights=``.
    # "IMAGENET1K_V1" is the weight set the legacy ``pretrained=True`` flag
    # resolved to for both backbones, so the loaded weights are unchanged.
    weights = "IMAGENET1K_V1" if pretrained else None

    if backbone == "resnet18":
        model = models.resnet18(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif backbone == "efficientnet":
        model = models.efficientnet_b0(weights=weights)
        # classifier is (Dropout, Linear); only the Linear head is replaced.
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    else:
        raise ValueError(
            f"Unsupported backbone: {backbone!r}. Choose either 'resnet18' or 'efficientnet'."
        )
    return model

In [None]:
#| export
def load_ckpt(model, ckpt_path):
    """Load a state dict from ``ckpt_path`` into ``model`` when the path exists.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint contents.
        ckpt_path: Path to a file written with ``torch.save(state_dict, ...)``.
            ``None`` or an empty string means "no checkpoint configured".

    Returns:
        The same ``model`` object, updated in place if a checkpoint was loaded.
    """
    # Guard against None/"" first: os.path.exists(None) raises TypeError.
    if ckpt_path and os.path.exists(ckpt_path):
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code.
        # The file is fed straight to load_state_dict, so nothing else is needed.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded pretrained weights from {ckpt_path}")
    else:
        # The model may already carry pretrained weights from its constructor,
        # so do not claim that the weights are random.
        print(f"No checkpoint found at {ckpt_path}. Keeping the model's current weights.")
    return model

In [None]:
#| export
def init_model(cfg):
    """Build the model described by ``cfg`` and set which parameters are trainable.

    Reads ``cfg.task``, ``cfg.model.backbone``, ``cfg.model.pretrained``,
    ``cfg.model.ckpt`` and ``cfg.data.label_names`` (its length is the number
    of output classes).

    Trainable parameters per task:
        * ``"eval"``: none (every parameter is frozen).
        * ``"probing"``: only the final linear classification layer.
        * ``"fine-tuning"``: all parameters.

    Returns:
        The constructed model, with the checkpoint applied if one was found.

    Raises:
        ValueError: If ``cfg.task`` is not one of the three supported tasks.
    """
    # Validate first so a config typo is reported before any model is built
    # or any weights are loaded.
    if cfg.task not in ("eval", "probing", "fine-tuning"):
        raise ValueError("Unsupported type. Choose either 'eval', 'probing', or 'fine-tuning'.")

    model = build_model(
        backbone=cfg.model.backbone,
        num_classes=len(cfg.data.label_names),
        pretrained=cfg.model.pretrained,
    )
    model = load_ckpt(model, cfg.model.ckpt)

    # Fine-tuning trains everything; the other tasks start fully frozen.
    train_everything = cfg.task == "fine-tuning"
    for param in model.parameters():
        param.requires_grad = train_everything

    if cfg.task == "probing":
        # getattr() cannot follow a dotted path such as "classifier.1", so
        # select the head explicitly, mirroring where build_model installs it.
        if cfg.model.backbone == "resnet18":
            head = model.fc
        else:
            head = model.classifier[1]
        for param in head.parameters():
            param.requires_grad = True

    return model

In [None]:
#| hide
# Manual smoke test (not exported): load a config from disk and run init_model on it.
from omegaconf import OmegaConf
cfg = OmegaConf.load("../cfgs/task_1/resnet/eval.yaml")
# Override the checkpoint path with one relative to this notebook's directory.
cfg.model.ckpt = "../pretrained_backbone/ckpt_resnet18_ep50.pt"

# Prints whether the checkpoint was found (see load_ckpt).
model = init_model(cfg)


Loaded pretrained weights from ../pretrained_backbone/ckpt_resnet18_ep50.pt


  state_dict = torch.load(ckpt_path, map_location='cpu')


In [None]:
#| hide
# Export every cell marked `#| export` in this notebook to the Python package.
import nbdev

nbdev.nbdev_export()