## AlexNet on CIFAR 10

In [1]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
import tqdm
import torchinfo


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
device

device(type='cuda', index=0)

In [3]:
'''
Step 1:
'''

transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

train_dataset = datasets.CIFAR10(root='./cifar_10data/',
                                 train=True, 
                                 transform=transform,
                                 download=True)

test_dataset = datasets.CIFAR10(root='./cifar_10data/',
                                train=False, 
                                transform=transforms.ToTensor())
    

Files already downloaded and verified


In [4]:
'''
Step 2
'''

class AlexNet(nn.Module) :
    
    def __init__(self, num_class=10) :
        super(AlexNet, self).__init__()
        
        self.conv_layer1 = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=4),
                nn.ReLU(),
                nn.Conv2d(96, 96, kernel_size=3),
                nn.ReLU()
                )
        self.conv_layer2 = nn.Sequential(
                nn.Conv2d(96, 256, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2)
                )
        self.conv_layer3 = nn.Sequential(
                nn.Conv2d(256, 384, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(384, 384, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(384, 256, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2)
                )
        
        self.fc_layer1 = nn.Sequential(
                nn.Dropout(),
                nn.Linear(9216, 4096),
                nn.ReLU(),
                nn.Dropout(),  #p=0.5 by default
                nn.Linear(4096, 4096),
                nn.ReLU(),     #p=0.5 by default
                nn.Linear(4096, 10)
                )
    
    def forward(self, x) :
        output = self.conv_layer1(x)
        output = self.conv_layer2(output)
        output = self.conv_layer3(output)
        output = output.view(-1, 9216)
        output = self.fc_layer1(output)
        return output

    

In [5]:
'''
Step 3
'''
model = AlexNet().to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, weight_decay=0.00005)

In [6]:
torchinfo.summary(model,(200,3,32,32))

Layer (type:depth-idx)                   Output Shape              Param #
AlexNet                                  [200, 10]                 --
├─Sequential: 1-1                        [200, 96, 27, 27]         --
│    └─Conv2d: 2-1                       [200, 96, 29, 29]         4,704
│    └─ReLU: 2-2                         [200, 96, 29, 29]         --
│    └─Conv2d: 2-3                       [200, 96, 27, 27]         83,040
│    └─ReLU: 2-4                         [200, 96, 27, 27]         --
├─Sequential: 1-2                        [200, 256, 13, 13]        --
│    └─Conv2d: 2-5                       [200, 256, 27, 27]        614,656
│    └─ReLU: 2-6                         [200, 256, 27, 27]        --
│    └─MaxPool2d: 2-7                    [200, 256, 13, 13]        --
├─Sequential: 1-3                        [200, 256, 6, 6]          --
│    └─Conv2d: 2-8                       [200, 384, 13, 13]        885,120
│    └─ReLU: 2-9                         [200, 384, 13, 13]        -

In [7]:
torch.cuda.empty_cache()

In [8]:
'''
Step 4
'''
model.train()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)

import time
start = time.time()
for epoch in tqdm.tqdm(range(100)):
    print("{}th epoch starting.".format(epoch))
    for i, (images, labels) in enumerate(train_loader) :
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        train_loss = loss_function(model(images), labels)
        train_loss.backward()

        optimizer.step()

    print ("Epoch [{}] Loss: {:.4f}".format(epoch+1, train_loss.item()))

end = time.time()
print("Time ellapsed in training is: {}".format(end - start))

  0%|          | 0/100 [00:00<?, ?it/s]

0th epoch starting.


  1%|          | 1/100 [00:14<24:01, 14.56s/it]

Epoch [1] Loss: 2.3025
1th epoch starting.


  2%|▏         | 2/100 [00:28<23:21, 14.30s/it]

Epoch [2] Loss: 2.2971
2th epoch starting.


  3%|▎         | 3/100 [00:42<22:59, 14.22s/it]

Epoch [3] Loss: 2.2222
3th epoch starting.


  4%|▍         | 4/100 [00:56<22:41, 14.19s/it]

Epoch [4] Loss: 2.3039
4th epoch starting.


  5%|▌         | 5/100 [01:11<22:25, 14.17s/it]

Epoch [5] Loss: 2.1935
5th epoch starting.


  6%|▌         | 6/100 [01:25<22:11, 14.17s/it]

Epoch [6] Loss: 2.0747
6th epoch starting.


  7%|▋         | 7/100 [01:39<21:57, 14.17s/it]

Epoch [7] Loss: 2.0630
7th epoch starting.


  8%|▊         | 8/100 [01:53<21:43, 14.17s/it]

Epoch [8] Loss: 2.0028
8th epoch starting.


  9%|▉         | 9/100 [02:07<21:29, 14.17s/it]

