In [None]:
import torch
from torch import nn
import torchvision
from torchvision import models, transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary

# Report the installed library versions so results can be tied to an environment.
for _label, _lib in (
    ("PyTorch Version: ", torch),
    ("Torchvision Version: ", torchvision),
):
    print(_label, _lib.__version__)

PyTorch Version:  2.0.1+cu118
Torchvision Version:  0.15.2+cu118


#### Model

torchvision models: https://pytorch.org/vision/stable/models.html

In [None]:
# Fine-tuning switch read by the classifier-replacement cell: when True, every
# parameter of the pre-trained model gets requires_grad=False before the new
# classification head is attached; when False, the loaded parameters are left
# untouched.
FREEZE = False

In [None]:
# Instantiate GoogLeNet initialised with torchvision's default pre-trained weights.
# (ResNet-50 alternative: models.resnet50(weights=ResNet50_Weights.DEFAULT))
from torchvision.models import GoogLeNet_Weights

model = torchvision.models.googlenet(weights=GoogLeNet_Weights.DEFAULT)

# Dump the layer hierarchy to inspect the architecture.
print(model)

Downloading: "https://download.pytorch.org/models/googlenet-1378be20.pth" to /root/.cache/torch/hub/checkpoints/googlenet-1378be20.pth
100%|██████████| 49.7M/49.7M [00:00<00:00, 85.5MB/s]

GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track




In [None]:
# Swap the pre-trained classification head for a fresh 6-way linear layer.
# Read the head's input width first, while the original layer is still attached.
num_features = model.fc.in_features

# Optionally freeze the network: Module.requires_grad_ sets requires_grad on
# every parameter in one call. This runs before the new head is created, so the
# replacement layer's own parameters are not touched by it.
if FREEZE:
    model.requires_grad_(False)

model.fc = torch.nn.Linear(in_features=num_features, out_features=6)
print(model.fc)


Linear(in_features=1024, out_features=6, bias=True)


#### Dataset

In [None]:
class TransferDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(transformed_image, class_index)`` pairs.

    Args:
        img_paths: Sequence of image file paths.
        transform: Callable applied to each loaded image (e.g. the preprocessing
            preset that ships with the pre-trained weights).
        labels: Optional sequence of class indices aligned with ``img_paths``.
            When omitted, every item is returned with the placeholder label ``-1``.
        loader: Optional callable ``path -> image``. Defaults to torchvision's
            ``default_loader`` (the loader ``ImageFolder`` uses), which by
            default yields an RGB PIL image.

    Raises:
        ValueError: If ``labels`` is given and its length differs from
            ``img_paths``.
    """

    def __init__(self, img_paths, transform, labels=None, loader=None):
        if labels is not None and len(labels) != len(img_paths):
            raise ValueError(
                f"got {len(labels)} labels for {len(img_paths)} image paths"
            )
        self.img_paths = img_paths
        # transform from pre-trained model
        self.transform = transform
        self.labels = labels
        # Fall back to torchvision's stock loader only when no custom loader is
        # supplied; it is reached through the notebook's `import torchvision`.
        self.loader = (
            loader if loader is not None else torchvision.datasets.folder.default_loader
        )

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th image with its class index."""
        # Read img
        path = self.img_paths[idx]
        img = self.loader(path)
        # transform img
        img = self.transform(img)
        # -1 marks "no label available" (e.g. inference-only datasets).
        cls_idx = -1 if self.labels is None else self.labels[idx]
        return img, cls_idx

In [None]:
from torchvision.models import GoogLeNet_Weights, ResNet50_Weights, resnet50

# Preprocess Transform
# Use the preprocessing preset bundled with the weights the model was actually
# built from (GoogLeNet). Each weights enum carries its own resize/crop/normalise
# settings, so pairing a model with another architecture's preset feeds it
# inputs prepared differently from what its weights expect.
# For a ResNet-50 backbone use instead: ResNet50_Weights.DEFAULT.transforms()
transform = GoogLeNet_Weights.DEFAULT.transforms()
print(transform)

# Build Dataset (empty path list here; presumably a placeholder for real image paths)
ds = TransferDataset([], transform)

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


```python
from torchvision.models import resnet50, ResNet50_Weights

resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

# Equivalent spellings (the `pretrained` argument is the old, deprecated API)
# pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

# no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated
```

#### Ref:

[Official: Transfer learning tutorials](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

[Official: Finetuning torchvision models tutorial](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#load-data)

[Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch](https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch)