In [1]:
import torch
import torchvision

import torch.nn as nn

import torchvision.transforms as transforms

from torch.optim.lr_scheduler import ReduceLROnPlateau

from torchinfo import summary

import matplotlib.pyplot as plt

In [2]:
# Select the compute device used by every later cell.
# `torch.cuda.is_available()` checks for a usable GPU at runtime, whereas
# `torch.backends.cuda.is_built()` only says whether this PyTorch build was
# compiled with CUDA (it can be True on a machine without a GPU).
# Falling back to the CPU keeps `device` defined, so later `.to(device)` calls
# do not fail with a NameError.
if torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.ones(1, device=device)  # sanity check: allocate a tensor on the GPU
    print(x)
else:
    device = torch.device("cpu")
    print("Cuda device not found; using CPU.")

tensor([1.], device='cuda:0')


## First, define Base Model C from Table 1 and train it on CIFAR-10

In [12]:
class base_c(nn.Module):
    """Base Model C (Table 1) as implemented in this notebook.

    Layer sequence (see ``forward``):
        input Dropout2d(0.2)
        -> 3x3 conv 96 + ReLU -> 3x3 conv 96 + ReLU   (padding=2)
        -> 3x3 max-pool, stride 2 -> Dropout2d(0.5)
        -> 3x3 conv 192 + ReLU -> 3x3 conv 192 + ReLU (padding=1)
        -> 3x3 max-pool, stride 2 -> Dropout2d(0.5)
        -> 3x3 conv 192 + ReLU (no padding)
        -> 1x1 conv 192 + ReLU -> 1x1 conv ``num_classes`` + ReLU
        -> global average pool -> flatten

    For a 3x32x32 input the spatial size evolves as
    32 -> 34 -> 36 (padding=2 convs) -> 17 (max-pool) -> 17 -> 17 -> 8 (max-pool)
    -> 6 (unpadded conv5), so the global pool averages a 6x6 map.

    Args:
        num_classes: output channels of the last 1x1 conv, i.e. the length of
            the returned score vector. Defaults to 10 (CIFAR-10).

    ``forward`` returns a tensor of shape ``(batch, num_classes)``. Because a
    ReLU follows the last conv, every returned score is non-negative.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # Dropout2d zeroes whole channels (for the 3-channel input: whole colour
        # planes). NOTE(review): confirm this is intended rather than
        # element-wise nn.Dropout.
        self.dropout1 = nn.Dropout2d(p=0.2)

        # padding=2 grows the map 32 -> 34 -> 36 so that, after the two unpadded
        # max-pools, conv5 ends on a 6x6 map.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=2)
        self.relu2 = nn.ReLU()

        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.dropout2 = nn.Dropout2d(p=0.5)

        self.conv3 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()

        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.dropout3 = nn.Dropout2d(p=0.5)

        self.conv5 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=0)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, padding=0)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(in_channels=192, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
        self.relu7 = nn.ReLU()

        # True global average pooling. It matches the former AvgPool2d(6) for
        # 32x32 inputs, and it still yields one value per class for other
        # input sizes.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.dropout1(x)
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.dropout2(self.mp1(x))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        x = self.dropout3(self.mp2(x))
        x = self.relu5(self.conv5(x))
        x = self.relu6(self.conv6(x))
        x = self.relu7(self.conv7(x))
        x = self.avg_pool(x)
        return torch.flatten(x, start_dim=1)

def glorot_uniform_init(net):
    """Glorot/Xavier *uniform* initialisation for ``nn.Conv2d`` modules.

    Meant to be passed to ``model.apply(...)``, which calls it once for every
    submodule. Conv2d weights are drawn from the Xavier uniform distribution,
    as the function name promises; ``xavier_normal_`` would draw from a normal
    distribution instead. Conv2d biases are set to zero. Any other module type
    is left untouched.

    Args:
        net: one ``nn.Module`` visited by ``Module.apply``.
    """
    if not isinstance(net, nn.Conv2d):
        return
    nn.init.xavier_uniform_(net.weight)
    if net.bias is not None:
        nn.init.zeros_(net.bias)

# Instantiate Base Model C, Glorot-initialise its conv layers and move it to
# `device` in one chain (Module.apply and Module.to both return the module).
# Then print a per-layer summary for a single 3x32x32 image.
base_model_c = base_c().apply(glorot_uniform_init).to(device)
summary(base_model_c, input_size=(1, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
base_c                                   [1, 10]                   --
├─Dropout2d: 1-1                         [1, 3, 32, 32]            --
├─Conv2d: 1-2                            [1, 96, 34, 34]           2,688
├─ReLU: 1-3                              [1, 96, 34, 34]           --
├─Conv2d: 1-4                            [1, 96, 36, 36]           83,040
├─ReLU: 1-5                              [1, 96, 36, 36]           --
├─MaxPool2d: 1-6                         [1, 96, 17, 17]           --
├─Dropout2d: 1-7                         [1, 96, 17, 17]           --
├─Conv2d: 1-8                            [1, 192, 17, 17]          166,080
├─ReLU: 1-9                              [1, 192, 17, 17]          --
├─Conv2d: 1-10                           [1, 192, 17, 17]          331,968
├─ReLU: 1-11                             [1, 192, 17, 17]          --
├─MaxPool2d: 1-12                        [1, 192, 8, 8]            -

In [None]:
# Smoke test: one random 3x32x32 image must yield one score per class.
# Run with dropout disabled and without building an autograd graph, then put
# the model back into whichever train/eval mode it was in before.
was_training = base_model_c.training
base_model_c.eval()
with torch.no_grad():
    dummy_in = torch.randn(1, 3, 32, 32, device=device)
    out = base_model_c(dummy_in)
base_model_c.train(was_training)
assert out.shape == (1, 10), f"unexpected output shape: {tuple(out.shape)}"


### Training and evaluating Base Model C on CIFAR-10

In [13]:
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# No resizing: the models take native 32x32 CIFAR images. Each image is turned
# into a tensor and standardised per channel with fixed CIFAR-10 statistics.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.48227, 0.4465],
            std=[0.2470, 0.2435, 0.2616],
        ),
    ]
)

# Training split first, then test split (downloaded into ./CIFAR if missing).
train_dataset, test_dataset = (
    CIFAR10(root='./CIFAR', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

# Batches of 64; only the training loader shuffles (ImageNet runs used 256).
train_loader = DataLoader(train_dataset, num_workers=4, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, num_workers=4, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Cross-entropy on the class scores; SGD with momentum and L2 weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    base_model_c.parameters(),
    momentum=0.9,
    weight_decay=0.001,
    lr=0.01,
)

# Tolerate up to 3 epochs without improvement of the monitored (minimised)
# loss, then multiply the learning rate by 0.1.
scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.1, mode='min')

In [15]:
# Train Base Model C for a fixed number of epochs, evaluating after each one.
# NOTE(review): `test_loader` serves as the validation set, and its loss drives
# the LR scheduler, so the reported accuracy is not a fully held-out estimate.
# Split a validation set off the training data if an unbiased test number is needed.
NUM_EPOCHS = 100

train_loss_mem = []
val_loss_mem = []
val_accuracy_mem = []

for epoch in range(NUM_EPOCHS):
    # ---- training ----
    base_model_c.train()  # enable dropout
    running_loss = 0.0
    train_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(base_model_c(inputs), labels)
        loss.backward()
        optimizer.step()
        # criterion uses the default reduction='mean'. Weighting by batch size
        # makes the epoch loss a true per-sample mean, so the smaller final
        # batch is not over-weighted.
        running_loss += loss.item() * labels.size(0)
        train_seen += labels.size(0)

    train_loss = running_loss / train_seen
    train_loss_mem.append(train_loss)
    print(f'Epoch [{epoch + 1}] training loss: {train_loss:.3f}')

    # ---- evaluation ----
    base_model_c.eval()  # disable dropout
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = base_model_c(inputs)
            val_running_loss += criterion(outputs, labels).item() * labels.size(0)
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_accuracy = 100 * val_correct / val_total
    print(f'Epoch [{epoch + 1}] validation loss: {val_loss:.3f}, accuracy: {val_accuracy:.2f}%')
    val_loss_mem.append(val_loss)
    val_accuracy_mem.append(val_accuracy)

    # ReduceLROnPlateau monitors the per-sample validation loss.
    scheduler.step(val_loss)
    print(f'LR: {scheduler.get_last_lr()}')


Epoch [1] training loss: 2.040
Epoch [1] validation loss: 1.806, accuracy: 30.75%
LR: [0.01]
Epoch [2] training loss: 1.771
Epoch [2] validation loss: 1.702, accuracy: 37.38%
LR: [0.01]
Epoch [3] training loss: 1.647
Epoch [3] validation loss: 1.475, accuracy: 45.32%
LR: [0.01]
Epoch [4] training loss: 1.550
Epoch [4] validation loss: 1.356, accuracy: 52.38%
LR: [0.01]
Epoch [5] training loss: 1.445
Epoch [5] validation loss: 1.211, accuracy: 57.62%
LR: [0.01]
Epoch [6] training loss: 1.361
Epoch [6] validation loss: 1.199, accuracy: 57.69%
LR: [0.01]
Epoch [7] training loss: 1.290
Epoch [7] validation loss: 1.118, accuracy: 60.94%
LR: [0.01]
Epoch [8] training loss: 1.215
Epoch [8] validation loss: 1.013, accuracy: 64.95%
LR: [0.01]
Epoch [9] training loss: 1.158
Epoch [9] validation loss: 0.943, accuracy: 66.74%
LR: [0.01]
Epoch [10] training loss: 1.118
Epoch [10] validation loss: 1.031, accuracy: 62.83%
LR: [0.01]
Epoch [11] training loss: 1.070
Epoch [11] validation loss: 0.891, a

Epoch [84] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [85] training loss: 0.459
Epoch [85] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [86] training loss: 0.456
Epoch [86] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [87] training loss: 0.453
Epoch [87] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [88] training loss: 0.459
Epoch [88] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [89] training loss: 0.457
Epoch [89] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [90] training loss: 0.459
Epoch [90] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [91] training loss: 0.458
Epoch [91] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [92] training loss: 0.454
Epoch [92] validation loss: 0.444, accuracy: 85.16%
LR: [1.0000000000000004e-08]
Epoch [93] trai

## Define All-CNN-C

In [8]:
class all_cnn_c(nn.Module):
    """All-CNN-C: Base Model C with each max-pool replaced by a 3x3, stride-2 conv.

    Layer sequence (see ``forward``):
        input Dropout2d(0.2)
        -> 3x3 conv 96 + ReLU -> 3x3 conv 96 + ReLU
        -> 3x3 conv 96, stride 2 + ReLU (replaces max-pool) -> Dropout2d(0.5)
        -> 3x3 conv 192 + ReLU -> 3x3 conv 192 + ReLU
        -> 3x3 conv 192, stride 2 + ReLU (replaces max-pool) -> Dropout2d(0.5)
        -> 3x3 conv 192 + ReLU (no padding)
        -> 1x1 conv 192 + ReLU -> 1x1 conv ``num_classes`` + ReLU
        -> global average pool -> flatten

    For a 3x32x32 input the spatial size evolves as
    32 -> 32 -> 32 -> 16 (stride-2 conv3) -> 16 -> 16 -> 8 (stride-2 conv6)
    -> 6 (unpadded conv7), so the global pool averages a 6x6 map.

    Args:
        num_classes: output channels of the last 1x1 conv, i.e. the length of
            the returned score vector. Defaults to 10 (CIFAR-10).

    ``forward`` returns a tensor of shape ``(batch, num_classes)``. Because a
    ReLU follows the last conv, every returned score is non-negative.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # Dropout2d zeroes whole channels (for the 3-channel input: whole colour
        # planes). NOTE(review): confirm this is intended rather than
        # element-wise nn.Dropout.
        self.dropout1 = nn.Dropout2d(p=0.2)

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        # Stride-2 conv that halves the map, taking the place of a max-pool.
        self.conv3 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout2d(p=0.5)

        self.conv4 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()
        # Stride-2 conv that halves the map, taking the place of a max-pool.
        self.conv6 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=2, padding=1)
        self.relu6 = nn.ReLU()
        self.dropout3 = nn.Dropout2d(p=0.5)

        self.conv7 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=0)
        self.relu7 = nn.ReLU()
        self.conv8 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1, stride=1, padding=0)
        self.relu8 = nn.ReLU()
        self.conv9 = nn.Conv2d(in_channels=192, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
        self.relu9 = nn.ReLU()

        # True global average pooling. It matches the former AvgPool2d(6) for
        # 32x32 inputs, and it still yields one value per class for other
        # input sizes.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.dropout1(x)
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.dropout2(self.relu3(self.conv3(x)))
        x = self.relu4(self.conv4(x))
        x = self.relu5(self.conv5(x))
        x = self.dropout3(self.relu6(self.conv6(x)))
        x = self.relu7(self.conv7(x))
        x = self.relu8(self.conv8(x))
        x = self.relu9(self.conv9(x))
        x = self.avg_pool(x)
        return torch.flatten(x, start_dim=1)

def glorot_uniform_init(net):
    """Glorot/Xavier *uniform* initialisation for ``nn.Conv2d`` modules.

    Meant to be passed to ``model.apply(...)``, which calls it once for every
    submodule. Conv2d weights are drawn from the Xavier uniform distribution,
    as the function name promises; ``xavier_normal_`` would draw from a normal
    distribution instead. Conv2d biases are set to zero. Any other module type
    is left untouched.

    Args:
        net: one ``nn.Module`` visited by ``Module.apply``.
    """
    if not isinstance(net, nn.Conv2d):
        return
    nn.init.xavier_uniform_(net.weight)
    if net.bias is not None:
        nn.init.zeros_(net.bias)

# Instantiate All-CNN-C, Glorot-initialise its conv layers and move it to
# `device` in one chain (Module.apply and Module.to both return the module).
# Then print a per-layer summary for a single 3x32x32 image.
all_cnn_c_model = all_cnn_c().apply(glorot_uniform_init).to(device)
summary(all_cnn_c_model, input_size=(1, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
all_cnn_c                                [1, 10]                   --
├─Dropout2d: 1-1                         [1, 3, 32, 32]            --
├─Conv2d: 1-2                            [1, 96, 32, 32]           2,688
├─ReLU: 1-3                              [1, 96, 32, 32]           --
├─Conv2d: 1-4                            [1, 96, 32, 32]           83,040
├─ReLU: 1-5                              [1, 96, 32, 32]           --
├─Conv2d: 1-6                            [1, 96, 16, 16]           83,040
├─ReLU: 1-7                              [1, 96, 16, 16]           --
├─Dropout2d: 1-8                         [1, 96, 16, 16]           --
├─Conv2d: 1-9                            [1, 192, 16, 16]          166,080
├─ReLU: 1-10                             [1, 192, 16, 16]          --
├─Conv2d: 1-11                           [1, 192, 16, 16]          331,968
├─ReLU: 1-12                             [1, 192, 16, 16]       

In [None]:
# Smoke test: one random 3x32x32 image must yield one score per class.
# Run with dropout disabled and without building an autograd graph, then put
# the model back into whichever train/eval mode it was in before.
was_training = all_cnn_c_model.training
all_cnn_c_model.eval()
with torch.no_grad():
    dummy_in = torch.randn(1, 3, 32, 32, device=device)
    out = all_cnn_c_model(dummy_in)
all_cnn_c_model.train(was_training)
assert out.shape == (1, 10), f"unexpected output shape: {tuple(out.shape)}"

### Training and evaluating All-CNN-C on CIFAR-10

In [9]:
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Same data pipeline as for Base Model C: native 32x32 images converted to
# tensors and standardised per channel with fixed CIFAR-10 statistics.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.48227, 0.4465],
            std=[0.2470, 0.2435, 0.2616],
        ),
    ]
)