Epoch [9] Loss: 1.8296
9th epoch starting.


 10%|█         | 10/100 [02:21<21:15, 14.17s/it]

Epoch [10] Loss: 1.8083
10th epoch starting.


 11%|█         | 11/100 [02:36<21:01, 14.17s/it]

Epoch [11] Loss: 2.0656
11th epoch starting.


 12%|█▏        | 12/100 [02:50<20:47, 14.17s/it]

Epoch [12] Loss: 1.5913
12th epoch starting.


 13%|█▎        | 13/100 [03:04<20:33, 14.18s/it]

Epoch [13] Loss: 1.4546
13th epoch starting.


 14%|█▍        | 14/100 [03:18<20:19, 14.18s/it]

Epoch [14] Loss: 1.4203
14th epoch starting.


 15%|█▌        | 15/100 [03:32<20:05, 14.18s/it]

Epoch [15] Loss: 1.5039
15th epoch starting.


 16%|█▌        | 16/100 [03:46<19:50, 14.17s/it]

Epoch [16] Loss: 1.3566
16th epoch starting.


 17%|█▋        | 17/100 [04:01<19:35, 14.17s/it]

Epoch [17] Loss: 1.4507
17th epoch starting.


 18%|█▊        | 18/100 [04:15<19:21, 14.16s/it]

Epoch [18] Loss: 1.2539
18th epoch starting.


 19%|█▉        | 19/100 [04:29<19:06, 14.16s/it]

Epoch [19] Loss: 1.2235
19th epoch starting.


 20%|██        | 20/100 [04:43<18:53, 14.17s/it]

Epoch [20] Loss: 1.0334
20th epoch starting.


 21%|██        | 21/100 [04:57<18:39, 14.17s/it]

Epoch [21] Loss: 0.9469
21th epoch starting.


 22%|██▏       | 22/100 [05:11<18:25, 14.17s/it]

Epoch [22] Loss: 1.0939
22th epoch starting.


 23%|██▎       | 23/100 [05:26<18:11, 14.17s/it]

Epoch [23] Loss: 0.9552
23th epoch starting.


 24%|██▍       | 24/100 [05:40<17:56, 14.17s/it]

Epoch [24] Loss: 0.8655
24th epoch starting.


 25%|██▌       | 25/100 [05:54<17:42, 14.16s/it]

Epoch [25] Loss: 0.7302
25th epoch starting.


 26%|██▌       | 26/100 [06:08<17:27, 14.16s/it]

Epoch [26] Loss: 0.7275
26th epoch starting.


 27%|██▋       | 27/100 [06:22<17:13, 14.15s/it]

Epoch [27] Loss: 0.8921
27th epoch starting.


 28%|██▊       | 28/100 [06:36<16:58, 14.15s/it]

Epoch [28] Loss: 0.7175
28th epoch starting.


 29%|██▉       | 29/100 [06:51<16:43, 14.14s/it]

Epoch [29] Loss: 0.7119
29th epoch starting.


 30%|███       | 30/100 [07:05<16:29, 14.14s/it]

Epoch [30] Loss: 0.6328
30th epoch starting.


 31%|███       | 31/100 [07:19<16:16, 14.15s/it]

Epoch [31] Loss: 0.5676
31th epoch starting.


 32%|███▏      | 32/100 [07:33<16:02, 14.16s/it]

Epoch [32] Loss: 0.7417
32th epoch starting.


 33%|███▎      | 33/100 [07:47<15:48, 14.15s/it]

Epoch [33] Loss: 0.7248
33th epoch starting.


 34%|███▍      | 34/100 [08:01<15:33, 14.15s/it]

Epoch [34] Loss: 0.5872
34th epoch starting.


 35%|███▌      | 35/100 [08:15<15:19, 14.14s/it]

Epoch [35] Loss: 0.5630
35th epoch starting.


 36%|███▌      | 36/100 [08:30<15:04, 14.14s/it]

Epoch [36] Loss: 0.6076
36th epoch starting.


 37%|███▋      | 37/100 [08:44<14:50, 14.14s/it]

Epoch [37] Loss: 0.4562
37th epoch starting.


 38%|███▊      | 38/100 [08:58<14:36, 14.14s/it]

Epoch [38] Loss: 0.4403
38th epoch starting.


 39%|███▉      | 39/100 [09:12<14:22, 14.13s/it]

Epoch [39] Loss: 0.5114
39th epoch starting.


 40%|████      | 40/100 [09:26<14:08, 14.13s/it]

Epoch [40] Loss: 0.4704
40th epoch starting.


 41%|████      | 41/100 [09:40<13:53, 14.13s/it]

