In [2]:
import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn

import numpy as np

import copy

print(f"Is CUDA supported by this system? {torch.cuda.is_available()}")
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

Is CUDA supported by this system? True


In [3]:
# https://github.com/kuangliu/pytorch-cifar/issues/19
test_t = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ]
)
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
        test_t
    ]
)

train_data = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True,
    transform=transform
)
test_data  = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=True,
    transform=test_t
)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=48, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=48, shuffle=False
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class Block(nn.Module):
    def __init__(self, channel_in, pass_on, channel_out, device):
        super().__init__()
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(channel_in, channel_out["1x1"], kernel_size=1),
            nn.BatchNorm2d(channel_out["1x1"]),
            nn.PReLU()
        )
    
        # 3x3 branch, we padding 1 in the 3x3 convolution layer to keep same size of image
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(channel_in, pass_on["3x3"], kernel_size=1),
            nn.BatchNorm2d(pass_on["3x3"]),
            nn.PReLU(),
            nn.Conv2d(pass_on["3x3"], channel_out["3x3"], kernel_size=3, padding=1),
            nn.BatchNorm2d(channel_out["3x3"]),
            nn.PReLU()
        )
        
        # 5x5 branch, we padding 2 in the 5x5 convolution layer to keep same size of image
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(channel_in, pass_on["5x5"], kernel_size=1),
            nn.BatchNorm2d(pass_on["5x5"]),
            nn.PReLU(),
            nn.Conv2d(pass_on["5x5"], channel_out["5x5"], kernel_size=5, padding=2),
            nn.BatchNorm2d(channel_out["5x5"]),
            nn.PReLU()
        ) 
        # Max pooling branch
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(channel_in, channel_out["max"], kernel_size=1),
            nn.BatchNorm2d(channel_out["max"]),
            nn.PReLU()
        )

        self._initialize_weights()

    def forward(self, x):
        return torch.cat(
            [
                self.conv_1x1(x), self.conv_3x3(x),
                self.conv_5x5(x), self.max_pool(x)
            ], dim=1 # concatenate along channels
        )

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                nn.init.constant_(m.bias, 0)

In [11]:
# https://arxiv.org/pdf/1502.01852.pdf
# https://arxiv.org/pdf/1409.4842.pdf

