Enable pretrained w. load #11
o-laurent committed Jun 12, 2023
1 parent f9c4cc3 commit 6dc0f68
Showing 2 changed files with 79 additions and 25 deletions.
70 changes: 50 additions & 20 deletions torch_uncertainty/models/resnet/packed.py
@@ -1,5 +1,5 @@
 # fmt: off
-from typing import List, Type, Union
+from typing import Any, Dict, List, Type, Union

 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,14 +20,14 @@

 weight_ids = {
     "10": {
-        "18": None,
+        "18": "pe_resnet18_c10",
         "32": None,
         "50": "pe_resnet50_c10",
         "101": None,
         "152": None,
     },
     "100": {
-        "18": None,
+        "18": "pe_resnet18_c100",
         "32": None,
         "50": "pe_resnet50_c100",
         "101": None,
@@ -40,13 +40,6 @@
         "101": None,
         "152": None,
     },
-    "1000_wider": {
-        "18": None,
-        "32": None,
-        "50": "pex4_resnet50",
-        "101": None,
-        "152": None,
-    },
 }
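Both levels of weight_ids are keyed by strings, which is what the loader fixes below address: the old lookup weight_ids[str(num_classes)][18] used an integer key and would raise a KeyError, while ["18"] resolves correctly. A quick illustration, with values taken from the table above:

    weight_ids["10"]["50"]  # -> "pe_resnet50_c10"
    weight_ids["10"][50]    # KeyError: the integer key does not exist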


@@ -211,6 +204,9 @@ def __init__(
         super().__init__()

         self.in_channels = in_channels
+        self.alpha = alpha
+        self.gamma = gamma
+        self.groups = groups
         self.num_estimators = num_estimators
         self.in_planes = 64
         block_planes = self.in_planes
@@ -350,6 +346,15 @@ def forward(self, x: Tensor) -> Tensor:
         out = self.linear(out)
         return out

+    def check_config(self, config: Dict[str, Any]) -> bool:
+        """Check if the pretrained configuration matches the current model."""
+        return (
+            (config["alpha"] == self.alpha)
+            * (config["gamma"] == self.gamma)
+            * (config["groups"] == self.groups)
+            * (config["num_estimators"] == self.num_estimators)
+        )
+

 def packed_resnet18(
     in_channels: int,
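Note that check_config multiplies comparison results rather than chaining `and`; since Python booleans are integers, the product is 1 only when every field matches, and any single mismatch drives it to 0. A standalone sketch of the same idea, with hypothetical config values:

    config = {"alpha": 2, "gamma": 1, "groups": 1, "num_estimators": 4}
    matches = (
        (config["alpha"] == 2)
        * (config["gamma"] == 1)
        * (config["groups"] == 1)
        * (config["num_estimators"] == 4)
    )
    print(bool(matches))  # True only if all four hyperparameters agree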
@@ -386,10 +391,15 @@ def packed_resnet18(
         style=style,
     )
     if pretrained:  # coverage: ignore
-        weights = weight_ids[str(num_classes)][18]
+        weights = weight_ids[str(num_classes)]["18"]
         if weights is None:
             raise ValueError("No pretrained weights for this configuration")
-        net.load_state_dict(load_hf(weights))
+        state_dict, config = load_hf(weights)
+        if not net.check_config(config):
+            raise ValueError(
+                "Pretrained weights do not match current configuration."
+            )
+        net.load_state_dict(state_dict)
     return net


@@ -428,10 +438,15 @@ def packed_resnet34(
         style=style,
     )
     if pretrained:  # coverage: ignore
-        weights = weight_ids[str(num_classes)][34]
+        weights = weight_ids[str(num_classes)]["34"]
         if weights is None:
             raise ValueError("No pretrained weights for this configuration")
-        net.load_state_dict(load_hf(weights))
+        state_dict, config = load_hf(weights)
+        if not net.check_config(config):
+            raise ValueError(
+                "Pretrained weights do not match current configuration."
+            )
+        net.load_state_dict(state_dict)
     return net


@@ -470,10 +485,15 @@ def packed_resnet50(
         style=style,
     )
     if pretrained:  # coverage: ignore
-        weights = weight_ids[str(num_classes)][50]
+        weights = weight_ids[str(num_classes)]["50"]
         if weights is None:
             raise ValueError("No pretrained weights for this configuration")
-        net.load_state_dict(load_hf(weights))
+        state_dict, config = load_hf(weights)
+        if not net.check_config(config):
+            raise ValueError(
+                "Pretrained weights do not match current configuration."
+            )
+        net.load_state_dict(state_dict)
     return net


@@ -512,10 +532,15 @@ def packed_resnet101(
         style=style,
     )
     if pretrained:  # coverage: ignore
-        weights = weight_ids[str(num_classes)][101]
+        weights = weight_ids[str(num_classes)]["101"]
        if weights is None:
             raise ValueError("No pretrained weights for this configuration")
-        net.load_state_dict(load_hf(weights))
+        state_dict, config = load_hf(weights)
+        if not net.check_config(config):
+            raise ValueError(
+                "Pretrained weights do not match current configuration."
+            )
+        net.load_state_dict(state_dict)
     return net


@@ -556,8 +581,13 @@ def packed_resnet152(
         style=style,
     )
     if pretrained:  # coverage: ignore
-        weights = weight_ids[str(num_classes)][152]
+        weights = weight_ids[str(num_classes)]["152"]
         if weights is None:
             raise ValueError("No pretrained weights for this configuration")
-        net.load_state_dict(load_hf(weights))
+        state_dict, config = load_hf(weights)
+        if not net.check_config(config):
+            raise ValueError(
+                "Pretrained weights do not match current configuration."
+            )
+        net.load_state_dict(state_dict)
     return net
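Taken together, each loader now validates the Hub config before touching the model. A minimal usage sketch, assuming the public factory signature matches the parameters shown above (the exact import path, defaults, and style values may differ in the released API):

    from torch_uncertainty.models.resnet import packed_resnet50

    # num_estimators, alpha, and gamma must match the values stored in the
    # pretrained config.yaml, otherwise check_config() rejects the weights.
    net = packed_resnet50(
        in_channels=3,
        num_estimators=4,
        alpha=2,
        gamma=2,
        num_classes=10,
        style="cifar",
        pretrained=True,
    )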
34 changes: 29 additions & 5 deletions torch_uncertainty/utils/hub.py
@@ -1,8 +1,32 @@
 # fmt: off
+from pathlib import Path
+from typing import Dict, Tuple
+
+import torch
+import yaml
 from huggingface_hub import hf_hub_download


-def load_hf(weight_id: str):
-    weights = hf_hub_download(
-        repo_id=f"torch-uncertainty/{weight_id}", filename=f"{weight_id}.ckpt"
-    )
-    return weights
-# fmt: on
+def load_hf(weight_id: str) -> Tuple[torch.Tensor, Dict]:
+    """Load a model from the huggingface hub.
+
+    Args:
+        weight_id (str): The id of the model to load.
+
+    Returns:
+        Tuple[torch.Tensor, Dict]: The model weights and config.
+    """
+    repo_id = f"torch-uncertainty/{weight_id}"
+
+    # Load the weights
+    weight_path = hf_hub_download(repo_id=repo_id, filename=f"{weight_id}.ckpt")
+    weight = torch.load(weight_path)
+    if "state_dict" in weight:
+        weight = weight["state_dict"]
+
+    # Load the config
+    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
+    config = yaml.safe_load(Path(config_path).read_text())
+
+    return weight, config
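The two-element return lets callers validate the checkpoint before mutating the model. A minimal sketch, using a weight id from the weight_ids table above (the config keys mirror those read by check_config):

    state_dict, config = load_hf("pe_resnet50_c10")
    # config carries the hyperparameters the checkpoint was trained with
    print(config["alpha"], config["gamma"], config["num_estimators"])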
