In [1]:
from fastai.vision.all import *

In [2]:
# Download/extract the PETS dataset and point at its `images` folder.
path = untar_data(URLs.PETS) / 'images'
fnames = get_image_files(path)

# Label = the part of the file name that precedes the trailing `_<digits>`.
pat = r'/([^/]+)_\d+.*'

# Batch-level augmentation (to 224px, no warping) followed by ImageNet normalisation.
batch_tfms = [
    *aug_transforms(size=224, max_warp=0),
    Normalize.from_stats(*imagenet_stats),
]
# Item-level random square crop, resized to 460px before batching.
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1., 1.))
bs = 64

pets = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=RegexLabeller(pat=pat),  # reuse the pattern defined above
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = pets.dataloaders(path, bs=bs)

In [4]:
!pip install timm

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl.metadata (60 kB)
     ---------------------------------------- 0.0/60.6 kB ? eta -:--:--
     ------ --------------------------------- 10.2/60.6 kB ? eta -:--:--
     ------------ ------------------------- 20.5/60.6 kB 217.9 kB/s eta 0:00:01
     ------------------- ------------------ 30.7/60.6 kB 186.2 kB/s eta 0:00:01
     -------------------------------- ----- 51.2/60.6 kB 260.9 kB/s eta 0:00:01
     -------------------------------------- 60.6/60.6 kB 268.0 kB/s eta 0:00:00
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.20.3-py3-none-any.whl.metadata (12 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.4.2-cp311-none-win_amd64.whl.metadata (3.9 kB)
Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
   ---------------------------------------- 0.0/2.2 MB ? eta -:--:--
    --------------------------------------- 0.0/2.2 MB 2.0 MB/s eta 0:00:02
   -- ------------------------------------- 

In [5]:
from timm import create_model
# Instantiate timm's `vit_tiny_patch16_224` with pretrained weights so its
# structure can be inspected.
net = create_model("vit_tiny_patch16_224", pretrained=True)

  from .autonotebook import tqdm as notebook_tqdm
model.safetensors: 100%|██████████| 22.9M/22.9M [00:01<00:00, 17.8MB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [6]:
# Reference point: a fastai Learner built on ResNet-18, whose generated head
# is inspected in the next cell.
learn = vision_learner(dls, models.resnet18)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\kevol/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:01<00:00, 25.7MB/s]


In [7]:
# The fastai model is indexable; its last element is the classification head.
learn.model[-1]

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): fastai.layers.Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=37, bias=False)
)

In [8]:
# The timm VisionTransformer cannot be indexed like the fastai model above.
# The error is the point of this cell, so catch and print it instead of
# letting it abort a "Restart & Run All".
try:
    net[-1]
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: 'VisionTransformer' object is not subscriptable

In [9]:
class MyModel(nn.Module):
    """
    Two stacked 1->1 linear layers written as a plain `nn.Module`.

    The layers are stored as the attributes `l1` and `l2`, so they are reached
    by name; the module itself does not support integer indexing.
    """
    def __init__(self):
        # `nn.Module.__init__` has to run before any submodule is assigned,
        # otherwise PyTorch refuses the attribute assignment.
        super().__init__()
        self.l1 = nn.Linear(1,1)
        self.l2 = nn.Linear(1,1)  # was `nn.linear`, which does not exist

    def forward(self, x):
        "Apply `l1`, then `l2`, to `x`."
        return self.l2(self.l1(x))

In [10]:
class MyModel(nn.Sequential):
    "Two 1->1 linear layers held in an `nn.Sequential`, which makes them indexable."
    def __init__(self):
        # Hand the layers straight to nn.Sequential, in application order.
        super().__init__(nn.Linear(1,1), nn.Linear(1,1))

In [11]:
# With the nn.Sequential-based MyModel, individual layers can be fetched by index.
net = MyModel()
net[0], net[1]

(Linear(in_features=1, out_features=1, bias=True),
 Linear(in_features=1, out_features=1, bias=True))

In [12]:
def custom_cut_model(model:nn.Module, cut:typing.Union[int, typing.Callable]):
    """
    Return the part of `model` selected by `cut`.

    An integer `cut` keeps the first `cut` direct children of `model` (list
    slicing, so negative values count from the end) and wraps them in an
    `nn.Sequential`. A callable `cut` is called with `model` and its result is
    returned unchanged. Any other value raises `NameError`.
    """
    if isinstance(cut, int):
        kept_children = list(model.children())[:cut]
        return nn.Sequential(*kept_children)
    if callable(cut):
        return cut(model)
    raise NameError("`cut` must either be an integer or a function")

