In [12]:
import numpy as np
import os
import glob
import argparse
import backbone
import torch.nn as nn


import torch
from torch.autograd import Variable
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.weight_norm import WeightNorm

#Redefine Conv4 here :

def init_layer(L):
    # Initialization using fan-in
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
        L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)    
    
class ConvNet(nn.Module):
    maml = False # Default

    def __init__(self, depth, n_way=-1, flatten=True, padding=1, bn=False):
        super(ConvNet, self).__init__()
        layers = []
        for i in range(depth):
            indim = 3 if i == 0 else 64
            outdim = 64
            if self.maml:
                conv_layer = Conv2d_fw(indim, outdim, 3, padding=padding)
                if bn:
                    BN     = BatchNorm2d_fw(outdim)
            else:
                conv_layer = nn.Conv2d(indim, outdim, 3, stride=1, padding=padding, bias=False)
                if bn:
                    BN     = nn.BatchNorm2d(outdim)
            
            relu = nn.ReLU(inplace=True)
            layers.append(conv_layer)
            if bn:
                layers.append(BN)
            layers.append(relu)

            if i < 4:  # Pooling only for the first 4 layers
                pool = nn.MaxPool2d(2)
                layers.append(pool)

            # Initialize the layers
            init_layer(conv_layer)
            if bn:
                init_layer(BN)

        if flatten:
            layers.append(Flatten())
        
        if n_way>0:
            layers.append(nn.Linear(1600,n_way))
            self.final_feat_dim = n_way
        else:
            self.final_feat_dim = 1600
            
        self.trunk = nn.Sequential(*layers)
        

    def forward(self, x):
        out = self.trunk(x)
        return out

def Conv4_mine():
    print("Conv4 No Batch Normalization")
    return ConvNet(4, bn=True)

Conv4_mine = Conv4_mine()

Conv4 No Batch Normalization
[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)]
[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)]
[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=2

In [11]:
# Simple Conv Block
class ConvBlock(nn.Module):
    maml = False #Default
    def __init__(self, indim, outdim, pool = True, padding = 1):
        super(ConvBlock, self).__init__()
        self.indim  = indim
        self.outdim = outdim
        if self.maml:
            self.C      = Conv2d_fw(indim, outdim, 3, padding = padding)
            self.BN     = BatchNorm2d_fw(outdim)
        else:
            self.C      = nn.Conv2d(indim, outdim, 3, stride = 1, padding = padding, bias = True) # Default : bias=False
            self.BN     = nn.BatchNorm2d(outdim)
        self.relu   = nn.ReLU(inplace=True)

        self.parametrized_layers = [self.C, self.BN, self.relu]
        if pool:
            self.pool   = nn.MaxPool2d(2)  # Originally : self.pool = nn.MaxPool2d(2)
            self.parametrized_layers.append(self.pool)

        for layer in self.parametrized_layers:
            init_layer(layer)

        self.trunk = nn.Sequential(*self.parametrized_layers)


    def forward(self,x):
        out = self.trunk(x)
        return out
    
    
class ConvNet(nn.Module):
    def __init__(self, depth, flatten = True):
        super(ConvNet,self).__init__()
        trunk = []
        for i in range(depth):
            indim = 3 if i == 0 else 64
            outdim = 64
            B = ConvBlock(indim, outdim, pool = ( i <4 ) ) #only pooling for fist 4 layers
            trunk.append(B)
            print(trunk)

        if flatten:
            trunk.append(Flatten())

        self.trunk = nn.Sequential(*trunk)
        self.final_feat_dim = 1600

    def forward(self,x):
        out = self.trunk(x)
        return out
    
def Conv4():
    print("Conv4")
    return ConvNet(4)

Conv4_original = Conv4()

Conv4
[ConvBlock(
  (C): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (trunk): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)]
[ConvBlock(
  (C): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (trunk): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d

In [10]:
print('Conv4 Mine State Dict')
for key, value in Conv4_mine.state_dict().items():
    print(f'Key : {key}   ;   Value : {value.shape}')
print('')
print('============================================')
print('')
print('Conv4 theirs State Dict')
for key, value in Conv4_original.state_dict().items():
    print(f'Key : {key}   ;   Value : {value.shape}')

Conv4 Mine State Dict
Key : trunk.0.weight   ;   Value : torch.Size([64, 3, 3, 3])
Key : trunk.1.weight   ;   Value : torch.Size([64])
Key : trunk.1.bias   ;   Value : torch.Size([64])
Key : trunk.1.running_mean   ;   Value : torch.Size([64])
Key : trunk.1.running_var   ;   Value : torch.Size([64])
Key : trunk.1.num_batches_tracked   ;   Value : torch.Size([])
Key : trunk.4.weight   ;   Value : torch.Size([64, 64, 3, 3])
Key : trunk.5.weight   ;   Value : torch.Size([64])
Key : trunk.5.bias   ;   Value : torch.Size([64])
Key : trunk.5.running_mean   ;   Value : torch.Size([64])
Key : trunk.5.running_var   ;   Value : torch.Size([64])
Key : trunk.5.num_batches_tracked   ;   Value : torch.Size([])
Key : trunk.8.weight   ;   Value : torch.Size([64, 64, 3, 3])
Key : trunk.9.weight   ;   Value : torch.Size([64])
Key : trunk.9.bias   ;   Value : torch.Size([64])
Key : trunk.9.running_mean   ;   Value : torch.Size([64])
Key : trunk.9.running_var   ;   Value : torch.Size([64])
Key : trunk.9.nu