class CifarGoogle(nn.Module):
    def __init__(self, 
                 in_channels=3, 
                 n_classes=10,
                 device='cpu', 
                 lr=0.01):
        super(CifarGoogle, self).__init__()
        
        self.seq = nn.Sequential(
            # input layer
            nn.Conv2d(3, 64, kernel_size=7, padding=3),
            nn.PReLU(),
            nn.LocalResponseNorm(128),
            nn.Conv2d(64, 112, kernel_size=1),
            nn.PReLU(),
            nn.Conv2d(112, 196, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.LocalResponseNorm(128),
            # pass through blocks
            Block(
                196, 
                pass_on={"3x3": 96, "5x5": 16}, 
                channel_out={"1x1": 64, "3x3": 128, "5x5": 32, "max": 32},
                device=device
            ),
            Block(
                256, 
                pass_on={"3x3": 128, "5x5": 32}, 
                channel_out={"1x1": 128, "3x3": 192, "5x5": 96, "max": 64},
                device=device
            ),
            # reduce dimensions
            nn.MaxPool2d(3, stride=2, padding=1),  
            # pass through blocks
            Block(
                480, 
                pass_on={"3x3": 96, "5x5": 16}, 
                channel_out={"1x1": 192, "3x3": 208, "5x5": 48, "max": 64}, 
                device=device
            ),
            Block(
                512, 
                pass_on={"3x3": 112, "5x5": 24}, 
                channel_out={"1x1": 160, "3x3": 224, "5x5": 64, "max": 64}, 
                device=device
            ),
            Block(
                512, 
                pass_on={"3x3": 128, "5x5": 24}, 
                channel_out={"1x1": 128, "3x3": 256, "5x5": 64, "max": 64}, 
                device=device
            ),
            Block(
                512, 
                pass_on={"3x3": 112, "5x5": 32}, 
                channel_out={"1x1": 112, "3x3": 288, "5x5": 64, "max": 64}, 
                device=device
            ),
            # reduce dimensions
            nn.MaxPool2d(3, stride=2, padding=1),
            # pass through last blocks
            Block(
                528, 
                pass_on={"3x3": 160, "5x5": 32}, 
                channel_out={"1x1": 256, "3x3": 320, "5x5": 128, "max": 128},
                device=device
            ),
            Block(
                832, 
                pass_on={"3x3": 150, "5x5": 42}, 
                channel_out={"1x1": 266, "3x3": 330, "5x5": 108, "max": 128},
                device=device
            ),
            # pool
            nn.AdaptiveAvgPool2d((1, 1)),
            # classification head
            nn.Dropout(0.4),
            nn.Flatten(),
            nn.Linear(832, 10),
            nn.Softmax(1)
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss().to(device)

        self.device = device
        self.n_classes = n_classes

        self.total_epochs_trained = 0

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.seq(x)

    def train(self, train_loader, test_loader, train_epochs):
        break_count = 0
        max_break_count = 15
        max_test_acc = -9e15
        epoch_train_loss = epoch_test_loss = 0
        epoch_train_losses = []
        epoch_test_losses = []

        state_dict = None

        for epoch in range(train_epochs):
            total_train_loss = total_test_acc = 0
            for image,label in train_loader:
                image = image.to(self.device)
                label = label.to(self.device)

                self.optimizer.zero_grad()

                preds = self.forward(image)

                loss = self.criterion(preds, label)
                total_train_loss += loss.item()
                loss.backward()

                self.optimizer.step()
            
            with torch.no_grad():
                self.training = False
                for module in self.children():
                    module.train(False)
                
                for image,label in test_loader:
                    image = image.to(self.device)
                    label = label.to(self.device)

                    preds = self.forward(image)

                    total_test_acc += (
                        torch.sum(torch.argmax(preds, dim=1) == label).item()
                    )
                    
                self.training = True
                for module in self.children():
                    module.train(True)

            total_test_acc /= test_loader.dataset.data.shape[0]

            self.total_epochs_trained += 1

            print(f'''[Epoch {epoch} / Global Epoch {self.total_epochs_trained}]''')
            print(f'''    Train Loss:     {total_train_loss}''')
            print(f'''    Test Accuracy:  {total_test_acc}''')

            if (total_test_acc > max_test_acc):
                max_test_acc = total_test_acc
                state_dict = copy.deepcopy(self.state_dict())
                break_count = 0
            else:
                break_count += 1
                if break_count >= max_break_count:
                    print("Stopping Early.")
                    self.load_state_dict(state_dict)
                    break

            print(f'''    Best Test Acc:  {max_test_acc}\n''')

    def predict(self, X):
        pass

In [12]:
mybing = CifarGoogle(
    in_channels=3, 
    n_classes=10,
    device=DEVICE,
    lr=0.0001
).to(DEVICE)

In [5]:
f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 1]
    Train Loss:     2126.458132624626
    Test Accuracy:  0.4889
    Best Test Acc:  0.4889

[Epoch 1 / Global Epoch 2]
    Train Loss:     1978.6962685585022
    Test Accuracy:  0.591
    Best Test Acc:  0.591

[Epoch 2 / Global Epoch 3]
    Train Loss:     1909.254210114479
    Test Accuracy:  0.638
    Best Test Acc:  0.638

[Epoch 3 / Global Epoch 4]
    Train Loss:     1854.7839585542679
    Test Accuracy:  0.6953
    Best Test Acc:  0.6953

[Epoch 4 / Global Epoch 5]
    Train Loss:     1816.6994429826736
    Test Accuracy:  0.7044
    Best Test Acc:  0.7044



In [26]:
checkpoint = torch.load("model_large_5.pt")

In [27]:
mybing.load_state_dict(checkpoint['model_state'])
mybing.optimizer.load_state_dict(checkpoint['optimizer_state'])
mybing.total_epochs_trained = checkpoint['epoch']

In [28]:
mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 6]
    Train Loss:     1790.1647598743439
    Test Accuracy:  0.7354
    Best Test Acc:  0.7354

[Epoch 1 / Global Epoch 7]
    Train Loss:     1769.217351436615
    Test Accuracy:  0.7563
    Best Test Acc:  0.7563

[Epoch 2 / Global Epoch 8]
    Train Loss:     1751.1587373018265
    Test Accuracy:  0.7698
    Best Test Acc:  0.7698

[Epoch 3 / Global Epoch 9]
    Train Loss:     1737.7428050041199
    Test Accuracy:  0.7795
    Best Test Acc:  0.7795

[Epoch 4 / Global Epoch 10]
    Train Loss:     1724.7332781553268
    Test Accuracy:  0.786
    Best Test Acc:  0.786



In [29]:
checkpoint2 = torch.load("model_large_10.pt")

In [30]:
mybing.load_state_dict(checkpoint2['model_state'])
mybing.optimizer.load_state_dict(checkpoint2['optimizer_state'])
mybing.total_epochs_trained = checkpoint2['epoch']

In [31]:
print(mybing.total_epochs_trained)

10


In [32]:
mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 11]
    Train Loss:     1716.1553529500961
    Test Accuracy:  0.7942
    Best Test Acc:  0.7942

