# import

In [1]:
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [24]:
def init_model_ResNet50(trainable_extractor = False):
    """
    Create an ResNet50 model.

    Returns:
        model (torch.nn.Module): ResNet50 model.
    """
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights).to(device)
    
    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = trainable_extractor
    
    model.fc= nn.Sequential(
    torch.nn.Linear(2048,1000),
    torch.nn.ReLU(),
    torch.nn.Linear(1000,500),
    torch.nn.Dropout(),
    torch.nn.Linear(in_features=500,
                    out_features=2,
                    bias=True)
    ).to(device)

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    
    # Recreate the classifier layer and seed it to the target device
    
    return model 

In [25]:
model_orinin = init_model_ResNet50()
summary(model_orinin, input_size=(1, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], depth=3, row_settings=["var_names"])

Layer (type (var_name))                  Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
ResNet (ResNet)                          [1, 3, 224, 224]          [1, 2]                    --                        --                        --
├─Conv2d (conv1)                         [1, 3, 224, 224]          [1, 64, 112, 112]         (9,408)                   [7, 7]                    118,013,952
├─BatchNorm2d (bn1)                      [1, 64, 112, 112]         [1, 64, 112, 112]         (128)                     --                        128
├─ReLU (relu)                            [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --                        --
├─MaxPool2d (maxpool)                    [1, 64, 112, 112]         [1, 64, 56, 56]           --                        3                         --
├─Sequential (layer1)                    [1, 64, 56, 56]           [1, 256, 56, 56]          --

In [None]:
summary(model_orinin.layer1, col_names=["num_params", "kernel_size"], depth=4)

In [114]:
summary(Channel_Attention(64))

Layer (type:depth-idx)                   Param #
Channel_Attention                        --
├─Sequential: 1-1                        --
│    └─Flatten: 2-1                      --
│    └─Linear: 2-2                       260
│    └─ReLU: 2-3                         --
│    └─Linear: 2-4                       320
Total params: 580
Trainable params: 580
Non-trainable params: 0

In [112]:
summary(ChannelAttention(64))

Layer (type:depth-idx)                   Param #
ChannelAttention                         --
├─AdaptiveAvgPool2d: 1-1                 --
├─AdaptiveMaxPool2d: 1-2                 --
├─Sequential: 1-3                        --
│    └─Conv2d: 2-1                       256
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       256
├─Sigmoid: 1-4                           --
Total params: 512
Trainable params: 512
Non-trainable params: 0

# khoi tao model ResNet50 CBAM

In [162]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out

class BottleneckWithCBAM(torchvision.models.resnet.Bottleneck):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BottleneckWithCBAM, self).__init__(inplanes, planes, stride, downsample)
        self.cbam = CBAM(planes * self.expansion)
        self.planes = planes

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.cbam(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

def init_model_ResNet50_CBAM(trainable_extractor = False, device='cuda'):
    """
    Create an ResNet50 model with CBAM attention.

    Returns:
        model (torch.nn.Module): ResNet50 model with CBAM.
    """
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights).to(device)

    for name, module in model.named_children():
        if name not in ['conv1', 'bn1', 'relu', 'maxpool', 'fc']:
            for block_name, block in module.named_children():
                if isinstance(block, torchvision.models.resnet.Bottleneck):
                    inplanes = block.conv1.in_channels
                    planes = block.conv2.out_channels # Thử lấy planes từ conv2
                    stride = block.conv2.stride[0]
                    downsample = block.downsample
                    setattr(module, block_name, BottleneckWithCBAM(inplanes, planes, stride, downsample))

    # Freeze layers (tương tự như hàm gốc)
    for param in model.parameters():
        param.requires_grad = trainable_extractor

    # Thay thế lớp fully connected
    model.fc= nn.Sequential(
        torch.nn.Linear(2048,1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024,512),
        torch.nn.Dropout(),
        torch.nn.Linear(in_features=512,
                        out_features=2,
                        bias=True)
    )
    # Chuyển toàn bộ model lên device sau khi đã thực hiện các thay đổi
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    # model.fc = model.fc.to(device) # Đảm bảo cả lớp fc cũng được chuyển

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    return model

