In [1]:
import torch, torchvision
import cv2
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import numpy as np

In [2]:
REBUILD_DATA = True
class CatDogs:
    IMG_SIZE = 50
    Cats = "../PetImages/Cats/"
    Dogs = "../PetImages/Dogs/"
    
    LABELS = {Cats:0, Dogs:1}
    training_data = []
    catcount = 0
    dogcount = 0
    
    def make_training_data(self):
        for label in self.LABELS:
#             print(label)
            for i in tqdm(os.listdir(label)):
                try:
                    PATH = os.path.join(label,i)
                    img = cv2.imread(PATH, cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
                    self.training_data.append([np.array(img), 
                                               np.eye(2)[self.LABELS[label]]])
                except Exception as e:
                    pass
                
                if label == self.Cats:
                            self.catcount += 1
                elif label == self.Dogs:
                    self.dogcount += 1
                
                

        np.random.shuffle(self.training_data)
        np.save("training_data.npy", self.training_data)
        print(len(self.training_data))
        print('Cats:', self.catcount)
        print('Dogs:', self.dogcount)
        
        
if REBUILD_DATA:
    catdogs = CatDogs()
    catdogs.make_training_data()

100%|██████████| 305/305 [00:00<00:00, 706.23it/s]
100%|██████████| 306/306 [00:00<00:00, 661.65it/s]

609
Cats: 305
Dogs: 306





In [3]:
training_data = np.load("training_data.npy", allow_pickle=True)
len(training_data)

609

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,32,5)
        self.conv2 = nn.Conv2d(32,64,5)
        self.conv3 = nn.Conv2d(64,128,5)
        
        x = torch.randn(50,50).view(-1,1,50,50)
        self._to_linear = None
        
        self.convs(x)
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512,2)

    def convs(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2,2))
        
#         print(x[0].shape)
        
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
#             print(self._to_linear)
        return x
    
    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return F.softmax(x, dim=1)
    
    
net = CNN()

In [10]:
optimizer = optim.Adam(net.parameters(), lr = 0.01)
loss_function = nn.MSELoss()
X = torch.Tensor([i[0] for i in training_data]).reshape(-1, 50, 50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

VAL_PCT = 0.1
val_size = int(len(X)*VAL_PCT)
val_size, len(X),len(y)

(60, 609, 609)

In [11]:
train_X = X[:val_size]
train_y = y[:val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]
len(train_X), len(train_y), len(test_X), len(test_y)

(60, 60, 60, 60)

In [12]:
BATCH = 1000
EPOCHS = 1

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH)):
#         print(i, i+BATCH)
        batch_X = train_X[i:i+BATCH].view(-1,1,50,50)
        batch_y = train_y[i:i+BATCH]

        net.zero_grad()
        output = net(batch_X)
        loss = loss_function(output, batch_y)
        loss.backward()
        optimizer.step()
print(loss)

100%|██████████| 1/1 [00:00<00:00,  2.62it/s]

tensor(0.2498, grad_fn=<MseLossBackward>)





In [13]:
correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_out = net(test_X[i].view(-1,1,50,50))[0]
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total +=1
print("Accuracy: ", round(correct/total,3))

100%|██████████| 60/60 [00:00<00:00, 236.48it/s]

Accuracy:  0.55



