In [2]:
!wget https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip # Download dataset for Google Colab
!unzip /content/kagglecatsanddogs_3367a.zip
print("Download Complete")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Download Complete


In [6]:
import os
import cv2
import numpy as np
from tqdm import tqdm

REBUILD_DATA = False # set to true to one once, then back to false unless you want to change something in your training data.

class DogsVSCats():
    IMG_SIZE = 50
    CATS = "PetImages/Cat"
    DOGS = "PetImages/Dog"
    TESTING = "PetImages/Testing"
    LABELS = {CATS: 0, DOGS: 1}
    training_data = []

    catcount = 0
    dogcount = 0

    def make_training_data(self):
        for label in self.LABELS:
            print(label)
            for f in tqdm(os.listdir(label)):
                if "jpg" in f:
                    try:
                        path = os.path.join(label, f)
                        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                        img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
                        self.training_data.append([np.array(img), np.eye(2)[self.LABELS[label]]])  

                        if label == self.CATS:
                            self.catcount += 1
                        elif label == self.DOGS:
                            self.dogcount += 1

                    except Exception as e:
                        pass
                        #print(label, f, str(e))

        np.random.shuffle(self.training_data)
        np.save("training_data.npy", self.training_data)
        print('Cats:',dogsvcats.catcount)
        print('Dogs:',dogsvcats.dogcount)

if REBUILD_DATA:
    dogsvcats = DogsVSCats()
    dogsvcats.make_training_data()


training_data = np.load("training_data.npy", allow_pickle=True)
print(len(training_data))

24946


In [22]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5) #input, output, kernel size ( 5x5 kernel)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)  

        x = torch.randn(50, 50).view(-1, 1, 50, 50) # -1 means any input dimension, 1 means input dimension, 50x50 is shape of input. (((batch_size, size(dim or channel), output_shape)))
        self._to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self._to_linear, 512) #fully connected layer
        self.fc2 = nn.Linear(512, 2)

    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) #2x2 max pooling output shape
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        print(x[0].shape)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]* x[0].shape[1]* x[0].shape[2]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
net = Net()
print(net)

torch.Size([128, 2, 2])
Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=2, bias=True)
)


In [23]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()

X = torch.Tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0 # Set pixel values between 0 and 1
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.1 # value percent. Test against %10 percent of our dataset.	In our case testset is approximately 2500.
val_size = int(len(X) * VAL_PCT)
print(val_size)

2494


In [24]:
train_X = X[:-val_size]
train_y = y[:-val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]

print(len(train_X))
print(len(test_X))

22452
2494


In [25]:
# TRAINING SECTION 

BATCH_SIZE = 100
EPOCHS = 1

for epoch in range(EPOCHS):
  for i in tqdm(range(0, len(train_X), BATCH_SIZE)): #between 0 - len(train_X) with BATCH_SIZE as a step value.
    batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
    batch_y = train_y[i:i+BATCH_SIZE]

    net.zero_grad()
    outputs = net(batch_X)
    loss = loss_function(outputs, batch_y)
    loss.backward()
    optimizer.step()
print("loss :", loss)

  0%|          | 0/225 [00:00<?, ?it/s]

torch.Size([128, 2, 2])


  1%|          | 2/225 [00:00<01:17,  2.89it/s]

torch.Size([128, 2, 2])


  1%|▏         | 3/225 [00:00<01:10,  3.16it/s]

torch.Size([128, 2, 2])


  2%|▏         | 4/225 [00:01<01:05,  3.39it/s]

torch.Size([128, 2, 2])


  2%|▏         | 5/225 [00:01<01:02,  3.52it/s]

torch.Size([128, 2, 2])


  3%|▎         | 6/225 [00:01<00:59,  3.67it/s]

torch.Size([128, 2, 2])


  3%|▎         | 7/225 [00:01<00:57,  3.76it/s]

torch.Size([128, 2, 2])


  4%|▎         | 8/225 [00:02<00:56,  3.82it/s]

torch.Size([128, 2, 2])


  4%|▍         | 9/225 [00:02<00:56,  3.84it/s]

torch.Size([128, 2, 2])


  4%|▍         | 10/225 [00:02<00:55,  3.88it/s]