# Training split first, then test split (downloaded into ./CIFAR if missing).
train_dataset, test_dataset = (
    CIFAR10(root='./CIFAR', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

# Batches of 64; only the training loader shuffles.
train_loader = DataLoader(train_dataset, num_workers=4, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, num_workers=4, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Loss, optimiser and LR schedule for All-CNN-C.
# NOTE(review): the following training cell re-creates `criterion`, `optimizer`
# and `scheduler` with identical settings, so these objects are replaced once
# that cell runs.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    all_cnn_c_model.parameters(),
    momentum=0.9,
    weight_decay=0.001,
    lr=0.01,
)
scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.1, mode='min')

In [11]:
# Fresh loss/optimiser/scheduler for this run. Re-running the cell restarts the
# optimiser state and LR schedule, but NOT the model weights, which carry on
# from their current values.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(all_cnn_c_model.parameters(), lr=0.01, weight_decay=0.001, momentum=0.9)

# Tolerate up to 3 epochs without improvement of the monitored loss, then
# multiply the learning rate by 0.1.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

# NOTE(review): `test_loader` serves as the validation set, and its loss drives
# the LR scheduler, so the reported accuracy is not a fully held-out estimate.
NUM_EPOCHS = 100

train_loss_mem = []
val_loss_mem = []
val_accuracy_mem = []

for epoch in range(NUM_EPOCHS):
    # ---- training ----
    all_cnn_c_model.train()  # enable dropout
    running_loss = 0.0
    train_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(all_cnn_c_model(inputs), labels)
        loss.backward()
        optimizer.step()
        # criterion uses the default reduction='mean'. Weighting by batch size
        # makes the epoch loss a true per-sample mean, so the smaller final
        # batch is not over-weighted.
        running_loss += loss.item() * labels.size(0)
        train_seen += labels.size(0)

    train_loss = running_loss / train_seen
    train_loss_mem.append(train_loss)
    print(f'Epoch [{epoch + 1}] training loss: {train_loss:.3f}')

    # ---- evaluation ----
    all_cnn_c_model.eval()  # disable dropout
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = all_cnn_c_model(inputs)
            val_running_loss += criterion(outputs, labels).item() * labels.size(0)
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_running_loss / val_total
    val_accuracy = 100 * val_correct / val_total
    print(f'Epoch [{epoch + 1}] validation loss: {val_loss:.3f}, accuracy: {val_accuracy:.2f}%')
    val_loss_mem.append(val_loss)
    val_accuracy_mem.append(val_accuracy)

    # ReduceLROnPlateau monitors the per-sample validation loss.
    scheduler.step(val_loss)
    print(f'LR: {scheduler.get_last_lr()}')


Epoch [1] training loss: 2.151
Epoch [1] validation loss: 1.883, accuracy: 27.41%
LR: [0.01]
Epoch [2] training loss: 1.870
Epoch [2] validation loss: 1.708, accuracy: 37.24%
LR: [0.01]
Epoch [3] training loss: 1.725
Epoch [3] validation loss: 1.777, accuracy: 36.21%
LR: [0.01]
Epoch [4] training loss: 1.633
Epoch [4] validation loss: 1.512, accuracy: 44.47%
LR: [0.01]
Epoch [5] training loss: 1.550
Epoch [5] validation loss: 1.428, accuracy: 46.92%
LR: [0.01]
Epoch [6] training loss: 1.466
Epoch [6] validation loss: 1.387, accuracy: 49.12%
LR: [0.01]
Epoch [7] training loss: 1.401
Epoch [7] validation loss: 1.268, accuracy: 54.74%
LR: [0.01]
Epoch [8] training loss: 1.340
Epoch [8] validation loss: 1.269, accuracy: 55.94%
LR: [0.01]
Epoch [9] training loss: 1.270
Epoch [9] validation loss: 1.123, accuracy: 59.79%
LR: [0.01]
Epoch [10] training loss: 1.225
Epoch [10] validation loss: 1.129, accuracy: 59.80%
LR: [0.01]
Epoch [11] training loss: 1.178
Epoch [11] validation loss: 1.036, a

Epoch [87] training loss: 0.427
Epoch [87] validation loss: 0.442, accuracy: 84.88%
LR: [1.0000000000000002e-06]
Epoch [88] training loss: 0.426
Epoch [88] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000002e-07]
Epoch [89] training loss: 0.425
Epoch [89] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000002e-07]
Epoch [90] training loss: 0.422
Epoch [90] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000002e-07]
Epoch [91] training loss: 0.428
Epoch [91] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000002e-07]
Epoch [92] training loss: 0.426
Epoch [92] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000004e-08]
Epoch [93] training loss: 0.426
Epoch [93] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000004e-08]
Epoch [94] training loss: 0.423
Epoch [94] validation loss: 0.442, accuracy: 84.85%
LR: [1.0000000000000004e-08]
Epoch [95] training loss: 0.429
Epoch [95] validation loss: 0.442, accuracy: 84.85%
LR: [1.00000

## Deconv from the paper (the fun stuff!)