Epoch [41] Loss: 0.4095
41th epoch starting.


 42%|████▏     | 42/100 [09:54<13:39, 14.14s/it]

Epoch [42] Loss: 0.3939
42th epoch starting.


 43%|████▎     | 43/100 [10:08<13:25, 14.14s/it]

Epoch [43] Loss: 0.5130
43th epoch starting.


 44%|████▍     | 44/100 [10:23<13:12, 14.15s/it]

Epoch [44] Loss: 0.4038
44th epoch starting.


 45%|████▌     | 45/100 [10:37<12:57, 14.14s/it]

Epoch [45] Loss: 0.3806
45th epoch starting.


 46%|████▌     | 46/100 [10:51<12:43, 14.14s/it]

Epoch [46] Loss: 0.3828
46th epoch starting.


 47%|████▋     | 47/100 [11:05<12:29, 14.15s/it]

Epoch [47] Loss: 0.3658
47th epoch starting.


 48%|████▊     | 48/100 [11:19<12:16, 14.16s/it]

Epoch [48] Loss: 0.3664
48th epoch starting.


 49%|████▉     | 49/100 [11:33<12:01, 14.14s/it]

Epoch [49] Loss: 0.2810
49th epoch starting.


 50%|█████     | 50/100 [11:47<11:46, 14.13s/it]

Epoch [50] Loss: 0.2696
50th epoch starting.


 51%|█████     | 51/100 [12:02<11:37, 14.22s/it]

Epoch [51] Loss: 0.3245
51th epoch starting.


 52%|█████▏    | 52/100 [12:16<11:21, 14.20s/it]

Epoch [52] Loss: 0.3586
52th epoch starting.


 53%|█████▎    | 53/100 [12:30<11:06, 14.17s/it]

Epoch [53] Loss: 0.2931
53th epoch starting.


 54%|█████▍    | 54/100 [12:44<10:51, 14.16s/it]

Epoch [54] Loss: 0.2656
54th epoch starting.


 55%|█████▌    | 55/100 [12:58<10:36, 14.15s/it]

Epoch [55] Loss: 0.2927
55th epoch starting.


 56%|█████▌    | 56/100 [13:13<10:22, 14.14s/it]

Epoch [56] Loss: 0.3051
56th epoch starting.


 57%|█████▋    | 57/100 [13:27<10:08, 14.14s/it]

Epoch [57] Loss: 0.3889
57th epoch starting.


 58%|█████▊    | 58/100 [13:41<09:53, 14.14s/it]

Epoch [58] Loss: 0.2850
58th epoch starting.


 59%|█████▉    | 59/100 [13:55<09:39, 14.15s/it]

Epoch [59] Loss: 0.3220
59th epoch starting.


 60%|██████    | 60/100 [14:09<09:25, 14.14s/it]

Epoch [60] Loss: 0.3015
60th epoch starting.


 61%|██████    | 61/100 [14:23<09:11, 14.13s/it]

Epoch [61] Loss: 0.2159
61th epoch starting.


 62%|██████▏   | 62/100 [14:37<08:57, 14.13s/it]

Epoch [62] Loss: 0.2165
62th epoch starting.


 63%|██████▎   | 63/100 [14:52<08:43, 14.14s/it]

Epoch [63] Loss: 0.2059
63th epoch starting.


 64%|██████▍   | 64/100 [15:06<08:29, 14.14s/it]

Epoch [64] Loss: 0.2267
64th epoch starting.


 65%|██████▌   | 65/100 [15:20<08:14, 14.14s/it]

Epoch [65] Loss: 0.1788
65th epoch starting.


 66%|██████▌   | 66/100 [15:34<08:00, 14.13s/it]

Epoch [66] Loss: 0.1584
66th epoch starting.


 67%|██████▋   | 67/100 [15:48<07:46, 14.14s/it]

Epoch [67] Loss: 0.1419
67th epoch starting.


 68%|██████▊   | 68/100 [16:02<07:32, 14.14s/it]

Epoch [68] Loss: 0.1792
68th epoch starting.


 69%|██████▉   | 69/100 [16:16<07:18, 14.14s/it]

Epoch [69] Loss: 0.1770
69th epoch starting.


 70%|███████   | 70/100 [16:31<07:04, 14.15s/it]

Epoch [70] Loss: 0.1288
70th epoch starting.


 71%|███████   | 71/100 [16:45<06:50, 14.15s/it]

Epoch [71] Loss: 0.1292
71th epoch starting.


 72%|███████▏  | 72/100 [16:59<06:36, 14.16s/it]

Epoch [72] Loss: 0.3029
72th epoch starting.


 73%|███████▎  | 73/100 [17:13<06:22, 14.16s/it]

