In [1]:
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

In [2]:
# The only preprocessing is ToTensor; no augmentation or normalisation is applied.
to_tensor = transforms.Compose([transforms.ToTensor()])

# Both CIFAR-10 splits live under the relative directory 'CIFAR10'
# (download = True is requested for each split).
train_data = datasets.CIFAR10(root = 'CIFAR10', train = True, transform = to_tensor, download = True)
test_data = datasets.CIFAR10(root = 'CIFAR10', train = False, transform = to_tensor, download = True)

# Training uses shuffled mini-batches of 100 images; the test loader yields
# one shuffled image per batch.
train_batch = torch.utils.data.DataLoader(dataset = train_data, batch_size = 100, shuffle = True)
test_batch = torch.utils.data.DataLoader(dataset = test_data, batch_size = 1, shuffle = True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class AlexNet(nn.Module):
    """AlexNet-style CNN sized for 3x32x32 inputs.

    Five 3x3 convolutions (stride 1, padding 1, so each keeps the spatial size)
    are interleaved with three 2x2 max-pools that take 32x32 down to 4x4. The
    resulting 256 * 4 * 4 = 4096 features go through three fully connected
    layers, and the 10 outputs are turned into probabilities with a softmax.
    """

    def __init__(self):
        super().__init__()
        # In the original paper the first convolution drastically reduces the
        # image resolution. The images here are already very small, so conv1
        # preserves the spatial dimensions instead.
        self.conv1 = nn.Conv2d(3, 96, kernel_size = (3, 3), stride = 1, padding = 1)    # 32x32
        self.conv2 = nn.Conv2d(96, 256, kernel_size = (3, 3), stride = 1, padding = 1)  # 16x16
        self.conv3 = nn.Conv2d(256, 384, kernel_size = (3, 3), stride = 1, padding = 1) # 8x8
        self.conv4 = nn.Conv2d(384, 384, kernel_size = (3, 3), stride = 1, padding = 1) # 8x8
        self.conv5 = nn.Conv2d(384, 256, kernel_size = (3, 3), stride = 1, padding = 1) # 8x8
        self.fc1 = nn.Linear(4096, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 10)

    def convs(self, x):
        """Run the convolutional feature extractor.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            Feature maps of shape (N, 256, H // 8, W // 8); (N, 256, 4, 4) for
            32x32 inputs.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (2, 2), stride = 2, padding = 0) # 16x16
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (2, 2), stride = 2, padding = 0) # 8x8
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.conv5(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (2, 2), stride = 2, padding = 0) # 4x4
        return x

    def forward(self, x):
        """Return per-class probabilities of shape (N, 10) for a (N, 3, 32, 32) batch.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not produce 4096
                features per sample (raised by the first linear layer).
        """
        x = self.convs(x)
        # Flatten per sample rather than with view(-1, 4096): the latter would
        # silently fold extra features into the batch dimension when the input
        # is not 32x32, returning more rows than there are samples.
        x = x.flatten(start_dim = 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        # Probabilities (not logits) are returned; each row sums to 1.
        x = F.softmax(x, dim = 1)
        return x

alexnet = AlexNet()

In [4]:
# Binary cross-entropy between the network outputs and the target tensor.
loss_function = nn.BCELoss()
# Adam over all of the model's parameters with a step size of 1e-3.
optimizer = optim.Adam(params = alexnet.parameters(), lr = 1e-3)

In [5]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Running on GPU" if device.type == 'cuda' else "Running on cpu")

Running on GPU


In [6]:
# Move the model onto the selected device. As the last expression in the cell,
# the value returned by .to() is what gets displayed as the cell output.
alexnet.to(device)

AlexNet(
  (conv1): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=2048, bias=True)
  (fc3): Linear(in_features=2048, out_features=10, bias=True)
)

In [7]:
def one_hot_encoder(labels, num_labels):
    """Build a one-hot matrix for a batch of integer class labels.

    Args:
        labels: sized sequence of class indices, one per sample.
        num_labels: number of columns (classes) in the encoding.

    Returns:
        A NumPy array of shape (len(labels), num_labels) that is zero
        everywhere except for a 1 in row i at column labels[i].
    """
    encoded = np.zeros((len(labels), num_labels))
    # Each `row` is a view into `encoded`, so writing to it fills the matrix.
    for row, label in zip(encoded, labels):
        row[label] = 1
    return encoded

In [11]:
EPOCHS = 10

def train(train_batch, epochs = None):
    """Train the global ``alexnet`` with the global ``optimizer`` and ``loss_function``.

    Args:
        train_batch: iterable yielding ``(images, labels)`` batches; labels are
            integer class indices that get one-hot encoded for the loss.
        epochs: number of passes over ``train_batch``. Defaults to the
            module-level ``EPOCHS`` (looked up at call time).

    Raises:
        ValueError: if ``train_batch`` yields no batches.

    After every epoch the loss of the last processed batch is printed.
    """
    if epochs is None:
        epochs = EPOCHS
    alexnet.train()
    for epoch in range(epochs):
        loss = None
        for images, labels in tqdm(train_batch):
            optimizer.zero_grad()
            # Call the module itself rather than .forward() so that
            # nn.Module.__call__ (and any registered hooks) runs.
            outputs = alexnet(images.to(device))
            # Size the one-hot targets from the network output instead of a
            # hard-coded class count.
            one_hot_labels = torch.Tensor(one_hot_encoder(labels, outputs.shape[1]))
            loss = loss_function(outputs, one_hot_labels.to(device))
            loss.backward()
            optimizer.step()
        if loss is None:
            # Without this guard the print below would fail on an undefined name.
            raise ValueError("train_batch produced no batches")
        print(f"Epoch: {epoch}. Loss: {loss}")

In [12]:
def test(test_batch):
    """Print the classification accuracy of the global ``alexnet`` on ``test_batch``.

    Args:
        test_batch: iterable yielding ``(images, labels)`` batches of any batch
            size; labels are integer class indices.

    Raises:
        ValueError: if ``test_batch`` yields no samples.

    The model is evaluated in eval mode with gradients disabled, and its
    previous train/eval mode is restored afterwards.
    """
    correct = 0
    total = 0
    was_training = alexnet.training
    alexnet.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(test_batch):
                labels = labels.to(device)
                net_out = alexnet(images.to(device))
                # Take the argmax per sample (dim = 1); without `dim` the index
                # is taken over the flattened batch, which is only meaningful
                # for a batch size of 1.
                predicted_class = torch.argmax(net_out, dim = 1)
                correct += (predicted_class == labels).sum().item()
                total += labels.size(0)
    finally:
        alexnet.train(was_training)
    if total == 0:
        # Avoid an opaque ZeroDivisionError on an empty loader.
        raise ValueError("test_batch produced no samples")
    print("Accuracy: ", round(correct/total, 3))

In [13]:
# Run the training loop over the training loader; per-epoch losses are printed below.
train(train_batch)

100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:14<00:00,  3.72it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 0. Loss: 0.22507333755493164


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:38<00:00,  3.15it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 1. Loss: 0.1813281923532486


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [03:11<00:00,  2.61it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 2. Loss: 0.13120971620082855


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:58<00:00,  2.80it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 3. Loss: 0.13807496428489685


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [03:05<00:00,  2.70it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 4. Loss: 0.10920222103595734


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:53<00:00,  2.88it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 5. Loss: 0.08597782254219055


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:59<00:00,  2.79it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 6. Loss: 0.09013327211141586


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:54<00:00,  2.86it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 7. Loss: 0.10360493510961533


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [02:54<00:00,  2.87it/s]
  0%|                                                                                          | 0/500 [00:00<?, ?it/s]

Epoch: 8. Loss: 0.0747457966208458


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [03:03<00:00,  2.73it/s]

Epoch: 9. Loss: 0.05097617208957672





In [14]:
# Evaluate the trained model on the test loader; the accuracy is printed below.
test(test_batch)

100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [01:10<00:00, 142.04it/s]

Accuracy:  0.715