In [163]:
model = init_model_ResNet50_CBAM(True)
weight = torchvision.models.ResNet50_Weights.DEFAULT
model.load_state_dict(weight.get_state_dict(progress=True), strict=False)

_IncompatibleKeys(missing_keys=['layer1.0.cbam.channel_attention.fc.0.weight', 'layer1.0.cbam.channel_attention.fc.2.weight', 'layer1.0.cbam.spatial_attention.conv.weight', 'layer1.1.cbam.channel_attention.fc.0.weight', 'layer1.1.cbam.channel_attention.fc.2.weight', 'layer1.1.cbam.spatial_attention.conv.weight', 'layer1.2.cbam.channel_attention.fc.0.weight', 'layer1.2.cbam.channel_attention.fc.2.weight', 'layer1.2.cbam.spatial_attention.conv.weight', 'layer2.0.cbam.channel_attention.fc.0.weight', 'layer2.0.cbam.channel_attention.fc.2.weight', 'layer2.0.cbam.spatial_attention.conv.weight', 'layer2.1.cbam.channel_attention.fc.0.weight', 'layer2.1.cbam.channel_attention.fc.2.weight', 'layer2.1.cbam.spatial_attention.conv.weight', 'layer2.2.cbam.channel_attention.fc.0.weight', 'layer2.2.cbam.channel_attention.fc.2.weight', 'layer2.2.cbam.spatial_attention.conv.weight', 'layer2.3.cbam.channel_attention.fc.0.weight', 'layer2.3.cbam.channel_attention.fc.2.weight', 'layer2.3.cbam.spatial_atten

In [187]:
for name, param in model.named_parameters():
    if name in weight.get_state_dict():
        param.requires_grad = False



In [None]:
summary(model, input_size=(16, 3, 224, 224), row_settings=["var_names"], col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds", "trainable"], depth=5)

Layer (type (var_name))                                      Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds                 Trainable
ResNet (ResNet)                                              [16, 3, 224, 224]         [16, 2]                   --                        --                        --                        Partial
├─Conv2d (conv1)                                             [16, 3, 224, 224]         [16, 64, 112, 112]        (9,408)                   [7, 7]                    1,888,223,232             False
├─BatchNorm2d (bn1)                                          [16, 64, 112, 112]        [16, 64, 112, 112]        (128)                     --                        2,048                     False
├─ReLU (relu)                                                [16, 64, 112, 112]        [16, 64, 112, 112]        --                        --                        --                        --
├─MaxPool2d 

: 

In [148]:
summary(model.layer1, col_names=["num_params", "trainable"], depth=8)

Layer (type:depth-idx)                        Param #                   Trainable
Sequential                                    --                        False
├─BottleneckWithCBAM: 1-1                     --                        False
│    └─Conv2d: 2-1                            (4,096)                   False
│    └─BatchNorm2d: 2-2                       (128)                     False
│    └─Conv2d: 2-3                            (36,864)                  False
│    └─BatchNorm2d: 2-4                       (128)                     False
│    └─Conv2d: 2-5                            (16,384)                  False
│    └─BatchNorm2d: 2-6                       (512)                     False
│    └─ReLU: 2-7                              --                        --
│    └─Sequential: 2-8                        --                        False
│    │    └─Conv2d: 3-1                       (16,384)                  False
│    │    └─BatchNorm2d: 3-2                  (512)            

In [None]:
summary(model_orinin.layer1, col_names=["num_params", "trainable"], depth=5)

Layer (type:depth-idx)                   Param #                   Trainable
Sequential                               --                        False
├─Bottleneck: 1-1                        --                        False
│    └─Conv2d: 2-1                       (4,096)                   False
│    └─BatchNorm2d: 2-2                  (128)                     False
│    └─Conv2d: 2-3                       (36,864)                  False
│    └─BatchNorm2d: 2-4                  (128)                     False
│    └─Conv2d: 2-5                       (16,384)                  False
│    └─BatchNorm2d: 2-6                  (512)                     False
│    └─ReLU: 2-7                         --                        --
│    └─Sequential: 2-8                   --                        False
│    │    └─Conv2d: 3-1                  (16,384)                  False
│    │    └─BatchNorm2d: 3-2             (512)                     False
├─Bottleneck: 1-2                        --       

: 

# Khoi tao model v2

In [101]:
class Channel_Attention(nn.Module):
    '''Channel Attention in CBAM.
    '''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=['avg', 'max']):
        '''Param init and architecture building.
        '''

        super(Channel_Attention, self).__init__()
        self.pool_types = pool_types

        self.shared_mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=channel_in, out_features=channel_in//reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=channel_in//reduction_ratio, out_features=channel_in)
        )


    def forward(self, x):
        '''Forward Propagation.
        '''

        channel_attentions = []

        for pool_types in self.pool_types:
            if pool_types == 'avg':
                pool_init = nn.AvgPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                avg_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(avg_pool))
            elif pool_types == 'max':
                pool_init = nn.MaxPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                max_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(max_pool))

        pooling_sums = torch.stack(channel_attentions, dim=0).sum(dim=0)
        scaled = nn.Sigmoid()(pooling_sums).unsqueeze(2).unsqueeze(3).expand_as(x)

        return x * scaled #return the element-wise multiplication between the input and the result.


