In [1]:
import numpy as np
import cv2
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from torchvision import models

In [3]:
REBUILD_DATA = False
class DogsVCats:
    IMG_SIZE = 50
    CATS = "PetImages/Cat"
    DOGS = "PetImages/Dog"
    LABELS = {CATS: 0, DOGS: 1}
    training_data = []
    catcount = 0
    dogcount = 0
    def train(self):
        for label in self.LABELS:
            for img in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, img)
                    image = cv2.imread(path, 0)
                    image = cv2.resize(image, (self.IMG_SIZE, self.IMG_SIZE))
                    self.training_data.append([np.array(image), np.eye(2)[self.LABELS[label]]])

                    if label == self.CATS:
                        self.catcount+=1
                    if label == self.DOGS:
                        self.dogcount+=1
                except Exception as e:
                    pass
        np.random.shuffle(self.training_data)
        np.save("training_data.npy", self.training_data)


In [4]:
if REBUILD_DATA:
    dogsvcats = DogsVCats()
    dogsvcats.train()
    
training_data = np.load("training_data.npy", allow_pickle = True)
print(len(training_data))

24946


In [16]:
import torch
x, y = next(iter(training_data))
y.shape

t = torch.tensor([[0.8587, 0.1413]])
t.shape

torch.Size([1, 2])

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        self.fc1 = nn.Linear(128*2*2, 512)
        self.fc2 = nn.Linear(512, 2)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        
        x = x.flatten(start_dim=1)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return F.softmax(x, dim=1)
        
net = Net()

In [6]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()

X = torch.Tensor([i[0] for i in training_data]).view(-1,50,50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.1  # lets reserve 10% of our data for validation
val_size = int(len(X)*VAL_PCT)
print(val_size)

train_X = X[:-val_size]
train_y = y[:-val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]

# print(len(train_X), len(test_X))


2494


In [7]:
BATCH_SIZE = 100
EPOCHS = 8

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)): # from 0, to the len of x, stepping BATCH_SIZE at a time. [:50] ..for now just to dev
        #print(f"{i}:{i+BATCH_SIZE}")
        batch_X = train_X[i:i+BATCH_SIZE].view(-1, 1, 50, 50)
        batch_y = train_y[i:i+BATCH_SIZE]

        net.zero_grad()

        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()    # Does the update

    print(f"Epoch: {epoch}. Loss: {loss}")

100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:16<00:00,  1.65it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 0. Loss: 0.23219765722751617


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:22<00:00,  1.58it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 1. Loss: 0.20397509634494781


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:23<00:00,  1.57it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 2. Loss: 0.16770240664482117


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:07<00:00,  1.76it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 3. Loss: 0.1571255326271057


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:12<00:00,  1.70it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 4. Loss: 0.13298757374286652


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:14<00:00,  1.68it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 5. Loss: 0.1117539331316948


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:11<00:00,  1.71it/s]
  0%|                                                                                          | 0/225 [00:00<?, ?it/s]

Epoch: 6. Loss: 0.1000947579741478


100%|████████████████████████████████████████████████████████████████████████████████| 225/225 [02:16<00:00,  1.64it/s]

Epoch: 7. Loss: 0.0811450406908989





In [8]:
correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_output = net(test_X[i].view(-1, 1, 50, 50))
        predicted_class = torch.argmax(net_output)
        if predicted_class == real_class:
            correct+=1
        total+=1

print("Acc ", correct/total)

100%|█████████████████████████████████████████████████████████████████████████████| 2494/2494 [00:09<00:00, 269.16it/s]

Acc  0.7582197273456295





In [11]:
torch.save(net, "DogieVCatie.pt")



In [12]:
img = cv2.imread("dog.jpg", 0)
img = cv2.resize(img, (50, 50))
img = np.array(img)
img = torch.tensor(img).view(-1, 1, 50, 50)
img = img/255.0
op = torch.argmax(net(img))

if op.item() == 0:
    print("Its a cat")
else:
    
    print("Its a dog")

Its a dog


In [7]:
LABELS = {"CATS": 0, "DOGS": 1}
len(LABELS)
for key in range(len(LABELS)):
    print(key)

0
1
