In [21]:
import torch
import numpy as np
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50

In [23]:
def seed_everything(seed_value):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (``torch.manual_seed`` seeds the generators of all devices). cuDNN
    autotuning is switched off because it conflicts with deterministic mode.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.backends.cudnn.deterministic = True
    # benchmark=True autotunes convolution algorithms per input shape, and the
    # algorithm it selects can differ between runs, so it must be disabled
    # for results to be reproducible.
    torch.backends.cudnn.benchmark = False


# Fix all random number generators once, up front, so a fresh
# Restart & Run All reproduces the same results.
seed = 100
seed_everything(seed_value=seed)

In [18]:
def prepare_model(backbone_model=deeplabv3_resnet50, num_classes=2, weights='DEFAULT'):
    """Build a torchvision DeepLabV3 segmentation model with a new class count.

    The model is created via ``backbone_model(weights=weights)``. Then the final
    1x1 convolution of the main head (``classifier[4]``) is replaced so that it
    predicts ``num_classes`` channels. When the auxiliary head exists, its final
    convolution (``aux_classifier[4]``) is replaced the same way. The replaced
    layers are freshly initialized, so their pre-trained weights are discarded.
    All other layers keep theirs.

    Args:
        backbone_model: Model builder called as ``backbone_model(weights=...)``.
            Defaults to ``deeplabv3_resnet50``. Other builders are assumed to
            share the same head layout (final conv at index 4); TODO confirm
            for any builder you pass.
        num_classes: Number of output channels for the new head layers.
        weights: Forwarded to the builder's ``weights`` argument. ``'DEFAULT'``
            loads pre-trained weights; ``None`` builds an untrained model.

    Returns:
        The modified model.
    """
    # Use the builder the caller asked for (previously this was ignored and
    # deeplabv3_resnet50 was always used).
    model = backbone_model(weights=weights)

    # Size the new layers from the layers they replace rather than using
    # nn.LazyConv2d: lazy layers stay uninitialized until a first forward
    # pass, so e.g. an optimizer could not be built from the fresh model.
    head = model.classifier[4]
    model.classifier[4] = nn.Conv2d(head.in_channels, num_classes, kernel_size=1)

    # torchvision only builds the auxiliary head when aux_loss is enabled
    # (it is when pre-trained weights are requested); otherwise it is None.
    if model.aux_classifier is not None:
        aux_head = model.aux_classifier[4]
        model.aux_classifier[4] = nn.Conv2d(aux_head.in_channels, num_classes, kernel_size=1)

    return model

In [19]:
# Dummy forward pass: materializes any lazily-initialized layers and checks the
# output shape. It runs in eval mode under no_grad. In train mode, BatchNorm
# layers would fold the statistics of this random input into the pre-trained
# running mean/variance, and an unneeded autograd graph would be built.
# (Eval mode also lifts the train-mode requirement of batch size >= 2.)
model = prepare_model(num_classes=2)
model.eval()
with torch.no_grad():
    out = model(torch.randn((2, 3, 384, 384)))
model.train()  # back to training mode for fine-tuning
print(out['out'].shape)  # torch.Size([2, 2, 384, 384])
assert out['out'].shape == (2, 2, 384, 384), out['out'].shape

torch.Size([2, 2, 384, 384])
