In [3]:
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [11]:
class ResNet50ViT(nn.Module):
    """Hybrid classifier: an ImageNet ResNet-50 backbone feeding a Transformer encoder.

    The ResNet-50 feature map (2048 channels) is projected to 768 channels with a
    1x1 convolution. Each spatial location becomes one token and a learnable class
    token is prepended. Learnable position embeddings are added, and the encoder's
    class-token output goes through an MLP head.

    Args:
        num_classes: number of output logits.
        img_size: input side length the position-embedding table is sized for.
            Inputs of other sizes are handled by resampling that table in ``forward``.
        freeze_backbone: when True (default) every ResNet-50 parameter gets
            ``requires_grad=False`` and the backbone is kept in eval mode, so its
            BatchNorm running statistics are not updated during training either.
    """

    def __init__(self, num_classes=2, img_size=224, freeze_backbone=True):
        # Zero-arg super() does not look up the global class name, which a notebook
        # cell may rebind (e.g. to an instance).
        super().__init__()
        embed_dim = 768
        self.freeze_backbone = freeze_backbone

        # Load the ImageNet-pretrained ResNet-50.
        resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        if freeze_backbone:
            for param in resnet50.parameters():
                param.requires_grad = False

        # Drop the last two children (global average pool AND fully connected layer)
        # so the backbone returns the spatial feature map [B, 2048, h, w].
        self.resnet50 = nn.Sequential(*list(resnet50.children())[:-2])
        if freeze_backbone:
            # requires_grad=False does not stop BatchNorm from updating running stats
            # in train mode; keep the frozen backbone in eval mode (see train()).
            self.resnet50.eval()

        # 1x1 conv to reduce ResNet-50 output channels to the transformer embedding size.
        self.conv2d = nn.Conv2d(in_channels=2048, out_channels=embed_dim, kernel_size=1)

        # Learnable class token prepended to the patch tokens.
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # ResNet-50 downsamples 5 times by stride 2 (conv1, maxpool, layer2-4), and each
        # stage rounds up, so the output side length is ceil(img_size / 32):
        # 224 -> 7, but e.g. 100 -> 4 (img_size // 32 would give 3 and crash forward()).
        self.grid_size = -(-img_size // 32)
        num_patches = self.grid_size ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))  # +1 for class token

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, dim_feedforward=2048, dropout=0.1, batch_first=True),
            num_layers=6
        )

        # Classification head applied to the class-token output.
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def train(self, mode=True):
        """Set train/eval mode on the whole model, keeping a frozen backbone in eval mode.

        ``nn.Module.eval()`` delegates to ``train(False)``, so this covers both calls.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.resnet50.eval()
        return self

    def _position_embeddings(self, height, width):
        """Return position embeddings [1, height*width + 1, 768] for an (height x width) grid.

        The learned table is returned as-is when the grid matches the size it was built
        for. Otherwise the patch part is bilinearly resampled to (height, width), so
        other input resolutions get spatially consistent positions. The class-token
        embedding is kept unchanged.
        """
        if height == self.grid_size and width == self.grid_size:
            return self.pos_embedding
        cls_pos = self.pos_embedding[:, :1]
        patch_pos = self.pos_embedding[:, 1:]
        dim = patch_pos.shape[-1]
        # [1, g*g, D] -> [1, D, g, g] so interpolate() treats D as channels.
        grid = patch_pos.reshape(1, self.grid_size, self.grid_size, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=(height, width), mode="bilinear", align_corners=False)
        # Back to row-major tokens, matching the token order built in forward().
        patch_pos = grid.permute(0, 2, 3, 1).reshape(1, height * width, dim)
        return torch.cat([cls_pos, patch_pos], dim=1)

    def forward(self, x):
        """Return class logits of shape [batch_size, num_classes].

        Args:
            x: image batch, presumably [batch_size, 3, H, W] normalised for the
               ImageNet ResNet-50 weights -- TODO confirm against the data pipeline.
        """
        # ResNet-50 features, then channel reduction: [B, 768, h, w].
        features = self.conv2d(self.resnet50(x))
        batch_size, _, height, width = features.shape

        # One token per spatial location in row-major order: [B, h*w, 768].
        tokens = features.flatten(2).transpose(1, 2)

        # Prepend the class token: [B, h*w + 1, 768].
        class_token = self.class_token.expand(batch_size, -1, -1)
        tokens = torch.cat([class_token, tokens], dim=1)

        # [1, h*w + 1, 768] broadcasts over the batch dimension.
        tokens = tokens + self._position_embeddings(height, width)

        encoded = self.transformer_encoder(tokens)  # [B, h*w + 1, 768]

        # Classify from the class-token output.
        return self.fc(encoded[:, 0])

In [12]:
from torchinfo import summary

# Bind the instance to its own name: assigning it to `ResNet50ViT` would shadow the
# class, so re-running this cell (or building another model) would fail.
vit_model = ResNet50ViT()
summary(vit_model, input_size=(32, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "trainable"])

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Trainable
ResNet50ViT                                        [32, 3, 224, 224]         [32, 2]                   39,168                    Partial
├─Sequential: 1-1                                  [32, 3, 224, 224]         [32, 2048, 7, 7]          --                        False
│    └─Conv2d: 2-1                                 [32, 3, 224, 224]         [32, 64, 112, 112]        (9,408)                   False
│    └─BatchNorm2d: 2-2                            [32, 64, 112, 112]        [32, 64, 112, 112]        (128)                     False
│    └─ReLU: 2-3                                   [32, 64, 112, 112]        [32, 64, 112, 112]        --                        --
│    └─MaxPool2d: 2-4                              [32, 64, 112, 112]        [32, 64, 56, 56]          --                        --
│    └─Sequential: 2-5                             [32,

In [None]:

# Baseline for comparison: the stock ImageNet ResNet-50 with every parameter frozen.
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
model.requires_grad_(False)

summary(
    model,
    input_size=(1, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
ResNet                                   [1, 3, 224, 224]          [1, 1000]                 --                        False
├─Conv2d: 1-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         (9,408)                   False
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         (128)                     False
├─ReLU: 1-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --
├─MaxPool2d: 1-4                         [1, 64, 112, 112]         [1, 64, 56, 56]           --                        --
├─Sequential: 1-5                        [1, 64, 56, 56]           [1, 256, 56, 56]          --                        False
│    └─Bottleneck: 2-1                   [1, 64, 56, 56]           [1, 256, 56, 56]          --                        False
│ 