[Epoch 1 / Global Epoch 12]
    Train Loss:     1702.9403219223022
    Test Accuracy:  0.7812
    Best Test Acc:  0.7942

[Epoch 2 / Global Epoch 13]
    Train Loss:     1694.8358651399612
    Test Accuracy:  0.8068
    Best Test Acc:  0.8068

[Epoch 3 / Global Epoch 14]
    Train Loss:     1687.8354079723358
    Test Accuracy:  0.8054
    Best Test Acc:  0.8068

[Epoch 4 / Global Epoch 15]
    Train Loss:     1678.9988315105438
    Test Accuracy:  0.8165
    Best Test Acc:  0.8165



In [7]:
checkpoint3 = torch.load("model_large_15.pt")

In [8]:
mybing.load_state_dict(checkpoint3['model_state'])
mybing.optimizer.load_state_dict(checkpoint3['optimizer_state'])
mybing.total_epochs_trained = checkpoint3['epoch']

print(mybing.total_epochs_trained)

15


In [9]:
mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 16]
    Train Loss:     1672.5715950727463
    Test Accuracy:  0.8227
    Best Test Acc:  0.8227

[Epoch 1 / Global Epoch 17]
    Train Loss:     1665.5884673595428
    Test Accuracy:  0.8175
    Best Test Acc:  0.8227

[Epoch 2 / Global Epoch 18]
    Train Loss:     1660.4399118423462
    Test Accuracy:  0.8319
    Best Test Acc:  0.8319

[Epoch 3 / Global Epoch 19]
    Train Loss:     1653.895160317421
    Test Accuracy:  0.8321
    Best Test Acc:  0.8321

[Epoch 4 / Global Epoch 20]
    Train Loss:     1649.6739035844803
    Test Accuracy:  0.8337
    Best Test Acc:  0.8337



In [13]:
checkpoint4 = torch.load("model_large_20.pt")

In [14]:
mybing.load_state_dict(checkpoint4['model_state'])
mybing.optimizer.load_state_dict(checkpoint4['optimizer_state'])
mybing.total_epochs_trained = checkpoint4['epoch']

print(mybing.total_epochs_trained)

20


In [15]:
mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 21]
    Train Loss:     1647.5573140382767
    Test Accuracy:  0.8551
    Best Test Acc:  0.8551

[Epoch 1 / Global Epoch 22]
    Train Loss:     1641.3319432735443
    Test Accuracy:  0.8532
    Best Test Acc:  0.8551

[Epoch 2 / Global Epoch 23]
    Train Loss:     1638.7560341358185
    Test Accuracy:  0.8612
    Best Test Acc:  0.8612

[Epoch 3 / Global Epoch 24]
    Train Loss:     1632.1314253807068
    Test Accuracy:  0.8586
    Best Test Acc:  0.8612

[Epoch 4 / Global Epoch 25]
    Train Loss:     1628.480310678482
    Test Accuracy:  0.8668
    Best Test Acc:  0.8668



In [16]:
checkpoint5 = torch.load("model_large_25.pt")

In [18]:
mybing.load_state_dict(checkpoint5['model_state'])
mybing.optimizer.load_state_dict(checkpoint5['optimizer_state'])
mybing.total_epochs_trained = checkpoint5['epoch']

print(mybing.total_epochs_trained)

25


In [19]:
mybing.train(
    train_loader,
    test_loader,
    5
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 26]
    Train Loss:     1624.3030341863632
    Test Accuracy:  0.8676
    Best Test Acc:  0.8676

[Epoch 1 / Global Epoch 27]
    Train Loss:     1621.163398385048
    Test Accuracy:  0.8626
    Best Test Acc:  0.8676

[Epoch 2 / Global Epoch 28]
    Train Loss:     1619.2955071926117
    Test Accuracy:  0.8658
    Best Test Acc:  0.8676

[Epoch 3 / Global Epoch 29]
    Train Loss:     1616.646876692772
    Test Accuracy:  0.8756
    Best Test Acc:  0.8756

[Epoch 4 / Global Epoch 30]
    Train Loss:     1613.7162551879883
    Test Accuracy:  0.875
    Best Test Acc:  0.8756



In [20]:
checkpoint6 = torch.load("model_large_30.pt")

In [21]:
mybing.load_state_dict(checkpoint6['model_state'])
mybing.optimizer.load_state_dict(checkpoint6['optimizer_state'])
mybing.total_epochs_trained = checkpoint6['epoch']

print(mybing.total_epochs_trained)

30


In [22]:
mybing.train(
    train_loader,
    test_loader,
    15
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 31]
    Train Loss:     1611.0862920284271
    Test Accuracy:  0.8699
    Best Test Acc:  0.8699

