In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

In [3]:
# Randomly initialised ResNet-18 (no pretrained weights), used in the next
# cell to inspect its final layer.
base_model = models.resnet18(
    pretrained=False,
)

In [4]:
# ResNet's final child is its classification head (`fc`); display it via the
# rich repr instead of print(). Its in_features is the feature width a new head must accept.
base_model.fc


Linear(in_features=512, out_features=1000, bias=True)


In [5]:
num_cls: int = 30  # number of output classes for the pretext classifier head

In [6]:
class PretextModel(nn.Module):
    """ResNet-50 trunk (no pretrained weights) with a fresh linear head.

    Args:
        num_cls: number of outputs produced by the linear classifier head.

    ``forward(x)`` returns a ``(batch, num_cls)`` tensor of unnormalised scores.
    """

    def __init__(self, num_cls):
        super().__init__()
        backbone = models.resnet50(pretrained=False)
        # Read the head's input width from the backbone instead of hard-coding
        # 2048: a different ResNet variant (e.g. resnet18 -> 512) would
        # otherwise give a silent shape mismatch at forward time.
        in_features = backbone.fc.in_features
        # Keep every child except the final `fc` classification layer.
        self.base_model = nn.Sequential(*list(backbone.children())[:-1])
        self.classifier = nn.Linear(in_features, num_cls)

    def forward(self, x):
        # assumes x is an image batch shaped (batch, 3, H, W) — TODO confirm
        features = self.base_model(x)
        # Flatten everything after the batch dimension into one vector per sample.
        features = torch.flatten(features, 1)
        return self.classifier(features)
    

In [7]:
# Use the config constant rather than repeating the literal class count.
pretext_model = PretextModel(num_cls)


In [8]:
# Smoke test: 5 random 3x245x245 inputs should yield one score per pretext class.
# NOTE(review): the model is still in training mode here, so this forward pass
# updates BatchNorm running statistics; re-running the cell changes model state.
inp = torch.randn(5, 3, 245, 245)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = pretext_model(inp)
assert out.shape == (5, num_cls), out.shape
out.shape

torch.Size([5, 30])


## Simple CNN

A small two-convolution baseline whose linear head is sized for 3×245×245 inputs.

In [37]:
class SimpleModel(nn.Module):
    """Small two-convolution CNN mapping square images to a flat output vector.

    Args:
        in_channels: number of channels in the input images (default 3).
        input_size: height/width of the square input images (default 245).
            The linear head is sized from this value, so inputs must match it.
        out_features: width of the output layer (default 120).

    Raises:
        ValueError: if ``input_size`` is too small to pass through the
            conv/pool stack.
    """

    def __init__(self, in_channels=3, input_size=245, out_features=120):
        super().__init__()
        side = self._feature_side(input_size)
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small: the conv/pool stack "
                f"would reduce it to {side} pixels"
            )
        # NOTE(review): there is no activation (e.g. nn.ReLU) between layers;
        # max-pooling is the only non-linearity. Confirm this is intended.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 6, 7),
            nn.MaxPool2d(4, 4),
            nn.BatchNorm2d(6),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(4, 4),
        )  # -> B x 16 x side x side (side == 13 for a 245x245 input)
        self.linear = nn.Linear(16 * side * side, out_features)

    @staticmethod
    def _feature_side(size):
        """Spatial side length after ``conv_layers`` for a ``size`` x ``size`` input."""
        size = size - 7 + 1         # Conv2d(kernel=7), stride 1, no padding
        size = (size - 4) // 4 + 1  # MaxPool2d(kernel=4, stride=4)
        size = size - 5 + 1         # Conv2d(kernel=5), stride 1, no padding
        size = (size - 4) // 4 + 1  # MaxPool2d(kernel=4, stride=4)
        return size

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)  # one feature vector per sample
        return self.linear(x)

In [39]:
# Smoke test: 5 random 3x245x245 inputs should each map to a 120-dim output.
inp = torch.randn(5, 3, 245, 245)
simple_model = SimpleModel()
with torch.no_grad():  # shape check only; no autograd graph needed
    out = simple_model(inp)
assert out.shape == (5, 120), out.shape
out.shape

torch.Size([5, 120])
