In [4]:
import torch
from torchinfo import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SFAModule(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(SFAModule, self).__init__()
        self.in_channels = in_channels
        self.reduction_ratio = reduction_ratio

        # Spatial Attention Branch (using strip pooling)
        self.spatial_x_pool = nn.AdaptiveAvgPool2d((None, 1))
        self.spatial_y_pool = nn.AdaptiveAvgPool2d((1, None))
        self.spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

        # Frequency Attention Branch
        self.freq_conv_real = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.freq_conv_imag = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.freq_conv_out = nn.Conv2d(in_channels // reduction_ratio, 1, kernel_size=1) # Output 1 channel for spatial application

        # Channel Attention Branch (using Squeeze-and-Excitation)
        self.channel_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # Spatial Attention
        spatial_x = self.spatial_x_pool(x)
        spatial_y = self.spatial_y_pool(x).permute(0, 1, 3, 2)
        spatial_att = self.sigmoid(self.spatial_conv(spatial_x + spatial_y))
        x_spatial = x * spatial_att

        # Frequency Attention
        # Perform FFT
        x_fft = torch.fft.fft2(torch.fft.fftshift(x, dim=(-2, -1)), norm="ortho")
        x_real = x_fft.real
        x_imag = x_fft.imag
        # Process real and imaginary parts
        freq_real = self.freq_conv_real(x_real)
        freq_imag = self.freq_conv_imag(x_imag)
        freq_att = self.sigmoid(self.freq_conv_out(freq_real + freq_imag))
        x_freq = x * freq_att # Apply in spatial domain for simplicity

        # Channel Attention
        channel_avg = self.channel_avg_pool(x).squeeze(-1).squeeze(-1)
        channel_att = self.channel_fc(channel_avg).unsqueeze(-1).unsqueeze(-1)
        x_channel = x * channel_att

        # Fusion (simple element-wise addition)
        out = x_spatial + x_freq + x_channel

        return out

if __name__ == '__main__':
    # Ví dụ sử dụng module SFA
    in_channels = 256
    batch_size = 4
    height = 224
    width = 224

    input_tensor = torch.randn(batch_size, in_channels, height, width)
    sfa_module = SFAModule(in_channels)
    output_tensor = sfa_module(input_tensor)

    print("Input shape:", input_tensor.shape)
    print("Output shape:", output_tensor.shape)

Input shape: torch.Size([4, 256, 224, 224])
Output shape: torch.Size([4, 256, 224, 224])


In [53]:
from torchinfo import summary
summary(sfa_module, depth=3, row_settings=["var_names"], input_size=(32, 256, 64, 64), col_names=["input_size", "output_size", "num_params"])

Layer (type (var_name))                       Input Shape               Output Shape              Param #
SFAModule (SFAModule)                         [32, 256, 64, 64]         [32, 256, 64, 64]         --
├─AdaptiveAvgPool2d (spatial_x_pool)          [32, 256, 64, 64]         [32, 256, 64, 1]          --
├─AdaptiveAvgPool2d (spatial_y_pool)          [32, 256, 64, 64]         [32, 256, 1, 64]          --
├─Conv2d (spatial_conv)                       [32, 256, 64, 1]          [32, 1, 64, 1]            256
├─Sigmoid (sigmoid)                           [32, 1, 64, 1]            [32, 1, 64, 1]            --
├─Conv2d (freq_conv_real)                     [32, 256, 64, 64]         [32, 16, 64, 64]          4,112
├─Conv2d (freq_conv_imag)                     [32, 256, 64, 64]         [32, 16, 64, 64]          4,112
├─Conv2d (freq_conv_out)                      [32, 16, 64, 64]          [32, 1, 64, 64]           17
├─Sigmoid (sigmoid)                           [32, 1, 64, 64]           [32, 1,

In [5]:
import torch
import torch.nn as nn
import torchvision.models as models

class BottleneckWithSFA(nn.Module):
    def __init__(self, bottleneck):
        super(BottleneckWithSFA, self).__init__()
        self.conv1 = bottleneck.conv1
        self.bn1 = bottleneck.bn1
        self.conv2 = bottleneck.conv2
        self.bn2 = bottleneck.bn2
        self.conv3 = bottleneck.conv3
        self.bn3 = bottleneck.bn3
        self.relu = bottleneck.relu
        self.sfa = SFAModule(bottleneck.conv3.out_channels)
        # self.sfa = SFAModule(bottleneck.expansion * bottleneck.conv3.out_channels)
        self.expansion = bottleneck.expansion
        self.out_channels = bottleneck.conv3.out_channels
        self.stride = bottleneck.stride
        self.expansion = bottleneck.expansion
        self.downsample = bottleneck.downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out) 
        out = self.sfa(out)# Apply SFA after the convolutional layers

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet50WithSFAIntegrated(nn.Module):
    def __init__(self, num_classes=1000, pretrained=True):
        super(ResNet50WithSFAIntegrated, self).__init__()
        weight = models.ResNet50_Weights.DEFAULT
        resnet50 = models.resnet50(weights=weight)
        self.conv1 = resnet50.conv1
        self.bn1 = resnet50.bn1
        self.relu = resnet50.relu
        self.maxpool = resnet50.maxpool

        self.layer1 = self._make_layer(resnet50.layer1)
        self.layer2 = self._make_layer(resnet50.layer2)
        self.layer3 = self._make_layer(resnet50.layer3)
        self.layer4 = self._make_layer(resnet50.layer4)

        self.avgpool = resnet50.avgpool
        self.fc = nn.Sequential(
            torch.nn.Linear(2048,1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024,512),
            torch.nn.Dropout(),
            torch.nn.Linear(in_features=512,
                            out_features=2,
                            bias=True)
        ) 

    def _make_layer(self, layer):
        blocks = []
        for bottleneck in layer:
            blocks.append(BottleneckWithSFA(bottleneck))
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    # Ví dụ sử dụng ResNet50WithSFAIntegrated
model_sfa = ResNet50WithSFAIntegrated(num_classes=2)
summary(model_sfa, depth=4, row_settings=["var_names"], input_size=(32, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds", "trainable"])

Layer (type (var_name))                                      Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds                 Trainable
ResNet50WithSFAIntegrated (ResNet50WithSFAIntegrated)        [32, 3, 224, 224]         [32, 2]                   --                        --                        --                        True
├─Conv2d (conv1)                                             [32, 3, 224, 224]         [32, 64, 112, 112]        9,408                     [7, 7]                    3,776,446,464             True
├─BatchNorm2d (bn1)                                          [32, 64, 112, 112]        [32, 64, 112, 112]        128                       --                        4,096                     True
├─ReLU (relu)                                                [32, 64, 112, 112]        [32, 64, 112, 112]        --                        --                        --                        --
├─MaxPool2d (maxp

In [6]:

weight = models.ResNet50_Weights.DEFAULT
model_sfa.load_state_dict(weight.get_state_dict(progress=True), strict=False)

for name, param in model_sfa.named_parameters():
    if name in weight.get_state_dict():
        param.requires_grad = False

In [None]:

summary(model_sfa, depth=4, row_settings=["var_names"], input_size=(32, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds", "trainable"])

Layer (type (var_name))                                      Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds                 Trainable
ResNet50WithSFAIntegrated (ResNet50WithSFAIntegrated)        [32, 3, 224, 224]         [32, 2]                   --                        --                        --                        Partial
├─Conv2d (conv1)                                             [32, 3, 224, 224]         [32, 64, 112, 112]        (9,408)                   [7, 7]                    3,776,446,464             False
├─BatchNorm2d (bn1)                                          [32, 64, 112, 112]        [32, 64, 112, 112]        (128)                     --                        4,096                     False
├─ReLU (relu)                                                [32, 64, 112, 112]        [32, 64, 112, 112]        --                        --                        --                        --
├─MaxPool2d 

In [11]:
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
def load_checkpoint(checkpoint_path):
    """
    Load model and optimizer state from a checkpoint.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the state into.
        optimizer (torch.optim.Optimizer): The optimizer to load the state into.

    Returns:
        model (torch.nn.Module): The model with loaded state.
        optimizer (torch.optim.Optimizer): The optimizer with loaded state.
        epoch (int): The epoch at which the checkpoint was saved.
    """
    checkpoint = torch.load(checkpoint_path)
    model_state_dict = checkpoint['model_state_dict']
    optimizer = checkpoint['optimizer_state_dict']
    epoch = checkpoint['epoch']
    return model_state_dict, optimizer, epoch

def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer):
    # Put model in train mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader data batches
    for (X, y) in tqdm(dataloader, desc="Batch"):
        # Send data to target device
        # print("\rbatch: " + str(batch) + "/" + str(round(int(100000/64))), end = "")
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate  and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        # Calculate and accumulate accuracy metric across all batches
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    # Adjust metrics to get average loss and accuracy per batch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module):
    # Put model in eval mode
    model.eval()

    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0

    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch, (X, y) in enumerate(dataloader):
            # Send data to target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            test_pred_logits = model(X)

            # 2. Calculate and accumulate loss
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()

            # Calculate and accumulate accuracy
            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))

    # Adjust metrics to get average loss and accuracy per batch
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          checkpoint_model_name: str = "",
          epochs: int = 5,
          pretrained: str = None):
    # 1. Take in various parameters required for training and test steps

    # 2. Create empty results dictionary
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }
    
    if pretrained:
        model_state_dict, optimizer_state_dict, start_epoch = load_checkpoint(pretrained)
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)
    else:
        start_epoch = 0
    # 3. Loop through training and testing steps for a number of epochs
    for epoch in range(start_epoch+1, start_epoch + epochs):
        print("Epoch:",epoch)
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn)
        # 4. Print out what's happening
        print(
            f"Epoch: {epoch} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # 5. Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

        # 6. Save Checkpoints
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc
        }
        torch.save(checkpoint, f"checkpoints/{checkpoint_model_name}_epoch_{epoch:02d}.pth")
        
    # 7. Return the filled results at the end of the epochs
    return results