In [13]:
class CustomTimmBody(nn.Module):
    """
    A small submodule to work with `timm` models more easily.

    Wraps `model` (or the result of `custom_cut_model(model, cut)` when `cut`
    is given) and picks the forward path based on the model's `default_cfg`.

    Args:
        model: module exposing a `default_cfg` dict (as `timm` models do).
        pretrained: kept for signature compatibility; not used by this class.
        cut: `None` to keep `model` as is, otherwise an int or callable that
            is forwarded to `custom_cut_model`.
        n_in: kept for signature compatibility; not used by this class.
    """
    def __init__(
        self,
        model,
        pretrained:bool=True,
        cut=None,
        n_in:int=3
    ):
        super().__init__()
        # Value of `pool_size` in the model config, or None when absent; only
        # its truthiness is used in `forward`.
        self.needs_pooling = model.default_cfg.get('pool_size', None)
        self.model = model if cut is None else custom_cut_model(model, cut)

    def forward(self, x):
        """
        Run `x` through the wrapped model.

        Uses `forward_features` when the config declared a `pool_size` and the
        wrapped model still provides that method. A cut model may not provide
        it, in which case the model is called directly instead of failing
        with an AttributeError.
        """
        if self.needs_pooling and hasattr(self.model, 'forward_features'):
            return self.model.forward_features(x)
        return self.model(x)

In [14]:
# Wrap the ViT as the model body and switch it to training mode.
# NOTE(review): `num_classes=0` presumably makes timm drop its classifier so
# the body emits features (the head below consumes 192 of them) — confirm
# against the timm docs.
body = CustomTimmBody(
    create_model("vit_tiny_patch16_224", pretrained=True, num_classes=0, in_chans=3)
).train()

In [15]:
# fastai head going from the body's `num_features` to `dls.c` outputs.
# With `pool=None` the printed head contains no pooling/flatten layers
# (compare with the ResNet head shown earlier).
head = create_head(body.model.num_features, dls.c, pool=None)

In [16]:
# Display the layers of the generated head.
head

Sequential(
  (0): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.25, inplace=False)
  (2): Linear(in_features=192, out_features=512, bias=False)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=512, out_features=37, bias=False)
)

In [17]:
# Random dummy input of shape (2, 3, 224, 224) for a quick forward-pass check.
x = torch.randn(2,3,224,224)

In [18]:
# Smoke test: push the dummy input through body then head and show the
# result together with its shape.
out = head(body(x))
out, out.shape

(tensor([[ 0.2664, -1.4011,  0.9820,  0.2703, -0.1990, -0.6010, -0.3267,  1.3051,
          -0.4930, -1.2725, -0.3302,  0.4171,  0.4434, -0.0224, -0.3191, -0.9780,
          -0.3840, -0.3713, -0.2089,  1.1003,  0.3625, -0.2849,  0.2323,  0.8902,
           0.8680, -1.2252, -0.8373, -0.7704,  1.5352,  0.6110, -0.6006, -0.0841,
           0.9687, -0.8723, -0.3574, -1.4446, -0.5222],
         [ 0.8661,  1.3563, -0.1964,  0.2642, -0.0186,  0.3100, -0.5519,  0.1121,
           0.5343,  1.9768, -1.6937, -0.4661,  0.2464,  0.1269, -0.0559,  0.6972,
           0.2942,  1.0785,  0.5170, -0.8845, -0.1946,  0.3110,  0.5366, -0.6239,
           0.5172,  1.3059,  0.6528,  0.9186, -1.0411, -0.1129,  0.0933,  0.7153,
          -0.3450,  0.6841, -0.6316,  0.8465,  1.3900]], grad_fn=<MmBackward0>),
 torch.Size([2, 37]))

In [19]:
# Initialise the newly created head with fastai's `apply_init`; only `head`
# is passed, so the pretrained body is not touched here.
apply_init(head)

In [20]:
def my_split_func(model:nn.Module):
    "Return two parameter groups: those of `model[0]` and those of `model[1:]`"
    first_part, remainder = model[0], model[1:]
    return L(first_part, remainder).map(params)

In [21]:
def splitter(model):
    "Return two parameter groups: body (`model[0]`) first, head (`model[1]`) second"
    body_part, head_part = model[0], model[1]
    return L(body_part, head_part).map(params)

In [22]:
# Chain body and head into one nn.Sequential (index 0 = body, index 1 = head)
# and pass `splitter` so the Learner gets one parameter group for each.
learn = Learner(
    dls,
    nn.Sequential(body, head),
    splitter=splitter
)

In [23]:
# Print only the last 250 characters of the summary: parameter totals,
# optimizer, loss function and callbacks.
print(learn.summary()[-250:])

l trainable params: 5,605,056
Total non-trainable params: 0

Optimizer used: <function Adam at 0x000002095644B920>
Loss function: FlattenedLoss of CrossEntropyLoss()

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback


In [24]:
# Freeze the model; the summary printed afterwards reports
# "Model frozen up to parameter group #1".
learn.freeze()

In [25]:
# Tail of the summary again, now after freezing. The slice is a little longer
# (295 characters), presumably to fit the extra "Model frozen ..." line.
print(learn.summary()[-295:])

l trainable params: 128,256
Total non-trainable params: 5,476,800

Optimizer used: <function Adam at 0x000002095644B920>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #1

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
