Commit

added model source parameter; fixes #56

hahahannes committed Aug 10, 2022
1 parent 32236ee commit 6c20312
Showing 8 changed files with 187 additions and 80 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -9,7 +9,7 @@

## Model collection

Features can be extracted for all models in [torchvision](https://pytorch.org/vision/0.8/models.html), all models in [Keras](https://www.tensorflow.org/api_docs/python/tf/keras/applications), all models in [timm](https://github.com/rwightman/pytorch-image-models), custom models trained on Ecoset, each of the [CORnet](https://github.com/dicarlolab/CORnet) versions, and both [CLIP](https://github.com/openai/CLIP) variants (`clip-ViT` and `clip-RN`). Note that the respective model name must be used. For example, if you want to use the VGG16 model from torchvision, you will have to use `vgg16`, and if you want to use the VGG16 model from Keras, you will have to use the model name `VGG16`.<br>
Features can be extracted for all models in [torchvision](https://pytorch.org/vision/0.8/models.html), all models in [Keras](https://www.tensorflow.org/api_docs/python/tf/keras/applications), all models in [timm](https://github.com/rwightman/pytorch-image-models), custom models trained on Ecoset, each of the [CORnet](https://github.com/dicarlolab/CORnet) versions, and both [CLIP](https://github.com/openai/CLIP) variants (`clip-ViT` and `clip-RN`). Note that the respective model name must be used. For example, if you want to use the VGG16 model from torchvision, you will have to use `vgg16`, and if you want to use the VGG16 model from Keras, you will have to use the model name `VGG16`. You can further specify the model source by setting the `source` parameter.<br>
For the correct abbreviations of [torchvision](https://pytorch.org/vision/0.8/models.html) models, have a look [here](https://github.com/pytorch/vision/tree/master/torchvision/models). For the correct abbreviations of [CORnet](https://github.com/dicarlolab/CORnet) models, look [here](https://github.com/dicarlolab/CORnet/tree/master/cornet). To separate the string `cornet` from its variant (e.g., `s`, `z`), use a hyphen instead of an underscore (e.g., `cornet-s`, `cornet-z`).<br>

Examples: `alexnet`, `resnet18`, `resnet50`, `resnet101`, `vit_b_16`, `vit_b_32`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`, `cornet-s`, `clip-ViT`
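
As a quick illustration of the new `source` parameter, here is a short sketch based on the README example in this commit (it is not an additional change in the diff):

```python
import torch
from thingsvision.model_class import Model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# torchvision's VGG16 (PyTorch backend, lowercase model name)
pt_model = Model('vgg16', pretrained=True, model_path=None,
                 device=device, backend='pt', source='torchvision')

# Keras' VGG16 (TensorFlow backend, uppercase model name)
tf_model = Model('VGG16', pretrained=True, model_path=None,
                 device=device, backend='tf', source='keras')
```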
@@ -64,7 +64,7 @@ You can find the jupyter notebook using `PyTorch` [here](https://colab.research.

6. If you happen to extract hidden unit activations for many images, it is possible to run into `MemoryErrors`. To circumvent such problems, a helper function called `split_activations` will split the activation matrix into several batches and store them in separate files. For now, the split parameter is set to `10`; hence, the function will split the activation matrix into `10` files. This parameter can, however, easily be modified in case you need more (or fewer) splits. To merge the separate activation batches back into a single activation matrix, just call `merge_activations` when loading the activations (e.g., `activations = merge_activations(PATH)`); see the sketch below.
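
A minimal loading sketch; the import location is an assumption based on the README's other snippets, and the path is hypothetical:

```python
# assumption: merge_activations is importable from thingsvision.vision
from thingsvision.vision import merge_activations

PATH = './activations/alexnet/features.10'  # hypothetical output directory

# merge the activation batches written by `split_activations` back into a
# single (n_images x n_units) matrix, as described above
activations = merge_activations(PATH)
```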

## Extract features at a specific layer of a state-of-the-art `torchvision`, `TensorFlow`, `CORnet`, or `CLIP` model
## Extract features at a specific layer of a state-of-the-art `torchvision`, `TensorFlow`, `CORnet`, `CLIP`, or `timm` model

The following examples demonstrate how to load a model with PyTorch or TensorFlow into memory, and how to subsequently extract features.
Please keep in mind that the model names as well as the layer names depend on the backend. If you use PyTorch, you will need to use these [model names](https://pytorch.org/vision/stable/models.html). If you use TensorFlow, you will need to use these [model names](https://keras.io/api/applications/). You can find the layer names by using `model.show()`.
@@ -79,11 +79,12 @@ import thingsvision.vision as vision
from thingsvision.model_class import Model

model_name = 'alexnet'
source = 'torchvision'
backend = 'pt'
batch_size = 64

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(model_name, pretrained=True, model_path=None, device=device, backend=backend)
model = Model(model_name, pretrained=True, model_path=None, device=device, backend=backend, source=source)
module_name = model.show()

AlexNet(
4 changes: 4 additions & 0 deletions doc/pytorch.ipynb
@@ -253,12 +253,16 @@
"source": [
"## load model\n",
"model_name = 'vgg16_bn' \n",
"# specify the model source as VGG16 can be used from torchvision, timm, ... In this case torchvision is used (https://pytorch.org/vision/stable/models.html)\n",
"source = 'torchvision' \n",
"\n",
"model = Model(\n",
" model_name,\n",
" pretrained=pretrained,\n",
" model_path=model_path,\n",
" device=device,\n",
" backend=backend,\n",
" source=source\n",
")"
]
},
4 changes: 4 additions & 0 deletions doc/tensorflow.ipynb
@@ -262,12 +262,16 @@
"source": [
"## load model\n",
"model_name = 'VGG16' \n",
"# specify the model source, in this case use Keras applications (https://keras.io/api/applications/)\n",
"source = 'keras' \n",
"\n",
"model = Model(\n",
" model_name,\n",
" pretrained=pretrained,\n",
" model_path=model_path,\n",
" device=device,\n",
" backend=backend,\n",
" source=source\n",
")"
]
},
2 changes: 1 addition & 1 deletion tests/helper.py
@@ -25,7 +25,7 @@

PT_MODEL_AND_MODULE_NAMES = {
# Torchvision models
'vgg16_bn': {
'vgg16': {
'modules': ['features.23', 'classifier.3'],
'pretrained': True
},
81 changes: 81 additions & 0 deletions tests/model/test_load_model.py
@@ -0,0 +1,81 @@
import unittest

import tests.helper as helper

import numpy as np
import tensorflow as tf
from torchvision import transforms as T
from thingsvision.model_class import Model

class ModelLoadingTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
helper.create_test_images()

def check_model_loading(self, model_name, expected_class_name, source=None, backend='pt'):
model = Model(model_name, False, 'cpu', source=source, backend=backend)
self.assertEqual(model.model.__class__.__name__, expected_class_name)

def check_unknown_model_loading(self, model_name, expected_exception, source=None, backend='pt'):
with self.assertRaises(Exception) as e:
model = Model(model_name, False, 'cpu', source=source, backend=backend)
self.assertEqual(str(e.exception), expected_exception)

def test_mode_and_device(self):
model_name = 'vgg16_bn'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
self.assertTrue(hasattr(model.model, helper.DEVICE))
self.assertFalse(model.model.training)

def test_load_model_without_source(self):
model_name = 'vgg16' # PyTorch
self.check_model_loading(model_name, 'VGG')

model_name = 'VGG16' # Tensorflow
self.check_model_loading(model_name, 'Functional', backend='tf')

model_name = 'random'
self.check_unknown_model_loading(model_name, f'Model {model_name} not found in all sources')

def test_load_custom_user_model(self):
source = 'custom'

model_name = 'VGG16bn_ecoset'
self.check_model_loading(model_name, 'VGG', source)

model_name = 'Resnet50_ecoset'
self.check_model_loading(model_name, 'ResNet', source)

model_name = 'Alexnet_ecoset'
self.check_model_loading(model_name, 'AlexNet', source)

model_name = 'random'
self.check_unknown_model_loading(model_name, f'Model {model_name} not found in {source}')

def test_load_timm_models(self):
model_name = 'mixnet_l'
source='timm'
self.check_model_loading(model_name, 'EfficientNet', source)

model_name = 'random'
self.check_unknown_model_loading(model_name, f'Model {model_name} not found in {source}')

def test_load_torchvision_models(self):
model_name = 'vgg16'
source='torchvision'
self.check_model_loading(model_name, 'VGG', source)

model_name = 'random'
self.check_unknown_model_loading(model_name, f'Model {model_name} not found in {source}')

def test_load_keras_models(self):
source = 'keras'
model_name = 'VGG16'
backend = 'tf'
self.check_model_loading(model_name, 'Functional', source, backend)

model_name = 'random'
self.check_unknown_model_loading(model_name, f'Model {model_name} not found in {source}', source, backend)
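
The new tests can be executed with the standard `unittest` runner; a minimal sketch (not part of the commit), assuming the repository root is on `PYTHONPATH` so that `tests.model.test_load_model` is importable:

```python
# hypothetical helper script, for illustration only
import unittest

if __name__ == '__main__':
    # run only the new model-loading tests, with verbose output
    unittest.main(module='tests.model.test_load_model', verbosity=2)
```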


30 changes: 2 additions & 28 deletions tests/model/test_model.py → tests/model/test_transformations.py
@@ -7,18 +7,12 @@
from torchvision import transforms as T
from thingsvision.model_class import Model

class ModelLoadingTestCase(unittest.TestCase):
class ModelTransformationsTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls):
helper.create_test_images()

def test_mode_and_device(self):
model_name = 'vgg16_bn'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
self.assertTrue(hasattr(model.model, helper.DEVICE))
self.assertFalse(model.model.training)

def test_transformations_clip(self):
model_name = 'clip-RN'
model, dataset, dl = helper.create_model_and_dl(model_name, 'pt')
@@ -38,24 +38,4 @@ def test_transformations_cnn(self):
model_name = 'VGG16'
model, dataset, dl = helper.create_model_and_dl(model_name, 'tf')
transforms = model.get_transformations()
self.assertTrue(isinstance(transforms, tf.keras.Sequential))

def test_load_custom_user_model(self):
model_name = 'VGG16bn_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'VGG')

model_name = 'Resnet50_ecoset'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'ResNet')

model_name = 'Alexnet_ecoset'
model = Model(model_name, False, 'cpu')
print(model.__class__.__name__)
self.assertEqual(model.model.__class__.__name__, 'AlexNet')

def test_load_timm_models(self):
model_name = 'mixnet_l'
model = Model(model_name, False, 'cpu')
self.assertEqual(model.model.__class__.__name__, 'EfficientNet')

self.assertTrue(isinstance(transforms, tf.keras.Sequential))
2 changes: 1 addition & 1 deletion tests/test_rest.py
@@ -12,7 +12,7 @@ class RDMTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
helper.create_test_images()
model_name = 'vgg16_bn'
model_name = 'vgg16'
model, _, dl = helper.create_model_and_dl(model_name, 'pt')
module_name = helper.PT_MODEL_AND_MODULE_NAMES[model_name]['modules'][0]
features, _ = model.extract_features(
137 changes: 90 additions & 47 deletions thingsvision/model_class.py
@@ -28,7 +28,8 @@ def __init__(self,
pretrained: bool,
device: str,
model_path: str=None,
backend: str='pt'):
backend: str=None,
source: str=None):
"""
Parameters
----------
@@ -52,70 +53,112 @@
backend : str (optional)
Deep learning framework that should be used.
'pt' for PyTorch and 'tf' for Tensorflow
source : str (optional)
Source of model and weights. If not set, all
model sources are searched for the model name
until the first occurrence.
"""

self.model_name = model_name
self.backend = backend
self.pretrained = pretrained
self.device = device
self.model_path = model_path
self.source = source
self.backend = backend
self.load_model()


def load_model(self) -> Tuple[Any, Any]:
"""Load a pretrained *torchvision* or CLIP model into memory."""
if self.backend == 'pt':
if re.search(r'^clip', self.model_name):
clip_model_name = "RN50"
if re.search(r'ViT$', self.model_name):
clip_model_name = "ViT-B/32"
self.model, self.clip_n_px = clip.load(
clip_model_name,
device=self.device,
model_path=self.model_path,
pretrained=self.pretrained,
jit=False,
)
def get_model_from_torchvision(self):
if hasattr(torchvision_models, self.model_name):
model = getattr(torchvision_models, self.model_name)
return model(pretrained=self.pretrained)

def get_model_from_timm(self):
if self.model_name in timm.list_models():
return timm.create_model(self.model_name, self.pretrained)

def get_model_from_custom_models(self):
if re.search(r'^clip', self.model_name):
clip_model_name = "RN50"
if re.search(r'ViT$', self.model_name):
clip_model_name = "ViT-B/32"
model, self.clip_n_px = clip.load(
clip_model_name,
device=self.device,
model_path=self.model_path,
pretrained=self.pretrained,
jit=False,
)
return model

if re.search(r'^cornet', self.model_name):
try:
model = getattr(cornet, f'cornet_{self.model_name[-1]}')
except:
model = getattr(cornet, f'cornet_{self.model_name[-2:]}')
model = model(pretrained=self.pretrained, map_location=torch.device(self.device))
model = model.module # remove DataParallel
return model

if hasattr(custom_models, self.model_name):
model = getattr(custom_models, self.model_name)
model = model(self.device, self.backend).create_model()
return model

def get_model_from_keras(self):
if hasattr(tensorflow_models, self.model_name):
model = getattr(tensorflow_models, self.model_name)
if self.pretrained:
weights = 'imagenet'
elif self.model_path:
weights = self.model_path
else:
device = torch.device(self.device)
if re.search(r'^cornet', self.model_name):
try:
self.model = getattr(cornet, f'cornet_{self.model_name[-1]}')
except:
self.model = getattr(cornet, f'cornet_{self.model_name[-2:]}')
self.model = self.model(pretrained=self.pretrained, map_location=device)
self.model = self.model.module # remove DataParallel
elif hasattr(custom_models, self.model_name):
self.model = getattr(custom_models, self.model_name)(self.device, self.backend).create_model()
elif hasattr(torchvision_models, self.model_name):
self.model = getattr(torchvision_models, self.model_name)
self.model = self.model(pretrained=self.pretrained)
elif self.model_name in timm.list_models():
self.model = timm.create_model(self.model_name, self.pretrained)

self.model = self.model.to(device)
weights = None
return model(weights=weights)

def load_model_from_source(self):
model = None
if self.source == 'timm':
model = self.get_model_from_timm()
elif self.source == 'keras':
model = self.get_model_from_keras()
elif self.source == 'torchvision':
model = self.get_model_from_torchvision()
elif self.source == 'custom':
model = self.get_model_from_custom_models()

if not model:
raise Exception(f'Model {self.model_name} not found in {self.source}')
else:
self.model = model

def load_model(self) -> Tuple[Any, Any]:
"""Load a pretrained model into memory."""
if self.source:
self.load_model_from_source()
elif not self.source:
for model_loader in [self.get_model_from_torchvision,
self.get_model_from_timm,
self.get_model_from_keras,
self.get_model_from_custom_models]:
found_model = model_loader()
if found_model:
self.model = found_model
break

if not self.model:
raise Exception(f'Model {self.model_name} not found in all sources')

if self.backend == 'pt':
device = torch.device(self.device)
if self.model_path:
try:
state_dict = torch.load(self.model_path, map_location=device)
except FileNotFoundError:
state_dict = torch.hub.load_state_dict_from_url(self.model_path, map_location=device)
self.model.load_state_dict(state_dict)
self.model.eval()
elif self.backend == 'tf':
if hasattr(custom_models, self.model_name):
self.model = getattr(custom_models, self.model_name)(self.device, self.backend).create_model()
else:
model = getattr(tensorflow_models, self.model_name)
if self.pretrained:
weights = 'imagenet'
elif self.model_path:
weights = self.model_path
else:
weights = None

self.model = model(weights=weights)

self.model = self.model.to(device)

def show(self) -> str:
"""Show architecture of model to select a layer."""
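
To summarize the loading behavior introduced here, a short usage sketch (illustrative, not part of the diff): when `source` is omitted, the loader searches torchvision, timm, Keras, and the custom models in that order and takes the first match; when `source` is given, only that source is consulted and an unknown model name raises an exception.

```python
from thingsvision.model_class import Model

# explicit source: only timm is searched
timm_model = Model('mixnet_l', pretrained=False, device='cpu',
                   backend='pt', source='timm')

# no source: torchvision -> timm -> keras -> custom, first hit wins
auto_model = Model('vgg16', pretrained=False, device='cpu', backend='pt')

# unknown names raise an Exception naming the searched source(s)
try:
    Model('random', pretrained=False, device='cpu',
          backend='pt', source='torchvision')
except Exception as err:
    print(err)  # e.g. "Model random not found in torchvision"
```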