In [None]:
from torchvision import models
import torch.nn as nn

def get_model(cfg_model, device):
    """Build a torchvision classifier with a task-specific head and move it to ``device``.

    Args:
        cfg_model: Config object. Reads ``type`` (``'resnet50'`` or
            ``'efficientnet_b0'``) and ``freeze_backbone`` (truthy/falsy).
            Optionally reads ``num_classes`` (defaults to 2) and
            ``pretrained`` (a torchvision weights spec such as ``'DEFAULT'``,
            a weights enum, ``None``, or a bool; defaults to ``'DEFAULT'``).
        device: Target device passed to ``model.to``.

    Returns:
        The model moved to ``device``. When ``freeze_backbone`` is truthy, only
        the parameters of the replaced classification head keep
        ``requires_grad=True``.

    Raises:
        ValueError: If ``cfg_model.type`` is not a supported architecture.
    """
    model_type = cfg_model.type
    num_classes = getattr(cfg_model, 'num_classes', 2)

    # torchvision >= 0.13 only accepts a weights enum, its string name, or
    # None for ``weights``; map legacy boolean flags onto that API.
    weights = getattr(cfg_model, 'pretrained', 'DEFAULT')
    if weights is True:
        weights = 'DEFAULT'
    elif weights is False:
        weights = None

    if model_type == 'resnet50':
        model = models.resnet50(weights=weights)

        num_features = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        )
        head = model.fc

    elif model_type == 'efficientnet_b0':
        model = models.efficientnet_b0(weights=weights)
        # ``classifier`` is Sequential(Dropout, Linear) and has no
        # ``in_features`` of its own: read it from the final Linear and
        # replace only that layer, keeping torchvision's dropout in front.
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_features, num_classes)
        head = model.classifier

    else:
        raise ValueError(f"Unsupported model type: {model_type!r}")

    if cfg_model.freeze_backbone:
        # Freeze by module, not by parameter-name substring: EfficientNet's
        # squeeze-excitation layers are named ``fc1``/``fc2`` and would
        # otherwise be mistaken for the head and left trainable.
        for param in model.parameters():
            param.requires_grad = False
        for param in head.parameters():
            param.requires_grad = True

    return model.to(device)