# Basic Imports

In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

# Load training data

In [2]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code -- only load this file if you trust where it came from.
# The filename spelling is kept as-is; presumably it matches the file on disk.
training_data = np.load(
    'traning_data.npy',
    allow_pickle=True,
)

# Structure of network

In [31]:
class Net(nn.Module):
    """Three-block CNN classifier for single-channel square images.

    Each block is Conv2d (5x5 kernel) -> ReLU -> 2x2 max-pool; the flattened
    result feeds two fully connected layers. ``forward`` returns softmax
    probabilities of shape (batch, num_classes).

    Args:
        img_size: side length of the square input images (default 50).
        num_classes: number of output classes (default 2, e.g. dog vs cat).
    """

    def __init__(self, img_size=50, num_classes=2):
        super().__init__()  # run nn.Module's init so submodules get registered
        self.conv1 = nn.Conv2d(1, 32, 5)   # 1 input channel, 32 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(32, 64, 5)  # consumes conv1's 32 channels, produces 64
        self.conv3 = nn.Conv2d(64, 128, 5)

        # Push one dummy image through the conv stack so convs() can record the
        # flattened feature size. no_grad: this probe never needs an autograd graph.
        self._to_linear = None
        with torch.no_grad():
            self.convs(torch.randn(1, 1, img_size, img_size))

        self.fc1 = nn.Linear(self._to_linear, 512)  # flattened conv features -> 512
        self.fc2 = nn.Linear(512, num_classes)      # 512 -> one score per class

    def convs(self, x):
        """Apply the three conv -> ReLU -> 2x2 max-pool blocks to ``x``.

        On the first call, also stores the per-sample flattened feature count
        in ``self._to_linear``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        if self._to_linear is None:
            self._to_linear = x[0].numel()  # channels * height * width of one sample
        return x

    def forward(self, x):
        """Return class probabilities for a batch ``x`` of shape (N, 1, H, W)."""
        x = self.convs(x)
        x = x.view(-1, self._to_linear)  # flatten to (N, _to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # raw per-class scores (logits)
        # Softmax pairs with the MSELoss / one-hot targets this notebook trains with;
        # return the logits instead if you switch to nn.CrossEntropyLoss.
        return F.softmax(x, dim=1)

In [32]:
# Build the model once and print its layer summary.
net = Net()
print(repr(net))

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=2, bias=True)
)


# Training

In [33]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()  # compares softmax outputs against one-hot targets

# Each record of training_data is indexed as [0] = image, [1] = label.
# Stack into one ndarray first: building a tensor from a Python list of
# ndarrays is very slow.
X = torch.tensor(np.array([i[0] for i in training_data]), dtype=torch.float32).view(-1, 50, 50)
X = X / 255.0  # assumes 8-bit pixel values -- scales them to [0, 1]
y = torch.tensor(np.array([i[1] for i in training_data]), dtype=torch.float32)

VAL_PCT = 0.1  # reserve 10% of the data for validation
val_size = int(len(X) * VAL_PCT)

# Explicit split index: slicing with -val_size breaks when val_size == 0
# (X[:-0] is empty and X[-0:] is the whole set).
# NOTE(review): no shuffle here -- assumes training_data was shuffled when built.
split = len(X) - val_size

train_X = X[:split]
train_y = y[:split]

test_X = X[split:]
test_y = y[split:]



In [34]:
# test_X is the validation slice, so the first and last counts match.
split_sizes = (
    ('Validation Size:', val_size),
    ('Train Size:', len(train_X)),
    ('Test Size:', len(test_X)),
)
for label, size in split_sizes:
    print(label, size)

Validation Size: 2494
Train Size: 22452
Test Size: 2494


In [37]:
BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
    # Walk train_X in BATCH_SIZE chunks; the final batch may be smaller.
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i + BATCH_SIZE].view(-1, 1, 50, 50)
        batch_y = train_y[i:i + BATCH_SIZE]
        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()  # apply the Adam update

    # Loss of the epoch's last batch; .item() prints a plain float, not the tensor repr.
    print(f"Epoch: {epoch}. Loss: {loss.item()}")



  0%|          | 0/225 [00:00<?, ?it/s][A
  0%|          | 1/225 [00:00<00:35,  6.33it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  1%|          | 2/225 [00:00<00:36,  6.14it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  1%|▏         | 3/225 [00:00<00:35,  6.18it/s][A
  2%|▏         | 4/225 [00:00<00:37,  5.92it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  2%|▏         | 5/225 [00:00<00:36,  6.10it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  3%|▎         | 6/225 [00:01<00:37,  5.77it/s][A
  3%|▎         | 7/225 [00:01<00:39,  5.49it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  4%|▎         | 8/225 [00:01<00:40,  5.36it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  4%|▍         | 9/225 [00:01<00:44,  4.85it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  4%|▍         | 10/225 [00:01<00:42,  5.10it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  5%|▍         | 11/225 [00:02<00:44,  4.86it/s][A
  5%|▌         | 12/225 [00:02<00:41,  5.12it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  6%|▌         | 13/225 [00:02<00:39,  5.34it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  6%|▌         | 14/225 [00:02<00:38,  5.51it/s][A
  7%|▋         | 15/225 [00:02<00:38,  5.46it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



  7%|▋         | 16/225 [00:02<00:37,  5.55it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  8%|▊         | 17/225 [00:03<00:37,  5.53it/s][A
  8%|▊         | 18/225 [00:03<00:34,  5.93it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  8%|▊         | 19/225 [00:03<00:33,  6.13it/s][A
  9%|▉         | 20/225 [00:03<00:32,  6.33it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



  9%|▉         | 21/225 [00:03<00:31,  6.39it/s][A
 10%|▉         | 22/225 [00:03<00:31,  6.45it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 10%|█         | 23/225 [00:04<00:35,  5.66it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 11%|█         | 24/225 [00:04<00:35,  5.66it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 11%|█         | 25/225 [00:04<00:39,  5.07it/s][A
 12%|█▏        | 26/225 [00:04<00:38,  5.19it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 12%|█▏        | 27/225 [00:04<00:36,  5.44it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 12%|█▏        | 28/225 [00:05<00:35,  5.62it/s][A
 13%|█▎        | 29/225 [00:05<00:34,  5.65it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 13%|█▎        | 30/225 [00:05<00:33,  5.83it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 14%|█▍        | 31/225 [00:05<00:33,  5.85it/s][A
 14%|█▍        | 32/225 [00:05<00:34,  5.67it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 15%|█▍        | 33/225 [00:05<00:31,  6.01it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 15%|█▌        | 34/225 [00:06<00:32,  5.83it/s][A
 16%|█▌        | 35/225 [00:06<00:30,  6.18it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 16%|█▌        | 36/225 [00:06<00:29,  6.31it/s][A
 16%|█▋        | 37/225 [00:06<00:30,  6.16it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 17%|█▋        | 38/225 [00:06<00:35,  5.20it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 17%|█▋        | 39/225 [00:06<00:35,  5.30it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 18%|█▊        | 40/225 [00:07<00:38,  4.83it/s][A
 18%|█▊        | 41/225 [00:07<00:34,  5.40it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 19%|█▊        | 42/225 [00:07<00:35,  5.15it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 19%|█▉        | 43/225 [00:07<00:35,  5.10it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 20%|█▉        | 44/225 [00:07<00:38,  4.75it/s][A
 20%|██        | 45/225 [00:08<00:36,  4.87it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 20%|██        | 46/225 [00:08<00:38,  4.60it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 21%|██        | 47/225 [00:08<00:37,  4.72it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 21%|██▏       | 48/225 [00:08<00:37,  4.73it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 22%|██▏       | 49/225 [00:09<00:37,  4.66it/s][A
 22%|██▏       | 50/225 [00:09<00:36,  4.79it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 23%|██▎       | 51/225 [00:09<00:32,  5.34it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 23%|██▎       | 52/225 [00:09<00:31,  5.44it/s][A
 24%|██▎       | 53/225 [00:09<00:32,  5.34it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 24%|██▍       | 54/225 [00:10<00:34,  4.97it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 24%|██▍       | 55/225 [00:10<00:33,  5.13it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 25%|██▍       | 56/225 [00:10<00:31,  5.36it/s][A
 25%|██▌       | 57/225 [00:10<00:29,  5.65it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 26%|██▌       | 58/225 [00:10<00:28,  5.82it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 26%|██▌       | 59/225 [00:10<00:28,  5.90it/s][A
 27%|██▋       | 60/225 [00:10<00:27,  6.03it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 27%|██▋       | 61/225 [00:11<00:29,  5.53it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 28%|██▊       | 62/225 [00:11<00:28,  5.64it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 28%|██▊       | 63/225 [00:11<00:30,  5.36it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 28%|██▊       | 64/225 [00:11<00:31,  5.13it/s][A
 29%|██▉       | 65/225 [00:11<00:30,  5.27it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 29%|██▉       | 66/225 [00:12<00:29,  5.43it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 30%|██▉       | 67/225 [00:12<00:28,  5.59it/s][A
 30%|███       | 68/225 [00:12<00:29,  5.30it/s]

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])


[A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 31%|███       | 69/225 [00:12<00:31,  4.93it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 31%|███       | 70/225 [00:12<00:33,  4.70it/s][A
 32%|███▏      | 71/225 [00:13<00:31,  4.86it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 32%|███▏      | 72/225 [00:13<00:31,  4.93it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 32%|███▏      | 73/225 [00:13<00:27,  5.45it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 33%|███▎      | 74/225 [00:13<00:28,  5.22it/s][A
 33%|███▎      | 75/225 [00:13<00:26,  5.63it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 34%|███▍      | 76/225 [00:14<00:25,  5.83it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 34%|███▍      | 77/225 [00:14<00:23,  6.29it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 35%|███▍      | 78/225 [00:14<00:26,  5.63it/s][A
 35%|███▌      | 79/225 [00:14<00:25,  5.73it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 36%|███▌      | 80/225 [00:14<00:27,  5.29it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 36%|███▌      | 81/225 [00:14<00:26,  5.51it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 36%|███▋      | 82/225 [00:15<00:27,  5.15it/s][A
 37%|███▋      | 83/225 [00:15<00:25,  5.58it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 37%|███▋      | 84/225 [00:15<00:24,  5.80it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 38%|███▊      | 85/225 [00:15<00:23,  5.98it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 38%|███▊      | 86/225 [00:15<00:25,  5.48it/s][A
 39%|███▊      | 87/225 [00:15<00:24,  5.64it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 39%|███▉      | 88/225 [00:16<00:22,  6.15it/s][A
 40%|███▉      | 89/225 [00:16<00:20,  6.56it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 40%|████      | 90/225 [00:16<00:20,  6.59it/s][A
 40%|████      | 91/225 [00:16<00:21,  6.31it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 41%|████      | 92/225 [00:16<00:20,  6.44it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 41%|████▏     | 93/225 [00:16<00:19,  6.80it/s][A
 42%|████▏     | 94/225 [00:16<00:18,  7.05it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 42%|████▏     | 95/225 [00:17<00:19,  6.70it/s][A
 43%|████▎     | 96/225 [00:17<00:18,  6.99it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 43%|████▎     | 97/225 [00:17<00:18,  6.78it/s][A
 44%|████▎     | 98/225 [00:17<00:19,  6.52it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 44%|████▍     | 99/225 [00:17<00:21,  5.73it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 44%|████▍     | 100/225 [00:18<00:23,  5.41it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 45%|████▍     | 101/225 [00:18<00:23,  5.36it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 45%|████▌     | 102/225 [00:18<00:21,  5.77it/s][A
 46%|████▌     | 103/225 [00:18<00:19,  6.20it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 46%|████▌     | 104/225 [00:18<00:18,  6.56it/s][A
 47%|████▋     | 105/225 [00:18<00:17,  6.78it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 47%|████▋     | 106/225 [00:18<00:18,  6.40it/s][A
 48%|████▊     | 107/225 [00:19<00:17,  6.73it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 48%|████▊     | 108/225 [00:19<00:16,  6.94it/s][A
 48%|████▊     | 109/225 [00:19<00:16,  7.16it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 49%|████▉     | 110/225 [00:19<00:15,  7.33it/s][A
 49%|████▉     | 111/225 [00:19<00:15,  7.48it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 50%|████▉     | 112/225 [00:19<00:15,  7.51it/s][A
 50%|█████     | 113/225 [00:19<00:14,  7.51it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 51%|█████     | 114/225 [00:19<00:14,  7.46it/s][A
 51%|█████     | 115/225 [00:20<00:16,  6.53it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 52%|█████▏    | 116/225 [00:20<00:17,  6.08it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 52%|█████▏    | 117/225 [00:20<00:19,  5.42it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 52%|█████▏    | 118/225 [00:20<00:19,  5.52it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 53%|█████▎    | 119/225 [00:20<00:17,  6.01it/s][A
 53%|█████▎    | 120/225 [00:21<00:16,  6.39it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 54%|█████▍    | 121/225 [00:21<00:15,  6.66it/s][A
 54%|█████▍    | 122/225 [00:21<00:14,  6.92it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 55%|█████▍    | 123/225 [00:21<00:14,  6.89it/s][A
 55%|█████▌    | 124/225 [00:21<00:14,  7.02it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 56%|█████▌    | 125/225 [00:21<00:14,  6.73it/s][A
 56%|█████▌    | 126/225 [00:21<00:14,  6.81it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 56%|█████▋    | 127/225 [00:22<00:15,  6.53it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 57%|█████▋    | 128/225 [00:22<00:14,  6.82it/s][A
 57%|█████▋    | 129/225 [00:22<00:13,  7.03it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 58%|█████▊    | 130/225 [00:22<00:13,  7.16it/s][A
 58%|█████▊    | 131/225 [00:22<00:12,  7.26it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 59%|█████▊    | 132/225 [00:22<00:13,  6.82it/s][A
 59%|█████▉    | 133/225 [00:22<00:13,  6.70it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 60%|█████▉    | 134/225 [00:23<00:15,  5.83it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 60%|██████    | 135/225 [00:23<00:17,  5.26it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 60%|██████    | 136/225 [00:23<00:17,  4.98it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 61%|██████    | 137/225 [00:23<00:17,  4.94it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 61%|██████▏   | 138/225 [00:23<00:16,  5.13it/s][A
 62%|██████▏   | 139/225 [00:24<00:15,  5.40it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 62%|██████▏   | 140/225 [00:24<00:15,  5.64it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 63%|██████▎   | 141/225 [00:24<00:14,  5.78it/s][A
 63%|██████▎   | 142/225 [00:24<00:14,  5.93it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 64%|██████▎   | 143/225 [00:24<00:14,  5.67it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 64%|██████▍   | 144/225 [00:24<00:13,  5.98it/s][A
 64%|██████▍   | 145/225 [00:25<00:13,  5.90it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 65%|██████▍   | 146/225 [00:25<00:12,  6.13it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 65%|██████▌   | 147/225 [00:25<00:12,  6.05it/s][A
 66%|██████▌   | 148/225 [00:25<00:11,  6.42it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 66%|██████▌   | 149/225 [00:25<00:12,  5.97it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 67%|██████▋   | 150/225 [00:25<00:12,  5.85it/s][A
 67%|██████▋   | 151/225 [00:26<00:13,  5.63it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 68%|██████▊   | 152/225 [00:26<00:12,  5.79it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 68%|██████▊   | 153/225 [00:26<00:11,  6.09it/s][A
 68%|██████▊   | 154/225 [00:26<00:11,  6.24it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 69%|██████▉   | 155/225 [00:26<00:11,  5.84it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 69%|██████▉   | 156/225 [00:27<00:11,  5.79it/s][A
 70%|██████▉   | 157/225 [00:27<00:11,  6.06it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 70%|███████   | 158/225 [00:27<00:10,  6.23it/s][A
 71%|███████   | 159/225 [00:27<00:10,  6.12it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 71%|███████   | 160/225 [00:27<00:11,  5.90it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 72%|███████▏  | 161/225 [00:27<00:11,  5.58it/s][A
 72%|███████▏  | 162/225 [00:27<00:10,  5.98it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 72%|███████▏  | 163/225 [00:28<00:10,  5.79it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 73%|███████▎  | 164/225 [00:28<00:10,  5.98it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 73%|███████▎  | 165/225 [00:28<00:10,  5.47it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 74%|███████▍  | 166/225 [00:28<00:11,  5.14it/s][A
 74%|███████▍  | 167/225 [00:28<00:11,  5.22it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 75%|███████▍  | 168/225 [00:29<00:10,  5.54it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 75%|███████▌  | 169/225 [00:29<00:10,  5.42it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 76%|███████▌  | 170/225 [00:29<00:11,  4.73it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 76%|███████▌  | 171/225 [00:29<00:12,  4.40it/s][A
 76%|███████▋  | 172/225 [00:30<00:11,  4.75it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 77%|███████▋  | 173/225 [00:30<00:09,  5.26it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 77%|███████▋  | 174/225 [00:30<00:10,  4.91it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 78%|███████▊  | 175/225 [00:30<00:10,  4.81it/s][A
 78%|███████▊  | 176/225 [00:30<00:10,  4.88it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 79%|███████▊  | 177/225 [00:31<00:09,  4.96it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 79%|███████▉  | 178/225 [00:31<00:08,  5.34it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 80%|███████▉  | 179/225 [00:31<00:09,  5.01it/s][A
 80%|████████  | 180/225 [00:31<00:08,  5.05it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 80%|████████  | 181/225 [00:31<00:08,  4.90it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 81%|████████  | 182/225 [00:32<00:09,  4.53it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 81%|████████▏ | 183/225 [00:32<00:09,  4.32it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 82%|████████▏ | 184/225 [00:32<00:09,  4.17it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 82%|████████▏ | 185/225 [00:32<00:08,  4.67it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 83%|████████▎ | 186/225 [00:32<00:08,  4.60it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 83%|████████▎ | 187/225 [00:33<00:08,  4.63it/s][A
 84%|████████▎ | 188/225 [00:33<00:07,  4.81it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 84%|████████▍ | 189/225 [00:33<00:06,  5.16it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 84%|████████▍ | 190/225 [00:33<00:06,  5.26it/s][A
 85%|████████▍ | 191/225 [00:33<00:06,  5.30it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 85%|████████▌ | 192/225 [00:34<00:06,  5.34it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 86%|████████▌ | 193/225 [00:34<00:05,  5.42it/s][A
 86%|████████▌ | 194/225 [00:34<00:05,  5.63it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 87%|████████▋ | 195/225 [00:34<00:05,  5.66it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 87%|████████▋ | 196/225 [00:34<00:05,  4.98it/s][A
 88%|████████▊ | 197/225 [00:35<00:05,  5.22it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 88%|████████▊ | 198/225 [00:35<00:05,  4.70it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 88%|████████▊ | 199/225 [00:35<00:06,  4.11it/s][A
 89%|████████▉ | 200/225 [00:35<00:05,  4.44it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 89%|████████▉ | 201/225 [00:35<00:05,  4.68it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 90%|████████▉ | 202/225 [00:36<00:04,  5.00it/s][A
 90%|█████████ | 203/225 [00:36<00:04,  5.38it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 91%|█████████ | 204/225 [00:36<00:03,  5.27it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 91%|█████████ | 205/225 [00:36<00:03,  5.30it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 92%|█████████▏| 206/225 [00:36<00:03,  5.61it/s][A
 92%|█████████▏| 207/225 [00:36<00:03,  5.83it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 92%|█████████▏| 208/225 [00:37<00:02,  5.92it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 93%|█████████▎| 209/225 [00:37<00:02,  5.89it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 93%|█████████▎| 210/225 [00:37<00:02,  5.38it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 94%|█████████▍| 211/225 [00:37<00:02,  4.68it/s][A
 94%|█████████▍| 212/225 [00:37<00:02,  5.15it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 95%|█████████▍| 213/225 [00:38<00:02,  5.27it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 95%|█████████▌| 214/225 [00:38<00:02,  5.19it/s][A
 96%|█████████▌| 215/225 [00:38<00:01,  5.28it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 96%|█████████▌| 216/225 [00:38<00:01,  5.52it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 96%|█████████▋| 217/225 [00:38<00:01,  5.86it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 97%|█████████▋| 218/225 [00:39<00:01,  5.04it/s][A
 97%|█████████▋| 219/225 [00:39<00:01,  5.29it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 98%|█████████▊| 220/225 [00:39<00:00,  5.50it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



 98%|█████████▊| 221/225 [00:39<00:00,  5.54it/s][A
 99%|█████████▊| 222/225 [00:39<00:00,  5.68it/s][A

torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])



 99%|█████████▉| 223/225 [00:39<00:00,  5.43it/s][A

torch.Size([100, 2])
torch.Size([100, 1, 50, 50]) torch.Size([100, 2])
torch.Size([100, 2])



100%|█████████▉| 224/225 [00:40<00:00,  5.37it/s][A
100%|██████████| 225/225 [00:40<00:00,  5.58it/s][A

torch.Size([52, 1, 50, 50]) torch.Size([52, 2])
torch.Size([52, 2])
Epoch: 0. Loss: 0.19713585078716278





# Test accuracy

In [38]:
correct = 0
total = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        # net returns a (1, num_classes) tensor of probabilities for this one
        # image; [0] selects its only row.
        net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
        predicted_class = torch.argmax(net_out)

        if predicted_class == real_class:
            correct += 1
        total += 1
# Guard against an empty test set instead of raising ZeroDivisionError.
print("Accuracy: ", round(correct / total, 3) if total else float('nan'))


  0%|          | 0/2494 [00:00<?, ?it/s][A
  1%|▏         | 32/2494 [00:00<00:07, 317.97it/s][A
  4%|▍         | 95/2494 [00:00<00:06, 373.27it/s][A
  7%|▋         | 169/2494 [00:00<00:05, 438.19it/s][A
  9%|▉         | 231/2494 [00:00<00:04, 479.95it/s][A
 12%|█▏        | 304/2494 [00:00<00:04, 534.05it/s][A
 15%|█▍        | 367/2494 [00:00<00:03, 559.00it/s][A
 18%|█▊        | 442/2494 [00:00<00:03, 603.74it/s][A
 21%|██        | 517/2494 [00:00<00:03, 640.79it/s][A
 24%|██▎       | 588/2494 [00:00<00:02, 659.47it/s][A
 27%|██▋       | 671/2494 [00:01<00:02, 702.14it/s][A
 30%|███       | 753/2494 [00:01<00:02, 731.78it/s][A
 33%|███▎      | 828/2494 [00:01<00:02, 722.65it/s][A
 36%|███▌      | 902/2494 [00:01<00:02, 716.52it/s][A
 39%|███▉      | 975/2494 [00:01<00:02, 719.59it/s][A
 42%|████▏     | 1054/2494 [00:01<00:01, 738.30it/s][A
 45%|████▌     | 1129/2494 [00:01<00:01, 737.16it/s][A
 48%|████▊     | 1204/2494 [00:01<00:01, 735.43it/s][A
 51%|█████     | 12

Accuracy:  0.642





# Train on GPU

In [42]:
# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
# (Other GPUs would be addressed as cuda:1, cuda:2, ...)
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
print("Running on the GPU:" if use_gpu else "Running on the CPU", device)

Running on the GPU: cuda:0


In [44]:
# Fresh, untrained model whose parameters live on the selected device.
net = Net()
net = net.to(device)

In [49]:
def train(net, BATCH_SIZE=100, EPOCHS=3, lr=0.001):
    """Train ``net`` in place with Adam on the notebook's training split.

    Args:
        net: model to train; its parameters are updated in place.
        BATCH_SIZE: number of samples per optimization step.
        EPOCHS: number of full passes over ``train_X``.
        lr: Adam learning rate (default matches the original hard-coded 0.001).

    Reads the notebook globals ``train_X``, ``train_y``, ``loss_function`` and
    ``device``. Prints the loss of the last batch of each epoch. A fresh
    optimizer is created per call, so Adam state does not carry over between calls.
    """
    optimizer = optim.Adam(net.parameters(), lr=lr)
    for epoch in range(EPOCHS):
        for i in range(0, len(train_X), BATCH_SIZE):
            batch_X = train_X[i:i + BATCH_SIZE].view(-1, 1, 50, 50).to(device)
            batch_y = train_y[i:i + BATCH_SIZE].to(device)

            # The optimizer owns every parameter of net, so this alone clears all
            # gradients (calling net.zero_grad() as well was redundant).
            optimizer.zero_grad()
            outputs = net(batch_X)
            loss = loss_function(outputs, batch_y)
            loss.backward()
            optimizer.step()  # apply the Adam update

        # .item() prints a plain float instead of the (device-tagged) tensor repr.
        print(f"Epoch: {epoch}. Loss: {loss.item()}")


In [None]:
# Ten passes over train_X at the default batch size of 100.
train(net, BATCH_SIZE=100, EPOCHS=10)

Epoch: 0. Loss: 0.09853548556566238
Epoch: 1. Loss: 0.10401438921689987
Epoch: 2. Loss: 0.09633658826351166
Epoch: 3. Loss: 0.1109447181224823
Epoch: 4. Loss: 0.0674598291516304


# Test Accuracy

In [46]:
correct = 0
total = 0
with torch.no_grad():  # evaluation only: don't build autograd graphs for every batch
    for start in tqdm(range(0, len(test_X), BATCH_SIZE)):
        batch_X = test_X[start:start + BATCH_SIZE].view(-1, 1, 50, 50).to(device)
        batch_y = test_y[start:start + BATCH_SIZE].to(device)
        batch_out = net(batch_X)

        # Row-wise argmax picks the predicted / true class index for each sample.
        predicted = batch_out.argmax(dim=1)
        target = batch_y.argmax(dim=1)
        correct += (predicted == target).sum().item()
        total += target.size(0)
# Guard against an empty test set instead of raising ZeroDivisionError.
print("Accuracy: ", round(correct / total, 3) if total else float('nan'))


  0%|          | 0/25 [00:00<?, ?it/s][A
 16%|█▌        | 4/25 [00:00<00:00, 36.00it/s][A
 40%|████      | 10/25 [00:00<00:00, 40.28it/s][A
 64%|██████▍   | 16/25 [00:00<00:00, 44.21it/s][A
100%|██████████| 25/25 [00:00<00:00, 53.13it/s][A

Accuracy:  0.755