class ChannelPool(nn.Module):
    '''Merge all the channels in a feature map into two separate channels where the first channel is produced by taking the max values from all channels, while the
       second one is produced by taking the mean from every channel.
    '''
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class Spatial_Attention(nn.Module):
    '''Spatial Attention in CBAM.
    '''

    def __init__(self, kernel_size=7):
        '''Spatial Attention Architecture.
        '''

        super(Spatial_Attention, self).__init__()

        self.compress = ChannelPool()
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=1, dilation=1, padding=(kernel_size-1)//2, bias=False),
            nn.BatchNorm2d(num_features=1, eps=1e-5, momentum=0.01, affine=True)
        )


    def forward(self, x):
        '''Forward Propagation.
        '''
        x_compress = self.compress(x)
        x_output = self.spatial_attention(x_compress)
        scaled = nn.Sigmoid()(x_output)
        return x * scaled


class CBAM(nn.Module):
    '''CBAM architecture.
    '''
    def __init__(self, channel_in, reduction_ratio=16, pool_types=['avg', 'max'], spatial=True):
        '''Param init and arch build.
        '''
        super(CBAM, self).__init__()
        self.spatial = spatial

        self.channel_attention = Channel_Attention(channel_in=channel_in, reduction_ratio=reduction_ratio, pool_types=pool_types)

        if self.spatial:
            self.spatial_attention = Spatial_Attention(kernel_size=7)


    def forward(self, x):
        '''Forward Propagation.
        '''
        x_out = self.channel_attention(x)
        if self.spatial:
            x_out = self.spatial_attention(x_out)

        return x_out
    
'''
ResNet-50 Architecture.
'''

class BottleNeck(nn.Module):
    '''Bottleneck modules
    '''

    def __init__(self, in_channels, out_channels, expansion=4, stride=1, use_cbam=True):
        '''Param init.
        '''
        super(BottleNeck, self).__init__()

        self.use_cbam = use_cbam
        #only the first conv will be affected by the given stride parameter. The rest have default stride value (which is 1).
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False, stride=stride)
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels*expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(num_features=out_channels*expansion)
        self.relu = nn.ReLU(inplace=True)

        #since the input has to be same size with the output during the identity mapping, whenever the stride or the number of output channels are
        #more than 1 and expansion*out_channels respectively, the input, x, has to be downsampled to the same level as well.
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != expansion*out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=expansion*out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_features=out_channels*expansion)
            )

        if self.use_cbam:
            self.cbam = CBAM(channel_in=out_channels*expansion)


    def forward(self, x):
        '''Forward Propagation.
        '''

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.use_cbam:
            out = self.cbam(out)

        out += self.downsample(x) #identity connection/skip connection
        out = self.relu(out)

        return out


