Skip to content

Commit

Permalink
added model source parameter; fixes #56
Browse files Browse the repository at this point in the history
  • Loading branch information
hahahannes committed Aug 8, 2022
1 parent 32236ee commit 1c174f4
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 75 deletions.
76 changes: 76 additions & 0 deletions tests/model/test_load_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import unittest

import tests.helper as helper

import numpy as np
import tensorflow as tf
from torchvision import transforms as T
from thingsvision.model_class import Model

class ModelLoadingTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
helper.create_test_images()

def test_mode_and_device(self):
model_name = 'vgg16_bn'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
self.assertTrue(hasattr(model.model, helper.DEVICE))
self.assertFalse(model.model.training)

def test_load_model_without_source(self):
model_name = 'vgg16'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG')

with self.assertRaises():
model_name = 'random_model'
model = Model(model_name, False, 'cpu')

def test_load_custom_user_model(self):
model_name = 'VGG16bn_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG', source='custom')

model_name = 'Resnet50_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'ResNet', source='custom')

model_name = 'Alexnet_ecoset'
model = Model(model_name, False, 'cpu')
print(model.__class__.__name__)
self.assertEqual(model.model.__class__.__name__, 'AlexNet', source='custom')

with self.assertRaises():
model_name = 'random_model'
model = Model(model_name, False, 'cpu', source='custom')

def test_load_timm_models(self):
model_name = 'mixnet_l'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'EfficientNet', source='timm')

with self.assertRaises():
model_name = 'random_model'
model = Model(model_name, False, 'cpu', source='timm')

def test_load_torchvision_models(self):
model_name = 'vgg16'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG', source='torchvision')

with self.assertRaises():
model_name = 'random_model'
model = Model(model_name, False, 'cpu', source='torchvision')

def test_load_keras_models(self):
model_name = 'VGG16'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG', source='keras')

with self.assertRaises():
model_name = 'random_model'
model = Model(model_name, False, 'cpu', source='keras')


30 changes: 2 additions & 28 deletions tests/model/test_model.py → tests/model/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@
from torchvision import transforms as T
from thingsvision.model_class import Model

class ModelLoadingTestCase(unittest.TestCase):
class ModelTransformationsTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
helper.create_test_images()

def test_mode_and_device(self):
model_name = 'vgg16_bn'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
self.assertTrue(hasattr(model.model, helper.DEVICE))
self.assertFalse(model.model.training)

def test_transformations_clip(self):
model_name = 'clip-RN'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
Expand All @@ -38,24 +32,4 @@ def test_transformations_cnn(self):
model_name = 'VGG16'
model, dataset, dl = helper.create_model_and_dl(model_name, 'tf')
transforms = model.get_transformations()
self.assertTrue(isinstance(transforms, tf.keras.Sequential))

def test_load_custom_user_model(self):
model_name = 'VGG16bn_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG')

model_name = 'Resnet50_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'ResNet')

model_name = 'Alexnet_ecoset'
model = Model(model_name, False, 'cpu')
print(model.__class__.__name__)
self.assertEqual(model.model.__class__.__name__, 'AlexNet')

def test_load_timm_models(self):
model_name = 'mixnet_l'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'EfficientNet')

self.assertTrue(isinstance(transforms, tf.keras.Sequential))
138 changes: 91 additions & 47 deletions thingsvision/model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def __init__(self,
pretrained: bool,
device: str,
model_path: str=None,
backend: str='pt'):
backend: str=None,
source: str=None):
"""
Parameters
----------
Expand All @@ -52,70 +53,113 @@ def __init__(self,
backend : str (optional)
Deep learning framework that should be used.
'pt' for PyTorch and 'tf' for Tensorflow
source: str (optional)
Source of model and weights. If not set, all
models sources are searched for the model name
until the first occurence.
"""

self.model_name = model_name
self.backend = backend
self.pretrained = pretrained
self.device = device
self.model_path = model_path
self.source = source
self.backend = backend
self.load_model()


def load_model_from_torchvision(self):
if hasattr(torchvision_models, self.model_name):
self.backend = 'pt'
model = getattr(torchvision_models, self.model_name)
return model(pretrained=self.pretrained)