def load_data(train_dir: str, valid_dir: str, batch_size: int = 64):
    # Define transforms
    weights = torchvision.models.ResNet50_Weights.DEFAULT #sai nhung ma ket qua tot nen de lai
    auto_transforms = weights.transforms()

    # Load data
    train_data = datasets.ImageFolder(train_dir, transform=auto_transforms, target_transform = None)
    valid_data = datasets.ImageFolder(valid_dir, transform=auto_transforms)

    # Create data loaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=batch_size, num_workers=1, shuffle=False)

    return train_dataloader, valid_dataloader

if __name__ == "__main__":
    device="cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")
    train_dir="archive/dataset/train"
    valid_dir="archive/dataset/valid"
    test_dir="archive/dataset/test"

    # model = init_model_ResNet50_CBAM(True)

    # weight = torchvision.models.ResNet50_Weights.DEFAULT
    # model.load_state_dict(weight.get_state_dict(progress=True), strict=False)

    # for name, param in model.named_parameters():
    #     if name in weight.get_state_dict():
    #         param.requires_grad = False
            
    # Train
    batch_size = 16
    train_dataloader, test_dataloader = load_data(train_dir, valid_dir, batch_size)

    train(model=model_sfa,
          train_dataloader=train_dataloader,
          test_dataloader=test_dataloader,
          optimizer=torch.optim.Adam(model_sfa.parameters(), lr=0.001),
          checkpoint_model_name="ResNet50_SFA",
          epochs=10,
          pretrained="")


device: cuda
Epoch: 1


Batch:   2%|▏         | 424/22052 [04:34<3:53:18,  1.54it/s]


KeyboardInterrupt: 