Epoch [73] Loss: 0.2477
73th epoch starting.


 74%|███████▍  | 74/100 [17:27<06:07, 14.15s/it]

Epoch [74] Loss: 0.1683
74th epoch starting.


 75%|███████▌  | 75/100 [17:41<05:53, 14.14s/it]

Epoch [75] Loss: 0.1962
75th epoch starting.


 76%|███████▌  | 76/100 [17:55<05:39, 14.14s/it]

Epoch [76] Loss: 0.1905
76th epoch starting.


 77%|███████▋  | 77/100 [18:10<05:25, 14.15s/it]

Epoch [77] Loss: 0.1337
77th epoch starting.


 78%|███████▊  | 78/100 [18:24<05:11, 14.14s/it]

Epoch [78] Loss: 0.1377
78th epoch starting.


 79%|███████▉  | 79/100 [18:38<04:57, 14.14s/it]

Epoch [79] Loss: 0.2116
79th epoch starting.


 80%|████████  | 80/100 [18:52<04:42, 14.14s/it]

Epoch [80] Loss: 0.1637
80th epoch starting.


 81%|████████  | 81/100 [19:06<04:28, 14.14s/it]

Epoch [81] Loss: 0.1031
81th epoch starting.


 82%|████████▏ | 82/100 [19:20<04:14, 14.14s/it]

Epoch [82] Loss: 0.1333
82th epoch starting.


 83%|████████▎ | 83/100 [19:34<04:00, 14.14s/it]

Epoch [83] Loss: 0.1243
83th epoch starting.


 84%|████████▍ | 84/100 [19:49<03:46, 14.14s/it]

Epoch [84] Loss: 0.1744
84th epoch starting.


 85%|████████▌ | 85/100 [20:03<03:32, 14.14s/it]

Epoch [85] Loss: 0.1850
85th epoch starting.


 86%|████████▌ | 86/100 [20:17<03:18, 14.14s/it]

Epoch [86] Loss: 0.1085
86th epoch starting.


 87%|████████▋ | 87/100 [20:31<03:03, 14.14s/it]

Epoch [87] Loss: 0.0918
87th epoch starting.


 88%|████████▊ | 88/100 [20:45<02:49, 14.14s/it]

Epoch [88] Loss: 0.1926
88th epoch starting.


 89%|████████▉ | 89/100 [20:59<02:35, 14.14s/it]

Epoch [89] Loss: 0.0535
89th epoch starting.


 90%|█████████ | 90/100 [21:13<02:21, 14.14s/it]

Epoch [90] Loss: 0.1748
90th epoch starting.


 91%|█████████ | 91/100 [21:28<02:07, 14.14s/it]

Epoch [91] Loss: 0.1104
91th epoch starting.


 92%|█████████▏| 92/100 [21:42<01:53, 14.13s/it]

Epoch [92] Loss: 0.0982
92th epoch starting.


 93%|█████████▎| 93/100 [21:56<01:38, 14.13s/it]

Epoch [93] Loss: 0.1666
93th epoch starting.


 94%|█████████▍| 94/100 [22:10<01:24, 14.13s/it]

Epoch [94] Loss: 0.1224
94th epoch starting.


 95%|█████████▌| 95/100 [22:24<01:10, 14.14s/it]

Epoch [95] Loss: 0.1674
95th epoch starting.


 96%|█████████▌| 96/100 [22:38<00:56, 14.14s/it]

Epoch [96] Loss: 0.0808
96th epoch starting.


 97%|█████████▋| 97/100 [22:52<00:42, 14.13s/it]

Epoch [97] Loss: 0.0650
97th epoch starting.


 98%|█████████▊| 98/100 [23:06<00:28, 14.13s/it]

Epoch [98] Loss: 0.1547
98th epoch starting.


 99%|█████████▉| 99/100 [23:21<00:14, 14.13s/it]

Epoch [99] Loss: 0.0925
99th epoch starting.


100%|██████████| 100/100 [23:35<00:00, 14.15s/it]

Epoch [100] Loss: 0.0703
Time ellapsed in training is: 1415.2418777942657





In [9]:
'''
Step 5
'''
model.eval()
test_loss, correct, total = 0, 0, 0

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
with torch.no_grad():  #using context manager
    for images, labels in test_loader :
        images, labels = images.to(device), labels.to(device)

        output = model(images)
        test_loss += loss_function(output, labels).item()

        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(labels.view_as(pred)).sum().item()

        total += labels.size(0)

print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss /total, correct, total,
        100. * correct / total))

[Test set] Average loss: 0.0039, Accuracy: 8953/10000 (89.53%)

