In [105]:
import torch
import torch.nn as nn
import torchvision.models as models

In [106]:
base_model = models.resnet18(pretrained=False)

In [107]:
print(list(base_model.children())[-1])


Linear(in_features=512, out_features=1000, bias=True)


In [108]:
num_cls = 30

In [86]:
class PretextModel(nn.Module):
    def __init__(self, num_cls):
        super(PretextModel, self).__init__()
        base_model = models.resnet50(pretrained=False)
        self.base_model = nn.Sequential(*list(base_model.children())[:-1])
        self.classifier = nn.Linear(2048, num_cls)
    
    def forward(self, x):
        x = self.base_model(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    

In [87]:
pretext_model = PretextModel(30)


In [88]:
inp = torch.randn(5, 3, 245,245)
out = pretext_model(inp)
print(out.shape)

torch.Size([5, 30])


## Simple net

In [103]:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
    
        self.conv_layers = nn.Sequential(nn.Conv2d(3, 64, 5),
                                         nn.BatchNorm2d(64),
                                         nn.MaxPool2d(4, 4),
                                         nn.Conv2d(64, 128, 3),
                                         nn.BatchNorm2d(128),
                                         nn.MaxPool2d(2, 2),
                                         nn.Conv2d(128, 256, 3),
                                         nn.BatchNorm2d(256),
                                         nn.MaxPool2d(2, 2),
                                         nn.Conv2d(256, 512, 3),
                                         nn.BatchNorm2d(512),
                                         nn.MaxPool2d(2, 2),) # B x 256 x 5 x 5)
        self.linear = nn.Linear(512 * 2 * 2, 128)
    
    def forward(self, x):
        x = self.conv_layers(x)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

In [104]:
inp = torch.randn(5, 3, 128,128)
simple_model = SimpleModel()
out = simple_model(inp)
print(out.shape)

torch.Size([5, 512, 2, 2])
torch.Size([5, 128])
