In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    The input is averaged over its two trailing (spatial) dimensions, passed
    through a two-layer bottleneck MLP, squashed with a sigmoid, and the
    resulting per-channel weights are multiplied back onto the input.

    Args:
        channels: Number of channels of the 4-D input (dimension 1).
        reduction: Bottleneck ratio; the hidden layer has
            ``channels // reduction`` units, but never fewer than one.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against channels < reduction: a zero-width hidden layer would
        # make the gate a constant (bias only) that ignores the input.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Return ``x`` rescaled channel-wise; the output has the shape of ``x``.

        ``x`` must be 4-D (the size is unpacked into four values).
        """
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one value per (sample, channel).
        z = F.adaptive_avg_pool2d(x, 1).view(b, c)
        # Excite: bottleneck MLP followed by a sigmoid gate.
        s = torch.sigmoid(self.fc2(F.relu(self.fc1(z))))
        # Reshape so the gate broadcasts over the spatial dimensions.
        return x * s.view(b, c, 1, 1)


In [None]:
from torchvision.models.resnet import BasicBlock, ResNet

class SEBasicBlock(BasicBlock):
    """torchvision ``BasicBlock`` with an ``SEBlock`` gate on the residual branch.

    All positional and keyword arguments are forwarded unchanged to
    ``BasicBlock``. ``reduction`` is keyword-only and is passed to ``SEBlock``
    as its bottleneck ratio.
    """

    def __init__(self, *args, reduction=16, **kwargs):
        super().__init__(*args, **kwargs)
        # Gate width follows the number of channels produced by conv2.
        self.se = SEBlock(self.conv2.out_channels, reduction)

    def forward(self, x):
        """Run conv-bn-relu-conv-bn, apply the SE gate, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.se(out)  # Apply SE attention

        # When the block was given a downsample module, the shortcut must go
        # through it; otherwise `identity` keeps the input's shape and the
        # addition below fails whenever the main branch changed stride/width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class SEResNet18(ResNet):
    """ResNet with the 18-layer layout ([2, 2, 2, 2]) built from ``SEBasicBlock``.

    Args:
        num_classes: Number of output logits of the final linear layer.
    """

    def __init__(self, num_classes=10):
        # Let ResNet size the classifier itself instead of building the
        # default 1000-way head and then replacing it with a hard-coded width.
        super().__init__(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)


In [None]:
from torchvision.models.resnet import BasicBlock, ResNet

# NOTE: this cell defines SEBasicBlock again; an earlier cell already defines a
# class of the same name, and running this cell rebinds the name.
class SEBasicBlock(BasicBlock):
    """torchvision ``BasicBlock`` with an ``SEBlock`` gate on the residual branch.

    All positional and keyword arguments are forwarded unchanged to
    ``BasicBlock``. ``reduction`` is keyword-only and is passed to ``SEBlock``
    as its bottleneck ratio.
    """

    def __init__(self, *args, reduction=16, **kwargs):
        super().__init__(*args, **kwargs)
        # Gate width follows the number of channels produced by conv2.
        self.se = SEBlock(self.conv2.out_channels, reduction)

    def forward(self, x):
        """Run conv-bn-relu-conv-bn, apply the SE gate, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.se(out)  # Apply SE attention

        # When the block was given a downsample module, the shortcut must go
        # through it; otherwise `identity` keeps the input's shape and the
        # addition below fails whenever the main branch changed stride/width.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


# NOTE: this cell defines SEResNet18 again; an earlier cell already defines a
# class of the same name, and running this cell rebinds the name.
class SEResNet18(ResNet):
    """ResNet with the 18-layer layout ([2, 2, 2, 2]) built from ``SEBasicBlock``.

    Args:
        num_classes: Number of output logits of the final linear layer.
    """

    def __init__(self, num_classes=10):
        # Let ResNet size the classifier itself instead of building the
        # default 1000-way head and then replacing it with a hard-coded width.
        super().__init__(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)


In [None]:
import torchvision
import torchvision.transforms as transforms

# Training-time pipeline: padded random 32x32 crop, random horizontal flip,
# then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# CIFAR-10 splits, downloaded into ./data when not already present.
# Only the training split is augmented; the test split is just tensorised.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batches of 128; only the training loader reshuffles.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

# Class names as exposed by the dataset object.
classes = trainset.classes


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE: the network is built from whichever SEResNet18 is bound when this cell
# runs. Cells further down redefine SEBasicBlock / SEResNet18; those
# redefinitions do not affect this already-constructed instance.
model = SEResNet18()
model = model.to(device)

# Cross-entropy on the logits, optimised with Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    """Run one optimisation pass over ``trainloader`` and print the mean batch loss.

    Operates on the module-level ``model``, ``trainloader``, ``criterion``,
    ``optimizer`` and ``device``. ``epoch`` is only used to label the printed
    line. Returns ``None``.
    """
    model.train()
    loss_total = 0.0

    for batch_inputs, batch_labels in trainloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()

    # Average over the number of batches, not the number of samples.
    mean_loss = loss_total / len(trainloader)
    print(f"Epoch {epoch}, Loss: {mean_loss:.4f}")

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")


In [None]:
from torchvision.models.resnet import BasicBlock, conv1x1
import torch.nn as nn
import torch.nn.functional as F

class SEBasicBlock(BasicBlock):
    """torchvision ``BasicBlock`` with an ``SEBlock`` gate on the residual branch.

    Args:
        inplanes, planes, stride, downsample: Forwarded to ``BasicBlock``.
        *args, **kwargs: Any further ``BasicBlock`` arguments, forwarded
            unchanged. ``ResNet`` passes groups, base width, dilation and norm
            layer this way, so the constructor must accept them.
        reduction: Keyword-only bottleneck ratio handed to ``SEBlock``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, *args, reduction=16, **kwargs):
        super().__init__(inplanes, planes, stride, downsample, *args, **kwargs)
        # Gate width follows the number of channels produced by conv2.
        self.se = SEBlock(self.conv2.out_channels, reduction)

    def forward(self, x):
        """Run conv-bn-relu-conv-bn, apply the SE gate, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.se(out)

        # Route the shortcut through the downsample module when one was given,
        # so it matches the main branch before the addition.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


In [None]:
class SEResNet18(ResNet):
    """ResNet with the 18-layer layout ([2, 2, 2, 2]) built from ``SEBasicBlock``.

    Args:
        num_classes: Number of output logits of the final linear layer.
    """

    def __init__(self, num_classes=10):
        # Pass the block class itself, not a lambda wrapper: ResNet reads the
        # class attribute `block.expansion`, which a lambda does not have.
        # SEBasicBlock already defaults to reduction=16, so no wrapper is needed.
        super().__init__(
            block=SEBasicBlock,
            layers=[2, 2, 2, 2],
            num_classes=num_classes,
        )


In [None]:
# Ten epochs; each one is a full training pass followed by a test-set evaluation.
epoch_indices = range(10)
for epoch in epoch_indices:
    train(epoch)
    test()