class ResNet50(nn.Module):
    '''ResNet-50 Architecture.
    '''

    def __init__(self, use_cbam=True, image_depth=3, num_classes=6):
        '''Params init and build arch.
        '''
        super(ResNet50, self).__init__()

        self.in_channels = 64
        self.expansion = 4
        self.num_blocks = [3, 3, 3, 2]

        self.conv1  = nn.Conv2d(kernel_size=7, stride=2, in_channels=image_depth, out_channels=self.in_channels, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # self.conv_block1 = nn.Sequential(nn.Conv2d(kernel_size=7, stride=2, in_channels=image_depth, out_channels=self.in_channels, padding=3, bias=False),
        #                                     nn.BatchNorm2d(self.in_channels),
        #                                     nn.ReLU(inplace=True),
        #                                     nn.MaxPool2d(stride=2, kernel_size=3, padding=1))

        self.layer1 = self.make_layer(out_channels=64, num_blocks=self.num_blocks[0], stride=1, use_cbam=use_cbam)
        self.layer2 = self.make_layer(out_channels=128, num_blocks=self.num_blocks[1], stride=2, use_cbam=use_cbam)
        self.layer3 = self.make_layer(out_channels=256, num_blocks=self.num_blocks[2], stride=2, use_cbam=use_cbam)
        self.layer4 = self.make_layer(out_channels=512, num_blocks=self.num_blocks[3], stride=2, use_cbam=use_cbam)
        self.avgpool = nn.AvgPool2d(7)
        # self.linear = nn.Linear(512*self.expansion, num_classes)
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(2048,1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024,512),
            torch.nn.Dropout(),
            torch.nn.Linear(in_features=512,
                            out_features=2,
                            bias=True)
            )

    def make_layer(self, out_channels, num_blocks, stride, use_cbam):
        '''To construct the bottleneck layers.
        '''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(BottleNeck(in_channels=self.in_channels, out_channels=out_channels, stride=stride, expansion=self.expansion, use_cbam=use_cbam))
            self.in_channels = out_channels * self.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        '''Forward propagation of ResNet-50.
        '''
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x_conv = self.layer4(x)
        x = self.avgpool(x_conv)
        x = nn.Flatten()(x) #flatten the feature maps.
        x = self.linear(x)

        return x_conv, x

In [96]:
# for param in model.parameters():
#     param.requires_grad = trainable_extractor

model = ResNet50(use_cbam=True, image_depth=3, num_classes=2)
# Thay thế lớp fully connected
# Chuyển toàn bộ model lên device sau khi đã thực hiện các thay đổi
model = model.to(device)

In [102]:
import torch
import torchvision.models as models

# Tải mô hình ResNet50 đã được huấn luyện trước từ torchvision
pretrained_model = models.resnet50(pretrained=True)
pretrained_dict = pretrained_model.state_dict()

model_dict = model.state_dict()

# Lọc các keys không khớp giữa pretrained_dict và model_dict
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].shape == v.shape}
print(pretrained_dict.keys())
print(model_dict.keys())
print(pretrained_dict.keys())
# Cập nhật model_dict với các weights từ pretrained_dict
model_dict.update(pretrained_dict)

# Load state dict vào model
model.load_state_dict(model_dict, strict=False)



odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias', 'layer2.3.conv1.weight', 'layer2.3.bn1.weight', 'layer2.3.bn1.bias', 'layer2.3.bn1.running_mean', 'layer2.3.bn1.running_var', 'layer2.3.bn1.num_batches_tracked', 'layer2.3.conv2.weight', 'layer2.3.bn2.weight', 'layer2.3.bn2.bias', 'layer2.3.bn2.running_mean', 'layer2.3.bn2.running_var', 'layer2.3.bn2.num_batches_tracked', 'layer2.3.conv3.weight', 'layer2.3.bn3.weight', 'layer2.3.bn3.bias', 'layer2.3.bn3.running_mean', 'layer2.3.bn3.running_var', 'layer2.3.bn3.num_batches_tracked', 'layer3.3.conv1.weight', 'layer3.3.bn1.weight', 'layer3.3.bn1.bias', 'layer3.3.bn1.running_mean', 'layer3.3.bn1.running_var', 'layer3.3.bn1.num_batches_tracked', 'layer3.3.conv2.weight', 'layer3.3.bn2.weight', 'layer3.3.bn2.bias', 'layer3.3.bn2.running_mean', 'layer3.3.bn2.running_var', 'layer3.3.bn2.num_batches_tracked', 'layer3.3.conv3.weight', 'layer3.3.bn3.weight', 'layer3.3.bn3.bias', 'layer3.3.bn3.running_mean', 'layer3.3.bn3.ru

