In [1]:
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [2]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 96, kernel_size = 11, stride = 4, padding = 0)
        self.conv2 = nn.Conv2d(96, 256, kernel_size = 5, stride = 1, padding = 2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size = 3, stride = 1, padding = 1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size = 3, stride = 1, padding = 1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size = 3, stride = 1, padding = 1)
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 2)
        
    def convs(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (3, 3), stride = 2, padding = 0)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (3, 3), stride = 2, padding = 0)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.conv5(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size = (3, 3), stride = 2, padding = 0)
        return x
        
    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, 9216)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.softmax(x, dim = 1)
        return x
    
alexnet = AlexNet()

In [3]:
import torch.optim as optim

optimizer = optim.Adam(alexnet.parameters(), lr = 0.001)
loss_function = nn.MSELoss()

In [4]:
training_data = np.load('training_data.npy', allow_pickle = True)

In [5]:
X = torch.Tensor([i[0] for i in training_data]).view(-1, 227, 227)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])
del training_data

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
del X

In [38]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("Running on GPU")
else:
    device = torch.device('cpu')
    print("Running on cpu")

Running on GPU


In [40]:
alexnet.to(device)

AlexNet(
  (conv1): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=2, bias=True)
)

In [41]:
BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(X_train), BATCH_SIZE)):
        X_batch = X_train[i:i+BATCH_SIZE].view(-1,1,227,227).to(device)
        y_batch = y_train[i:i+BATCH_SIZE].to(device)
        alexnet.zero_grad()
        outputs = alexnet(X_batch)
        loss = loss_function(outputs, y_batch)
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch}. Loss: {loss}")


  0%|                                                                                          | 0/213 [00:00<?, ?it/s][A
  0%|▍                                                                                 | 1/213 [00:03<13:37,  3.86s/it][A
  1%|▊                                                                                 | 2/213 [00:04<10:02,  2.86s/it][A
  1%|█▏                                                                                | 3/213 [00:05<07:58,  2.28s/it][A
  2%|█▌                                                                                | 4/213 [00:06<06:34,  1.89s/it][A
  2%|█▉                                                                                | 5/213 [00:07<05:35,  1.61s/it][A
  3%|██▎                                                                               | 6/213 [00:08<04:53,  1.42s/it][A
  3%|██▋                                                                               | 7/213 [00:09<04:25,  1.29s/it][A
  4%|███       

 62%|█████████████████████████████████████████████████▌                              | 132/213 [02:27<01:33,  1.16s/it][A
 62%|█████████████████████████████████████████████████▉                              | 133/213 [02:28<01:32,  1.16s/it][A
 63%|██████████████████████████████████████████████████▎                             | 134/213 [02:29<01:31,  1.15s/it][A
 63%|██████████████████████████████████████████████████▋                             | 135/213 [02:30<01:30,  1.16s/it][A
 64%|███████████████████████████████████████████████████                             | 136/213 [02:31<01:29,  1.17s/it][A
 64%|███████████████████████████████████████████████████▍                            | 137/213 [02:33<01:29,  1.17s/it][A
 65%|███████████████████████████████████████████████████▊                            | 138/213 [02:34<01:28,  1.18s/it][A
 65%|████████████████████████████████████████████████████▏                           | 139/213 [02:35<01:25,  1.16s/it][A
 66%|███████████

Epoch: 0. Loss: 0.2498754858970642


In [45]:
def test(alexnet):
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(X_test))):
            real_class = torch.argmax(y_test[i]).to(device)
            net_out = alexnet(X_test[i].view(-1, 1, 227, 227).to(device))[0]
            predicted_class = torch.argmax(net_out)

            if predicted_class == real_class:
                correct += 1
            total += 1
    print("Accuracy: ", round(correct/total, 3))

test(alexnet)


  0%|                                                                                         | 0/2365 [00:00<?, ?it/s][A
  0%|▏                                                                                | 5/2365 [00:00<00:52, 45.11it/s][A
  0%|▎                                                                               | 11/2365 [00:00<00:48, 48.68it/s][A
  1%|▌                                                                               | 18/2365 [00:00<00:45, 52.03it/s][A
  1%|▊                                                                               | 24/2365 [00:00<00:43, 53.92it/s][A
  1%|█                                                                               | 30/2365 [00:00<00:42, 55.50it/s][A
  2%|█▎                                                                              | 37/2365 [00:00<00:41, 56.55it/s][A
  2%|█▍                                                                              | 44/2365 [00:00<00:39, 58.05it/s][A
  2%|█▋        

 37%|████████████████████████████▉                                                  | 867/2365 [00:14<00:26, 56.75it/s][A
 37%|█████████████████████████████▏                                                 | 873/2365 [00:14<00:26, 57.07it/s][A
 37%|█████████████████████████████▎                                                 | 879/2365 [00:14<00:26, 56.07it/s][A
 37%|█████████████████████████████▌                                                 | 885/2365 [00:14<00:26, 56.75it/s][A
 38%|█████████████████████████████▊                                                 | 892/2365 [00:15<00:24, 59.09it/s][A
 38%|█████████████████████████████▉                                                 | 898/2365 [00:15<00:25, 56.87it/s][A
 38%|██████████████████████████████▏                                                | 904/2365 [00:15<00:25, 56.43it/s][A
 38%|██████████████████████████████▍                                                | 910/2365 [00:15<00:25, 56.85it/s][A
 39%|███████████

 71%|███████████████████████████████████████████████████████▏                      | 1672/2365 [00:28<00:12, 55.63it/s][A
 71%|███████████████████████████████████████████████████████▎                      | 1678/2365 [00:28<00:12, 55.25it/s][A
 71%|███████████████████████████████████████████████████████▌                      | 1684/2365 [00:29<00:12, 55.75it/s][A
 71%|███████████████████████████████████████████████████████▋                      | 1690/2365 [00:29<00:11, 56.68it/s][A
 72%|███████████████████████████████████████████████████████▉                      | 1696/2365 [00:29<00:11, 56.38it/s][A
 72%|████████████████████████████████████████████████████████▏                     | 1702/2365 [00:29<00:11, 56.32it/s][A
 72%|████████████████████████████████████████████████████████▎                     | 1708/2365 [00:29<00:11, 56.53it/s][A
 72%|████████████████████████████████████████████████████████▌                     | 1714/2365 [00:29<00:11, 55.51it/s][A
 73%|███████████

Accuracy:  0.478



