In [1]:
import timm
import torch
from argparse import Namespace
from ml import learner

from pipe import constants

In [2]:
hparams = Namespace(arch='resnet18')
in_channels = 1
CKPT_PATH = '../models/arch=resnet18_sz=128_fold=0.ckpt'

Let's instantiate an ImageClassifier without passing anything. As it stands, it will download pre-trained weights from the internet.

In [3]:
model = learner.ImageClassifier(in_channels=in_channels, num_classes=len(constants.target_cols), target_cols=constants.target_cols, **vars(hparams))

In [4]:
list(model.model.parameters())[0][0]

tensor([[[-0.0238, -0.0419, -0.0152,  0.2014,  0.1783,  0.0514, -0.0585],
         [ 0.0722,  0.0245, -0.3404, -0.8849, -0.8372, -0.4194, -0.0255],
         [ 0.0021,  0.2066,  0.9146,  1.7876,  1.5763,  0.7295,  0.2065],
         [ 0.0511, -0.2216, -0.9161, -1.2994, -0.7589,  0.0356,  0.2399],
         [-0.1118,  0.0292,  0.1983, -0.1444, -0.9120, -1.2490, -0.7778],
         [ 0.0842,  0.0939,  0.1247,  0.6146,  1.2016,  1.1198,  0.4827],
         [-0.0077,  0.0038, -0.0514, -0.1908, -0.4476, -0.2770, -0.0434]]],
       grad_fn=<SelectBackward>)

For reference, the one below are the weights our model learned during training.

In [5]:
checkpoint = torch.load(CKPT_PATH)
checkpoint['state_dict']['model.conv1.weight'][0]

tensor([[[-0.0079, -0.0354,  0.0197,  0.2234,  0.1945,  0.0575, -0.0791],
         [ 0.0843,  0.0221, -0.3155, -0.8817, -0.8399, -0.4252, -0.0661],
         [ 0.0102,  0.2022,  0.9309,  1.8050,  1.5994,  0.7278,  0.1703],
         [ 0.0692, -0.2240, -0.9167, -1.3040, -0.7457,  0.0461,  0.2296],
         [-0.1043,  0.0242,  0.2088, -0.1371, -0.8993, -1.2391, -0.7755],
         [ 0.0819,  0.0707,  0.1195,  0.6113,  1.2126,  1.1297,  0.4838],
         [ 0.0047, -0.0082, -0.0353, -0.1824, -0.4406, -0.2664, -0.0408]]],
       device='cuda:0')

Let's try to load the checkpoint from an already initialized ImageClassifier.

In [6]:
model.load_from_checkpoint(checkpoint_path=CKPT_PATH)
list(model.model.parameters())[0][0]

tensor([[[-0.0238, -0.0419, -0.0152,  0.2014,  0.1783,  0.0514, -0.0585],
         [ 0.0722,  0.0245, -0.3404, -0.8849, -0.8372, -0.4194, -0.0255],
         [ 0.0021,  0.2066,  0.9146,  1.7876,  1.5763,  0.7295,  0.2065],
         [ 0.0511, -0.2216, -0.9161, -1.2994, -0.7589,  0.0356,  0.2399],
         [-0.1118,  0.0292,  0.1983, -0.1444, -0.9120, -1.2490, -0.7778],
         [ 0.0842,  0.0939,  0.1247,  0.6146,  1.2016,  1.1198,  0.4827],
         [-0.0077,  0.0038, -0.0514, -0.1908, -0.4476, -0.2770, -0.0434]]],
       grad_fn=<SelectBackward>)

Interesting! Loading the weights from an already instantiate ImageClassifier doesn't work as you would expect. We are still using the same pre-trained weights shipped with timm, and not those we learned during training.

In [7]:
model2 = learner.ImageClassifier(in_channels=in_channels, num_classes=len(constants.target_cols), target_cols=constants.target_cols, **vars(hparams))
model2.load_from_checkpoint(checkpoint_path=CKPT_PATH)
list(model2.model.parameters())[0][0]

tensor([[[-0.0238, -0.0419, -0.0152,  0.2014,  0.1783,  0.0514, -0.0585],
         [ 0.0722,  0.0245, -0.3404, -0.8849, -0.8372, -0.4194, -0.0255],
         [ 0.0021,  0.2066,  0.9146,  1.7876,  1.5763,  0.7295,  0.2065],
         [ 0.0511, -0.2216, -0.9161, -1.2994, -0.7589,  0.0356,  0.2399],
         [-0.1118,  0.0292,  0.1983, -0.1444, -0.9120, -1.2490, -0.7778],
         [ 0.0842,  0.0939,  0.1247,  0.6146,  1.2016,  1.1198,  0.4827],
         [-0.0077,  0.0038, -0.0514, -0.1908, -0.4476, -0.2770, -0.0434]]],
       grad_fn=<SelectBackward>)

Below is the recommended way to load a checkpoint in PyTorch Lightning (PL).

In [8]:
model3 = learner.ImageClassifier.load_from_checkpoint(
    checkpoint_path='../models/arch=resnet18_sz=128_fold=0.ckpt',
)
list(model3.model.parameters())[0][0]

tensor([[[-0.0079, -0.0354,  0.0197,  0.2234,  0.1945,  0.0575, -0.0791],
         [ 0.0843,  0.0221, -0.3155, -0.8817, -0.8399, -0.4252, -0.0661],
         [ 0.0102,  0.2022,  0.9309,  1.8050,  1.5994,  0.7278,  0.1703],
         [ 0.0692, -0.2240, -0.9167, -1.3040, -0.7457,  0.0461,  0.2296],
         [-0.1043,  0.0242,  0.2088, -0.1371, -0.8993, -1.2391, -0.7755],
         [ 0.0819,  0.0707,  0.1195,  0.6113,  1.2126,  1.1297,  0.4838],
         [ 0.0047, -0.0082, -0.0353, -0.1824, -0.4406, -0.2664, -0.0408]]],
       grad_fn=<SelectBackward>)