In [6]:
from unlabeled_extrapolation.models import bit_resnet, vit_model, timm_model
from unlabeled_extrapolation.utils import utils
import importlib
import timm
import torch
# Re-execute the project modules so edits to their source are picked up
# without restarting the kernel.
for _module in (bit_resnet, vit_model, utils, timm_model):
    importlib.reload(_module)

<module 'unlabeled_extrapolation.models.timm_model' from '/juice/scr/ananya/cifar_experiments/transfer_learning/unlabeled_extrapolation/models/timm_model.py'>

In [2]:
def get_layers_freeze_test(model, ks=None, count_fn=None):
    """Sanity-check a model's ``get_layers`` / ``freeze_bottom_k`` API.

    Prints the trainable-parameter count, one sample layer and the number of
    layers. Then it freezes progressively more bottom layers and prints the
    trainable-parameter count after each step.

    NOTE: ``freeze_bottom_k`` is applied to ``model`` in place. With the
    default ``ks``, the model is left with every layer returned by
    ``get_layers()`` frozen.

    Args:
        model: object exposing ``get_layers()`` (returning a sized, indexable
            sequence) and ``freeze_bottom_k(k=...)``.
        ks: values of k to freeze, in order. Defaults to ``1``, ``2`` and the
            number of layers. Values larger than the layer count are dropped,
            and duplicates are removed so small models are not frozen twice
            with the same k.
        count_fn: callable ``(model, trainable=True) -> int`` counting
            parameters. Defaults to ``utils.count_parameters``.
    """
    if count_fn is None:
        count_fn = utils.count_parameters
    print('num params before freezing: ', count_fn(model, trainable=True))

    # Fetch the layer list once instead of once per use.
    layers = model.get_layers()
    num_layers = len(layers)
    # Show the second layer as a sample, falling back to the first so that
    # single-layer models do not raise IndexError.
    if num_layers > 0:
        print(layers[min(1, num_layers - 1)])
    print(f'num layers: {num_layers}')

    if ks is None:
        ks = [k for k in (1, 2, num_layers) if k <= num_layers]
    # dict.fromkeys keeps the first occurrence order while dropping repeats.
    for k in dict.fromkeys(ks):
        model.freeze_bottom_k(k=k)
        print(f'num params after freezing {k}: {count_fn(model, trainable=True)}')


In [50]:
# BiT ResNet-50 and ResNet-101: load each pretrained checkpoint and run the
# get_layers / freeze_bottom_k sanity check on it.
# NOTE(review): absolute, user-specific directory; better moved to a config cell.
BIT_CHECKPOINT_DIR = "/u/scr/ananya/simclr_weights"

bit_resnets = {}
for bit_model_name in ('BiT-M-R50x1', 'BiT-M-R101x1'):
    # Checkpoints are stored as <model name>.npz inside BIT_CHECKPOINT_DIR.
    bit_resnets[bit_model_name] = bit_resnet.BitResNet(
        model_name=bit_model_name,
        checkpoint_path=f"{BIT_CHECKPOINT_DIR}/{bit_model_name}.npz",
    )
    get_layers_freeze_test(bit_resnets[bit_model_name])

# Keep the original variable names available to later cells.
resnet50 = bit_resnets['BiT-M-R50x1']
resnet101 = bit_resnets['BiT-M-R101x1']

num params before freezing:  68256659
('conv-block-0', Sequential(
  (unit01): PreActBottleneck(
    (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (relu): ReLU(inplace=True)
    (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (unit02): PreActBottleneck(
    (gn1): GroupNorm(32, 256, eps=1e-05, affine=True)
    (conv1): StdConv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv3)

In [57]:
# DINO ViT-B/16: attach a new last layer with DINO_NUM_CLASSES outputs, then
# sanity-check layer freezing.
# NOTE(review): 10 presumably matches a CIFAR-10 target task; confirm.
DINO_NUM_CLASSES = 10
model = vit_model.VitModel(model_name='dino_vitb16')
model.new_last_layer(DINO_NUM_CLASSES)
get_layers_freeze_test(model)


Downloading: "https://github.com/facebookresearch/dino/archive/main.zip" to /sailhome/ananya/.cache/torch/hub/main.zip


num params before freezing:  85806346
('empty_ln_pre', Module())
54
1
num params after freezing 1: 85215754
2
num params after freezing 2: 85215754
54
num params after freezing 54: 0


In [48]:
# timm `vit_small_patch16_224` through the VitModel wrapper. Unlike the DINO
# cell, no new_last_layer call is made, so the loaded head is kept.
model = vit_model.VitModel(
    model_name='timm.vit_small_patch16_224',
)
get_layers_freeze_test(model)


num params before freezing:  22050664
('empty_ln_pre', Module())
54
num params after freezing 1: 21755368
num params after freezing 2: 21755368
num params after freezing 54: 0


In [3]:
# ConvNeXt-Base (`convnext_base_in22k`) through the TimmModel wrapper.
# The freeze test ends by freezing every layer returned by get_layers(), and
# the forward-pass cell below reuses this `model`.
model = timm_model.TimmModel(
    'convnext_base_in22k',
)
get_layers_freeze_test(model)

num params before freezing:  109953489
('stage0', ConvNeXtStage(
  (downsample): Identity()
  (blocks): Sequential(
    (0): ConvNeXtBlock(
      (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
      (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
    )
    (1): ConvNeXtBlock(
      (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
      (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=512

In [9]:
# Smoke test: a forward pass on an all-zero dummy batch shaped
# (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).
# NOTE(review): assumes a CUDA device is available and the model already lives on it.
BATCH_SIZE, IMAGE_SIZE = 8, 224
dummy_images = torch.zeros((BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)).cuda()
# Inference only: don't build an autograd graph.
with torch.no_grad():
    logits = model(dummy_images)
assert logits.shape[0] == BATCH_SIZE, logits.shape
logits

tensor([[-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625],
        [-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625],
        [-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625],
        ...,
        [-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625],
        [-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625],
        [-2.1558,  0.1969, -2.2230,  ..., -2.2479,  0.7122,  0.2625]],
       device='cuda:0')

In [10]:
# Inspect the head of the underlying timm model (the wrapper's private
# `_model` attribute). As the cell's last expression it uses rich display.
model._model.head

Sequential(
  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
  (norm): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (drop): Dropout(p=0.0, inplace=False)
  (fc): Linear(in_features=1024, out_features=21841, bias=True)
)
