<a href="https://colab.research.google.com/github/LuckerZOfficiaL/A-Contrastive-Learning-Approach-for-Finger-Photo-Identification/blob/main/Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Architecture Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.models as models

In [None]:
class ResnetClassifier(nn.Module):
    """ResNet-18 backbone followed by a projection head and a linear classifier.

    Pipeline (the full chain is ``forward``)::

        [1x1 conv, 1 -> 3 channels, only for single-channel input]
        -> ResNet-18 -> LayerNorm -> Linear(hidden_size) -> LayerNorm
        -> GELU -> Dropout -> Linear(num_classes) -> Softmax

    Intermediate stages are exposed through ``forward_backbone``,
    ``forward_projector`` and ``forward_logits``.

    Args:
        num_classes: Number of output classes.
        hidden_size: Width of the projection layer.
        initialize: ``"xavier"`` or ``"kaiming"`` to re-initialise the
            projector / classifier weights; any other value keeps PyTorch's
            default initialisation.
        dropout_p: Dropout probability applied before the classifier
            (keyword-only, default 0.2).
        pretrained: Load ImageNet ``IMAGENET1K_V1`` weights into ResNet-18
            (keyword-only, default True).
    """

    def __init__(self, num_classes=100, hidden_size=256, initialize="xavier",
                 *, dropout_p=0.2, pretrained=True):
        super().__init__()
        # 1x1 conv that lifts single-channel input to the 3 channels ResNet expects.
        self.conv_1to3 = nn.Conv2d(1, 3, kernel_size=1, stride=1, padding=0)
        self.pretrained_model = self._build_resnet18(pretrained)
        # Output width of the backbone (1000 for torchvision's ResNet-18 head);
        # read from the model rather than hard-coded.
        backbone_dim = self.pretrained_model.fc.out_features
        # LayerNorm rather than BatchNorm1d: batch norm fails when batch size is 1.
        self.intermediate_norm = nn.LayerNorm(backbone_dim)
        self.projector = nn.Linear(backbone_dim, hidden_size)
        self.final_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.activation = nn.GELU()
        self.fc = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=1)

        self._init_weights(initialize)

    @staticmethod
    def _build_resnet18(pretrained):
        """Return torchvision's ResNet-18, optionally with IMAGENET1K_V1 weights."""
        if hasattr(models, "ResNet18_Weights"):
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
            # ``weights=``; IMAGENET1K_V1 is what ``pretrained=True`` loaded.
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            return models.resnet18(weights=weights)
        return models.resnet18(pretrained=pretrained)

    def _init_weights(self, initialize):
        """Re-initialise projector / classifier weights according to ``initialize``.

        Only ``fc.bias`` is reset (to 0); ``projector.bias`` keeps PyTorch's
        default. Unrecognised values leave every layer at its default init.
        """
        if initialize == "kaiming":
            # leaky_relu with the default negative slope a=0 uses the ReLU gain
            # sqrt(2). NOTE(review): the actual activation is GELU, so this is
            # an approximation.
            init.kaiming_uniform_(self.projector.weight, mode='fan_in',
                                  nonlinearity='leaky_relu')
            init.kaiming_uniform_(self.fc.weight, mode='fan_out')
            init.constant_(self.fc.bias, 0)
        elif initialize == "xavier":
            init.xavier_uniform_(self.projector.weight)
            init.xavier_uniform_(self.fc.weight)
            init.constant_(self.fc.bias, 0)

    def _to_rgb(self, x: torch.Tensor) -> torch.Tensor:
        """Map single-channel input to 3 channels; pass other inputs through."""
        # Channel axis is dim 1, so x must be batched (N, C, H, W) as Conv2d expects.
        return self.conv_1to3(x) if x.shape[1] == 1 else x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over dim 1) for batch ``x``.

        Use ``forward_logits`` with losses such as ``nn.CrossEntropyLoss``
        that expect raw, unnormalised scores.
        """
        return self.softmax(self.forward_logits(x))

    def forward_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw classifier scores of shape (batch, num_classes)."""
        x = self.forward_projector(x)
        x = self.final_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.fc(x)

    def forward_projector(self, x: torch.Tensor) -> torch.Tensor:
        """Return projector output (batch, hidden_size), before final norm/activation."""
        return self.projector(self.intermediate_norm(self.forward_backbone(x)))

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw ResNet-18 output for ``x`` (grayscale lifted to 3 channels)."""
        return self.pretrained_model(self._to_rgb(x))


In [None]:
# Disabled shape check: the code is kept inside a string literal so it does not
# run (constructing the model with pretrained=True may download ResNet-18 weights).
# When it was executed, forward_projector on an (8, 1, H, W) batch with
# hidden_size=256 produced torch.Size([8, 256]).
"""x = torch.rand((8, 1, 123, 331))
model = ResnetClassifier(num_classes=100, hidden_size=256, initialize="xavier")
model.forward_projector(x).shape"""

torch.Size([8, 256])