torch.Size([128, 2, 2])


  5%|▍         | 11/225 [00:02<00:54,  3.92it/s]

torch.Size([128, 2, 2])


  5%|▌         | 12/225 [00:03<00:53,  3.95it/s]

torch.Size([128, 2, 2])


  6%|▌         | 13/225 [00:03<00:53,  3.97it/s]

torch.Size([128, 2, 2])


  6%|▌         | 14/225 [00:03<00:52,  3.98it/s]

torch.Size([128, 2, 2])


  7%|▋         | 15/225 [00:03<00:52,  4.00it/s]

torch.Size([128, 2, 2])


  7%|▋         | 16/225 [00:04<00:52,  4.00it/s]

torch.Size([128, 2, 2])


  8%|▊         | 17/225 [00:04<00:52,  3.98it/s]

torch.Size([128, 2, 2])


  8%|▊         | 18/225 [00:04<00:52,  3.97it/s]

torch.Size([128, 2, 2])


  8%|▊         | 19/225 [00:04<00:51,  3.96it/s]

torch.Size([128, 2, 2])


  9%|▉         | 20/225 [00:05<00:51,  3.97it/s]

torch.Size([128, 2, 2])


  9%|▉         | 21/225 [00:05<00:52,  3.88it/s]

torch.Size([128, 2, 2])


 10%|▉         | 22/225 [00:05<00:52,  3.89it/s]

torch.Size([128, 2, 2])


 10%|█         | 23/225 [00:05<00:51,  3.92it/s]

torch.Size([128, 2, 2])


 11%|█         | 24/225 [00:06<00:50,  3.95it/s]

