In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.amp import GradScaler

# Device selection: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Gradient scalers for mixed-precision training, one per model.
# They are explicitly disabled when running on the CPU so that a CPU-only run
# does not construct an "enabled" CUDA scaler it cannot use.
use_amp = device.type == 'cuda'
scaler_resnet_50 = GradScaler('cuda', enabled=use_amp)
scaler_se_resnet_50 = GradScaler('cuda', enabled=use_amp)

batch_size = 512  # mini-batch size

# Per-channel normalization statistics shared by every pipeline below.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# Training pipeline: data augmentation followed by tensor conversion + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Validation / test pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Legacy, misspelled name from an earlier version of this cell. It described the
# same augmentation-free pipeline, so it is kept only as an alias in case another
# cell still refers to it. The extra CIFAR-10 load that used it was dropped because
# its result was immediately overwritten by the load below.
transfrom = transform_test

# Download (if needed) and load the datasets.
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Human-readable class names, indexed by CIFAR-10 label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Using device: cuda


In [5]:
from torch.utils.data import random_split

# Train/validation split (80:20)
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size

# Seed the split so the same images land in train/validation on every run;
# otherwise validation metrics are not comparable between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
datasets = random_split(full_trainset, [train_size, val_size], generator=split_generator)
trainset = datasets[0]  # training subset (uses the augmenting transform of full_trainset)

# The validation subset must not go through the training augmentation
# (random crop / flip). Re-open the same training images with the evaluation
# transform and select exactly the indices that the split assigned to validation.
val_base = torchvision.datasets.CIFAR10(
    root=full_trainset.root, train=True, download=False, transform=transform_test
)
valset = torch.utils.data.Subset(val_base, datasets[1].indices)  # validation subset


# Data loaders
# DataLoader: serves a dataset in mini-batches
#     batch_size  : number of samples per batch
#     shuffle     : whether to reshuffle the data every epoch
#     num_workers : number of worker subprocesses used for loading (0 = load in the main process)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Channel Attention Module
class ChannelAttention(nn.Module):
    """Channel attention gate: rescales every channel of a feature map.

    The input is squeezed spatially with both average and max pooling, each
    pooled descriptor is passed through the same two-layer bottleneck MLP, the
    two results are summed, and a sigmoid turns the sum into a per-channel
    multiplier that is applied to the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Bottleneck ratio; the hidden layer has
            ``in_channels // reduction`` units, but never fewer than one.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()

        # Clamp the bottleneck width to at least 1. With in_channels < reduction
        # the integer division is 0, which would create a zero-width hidden layer
        # and leave the gate unable to depend on its input.
        hidden_channels = max(1, in_channels // reduction)

        # Bottleneck MLP shared by the average- and max-pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels, bias=False),
        )

        # Global spatial pooling down to a 1x1 map per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the channel gate.

        Args:
            x: 4-D tensor unpacked as (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied by a
            per-sample, per-channel weight in (0, 1).
        """
        b, c, _, _ = x.size()

        # Spatially pooled channel descriptors, flattened to (b, c) for the MLP.
        avg_desc = self.avg_pool(x).view(b, c)
        max_desc = self.max_pool(x).view(b, c)

        # Same MLP for both descriptors; the outputs are fused by addition.
        logits = self.mlp(avg_desc) + self.mlp(max_desc)

        # Reshape to (b, c, 1, 1) so the gate broadcasts over height and width.
        scale = self.sigmoid(logits).view(b, c, 1, 1)

        return x * scale

In [7]:
# Spatial Attention Module
class SpatialAttention(nn.Module):
    """Spatial attention gate: rescales every spatial location of a feature map.

    The input is reduced along the channel axis with both mean and max, the two
    resulting single-channel maps are concatenated and passed through one
    convolution, and a sigmoid turns the result into a per-location multiplier
    that is applied to the input.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be odd.

    Raises:
        ValueError: If ``kernel_size`` is even. With ``padding=kernel_size // 2``
            an even kernel produces a map one pixel larger than the input in
            each spatial dimension, which cannot be multiplied with the input.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        # Fail at construction time instead of with a shape error in forward().
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd so the attention map keeps the input's spatial size, got {kernel_size}"
            )

        # 2 input channels (mean map + max map, concatenated) -> 1 output channel.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the spatial gate.

        Args:
            x: Feature map whose dimension 1 is the channel axis
               (the comments below use [B, C, H, W]).

        Returns:
            Tensor with the same shape as ``x``: ``x`` multiplied by a
            per-location weight in (0, 1) that is shared across channels.
        """
        # Channel-wise mean and max -> two [B, 1, H, W] maps.
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)

        # Stack them on the channel axis -> [B, 2, H, W].
        stacked = torch.cat([mean_map, max_map], dim=1)

        # Convolution -> sigmoid gives the [B, 1, H, W] gate.
        scale = self.sigmoid(self.conv(stacked))

        return x * scale

In [8]:
class CBAM(nn.Module):
    """Attention block that chains a channel gate and a spatial gate.

    Args:
        in_channels: Channel count forwarded to ``ChannelAttention``.
        reduction: Bottleneck ratio forwarded to ``ChannelAttention``.
        kernel_size: Convolution kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, reduction)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x):
        """Return ``x`` refined by channel attention first, then spatial attention."""
        channel_refined = self.channel_attn(x)
        return self.spatial_attn(channel_refined)