In [107]:
summary(model, input_size=(16, 3, 224, 224), row_settings=["var_names"],col_names=["input_size", "output_size", "num_params", "trainable"], depth=7)

Layer (type (var_name))                                      Input Shape               Output Shape              Param #                   Trainable
ResNet50 (ResNet50)                                          [16, 3, 224, 224]         [16, 2048, 7, 7]          --                        True
├─Conv2d (conv1)                                             [16, 3, 224, 224]         [16, 64, 112, 112]        9,408                     True
├─BatchNorm2d (bn1)                                          [16, 64, 112, 112]        [16, 64, 112, 112]        128                       True
├─ReLU (relu)                                                [16, 64, 112, 112]        [16, 64, 112, 112]        --                        --
├─MaxPool2d (maxpool)                                        [16, 64, 112, 112]        [16, 64, 56, 56]          --                        --
├─Sequential (layer1)                                        [16, 64, 56, 56]          [16, 256, 56, 56]         --                    

# Train

In [None]:
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer):
    # Put model in train mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader data batches
    for (X, y) in tqdm(dataloader, desc="Batch"):
        # Send data to target device
        # print("\rbatch: " + str(batch) + "/" + str(round(int(100000/64))), end = "")
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate  and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        # Calculate and accumulate accuracy metric across all batches
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    # Adjust metrics to get average loss and accuracy per batch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module):
    # Put model in eval mode
    model.eval()

    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0

    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch, (X, y) in enumerate(dataloader):
            # Send data to target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            test_pred_logits = model(X)

            # 2. Calculate and accumulate loss
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()

            # Calculate and accumulate accuracy
            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))

    # Adjust metrics to get average loss and accuracy per batch
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          checkpoint_model_name: str = "",
          epochs: int = 5,
          pretrained: str = None):
    # 1. Take in various parameters required for training and test steps

    # 2. Create empty results dictionary
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }
    
    if pretrained:
        model_state_dict, optimizer_state_dict, start_epoch = load_checkpoint(pretrained)
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)
    else:
        start_epoch = 0
    # 3. Loop through training and testing steps for a number of epochs
    for epoch in range(start_epoch+1, start_epoch + epochs):
        print("Epoch:",epoch)
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn)
        # 4. Print out what's happening
        print(
            f"Epoch: {epoch} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # 5. Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        # 6. Save Checkpoints
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }
        torch.save(checkpoint, f"checkpoints/{checkpoint_model_name}_epoch_{epoch:02d}.pth")
        
    # 7. Return the filled results at the end of the epochs
    return results

def load_data(train_dir: str, valid_dir: str, batch_size: int = 32):
    # Define transforms
    weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
    auto_transforms = weights.transforms()

    # Load data
    train_data = datasets.ImageFolder(train_dir, transform=auto_transforms, target_transform = None)
    valid_data = datasets.ImageFolder(valid_dir, transform=auto_transforms)

    # Create data loaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=batch_size, num_workers=1, shuffle=False)

    return train_dataloader, valid_dataloader

device="cuda" if torch.cuda.is_available() else "cpu"
device= "cpu"
print(f"device: {device}")
train_dir="archive/dataset/train"
valid_dir="archive/dataset/valid"
test_dir="archive/dataset/test"

model = init_model_ResNet50_CBAM()
# Train
batch_size = 16
train_dataloader, test_dataloader = load_data(train_dir, valid_dir, batch_size)

train(model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
        checkpoint_model_name="ResNet50_CBAM",
        epochs=10,
        pretrained="")


device: cpu
Epoch: 1


Batch:   0%|          | 2/22052 [00:30<93:42:48, 15.30s/it]


KeyboardInterrupt: 