[Epoch 1 / Global Epoch 32]
    Train Loss:     1609.408675789833
    Test Accuracy:  0.8734
    Best Test Acc:  0.8734

[Epoch 2 / Global Epoch 33]
    Train Loss:     1604.6497066020966
    Test Accuracy:  0.8783
    Best Test Acc:  0.8783

[Epoch 3 / Global Epoch 34]
    Train Loss:     1602.7326176166534
    Test Accuracy:  0.8683
    Best Test Acc:  0.8783

[Epoch 4 / Global Epoch 35]
    Train Loss:     1601.099912762642
    Test Accuracy:  0.8801
    Best Test Acc:  0.8801

[Epoch 5 / Global Epoch 36]
    Train Loss:     1598.07659471035
    Test Accuracy:  0.8773
    Best Test Acc:  0.8801

[Epoch 6 / Global Epoch 37]
    Train Loss:     1596.1671664714813
    Test Accuracy:  0.8786
    Best Test Acc:  0.8801

[Epoch 7 / Global Epoch 38]
    Train Loss:     1595.1360617876053
    Test Accuracy:  0.8793
    Best Test Acc:  0.8801

[Epoch 8 / Global Epoch 39]


In [23]:
checkpoint7 = torch.load("model_large_45.pt")

In [24]:
mybing.load_state_dict(checkpoint7['model_state'])
mybing.optimizer.load_state_dict(checkpoint7['optimizer_state'])
mybing.total_epochs_trained = checkpoint7['epoch']

print(mybing.total_epochs_trained)

45


In [26]:
mybing.train(
    train_loader,
    test_loader,
    7
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 46]
    Train Loss:     1581.3161803483963
    Test Accuracy:  0.8834
    Best Test Acc:  0.8834

[Epoch 1 / Global Epoch 47]
    Train Loss:     1579.2013140916824
    Test Accuracy:  0.8895
    Best Test Acc:  0.8895

[Epoch 2 / Global Epoch 48]
    Train Loss:     1578.717133283615
    Test Accuracy:  0.8855
    Best Test Acc:  0.8895

[Epoch 3 / Global Epoch 49]
    Train Loss:     1578.2780284881592
    Test Accuracy:  0.8866
    Best Test Acc:  0.8895

[Epoch 4 / Global Epoch 50]
    Train Loss:     1576.4174890518188
    Test Accuracy:  0.8902
    Best Test Acc:  0.8902

[Epoch 5 / Global Epoch 51]
    Train Loss:     1574.875913977623
    Test Accuracy:  0.8913
    Best Test Acc:  0.8913

[Epoch 6 / Global Epoch 52]
    Train Loss:     1573.2555277347565
    Test Accuracy:  0.8913
    Best Test Acc:  0.8913



In [27]:
checkpoint8 = torch.load("model_large_52.pt")

In [28]:
mybing.load_state_dict(checkpoint8['model_state'])
mybing.optimizer.load_state_dict(checkpoint8['optimizer_state'])
mybing.total_epochs_trained = checkpoint8['epoch']

print(mybing.total_epochs_trained)

52


In [29]:
mybing.train(
    train_loader,
    test_loader,
    8
)

f = open(f"model_large_{mybing.total_epochs_trained}.pt", "w").close()
torch.save({
    'model_state': mybing.state_dict(),
    'optimizer_state': mybing.optimizer.state_dict(),
    'epoch': mybing.total_epochs_trained
    }, f"model_large_{mybing.total_epochs_trained}.pt"
)

[Epoch 0 / Global Epoch 53]
    Train Loss:     1573.055364727974
    Test Accuracy:  0.8891
    Best Test Acc:  0.8891

[Epoch 1 / Global Epoch 54]
    Train Loss:     1572.4587986469269
    Test Accuracy:  0.8907
    Best Test Acc:  0.8907

[Epoch 2 / Global Epoch 55]
    Train Loss:     1571.5949985980988
    Test Accuracy:  0.8909
    Best Test Acc:  0.8909

[Epoch 3 / Global Epoch 56]
    Train Loss:     1570.350787639618
    Test Accuracy:  0.8883
    Best Test Acc:  0.8909

[Epoch 4 / Global Epoch 57]
    Train Loss:     1568.5161921977997
    Test Accuracy:  0.8923
    Best Test Acc:  0.8923

[Epoch 5 / Global Epoch 58]
    Train Loss:     1567.1864584684372
    Test Accuracy:  0.8931
    Best Test Acc:  0.8931

[Epoch 6 / Global Epoch 59]
    Train Loss:     1568.2514601945877
    Test Accuracy:  0.8895
    Best Test Acc:  0.8931

[Epoch 7 / Global Epoch 60]
    Train Loss:     1566.8065421581268
    Test Accuracy:  0.8887
    Best Test Acc:  0.8931

