# Gated Convolutional Neural Networks

<div style="display: flex; align-items: center;">
    <img src="../imgs/GatedCNN.jpg" alt="Your Image" width="300" style="margin-right: 20px;">
    <div>
        <p>Gated Convolutional Neural Networks (Gated CNNs) are an advanced architectural variant of traditional Convolutional Neural Networks (CNNs) that incorporate gating mechanisms to regulate the flow of information through the network. These gating units act like learnable switches that can open or close, allowing certain features to pass while suppressing others, thus enabling the network to focus on relevant information.</p>
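        <p>The core of a Gated CNN is its gating mechanism. Classic gated units apply a sigmoid to a gate branch, producing values between 0 and 1 that determine how much of each feature is let through. The block implemented in this chapter passes its gate branch through GELU instead (similar to GEGLU-style gates), so its gate values are not strictly confined to the range [0, 1].</p>
        <p>Gated CNNs often use a depthwise convolution as the token (spatial) mixer because it is computationally cheap. A depthwise separable convolution consists of a depthwise convolution, which filters each input channel separately, followed by a pointwise (1×1) convolution that mixes the channels. In the block below, the depthwise convolution is applied to part of the channels, and channel mixing is done by the linear layers <code>fc1</code> and <code>fc2</code>, which act as pointwise convolutions on the channels-last tensor.</p>
        <p>This chapter also introduces two more challenging datasets, CIFAR-10 and CIFAR-100, and evaluates both the simple CNN from Chapter 3 and a Gated CNN on them.</p>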
    </div>
</div>
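To make the gating idea concrete, here is a minimal sigmoid gate in the style of a gated linear unit (GLU). `SigmoidGate` is a hypothetical illustration, not the block used later in this chapter: one linear layer produces a value and a gate, and the gate, squashed into (0, 1) by a sigmoid, scales the value feature by feature. PyTorch's `torch.nn.functional.glu` computes the same product, so the sketch uses it as a cross-check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class SigmoidGate(nn.Module):
    """Hypothetical GLU-style gate: value * sigmoid(gate), for illustration only."""

    def __init__(self, dim):
        super().__init__()
        # One projection produces both the value and the gate (2 * dim features)
        self.proj = nn.Linear(dim, 2 * dim)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        # sigmoid(gate) lies in (0, 1): 0 blocks a feature, 1 lets it through unchanged
        return value * torch.sigmoid(gate)


x = torch.randn(4, 16)
layer = SigmoidGate(16)
y = layer(x)

# F.glu splits its input in half along dim: first half * sigmoid(second half)
y_ref = F.glu(layer.proj(x), dim=-1)
print(y.shape, torch.allclose(y, y_ref))
```

Swapping `torch.sigmoid` for GELU gives the kind of gate used by the block in this chapter, at the cost of no longer bounding the gate to (0, 1).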

## Gated CNN
Before building the model, let's look at the parameters of a layer we have not used before:<br>
**nn.LayerNorm**:<br>
`normalized_shape`: The shape of the trailing dimensions to normalize over. A single integer N normalizes over the last dimension, which must have size N; a tuple such as `(C, H, W)` normalizes over the last `len(tuple)` dimensions, whose sizes must match.<br>
`eps`: A small constant added to the variance for numerical stability, preventing division by zero during normalization (default `1e-5`; the block below uses `1e-6`).<br>
`elementwise_affine`: When `True` (the default), applies a learnable per-element scale (`weight`) and shift (`bias`) after normalization.
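A quick way to see what these parameters do: apply `nn.LayerNorm(128, eps=1e-6)` (the configuration the block below uses for `dim=128`) to a channels-last tensor and check that every spatial position ends up with mean close to 0 and variance close to 1 across its 128 channels. The tensor sizes here are arbitrary choices for the sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Feature maps in channels-last layout [B, H, W, C], the layout GatedCNNBlock normalizes
x = torch.randn(2, 8, 8, 128) * 3 + 5

# normalized_shape=128: statistics are taken over the last dimension (the 128 channels)
ln = nn.LayerNorm(128, eps=1e-6, elementwise_affine=True)
y = ln(x)

# Per-position statistics after normalization (weight=1, bias=0 at initialization)
mean = y.mean(dim=-1)
var = ((y - mean.unsqueeze(-1)) ** 2).mean(dim=-1)  # biased variance, as LayerNorm uses
print(mean.abs().max().item(), var.mean().item())   # roughly 0 and 1
print(ln.weight.shape, ln.bias.shape)               # one learnable scale and shift per channel
```

Because `normalized_shape` refers to the *last* dimensions, the block permutes `[B, C, H, W]` to `[B, H, W, C]` before calling the norm.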

In [1]:
import torch
import torch.nn as nn
from functools import partial

class GatedCNN(nn.Module):
    """Conv stem -> strided conv -> 5 GatedCNNBlocks -> global average pooling -> linear classifier."""

    def __init__(self, input_channels, output_size):
        super().__init__()
        # Stem: 7x7 stride-2 conv, e.g. [B, 3, 32, 32] -> [B, 64, 16, 16] for CIFAR inputs
        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )
        # Downsample: 3x3 stride-2 conv, [B, 64, 16, 16] -> [B, 128, 8, 8]
        self.downsample = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128)
        )
        # Five residual gated blocks; each keeps the shape [B, 128, 8, 8]
        self.blocks = nn.Sequential(
            GatedCNNBlock(128, kernel_size=3),
            GatedCNNBlock(128, kernel_size=3),
            GatedCNNBlock(128, kernel_size=3),
            GatedCNNBlock(128, kernel_size=3),
            GatedCNNBlock(128, kernel_size=3)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # [B, 128, 8, 8] -> [B, 128, 1, 1]
        self.fc = nn.Linear(128, output_size)

    def forward(self, x):
        x = self.stem(x)
        x = self.downsample(x)
        x = self.blocks(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # [B, 128, 1, 1] -> [B, 128]
        x = self.fc(x)
        return x

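class GatedCNNBlock(nn.Module):
    """Residual gated block: norm -> fc1 (expand) -> split into gate g / identity i / conv part c
    -> depthwise conv on c -> act(g) * concat(i, c) -> fc2 (project back) -> add shortcut."""

    def __init__(self, dim, expansion_ratio=8/3, kernel_size=7, conv_ratio=1.0,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 act_layer=nn.GELU,
                 drop_path=0.,  # kept for interface compatibility; stochastic depth is not applied
                 **kwargs):
        super().__init__()
        hidden = int(expansion_ratio * dim)
        conv_channels = int(conv_ratio * dim)
        # fc1 outputs 2 * hidden features, split into g (hidden), i (hidden - conv_channels), c (conv_channels)
        self.split_indices = (hidden, hidden - conv_channels, conv_channels)

        self.norm = norm_layer(dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = act_layer()
        # groups=conv_channels: depthwise conv, one kernel per channel (spatial mixing only)
        self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size,
                              padding=kernel_size // 2, groups=conv_channels)
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):
        shortcut = x  # [B, C, H, W]
        x = x.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]; LayerNorm and Linear act on the last dim
        x = self.norm(x)
        g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1)
        c = c.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        c = self.conv(c)
        c = c.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))  # gate, then project back to dim channels
        x = x.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        return x + shortcut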
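With the settings `GatedCNN` passes to every block (`dim=128`, `kernel_size=3`, plus the defaults `expansion_ratio=8/3` and `conv_ratio=1.0`), the channel bookkeeping works out as follows. This sketch recomputes the split sizes and compares the parameter count of the depthwise convolution (`groups=conv_channels`) with a standard convolution of the same shape; the standard convolution is built only for comparison and is not part of the model.

```python
import torch.nn as nn

# Settings used by GatedCNN for every block
dim, expansion_ratio, conv_ratio, kernel_size = 128, 8 / 3, 1.0, 3

hidden = int(expansion_ratio * dim)      # int(341.33...) = 341
conv_channels = int(conv_ratio * dim)    # 128
split = (hidden, hidden - conv_channels, conv_channels)  # sizes of g, i, c

# Depthwise: groups == channels, so each channel gets its own single k x k filter
depthwise = nn.Conv2d(conv_channels, conv_channels, kernel_size,
                      padding=kernel_size // 2, groups=conv_channels)
# Standard conv with the same channels, built only for comparison
standard = nn.Conv2d(conv_channels, conv_channels, kernel_size,
                     padding=kernel_size // 2)


def n_params(module):
    return sum(p.numel() for p in module.parameters())


print("split (g, i, c):", split)
print("depthwise params:", n_params(depthwise))
print("standard params:", n_params(standard))
```

The depthwise version needs 128 × 3 × 3 + 128 = 1,280 parameters versus 128 × 128 × 3 × 3 + 128 = 147,584 for the standard convolution. Only the `c` slice (128 channels) passes through it; the `i` slice (213 channels) bypasses the convolution, and `torch.cat((i, c))` restores the 341 channels that the gate `g` multiplies.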

## CIFAR10/100 dataset
The CIFAR dataset is a widely used benchmark in the field of machine learning and computer vision, particularly for image classification tasks. Here are some of its key characteristics:

### CIFAR Dataset Characteristics
**Two Variants**: CIFAR comes in two versions, CIFAR-10 and CIFAR-100:<br>
`CIFAR-10`: Contains 60,000 32x32 color images in 10 classes, with 6,000 images per class. It is split into 50,000 training images and 10,000 test images.<br>
`CIFAR-100`: Similar to `CIFAR-10`, but has 100 classes with 600 images per class (500 training and 100 test). The 100 classes ("fine" labels) are grouped into 20 superclasses ("coarse" labels), so each image carries both a fine and a coarse label.

- Image Size: All images are 32x32 pixels with 3 color channels (RGB), making the dataset relatively small in terms of image resolution.

- Diversity: The dataset includes a wide variety of images in each class, making it more challenging than simpler datasets like MNIST.

**Comparison with MNIST Complexity**:<br>
`MNIST`: Consists of grayscale images of handwritten digits (0-9), with each image being 28x28 pixels. The dataset is relatively simple with less variation in the data.<br>
`CIFAR`: Contains color images with greater variability in terms of objects, backgrounds, and lighting conditions, making classification more challenging.<br>
**Number of Classes**:<br>
MNIST: 10 classes (digits 0-9).<br>
CIFAR-10: 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).<br>
CIFAR-100: 100 classes, adding significantly more complexity compared to both MNIST and CIFAR-10.<br>

**Image Resolution and Color Channels**:<br>
MNIST: 28x28 pixel grayscale images (1 channel).<br>
CIFAR: 32x32 pixel color images (3 channels), requiring models to handle more data and more complex features.
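The resolution, channel and class comparisons above can be put into numbers with a little arithmetic: how many input values the network reads per image, and what accuracy uniform random guessing would score. This is a sketch built only from the figures listed above.

```python
# Per-image input size and chance-level accuracy for the datasets compared above
specs = {
    # name: (channels, height, width, number of classes)
    "MNIST": (1, 28, 28, 10),
    "CIFAR-10": (3, 32, 32, 10),
    "CIFAR-100": (3, 32, 32, 100),
}

summary = {}
for name, (c, h, w, n_classes) in specs.items():
    values_per_image = c * h * w   # numbers the network must read per image
    chance_acc = 100 / n_classes   # accuracy (%) of uniform random guessing
    summary[name] = (values_per_image, chance_acc)
    print(f"{name:<10} values/image={values_per_image:>5}  chance accuracy={chance_acc:g}%")
```

A CIFAR image therefore carries about 3.9 times as many input values as an MNIST digit (3072 vs 784), and on CIFAR-100 a random guess is right only 1% of the time, which is worth keeping in mind when reading the accuracies below.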

In [2]:
import sys
sys.path.append('../tools')
from CIFAR10 import CIFAR10Trainer
from CIFAR100 import CIFAR100Trainer

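### Training the simple CNN from Chapter 3 on CIFAR10
After 30 epochs the simple CNN reaches only about 72.5% validation accuracy, well below its performance on MNIST. The training loss drops almost to zero (0.000324), but the validation loss is lowest at epoch 5 (0.00694) and then climbs steadily to 0.0179: after the first few epochs the model mostly memorizes the training set, i.e. it overfits. (`val_loss` appears to be reported on a different scale from `train_loss`, so compare trends rather than absolute values.) Harder datasets generally need deeper networks. As of 5/18/2024, the state-of-the-art accuracy on CIFAR-10 is 99.5%.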

In [1]:
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
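        # fc1 and fc2 together act as a small MLP classifier head.
        # 64 * 8 * 8: two 2x2 max-pools shrink 32x32 CIFAR images to 8x8 with 64 channels
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # [B, 64, 8, 8] -> [B, 4096]; keeps the batch dimension intact, unlike view(-1, ...)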
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
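The hard-coded `64 * 8 * 8` in `fc1` assumes 32x32 inputs: each 3x3 convolution with `padding=1` keeps the spatial size, and each 2x2 max-pool halves it (32 to 16 to 8). A dummy forward pass through the same conv/pool layers confirms the flattened size; the layers are re-created here only to inspect shapes.

```python
import torch
import torch.nn as nn

# Same conv/pool stack as the CNN above, used only to inspect shapes
conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)

x = torch.zeros(1, 3, 32, 32)          # one dummy CIFAR-sized image
x = pool(torch.relu(conv1(x)))         # padding keeps 32x32, pool halves -> [1, 32, 16, 16]
x = pool(torch.relu(conv2(x)))         # -> [1, 64, 8, 8]
print(x.shape, x[0].numel())           # per-image flattened size feeding fc1
```

On 28x28 MNIST images the same layers produce 7x7 maps (28 to 14 to 7), which is why `fc1` has to be sized for the input resolution.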

In [None]:
cnn = CNN(input_channels=3, output_size=10)
cnn_trainer_10 = CIFAR10Trainer(cnn, loss='CE', lr=0.01, optimizer='SGD', batch_size=128, epoch=30, model_type='classification')
cnn_trainer_10.train()
cnn_trainer_10.test()

Files already downloaded and verified
Files already downloaded and verified
2024-05-18 22:23:44
Epoch 1 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:10<00:00, 32.36it/s, train_loss=1.61]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 38.55it/s, val_acc=53.4, val_loss=0.0103]


2024-05-18 22:23:56
Epoch 2 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 31.49it/s, train_loss=1.16]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 36.81it/s, val_acc=60.4, val_loss=0.00895]


2024-05-18 22:24:08
Epoch 3 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 31.69it/s, train_loss=0.967]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 36.05it/s, val_acc=65.1, val_loss=0.00777]


2024-05-18 22:24:20
Epoch 4 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.97it/s, train_loss=0.832]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 37.21it/s, val_acc=68.4, val_loss=0.00714]


2024-05-18 22:24:33
Epoch 5 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.98it/s, train_loss=0.722]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 31.12it/s, val_acc=70, val_loss=0.00694]


2024-05-18 22:24:47
Epoch 6 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 30.13it/s, train_loss=0.614]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 30.58it/s, val_acc=70.1, val_loss=0.00698]


2024-05-18 22:25:00
Epoch 7 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.97it/s, train_loss=0.533]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 30.55it/s, val_acc=71, val_loss=0.00701]


2024-05-18 22:25:13
Epoch 8 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.88it/s, train_loss=0.439]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 32.97it/s, val_acc=70.5, val_loss=0.0074]


2024-05-18 22:25:26
Epoch 9 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.93it/s, train_loss=0.354]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 34.84it/s, val_acc=70.1, val_loss=0.00769]


2024-05-18 22:25:39
Epoch 10 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 30.04it/s, train_loss=0.276]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 32.54it/s, val_acc=71.8, val_loss=0.00806]


2024-05-18 22:25:52
Epoch 11 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.75it/s, train_loss=0.208]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 35.19it/s, val_acc=72.2, val_loss=0.00884]


2024-05-18 22:26:05
Epoch 12 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 27.96it/s, train_loss=0.15]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.57it/s, val_acc=70.8, val_loss=0.0103]


2024-05-18 22:26:19
Epoch 13 / 30


[Train]: 100%|████████████████████████| 352/352 [00:13<00:00, 27.03it/s, train_loss=0.115]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 29.58it/s, val_acc=71.6, val_loss=0.0106]


2024-05-18 22:26:33
Epoch 14 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 27.75it/s, train_loss=0.0834]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.95it/s, val_acc=71.5, val_loss=0.0113]


2024-05-18 22:26:47
Epoch 15 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 27.42it/s, train_loss=0.0597]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 28.88it/s, val_acc=71.2, val_loss=0.012]


2024-05-18 22:27:01
Epoch 16 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 27.73it/s, train_loss=0.0555]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.58it/s, val_acc=69.9, val_loss=0.0133]


2024-05-18 22:27:15
Epoch 17 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 27.96it/s, train_loss=0.0489]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.84it/s, val_acc=70.8, val_loss=0.0141]


2024-05-18 22:27:29
Epoch 18 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 28.17it/s, train_loss=0.0364]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.89it/s, val_acc=71.5, val_loss=0.0144]


2024-05-18 22:27:43
Epoch 19 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 28.64it/s, train_loss=0.0187]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 31.02it/s, val_acc=71.6, val_loss=0.0152]


2024-05-18 22:27:57
Epoch 20 / 30


[Train]: 100%|███████████████████████| 352/352 [00:12<00:00, 29.26it/s, train_loss=0.0112]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 33.46it/s, val_acc=72.4, val_loss=0.0157]


2024-05-18 22:28:10
Epoch 21 / 30


[Train]: 100%|██████████████████████| 352/352 [00:12<00:00, 27.36it/s, train_loss=0.00441]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.52it/s, val_acc=72.6, val_loss=0.0159]


2024-05-18 22:28:24
Epoch 22 / 30


[Train]: 100%|██████████████████████| 352/352 [00:12<00:00, 28.58it/s, train_loss=0.00141]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.28it/s, val_acc=72.5, val_loss=0.0164]


2024-05-18 22:28:38
Epoch 23 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.60it/s, train_loss=0.000774]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 31.83it/s, val_acc=72.4, val_loss=0.0167]


2024-05-18 22:28:52
Epoch 24 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.08it/s, train_loss=0.000682]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.94it/s, val_acc=72.6, val_loss=0.017]


2024-05-18 22:29:06
Epoch 25 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.25it/s, train_loss=0.000547]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.61it/s, val_acc=72.4, val_loss=0.0172]


2024-05-18 22:29:19
Epoch 26 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.48it/s, train_loss=0.000463]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.07it/s, val_acc=72.5, val_loss=0.0173]


2024-05-18 22:29:33
Epoch 27 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.33it/s, train_loss=0.000416]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.24it/s, val_acc=72.6, val_loss=0.0175]


2024-05-18 22:29:47
Epoch 28 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.67it/s, train_loss=0.000381]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.85it/s, val_acc=72.5, val_loss=0.0177]


2024-05-18 22:30:01
Epoch 29 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.73it/s, train_loss=0.000351]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.13it/s, val_acc=72.5, val_loss=0.0178]


2024-05-18 22:30:14
Epoch 30 / 30


[Train]: 100%|█████████████████████| 352/352 [00:12<00:00, 28.42it/s, train_loss=0.000324]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 29.10it/s, val_acc=72.5, val_loss=0.0179]
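Each epoch above runs 352 training and 40 validation batches at `batch_size=128`. The trainer's source is not shown here, so the split is an assumption: the counts are consistent with holding out 10% of the 50,000 CIFAR training images (45,000 train / 5,000 validation), as this arithmetic check shows.

```python
import math

batch_size = 128
train_images = 50_000   # official CIFAR training set size
val_fraction = 0.1      # ASSUMPTION: 10% held out for validation (not confirmed by the trainer source)

n_val = int(train_images * val_fraction)   # 5,000
n_train = train_images - n_val             # 45,000

# Without drop_last, a loader yields ceil(N / batch_size) batches (the last one is partial)
train_batches = math.ceil(n_train / batch_size)
val_batches = math.ceil(n_val / batch_size)
print(train_batches, val_batches)
```

The CIFAR-100 runs below show the same 352/40 batch counts, since both datasets have 50,000 training images.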


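### Training the simple CNN from Chapter 3 on CIFAR100
On CIFAR100 the simple CNN is much weaker: validation accuracy peaks at 39.8% in epoch 7 and then slips to about 36% to 37%, while the validation loss roughly triples (from 0.019 to 0.058) as the model overfits. As of 5/18/2024, the state-of-the-art accuracy on CIFAR-100 is 96.08%.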

In [None]:
cnn_100 = CNN(input_channels=3, output_size=100)
cnn_trainer_100 = CIFAR100Trainer(cnn_100, loss='CE', lr=0.01, optimizer='SGD', batch_size=128, epoch=30, model_type='classification')
cnn_trainer_100.train()
cnn_trainer_100.test()

Files already downloaded and verified
Files already downloaded and verified
2024-05-18 22:48:23
Epoch 1 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 31.13it/s, train_loss=4.04]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.84it/s, val_acc=16.3, val_loss=0.0286]


2024-05-18 22:48:36
Epoch 2 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 30.03it/s, train_loss=3.33]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 32.76it/s, val_acc=23.5, val_loss=0.025]


2024-05-18 22:48:48
Epoch 3 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 30.19it/s, train_loss=2.92]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.50it/s, val_acc=29.7, val_loss=0.0224]


2024-05-18 22:49:01
Epoch 4 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 29.79it/s, train_loss=2.63]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 35.99it/s, val_acc=32, val_loss=0.0213]


2024-05-18 22:49:14
Epoch 5 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 29.31it/s, train_loss=2.39]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 36.96it/s, val_acc=35.5, val_loss=0.0204]


2024-05-18 22:49:27
Epoch 6 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 27.67it/s, train_loss=2.18]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.31it/s, val_acc=36.8, val_loss=0.0196]


2024-05-18 22:49:41
Epoch 7 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 29.19it/s, train_loss=1.98]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 34.66it/s, val_acc=39.8, val_loss=0.019]


2024-05-18 22:49:54
Epoch 8 / 30


[Train]: 100%|██████████████████████████| 352/352 [00:12<00:00, 28.51it/s, train_loss=1.8]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.85it/s, val_acc=38.9, val_loss=0.0197]


2024-05-18 22:50:08
Epoch 9 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:11<00:00, 30.51it/s, train_loss=1.62]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 36.07it/s, val_acc=38.5, val_loss=0.0197]


2024-05-18 22:50:21
Epoch 10 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 27.23it/s, train_loss=1.45]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 31.78it/s, val_acc=38.6, val_loss=0.0205]


2024-05-18 22:50:35
Epoch 11 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:12<00:00, 29.20it/s, train_loss=1.27]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.66it/s, val_acc=38.1, val_loss=0.0214]


2024-05-18 22:50:48
Epoch 12 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:13<00:00, 26.89it/s, train_loss=1.09]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 31.62it/s, val_acc=38.9, val_loss=0.0228]


2024-05-18 22:51:02
Epoch 13 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.67it/s, train_loss=0.925]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.85it/s, val_acc=37.7, val_loss=0.0251]


2024-05-18 22:51:16
Epoch 14 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 27.85it/s, train_loss=0.762]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 32.51it/s, val_acc=36.9, val_loss=0.0269]


2024-05-18 22:51:30
Epoch 15 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 30.57it/s, train_loss=0.623]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 36.18it/s, val_acc=36.5, val_loss=0.0291]


2024-05-18 22:51:42
Epoch 16 / 30


[Train]: 100%|████████████████████████| 352/352 [00:13<00:00, 26.25it/s, train_loss=0.512]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 33.27it/s, val_acc=36.4, val_loss=0.0314]


2024-05-18 22:51:57
Epoch 17 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.44it/s, train_loss=0.403]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.72it/s, val_acc=35.9, val_loss=0.0346]


2024-05-18 22:52:11
Epoch 18 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.73it/s, train_loss=0.358]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 33.46it/s, val_acc=35.7, val_loss=0.0361]


2024-05-18 22:52:24
Epoch 19 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.52it/s, train_loss=0.302]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 33.89it/s, val_acc=34.9, val_loss=0.0405]


2024-05-18 22:52:37
Epoch 20 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 27.34it/s, train_loss=0.245]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.80it/s, val_acc=36.7, val_loss=0.0417]


2024-05-18 22:52:51
Epoch 21 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.15it/s, train_loss=0.209]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.38it/s, val_acc=36.1, val_loss=0.0442]


2024-05-18 22:53:05
Epoch 22 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.02it/s, train_loss=0.173]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 35.57it/s, val_acc=35.7, val_loss=0.0467]


2024-05-18 22:53:18
Epoch 23 / 30


[Train]: 100%|████████████████████████| 352/352 [00:13<00:00, 27.06it/s, train_loss=0.166]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 37.44it/s, val_acc=36.3, val_loss=0.047]


2024-05-18 22:53:33
Epoch 24 / 30


[Train]: 100%|████████████████████████| 352/352 [00:11<00:00, 29.52it/s, train_loss=0.131]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.00it/s, val_acc=35.8, val_loss=0.0474]


2024-05-18 22:53:46
Epoch 25 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.20it/s, train_loss=0.123]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.75it/s, val_acc=36.1, val_loss=0.0505]


2024-05-18 22:53:59
Epoch 26 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:13<00:00, 26.41it/s, train_loss=0.12]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 33.01it/s, val_acc=35.9, val_loss=0.0528]


2024-05-18 22:54:14
Epoch 27 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.45it/s, train_loss=0.108]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 32.97it/s, val_acc=35.7, val_loss=0.0537]


2024-05-18 22:54:27
Epoch 28 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.16it/s, train_loss=0.119]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.31it/s, val_acc=36.1, val_loss=0.0553]


2024-05-18 22:54:41
Epoch 29 / 30


[Train]: 100%|████████████████████████| 352/352 [00:13<00:00, 26.33it/s, train_loss=0.132]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 34.12it/s, val_acc=35.8, val_loss=0.0568]


2024-05-18 22:54:56
Epoch 30 / 30


[Train]: 100%|████████████████████████| 352/352 [00:12<00:00, 28.14it/s, train_loss=0.102]
[Valid]: 100%|██████████████████████████| 40/40 [00:01<00:00, 32.88it/s, val_acc=37, val_loss=0.058]


## Train GatedCNN on CIFAR10/100
GatedCNN is a deeper network, so let's see what results it achieves on CIFAR10/100. We keep the same settings as in the CNN runs (learning rate, optimizer, batch size and number of epochs).

In [2]:
import sys
import torch.nn as nn
sys.path.append('../tools')
from CIFAR10 import CIFAR10Trainer
from CIFAR100 import CIFAR100Trainer

In [None]:
model = GatedCNN(input_channels=3, output_size=10)
trainer = CIFAR10Trainer(model, loss='CE', lr=0.01, optimizer='SGD', batch_size=128, epoch=30, model_type='classification')
trainer.train()
trainer.test()

Files already downloaded and verified
Files already downloaded and verified
2024-05-18 22:57:18
Epoch 1 / 30


[Train]: 100%|██████████████████████████| 352/352 [00:14<00:00, 23.73it/s, train_loss=1.7]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 29.74it/s, val_acc=43.5, val_loss=0.0125]


2024-05-18 22:57:34
Epoch 2 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:14<00:00, 24.39it/s, train_loss=1.42]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.63it/s, val_acc=52.6, val_loss=0.0107]


2024-05-18 22:57:50
Epoch 3 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:14<00:00, 24.13it/s, train_loss=1.27]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 30.79it/s, val_acc=54.6, val_loss=0.0102]


2024-05-18 22:58:06
Epoch 4 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:14<00:00, 24.09it/s, train_loss=1.17]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 29.84it/s, val_acc=59.3, val_loss=0.00925]


2024-05-18 22:58:22
Epoch 5 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:14<00:00, 23.99it/s, train_loss=1.07]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 30.36it/s, val_acc=60.3, val_loss=0.00885]


2024-05-18 22:58:38
Epoch 6 / 30


[Train]: 100%|████████████████████████████| 352/352 [00:14<00:00, 23.52it/s, train_loss=1]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 29.66it/s, val_acc=63.9, val_loss=0.00811]


2024-05-18 22:58:54
Epoch 7 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:14<00:00, 23.62it/s, train_loss=0.94]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 29.80it/s, val_acc=65.8, val_loss=0.00785]


2024-05-18 22:59:10
Epoch 8 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 23.33it/s, train_loss=0.877]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 28.45it/s, val_acc=65.6, val_loss=0.00776]


2024-05-18 22:59:27
Epoch 9 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.73it/s, train_loss=0.826]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 28.61it/s, val_acc=64.6, val_loss=0.00802]


2024-05-18 22:59:44
Epoch 10 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.75it/s, train_loss=0.782]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.75it/s, val_acc=68.7, val_loss=0.00715]


2024-05-18 23:00:01
Epoch 11 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.55it/s, train_loss=0.724]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 28.47it/s, val_acc=66.4, val_loss=0.00779]


2024-05-18 23:00:18
Epoch 12 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.16it/s, train_loss=0.677]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 28.07it/s, val_acc=69, val_loss=0.0072]


2024-05-18 23:00:35
Epoch 13 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.35it/s, train_loss=0.633]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 28.00it/s, val_acc=69.3, val_loss=0.00736]


2024-05-18 23:00:52
Epoch 14 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.31it/s, train_loss=0.579]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 28.00it/s, val_acc=69.2, val_loss=0.00751]


2024-05-18 23:01:09
Epoch 15 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.08it/s, train_loss=0.529]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.17it/s, val_acc=69.8, val_loss=0.00751]


2024-05-18 23:01:27
Epoch 16 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.08it/s, train_loss=0.482]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.57it/s, val_acc=71.3, val_loss=0.00727]


2024-05-18 23:01:44
Epoch 17 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 22.25it/s, train_loss=0.43]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.85it/s, val_acc=69.9, val_loss=0.00801]


2024-05-18 23:02:02
Epoch 18 / 30


[Train]: 100%|████████████████████████| 352/352 [00:15<00:00, 22.05it/s, train_loss=0.391]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.66it/s, val_acc=70, val_loss=0.00768]


2024-05-18 23:02:19
Epoch 19 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.78it/s, train_loss=0.335]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.65it/s, val_acc=68.5, val_loss=0.00855]


2024-05-18 23:02:37
Epoch 20 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.79it/s, train_loss=0.296]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 26.93it/s, val_acc=68.8, val_loss=0.00903]


2024-05-18 23:02:54
Epoch 21 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.67it/s, train_loss=0.261]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.41it/s, val_acc=69.7, val_loss=0.00963]


2024-05-18 23:03:12
Epoch 22 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.75it/s, train_loss=0.225]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.70it/s, val_acc=68.9, val_loss=0.00961]


2024-05-18 23:03:30
Epoch 23 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.83it/s, train_loss=0.19]
[Valid]: 100%|██████████████████████| 40/40 [00:01<00:00, 27.95it/s, val_acc=71.2, val_loss=0.00984]


2024-05-18 23:03:47
Epoch 24 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.80it/s, train_loss=0.165]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.07it/s, val_acc=69.7, val_loss=0.011]


2024-05-18 23:04:05
Epoch 25 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.72it/s, train_loss=0.15]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.92it/s, val_acc=70.9, val_loss=0.0106]


2024-05-18 23:04:22
Epoch 26 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.37it/s, train_loss=0.115]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 27.13it/s, val_acc=70, val_loss=0.0126]


2024-05-18 23:04:40
Epoch 27 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.65it/s, train_loss=0.117]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.35it/s, val_acc=70.3, val_loss=0.0123]


2024-05-18 23:04:58
Epoch 28 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.53it/s, train_loss=0.106]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 26.66it/s, val_acc=70, val_loss=0.0118]


2024-05-18 23:05:16
Epoch 29 / 30


[Train]: 100%|███████████████████████| 352/352 [00:16<00:00, 21.63it/s, train_loss=0.0881]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.00it/s, val_acc=68.9, val_loss=0.0137]


2024-05-18 23:05:34
Epoch 30 / 30


[Train]: 100%|███████████████████████| 352/352 [00:16<00:00, 21.42it/s, train_loss=0.0845]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.48it/s, val_acc=70.9, val_loss=0.0131]


In [None]:
model_100 = GatedCNN(input_channels=3, output_size=100)
trainer_100 = CIFAR100Trainer(model_100, loss='CE', lr=0.01, optimizer='SGD', batch_size=128, epoch=30, model_type='classification')
trainer_100.train()
trainer_100.test()

Files already downloaded and verified
Files already downloaded and verified
2024-05-18 23:06:43
Epoch 1 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 23.13it/s, train_loss=4.03]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 27.86it/s, val_acc=12, val_loss=0.0302]


2024-05-18 23:07:00
Epoch 2 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 23.00it/s, train_loss=3.58]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.56it/s, val_acc=16.9, val_loss=0.0275]


2024-05-18 23:07:16
Epoch 3 / 30


[Train]: 100%|██████████████████████████| 352/352 [00:15<00:00, 22.72it/s, train_loss=3.3]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.75it/s, val_acc=20.5, val_loss=0.0264]


2024-05-18 23:07:33
Epoch 4 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 22.58it/s, train_loss=3.09]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.08it/s, val_acc=24.7, val_loss=0.0245]


2024-05-18 23:07:50
Epoch 5 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 22.11it/s, train_loss=2.93]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.77it/s, val_acc=26.1, val_loss=0.0235]


2024-05-18 23:08:08
Epoch 6 / 30


[Train]: 100%|██████████████████████████| 352/352 [00:15<00:00, 22.12it/s, train_loss=2.8]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.54it/s, val_acc=28.3, val_loss=0.0227]


2024-05-18 23:08:25
Epoch 7 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.82it/s, train_loss=2.66]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.07it/s, val_acc=30.6, val_loss=0.0217]


2024-05-18 23:08:43
Epoch 8 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.98it/s, train_loss=2.54]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.15it/s, val_acc=32.3, val_loss=0.0211]


2024-05-18 23:09:00
Epoch 9 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:15<00:00, 22.07it/s, train_loss=2.43]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 28.14it/s, val_acc=33.7, val_loss=0.0202]


2024-05-18 23:09:18
Epoch 10 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.78it/s, train_loss=2.32]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.76it/s, val_acc=35.1, val_loss=0.0197]


2024-05-18 23:09:35
Epoch 11 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.75it/s, train_loss=2.24]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.50it/s, val_acc=35.2, val_loss=0.0199]


2024-05-18 23:09:53
Epoch 12 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.81it/s, train_loss=2.14]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.46it/s, val_acc=37.3, val_loss=0.0189]


2024-05-18 23:10:11
Epoch 13 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.64it/s, train_loss=2.05]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.14it/s, val_acc=38.3, val_loss=0.0183]


2024-05-18 23:10:28
Epoch 14 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.96it/s, train_loss=1.96]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.11it/s, val_acc=39.7, val_loss=0.0182]


2024-05-18 23:10:46
Epoch 15 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.40it/s, train_loss=1.87]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.42it/s, val_acc=40.7, val_loss=0.0176]


2024-05-18 23:11:04
Epoch 16 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.61it/s, train_loss=1.78]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.07it/s, val_acc=41.6, val_loss=0.0177]


2024-05-18 23:11:21
Epoch 17 / 30


[Train]: 100%|██████████████████████████| 352/352 [00:16<00:00, 21.63it/s, train_loss=1.7]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.60it/s, val_acc=41.7, val_loss=0.0174]


2024-05-18 23:11:39
Epoch 18 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.42it/s, train_loss=1.62]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.98it/s, val_acc=41.5, val_loss=0.0176]


2024-05-18 23:11:57
Epoch 19 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.63it/s, train_loss=1.54]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.66it/s, val_acc=43.1, val_loss=0.0172]


2024-05-18 23:12:15
Epoch 20 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.61it/s, train_loss=1.44]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.72it/s, val_acc=42.3, val_loss=0.0175]


2024-05-18 23:12:33
Epoch 21 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.78it/s, train_loss=1.35]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.04it/s, val_acc=42.4, val_loss=0.018]


2024-05-18 23:12:50
Epoch 22 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.64it/s, train_loss=1.26]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.30it/s, val_acc=43.5, val_loss=0.018]


2024-05-18 23:13:08
Epoch 23 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.49it/s, train_loss=1.16]
[Valid]: 100%|█████████████████████████| 40/40 [00:01<00:00, 25.97it/s, val_acc=43, val_loss=0.0185]


2024-05-18 23:13:26
Epoch 24 / 30


[Train]: 100%|█████████████████████████| 352/352 [00:16<00:00, 21.70it/s, train_loss=1.07]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.26it/s, val_acc=41.5, val_loss=0.0192]


2024-05-18 23:13:44
Epoch 25 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.80it/s, train_loss=0.993]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.26it/s, val_acc=42.7, val_loss=0.0193]


2024-05-18 23:14:01
Epoch 26 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.59it/s, train_loss=0.899]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 27.62it/s, val_acc=42.6, val_loss=0.0202]


2024-05-18 23:14:19
Epoch 27 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.50it/s, train_loss=0.823]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.59it/s, val_acc=42.2, val_loss=0.021]


2024-05-18 23:14:37
Epoch 28 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.84it/s, train_loss=0.739]
[Valid]: 100%|███████████████████████| 40/40 [00:01<00:00, 26.94it/s, val_acc=42.1, val_loss=0.0214]


2024-05-18 23:14:55
Epoch 29 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.57it/s, train_loss=0.653]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 27.49it/s, val_acc=42.1, val_loss=0.023]


2024-05-18 23:15:12
Epoch 30 / 30


[Train]: 100%|████████████████████████| 352/352 [00:16<00:00, 21.77it/s, train_loss=0.595]
[Valid]: 100%|████████████████████████| 40/40 [00:01<00:00, 26.94it/s, val_acc=41.8, val_loss=0.023]
