-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
30 lines (23 loc) · 843 Bytes
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Uses a ResNet-50 backbone.
import torchvision.models as models
from torch.nn import Parameter
import torch
import torch.nn as nn
import sys
class Resnet(nn.Module):
    """Wrap a ResNet-style backbone, replacing its last child with a new linear head.

    Every child of ``model`` except the last one (the backbone's own
    classifier) is kept, in order, as the feature extractor ``self.features``.
    A freshly initialised ``nn.Linear`` (``self.linear``) maps the flattened
    features to ``num_classes`` outputs.

    Args:
        model: Backbone module whose ``children()`` end with the classifier to
            drop. If it exposes an ``fc`` attribute that is an ``nn.Linear``
            (as torchvision ResNets do), that layer's ``in_features`` sizes the
            new head. Otherwise the ResNet-50 width of 2048 is assumed.
        num_classes: Number of outputs produced by the new head.
    """

    # Head input width used when the backbone has no ``fc`` Linear layer to
    # inspect; equals the value this class originally hard-coded (ResNet-50).
    _DEFAULT_IN_FEATURES = 2048

    def __init__(self, model, num_classes):
        super().__init__()
        # Drop the backbone's final child (its classifier); keep the rest in order.
        self.features = nn.Sequential(*list(model.children())[:-1])
        # Size the head from the backbone instead of assuming ResNet-50, so
        # e.g. ResNet-18/34 (512-wide) backbones also produce a matching head.
        old_head = getattr(model, "fc", None)
        if isinstance(old_head, nn.Linear):
            in_features = old_head.in_features
        else:
            in_features = self._DEFAULT_IN_FEATURES
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, feature, get_feature=False):
        """Run the backbone and, unless only features are requested, the head.

        Args:
            feature: Input batch for the backbone; presumably images shaped
                (N, 3, H, W) — confirm against callers.
            get_feature: If true, return the backbone output instead of logits.

        Returns:
            When ``get_feature`` is true, the output of ``self.features``
            unchanged (not flattened). Otherwise the head output of shape
            (N, num_classes).
        """
        feature = self.features(feature)
        if get_feature:
            # Skip the head entirely; its result would be discarded anyway.
            return feature
        # flatten(…, 1) keeps the batch dim and, unlike .view, also accepts
        # non-contiguous tensors.
        return self.linear(torch.flatten(feature, 1))
def get_resnet50(num_classes, pretrained=True):
    """Build a torchvision ResNet-50 and wrap it in ``Resnet`` with a new head.

    Args:
        num_classes: Number of outputs of the new linear classifier.
        pretrained: If true, load torchvision's ``IMAGENET1K_V1`` weights (the
            same set the legacy ``pretrained=True`` flag selected); otherwise
            the backbone is randomly initialised.

    Returns:
        A ``Resnet`` instance whose classifier head is freshly initialised.
    """
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the legacy boolean flag.
        backbone = models.resnet50(pretrained=pretrained)
    else:
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet50(weights=weights)
    return Resnet(backbone, num_classes)