torch.Size([128, 2, 2])


 11%|█         | 25/225 [00:06<00:50,  3.94it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 26/225 [00:06<00:50,  3.94it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 27/225 [00:06<00:50,  3.94it/s]

torch.Size([128, 2, 2])


 12%|█▏        | 28/225 [00:07<00:49,  3.96it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 29/225 [00:07<00:49,  3.93it/s]

torch.Size([128, 2, 2])


 13%|█▎        | 30/225 [00:07<00:49,  3.92it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 31/225 [00:07<00:49,  3.95it/s]

torch.Size([128, 2, 2])


 14%|█▍        | 32/225 [00:08<00:48,  3.99it/s]

torch.Size([128, 2, 2])


 15%|█▍        | 33/225 [00:08<00:48,  3.96it/s]

torch.Size([128, 2, 2])


 15%|█▌        | 34/225 [00:08<00:48,  3.96it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 35/225 [00:08<00:49,  3.85it/s]

torch.Size([128, 2, 2])


 16%|█▌        | 36/225 [00:09<00:48,  3.91it/s]

torch.Size([128, 2, 2])


 16%|█▋        | 37/225 [00:09<00:47,  3.92it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 38/225 [00:09<00:47,  3.96it/s]

torch.Size([128, 2, 2])


 17%|█▋        | 39/225 [00:09<00:46,  3.98it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 40/225 [00:10<00:45,  4.03it/s]

torch.Size([128, 2, 2])


 18%|█▊        | 41/225 [00:10<00:46,  3.99it/s]

torch.Size([128, 2, 2])


 19%|█▊        | 42/225 [00:10<00:45,  4.02it/s]

torch.Size([128, 2, 2])


 19%|█▉        | 43/225 [00:10<00:45,  4.02it/s]

torch.Size([128, 2, 2])


 20%|█▉        | 44/225 [00:11<00:44,  4.02it/s]

torch.Size([128, 2, 2])


 20%|██        | 45/225 [00:11<00:45,  3.96it/s]

torch.Size([128, 2, 2])


 20%|██        | 46/225 [00:11<00:45,  3.97it/s]

torch.Size([128, 2, 2])


 21%|██        | 47/225 [00:11<00:44,  3.98it/s]

torch.Size([128, 2, 2])


 21%|██▏       | 48/225 [00:12<00:44,  3.99it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 49/225 [00:12<00:44,  3.96it/s]

torch.Size([128, 2, 2])


 22%|██▏       | 50/225 [00:12<00:44,  3.94it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 51/225 [00:12<00:44,  3.92it/s]

torch.Size([128, 2, 2])


 23%|██▎       | 52/225 [00:13<00:43,  3.95it/s]

torch.Size([128, 2, 2])


 24%|██▎       | 53/225 [00:13<00:43,  3.95it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 54/225 [00:13<00:43,  3.95it/s]

torch.Size([128, 2, 2])


 24%|██▍       | 55/225 [00:14<00:42,  3.97it/s]

torch.Size([128, 2, 2])


 25%|██▍       | 56/225 [00:14<00:42,  3.96it/s]

torch.Size([128, 2, 2])


 25%|██▌       | 57/225 [00:14<00:42,  3.93it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 58/225 [00:14<00:42,  3.94it/s]

torch.Size([128, 2, 2])


 26%|██▌       | 59/225 [00:15<00:42,  3.94it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 60/225 [00:15<00:41,  3.97it/s]

torch.Size([128, 2, 2])


 27%|██▋       | 61/225 [00:15<00:41,  3.94it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 62/225 [00:15<00:41,  3.93it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 63/225 [00:16<00:41,  3.92it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 64/225 [00:16<00:40,  3.94it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 65/225 [00:16<00:40,  3.91it/s]

torch.Size([128, 2, 2])


 29%|██▉       | 66/225 [00:16<00:40,  3.94it/s]

torch.Size([128, 2, 2])


 30%|██▉       | 67/225 [00:17<00:40,  3.92it/s]

torch.Size([128, 2, 2])


 30%|███       | 68/225 [00:17<00:39,  3.93it/s]

torch.Size([128, 2, 2])


 31%|███       | 69/225 [00:17<00:39,  3.92it/s]

torch.Size([128, 2, 2])


 31%|███       | 70/225 [00:17<00:39,  3.92it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 71/225 [00:18<00:39,  3.92it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 72/225 [00:18<00:39,  3.90it/s]

torch.Size([128, 2, 2])


 32%|███▏      | 73/225 [00:18<00:39,  3.89it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 74/225 [00:18<00:38,  3.93it/s]

torch.Size([128, 2, 2])


 33%|███▎      | 75/225 [00:19<00:38,  3.90it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 76/225 [00:19<00:38,  3.88it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 77/225 [00:19<00:38,  3.87it/s]

torch.Size([128, 2, 2])


 35%|███▍      | 78/225 [00:19<00:37,  3.90it/s]

torch.Size([128, 2, 2])


 35%|███▌      | 79/225 [00:20<00:37,  3.88it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 80/225 [00:20<00:37,  3.89it/s]

torch.Size([128, 2, 2])


 36%|███▌      | 81/225 [00:20<00:37,  3.88it/s]

torch.Size([128, 2, 2])


 36%|███▋      | 82/225 [00:20<00:36,  3.91it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 83/225 [00:21<00:36,  3.92it/s]

torch.Size([128, 2, 2])


 37%|███▋      | 84/225 [00:21<00:35,  3.95it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 85/225 [00:21<00:35,  3.90it/s]

torch.Size([128, 2, 2])


 38%|███▊      | 86/225 [00:21<00:35,  3.91it/s]

torch.Size([128, 2, 2])


 39%|███▊      | 87/225 [00:22<00:36,  3.80it/s]

torch.Size([128, 2, 2])


 39%|███▉      | 88/225 [00:22<00:35,  3.84it/s]

torch.Size([128, 2, 2])


 40%|███▉      | 89/225 [00:22<00:35,  3.86it/s]

torch.Size([128, 2, 2])


 40%|████      | 90/225 [00:22<00:35,  3.86it/s]

torch.Size([128, 2, 2])


 40%|████      | 91/225 [00:23<00:34,  3.89it/s]

torch.Size([128, 2, 2])


 41%|████      | 92/225 [00:23<00:34,  3.90it/s]

torch.Size([128, 2, 2])


 41%|████▏     | 93/225 [00:23<00:34,  3.88it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 94/225 [00:24<00:33,  3.85it/s]

torch.Size([128, 2, 2])


 42%|████▏     | 95/225 [00:24<00:33,  3.89it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 96/225 [00:24<00:33,  3.90it/s]

torch.Size([128, 2, 2])


 43%|████▎     | 97/225 [00:24<00:33,  3.87it/s]

torch.Size([128, 2, 2])


 44%|████▎     | 98/225 [00:25<00:33,  3.84it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 99/225 [00:25<00:32,  3.85it/s]

torch.Size([128, 2, 2])


 44%|████▍     | 100/225 [00:25<00:32,  3.87it/s]

torch.Size([128, 2, 2])


 45%|████▍     | 101/225 [00:25<00:31,  3.90it/s]

torch.Size([128, 2, 2])


 45%|████▌     | 102/225 [00:26<00:31,  3.93it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 103/225 [00:26<00:30,  3.96it/s]

torch.Size([128, 2, 2])


 46%|████▌     | 104/225 [00:26<00:30,  3.96it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 105/225 [00:26<00:30,  3.95it/s]

torch.Size([128, 2, 2])


 47%|████▋     | 106/225 [00:27<00:30,  3.94it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 107/225 [00:27<00:29,  3.95it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 108/225 [00:27<00:29,  3.95it/s]

torch.Size([128, 2, 2])


 48%|████▊     | 109/225 [00:27<00:29,  3.87it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 110/225 [00:28<00:29,  3.88it/s]

torch.Size([128, 2, 2])


 49%|████▉     | 111/225 [00:28<00:29,  3.89it/s]

torch.Size([128, 2, 2])


 50%|████▉     | 112/225 [00:28<00:28,  3.90it/s]

torch.Size([128, 2, 2])


 50%|█████     | 113/225 [00:28<00:28,  3.91it/s]

torch.Size([128, 2, 2])


 51%|█████     | 114/225 [00:29<00:28,  3.92it/s]

torch.Size([128, 2, 2])


 51%|█████     | 115/225 [00:29<00:27,  3.94it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 116/225 [00:29<00:28,  3.88it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 117/225 [00:29<00:27,  3.91it/s]

torch.Size([128, 2, 2])


 52%|█████▏    | 118/225 [00:30<00:27,  3.94it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 119/225 [00:30<00:26,  3.94it/s]

torch.Size([128, 2, 2])


 53%|█████▎    | 120/225 [00:30<00:26,  3.93it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 121/225 [00:30<00:26,  3.96it/s]

torch.Size([128, 2, 2])


 54%|█████▍    | 122/225 [00:31<00:26,  3.96it/s]

torch.Size([128, 2, 2])


 55%|█████▍    | 123/225 [00:31<00:25,  3.96it/s]

torch.Size([128, 2, 2])


 55%|█████▌    | 124/225 [00:31<00:25,  3.94it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 125/225 [00:31<00:25,  3.94it/s]

torch.Size([128, 2, 2])


 56%|█████▌    | 126/225 [00:32<00:24,  3.97it/s]

torch.Size([128, 2, 2])


 56%|█████▋    | 127/225 [00:32<00:24,  3.97it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 128/225 [00:32<00:24,  3.95it/s]

torch.Size([128, 2, 2])


 57%|█████▋    | 129/225 [00:32<00:24,  3.91it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 130/225 [00:33<00:24,  3.92it/s]

torch.Size([128, 2, 2])


 58%|█████▊    | 131/225 [00:33<00:23,  3.98it/s]

torch.Size([128, 2, 2])


 59%|█████▊    | 132/225 [00:33<00:23,  3.97it/s]

torch.Size([128, 2, 2])


 59%|█████▉    | 133/225 [00:33<00:23,  3.97it/s]

torch.Size([128, 2, 2])


 60%|█████▉    | 134/225 [00:34<00:22,  3.98it/s]

torch.Size([128, 2, 2])


 60%|██████    | 135/225 [00:34<00:22,  4.02it/s]

torch.Size([128, 2, 2])


 60%|██████    | 136/225 [00:34<00:22,  4.03it/s]

torch.Size([128, 2, 2])


 61%|██████    | 137/225 [00:34<00:22,  3.96it/s]

torch.Size([128, 2, 2])


 61%|██████▏   | 138/225 [00:35<00:21,  3.97it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 139/225 [00:35<00:21,  3.97it/s]

torch.Size([128, 2, 2])


 62%|██████▏   | 140/225 [00:35<00:21,  3.96it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 141/225 [00:35<00:21,  3.97it/s]

torch.Size([128, 2, 2])


 63%|██████▎   | 142/225 [00:36<00:20,  3.97it/s]

torch.Size([128, 2, 2])


 64%|██████▎   | 143/225 [00:36<00:20,  3.98it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 144/225 [00:36<00:20,  4.01it/s]

torch.Size([128, 2, 2])


 64%|██████▍   | 145/225 [00:36<00:20,  3.99it/s]

torch.Size([128, 2, 2])


 65%|██████▍   | 146/225 [00:37<00:19,  3.98it/s]

torch.Size([128, 2, 2])


 65%|██████▌   | 147/225 [00:37<00:19,  3.98it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 148/225 [00:37<00:19,  3.98it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 149/225 [00:37<00:19,  3.98it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 150/225 [00:38<00:18,  3.98it/s]

torch.Size([128, 2, 2])


 67%|██████▋   | 151/225 [00:38<00:18,  4.00it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 152/225 [00:38<00:18,  3.97it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 153/225 [00:38<00:18,  3.91it/s]

torch.Size([128, 2, 2])


 68%|██████▊   | 154/225 [00:39<00:17,  3.95it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 155/225 [00:39<00:17,  3.97it/s]

torch.Size([128, 2, 2])


 69%|██████▉   | 156/225 [00:39<00:17,  3.91it/s]

torch.Size([128, 2, 2])


 70%|██████▉   | 157/225 [00:39<00:17,  3.89it/s]

torch.Size([128, 2, 2])


 70%|███████   | 158/225 [00:40<00:17,  3.88it/s]

torch.Size([128, 2, 2])


 71%|███████   | 159/225 [00:40<00:16,  3.90it/s]

torch.Size([128, 2, 2])


 71%|███████   | 160/225 [00:40<00:16,  3.95it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 161/225 [00:40<00:16,  3.96it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 162/225 [00:41<00:15,  3.96it/s]

torch.Size([128, 2, 2])


 72%|███████▏  | 163/225 [00:41<00:15,  3.97it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 164/225 [00:41<00:15,  3.97it/s]

torch.Size([128, 2, 2])


 73%|███████▎  | 165/225 [00:41<00:15,  3.97it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 166/225 [00:42<00:14,  3.97it/s]

torch.Size([128, 2, 2])


 74%|███████▍  | 167/225 [00:42<00:14,  3.97it/s]

torch.Size([128, 2, 2])


 75%|███████▍  | 168/225 [00:42<00:14,  3.95it/s]

torch.Size([128, 2, 2])


 75%|███████▌  | 169/225 [00:43<00:14,  3.93it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 170/225 [00:43<00:14,  3.92it/s]

torch.Size([128, 2, 2])


 76%|███████▌  | 171/225 [00:43<00:13,  3.93it/s]

torch.Size([128, 2, 2])


 76%|███████▋  | 172/225 [00:43<00:13,  3.96it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 173/225 [00:44<00:13,  3.94it/s]

torch.Size([128, 2, 2])


 77%|███████▋  | 174/225 [00:44<00:12,  3.94it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 175/225 [00:44<00:12,  3.95it/s]

torch.Size([128, 2, 2])


 78%|███████▊  | 176/225 [00:44<00:12,  3.93it/s]

torch.Size([128, 2, 2])


 79%|███████▊  | 177/225 [00:45<00:12,  3.92it/s]

torch.Size([128, 2, 2])


 79%|███████▉  | 178/225 [00:45<00:11,  3.92it/s]

torch.Size([128, 2, 2])


 80%|███████▉  | 179/225 [00:45<00:11,  3.92it/s]

torch.Size([128, 2, 2])


 80%|████████  | 180/225 [00:45<00:11,  3.93it/s]

torch.Size([128, 2, 2])


 80%|████████  | 181/225 [00:46<00:11,  3.95it/s]

torch.Size([128, 2, 2])


 81%|████████  | 182/225 [00:46<00:10,  3.94it/s]

torch.Size([128, 2, 2])


 81%|████████▏ | 183/225 [00:46<00:10,  3.97it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 184/225 [00:46<00:10,  4.00it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 185/225 [00:47<00:10,  3.99it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 186/225 [00:47<00:09,  4.01it/s]

torch.Size([128, 2, 2])


 83%|████████▎ | 187/225 [00:47<00:09,  4.01it/s]

torch.Size([128, 2, 2])


 84%|████████▎ | 188/225 [00:47<00:09,  3.98it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 189/225 [00:48<00:09,  3.96it/s]

torch.Size([128, 2, 2])


 84%|████████▍ | 190/225 [00:48<00:08,  3.94it/s]

torch.Size([128, 2, 2])


 85%|████████▍ | 191/225 [00:48<00:08,  3.98it/s]

torch.Size([128, 2, 2])


 85%|████████▌ | 192/225 [00:48<00:08,  3.98it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 193/225 [00:49<00:08,  3.97it/s]

torch.Size([128, 2, 2])


 86%|████████▌ | 194/225 [00:49<00:07,  3.94it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 195/225 [00:49<00:07,  3.94it/s]

torch.Size([128, 2, 2])


 87%|████████▋ | 196/225 [00:49<00:07,  3.99it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 197/225 [00:50<00:07,  3.90it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 198/225 [00:50<00:06,  3.90it/s]

torch.Size([128, 2, 2])


 88%|████████▊ | 199/225 [00:50<00:06,  3.85it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 200/225 [00:50<00:06,  3.91it/s]

torch.Size([128, 2, 2])


 89%|████████▉ | 201/225 [00:51<00:06,  3.91it/s]

torch.Size([128, 2, 2])


 90%|████████▉ | 202/225 [00:51<00:05,  3.92it/s]

torch.Size([128, 2, 2])


 90%|█████████ | 203/225 [00:51<00:05,  3.89it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 204/225 [00:51<00:05,  3.88it/s]

torch.Size([128, 2, 2])


 91%|█████████ | 205/225 [00:52<00:05,  3.90it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 206/225 [00:52<00:04,  3.92it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 207/225 [00:52<00:04,  3.96it/s]

torch.Size([128, 2, 2])


 92%|█████████▏| 208/225 [00:52<00:04,  3.99it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 209/225 [00:53<00:04,  3.98it/s]

torch.Size([128, 2, 2])


 93%|█████████▎| 210/225 [00:53<00:03,  3.95it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 211/225 [00:53<00:03,  3.89it/s]

torch.Size([128, 2, 2])


 94%|█████████▍| 212/225 [00:53<00:03,  3.93it/s]

torch.Size([128, 2, 2])


 95%|█████████▍| 213/225 [00:54<00:03,  3.90it/s]

torch.Size([128, 2, 2])


 95%|█████████▌| 214/225 [00:54<00:02,  3.90it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 215/225 [00:54<00:02,  3.90it/s]

torch.Size([128, 2, 2])


 96%|█████████▌| 216/225 [00:54<00:02,  3.90it/s]

torch.Size([128, 2, 2])


 96%|█████████▋| 217/225 [00:55<00:02,  3.88it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 218/225 [00:55<00:01,  3.85it/s]

torch.Size([128, 2, 2])


 97%|█████████▋| 219/225 [00:55<00:01,  3.87it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 220/225 [00:55<00:01,  3.87it/s]

torch.Size([128, 2, 2])


 98%|█████████▊| 221/225 [00:56<00:01,  3.87it/s]

torch.Size([128, 2, 2])


 99%|█████████▊| 222/225 [00:56<00:00,  3.87it/s]

torch.Size([128, 2, 2])


 99%|█████████▉| 223/225 [00:56<00:00,  3.87it/s]

torch.Size([128, 2, 2])


100%|█████████▉| 224/225 [00:57<00:00,  3.89it/s]

torch.Size([128, 2, 2])


100%|██████████| 225/225 [00:57<00:00,  3.94it/s]

torch.Size([128, 2, 2])
loss : tensor(0.2526, grad_fn=<MseLossBackward>)





In [28]:
correct = 0 
total = 0
with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total += 1
print("Accuracy:", round(correct/total, 3))

  4%|▍         | 99/2494 [00:00<00:04, 485.10it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
Accuracy: 0.646