def load_model_from_timm(self):
if self.model_name in timm.list_models():
self.backend = 'pt'
return timm.create_model(self.model_name, self.pretrained)

def load_model_from_custom_models(self):
if re.search(r'^clip', self.model_name):
clip_model_name = "RN50"
if re.search(r'ViT$', self.model_name):
clip_model_name = "ViT-B/32"
model, self.clip_n_px = clip.load(
clip_model_name,
device=self.device,
model_path=self.model_path,
pretrained=self.pretrained,
jit=False,
)
return model

if re.search(r'^cornet', self.model_name):
try:
model = getattr(cornet, f'cornet_{self.model_name[-1]}')
except:
model = getattr(cornet, f'cornet_{self.model_name[-2:]}')
model = model(pretrained=self.pretrained, map_location=torch.device(self.device))
model = model.module # remove DataParallel
return model

if hasattr(custom_models, self.model_name):
model = getattr(custom_models, self.model_name)
return model(self.device, self.backend).create_model()

def load_model_from_keras(self):
if hasattr(tensorflow_models, self.model_name):
self.backend = 'tf'
model = getattr(tensorflow_models, self.model_name)
if self.pretrained:
weights = 'imagenet'
elif self.model_path:
weights = self.model_path
else:
weights = None

return model(weights=weights)


def load_model(self) -> Tuple[Any, Any]:
"""Load a pretrained *torchvision* or CLIP model into memory."""
if self.backend == 'pt':
if re.search(r'^clip', self.model_name):
clip_model_name = "RN50"
if re.search(r'ViT$', self.model_name):
clip_model_name = "ViT-B/32"
self.model, self.clip_n_px = clip.load(
clip_model_name,
device=self.device,
model_path=self.model_path,
pretrained=self.pretrained,
jit=False,
)
"""Load a pretrained model into memory."""

if self.source:
model = None
if self.source == 'timm':
model = self.get_model_from_timm()
elif self.source == 'keras':
model = self.get_model_from_keras()
elif self.source == 'torchvision':
model = self.get_model_from_torchvision()
elif self.source == 'custom':
model = self.get_model_from_custom_models()

if not model:
raise Exception(f'Model {self.model_name} not found in {self.source}')
elif not self.source:
model = None
for found_model in [self.load_model_from_torchvision(),
self.load_model_from_timm(),
self.load_model_from_keras(),
self.load_model_from_custom_models()]:
if found_model:
model = found_model
break
if found_model:
self.model = found_model
else:
device = torch.device(self.device)
if re.search(r'^cornet', self.model_name):
try:
self.model = getattr(cornet, f'cornet_{self.model_name[-1]}')
except:
self.model = getattr(cornet, f'cornet_{self.model_name[-2:]}')
self.model = self.model(pretrained=self.pretrained, map_location=device)
self.model = self.model.module # remove DataParallel
elif hasattr(custom_models, self.model_name):
self.model = getattr(custom_models, self.model_name)(self.device, self.backend).create_model()
elif hasattr(torchvision_models, self.model_name):
self.model = getattr(torchvision_models, self.model_name)
self.model = self.model(pretrained=self.pretrained)
elif self.model_name in timm.list_models():
self.model = timm.create_model(self.model_name, self.pretrained)

self.model = self.model.to(device)

raise Exception(f'Model {self.model_name} not found in all sources')

if self.backend == 'pt':
device = torch.device(self.device)
if self.model_path:
try:
state_dict = torch.load(self.model_path, map_location=device)
except FileNotFoundError:
state_dict = torch.hub.load_state_dict_from_url(self.model_path, map_location=device)
self.model.load_state_dict(state_dict)
self.model.eval()
elif self.backend == 'tf':
if hasattr(custom_models, self.model_name):
self.model = getattr(custom_models, self.model_name)(self.device, self.backend).create_model()
else:
model = getattr(tensorflow_models, self.model_name)
if self.pretrained:
weights = 'imagenet'
elif self.model_path:
weights = self.model_path
else:
weights = None

self.model = model(weights=weights)

self.model = self.model.to(device)

def show(self) -> str:
"""Show architecture of model to select a layer."""
Expand Down

0 comments on commit 1c174f4

Please sign in to comment.