<a href="https://colab.research.google.com/github/DevKiroBerty/deep_learning/blob/main/AA_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [38]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [39]:
class AugmentedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, heads=4, dv=0.1):
        super(AugmentedConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.attention_conv = nn.Conv2d(in_channels, out_channels, 1)  # Same number of channels as conv

    def forward(self, x):
        conv_out = self.conv(x)
        attention_out = self.attention_conv(x)
        # Ensure same dimensions by matching the sizes
        if conv_out.size() != attention_out.size():
            attention_out = F.interpolate(attention_out, size=conv_out.size()[2:], mode='bilinear', align_corners=False)
        return conv_out + attention_out



In [40]:
class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = AugmentedConv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = AugmentedConv(planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [41]:
class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, num_classes):
        super(WideResNet, self).__init__()
        n = (depth - 4) // 6
        k = widen_factor
        self.in_planes = 16

        def wide_layer(block, planes, num_blocks, stride):
            strides = [stride] + [1]*(num_blocks-1)
            layers = []
            for stride in strides:
                layers.append(block(self.in_planes, planes, stride))
                self.in_planes = planes
            return nn.Sequential(*layers)

        self.conv1 = AugmentedConv(3, 16, kernel_size=3, stride=1, padding=1)
        self.layer1 = wide_layer(BasicBlock, 16*k, n, stride=1)
        self.layer2 = wide_layer(BasicBlock, 32*k, n, stride=2)
        self.layer3 = wide_layer(BasicBlock, 64*k, n, stride=2)
        self.bn1 = nn.BatchNorm2d(64*k)
        self.linear = nn.Linear(64*k, num_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [42]:
def load_data(batch_size):
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainloader, testloader

In [43]:
def train_model(model, trainloader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print(f'Epoch [{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

In [44]:
def test_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the test images: {100 * correct / total:.2f}%')

In [45]:
def main():
    batch_size = 128
    epochs = 10
    learning_rate = 0.01

    trainloader, testloader = load_data(batch_size)

    model = WideResNet(depth=16, widen_factor=2, num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    train_model(model, trainloader, criterion, optimizer, epochs)
    test_model(model, testloader)

if __name__ == '__main__':
    main()

Files already downloaded and verified
Files already downloaded and verified
Epoch [1, 100] loss: 1.920
Epoch [1, 200] loss: 1.571
Epoch [1, 300] loss: 1.363
Epoch [2, 100] loss: 1.128
Epoch [2, 200] loss: 1.042
Epoch [2, 300] loss: 0.985
Epoch [3, 100] loss: 0.908
Epoch [3, 200] loss: 0.862
Epoch [3, 300] loss: 0.826
Epoch [4, 100] loss: 0.737
Epoch [4, 200] loss: 0.724
Epoch [4, 300] loss: 0.700
Epoch [5, 100] loss: 0.636
Epoch [5, 200] loss: 0.641
Epoch [5, 300] loss: 0.632
Epoch [6, 100] loss: 0.600
Epoch [6, 200] loss: 0.576
Epoch [6, 300] loss: 0.556
Epoch [7, 100] loss: 0.534
Epoch [7, 200] loss: 0.534
Epoch [7, 300] loss: 0.527
Epoch [8, 100] loss: 0.496
Epoch [8, 200] loss: 0.500
Epoch [8, 300] loss: 0.493
Epoch [9, 100] loss: 0.467
Epoch [9, 200] loss: 0.455
Epoch [9, 300] loss: 0.476
Epoch [10, 100] loss: 0.438
Epoch [10, 200] loss: 0.429
Epoch [10, 300] loss: 0.444
Accuracy of the network on the test images: 79.26%
